diff --git a/application/auth/aksk/factory.go b/application/auth/aksk/factory.go index 334a4b39..a0934288 100644 --- a/application/auth/aksk/factory.go +++ b/application/auth/aksk/factory.go @@ -14,7 +14,7 @@ var _ auth.IAuthFactory = (*factory)(nil) var driverName = "aksk" -//Register 注册auth驱动工厂 +// Register 注册auth驱动工厂 func Register() { auth.FactoryRegister(driverName, NewFactory()) } @@ -44,6 +44,10 @@ func (f *factory) Alias() []string { } } +func (f *factory) PreRouters() []*auth.PreRouter { + return nil +} + func (f *factory) Create(tokenName string, position string, rule interface{}) (application.IAuth, error) { a := &aksk{ id: toId(tokenName, position), @@ -54,7 +58,7 @@ func (f *factory) Create(tokenName string, position string, rule interface{}) (a return a, nil } -//NewFactory 生成一个 auth_apiKey工厂 +// NewFactory 生成一个 auth_apiKey工厂 func NewFactory() auth.IAuthFactory { typ := reflect.TypeOf((*Config)(nil)) render, _ := schema.Generate(typ, nil) diff --git a/application/auth/apikey/factory.go b/application/auth/apikey/factory.go index 496d37ab..ee124852 100644 --- a/application/auth/apikey/factory.go +++ b/application/auth/apikey/factory.go @@ -43,6 +43,10 @@ func (f *factory) Alias() []string { } } +func (f *factory) PreRouters() []*auth.PreRouter { + return nil +} + func (f *factory) Create(tokenName string, position string, rule interface{}) (application.IAuth, error) { a := &apikey{ id: toId(tokenName, position), diff --git a/application/auth/basic/factory.go b/application/auth/basic/factory.go index 04857553..e11dfd29 100644 --- a/application/auth/basic/factory.go +++ b/application/auth/basic/factory.go @@ -44,6 +44,10 @@ func (f *factory) Alias() []string { } } +func (f *factory) PreRouters() []*auth.PreRouter { + return nil +} + func (f *factory) Create(tokenName string, position string, rule interface{}) (application.IAuth, error) { a := &basic{ id: toId(tokenName, position), diff --git a/application/auth/factory.go b/application/auth/factory.go index 0dc7bbd7..b34f3b10 100644 --- a/application/auth/factory.go +++ b/application/auth/factory.go @@ -6,6 +6,8 @@ import ( "reflect" "strings" + "github.com/eolinker/apinto/router" + "github.com/eolinker/apinto/application" "github.com/eolinker/eosc/log" @@ -18,16 +20,24 @@ var ( _ eosc.ISetting = defaultAuthFactoryRegister ) -//IAuthFactory 鉴权工厂方法 +type PreRouter struct { + ID string + PreHandler router.IRouterPreHandler + Path string + Method []string +} + +// IAuthFactory 鉴权工厂方法 type IAuthFactory interface { Create(tokenName string, position string, rule interface{}) (application.IAuth, error) Alias() []string Render() interface{} ConfigType() reflect.Type UserType() reflect.Type + PreRouters() []*PreRouter } -//IAuthFactoryRegister 实现了鉴权工厂管理器 +// IAuthFactoryRegister 实现了鉴权工厂管理器 type IAuthFactoryRegister interface { RegisterFactoryByKey(key string, factory IAuthFactory) GetFactoryByKey(key string) (IAuthFactory, bool) @@ -35,7 +45,7 @@ type IAuthFactoryRegister interface { Alias() map[string]string } -//driverRegister 驱动注册器 +// driverRegister 驱动注册器 type driverRegister struct { register eosc.IRegister[IAuthFactory] keys []string @@ -80,7 +90,7 @@ func (dm *driverRegister) ReadOnly() bool { return true } -//newAuthFactoryManager 创建auth工厂管理器 +// newAuthFactoryManager 创建auth工厂管理器 func newAuthFactoryManager() *driverRegister { return &driverRegister{ register: eosc.NewRegister[IAuthFactory](), @@ -90,12 +100,12 @@ func newAuthFactoryManager() *driverRegister { } } -//GetFactoryByKey 获取指定auth工厂 +// GetFactoryByKey 获取指定auth工厂 func (dm *driverRegister) GetFactoryByKey(key string) (IAuthFactory, bool) { return dm.register.Get(key) } -//RegisterFactoryByKey 注册auth工厂 +// RegisterFactoryByKey 注册auth工厂 func (dm *driverRegister) RegisterFactoryByKey(key string, factory IAuthFactory) { err := dm.register.Register(key, factory, true) if err != nil { @@ -109,7 +119,7 @@ func (dm *driverRegister) RegisterFactoryByKey(key string, factory IAuthFactory) } } -//Keys 返回所有已注册的key +// Keys 返回所有已注册的key func (dm *driverRegister) Keys() []string { return dm.keys } @@ -118,18 +128,18 @@ func (dm *driverRegister) Alias() map[string]string { return dm.driverAlias } -//FactoryRegister 注册auth工厂到默认auth工厂注册器 +// FactoryRegister 注册auth工厂到默认auth工厂注册器 func FactoryRegister(key string, factory IAuthFactory) { defaultAuthFactoryRegister.RegisterFactoryByKey(key, factory) } -//Get 从默认auth工厂注册器中获取auth工厂 +// Get 从默认auth工厂注册器中获取auth工厂 func Get(key string) (IAuthFactory, bool) { return defaultAuthFactoryRegister.GetFactoryByKey(key) } -//Keys 返回默认的auth工厂注册器中所有已注册的key +// Keys 返回默认的auth工厂注册器中所有已注册的key func Keys() []string { return defaultAuthFactoryRegister.Keys() } @@ -138,7 +148,7 @@ func Alias() map[string]string { return defaultAuthFactoryRegister.Alias() } -//GetFactory 获取指定auth工厂,若指定的不存在则返回一个已注册的工厂 +// GetFactory 获取指定auth工厂,若指定的不存在则返回一个已注册的工厂 func GetFactory(name string) (IAuthFactory, error) { factory, ok := Get(name) if !ok { diff --git a/application/auth/jwt/factory.go b/application/auth/jwt/factory.go index 716903fe..79d89be5 100644 --- a/application/auth/jwt/factory.go +++ b/application/auth/jwt/factory.go @@ -14,7 +14,7 @@ var _ auth.IAuthFactory = (*factory)(nil) var driverName = "jwt" -//Register 注册auth驱动工厂 +// Register 注册auth驱动工厂 func Register() { auth.FactoryRegister(driverName, NewFactory()) } @@ -43,6 +43,10 @@ func (f *factory) Alias() []string { } } +func (f *factory) PreRouters() []*auth.PreRouter { + return nil +} + func (f *factory) Create(tokenName string, position string, rule interface{}) (application.IAuth, error) { baseConfig, ok := rule.(*application.BaseConfig) if !ok { @@ -66,7 +70,7 @@ func (f *factory) Create(tokenName string, position string, rule interface{}) (a return a, nil } -//NewFactory 生成一个 auth_apiKey工厂 +// NewFactory 生成一个 auth_apiKey工厂 func NewFactory() auth.IAuthFactory { typ := reflect.TypeOf((*Config)(nil)) render, _ := schema.Generate(typ, nil) diff --git a/application/auth/oauth2/authorize.go b/application/auth/oauth2/authorize.go new file mode 100644 index 00000000..cd8eb53d --- /dev/null +++ b/application/auth/oauth2/authorize.go @@ -0,0 +1,144 @@ +package oauth2 + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "os" + "strconv" + "sync" + "time" + + scope_manager "github.com/eolinker/apinto/scope-manager" + + "github.com/eolinker/apinto/resources" + http_context "github.com/eolinker/eosc/eocontext/http-context" +) + +const ( + ResponseTypeCode = "code" + ResponseTypeToken = "token" +) + +func NewAuthorizeHandler() *AuthorizeHandler { + return &AuthorizeHandler{} +} + +type AuthorizeHandler struct { + cache scope_manager.IProxyOutput[resources.ICache] + once sync.Once +} + +func (a *AuthorizeHandler) Handle(ctx http_context.IHttpContext, client *Client, params url.Values) { + responseType := params.Get("response_type") + if responseType == "" || !((responseType == ResponseTypeCode && client.EnableAuthorizationCode) || (responseType == ResponseTypeToken && client.EnableImplicitGrant)) { + ctx.Response().SetBody([]byte(fmt.Sprintf("unsupported response type: %s,client id is %s", responseType, client.ClientId))) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + + scope := params.Get("scope") + if scope == "" && client.MandatoryScope { + ctx.Response().SetBody([]byte("scope is required, client id is " + client.ClientId)) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + matchScope := false + for _, s := range client.Scopes { + if s == scope { + matchScope = true + break + } + } + if !matchScope { + ctx.Response().SetBody([]byte("invalid scope, client id is " + client.ClientId)) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + + redirectURI := params.Get("redirect_uri") + if redirectURI == "" { + ctx.Response().SetBody([]byte("redirect uri is required, client id is " + client.ClientId)) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + + matchRedirectUri := false + for _, uri := range client.RedirectUrls { + if uri == redirectURI { + matchRedirectUri = true + break + } + } + if !matchRedirectUri { + ctx.Response().SetBody([]byte("invalid redirect uri, client id is " + client.ClientId)) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + uri, err := url.Parse(redirectURI) + if err != nil { + ctx.Response().SetBody([]byte("invalid redirect uri, client id is " + client.ClientId)) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + a.once.Do(func() { + a.cache = scope_manager.Auto[resources.ICache]("", "redis") + }) + list := a.cache.List() + if len(list) < 1 { + ctx.Response().SetBody([]byte("redis cache is not available")) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + cache := list[0] + query := url.Values{} + switch responseType { + case ResponseTypeCode: + { + // 授权码模式 + provisionKey := params.Get("provision_key") + if provisionKey != client.ProvisionKey { + ctx.Response().SetBody([]byte("invalid provision key, client id is " + client.ClientId)) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + code := generateRandomString() + redisKey := fmt.Sprintf("apinto:oauth2_codes:%s:%s", os.Getenv("cluster_id"), code) + field := map[string]interface{}{ + "code": code, + "scope": scope, + } + _, err = cache.HMSetN(ctx.Context(), redisKey, field, 6*time.Minute).Result() + if err != nil { + ctx.Response().SetBody([]byte(fmt.Sprintf("(%s)redis HMSet %s error: %s", client.ClientId, redisKey, err.Error()))) + ctx.Response().SetStatus(http.StatusInternalServerError, "server error") + return + } + query.Set("code", code) + } + case ResponseTypeToken: + { + token, err := generateToken(ctx.Context(), cache, client.ClientId, client.TokenExpiration, client.RefreshTokenTTL, scope, false) + if err != nil { + ctx.Response().SetBody([]byte(fmt.Sprintf("(%s)generate token error: %s", client.ClientId, err.Error()))) + ctx.Response().SetStatus(http.StatusInternalServerError, "server error") + return + } + query.Set("access_token", token.AccessToken) + query.Set("token_type", "bearer") + query.Set("expires_in", strconv.Itoa(token.ExpiresIn)) + } + } + + state := params.Get("state") + if state != "" { + query.Set("state", state) + } + data, _ := json.Marshal(map[string]interface{}{ + "redirect_uri": fmt.Sprintf("%s?%s", uri.String(), query.Encode()), + }) + ctx.Response().SetBody(data) + ctx.Response().SetStatus(http.StatusOK, "OK") + return +} diff --git a/application/auth/oauth2/config.go b/application/auth/oauth2/config.go new file mode 100644 index 00000000..12d8a3da --- /dev/null +++ b/application/auth/oauth2/config.go @@ -0,0 +1,42 @@ +package oauth2 + +import "github.com/eolinker/apinto/application" + +const ( + GrantAuthorizationCode = "authorization_code" + GrantClientCredentials = "client_credentials" + GrantRefreshToken = "refresh_token" +) + +type Config struct { + application.Auth + Users []*User `json:"users" label:"用户列表"` +} + +type User struct { + Pattern Pattern `json:"pattern" label:"用户信息"` + application.User +} + +type Pattern struct { + ClientId string `json:"client_id"` + ClientSecret string `json:"client_secret"` + ClientType string `json:"client_type"` + HashSecret bool `json:"hash_secret"` + RedirectUrls []string `json:"redirect_urls" label:"重定向URL"` + Scopes []string `json:"scopes" label:"授权范围"` + MandatoryScope bool `json:"mandatory_scope" label:"强制授权"` + ProvisionKey string `json:"provision_key" label:"Provision Key"` + TokenExpiration int `json:"token_expiration" label:"令牌过期时间"` + RefreshTokenTTL int `json:"refresh_token_ttl" label:"刷新令牌TTL"` + EnableAuthorizationCode bool `json:"enable_authorization_code" label:"启用授权码模式"` + EnableImplicitGrant bool `json:"enable_implicit_grant" label:"启用隐式授权模式"` + EnableClientCredentials bool `json:"enable_client_credentials" label:"启用客户端凭证模式"` + AcceptHttpIfAlreadyTerminated bool `json:"accept_http_if_already_terminated" label:"如果已终止,则接受HTTP"` + ReuseRefreshToken bool `json:"reuse_refresh_token" label:"重用刷新令牌"` + PersistentRefreshToken bool `json:"persistent_refresh_token" label:"持久刷新令牌"` +} + +func (u *User) Username() string { + return u.Pattern.ClientId +} diff --git a/application/auth/oauth2/factory.go b/application/auth/oauth2/factory.go new file mode 100644 index 00000000..700134d0 --- /dev/null +++ b/application/auth/oauth2/factory.go @@ -0,0 +1,88 @@ +package oauth2 + +import ( + "fmt" + "net/http" + "reflect" + + "github.com/eolinker/eosc/utils/schema" + + "github.com/eolinker/apinto/application" + "github.com/eolinker/apinto/application/auth" +) + +var _ auth.IAuthFactory = (*factory)(nil) + +var driverName = "oauth2" + +// Register 注册auth驱动工厂 +func Register() { + auth.FactoryRegister(driverName, NewFactory()) +} + +type factory struct { + configType reflect.Type + render *schema.Schema + userType reflect.Type +} + +func (f *factory) Render() interface{} { + return f.render +} + +func (f *factory) ConfigType() reflect.Type { + return f.configType +} + +func (f *factory) UserType() reflect.Type { + return f.userType +} + +func (f *factory) Alias() []string { + return []string{ + "oauth2", + "oauth2_auth", + } +} + +func (f *factory) PreRouters() []*auth.PreRouter { + return []*auth.PreRouter{ + { + ID: "/oauth2/token", + PreHandler: NewHandler(NewTokenHandler()), + Path: "/oauth2/token", + Method: []string{http.MethodPost}, + }, + { + ID: "/oauth2/authorize", + PreHandler: NewHandler(NewAuthorizeHandler()), + Path: "/oauth2/authorize", + Method: []string{http.MethodPost}, + }, + } +} + +func (f *factory) Create(tokenName string, position string, rule interface{}) (application.IAuth, error) { + a := &oauth2{ + id: toId(tokenName, position), + tokenName: tokenName, + position: position, + users: application.NewUserManager(), + } + return a, nil +} + +// NewFactory 生成一个 auth_apiKey工厂 +func NewFactory() auth.IAuthFactory { + typ := reflect.TypeOf((*Config)(nil)) + render, _ := schema.Generate(typ, nil) + return &factory{ + configType: typ, + render: render, + userType: reflect.TypeOf((*User)(nil)), + } +} + +func toId(tokenName, position string) string { + return fmt.Sprintf("%s@%s@%s", tokenName, position, driverName) +} diff --git a/application/auth/oauth2/handler.go b/application/auth/oauth2/handler.go new file mode 100644 index 00000000..52a5344f --- /dev/null +++ b/application/auth/oauth2/handler.go @@ -0,0 +1,61 @@ +package oauth2 + +import ( + "net/http" + "net/url" + "strings" + "time" + + "github.com/eolinker/eosc/log" + + eoscContext "github.com/eolinker/eosc/eocontext" + + http_context "github.com/eolinker/eosc/eocontext/http-context" +) + +type IHandler interface { + Handle(ctx http_context.IHttpContext, client *Client, params url.Values) +} + +type Handler struct { + handler IHandler +} + +func NewHandler(handler IHandler) *Handler { + return &Handler{handler: handler} +} + +func (h *Handler) Server(eoContext eoscContext.EoContext) (isContinue bool) { + // 简化模式/授权码模式执行该流程 + ctx, err := http_context.Assert(eoContext) + if err != nil { + log.Errorf("assert http context error: %s", err) + return true + } + params := retrieveParameters(ctx) + clientId := params.Get("client_id") + if clientId == "" { + // 当空时视为正常请求,不做拦截 + return true + } + client, has := getClient(clientId) + if !has { + ctx.Response().SetBody([]byte("invalid client id")) + ctx.Response().SetStatus(http.StatusNotFound, "not found") + return false + } + + if strings.ToUpper(ctx.Request().URI().Scheme()) != "HTTPS" && !client.AcceptHttpIfAlreadyTerminated { + return false + } + if client.Expire > 0 && client.Expire < time.Now().Unix() { + ctx.Response().SetBody([]byte("client id is expired")) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return false + } + if h.handler != nil { + h.handler.Handle(ctx, client, params) + } + + return false +} diff --git a/application/auth/oauth2/hash.go b/application/auth/oauth2/hash.go new file mode 100644 index 00000000..c4b0ca13 --- /dev/null +++ b/application/auth/oauth2/hash.go @@ -0,0 +1,85 @@ +package oauth2 + +import ( + "fmt" + "strconv" + "strings" +) + +type hashRule struct { + algorithm string + iterations int + length int + salt string + value string +} + +func extractHashRule(hash string) (*hashRule, error) { + parts := strings.Split(hash, "$") + if len(parts) != 5 { + return nil, fmt.Errorf("invalid hashed password format") + } + subParts := strings.Split(parts[2], ",") + if len(subParts) != 2 { + return nil, fmt.Errorf("invalid hashed sub part format") + } + iterationsIndex := strings.Index(subParts[0], "=") + if iterationsIndex == -1 { + return nil, fmt.Errorf("iterations not found") + } + iterations, err := strconv.Atoi(subParts[0][iterationsIndex+1:]) + if err != nil { + return nil, fmt.Errorf("invalid iterations format") + } + lengthIndex := strings.Index(subParts[1], "=") + if lengthIndex == -1 { + return nil, fmt.Errorf("length not found") + } + length, err := strconv.Atoi(subParts[1][lengthIndex+1:]) + if err != nil { + return nil, fmt.Errorf("invalid length format") + } + return &hashRule{ + algorithm: parts[0], + iterations: iterations, + length: length, + salt: parts[3], + value: parts[4], + }, nil +} + +// +//func hashSecret(secret []byte, saltLen int, iterations int, keyLength int) (string, error) { +// if saltLen < 1 { +// saltLen = 16 +// } +// salt, err := generateRandomSalt(saltLen) +// if err != nil { +// return "", err +// } +// // 迭代次数和密钥长度 +// if iterations < 1 { +// iterations = 10000 +// } +// if keyLength < 1 { +// keyLength = 32 +// } +// +// // 使用 PBKDF2 密钥派生函数 +// key := pbkdf2.Key(secret, salt, iterations, keyLength, sha512.New) +// return fmt.Sprintf("$pbkdf2-sha512$i=%d,l=%d$%s$%s", iterations, keyLength, base64.RawStdEncoding.EncodeToString(salt), base64.RawStdEncoding.EncodeToString(key)), nil +//} + +//func generateRandomSalt(length int) ([]byte, error) { +// // Create a byte slice with the specified length +// salt := make([]byte, length) +// +// // Use crypto/rand to fill the slice with random bytes +// _, err := rand.Read(salt) +// if err != nil { +// return nil, err +// } +// +// // Return the salt as a hexadecimal string +// return salt, nil +//} diff --git a/application/auth/oauth2/hash_test.go b/application/auth/oauth2/hash_test.go new file mode 100644 index 00000000..39567a67 --- /dev/null +++ b/application/auth/oauth2/hash_test.go @@ -0,0 +1,12 @@ +package oauth2 + +import "testing" + +func TestHash(t *testing.T) { + data := "$pbkdf2-sha512$i=10000,l=32$7BGLyS03BLF+F+M01p7MBg$OTAR1PTJpXzCVBfRq3VcGXYlSeRD2IUEzk/RsRQwfwI" + + _, err := extractHashRule(data) + if err != nil { + t.Fatal(err) + } +} diff --git a/application/auth/oauth2/manager.go b/application/auth/oauth2/manager.go new file mode 100644 index 00000000..9e51f79e --- /dev/null +++ b/application/auth/oauth2/manager.go @@ -0,0 +1,33 @@ +package oauth2 + +import "github.com/eolinker/eosc" + +func registerClient(clientId string, client *Client) { + manager.clients.Set(clientId, client) +} + +func removeClient(clientId string) { + manager.clients.Del(clientId) +} + +func getClient(clientId string) (*Client, bool) { + return manager.clients.Get(clientId) +} + +var manager = NewManager() + +// Manager 管理oauth2配置 +type Manager struct { + clients eosc.Untyped[string, *Client] +} + +func NewManager() *Manager { + return &Manager{clients: eosc.BuildUntyped[string, *Client]()} +} + +type Client struct { + *Pattern + // Expire 过期时间 + Expire int64 + hashRule *hashRule +} diff --git a/application/auth/oauth2/oauth2.go b/application/auth/oauth2/oauth2.go new file mode 100644 index 00000000..8ea736fe --- /dev/null +++ b/application/auth/oauth2/oauth2.go @@ -0,0 +1,108 @@ +package oauth2 + +import ( + "fmt" + "sync" + + "github.com/eolinker/apinto/resources" + scope_manager "github.com/eolinker/apinto/scope-manager" + + "github.com/eolinker/eosc/log" + + "github.com/eolinker/apinto/application" + + http_service "github.com/eolinker/eosc/eocontext/http-context" +) + +var _ application.IAuth = (*oauth2)(nil) + +type oauth2 struct { + id string + tokenName string + position string + users application.IUserManager + cache scope_manager.IProxyOutput[resources.ICache] + once sync.Once +} + +func (o *oauth2) GetUser(ctx http_service.IHttpContext) (*application.UserInfo, bool) { + token, has := application.GetToken(ctx, o.tokenName, o.position) + if !has || token == "" { + return nil, false + } + o.once.Do(func() { + o.cache = scope_manager.Auto[resources.ICache]("", "redis") + }) + list := o.cache.List() + if len(list) < 1 { + return nil, false + } + clientID, err := validToken(ctx.Context(), list[0], token) + if err != nil { + log.Error("valid token error:", err, "token:", token) + return nil, false + } + + return o.users.Get(clientID) +} + +func (o *oauth2) ID() string { + return o.id +} + +func (o *oauth2) Driver() string { + return driverName +} + +func (o *oauth2) Check(appID string, users []application.ITransformConfig) error { + us := make([]application.IUser, 0, len(users)) + for _, u := range users { + v, ok := u.Config().(*User) + if !ok { + return fmt.Errorf("%s check error: invalid config type", driverName) + } + us = append(us, v) + } + return o.users.Check(appID, driverName, us) +} + +func (o *oauth2) Set(app application.IApp, users []application.ITransformConfig) { + infos := make([]*application.UserInfo, 0, len(users)) + for _, user := range users { + v, _ := user.Config().(*User) + client := &Client{ + Pattern: &v.Pattern, + Expire: v.Expire, + } + if v.Pattern.HashSecret { + hr, err := extractHashRule(v.Pattern.ClientSecret) + if err != nil { + log.Error("extract hash error:", err, "client secret:", v.Pattern.ClientSecret) + continue + } + log.Debug("hash rule: ", *hr) + client.hashRule = hr + } + registerClient(v.Pattern.ClientId, client) + + infos = append(infos, &application.UserInfo{ + Name: v.Username(), + Value: v.Pattern.ClientSecret, + Expire: v.Expire, + Labels: v.Labels, + HideCredential: v.HideCredential, + TokenName: o.tokenName, + Position: o.position, + App: app, + }) + } + o.users.Set(app.Id(), infos) +} + +func (o *oauth2) Del(appID string) { + o.users.DelByAppID(appID) +} + +func (o *oauth2) UserCount() int { + return o.users.Count() +} diff --git a/application/auth/oauth2/scan.go b/application/auth/oauth2/scan.go new file mode 100644 index 00000000..550e3885 --- /dev/null +++ b/application/auth/oauth2/scan.go @@ -0,0 +1,281 @@ +// Copyright 2012 Gary Burd +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package oauth2 + +import ( + "errors" + "fmt" + "reflect" + "strconv" +) + +type Error string + +func (err Error) Error() string { return string(err) } + +type Scanner interface { + // RedisScan assigns a value from a Redis value. The argument src is one of + // the reply types listed in the section `Executing Commands`. + // + // An error should be returned if the value cannot be stored without + // loss of information. + RedisScan(src interface{}) error +} + +func ensureLen(d reflect.Value, n int) { + if n > d.Cap() { + d.Set(reflect.MakeSlice(d.Type(), n, n)) + } else { + d.SetLen(n) + } +} + +func cannotConvert(d reflect.Value, s interface{}) error { + var name string + switch s.(type) { + case string: + name = "Redis simple string" + case Error: + name = "Redis error" + case int64: + name = "Redis integer" + case []byte: + name = "Redis bulk string" + case []interface{}: + name = "Redis array" + default: + name = reflect.TypeOf(s).String() + } + return fmt.Errorf("cannot convert from %s to %s", name, d.Type()) +} + +func convertAssignBulkString(d reflect.Value, s []byte) (err error) { + switch d.Type().Kind() { + case reflect.Float32, reflect.Float64: + var x float64 + x, err = strconv.ParseFloat(string(s), d.Type().Bits()) + d.SetFloat(x) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + var x int64 + x, err = strconv.ParseInt(string(s), 10, d.Type().Bits()) + d.SetInt(x) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + var x uint64 + x, err = strconv.ParseUint(string(s), 10, d.Type().Bits()) + d.SetUint(x) + case reflect.Bool: + var x bool + x, err = strconv.ParseBool(string(s)) + d.SetBool(x) + case reflect.String: + d.SetString(string(s)) + case reflect.Slice: + if d.Type().Elem().Kind() != reflect.Uint8 { + err = cannotConvert(d, s) + } else { + d.SetBytes(s) + } + default: + err = cannotConvert(d, s) + } + return +} + +func convertAssignInt(d reflect.Value, s int64) (err error) { + switch d.Type().Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + d.SetInt(s) + if d.Int() != s { + err = strconv.ErrRange + d.SetInt(0) + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if s < 0 { + err = strconv.ErrRange + } else { + x := uint64(s) + d.SetUint(x) + if d.Uint() != x { + err = strconv.ErrRange + d.SetUint(0) + } + } + case reflect.Bool: + d.SetBool(s != 0) + default: + err = cannotConvert(d, s) + } + return +} + +func convertAssignValue(d reflect.Value, s interface{}) (err error) { + if d.Kind() != reflect.Ptr { + if d.CanAddr() { + d2 := d.Addr() + if d2.CanInterface() { + if scanner, ok := d2.Interface().(Scanner); ok { + return scanner.RedisScan(s) + } + } + } + } else if d.CanInterface() { + // Already a reflect.Ptr + if d.IsNil() { + d.Set(reflect.New(d.Type().Elem())) + } + if scanner, ok := d.Interface().(Scanner); ok { + return scanner.RedisScan(s) + } + } + + switch s := s.(type) { + case []byte: + err = convertAssignBulkString(d, s) + case int64: + err = convertAssignInt(d, s) + default: + err = cannotConvert(d, s) + } + return err +} + +func convertAssignArray(d reflect.Value, s []interface{}) error { + if d.Type().Kind() != reflect.Slice { + return cannotConvert(d, s) + } + ensureLen(d, len(s)) + for i := 0; i < len(s); i++ { + if err := convertAssignValue(d.Index(i), s[i]); err != nil { + return err + } + } + return nil +} + +func convertAssign(d interface{}, s interface{}) (err error) { + if scanner, ok := d.(Scanner); ok { + return scanner.RedisScan(s) + } + + // Handle the most common destination types using type switches and + // fall back to reflection for all other types. + switch s := s.(type) { + case nil: + // ignore + case []byte: + switch d := d.(type) { + case *string: + *d = string(s) + case *int: + *d, err = strconv.Atoi(string(s)) + case *bool: + *d, err = strconv.ParseBool(string(s)) + case *[]byte: + *d = s + case *interface{}: + *d = s + case nil: + // skip value + default: + if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr { + err = cannotConvert(d, s) + } else { + err = convertAssignBulkString(d.Elem(), s) + } + } + case int64: + switch d := d.(type) { + case *int: + x := int(s) + if int64(x) != s { + err = strconv.ErrRange + x = 0 + } + *d = x + case *bool: + *d = s != 0 + case *interface{}: + *d = s + case nil: + // skip value + default: + if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr { + err = cannotConvert(d, s) + } else { + err = convertAssignInt(d.Elem(), s) + } + } + case string: + switch d := d.(type) { + case *string: + *d = s + case *interface{}: + *d = s + case nil: + // skip value + default: + err = cannotConvert(reflect.ValueOf(d), s) + } + case []interface{}: + switch d := d.(type) { + case *[]interface{}: + *d = s + case *interface{}: + *d = s + case nil: + // skip value + default: + if d := reflect.ValueOf(d); d.Type().Kind() != reflect.Ptr { + err = cannotConvert(d, s) + } else { + err = convertAssignArray(d.Elem(), s) + } + } + case Error: + err = s + default: + err = cannotConvert(reflect.ValueOf(d), s) + } + return +} + +// Scan copies from src to the values pointed at by dest. +// +// Scan uses RedisScan if available otherwise: +// +// The values pointed at by dest must be an integer, float, boolean, string, +// []byte, interface{} or slices of these types. Scan uses the standard strconv +// package to convert bulk strings to numeric and boolean types. +// +// If a dest value is nil, then the corresponding src value is skipped. +// +// If a src element is nil, then the corresponding dest value is not modified. +// +// To enable easy use of Scan in a loop, Scan returns the slice of src +// following the copied values. +func Scan(src []interface{}, dest ...interface{}) ([]interface{}, error) { + if len(src) < len(dest) { + return nil, errors.New("redigo.Scan: array short") + } + var err error + for i, d := range dest { + err = convertAssign(d, src[i]) + if err != nil { + err = fmt.Errorf("redigo.Scan: cannot assign to dest %d: %v", i, err) + break + } + } + return src[len(dest):], err +} diff --git a/application/auth/oauth2/token.go b/application/auth/oauth2/token.go new file mode 100644 index 00000000..a132786b --- /dev/null +++ b/application/auth/oauth2/token.go @@ -0,0 +1,444 @@ +package oauth2 + +import ( + "context" + "crypto/md5" + "crypto/rand" + "crypto/sha512" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/eolinker/eosc/router" + + "github.com/eolinker/eosc/log" + + scope_manager "github.com/eolinker/apinto/scope-manager" + + http_service "github.com/eolinker/eosc/eocontext/http-context" + "golang.org/x/crypto/pbkdf2" + + "github.com/eolinker/apinto/resources" +) + +type TokenResponse struct { + Total int `json:"total"` + Data []*TokenData `json:"data"` +} + +type TokenData struct { + AuthenticatedUserid interface{} `json:"authenticated_userid"` + Credential struct { + Id string `json:"id"` + } `json:"credential"` + AccessToken string `json:"access_token"` + Service interface{} `json:"service"` + CreatedAt int64 `json:"created_at"` + RefreshToken interface{} `json:"refresh_token"` + Scope interface{} `json:"scope"` + Ttl int `json:"ttl"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + ClientID string `json:"client_id"` +} + +func NewTokenHandler() *TokenHandler { + h := &TokenHandler{} + router.SetPath("aaaa", "/oauth_tokens/", h) + return h +} + +type TokenHandler struct { + cache scope_manager.IProxyOutput[resources.ICache] + once sync.Once +} + +func (t *TokenHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + t.once.Do(func() { + t.cache = scope_manager.Auto[resources.ICache]("", "redis") + }) + list := t.cache.List() + if len(list) < 1 { + writer.WriteHeader(http.StatusOK) + writer.Write(newError(http.StatusForbidden, "redis cache not found")) + return + } + cache := list[0] + switch request.Method { + case http.MethodPost: + // 创建token + body, err := io.ReadAll(request.Body) + if err != nil { + writer.WriteHeader(http.StatusOK) + writer.Write(newError(http.StatusForbidden, err.Error())) + return + } + var resp TokenResponse + err = json.Unmarshal(body, &resp) + if err != nil { + writer.WriteHeader(http.StatusOK) + writer.Write(newError(-1, err.Error())) + return + } + for _, token := range resp.Data { + createAt := time.UnixMilli(token.CreatedAt) + if createAt.Add(time.Duration(token.ExpiresIn) * time.Second).Before(time.Now()) { + // 过期 + continue + } + redisKey := fmt.Sprintf("apinto:oauth2_access_tokens:%s:%s", os.Getenv("cluster_id"), token.AccessToken) + // 保存token + cache.HMSetN(context.Background(), redisKey, map[string]interface{}{ + "access_token": token.AccessToken, + "scope": token.Scope, + "expires_in": token.ExpiresIn, + "create_at": token.CreatedAt, + "refresh_token": token.RefreshToken, + "client_id": token.Credential.Id, + }, time.Duration(token.ExpiresIn)*time.Second) + } + byteBody, _ := json.Marshal(map[string]interface{}{ + "code": 0, + }) + writer.WriteHeader(http.StatusOK) + writer.Write(byteBody) + return + case http.MethodGet: + // 获取tokens + tokenKeys, err := cache.Keys(context.Background(), fmt.Sprintf("apinto:oauth2_access_tokens:%s:*", os.Getenv("cluster_id"))).Result() + if err != nil { + writer.WriteHeader(http.StatusOK) + writer.Write(newError(-1, err.Error())) + return + } + var tokens []*TokenData + for _, key := range tokenKeys { + token, err := getTokenByRedis(cache, key) + if err != nil { + log.Errorf("get token error: %s", err.Error()) + continue + } + tokens = append(tokens, token) + } + data, err := json.Marshal(TokenResponse{ + Total: len(tokens), + Data: tokens, + }) + if err != nil { + writer.WriteHeader(http.StatusOK) + writer.Write(newError(-1, err.Error())) + return + } + writer.WriteHeader(http.StatusOK) + writer.Write(data) + return + } +} + +func getTokenByRedis(cache resources.ICache, redisKey string) (*TokenData, error) { + var accessToken, scope, refreshToken, clientId, createdAt, expiresIn string + result, err := cache.HMGet(context.Background(), redisKey, "access_token", "scope", "expires_in", "create_at", "refresh_token", "client_id").Result() + if err != nil { + return nil, err + } + _, err = Scan(result, &accessToken, &scope, &expiresIn, &createdAt, &refreshToken, &clientId) + if err != nil { + return nil, err + } + expiresInInt, err := strconv.Atoi(expiresIn) + if err != nil { + return nil, err + } + createdAtInt, err := strconv.ParseInt(createdAt, 10, 64) + if err != nil { + return nil, err + } + return &TokenData{ + AccessToken: accessToken, + RefreshToken: refreshToken, + Scope: scope, + ExpiresIn: expiresInInt, + CreatedAt: createdAtInt, + ClientID: clientId, + }, nil +} + +func newError(code int, msg string) []byte { + body, _ := json.Marshal(map[string]interface{}{ + "code": code, + "err": msg, + }) + return body +} + +func (t *TokenHandler) Handle(ctx http_service.IHttpContext, client *Client, params url.Values) { + + grantType := params.Get("grant_type") + clientSecret := params.Get("client_secret") + state := params.Get("state") + if grantType == "" || !((grantType == GrantAuthorizationCode && client.EnableAuthorizationCode) || (grantType == GrantClientCredentials && client.EnableClientCredentials) || grantType == GrantRefreshToken) { + ctx.Response().SetBody([]byte(fmt.Sprintf("unsupported grant type: %s,client id is %s", grantType, client.ClientId))) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + + if client.HashSecret { + // 密钥经过加密 + salt, _ := base64.RawStdEncoding.DecodeString(client.hashRule.salt) + secret := pbkdf2.Key([]byte(clientSecret), salt, client.hashRule.iterations, client.hashRule.length, sha512.New) + clientSecret = base64.RawStdEncoding.EncodeToString(secret) + } + + if clientSecret != client.hashRule.value { + ctx.Response().SetBody([]byte(fmt.Sprintf("fail to match secret,now: %s,hope: %s,client id is %s", clientSecret, client.hashRule.value, client.ClientId))) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + type Response struct { + *Token + State string `json:"state,omitempty"` + } + t.once.Do(func() { + t.cache = scope_manager.Auto[resources.ICache]("", "redis") + }) + list := t.cache.List() + if len(list) < 1 { + ctx.Response().SetBody([]byte("redis cache is not found")) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + + cache := list[0] + switch grantType { + case GrantRefreshToken: + refreshToken := params.Get("refresh_token") + if refreshToken == "" { + ctx.Response().SetBody([]byte("refresh token is required, client id is " + client.ClientId)) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + + redisKey := fmt.Sprintf("apinto:oauth2_refresh_tokens:%s:%s", os.Getenv("cluster_id"), refreshToken) + + result, err := cache.HMGet(ctx.Context(), redisKey, "refresh_token", "access_token").Result() + if err != nil { + ctx.Response().SetBody([]byte("fail to get refresh token, client id is " + client.ClientId)) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + var refreshTokenStr, accessTokenStr string + _, err = Scan(result, &refreshTokenStr, &accessTokenStr) + if err != nil { + ctx.Response().SetBody([]byte("invalid refresh token, client id is " + client.ClientId)) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + if refreshTokenStr != refreshToken { + ctx.Response().SetBody([]byte("invalid refresh token, client id is " + client.ClientId)) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + token, err := generateToken(ctx.Context(), cache, client.ClientId, client.TokenExpiration, client.RefreshTokenTTL, "", !client.ReuseRefreshToken) + if !client.PersistentRefreshToken { + // 不持久化refresh token + accessTokenRedisKey := fmt.Sprintf("apinto:oauth2_access_tokens:%s:%s", os.Getenv("cluster_id"), accessTokenStr) + cache.Del(ctx.Context(), accessTokenRedisKey) + } + if client.ReuseRefreshToken { + // 重用refresh token + token.AccessToken = accessTokenStr + cache.HMSetN(ctx.Context(), redisKey, map[string]interface{}{ + "access_token": token.AccessToken, + }, 0) + } else { + cache.Del(ctx.Context(), redisKey) + } + response := &Response{ + Token: token, + State: state, + } + data, _ := json.Marshal(response) + ctx.Response().SetBody(data) + ctx.Response().SetStatus(http.StatusOK, "ok") + return + case GrantAuthorizationCode: + code := params.Get("code") + if code == "" { + ctx.Response().SetBody([]byte("code is required, client id is " + client.ClientId)) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + redisKey := fmt.Sprintf("apinto:oauth2_codes:%s:%s", os.Getenv("cluster_id"), code) + result, err := cache.HMGet(ctx.Context(), redisKey, "code", "scope").Result() + if err != nil { + ctx.Response().SetBody([]byte("fail to get code, client id is " + client.ClientId)) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + // 删除旧授权码 + cache.Del(ctx.Context(), redisKey) + var codeStr, scope string + _, err = Scan(result, &codeStr, &scope) + if err != nil { + ctx.Response().SetBody([]byte("invalid code")) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + if codeStr != code { + ctx.Response().SetBody([]byte("invalid code")) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + + token, err := generateToken(ctx.Context(), cache, client.ClientId, client.TokenExpiration, client.RefreshTokenTTL, scope, true) + if err != nil { + ctx.Response().SetBody([]byte(fmt.Sprintf("(%s)generate token error: %s", client.ClientId, err.Error()))) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + response := &Response{ + Token: token, + State: state, + } + data, _ := json.Marshal(response) + ctx.Response().SetBody(data) + ctx.Response().SetStatus(http.StatusOK, "ok") + return + case GrantClientCredentials: + // 生成token + token, err := generateToken(ctx.Context(), cache, client.ClientId, client.TokenExpiration, client.RefreshTokenTTL, "", false) + if err != nil { + ctx.Response().SetBody([]byte(fmt.Sprintf("(%s)generate token error: %s", client.ClientId, err.Error()))) + ctx.Response().SetStatus(http.StatusForbidden, "forbidden") + return + } + response := &Response{ + Token: token, + State: state, + } + data, _ := json.Marshal(response) + ctx.Response().SetBody(data) + ctx.Response().SetStatus(http.StatusOK, "ok") + return + } +} + +func generateRandomString() string { + b := make([]byte, 40) + if _, err := io.ReadFull(rand.Reader, b); err != nil { + return "" + } + baseRes := base64.StdEncoding.EncodeToString(b) + h := md5.New() + h.Write([]byte(baseRes)) + res := hex.EncodeToString(h.Sum(nil)) + return res +} + +func retrieveParameters(ctx http_service.IHttpContext) url.Values { + params := url.Values{} + queries, _ := url.ParseQuery(ctx.Request().URI().RawQuery()) + for k, v := range queries { + params.Set(k, v[0]) + } + if strings.Contains(ctx.Request().ContentType(), "application/x-www-form-urlencoded") { + body, _ := ctx.Request().Body().BodyForm() + for k, v := range body { + params.Set(k, v[0]) + } + } else if strings.Contains(ctx.Request().ContentType(), "application/json") { + var body map[string]string + rawBody, _ := ctx.Request().Body().RawBody() + json.Unmarshal(rawBody, &body) + for k, v := range body { + params.Set(k, v) + } + } + return params +} + +func generateToken(ctx context.Context, cache resources.ICache, clientID string, tokenExpired int, refreshTokenTTL int, scope string, isRefresh bool) (*Token, error) { + // 简化模式 + accessToken := generateRandomString() + if tokenExpired <= 0 { + tokenExpired = 7200 + } + if refreshTokenTTL <= 0 { + refreshTokenTTL = 1209600 + } + refreshToken := "" + if isRefresh { + refreshToken = generateRandomString() + } + + redisKey := fmt.Sprintf("apinto:oauth2_access_tokens:%s:%s", os.Getenv("cluster_id"), accessToken) + now := time.Now() + fields := map[string]interface{}{ + "client_id": clientID, + "expires_in": tokenExpired, + "access_token": accessToken, + "refresh_token": refreshToken, + "create_at": now.UnixMilli(), + "scope": scope, + } + _, err := cache.HMSetN(ctx, redisKey, fields, time.Duration(tokenExpired)*time.Second).Result() + if err != nil { + return nil, fmt.Errorf("(%s)redis HMSet %s error: %s", clientID, redisKey, err.Error()) + } + if isRefresh { + redisKey = fmt.Sprintf("apinto:oauth2_refresh_tokens:%s:%s", os.Getenv("cluster_id"), refreshToken) + + _, err = cache.HMSetN(ctx, redisKey, fields, time.Duration(refreshTokenTTL)*time.Second).Result() + if err != nil { + return nil, fmt.Errorf("(%s)redis HMSet %s error: %s", clientID, redisKey, err.Error()) + } + } + return &Token{ + TokenType: "bearer", + ExpiresIn: tokenExpired, + AccessToken: accessToken, + RefreshToken: refreshToken, + Scope: scope, + }, nil +} + +func validToken(ctx context.Context, cache resources.ICache, token string) (string, error) { + redisKey := fmt.Sprintf("apinto:oauth2_access_tokens:%s:%s", os.Getenv("cluster_id"), token) + result, err := cache.HMGet(ctx, redisKey, "client_id", "access_token", "create_at", "expires_in").Result() + if err != nil { + return "", fmt.Errorf("redis HMGet %s error: %s", redisKey, err.Error()) + } + var clientID, accessToken, createAt, expiresInStr string + _, err = Scan(result, &clientID, &accessToken, &createAt, &expiresInStr) + if err != nil { + return "", fmt.Errorf("scan redis result error: %s", err.Error()) + } + createAtTime, _ := strconv.ParseInt(createAt, 10, 64) + expiresIn, _ := strconv.ParseInt(expiresInStr, 10, 64) + createTime := time.UnixMilli(createAtTime) + if time.Now().After(createTime.Add(time.Duration(expiresIn) * time.Second)) { + // token过期 + return "", fmt.Errorf("token expired") + } + if accessToken != token { + return "", fmt.Errorf("invalid token") + } + return clientID, nil +} + +type Token struct { + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + AccessToken string `json:"access_token,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` +} diff --git a/drivers/app/factory.go b/drivers/app/factory.go index 74289526..22497e92 100644 --- a/drivers/app/factory.go +++ b/drivers/app/factory.go @@ -1,15 +1,17 @@ package app import ( + "sync" + "github.com/eolinker/apinto/application/auth" "github.com/eolinker/apinto/application/auth/aksk" "github.com/eolinker/apinto/application/auth/apikey" "github.com/eolinker/apinto/application/auth/basic" "github.com/eolinker/apinto/application/auth/jwt" + "github.com/eolinker/apinto/application/auth/oauth2" "github.com/eolinker/apinto/drivers" "github.com/eolinker/apinto/drivers/app/manager" "github.com/eolinker/eosc/common/bean" - "sync" "github.com/eolinker/eosc" ) @@ -21,21 +23,21 @@ var ( ones sync.Once ) -//Register 注册service_http驱动工厂 +// Register 注册service_http驱动工厂 func Register(register eosc.IExtenderDriverRegister) { register.RegisterExtenderDriver(name, NewFactory()) } -//NewFactory 创建service_http驱动工厂 +// NewFactory 创建service_http驱动工厂 func NewFactory() eosc.IExtenderDriverFactory { ones.Do(func() { apikey.Register() basic.Register() aksk.Register() jwt.Register() + oauth2.Register() appManager = manager.NewManager(auth.Alias(), auth.Keys()) bean.Injection(&appManager) }) return drivers.NewFactory[Config](Create) } - diff --git a/drivers/app/manager/app.go b/drivers/app/manager/app.go index 8746e8bb..175a283f 100644 --- a/drivers/app/manager/app.go +++ b/drivers/app/manager/app.go @@ -1,6 +1,12 @@ package manager -import "sync" +import ( + "sync" + + "github.com/eolinker/apinto/drivers/router/http-router/manager" + + "github.com/eolinker/apinto/application/auth" +) var _ IAppManager = (*AppManager)(nil) @@ -52,9 +58,14 @@ func (a *AppManager) Set(appID string, driver string, ids []string) { app.Set(appID, ids) return } + app = NewAppData() app.Set(appID, ids) a.apps[driver] = app + fac, _ := auth.GetFactory(driver) + for _, r := range fac.PreRouters() { + manager.AddPreRouter(r.ID, r.Method, r.Path, r.PreHandler) + } } func (a *AppManager) DelByDriver(driver string) { @@ -66,8 +77,15 @@ func (a *AppManager) DelByDriver(driver string) { func (a *AppManager) DelByAppID(appID string) { a.locker.RLock() defer a.locker.RUnlock() - for _, app := range a.apps { + for driver, app := range a.apps { app.Del(appID) + ids := app.All() + if len(ids) == 0 { + fac, _ := auth.GetFactory(driver) + for _, r := range fac.PreRouters() { + manager.DeletePreRouter(r.ID) + } + } } } diff --git a/drivers/resources/redis/cmdable.go b/drivers/resources/redis/cmdable.go index 5fbabca2..ad2283b9 100644 --- a/drivers/resources/redis/cmdable.go +++ b/drivers/resources/redis/cmdable.go @@ -9,12 +9,14 @@ import ( ) var ( - ErrorNotInitRedis = errors.New("redis not init") - intError = resources.NewIntResult(0, ErrorNotInitRedis) - boolError = resources.NewBoolResult(false, ErrorNotInitRedis) - stringError = resources.NewStringResult("", ErrorNotInitRedis) - statusError = resources.NewStatusResult(ErrorNotInitRedis) - interfaceError = resources.NewInterfaceResult(nil, ErrorNotInitRedis) + ErrorNotInitRedis = errors.New("redis not init") + intError = resources.NewIntResult(0, ErrorNotInitRedis) + boolError = resources.NewBoolResult(false, ErrorNotInitRedis) + stringError = resources.NewStringResult("", ErrorNotInitRedis) + statusError = resources.NewStatusResult(ErrorNotInitRedis) + interfaceError = resources.NewInterfaceResult(nil, ErrorNotInitRedis) + stringSliceError = resources.NewStringSliceResult(nil, ErrorNotInitRedis) + arrayInterfaceError = resources.NewArrayInterfaceResult(nil, ErrorNotInitRedis) ) type Empty struct { @@ -45,6 +47,10 @@ func (e *Empty) IncrBy(ctx context.Context, key string, decrement int64, expirat return intError } +func (e *Empty) Keys(ctx context.Context, key string) resources.StringSliceResult { + return stringSliceError +} + func (e *Empty) Get(ctx context.Context, key string) resources.StringResult { return stringError } @@ -57,6 +63,14 @@ func (e *Empty) Del(ctx context.Context, keys ...string) resources.IntResult { return intError } +func (e *Empty) HMSetN(ctx context.Context, key string, fields map[string]interface{}, expiration time.Duration) resources.BoolResult { + return boolError +} + +func (e *Empty) HMGet(ctx context.Context, key string, fields ...string) resources.ArrayInterfaceResult { + return arrayInterfaceError +} + func (e *Empty) Run(ctx context.Context, script interface{}, keys []string, args ...interface{}) resources.InterfaceResult { return interfaceError } diff --git a/drivers/resources/redis/redis.go b/drivers/resources/redis/redis.go index 4ce7cfe0..d3d95c63 100644 --- a/drivers/resources/redis/redis.go +++ b/drivers/resources/redis/redis.go @@ -5,6 +5,8 @@ import ( "fmt" "time" + "github.com/eolinker/eosc/log" + "github.com/eolinker/apinto/resources" "github.com/go-redis/redis/v8" ) @@ -89,6 +91,10 @@ func (r *Cmdable) IncrBy(ctx context.Context, key string, decrement int64, expir return result } +func (r *Cmdable) Keys(ctx context.Context, key string) resources.StringSliceResult { + return r.cmdable.Keys(ctx, key) +} + func (r *Cmdable) Get(ctx context.Context, key string) resources.StringResult { return r.cmdable.Get(ctx, key) @@ -99,6 +105,24 @@ func (r *Cmdable) GetDel(ctx context.Context, key string) resources.StringResult } +func (r *Cmdable) HMSetN(ctx context.Context, key string, fields map[string]interface{}, expiration time.Duration) resources.BoolResult { + pipeline := r.cmdable.Pipeline() + result := pipeline.HMSet(ctx, key, fields) + if expiration > 0 { + pipeline.Expire(ctx, key, expiration) + } + _, err := pipeline.Exec(ctx) + if err != nil { + log.Errorf("HMSetN error:%s", err.Error()) + return nil + } + return result +} + +func (r *Cmdable) HMGet(ctx context.Context, key string, fields ...string) resources.ArrayInterfaceResult { + return r.cmdable.HMGet(ctx, key, fields...) +} + func (r *Cmdable) Del(ctx context.Context, keys ...string) resources.IntResult { return r.cmdable.Del(ctx, keys...) } diff --git a/drivers/router/dubbo2-router/handler.go b/drivers/router/dubbo2-router/handler.go index 97a8a77d..59e20e9f 100644 --- a/drivers/router/dubbo2-router/handler.go +++ b/drivers/router/dubbo2-router/handler.go @@ -31,7 +31,7 @@ type dubboHandler struct { var completeCaller = manager.NewCompleteCaller() -func (d *dubboHandler) ServeHTTP(ctx eocontext.EoContext) { +func (d *dubboHandler) Serve(ctx eocontext.EoContext) { dubboCtx, err := dubbo2_context.Assert(ctx) if err != nil { diff --git a/drivers/router/dubbo2-router/manager/manager.go b/drivers/router/dubbo2-router/manager/manager.go index 4a37271d..33d84faa 100644 --- a/drivers/router/dubbo2-router/manager/manager.go +++ b/drivers/router/dubbo2-router/manager/manager.go @@ -93,7 +93,7 @@ func (d *dubboManger) Handler(port int, req *invocation.RPCInvocation) protocol. } else { log.Debug("match has:", port) - match.ServeHTTP(ctx) + match.Serve(ctx) } finish := ctx.GetFinish() diff --git a/drivers/router/grpc-router/handler.go b/drivers/router/grpc-router/handler.go index 6b2f3394..390777e4 100644 --- a/drivers/router/grpc-router/handler.go +++ b/drivers/router/grpc-router/handler.go @@ -31,7 +31,7 @@ type grpcRouter struct { timeout time.Duration } -func (h *grpcRouter) ServeHTTP(ctx eocontext.EoContext) { +func (h *grpcRouter) Serve(ctx eocontext.EoContext) { grpcContext, err := grpc_context.Assert(ctx) if err != nil { return diff --git a/drivers/router/grpc-router/manager/manager.go b/drivers/router/grpc-router/manager/manager.go index 8f524d6c..b1214356 100644 --- a/drivers/router/grpc-router/manager/manager.go +++ b/drivers/router/grpc-router/manager/manager.go @@ -85,7 +85,7 @@ func (m *Manager) FastHandler(port int, srv interface{}, stream grpc.ServerStrea } } else { log.Debug("match has:", port) - r.ServeHTTP(ctx) + r.Serve(ctx) } finishHandler := ctx.GetFinish() diff --git a/drivers/router/http-router/http-handler.go b/drivers/router/http-router/http-handler.go index 9f50eeec..5243f7e7 100644 --- a/drivers/router/http-router/http-handler.go +++ b/drivers/router/http-router/http-handler.go @@ -34,7 +34,7 @@ type httpHandler struct { timeout time.Duration } -func (h *httpHandler) ServeHTTP(ctx eocontext.EoContext) { +func (h *httpHandler) Serve(ctx eocontext.EoContext) { httpContext, err := http_context.Assert(ctx) if err != nil { return diff --git a/drivers/router/http-router/manager/export.go b/drivers/router/http-router/manager/export.go new file mode 100644 index 00000000..0819db61 --- /dev/null +++ b/drivers/router/http-router/manager/export.go @@ -0,0 +1,16 @@ +package manager + +import "github.com/eolinker/apinto/router" + +func Set(id string, port int, hosts []string, method []string, path string, append []AppendRule, router router.IRouterHandler) error { + return routerManager.Set(id, port, hosts, method, path, append, router) +} +func Delete(id string) { + routerManager.Delete(id) +} +func AddPreRouter(id string, method []string, path string, handler router.IRouterPreHandler) { + routerManager.AddPreRouter(id, method, path, handler) +} +func DeletePreRouter(id string) { + routerManager.DeletePreRouter(id) +} diff --git a/drivers/router/http-router/manager/init.go b/drivers/router/http-router/manager/init.go index 6f9875a3..e543a827 100644 --- a/drivers/router/http-router/manager/init.go +++ b/drivers/router/http-router/manager/init.go @@ -12,12 +12,12 @@ import ( ) var ( - chainProxy eocontext.IChainPro + chainProxy eocontext.IChainPro + routerManager = NewManager() ) func init() { - var routerManager = NewManager() serverHandler := func(port int, ln net.Listener) { server := fasthttp.Server{ StreamRequestBody: true, diff --git a/drivers/router/http-router/manager/manager.go b/drivers/router/http-router/manager/manager.go index c59a972e..92d1b237 100644 --- a/drivers/router/http-router/manager/manager.go +++ b/drivers/router/http-router/manager/manager.go @@ -20,8 +20,12 @@ var completeCaller = http_complete.NewHttpCompleteCaller() type IManger interface { Set(id string, port int, hosts []string, method []string, path string, append []AppendRule, router router.IRouterHandler) error Delete(id string) + AddPreRouter(id string, method []string, path string, handler router.IRouterPreHandler) + DeletePreRouter(id string) } + type Manager struct { + IPreRouterData lock sync.RWMutex matcher router.IMatcher @@ -35,7 +39,8 @@ func (m *Manager) SetGlobalFilters(globalFilters *eoscContext.IChainPro) { // NewManager 创建路由管理器 func NewManager() *Manager { - return &Manager{routersData: new(RouterData)} + return &Manager{routersData: new(RouterData), + IPreRouterData: newImlPreRouterData()} } func (m *Manager) Set(id string, port int, hosts []string, method []string, path string, append []AppendRule, router router.IRouterHandler) error { @@ -68,6 +73,9 @@ func (m *Manager) Delete(id string) { func (m *Manager) FastHandler(port int, ctx *fasthttp.RequestCtx) { httpContext := http_context.NewContext(ctx, port) + if !m.IPreRouterData.Server(httpContext) { + return + } if m.matcher == nil { httpContext.SetFinish(notFound) httpContext.SetCompleteHandler(notFound) @@ -88,7 +96,7 @@ func (m *Manager) FastHandler(port int, ctx *fasthttp.RequestCtx) { } } else { log.Debug("match has:", port) - r.ServeHTTP(httpContext) + r.Serve(httpContext) } finishHandler := httpContext.GetFinish() if finishHandler != nil { diff --git a/drivers/router/http-router/manager/pre.go b/drivers/router/http-router/manager/pre.go new file mode 100644 index 00000000..87f1fdcb --- /dev/null +++ b/drivers/router/http-router/manager/pre.go @@ -0,0 +1,123 @@ +package manager + +import ( + "sync" + + "github.com/eolinker/apinto/router" + "github.com/eolinker/eosc/eocontext" + http_context "github.com/eolinker/eosc/eocontext/http-context" + "github.com/eolinker/eosc/log" +) + +type IPreRouterData interface { + router.IRouterPreHandler + AddPreRouter(id string, method []string, path string, handler router.IRouterPreHandler) + DeletePreRouter(id string) +} + +type preRouterItem struct { + id string + method []string + path string + handler router.IRouterPreHandler +} + +var ( + _ router.IRouterPreHandler = (*iPreRouterHandler)(nil) +) + +type iPreRouterHandler struct { + routers map[string]map[string][]router.IRouterPreHandler +} + +func (i *iPreRouterHandler) Server(ctx eocontext.EoContext) (isContinue bool) { + if i == nil { + return true + } + httpCtx, err := http_context.Assert(ctx) + if err != nil { + return true + } + method := httpCtx.Request().Method() + path := httpCtx.Request().URI().Path() + + ms, has := i.routers[path] + + if !has { + return true + } + + handlers, has := ms[method] + if !has { + handlers, has = ms["*"] + if !has { + return true + } + } + for _, handler := range handlers { + if !handler.Server(ctx) { + return false + } + } + return true + +} + +type imlPreRouterData struct { + lock sync.RWMutex + + handler router.IRouterPreHandler + items map[string]*preRouterItem +} + +func (p *imlPreRouterData) Server(ctx eocontext.EoContext) (isContinue bool) { + if p == nil || p.handler == nil { + return true + } + log.Debug("pre router hander:", p.handler) + return p.handler.Server(ctx) +} + +func newImlPreRouterData() IPreRouterData { + return &imlPreRouterData{ + items: make(map[string]*preRouterItem), + } +} + +func (p *imlPreRouterData) AddPreRouter(id string, method []string, path string, handler router.IRouterPreHandler) { + p.lock.Lock() + defer p.lock.Unlock() + p.items[id] = &preRouterItem{ + id: id, + method: method, + path: path, + handler: handler, + } + log.Debug("add pre router:", p.items) + p.handler = p.parse() +} + +func (p *imlPreRouterData) DeletePreRouter(id string) { + p.lock.Lock() + defer p.lock.Unlock() + delete(p.items, id) + p.handler = p.parse() +} +func (p *imlPreRouterData) parse() router.IRouterPreHandler { + if len(p.items) == 0 { + return nil + } + routers := make(map[string]map[string][]router.IRouterPreHandler) + for _, v := range p.items { + if _, has := routers[v.path]; !has { + routers[v.path] = make(map[string][]router.IRouterPreHandler) + } + if len(v.method) == 0 { + v.method = []string{"*"} + } + for _, method := range v.method { + routers[v.path][method] = append(routers[v.path][method], v.handler) + } + } + return &iPreRouterHandler{routers: routers} +} diff --git a/go.mod b/go.mod index 85afeed7..16851248 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/clbanning/mxj v1.8.4 github.com/coocood/freecache v1.2.2 github.com/dubbogo/gost v1.13.1 - github.com/eolinker/eosc v0.15.2 + github.com/eolinker/eosc v0.16.1 github.com/fasthttp/websocket v1.5.0 github.com/fullstorydev/grpcurl v1.8.7 github.com/go-redis/redis/v8 v8.11.5 @@ -167,4 +167,4 @@ require ( gopkg.in/yaml.v2 v2.4.0 // indirect ) -replace github.com/eolinker/eosc => ../eosc +//replace github.com/eolinker/eosc => ../eosc diff --git a/resources/cache-local.go b/resources/cache-local.go index d5f87f51..dd0cb5da 100644 --- a/resources/cache-local.go +++ b/resources/cache-local.go @@ -131,9 +131,13 @@ func ToInt(b []byte) int64 { return v } func ToBytes(v int64) []byte { - return []byte(strconv.FormatInt(v, 10)) } + +func (n *cacheLocal) Keys(ctx context.Context, pattern string) StringSliceResult { + return NewStringSliceResult(nil, errors.New("not support")) +} + func (n *cacheLocal) Get(ctx context.Context, key string) StringResult { data, err := n.client.Get([]byte(key)) if err != nil { @@ -152,6 +156,14 @@ func (n *cacheLocal) GetDel(ctx context.Context, key string) StringResult { return NewStringResultBytes(bytes, nil) } +func (n *cacheLocal) HMSetN(ctx context.Context, key string, fields map[string]interface{}, expiration time.Duration) BoolResult { + return NewBoolResult(false, errors.New("not support")) +} + +func (n *cacheLocal) HMGet(ctx context.Context, key string, fields ...string) ArrayInterfaceResult { + return NewArrayInterfaceResult(nil, errors.New("not support")) +} + func (n *cacheLocal) Del(ctx context.Context, keys ...string) IntResult { var count int64 = 0 for _, key := range keys { diff --git a/resources/cache.go b/resources/cache.go index 44e79ab2..f4b8f403 100644 --- a/resources/cache.go +++ b/resources/cache.go @@ -13,8 +13,11 @@ type ICache interface { SetNX(ctx context.Context, key string, value []byte, expiration time.Duration) BoolResult DecrBy(ctx context.Context, key string, decrement int64, expiration time.Duration) IntResult IncrBy(ctx context.Context, key string, decrement int64, expiration time.Duration) IntResult + Keys(ctx context.Context, key string) StringSliceResult Get(ctx context.Context, key string) StringResult GetDel(ctx context.Context, key string) StringResult + HMSetN(ctx context.Context, key string, fields map[string]interface{}, expiration time.Duration) BoolResult + HMGet(ctx context.Context, key string, fields ...string) ArrayInterfaceResult Del(ctx context.Context, keys ...string) IntResult Run(ctx context.Context, script interface{}, keys []string, args ...interface{}) InterfaceResult Tx() TX @@ -25,6 +28,10 @@ type TX interface { Exec(ctx context.Context) error } +type ArrayInterfaceResult interface { + Result() ([]interface{}, error) +} + type InterfaceResult interface { Result() (interface{}, error) } @@ -39,6 +46,10 @@ type StringResult interface { Result() (string, error) Bytes() ([]byte, error) } + +type StringSliceResult interface { + Result() ([]string, error) +} type StatusResult interface { Result() error } @@ -105,6 +116,19 @@ func (b *intResult) Result() (int64, error) { return b.val, b.err } +type stringSliceResult struct { + val []string + err error +} + +func NewStringSliceResult(val []string, err error) *stringSliceResult { + return &stringSliceResult{val: val, err: err} +} + +func (b *stringSliceResult) Result() ([]string, error) { + return b.val, b.err +} + type interfaceResult struct { val interface{} err error @@ -117,3 +141,16 @@ func NewInterfaceResult(val interface{}, err error) *interfaceResult { func (b *interfaceResult) Result() (interface{}, error) { return b.val, b.err } + +type arrayInterfaceResult struct { + val []interface{} + err error +} + +func NewArrayInterfaceResult(val []interface{}, err error) *arrayInterfaceResult { + return &arrayInterfaceResult{val: val, err: err} +} + +func (b *arrayInterfaceResult) Result() ([]interface{}, error) { + return b.val, b.err +} diff --git a/router/match.go b/router/match.go index ad3a0570..5b82a0c6 100644 --- a/router/match.go +++ b/router/match.go @@ -9,5 +9,9 @@ type IMatcher interface { } type IRouterHandler interface { - ServeHTTP(ctx eoscContext.EoContext) + Serve(ctx eoscContext.EoContext) +} + +type IRouterPreHandler interface { + Server(ctx eoscContext.EoContext) (isContinue bool) }