diff --git a/internal/api/api.go b/internal/api/api.go index 2591f5f..a322b2c 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -2,6 +2,7 @@ package api import ( "github.com/go-chi/chi/v5" + gh "github.com/moonwalker/moonbase/pkg/github" ) // @title Moonbase @@ -36,7 +37,7 @@ func Routes() chi.Router { // api routes which needs authenticated user token r.Group(func(r chi.Router) { - r.Use(withUser) + r.Use(gh.WithUser) // low level github apis r.Group(func(r chi.Router) { r.Get("/repos", getRepos) diff --git a/internal/api/auth.go b/internal/api/auth.go index a9d9e5f..ba1c324 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -1,7 +1,6 @@ package api import ( - "context" "encoding/base64" "fmt" "net/http" @@ -21,15 +20,12 @@ import ( // test the flow: // http://localhost:8080/login/github?return_url=/login/github/authenticate -type ctxKey int - const ( - retUrlCodePath = 0 - retUrlCodeQuery = 1 - oauthStateSep = "|" - codeTokenExpires = time.Minute - accessTokenExpires = time.Hour * 24 - ctxKeyAccessToken ctxKey = iota + retUrlCodePath = 0 + retUrlCodeQuery = 1 + oauthStateSep = "|" + codeTokenExpires = time.Minute + accessTokenExpires = time.Hour * 24 ) var ( @@ -188,41 +184,3 @@ func decryptExchangeCode(code string) (string, error) { return string(token.Claims.(*jwt.AuthClaims).Data), nil } - -func withUser(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var tokenString string - - // get token from authorization header - bearer := r.Header.Get("Authorization") - if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" { - tokenString = bearer[7:] - } - if len(tokenString) == 0 { - errAuthNoToken().Json(w) - return - } - - token, err := jwt.VerifyAndDecrypt(env.JweKey, env.JwtKey, tokenString) - if err != nil { - errAuthBadToken().Json(w) - return - } - - authClaims, ok := token.Claims.(*jwt.AuthClaims) - if !ok { - errAuthBadClaims().Json(w) - return - } - - // add auth claims to context - ctx := context.WithValue(r.Context(), ctxKeyAccessToken, string(authClaims.Data)) - - // authenticated, pass it through - next.ServeHTTP(w, r.WithContext(ctx)) - }) -} - -func accessTokenFromContext(ctx context.Context) string { - return ctx.Value(ctxKeyAccessToken).(string) -} diff --git a/internal/api/cms.go b/internal/api/cms.go index 25c7764..8f4f07b 100644 --- a/internal/api/cms.go +++ b/internal/api/cms.go @@ -133,7 +133,7 @@ func getSchema(ctx context.Context, accessToken string, owner string, repo strin // @Security bearerToken func getInfo(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") @@ -172,7 +172,7 @@ func getInfo(w http.ResponseWriter, r *http.Request) { // @Security bearerToken func getCollections(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") @@ -213,7 +213,7 @@ func getCollections(w http.ResponseWriter, r *http.Request) { // @Security bearerToken func postCollection(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") @@ -255,7 +255,7 @@ func postCollection(w http.ResponseWriter, r *http.Request) { // @Security bearerToken func delCollection(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") @@ -297,7 +297,7 @@ func delCollection(w http.ResponseWriter, r *http.Request) { // @Security bearerToken func getEntries(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") @@ -362,7 +362,7 @@ func putEntry(w http.ResponseWriter, r *http.Request) { func createOrUpdateEntry(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") @@ -486,7 +486,7 @@ func createOrUpdateEntry(w http.ResponseWriter, r *http.Request) { // @Security bearerToken func getEntry(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") @@ -555,7 +555,7 @@ func getEntry(w http.ResponseWriter, r *http.Request) { // @Security bearerToken func delEntry(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") @@ -589,7 +589,7 @@ func delEntry(w http.ResponseWriter, r *http.Request) { // @Security bearerToken func postImage(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") @@ -667,7 +667,7 @@ func postImage(w http.ResponseWriter, r *http.Request) { // @Security bearerToken func getSettings(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") @@ -705,7 +705,7 @@ func getSettings(w http.ResponseWriter, r *http.Request) { // @Security bearerToken func getSetting(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") @@ -745,7 +745,7 @@ func getSetting(w http.ResponseWriter, r *http.Request) { // @Security bearerToken func postSetting(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") @@ -782,7 +782,7 @@ func postSetting(w http.ResponseWriter, r *http.Request) { // @Security bearerToken func delSetting(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") @@ -815,7 +815,7 @@ func delSetting(w http.ResponseWriter, r *http.Request) { // @Security bearerToken func getReference(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") @@ -871,7 +871,7 @@ func getReference(w http.ResponseWriter, r *http.Request) { // @Security bearerToken func getCollectionGroups(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") @@ -912,7 +912,7 @@ func getCollectionGroups(w http.ResponseWriter, r *http.Request) { // @Security bearerToken func getCollectionGroup(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") diff --git a/internal/api/github.go b/internal/api/github.go index a70af06..aedd7dc 100644 --- a/internal/api/github.go +++ b/internal/api/github.go @@ -63,7 +63,7 @@ type commitPayload struct { // @Security bearerToken func getRepos(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) page, _ := strconv.Atoi(r.FormValue("page")) perPage, _ := strconv.Atoi(r.FormValue("per_page")) @@ -101,7 +101,7 @@ func getRepos(w http.ResponseWriter, r *http.Request) { // @Security bearerToken func getBranches(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") @@ -141,7 +141,7 @@ func getBranches(w http.ResponseWriter, r *http.Request) { // @Security bearerToken func getTree(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") @@ -181,7 +181,7 @@ func getTree(w http.ResponseWriter, r *http.Request) { // @Security bearerToken func getBlob(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") @@ -216,7 +216,7 @@ func getBlob(w http.ResponseWriter, r *http.Request) { // @Security bearerToken func postBlob(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") @@ -253,7 +253,7 @@ func postBlob(w http.ResponseWriter, r *http.Request) { // @Security bearerToken func delBlob(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - accessToken := accessTokenFromContext(ctx) + accessToken := gh.AccessTokenFromContext(ctx) owner := chi.URLParam(r, "owner") repo := chi.URLParam(r, "repo") diff --git a/pkg/github/auth.go b/pkg/github/auth.go new file mode 100644 index 0000000..0af7b9e --- /dev/null +++ b/pkg/github/auth.go @@ -0,0 +1,54 @@ +package github + +import ( + "context" + "net/http" + "strings" + + "github.com/moonwalker/moonbase/internal/env" + "github.com/moonwalker/moonbase/internal/jwt" +) + +type ctxKey int + +const ( + ctxKeyAccessToken ctxKey = iota +) + +func WithUser(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var tokenString string + + // get token from authorization header + bearer := r.Header.Get("Authorization") + if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" { + tokenString = bearer[7:] + } + if len(tokenString) == 0 { + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + + token, err := jwt.VerifyAndDecrypt(env.JweKey, env.JwtKey, tokenString) + if err != nil { + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + + authClaims, ok := token.Claims.(*jwt.AuthClaims) + if !ok { + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + + // add auth claims to context + ctx := context.WithValue(r.Context(), ctxKeyAccessToken, string(authClaims.Data)) + + // authenticated, pass it through + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func AccessTokenFromContext(ctx context.Context) string { + return ctx.Value(ctxKeyAccessToken).(string) +}