diff --git a/README.md b/README.md index 9d74612..bc9a835 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ - auth/basic 基本的验证处理; - auth/jwt JSON Web Tokens 中间件; - auth/session session 管理; +- auth/temporary 临时令牌; - auth/token 传统方式的令牌管理; - empty 提供了一个不作任何操作的中间件; - skip 根据条件跳过路由的执行; diff --git a/middlewares/auth/temporary/temporary.go b/middlewares/auth/temporary/temporary.go new file mode 100644 index 0000000..2ea0635 --- /dev/null +++ b/middlewares/auth/temporary/temporary.go @@ -0,0 +1,122 @@ +// SPDX-FileCopyrightText: 2024 caixw +// +// SPDX-License-Identifier: MIT + +// Package temporary 用于创建一个一次性的令牌 +package temporary + +import ( + "errors" + "net/http" + "time" + + "github.com/issue9/cache" + "github.com/issue9/mux/v9/header" + "github.com/issue9/web" + "github.com/issue9/web/openapi" + + "github.com/issue9/webuse/v7/internal/mauth" + "github.com/issue9/webuse/v7/middlewares/auth" +) + +type tokenType int + +const tokenContext tokenType = 0 + +type Response struct { + XMLName struct{} `json:"-" cbor:"-" xml:"token" yaml:"-"` + Token string `json:"token" xml:"token" cbor:"token" comment:"access token"` // 访问令牌 + Expire int `json:"expire" xml:"expire,attr" cbor:"expire" comment:"access token expired"` // 访问令牌的有效时长,单位为秒 +} + +type Temporary[T any] struct { + cache web.Cache + ttl time.Duration + expire int + once bool + unauthProblemID string + invalidTokenProblemID string +} + +// New 创建 [Temporary] 对象 +// +// ttl 表示令牌的过期时间。 +// once 是否为一次性令牌,如果为 true,在验证成功之后,该令牌将自动失效; +// unauthProblemID 验证不通过时的错误代码; +// invalidTokenProblemID 令牌无效时返回的错误代码; +func New[T any](s web.Server, ttl time.Duration, once bool, unauthProblemID, invalidTokenProblemID string) *Temporary[T] { + return &Temporary[T]{ + cache: web.NewCache(s.UniqueID(), s.Cache()), + ttl: ttl, + expire: int(ttl.Seconds()), + once: once, + unauthProblemID: unauthProblemID, + invalidTokenProblemID: invalidTokenProblemID, + } +} + +// New 创建令牌 +// +// v 为令牌关联的数据,之后通过验证接口可以访问该数据; +func (t *Temporary[T]) New(ctx *web.Context, v T, status int) web.Responser { + token := ctx.Server().UniqueID() + if err := t.cache.Set(token, v, t.ttl); err != nil { + return ctx.Error(err, "") + } + + return web.Response(status, &Response{Token: token, Expire: t.expire}) +} + +func (t *Temporary[T]) Middleware(next web.HandlerFunc, method, _, _ string) web.HandlerFunc { + if method == http.MethodOptions { + return next + } + + return func(ctx *web.Context) web.Responser { + token := auth.GetBearerToken(ctx, header.Authorization) + if token == "" { + return ctx.Problem(t.unauthProblemID) + } + + var v T + err := t.cache.Get(token, &v) + switch { + case errors.Is(err, cache.ErrCacheMiss()): + return ctx.Problem(t.unauthProblemID) + case err != nil: + return ctx.Error(err, t.invalidTokenProblemID) + default: + mauth.Set(ctx, v) + ctx.SetVar(tokenContext, token) + + if t.once { + if err := t.cache.Delete(token); err != nil { + ctx.Server().Logs().ERROR().Error(err) // 只记录错误,不反馈给客户端。 + } + } + + return next(ctx) + } + } +} + +func (t *Temporary[T]) Logout(ctx *web.Context) error { + if key, found := ctx.GetVar(tokenContext); found { + return t.cache.Delete(key.(string)) + } + return nil +} + +func (t *Temporary[T]) GetInfo(ctx *web.Context) (T, bool) { + return mauth.Get[T](ctx) +} + +// SecurityScheme 声明支持 openapi 的 [openapi.SecurityScheme] 对象 +func SecurityScheme(id string, desc web.LocaleStringer) *openapi.SecurityScheme { + return &openapi.SecurityScheme{ + ID: id, + Type: openapi.SecuritySchemeTypeHTTP, + Description: desc, + Scheme: auth.Bearer, + } +} diff --git a/middlewares/auth/temporary/temporary_test.go b/middlewares/auth/temporary/temporary_test.go new file mode 100644 index 0000000..9274729 --- /dev/null +++ b/middlewares/auth/temporary/temporary_test.go @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: 2024 caixw +// +// SPDX-License-Identifier: MIT + +package temporary + +import ( + "encoding/json" + "net/http" + "testing" + "time" + + "github.com/issue9/assert/v4" + "github.com/issue9/mux/v9/header" + "github.com/issue9/web" + "github.com/issue9/web/server/servertest" + + "github.com/issue9/webuse/v7/internal/testserver" + "github.com/issue9/webuse/v7/middlewares/auth" +) + +var _ auth.Auth[string] = &Temporary[string]{} + +func TestTemporary(t *testing.T) { + a := assert.New(t, false) + s := testserver.New(a) + + temp := New[string](s, time.Second, true, web.ProblemForbidden, web.ProblemBadRequest) + a.NotNil(temp) + s.Routers() + + r := s.Routers().New("default", nil) + r.Post("/login", func(ctx *web.Context) web.Responser { + return temp.New(ctx, "5", http.StatusCreated) + }) + + r.Get("/info", func(ctx *web.Context) web.Responser { + if info, ok := temp.GetInfo(ctx); ok { + return web.OK(info) // info == /login 中传递的值 "5" + } + panic("永远不可能达到此处") + }, temp) + + defer servertest.Run(a, s)() + defer s.Close(0) + + // 未登录 + servertest.Get(a, "http://localhost:8080/info"). + Do(nil). + Status(http.StatusForbidden) + + servertest.Post(a, "http://localhost:8080/login", nil). + Do(nil). + Status(http.StatusCreated). + BodyFunc(func(a *assert.Assertion, body []byte) { + resp := &Response{} + a.NotError(json.Unmarshal(body, resp)). + NotEmpty(resp.Token). + Equal(1, resp.Expire) + + // 正常访问 + servertest.Get(a, "http://localhost:8080/info"). + Header(header.Authorization, auth.BearerToken(resp.Token)). + Do(nil). + Status(http.StatusOK). + StringBody(`"5"`) + + // 再次访问,令牌失效 + servertest.Get(a, "http://localhost:8080/info"). + Header(header.Authorization, auth.BearerToken(resp.Token)). + Do(nil). + Status(http.StatusForbidden) + }) +}