Skip to content

Commit

Permalink
feat: add user info to run environment variables
Browse files Browse the repository at this point in the history
Additionally, this change adds this information to the JWT token for
each run.

Signed-off-by: Donnie Adams <[email protected]>
  • Loading branch information
thedadams committed Dec 17, 2024
1 parent 4035938 commit 5d4f00a
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 25 deletions.
4 changes: 2 additions & 2 deletions pkg/api/authz/thread.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
)

func authorizeThread(req *http.Request, user user.Info) bool {
thread := types.FirstSet(user.GetExtra()["otto:threadID"]...)
agent := types.FirstSet(user.GetExtra()["otto:agentID"]...)
thread := types.FirstSet(user.GetExtra()["acorn:threadID"]...)
agent := types.FirstSet(user.GetExtra()["acorn:agentID"]...)
if thread == "" || agent == "" {
return false
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/gateway/client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,26 @@ import (
)

type UserDecorator struct {
Next authenticator.Request
Client *Client
next authenticator.Request
client *Client
}

func NewUserDecorator(next authenticator.Request, client *Client) *UserDecorator {
return &UserDecorator{
Next: next,
Client: client,
next: next,
client: client,
}
}

func (u UserDecorator) AuthenticateRequest(req *http.Request) (*authenticator.Response, bool, error) {
resp, ok, err := u.Next.AuthenticateRequest(req)
resp, ok, err := u.next.AuthenticateRequest(req)
if err != nil {
return nil, false, err
} else if !ok {
return nil, false, nil
}

gatewayUser, err := u.Client.EnsureIdentity(req.Context(), &types.Identity{
gatewayUser, err := u.client.EnsureIdentity(req.Context(), &types.Identity{
Email: firstValue(resp.User.GetExtra(), "email"),
AuthProviderID: uint(firstValueAsInt(resp.User.GetExtra(), "auth_provider_id")),
ProviderUsername: resp.User.GetName(),
Expand Down
5 changes: 5 additions & 0 deletions pkg/gateway/client/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ func (c *Client) User(ctx context.Context, username string) (*types.User, error)
return u, c.db.WithContext(ctx).Where("username = ?", username).First(u).Error
}

func (c *Client) UserByID(ctx context.Context, id string) (*types.User, error) {
u := new(types.User)
return u, c.db.WithContext(ctx).Where("id = ?", id).First(u).Error
}

func (c *Client) UpdateProfileIconIfNeeded(ctx context.Context, user *types.User, authProviderID uint) error {
if authProviderID == 0 {
return nil
Expand Down
43 changes: 30 additions & 13 deletions pkg/invoke/invoker.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/acorn-io/acorn/apiclient/types"
"github.com/acorn-io/acorn/logger"
"github.com/acorn-io/acorn/pkg/events"
"github.com/acorn-io/acorn/pkg/gateway/client"
"github.com/acorn-io/acorn/pkg/gz"
"github.com/acorn-io/acorn/pkg/hash"
"github.com/acorn-io/acorn/pkg/jwt"
Expand All @@ -34,22 +35,22 @@ import (
var log = logger.Package()

type Invoker struct {
gptClient *gptscript.GPTScript
uncached kclient.WithWatch
tokenService *jwt.TokenService
events *events.Emitter
threadWorkspaceProvider string
serverURL string
gptClient *gptscript.GPTScript
uncached kclient.WithWatch
gatewayClient *client.Client
tokenService *jwt.TokenService
events *events.Emitter
serverURL string
}

func NewInvoker(c kclient.WithWatch, gptClient *gptscript.GPTScript, serverURL, workspaceProviderType string, tokenService *jwt.TokenService, events *events.Emitter) *Invoker {
func NewInvoker(c kclient.WithWatch, gptClient *gptscript.GPTScript, gatewayClient *client.Client, serverURL string, tokenService *jwt.TokenService, events *events.Emitter) *Invoker {
return &Invoker{
uncached: c,
gptClient: gptClient,
tokenService: tokenService,
events: events,
threadWorkspaceProvider: workspaceProviderType,
serverURL: serverURL,
uncached: c,
gptClient: gptClient,
gatewayClient: gatewayClient,
tokenService: tokenService,
events: events,
serverURL: serverURL,
}
}

Expand Down Expand Up @@ -461,13 +462,26 @@ func (i *Invoker) Resume(ctx context.Context, c kclient.WithWatch, thread *v1.Th
return err
}

var userID, userName, userEmail string
if thread.Spec.UserUID != "" {
u, err := i.gatewayClient.UserByID(ctx, thread.Spec.UserUID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}

userID, userName, userEmail = thread.Spec.UserUID, u.Username, u.Email
}

token, err := i.tokenService.NewToken(jwt.TokenContext{
RunID: run.Name,
ThreadID: thread.Name,
AgentID: run.Spec.AgentName,
WorkflowID: run.Spec.WorkflowName,
WorkflowStepID: run.Spec.WorkflowStepID,
Scope: thread.Namespace,
UserID: userID,
UserName: userName,
UserEmail: userEmail,
})
if err != nil {
return err
Expand Down Expand Up @@ -496,6 +510,9 @@ func (i *Invoker) Resume(ctx context.Context, c kclient.WithWatch, thread *v1.Th
"ACORN_DEFAULT_TEXT_EMBEDDING_MODEL="+string(types.DefaultModelAliasTypeTextEmbedding),
"ACORN_DEFAULT_IMAGE_GENERATION_MODEL="+string(types.DefaultModelAliasTypeImageGeneration),
"ACORN_DEFAULT_VISION_MODEL="+string(types.DefaultModelAliasTypeVision),
"ACORN_USER_ID="+userID,
"ACORN_USER_NAME="+userName,
"ACORN_USER_EMAIL="+userEmail,
"GPTSCRIPT_HTTP_ENV=ACORN_TOKEN,ACORN_RUN_ID,ACORN_THREAD_ID,ACORN_WORKFLOW_ID,ACORN_WORKFLOW_STEP_ID,ACORN_AGENT_ID",
),
DefaultModel: run.Spec.DefaultModel,
Expand Down
18 changes: 15 additions & 3 deletions pkg/jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ type TokenContext struct {
WorkflowID string
WorkflowStepID string
Scope string
UserID string
UserName string
UserEmail string
}

type TokenService struct{}
Expand All @@ -37,9 +40,12 @@ func (t *TokenService) AuthenticateRequest(req *http.Request) (*authenticator.Re
authz.AuthenticatedGroup,
},
Extra: map[string][]string{
"otto:runID": {tokenContext.RunID},
"otto:threadID": {tokenContext.ThreadID},
"otto:agentID": {tokenContext.AgentID},
"acorn:runID": {tokenContext.RunID},
"acorn:threadID": {tokenContext.ThreadID},
"acorn:agentID": {tokenContext.AgentID},
"acorn:userID": {tokenContext.UserID},
"acorn:userName": {tokenContext.UserName},
"acorn:userEmail": {tokenContext.UserEmail},
},
},
}, true, nil
Expand All @@ -63,6 +69,9 @@ func (t *TokenService) DecodeToken(token string) (*TokenContext, error) {
Scope: claims["Scope"].(string),
WorkflowID: claims["WorkflowID"].(string),
WorkflowStepID: claims["WorkflowStepID"].(string),
UserID: claims["UserID"].(string),
UserName: claims["UserName"].(string),
UserEmail: claims["UserEmail"].(string),
}, nil
}

Expand All @@ -74,6 +83,9 @@ func (t *TokenService) NewToken(context TokenContext) (string, error) {
"Scope": context.Scope,
"WorkflowID": context.WorkflowID,
"WorkflowStepID": context.WorkflowStepID,
"UserID": context.UserID,
"UserName": context.UserName,
"UserEmail": context.UserEmail,
})
return token.SignedString([]byte(secret))
}
2 changes: 1 addition & 1 deletion pkg/services/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ func New(ctx context.Context, config Config) (*Services, error) {
tokenServer = &jwt.TokenService{}
events = events.NewEmitter(storageClient)
gatewayClient = client.New(gatewayDB, config.AuthAdminEmails)
invoker = invoke.NewInvoker(storageClient, c, config.Hostname, config.WorkspaceProviderType, tokenServer, events)
invoker = invoke.NewInvoker(storageClient, c, client.New(gatewayDB, config.AuthAdminEmails), config.Hostname, tokenServer, events)
modelProviderDispatcher = dispatcher.New(invoker, storageClient, c)

proxyServer *proxy.Proxy
Expand Down

0 comments on commit 5d4f00a

Please sign in to comment.