Skip to content

Commit

Permalink
Merge pull request from GHSA-94w9-97p3-p368
Browse files Browse the repository at this point in the history
* feat: improved csrf with session support

* fix: double submit cookie

* feat: add warning cookie extractor without session

* feat: add warning CsrfFromCookie SameSite

* fix: use byes.Equal instead

* fix: Overriden CookieName KeyLookup cookie:<name>

* feat: Create helpers.go

* feat: use compareTokens (constant time compare)

* feat: validate cookie to prevent token injection

* refactor: clean up csrf.go

* docs: update comment about Double Submit Cookie

* docs: update docs for CSRF changes

* feat: add DeleteToken

* refactor: no else

* test: add more tests

* refactor: re-order tests

* docs: update safe methods RCF add note

* test: add CSRF_Cookie_Injection_Exploit

* feat: add SingleUseToken config

* test: check for new token

* docs: use warning

* fix: always register type Token

* feat: use UUIDv4

* test: swap in UUIDv4 here too
  • Loading branch information
sixcolors authored Oct 11, 2023
1 parent 9292a36 commit b50d91d
Show file tree
Hide file tree
Showing 9 changed files with 793 additions and 74 deletions.
117 changes: 105 additions & 12 deletions docs/api/middleware/csrf.md

Large diffs are not rendered by default.

54 changes: 47 additions & 7 deletions middleware/csrf/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
"github.com/gofiber/fiber/v2/middleware/session"
"github.com/gofiber/fiber/v2/utils"
)

