Skip to content

Commit

Permalink
syncing up to 6958704e221de3f6d410f1491cff571ab4c287fc
Browse files Browse the repository at this point in the history
Co-authored-by: Bruce Yu <[email protected]>
  • Loading branch information
superblocksadmin and bruce-y committed Dec 28, 2024
1 parent 1063129 commit 935252e
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 4 deletions.
35 changes: 35 additions & 0 deletions internal/auth/add_token_if_needed.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/url"
"strings"

"github.com/golang-jwt/jwt/v4"
"github.com/superblocksteam/agent/internal/auth/oauth"
authtypes "github.com/superblocksteam/agent/internal/auth/types"
"github.com/superblocksteam/agent/pkg/jsonutils"
Expand Down Expand Up @@ -160,6 +161,17 @@ func (t *tokenManager) AddTokenIfNeeded(
}
authConfig.Fields["authToken"] = structpb.NewStringValue(tokenPayload.Token)
authConfig.Fields["idToken"] = structpb.NewStringValue(tokenPayload.IdToken)

{
decodedToken, err := t.decodeJwt(ctx, tokenPayload.IdToken)
if err == nil {
tokenPayload.TokenDecoded = decodedToken
// We might also want to call this tokenClaims instead of tokenDecoded
authConfig.Fields["tokenDecoded"] = structpb.NewStructValue(decodedToken)
} else {
log.Warn("error decoding id token", zap.Error(err))
}
}
}

tokenPayload.BindingName = "oauth"
Expand Down Expand Up @@ -441,6 +453,29 @@ func (t *tokenManager) addParam(ctx context.Context, datasourceConfig *structpb.
})
}

func (t *tokenManager) decodeJwt(ctx context.Context, token string) (*structpb.Struct, error) {
if token == "" {
return nil, fmt.Errorf("empty token")
}

parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
if err != nil {
return nil, fmt.Errorf("failed to parse JWT: %w", err)
}

claims, ok := parsedToken.Claims.(jwt.MapClaims)
if !ok {
return nil, fmt.Errorf("failed to get JWT claims")
}

result, err := structpb.NewStruct(map[string]interface{}(claims))
if err != nil {
return nil, fmt.Errorf("failed to convert claims to Struct: %w", err)
}

return result, nil
}

func (t *tokenManager) marshalInputToJSONString(ctx context.Context, input interface{}) string {
log := observability.ZapLogger(ctx, t.logger)

Expand Down
50 changes: 50 additions & 0 deletions internal/auth/add_token_if_needed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1262,3 +1262,53 @@ func DatasourceConfig(authType string, authConfig map[string]interface{}) *struc
}
return s
}

func TestDecodeJwt(t *testing.T) {
tm := &tokenManager{logger: zaptest.NewLogger(t)}

tests := []struct {
name string
token string
wantErr bool
wantClaims map[string]interface{}
errContains string
}{
{
name: "empty token",
token: "",
wantErr: true,
errContains: "empty token",
},
{
name: "invalid token",
token: "invalid.token",
wantErr: true,
errContains: "failed to parse JWT",
},
{
name: "valid token",
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiZW1haWwiOiJmb29Ac3VwZXJibG9ja3MuY29tIiwiaWF0IjoxNTE2MjM5MDIyfQ.acxuPTE4HrmSFMY9v73QY5qgQWrXsRrbWdLo5Ss7fgU",
wantClaims: map[string]interface{}{
"sub": "1234567890",
"name": "John Doe",
"email": "[email protected]",
"iat": float64(1516239022),
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
claims, err := tm.decodeJwt(context.Background(), tt.token)

if tt.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errContains)
return
}

require.NoError(t, err)
assert.Equal(t, tt.wantClaims, claims.AsMap())
})
}
}
9 changes: 5 additions & 4 deletions internal/auth/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ func transform(value any) (int, error) {
}

type TokenPayload struct {
Token string
IdToken string
UserId string
BindingName string
Token string
IdToken string
UserId string
TokenDecoded *structpb.Struct
BindingName string
}
9 changes: 9 additions & 0 deletions pkg/executor/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ func EvaluateDatasource(
authToken := tokenPayload.Token
authIdToken := tokenPayload.IdToken
authUserId := tokenPayload.UserId
tokenDecoded := tokenPayload.TokenDecoded
bindingName := tokenPayload.BindingName

newVars := []*apiv1.Variables_Config{}
Expand All @@ -656,6 +657,14 @@ func EvaluateDatasource(
objPairs = append(objPairs, fmt.Sprintf("idToken: '%s'", authIdToken))
redactedObjPairs = append(redactedObjPairs, fmt.Sprintf("idToken: '%s'", auth.RedactedSecret))
}
if tokenDecoded != nil {
tokenDecodedJson, err := protojson.Marshal(tokenDecoded)
if err != nil {
return nil, nil, err
}
objPairs = append(objPairs, fmt.Sprintf("tokenDecoded: %s", string(tokenDecodedJson)))
redactedObjPairs = append(redactedObjPairs, fmt.Sprintf("tokenDecoded: %s", "{}"))
}
value := fmt.Sprintf("{{ { %s } }}", strings.Join(objPairs, ", "))
valueRedacted := fmt.Sprintf("{{ { %s } }}", strings.Join(redactedObjPairs, ", "))

Expand Down

0 comments on commit 935252e

Please sign in to comment.