| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210 | package middlewareimport (	"crypto/subtle"	"errors"	"net/http"	"strings"	"time"	"github.com/labstack/echo"	"github.com/labstack/gommon/random")type (	// CSRFConfig defines the config for CSRF middleware.	CSRFConfig struct {		// Skipper defines a function to skip middleware.		Skipper Skipper		// TokenLength is the length of the generated token.		TokenLength uint8 `yaml:"token_length"`		// Optional. Default value 32.		// TokenLookup is a string in the form of "<source>:<key>" that is used		// to extract token from the request.		// Optional. Default value "header:X-CSRF-Token".		// Possible values:		// - "header:<name>"		// - "form:<name>"		// - "query:<name>"		TokenLookup string `yaml:"token_lookup"`		// Context key to store generated CSRF token into context.		// Optional. Default value "csrf".		ContextKey string `yaml:"context_key"`		// Name of the CSRF cookie. This cookie will store CSRF token.		// Optional. Default value "csrf".		CookieName string `yaml:"cookie_name"`		// Domain of the CSRF cookie.		// Optional. Default value none.		CookieDomain string `yaml:"cookie_domain"`		// Path of the CSRF cookie.		// Optional. Default value none.		CookiePath string `yaml:"cookie_path"`		// Max age (in seconds) of the CSRF cookie.		// Optional. Default value 86400 (24hr).		CookieMaxAge int `yaml:"cookie_max_age"`		// Indicates if CSRF cookie is secure.		// Optional. Default value false.		CookieSecure bool `yaml:"cookie_secure"`		// Indicates if CSRF cookie is HTTP only.		// Optional. Default value false.		CookieHTTPOnly bool `yaml:"cookie_http_only"`	}	// csrfTokenExtractor defines a function that takes `echo.Context` and returns	// either a token or an error.	csrfTokenExtractor func(echo.Context) (string, error))var (	// DefaultCSRFConfig is the default CSRF middleware config.	DefaultCSRFConfig = CSRFConfig{		Skipper:      DefaultSkipper,		TokenLength:  32,		TokenLookup:  "header:" + echo.HeaderXCSRFToken,		ContextKey:   "csrf",		CookieName:   "_csrf",		CookieMaxAge: 86400,	})// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.// See: https://en.wikipedia.org/wiki/Cross-site_request_forgeryfunc CSRF() echo.MiddlewareFunc {	c := DefaultCSRFConfig	return CSRFWithConfig(c)}// CSRFWithConfig returns a CSRF middleware with config.// See `CSRF()`.func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {	// Defaults	if config.Skipper == nil {		config.Skipper = DefaultCSRFConfig.Skipper	}	if config.TokenLength == 0 {		config.TokenLength = DefaultCSRFConfig.TokenLength	}	if config.TokenLookup == "" {		config.TokenLookup = DefaultCSRFConfig.TokenLookup	}	if config.ContextKey == "" {		config.ContextKey = DefaultCSRFConfig.ContextKey	}	if config.CookieName == "" {		config.CookieName = DefaultCSRFConfig.CookieName	}	if config.CookieMaxAge == 0 {		config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge	}	// Initialize	parts := strings.Split(config.TokenLookup, ":")	extractor := csrfTokenFromHeader(parts[1])	switch parts[0] {	case "form":		extractor = csrfTokenFromForm(parts[1])	case "query":		extractor = csrfTokenFromQuery(parts[1])	}	return func(next echo.HandlerFunc) echo.HandlerFunc {		return func(c echo.Context) error {			if config.Skipper(c) {				return next(c)			}			req := c.Request()			k, err := c.Cookie(config.CookieName)			token := ""			// Generate token			if err != nil {				token = random.String(config.TokenLength)			} else {				// Reuse token				token = k.Value			}			switch req.Method {			case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE:			default:				// Validate token only for requests which are not defined as 'safe' by RFC7231				clientToken, err := extractor(c)				if err != nil {					return echo.NewHTTPError(http.StatusBadRequest, err.Error())				}				if !validateCSRFToken(token, clientToken) {					return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")				}			}			// Set CSRF cookie			cookie := new(http.Cookie)			cookie.Name = config.CookieName			cookie.Value = token			if config.CookiePath != "" {				cookie.Path = config.CookiePath			}			if config.CookieDomain != "" {				cookie.Domain = config.CookieDomain			}			cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)			cookie.Secure = config.CookieSecure			cookie.HttpOnly = config.CookieHTTPOnly			c.SetCookie(cookie)			// Store token in the context			c.Set(config.ContextKey, token)			// Protect clients from caching the response			c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie)			return next(c)		}	}}// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the// provided request header.func csrfTokenFromHeader(header string) csrfTokenExtractor {	return func(c echo.Context) (string, error) {		return c.Request().Header.Get(header), nil	}}// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the// provided form parameter.func csrfTokenFromForm(param string) csrfTokenExtractor {	return func(c echo.Context) (string, error) {		token := c.FormValue(param)		if token == "" {			return "", errors.New("missing csrf token in the form parameter")		}		return token, nil	}}// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the// provided query parameter.func csrfTokenFromQuery(param string) csrfTokenExtractor {	return func(c echo.Context) (string, error) {		token := c.QueryParam(param)		if token == "" {			return "", errors.New("missing csrf token in the query string")		}		return token, nil	}}func validateCSRFToken(token, clientToken string) bool {	return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1}
 |