Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/pre router #149

Merged
merged 7 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions application/auth/aksk/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ var _ auth.IAuthFactory = (*factory)(nil)

var driverName = "aksk"

//Register 注册auth驱动工厂
// Register 注册auth驱动工厂
func Register() {
auth.FactoryRegister(driverName, NewFactory())
}
Expand Down Expand Up @@ -44,6 +44,10 @@ func (f *factory) Alias() []string {
}
}

func (f *factory) PreRouters() []*auth.PreRouter {
return nil
}

func (f *factory) Create(tokenName string, position string, rule interface{}) (application.IAuth, error) {
a := &aksk{
id: toId(tokenName, position),
Expand All @@ -54,7 +58,7 @@ func (f *factory) Create(tokenName string, position string, rule interface{}) (a
return a, nil
}

//NewFactory 生成一个 auth_apiKey工厂
// NewFactory 生成一个 auth_apiKey工厂
func NewFactory() auth.IAuthFactory {
typ := reflect.TypeOf((*Config)(nil))
render, _ := schema.Generate(typ, nil)
Expand Down
4 changes: 4 additions & 0 deletions application/auth/apikey/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ func (f *factory) Alias() []string {
}
}

func (f *factory) PreRouters() []*auth.PreRouter {
return nil
}

func (f *factory) Create(tokenName string, position string, rule interface{}) (application.IAuth, error) {
a := &apikey{
id: toId(tokenName, position),
Expand Down
4 changes: 4 additions & 0 deletions application/auth/basic/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ func (f *factory) Alias() []string {
}
}

func (f *factory) PreRouters() []*auth.PreRouter {
return nil
}

func (f *factory) Create(tokenName string, position string, rule interface{}) (application.IAuth, error) {
a := &basic{
id: toId(tokenName, position),
Expand Down
32 changes: 21 additions & 11 deletions application/auth/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"reflect"
"strings"

"github.com/eolinker/apinto/router"

"github.com/eolinker/apinto/application"
"github.com/eolinker/eosc/log"

Expand All @@ -18,24 +20,32 @@ var (
_ eosc.ISetting = defaultAuthFactoryRegister
)

//IAuthFactory 鉴权工厂方法
type PreRouter struct {
ID string
PreHandler router.IRouterPreHandler
Path string
Method []string
}

// IAuthFactory 鉴权工厂方法
type IAuthFactory interface {
Create(tokenName string, position string, rule interface{}) (application.IAuth, error)
Alias() []string
Render() interface{}
ConfigType() reflect.Type
UserType() reflect.Type
PreRouters() []*PreRouter
}

//IAuthFactoryRegister 实现了鉴权工厂管理器
// IAuthFactoryRegister 实现了鉴权工厂管理器
type IAuthFactoryRegister interface {
RegisterFactoryByKey(key string, factory IAuthFactory)
GetFactoryByKey(key string) (IAuthFactory, bool)
Keys() []string
Alias() map[string]string
}

//driverRegister 驱动注册器
// driverRegister 驱动注册器
type driverRegister struct {
register eosc.IRegister[IAuthFactory]
keys []string
Expand Down Expand Up @@ -80,7 +90,7 @@ func (dm *driverRegister) ReadOnly() bool {
return true
}

//newAuthFactoryManager 创建auth工厂管理器
// newAuthFactoryManager 创建auth工厂管理器
func newAuthFactoryManager() *driverRegister {
return &driverRegister{
register: eosc.NewRegister[IAuthFactory](),
Expand All @@ -90,12 +100,12 @@ func newAuthFactoryManager() *driverRegister {
}
}

//GetFactoryByKey 获取指定auth工厂
// GetFactoryByKey 获取指定auth工厂
func (dm *driverRegister) GetFactoryByKey(key string) (IAuthFactory, bool) {
return dm.register.Get(key)
}

//RegisterFactoryByKey 注册auth工厂
// RegisterFactoryByKey 注册auth工厂
func (dm *driverRegister) RegisterFactoryByKey(key string, factory IAuthFactory) {
err := dm.register.Register(key, factory, true)
if err != nil {
Expand All @@ -109,7 +119,7 @@ func (dm *driverRegister) RegisterFactoryByKey(key string, factory IAuthFactory)
}
}

//Keys 返回所有已注册的key
// Keys 返回所有已注册的key
func (dm *driverRegister) Keys() []string {
return dm.keys
}
Expand All @@ -118,18 +128,18 @@ func (dm *driverRegister) Alias() map[string]string {
return dm.driverAlias
}

//FactoryRegister 注册auth工厂到默认auth工厂注册器
// FactoryRegister 注册auth工厂到默认auth工厂注册器
func FactoryRegister(key string, factory IAuthFactory) {

defaultAuthFactoryRegister.RegisterFactoryByKey(key, factory)
}

//Get 从默认auth工厂注册器中获取auth工厂
// Get 从默认auth工厂注册器中获取auth工厂
func Get(key string) (IAuthFactory, bool) {
return defaultAuthFactoryRegister.GetFactoryByKey(key)
}

//Keys 返回默认的auth工厂注册器中所有已注册的key
// Keys 返回默认的auth工厂注册器中所有已注册的key
func Keys() []string {
return defaultAuthFactoryRegister.Keys()
}
Expand All @@ -138,7 +148,7 @@ func Alias() map[string]string {
return defaultAuthFactoryRegister.Alias()
}

//GetFactory 获取指定auth工厂,若指定的不存在则返回一个已注册的工厂
// GetFactory 获取指定auth工厂,若指定的不存在则返回一个已注册的工厂
func GetFactory(name string) (IAuthFactory, error) {
factory, ok := Get(name)
if !ok {
Expand Down
8 changes: 6 additions & 2 deletions application/auth/jwt/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ var _ auth.IAuthFactory = (*factory)(nil)

var driverName = "jwt"

//Register 注册auth驱动工厂
// Register 注册auth驱动工厂
func Register() {
auth.FactoryRegister(driverName, NewFactory())
}
Expand Down Expand Up @@ -43,6 +43,10 @@ func (f *factory) Alias() []string {
}
}

func (f *factory) PreRouters() []*auth.PreRouter {
return nil
}

func (f *factory) Create(tokenName string, position string, rule interface{}) (application.IAuth, error) {
baseConfig, ok := rule.(*application.BaseConfig)
if !ok {
Expand All @@ -66,7 +70,7 @@ func (f *factory) Create(tokenName string, position string, rule interface{}) (a
return a, nil
}

//NewFactory 生成一个 auth_apiKey工厂
// NewFactory 生成一个 auth_apiKey工厂
func NewFactory() auth.IAuthFactory {
typ := reflect.TypeOf((*Config)(nil))
render, _ := schema.Generate(typ, nil)
Expand Down
144 changes: 144 additions & 0 deletions application/auth/oauth2/authorize.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package oauth2

import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"strconv"
"sync"
"time"

scope_manager "github.com/eolinker/apinto/scope-manager"

"github.com/eolinker/apinto/resources"
http_context "github.com/eolinker/eosc/eocontext/http-context"
)

const (
ResponseTypeCode = "code"
ResponseTypeToken = "token"
)

func NewAuthorizeHandler() *AuthorizeHandler {
return &AuthorizeHandler{}
}

type AuthorizeHandler struct {
cache scope_manager.IProxyOutput[resources.ICache]
once sync.Once
}

func (a *AuthorizeHandler) Handle(ctx http_context.IHttpContext, client *Client, params url.Values) {
responseType := params.Get("response_type")
if responseType == "" || !((responseType == ResponseTypeCode && client.EnableAuthorizationCode) || (responseType == ResponseTypeToken && client.EnableImplicitGrant)) {
ctx.Response().SetBody([]byte(fmt.Sprintf("unsupported response type: %s,client id is %s", responseType, client.ClientId)))
ctx.Response().SetStatus(http.StatusForbidden, "forbidden")
return
}

scope := params.Get("scope")
if scope == "" && client.MandatoryScope {
ctx.Response().SetBody([]byte("scope is required, client id is " + client.ClientId))
ctx.Response().SetStatus(http.StatusForbidden, "forbidden")
return
}
matchScope := false
for _, s := range client.Scopes {
if s == scope {
matchScope = true
break
}
}
if !matchScope {
ctx.Response().SetBody([]byte("invalid scope, client id is " + client.ClientId))
ctx.Response().SetStatus(http.StatusForbidden, "forbidden")
return
}

redirectURI := params.Get("redirect_uri")
if redirectURI == "" {
ctx.Response().SetBody([]byte("redirect uri is required, client id is " + client.ClientId))
ctx.Response().SetStatus(http.StatusForbidden, "forbidden")
return
}

matchRedirectUri := false
for _, uri := range client.RedirectUrls {
if uri == redirectURI {
matchRedirectUri = true
break
}
}
if !matchRedirectUri {
ctx.Response().SetBody([]byte("invalid redirect uri, client id is " + client.ClientId))
ctx.Response().SetStatus(http.StatusForbidden, "forbidden")
return
}
uri, err := url.Parse(redirectURI)
if err != nil {
ctx.Response().SetBody([]byte("invalid redirect uri, client id is " + client.ClientId))
ctx.Response().SetStatus(http.StatusForbidden, "forbidden")
return
}
a.once.Do(func() {
a.cache = scope_manager.Auto[resources.ICache]("", "redis")
})
list := a.cache.List()
if len(list) < 1 {
ctx.Response().SetBody([]byte("redis cache is not available"))
ctx.Response().SetStatus(http.StatusForbidden, "forbidden")
return
}
cache := list[0]
query := url.Values{}
switch responseType {
case ResponseTypeCode:
{
// 授权码模式
provisionKey := params.Get("provision_key")
if provisionKey != client.ProvisionKey {
ctx.Response().SetBody([]byte("invalid provision key, client id is " + client.ClientId))
ctx.Response().SetStatus(http.StatusForbidden, "forbidden")
return
}
code := generateRandomString()
redisKey := fmt.Sprintf("apinto:oauth2_codes:%s:%s", os.Getenv("cluster_id"), code)
field := map[string]interface{}{
"code": code,
"scope": scope,
}
_, err = cache.HMSetN(ctx.Context(), redisKey, field, 6*time.Minute).Result()
if err != nil {
ctx.Response().SetBody([]byte(fmt.Sprintf("(%s)redis HMSet %s error: %s", client.ClientId, redisKey, err.Error())))
ctx.Response().SetStatus(http.StatusInternalServerError, "server error")
return
}
query.Set("code", code)
}
case ResponseTypeToken:
{
token, err := generateToken(ctx.Context(), cache, client.ClientId, client.TokenExpiration, client.RefreshTokenTTL, scope, false)
if err != nil {
ctx.Response().SetBody([]byte(fmt.Sprintf("(%s)generate token error: %s", client.ClientId, err.Error())))
ctx.Response().SetStatus(http.StatusInternalServerError, "server error")
return
}
query.Set("access_token", token.AccessToken)
query.Set("token_type", "bearer")
query.Set("expires_in", strconv.Itoa(token.ExpiresIn))
}
}

state := params.Get("state")
if state != "" {
query.Set("state", state)
}
data, _ := json.Marshal(map[string]interface{}{
"redirect_uri": fmt.Sprintf("%s?%s", uri.String(), query.Encode()),
})
ctx.Response().SetBody(data)
ctx.Response().SetStatus(http.StatusOK, "OK")
return
}
42 changes: 42 additions & 0 deletions application/auth/oauth2/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package oauth2

import "github.com/eolinker/apinto/application"

const (
GrantAuthorizationCode = "authorization_code"
GrantClientCredentials = "client_credentials"
GrantRefreshToken = "refresh_token"
)

type Config struct {
application.Auth
Users []*User `json:"users" label:"用户列表"`
}

type User struct {
Pattern Pattern `json:"pattern" label:"用户信息"`
application.User
}

type Pattern struct {
ClientId string `json:"client_id"`
ClientSecret string `json:"client_secret"`
ClientType string `json:"client_type"`
HashSecret bool `json:"hash_secret"`
RedirectUrls []string `json:"redirect_urls" label:"重定向URL"`
Scopes []string `json:"scopes" label:"授权范围"`
MandatoryScope bool `json:"mandatory_scope" label:"强制授权"`
ProvisionKey string `json:"provision_key" label:"Provision Key"`
TokenExpiration int `json:"token_expiration" label:"令牌过期时间"`
RefreshTokenTTL int `json:"refresh_token_ttl" label:"刷新令牌TTL"`
EnableAuthorizationCode bool `json:"enable_authorization_code" label:"启用授权码模式"`
EnableImplicitGrant bool `json:"enable_implicit_grant" label:"启用隐式授权模式"`
EnableClientCredentials bool `json:"enable_client_credentials" label:"启用客户端凭证模式"`
AcceptHttpIfAlreadyTerminated bool `json:"accept_http_if_already_terminated" label:"如果已终止,则接受HTTP"`
ReuseRefreshToken bool `json:"reuse_refresh_token" label:"重用刷新令牌"`
PersistentRefreshToken bool `json:"persistent_refresh_token" label:"持久刷新令牌"`
}

func (u *User) Username() string {
return u.Pattern.ClientId
}
Loading