From 5d4f00ae4660e39baadf3dc8fe453f73c32a1395 Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Mon, 16 Dec 2024 19:51:44 -0500 Subject: [PATCH] feat: add user info to run environment variables Additionally, this change adds this information to the JWT token for each run. Signed-off-by: Donnie Adams --- pkg/api/authz/thread.go | 4 ++-- pkg/gateway/client/auth.go | 12 +++++------ pkg/gateway/client/user.go | 5 +++++ pkg/invoke/invoker.go | 43 ++++++++++++++++++++++++++------------ pkg/jwt/jwt.go | 18 +++++++++++++--- pkg/services/config.go | 2 +- 6 files changed, 59 insertions(+), 25 deletions(-) diff --git a/pkg/api/authz/thread.go b/pkg/api/authz/thread.go index 7b0c6d056..67a66a746 100644 --- a/pkg/api/authz/thread.go +++ b/pkg/api/authz/thread.go @@ -9,8 +9,8 @@ import ( ) func authorizeThread(req *http.Request, user user.Info) bool { - thread := types.FirstSet(user.GetExtra()["otto:threadID"]...) - agent := types.FirstSet(user.GetExtra()["otto:agentID"]...) + thread := types.FirstSet(user.GetExtra()["acorn:threadID"]...) + agent := types.FirstSet(user.GetExtra()["acorn:agentID"]...) if thread == "" || agent == "" { return false } diff --git a/pkg/gateway/client/auth.go b/pkg/gateway/client/auth.go index 6b86447c9..46d76a328 100644 --- a/pkg/gateway/client/auth.go +++ b/pkg/gateway/client/auth.go @@ -13,26 +13,26 @@ import ( ) type UserDecorator struct { - Next authenticator.Request - Client *Client + next authenticator.Request + client *Client } func NewUserDecorator(next authenticator.Request, client *Client) *UserDecorator { return &UserDecorator{ - Next: next, - Client: client, + next: next, + client: client, } } func (u UserDecorator) AuthenticateRequest(req *http.Request) (*authenticator.Response, bool, error) { - resp, ok, err := u.Next.AuthenticateRequest(req) + resp, ok, err := u.next.AuthenticateRequest(req) if err != nil { return nil, false, err } else if !ok { return nil, false, nil } - gatewayUser, err := u.Client.EnsureIdentity(req.Context(), &types.Identity{ + gatewayUser, err := u.client.EnsureIdentity(req.Context(), &types.Identity{ Email: firstValue(resp.User.GetExtra(), "email"), AuthProviderID: uint(firstValueAsInt(resp.User.GetExtra(), "auth_provider_id")), ProviderUsername: resp.User.GetName(), diff --git a/pkg/gateway/client/user.go b/pkg/gateway/client/user.go index c3bfd93da..906305574 100644 --- a/pkg/gateway/client/user.go +++ b/pkg/gateway/client/user.go @@ -17,6 +17,11 @@ func (c *Client) User(ctx context.Context, username string) (*types.User, error) return u, c.db.WithContext(ctx).Where("username = ?", username).First(u).Error } +func (c *Client) UserByID(ctx context.Context, id string) (*types.User, error) { + u := new(types.User) + return u, c.db.WithContext(ctx).Where("id = ?", id).First(u).Error +} + func (c *Client) UpdateProfileIconIfNeeded(ctx context.Context, user *types.User, authProviderID uint) error { if authProviderID == 0 { return nil diff --git a/pkg/invoke/invoker.go b/pkg/invoke/invoker.go index be611a877..f1290f01c 100644 --- a/pkg/invoke/invoker.go +++ b/pkg/invoke/invoker.go @@ -15,6 +15,7 @@ import ( "github.com/acorn-io/acorn/apiclient/types" "github.com/acorn-io/acorn/logger" "github.com/acorn-io/acorn/pkg/events" + "github.com/acorn-io/acorn/pkg/gateway/client" "github.com/acorn-io/acorn/pkg/gz" "github.com/acorn-io/acorn/pkg/hash" "github.com/acorn-io/acorn/pkg/jwt" @@ -34,22 +35,22 @@ import ( var log = logger.Package() type Invoker struct { - gptClient *gptscript.GPTScript - uncached kclient.WithWatch - tokenService *jwt.TokenService - events *events.Emitter - threadWorkspaceProvider string - serverURL string + gptClient *gptscript.GPTScript + uncached kclient.WithWatch + gatewayClient *client.Client + tokenService *jwt.TokenService + events *events.Emitter + serverURL string } -func NewInvoker(c kclient.WithWatch, gptClient *gptscript.GPTScript, serverURL, workspaceProviderType string, tokenService *jwt.TokenService, events *events.Emitter) *Invoker { +func NewInvoker(c kclient.WithWatch, gptClient *gptscript.GPTScript, gatewayClient *client.Client, serverURL string, tokenService *jwt.TokenService, events *events.Emitter) *Invoker { return &Invoker{ - uncached: c, - gptClient: gptClient, - tokenService: tokenService, - events: events, - threadWorkspaceProvider: workspaceProviderType, - serverURL: serverURL, + uncached: c, + gptClient: gptClient, + gatewayClient: gatewayClient, + tokenService: tokenService, + events: events, + serverURL: serverURL, } } @@ -461,6 +462,16 @@ func (i *Invoker) Resume(ctx context.Context, c kclient.WithWatch, thread *v1.Th return err } + var userID, userName, userEmail string + if thread.Spec.UserUID != "" { + u, err := i.gatewayClient.UserByID(ctx, thread.Spec.UserUID) + if err != nil { + return fmt.Errorf("failed to get user: %w", err) + } + + userID, userName, userEmail = thread.Spec.UserUID, u.Username, u.Email + } + token, err := i.tokenService.NewToken(jwt.TokenContext{ RunID: run.Name, ThreadID: thread.Name, @@ -468,6 +479,9 @@ func (i *Invoker) Resume(ctx context.Context, c kclient.WithWatch, thread *v1.Th WorkflowID: run.Spec.WorkflowName, WorkflowStepID: run.Spec.WorkflowStepID, Scope: thread.Namespace, + UserID: userID, + UserName: userName, + UserEmail: userEmail, }) if err != nil { return err @@ -496,6 +510,9 @@ func (i *Invoker) Resume(ctx context.Context, c kclient.WithWatch, thread *v1.Th "ACORN_DEFAULT_TEXT_EMBEDDING_MODEL="+string(types.DefaultModelAliasTypeTextEmbedding), "ACORN_DEFAULT_IMAGE_GENERATION_MODEL="+string(types.DefaultModelAliasTypeImageGeneration), "ACORN_DEFAULT_VISION_MODEL="+string(types.DefaultModelAliasTypeVision), + "ACORN_USER_ID="+userID, + "ACORN_USER_NAME="+userName, + "ACORN_USER_EMAIL="+userEmail, "GPTSCRIPT_HTTP_ENV=ACORN_TOKEN,ACORN_RUN_ID,ACORN_THREAD_ID,ACORN_WORKFLOW_ID,ACORN_WORKFLOW_STEP_ID,ACORN_AGENT_ID", ), DefaultModel: run.Spec.DefaultModel, diff --git a/pkg/jwt/jwt.go b/pkg/jwt/jwt.go index fa76aee8b..5a14a2628 100644 --- a/pkg/jwt/jwt.go +++ b/pkg/jwt/jwt.go @@ -20,6 +20,9 @@ type TokenContext struct { WorkflowID string WorkflowStepID string Scope string + UserID string + UserName string + UserEmail string } type TokenService struct{} @@ -37,9 +40,12 @@ func (t *TokenService) AuthenticateRequest(req *http.Request) (*authenticator.Re authz.AuthenticatedGroup, }, Extra: map[string][]string{ - "otto:runID": {tokenContext.RunID}, - "otto:threadID": {tokenContext.ThreadID}, - "otto:agentID": {tokenContext.AgentID}, + "acorn:runID": {tokenContext.RunID}, + "acorn:threadID": {tokenContext.ThreadID}, + "acorn:agentID": {tokenContext.AgentID}, + "acorn:userID": {tokenContext.UserID}, + "acorn:userName": {tokenContext.UserName}, + "acorn:userEmail": {tokenContext.UserEmail}, }, }, }, true, nil @@ -63,6 +69,9 @@ func (t *TokenService) DecodeToken(token string) (*TokenContext, error) { Scope: claims["Scope"].(string), WorkflowID: claims["WorkflowID"].(string), WorkflowStepID: claims["WorkflowStepID"].(string), + UserID: claims["UserID"].(string), + UserName: claims["UserName"].(string), + UserEmail: claims["UserEmail"].(string), }, nil } @@ -74,6 +83,9 @@ func (t *TokenService) NewToken(context TokenContext) (string, error) { "Scope": context.Scope, "WorkflowID": context.WorkflowID, "WorkflowStepID": context.WorkflowStepID, + "UserID": context.UserID, + "UserName": context.UserName, + "UserEmail": context.UserEmail, }) return token.SignedString([]byte(secret)) } diff --git a/pkg/services/config.go b/pkg/services/config.go index 7c6bf6918..ad2385380 100644 --- a/pkg/services/config.go +++ b/pkg/services/config.go @@ -217,7 +217,7 @@ func New(ctx context.Context, config Config) (*Services, error) { tokenServer = &jwt.TokenService{} events = events.NewEmitter(storageClient) gatewayClient = client.New(gatewayDB, config.AuthAdminEmails) - invoker = invoke.NewInvoker(storageClient, c, config.Hostname, config.WorkspaceProviderType, tokenServer, events) + invoker = invoke.NewInvoker(storageClient, c, client.New(gatewayDB, config.AuthAdminEmails), config.Hostname, tokenServer, events) modelProviderDispatcher = dispatcher.New(invoker, storageClient, c) proxyServer *proxy.Proxy