Expand All @@ -33,6 +34,7 @@ type Config struct {

// Name of the session cookie. This cookie will store session key.
// Optional. Default value "csrf_".
// Overriden if KeyLookup == "cookie:<name>"
CookieName string

// Domain of the CSRF cookie.
Expand Down Expand Up @@ -64,11 +66,29 @@ type Config struct {
// Optional. Default: 1 * time.Hour
Expiration time.Duration

// SingleUseToken indicates if the CSRF token be destroyed
// and a new one generated on each use.
//
// Optional. Default: false
SingleUseToken bool

// Store is used to store the state of the middleware
//
// Optional. Default: memory.New()
// Ignored if Session is set.
Storage fiber.Storage

// Session is used to store the state of the middleware
//
// Optional. Default: nil
// If set, the middleware will use the session store instead of the storage
Session *session.Store

// SessionKey is the key used to store the token in the session
//
// Default: "fiber.csrf.token"
SessionKey string

// Context key to store generated CSRF token into context.
// If left empty, token will not be stored in context.
//
Expand Down Expand Up @@ -100,19 +120,26 @@ type Config struct {
//
// Optional. Default will create an Extractor based on KeyLookup.
Extractor func(c *fiber.Ctx) (string, error)

// HandlerContextKey is used to store the CSRF Handler into context
//
// Default: "fiber.csrf.handler"
HandlerContextKey string
}

const HeaderName = "X-Csrf-Token"

// ConfigDefault is the default config
var ConfigDefault = Config{
KeyLookup: "header:" + HeaderName,
CookieName: "csrf_",
CookieSameSite: "Lax",
Expiration: 1 * time.Hour,
KeyGenerator: utils.UUID,
ErrorHandler: defaultErrorHandler,
Extractor: CsrfFromHeader(HeaderName),
KeyLookup: "header:" + HeaderName,
CookieName: "csrf_",
CookieSameSite: "Lax",
Expiration: 1 * time.Hour,
KeyGenerator: utils.UUIDv4,
ErrorHandler: defaultErrorHandler,
Extractor: CsrfFromHeader(HeaderName),
SessionKey: "fiber.csrf.token",
HandlerContextKey: "fiber.csrf.handler",
}

// default ErrorHandler that process return error from fiber.Handler
Expand Down Expand Up @@ -174,6 +201,12 @@ func configDefault(config ...Config) Config {
if cfg.ErrorHandler == nil {
cfg.ErrorHandler = ConfigDefault.ErrorHandler
}
if cfg.SessionKey == "" {
cfg.SessionKey = ConfigDefault.SessionKey
}
if cfg.HandlerContextKey == "" {
cfg.HandlerContextKey = ConfigDefault.HandlerContextKey
}

// Generate the correct extractor to get the token from the correct location
selectors := strings.Split(cfg.KeyLookup, ":")
Expand All @@ -195,7 +228,14 @@ func configDefault(config ...Config) Config {
case "param":
cfg.Extractor = CsrfFromParam(selectors[1])
case "cookie":
if cfg.Session == nil {
log.Warn("[CSRF] Cookie extractor is not recommended without a session store")
}
if cfg.CookieSameSite == "None" || cfg.CookieSameSite != "Lax" && cfg.CookieSameSite != "Strict" {
log.Warn("[CSRF] Cookie extractor is only recommended for use with SameSite=Lax or SameSite=Strict")
}
cfg.Extractor = CsrfFromCookie(selectors[1])
cfg.CookieName = selectors[1] // Cookie name is the same as the key
}
}

Expand Down
214 changes: 173 additions & 41 deletions middleware/csrf/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,42 @@ package csrf

import (
"errors"
"reflect"
"time"

"github.com/gofiber/fiber/v2"
)

var errTokenNotFound = errors.New("csrf token not found")
var (
ErrTokenNotFound = errors.New("csrf token not found")
ErrTokenInvalid = errors.New("csrf token invalid")
ErrNoReferer = errors.New("referer not supplied")
ErrBadReferer = errors.New("referer invalid")
dummyValue = []byte{'+'}
)

type CSRFHandler struct {
config *Config
sessionManager *sessionManager
storageManager *storageManager
}

// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)

// Create manager to simplify storage operations ( see manager.go )
manager := newManager(cfg.Storage)
// Create manager to simplify storage operations ( see *_manager.go )
var sessionManager *sessionManager
var storageManager *storageManager
if cfg.Session != nil {
// Register the Token struct in the session store
cfg.Session.RegisterType(Token{})

dummyValue := []byte{'+'}
sessionManager = newSessionManager(cfg.Session, cfg.SessionKey)
} else {
storageManager = newStorageManager(cfg.Storage)
}

// Return new handler
return func(c *fiber.Ctx) error {
Expand All @@ -26,36 +46,69 @@ func New(config ...Config) fiber.Handler {
return c.Next()
}

// Store the CSRF handler in the context if a context key is specified
if cfg.HandlerContextKey != "" {
c.Locals(cfg.HandlerContextKey, &CSRFHandler{
config: &cfg,
sessionManager: sessionManager,
storageManager: storageManager,
})
}

var token string

// Action depends on the HTTP method
switch c.Method() {
case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace:
// Declare empty token and try to get existing CSRF from cookie
token = c.Cookies(cfg.CookieName)
cookieToken := c.Cookies(cfg.CookieName)

if cookieToken != "" {
rawToken := getTokenFromStorage(c, cookieToken, cfg, sessionManager, storageManager)

if rawToken != nil {
token = string(rawToken)
}
}
default:
// Assume that anything not defined as 'safe' by RFC7231 needs protection

// Enforce an origin check for HTTPS connections.
if c.Protocol() == "https" {
if err := refererMatchesHost(c); err != nil {
return cfg.ErrorHandler(c, err)
}
}

// Extract token from client request i.e. header, query, param, form or cookie
token, err := cfg.Extractor(c)
extractedToken, err := cfg.Extractor(c)
if err != nil {
return cfg.ErrorHandler(c, err)
}

// if token does not exist in Storage
if manager.getRaw(token) == nil {
// Expire cookie
c.Cookie(&fiber.Cookie{
Name: cfg.CookieName,
Domain: cfg.CookieDomain,
Path: cfg.CookiePath,
Expires: time.Now().Add(-1 * time.Minute),
Secure: cfg.CookieSecure,
HTTPOnly: cfg.CookieHTTPOnly,
SameSite: cfg.CookieSameSite,
SessionOnly: cfg.CookieSessionOnly,
})
return cfg.ErrorHandler(c, errTokenNotFound)
if extractedToken == "" {
return cfg.ErrorHandler(c, ErrTokenNotFound)
}

// If not using CsrfFromCookie extractor, check that the token matches the cookie
// This is to prevent CSRF attacks by using a Double Submit Cookie method
// Useful when we do not have access to the users Session
if !isCsrfFromCookie(cfg.Extractor) && extractedToken != c.Cookies(cfg.CookieName) {
return cfg.ErrorHandler(c, ErrTokenInvalid)
}

rawToken := getTokenFromStorage(c, extractedToken, cfg, sessionManager, storageManager)

if rawToken == nil {
// If token is not in storage, expire the cookie
expireCSRFCookie(c, cfg)
// and return an error
return cfg.ErrorHandler(c, ErrTokenNotFound)
}
if cfg.SingleUseToken {
// If token is single use, delete it from storage
deleteTokenFromStorage(c, extractedToken, cfg, sessionManager, storageManager)
} else {
token = string(rawToken)
}
}

Expand All @@ -65,29 +118,16 @@ func New(config ...Config) fiber.Handler {
token = cfg.KeyGenerator()
}

// Add/update token to Storage
manager.setRaw(token, dummyValue, cfg.Expiration)

// Create cookie to pass token to client
cookie := &fiber.Cookie{
Name: cfg.CookieName,
Value: token,
Domain: cfg.CookieDomain,
Path: cfg.CookiePath,
Expires: time.Now().Add(cfg.Expiration),
Secure: cfg.CookieSecure,
HTTPOnly: cfg.CookieHTTPOnly,
SameSite: cfg.CookieSameSite,
SessionOnly: cfg.CookieSessionOnly,
}
// Set cookie to response
c.Cookie(cookie)
// Create or extend the token in the storage
createOrExtendTokenInStorage(c, token, cfg, sessionManager, storageManager)

// Protect clients from caching the response by telling the browser
// a new header value is generated
// Update the CSRF cookie
updateCSRFCookie(c, cfg, token)

// Tell the browser that a new header value is generated
c.Vary(fiber.HeaderCookie)

// Store token in context if set
// Store the token in the context if a context key is specified
if cfg.ContextKey != "" {
c.Locals(cfg.ContextKey, token)
}
Expand All @@ -96,3 +136,95 @@ func New(config ...Config) fiber.Handler {
return c.Next()
}
}

// getTokenFromStorage returns the raw token from the storage
// returns nil if the token does not exist, is expired or is invalid
func getTokenFromStorage(c *fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) []byte {
if cfg.Session != nil {
return sessionManager.getRaw(c, token, dummyValue)
}
return storageManager.getRaw(token)
}

// createOrExtendTokenInStorage creates or extends the token in the storage
func createOrExtendTokenInStorage(c *fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) {
if cfg.Session != nil {
sessionManager.setRaw(c, token, dummyValue, cfg.Expiration)
} else {
storageManager.setRaw(token, dummyValue, cfg.Expiration)
}
}

func deleteTokenFromStorage(c *fiber.Ctx, token string, cfg Config, sessionManager *sessionManager, storageManager *storageManager) {
if cfg.Session != nil {
sessionManager.delRaw(c)
} else {
storageManager.delRaw(token)
}
}

// Update CSRF cookie
// if expireCookie is true, the cookie will expire immediately
func updateCSRFCookie(c *fiber.Ctx, cfg Config, token string) {
setCSRFCookie(c, cfg, token, cfg.Expiration)
}

func expireCSRFCookie(c *fiber.Ctx, cfg Config) {
setCSRFCookie(c, cfg, "", -time.Hour)
}

func setCSRFCookie(c *fiber.Ctx, cfg Config, token string, expiry time.Duration) {
cookie := &fiber.Cookie{
Name: cfg.CookieName,
Value: token,
Domain: cfg.CookieDomain,
Path: cfg.CookiePath,
Secure: cfg.CookieSecure,
HTTPOnly: cfg.CookieHTTPOnly,
SameSite: cfg.CookieSameSite,
SessionOnly: cfg.CookieSessionOnly,
Expires: time.Now().Add(expiry),
}

// Set the CSRF cookie to the response
c.Cookie(cookie)
}

// DeleteToken removes the token found in the context from the storage
// and expires the CSRF cookie
func (handler *CSRFHandler) DeleteToken(c *fiber.Ctx) error {
// Get the config from the context
config := handler.config
if config == nil {
panic("CSRFHandler config not found in context")
}
// Extract token from the client request cookie
cookieToken := c.Cookies(config.CookieName)
if cookieToken == "" {
return config.ErrorHandler(c, ErrTokenNotFound)
}
// Remove the token from storage
deleteTokenFromStorage(c, cookieToken, *config, handler.sessionManager, handler.storageManager)
// Expire the cookie
expireCSRFCookie(c, *config)
return nil
}

// isCsrfFromCookie checks if the extractor is set to ExtractFromCookie
func isCsrfFromCookie(extractor interface{}) bool {
return reflect.ValueOf(extractor).Pointer() == reflect.ValueOf(CsrfFromCookie).Pointer()
}

// refererMatchesHost checks that the referer header matches the host header
// returns an error if the referer header is not present or is invalid
// returns nil if the referer header is valid
func refererMatchesHost(c *fiber.Ctx) error {
referer := c.Get(fiber.HeaderReferer)
if referer == "" {
return ErrNoReferer
}
if referer != c.Protocol()+"://"+c.Hostname() {
return ErrBadReferer
}
return nil
}
Loading

0 comments on commit b50d91d

Please sign in to comment.