diff --git a/README.md b/README.md index a25c19b..1a80c7c 100644 --- a/README.md +++ b/README.md @@ -37,12 +37,13 @@ Here are a few bullet point reasons you might like to try it out: ## Strategies * [kubernetes (Token Review)](https://pkg.go.dev/github.com/shaj13/go-guardian/auth/strategies/kubernetes?tab=doc) * [2FA](https://pkg.go.dev/github.com/shaj13/go-guardian/auth/strategies/twofactor?tab=doc) -* [Certificate-Based](https://pkg.go.dev/github.com/shaj13/go-guardian@v1.2.0/auth/strategies/x509?tab=doc) -* [Bearer-Token](https://pkg.go.dev/github.com/shaj13/go-guardian@v1.2.0/auth/strategies/bearer?tab=doc) -* [Static-Token](https://pkg.go.dev/github.com/shaj13/go-guardian@v1.2.0/auth/strategies/bearer?tab=doc) -* [LDAP](https://pkg.go.dev/github.com/shaj13/go-guardian@v1.2.0/auth/strategies/ldap?tab=doc) -* [Basic](https://pkg.go.dev/github.com/shaj13/go-guardian@v1.2.0/auth/strategies/basic?tab=doc) -* [Digest](https://pkg.go.dev/github.com/shaj13/go-guardian@v1.2.0/auth/strategies/digest?tab=doc) +* [Certificate-Based](https://pkg.go.dev/github.com/shaj13/go-guardian/auth/strategies/x509?tab=doc) +* [Bearer-Token](https://pkg.go.dev/github.com/shaj13/go-guardian/auth/strategies/token?tab=doc) +* [Static-Token](https://pkg.go.dev/github.com/shaj13/go-guardian/auth/strategies/token?tab=doc) +* [LDAP](https://pkg.go.dev/github.com/shaj13/go-guardian/auth/strategies/ldap?tab=doc) +* [Basic](https://pkg.go.dev/github.com/shaj13/go-guardian/auth/strategies/basic?tab=doc) +* [Digest](https://pkg.go.dev/github.com/shaj13/go-guardian/auth/strategies/digest?tab=doc) +* [Union](https://pkg.go.dev/github.com/shaj13/go-guardian/auth/strategies/union?tab=doc) # Examples Examples are available on [GoDoc](https://pkg.go.dev/github.com/shaj13/go-guardian) or [Examples Folder](./_examples). diff --git a/_examples/basic_bearer/main.go b/_examples/basic_bearer/main.go index 89b4b7d..6daad46 100644 --- a/_examples/basic_bearer/main.go +++ b/_examples/basic_bearer/main.go @@ -16,8 +16,10 @@ import ( "github.com/shaj13/go-guardian/auth" "github.com/shaj13/go-guardian/auth/strategies/basic" - "github.com/shaj13/go-guardian/auth/strategies/bearer" - "github.com/shaj13/go-guardian/store" + "github.com/shaj13/go-guardian/auth/strategies/token" + "github.com/shaj13/go-guardian/auth/strategies/union" + "github.com/shaj13/go-guardian/cache" + "github.com/shaj13/go-guardian/cache/container/fifo" ) // Usage: @@ -25,8 +27,9 @@ import ( // curl -k http://127.0.0.1:8080/v1/auth/token -u admin:admin // curl -k http://127.0.0.1:8080/v1/book/1449311601 -H "Authorization: Bearer " -var authenticator auth.Authenticator -var cache store.Cache +var strategy union.Union +var tokenStrategy auth.Strategy +var cacheObj cache.Cache func main() { setupGoGuardian() @@ -39,9 +42,8 @@ func main() { func createToken(w http.ResponseWriter, r *http.Request) { token := uuid.New().String() - user := auth.NewDefaultUser("medium", "1", nil, nil) - tokenStrategy := authenticator.Strategy(bearer.CachedStrategyKey) - auth.Append(tokenStrategy, token, user, r) + user := auth.NewDefaultUser("admin", "1", nil, nil) + auth.Append(tokenStrategy, token, user) body := fmt.Sprintf("token: %s \n", token) w.Write([]byte(body)) } @@ -59,20 +61,20 @@ func getBookAuthor(w http.ResponseWriter, r *http.Request) { } func setupGoGuardian() { - authenticator = auth.New() - cache = store.NewFIFO(context.Background(), time.Minute*10) - - basicStrategy := basic.New(validateUser, cache) - tokenStrategy := bearer.New(bearer.NoOpAuthenticate, cache) - - authenticator.EnableStrategy(basic.StrategyKey, basicStrategy) - authenticator.EnableStrategy(bearer.CachedStrategyKey, tokenStrategy) + ttl := fifo.TTL(time.Minute * 5) + exp := fifo.RegisterOnExpired(func(key interface{}) { + cacheObj.Peek(key) + }) + cacheObj = cache.FIFO.New(ttl, exp) + basicStrategy := basic.NewCached(validateUser, cacheObj) + tokenStrategy = token.New(token.NoOpAuthenticate, cacheObj) + strategy = union.New(tokenStrategy, basicStrategy) } func validateUser(ctx context.Context, r *http.Request, userName, password string) (auth.Info, error) { // here connect to db or any other service to fetch user and validate it. if userName == "admin" && password == "admin" { - return auth.NewDefaultUser("medium", "1", nil, nil), nil + return auth.NewDefaultUser("admin", "1", nil, nil), nil } return nil, fmt.Errorf("Invalid credentials") @@ -81,13 +83,14 @@ func validateUser(ctx context.Context, r *http.Request, userName, password strin func middleware(next http.Handler) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Println("Executing Auth Middleware") - user, err := authenticator.Authenticate(r) + _, user, err := strategy.AuthenticateRequest(r) if err != nil { + fmt.Println(err) code := http.StatusUnauthorized http.Error(w, http.StatusText(code), code) return } - log.Printf("User %s Authenticated\n", user.UserName()) + log.Printf("User %s Authenticated\n", user.GetUserName()) next.ServeHTTP(w, r) }) } diff --git a/_examples/digest/main.go b/_examples/digest/main.go index c570a60..4212288 100644 --- a/_examples/digest/main.go +++ b/_examples/digest/main.go @@ -1,25 +1,36 @@ package main import ( - "crypto/md5" + _ "crypto/md5" "fmt" - "hash" "log" "net/http" + "time" "github.com/gorilla/mux" "github.com/shaj13/go-guardian/auth" "github.com/shaj13/go-guardian/auth/strategies/digest" + "github.com/shaj13/go-guardian/cache" + "github.com/shaj13/go-guardian/cache/container/fifo" ) // Usage: // curl --digest --user admin:admin http://127.0.0.1:8080/v1/book/1449311601 -var authenticator auth.Authenticator +var strategy *digest.Digest + +func init() { + var c cache.Cache + ttl := fifo.TTL(time.Minute * 3) + exp := fifo.RegisterOnExpired(func(key interface{}) { + c.Delete(key) + }) + c = cache.FIFO.New(ttl, exp) + strategy = digest.New(validateUser, c) +} func main() { - setupGoGuardian() router := mux.NewRouter() router.HandleFunc("/v1/book/{id}", middleware(http.HandlerFunc(getBookAuthor))).Methods("GET") log.Println("server started and listening on http://127.0.0.1:8080") @@ -38,23 +49,10 @@ func getBookAuthor(w http.ResponseWriter, r *http.Request) { w.Write([]byte(body)) } -func setupGoGuardian() { - authenticator = auth.New() - - digestStrategy := &digest.Strategy{ - Algorithm: "md5", - Hash: func(algo string) hash.Hash { return md5.New() }, - Realm: "test", - FetchUser: validateUser, - } - - authenticator.EnableStrategy(digest.StrategyKey, digestStrategy) -} - func validateUser(userName string) (string, auth.Info, error) { // here connect to db or any other service to fetch user and validate it. if userName == "admin" { - return "admin", auth.NewDefaultUser("medium", "1", nil, nil), nil + return "admin", auth.NewDefaultUser("admin", "1", nil, nil), nil } return "", nil, fmt.Errorf("Invalid credentials") @@ -63,16 +61,15 @@ func validateUser(userName string) (string, auth.Info, error) { func middleware(next http.Handler) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Println("Executing Auth Middleware") - user, err := authenticator.Authenticate(r) + user, err := strategy.Authenticate(r.Context(), r) if err != nil { code := http.StatusUnauthorized - s := authenticator.Strategy(digest.StrategyKey) - s.(*digest.Strategy).WWWAuthenticate(w.Header()) + w.Header().Add("WWW-Authenticate", strategy.GetChallenge()) http.Error(w, http.StatusText(code), code) fmt.Println("send error", err) return } - log.Printf("User %s Authenticated\n", user.UserName()) + log.Printf("User %s Authenticated\n", user.GetUserName()) next.ServeHTTP(w, r) }) } diff --git a/_examples/kubernetes/main.go b/_examples/kubernetes/main.go index 17e9e11..ba2ef26 100644 --- a/_examples/kubernetes/main.go +++ b/_examples/kubernetes/main.go @@ -5,7 +5,6 @@ package main import ( - "context" "fmt" "log" "net/http" @@ -15,18 +14,18 @@ import ( "github.com/shaj13/go-guardian/auth" "github.com/shaj13/go-guardian/auth/strategies/kubernetes" - "github.com/shaj13/go-guardian/auth/strategies/token" - "github.com/shaj13/go-guardian/store" + "github.com/shaj13/go-guardian/cache" + "github.com/shaj13/go-guardian/cache/container/fifo" ) // Usage: // Run kubernetes mock api and get agent token // go run mock.go // Request server to verify token and get book author -// curl -k http://127.0.0.1:8080/v1/book/1449311601 -H "Authorization: Bearer " +// " -var authenticator auth.Authenticator -var cache store.Cache +var strategy auth.Strategy +var cacheObj cache.Cache func main() { setupGoGuardian() @@ -50,22 +49,24 @@ func getBookAuthor(w http.ResponseWriter, r *http.Request) { } func setupGoGuardian() { - authenticator = auth.New() - cache = store.NewFIFO(context.Background(), time.Minute*10) - kubeStrategy := kubernetes.New(cache) - authenticator.EnableStrategy(token.CachedStrategyKey, kubeStrategy) + ttl := fifo.TTL(time.Minute * 5) + exp := fifo.RegisterOnExpired(func(key interface{}) { + cacheObj.Peek(key) + }) + cacheObj = cache.FIFO.New(ttl, exp) + strategy = kubernetes.New(cacheObj) } func middleware(next http.Handler) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Println("Executing Auth Middleware") - user, err := authenticator.Authenticate(r) + user, err := strategy.Authenticate(r.Context(), r) if err != nil { code := http.StatusUnauthorized http.Error(w, http.StatusText(code), code) return } - log.Printf("User %s Authenticated\n", user.UserName()) + log.Printf("User %s Authenticated\n", user.GetUserName()) next.ServeHTTP(w, r) }) } diff --git a/_examples/ldap/main.go b/_examples/ldap/main.go index e8e150e..cbeac57 100644 --- a/_examples/ldap/main.go +++ b/_examples/ldap/main.go @@ -5,7 +5,6 @@ package main import ( - "context" "fmt" "log" "net/http" @@ -15,14 +14,15 @@ import ( "github.com/shaj13/go-guardian/auth" "github.com/shaj13/go-guardian/auth/strategies/ldap" - "github.com/shaj13/go-guardian/store" + "github.com/shaj13/go-guardian/cache" + "github.com/shaj13/go-guardian/cache/container/fifo" ) // Usage: // curl -k http://127.0.0.1:8080/v1/book/1449311601 -u tesla:password -var authenticator auth.Authenticator -var cache store.Cache +var strategy auth.Strategy +var cacheObj cache.Cache func main() { setupGoGuardian() @@ -53,22 +53,24 @@ func setupGoGuardian() { BindPassword: "password", Filter: "(uid=%s)", } - authenticator = auth.New() - cache = store.NewFIFO(context.Background(), time.Minute*10) - strategy := ldap.NewCached(cfg, cache) - authenticator.EnableStrategy(ldap.StrategyKey, strategy) + ttl := fifo.TTL(time.Minute * 5) + exp := fifo.RegisterOnExpired(func(key interface{}) { + cacheObj.Peek(key) + }) + cacheObj = cache.FIFO.New(ttl, exp) + strategy = ldap.NewCached(cfg, cacheObj) } func middleware(next http.Handler) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Println("Executing Auth Middleware") - user, err := authenticator.Authenticate(r) + user, err := strategy.Authenticate(r.Context(), r) if err != nil { code := http.StatusUnauthorized http.Error(w, http.StatusText(code), code) return } - log.Printf("User %s Authenticated\n", user.UserName()) + log.Printf("User %s Authenticated\n", user.GetUserName()) next.ServeHTTP(w, r) }) } diff --git a/_examples/twofactor/main.go b/_examples/twofactor/main.go index 23efc8e..19159d5 100644 --- a/_examples/twofactor/main.go +++ b/_examples/twofactor/main.go @@ -21,7 +21,7 @@ import ( // Usage: // curl -k http://127.0.0.1:8080/v1/book/1449311601 -u admin:admin -H "X-Example-OTP: 345515" -var authenticator auth.Authenticator +var strategy auth.Strategy func main() { setupGoGuardian() @@ -44,16 +44,12 @@ func getBookAuthor(w http.ResponseWriter, r *http.Request) { } func setupGoGuardian() { - authenticator = auth.New() - - basicStrategy := basic.AuthenticateFunc(validateUser) - tfaStrategy := twofactor.Strategy{ + basicStrategy := basic.New(validateUser) + strategy = twofactor.TwoFactor{ Parser: twofactor.XHeaderParser("X-Example-OTP"), Manager: OTPManager{}, Primary: basicStrategy, } - - authenticator.EnableStrategy(twofactor.StrategyKey, tfaStrategy) } func validateUser(ctx context.Context, r *http.Request, userName, password string) (auth.Info, error) { @@ -68,13 +64,13 @@ func validateUser(ctx context.Context, r *http.Request, userName, password strin func middleware(next http.Handler) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Println("Executing Auth Middleware") - user, err := authenticator.Authenticate(r) + user, err := strategy.Authenticate(r.Context(), r) if err != nil { code := http.StatusUnauthorized http.Error(w, http.StatusText(code), code) return } - log.Printf("User %s Authenticated\n", user.UserName()) + log.Printf("User %s Authenticated\n", user.GetUserName()) next.ServeHTTP(w, r) }) } @@ -83,14 +79,14 @@ type OTPManager struct{} func (OTPManager) Enabled(_ auth.Info) bool { return true } -func (OTPManager) Load(_ auth.Info) (twofactor.OTP, error) { +func (OTPManager) Load(_ auth.Info) (twofactor.Verifier, error) { // user otp configuration must be loaded from persistent storage key := otp.NewKey(otp.HOTP, "LABEL", "GXNRHI2MFRFWXQGJHWZJFOSYI6E7MEVA") ver := otp.New(key) return ver, nil } -func (OTPManager) Store(_ auth.Info, otp twofactor.OTP) error { +func (OTPManager) Store(_ auth.Info, otp twofactor.Verifier) error { // persist user otp after verification return nil } diff --git a/_examples/x509/main.go b/_examples/x509/main.go index 5d81717..885a6f7 100644 --- a/_examples/x509/main.go +++ b/_examples/x509/main.go @@ -22,7 +22,7 @@ import ( // Usage: // curl --cacert ./certs/ca --key ./certs/admin-key --cert ./certs/admin https://127.0.0.1:8080/v1/book/1449311601 -var authenticator auth.Authenticator +var strategy auth.Strategy func main() { setupGoGuardian() @@ -52,22 +52,20 @@ func getBookAuthor(w http.ResponseWriter, r *http.Request) { } func setupGoGuardian() { - authenticator = auth.New() opts := verifyOptions() - strategy := gx509.New(opts) - authenticator.EnableStrategy(gx509.StrategyKey, strategy) + strategy = gx509.New(opts) } func middleware(next http.Handler) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Println("Executing Auth Middleware") - user, err := authenticator.Authenticate(r) + user, err := strategy.Authenticate(r.Context(), r) if err != nil { code := http.StatusUnauthorized http.Error(w, http.StatusText(code), code) return } - log.Printf("User %s Authenticated\n", user.UserName()) + log.Printf("User %s Authenticated\n", user.GetUserName()) next.ServeHTTP(w, r) }) } diff --git a/auth/authenticator.go b/auth/authenticator.go deleted file mode 100644 index 21462f4..0000000 --- a/auth/authenticator.go +++ /dev/null @@ -1,100 +0,0 @@ -package auth - -import ( - "errors" - "net/http" - "strings" - - gerrors "github.com/shaj13/go-guardian/errors" -) - -var ( - // ErrNoMatch is returned by Authenticator when request not authenticated, - // and all registered Strategies returned errors. - ErrNoMatch = errors.New("authenticator: No authentication strategy matched") - - // ErrDisabledPath is a soft error similar to EOF. - // returned by Authenticator when a attempting to authenticate request have a disabled path. - // Authenticator return DisabledPath only to signal the caller. - // The caller should continue the request flow, and never return the error to the end users. - ErrDisabledPath = errors.New("authenticator: Disabled Path") - - // ErrNOOP is a soft error similar to EOF, - // returned by strategies that have NoOpAuthenticate function to indicate there no op, - // and signal authenticator to unauthenticate the request. - ErrNOOP = errors.New("NOOP") -) - -// Authenticator carry the registered authentication strategies, -// and represents the first API to authenticate received requests. -type Authenticator interface { - // Authenticate dispatch the request to the registered authentication strategies, - // and return user information from the first strategy that successfully authenticates the request. - // Otherwise, an aggregated error returned. - // if request attempt to visit a disabled path, ErrDisabledPath returned to signal the caller, - // Otherwise, start the authentication process. - // See ErrDisabledPath documentation for more info. - // - // NOTICE: Authenticate does not guarantee the order strategies run in. - Authenticate(r *http.Request) (Info, error) - // EnableStrategy register a new strategy to the authenticator. - EnableStrategy(key StrategyKey, strategy Strategy) - // DisableStrategy unregister a strategy from the authenticator. - DisableStrategy(key StrategyKey) - // Strategy return a registered strategy, Otherwise, nil. - Strategy(key StrategyKey) Strategy - // DisabledPaths return a map[string]struct{} represents a paths disabled from authentication. - // Typically the paths are given during authenticator initialization. - DisabledPaths() map[string]struct{} -} - -type authenticator struct { - strategies map[StrategyKey]Strategy - paths map[string]struct{} -} - -func (a *authenticator) Authenticate(r *http.Request) (Info, error) { - // check if request to a disabled path - if a.disabledPath(r.RequestURI) { - return nil, ErrDisabledPath - } - - errs := gerrors.MultiError{ErrNoMatch} - - for _, strategy := range a.strategies { - info, err := strategy.Authenticate(r.Context(), r) - if err == nil { - return info, nil - } - errs = append(errs, err) - } - - return nil, errs -} - -func (a *authenticator) disabledPath(path string) bool { - path = strings.TrimPrefix(path, "/") - _, ok := a.paths[path] - return ok -} - -func (a *authenticator) Strategy(key StrategyKey) Strategy { return a.strategies[key] } -func (a *authenticator) EnableStrategy(key StrategyKey, s Strategy) { a.strategies[key] = s } -func (a *authenticator) DisableStrategy(key StrategyKey) { delete(a.strategies, key) } -func (a *authenticator) DisabledPaths() map[string]struct{} { return a.paths } - -// New return new Authenticator and disables authentication process at a given paths. -// The returned authenticator not safe for concurrent access. -func New(paths ...string) Authenticator { - p := make(map[string]struct{}) - - for _, path := range paths { - path = strings.TrimPrefix(path, "/") - p[path] = struct{}{} - } - - return &authenticator{ - strategies: make(map[StrategyKey]Strategy), - paths: p, - } -} diff --git a/auth/authenticator_test.go b/auth/authenticator_test.go deleted file mode 100644 index 47951cc..0000000 --- a/auth/authenticator_test.go +++ /dev/null @@ -1,94 +0,0 @@ -package auth - -import ( - "context" - "fmt" - "net/http" - "strconv" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestAuthenticator(t *testing.T) { - table := []struct { - name string - strategies []Strategy - paths []string - userID string - ExpetedErr bool - }{ - { - name: "it return error when all strategies return errors", - strategies: []Strategy{ - strategy{returnErr: true}, - strategy{returnErr: true}, - strategy{returnErr: true}, - }, - ExpetedErr: true, - }, - { - name: "it return error when no strategies enabled", - ExpetedErr: true, - }, - { - name: "it return user info when first strategy return info", - strategies: []Strategy{ - strategy{returnErr: true}, - strategy{id: "1"}, - strategy{returnErr: true}, - }, - ExpetedErr: false, - userID: "1", - }, - { - name: "it return DisabledPath when path disabled", - paths: []string{"/health", "health2", "/api/health"}, - }, - } - - for _, tt := range table { - t.Run(tt.name, func(t *testing.T) { - authenticator := New(tt.paths...) - r, _ := http.NewRequest("GET", "/", nil) - - for i, strategy := range tt.strategies { - key := strconv.Itoa(i) - authenticator.EnableStrategy(StrategyKey(key), strategy) - } - - if len(tt.paths) > 0 { - for _, p := range tt.paths { - r.RequestURI = p - _, err := authenticator.Authenticate(r) - if err != ErrDisabledPath { - t.Errorf("Expected %v, Got %v, For Path %s", ErrDisabledPath, err, p) - continue - } - } - return - } - - info, err := authenticator.Authenticate(r) - - if tt.ExpetedErr { - assert.True(t, err != nil, "Expected error got nil") - return - } - - assert.Equal(t, tt.userID, info.ID()) - }) - } -} - -type strategy struct { - id string - returnErr bool -} - -func (s strategy) Authenticate(ctx context.Context, r *http.Request) (Info, error) { - if s.returnErr { - return nil, fmt.Errorf("authenticator test strategy L25") - } - return NewDefaultUser("", s.id, nil, nil), nil -} diff --git a/auth/cache.go b/auth/cache.go new file mode 100644 index 0000000..6b839a3 --- /dev/null +++ b/auth/cache.go @@ -0,0 +1,12 @@ +package auth + +// Cache type describes the requirements for authentication strategies, +// that cache the authentication decisions. +type Cache interface { + // Load returns key's value. + Load(key interface{}) (interface{}, bool) + // Store sets the value for a key. + Store(key interface{}, value interface{}) + // Delete deletes the key value. + Delete(key interface{}) +} diff --git a/auth/error.go b/auth/error.go new file mode 100644 index 0000000..4ee5175 --- /dev/null +++ b/auth/error.go @@ -0,0 +1,33 @@ +package auth + +import ( + "reflect" +) + +// TypeError represent invalid type assertion error. +type TypeError struct { + prefix string + Want string + Got string +} + +// Error describe error as a string +func (i TypeError) Error() string { + return i.prefix + " Invalid type assertion: Want " + i.Want + " Got " + i.Got +} + +// NewTypeError returns InvalidType error +func NewTypeError(prefix string, want, got interface{}) error { + f := func(v interface{}) string { + if v == nil { + return "" + } + return reflect.TypeOf(v).String() + } + + return TypeError{ + Want: f(want), + Got: f(got), + prefix: prefix, + } +} diff --git a/errors/error_test.go b/auth/error_test.go similarity index 60% rename from errors/error_test.go rename to auth/error_test.go index a38c818..fe59f0b 100644 --- a/errors/error_test.go +++ b/auth/error_test.go @@ -1,7 +1,6 @@ -package errors +package auth import ( - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -40,8 +39,8 @@ func TestInvalidType(t *testing.T) { for _, tt := range table { t.Run(tt.name, func(t *testing.T) { - err := NewInvalidType(tt.want, tt.got) - it, ok := err.(InvalidType) + err := NewTypeError("", tt.want, tt.got) + it, ok := err.(TypeError) assert.True(t, ok) assert.Equal(t, tt.gotStr, it.Got) assert.Equal(t, tt.wantStr, it.Want) @@ -50,33 +49,3 @@ func TestInvalidType(t *testing.T) { }) } } - -func TestError(t *testing.T) { - table := []struct { - errs MultiError - errStr string - }{ - { - errs: MultiError{ - fmt.Errorf("1st error"), - fmt.Errorf("2nd error"), - fmt.Errorf("3rd error"), - }, - errStr: "1st error: [2nd error, 3rd error, ]", - }, - { - errs: MultiError{}, - errStr: "", - }, - { - errs: MultiError{ - fmt.Errorf("Test error"), - }, - errStr: "Test error", - }, - } - - for _, tt := range table { - assert.Equal(t, tt.errs.Error(), tt.errStr) - } -} diff --git a/auth/example_test.go b/auth/example_test.go deleted file mode 100644 index fb288c2..0000000 --- a/auth/example_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package auth - -import ( - "fmt" - "net/http" -) - -func ExampleAppend() { - strategy := &mockStrategy{} - info := NewDefaultUser("1", "2", nil, nil) - token := "90d64460d14870c08c81352a05dedd3465940a7" - r, _ := http.NewRequest("POST", "/login", nil) - // append new token to cached bearer strategy - err := Append(strategy, token, info, r) - fmt.Println(err) - // Output: - // -} - -func ExampleRevoke() { - strategy := &mockStrategy{} - r, _ := http.NewRequest("GET", "/logout", nil) - // assume token extracted from header - token := "90d64460d14870c08c81352a05dedd3465940a7" - err := Revoke(strategy, token, r) - fmt.Println(err) - // Output: - // -} diff --git a/auth/extensions.go b/auth/extensions.go new file mode 100644 index 0000000..dce10f8 --- /dev/null +++ b/auth/extensions.go @@ -0,0 +1,53 @@ +package auth + +// Extensions represents additional information to a user. +type Extensions map[string][]string + +// Add adds the key, value pair to the extensions. +// It appends to any existing values associated with key. +// The key is case sensitive. +func (exts Extensions) Add(key, value string) { + exts[key] = append(exts[key], value) +} + +// Set sets the extensions entries associated with key to +// the single element value. It replaces any existing +// values associated with key. +func (exts Extensions) Set(key, value string) { + exts[key] = []string{value} +} + +// Del deletes the values associated with key. +func (exts Extensions) Del(key string) { + delete(exts, key) +} + +// Get gets the first value associated with the given key. +// It is case sensitive; +// If there are no values associated with the key, Get returns "". +func (exts Extensions) Get(key string) string { + if exts == nil { + return "" + } + v := exts[key] + if len(v) == 0 { + return "" + } + return v[0] +} + +// Values returns all values associated with the given key. +// It is case sensitive; +// The returned slice is not a copy. +func (exts Extensions) Values(key string) []string { + if exts == nil { + return nil + } + return exts[key] +} + +// Has reports whether extensions has the provided key defined. +func (exts Extensions) Has(key string) bool { + _, ok := exts[key] + return ok +} diff --git a/auth/info.go b/auth/info.go index c75d467..4c718d1 100644 --- a/auth/info.go +++ b/auth/info.go @@ -1,130 +1,92 @@ package auth -import ( - "bytes" - "encoding/gob" - "encoding/json" -) - var ic InfoConstructor -func init() { - gob.Register(&DefaultUser{}) -} - // Info describes a user that has been authenticated to the system. type Info interface { - // UserName returns the name that uniquely identifies this user among all + // GetUserName returns the name that uniquely identifies this user among all // other active users. - UserName() string - // ID returns a unique value identify a particular user - ID() string - // Groups returns the names of the groups the user is a member of - Groups() []string - // Extensions can contain any additional information. - Extensions() map[string][]string + GetUserName() string + // SetUserName set the name that uniquely identifies this user among all + // other active users. + SetUserName(string) + // GetID returns a unique value identify a particular user. + GetID() string + // SetID set a unique value identify a particular user. + SetID(string) + // GetGroups returns the names of the groups the user is a member of + GetGroups() []string // SetGroups set the names of the groups the user is a member of. SetGroups(groups []string) + // Extensions can contain any additional information. + GetExtensions() Extensions // SetExtensions to contain additional information. - SetExtensions(exts map[string][]string) + SetExtensions(exts Extensions) } // InfoConstructor define function signature to create new Info object. -type InfoConstructor func(name, id string, groups []string, extensions map[string][]string) Info +type InfoConstructor func(name, id string, groups []string, extensions Extensions) Info // DefaultUser implement Info interface and provides a simple user information. type DefaultUser struct { - name string - id string - groups []string - extensions map[string][]string + Name string + ID string + Groups []string + Extensions Extensions } -// UserName returns the name that uniquely identifies this user among all +// GetUserName returns the name that uniquely identifies this user among all // other active users. -func (d *DefaultUser) UserName() string { - return d.name -} - -// ID returns a unique value identify a particular user -func (d *DefaultUser) ID() string { - return d.id +func (d *DefaultUser) GetUserName() string { + return d.Name } -// Groups returns the names of the groups the user is a member of -func (d *DefaultUser) Groups() []string { - return d.groups +// SetUserName set the name that uniquely identifies this user among all +// other active users. +func (d *DefaultUser) SetUserName(name string) { + d.Name = name } -// SetGroups set the names of the groups the user is a member of. -func (d *DefaultUser) SetGroups(groups []string) { - d.groups = groups +// GetID returns a unique value identify a particular user +func (d *DefaultUser) GetID() string { + return d.ID } -// Extensions return additional information. -func (d *DefaultUser) Extensions() map[string][]string { - return d.extensions +// SetID set a unique value identify a particular user. +func (d *DefaultUser) SetID(id string) { + d.ID = id } -// SetExtensions to contain additional information. -func (d *DefaultUser) SetExtensions(exts map[string][]string) { - d.extensions = exts +// GetGroups returns the names of the groups the user is a member of +func (d *DefaultUser) GetGroups() []string { + return d.Groups } -// MarshalBinary encodes the default user into a binary form and returns the result. -func (d *DefaultUser) MarshalBinary() ([]byte, error) { - b := new(bytes.Buffer) - i := new(internalUser) - i.from(d) - err := gob.NewEncoder(b).Encode(i) - return b.Bytes(), err +// SetGroups set the names of the groups the user is a member of. +func (d *DefaultUser) SetGroups(groups []string) { + d.Groups = groups } -// UnmarshalBinary decode the binary form generated by MarshalBinary. -func (d *DefaultUser) UnmarshalBinary(data []byte) error { - b := new(bytes.Buffer) - i := new(internalUser) - - _, err := b.Write(data) - if err != nil { - return err +// GetExtensions return additional information. +func (d *DefaultUser) GetExtensions() Extensions { + if d.Extensions == nil { + d.Extensions = Extensions{} } - - err = gob.NewDecoder(b).Decode(i) - if err != nil { - return err - } - - i.to(d) - return nil + return d.Extensions } -// MarshalJSON encodes the default user into a json and returns the result. -func (d *DefaultUser) MarshalJSON() ([]byte, error) { - i := new(internalUser) - i.from(d) - return json.Marshal(i) -} - -// UnmarshalJSON decode the json generated by MarshalJSON. -func (d *DefaultUser) UnmarshalJSON(data []byte) error { - i := new(internalUser) - err := json.Unmarshal(data, i) - if err != nil { - return err - } - - i.from(d) - return nil +// SetExtensions to contain additional information. +func (d *DefaultUser) SetExtensions(exts Extensions) { + d.Extensions = exts } // NewDefaultUser return new default user -func NewDefaultUser(name, id string, groups []string, extensions map[string][]string) *DefaultUser { +func NewDefaultUser(name, id string, groups []string, extensions Extensions) *DefaultUser { return &DefaultUser{ - name: name, - id: id, - groups: groups, - extensions: extensions, + Name: name, + ID: id, + Groups: groups, + Extensions: extensions, } } @@ -147,24 +109,3 @@ func NewUserInfo(name, id string, groups []string, extensions map[string][]strin func SetInfoConstructor(c InfoConstructor) { ic = c } - -type internalUser struct { - Name string `json:"name,omitempty"` - ID string `json:"id,omitempty"` - Groups []string `json:"groups,omitempty"` - Extensions map[string][]string `json:"extensions,omitempty"` -} - -func (i *internalUser) from(d *DefaultUser) { - i.Name = d.name - i.ID = d.id - i.Groups = d.groups - i.Extensions = d.extensions -} - -func (i *internalUser) to(d *DefaultUser) { - d.name = i.Name - d.id = i.ID - d.groups = i.Groups - d.extensions = i.Extensions -} diff --git a/internal/parse.go b/auth/internal/parse.go similarity index 100% rename from internal/parse.go rename to auth/internal/parse.go diff --git a/auth/strategies/basic/basic.go b/auth/strategies/basic/basic.go index a288e1c..d454530 100644 --- a/auth/strategies/basic/basic.go +++ b/auth/strategies/basic/basic.go @@ -4,31 +4,21 @@ package basic import ( "context" - "crypto" "errors" - "fmt" "net/http" "github.com/shaj13/go-guardian/auth" - gerrors "github.com/shaj13/go-guardian/errors" - "github.com/shaj13/go-guardian/store" ) -// ErrMissingPrams is returned by Authenticate Strategy method, -// when failed to retrieve user credentials from request. -var ErrMissingPrams = errors.New("basic: Request missing BasicAuth") +var ( + // ErrMissingPrams is returned by Authenticate Strategy method, + // when failed to retrieve user credentials from request. + ErrMissingPrams = errors.New("strategies/basic: Request missing BasicAuth") -// ErrInvalidCredentials is returned by Authenticate Strategy method, -// when user password is invalid. -var ErrInvalidCredentials = errors.New("basic: Invalid user credentials") - -// StrategyKey export identifier for the basic strategy, -// commonly used when enable/add strategy to go-guardian authenticator. -const StrategyKey = auth.StrategyKey("Basic.Strategy") - -// ExtensionKey represents a key for the password in info extensions. -// Typically used when basic strategy cache the authentication decisions. -const ExtensionKey = "x-go-guardian-basic-password" + // ErrInvalidCredentials is returned by Authenticate Strategy method, + // when user password is invalid. + ErrInvalidCredentials = errors.New("strategies/basic: Invalid user credentials") +) // AuthenticateFunc declare custom function to authenticate request using user credentials. // the authenticate function invoked by Authenticate Strategy method after extracting user credentials @@ -36,122 +26,26 @@ const ExtensionKey = "x-go-guardian-basic-password" // with ErrMissingPrams returned, Otherwise, return Authenticate invocation result. type AuthenticateFunc func(ctx context.Context, r *http.Request, userName, password string) (auth.Info, error) -// Authenticate implement Authenticate Strategy method, and return user info or an appropriate error. -func (auth AuthenticateFunc) Authenticate(ctx context.Context, r *http.Request) (auth.Info, error) { - user, pass, err := auth.credentials(r) - - if err != nil { - return nil, err - } - - return auth(ctx, r, user, pass) -} - -// Challenge returns string indicates the authentication scheme. -// Typically used to adds a HTTP WWW-Authenticate header. -func (auth AuthenticateFunc) Challenge(realm string) string { - return fmt.Sprintf(`Basic realm="%s", title="'Basic' HTTP Authentication Scheme"`, realm) -} - -func (auth AuthenticateFunc) credentials(r *http.Request) (string, string, error) { - user, pass, ok := r.BasicAuth() - - if !ok { - return "", "", ErrMissingPrams - } - - return user, pass, nil +type basic struct { + fn AuthenticateFunc + parser Parser } -type cachedBasic struct { - AuthenticateFunc - comparator Comparator - cache store.Cache -} - -func (c *cachedBasic) authenticate(ctx context.Context, r *http.Request, userName, pass string) (auth.Info, error) { // nolint:lll - v, ok, err := c.cache.Load(userName, r) - +func (b basic) Authenticate(ctx context.Context, r *http.Request) (auth.Info, error) { + user, pass, err := b.parser.Credentials(r) if err != nil { return nil, err } - - // if info not found invoke user authenticate function - if !ok { - return c.authenticatAndHash(ctx, r, userName, pass) - } - - if _, ok := v.(auth.Info); !ok { - return nil, gerrors.NewInvalidType((*auth.Info)(nil), v) - } - - info := v.(auth.Info) - ext := info.Extensions() - hashedPass, ok := ext[ExtensionKey] - - if !ok { - return c.authenticatAndHash(ctx, r, userName, pass) - } - - return info, c.comparator.Verify(hashedPass[0], pass) -} - -func (c *cachedBasic) authenticatAndHash(ctx context.Context, r *http.Request, userName, pass string) (auth.Info, error) { //nolint:lll - info, err := c.AuthenticateFunc(ctx, r, userName, pass) - if err != nil { - return nil, err - } - - ext := info.Extensions() - if ext == nil { - ext = make(map[string][]string) - } - - hashedPass, _ := c.comparator.Hash(pass) - ext[ExtensionKey] = []string{hashedPass} - info.SetExtensions(ext) - - // cache result - if err := c.cache.Store(userName, info, r); err != nil { - return nil, err - } - - return info, nil + return b.fn(ctx, r, user, pass) } // New return new auth.Strategy. -// The returned strategy, caches the invocation result of authenticate function. -func New(f AuthenticateFunc, cache store.Cache) auth.Strategy { - return NewWithOptions(f, cache) -} - -// NewWithOptions return new auth.Strategy. -// The returned strategy, caches the invocation result of authenticate function. -func NewWithOptions(f AuthenticateFunc, cache store.Cache, opts ...auth.Option) auth.Strategy { - cb := &cachedBasic{ - AuthenticateFunc: f, - cache: cache, - comparator: plainText{}, - } - +func New(fn AuthenticateFunc, opts ...auth.Option) auth.Strategy { + b := new(basic) + b.fn = fn + b.parser = AuthorizationParser() for _, opt := range opts { - opt.Apply(cb) + opt.Apply(b) } - - return AuthenticateFunc(cb.authenticate) -} - -// SetHash set the hashing algorithm to hash the user password. -func SetHash(h crypto.Hash) auth.Option { - b := basicHashing{h} - return SetComparator(b) -} - -// SetComparator set password comparator. -func SetComparator(c Comparator) auth.Option { - return auth.OptionFunc(func(v interface{}) { - if v, ok := v.(*cachedBasic); ok { - v.comparator = c - } - }) + return b } diff --git a/auth/strategies/basic/basic_test.go b/auth/strategies/basic/basic_test.go index 2dd70eb..55b6ea7 100644 --- a/auth/strategies/basic/basic_test.go +++ b/auth/strategies/basic/basic_test.go @@ -2,15 +2,10 @@ package basic import ( "context" - "crypto" - _ "crypto/sha256" "fmt" "net/http" - "sync" "testing" - "github.com/stretchr/testify/assert" - "github.com/shaj13/go-guardian/auth" ) @@ -50,7 +45,7 @@ func Test(t *testing.T) { t.Run(tt.name, func(t *testing.T) { r, _ := http.NewRequest("GET", "/", nil) tt.setCredentials(r) - info, err := AuthenticateFunc(auth).Authenticate(r.Context(), r) + info, err := New(auth).Authenticate(r.Context(), r) if tt.expectedErr && err == nil { t.Errorf("%s: Expected error, got none", tt.name) @@ -65,124 +60,11 @@ func Test(t *testing.T) { } } -func TestChallenge(t *testing.T) { - basic := AuthenticateFunc(func(_ context.Context, _ *http.Request, _, _ string) (auth.Info, error) { - return nil, nil - }) - - got := basic.Challenge("Test Realm") - expected := `Basic realm="Test Realm", title="'Basic' HTTP Authentication Scheme"` - - assert.Equal(t, expected, got) -} - -//nolint:goconst -func TestNewCached(t *testing.T) { - authFunc := func(ctx context.Context, r *http.Request, userName, password string) (auth.Info, error) { - if userName == "auth-error" { - return nil, fmt.Errorf("Invalid credentials") - } - - return auth.NewDefaultUser("test", "10", nil, nil), nil - } - - table := []struct { - name string - setCredentials func(r *http.Request) - expectedErr bool - }{ - { - name: "it return user from cache", - setCredentials: func(r *http.Request) { r.SetBasicAuth("predefined2", "test") }, - expectedErr: false, - }, - { - name: "it re-authenticate user when hash missing", - setCredentials: func(r *http.Request) { r.SetBasicAuth("predefined3", "test") }, - expectedErr: false, - }, - { - name: "it return error when cache hold invalid user info", - setCredentials: func(r *http.Request) { r.SetBasicAuth("predefined", "test") }, - expectedErr: true, - }, - { - name: "it return error when cache return error on load", - setCredentials: func(r *http.Request) { r.SetBasicAuth("error", "test") }, - expectedErr: true, - }, - { - name: "it return error when cache return error on store", - setCredentials: func(r *http.Request) { r.SetBasicAuth("store-error", "test") }, - expectedErr: true, - }, - { - name: "it return user info when request authenticated", - setCredentials: func(r *http.Request) { r.SetBasicAuth("test", "test") }, - expectedErr: false, - }, - { - name: "it return error when request has invalid credentials", - setCredentials: func(r *http.Request) { r.SetBasicAuth("auth-error", "unknown") }, - expectedErr: true, - }, - { - name: "it return error when request missing basic auth params", - setCredentials: func(r *http.Request) { /* no op */ }, - expectedErr: true, - }, - } - - for _, tt := range table { - t.Run(tt.name, func(t *testing.T) { - r, _ := http.NewRequest("GET", "/", nil) - tt.setCredentials(r) - - c := newMockCache() - c.cache["predefined"] = "invalid-type" - c.cache["predefined2"] = auth.NewDefaultUser( - "predefined2", - "10", - nil, - map[string][]string{ - ExtensionKey: {"9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08"}, - }, - ) - c.cache["predefined3"] = auth.NewDefaultUser("predefined3", "10", nil, nil) - - opt := SetHash(crypto.SHA256) - info, err := NewWithOptions(authFunc, c, opt).Authenticate(r.Context(), r) - - assert.Equal(t, tt.expectedErr, err != nil, "%s: Got Unexpected error %v", tt.name, err) - assert.Equal(t, !tt.expectedErr, info != nil, "%s: Expected info object, got nil", tt.name) - }) - } -} - -func BenchmarkCachedBasic(b *testing.B) { - r, _ := http.NewRequest("GET", "/", nil) - r.SetBasicAuth("test", "test") - - cache := newMockCache() - - strategy := New(exampleAuthFunc, cache) - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - _, err := strategy.Authenticate(r.Context(), r) - if err != nil { - b.Error(err) - } - } - }) -} - func BenchmarkBasic(b *testing.B) { r, _ := http.NewRequest("GET", "/", nil) r.SetBasicAuth("test", "test") - strategy := AuthenticateFunc(exampleAuthFunc) + strategy := New(exampleAuthFunc) b.ResetTimer() b.RunParallel(func(pb *testing.PB) { @@ -195,40 +77,11 @@ func BenchmarkBasic(b *testing.B) { }) } -type mockCache struct { - cache map[string]interface{} - mu *sync.Mutex -} - -func (m mockCache) Load(key string, _ *http.Request) (interface{}, bool, error) { - m.mu.Lock() - defer m.mu.Unlock() - if key == "error" { - return nil, false, fmt.Errorf("Load Error") - } - v, ok := m.cache[key] - return v, ok, nil -} - -func (m mockCache) Store(key string, value interface{}, _ *http.Request) error { - m.mu.Lock() - defer m.mu.Unlock() - if key == "store-error" { - return fmt.Errorf("Store Error") +func exampleAuthFunc(ctx context.Context, r *http.Request, userName, password string) (auth.Info, error) { + // here connect to db or any other service to fetch user and validate it. + if userName == "test" && password == "test" { + return auth.NewDefaultUser("test", "10", nil, nil), nil } - m.cache[key] = value - return nil -} - -func (m mockCache) Delete(key string, _ *http.Request) error { - return nil -} -func (m mockCache) Keys() []string { return nil } - -func newMockCache() mockCache { - return mockCache{ - cache: make(map[string]interface{}), - mu: new(sync.Mutex), - } + return nil, fmt.Errorf("Invalid credentials") } diff --git a/auth/strategies/basic/cached.go b/auth/strategies/basic/cached.go new file mode 100644 index 0000000..e8e626d --- /dev/null +++ b/auth/strategies/basic/cached.go @@ -0,0 +1,66 @@ +package basic + +import ( + "context" + "net/http" + + "github.com/shaj13/go-guardian/auth" +) + +// ExtensionKey represents a key for the password in info extensions. +// Typically used when basic strategy cache the authentication decisions. +const ExtensionKey = "x-go-guardian-basic-password" + +type cachedBasic struct { + fn AuthenticateFunc + comparator Comparator + cache auth.Cache +} + +func (c *cachedBasic) authenticate(ctx context.Context, r *http.Request, userName, pass string) (auth.Info, error) { // nolint:lll + v, ok := c.cache.Load(userName) + + // if info not found invoke user authenticate function + if !ok { + return c.authenticatAndHash(ctx, r, userName, pass) + } + + if _, ok := v.(auth.Info); !ok { + return nil, auth.NewTypeError("strategies/basic:", (*auth.Info)(nil), v) + } + + info := v.(auth.Info) + ext := info.GetExtensions() + + if !ext.Has(ExtensionKey) { + return c.authenticatAndHash(ctx, r, userName, pass) + } + + return info, c.comparator.Compare(ext.Get(ExtensionKey), pass) +} + +func (c *cachedBasic) authenticatAndHash(ctx context.Context, r *http.Request, userName, pass string) (auth.Info, error) { //nolint:lll + info, err := c.fn(ctx, r, userName, pass) + if err != nil { + return nil, err + } + + hashedPass, _ := c.comparator.Hash(pass) + info.GetExtensions().Set(ExtensionKey, hashedPass) + c.cache.Store(userName, info) + + return info, nil +} + +// NewCached return new auth.Strategy. +// The returned strategy, caches the invocation result of authenticate function. +func NewCached(f AuthenticateFunc, cache auth.Cache, opts ...auth.Option) auth.Strategy { + cb := new(cachedBasic) + cb.fn = f + cb.cache = cache + cb.comparator = plainText{} + for _, opt := range opts { + opt.Apply(cb) + } + return New(cb.authenticate, opts...) +} diff --git a/auth/strategies/basic/cached_test.go b/auth/strategies/basic/cached_test.go new file mode 100644 index 0000000..74cf26c --- /dev/null +++ b/auth/strategies/basic/cached_test.go @@ -0,0 +1,134 @@ +package basic + +import ( + "context" + "crypto" + _ "crypto/sha256" + "fmt" + "net/http" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/shaj13/go-guardian/auth" +) + +//nolint:goconst +func TestNewCached(t *testing.T) { + authFunc := func(ctx context.Context, r *http.Request, userName, password string) (auth.Info, error) { + if userName == "auth-error" { + return nil, fmt.Errorf("Invalid credentials") + } + + return auth.NewDefaultUser("test", "10", nil, nil), nil + } + + table := []struct { + name string + setCredentials func(r *http.Request) + expectedErr bool + }{ + { + name: "it return user from cache", + setCredentials: func(r *http.Request) { r.SetBasicAuth("predefined2", "test") }, + expectedErr: false, + }, + { + name: "it re-authenticate user when hash missing", + setCredentials: func(r *http.Request) { r.SetBasicAuth("predefined3", "test") }, + expectedErr: false, + }, + { + name: "it return error when cache hold invalid user info", + setCredentials: func(r *http.Request) { r.SetBasicAuth("predefined", "test") }, + expectedErr: true, + }, + { + name: "it return user info when request authenticated", + setCredentials: func(r *http.Request) { r.SetBasicAuth("test", "test") }, + expectedErr: false, + }, + { + name: "it return error when request has invalid credentials", + setCredentials: func(r *http.Request) { r.SetBasicAuth("auth-error", "unknown") }, + expectedErr: true, + }, + { + name: "it return error when request missing basic auth params", + setCredentials: func(r *http.Request) { /* no op */ }, + expectedErr: true, + }, + } + + for _, tt := range table { + t.Run(tt.name, func(t *testing.T) { + r, _ := http.NewRequest("GET", "/", nil) + tt.setCredentials(r) + + c := newMockCache() + c.cache["predefined"] = "invalid-type" + c.cache["predefined2"] = auth.NewDefaultUser( + "predefined2", + "10", + nil, + map[string][]string{ + ExtensionKey: {"9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08"}, + }, + ) + c.cache["predefined3"] = auth.NewDefaultUser("predefined3", "10", nil, nil) + + opt := SetHash(crypto.SHA256) + info, err := NewCached(authFunc, c, opt).Authenticate(r.Context(), r) + + assert.Equal(t, tt.expectedErr, err != nil, "%s: Got Unexpected error %v", tt.name, err) + assert.Equal(t, !tt.expectedErr, info != nil, "%s: Expected info object, got nil", tt.name) + }) + } +} + +func BenchmarkCachedBasic(b *testing.B) { + r, _ := http.NewRequest("GET", "/", nil) + r.SetBasicAuth("test", "test") + + cache := newMockCache() + + strategy := NewCached(exampleAuthFunc, cache) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := strategy.Authenticate(r.Context(), r) + if err != nil { + b.Error(err) + } + } + }) +} + +type mockCache struct { + cache map[interface{}]interface{} + mu *sync.Mutex +} + +func (m mockCache) Load(key interface{}) (interface{}, bool) { + m.mu.Lock() + defer m.mu.Unlock() + v, ok := m.cache[key] + return v, ok +} + +func (m mockCache) Store(key interface{}, value interface{}) { + m.mu.Lock() + defer m.mu.Unlock() + m.cache[key] = value +} + +func (m mockCache) Delete(interface{}) {} + +func newMockCache() mockCache { + return mockCache{ + cache: make(map[interface{}]interface{}), + mu: new(sync.Mutex), + } +} diff --git a/auth/strategies/basic/comparator.go b/auth/strategies/basic/comparator.go index e2f4461..8e24153 100644 --- a/auth/strategies/basic/comparator.go +++ b/auth/strategies/basic/comparator.go @@ -11,7 +11,7 @@ import ( // with its possible plaintext equivalent type Comparator interface { Hash(password string) (string, error) - Verify(hashedPassword, password string) error + Compare(hashedPassword, password string) error } type basicHashing struct { @@ -25,7 +25,7 @@ func (b basicHashing) Hash(password string) (string, error) { return hex.EncodeToString(sum), nil } -func (b basicHashing) Verify(hashedPassword, password string) error { +func (b basicHashing) Compare(hashedPassword, password string) error { hash, err := b.Hash(password) if err != nil { return err @@ -44,7 +44,7 @@ func (p plainText) Hash(password string) (string, error) { return password, nil } -func (p plainText) Verify(hashedPassword, password string) error { +func (p plainText) Compare(hashedPassword, password string) error { if subtle.ConstantTimeCompare([]byte(hashedPassword), []byte(password)) == 1 { return nil } diff --git a/auth/strategies/basic/comparator_test.go b/auth/strategies/basic/comparator_test.go index b69adfb..ac4fea7 100644 --- a/auth/strategies/basic/comparator_test.go +++ b/auth/strategies/basic/comparator_test.go @@ -11,8 +11,8 @@ func TestBasicHashing(t *testing.T) { b := basicHashing{h: crypto.SHA256} pass := "password" hash, err := b.Hash(pass) - match := b.Verify(hash, pass) - missmatch := b.Verify(hash, "pass") + match := b.Compare(hash, pass) + missmatch := b.Compare(hash, "pass") assert.NoError(t, err) assert.NotEqual(t, hash, pass) @@ -24,8 +24,8 @@ func TestPlainText(t *testing.T) { p := plainText{} pass := "password" hash, err := p.Hash(pass) - match := p.Verify(hash, pass) - missmatch := p.Verify(hash, "pass") + match := p.Compare(hash, pass) + missmatch := p.Compare(hash, "pass") assert.NoError(t, err) assert.Equal(t, hash, pass) diff --git a/auth/strategies/basic/example_test.go b/auth/strategies/basic/example_test.go index 6cc9ae4..0506dce 100644 --- a/auth/strategies/basic/example_test.go +++ b/auth/strategies/basic/example_test.go @@ -1,4 +1,4 @@ -package basic +package basic_test import ( "context" @@ -7,24 +7,23 @@ import ( "net/http" "github.com/shaj13/go-guardian/auth" - "github.com/shaj13/go-guardian/errors" - "github.com/shaj13/go-guardian/store" + "github.com/shaj13/go-guardian/auth/strategies/basic" + "github.com/shaj13/go-guardian/cache" + _ "github.com/shaj13/go-guardian/cache/container/lru" ) func Example() { - strategy := AuthenticateFunc(exampleAuthFunc) - authenticator := auth.New() - authenticator.EnableStrategy(StrategyKey, strategy) + strategy := basic.New(exampleAuthFunc) // user request req, _ := http.NewRequest("GET", "/", nil) req.SetBasicAuth("test", "test") - user, err := authenticator.Authenticate(req) - fmt.Println(user.ID(), err) + user, err := strategy.Authenticate(req.Context(), req) + fmt.Println(user.GetID(), err) req.SetBasicAuth("test", "1234") - _, err = authenticator.Authenticate(req) - fmt.Println(err.(errors.MultiError)[1]) + _, err = strategy.Authenticate(req.Context(), req) + fmt.Println(err) // Output: // 10 @@ -34,31 +33,28 @@ func Example() { func Example_second() { // This example show how to caches the result of basic auth. // With LRU cache - cache := store.New(2) - - strategy := New(exampleAuthFunc, cache) - authenticator := auth.New() - authenticator.EnableStrategy(StrategyKey, strategy) + cache := cache.LRU.New() + strategy := basic.NewCached(exampleAuthFunc, cache) // user request req, _ := http.NewRequest("GET", "/", nil) req.SetBasicAuth("test", "test") - user, err := authenticator.Authenticate(req) - fmt.Println(user.ID(), err) + user, err := strategy.Authenticate(req.Context(), req) + fmt.Println(user.GetID(), err) req.SetBasicAuth("test", "1234") - _, err = authenticator.Authenticate(req) - fmt.Println(err.(errors.MultiError)[1]) + _, err = strategy.Authenticate(req.Context(), req) + fmt.Println(err) // Output: // 10 - // basic: Invalid user credentials + // strategies/basic: Invalid user credentials } func ExampleSetHash() { - opt := SetHash(crypto.SHA256) // import _ crypto/sha256 - cache := store.New(2) - NewWithOptions(exampleAuthFunc, cache, opt) + opt := basic.SetHash(crypto.SHA256) // import _ crypto/sha256 + cache := cache.LRU.New() + basic.NewCached(exampleAuthFunc, cache, opt) } func exampleAuthFunc(ctx context.Context, r *http.Request, userName, password string) (auth.Info, error) { diff --git a/auth/strategies/basic/options.go b/auth/strategies/basic/options.go new file mode 100644 index 0000000..2bc9300 --- /dev/null +++ b/auth/strategies/basic/options.go @@ -0,0 +1,31 @@ +package basic + +import ( + "crypto" + + "github.com/shaj13/go-guardian/auth" +) + +// SetHash set the hashing algorithm to hash the user password. +func SetHash(h crypto.Hash) auth.Option { + b := basicHashing{h} + return SetComparator(b) +} + +// SetComparator set password comparator. +func SetComparator(c Comparator) auth.Option { + return auth.OptionFunc(func(v interface{}) { + if v, ok := v.(*cachedBasic); ok { + v.comparator = c + } + }) +} + +// SetParser set credentials parser. +func SetParser(p Parser) auth.Option { + return auth.OptionFunc(func(v interface{}) { + if v, ok := v.(*basic); ok { + v.parser = p + } + }) +} diff --git a/auth/strategies/basic/parser.go b/auth/strategies/basic/parser.go new file mode 100644 index 0000000..d73f97f --- /dev/null +++ b/auth/strategies/basic/parser.go @@ -0,0 +1,28 @@ +package basic + +import "net/http" + +// Parser parse and extract user credentials from incoming HTTP request. +type Parser interface { + Credentials(r *http.Request) (string, string, error) +} + +type credentialsFn func(r *http.Request) (string, string, error) + +func (fn credentialsFn) Credentials(r *http.Request) (string, string, error) { + return fn(r) +} + +// AuthorizationParser return a credentials parser, +// where credentials extracted form Authorization header. +func AuthorizationParser() Parser { + fn := func(r *http.Request) (string, string, error) { + user, pass, ok := r.BasicAuth() + if !ok { + return "", "", ErrMissingPrams + } + return user, pass, nil + } + + return credentialsFn(fn) +} diff --git a/auth/strategies/bearer/bearer.go b/auth/strategies/bearer/bearer.go deleted file mode 100644 index e6e7bdc..0000000 --- a/auth/strategies/bearer/bearer.go +++ /dev/null @@ -1,71 +0,0 @@ -// Package bearer provides authentication strategy, -// to authenticate HTTP requests based on the bearer token. -// -// Deprecated: Use token Strategy instead. -package bearer - -import ( - "context" - "net/http" - - "github.com/shaj13/go-guardian/auth" - "github.com/shaj13/go-guardian/auth/strategies/token" - "github.com/shaj13/go-guardian/store" -) - -var ( - // ErrInvalidToken indicate a hit of an invalid bearer token format. - // And it's returned by Token function. - ErrInvalidToken = token.ErrInvalidToken - // ErrTokenNotFound is returned by authenticating functions for bearer strategies, - // when token not found in their store. - ErrTokenNotFound = token.ErrTokenNotFound -) - -const ( - // CachedStrategyKey export identifier for the cached bearer strategy, - // commonly used when enable/add strategy to go-guardian authenticator. - CachedStrategyKey = token.CachedStrategyKey - // StatitcStrategyKey export identifier for the static bearer strategy, - // commonly used when enable/add strategy to go-guardian authenticator. - StatitcStrategyKey = token.StatitcStrategyKey -) - -// AuthenticateFunc declare custom function to authenticate request using token. -// The authenticate function invoked by Authenticate Strategy method when -// The token does not exist in the cahce and the invocation result will be cached, unless an error returned. -// Use NoOpAuthenticate instead to refresh/mangae token directly using cache or Append function. -type AuthenticateFunc = token.AuthenticateFunc - -// Static implements auth.Strategy and define a synchronized map honor all predefined bearer tokens. -type Static = token.Static - -// Token return bearer token from Authorization header, or ErrInvalidToken, -// The returned token will not contain "Bearer" keyword -func Token(r *http.Request) (string, error) { - return token.AuthorizationParser("Bearer").Token(r) -} - -// NewStaticFromFile returns static auth.Strategy, populated from a CSV file. -func NewStaticFromFile(path string) (auth.Strategy, error) { - return token.NewStaticFromFile(path) -} - -// NewStatic returns static auth.Strategy, populated from a map. -func NewStatic(tokens map[string]auth.Info) auth.Strategy { - return token.NewStatic(tokens) -} - -// New return new auth.Strategy. -// The returned strategy, caches the invocation result of authenticate function, See AuthenticateFunc. -// Use NoOpAuthenticate to refresh/mangae token directly using cache or Append function, See NoOpAuthenticate. -func New(auth AuthenticateFunc, c store.Cache) auth.Strategy { - return token.New(auth, c) -} - -// NoOpAuthenticate implements Authenticate function, it return nil, auth.ErrNOOP, -// commonly used when token refreshed/mangaed directly using cache or Append function, -// and there is no need to parse token and authenticate request. -func NoOpAuthenticate(ctx context.Context, r *http.Request, token string) (auth.Info, error) { - return nil, auth.ErrNOOP -} diff --git a/auth/strategies/digest/cached.go b/auth/strategies/digest/cached.go deleted file mode 100644 index b62da8a..0000000 --- a/auth/strategies/digest/cached.go +++ /dev/null @@ -1,72 +0,0 @@ -package digest - -import ( - "context" - "net/http" - - "github.com/shaj13/go-guardian/auth" - "github.com/shaj13/go-guardian/errors" - "github.com/shaj13/go-guardian/store" -) - -const extensionKey = "x-go-guardian-digest" - -// CachedStrategy caches digest strategy authentication response based on authorization nonce field. -type CachedStrategy struct { - *Strategy - Cache store.Cache -} - -// Authenticate user request and returns user info, Otherwise error. -func (c *CachedStrategy) Authenticate(ctx context.Context, r *http.Request) (auth.Info, error) { - authz := r.Header.Get("Authorization") - h := make(Header) - _ = h.Parse(authz) - - info, ok, err := c.Cache.Load(h.Nonce(), r) - - if err != nil { - return nil, err - } - - if !ok { - info, err := c.Strategy.Authenticate(ctx, r) - - if err != nil { - return nil, err - } - - ext := info.Extensions() - if ext == nil { - ext = make(map[string][]string) - } - - ext[extensionKey] = []string{h.String()} - info.SetExtensions(ext) - - // cache result - err = c.Cache.Store(h.Nonce(), info, r) - - return info, err - } - - v, ok := info.(auth.Info) - - if !ok { - return nil, errors.NewInvalidType((*auth.Info)(nil), info) - } - - sh := make(Header) - _ = sh.Parse(v.Extensions()[extensionKey][0]) - - h.SetNC("00000001") - h.SetURI(r.RequestURI) - sh.SetNC("00000001") - sh.SetURI(r.RequestURI) - - if err := sh.Compare(h); err != nil { - return nil, err - } - - return v, nil -} diff --git a/auth/strategies/digest/cached_test.go b/auth/strategies/digest/cached_test.go deleted file mode 100644 index 2938333..0000000 --- a/auth/strategies/digest/cached_test.go +++ /dev/null @@ -1,182 +0,0 @@ -//nolint: lll, goconst -package digest - -import ( - "crypto/md5" - "fmt" - "hash" - "net/http" - "strings" - "sync" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/shaj13/go-guardian/auth" -) - -func TestNewCahced(t *testing.T) { - table := []struct { - name string - expectedErr bool - info interface{} - key string - header string - }{ - { - name: "it return error when cache load return error", - expectedErr: true, - header: "Digest nonce=error", - info: nil, - }, - { - name: "it return error when user authenticate func return error", - expectedErr: true, - header: "Digest nonce=ignore", - info: nil, - }, - { - name: "it return error when cache store return error", - expectedErr: true, - header: `Digest username="a", realm="t", nonce="store-error", uri="/", cnonce="1=", nc=00000001, qop=auth, response="14979b5053904998faf57bc72a1d7a56", opaque="1"`, - info: nil, - }, - { - name: "it return error when cache return invalid type", - expectedErr: true, - info: "sample-data", - header: "Digest nonce=invalid", - key: "invalid", - }, - { - name: "it authinticate user", - expectedErr: false, - info: nil, - header: `Digest username="a", realm="t", nonce="150", uri="/", cnonce="1=", nc=00000001, qop=auth, response="22b46269226822f5edde5a52985679ad", opaque="1"`, - }, - { - name: "it authinticate user even if nc, uri changed and load it from cache", - expectedErr: false, - info: auth.NewDefaultUser( - "test", - "1", - nil, - map[string][]string{ - extensionKey: {`Digest username="a", realm="t", nonce="150", uri="/abc", cnonce="1=", nc=00000001, qop=auth, response="22b46269226822f5edde5a52985679ad", opaque="1"`}, - }, - ), - key: "150", - header: `Digest username="a", realm="t", nonce="150", uri="/", cnonce="1=", nc=00000002, qop=auth, response="22b46269226822f5edde5a52985679ad", opaque="1"`, - }, - } - - for _, tt := range table { - t.Run(tt.name, func(t *testing.T) { - cache := newMockCache() - - s := &CachedStrategy{ - Strategy: &Strategy{ - Algorithm: "md5", - Realm: "test", - FetchUser: func(userName string) (string, auth.Info, error) { - if userName == "error" { - return "", nil, fmt.Errorf("Error #L 51") - } - return "", auth.NewDefaultUser("test", "1", nil, nil), nil - }, - Hash: func(algo string) hash.Hash { - return md5.New() - }, - }, - Cache: cache, - } - - r, _ := http.NewRequest("GET", "/", nil) - r.Header.Set("Authorization", tt.header) - - _ = cache.Store(tt.key, tt.info, r) - - _, err := s.Authenticate(r.Context(), r) - - assert.Equal(t, tt.expectedErr, err != nil) - }) - } -} - -func BenchmarkCached(b *testing.B) { - count := 0 - - authz := `Digest username="a", realm="t", nonce="1", uri="/", cnonce="1=", nc=00000001, qop=auth, response="60bf91821e3499f7b04f3de9ee7e3aa6", opaque="1"` - r, _ := http.NewRequest("GET", "/", nil) - r.Header.Set("Authorization", authz) - - strategy := &CachedStrategy{ - Strategy: &Strategy{ - Algorithm: "md5", - Realm: "test", - FetchUser: func(userName string) (string, auth.Info, error) { - count++ - if count >= 10 { - b.Fatal("Function fetch user called more than 10 times this should not happens in cached strategy") - } - return "", auth.NewDefaultUser("benchmark", "1", nil, nil), nil - }, - Hash: func(algo string) hash.Hash { - return md5.New() - }, - }, - Cache: newMockCache(), - } - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - _, err := strategy.Authenticate(r.Context(), r) - if err != nil { - b.Error(err) - } - } - }) -} - -type mockCache struct { - cache map[string]interface{} - mu *sync.Mutex -} - -func (m mockCache) Load(key string, _ *http.Request) (interface{}, bool, error) { - m.mu.Lock() - defer m.mu.Unlock() - if key == "error" { - return nil, false, fmt.Errorf("Load Error") - } - v, ok := m.cache[key] - return v, ok, nil -} - -func (m mockCache) Store(key string, value interface{}, _ *http.Request) error { - m.mu.Lock() - defer m.mu.Unlock() - if key == "ignore" { - return nil - } - - if strings.Contains(key, "store-error") { - return fmt.Errorf("Store Error") - } - m.cache[key] = value - return nil -} - -func (m mockCache) Delete(key string, _ *http.Request) error { - return nil -} - -func (m mockCache) Keys() []string { return nil } - -func newMockCache() mockCache { - return mockCache{ - cache: make(map[string]interface{}), - mu: new(sync.Mutex), - } -} diff --git a/auth/strategies/digest/digest.go b/auth/strategies/digest/digest.go index 326ad1e..32c9109 100644 --- a/auth/strategies/digest/digest.go +++ b/auth/strategies/digest/digest.go @@ -4,37 +4,32 @@ package digest import ( "context" + "crypto" + _ "crypto/md5" //nolint:gosec + "crypto/subtle" "encoding/hex" - "hash" + "errors" "net/http" "github.com/shaj13/go-guardian/auth" - "github.com/shaj13/go-guardian/errors" ) -// StrategyKey export identifier for the digest strategy, -// commonly used when enable/add strategy to go-guardian authenticator. -const StrategyKey = auth.StrategyKey("Digest.Strategy") - // ErrInvalidResponse is returned by Strategy when client authz response does not match server hash. -var ErrInvalidResponse = errors.New("Digest: Invalid Response") - -// Strategy implements auth.Strategy and represents digest authentication as described in RFC 7616. -type Strategy struct { - // Hash a callback function to return the desired hash algorithm, - // the passed algo args is extracted from authorization header - // and if it missing the args will be same as algorithm field provided to the strategy. - Hash func(algo string) hash.Hash +var ErrInvalidResponse = errors.New("strategies/digest: Invalid Response") - // FetchUser a callback function to return the user password and user info or error in case of occurs. - FetchUser func(userName string) (string, auth.Info, error) +// FetchUser a callback function to return the user password and user info. +type FetchUser func(userName string) (string, auth.Info, error) - Realm string - Algorithm string +// Digest authentication strategy. +type Digest struct { + fn FetchUser + chash crypto.Hash + c auth.Cache + h Header } // Authenticate user request and returns user info, Otherwise error. -func (s *Strategy) Authenticate(ctx context.Context, r *http.Request) (auth.Info, error) { +func (d *Digest) Authenticate(ctx context.Context, r *http.Request) (auth.Info, error) { authz := r.Header.Get("Authorization") h := make(Header) @@ -42,67 +37,65 @@ func (s *Strategy) Authenticate(ctx context.Context, r *http.Request) (auth.Info return nil, err } - passwd, info, err := s.FetchUser(h.UserName()) - + passwd, info, err := d.fn(h.UserName()) if err != nil { return nil, err } - algo := h.Algorithm() - if len(algo) == 0 { - algo = s.Algorithm + HA1 := d.hash(h.UserName() + ":" + h.Realm() + ":" + passwd) + HA2 := d.hash(r.Method + ":" + r.RequestURI) + HKD := d.hash(HA1 + ":" + h.Nonce() + ":" + h.NC() + ":" + h.Cnonce() + ":" + h.QOP() + ":" + HA2) + if subtle.ConstantTimeCompare([]byte(HKD), []byte(h.Response())) != 1 { + return nil, ErrInvalidResponse } - HA1 := s.hash(algo, h.UserName()+":"+h.Realm()+":"+passwd) - HA2 := s.hash(algo, r.Method+":"+r.RequestURI) - HKD := s.hash(algo, HA1+":"+h.Nonce()+":"+h.NC()+":"+h.Cnonce()+":"+h.QOP()+":"+HA2) - - if HKD != h.Response() { + if _, ok := d.c.Load(h.Nonce()); !ok { return nil, ErrInvalidResponse } - return info, nil -} - -func (s *Strategy) hash(algo, str string) string { - h := s.Hash(algo) - _, _ = h.Write([]byte(str)) - p := h.Sum(nil) - return hex.EncodeToString(p) -} - -// WWWAuthenticate set HTTP WWW-Authenticate header field with Digest scheme. -func (s *Strategy) WWWAuthenticate(hh http.Header) error { - // TODO: Deprecated - h := make(Header) - h.SetRealm(s.Realm) - h.SetAlgorithm(s.Algorithm) + // validate the header values. + ch := d.h.Clone() + ch.SetNonce(h.Nonce()) - str, err := h.WWWAuthenticate() - if err != nil { - return err + if err := ch.Compare(h); err != nil { + return nil, err } - hh.Set("WWW-Authenticate", str) - return nil + return info, nil } -// Challenge returns string indicates the authentication scheme. +// GetChallenge returns string indicates the authentication scheme. // Typically used to adds a HTTP WWW-Authenticate header. -func (s *Strategy) Challenge(realm string) string { - if len(realm) == 0 { - realm = s.Realm - } - - h := make(Header) - h.SetRealm(realm) - h.SetAlgorithm(s.Algorithm) +func (d *Digest) GetChallenge() string { + h := d.h.Clone() + str := h.WWWAuthenticate() + d.c.Store(h.Nonce(), struct{}{}) + return str +} - str, err := h.WWWAuthenticate() +func (d *Digest) hash(str string) string { + h := d.chash.New() + _, _ = h.Write([]byte(str)) + p := h.Sum(nil) + return hex.EncodeToString(p) +} - if err != nil { - panic(err) +// New returns digest authentication strategy. +// Digest strategy use MD5 as default hash. +// Digest use cache to store nonce. +func New(f FetchUser, c auth.Cache, opts ...auth.Option) *Digest { + d := new(Digest) + d.fn = f + d.chash = crypto.MD5 + d.c = c + d.h = make(Header) + d.h.SetRealm("Users") + d.h.SetAlgorithm("md5") + d.h.SetOpaque(secretKey()) + + for _, opt := range opts { + opt.Apply(d) } - return str + return d } diff --git a/auth/strategies/digest/digest_test.go b/auth/strategies/digest/digest_test.go index 5ae4566..4c60373 100644 --- a/auth/strategies/digest/digest_test.go +++ b/auth/strategies/digest/digest_test.go @@ -2,9 +2,7 @@ package digest import ( - "crypto/md5" "fmt" - "hash" "net/http" "testing" @@ -13,40 +11,14 @@ import ( "github.com/shaj13/go-guardian/auth" ) -func TestWWWAuthenticate(t *testing.T) { - s := &Strategy{ - Algorithm: "md5", - Realm: "test", - } - - h := make(http.Header) - - s.WWWAuthenticate(h) - str := h.Get("WWW-Authenticate") - assert.Contains(t, str, `qop="auth"`) - assert.Contains(t, str, "Digest") - assert.Contains(t, str, "algorithm=md5") - assert.Contains(t, str, "opaque=") - assert.Contains(t, str, "nonce=") - assert.Contains(t, str, `realm="test"`) -} - -func TestChallenge(t *testing.T) { - s := &Strategy{ - Algorithm: "md5", +func TestStartegy(t *testing.T) { + fn := func(userName string) (string, auth.Info, error) { + if userName == "error" { + return "", nil, fmt.Errorf("Error #L 51") + } + return "", nil, nil } - str := s.Challenge("test2") - - assert.Contains(t, str, `qop="auth"`) - assert.Contains(t, str, "Digest") - assert.Contains(t, str, "algorithm=md5") - assert.Contains(t, str, "opaque=") - assert.Contains(t, str, "nonce=") - assert.Contains(t, str, `realm="test2"`) -} - -func TestStartegy(t *testing.T) { table := []struct { name string authz string @@ -59,7 +31,7 @@ func TestStartegy(t *testing.T) { }, { name: "it authinticate user", - authz: `Digest username="a", realm="t", nonce="1", uri="/", cnonce="1=", nc=00000001, qop=auth, response="22cf307b29e6318dafba1fc1d564fc12", opaque="1"`, + authz: `Digest username="a", realm="t", nonce="1", uri="/", cnonce="1=", nc=00000001, qop=auth, response="22cf307b29e6318dafba1fc1d564fc12", opaque="1", algorithm="md5"`, }, { name: "it return error when username is invalid", @@ -75,51 +47,54 @@ func TestStartegy(t *testing.T) { for _, tt := range table { t.Run(tt.name, func(t *testing.T) { - s := Strategy{ - Algorithm: "md5", - Realm: "test", - FetchUser: func(userName string) (string, auth.Info, error) { - if userName == "error" { - return "", nil, fmt.Errorf("Error #L 51") - } - return "", nil, nil - }, - Hash: func(algo string) hash.Hash { - return md5.New() - }, - } - + strategy := testDigest(fn) r, _ := http.NewRequest("GEt", "/", nil) r.Header.Set("Authorization", tt.authz) - _, err := s.Authenticate(r.Context(), r) - assert.Equal(t, tt.expectedErr, err != nil) + _, err := strategy.Authenticate(r.Context(), r) + assert.Equal(t, tt.expectedErr, err != nil, err) }) } } func BenchmarkStrategy(b *testing.B) { - authz := `Digest username="a", realm="t", nonce="1", uri="/", cnonce="1=", nc=00000001, qop=auth, response="22cf307b29e6318dafba1fc1d564fc12", opaque="1"` - s := Strategy{ - Algorithm: "md5", - Realm: "test", - FetchUser: func(userName string) (string, auth.Info, error) { - return "", nil, nil - }, - Hash: func(algo string) hash.Hash { - return md5.New() - }, + authz := `Digest username="a", realm="t", nonce="1", uri="/", cnonce="1=", nc=00000001, qop=auth, response="22cf307b29e6318dafba1fc1d564fc12", opaque="1", algorithm="md5"` + fn := func(userName string) (string, auth.Info, error) { + return "", nil, nil } + strategy := testDigest(fn) + r, _ := http.NewRequest("GEt", "/", nil) r.Header.Set("Authorization", authz) b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - _, err := s.Authenticate(r.Context(), r) + _, err := strategy.Authenticate(r.Context(), r) if err != nil { b.Error(err) } } }) } + +func testDigest(fn FetchUser) auth.Strategy { + opaque := SetOpaque("1") + realm := SetRealm("t") + cache := make(mockCache) + cache.Store("1", nil) + return New(fn, cache, opaque, realm) +} + +type mockCache map[interface{}]interface{} + +func (m mockCache) Load(key interface{}) (interface{}, bool) { + v, ok := m[key] + return v, ok +} + +func (m mockCache) Store(key, value interface{}) { + m[key] = value +} + +func (m mockCache) Delete(key interface{}) {} diff --git a/auth/strategies/digest/header.go b/auth/strategies/digest/header.go index 191580d..d2aebfd 100644 --- a/auth/strategies/digest/header.go +++ b/auth/strategies/digest/header.go @@ -3,14 +3,13 @@ package digest import ( "crypto/rand" "encoding/hex" + "errors" "fmt" "strings" - - "github.com/shaj13/go-guardian/errors" ) // ErrInavlidHeader is returned by Header parse when authz header is not digest. -var ErrInavlidHeader = errors.New("Digest: Invalid Authorization Header") +var ErrInavlidHeader = errors.New("strategies/digest: Invalid Authorization Header") const ( username = "username" @@ -154,25 +153,13 @@ func (h Header) Parse(authorization string) error { } // WWWAuthenticate return string represents HTTP WWW-Authenticate header field with Digest scheme. -func (h Header) WWWAuthenticate() (string, error) { +func (h Header) WWWAuthenticate() string { if len(h.Nonce()) == 0 { - nonce, err := secretKey() - - if err != nil { - return "", err - } - - h.SetNonce(nonce) + h.SetNonce(secretKey()) } if len(h.Opaque()) == 0 { - opaque, err := secretKey() - - if err != nil { - return "", err - } - - h.SetOpaque(opaque) + h.SetOpaque(secretKey()) } h.SetQOP("auth") @@ -186,7 +173,7 @@ func (h Header) WWWAuthenticate() (string, error) { h.QOP(), ) - return s, nil + return s } // Compare server header vs client header returns error if any diff found. @@ -194,7 +181,7 @@ func (h Header) Compare(ch Header) error { for k, v := range h { cv := ch[k] if cv != v { - return fmt.Errorf("Digest: %s Does not match value in provided header", k) + return fmt.Errorf("strategies/digest: %s Does not match value in provided header", k) } } return nil @@ -211,11 +198,24 @@ func (h Header) String() string { return str[:len(str)-2] } -func secretKey() (string, error) { +// Clone returns a copy of h or nil if h is nil. +func (h Header) Clone() Header { + if h == nil { + return nil + } + h2 := make(Header, len(h)) + for k, v := range h { + h2[k] = v + } + + return h2 +} + +func secretKey() string { secret := make([]byte, 16) _, err := rand.Read(secret) if err != nil { - return "", err + panic(err) } - return hex.EncodeToString(secret), nil + return hex.EncodeToString(secret) } diff --git a/auth/strategies/digest/header_test.go b/auth/strategies/digest/header_test.go index f778c11..1923d8e 100644 --- a/auth/strategies/digest/header_test.go +++ b/auth/strategies/digest/header_test.go @@ -208,7 +208,7 @@ func TestWWWAuthenticateForHeader(t *testing.T) { h := make(Header) // Round #1 - str, _ := h.WWWAuthenticate() + str := h.WWWAuthenticate() assert.Contains(t, str, `qop="auth"`) assert.Contains(t, str, "Digest") assert.Contains(t, str, "algorithm=") @@ -219,7 +219,7 @@ func TestWWWAuthenticateForHeader(t *testing.T) { // Round #2 h.SetNonce("nounce") h.SetOpaque("opaque") - str, _ = h.WWWAuthenticate() + str = h.WWWAuthenticate() assert.Equal(t, str, `Digest realm="", nonce="nounce", opaque="opaque", algorithm=, qop="auth"`) } diff --git a/auth/strategies/digest/options.go b/auth/strategies/digest/options.go new file mode 100644 index 0000000..fda658f --- /dev/null +++ b/auth/strategies/digest/options.go @@ -0,0 +1,41 @@ +package digest + +import ( + "crypto" + + "github.com/shaj13/go-guardian/auth" +) + +// SetHash set the hashing algorithm to hash the user password. +// Default md5 +// +// SetHash(crypto.SHA1, "sha1") +// +func SetHash(h crypto.Hash, algorithm string) auth.Option { + return auth.OptionFunc(func(v interface{}) { + if d, ok := v.(*Digest); ok { + d.chash = h + d.h.SetAlgorithm(algorithm) + } + }) +} + +// SetRealm set digest authentication realm. +// Default "Users". +func SetRealm(realm string) auth.Option { + return auth.OptionFunc(func(v interface{}) { + if d, ok := v.(*Digest); ok { + d.h.SetRealm(realm) + } + }) +} + +// SetOpaque set digest opaque. +// Default random generated key. +func SetOpaque(opaque string) auth.Option { + return auth.OptionFunc(func(v interface{}) { + if d, ok := v.(*Digest); ok { + d.h.SetOpaque(opaque) + } + }) +} diff --git a/auth/strategies/kubernetes/example_test.go b/auth/strategies/kubernetes/example_test.go index ee2c031..852aa4e 100644 --- a/auth/strategies/kubernetes/example_test.go +++ b/auth/strategies/kubernetes/example_test.go @@ -5,11 +5,12 @@ import ( "net/http" "github.com/shaj13/go-guardian/auth/strategies/token" - "github.com/shaj13/go-guardian/store" + "github.com/shaj13/go-guardian/cache" + _ "github.com/shaj13/go-guardian/cache/container/idle" ) func ExampleNew() { - cache := store.New(2) + cache := cache.IDLE.NewUnsafe() kube := New(cache) r, _ := http.NewRequest("", "/", nil) _, err := kube.Authenticate(r.Context(), r) @@ -19,7 +20,7 @@ func ExampleNew() { } func ExampleGetAuthenticateFunc() { - cache := store.New(2) + cache := cache.IDLE.NewUnsafe() fn := GetAuthenticateFunc() kube := token.New(fn, cache) r, _ := http.NewRequest("", "/", nil) @@ -31,7 +32,7 @@ func ExampleGetAuthenticateFunc() { func Example() { st := SetServiceAccountToken("Service Account Token") - cache := store.New(2) + cache := cache.IDLE.NewUnsafe() kube := New(cache, st) r, _ := http.NewRequest("", "/", nil) _, err := kube.Authenticate(r.Context(), r) diff --git a/auth/strategies/kubernetes/kubernetes.go b/auth/strategies/kubernetes/kubernetes.go index f7468c1..c9aaf3f 100644 --- a/auth/strategies/kubernetes/kubernetes.go +++ b/auth/strategies/kubernetes/kubernetes.go @@ -18,7 +18,6 @@ import ( "github.com/shaj13/go-guardian/auth" "github.com/shaj13/go-guardian/auth/strategies/token" - "github.com/shaj13/go-guardian/store" ) type kubeReview struct { @@ -110,7 +109,7 @@ func GetAuthenticateFunc(opts ...auth.Option) token.AuthenticateFunc { // New return strategy authenticate request using kubernetes token review. // New is similar to token.New(). -func New(c store.Cache, opts ...auth.Option) auth.Strategy { +func New(c auth.Cache, opts ...auth.Option) auth.Strategy { fn := GetAuthenticateFunc(opts...) return token.New(fn, c, opts...) } diff --git a/auth/strategies/ldap/ldap.go b/auth/strategies/ldap/ldap.go index c402cca..b830c10 100644 --- a/auth/strategies/ldap/ldap.go +++ b/auth/strategies/ldap/ldap.go @@ -11,18 +11,13 @@ import ( "github.com/shaj13/go-guardian/auth" "github.com/shaj13/go-guardian/auth/strategies/basic" - "github.com/shaj13/go-guardian/store" "gopkg.in/ldap.v3" ) -// StrategyKey export identifier for the LDAP strategy, -// commonly used when enable/add strategy to go-guardian authenticator. -const StrategyKey = auth.StrategyKey("LDAP.Strategy") - // ErrEntries is returned by ldap authenticate function, // When search result return user DN does not exist or too many entries returned. -var ErrEntries = errors.New("LDAP: Serach user DN does not exist or too many entries returned") +var ErrEntries = errors.New("strategies/ldap: Serach user DN does not exist or too many entries returned") type conn interface { Bind(username, password string) error @@ -67,7 +62,6 @@ func dial(cfg *Config) (conn, error) { } type client struct { - auth.Strategy dial func(cfg *Config) (conn, error) cfg *Config } @@ -138,25 +132,26 @@ func (c client) authenticate(ctx context.Context, r *http.Request, userName, pas return auth.NewUserInfo(userName, id, nil, ext), nil } -func (c client) Challenge(realm string) string { - return fmt.Sprintf(`LDAP realm="%s", title="LDAP Based Authentication"`, realm) -} - -// New return new auth.Strategy. -func New(cfg *Config) auth.Strategy { +// GetAuthenticateFunc return function to authenticate request using LDAP. +// The returned function typically used with the basic strategy. +func GetAuthenticateFunc(cfg *Config, opts ...auth.Option) basic.AuthenticateFunc { cl := new(client) cl.dial = dial cl.cfg = cfg - cl.Strategy = basic.AuthenticateFunc(cl.authenticate) - return cl + return cl.authenticate +} + +// New return strategy authenticate request using LDAP. +// New is similar to Basic.New(). +func New(cfg *Config, opts ...auth.Option) auth.Strategy { + fn := GetAuthenticateFunc(cfg, opts...) + return basic.New(fn, opts...) } -// NewCached return new auth.Strategy. +// NewCached return strategy authenticate request using LDAP. // The returned strategy, caches the authentication decision. -func NewCached(cfg *Config, c store.Cache) auth.Strategy { - cl := new(client) - cl.dial = dial - cl.cfg = cfg - cl.Strategy = basic.New(cl.authenticate, c) - return cl +// New is similar to Basic.NewCached(). +func NewCached(cfg *Config, c auth.Cache, opts ...auth.Option) auth.Strategy { + fn := GetAuthenticateFunc(cfg, opts...) + return basic.NewCached(fn, c, opts...) } diff --git a/auth/strategies/ldap/ldap_test.go b/auth/strategies/ldap/ldap_test.go index ff0382c..cc86a70 100644 --- a/auth/strategies/ldap/ldap_test.go +++ b/auth/strategies/ldap/ldap_test.go @@ -147,22 +147,14 @@ func TestLdap(t *testing.T) { assert.Equal(t, tt.expectedErr, err != nil) if !tt.expectedErr { - assert.Equal(t, tt.id, info.ID()) - assert.Equal(t, tt.user, info.UserName()) + assert.Equal(t, tt.id, info.GetID()) + assert.Equal(t, tt.user, info.GetUserName()) } }) } } -func TestChallenge(t *testing.T) { - strategy := new(client) - got := strategy.Challenge("Test Realm") - expected := `LDAP realm="Test Realm", title="LDAP Based Authentication"` - - assert.Equal(t, expected, got) -} - type mockConn struct { mock.Mock } diff --git a/auth/strategies/token/cached.go b/auth/strategies/token/cached.go index 6afcecc..c0aa21e 100644 --- a/auth/strategies/token/cached.go +++ b/auth/strategies/token/cached.go @@ -5,14 +5,8 @@ import ( "net/http" "github.com/shaj13/go-guardian/auth" - "github.com/shaj13/go-guardian/errors" - "github.com/shaj13/go-guardian/store" ) -// CachedStrategyKey export identifier for the cached bearer strategy, -// commonly used when enable/add strategy to go-guardian authenticator. -const CachedStrategyKey = auth.StrategyKey("Token.Cached.Strategy") - // AuthenticateFunc declare custom function to authenticate request using token. // The authenticate function invoked by Authenticate Strategy method when // The token does not exist in the cahce and the invocation result will be cached, unless an error returned. @@ -22,13 +16,13 @@ type AuthenticateFunc func(ctx context.Context, r *http.Request, token string) ( // New return new auth.Strategy. // The returned strategy, caches the invocation result of authenticate function, See AuthenticateFunc. // Use NoOpAuthenticate to refresh/mangae token directly using cache or Append function, See NoOpAuthenticate. -func New(auth AuthenticateFunc, c store.Cache, opts ...auth.Option) auth.Strategy { +func New(auth AuthenticateFunc, c auth.Cache, opts ...auth.Option) auth.Strategy { if auth == nil { - panic("Authenticate Function required and can't be nil") + panic("strategies/token: Authenticate Function required and can't be nil") } if c == nil { - panic("Cache object required and can't be nil") + panic("strategies/token: Cache object required and can't be nil") } cached := &cachedToken{ @@ -48,7 +42,7 @@ func New(auth AuthenticateFunc, c store.Cache, opts ...auth.Option) auth.Strateg type cachedToken struct { parser Parser typ Type - cache store.Cache + cache auth.Cache authFunc AuthenticateFunc } @@ -58,45 +52,37 @@ func (c *cachedToken) Authenticate(ctx context.Context, r *http.Request) (auth.I return nil, err } - info, ok, err := c.cache.Load(token, r) - - if err != nil { - return nil, err - } + info, ok := c.cache.Load(token) // if token not found invoke user authenticate function if !ok { info, err = c.authFunc(ctx, r, token) - if err == nil { - // cache result - err = c.cache.Store(token, info, r) + if err != nil { + return nil, err } - } - - if err != nil { - return nil, err + c.cache.Store(token, info) } if _, ok := info.(auth.Info); !ok { - return nil, errors.NewInvalidType((*auth.Info)(nil), info) + return nil, auth.NewTypeError("strategies/token:", (*auth.Info)(nil), info) } return info.(auth.Info), nil } -func (c *cachedToken) Append(token string, info auth.Info, r *http.Request) error { - return c.cache.Store(token, info, r) +func (c *cachedToken) Append(token interface{}, info auth.Info) error { + c.cache.Store(token, info) + return nil } -func (c *cachedToken) Revoke(token string, r *http.Request) error { - return c.cache.Delete(token, r) +func (c *cachedToken) Revoke(token interface{}) error { + c.cache.Delete(token) + return nil } -func (c *cachedToken) Challenge(realm string) string { return challenge(realm, c.typ) } - -// NoOpAuthenticate implements Authenticate function, it return nil, auth.ErrNOOP, +// NoOpAuthenticate implements Authenticate function, it return nil, ErrNOOP, // commonly used when token refreshed/mangaed directly using cache or Append function, // and there is no need to parse token and authenticate request. func NoOpAuthenticate(ctx context.Context, r *http.Request, token string) (auth.Info, error) { - return nil, auth.ErrNOOP + return nil, ErrNOOP } diff --git a/auth/strategies/token/cached_test.go b/auth/strategies/token/cached_test.go index 5f2c20e..1247bcc 100644 --- a/auth/strategies/token/cached_test.go +++ b/auth/strategies/token/cached_test.go @@ -2,14 +2,12 @@ package token import ( "context" - "fmt" "net/http" "testing" "github.com/stretchr/testify/assert" "github.com/shaj13/go-guardian/auth" - "github.com/shaj13/go-guardian/store" ) func TestNewCahced(t *testing.T) { @@ -17,7 +15,7 @@ func TestNewCahced(t *testing.T) { name string panic bool expectedErr bool - cache store.Cache + cache auth.Cache authFunc AuthenticateFunc info interface{} token string @@ -93,7 +91,7 @@ func TestNewCahced(t *testing.T) { strategy := New(tt.authFunc, tt.cache) r, _ := http.NewRequest("GET", "/", nil) r.Header.Set("Authorization", "Bearer "+tt.token) - _ = tt.cache.Store(tt.token, tt.info, r) + tt.cache.Store(tt.token, tt.info) info, err := strategy.Authenticate(r.Context(), r) if tt.expectedErr { assert.Error(t, err) @@ -108,29 +106,18 @@ func TestCahcedTokenAppend(t *testing.T) { cache := make(mockCache) strategy := &cachedToken{cache: cache} info := auth.NewDefaultUser("1", "2", nil, nil) - strategy.Append("test-append", info, nil) - cachedInfo, ok, _ := cache.Load("test-append", nil) + strategy.Append("test-append", info) + cachedInfo, ok := cache.Load("test-append") assert.True(t, ok) assert.Equal(t, info, cachedInfo) } -func TestCahcedTokenChallenge(t *testing.T) { - strategy := &cachedToken{ - typ: Bearer, - } - - got := strategy.Challenge("Test Realm") - expected := `Bearer realm="Test Realm", title="Bearer Token Based Authentication Scheme"` - - assert.Equal(t, expected, got) -} - func BenchmarkCachedToken(b *testing.B) { r, _ := http.NewRequest("GET", "/", nil) r.Header.Set("Authorization", "Bearer token") cache := make(mockCache) - cache.Store("token", auth.NewDefaultUser("benchmark", "1", nil, nil), r) + cache.Store("token", auth.NewDefaultUser("benchmark", "1", nil, nil)) strategy := New(NoOpAuthenticate, cache) @@ -145,25 +132,15 @@ func BenchmarkCachedToken(b *testing.B) { }) } -type mockCache map[string]interface{} +type mockCache map[interface{}]interface{} -func (m mockCache) Load(key string, _ *http.Request) (interface{}, bool, error) { - if key == "error" { - return nil, false, fmt.Errorf("Load Error") - } +func (m mockCache) Load(key interface{}) (interface{}, bool) { v, ok := m[key] - return v, ok, nil + return v, ok } -func (m mockCache) Store(key string, value interface{}, _ *http.Request) error { - if key == "store-error" { - return fmt.Errorf("Store Error") - } +func (m mockCache) Store(key, value interface{}) { m[key] = value - return nil } -func (m mockCache) Keys() []string { return nil } -func (m mockCache) Delete(key string, _ *http.Request) error { - return nil -} +func (m mockCache) Delete(key interface{}) {} diff --git a/auth/strategies/token/example_test.go b/auth/strategies/token/example_test.go index 23cc862..fa2e3df 100644 --- a/auth/strategies/token/example_test.go +++ b/auth/strategies/token/example_test.go @@ -4,10 +4,10 @@ import ( "context" "fmt" "net/http" - "time" "github.com/shaj13/go-guardian/auth" - "github.com/shaj13/go-guardian/store" + "github.com/shaj13/go-guardian/cache" + _ "github.com/shaj13/go-guardian/cache/container/lru" ) func ExampleNewStaticFromFile() { @@ -15,7 +15,7 @@ func ExampleNewStaticFromFile() { r, _ := http.NewRequest("GET", "/", nil) r.Header.Set("Authorization", "Bearer testUserToken") info, err := strategy.Authenticate(r.Context(), r) - fmt.Println(err, info.ID()) + fmt.Println(err, info.GetID()) // Output: // 1 } @@ -27,15 +27,12 @@ func ExampleNewStatic() { r, _ := http.NewRequest("GET", "/", nil) r.Header.Set("Authorization", "Bearer 90d64460d14870c08c81352a05dedd3465940a7") info, err := strategy.Authenticate(r.Context(), r) - fmt.Println(err, info.ID()) + fmt.Println(err, info.GetID()) // Output: // 1 } func ExampleNew() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - authFunc := AuthenticateFunc(func(ctx context.Context, r *http.Request, token string) (auth.Info, error) { fmt.Print("authFunc called ") if token == "90d64460d14870c08c81352a05dedd3465940a7" { @@ -44,7 +41,7 @@ func ExampleNew() { return nil, fmt.Errorf("Invalid user token") }) - cache := store.NewFIFO(ctx, time.Minute*5) + cache := cache.LRU.New() strategy := New(authFunc, cache) r, _ := http.NewRequest("GET", "/", nil) @@ -52,41 +49,30 @@ func ExampleNew() { // first request when authentication decision not cached info, err := strategy.Authenticate(r.Context(), r) - fmt.Println(err, info.ID()) + fmt.Println(err, info.GetID()) // second request where authentication decision cached and authFunc will not be called info, err = strategy.Authenticate(r.Context(), r) - fmt.Println(err, info.ID()) + fmt.Println(err, info.GetID()) // Output: // authFunc called 1 // 1 } func ExampleNoOpAuthenticate() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - cache := store.NewFIFO(ctx, time.Microsecond*500) + cache := cache.LRU.New() strategy := New(NoOpAuthenticate, cache) - // demonstrate a user attempt to login - r, _ := http.NewRequest("GET", "/login", nil) // user verified and add the user token to strategy using append or cache - cache.Store("token", auth.NewDefaultUser("example", "1", nil, nil), r) + cache.Store("token", auth.NewDefaultUser("example", "1", nil, nil)) // first request where authentication decision added to cached - r, _ = http.NewRequest("GET", "/login", nil) + r, _ := http.NewRequest("GET", "/", nil) r.Header.Set("Authorization", "Bearer token") info, err := strategy.Authenticate(r.Context(), r) - fmt.Println(err, info.ID()) - - // second request where authentication decision expired and user must login again - time.Sleep(time.Second) - info, err = strategy.Authenticate(r.Context(), r) - fmt.Println(err, info) + fmt.Println(err, info.GetID()) // Output: // 1 - // NOOP } func ExampleAuthorizationParser() { @@ -140,7 +126,7 @@ func ExampleNewStatic_apikey() { } strategy := NewStatic(tokens, opt) info, err := strategy.Authenticate(r.Context(), r) - fmt.Println(info.UserName(), err) + fmt.Println(info.GetUserName(), err) // Output: // example @@ -151,9 +137,6 @@ func ExampleNew_apikey() { parser := QueryParser("api_key") opt := SetParser(parser) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - authFunc := AuthenticateFunc(func(ctx context.Context, r *http.Request, token string) (auth.Info, error) { if token == "token" { return auth.NewDefaultUser("example", "1", nil, nil), nil @@ -161,11 +144,11 @@ func ExampleNew_apikey() { return nil, fmt.Errorf("Invalid user token") }) - cache := store.NewFIFO(ctx, time.Minute*5) + cache := cache.LRU.New() strategy := New(authFunc, cache, opt) info, err := strategy.Authenticate(r.Context(), r) - fmt.Println(info.UserName(), err) + fmt.Println(info.GetUserName(), err) // Output: // example } diff --git a/auth/strategies/token/parser.go b/auth/strategies/token/parser.go index a1e0677..3f6bff5 100644 --- a/auth/strategies/token/parser.go +++ b/auth/strategies/token/parser.go @@ -3,7 +3,7 @@ package token import ( "net/http" - "github.com/shaj13/go-guardian/internal" + "github.com/shaj13/go-guardian/auth/internal" ) // Parser parse and extract token from incoming HTTP request. @@ -52,3 +52,12 @@ func CookieParser(key string) Parser { return tokenFn(fn) } + +// JSONBodyParser return a token parser, where token extracted extracted form request body. +func JSONBodyParser(key string) Parser { + fn := func(r *http.Request) (string, error) { + return internal.ParseJSONBody(key, r, ErrInvalidToken) + } + + return tokenFn(fn) +} diff --git a/auth/strategies/token/static.go b/auth/strategies/token/static.go index 788bc2a..68127c4 100644 --- a/auth/strategies/token/static.go +++ b/auth/strategies/token/static.go @@ -13,30 +13,26 @@ import ( "github.com/shaj13/go-guardian/auth" ) -// StatitcStrategyKey export identifier for the static bearer strategy, -// commonly used when enable/add strategy to go-guardian authenticator. -const StatitcStrategyKey = auth.StrategyKey("Token.Static.Strategy") - // Static implements auth.Strategy and define a synchronized map honor all predefined bearer tokens. -type Static struct { - MU *sync.Mutex - Tokens map[string]auth.Info - Type Type - Parser Parser +type static struct { + mu *sync.Mutex + tokens map[string]auth.Info + ttype Type + parser Parser } // Authenticate user request against predefined tokens by verifying request token existence in the static Map. // Once token found auth.Info returned with a nil error, // Otherwise, a nil auth.Info and ErrTokenNotFound returned. -func (s *Static) Authenticate(ctx context.Context, r *http.Request) (auth.Info, error) { - token, err := s.Parser.Token(r) +func (s *static) Authenticate(ctx context.Context, r *http.Request) (auth.Info, error) { + token, err := s.parser.Token(r) if err != nil { return nil, err } - s.MU.Lock() - defer s.MU.Unlock() - info, ok := s.Tokens[token] + s.mu.Lock() + defer s.mu.Unlock() + info, ok := s.tokens[token] if !ok { return nil, ErrTokenNotFound @@ -46,25 +42,21 @@ func (s *Static) Authenticate(ctx context.Context, r *http.Request) (auth.Info, } // Append add new token to static store. -func (s *Static) Append(token string, info auth.Info, _ *http.Request) error { - s.MU.Lock() - defer s.MU.Unlock() - s.Tokens[token] = info +func (s *static) Append(token interface{}, info auth.Info) error { + s.mu.Lock() + defer s.mu.Unlock() + s.tokens[token.(string)] = info return nil } // Revoke delete token from static store. -func (s *Static) Revoke(token string, _ *http.Request) error { - s.MU.Lock() - defer s.MU.Unlock() - delete(s.Tokens, token) +func (s *static) Revoke(token interface{}) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.tokens, token.(string)) return nil } -// Challenge returns string indicates the authentication scheme. -// Typically used to adds a HTTP WWW-Authenticate header. -func (s *Static) Challenge(realm string) string { return challenge(realm, s.Type) } - // NewStaticFromFile returns static auth.Strategy, populated from a CSV file. // The CSV file must contain records in one of following formats // basic record: `token,username,userid` @@ -94,20 +86,20 @@ func NewStaticFromFile(path string, opts ...auth.Option) (auth.Strategy, error) if len(record) < 3 { return nil, fmt.Errorf( - "static: record must have at least 3 columns (token, username, id), Record: %v", + "strategies/token: record must have at least 3 columns (token, username, id), Record: %v", record, ) } if record[0] == "" { - return nil, fmt.Errorf("static: a non empty token is required, Record: %v", record) + return nil, fmt.Errorf("strategies/token: a non empty token is required, Record: %v", record) } // if token Contains Bearer remove it record[0] = strings.TrimPrefix(record[0], "Bearer ") if _, ok := tokens[record[0]]; ok { - return nil, fmt.Errorf("static: token already exists, Record: %v", record) + return nil, fmt.Errorf("strategies/token: token already exists, Record: %v", record) } info := auth.NewUserInfo(record[1], record[2], nil, nil) @@ -137,11 +129,11 @@ func NewStaticFromFile(path string, opts ...auth.Option) (auth.Strategy, error) // NewStatic returns static auth.Strategy, populated from a map. func NewStatic(tokens map[string]auth.Info, opts ...auth.Option) auth.Strategy { - static := &Static{ - Tokens: tokens, - MU: new(sync.Mutex), - Type: Bearer, - Parser: AuthorizationParser(string(Bearer)), + static := &static{ + tokens: tokens, + mu: new(sync.Mutex), + ttype: Bearer, + parser: AuthorizationParser(string(Bearer)), } for _, opt := range opts { diff --git a/auth/strategies/token/static_test.go b/auth/strategies/token/static_test.go index 5769733..7d0cb61 100644 --- a/auth/strategies/token/static_test.go +++ b/auth/strategies/token/static_test.go @@ -76,10 +76,10 @@ func TestNewStaticFromFile(t *testing.T) { r.Header.Set("Authorization", "Bearer "+k) info, err := strategy.Authenticate(r.Context(), r) assert.NoError(t, err) - assert.EqualValues(t, v.ID(), info.ID()) - assert.EqualValues(t, v.UserName(), info.UserName()) - assert.EqualValues(t, v.Groups(), info.Groups()) - assert.EqualValues(t, v.Extensions(), info.Extensions()) + assert.EqualValues(t, v.GetID(), info.GetID()) + assert.EqualValues(t, v.GetUserName(), info.GetUserName()) + assert.EqualValues(t, v.GetGroups(), info.GetGroups()) + assert.EqualValues(t, v.GetExtensions(), info.GetExtensions()) } }) } @@ -108,24 +108,13 @@ func TestNewStatic(t *testing.T) { info, err := strategy.Authenticate(context.Background(), r) assert.NoError(t, err) - assert.EqualValues(t, v.ID(), info.ID()) - assert.EqualValues(t, v.UserName(), info.UserName()) - assert.EqualValues(t, v.Groups(), info.Groups()) - assert.EqualValues(t, v.Extensions(), info.Extensions()) + assert.EqualValues(t, v.GetID(), info.GetID()) + assert.EqualValues(t, v.GetUserName(), info.GetUserName()) + assert.EqualValues(t, v.GetGroups(), info.GetGroups()) + assert.EqualValues(t, v.GetExtensions(), info.GetExtensions()) } } -func TestStaticChallenge(t *testing.T) { - strategy := &Static{ - Type: Bearer, - } - - got := strategy.Challenge("Test Realm") - expected := `Bearer realm="Test Realm", title="Bearer Token Based Authentication Scheme"` - - assert.Equal(t, expected, got) -} - func BenchmarkStaticToken(b *testing.B) { r, _ := http.NewRequest("GET", "/", nil) r.Header.Set("Authorization", "Bearer token") diff --git a/auth/strategies/token/token.go b/auth/strategies/token/token.go index 3f01d28..add7567 100644 --- a/auth/strategies/token/token.go +++ b/auth/strategies/token/token.go @@ -1,8 +1,9 @@ +// Package token provides authentication strategy, +// to authenticate HTTP requests based on token. package token import ( "errors" - "fmt" "github.com/shaj13/go-guardian/auth" ) @@ -11,9 +12,15 @@ var ( // ErrInvalidToken indicate a hit of an invalid token format. // And it's returned by Token Parser. ErrInvalidToken = errors.New("strategies/token: Invalid token") + // ErrTokenNotFound is returned by authenticating functions for token strategies, // when token not found in their store. ErrTokenNotFound = errors.New("strategies/token: Token does not exists") + + // ErrNOOP is a soft error similar to EOF, + // returned by NoOpAuthenticate function to indicate there no op, + // and signal the caller to unauthenticate the request. + ErrNOOP = errors.New("strategies/token: NOOP") ) // Type is Authentication token type or scheme. A common type is Bearer. @@ -31,8 +38,8 @@ const ( func SetType(t Type) auth.Option { return auth.OptionFunc(func(v interface{}) { switch v := v.(type) { - case *Static: - v.Type = t + case *static: + v.ttype = t case *cachedToken: v.typ = t } @@ -43,14 +50,10 @@ func SetType(t Type) auth.Option { func SetParser(p Parser) auth.Option { return auth.OptionFunc(func(v interface{}) { switch v := v.(type) { - case *Static: - v.Parser = p + case *static: + v.parser = p case *cachedToken: v.parser = p } }) } - -func challenge(realm string, t Type) string { - return fmt.Sprintf(`%s realm="%s", title="%s Token Based Authentication Scheme"`, t, realm, t) -} diff --git a/auth/strategies/token/token_test.go b/auth/strategies/token/token_test.go index 3dd3308..e090a78 100644 --- a/auth/strategies/token/token_test.go +++ b/auth/strategies/token/token_test.go @@ -8,7 +8,7 @@ import ( func TestSetType(t *testing.T) { cached := new(cachedToken) - static := new(Static) + static := new(static) typ := Type("test") opt := SetType(typ) @@ -17,12 +17,12 @@ func TestSetType(t *testing.T) { opt.Apply(static) assert.Equal(t, cached.typ, typ) - assert.Equal(t, static.Type, typ) + assert.Equal(t, static.ttype, typ) } func TestSetParser(t *testing.T) { cached := new(cachedToken) - static := new(Static) + static := new(static) p := XHeaderParser("") opt := SetParser(p) @@ -31,5 +31,5 @@ func TestSetParser(t *testing.T) { opt.Apply(static) assert.True(t, cached.parser != nil) - assert.True(t, static.Parser != nil) + assert.True(t, static.parser != nil) } diff --git a/auth/strategies/twofactor/example_test.go b/auth/strategies/twofactor/example_test.go index bd8abd5..932a01b 100644 --- a/auth/strategies/twofactor/example_test.go +++ b/auth/strategies/twofactor/example_test.go @@ -15,24 +15,24 @@ type OTPManager struct{} func (OTPManager) Enabled(_ auth.Info) bool { return true } -func (OTPManager) Load(_ auth.Info) (twofactor.OTP, error) { +func (OTPManager) Load(_ auth.Info) (twofactor.Verifier, error) { // user otp configuration must be loaded from persistent storage key := otp.NewKey(otp.HOTP, "LABEL", "GXNRHI2MFRFWXQGJHWZJFOSYI6E7MEVA") ver := otp.New(key) return ver, nil } -func (OTPManager) Store(_ auth.Info, o twofactor.OTP) error { +func (OTPManager) Store(_ auth.Info, o twofactor.Verifier) error { // persist user otp after verification fmt.Println("Failures: ", o.(*otp.Verifier).Failures) return nil } func Example() { - strategy := twofactor.Strategy{ + strategy := twofactor.TwoFactor{ Parser: twofactor.XHeaderParser("X-Example-OTP"), Manager: OTPManager{}, - Primary: basic.AuthenticateFunc( + Primary: basic.New( func(ctx context.Context, r *http.Request, userName, password string) (auth.Info, error) { return auth.NewDefaultUser("example", "1", nil, nil), nil }, @@ -44,7 +44,7 @@ func Example() { r.Header.Set("X-Example-OTP", "345515") info, err := strategy.Authenticate(r.Context(), r) - fmt.Println(info.UserName(), err) + fmt.Println(info.GetUserName(), err) // Output: // Failures: 0 diff --git a/auth/strategies/twofactor/parser.go b/auth/strategies/twofactor/parser.go index 9e16c1f..1e17326 100644 --- a/auth/strategies/twofactor/parser.go +++ b/auth/strategies/twofactor/parser.go @@ -4,56 +4,56 @@ import ( "errors" "net/http" - "github.com/shaj13/go-guardian/internal" + "github.com/shaj13/go-guardian/auth/internal" ) -// ErrMissingPin is returned by Parser, +// ErrMissingOTP is returned by Parser, // When one-time password missing or empty in HTTP request. -var ErrMissingPin = errors.New("strategies/twofactor: One-time password missing or empty") +var ErrMissingOTP = errors.New("strategies/twofactor: One-time password missing or empty") // Parser parse and extract one-time password from incoming HTTP request. type Parser interface { - PinCode(r *http.Request) (string, error) + GetOTP(r *http.Request) (string, error) } -type pinFn func(r *http.Request) (string, error) +type otpFn func(r *http.Request) (string, error) -func (fn pinFn) PinCode(r *http.Request) (string, error) { +func (fn otpFn) GetOTP(r *http.Request) (string, error) { return fn(r) } -// XHeaderParser return a one-time password parser, where pin extracted form "X-" header. +// XHeaderParser return a one-time password parser, where otp extracted form "X-" header. func XHeaderParser(header string) Parser { fn := func(r *http.Request) (string, error) { - return internal.ParseHeader(header, r, ErrMissingPin) + return internal.ParseHeader(header, r, ErrMissingOTP) } - return pinFn(fn) + return otpFn(fn) } -// JSONBodyParser return a one-time password parser, where pin extracted form request body. +// JSONBodyParser return a one-time password parser, where otp extracted form request body. func JSONBodyParser(key string) Parser { fn := func(r *http.Request) (string, error) { - return internal.ParseJSONBody(key, r, ErrMissingPin) + return internal.ParseJSONBody(key, r, ErrMissingOTP) } - return pinFn(fn) + return otpFn(fn) } -// QueryParser return a one-time password parser, where pin extracted form HTTP query string. +// QueryParser return a one-time password parser, where otp extracted form HTTP query string. func QueryParser(key string) Parser { fn := func(r *http.Request) (string, error) { - return internal.ParseQuery(key, r, ErrMissingPin) + return internal.ParseQuery(key, r, ErrMissingOTP) } - return pinFn(fn) + return otpFn(fn) } -// CookieParser return a one-time password parser, where pin extracted form HTTP Cookie. +// CookieParser return a one-time password parser, where otp extracted form HTTP Cookie. func CookieParser(key string) Parser { fn := func(r *http.Request) (string, error) { - return internal.ParseCookie(key, r, ErrMissingPin) + return internal.ParseCookie(key, r, ErrMissingOTP) } - return pinFn(fn) + return otpFn(fn) } diff --git a/auth/strategies/twofactor/parser_test.go b/auth/strategies/twofactor/parser_test.go index 66e4468..62e9afb 100644 --- a/auth/strategies/twofactor/parser_test.go +++ b/auth/strategies/twofactor/parser_test.go @@ -22,7 +22,7 @@ func TestParser(t *testing.T) { parser := XHeaderParser("X-OTP") return parser, req }, - err: ErrMissingPin, + err: ErrMissingOTP, pin: "", }, { @@ -43,7 +43,7 @@ func TestParser(t *testing.T) { parser := QueryParser("otp") return parser, req }, - err: ErrMissingPin, + err: ErrMissingOTP, pin: "", }, { @@ -75,7 +75,7 @@ func TestParser(t *testing.T) { parser := CookieParser("otp") return parser, req }, - err: ErrMissingPin, + err: ErrMissingOTP, pin: "", }, { @@ -108,7 +108,7 @@ func TestParser(t *testing.T) { parser := JSONBodyParser("otp") return parser, req }, - err: ErrMissingPin, + err: ErrMissingOTP, pin: "", }, } @@ -116,7 +116,7 @@ func TestParser(t *testing.T) { for _, tt := range table { t.Run(tt.name, func(t *testing.T) { p, r := tt.prepare() - pin, err := p.PinCode(r) + pin, err := p.GetOTP(r) assert.Equal(t, tt.pin, pin) assert.Equal(t, tt.err, err) diff --git a/auth/strategies/twofactor/twofactor.go b/auth/strategies/twofactor/twofactor.go index da23f58..35f30be 100644 --- a/auth/strategies/twofactor/twofactor.go +++ b/auth/strategies/twofactor/twofactor.go @@ -1,3 +1,5 @@ +// Package twofactor provides authentication strategy, +// to authenticate HTTP requests based on one time password(otp). package twofactor import ( @@ -8,62 +10,58 @@ import ( "github.com/shaj13/go-guardian/auth" ) -// StrategyKey export identifier for the two factor strategy, -// commonly used when enable/add strategy to go-guardian authenticator. -const StrategyKey = auth.StrategyKey("2FA.Strategy") - -// ErrInvalidPin is returned by strategy, +// ErrInvalidOTP is returned by twofactor strategy, // When the user-supplied an invalid one time password and verification process failed. -var ErrInvalidPin = errors.New("strategies/twofactor: Invalid one time password") +var ErrInvalidOTP = errors.New("strategies/twofactor: Invalid one time password") -// OTP represents one-time password verification. -type OTP interface { +// Verifier represents one-time password verification. +type Verifier interface { // Verify user one-time password. Verify(pin string) (bool, error) } -// OTPManager load and store user OTP. -type OTPManager interface { +// Manager load and store user OTP Verifier. +type Manager interface { // Enabled check if two factor for user enabled. Enabled(user auth.Info) bool - // Load return user OTP or error. - Load(user auth.Info) (OTP, error) - // Store user OTP. - Store(user auth.Info, otp OTP) error + // Load return user OTP Verifier or error. + Load(user auth.Info) (Verifier, error) + // Store user OTP Verifier. + Store(user auth.Info, v Verifier) error } -// Strategy represents two factor authentication strategy. -type Strategy struct { +// TwoFactor represents two factor authentication strategy. +type TwoFactor struct { // Primary strategy that authenticates the user before verifying the one time password. // The primary strategy Typically of type basic or LDAP. Primary auth.Strategy Parser Parser - Manager OTPManager + Manager Manager } // Authenticate returns user info or error by authenticating request using primary strategy, // and then verifying one-time password. -func (s Strategy) Authenticate(ctx context.Context, r *http.Request) (auth.Info, error) { - info, err := s.Primary.Authenticate(ctx, r) +func (t TwoFactor) Authenticate(ctx context.Context, r *http.Request) (auth.Info, error) { + info, err := t.Primary.Authenticate(ctx, r) if err != nil { return nil, err } - if !s.Manager.Enabled(info) { + if !t.Manager.Enabled(info) { return info, nil } - pin, err := s.Parser.PinCode(r) + pin, err := t.Parser.GetOTP(r) if err != nil { return nil, err } - otp, err := s.Manager.Load(info) + otp, err := t.Manager.Load(info) if err != nil { return nil, err } - defer s.Manager.Store(info, otp) + defer t.Manager.Store(info, otp) ok, err := otp.Verify(pin) if err != nil { @@ -71,7 +69,7 @@ func (s Strategy) Authenticate(ctx context.Context, r *http.Request) (auth.Info, } if !ok { - return nil, ErrInvalidPin + return nil, ErrInvalidOTP } return info, nil diff --git a/auth/strategies/twofactor/twofactor_test.go b/auth/strategies/twofactor/twofactor_test.go index 7047f4d..3eb6cf8 100644 --- a/auth/strategies/twofactor/twofactor_test.go +++ b/auth/strategies/twofactor/twofactor_test.go @@ -18,51 +18,51 @@ func TestStrategy(t *testing.T) { err error pin string expectedInfo bool - prepare func(t *testing.T) Strategy + prepare func(t *testing.T) TwoFactor }{ { name: "it return error when primary strategy return error", err: fmt.Errorf("primary strategy error"), - prepare: func(t *testing.T) Strategy { + prepare: func(t *testing.T) TwoFactor { m := &mockStrategy{mock.Mock{}} m. On("Authenticate"). Return(nil, fmt.Errorf("primary strategy error")) - return Strategy{Primary: m} + return TwoFactor{Primary: m} }, }, { name: "it return info when user does not enabled tfa", err: nil, expectedInfo: true, - prepare: func(t *testing.T) Strategy { + prepare: func(t *testing.T) TwoFactor { m := &mockStrategy{mock.Mock{}} m.On("Authenticate").Return(nil, nil) mng := &mockManager{mock.Mock{}} mng.On("Enabled").Return(false) - return Strategy{Primary: m, Manager: mng} + return TwoFactor{Primary: m, Manager: mng} }, }, { name: "it return error when pin missing", - err: ErrMissingPin, - prepare: func(t *testing.T) Strategy { + err: ErrMissingOTP, + prepare: func(t *testing.T) TwoFactor { m := &mockStrategy{mock.Mock{}} m.On("Authenticate").Return(nil, nil) mng := &mockManager{mock.Mock{}} mng.On("Enabled").Return(true) - return Strategy{Primary: m, Manager: mng} + return TwoFactor{Primary: m, Manager: mng} }, }, { name: "it return error when manager load return error", err: fmt.Errorf("manager load error"), pin: "123456", - prepare: func(t *testing.T) Strategy { + prepare: func(t *testing.T) TwoFactor { m := &mockStrategy{mock.Mock{}} m.On("Authenticate").Return(nil, nil) @@ -70,14 +70,14 @@ func TestStrategy(t *testing.T) { mng.On("Enabled").Return(true) mng.On("Load").Return(new(mockOTP), fmt.Errorf("manager load error")) - return Strategy{Primary: m, Manager: mng} + return TwoFactor{Primary: m, Manager: mng} }, }, { name: "it return error when otp Verify return error", err: fmt.Errorf("OTP Error"), pin: "123456", - prepare: func(t *testing.T) Strategy { + prepare: func(t *testing.T) TwoFactor { m := &mockStrategy{mock.Mock{}} m.On("Authenticate").Return(nil, nil) @@ -89,14 +89,14 @@ func TestStrategy(t *testing.T) { mng.On("Load").Return(otp, nil) mng.On("Store").Return(nil) - return Strategy{Primary: m, Manager: mng} + return TwoFactor{Primary: m, Manager: mng} }, }, { name: "it return error when otp Verify return false", - err: ErrInvalidPin, + err: ErrInvalidOTP, pin: "123456", - prepare: func(t *testing.T) Strategy { + prepare: func(t *testing.T) TwoFactor { m := &mockStrategy{mock.Mock{}} m.On("Authenticate").Return(nil, nil) @@ -108,7 +108,7 @@ func TestStrategy(t *testing.T) { mng.On("Load").Return(otp, nil) mng.On("Store").Return(nil) - return Strategy{Primary: m, Manager: mng} + return TwoFactor{Primary: m, Manager: mng} }, }, { @@ -116,7 +116,7 @@ func TestStrategy(t *testing.T) { err: nil, pin: "123456", expectedInfo: true, - prepare: func(t *testing.T) Strategy { + prepare: func(t *testing.T) TwoFactor { m := &mockStrategy{mock.Mock{}} m.On("Authenticate").Return(nil, nil) @@ -128,7 +128,7 @@ func TestStrategy(t *testing.T) { mng.On("Load").Return(otp, nil) mng.On("Store").Return(nil) - return Strategy{Primary: m, Manager: mng} + return TwoFactor{Primary: m, Manager: mng} }, }, } @@ -168,12 +168,12 @@ func (m *mockManager) Enabled(user auth.Info) bool { return args.Bool(0) } -func (m *mockManager) Load(user auth.Info) (OTP, error) { +func (m *mockManager) Load(user auth.Info) (Verifier, error) { args := m.Called() - return args.Get(0).(OTP), args.Error(1) + return args.Get(0).(Verifier), args.Error(1) } -func (m *mockManager) Store(user auth.Info, otp OTP) error { +func (m *mockManager) Store(user auth.Info, otp Verifier) error { args := m.Called() return args.Error(0) } diff --git a/auth/strategies/union/union.go b/auth/strategies/union/union.go new file mode 100644 index 0000000..c7c56eb --- /dev/null +++ b/auth/strategies/union/union.go @@ -0,0 +1,63 @@ +package union + +import ( + "context" + "net/http" + + "github.com/shaj13/go-guardian/auth" +) + +// MultiError represent multiple errors that occur when attempting to authenticate a request. +type MultiError []error + +func (errs MultiError) Error() string { + if len(errs) == 0 { + return "" + } + + str := "" + for _, err := range errs { + str += err.Error() + ", " + } + + return "strategies/union: [" + str[:len(str)-2] + "]" +} + +// Union implements authentication strategy, +// and consolidate a chain of strategies. +type Union interface { + auth.Strategy + // AuthenticateRequest authenticates the request using a chain of strategies. + // AuthenticateRequest returns user info alongside the successful strategy. + AuthenticateRequest(r *http.Request) (auth.Strategy, auth.Info, error) + // Chain returns chain of strategies + Chain() []auth.Strategy +} + +type union []auth.Strategy + +func (u union) Authenticate(ctx context.Context, r *http.Request) (auth.Info, error) { + _, info, err := u.AuthenticateRequest(r) + return info, err +} + +func (u union) AuthenticateRequest(r *http.Request) (auth.Strategy, auth.Info, error) { + errs := MultiError{} + for _, s := range u { + info, err := s.Authenticate(r.Context(), r) + if err == nil { + return s, info, nil + } + errs = append(errs, err) + } + return nil, nil, errs +} + +func (u union) Chain() []auth.Strategy { + return u +} + +// New returns new union strategy. +func New(strategies ...auth.Strategy) Union { + return union(strategies) +} diff --git a/auth/strategies/x509/example_test.go b/auth/strategies/x509/example_test.go index 0a7882d..0ab3948 100644 --- a/auth/strategies/x509/example_test.go +++ b/auth/strategies/x509/example_test.go @@ -8,7 +8,6 @@ import ( "testing" "github.com/shaj13/go-guardian/auth" - "github.com/shaj13/go-guardian/errors" ) func Example() { @@ -20,21 +19,19 @@ func Example() { // create strategy and authenticator strategy := New(opts) - authenticator := auth.New() - authenticator.EnableStrategy(StrategyKey, strategy) // user request req, _ := http.NewRequest("GET", "/", nil) req.TLS = &tls.ConnectionState{PeerCertificates: readCertificates("client_valid")} // validate request - info, err := authenticator.Authenticate(req) - fmt.Println(info.UserName(), err) + info, err := strategy.Authenticate(req.Context(), req) + fmt.Println(info.GetUserName(), err) // validate expired client certificate req.TLS = &tls.ConnectionState{PeerCertificates: readCertificates("client_expired")} - info, err = authenticator.Authenticate(req) - fmt.Println(info, err.(errors.MultiError)[1]) + info, err = strategy.Authenticate(req.Context(), req) + fmt.Println(info, err) // Output: // host.test.com @@ -48,21 +45,20 @@ func ExampleInfoBuilder() { // Read Root Ca Certificate opts.Roots.AddCert(readCertificates("ca")[0]) - // create strategy and authenticator - strategy := New(opts) - Builder = InfoBuilder(func(chain [][]*x509.Certificate) (auth.Info, error) { + builder := SetInfoBuilder(func(chain [][]*x509.Certificate) (auth.Info, error) { return auth.NewDefaultUser("user-info-builder", "10", nil, nil), nil }) - authenticator := auth.New() - authenticator.EnableStrategy(StrategyKey, strategy) + + // create strategy and authenticator + strategy := New(opts, builder) // user request req, _ := http.NewRequest("GET", "/", nil) req.TLS = &tls.ConnectionState{PeerCertificates: readCertificates("client_valid")} // validate request - info, err := authenticator.Authenticate(req) - fmt.Println(info.UserName(), err) + info, err := strategy.Authenticate(req.Context(), req) + fmt.Println(info.GetUserName(), err) // Output: // user-info-builder diff --git a/auth/strategies/x509/x509.go b/auth/strategies/x509/x509.go index 8ac1201..820fd33 100644 --- a/auth/strategies/x509/x509.go +++ b/auth/strategies/x509/x509.go @@ -6,31 +6,26 @@ import ( "context" "crypto/x509" "errors" - "fmt" "net/http" "github.com/shaj13/go-guardian/auth" ) -// StrategyKey export identifier for the x509 strategy, -// commonly used when enable/add strategy to go-guardian authenticator. -const StrategyKey = auth.StrategyKey("x509.Strategy") - var ( // ErrMissingCN is returned by DefaultBuilder when Certificate CommonName missing. - ErrMissingCN = errors.New("x509.strategy: Certificate subject CN missing") + ErrMissingCN = errors.New("strategies/x509: Certificate subject CN missing") // ErrInvalidRequest is returned by x509 strategy when a non TLS request received. - ErrInvalidRequest = errors.New("x509.strategy: Invalid request, missing TLS parameters") + ErrInvalidRequest = errors.New("strategy/x509: Invalid request, missing TLS parameters") ) // InfoBuilder declare a function signature for building Info from certificate chain. type InfoBuilder func(chain [][]*x509.Certificate) (auth.Info, error) -// Builder define default InfoBuilder by building Info from certificate chain subject. +// builder define default InfoBuilder by building Info from certificate chain subject. // where the subject values mapped in the following format, // CommonName to UserName, SerialNumber to ID, Organization to groups // and country, postalCode, streetAddress, locality, province mapped to Extensions. -var Builder = InfoBuilder(func(chain [][]*x509.Certificate) (auth.Info, error) { +func builder(chain [][]*x509.Certificate) (auth.Info, error) { subject := chain[0][0].Subject if len(subject.CommonName) == 0 { @@ -51,18 +46,21 @@ var Builder = InfoBuilder(func(chain [][]*x509.Certificate) (auth.Info, error) { subject.Organization, exts, ), nil -}) +} -type authenticateFunc func() x509.VerifyOptions +type strategy struct { + fn func() x509.VerifyOptions + builder InfoBuilder +} -func (f authenticateFunc) Authenticate(ctx context.Context, r *http.Request) (auth.Info, error) { +func (s strategy) Authenticate(ctx context.Context, r *http.Request) (auth.Info, error) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { return nil, ErrInvalidRequest } // get verify options shallow copy - opts := f() + opts := s.fn() // copy intermediates certificates to verify options from request if needed. // ignore r.TLS.PeerCertificates[0] it refer to client certificates. @@ -79,14 +77,25 @@ func (f authenticateFunc) Authenticate(ctx context.Context, r *http.Request) (au return nil, err } - return Builder(chain) + return s.builder(chain) } -func (f authenticateFunc) Challenge(realm string) string { - return fmt.Sprintf(`X.509 realm="%s", title="Certificate Based Authentication"`, realm) +// New returns auth.Strategy authenticate request from client certificates +func New(vopt x509.VerifyOptions, opts ...auth.Option) auth.Strategy { + s := new(strategy) + s.fn = func() x509.VerifyOptions { return vopt } + s.builder = builder + for _, opt := range opts { + opt.Apply(s) + } + return s } -// New returns auth.Strategy authenticate request from client certificates -func New(opts x509.VerifyOptions) auth.Strategy { - return authenticateFunc(func() x509.VerifyOptions { return opts }) +// SetInfoBuilder sets x509 info builder. +func SetInfoBuilder(builder InfoBuilder) auth.Option { + return auth.OptionFunc(func(v interface{}) { + if s, ok := v.(*strategy); ok { + s.builder = builder + } + }) } diff --git a/auth/strategies/x509/x509_test.go b/auth/strategies/x509/x509_test.go index 0446f1d..7100072 100644 --- a/auth/strategies/x509/x509_test.go +++ b/auth/strategies/x509/x509_test.go @@ -8,8 +8,6 @@ import ( "net/http" "strings" "testing" - - "github.com/stretchr/testify/assert" ) func Test(t *testing.T) { @@ -99,17 +97,6 @@ func Test(t *testing.T) { } } -func TestChallenge(t *testing.T) { - strategy := authenticateFunc(func() x509.VerifyOptions { - return x509.VerifyOptions{} - }) - - got := strategy.Challenge("Test Realm") - expected := `X.509 realm="Test Realm", title="Certificate Based Authentication"` - - assert.Equal(t, expected, got) -} - func BenchmarkX509(b *testing.B) { opts := testVerifyOptions(b) strategy := New(opts) diff --git a/auth/strategy.go b/auth/strategy.go index 9c6acf4..03a786b 100644 --- a/auth/strategy.go +++ b/auth/strategy.go @@ -8,10 +8,7 @@ import ( // ErrInvalidStrategy is returned by Append/Revoke functions, // when passed strategy does not implement Append/Revoke. -var ErrInvalidStrategy = errors.New("Invalid strategy") - -// StrategyKey define a custom type to expose a strategy identifier. -type StrategyKey string +var ErrInvalidStrategy = errors.New("auth: Invalid strategy") // Strategy represents an authentication mechanism or method to authenticate users requests. type Strategy interface { @@ -40,13 +37,13 @@ func (fn OptionFunc) Apply(v interface{}) { // Otherwise, nil. // // WARNING: Append function does not guarantee safe concurrency, It's natively depends on strategy store. -func Append(s Strategy, key string, info Info, r *http.Request) error { +func Append(s Strategy, key interface{}, info Info) error { u, ok := s.(interface { - Append(key string, info Info, r *http.Request) error + Append(interface{}, Info) error }) if ok { - return u.Append(key, info, r) + return u.Append(key, info) } return ErrInvalidStrategy @@ -57,45 +54,14 @@ func Append(s Strategy, key string, info Info, r *http.Request) error { // Otherwise, nil. // // WARNING: Revoke function does not guarantee safe concurrency, It's natively depends on strategy store. -func Revoke(s Strategy, key string, r *http.Request) error { +func Revoke(s Strategy, key interface{}) error { u, ok := s.(interface { - Revoke(key string, r *http.Request) error + Revoke(key interface{}) error }) if ok { - return u.Revoke(key, r) + return u.Revoke(key) } return ErrInvalidStrategy } - -// SetWWWAuthenticate adds a HTTP WWW-Authenticate header to the provided ResponseWriter's headers. -// by consolidating the result of calling Challenge methods on provided strategies. -// if strategy contains an Challenge method call it. -// Otherwise, strategy ignored. -func SetWWWAuthenticate(w http.ResponseWriter, realm string, strategies ...Strategy) { - str := "" - - if len(strategies) == 0 { - return - } - - for _, s := range strategies { - u, ok := s.(interface { - Challenge(string) string - }) - - if ok { - str = str + u.Challenge(realm) + ", " - } - } - - if len(str) == 0 { - return - } - - // remove ", " - str = str[0 : len(str)-2] - - w.Header().Set("WWW-Authenticate", str) -} diff --git a/auth/strategy_test.go b/auth/strategy_test.go index e530a86..a590027 100644 --- a/auth/strategy_test.go +++ b/auth/strategy_test.go @@ -3,7 +3,6 @@ package auth import ( "context" "net/http" - "net/http/httptest" "testing" "github.com/stretchr/testify/assert" @@ -49,9 +48,9 @@ func TestAppendRevoke(t *testing.T) { var err error switch tt.funcName { case "append": - err = Append(strategy, "", nil, nil) + err = Append(strategy, "", nil) case "revoke": - err = Revoke(strategy, "", nil) + err = Revoke(strategy, "") default: t.Errorf("Unsupported function %s", tt.funcName) return @@ -66,51 +65,6 @@ func TestAppendRevoke(t *testing.T) { } } -func TestSetWWWAuthenticate(t *testing.T) { - var ( - basic = &mockStrategy{challenge: `Basic realm="test"`} - bearer = &mockStrategy{challenge: `Bearer realm="test"`} - invalid = new(mockInvalidStrategy) - ) - - table := []struct { - name string - strategies []Strategy - expected string - }{ - { - name: "it does not set heder when no provided strategies", - expected: "", - }, - { - name: "it does not set heder when no provided strategies not implements Challenge method", - strategies: []Strategy{invalid}, - expected: "", - }, - { - name: "it ignore strategy if not implements Challenge method", - strategies: []Strategy{basic, invalid}, - expected: `Basic realm="test"`, - }, - { - name: "it consolidate strategies challenges into header", - strategies: []Strategy{basic, bearer}, - expected: `Basic realm="test", Bearer realm="test"`, - }, - } - - for _, tt := range table { - t.Run(tt.name, func(t *testing.T) { - w := httptest.NewRecorder() - SetWWWAuthenticate(w, "test", tt.strategies...) - - got := w.Header().Get("WWW-Authenticate") - - assert.Equal(t, tt.expected, got) - }) - } -} - type mockStrategy struct { called bool challenge string @@ -119,12 +73,12 @@ type mockStrategy struct { func (m *mockStrategy) Authenticate(ctx context.Context, r *http.Request) (Info, error) { return nil, nil } -func (m *mockStrategy) Append(token string, info Info, r *http.Request) error { +func (m *mockStrategy) Append(interface{}, Info) error { m.called = true return nil } -func (m *mockStrategy) Revoke(token string, r *http.Request) error { +func (m *mockStrategy) Revoke(interface{}) error { m.called = true return nil } diff --git a/cache/cache.go b/cache/cache.go new file mode 100644 index 0000000..2f49737 --- /dev/null +++ b/cache/cache.go @@ -0,0 +1,120 @@ +// Package cache provides in-memory caches based on different caches replacement algorithms. +package cache + +import ( + "sync" +) + +// Cache stores data so that future requests for that data can be served faster. +type Cache interface { + // Load returns key's value. + Load(key interface{}) (interface{}, bool) + // Peek returns key's value without updating the underlying "rank". + Peek(key interface{}) (interface{}, bool) + // Update the key value without updating the underlying "rank". + Update(key interface{}, value interface{}) + // Store sets the value for a key. + Store(key interface{}, value interface{}) + // Delete deletes the key value. + Delete(key interface{}) + // Keys return cache records keys. + Keys() []interface{} + // Contains Checks if a key exists in cache. + Contains(key interface{}) bool + // Purge Clears all cache entries. + Purge() + // Resize cache, returning number evicted + Resize(int) int + // Len Returns the number of items in the cache. + Len() int + // Cap Returns the cache capacity. + Cap() int +} + +// OnEvicted define a function signature to be +// executed in its own goroutine when an entry is purged from the cache. +type OnEvicted func(key interface{}, value interface{}) + +// OnExpired define a function signature to be +// executed in its own goroutine when an entry TTL elapsed. +// invocation of OnExpired, does not mean the entry is purged from the cache, +// if need be, it must coordinate with the cache explicitly. +// +// var cache cache.Cache +// onExpired := func(key interface{}) { +// _, _, _ = cache.Peek(key) +// } +// +// This should not be done unless the cache thread-safe. +type OnExpired func(key interface{}) + +type cache struct { + mu sync.RWMutex + container Cache +} + +func (c *cache) Load(key interface{}) (interface{}, bool) { + c.mu.Lock() + defer c.mu.Unlock() + return c.container.Load(key) +} + +func (c *cache) Peek(key interface{}) (interface{}, bool) { + c.mu.Lock() + defer c.mu.Unlock() + return c.container.Peek(key) +} + +func (c *cache) Update(key interface{}, value interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + c.container.Update(key, value) +} + +func (c *cache) Store(key interface{}, value interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + c.container.Store(key, value) +} + +func (c *cache) Delete(key interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + c.container.Delete(key) +} + +func (c *cache) Keys() []interface{} { + c.mu.RLock() + defer c.mu.RUnlock() + return c.container.Keys() +} + +func (c *cache) Contains(key interface{}) bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.container.Contains(key) +} + +func (c *cache) Purge() { + c.mu.Lock() + defer c.mu.Unlock() + c.container.Purge() +} + +func (c *cache) Resize(s int) int { + c.mu.Lock() + defer c.mu.Unlock() + return c.container.Resize(s) +} + +func (c *cache) Len() int { + c.mu.RLock() + defer c.mu.RUnlock() + return c.container.Len() +} + +func (c *cache) Cap() int { + c.mu.RLock() + defer c.mu.RUnlock() + return c.container.Cap() +} diff --git a/cache/container.go b/cache/container.go new file mode 100644 index 0000000..55778c4 --- /dev/null +++ b/cache/container.go @@ -0,0 +1,56 @@ +package cache + +import ( + "strconv" + "sync" +) + +// Container identifies a cache container that implemented in another package. +type Container uint + +const ( + // IDLE cache container. + IDLE Container = iota + 1 + // LRU cache container. + LRU + // FIFO cache container. + FIFO + max +) + +var containers = make([]func(opts ...Option) Cache, max) + +// Register registers a function that returns a new instance, +// of the given cache container function. +// This is intended to be called from the init function in packages that implement container functions. +func (c Container) Register(function func(opts ...Option) Cache) { + if c <= 0 && c >= max { //nolint:staticcheck + panic("cache: Register of unknown cache container function") + } + + containers[c] = function +} + +// Available reports whether the given cache container is linked into the binary. +func (c Container) Available() bool { + return c > 0 && c < max && containers[c] != nil +} + +// New returns a new thread safe cache. +// New panics if the cache container function is not linked into the binary. +func (c Container) New(opts ...Option) Cache { + cache := new(cache) + cache.mu = sync.RWMutex{} + cache.container = c.NewUnsafe(opts...) + return cache +} + +// NewUnsafe returns a new thread unsafe cache. +// NewUnsafe panics if the cache container function is not linked into the binary. +func (c Container) NewUnsafe(opts ...Option) Cache { + if !c.Available() { + panic("cache: Requested cache container function #" + strconv.Itoa(int(c)) + " is unavailable") + } + + return containers[c](opts...) +} diff --git a/cache/container/fifo/fifo.go b/cache/container/fifo/fifo.go new file mode 100644 index 0000000..eb9ab46 --- /dev/null +++ b/cache/container/fifo/fifo.go @@ -0,0 +1,103 @@ +// Package fifo implements an FIFO cache. +package fifo + +import ( + "container/list" + + "github.com/shaj13/go-guardian/cache" + "github.com/shaj13/go-guardian/cache/internal" +) + +func init() { + cache.FIFO.Register(New) +} + +// New returns new thread unsafe cache container. +func New(opts ...cache.Option) cache.Cache { + col := &collection{list.New()} + fifo := new(fifo) + fifo.c = internal.New(col) + for _, opt := range opts { + opt.Apply(fifo) + } + return fifo +} + +type fifo struct { + c *internal.Container +} + +func (f *fifo) Load(key interface{}) (interface{}, bool) { + return f.c.Load(key) +} + +func (f *fifo) Peek(key interface{}) (interface{}, bool) { + return f.c.Peek(key) +} + +func (f *fifo) Store(key, value interface{}) { + f.c.Store(key, value) +} + +func (f *fifo) Update(key, value interface{}) { + f.c.Update(key, value) +} + +func (f *fifo) Delete(key interface{}) { + f.c.Delete(key) +} + +func (f *fifo) Contains(key interface{}) bool { + return f.c.Contains(key) +} + +func (f *fifo) Resize(size int) int { + return f.c.Resize(size) +} + +func (f *fifo) Purge() { + f.c.Purge() +} + +func (f *fifo) Keys() []interface{} { + return f.c.Keys() +} + +func (f *fifo) Len() int { + return f.c.Len() +} + +func (f *fifo) Cap() int { + return f.c.Capacity +} + +type collection struct { + ll *list.List +} + +func (c *collection) Move(e *internal.Entry) {} + +func (c *collection) Push(e *internal.Entry) { + le := c.ll.PushBack(e) + e.Element = le +} + +func (c *collection) Remove(e *internal.Entry) { + le := e.Element.(*list.Element) + c.ll.Remove(le) +} + +func (c *collection) GetOldest() (e *internal.Entry) { + if le := c.ll.Front(); le != nil { + e = le.Value.(*internal.Entry) + } + return +} + +func (c *collection) Len() int { + return c.ll.Len() +} + +func (c *collection) Init() { + c.ll.Init() +} diff --git a/cache/container/fifo/fifo_test.go b/cache/container/fifo/fifo_test.go new file mode 100644 index 0000000..b381eb4 --- /dev/null +++ b/cache/container/fifo/fifo_test.go @@ -0,0 +1,223 @@ +package fifo + +import ( + "fmt" + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/shaj13/go-guardian/cache" +) + +func TestStore(t *testing.T) { + fifo := New() + fifo.Store(1, 1) + ok := fifo.Contains(1) + assert.True(t, ok) +} + +func TestLoad(t *testing.T) { + fifo := New() + fifo.Store("1", 1) + v, ok := fifo.Load("1") + assert.True(t, ok) + assert.Equal(t, 1, v) +} + +func TestDelete(t *testing.T) { + fifo := New() + fifo.Store(1, 1) + fifo.Delete(1) + ok := fifo.Contains(1) + assert.False(t, ok) +} + +func TestPeek(t *testing.T) { + fifo := New().(*fifo) + fifo.c.Capacity = 3 + + fifo.Store(1, 0) + fifo.Store(2, 0) + fifo.Store(3, 0) + v, ok := fifo.Peek(1) + fifo.Store(4, 0) + found := fifo.Contains(1) + + assert.Equal(t, 0, v) + assert.True(t, ok) + assert.False(t, found, "Peek should not move element") +} + +func TestContains(t *testing.T) { + fifo := New().(*fifo) + fifo.c.Capacity = 3 + + fifo.Store(1, 0) + fifo.Store(2, 0) + fifo.Store(3, 0) + found := fifo.Contains(1) + fifo.Store(4, 0) + _, ok := fifo.Load(1) + + assert.True(t, found) + assert.False(t, ok, "Contains should not move element") +} + +func TestUpdate(t *testing.T) { + fifo := New().(*fifo) + fifo.c.Capacity = 3 + + fifo.Store(1, 0) + fifo.Store(2, 0) + fifo.Store(3, 0) + fifo.Update(1, 1) + v, ok := fifo.Peek(1) + fifo.Store(4, 0) + found := fifo.Contains(1) + + assert.Equal(t, 1, v) + assert.True(t, ok) + assert.False(t, found, "Update should not move element") +} + +func TestPurge(t *testing.T) { + fifo := New().(*fifo) + fifo.c.Capacity = 3 + + fifo.Store(1, 0) + fifo.Store(2, 0) + fifo.Store(3, 0) + fifo.Purge() + + assert.Equal(t, 0, fifo.Len()) +} + +func TestResize(t *testing.T) { + fifo := New().(*fifo) + fifo.c.Capacity = 3 + + fifo.Store(1, 0) + fifo.Store(2, 0) + fifo.Store(3, 0) + fifo.Resize(2) + + assert.Equal(t, 2, fifo.Len()) + assert.True(t, fifo.Contains(2)) + assert.True(t, fifo.Contains(3)) + assert.False(t, fifo.Contains(1)) +} + +func TestKeys(t *testing.T) { + fifo := New() + + fifo.Store(1, 0) + fifo.Store(2, 0) + fifo.Store(3, 0) + + assert.ElementsMatch(t, []interface{}{1, 2, 3}, fifo.Keys()) +} + +func TestCap(t *testing.T) { + fifo := New().(*fifo) + fifo.c.Capacity = 3 + assert.Equal(t, 3, fifo.Cap()) +} + +func TestOnEvicted(t *testing.T) { + send := make(chan interface{}) + done := make(chan bool) + + evictedKeys := make([]interface{}, 0, 2) + + onEvictedFun := func(key, value interface{}) { + send <- key + } + + fifo := New().(*fifo) + fifo.c.Capacity = 20 + fifo.c.OnEvicted = onEvictedFun + + go func() { + for { + key := <-send + evictedKeys = append(evictedKeys, key) + if len(evictedKeys) >= 2 { + done <- true + return + } + } + }() + + for i := 0; i < 22; i++ { + fifo.Store(fmt.Sprintf("myKey%d", i), 1234) + } + + select { + case <-done: + case <-time.After(time.Second * 2): + t.Fatal("TestOnEvicted timeout exceeded, expected to receive evicted keys") + } + + assert.ElementsMatch(t, []interface{}{"myKey0", "myKey1"}, evictedKeys) +} + +func TestOnExpired(t *testing.T) { + send := make(chan interface{}) + done := make(chan bool) + + expiredKeys := make([]interface{}, 0, 2) + + onExpiredFun := func(key interface{}) { + send <- key + } + + fifo := New().(*fifo) + fifo.c.OnExpired = onExpiredFun + fifo.c.TTL = time.Millisecond + + go func() { + for { + key := <-send + expiredKeys = append(expiredKeys, key) + if len(expiredKeys) >= 2 { + done <- true + return + } + } + }() + + fifo.Store(1, 1234) + fifo.Store(2, 1234) + + select { + case <-done: + case <-time.After(time.Second * 2): + t.Fatal("TestOnExpired timeout exceeded, expected to receive expired keys") + } + + assert.ElementsMatch(t, []interface{}{1, 2}, expiredKeys) +} + +func BenchmarkFIFO(b *testing.B) { + keys := []interface{}{} + fifo := cache.FIFO.New() + + for i := 0; i < 100; i++ { + keys = append(keys, i) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + key := keys[rand.Intn(100)] + _, ok := fifo.Load(key) + if ok { + fifo.Delete(key) + } else { + fifo.Store(key, struct{}{}) + } + } + }) +} diff --git a/cache/container/fifo/options.go b/cache/container/fifo/options.go new file mode 100644 index 0000000..d01ae7e --- /dev/null +++ b/cache/container/fifo/options.go @@ -0,0 +1,43 @@ +package fifo + +import ( + "time" + + "github.com/shaj13/go-guardian/cache" +) + +// TTL set cache container entries TTL. +func TTL(ttl time.Duration) cache.Option { + return cache.OptionFunc(func(c cache.Cache) { + if l, ok := c.(*fifo); ok { + l.c.TTL = ttl + } + }) +} + +// Capacity set cache container capacity. +func Capacity(capacity int) cache.Option { + return cache.OptionFunc(func(c cache.Cache) { + if l, ok := c.(*fifo); ok { + l.c.Capacity = capacity + } + }) +} + +// RegisterOnEvicted register OnEvicted callback. +func RegisterOnEvicted(cb cache.OnEvicted) cache.Option { + return cache.OptionFunc(func(c cache.Cache) { + if l, ok := c.(*fifo); ok { + l.c.OnEvicted = cb + } + }) +} + +// RegisterOnExpired register OnExpired callback. +func RegisterOnExpired(cb cache.OnExpired) cache.Option { + return cache.OptionFunc(func(c cache.Cache) { + if l, ok := c.(*fifo); ok { + l.c.OnExpired = cb + } + }) +} diff --git a/cache/container/fifo/options_test.go b/cache/container/fifo/options_test.go new file mode 100644 index 0000000..32cd088 --- /dev/null +++ b/cache/container/fifo/options_test.go @@ -0,0 +1,36 @@ +package fifo + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestCapacity(t *testing.T) { + opt := Capacity(100) + fifo := New(opt).(*fifo) + + assert.Equal(t, fifo.c.Capacity, 100) +} + +func TestRegisterOnEvicted(t *testing.T) { + opt := RegisterOnEvicted(func(key, value interface{}) {}) + fifo := New(opt).(*fifo) + + assert.NotNil(t, fifo.c.OnEvicted) +} + +func TestRegisterOnExpired(t *testing.T) { + opt := RegisterOnExpired(func(key interface{}) {}) + fifo := New(opt).(*fifo) + + assert.NotNil(t, fifo.c.OnExpired) +} + +func TestTTL(t *testing.T) { + opt := TTL(time.Hour) + fifo := New(opt).(*fifo) + + assert.Equal(t, fifo.c.TTL, time.Hour) +} diff --git a/cache/container/idle/idle.go b/cache/container/idle/idle.go new file mode 100644 index 0000000..6db585b --- /dev/null +++ b/cache/container/idle/idle.go @@ -0,0 +1,27 @@ +// Package idle implements an IDLE cache, that never finds/stores a key's value. +package idle + +import "github.com/shaj13/go-guardian/cache" + +func init() { + cache.IDLE.Register(New) +} + +// New return idle cache container that never finds/stores a key's value. +func New(opts ...cache.Option) cache.Cache { + return idle{} +} + +type idle struct{} + +func (idle) Load(interface{}) (v interface{}, ok bool) { return } +func (idle) Peek(interface{}) (v interface{}, ok bool) { return } +func (idle) Keys() (keys []interface{}) { return } +func (idle) Contains(interface{}) (ok bool) { return } +func (idle) Resize(int) (i int) { return } +func (idle) Len() (len int) { return } +func (idle) Cap() (cap int) { return } +func (idle) Update(interface{}, interface{}) {} +func (idle) Store(interface{}, interface{}) {} +func (idle) Delete(interface{}) {} +func (idle) Purge() {} diff --git a/cache/container/lru/lru.go b/cache/container/lru/lru.go new file mode 100644 index 0000000..981ec2f --- /dev/null +++ b/cache/container/lru/lru.go @@ -0,0 +1,106 @@ +// Package lru implements an LRU cache. +package lru + +import ( + "container/list" + + "github.com/shaj13/go-guardian/cache" + "github.com/shaj13/go-guardian/cache/internal" +) + +func init() { + cache.LRU.Register(New) +} + +// New returns new thread unsafe cache container. +func New(opts ...cache.Option) cache.Cache { + col := &collection{list.New()} + lru := new(lru) + lru.c = internal.New(col) + for _, opt := range opts { + opt.Apply(lru) + } + return lru +} + +type lru struct { + c *internal.Container +} + +func (l *lru) Load(key interface{}) (interface{}, bool) { + return l.c.Load(key) +} + +func (l *lru) Peek(key interface{}) (interface{}, bool) { + return l.c.Peek(key) +} + +func (l *lru) Store(key, value interface{}) { + l.c.Store(key, value) +} + +func (l *lru) Update(key, value interface{}) { + l.c.Update(key, value) +} + +func (l *lru) Delete(key interface{}) { + l.c.Delete(key) +} + +func (l *lru) Contains(key interface{}) bool { + return l.c.Contains(key) +} + +func (l *lru) Resize(size int) int { + return l.c.Resize(size) +} + +func (l *lru) Purge() { + l.c.Purge() +} + +func (l *lru) Keys() []interface{} { + return l.c.Keys() +} + +func (l *lru) Len() int { + return l.c.Len() +} + +func (l *lru) Cap() int { + return l.c.Capacity +} + +type collection struct { + ll *list.List +} + +func (c *collection) Move(e *internal.Entry) { + le := e.Element.(*list.Element) + c.ll.MoveToFront(le) +} + +func (c *collection) Push(e *internal.Entry) { + le := c.ll.PushFront(e) + e.Element = le +} + +func (c *collection) Remove(e *internal.Entry) { + le := e.Element.(*list.Element) + c.ll.Remove(le) +} + +func (c *collection) GetOldest() (e *internal.Entry) { + if le := c.ll.Back(); le != nil { + e = le.Value.(*internal.Entry) + } + return +} + +func (c *collection) Len() int { + return c.ll.Len() +} + +func (c *collection) Init() { + c.ll.Init() +} diff --git a/cache/container/lru/lru_test.go b/cache/container/lru/lru_test.go new file mode 100644 index 0000000..3f9bef4 --- /dev/null +++ b/cache/container/lru/lru_test.go @@ -0,0 +1,223 @@ +package lru + +import ( + "fmt" + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/shaj13/go-guardian/cache" +) + +func TestStore(t *testing.T) { + lru := New() + lru.Store(1, 1) + ok := lru.Contains(1) + assert.True(t, ok) +} + +func TestLoad(t *testing.T) { + lru := New() + lru.Store("1", 1) + v, ok := lru.Load("1") + assert.True(t, ok) + assert.Equal(t, 1, v) +} + +func TestDelete(t *testing.T) { + lru := New() + lru.Store(1, 1) + lru.Delete(1) + ok := lru.Contains(1) + assert.False(t, ok) +} + +func TestPeek(t *testing.T) { + lru := New().(*lru) + lru.c.Capacity = 3 + + lru.Store(1, 0) + lru.Store(2, 0) + lru.Store(3, 0) + v, ok := lru.Peek(1) + lru.Store(4, 0) + found := lru.Contains(1) + + assert.Equal(t, 0, v) + assert.True(t, ok) + assert.False(t, found, "Peek should not move element") +} + +func TestContains(t *testing.T) { + lru := New().(*lru) + lru.c.Capacity = 3 + + lru.Store(1, 0) + lru.Store(2, 0) + lru.Store(3, 0) + found := lru.Contains(1) + lru.Store(4, 0) + _, ok := lru.Load(1) + + assert.True(t, found) + assert.False(t, ok, "Contains should not move element") +} + +func TestUpdate(t *testing.T) { + lru := New().(*lru) + lru.c.Capacity = 3 + + lru.Store(1, 0) + lru.Store(2, 0) + lru.Store(3, 0) + lru.Update(1, 1) + v, ok := lru.Peek(1) + lru.Store(4, 0) + found := lru.Contains(1) + + assert.Equal(t, 1, v) + assert.True(t, ok) + assert.False(t, found, "Update should not move element") +} + +func TestPurge(t *testing.T) { + lru := New().(*lru) + lru.c.Capacity = 3 + + lru.Store(1, 0) + lru.Store(2, 0) + lru.Store(3, 0) + lru.Purge() + + assert.Equal(t, 0, lru.Len()) +} + +func TestResize(t *testing.T) { + lru := New().(*lru) + lru.c.Capacity = 3 + + lru.Store(1, 0) + lru.Store(2, 0) + lru.Store(3, 0) + lru.Resize(2) + + assert.Equal(t, 2, lru.Len()) + assert.True(t, lru.Contains(2)) + assert.True(t, lru.Contains(3)) + assert.False(t, lru.Contains(1)) +} + +func TestKeys(t *testing.T) { + lru := New() + + lru.Store(1, 0) + lru.Store(2, 0) + lru.Store(3, 0) + + assert.ElementsMatch(t, []interface{}{1, 2, 3}, lru.Keys()) +} + +func TestCap(t *testing.T) { + lru := New().(*lru) + lru.c.Capacity = 3 + assert.Equal(t, 3, lru.Cap()) +} + +func TestOnEvicted(t *testing.T) { + send := make(chan interface{}) + done := make(chan bool) + + evictedKeys := make([]interface{}, 0, 2) + + onEvictedFun := func(key, value interface{}) { + send <- key + } + + lru := New().(*lru) + lru.c.Capacity = 20 + lru.c.OnEvicted = onEvictedFun + + go func() { + for { + key := <-send + evictedKeys = append(evictedKeys, key) + if len(evictedKeys) >= 2 { + done <- true + return + } + } + }() + + for i := 0; i < 22; i++ { + lru.Store(fmt.Sprintf("myKey%d", i), 1234) + } + + select { + case <-done: + case <-time.After(time.Second * 2): + t.Fatal("TestOnEvicted timeout exceeded, expected to receive evicted keys") + } + + assert.ElementsMatch(t, []interface{}{"myKey0", "myKey1"}, evictedKeys) +} + +func TestOnExpired(t *testing.T) { + send := make(chan interface{}) + done := make(chan bool) + + expiredKeys := make([]interface{}, 0, 2) + + onExpiredFun := func(key interface{}) { + send <- key + } + + lru := New().(*lru) + lru.c.OnExpired = onExpiredFun + lru.c.TTL = time.Millisecond + + go func() { + for { + key := <-send + expiredKeys = append(expiredKeys, key) + if len(expiredKeys) >= 2 { + done <- true + return + } + } + }() + + lru.Store(1, 1234) + lru.Store(2, 1234) + + select { + case <-done: + case <-time.After(time.Second * 2): + t.Fatal("TestOnExpired timeout exceeded, expected to receive expired keys") + } + + assert.ElementsMatch(t, []interface{}{1, 2}, expiredKeys) +} + +func BenchmarkLRU(b *testing.B) { + keys := []interface{}{} + lru := cache.LRU.New() + + for i := 0; i < 100; i++ { + keys = append(keys, i) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + key := keys[rand.Intn(100)] + _, ok := lru.Load(key) + if ok { + lru.Delete(key) + } else { + lru.Store(key, struct{}{}) + } + } + }) +} diff --git a/cache/container/lru/options.go b/cache/container/lru/options.go new file mode 100644 index 0000000..1357ce0 --- /dev/null +++ b/cache/container/lru/options.go @@ -0,0 +1,43 @@ +package lru + +import ( + "time" + + "github.com/shaj13/go-guardian/cache" +) + +// TTL set cache container entries TTL. +func TTL(ttl time.Duration) cache.Option { + return cache.OptionFunc(func(c cache.Cache) { + if l, ok := c.(*lru); ok { + l.c.TTL = ttl + } + }) +} + +// Capacity set cache container capacity. +func Capacity(capacity int) cache.Option { + return cache.OptionFunc(func(c cache.Cache) { + if l, ok := c.(*lru); ok { + l.c.Capacity = capacity + } + }) +} + +// RegisterOnEvicted register OnEvicted callback. +func RegisterOnEvicted(cb cache.OnEvicted) cache.Option { + return cache.OptionFunc(func(c cache.Cache) { + if l, ok := c.(*lru); ok { + l.c.OnEvicted = cb + } + }) +} + +// RegisterOnExpired register OnExpired callback. +func RegisterOnExpired(cb cache.OnExpired) cache.Option { + return cache.OptionFunc(func(c cache.Cache) { + if l, ok := c.(*lru); ok { + l.c.OnExpired = cb + } + }) +} diff --git a/cache/container/lru/options_test.go b/cache/container/lru/options_test.go new file mode 100644 index 0000000..16cf180 --- /dev/null +++ b/cache/container/lru/options_test.go @@ -0,0 +1,36 @@ +package lru + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestCapacity(t *testing.T) { + opt := Capacity(100) + lru := New(opt).(*lru) + + assert.Equal(t, lru.c.Capacity, 100) +} + +func TestRegisterOnEvicted(t *testing.T) { + opt := RegisterOnEvicted(func(key, value interface{}) {}) + lru := New(opt).(*lru) + + assert.NotNil(t, lru.c.OnEvicted) +} + +func TestRegisterOnExpired(t *testing.T) { + opt := RegisterOnExpired(func(key interface{}) {}) + lru := New(opt).(*lru) + + assert.NotNil(t, lru.c.OnExpired) +} + +func TestTTL(t *testing.T) { + opt := TTL(time.Hour) + lru := New(opt).(*lru) + + assert.Equal(t, lru.c.TTL, time.Hour) +} diff --git a/cache/example_test.go b/cache/example_test.go new file mode 100644 index 0000000..d456937 --- /dev/null +++ b/cache/example_test.go @@ -0,0 +1,64 @@ +package cache_test + +import ( + "fmt" + "time" + + "github.com/shaj13/go-guardian/cache" + "github.com/shaj13/go-guardian/cache/container/fifo" + _ "github.com/shaj13/go-guardian/cache/container/idle" + "github.com/shaj13/go-guardian/cache/container/lru" +) + +func Example_idle() { + // it can be unsafe, no any race conditions + c := cache.IDLE.NewUnsafe() + c.Store(1, 0) + fmt.Println(c.Contains(1)) + // Output: + // false +} + +func Example_fifo() { + cap := fifo.Capacity(2) + c := cache.FIFO.New(cap) + c.Store(1, 0) + c.Store(2, 0) + c.Store(3, 0) + fmt.Println(c.Contains(1)) + // Output: + // false +} + +func Example_lru() { + cap := lru.Capacity(2) + c := cache.LRU.New(cap) + c.Store(1, 0) + c.Store(2, 0) + c.Store(3, 0) + fmt.Println(c.Contains(1)) + // Output: + // false +} + +func Example_onexpired() { + // c must be thread safe + var c cache.Cache + + ttl := lru.TTL(time.Millisecond) + exp := lru.RegisterOnExpired(func(key interface{}) { + fmt.Println("") + // use Peek/Load over delete, perhaps a new entry added with the same key during expiration, + // or entry refreshed from other thread. + c.Peek(key) + }) + + c = cache.LRU.New(ttl, exp) + + c.Store(1, 0) + + time.Sleep(time.Millisecond * 5) + fmt.Println(c.Contains(1)) + // Output: + // false +} diff --git a/cache/internal/container.go b/cache/internal/container.go new file mode 100644 index 0000000..484db0f --- /dev/null +++ b/cache/internal/container.go @@ -0,0 +1,186 @@ +package internal + +import ( + "time" +) + +// Collection represents the container underlying data structure, +// and defines the functions or operations that can be applied to the data elements. +type Collection interface { + Move(*Entry) + Push(*Entry) + Remove(*Entry) + GetOldest() *Entry + Len() int + Init() +} + +// Entry is used to hold a value in the cache. +type Entry struct { + Key interface{} + Value interface{} + Element interface{} + Exp time.Time + Timer *time.Timer +} + +// Container represent core cache container. +type Container struct { + Collection Collection + Entries map[interface{}]*Entry + OnEvicted func(key interface{}, value interface{}) + OnExpired func(key interface{}) + TTL time.Duration + Capacity int +} + +// Load returns key's value. +func (c *Container) Load(key interface{}) (interface{}, bool) { + return c.get(key, false) +} + +// Peek returns key's value without updating the underlying "rank". +func (c *Container) Peek(key interface{}) (interface{}, bool) { + return c.get(key, true) +} + +func (c *Container) get(key interface{}, peek bool) (v interface{}, found bool) { + e, ok := c.Entries[key] + if !ok { + return + } + + if c.TTL > 0 && time.Now().UTC().After(e.Exp) { + c.Evict(e) + return + } + + if !peek { + c.Collection.Move(e) + } + + return e.Value, ok +} + +// Store sets the value for a key. +func (c *Container) Store(key, value interface{}) { + if e, ok := c.Entries[key]; ok { + c.RemoveEntry(e) + } + + e := c.withTTL(key, value) + c.Entries[key] = e + c.Collection.Push(e) + c.RemoveOldest() +} + +// Update the key value without updating the underlying "rank". +func (c *Container) Update(key interface{}, value interface{}) { + if e, ok := c.Entries[key]; ok { + e.Value = value + } +} + +// Purge Clears all cache entries. +func (c *Container) Purge() { + defer c.Collection.Init() + + if c.OnEvicted == nil { + c.Entries = make(map[interface{}]*Entry) + return + } + + for _, e := range c.Entries { + c.Evict(e) + } +} + +// Resize cache, returning number evicted +func (c *Container) Resize(size int) int { + c.Capacity = size + diff := c.Len() - size + + if diff < 0 { + diff = 0 + } + + for i := 0; i < diff; i++ { + c.RemoveOldest() + } + + return diff +} + +// Delete deletes the key value. +func (c *Container) Delete(key interface{}) { + if e, ok := c.Entries[key]; ok { + c.Evict(e) + } +} + +// Contains Checks if a key exists in cache. +func (c *Container) Contains(key interface{}) (ok bool) { + _, ok = c.Peek(key) + return +} + +// Keys return cache records keys. +func (c *Container) Keys() (keys []interface{}) { + for k := range c.Entries { + keys = append(keys, k) + } + return +} + +// Len Returns the number of items in the cache. +func (c *Container) Len() int { + return c.Collection.Len() +} + +// RemoveOldest Removes the oldest entry from cache. +func (c *Container) RemoveOldest() { + e := c.Collection.GetOldest() + if c.Capacity != 0 && c.Len() > c.Capacity && e != nil { + c.Evict(e) + } +} + +// RemoveEntry remove entry silently. +func (c *Container) RemoveEntry(e *Entry) { + c.Collection.Remove(e) + if e.Timer != nil { + e.Timer.Stop() + } + delete(c.Entries, e.Key) +} + +// Evict remove entry and fire on evicted callback. +func (c *Container) Evict(e *Entry) { + c.RemoveEntry(e) + if c.OnEvicted != nil { + go c.OnEvicted(e.Key, e.Value) + } +} + +func (c *Container) withTTL(key, value interface{}) *Entry { + e := &Entry{Key: key, Value: value} + + if c.TTL > 0 { + if c.OnExpired != nil { + e.Timer = time.AfterFunc(c.TTL, func() { + c.OnExpired(e.Key) + }) + } + e.Exp = time.Now().UTC().Add(c.TTL) + } + + return e +} + +// New return new container. +func New(c Collection) *Container { + return &Container{ + Collection: c, + Entries: make(map[interface{}]*Entry), + } +} diff --git a/cache/option.go b/cache/option.go new file mode 100644 index 0000000..744dc67 --- /dev/null +++ b/cache/option.go @@ -0,0 +1,18 @@ +package cache + +// Option configures cache instance using the functional options paradigm +// popularized by Rob Pike and Dave Cheney. +// If you're unfamiliar with this style, +// see https://commandcenter.blogspot.com/2014/01/self-referential-functions-and-design.html and +// https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis. +type Option interface { + Apply(c Cache) +} + +// OptionFunc implements Option interface. +type OptionFunc func(c Cache) + +// Apply the configuration to the provided cache. +func (fn OptionFunc) Apply(c Cache) { + fn(c) +} diff --git a/doc.go b/doc.go index ded3449..1158335 100644 --- a/doc.go +++ b/doc.go @@ -21,75 +21,6 @@ Here are a few bullet point reasons you might like to try it out: * provides a package to caches the authentication decisions, based on different mechanisms and algorithms. * provides two-factor authentication and one-time password as defined in [RFC-4226](https://tools.ietf.org/html/rfc4226) and [RFC-6238](https://tools.ietf.org/html/rfc6238) * provides a mechanism to customize strategies, even enables writing a custom strategy - - -Example: - package main - - import ( - "crypto/x509" - "encoding/pem" - "io/ioutil" - "log" - "net/http" - - "github.com/gorilla/mux" - "github.com/shaj13/go-guardian/auth" - x509Strategy "github.com/shaj13/go-guardian/auth/strategies/x509" - ) - - var authenticator auth.Authenticator - - func middleware(next http.Handler) http.HandlerFunc { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Println("Executing Auth Middleware") - user, err := authenticator.Authenticate(r) - if err != nil { - code := http.StatusUnauthorized - http.Error(w, http.StatusText(code), code) - return - } - log.Printf("User %s Authenticated\n", user.UserName()) - next.ServeHTTP(w, r) - }) - } - - func Handler(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("Handler!!\n")) - } - - func main() { - opts := x509.VerifyOptions{} - opts.KeyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth} - opts.Roots = x509.NewCertPool() - // Read Root Ca Certificate - opts.Roots.AddCert(readCertificate("//")) - - // create strategy and bind it to authenticator. - strategy := x509Strategy.New(opts) - authenticator = auth.New() - authenticator.EnableStrategy(x509Strategy.StrategyKey, strategy) - - r := mux.NewRouter() - r.HandleFunc("/", middleware(http.HandlerFunc(Handler))) - log.Fatal(http.ListenAndServeTLS(":8080", "", "", r)) - } - - func readCertificate(file string) *x509.Certificate { - data, err := ioutil.ReadFile(file) - - if err != nil { - log.Fatalf("error reading %s: %v", file, err) - } - - p, _ := pem.Decode(data) - cert, err := x509.ParseCertificate(p.Bytes) - if err != nil { - log.Fatalf("error parseing certificate %s: %v", file, err) - } - - return cert - } */ //nolint:lll package guardian diff --git a/errors/errors.go b/errors/errors.go deleted file mode 100644 index 05ce390..0000000 --- a/errors/errors.go +++ /dev/null @@ -1,60 +0,0 @@ -package errors - -import ( - "errors" - "fmt" - "reflect" -) - -// InvalidType represent invalid type assertion error. -type InvalidType struct { - Want string - Got string -} - -// Error describe error as a string -func (i InvalidType) Error() string { - return "Invalid type assertion: Want " + i.Want + " Got " + i.Got -} - -// NewInvalidType returns InvalidType error -func NewInvalidType(want, got interface{}) error { - f := func(v interface{}) string { - if v == nil { - return "" - } - return reflect.TypeOf(v).String() - } - - return InvalidType{ - Want: f(want), - Got: f(got), - } -} - -// MultiError stores multiple errors. -type MultiError []error - -func (errs MultiError) Error() string { - if len(errs) == 0 { - return "" - } - - if len(errs) == 1 { - return errs[0].Error() - } - - str := "" - - for _, err := range errs[1:] { - str += err.Error() + ", " - } - - return fmt.Sprintf("%v: [%s]", errs[0], str) -} - -// New returns an error that formats as the given text. -// Each call to New returns a distinct error value even if the text is identical. -func New(str string) error { - return errors.New(str) -} diff --git a/go.sum b/go.sum index c7c50df..ea28345 100644 --- a/go.sum +++ b/go.sum @@ -27,6 +27,7 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.1.0 h1:Hsa8mG0dQ46ij8Sl2AYJDUv1oA9/d6Vk+3LG99Oe02g= github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gnostic v0.0.0-20170729233727-0c5108395e2d/go.mod h1:sJBsCZ4ayReDTBIg8b9dl28c5xFWyhBTVRp3pOg5EKY= github.com/googleapis/gnostic v0.1.0/go.mod h1:sJBsCZ4ayReDTBIg8b9dl28c5xFWyhBTVRp3pOg5EKY= diff --git a/store/cache.go b/store/cache.go deleted file mode 100644 index 125516d..0000000 --- a/store/cache.go +++ /dev/null @@ -1,48 +0,0 @@ -// Package store provides different cache mechanisms and algorithms, -// To caches the authentication decisions. -package store - -import ( - "errors" - "net/http" - "time" -) - -// ErrCachedExp returned by cache when cached record have expired, -// and no longer living in cache (deleted) -var ErrCachedExp = errors.New("cache: Cached record have expired") - -// Cache stores data so that future requests for that data can be served faster. -type Cache interface { - // Load returns the value stored in the cache for a key, or nil if no value is present. - // The ok result indicates whether value was found in the Cache. - // The error reserved for moderate cache and returned if an error occurs, Otherwise nil. - Load(key string, r *http.Request) (interface{}, bool, error) - // Store sets the value for a key. - // The error reserved for moderate cache and returned if an error occurs, Otherwise nil. - Store(key string, value interface{}, r *http.Request) error - // Delete deletes the value for a key. - // The error reserved for moderate cache and returned if an error occurs, Otherwise nil. - Delete(key string, r *http.Request) error - - // Keys return cache records keys. - Keys() []string -} - -// OnEvicted define a function signature to be -// executed when an entry is purged from the cache. -type OnEvicted func(key string, value interface{}) - -// NoCache is an implementation of Cache interface that never finds/stores a record. -type NoCache struct{} - -func (NoCache) Store(_ string, _ interface{}, _ *http.Request) (err error) { return } // nolint:golint -func (NoCache) Load(_ string, _ *http.Request) (v interface{}, ok bool, err error) { return } // nolint:golint -func (NoCache) Delete(_ string, _ *http.Request) (err error) { return } // nolint:golint -func (NoCache) Keys() (keys []string) { return } // nolint:golint - -type record struct { - Exp time.Time - Key string - Value interface{} -} diff --git a/store/example_test.go b/store/example_test.go deleted file mode 100644 index 42e6fa4..0000000 --- a/store/example_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package store - -import ( - "context" - "fmt" - "net/http" - "time" -) - -func ExampleNewFIFO() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - r, _ := http.NewRequest("GET", "/", nil) - cache := NewFIFO(ctx, time.Minute*5) - - cache.Store("key", "value", r) - - v, ok, _ := cache.Load("key", r) - fmt.Println(v, ok) - - cache.Delete("key", r) - v, ok, _ = cache.Load("key", r) - fmt.Println(v, ok) - - // Output: - // value true - // false -} - -func ExampleLRU() { - r, _ := http.NewRequest("GET", "/", nil) - - cache := New(2) - - cache.Store("key", "value", r) - - v, ok, _ := cache.Load("key", r) - fmt.Println(v, ok) - - cache.Delete("key", r) - v, ok, _ = cache.Load("key", r) - fmt.Println(v, ok) - - // Output: - // value true - // false -} - -func ExampleReplicator() { - lru := New(20) - fileSystem := NewFileSystem(context.TODO(), 0, "/tmp") - replicator := Replicator{ - Persistent: fileSystem, - InMemory: lru, - } - - replicator.Store("key", "value", nil) - - v, _, _ := lru.Load("key", nil) - fmt.Println(v) - - v, _, _ = fileSystem.Load("key", nil) - fmt.Println(v) - - // Output: - // value - // value -} diff --git a/store/fifo.go b/store/fifo.go deleted file mode 100644 index 235b61e..0000000 --- a/store/fifo.go +++ /dev/null @@ -1,208 +0,0 @@ -package store - -import ( - "context" - "net/http" - "sync" - "time" -) - -// NewFIFO return a simple FIFO Cache instance safe for concurrent usage, -// And spawning a garbage collector goroutine to collect expired record. -// The cache send record to garbage collector through a queue when it stored a new one. -// Once the garbage collector received the record it checks if record not expired to wait until expiration, -// Otherwise, wait for the next record. -// When the all expired record collected the garbage collector will be blocked, -// until new record stored to repeat the process. -// The context will be Passed to garbage collector -func NewFIFO(ctx context.Context, ttl time.Duration) *FIFO { - queue := &queue{ - notify: make(chan struct{}, 1), - mu: &sync.Mutex{}, - } - - f := &FIFO{ - queue: queue, - TTL: ttl, - records: make(map[string]*record), - MU: &sync.Mutex{}, - } - - go gc(ctx, queue, f) - - return f -} - -// FIFO Cache instance safe for concurrent usage. -type FIFO struct { - // TTL To expire a value in cache. - // TTL Must be greater than 0. - TTL time.Duration - - // OnEvicted optionally specifies a callback function to be - // executed when an entry is purged from the cache. - OnEvicted OnEvicted - - MU *sync.Mutex - records map[string]*record - queue *queue -} - -// Load returns the value stored in the Cache for a key, or nil if no value is present. -// The ok result indicates whether value was found in the Cache. -func (f *FIFO) Load(key string, _ *http.Request) (interface{}, bool, error) { - f.MU.Lock() - defer f.MU.Unlock() - - record, ok := f.records[key] - - if !ok { - return nil, ok, nil - } - - if time.Now().UTC().After(record.Exp) { - f.delete(key) - return nil, ok, ErrCachedExp - } - - return record.Value, ok, nil -} - -// Store sets the value for a key. -func (f *FIFO) Store(key string, value interface{}, _ *http.Request) error { - f.MU.Lock() - defer f.MU.Unlock() - - exp := time.Now().UTC().Add(f.TTL) - record := &record{ - Key: key, - Exp: exp, - Value: value, - } - - f.records[key] = record - f.queue.push(record) - - return nil -} - -// Update the value for a key without updating TTL. -func (f *FIFO) Update(key string, value interface{}, _ *http.Request) error { - f.MU.Lock() - defer f.MU.Unlock() - - if r, ok := f.records[key]; ok { - r.Value = value - } - - return nil -} - -// Delete the value for a key. -func (f *FIFO) Delete(key string, _ *http.Request) error { - f.MU.Lock() - defer f.MU.Unlock() - f.delete(key) - return nil -} - -func (f *FIFO) delete(key string) { - r, ok := f.records[key] - - if !ok { - return - } - - delete(f.records, key) - - if f.OnEvicted != nil { - f.OnEvicted(key, r.Value) - } -} - -// Keys return cache records keys. -func (f *FIFO) Keys() []string { - f.MU.Lock() - defer f.MU.Unlock() - keys := make([]string, 0) - - for k := range f.records { - keys = append(keys, k) - } - - return keys -} - -type node struct { - record *record - next *node -} - -type queue struct { - mu *sync.Mutex - head *node - tail *node - notify chan struct{} -} - -func (q *queue) next() *record { - q.mu.Lock() - defer q.mu.Unlock() - if q.head != nil { - current := q.head - q.head = current.next - return current.record - } - return nil -} - -func (q *queue) push(r *record) { - q.mu.Lock() - defer q.mu.Unlock() - node := &node{ - record: r, - next: nil, - } - if q.head == nil { - q.head = node - q.tail = q.head - select { - case q.notify <- struct{}{}: - default: - } - return - } - q.tail.next = node - q.tail = q.tail.next -} - -func gc(ctx context.Context, queue *queue, cache Cache) { - for { - record := queue.next() - - if record == nil { - select { - case <-queue.notify: - continue - case <-ctx.Done(): - return - } - } - - _, ok, _ := cache.Load(record.Key, nil) - - // check if the record exist then wait until it expired - if ok { - d := record.Exp.Sub(time.Now().UTC()) - select { - case <-time.After(d): - case <-ctx.Done(): - return - } - } - - // invoke Load to expire the record, the load method used over delete - // cause the record may be renewed, and the cache queued the record again in another node. - _, _, _ = cache.Load(record.Key, nil) - } -} diff --git a/store/fifo_test.go b/store/fifo_test.go deleted file mode 100644 index 4177cb5..0000000 --- a/store/fifo_test.go +++ /dev/null @@ -1,254 +0,0 @@ -package store - -import ( - "context" - "fmt" - "math/rand" - "strconv" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestFIFOStore(t *testing.T) { - const key = "key" - - queue := &queue{ - notify: make(chan struct{}, 1), - mu: new(sync.Mutex), - } - - cache := &FIFO{ - MU: new(sync.Mutex), - records: make(map[string]*record, 1), - queue: queue, - } - - err := cache.Store(key, "value", nil) - _, ok := cache.records[key] - r := queue.next() - - assert.NoError(t, err) - assert.True(t, ok) - assert.Equal(t, key, r.Key) -} - -func TestFIFOUpdate(t *testing.T) { - const ( - key = "key" - value = "value" - ) - - cache := &FIFO{ - MU: new(sync.Mutex), - records: make(map[string]*record, 1), - } - cache.records[key] = new(record) - - err := cache.Update(key, value, nil) - r := cache.records[key] - - assert.NoError(t, err) - assert.Equal(t, value, r.Value) -} - -func TestFIFODelete(t *testing.T) { - table := []struct { - name string - key string - value interface{} - }{ - { - name: "it's not crash when trying to delete a non exist record", - key: "key", - }, - { - name: "it delete a exist record", - key: "key", - value: "value", - }, - } - - for _, tt := range table { - t.Run(tt.name, func(t *testing.T) { - cache := &FIFO{ - MU: new(sync.Mutex), - records: make(map[string]*record, 1), - } - - if tt.value != nil { - cache.records[tt.key] = new(record) - } - - err := cache.Delete(tt.key, nil) - _, ok := cache.records[tt.key] - - assert.NoError(t, err) - assert.False(t, ok) - }) - } -} - -func TestFifoLoad(t *testing.T) { - table := []struct { - name string - key string - value interface{} - add bool - found bool - exp time.Time - err error - }{ - { - name: "it return err when key expired", - key: "expired", - add: true, - found: true, - exp: time.Now().Add(-time.Hour), - err: ErrCachedExp, - }, - { - name: "it return false when key does not exist", - key: "key", - }, - { - name: "it return true and value when exist", - exp: time.Now().Add(time.Hour), - found: true, - add: true, - key: "test", - value: "test", - }, - } - - for _, tt := range table { - t.Run(tt.name, func(t *testing.T) { - cache := &FIFO{ - MU: new(sync.Mutex), - records: make(map[string]*record, 1), - } - - if tt.add { - r := &record{ - Exp: tt.exp, - Key: tt.key, - Value: tt.value, - } - - cache.records[tt.key] = r - } - - v, ok, err := cache.Load(tt.key, nil) - - assert.Equal(t, tt.value, v) - assert.Equal(t, tt.err, err) - assert.Equal(t, tt.found, ok) - }) - } -} - -func TestFIFOEvict(t *testing.T) { - evictedKeys := make([]string, 0) - onEvictedFun := func(key string, value interface{}) { - evictedKeys = append(evictedKeys, key) - } - - fifo := NewFIFO(context.Background(), 1) - fifo.OnEvicted = onEvictedFun - - for i := 0; i < 10; i++ { - key := fmt.Sprintf("myKey%d", i) - fifo.Store(key, 1234, nil) - } - - for i := 0; i < 10; i++ { - key := fmt.Sprintf("myKey%d", i) - fifo.Delete(key, nil) - } - - assert.Equal(t, 10, len(evictedKeys)) - - for i := 0; i < 10; i++ { - key := fmt.Sprintf("myKey%d", i) - assert.Equal(t, key, evictedKeys[i]) - } -} - -func TestQueue(t *testing.T) { - queue := &queue{ - notify: make(chan struct{}, 10), - mu: &sync.Mutex{}, - } - - for i := 0; i < 5; i++ { - queue.push( - &record{ - Key: "any", - Value: i, - }) - } - - for i := 0; i < 5; i++ { - r := queue.next() - assert.Equal(t, i, r.Value) - } -} - -func TestFifoKeys(t *testing.T) { - ctx, cacnel := context.WithCancel(context.Background()) - defer cacnel() - - f := NewFIFO(ctx, time.Minute) - - f.Store("1", "", nil) - f.Store("2", "", nil) - f.Store("3", "", nil) - - assert.ElementsMatch(t, []string{"1", "2", "3"}, f.Keys()) -} - -func TestGC(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - cache := NewFIFO(ctx, time.Nanosecond*50) - - cache.Store("1", 1, nil) - cache.Store("2", 2, nil) - - time.Sleep(time.Millisecond) - _, ok, _ := cache.Load("1", nil) - assert.False(t, ok) - - _, ok, _ = cache.Load("2", nil) - assert.False(t, ok) -} - -func BenchmarkFIFIO(b *testing.B) { - cache := NewFIFO(context.Background(), time.Minute) - benchmarkCache(b, cache) -} - -func benchmarkCache(b *testing.B, cache Cache) { - keys := []string{} - - for i := 0; i < 100; i++ { - key := strconv.Itoa(i) - keys = append(keys, key) - } - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - key := keys[rand.Intn(100)] - _, ok, _ := cache.Load(key, nil) - if ok { - cache.Delete(key, nil) - } else { - cache.Store(key, struct{}{}, nil) - } - } - }) -} diff --git a/store/filesystem.go b/store/filesystem.go deleted file mode 100644 index 8ce6986..0000000 --- a/store/filesystem.go +++ /dev/null @@ -1,172 +0,0 @@ -package store - -import ( - "bytes" - "context" - "encoding/gob" - "io/ioutil" - "net/http" - "os" - "path/filepath" - "sync" - "time" -) - -func init() { - gob.Register(&record{}) -} - -const fileExt = ".cache.gob" - -// FileSystem stores cache record in the filesystem. -// FileSystem encode/decode cache record using encoding/gob. -type FileSystem struct { - // TTL To expire a value in cache. - // 0 TTL means no expiry policy specified. - TTL time.Duration - - MU *sync.RWMutex - - path string - queue *queue -} - -// Load returns the value stored in the Cache for a key, or nil if no value is present. -// The ok result indicates whether value was found in the Cache. -func (f *FileSystem) Load(key string, _ *http.Request) (interface{}, bool, error) { - filename := f.fileName(key) - - f.MU.RLock() - defer f.MU.RUnlock() - - data, err := ioutil.ReadFile(filename) - if err != nil { - if os.IsNotExist(err) { - return nil, false, nil - } - return nil, false, err - } - - r := new(record) - err = f.decode(r, data) - if err != nil { - return nil, true, err - } - - if f.TTL > 0 { - if time.Now().UTC().After(r.Exp) { - _ = f.delete(key) - return nil, false, ErrCachedExp - } - } - - return r.Value, true, nil -} - -// Store sets the value for a key. -func (f *FileSystem) Store(key string, value interface{}, _ *http.Request) error { - filename := f.fileName(key) - - r := &record{ - Key: key, - Value: value, - Exp: time.Now().UTC(), - } - - f.MU.Lock() - defer f.MU.Unlock() - - b, err := f.encode(r) - if err != nil { - return err - } - - if f.TTL > 0 { - r.Value = nil - f.queue.push(r) - } - - return ioutil.WriteFile(filename, b, 0600) -} - -// Delete the value for a key. -func (f *FileSystem) Delete(key string, _ *http.Request) error { - f.MU.RLock() - defer f.MU.RUnlock() - - return f.delete(key) -} - -func (f *FileSystem) delete(key string) error { - filename := f.fileName(key) - return os.Remove(filename) -} - -// Keys return cache records keys. -func (f *FileSystem) Keys() []string { - p := f.fileName("*") - - f.MU.RLock() - defer f.MU.RUnlock() - - files, err := filepath.Glob(p) - if err != nil { - panic(err) - } - - for i, f := range files { - f = filepath.Base(f) - l := len(f) - len(fileExt) - files[i] = f[:l] - } - - return files -} - -func (f *FileSystem) fileName(key string) string { - return filepath.Join(f.path, key+fileExt) -} - -func (f *FileSystem) decode(r *record, data []byte) error { - b := new(bytes.Buffer) - _, err := b.Write(data) - if err != nil { - return err - } - - return gob.NewDecoder(b).Decode(r) -} - -func (f *FileSystem) encode(r *record) ([]byte, error) { - b := new(bytes.Buffer) - err := gob.NewEncoder(b).Encode(r) - return b.Bytes(), err -} - -// NewFileSystem return FileSystem Cache instance safe for concurrent usage, -// And spawning a garbage collector goroutine to collect expired record. -// The cache send record to garbage collector through a queue when it stored a new one. -// Once the garbage collector received the record it checks if record not expired to wait until expiration, -// Otherwise, wait for the next record. -// When the all expired record collected the garbage collector will be blocked, -// until new record stored to repeat the process. -// The context will be Passed to garbage collector -func NewFileSystem(ctx context.Context, ttl time.Duration, path string) *FileSystem { - queue := &queue{ - notify: make(chan struct{}, 1), - mu: &sync.Mutex{}, - } - - f := &FileSystem{ - path: path, - MU: &sync.RWMutex{}, - queue: queue, - TTL: ttl, - } - - if ttl > 0 { - go gc(ctx, queue, f) - } - - return f -} diff --git a/store/filesystem_test.go b/store/filesystem_test.go deleted file mode 100644 index 2d8abca..0000000 --- a/store/filesystem_test.go +++ /dev/null @@ -1,169 +0,0 @@ -package store - -import ( - "context" - "encoding/gob" - "fmt" - "os" - "sync" - "time" - - // "net/http" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestFileSystemStore(t *testing.T) { - table := []struct { - name string - expectedErr bool - key string - value interface{} - ttl time.Duration - }{ - { - name: "it return error when failed to encode value", - expectedErr: true, - key: "error", - value: &struct{ int }{1}, - }, - { - name: "it return nil error and store data in filesystem", - expectedErr: false, - key: "key", - value: "value", - }, - { - name: "it push record to queue when ttl > 0", - expectedErr: false, - key: "queue", - value: "value", - ttl: time.Second, - }, - } - - for i, tt := range table { - t.Run(tt.name, func(t *testing.T) { - path := fmt.Sprintf("/tmp/test/%v", i) - os.MkdirAll(path, os.ModePerm) - defer os.RemoveAll(path) - - f := &FileSystem{ - path: path, - MU: &sync.RWMutex{}, - queue: &queue{ - notify: make(chan struct{}, 1), - mu: &sync.Mutex{}, - }, - TTL: tt.ttl, - } - - err := f.Store(tt.key, tt.value, nil) - - if tt.expectedErr { - assert.True(t, err != nil) - return - } - - if tt.ttl > 0 { - r := f.queue.next() - assert.Equal(t, nil, r.Value) - assert.Equal(t, tt.key, r.Key) - } - - _, err = os.Stat(f.fileName(tt.key)) - assert.NoError(t, err) - - }) - } -} - -func TestFileSystemLoad(t *testing.T) { - table := []struct { - name string - expectedErr bool - key string - value interface{} - found bool - }{ - { - name: "it return error when failed to decode", - expectedErr: true, - key: "error", - value: nil, - found: true, - }, - { - name: "it return false and nil error when record does not exist", - expectedErr: false, - key: "notfound", - value: nil, - found: false, - }, - { - name: "it return value", - expectedErr: false, - key: "key", - value: "value", - found: true, - }, - } - - for _, tt := range table { - t.Run(tt.name, func(t *testing.T) { - ctx, cacnel := context.WithCancel(context.Background()) - defer cacnel() - - f := NewFileSystem(ctx, 0, "./testdata") - v, ok, err := f.Load(tt.key, nil) - - assert.Equal(t, tt.expectedErr, err != nil) - assert.Equal(t, tt.value, v) - assert.Equal(t, tt.found, ok) - - }) - } -} - -func TestFileSystemKeys(t *testing.T) { - ctx, cacnel := context.WithCancel(context.Background()) - defer cacnel() - - f := NewFileSystem(ctx, 0, "./testdata") - keys := f.Keys() - expected := []string{"error", "key", "queue"} - - assert.ElementsMatch(t, keys, expected) -} - -func TestFileSystemName(t *testing.T) { - f := new(FileSystem) - f.path = "/test" - - got := f.fileName("key") - expected := "/test/key.cache.gob" - - assert.Equal(t, expected, got) -} - -func TestFileSystemDelete(t *testing.T) { - f := new(FileSystem) - f.path = "/tmp" - f.MU = new(sync.RWMutex) - - filename := f.fileName("deletekey") - os.OpenFile(filename, os.O_RDONLY|os.O_CREATE, 0666) - - err := f.Delete("deletekey", nil) - assert.NoError(t, err) - - _, err = os.Stat(filename) - assert.Error(t, err) -} - -func BenchmarkFileSystem(b *testing.B) { - gob.Register(struct{}{}) - cache := NewFileSystem(context.Background(), 0, "/tmp") - benchmarkCache(b, cache) -} diff --git a/store/lru.go b/store/lru.go deleted file mode 100644 index b8c26bc..0000000 --- a/store/lru.go +++ /dev/null @@ -1,225 +0,0 @@ -package store - -import ( - "container/list" - "net/http" - "sync" - "time" -) - -// LRU implements a fixed-size thread safe LRU cache. -// It is based on the LRU cache in Groupcache. -type LRU struct { - // MaxEntries is the maximum number of cache entries before - // an item is evicted. Zero means no limit. - MaxEntries int - - // OnEvicted optionally specifies a callback function to be - // executed when an entry is purged from the cache. - OnEvicted OnEvicted - - // TTL To expire a value in cache. - // 0 TTL means no expiry policy specified. - TTL time.Duration - - MU *sync.Mutex - - ll *list.List - cache map[string]*list.Element -} - -// New creates a new LRU Cache. -// If maxEntries is zero, the cache has no limit and it's assumed -// that eviction is done by the caller. -func New(maxEntries int) *LRU { - return &LRU{ - MaxEntries: maxEntries, - ll: list.New(), - cache: make(map[string]*list.Element), - MU: new(sync.Mutex), - } -} - -// Store sets the value for a key. -func (l *LRU) Store(key string, value interface{}, _ *http.Request) error { - l.MU.Lock() - defer l.MU.Unlock() - - if l.cache == nil { - l.cache = make(map[string]*list.Element) - l.ll = list.New() - } - - if ee, ok := l.cache[key]; ok { - l.ll.MoveToFront(ee) - r := ee.Value.(*record) - r.Value = value - l.withTTL(r) - return nil - } - - r := &record{ - Key: key, - Value: value, - } - - l.withTTL(r) - - ele := l.ll.PushFront(r) - l.cache[key] = ele - - if l.MaxEntries != 0 && l.ll.Len() > l.MaxEntries { - l.removeOldest() - } - - return nil -} - -// Update the value for a key without updating the "recently used". -func (l *LRU) Update(key string, value interface{}, _ *http.Request) error { - l.MU.Lock() - defer l.MU.Unlock() - - if e, ok := l.cache[key]; ok { - r := e.Value.(*record) - r.Value = value - } - - return nil -} - -func (l *LRU) withTTL(r *record) { - if l.TTL > 0 { - r.Exp = time.Now().UTC().Add(l.TTL) - } -} - -// Load returns the value stored in the Cache for a key, or nil if no value is present. -// The ok result indicates whether value was found in the Cache. -func (l *LRU) Load(key string, _ *http.Request) (interface{}, bool, error) { - l.MU.Lock() - defer l.MU.Unlock() - - if l.cache == nil { - return nil, false, nil - } - - if ele, ok := l.cache[key]; ok { - r := ele.Value.(*record) - - if l.TTL > 0 { - if time.Now().UTC().After(r.Exp) { - l.removeElement(ele) - return nil, ok, ErrCachedExp - } - } - - l.ll.MoveToFront(ele) - return r.Value, ok, nil - } - - return nil, false, nil -} - -// Peek returns the value stored in the Cache for a key -// without updating the "recently used", or nil if no value is present. -// The ok result indicates whether value was found in the Cache. -func (l *LRU) Peek(key string, _ *http.Request) (interface{}, bool, error) { - if ele, ok := l.cache[key]; ok { - r := ele.Value.(*record) - - if l.TTL > 0 { - if time.Now().UTC().After(r.Exp) { - l.removeElement(ele) - return nil, ok, ErrCachedExp - } - } - - return r.Value, ok, nil - } - - return nil, false, nil -} - -// Delete the value for a key. -func (l *LRU) Delete(key string, _ *http.Request) error { - l.MU.Lock() - defer l.MU.Unlock() - - if l.cache == nil { - return nil - } - - if ele, hit := l.cache[key]; hit { - l.removeElement(ele) - } - - return nil -} - -// RemoveOldest removes the oldest item from the cache. -func (l *LRU) RemoveOldest() { - l.MU.Lock() - defer l.MU.Unlock() - l.removeOldest() -} - -func (l *LRU) removeOldest() { - if l.cache == nil { - return - } - - if ele := l.ll.Back(); ele != nil { - l.removeElement(ele) - } -} - -func (l *LRU) removeElement(e *list.Element) { - l.ll.Remove(e) - kv := e.Value.(*record) - delete(l.cache, kv.Key) - if l.OnEvicted != nil { - l.OnEvicted(kv.Key, kv.Value) - } -} - -// Len returns the number of items in the cache. -func (l *LRU) Len() int { - l.MU.Lock() - defer l.MU.Unlock() - - if l.cache == nil { - return 0 - } - - return l.ll.Len() -} - -// Clear purges all stored items from the cache. -func (l *LRU) Clear() { - l.MU.Lock() - defer l.MU.Unlock() - - if l.OnEvicted != nil { - for _, e := range l.cache { - kv := e.Value.(*record) - l.OnEvicted(kv.Key, kv.Value) - } - } - - l.ll = nil - l.cache = nil -} - -// Keys return cache records keys. -func (l *LRU) Keys() []string { - l.MU.Lock() - defer l.MU.Unlock() - keys := make([]string, 0) - - for k := range l.cache { - keys = append(keys, k) - } - - return keys -} diff --git a/store/lru_test.go b/store/lru_test.go deleted file mode 100644 index 2d7f61d..0000000 --- a/store/lru_test.go +++ /dev/null @@ -1,205 +0,0 @@ -package store - -import ( - "fmt" - "net/http" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestLRU(t *testing.T) { - table := []struct { - name string - key string - value interface{} - op string - expectedErr bool - found bool - }{ - { - name: "it return false when key does not exist", - op: "load", - key: "key", - found: false, - }, - { - name: "it return true and value when exist", - op: "load", - key: "test", - value: "test", - found: true, - }, - { - name: "it overwrite exist key and value when store", - op: "store", - key: "test", - value: "test2", - found: true, - }, - { - name: "it create new record when store", - op: "store", - key: "key", - value: "value", - found: true, - }, - { - name: "it's not crash when trying to delete a non exist record", - key: "key", - found: false, - }, - { - name: "it delete a exist record", - op: "delete", - key: "test", - found: false, - }, - } - - for _, tt := range table { - t.Run(tt.name, func(t *testing.T) { - cache := New(2) - - cache.Store("test", "test", nil) - - r, _ := http.NewRequest("GET", "/", nil) - var err error - - switch tt.op { - case "load": - v, ok, err := cache.Load(tt.key, r) - assert.Equal(t, tt.value, v) - assert.Equal(t, tt.found, ok) - assert.NoError(t, err) - return - case "store": - err = cache.Store(tt.key, tt.value, r) - case "delete": - err = cache.Delete(tt.key, r) - } - - v, ok, _ := cache.Load(tt.key, nil) - assert.NoError(t, err) - assert.Equal(t, tt.found, ok) - assert.Equal(t, tt.value, v) - - }) - } - -} - -func TestLRUUpdate(t *testing.T) { - cache := New(2) - cache.Store("1", 1, nil) - cache.Store("2", 2, nil) - - err := cache.Update("1", 0, nil) - r := cache.ll.Back().Value.(*record) - - assert.NoError(t, err) - assert.Equal(t, "1", r.Key) - assert.Equal(t, 0, r.Value) -} - -func TestLRUPeek(t *testing.T) { - cache := New(2) - cache.Store("1", 1, nil) - cache.Store("2", 2, nil) - - v, ok, err := cache.Peek("1", nil) - r := cache.ll.Back().Value.(*record) - - assert.NoError(t, err) - assert.True(t, ok) - assert.Equal(t, 1, v) - assert.Equal(t, "1", r.Key) -} - -func TestTTLLRU(t *testing.T) { - cache := New(2) - cache.TTL = time.Nanosecond * 100 - - _ = cache.Store("key", "value", nil) - - time.Sleep(time.Nanosecond * 110) - - v, ok, err := cache.Load("key", nil) - - assert.Equal(t, ErrCachedExp, err) - assert.True(t, ok) - assert.Nil(t, v) -} - -func TestLRUEvict(t *testing.T) { - evictedKeys := make([]string, 0) - onEvictedFun := func(key string, value interface{}) { - evictedKeys = append(evictedKeys, key) - } - - lru := New(20) - lru.OnEvicted = onEvictedFun - - for i := 0; i < 22; i++ { - lru.Store(fmt.Sprintf("myKey%d", i), 1234, nil) - } - - assert.Equal(t, 2, len(evictedKeys)) - assert.Equal(t, 2, len(evictedKeys)) - assert.Equal(t, "myKey0", evictedKeys[0]) - assert.Equal(t, "myKey1", evictedKeys[1]) -} - -func TestLRULen(t *testing.T) { - lru := New(1) - lru.Store("1", 1, nil) - - assert.Equal(t, lru.Len(), 1) - - lru.Clear() - assert.Equal(t, lru.Len(), 0) -} - -func TestLRUClear(t *testing.T) { - i := 0 - - onEvictedFun := func(key string, value interface{}) { - i++ - } - - lru := New(20) - lru.OnEvicted = onEvictedFun - - for i := 0; i < 20; i++ { - lru.Store(fmt.Sprintf("myKey%d", i), 1234, nil) - } - - lru.Clear() - assert.Equal(t, 20, i) -} - -func TestRemoveOldest(t *testing.T) { - lru := New(1) - lru.Store("1", 1, nil) - - assert.Equal(t, lru.Len(), 1) - - lru.RemoveOldest() - assert.Equal(t, lru.Len(), 0) -} - -func TestLRUKeys(t *testing.T) { - l := New(3) - - l.Store("1", "", nil) - l.Store("2", "", nil) - l.Store("3", "", nil) - - assert.ElementsMatch(t, []string{"1", "2", "3"}, l.Keys()) -} - -func BenchmarkLRU(b *testing.B) { - cache := New(2) - benchmarkCache(b, cache) -} diff --git a/store/replicator.go b/store/replicator.go deleted file mode 100644 index 9735487..0000000 --- a/store/replicator.go +++ /dev/null @@ -1,103 +0,0 @@ -package store - -import ( - "net/http" - "sort" -) - -// Replicator holding two caches instances to replicate data between them on load and store data, -// Replicator is typically used to replicate data between in-memory and a persistent caches, -// to obtain data consistency and persistency between runs of a program or scaling purposes. -// -// NOTICE: Replicator cache ignore errors from in-memory cache since all data first stored in Persistent. -type Replicator struct { - InMemory Cache - Persistent Cache -} - -// Load returns the value stored in the Cache for a key, or nil if no value is present. -// The ok result indicates whether value was found in the Cache. -// When any error occurs fallback to persistent cache. -func (r *Replicator) Load(key string, req *http.Request) (interface{}, bool, error) { - v, ok, err := r.InMemory.Load(key, req) - - if err != nil { - ok = false - } - - if !ok { - - v, ok, err = r.Persistent.Load(key, req) - - if ok { - _ = r.InMemory.Store(key, v, req) - } - } - - return v, ok, err -} - -// Store sets the value for a key. -func (r *Replicator) Store(key string, value interface{}, req *http.Request) error { - err := r.Persistent.Store(key, value, req) - - if err != nil { - return err - } - - _ = r.InMemory.Store(key, value, req) - return nil -} - -// Delete the value for a key. -func (r *Replicator) Delete(key string, req *http.Request) error { - if err := r.Persistent.Delete(key, req); err != nil { - return err - } - - _ = r.InMemory.Delete(key, req) - return nil -} - -// Keys return cache records keys. -func (r *Replicator) Keys() []string { - return r.Persistent.Keys() -} - -// IsSynced return true/false if two cached keys are equal. -func (r *Replicator) IsSynced() bool { - mk := r.InMemory.Keys() - pk := r.Persistent.Keys() - - if len(mk) != len(pk) { - return false - } - - sort.Strings(mk) - sort.Strings(pk) - - for i, v := range mk { - if v != pk[i] { - return false - } - } - - return true -} - -// Sync two caches by reload all persistent cache records into in-memory cache. -func (r *Replicator) Sync() error { - - if r.IsSynced() { - return nil - } - - for _, k := range r.Persistent.Keys() { - _, _, err := r.Load(k, nil) - if err != nil { - return err - } - } - - return nil -} diff --git a/store/replicator_test.go b/store/replicator_test.go deleted file mode 100644 index 1bb7553..0000000 --- a/store/replicator_test.go +++ /dev/null @@ -1,355 +0,0 @@ -package store - -import ( - "fmt" - "net/http" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestReplicatorLoad(t *testing.T) { - const ( - key = "key" - value = "value" - funcName = "Load" - ) - - table := []struct { - name string - value interface{} - found bool - expectedErr bool - prepare func() (Cache, Cache) - }{ - { - name: "it return false when key does not exist in both caches", - value: nil, - found: false, - expectedErr: false, - prepare: func() (Cache, Cache) { - p := &mockCache{ - Mock: mock.Mock{}, - } - - m := &mockCache{ - Mock: mock.Mock{}, - } - - p.On(funcName).Return(nil, false, nil) - m.On(funcName).Return(nil, false, nil) - - return p, m - }, - }, - { - name: "it return value from Persistent cache when in memory return error", - value: value, - found: true, - expectedErr: false, - prepare: func() (Cache, Cache) { - p := &mockCache{ - Mock: mock.Mock{}, - } - - m := &mockCache{ - Mock: mock.Mock{}, - } - - p.On(funcName).Return(value, true, nil) - m.On(funcName).Return(nil, false, fmt.Errorf("Replicator test #L56")) - m.On("Store").Return(nil) - return p, m - }, - }, - { - name: "it return error from Persistent cache", - value: nil, - found: false, - expectedErr: true, - prepare: func() (Cache, Cache) { - p := &mockCache{ - Mock: mock.Mock{}, - } - - m := &mockCache{ - Mock: mock.Mock{}, - } - - p.On(funcName).Return(nil, false, fmt.Errorf("Replicator test #78")) - m.On(funcName).Return(nil, false, nil) - return p, m - }, - }, - { - name: "it return value from memory cache", - value: value, - found: true, - expectedErr: false, - prepare: func() (Cache, Cache) { - p := &mockCache{ - Mock: mock.Mock{}, - } - - m := &mockCache{ - Mock: mock.Mock{}, - } - - m.On(funcName).Return(value, true, nil) - return p, m - }, - }, - } - - for _, tt := range table { - t.Run(tt.name, func(t *testing.T) { - p, m := tt.prepare() - - r := &Replicator{ - Persistent: p, - InMemory: m, - } - - v, ok, err := r.Load(key, nil) - - assert.Equal(t, tt.value, v) - assert.Equal(t, tt.found, ok) - assert.Equal(t, tt.expectedErr, err != nil) - }) - } -} - -func TestReplicatorStore(t *testing.T) { - const funcName = "Store" - - table := []struct { - name string - expectedErr bool - prepare func() (Cache, Cache) - }{ - { - name: "it return error when Persistent cache return error ", - expectedErr: true, - prepare: func() (Cache, Cache) { - p := &mockCache{ - Mock: mock.Mock{}, - } - - m := &mockCache{ - Mock: mock.Mock{}, - } - - p.On(funcName).Return(fmt.Errorf("Replicator test #L140")) - return p, m - }, - }, - { - name: "it return nil error when memory cache return error ", - expectedErr: false, - prepare: func() (Cache, Cache) { - p := &mockCache{ - Mock: mock.Mock{}, - } - - m := &mockCache{ - Mock: mock.Mock{}, - } - - p.On(funcName).Return(nil) - m.On(funcName).Return(fmt.Errorf("Replicator test #L140")) - return p, m - }, - }, - } - - for _, tt := range table { - t.Run(tt.name, func(t *testing.T) { - p, m := tt.prepare() - - r := &Replicator{ - Persistent: p, - InMemory: m, - } - - err := r.Store("key", "value", nil) - assert.Equal(t, tt.expectedErr, err != nil) - }) - } -} - -func TestReplicatorDelete(t *testing.T) { - const funcName = "Delete" - - table := []struct { - name string - expectedErr bool - prepare func() (Cache, Cache) - }{ - { - name: "it return error when Persistent cache return error ", - expectedErr: true, - prepare: func() (Cache, Cache) { - p := &mockCache{ - Mock: mock.Mock{}, - } - - m := &mockCache{ - Mock: mock.Mock{}, - } - - p.On(funcName).Return(fmt.Errorf("Replicator test #L201")) - return p, m - }, - }, - { - name: "it return nil error when memory cache return error ", - expectedErr: false, - prepare: func() (Cache, Cache) { - p := &mockCache{ - Mock: mock.Mock{}, - } - - m := &mockCache{ - Mock: mock.Mock{}, - } - - p.On(funcName).Return(nil) - m.On(funcName).Return(fmt.Errorf("Replicator test #L140")) - return p, m - }, - }, - } - - for _, tt := range table { - t.Run(tt.name, func(t *testing.T) { - p, m := tt.prepare() - - r := &Replicator{ - Persistent: p, - InMemory: m, - } - - err := r.Delete("key", nil) - assert.Equal(t, tt.expectedErr, err != nil) - }) - } -} - -func TestReplicatorKeys(t *testing.T) { - p := &mockCache{ - Mock: mock.Mock{}, - } - - r := &Replicator{ - Persistent: p, - } - - keys := []string{"1", "2", "3"} - p.On("Keys").Return(keys) - got := r.Keys() - - assert.Equal(t, keys, got) -} - -func TestReplicatorIsSynced(t *testing.T) { - const funcName = "Keys" - - table := []struct { - name string - expected bool - prepare func() (Cache, Cache) - }{ - { - name: "it return true when Persistent and memory synced", - expected: true, - prepare: func() (Cache, Cache) { - p := &mockCache{ - Mock: mock.Mock{}, - } - - m := &mockCache{ - Mock: mock.Mock{}, - } - - keys := []string{"1"} - - p.On(funcName).Return(keys) - m.On(funcName).Return(keys) - return p, m - }, - }, - { - name: "it return false when Persistent and memory keys are not equal", - expected: false, - prepare: func() (Cache, Cache) { - p := &mockCache{ - Mock: mock.Mock{}, - } - - m := &mockCache{ - Mock: mock.Mock{}, - } - - p.On(funcName).Return([]string{"1", "2"}) - m.On(funcName).Return([]string{"1"}) - return p, m - }, - }, - { - name: "it return false when Persistent and memory keys are not deep equal", - expected: false, - prepare: func() (Cache, Cache) { - p := &mockCache{ - Mock: mock.Mock{}, - } - - m := &mockCache{ - Mock: mock.Mock{}, - } - - p.On(funcName).Return([]string{"1", "2"}) - m.On(funcName).Return([]string{"1", "3"}) - return p, m - }, - }, - } - - for _, tt := range table { - t.Run(tt.name, func(t *testing.T) { - p, m := tt.prepare() - - r := &Replicator{ - Persistent: p, - InMemory: m, - } - - got := r.IsSynced() - assert.Equal(t, tt.expected, got) - }) - } -} - -type mockCache struct { - mock.Mock -} - -func (m *mockCache) Load(key string, req *http.Request) (interface{}, bool, error) { - args := m.Called() - return args.Get(0), args.Bool(1), args.Error(2) -} - -func (m *mockCache) Store(key string, value interface{}, _ *http.Request) error { - args := m.Called() - return args.Error(0) -} - -func (m *mockCache) Delete(key string, _ *http.Request) error { - args := m.Called() - return args.Error(0) -} - -func (m *mockCache) Keys() []string { - args := m.Called() - return args.Get(0).([]string) -} diff --git a/store/testdata/error.cache.gob b/store/testdata/error.cache.gob deleted file mode 100644 index d30b310..0000000 Binary files a/store/testdata/error.cache.gob and /dev/null differ diff --git a/store/testdata/key.cache.gob b/store/testdata/key.cache.gob deleted file mode 100644 index a26d9ec..0000000 Binary files a/store/testdata/key.cache.gob and /dev/null differ diff --git a/store/testdata/queue.cache.gob b/store/testdata/queue.cache.gob deleted file mode 100644 index f0ab5c4..0000000 Binary files a/store/testdata/queue.cache.gob and /dev/null differ