From 29bad43a49b348175d6aca749a15669b88d9e4c9 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 6 Feb 2025 14:59:03 -0500 Subject: [PATCH 01/86] oauth providers to access oidc and related functions Signed-off-by: Aaron Choo --- internal/controller/oauth/base_provider.go | 50 +++ .../oauth/client_credentials_provider.go | 99 ++++++ internal/controller/oauth/oidc_provider.go | 212 ++++++++++++ internal/controller/oauth/types.go | 32 ++ internal/controller/rotators/aws_common.go | 296 ++++++++++++++++ .../controller/rotators/aws_oidc_rotator.go | 319 ++++++++++++++++++ internal/controller/rotators/common.go | 1 + 7 files changed, 1009 insertions(+) create mode 100644 internal/controller/oauth/base_provider.go create mode 100644 internal/controller/oauth/client_credentials_provider.go create mode 100644 internal/controller/oauth/oidc_provider.go create mode 100644 internal/controller/oauth/types.go create mode 100644 internal/controller/rotators/aws_common.go create mode 100644 internal/controller/rotators/aws_oidc_rotator.go create mode 100644 internal/controller/rotators/common.go diff --git a/internal/controller/oauth/base_provider.go b/internal/controller/oauth/base_provider.go new file mode 100644 index 000000000..4f5e8b520 --- /dev/null +++ b/internal/controller/oauth/base_provider.go @@ -0,0 +1,50 @@ +package oauth + +import ( + "context" + "fmt" + "net/http" + "time" + + "github.com/go-logr/logr" + corev1 "k8s.io/api/core/v1" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +// BaseProvider implements common OAuth functionality +type BaseProvider struct { + client client.Client + logger logr.Logger + http *http.Client +} + +// NewBaseProvider creates a new base provider +func NewBaseProvider(client client.Client, logger logr.Logger, httpClient *http.Client) *BaseProvider { + if httpClient == nil { + httpClient = &http.Client{Timeout: 30 * time.Second} + } + + return &BaseProvider{ + client: client, + logger: logger, + http: httpClient, + } +} + +// getClientSecret retrieves the client secret from a Kubernetes secret +func (p *BaseProvider) getClientSecret(ctx context.Context, secretRef *corev1.SecretReference) (string, error) { + secret := &corev1.Secret{} + if err := p.client.Get(ctx, client.ObjectKey{ + Namespace: secretRef.Namespace, + Name: secretRef.Name, + }, secret); err != nil { + return "", fmt.Errorf("failed to get client secret: %w", err) + } + + clientSecret, ok := secret.Data["client-secret"] + if !ok { + return "", fmt.Errorf("client-secret key not found in secret") + } + + return string(clientSecret), nil +} diff --git a/internal/controller/oauth/client_credentials_provider.go b/internal/controller/oauth/client_credentials_provider.go new file mode 100644 index 000000000..a97f4c9ac --- /dev/null +++ b/internal/controller/oauth/client_credentials_provider.go @@ -0,0 +1,99 @@ +package oauth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" + corev1 "k8s.io/api/core/v1" +) + +// ClientCredentialsProvider implements the standard OAuth2 client credentials flow +type ClientCredentialsProvider struct { + *BaseProvider +} + +// NewClientCredentialsProvider creates a new client credentials provider +func NewClientCredentialsProvider(base *BaseProvider) *ClientCredentialsProvider { + return &ClientCredentialsProvider{ + BaseProvider: base, + } +} + +func (p *ClientCredentialsProvider) FetchToken(ctx context.Context, oidc *egv1a1.OIDC) (*TokenResponse, error) { + clientSecret, err := p.getClientSecret(ctx, &corev1.SecretReference{ + Name: string(oidc.ClientSecret.Name), + Namespace: string(*oidc.ClientSecret.Namespace), + }) + if err != nil { + return nil, err + } + + // Prepare token request + form := url.Values{} + form.Set("grant_type", "client_credentials") + form.Set("client_id", oidc.ClientID) + form.Set("client_secret", clientSecret) + if len(oidc.Scopes) > 0 { + form.Set("scope", strings.Join(oidc.Scopes, " ")) + } + + // Make request + req, err := http.NewRequestWithContext(ctx, "POST", *oidc.Provider.TokenEndpoint, + strings.NewReader(form.Encode())) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := p.http.Do(req) + if err != nil { + return nil, fmt.Errorf("token request failed: %w", err) + } + defer resp.Body.Close() + + // Parse response + var raw map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&raw); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // Convert to TokenResponse + token := &TokenResponse{ + Raw: raw, + } + + // Extract standard fields + if v, ok := raw["access_token"].(string); ok { + token.AccessToken = v + } + if v, ok := raw["token_type"].(string); ok { + token.TokenType = v + } + if v, ok := raw["scope"].(string); ok { + token.Scope = v + } + + // Handle expiration + if v, ok := raw["expires_in"].(float64); ok { + token.ExpiresAt = time.Now().Add(time.Duration(v) * time.Second) + } + + return token, nil +} + +func (p *ClientCredentialsProvider) SupportsFlow(flowType FlowType) bool { + return flowType == FlowClientCredentials +} + +func (p *ClientCredentialsProvider) ValidateToken(ctx context.Context, token string) error { + // Implement token validation logic + // This might involve introspection endpoint if available + return nil +} diff --git a/internal/controller/oauth/oidc_provider.go b/internal/controller/oauth/oidc_provider.go new file mode 100644 index 000000000..8f0725333 --- /dev/null +++ b/internal/controller/oauth/oidc_provider.go @@ -0,0 +1,212 @@ +package oauth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" + "github.com/golang-jwt/jwt/v5" +) + +// OIDCProvider extends ClientCredentialsProvider with OIDC support +type OIDCProvider struct { + *ClientCredentialsProvider + httpClient *http.Client + oidcCredential *egv1a1.OIDC +} + +// OIDCMetadata represents the OpenID Connect provider metadata +type OIDCMetadata struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + JWKSURI string `json:"jwks_uri"` + SupportedScopes []string `json:"scopes_supported"` +} + +// NewOIDCProvider creates a new OIDC-aware provider +func NewOIDCProvider(base *BaseProvider, oidcCredentials *egv1a1.OIDC) *OIDCProvider { + return &OIDCProvider{ + ClientCredentialsProvider: NewClientCredentialsProvider(base), + httpClient: &http.Client{Timeout: 30 * time.Second}, + oidcCredential: oidcCredentials, + } +} + +// getOIDCMetadata retrieves or creates OIDC metadata for the given issuer URL +func (p *OIDCProvider) getOIDCMetadata(ctx context.Context, issuerURL string) (*OIDCMetadata, error) { + // Check context before proceeding + if err := ctx.Err(); err != nil { + return nil, fmt.Errorf("context error before discovery: %w", err) + } + + // Fetch OIDC configuration + wellKnown := strings.TrimSuffix(issuerURL, "/") + "/.well-known/openid-configuration" + req, err := http.NewRequestWithContext(ctx, "GET", wellKnown, nil) + if err != nil { + return nil, fmt.Errorf("failed to create discovery request: %w", err) + } + + resp, err := p.httpClient.Do(req) + + if err != nil { + return nil, fmt.Errorf("failed to fetch OIDC metadata: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code from discovery endpoint: %d", resp.StatusCode) + } + + var metadata OIDCMetadata + if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil { + return nil, fmt.Errorf("failed to decode OIDC metadata: %w", err) + } + + // Validate required fields + if metadata.Issuer == "" { + return nil, fmt.Errorf("issuer is required in OIDC metadata") + } + if metadata.TokenEndpoint == "" { + return nil, fmt.Errorf("token_endpoint is required in OIDC metadata") + } + + return &metadata, nil +} + +// validateIDToken validates the ID token according to the OIDC spec +func (p *OIDCProvider) validateIDToken(ctx context.Context, rawIDToken, issuerURL, clientID string) (map[string]interface{}, error) { + if err := ctx.Err(); err != nil { + return nil, fmt.Errorf("context error before validation: %w", err) + } + + token, err := jwt.Parse(rawIDToken, func(token *jwt.Token) (interface{}, error) { + // For now, we skip signature validation as we don't have the key + // TODO: Implement JWKS validation + return jwt.UnsafeAllowNoneSignatureType, nil + }) + if err != nil { + return nil, fmt.Errorf("failed to parse ID token: %w", err) + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return nil, fmt.Errorf("invalid claims format in token") + } + + now := time.Now() + + // Validate issuer + if iss, err := claims.GetIssuer(); err != nil || iss != issuerURL { + return nil, fmt.Errorf("invalid issuer claim") + } + + // Validate audience + if aud, err := claims.GetAudience(); err != nil || !contains(aud, clientID) { + return nil, fmt.Errorf("invalid audience claim") + } + + // Validate expiration + if exp, err := claims.GetExpirationTime(); err != nil || exp.Before(now) { + return nil, fmt.Errorf("token is expired") + } + + // Validate issued at + if iat, err := claims.GetIssuedAt(); err != nil || iat.After(now) { + return nil, fmt.Errorf("token used before issued") + } + + return claims, nil +} + +// contains checks if a string slice contains a value +func contains(slice []string, val string) bool { + for _, item := range slice { + if item == val { + return true + } + } + return false +} + +// FetchToken retrieves and validates tokens using the client credentials flow with OIDC support +func (p *OIDCProvider) FetchToken(ctx context.Context) (*TokenResponse, error) { + // If issuer URL is provided, fetch OIDC metadata + if issuerURL := p.oidcCredential.Provider.Issuer; issuerURL != "" { + metadata, err := p.getOIDCMetadata(ctx, issuerURL) + if err != nil { + return nil, fmt.Errorf("failed to get OIDC metadata: %w", err) + } + + // Use discovered token endpoint if not explicitly provided + if p.oidcCredential.Provider.TokenEndpoint == nil { + p.oidcCredential.Provider.TokenEndpoint = &metadata.TokenEndpoint + } + + // Add discovered scopes if available + if len(metadata.SupportedScopes) > 0 { + requestedScopes := make(map[string]bool) + for _, scope := range p.oidcCredential.Scopes { + requestedScopes[scope] = true + } + + // Add supported scopes that aren't already requested + for _, scope := range metadata.SupportedScopes { + if !requestedScopes[scope] { + p.oidcCredential.Scopes = append(p.oidcCredential.Scopes, scope) + } + } + } + } + + // Ensure openid scope is present + hasOpenID := false + for _, scope := range p.oidcCredential.Scopes { + if scope == "openid" { + hasOpenID = true + break + } + } + if !hasOpenID { + p.oidcCredential.Scopes = append(p.oidcCredential.Scopes, "openid") + } + + // Get base token response + token, err := p.ClientCredentialsProvider.FetchToken(ctx, p.oidcCredential) + if err != nil { + return nil, fmt.Errorf("failed to get token: %w", err) + } + + // Extract ID token if present + if rawIDToken, ok := token.Raw["id_token"].(string); ok { + token.IDToken = rawIDToken + + // Validate ID token if issuer URL is provided + if issuerURL := p.oidcCredential.Provider.Issuer; issuerURL != "" { + claims, err := p.validateIDToken(ctx, rawIDToken, issuerURL, p.oidcCredential.ClientID) + if err != nil { + return nil, fmt.Errorf("failed to validate ID token: %w", err) + } + + // Store claims in raw map for access by consumers + token.Raw["id_token_claims"] = claims + } + } + + return token, nil +} + +func (p *OIDCProvider) SupportsFlow(flowType FlowType) bool { + return flowType == FlowClientCredentialsWithIDToken +} + +// ValidateToken implements token validation for both access tokens and ID tokens +func (p *OIDCProvider) ValidateToken(ctx context.Context, token string) error { + // For ID tokens, we expect them to have been validated during GetToken + // For access tokens, we could implement introspection here if needed + return nil +} diff --git a/internal/controller/oauth/types.go b/internal/controller/oauth/types.go new file mode 100644 index 000000000..d4e753798 --- /dev/null +++ b/internal/controller/oauth/types.go @@ -0,0 +1,32 @@ +package oauth + +import ( + "context" + "time" +) + +// FlowType represents different OAuth/OIDC flow types +type FlowType string + +const ( + FlowClientCredentials FlowType = "client_credentials" + FlowClientCredentialsWithIDToken FlowType = "client_credentials_with_id_token" +) + +// TokenResponse represents the common token response structure +type TokenResponse struct { + AccessToken string + TokenType string + ExpiresAt time.Time + Scope string + IDToken string // Optional OIDC field + RefreshToken string // Optional refresh token + Raw map[string]interface{} +} + +// Provider defines the interface for OAuth token providers +type Provider interface { + FetchToken(ctx context.Context) (*TokenResponse, error) + ValidateToken(ctx context.Context, token string) error + SupportsFlow(flowType FlowType) bool +} diff --git a/internal/controller/rotators/aws_common.go b/internal/controller/rotators/aws_common.go new file mode 100644 index 000000000..675c1789c --- /dev/null +++ b/internal/controller/rotators/aws_common.go @@ -0,0 +1,296 @@ +/* +Package backendauthrotators provides credential rotation implementations. +This file contains common AWS functionality shared between different AWS credential +rotators. It provides: +1. AWS Client Interfaces and Implementations: +- IAMOperations for AWS IAM API operations +- STSOperations for AWS STS API operations +- Concrete implementations with proper AWS SDK integration +2. Credential File Management: +- Parsing and formatting of AWS credentials files +- Support for multiple credential profiles +- Handling of temporary credentials and session tokens +3. Common Configuration: +- Default AWS configuration with adaptive retry +- Standard timeouts and delays +- Session name formatting +*/ +package backendauthrotators + +import ( + "context" + "fmt" + "sort" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/iam" + "github.com/aws/aws-sdk-go-v2/service/sts" + corev1 "k8s.io/api/core/v1" +) + +// Common constants for AWS operations +const ( + // defaultKeyDeletionDelay is the time to wait before deleting old access keys + defaultKeyDeletionDelay = 60 * time.Second + // defaultMinPropagationDelay is the minimum time to wait for credential propagation + defaultMinPropagationDelay = 30 * time.Second + // credentialsKey is the key used to store AWS credentials in Kubernetes secrets + credentialsKey = "credentials" + // awsSessionNameFormat is the format string for AWS session names + awsSessionNameFormat = "ai-gateway-%s" +) + +// profileFromMetadata determines which AWS credentials profile to use. +// If a profile is specified in the metadata, that profile is used. +// Otherwise, if there is only one profile in the credentials, that profile is used. +// If there are multiple profiles and none specified, an error is returned. +func profileFromMetadata(metadata map[string]string, creds *awsCredentialsFile) (string, error) { + // If profile is specified in metadata, use that + if profile, ok := metadata["profile"]; ok { + if _, exists := creds.profiles[profile]; !exists { + return "", fmt.Errorf("specified profile %q not found in credentials", profile) + } + return profile, nil + } + + // If only one profile exists, use that + if len(creds.profiles) == 1 { + for profile := range creds.profiles { + return profile, nil + } + } + + // Multiple profiles exist but none specified + return "", fmt.Errorf("multiple AWS credential profiles found but none specified in metadata") +} + +// defaultAWSConfig returns an AWS config with adaptive retry mode enabled. +// This ensures better handling of transient API failures and rate limiting. +func defaultAWSConfig(ctx context.Context) (aws.Config, error) { + return config.LoadDefaultConfig(ctx, + config.WithRetryMode(aws.RetryModeAdaptive), + ) +} + +// awsConfigFromCredentials creates an AWS config using the provided credentials. +// This is used when we want to explicitly use credentials from a secret rather than +// relying on the default credential chain. +func awsConfigFromCredentials(ctx context.Context, creds *awsCredentials) (aws.Config, error) { + return config.LoadDefaultConfig(ctx, + config.WithRetryMode(aws.RetryModeAdaptive), + config.WithCredentialsProvider(aws.CredentialsProviderFunc( + func(ctx context.Context) (aws.Credentials, error) { + return aws.Credentials{ + AccessKeyID: creds.accessKeyID, + SecretAccessKey: creds.secretAccessKey, + SessionToken: creds.sessionToken, + }, nil + }, + )), + config.WithRegion(creds.region), + ) +} + +// IAMOperations defines the interface for AWS IAM operations required by the rotators. +// This interface allows for easier testing through mocks and provides a clear +// contract for required IAM functionality. +type IAMOperations interface { + // CreateAccessKey creates a new IAM access key + CreateAccessKey(ctx context.Context, params *iam.CreateAccessKeyInput, optFns ...func(*iam.Options)) (*iam.CreateAccessKeyOutput, error) + // DeleteAccessKey deletes an existing IAM access key + DeleteAccessKey(ctx context.Context, params *iam.DeleteAccessKeyInput, optFns ...func(*iam.Options)) (*iam.DeleteAccessKeyOutput, error) +} + +// STSOperations defines the interface for AWS STS operations required by the rotators. +// This interface encapsulates the STS API operations needed for OIDC token exchange +// and role assumption. +type STSOperations interface { + // AssumeRoleWithWebIdentity exchanges a web identity token for temporary AWS credentials + AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) +} + +// IAMClient implements the IAMOperations interface using the AWS SDK v2. +// It provides a concrete implementation for IAM operations using the official AWS SDK. +type IAMClient struct { + client *iam.Client +} + +// NewIAMClient creates a new IAMClient with the given AWS config. +// The client is configured with the provided AWS configuration, which should +// include appropriate credentials and region settings. +func NewIAMClient(cfg aws.Config) *IAMClient { + return &IAMClient{ + client: iam.NewFromConfig(cfg), + } +} + +// CreateAccessKey implements the IAMOperations interface by creating a new IAM access key. +func (c *IAMClient) CreateAccessKey(ctx context.Context, params *iam.CreateAccessKeyInput, optFns ...func(*iam.Options)) (*iam.CreateAccessKeyOutput, error) { + return c.client.CreateAccessKey(ctx, params, optFns...) +} + +// DeleteAccessKey implements the IAMOperations interface by deleting an IAM access key. +func (c *IAMClient) DeleteAccessKey(ctx context.Context, params *iam.DeleteAccessKeyInput, optFns ...func(*iam.Options)) (*iam.DeleteAccessKeyOutput, error) { + return c.client.DeleteAccessKey(ctx, params, optFns...) +} + +// STSClient implements the STSOperations interface using the AWS SDK v2. +// It provides a concrete implementation for STS operations using the official AWS SDK. +type STSClient struct { + client *sts.Client +} + +// NewSTSClient creates a new STSClient with the given AWS config. +// The client is configured with the provided AWS configuration, which should +// include appropriate credentials and region settings. +func NewSTSClient(cfg aws.Config) *STSClient { + return &STSClient{ + client: sts.NewFromConfig(cfg), + } +} + +// AssumeRoleWithWebIdentity implements the STSOperations interface by exchanging +// a web identity token for temporary AWS credentials. +func (c *STSClient) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { + return c.client.AssumeRoleWithWebIdentity(ctx, params, optFns...) +} + +// awsCredentials represents a single set of AWS credentials, including optional +// session token and region configuration. It maps to a single profile in an +// AWS credentials file. +type awsCredentials struct { + // profile is the name of the credentials profile + profile string + // accessKeyID is the AWS access key ID + accessKeyID string + // secretAccessKey is the AWS secret access key + secretAccessKey string + // sessionToken is the optional AWS session token for temporary credentials + sessionToken string + // region is the optional AWS region for the profile + region string +} + +// awsCredentialsFile represents a complete AWS credentials file containing +// multiple credential profiles. It provides a structured way to manage +// multiple sets of AWS credentials. +type awsCredentialsFile struct { + // profiles maps profile names to their respective credentials + profiles map[string]*awsCredentials +} + +// parseAWSCredentialsFile parses an AWS credentials file with multiple profiles. +// The file format follows the standard AWS credentials file format: +// +// [profile-name] +// aws_access_key_id = AKIAXXXXXXXXXXXXXXXX +// aws_secret_access_key = xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx +// aws_session_token = xxxxxxxx (optional) +// region = xx-xxxx-x (optional) +// +// Returns a structured representation of the credentials file. +func parseAWSCredentialsFile(data string) *awsCredentialsFile { + file := &awsCredentialsFile{ + profiles: make(map[string]*awsCredentials), + } + + var currentCreds *awsCredentials + + for _, line := range strings.Split(data, "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + + if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { + profileName := strings.TrimPrefix(strings.TrimSuffix(line, "]"), "[") + currentCreds = &awsCredentials{profile: profileName} + file.profiles[profileName] = currentCreds + continue + } + + if currentCreds == nil { + continue + } + + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + continue + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + switch key { + case "aws_access_key_id": + currentCreds.accessKeyID = value + case "aws_secret_access_key": + currentCreds.secretAccessKey = value + case "aws_session_token": + currentCreds.sessionToken = value + case "region": + currentCreds.region = value + } + } + + return file +} + +// formatAWSCredentialsFile formats multiple AWS credential profiles into a credentials file. +// The output follows the standard AWS credentials file format and ensures: +// - Consistent ordering of profiles through sorting +// - Proper formatting of all credential components +// - Optional inclusion of session tokens and regions +// - Profile isolation with proper section markers +func formatAWSCredentialsFile(file *awsCredentialsFile) string { + var builder strings.Builder + + // Sort profiles to ensure consistent output + profileNames := make([]string, 0, len(file.profiles)) + for profileName := range file.profiles { + profileNames = append(profileNames, profileName) + } + sort.Strings(profileNames) + + for i, profileName := range profileNames { + if i > 0 { + builder.WriteString("\n") + } + creds := file.profiles[profileName] + builder.WriteString(fmt.Sprintf("[%s]\n", profileName)) + builder.WriteString(fmt.Sprintf("aws_access_key_id = %s\n", creds.accessKeyID)) + builder.WriteString(fmt.Sprintf("aws_secret_access_key = %s\n", creds.secretAccessKey)) + if creds.sessionToken != "" { + builder.WriteString(fmt.Sprintf("aws_session_token = %s\n", creds.sessionToken)) + } + if creds.region != "" { + builder.WriteString(fmt.Sprintf("region = %s\n", creds.region)) + } + } + return builder.String() +} + +// validateAWSSecret validates that a secret contains valid AWS credentials +func validateAWSSecret(secret *corev1.Secret) (*awsCredentialsFile, error) { + if secret.Data == nil || len(secret.Data[credentialsKey]) == 0 { + return nil, fmt.Errorf("secret contains no AWS credentials") + } + + creds := parseAWSCredentialsFile(string(secret.Data[credentialsKey])) + if creds == nil || len(creds.profiles) == 0 { + return nil, fmt.Errorf("no valid AWS credentials found in secret") + } + + return creds, nil +} + +// updateAWSCredentialsInSecret updates AWS credentials in a secret +func updateAWSCredentialsInSecret(secret *corev1.Secret, creds *awsCredentialsFile) { + if secret.Data == nil { + secret.Data = make(map[string][]byte) + } + secret.Data[credentialsKey] = []byte(formatAWSCredentialsFile(creds)) +} diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go new file mode 100644 index 000000000..2150ba864 --- /dev/null +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -0,0 +1,319 @@ +package backendauthrotators + +import ( + "context" + "errors" + "fmt" + "net/http" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/go-logr/logr" + "k8s.io/client-go/kubernetes" + "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/envoyproxy/ai-gateway/internal/controller/oauth" +) + +// ----------------------------------------------------------------------------- +// Types and Constants +// ----------------------------------------------------------------------------- + +// AWSOIDCRotator implements the Rotator interface for AWS OIDC token exchange. +// It manages the lifecycle of temporary AWS credentials obtained through OIDC token +// exchange with AWS STS. The rotator automatically schedules credential refresh +// before expiration to ensure continuous access. +// +// Key features: +// - Automatic credential refresh before expiration +// - Support for role assumption with web identity +// - Integration with Kubernetes secrets for credential storage +// - Channel-based rotation scheduling +type AWSOIDCRotator struct { + // client is used for Kubernetes API operations + client client.Client + // kube provides additional Kubernetes API capabilities + kube kubernetes.Interface + // logger is used for structured logging + logger logr.Logger + // stsOps provides AWS STS operations interface + stsOps STSOperations + // oidcProvider provides OIDC token provider + oidcProvider oauth.Provider +} + +// ----------------------------------------------------------------------------- +// Constructor and Interface Implementation +// ----------------------------------------------------------------------------- + +// NewAWSOIDCRotator creates a new AWS OIDC rotator with the specified configuration. +// It initializes the AWS STS client and sets up the rotation channels. +func NewAWSOIDCRotator( + client client.Client, + kube kubernetes.Interface, + logger logr.Logger, +) (*AWSOIDCRotator, error) { + cfg, err := defaultAWSConfig(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to load AWS config: %w", err) + } + + stsClient := NewSTSClient(cfg) + + // Create OIDC provider + baseProvider := oauth.NewBaseProvider(client, logger, &http.Client{Timeout: 30 * time.Second}) + oidcProvider := oauth.NewOIDCProvider(baseProvider) + + return &AWSOIDCRotator{ + client: client, + kube: kube, + logger: logger, + stsOps: stsClient, + oidcProvider: oidcProvider, + }, nil +} + +// Type returns the type of rotation this rotator handles +func (r *AWSOIDCRotator) Type() RotationType { + return RotationTypeAWSOIDC +} + +// SetSTSOperations sets the STS operations implementation - primarily used for testing +func (r *AWSOIDCRotator) SetSTSOperations(ops STSOperations) { + r.stsOps = ops +} + +// ----------------------------------------------------------------------------- +// Event Processing +// ----------------------------------------------------------------------------- + +// Start begins processing rotation events from the rotation channel. +// It runs until the context is cancelled, processing only events +// that match this rotator's type. +func (r *AWSOIDCRotator) Start(ctx context.Context) error { + for { + select { + case event := <-r.rotationChan: + // Only process events for this rotator type + if event.Type != RotationTypeAWSOIDC { + continue + } + + if err := r.Rotate(ctx, event); err != nil { + if !errors.Is(err, context.Canceled) { + r.logger.Error(err, "failed to rotate credentials", + "namespace", event.Namespace, + "name", event.Name) + } + } + case <-ctx.Done(): + return ctx.Err() + } + } +} + +// ----------------------------------------------------------------------------- +// Main Interface Methods - Initialize and Rotate +// ----------------------------------------------------------------------------- + +// Initialize implements the initial token retrieval for AWS OIDC tokens +func (r *AWSOIDCRotator) Initialize(ctx context.Context, event RotationEvent) error { + r.logger.Info("initializing AWS OIDC token", + "namespace", event.Namespace, + "name", event.Name) + + // Get OIDC configuration from metadata + config, err := r.getOIDCConfig(ctx, event) + if err != nil { + return fmt.Errorf("failed to get OIDC config: %w", err) + } + + // Fetch and validate OIDC token + token, err := r.oidcProvider.FetchToken(ctx, config) + if err != nil { + return fmt.Errorf("failed to fetch OIDC token: %w", err) + } + + // Exchange token for AWS credentials + result, err := r.assumeRoleWithToken(ctx, event, token.IDToken) + if err != nil { + return err + } + + // Create new secret struct + secret := newSecret(event.Namespace, event.Name) + + // Get profile from metadata, defaulting to "default" if not specified + profile := event.Metadata["profile"] + if profile == "" { + profile = "default" + } + + // Create credentials file with the specified profile + credsFile := &awsCredentialsFile{ + profiles: map[string]*awsCredentials{ + profile: { + profile: profile, + accessKeyID: aws.ToString(result.Credentials.AccessKeyId), + secretAccessKey: aws.ToString(result.Credentials.SecretAccessKey), + sessionToken: aws.ToString(result.Credentials.SessionToken), + region: event.Metadata["region"], + }, + }, + } + + // Update secret with credentials + updateAWSCredentialsInSecret(secret, credsFile) + return updateSecret(ctx, r.client, secret) +} + +// Rotate implements the Rotator interface for AWS OIDC credentials +func (r *AWSOIDCRotator) Rotate(ctx context.Context, event RotationEvent) error { + if err := validateRotationEvent(event); err != nil { + return err + } + + // Get OIDC configuration from metadata + config, err := r.getOIDCConfig(ctx, event) + if err != nil { + return fmt.Errorf("failed to get OIDC config: %w", err) + } + + // Fetch and validate OIDC token + token, err := r.oidcProvider.FetchToken(ctx, config) + if err != nil { + return fmt.Errorf("failed to fetch OIDC token: %w", err) + } + + // Get existing secret + secret, err := lookupSecret(ctx, r.client, event.Namespace, event.Name) + if err != nil { + return err + } + + existingCreds, err := validateAWSSecret(secret) + if err != nil { + return err + } + + // Determine which profile to use + profile, err := profileFromMetadata(event.Metadata, existingCreds) + if err != nil { + return fmt.Errorf("failed to determine AWS profile: %w", err) + } + + // Exchange token for AWS credentials + resp, err := r.assumeRoleWithToken(ctx, event, token.IDToken) + if err != nil { + return err + } + + // Update only the specified profile's credentials + existingCreds.profiles[profile].accessKeyID = aws.ToString(resp.Credentials.AccessKeyId) + existingCreds.profiles[profile].secretAccessKey = aws.ToString(resp.Credentials.SecretAccessKey) + existingCreds.profiles[profile].sessionToken = aws.ToString(resp.Credentials.SessionToken) + existingCreds.profiles[profile].region = event.Metadata["region"] + + // Update secret with credentials + updateAWSCredentialsInSecret(secret, existingCreds) + if err := updateSecret(ctx, r.client, secret); err != nil { + return err + } + + // Schedule next rotation if needed + return r.scheduleNextRotation(event, resp) +} + +// ----------------------------------------------------------------------------- +// Helper Methods - Token and Credential Management +// ----------------------------------------------------------------------------- + +// getOIDCConfig creates an OAuth config from the rotation event metadata +func (r *AWSOIDCRotator) getOIDCConfig(ctx context.Context, event RotationEvent) (oauth.Config, error) { + // Convert metadata to expected format + params := make(map[string]string) + + // Required fields + params["token_url"] = event.Metadata["token_url"] + params["client_id"] = event.Metadata["client_id"] + params["client-secret-name"] = event.Metadata["client_secret_name"] + + // Optional fields + if issuerURL := event.Metadata["issuer_url"]; issuerURL != "" { + params["issuer_url"] = issuerURL + } + if scopes := event.Metadata["scopes"]; scopes != "" { + params["scopes"] = scopes + } + + return oauth.NewOIDCConfig(ctx, r.client, event.Namespace, params) +} + +// assumeRoleWithToken exchanges an OIDC token for AWS credentials +func (r *AWSOIDCRotator) assumeRoleWithToken(ctx context.Context, event RotationEvent, token string) (*sts.AssumeRoleWithWebIdentityOutput, error) { + roleARN := event.Metadata["role_arn"] + if roleARN == "" { + roleARN = event.Metadata["role-arn"] // support both formats + } + if roleARN == "" { + return nil, fmt.Errorf("role ARN is required in metadata") + } + + return r.stsOps.AssumeRoleWithWebIdentity(ctx, &sts.AssumeRoleWithWebIdentityInput{ + RoleArn: aws.String(roleARN), + WebIdentityToken: aws.String(token), + RoleSessionName: aws.String(fmt.Sprintf(awsSessionNameFormat, event.Name)), + }) +} + +// ----------------------------------------------------------------------------- +// Helper Methods - Validation and Scheduling +// ----------------------------------------------------------------------------- + +// scheduleNextRotation schedules the next rotation before credentials expire +func (r *AWSOIDCRotator) scheduleNextRotation(event RotationEvent, resp *sts.AssumeRoleWithWebIdentityOutput) error { + if resp.Credentials.Expiration == nil { + return nil + } + + // Calculate when we should rotate - 5 minutes before expiry + rotateAt := resp.Credentials.Expiration.Add(-5 * time.Minute) + + // If we're not too close to expiry, schedule the next rotation + if time.Until(rotateAt) > time.Second { + // Create a new event for the next rotation, preserving all metadata + nextEvent := RotationEvent{ + Namespace: event.Namespace, + Name: event.Name, + Type: RotationTypeAWSOIDC, + Metadata: make(map[string]string), + } + + // Copy all metadata from the original event + for k, v := range event.Metadata { + nextEvent.Metadata[k] = v + } + + // Update the rotation time + nextEvent.Metadata["rotate_at"] = rotateAt.Format(time.RFC3339) + + // Send the event through the schedule channel + select { + case r.scheduleChan <- nextEvent: + r.logger.Info("scheduled next rotation", + "namespace", event.Namespace, + "name", event.Name, + "profile", event.Metadata["profile"], + "rotateAt", rotateAt) + default: + r.logger.Error(fmt.Errorf("schedule channel is full"), "failed to schedule next rotation", + "namespace", event.Namespace, + "name", event.Name, + "profile", event.Metadata["profile"], + "rotateAt", rotateAt) + } + } + + return nil +} diff --git a/internal/controller/rotators/common.go b/internal/controller/rotators/common.go new file mode 100644 index 000000000..cf0164297 --- /dev/null +++ b/internal/controller/rotators/common.go @@ -0,0 +1 @@ +package backendauthrotators From 06f53190279c5a83eee390bcf0825e1a8b158725 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 6 Feb 2025 15:00:04 -0500 Subject: [PATCH 02/86] add aws oidc sts fetch Signed-off-by: Aaron Choo --- internal/controller/rotators/aws_common.go | 156 --------- .../controller/rotators/aws_oidc_rotator.go | 301 +++++------------- internal/controller/rotators/common.go | 80 +++++ 3 files changed, 152 insertions(+), 385 deletions(-) diff --git a/internal/controller/rotators/aws_common.go b/internal/controller/rotators/aws_common.go index 675c1789c..8a2a6fd40 100644 --- a/internal/controller/rotators/aws_common.go +++ b/internal/controller/rotators/aws_common.go @@ -3,7 +3,6 @@ Package backendauthrotators provides credential rotation implementations. This file contains common AWS functionality shared between different AWS credential rotators. It provides: 1. AWS Client Interfaces and Implementations: -- IAMOperations for AWS IAM API operations - STSOperations for AWS STS API operations - Concrete implementations with proper AWS SDK integration 2. Credential File Management: @@ -22,51 +21,21 @@ import ( "fmt" "sort" "strings" - "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/service/iam" "github.com/aws/aws-sdk-go-v2/service/sts" corev1 "k8s.io/api/core/v1" ) // Common constants for AWS operations const ( - // defaultKeyDeletionDelay is the time to wait before deleting old access keys - defaultKeyDeletionDelay = 60 * time.Second - // defaultMinPropagationDelay is the minimum time to wait for credential propagation - defaultMinPropagationDelay = 30 * time.Second // credentialsKey is the key used to store AWS credentials in Kubernetes secrets credentialsKey = "credentials" // awsSessionNameFormat is the format string for AWS session names awsSessionNameFormat = "ai-gateway-%s" ) -// profileFromMetadata determines which AWS credentials profile to use. -// If a profile is specified in the metadata, that profile is used. -// Otherwise, if there is only one profile in the credentials, that profile is used. -// If there are multiple profiles and none specified, an error is returned. -func profileFromMetadata(metadata map[string]string, creds *awsCredentialsFile) (string, error) { - // If profile is specified in metadata, use that - if profile, ok := metadata["profile"]; ok { - if _, exists := creds.profiles[profile]; !exists { - return "", fmt.Errorf("specified profile %q not found in credentials", profile) - } - return profile, nil - } - - // If only one profile exists, use that - if len(creds.profiles) == 1 { - for profile := range creds.profiles { - return profile, nil - } - } - - // Multiple profiles exist but none specified - return "", fmt.Errorf("multiple AWS credential profiles found but none specified in metadata") -} - // defaultAWSConfig returns an AWS config with adaptive retry mode enabled. // This ensures better handling of transient API failures and rate limiting. func defaultAWSConfig(ctx context.Context) (aws.Config, error) { @@ -75,35 +44,6 @@ func defaultAWSConfig(ctx context.Context) (aws.Config, error) { ) } -// awsConfigFromCredentials creates an AWS config using the provided credentials. -// This is used when we want to explicitly use credentials from a secret rather than -// relying on the default credential chain. -func awsConfigFromCredentials(ctx context.Context, creds *awsCredentials) (aws.Config, error) { - return config.LoadDefaultConfig(ctx, - config.WithRetryMode(aws.RetryModeAdaptive), - config.WithCredentialsProvider(aws.CredentialsProviderFunc( - func(ctx context.Context) (aws.Credentials, error) { - return aws.Credentials{ - AccessKeyID: creds.accessKeyID, - SecretAccessKey: creds.secretAccessKey, - SessionToken: creds.sessionToken, - }, nil - }, - )), - config.WithRegion(creds.region), - ) -} - -// IAMOperations defines the interface for AWS IAM operations required by the rotators. -// This interface allows for easier testing through mocks and provides a clear -// contract for required IAM functionality. -type IAMOperations interface { - // CreateAccessKey creates a new IAM access key - CreateAccessKey(ctx context.Context, params *iam.CreateAccessKeyInput, optFns ...func(*iam.Options)) (*iam.CreateAccessKeyOutput, error) - // DeleteAccessKey deletes an existing IAM access key - DeleteAccessKey(ctx context.Context, params *iam.DeleteAccessKeyInput, optFns ...func(*iam.Options)) (*iam.DeleteAccessKeyOutput, error) -} - // STSOperations defines the interface for AWS STS operations required by the rotators. // This interface encapsulates the STS API operations needed for OIDC token exchange // and role assumption. @@ -112,31 +52,6 @@ type STSOperations interface { AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) } -// IAMClient implements the IAMOperations interface using the AWS SDK v2. -// It provides a concrete implementation for IAM operations using the official AWS SDK. -type IAMClient struct { - client *iam.Client -} - -// NewIAMClient creates a new IAMClient with the given AWS config. -// The client is configured with the provided AWS configuration, which should -// include appropriate credentials and region settings. -func NewIAMClient(cfg aws.Config) *IAMClient { - return &IAMClient{ - client: iam.NewFromConfig(cfg), - } -} - -// CreateAccessKey implements the IAMOperations interface by creating a new IAM access key. -func (c *IAMClient) CreateAccessKey(ctx context.Context, params *iam.CreateAccessKeyInput, optFns ...func(*iam.Options)) (*iam.CreateAccessKeyOutput, error) { - return c.client.CreateAccessKey(ctx, params, optFns...) -} - -// DeleteAccessKey implements the IAMOperations interface by deleting an IAM access key. -func (c *IAMClient) DeleteAccessKey(ctx context.Context, params *iam.DeleteAccessKeyInput, optFns ...func(*iam.Options)) (*iam.DeleteAccessKeyOutput, error) { - return c.client.DeleteAccessKey(ctx, params, optFns...) -} - // STSClient implements the STSOperations interface using the AWS SDK v2. // It provides a concrete implementation for STS operations using the official AWS SDK. type STSClient struct { @@ -182,63 +97,6 @@ type awsCredentialsFile struct { profiles map[string]*awsCredentials } -// parseAWSCredentialsFile parses an AWS credentials file with multiple profiles. -// The file format follows the standard AWS credentials file format: -// -// [profile-name] -// aws_access_key_id = AKIAXXXXXXXXXXXXXXXX -// aws_secret_access_key = xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx -// aws_session_token = xxxxxxxx (optional) -// region = xx-xxxx-x (optional) -// -// Returns a structured representation of the credentials file. -func parseAWSCredentialsFile(data string) *awsCredentialsFile { - file := &awsCredentialsFile{ - profiles: make(map[string]*awsCredentials), - } - - var currentCreds *awsCredentials - - for _, line := range strings.Split(data, "\n") { - line = strings.TrimSpace(line) - if line == "" { - continue - } - - if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { - profileName := strings.TrimPrefix(strings.TrimSuffix(line, "]"), "[") - currentCreds = &awsCredentials{profile: profileName} - file.profiles[profileName] = currentCreds - continue - } - - if currentCreds == nil { - continue - } - - parts := strings.SplitN(line, "=", 2) - if len(parts) != 2 { - continue - } - - key := strings.TrimSpace(parts[0]) - value := strings.TrimSpace(parts[1]) - - switch key { - case "aws_access_key_id": - currentCreds.accessKeyID = value - case "aws_secret_access_key": - currentCreds.secretAccessKey = value - case "aws_session_token": - currentCreds.sessionToken = value - case "region": - currentCreds.region = value - } - } - - return file -} - // formatAWSCredentialsFile formats multiple AWS credential profiles into a credentials file. // The output follows the standard AWS credentials file format and ensures: // - Consistent ordering of profiles through sorting @@ -273,20 +131,6 @@ func formatAWSCredentialsFile(file *awsCredentialsFile) string { return builder.String() } -// validateAWSSecret validates that a secret contains valid AWS credentials -func validateAWSSecret(secret *corev1.Secret) (*awsCredentialsFile, error) { - if secret.Data == nil || len(secret.Data[credentialsKey]) == 0 { - return nil, fmt.Errorf("secret contains no AWS credentials") - } - - creds := parseAWSCredentialsFile(string(secret.Data[credentialsKey])) - if creds == nil || len(creds.profiles) == 0 { - return nil, fmt.Errorf("no valid AWS credentials found in secret") - } - - return creds, nil -} - // updateAWSCredentialsInSecret updates AWS credentials in a secret func updateAWSCredentialsInSecret(secret *corev1.Secret, creds *awsCredentialsFile) { if secret.Data == nil { diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index 2150ba864..dc3ac8456 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -2,34 +2,23 @@ package backendauthrotators import ( "context" - "errors" "fmt" "net/http" + "net/url" + "os" "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/go-logr/logr" + "k8s.io/apimachinery/pkg/api/errors" "k8s.io/client-go/kubernetes" "sigs.k8s.io/controller-runtime/pkg/client" - - "github.com/envoyproxy/ai-gateway/internal/controller/oauth" ) -// ----------------------------------------------------------------------------- -// Types and Constants -// ----------------------------------------------------------------------------- - // AWSOIDCRotator implements the Rotator interface for AWS OIDC token exchange. // It manages the lifecycle of temporary AWS credentials obtained through OIDC token -// exchange with AWS STS. The rotator automatically schedules credential refresh -// before expiration to ensure continuous access. -// -// Key features: -// - Automatic credential refresh before expiration -// - Support for role assumption with web identity -// - Integration with Kubernetes secrets for credential storage -// - Channel-based rotation scheduling +// exchange with AWS STS. type AWSOIDCRotator struct { // client is used for Kubernetes API operations client client.Client @@ -39,118 +28,110 @@ type AWSOIDCRotator struct { logger logr.Logger // stsOps provides AWS STS operations interface stsOps STSOperations - // oidcProvider provides OIDC token provider - oidcProvider oauth.Provider + // backendSecurityPolicyName provides name of backend security policy + backendSecurityPolicyName string + // backendSecurityPolicyNamespace provides namespace of backend security policy + backendSecurityPolicyNamespace string + // preRotationWindow specifies how long before expiry to rotate + preRotationWindow time.Duration } -// ----------------------------------------------------------------------------- -// Constructor and Interface Implementation -// ----------------------------------------------------------------------------- - // NewAWSOIDCRotator creates a new AWS OIDC rotator with the specified configuration. // It initializes the AWS STS client and sets up the rotation channels. func NewAWSOIDCRotator( client client.Client, kube kubernetes.Interface, logger logr.Logger, + backendSecurityPolicyNamespace string, + backendSecurityPolicyName string, + preRotationWindow time.Duration, + region string, ) (*AWSOIDCRotator, error) { cfg, err := defaultAWSConfig(context.Background()) if err != nil { return nil, fmt.Errorf("failed to load AWS config: %w", err) } - stsClient := NewSTSClient(cfg) + if region != "" { + cfg.Region = region + } - // Create OIDC provider - baseProvider := oauth.NewBaseProvider(client, logger, &http.Client{Timeout: 30 * time.Second}) - oidcProvider := oauth.NewOIDCProvider(baseProvider) + if proxyURL := os.Getenv("AI_GATEWY_STS_PROXY_URL"); proxyURL != "" { + cfg.HTTPClient = &http.Client{ + Transport: &http.Transport{ + Proxy: func(*http.Request) (*url.URL, error) { + return url.Parse(proxyURL) + }, + }, + } + } + + stsClient := NewSTSClient(cfg) return &AWSOIDCRotator{ - client: client, - kube: kube, - logger: logger, - stsOps: stsClient, - oidcProvider: oidcProvider, + client: client, + kube: kube, + logger: logger, + stsOps: stsClient, + backendSecurityPolicyNamespace: backendSecurityPolicyNamespace, + backendSecurityPolicyName: backendSecurityPolicyName, + preRotationWindow: preRotationWindow, }, nil } -// Type returns the type of rotation this rotator handles -func (r *AWSOIDCRotator) Type() RotationType { - return RotationTypeAWSOIDC -} - // SetSTSOperations sets the STS operations implementation - primarily used for testing func (r *AWSOIDCRotator) SetSTSOperations(ops STSOperations) { r.stsOps = ops } -// ----------------------------------------------------------------------------- -// Event Processing -// ----------------------------------------------------------------------------- - -// Start begins processing rotation events from the rotation channel. -// It runs until the context is cancelled, processing only events -// that match this rotator's type. -func (r *AWSOIDCRotator) Start(ctx context.Context) error { - for { - select { - case event := <-r.rotationChan: - // Only process events for this rotator type - if event.Type != RotationTypeAWSOIDC { - continue - } - - if err := r.Rotate(ctx, event); err != nil { - if !errors.Is(err, context.Canceled) { - r.logger.Error(err, "failed to rotate credentials", - "namespace", event.Namespace, - "name", event.Name) - } - } - case <-ctx.Done(): - return ctx.Err() - } +func (r *AWSOIDCRotator) IsExpired() (bool, error) { + preRotationExpirationTime := r.GetPreRotationTime() + if preRotationExpirationTime == nil { + return true, nil } + return IsExpired(0, *preRotationExpirationTime), nil } -// ----------------------------------------------------------------------------- -// Main Interface Methods - Initialize and Rotate -// ----------------------------------------------------------------------------- - -// Initialize implements the initial token retrieval for AWS OIDC tokens -func (r *AWSOIDCRotator) Initialize(ctx context.Context, event RotationEvent) error { - r.logger.Info("initializing AWS OIDC token", - "namespace", event.Namespace, - "name", event.Name) - - // Get OIDC configuration from metadata - config, err := r.getOIDCConfig(ctx, event) +func (r *AWSOIDCRotator) GetPreRotationTime() *time.Time { + secret, err := LookupSecret(context.Background(), r.client, r.backendSecurityPolicyNamespace, r.backendSecurityPolicyName) if err != nil { - return fmt.Errorf("failed to get OIDC config: %w", err) + if !errors.IsNotFound(err) { + return nil + } + return nil } - - // Fetch and validate OIDC token - token, err := r.oidcProvider.FetchToken(ctx, config) + expirationTime, err := GetExpirationSecretAnnotation(secret) if err != nil { - return fmt.Errorf("failed to fetch OIDC token: %w", err) + return nil } + preRotationTime := expirationTime.Add(-r.preRotationWindow) + return &preRotationTime +} + +// Rotate implements the retrieval and storage of AWS sts credentials +func (r *AWSOIDCRotator) Rotate(ctx context.Context, region, roleARN, token string) error { + r.logger.Info("rotating AWS sts temporary credentials", + "namespace", r.backendSecurityPolicyNamespace, + "name", r.backendSecurityPolicyName) - // Exchange token for AWS credentials - result, err := r.assumeRoleWithToken(ctx, event, token.IDToken) + result, err := r.assumeRoleWithToken(ctx, roleARN, token) if err != nil { + r.logger.Error(err, "failed to assume role", "role", roleARN, "ID", token) return err } - // Create new secret struct - secret := newSecret(event.Namespace, event.Name) - - // Get profile from metadata, defaulting to "default" if not specified - profile := event.Metadata["profile"] - if profile == "" { - profile = "default" + secret, err := LookupSecret(ctx, r.client, r.backendSecurityPolicyNamespace, r.backendSecurityPolicyName) + if err != nil { + if !errors.IsNotFound(err) { + return err + } + secret = newSecret(r.backendSecurityPolicyNamespace, r.backendSecurityPolicyName) } - // Create credentials file with the specified profile + updateExpirationSecretAnnotation(secret, *result.Credentials.Expiration) + + // For now have profile as default + profile := "default" credsFile := &awsCredentialsFile{ profiles: map[string]*awsCredentials{ profile: { @@ -158,104 +139,17 @@ func (r *AWSOIDCRotator) Initialize(ctx context.Context, event RotationEvent) er accessKeyID: aws.ToString(result.Credentials.AccessKeyId), secretAccessKey: aws.ToString(result.Credentials.SecretAccessKey), sessionToken: aws.ToString(result.Credentials.SessionToken), - region: event.Metadata["region"], + region: region, }, }, } - // Update secret with credentials updateAWSCredentialsInSecret(secret, credsFile) return updateSecret(ctx, r.client, secret) } -// Rotate implements the Rotator interface for AWS OIDC credentials -func (r *AWSOIDCRotator) Rotate(ctx context.Context, event RotationEvent) error { - if err := validateRotationEvent(event); err != nil { - return err - } - - // Get OIDC configuration from metadata - config, err := r.getOIDCConfig(ctx, event) - if err != nil { - return fmt.Errorf("failed to get OIDC config: %w", err) - } - - // Fetch and validate OIDC token - token, err := r.oidcProvider.FetchToken(ctx, config) - if err != nil { - return fmt.Errorf("failed to fetch OIDC token: %w", err) - } - - // Get existing secret - secret, err := lookupSecret(ctx, r.client, event.Namespace, event.Name) - if err != nil { - return err - } - - existingCreds, err := validateAWSSecret(secret) - if err != nil { - return err - } - - // Determine which profile to use - profile, err := profileFromMetadata(event.Metadata, existingCreds) - if err != nil { - return fmt.Errorf("failed to determine AWS profile: %w", err) - } - - // Exchange token for AWS credentials - resp, err := r.assumeRoleWithToken(ctx, event, token.IDToken) - if err != nil { - return err - } - - // Update only the specified profile's credentials - existingCreds.profiles[profile].accessKeyID = aws.ToString(resp.Credentials.AccessKeyId) - existingCreds.profiles[profile].secretAccessKey = aws.ToString(resp.Credentials.SecretAccessKey) - existingCreds.profiles[profile].sessionToken = aws.ToString(resp.Credentials.SessionToken) - existingCreds.profiles[profile].region = event.Metadata["region"] - - // Update secret with credentials - updateAWSCredentialsInSecret(secret, existingCreds) - if err := updateSecret(ctx, r.client, secret); err != nil { - return err - } - - // Schedule next rotation if needed - return r.scheduleNextRotation(event, resp) -} - -// ----------------------------------------------------------------------------- -// Helper Methods - Token and Credential Management -// ----------------------------------------------------------------------------- - -// getOIDCConfig creates an OAuth config from the rotation event metadata -func (r *AWSOIDCRotator) getOIDCConfig(ctx context.Context, event RotationEvent) (oauth.Config, error) { - // Convert metadata to expected format - params := make(map[string]string) - - // Required fields - params["token_url"] = event.Metadata["token_url"] - params["client_id"] = event.Metadata["client_id"] - params["client-secret-name"] = event.Metadata["client_secret_name"] - - // Optional fields - if issuerURL := event.Metadata["issuer_url"]; issuerURL != "" { - params["issuer_url"] = issuerURL - } - if scopes := event.Metadata["scopes"]; scopes != "" { - params["scopes"] = scopes - } - - return oauth.NewOIDCConfig(ctx, r.client, event.Namespace, params) -} - // assumeRoleWithToken exchanges an OIDC token for AWS credentials -func (r *AWSOIDCRotator) assumeRoleWithToken(ctx context.Context, event RotationEvent, token string) (*sts.AssumeRoleWithWebIdentityOutput, error) { - roleARN := event.Metadata["role_arn"] - if roleARN == "" { - roleARN = event.Metadata["role-arn"] // support both formats - } +func (r *AWSOIDCRotator) assumeRoleWithToken(ctx context.Context, roleARN, token string) (*sts.AssumeRoleWithWebIdentityOutput, error) { if roleARN == "" { return nil, fmt.Errorf("role ARN is required in metadata") } @@ -263,57 +157,6 @@ func (r *AWSOIDCRotator) assumeRoleWithToken(ctx context.Context, event Rotation return r.stsOps.AssumeRoleWithWebIdentity(ctx, &sts.AssumeRoleWithWebIdentityInput{ RoleArn: aws.String(roleARN), WebIdentityToken: aws.String(token), - RoleSessionName: aws.String(fmt.Sprintf(awsSessionNameFormat, event.Name)), + RoleSessionName: aws.String(fmt.Sprintf(awsSessionNameFormat, r.backendSecurityPolicyName)), }) } - -// ----------------------------------------------------------------------------- -// Helper Methods - Validation and Scheduling -// ----------------------------------------------------------------------------- - -// scheduleNextRotation schedules the next rotation before credentials expire -func (r *AWSOIDCRotator) scheduleNextRotation(event RotationEvent, resp *sts.AssumeRoleWithWebIdentityOutput) error { - if resp.Credentials.Expiration == nil { - return nil - } - - // Calculate when we should rotate - 5 minutes before expiry - rotateAt := resp.Credentials.Expiration.Add(-5 * time.Minute) - - // If we're not too close to expiry, schedule the next rotation - if time.Until(rotateAt) > time.Second { - // Create a new event for the next rotation, preserving all metadata - nextEvent := RotationEvent{ - Namespace: event.Namespace, - Name: event.Name, - Type: RotationTypeAWSOIDC, - Metadata: make(map[string]string), - } - - // Copy all metadata from the original event - for k, v := range event.Metadata { - nextEvent.Metadata[k] = v - } - - // Update the rotation time - nextEvent.Metadata["rotate_at"] = rotateAt.Format(time.RFC3339) - - // Send the event through the schedule channel - select { - case r.scheduleChan <- nextEvent: - r.logger.Info("scheduled next rotation", - "namespace", event.Namespace, - "name", event.Name, - "profile", event.Metadata["profile"], - "rotateAt", rotateAt) - default: - r.logger.Error(fmt.Errorf("schedule channel is full"), "failed to schedule next rotation", - "namespace", event.Namespace, - "name", event.Name, - "profile", event.Metadata["profile"], - "rotateAt", rotateAt) - } - } - - return nil -} diff --git a/internal/controller/rotators/common.go b/internal/controller/rotators/common.go index cf0164297..f1573f328 100644 --- a/internal/controller/rotators/common.go +++ b/internal/controller/rotators/common.go @@ -1 +1,81 @@ package backendauthrotators + +import ( + "context" + "fmt" + "time" + + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +const ExpirationTimeAnnotationKey = "rotators/expiration-time" + +// newSecret creates a new secret struct (does not persist to k8s) +func newSecret(namespace, name string) *corev1.Secret { + return &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + Type: corev1.SecretTypeOpaque, + Data: make(map[string][]byte), + } +} + +// updateSecret updates an existing secret or creates a new one +func updateSecret(ctx context.Context, k8sClient client.Client, secret *corev1.Secret) error { + if secret.ResourceVersion == "" { + if err := k8sClient.Create(ctx, secret); err != nil { + return fmt.Errorf("failed to create secret: %w", err) + } + } else { + if err := k8sClient.Update(ctx, secret); err != nil { + return fmt.Errorf("failed to update secret: %w", err) + } + } + return nil +} + +// LookupSecret retrieves an existing secret +func LookupSecret(ctx context.Context, k8sClient client.Client, namespace, name string) (*corev1.Secret, error) { + secret := &corev1.Secret{} + if err := k8sClient.Get(ctx, client.ObjectKey{ + Namespace: namespace, + Name: name, + }, secret); err != nil { + if errors.IsNotFound(err) { + return nil, err + } + return nil, fmt.Errorf("failed to get secret: %w", err) + } + return secret, nil +} + +// updateExpirationSecretAnnotation will set the expiration time of credentials set in secret annotation +func updateExpirationSecretAnnotation(secret *corev1.Secret, time time.Time) { + if secret.Annotations == nil { + secret.Annotations = make(map[string]string) + } + secret.Annotations[ExpirationTimeAnnotationKey] = time.String() +} + +// GetExpirationSecretAnnotation will get the expiration time of credentials set in secret annotation +func GetExpirationSecretAnnotation(secret *corev1.Secret) (*time.Time, error) { + expirationTimeAnnotationKey, ok := secret.Annotations[ExpirationTimeAnnotationKey] + if !ok { + return nil, fmt.Errorf("secret %s/%s missing expiration time annotation", secret.Namespace, secret.Name) + } + + expirationTime, err := time.Parse(time.RFC3339, expirationTimeAnnotationKey) + if err != nil { + return nil, fmt.Errorf("failed to parse expiration time annotation: %w", err) + } + return &expirationTime, nil +} + +func IsExpired(preRotationInterval time.Duration, expirationTime time.Time) bool { + return expirationTime.Add(-preRotationInterval).Before(time.Now()) +} From 9ed1c0258b39ed307069a61ab14b296d05dbadc5 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 6 Feb 2025 15:04:54 -0500 Subject: [PATCH 03/86] linting Signed-off-by: Aaron Choo --- .../controller/backend_security_policy.go | 41 ++++++++++++- internal/controller/oauth/oidc_provider.go | 1 - internal/controller/sink.go | 57 ++++++++++++++++++- 3 files changed, 92 insertions(+), 7 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index cc252a76a..78ffb1c02 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -7,7 +7,9 @@ package controller import ( "context" + "time" + egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/go-logr/logr" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/client-go/kubernetes" @@ -15,6 +17,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1" + backendauthrotators "github.com/envoyproxy/ai-gateway/internal/controller/rotators" ) // backendSecurityPolicyController implements [reconcile.TypedReconciler] for [aigv1a1.BackendSecurityPolicy]. @@ -37,9 +40,9 @@ func newBackendSecurityPolicyController(client client.Client, kube kubernetes.In } // Reconcile implements the [reconcile.TypedReconciler] for [aigv1a1.BackendSecurityPolicy]. -func (b backendSecurityPolicyController) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { +func (b backendSecurityPolicyController) Reconcile(ctx context.Context, req ctrl.Request) (res ctrl.Result, err error) { var backendSecurityPolicy aigv1a1.BackendSecurityPolicy - if err := b.client.Get(ctx, req.NamespacedName, &backendSecurityPolicy); err != nil { + if err = b.client.Get(ctx, req.NamespacedName, &backendSecurityPolicy); err != nil { if errors.IsNotFound(err) { ctrl.Log.Info("Deleting Backend Security Policy", "namespace", req.Namespace, "name", req.Name) @@ -48,7 +51,39 @@ func (b backendSecurityPolicyController) Reconcile(ctx context.Context, req ctrl return ctrl.Result{}, err } + if isBackendSecurityPolicyAuthOIDC(backendSecurityPolicy.Spec) { + var requeue time.Duration + requeue = time.Minute + region := backendSecurityPolicy.Spec.AWSCredentials.Region + rotator, err := backendauthrotators.NewAWSOIDCRotator(b.client, b.kube, b.logger, backendSecurityPolicy.Namespace, backendSecurityPolicy.Name, preRotationWindow, region) + if err != nil { + b.logger.Error(err, "failed to create AWS OIDC rotator") + } else if expired, err := rotator.IsExpired(); err != nil && !expired { + requeue = time.Until(*rotator.GetPreRotationTime()) + if requeue.Seconds() == 0 { + requeue = time.Minute + } + } + res = ctrl.Result{RequeueAfter: requeue, Requeue: true} + } // Send the backend security policy to the config sink so that it can modify the configuration together with the state of other resources. b.eventChan <- backendSecurityPolicy.DeepCopy() - return ctrl.Result{}, nil + return +} + +// Move to rotator file +func isBackendSecurityPolicyAuthOIDC(spec aigv1a1.BackendSecurityPolicySpec) bool { + if spec.AWSCredentials != nil { + return spec.AWSCredentials.OIDCExchangeToken != nil + } + return false +} + +func getBackendSecurityPolicyAuthOIDC(spec aigv1a1.BackendSecurityPolicySpec) *egv1a1.OIDC { + if isBackendSecurityPolicyAuthOIDC(spec) { + if spec.AWSCredentials.OIDCExchangeToken != nil { + return &spec.AWSCredentials.OIDCExchangeToken.OIDC + } + } + return nil } diff --git a/internal/controller/oauth/oidc_provider.go b/internal/controller/oauth/oidc_provider.go index 8f0725333..3550b4b39 100644 --- a/internal/controller/oauth/oidc_provider.go +++ b/internal/controller/oauth/oidc_provider.go @@ -52,7 +52,6 @@ func (p *OIDCProvider) getOIDCMetadata(ctx context.Context, issuerURL string) (* } resp, err := p.httpClient.Do(req) - if err != nil { return nil, fmt.Errorf("failed to fetch OIDC metadata: %w", err) } diff --git a/internal/controller/sink.go b/internal/controller/sink.go index c28630f90..00cebd780 100644 --- a/internal/controller/sink.go +++ b/internal/controller/sink.go @@ -8,7 +8,9 @@ package controller import ( "context" "fmt" + "net/http" "path" + "time" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/go-logr/logr" @@ -27,6 +29,8 @@ import ( aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1" "github.com/envoyproxy/ai-gateway/filterapi" + "github.com/envoyproxy/ai-gateway/internal/controller/oauth" + backendauthrotators "github.com/envoyproxy/ai-gateway/internal/controller/rotators" "github.com/envoyproxy/ai-gateway/internal/llmcostcel" ) @@ -41,6 +45,10 @@ const ( // secret with backendSecurityPolicy auth instead of mounting new secret files to the external proc. const mountedExtProcSecretPath = "/etc/backend_security_policy" // #nosec G101 +// preRotationWindow specifies how long before expiry to rotate credentials +// temporarily a fixed duration +const preRotationWindow = 5 * time.Minute + // ConfigSinkEvent is the interface for the events that the configSink can handle. // It can be either an AIServiceBackend, an AIGatewayRoute, or a deletion event. // @@ -65,6 +73,7 @@ type configSink struct { extProcImagePullPolicy corev1.PullPolicy extProcLogLevel string eventChan chan ConfigSinkEvent + oidcTokenCache map[string]*oauth.TokenResponse } func newConfigSink( @@ -83,6 +92,7 @@ func newConfigSink( extProcImagePullPolicy: corev1.PullIfNotPresent, extProcLogLevel: extProcLogLevel, eventChan: eventChan, + oidcTokenCache: make(map[string]*oauth.TokenResponse), } return c } @@ -256,6 +266,48 @@ func (c *configSink) syncBackendSecurityPolicy(ctx context.Context, bsp *aigv1a1 aiBackend := &aiServiceBackends.Items[i] c.syncAIServiceBackend(ctx, aiBackend) } + + if isBackendSecurityPolicyAuthOIDC(bsp.Spec) { + tokenResponse, ok := c.oidcTokenCache[key] + if !ok || backendauthrotators.IsExpired(preRotationWindow, tokenResponse.ExpiresAt) { + baseProvider := oauth.NewBaseProvider(c.client, c.logger, &http.Client{Timeout: 30 * time.Second}) + oidcProvider := oauth.NewOIDCProvider(baseProvider, getBackendSecurityPolicyAuthOIDC(bsp.Spec)) + + tokenRes, err := oidcProvider.FetchToken(context.TODO()) + if err != nil { + c.logger.Error(err, "failed to fetch OIDC provider token") + return + } + c.oidcTokenCache[key] = tokenRes + tokenResponse = tokenRes + } + + awsCredentials := bsp.Spec.AWSCredentials + rotator, err := backendauthrotators.NewAWSOIDCRotator(c.client, c.kube, c.logger, bsp.Namespace, bsp.Name, preRotationWindow, awsCredentials.Region) + if err != nil { + c.logger.Error(err, "failed to create AWS OIDC rotator") + return + } + + expired, err := rotator.IsExpired() + if err != nil { + c.logger.Error(err, "failed to check if AWS OIDC rotator is expired") + return + } + + if expired { + token := tokenResponse.IDToken + if token == "" { + token = tokenResponse.AccessToken + } + + err = rotator.Rotate(context.Background(), awsCredentials.Region, awsCredentials.OIDCExchangeToken.AwsRoleArn, token) + if err != nil { + c.logger.Error(err, "failed to rotate AWS OIDC exchange token") + return + } + } + } } // updateExtProcConfigMap updates the external process configmap with the new AIGatewayRoute. @@ -309,7 +361,7 @@ func (c *configSink) updateExtProcConfigMap(ctx context.Context, aiGatewayRoute if backendSecurityPolicy.Spec.AWSCredentials == nil { return fmt.Errorf("AWSCredentials type selected but not defined %s", backendSecurityPolicy.Name) } - if backendSecurityPolicy.Spec.AWSCredentials.CredentialsFile != nil { + if awsCred := backendSecurityPolicy.Spec.AWSCredentials; awsCred.CredentialsFile != nil || awsCred.OIDCExchangeToken != nil { ec.Rules[i].Backends[j].Auth = &filterapi.BackendAuth{ AWSAuth: &filterapi.AWSAuth{ CredentialFileName: path.Join(backendSecurityMountPath(volumeName), "/credentials"), @@ -607,8 +659,7 @@ func (c *configSink) mountBackendSecurityPolicySecrets(ctx context.Context, spec if backendSecurityPolicy.Spec.AWSCredentials.CredentialsFile != nil { secretName = string(backendSecurityPolicy.Spec.AWSCredentials.CredentialsFile.SecretRef.Name) } else { - // Will introduce OIDC in a following PR - continue + secretName = backendSecurityPolicy.Name } default: return nil, fmt.Errorf("backend security policy %s is not supported", backendSecurityPolicy.Spec.Type) From ed575c4d28b7aae1be5e54cf3cd8e8cc672e635d Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 6 Feb 2025 15:05:19 -0500 Subject: [PATCH 04/86] enable proxy for aws and mounted cm Signed-off-by: Aaron Choo --- .../ai-gateway-helm/templates/deployment.yaml | 26 +++++++++++++++++++ manifests/charts/ai-gateway-helm/values.yaml | 12 +++++++++ 2 files changed, 38 insertions(+) diff --git a/manifests/charts/ai-gateway-helm/templates/deployment.yaml b/manifests/charts/ai-gateway-helm/templates/deployment.yaml index ea8a13414..c7ea8f829 100644 --- a/manifests/charts/ai-gateway-helm/templates/deployment.yaml +++ b/manifests/charts/ai-gateway-helm/templates/deployment.yaml @@ -47,6 +47,13 @@ spec: port: 1063 initialDelaySeconds: 5 periodSeconds: 2 + {{- if .Values.controller.podEnv }} + env: + {{- range $key, $val := .Values.controller.podEnv }} + - name: {{ $key }} + value: {{ $val }} + {{- end }} + {{- end }} readinessProbe: grpc: port: 1063 @@ -54,6 +61,25 @@ spec: periodSeconds: 2 resources: {{- toYaml .Values.controller.resources | nindent 12 }} + {{- if .Values.controller.volumes }} + volumeMounts: + {{- range $volume := .Values.controller.volumes }} + - mountPath: {{ $volume.mountPath }} + name: {{ $volume.name }} + {{- if $volume.subPath }} + subPath: {{ $volume.subPath }} + {{- end }} + {{- end}} + {{- end }} + {{- if .Values.controller.volumes }} + volumes: + {{- range $volume := .Values.controller.volumes }} + - name: {{ $volume.name }} + configMap: + defaultMode: {{ $volume.configmap.defaultMode }} + name: {{ $volume.configmap.name }} + {{- end }} + {{- end }} {{- with .Values.controller.nodeSelector }} nodeSelector: {{- toYaml . | nindent 8 }} diff --git a/manifests/charts/ai-gateway-helm/values.yaml b/manifests/charts/ai-gateway-helm/values.yaml index f8257eb89..b6f3d6a36 100644 --- a/manifests/charts/ai-gateway-helm/values.yaml +++ b/manifests/charts/ai-gateway-helm/values.yaml @@ -38,6 +38,18 @@ controller: podAnnotations: {} podSecurityContext: {} securityContext: {} + # Example of a podEnv + # - key: AI_GATEWY_STS_PROXY_URL + # value: some-proxy-placeholder + podEnv: {} + # Example of volumes + # - mountPath: /placeholder/path + # name: volume-name + # subPath: placeholder-sub-path + # configmap: + # defaultMode: placeholder + # name: configmap-name + volumes: {} service: type: ClusterIP ports: From 98fa80263f947c7f098bdb62cb4a4ddbea6d6671 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 6 Feb 2025 15:05:33 -0500 Subject: [PATCH 05/86] update go packages Signed-off-by: Aaron Choo --- go.mod | 3 ++- go.sum | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index c54cc3d00..3c016dcd4 100644 --- a/go.mod +++ b/go.mod @@ -6,9 +6,11 @@ require ( github.com/aws/aws-sdk-go-v2 v1.36.1 github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.8 github.com/aws/aws-sdk-go-v2/config v1.29.6 + github.com/aws/aws-sdk-go-v2/service/sts v1.33.14 github.com/envoyproxy/gateway v1.3.0 github.com/envoyproxy/go-control-plane/envoy v1.32.4 github.com/go-logr/logr v1.4.2 + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/cel-go v0.23.2 github.com/google/go-cmp v0.6.0 github.com/openai/openai-go v0.1.0-alpha.56 @@ -79,7 +81,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.13 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.24.15 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.14 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.33.14 // indirect github.com/aws/smithy-go v1.22.2 // indirect github.com/baulk/chardet v0.1.0 // indirect github.com/beorn7/perks v1.0.1 // indirect diff --git a/go.sum b/go.sum index 154bd492e..a47a2aaa0 100644 --- a/go.sum +++ b/go.sum @@ -359,6 +359,8 @@ github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUv github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golangci/dupl v0.0.0-20180902072040-3e9179ac440a h1:w8hkcTqaFpzKqonE9uMCefW1WDie15eSP/4MssdenaM= From 134191182216f806dbc8c84de3074a52f4657e8a Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 6 Feb 2025 15:26:41 -0500 Subject: [PATCH 06/86] precommit test Signed-off-by: Aaron Choo --- internal/controller/oauth/client_credentials_provider.go | 2 +- internal/controller/oauth/oidc_provider.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/controller/oauth/client_credentials_provider.go b/internal/controller/oauth/client_credentials_provider.go index a97f4c9ac..7687d9535 100644 --- a/internal/controller/oauth/client_credentials_provider.go +++ b/internal/controller/oauth/client_credentials_provider.go @@ -92,7 +92,7 @@ func (p *ClientCredentialsProvider) SupportsFlow(flowType FlowType) bool { return flowType == FlowClientCredentials } -func (p *ClientCredentialsProvider) ValidateToken(ctx context.Context, token string) error { +func (p *ClientCredentialsProvider) ValidateToken(_ context.Context, _ string) error { // Implement token validation logic // This might involve introspection endpoint if available return nil diff --git a/internal/controller/oauth/oidc_provider.go b/internal/controller/oauth/oidc_provider.go index 3550b4b39..d65e2b51d 100644 --- a/internal/controller/oauth/oidc_provider.go +++ b/internal/controller/oauth/oidc_provider.go @@ -83,7 +83,7 @@ func (p *OIDCProvider) validateIDToken(ctx context.Context, rawIDToken, issuerUR return nil, fmt.Errorf("context error before validation: %w", err) } - token, err := jwt.Parse(rawIDToken, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.Parse(rawIDToken, func(_ *jwt.Token) (interface{}, error) { // For now, we skip signature validation as we don't have the key // TODO: Implement JWKS validation return jwt.UnsafeAllowNoneSignatureType, nil @@ -204,7 +204,7 @@ func (p *OIDCProvider) SupportsFlow(flowType FlowType) bool { } // ValidateToken implements token validation for both access tokens and ID tokens -func (p *OIDCProvider) ValidateToken(ctx context.Context, token string) error { +func (p *OIDCProvider) ValidateToken(_ context.Context, _ string) error { // For ID tokens, we expect them to have been validated during GetToken // For access tokens, we could implement introspection here if needed return nil From 88f53fbd905e9321106532fb499fe542c44e83f1 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 6 Feb 2025 15:39:38 -0500 Subject: [PATCH 07/86] remove comment Signed-off-by: Aaron Choo --- internal/controller/backend_security_policy.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 78ffb1c02..dbf0b408b 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -71,7 +71,6 @@ func (b backendSecurityPolicyController) Reconcile(ctx context.Context, req ctrl return } -// Move to rotator file func isBackendSecurityPolicyAuthOIDC(spec aigv1a1.BackendSecurityPolicySpec) bool { if spec.AWSCredentials != nil { return spec.AWSCredentials.OIDCExchangeToken != nil From 40beefc31de92e80ee66c9258f0fe1bd498a0dc8 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 6 Feb 2025 17:39:50 -0500 Subject: [PATCH 08/86] set up index for aws oidc Signed-off-by: Aaron Choo --- internal/controller/controller.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/controller/controller.go b/internal/controller/controller.go index ec728d45f..fb3756ad7 100644 --- a/internal/controller/controller.go +++ b/internal/controller/controller.go @@ -186,8 +186,9 @@ func backendSecurityPolicyIndexFunc(o client.Object) []string { awsCreds := backendSecurityPolicy.Spec.AWSCredentials if awsCreds.CredentialsFile != nil { key = getSecretNameAndNamespace(awsCreds.CredentialsFile.SecretRef, backendSecurityPolicy.Namespace) + } else if awsCreds.OIDCExchangeToken != nil { + key = fmt.Sprintf("%s.%s", backendSecurityPolicy.Name, backendSecurityPolicy.Namespace) } - // TODO: OIDC. } return []string{key} } From 88f5fc7277bd62d6148f330ec1b4f1850ade4ea3 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Fri, 7 Feb 2025 11:20:42 -0500 Subject: [PATCH 09/86] trigger refresh on controller start up Signed-off-by: Aaron Choo --- .../controller/backend_security_policy.go | 39 ++++++++++++++----- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index dbf0b408b..696693db0 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -24,23 +24,44 @@ import ( // // This handles the BackendSecurityPolicy resource and sends it to the config sink so that it can modify configuration. type backendSecurityPolicyController struct { - client client.Client - kube kubernetes.Interface - logger logr.Logger - eventChan chan ConfigSinkEvent + client client.Client + kube kubernetes.Interface + logger logr.Logger + eventChan chan ConfigSinkEvent + reconcileAll bool } func newBackendSecurityPolicyController(client client.Client, kube kubernetes.Interface, logger logr.Logger, ch chan ConfigSinkEvent) *backendSecurityPolicyController { return &backendSecurityPolicyController{ - client: client, - kube: kube, - logger: logger, - eventChan: ch, + client: client, + kube: kube, + logger: logger, + eventChan: ch, + reconcileAll: true, } } // Reconcile implements the [reconcile.TypedReconciler] for [aigv1a1.BackendSecurityPolicy]. -func (b backendSecurityPolicyController) Reconcile(ctx context.Context, req ctrl.Request) (res ctrl.Result, err error) { +func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctrl.Request) (res ctrl.Result, err error) { + if b.reconcileAll { + var backendSecPolicyList aigv1a1.BackendSecurityPolicyList + err = b.client.List(ctx, &backendSecPolicyList) + if err != nil { + b.logger.Error(err, "failed to trigger refresh for existing backendSecPolicy resources") + } else { + refreshTime := time.Now().String() + for _, backendSecurityPolicy := range backendSecPolicyList.Items { + if isBackendSecurityPolicyAuthOIDC(backendSecurityPolicy.Spec) { + backendSecurityPolicy.Annotations["refresh"] = refreshTime + } + err = b.client.Update(ctx, &backendSecurityPolicy) + if err != nil { + b.logger.Error(err, "failed to trigger refresh for existing backendSecPolicy resource", "name", backendSecurityPolicy.Name) + } + } + b.reconcileAll = false + } + } var backendSecurityPolicy aigv1a1.BackendSecurityPolicy if err = b.client.Get(ctx, req.NamespacedName, &backendSecurityPolicy); err != nil { if errors.IsNotFound(err) { From d30315c49c4cd48383c703f3a17efff58baa0571 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Fri, 7 Feb 2025 11:49:52 -0500 Subject: [PATCH 10/86] backend security policy tests Signed-off-by: Aaron Choo --- .../controller/backend_security_policy.go | 3 +++ .../backend_security_policy_test.go | 24 ++++++++++++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 696693db0..341661b0d 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -52,6 +52,9 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr refreshTime := time.Now().String() for _, backendSecurityPolicy := range backendSecPolicyList.Items { if isBackendSecurityPolicyAuthOIDC(backendSecurityPolicy.Spec) { + if len(backendSecurityPolicy.Annotations) == 0 { + backendSecurityPolicy.Annotations = make(map[string]string) + } backendSecurityPolicy.Annotations["refresh"] = refreshTime } err = b.client.Update(ctx, &backendSecurityPolicy) diff --git a/internal/controller/backend_security_policy_test.go b/internal/controller/backend_security_policy_test.go index d1b1ea7e3..e02e4fced 100644 --- a/internal/controller/backend_security_policy_test.go +++ b/internal/controller/backend_security_policy_test.go @@ -6,6 +6,8 @@ package controller import ( + "context" + "fmt" "testing" "github.com/stretchr/testify/require" @@ -28,7 +30,18 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { err := cl.Create(t.Context(), &aigv1a1.BackendSecurityPolicy{ObjectMeta: metav1.ObjectMeta{Name: backendSecurityPolicyName, Namespace: namespace}}) require.NoError(t, err) - _, err = c.Reconcile(t.Context(), reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: backendSecurityPolicyName}}) + err = cl.Create(context.Background(), &aigv1a1.BackendSecurityPolicy{ + ObjectMeta: metav1.ObjectMeta{Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName), Namespace: namespace}, + Spec: aigv1a1.BackendSecurityPolicySpec{ + Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, + AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ + Region: "us-east-1", + OIDCExchangeToken: &aigv1a1.AWSOIDCExchangeToken{}, + }, + }, + }) + require.NoError(t, err) + _, err = c.Reconcile(context.Background(), reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: backendSecurityPolicyName}}) require.NoError(t, err) item, ok := <-ch require.True(t, ok) @@ -36,6 +49,15 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { require.Equal(t, backendSecurityPolicyName, item.(*aigv1a1.BackendSecurityPolicy).Name) require.Equal(t, namespace, item.(*aigv1a1.BackendSecurityPolicy).Namespace) + // Test backendSecurityPolicy with OIDC credentials have the annotation added + oidcBackendSecurityPolicy := &aigv1a1.BackendSecurityPolicy{} + err = cl.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName)}, oidcBackendSecurityPolicy) + require.NoError(t, err) + require.Len(t, oidcBackendSecurityPolicy.Annotations, 1) + time, ok := oidcBackendSecurityPolicy.Annotations["refresh"] + require.True(t, ok) + require.NotEmpty(t, time) + // Test the case where the BackendSecurityPolicy is being deleted. err = cl.Delete(t.Context(), &aigv1a1.BackendSecurityPolicy{ObjectMeta: metav1.ObjectMeta{Name: backendSecurityPolicyName, Namespace: namespace}}) require.NoError(t, err) From 10f842ab671756b65579af143820ef08dcce00c0 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Fri, 7 Feb 2025 16:15:54 -0500 Subject: [PATCH 11/86] patch reconcile annotation Signed-off-by: Aaron Choo --- .../controller/backend_security_policy.go | 23 ++++--- .../backend_security_policy_test.go | 66 ++++++++++++++++++- .../controller/oauth/base_provider_test.go | 0 .../oauth/client_credentials_provider_test.go | 0 .../controller/oauth/oidc_provider_test.go | 0 5 files changed, 80 insertions(+), 9 deletions(-) create mode 100644 internal/controller/oauth/base_provider_test.go create mode 100644 internal/controller/oauth/client_credentials_provider_test.go create mode 100644 internal/controller/oauth/oidc_provider_test.go diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 341661b0d..ed780ace5 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -7,11 +7,13 @@ package controller import ( "context" + "fmt" "time" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/go-logr/logr" "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/kubernetes" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" @@ -41,6 +43,8 @@ func newBackendSecurityPolicyController(client client.Client, kube kubernetes.In } } +type patchBackendSecurityPolicy struct{} + // Reconcile implements the [reconcile.TypedReconciler] for [aigv1a1.BackendSecurityPolicy]. func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctrl.Request) (res ctrl.Result, err error) { if b.reconcileAll { @@ -49,17 +53,12 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr if err != nil { b.logger.Error(err, "failed to trigger refresh for existing backendSecPolicy resources") } else { - refreshTime := time.Now().String() for _, backendSecurityPolicy := range backendSecPolicyList.Items { if isBackendSecurityPolicyAuthOIDC(backendSecurityPolicy.Spec) { - if len(backendSecurityPolicy.Annotations) == 0 { - backendSecurityPolicy.Annotations = make(map[string]string) + err = b.client.Patch(ctx, &backendSecurityPolicy, patchBackendSecurityPolicy{}) + if err != nil { + b.logger.Error(err, "failed to trigger refresh for existing backendSecPolicy resource", "name", backendSecurityPolicy.Name) } - backendSecurityPolicy.Annotations["refresh"] = refreshTime - } - err = b.client.Update(ctx, &backendSecurityPolicy) - if err != nil { - b.logger.Error(err, "failed to trigger refresh for existing backendSecPolicy resource", "name", backendSecurityPolicy.Name) } } b.reconcileAll = false @@ -110,3 +109,11 @@ func getBackendSecurityPolicyAuthOIDC(spec aigv1a1.BackendSecurityPolicySpec) *e } return nil } + +func (p patchBackendSecurityPolicy) Type() types.PatchType { + return types.MergePatchType +} + +func (p patchBackendSecurityPolicy) Data(_ client.Object) ([]byte, error) { + return []byte(fmt.Sprintf(`{"metadata":{"annotations":{"%s":"%s"}}}`, "reconcile", time.Now().String())), nil +} diff --git a/internal/controller/backend_security_policy_test.go b/internal/controller/backend_security_policy_test.go index e02e4fced..64a070dcc 100644 --- a/internal/controller/backend_security_policy_test.go +++ b/internal/controller/backend_security_policy_test.go @@ -10,6 +10,7 @@ import ( "fmt" "testing" + egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/stretchr/testify/require" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" @@ -17,6 +18,7 @@ import ( ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/controller-runtime/pkg/reconcile" + gwapiv1 "sigs.k8s.io/gateway-api/apis/v1" aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1" ) @@ -54,7 +56,7 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { err = cl.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName)}, oidcBackendSecurityPolicy) require.NoError(t, err) require.Len(t, oidcBackendSecurityPolicy.Annotations, 1) - time, ok := oidcBackendSecurityPolicy.Annotations["refresh"] + time, ok := oidcBackendSecurityPolicy.Annotations["reconcile"] require.True(t, ok) require.NotEmpty(t, time) @@ -64,3 +66,65 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { _, err = c.Reconcile(t.Context(), reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: backendSecurityPolicyName}}) require.NoError(t, err) } + +func TestBackendSecurityController_IsBackendSecurityPolicyAuthOIDC(t *testing.T) { + require.False(t, isBackendSecurityPolicyAuthOIDC(aigv1a1.BackendSecurityPolicySpec{ + Type: aigv1a1.BackendSecurityPolicyTypeAPIKey, + APIKey: &aigv1a1.BackendSecurityPolicyAPIKey{}, + })) + + require.False(t, isBackendSecurityPolicyAuthOIDC(aigv1a1.BackendSecurityPolicySpec{ + Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, + AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ + Region: "us-east-1", + CredentialsFile: &aigv1a1.AWSCredentialsFile{}, + }, + })) + + require.True(t, isBackendSecurityPolicyAuthOIDC(aigv1a1.BackendSecurityPolicySpec{ + Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, + AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ + Region: "us-east-1", + OIDCExchangeToken: &aigv1a1.AWSOIDCExchangeToken{}, + }, + })) +} + +func TestBackendSecurityController_GetBackendSecurityPolicyAuthOIDC(t *testing.T) { + require.Nil(t, getBackendSecurityPolicyAuthOIDC(aigv1a1.BackendSecurityPolicySpec{ + Type: aigv1a1.BackendSecurityPolicyTypeAPIKey, + APIKey: &aigv1a1.BackendSecurityPolicyAPIKey{}, + })) + + require.Nil(t, getBackendSecurityPolicyAuthOIDC(aigv1a1.BackendSecurityPolicySpec{ + Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, + AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ + Region: "us-east-1", + CredentialsFile: &aigv1a1.AWSCredentialsFile{}, + }, + })) + + oidc := egv1a1.OIDC{ + Provider: egv1a1.OIDCProvider{ + Issuer: "https://oidc.example.com", + }, + ClientID: "client-id", + ClientSecret: gwapiv1.SecretObjectReference{ + Name: "client-secret", + }, + } + + actualOIDC := getBackendSecurityPolicyAuthOIDC(aigv1a1.BackendSecurityPolicySpec{ + Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, + AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ + Region: "us-east-1", + OIDCExchangeToken: &aigv1a1.AWSOIDCExchangeToken{ + OIDC: oidc, + }, + }, + }) + require.NotNil(t, actualOIDC) + require.Equal(t, oidc.ClientID, actualOIDC.ClientID) + require.Equal(t, oidc.Provider.Issuer, actualOIDC.Provider.Issuer) + require.Equal(t, oidc.ClientSecret.Name, actualOIDC.ClientSecret.Name) +} diff --git a/internal/controller/oauth/base_provider_test.go b/internal/controller/oauth/base_provider_test.go new file mode 100644 index 000000000..e69de29bb diff --git a/internal/controller/oauth/client_credentials_provider_test.go b/internal/controller/oauth/client_credentials_provider_test.go new file mode 100644 index 000000000..e69de29bb diff --git a/internal/controller/oauth/oidc_provider_test.go b/internal/controller/oauth/oidc_provider_test.go new file mode 100644 index 000000000..e69de29bb From 01fbe61f5c401c4b946f6285ae7873f2818c1a32 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Fri, 7 Feb 2025 17:50:47 -0500 Subject: [PATCH 12/86] testing Signed-off-by: Aaron Choo --- .../controller/oauth/base_provider_test.go | 50 ++++++++++++ .../oauth/client_credentials_provider_test.go | 81 +++++++++++++++++++ .../controller/oauth/oidc_provider_test.go | 56 +++++++++++++ 3 files changed, 187 insertions(+) diff --git a/internal/controller/oauth/base_provider_test.go b/internal/controller/oauth/base_provider_test.go index e69de29bb..518e93ea0 100644 --- a/internal/controller/oauth/base_provider_test.go +++ b/internal/controller/oauth/base_provider_test.go @@ -0,0 +1,50 @@ +package oauth + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client/fake" +) + +func TestNewBaseProvider(t *testing.T) { + scheme := runtime.NewScheme() + cl := fake.NewClientBuilder().WithScheme(scheme).Build() + require.NotNil(t, NewBaseProvider(cl, ctrl.Log, nil)) +} + +func TestBaseProvider_GetClientSecret(t *testing.T) { + scheme := runtime.NewScheme() + scheme.AddKnownTypes(corev1.SchemeGroupVersion, + &corev1.Secret{}, + ) + cl := fake.NewClientBuilder().WithScheme(scheme).Build() + baseProvider := NewBaseProvider(cl, ctrl.Log, nil) + + secretName, secretNamespace := "secret", "secret-ns" + err := cl.Create(context.Background(), &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: secretName, + Namespace: secretNamespace, + }, + Immutable: nil, + Data: map[string][]byte{ + "client-secret": []byte("client-secret"), + }, + StringData: nil, + Type: "", + }) + require.NoError(t, err) + + secret, err := baseProvider.getClientSecret(context.Background(), &corev1.SecretReference{ + Name: secretName, + Namespace: secretNamespace, + }) + require.NoError(t, err) + require.Equal(t, "client-secret", secret) +} diff --git a/internal/controller/oauth/client_credentials_provider_test.go b/internal/controller/oauth/client_credentials_provider_test.go index e69de29bb..d5cf39775 100644 --- a/internal/controller/oauth/client_credentials_provider_test.go +++ b/internal/controller/oauth/client_credentials_provider_test.go @@ -0,0 +1,81 @@ +package oauth + +import ( + "context" + egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "net/http" + "net/http/httptest" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + gwapiv1 "sigs.k8s.io/gateway-api/apis/v1" + "testing" + "time" +) + +func TestNewClientCredentialsProvider(t *testing.T) { + require.NotNil(t, NewClientCredentialsProvider(nil)) +} + +func TestClientCredentialsProvider_FetchToken(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(`{"access_token": "token", "token_type": "Bearer", "expires_in": 3600}`)) + })) + defer ts.Close() + + scheme := runtime.NewScheme() + scheme.AddKnownTypes(corev1.SchemeGroupVersion, + &corev1.Secret{}, + ) + cl := fake.NewClientBuilder().WithScheme(scheme).Build() + baseProvider := NewBaseProvider(cl, ctrl.Log, nil) + require.NotNil(t, baseProvider) + + secretName, secretNamespace := "secret", "secret-ns" + err := cl.Create(context.Background(), &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: secretName, + Namespace: secretNamespace, + }, + Immutable: nil, + Data: map[string][]byte{ + "client-secret": []byte("client-secret"), + }, + StringData: nil, + Type: "", + }) + require.NoError(t, err) + + clientProvider := NewClientCredentialsProvider(baseProvider) + require.NotNil(t, clientProvider) + + namespaceRef := gwapiv1.Namespace(secretNamespace) + token, err := clientProvider.FetchToken(context.Background(), &egv1a1.OIDC{ + Provider: egv1a1.OIDCProvider{ + Issuer: ts.URL, + TokenEndpoint: &ts.URL, + }, + ClientID: "some-client-id", + ClientSecret: gwapiv1.SecretObjectReference{ + Name: gwapiv1.ObjectName(secretName), + Namespace: &namespaceRef, + }, + }) + require.NoError(t, err) + require.Equal(t, "token", token.AccessToken) + require.WithinRangef(t, token.ExpiresAt, time.Now().Add(3590*time.Second), time.Now().Add(3600*time.Second), "token expires at") +} + +func TestClientCredentialsProvider_SupportsFlow(t *testing.T) { + provider := NewClientCredentialsProvider(nil) + require.True(t, provider.SupportsFlow(FlowClientCredentials)) + require.False(t, provider.SupportsFlow(FlowClientCredentialsWithIDToken)) +} + +func TestClientCredentialsProvider_ValidateToken(t *testing.T) { + provider := NewClientCredentialsProvider(nil) + require.Nil(t, provider.ValidateToken(context.Background(), "")) +} diff --git a/internal/controller/oauth/oidc_provider_test.go b/internal/controller/oauth/oidc_provider_test.go index e69de29bb..31523495c 100644 --- a/internal/controller/oauth/oidc_provider_test.go +++ b/internal/controller/oauth/oidc_provider_test.go @@ -0,0 +1,56 @@ +package oauth + +import ( + "context" + egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/runtime" + "net/http" + "net/http/httptest" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + "testing" +) + +func TestNewOIDCProvider(t *testing.T) { + require.NotNil(t, NewOIDCProvider(nil, &egv1a1.OIDC{})) +} + +func TestOIDCProvider_GetOIDCMetadata(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(`{"issuer": "issuer", "token_endpoint": "token_endpoint", "authorization_endpoint": "authorization_endpoint", "jwks_uri": "jwks_uri", "scopes_supported": []}`)) + })) + defer ts.Close() + + scheme := runtime.NewScheme() + scheme.AddKnownTypes(corev1.SchemeGroupVersion, + &corev1.Secret{}, + ) + cl := fake.NewClientBuilder().WithScheme(scheme).Build() + baseProvider := NewBaseProvider(cl, ctrl.Log, nil) + require.NotNil(t, baseProvider) + + oidc := &egv1a1.OIDC{ + Provider: egv1a1.OIDCProvider{ + Issuer: ts.URL, + TokenEndpoint: &ts.URL, + }, + ClientID: "some-client-id", + } + + oidcProvider := NewOIDCProvider(baseProvider, oidc) + metadata, err := oidcProvider.getOIDCMetadata(context.Background(), ts.URL) + require.NoError(t, err) + require.Equal(t, "token_endpoint", metadata.TokenEndpoint) + require.Equal(t, "issuer", metadata.Issuer) +} +func TestOIDCProvider_validateIDToken(t *testing.T) {} + +func TestOIDCProvider_FetchToken(t *testing.T) { + +} + +func TestOIDCProvider_SupportsFlow(t *testing.T) {} + +func TestOIDCProvider_ValidateToken(t *testing.T) {} From 74cada37e346656ec745ba502ebdcaf6ff182504 Mon Sep 17 00:00:00 2001 From: Dan Sun Date: Sun, 9 Feb 2025 13:59:29 -0500 Subject: [PATCH 13/86] Refactor oauth2 code Signed-off-by: Dan Sun --- go.mod | 5 +- go.sum | 6 +- internal/controller/oauth/base_provider.go | 10 +- .../controller/oauth/base_provider_test.go | 4 +- .../oauth/client_credentials_provider.go | 80 +++---------- .../oauth/client_credentials_provider_test.go | 60 ++++++---- internal/controller/oauth/oidc_provider.go | 110 ++---------------- .../controller/oauth/oidc_provider_test.go | 67 +++++++++-- internal/controller/oauth/types.go | 29 +---- internal/controller/sink.go | 14 +-- 10 files changed, 140 insertions(+), 245 deletions(-) diff --git a/go.mod b/go.mod index 3c016dcd4..048957dfa 100644 --- a/go.mod +++ b/go.mod @@ -10,14 +10,14 @@ require ( github.com/envoyproxy/gateway v1.3.0 github.com/envoyproxy/go-control-plane/envoy v1.32.4 github.com/go-logr/logr v1.4.2 - github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/cel-go v0.23.2 github.com/google/go-cmp v0.6.0 github.com/openai/openai-go v0.1.0-alpha.56 github.com/stretchr/testify v1.10.0 go.uber.org/goleak v1.3.0 go.uber.org/zap v1.27.0 - golang.org/x/exp v0.0.0-20250207012021-f9890c6ad9f3 + golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c + golang.org/x/oauth2 v0.26.0 google.golang.org/grpc v1.70.0 google.golang.org/protobuf v1.36.5 k8s.io/api v0.32.1 @@ -352,7 +352,6 @@ require ( golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.23.0 // indirect golang.org/x/net v0.35.0 // indirect - golang.org/x/oauth2 v0.26.0 // indirect golang.org/x/sync v0.11.0 // indirect golang.org/x/sys v0.30.0 // indirect golang.org/x/term v0.29.0 // indirect diff --git a/go.sum b/go.sum index a47a2aaa0..1b7834776 100644 --- a/go.sum +++ b/go.sum @@ -359,8 +359,6 @@ github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUv github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= -github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golangci/dupl v0.0.0-20180902072040-3e9179ac440a h1:w8hkcTqaFpzKqonE9uMCefW1WDie15eSP/4MssdenaM= @@ -919,8 +917,8 @@ golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliY golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= -golang.org/x/exp v0.0.0-20250207012021-f9890c6ad9f3 h1:qNgPs5exUA+G0C96DrPwNrvLSj7GT/9D+3WMWUcUg34= -golang.org/x/exp v0.0.0-20250207012021-f9890c6ad9f3/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= +golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c h1:KL/ZBHXgKGVmuZBZ01Lt57yE5ws8ZPSkkihmEyq7FXc= +golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= golang.org/x/exp/typeparams v0.0.0-20220428152302-39d4317da171/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= golang.org/x/exp/typeparams v0.0.0-20230203172020-98cc5a0785f9/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= golang.org/x/exp/typeparams v0.0.0-20241108190413-2d47ceb2692f h1:WTyX8eCCyfdqiPYkRGm0MqElSfYFH3yR1+rl/mct9sA= diff --git a/internal/controller/oauth/base_provider.go b/internal/controller/oauth/base_provider.go index 4f5e8b520..177779df6 100644 --- a/internal/controller/oauth/base_provider.go +++ b/internal/controller/oauth/base_provider.go @@ -3,8 +3,6 @@ package oauth import ( "context" "fmt" - "net/http" - "time" "github.com/go-logr/logr" corev1 "k8s.io/api/core/v1" @@ -15,19 +13,13 @@ import ( type BaseProvider struct { client client.Client logger logr.Logger - http *http.Client } // NewBaseProvider creates a new base provider -func NewBaseProvider(client client.Client, logger logr.Logger, httpClient *http.Client) *BaseProvider { - if httpClient == nil { - httpClient = &http.Client{Timeout: 30 * time.Second} - } - +func NewBaseProvider(client client.Client, logger logr.Logger) *BaseProvider { return &BaseProvider{ client: client, logger: logger, - http: httpClient, } } diff --git a/internal/controller/oauth/base_provider_test.go b/internal/controller/oauth/base_provider_test.go index 518e93ea0..e7039cd20 100644 --- a/internal/controller/oauth/base_provider_test.go +++ b/internal/controller/oauth/base_provider_test.go @@ -15,7 +15,7 @@ import ( func TestNewBaseProvider(t *testing.T) { scheme := runtime.NewScheme() cl := fake.NewClientBuilder().WithScheme(scheme).Build() - require.NotNil(t, NewBaseProvider(cl, ctrl.Log, nil)) + require.NotNil(t, NewBaseProvider(cl, ctrl.Log)) } func TestBaseProvider_GetClientSecret(t *testing.T) { @@ -24,7 +24,7 @@ func TestBaseProvider_GetClientSecret(t *testing.T) { &corev1.Secret{}, ) cl := fake.NewClientBuilder().WithScheme(scheme).Build() - baseProvider := NewBaseProvider(cl, ctrl.Log, nil) + baseProvider := NewBaseProvider(cl, ctrl.Log) secretName, secretNamespace := "secret", "secret-ns" err := cl.Create(context.Background(), &corev1.Secret{ diff --git a/internal/controller/oauth/client_credentials_provider.go b/internal/controller/oauth/client_credentials_provider.go index 7687d9535..c07ce9bb3 100644 --- a/internal/controller/oauth/client_credentials_provider.go +++ b/internal/controller/oauth/client_credentials_provider.go @@ -2,14 +2,12 @@ package oauth import ( "context" - "encoding/json" "fmt" - "net/http" - "net/url" - "strings" "time" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" + "golang.org/x/oauth2" + "golang.org/x/oauth2/clientcredentials" corev1 "k8s.io/api/core/v1" ) @@ -19,13 +17,14 @@ type ClientCredentialsProvider struct { } // NewClientCredentialsProvider creates a new client credentials provider -func NewClientCredentialsProvider(base *BaseProvider) *ClientCredentialsProvider { +func NewClientCredentialsProvider(base *BaseProvider) TokenProvider { return &ClientCredentialsProvider{ BaseProvider: base, } } -func (p *ClientCredentialsProvider) FetchToken(ctx context.Context, oidc *egv1a1.OIDC) (*TokenResponse, error) { +// FetchToken gets the client secret from the secret reference and fetches the token from the provider token URL. +func (p *ClientCredentialsProvider) FetchToken(ctx context.Context, oidc *egv1a1.OIDC) (*oauth2.Token, error) { clientSecret, err := p.getClientSecret(ctx, &corev1.SecretReference{ Name: string(oidc.ClientSecret.Name), Namespace: string(*oidc.ClientSecret.Namespace), @@ -33,67 +32,26 @@ func (p *ClientCredentialsProvider) FetchToken(ctx context.Context, oidc *egv1a1 if err != nil { return nil, err } + return p.getTokenWithClientCredentialConfig(ctx, oidc, clientSecret) +} - // Prepare token request - form := url.Values{} - form.Set("grant_type", "client_credentials") - form.Set("client_id", oidc.ClientID) - form.Set("client_secret", clientSecret) - if len(oidc.Scopes) > 0 { - form.Set("scope", strings.Join(oidc.Scopes, " ")) +// getTokenWithClientCredentialFlow fetches the oauth2 token with client credential config +func (p *ClientCredentialsProvider) getTokenWithClientCredentialConfig(ctx context.Context, oidc *egv1a1.OIDC, clientSecret string) (*oauth2.Token, error) { + oauth2Config := clientcredentials.Config{ + ClientID: oidc.ClientID, + ClientSecret: clientSecret, + // Discovery returns the OAuth2 endpoints. + TokenURL: *oidc.Provider.TokenEndpoint, + Scopes: oidc.Scopes, } - - // Make request - req, err := http.NewRequestWithContext(ctx, "POST", *oidc.Provider.TokenEndpoint, - strings.NewReader(form.Encode())) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - resp, err := p.http.Do(req) + token, err := oauth2Config.Token(ctx) if err != nil { - return nil, fmt.Errorf("token request failed: %w", err) - } - defer resp.Body.Close() - - // Parse response - var raw map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&raw); err != nil { - return nil, fmt.Errorf("failed to parse response: %w", err) - } - - // Convert to TokenResponse - token := &TokenResponse{ - Raw: raw, - } - - // Extract standard fields - if v, ok := raw["access_token"].(string); ok { - token.AccessToken = v - } - if v, ok := raw["token_type"].(string); ok { - token.TokenType = v - } - if v, ok := raw["scope"].(string); ok { - token.Scope = v + return nil, fmt.Errorf("fail to get oauth2 token %w", err) } // Handle expiration - if v, ok := raw["expires_in"].(float64); ok { - token.ExpiresAt = time.Now().Add(time.Duration(v) * time.Second) + if token.ExpiresIn > 0 { + token.Expiry = time.Now().Add(time.Duration(token.ExpiresIn) * time.Second) } - return token, nil } - -func (p *ClientCredentialsProvider) SupportsFlow(flowType FlowType) bool { - return flowType == FlowClientCredentials -} - -func (p *ClientCredentialsProvider) ValidateToken(_ context.Context, _ string) error { - // Implement token validation logic - // This might involve introspection endpoint if available - return nil -} diff --git a/internal/controller/oauth/client_credentials_provider_test.go b/internal/controller/oauth/client_credentials_provider_test.go index d5cf39775..29b54f741 100644 --- a/internal/controller/oauth/client_credentials_provider_test.go +++ b/internal/controller/oauth/client_credentials_provider_test.go @@ -2,27 +2,54 @@ package oauth import ( "context" + "net/http" + "net/http/httptest" + "testing" + "time" + egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/stretchr/testify/require" + "golang.org/x/oauth2" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" - "net/http" - "net/http/httptest" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client/fake" gwapiv1 "sigs.k8s.io/gateway-api/apis/v1" - "testing" - "time" ) -func TestNewClientCredentialsProvider(t *testing.T) { - require.NotNil(t, NewClientCredentialsProvider(nil)) +// MockClientCredentialsProvider implements the standard OAuth2 client credentials flow +type MockClientCredentialsProvider struct { + *BaseProvider +} + +// NewMockClientCredentialsProvider creates a new client credentials provider +func NewMockClientCredentialsProvider(base *BaseProvider) TokenProvider { + return &MockClientCredentialsProvider{ + BaseProvider: base, + } +} + +// FetchToken gets the client secret from the secret reference and fetches the token from provider token URL. +func (m *MockClientCredentialsProvider) FetchToken(ctx context.Context, oidc *egv1a1.OIDC) (*oauth2.Token, error) { + _, err := m.getClientSecret(ctx, &corev1.SecretReference{ + Name: string(oidc.ClientSecret.Name), + Namespace: string(*oidc.ClientSecret.Namespace), + }) + if err != nil { + return nil, err + } + return &oauth2.Token{ + AccessToken: "token", + ExpiresIn: 3600, + Expiry: time.Now().Add(time.Duration(3600) * time.Second), + }, nil } func TestClientCredentialsProvider_FetchToken(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(`{"access_token": "token", "token_type": "Bearer", "expires_in": 3600}`)) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, err := w.Write([]byte(`{"access_token": "token", "token_type": "Bearer", "expires_in": 3600}`)) + require.NoError(t, err) })) defer ts.Close() @@ -31,7 +58,7 @@ func TestClientCredentialsProvider_FetchToken(t *testing.T) { &corev1.Secret{}, ) cl := fake.NewClientBuilder().WithScheme(scheme).Build() - baseProvider := NewBaseProvider(cl, ctrl.Log, nil) + baseProvider := NewBaseProvider(cl, ctrl.Log) require.NotNil(t, baseProvider) secretName, secretNamespace := "secret", "secret-ns" @@ -49,7 +76,7 @@ func TestClientCredentialsProvider_FetchToken(t *testing.T) { }) require.NoError(t, err) - clientProvider := NewClientCredentialsProvider(baseProvider) + clientProvider := NewMockClientCredentialsProvider(baseProvider) require.NotNil(t, clientProvider) namespaceRef := gwapiv1.Namespace(secretNamespace) @@ -66,16 +93,5 @@ func TestClientCredentialsProvider_FetchToken(t *testing.T) { }) require.NoError(t, err) require.Equal(t, "token", token.AccessToken) - require.WithinRangef(t, token.ExpiresAt, time.Now().Add(3590*time.Second), time.Now().Add(3600*time.Second), "token expires at") -} - -func TestClientCredentialsProvider_SupportsFlow(t *testing.T) { - provider := NewClientCredentialsProvider(nil) - require.True(t, provider.SupportsFlow(FlowClientCredentials)) - require.False(t, provider.SupportsFlow(FlowClientCredentialsWithIDToken)) -} - -func TestClientCredentialsProvider_ValidateToken(t *testing.T) { - provider := NewClientCredentialsProvider(nil) - require.Nil(t, provider.ValidateToken(context.Background(), "")) + require.WithinRangef(t, token.Expiry, time.Now().Add(3590*time.Second), time.Now().Add(3600*time.Second), "token expires at") } diff --git a/internal/controller/oauth/oidc_provider.go b/internal/controller/oauth/oidc_provider.go index d65e2b51d..61347c0f3 100644 --- a/internal/controller/oauth/oidc_provider.go +++ b/internal/controller/oauth/oidc_provider.go @@ -9,12 +9,12 @@ import ( "time" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" - "github.com/golang-jwt/jwt/v5" + "golang.org/x/oauth2" ) // OIDCProvider extends ClientCredentialsProvider with OIDC support type OIDCProvider struct { - *ClientCredentialsProvider + tokenProvider TokenProvider httpClient *http.Client oidcCredential *egv1a1.OIDC } @@ -29,11 +29,11 @@ type OIDCMetadata struct { } // NewOIDCProvider creates a new OIDC-aware provider -func NewOIDCProvider(base *BaseProvider, oidcCredentials *egv1a1.OIDC) *OIDCProvider { +func NewOIDCProvider(tokenProvider TokenProvider, oidcCredentials *egv1a1.OIDC) *OIDCProvider { return &OIDCProvider{ - ClientCredentialsProvider: NewClientCredentialsProvider(base), - httpClient: &http.Client{Timeout: 30 * time.Second}, - oidcCredential: oidcCredentials, + tokenProvider: tokenProvider, + httpClient: &http.Client{Timeout: 30 * time.Second}, + oidcCredential: oidcCredentials, } } @@ -77,63 +77,8 @@ func (p *OIDCProvider) getOIDCMetadata(ctx context.Context, issuerURL string) (* return &metadata, nil } -// validateIDToken validates the ID token according to the OIDC spec -func (p *OIDCProvider) validateIDToken(ctx context.Context, rawIDToken, issuerURL, clientID string) (map[string]interface{}, error) { - if err := ctx.Err(); err != nil { - return nil, fmt.Errorf("context error before validation: %w", err) - } - - token, err := jwt.Parse(rawIDToken, func(_ *jwt.Token) (interface{}, error) { - // For now, we skip signature validation as we don't have the key - // TODO: Implement JWKS validation - return jwt.UnsafeAllowNoneSignatureType, nil - }) - if err != nil { - return nil, fmt.Errorf("failed to parse ID token: %w", err) - } - - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - return nil, fmt.Errorf("invalid claims format in token") - } - - now := time.Now() - - // Validate issuer - if iss, err := claims.GetIssuer(); err != nil || iss != issuerURL { - return nil, fmt.Errorf("invalid issuer claim") - } - - // Validate audience - if aud, err := claims.GetAudience(); err != nil || !contains(aud, clientID) { - return nil, fmt.Errorf("invalid audience claim") - } - - // Validate expiration - if exp, err := claims.GetExpirationTime(); err != nil || exp.Before(now) { - return nil, fmt.Errorf("token is expired") - } - - // Validate issued at - if iat, err := claims.GetIssuedAt(); err != nil || iat.After(now) { - return nil, fmt.Errorf("token used before issued") - } - - return claims, nil -} - -// contains checks if a string slice contains a value -func contains(slice []string, val string) bool { - for _, item := range slice { - if item == val { - return true - } - } - return false -} - // FetchToken retrieves and validates tokens using the client credentials flow with OIDC support -func (p *OIDCProvider) FetchToken(ctx context.Context) (*TokenResponse, error) { +func (p *OIDCProvider) FetchToken(ctx context.Context) (*oauth2.Token, error) { // If issuer URL is provided, fetch OIDC metadata if issuerURL := p.oidcCredential.Provider.Issuer; issuerURL != "" { metadata, err := p.getOIDCMetadata(ctx, issuerURL) @@ -162,50 +107,11 @@ func (p *OIDCProvider) FetchToken(ctx context.Context) (*TokenResponse, error) { } } - // Ensure openid scope is present - hasOpenID := false - for _, scope := range p.oidcCredential.Scopes { - if scope == "openid" { - hasOpenID = true - break - } - } - if !hasOpenID { - p.oidcCredential.Scopes = append(p.oidcCredential.Scopes, "openid") - } - // Get base token response - token, err := p.ClientCredentialsProvider.FetchToken(ctx, p.oidcCredential) + token, err := p.tokenProvider.FetchToken(ctx, p.oidcCredential) if err != nil { return nil, fmt.Errorf("failed to get token: %w", err) } - // Extract ID token if present - if rawIDToken, ok := token.Raw["id_token"].(string); ok { - token.IDToken = rawIDToken - - // Validate ID token if issuer URL is provided - if issuerURL := p.oidcCredential.Provider.Issuer; issuerURL != "" { - claims, err := p.validateIDToken(ctx, rawIDToken, issuerURL, p.oidcCredential.ClientID) - if err != nil { - return nil, fmt.Errorf("failed to validate ID token: %w", err) - } - - // Store claims in raw map for access by consumers - token.Raw["id_token_claims"] = claims - } - } - return token, nil } - -func (p *OIDCProvider) SupportsFlow(flowType FlowType) bool { - return flowType == FlowClientCredentialsWithIDToken -} - -// ValidateToken implements token validation for both access tokens and ID tokens -func (p *OIDCProvider) ValidateToken(_ context.Context, _ string) error { - // For ID tokens, we expect them to have been validated during GetToken - // For access tokens, we could implement introspection here if needed - return nil -} diff --git a/internal/controller/oauth/oidc_provider_test.go b/internal/controller/oauth/oidc_provider_test.go index 31523495c..16b35ae08 100644 --- a/internal/controller/oauth/oidc_provider_test.go +++ b/internal/controller/oauth/oidc_provider_test.go @@ -2,15 +2,18 @@ package oauth import ( "context" + "net/http" + "net/http/httptest" + "testing" + egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/stretchr/testify/require" corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" - "net/http" - "net/http/httptest" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client/fake" - "testing" + gwapiv1 "sigs.k8s.io/gateway-api/apis/v1" ) func TestNewOIDCProvider(t *testing.T) { @@ -18,8 +21,9 @@ func TestNewOIDCProvider(t *testing.T) { } func TestOIDCProvider_GetOIDCMetadata(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(`{"issuer": "issuer", "token_endpoint": "token_endpoint", "authorization_endpoint": "authorization_endpoint", "jwks_uri": "jwks_uri", "scopes_supported": []}`)) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, err := w.Write([]byte(`{"issuer": "issuer", "token_endpoint": "token_endpoint", "authorization_endpoint": "authorization_endpoint", "jwks_uri": "jwks_uri", "scopes_supported": []}`)) + require.NoError(t, err) })) defer ts.Close() @@ -28,7 +32,7 @@ func TestOIDCProvider_GetOIDCMetadata(t *testing.T) { &corev1.Secret{}, ) cl := fake.NewClientBuilder().WithScheme(scheme).Build() - baseProvider := NewBaseProvider(cl, ctrl.Log, nil) + baseProvider := NewBaseProvider(cl, ctrl.Log) require.NotNil(t, baseProvider) oidc := &egv1a1.OIDC{ @@ -39,18 +43,59 @@ func TestOIDCProvider_GetOIDCMetadata(t *testing.T) { ClientID: "some-client-id", } - oidcProvider := NewOIDCProvider(baseProvider, oidc) + oidcProvider := NewOIDCProvider(NewMockClientCredentialsProvider(baseProvider), oidc) metadata, err := oidcProvider.getOIDCMetadata(context.Background(), ts.URL) require.NoError(t, err) require.Equal(t, "token_endpoint", metadata.TokenEndpoint) require.Equal(t, "issuer", metadata.Issuer) } -func TestOIDCProvider_validateIDToken(t *testing.T) {} func TestOIDCProvider_FetchToken(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, err := w.Write([]byte(`{"issuer": "issuer", "token_endpoint": "token_endpoint", "authorization_endpoint": "authorization_endpoint", "jwks_uri": "jwks_uri", "scopes_supported": []}`)) + require.NoError(t, err) + })) + defer ts.Close() -} + scheme := runtime.NewScheme() + scheme.AddKnownTypes(corev1.SchemeGroupVersion, + &corev1.Secret{}, + ) + cl := fake.NewClientBuilder().WithScheme(scheme).Build() + baseProvider := NewBaseProvider(cl, ctrl.Log) + require.NotNil(t, baseProvider) -func TestOIDCProvider_SupportsFlow(t *testing.T) {} + secretName, secretNamespace := "secret", "secret-ns" + err := cl.Create(context.Background(), &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: secretName, + Namespace: secretNamespace, + }, + Immutable: nil, + Data: map[string][]byte{ + "client-secret": []byte("client-secret"), + }, + StringData: nil, + Type: "", + }) + require.NoError(t, err) + namespaceRef := gwapiv1.Namespace(secretNamespace) + oidc := &egv1a1.OIDC{ + Provider: egv1a1.OIDCProvider{ + Issuer: ts.URL, + TokenEndpoint: &ts.URL, + }, + ClientID: "some-client-id", + ClientSecret: gwapiv1.SecretObjectReference{ + Name: gwapiv1.ObjectName(secretName), + Namespace: &namespaceRef, + }, + } -func TestOIDCProvider_ValidateToken(t *testing.T) {} + oidcProvider := NewOIDCProvider(NewMockClientCredentialsProvider(baseProvider), oidc) + token, err := oidcProvider.FetchToken(context.Background()) + require.NoError(t, err) + require.Equal(t, "token", token.AccessToken) + require.Equal(t, "Bearer", token.Type()) + require.Equal(t, int64(3600), token.ExpiresIn) +} diff --git a/internal/controller/oauth/types.go b/internal/controller/oauth/types.go index d4e753798..fa3c3795c 100644 --- a/internal/controller/oauth/types.go +++ b/internal/controller/oauth/types.go @@ -2,31 +2,12 @@ package oauth import ( "context" - "time" -) - -// FlowType represents different OAuth/OIDC flow types -type FlowType string -const ( - FlowClientCredentials FlowType = "client_credentials" - FlowClientCredentialsWithIDToken FlowType = "client_credentials_with_id_token" + egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" + "golang.org/x/oauth2" ) -// TokenResponse represents the common token response structure -type TokenResponse struct { - AccessToken string - TokenType string - ExpiresAt time.Time - Scope string - IDToken string // Optional OIDC field - RefreshToken string // Optional refresh token - Raw map[string]interface{} -} - -// Provider defines the interface for OAuth token providers -type Provider interface { - FetchToken(ctx context.Context) (*TokenResponse, error) - ValidateToken(ctx context.Context, token string) error - SupportsFlow(flowType FlowType) bool +// TokenProvider defines the interface for OAuth token providers +type TokenProvider interface { + FetchToken(ctx context.Context, oidc *egv1a1.OIDC) (*oauth2.Token, error) } diff --git a/internal/controller/sink.go b/internal/controller/sink.go index 00cebd780..83cdd64c3 100644 --- a/internal/controller/sink.go +++ b/internal/controller/sink.go @@ -8,12 +8,12 @@ package controller import ( "context" "fmt" - "net/http" "path" "time" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/go-logr/logr" + "golang.org/x/oauth2" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -73,7 +73,7 @@ type configSink struct { extProcImagePullPolicy corev1.PullPolicy extProcLogLevel string eventChan chan ConfigSinkEvent - oidcTokenCache map[string]*oauth.TokenResponse + oidcTokenCache map[string]*oauth2.Token } func newConfigSink( @@ -92,7 +92,7 @@ func newConfigSink( extProcImagePullPolicy: corev1.PullIfNotPresent, extProcLogLevel: extProcLogLevel, eventChan: eventChan, - oidcTokenCache: make(map[string]*oauth.TokenResponse), + oidcTokenCache: make(map[string]*oauth2.Token), } return c } @@ -269,9 +269,9 @@ func (c *configSink) syncBackendSecurityPolicy(ctx context.Context, bsp *aigv1a1 if isBackendSecurityPolicyAuthOIDC(bsp.Spec) { tokenResponse, ok := c.oidcTokenCache[key] - if !ok || backendauthrotators.IsExpired(preRotationWindow, tokenResponse.ExpiresAt) { - baseProvider := oauth.NewBaseProvider(c.client, c.logger, &http.Client{Timeout: 30 * time.Second}) - oidcProvider := oauth.NewOIDCProvider(baseProvider, getBackendSecurityPolicyAuthOIDC(bsp.Spec)) + if !ok || backendauthrotators.IsExpired(preRotationWindow, tokenResponse.Expiry) { + baseProvider := oauth.NewBaseProvider(c.client, c.logger) + oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(baseProvider), getBackendSecurityPolicyAuthOIDC(bsp.Spec)) tokenRes, err := oidcProvider.FetchToken(context.TODO()) if err != nil { @@ -296,7 +296,7 @@ func (c *configSink) syncBackendSecurityPolicy(ctx context.Context, bsp *aigv1a1 } if expired { - token := tokenResponse.IDToken + token := tokenResponse.AccessToken if token == "" { token = tokenResponse.AccessToken } From 44c776d26c2eba79b7909f06a3f6956f55f68823 Mon Sep 17 00:00:00 2001 From: Dan Sun Date: Sun, 9 Feb 2025 16:20:57 -0500 Subject: [PATCH 14/86] Add AWS rotator test Signed-off-by: Dan Sun --- ...o => client_credentials_token_provider.go} | 10 +- ...client_credentials_token_provider_test.go} | 0 internal/controller/oauth/oidc_provider.go | 2 +- internal/controller/rotators/aws_common.go | 57 +++++++ .../rotators/aws_oidc_rotator_test.go | 152 ++++++++++++++++++ 5 files changed, 215 insertions(+), 6 deletions(-) rename internal/controller/oauth/{client_credentials_provider.go => client_credentials_token_provider.go} (75%) rename internal/controller/oauth/{client_credentials_provider_test.go => client_credentials_token_provider_test.go} (100%) create mode 100644 internal/controller/rotators/aws_oidc_rotator_test.go diff --git a/internal/controller/oauth/client_credentials_provider.go b/internal/controller/oauth/client_credentials_token_provider.go similarity index 75% rename from internal/controller/oauth/client_credentials_provider.go rename to internal/controller/oauth/client_credentials_token_provider.go index c07ce9bb3..8c665b6f2 100644 --- a/internal/controller/oauth/client_credentials_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -11,20 +11,20 @@ import ( corev1 "k8s.io/api/core/v1" ) -// ClientCredentialsProvider implements the standard OAuth2 client credentials flow -type ClientCredentialsProvider struct { +// ClientCredentialsTokenProvider implements the standard OAuth2 client credentials flow +type ClientCredentialsTokenProvider struct { *BaseProvider } // NewClientCredentialsProvider creates a new client credentials provider func NewClientCredentialsProvider(base *BaseProvider) TokenProvider { - return &ClientCredentialsProvider{ + return &ClientCredentialsTokenProvider{ BaseProvider: base, } } // FetchToken gets the client secret from the secret reference and fetches the token from the provider token URL. -func (p *ClientCredentialsProvider) FetchToken(ctx context.Context, oidc *egv1a1.OIDC) (*oauth2.Token, error) { +func (p *ClientCredentialsTokenProvider) FetchToken(ctx context.Context, oidc *egv1a1.OIDC) (*oauth2.Token, error) { clientSecret, err := p.getClientSecret(ctx, &corev1.SecretReference{ Name: string(oidc.ClientSecret.Name), Namespace: string(*oidc.ClientSecret.Namespace), @@ -36,7 +36,7 @@ func (p *ClientCredentialsProvider) FetchToken(ctx context.Context, oidc *egv1a1 } // getTokenWithClientCredentialFlow fetches the oauth2 token with client credential config -func (p *ClientCredentialsProvider) getTokenWithClientCredentialConfig(ctx context.Context, oidc *egv1a1.OIDC, clientSecret string) (*oauth2.Token, error) { +func (p *ClientCredentialsTokenProvider) getTokenWithClientCredentialConfig(ctx context.Context, oidc *egv1a1.OIDC, clientSecret string) (*oauth2.Token, error) { oauth2Config := clientcredentials.Config{ ClientID: oidc.ClientID, ClientSecret: clientSecret, diff --git a/internal/controller/oauth/client_credentials_provider_test.go b/internal/controller/oauth/client_credentials_token_provider_test.go similarity index 100% rename from internal/controller/oauth/client_credentials_provider_test.go rename to internal/controller/oauth/client_credentials_token_provider_test.go diff --git a/internal/controller/oauth/oidc_provider.go b/internal/controller/oauth/oidc_provider.go index 61347c0f3..defa6b4bd 100644 --- a/internal/controller/oauth/oidc_provider.go +++ b/internal/controller/oauth/oidc_provider.go @@ -12,7 +12,7 @@ import ( "golang.org/x/oauth2" ) -// OIDCProvider extends ClientCredentialsProvider with OIDC support +// OIDCProvider extends ClientCredentialsTokenProvider with OIDC support type OIDCProvider struct { tokenProvider TokenProvider httpClient *http.Client diff --git a/internal/controller/rotators/aws_common.go b/internal/controller/rotators/aws_common.go index 8a2a6fd40..5a726199b 100644 --- a/internal/controller/rotators/aws_common.go +++ b/internal/controller/rotators/aws_common.go @@ -97,6 +97,63 @@ type awsCredentialsFile struct { profiles map[string]*awsCredentials } +// parseAWSCredentialsFile parses an AWS credentials file with multiple profiles. +// The file format follows the standard AWS credentials file format: +// +// [profile-name] +// aws_access_key_id = AKIAXXXXXXXXXXXXXXXX +// aws_secret_access_key = xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx +// aws_session_token = xxxxxxxx (optional) +// region = xx-xxxx-x (optional) +// +// Returns a structured representation of the credentials file. +func parseAWSCredentialsFile(data string) *awsCredentialsFile { + file := &awsCredentialsFile{ + profiles: make(map[string]*awsCredentials), + } + + var currentCreds *awsCredentials + + for _, line := range strings.Split(data, "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + + if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { + profileName := strings.TrimPrefix(strings.TrimSuffix(line, "]"), "[") + currentCreds = &awsCredentials{profile: profileName} + file.profiles[profileName] = currentCreds + continue + } + + if currentCreds == nil { + continue + } + + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + continue + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + switch key { + case "aws_access_key_id": + currentCreds.accessKeyID = value + case "aws_secret_access_key": + currentCreds.secretAccessKey = value + case "aws_session_token": + currentCreds.sessionToken = value + case "region": + currentCreds.region = value + } + } + + return file +} + // formatAWSCredentialsFile formats multiple AWS credential profiles into a credentials file. // The output follows the standard AWS credentials file format and ensures: // - Consistent ordering of profiles through sorting diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go new file mode 100644 index 000000000..22882dcc4 --- /dev/null +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -0,0 +1,152 @@ +package backendauthrotators + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/aws/aws-sdk-go-v2/service/sts/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" +) + +// ----------------------------------------------------------------------------- +// Test Helper Methods +// ----------------------------------------------------------------------------- + +// createTestAWSSecret creates a test secret with given credentials +func createTestAWSSecret(t *testing.T, client client.Client, name string, accessKey, secretKey, sessionToken string, profile string) { + if profile == "" { + profile = "default" + } + data := map[string][]byte{ + credentialsKey: []byte(fmt.Sprintf("[%s]\naws_access_key_id = %s\naws_secret_access_key = %s\naws_session_token = %s\nregion = us-west-2", + profile, accessKey, secretKey, sessionToken)), + } + err := client.Create(context.Background(), &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: "default", + }, + Data: data, + }) + require.NoError(t, err) +} + +// verifyAWSSecretCredentials verifies the credentials in a secret +func verifyAWSSecretCredentials(t *testing.T, client client.Client, namespace, secretName, expectedKeyID, expectedSecret, expectedToken string, profile string) { + if profile == "" { + profile = "default" + } + secret, err := LookupSecret(context.Background(), client, namespace, secretName) + require.NoError(t, err) + creds := parseAWSCredentialsFile(string(secret.Data[credentialsKey])) + require.NotNil(t, creds) + require.Contains(t, creds.profiles, profile) + assert.Equal(t, expectedKeyID, creds.profiles[profile].accessKeyID) + assert.Equal(t, expectedSecret, creds.profiles[profile].secretAccessKey) + assert.Equal(t, expectedToken, creds.profiles[profile].sessionToken) +} + +// createClientSecret creates the OIDC client secret +func createClientSecret(t *testing.T, name string) { + data := map[string][]byte{ + "client-secret": []byte("test-client-secret"), + } + scheme := runtime.NewScheme() + scheme.AddKnownTypes(corev1.SchemeGroupVersion, + &corev1.Secret{}, + ) + cl := fake.NewClientBuilder().WithScheme(scheme).Build() + err := cl.Create(context.Background(), &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: "default", + }, + Data: data, + }) + require.NoError(t, err) +} + +// MockSTSOperations implements the STSOperations interface for testing +type MockSTSOperations struct { + assumeRoleWithWebIdentityFunc func(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) +} + +func (m *MockSTSOperations) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { + if m.assumeRoleWithWebIdentityFunc != nil { + return m.assumeRoleWithWebIdentityFunc(ctx, params, optFns...) + } + return nil, fmt.Errorf("mock not implemented") +} + +// ----------------------------------------------------------------------------- +// Test Cases +// ----------------------------------------------------------------------------- + +func TestAWS_OIDCRotator(t *testing.T) { + t.Run("basic rotation", func(t *testing.T) { + var mockSTS STSOperations = &MockSTSOperations{ + assumeRoleWithWebIdentityFunc: func(_ context.Context, _ *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { + return &sts.AssumeRoleWithWebIdentityOutput{ + Credentials: &types.Credentials{ + AccessKeyId: aws.String("NEWKEY"), + SecretAccessKey: aws.String("NEWSECRET"), + SessionToken: aws.String("NEWTOKEN"), + Expiration: aws.Time(time.Now().Add(1 * time.Hour)), + }, + }, nil + }, + } + scheme := runtime.NewScheme() + scheme.AddKnownTypes(corev1.SchemeGroupVersion, + &corev1.Secret{}, + ) + cl := fake.NewClientBuilder().WithScheme(scheme).Build() + // Setup initial credentials and client secret + createTestAWSSecret(t, cl, "test-secret", "OLDKEY", "OLDSECRET", "OLDTOKEN", "default") + createClientSecret(t, "test-client-secret") + + awsOidcRotator := AWSOIDCRotator{ + client: cl, + stsOps: mockSTS, + backendSecurityPolicyNamespace: "default", + backendSecurityPolicyName: "test-secret", + } + + require.NoError(t, awsOidcRotator.Rotate(context.Background(), "us-east1", "test", "NEW-OIDC-TOKEN")) + verifyAWSSecretCredentials(t, cl, "default", "test-secret", "NEWKEY", "NEWSECRET", "NEWTOKEN", "default") + }) + + t.Run("error handling - STS assume role failure", func(t *testing.T) { + scheme := runtime.NewScheme() + scheme.AddKnownTypes(corev1.SchemeGroupVersion, + &corev1.Secret{}, + ) + cl := fake.NewClientBuilder().WithScheme(scheme).Build() + createTestAWSSecret(t, cl, "test-secret", "OLDKEY", "OLDSECRET", "OLDTOKEN", "default") + createClientSecret(t, "test-client-secret") + var mockSTS STSOperations = &MockSTSOperations{ + assumeRoleWithWebIdentityFunc: func(_ context.Context, _ *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { + return nil, fmt.Errorf("failed to assume role") + }, + } + awsOidcRotator := AWSOIDCRotator{ + client: cl, + stsOps: mockSTS, + backendSecurityPolicyNamespace: "default", + backendSecurityPolicyName: "test-secret", + } + err := awsOidcRotator.Rotate(context.Background(), "us-east1", "test", "NEW-OIDC-TOKEN") + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to assume role") + }) +} From 09b8b00cf48f00e6b1fa894abad314e4af5c1b3a Mon Sep 17 00:00:00 2001 From: Dan Sun Date: Sun, 9 Feb 2025 16:49:33 -0500 Subject: [PATCH 15/86] Mock token source Signed-off-by: Dan Sun --- .../client_credentials_token_provider.go | 20 +++++++------ .../client_credentials_token_provider_test.go | 28 +++++-------------- internal/controller/oauth/oidc_provider.go | 4 +-- .../controller/oauth/oidc_provider_test.go | 8 ++++-- 4 files changed, 26 insertions(+), 34 deletions(-) diff --git a/internal/controller/oauth/client_credentials_token_provider.go b/internal/controller/oauth/client_credentials_token_provider.go index 8c665b6f2..3aa43c6f8 100644 --- a/internal/controller/oauth/client_credentials_token_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -14,10 +14,11 @@ import ( // ClientCredentialsTokenProvider implements the standard OAuth2 client credentials flow type ClientCredentialsTokenProvider struct { *BaseProvider + TokenSource oauth2.TokenSource } // NewClientCredentialsProvider creates a new client credentials provider -func NewClientCredentialsProvider(base *BaseProvider) TokenProvider { +func NewClientCredentialsProvider(base *BaseProvider) *ClientCredentialsTokenProvider { return &ClientCredentialsTokenProvider{ BaseProvider: base, } @@ -37,14 +38,17 @@ func (p *ClientCredentialsTokenProvider) FetchToken(ctx context.Context, oidc *e // getTokenWithClientCredentialFlow fetches the oauth2 token with client credential config func (p *ClientCredentialsTokenProvider) getTokenWithClientCredentialConfig(ctx context.Context, oidc *egv1a1.OIDC, clientSecret string) (*oauth2.Token, error) { - oauth2Config := clientcredentials.Config{ - ClientID: oidc.ClientID, - ClientSecret: clientSecret, - // Discovery returns the OAuth2 endpoints. - TokenURL: *oidc.Provider.TokenEndpoint, - Scopes: oidc.Scopes, + if p.TokenSource == nil { + oauth2Config := clientcredentials.Config{ + ClientID: oidc.ClientID, + ClientSecret: clientSecret, + // Discovery returns the OAuth2 endpoints. + TokenURL: *oidc.Provider.TokenEndpoint, + Scopes: oidc.Scopes, + } + p.TokenSource = oauth2Config.TokenSource(ctx) } - token, err := oauth2Config.Token(ctx) + token, err := p.TokenSource.Token() if err != nil { return nil, fmt.Errorf("fail to get oauth2 token %w", err) } diff --git a/internal/controller/oauth/client_credentials_token_provider_test.go b/internal/controller/oauth/client_credentials_token_provider_test.go index 29b54f741..f4fe28986 100644 --- a/internal/controller/oauth/client_credentials_token_provider_test.go +++ b/internal/controller/oauth/client_credentials_token_provider_test.go @@ -18,31 +18,16 @@ import ( gwapiv1 "sigs.k8s.io/gateway-api/apis/v1" ) -// MockClientCredentialsProvider implements the standard OAuth2 client credentials flow -type MockClientCredentialsProvider struct { +// MockClientCredentialsTokenSource implements the standard OAuth2 client credentials flow +type MockClientCredentialsTokenSource struct { *BaseProvider } -// NewMockClientCredentialsProvider creates a new client credentials provider -func NewMockClientCredentialsProvider(base *BaseProvider) TokenProvider { - return &MockClientCredentialsProvider{ - BaseProvider: base, - } -} - // FetchToken gets the client secret from the secret reference and fetches the token from provider token URL. -func (m *MockClientCredentialsProvider) FetchToken(ctx context.Context, oidc *egv1a1.OIDC) (*oauth2.Token, error) { - _, err := m.getClientSecret(ctx, &corev1.SecretReference{ - Name: string(oidc.ClientSecret.Name), - Namespace: string(*oidc.ClientSecret.Namespace), - }) - if err != nil { - return nil, err - } +func (m *MockClientCredentialsTokenSource) Token() (*oauth2.Token, error) { return &oauth2.Token{ AccessToken: "token", ExpiresIn: 3600, - Expiry: time.Now().Add(time.Duration(3600) * time.Second), }, nil } @@ -76,11 +61,12 @@ func TestClientCredentialsProvider_FetchToken(t *testing.T) { }) require.NoError(t, err) - clientProvider := NewMockClientCredentialsProvider(baseProvider) - require.NotNil(t, clientProvider) + clientCredentialProvider := NewClientCredentialsProvider(baseProvider) + clientCredentialProvider.TokenSource = &MockClientCredentialsTokenSource{BaseProvider: baseProvider} + require.NotNil(t, clientCredentialProvider) namespaceRef := gwapiv1.Namespace(secretNamespace) - token, err := clientProvider.FetchToken(context.Background(), &egv1a1.OIDC{ + token, err := clientCredentialProvider.FetchToken(context.Background(), &egv1a1.OIDC{ Provider: egv1a1.OIDCProvider{ Issuer: ts.URL, TokenEndpoint: &ts.URL, diff --git a/internal/controller/oauth/oidc_provider.go b/internal/controller/oauth/oidc_provider.go index defa6b4bd..522ba268d 100644 --- a/internal/controller/oauth/oidc_provider.go +++ b/internal/controller/oauth/oidc_provider.go @@ -14,7 +14,7 @@ import ( // OIDCProvider extends ClientCredentialsTokenProvider with OIDC support type OIDCProvider struct { - tokenProvider TokenProvider + tokenProvider *ClientCredentialsTokenProvider httpClient *http.Client oidcCredential *egv1a1.OIDC } @@ -29,7 +29,7 @@ type OIDCMetadata struct { } // NewOIDCProvider creates a new OIDC-aware provider -func NewOIDCProvider(tokenProvider TokenProvider, oidcCredentials *egv1a1.OIDC) *OIDCProvider { +func NewOIDCProvider(tokenProvider *ClientCredentialsTokenProvider, oidcCredentials *egv1a1.OIDC) *OIDCProvider { return &OIDCProvider{ tokenProvider: tokenProvider, httpClient: &http.Client{Timeout: 30 * time.Second}, diff --git a/internal/controller/oauth/oidc_provider_test.go b/internal/controller/oauth/oidc_provider_test.go index 16b35ae08..3aa254726 100644 --- a/internal/controller/oauth/oidc_provider_test.go +++ b/internal/controller/oauth/oidc_provider_test.go @@ -43,7 +43,7 @@ func TestOIDCProvider_GetOIDCMetadata(t *testing.T) { ClientID: "some-client-id", } - oidcProvider := NewOIDCProvider(NewMockClientCredentialsProvider(baseProvider), oidc) + oidcProvider := NewOIDCProvider(NewClientCredentialsProvider(baseProvider), oidc) metadata, err := oidcProvider.getOIDCMetadata(context.Background(), ts.URL) require.NoError(t, err) require.Equal(t, "token_endpoint", metadata.TokenEndpoint) @@ -91,8 +91,10 @@ func TestOIDCProvider_FetchToken(t *testing.T) { Namespace: &namespaceRef, }, } - - oidcProvider := NewOIDCProvider(NewMockClientCredentialsProvider(baseProvider), oidc) + clientCredentialProvider := NewClientCredentialsProvider(baseProvider) + clientCredentialProvider.TokenSource = &MockClientCredentialsTokenSource{BaseProvider: baseProvider} + require.NotNil(t, clientCredentialProvider) + oidcProvider := NewOIDCProvider(clientCredentialProvider, oidc) token, err := oidcProvider.FetchToken(context.Background()) require.NoError(t, err) require.Equal(t, "token", token.AccessToken) From 785ea99f869d192f89a47d9398230b3dea8d340c Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Mon, 10 Feb 2025 11:58:52 -0500 Subject: [PATCH 16/86] remove patch and fix duplicate code Signed-off-by: Aaron Choo --- .../controller/backend_security_policy.go | 28 ++----------------- .../backend_security_policy_test.go | 23 --------------- internal/controller/sink.go | 8 ++---- 3 files changed, 4 insertions(+), 55 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index ed780ace5..779234ce3 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -47,23 +47,6 @@ type patchBackendSecurityPolicy struct{} // Reconcile implements the [reconcile.TypedReconciler] for [aigv1a1.BackendSecurityPolicy]. func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctrl.Request) (res ctrl.Result, err error) { - if b.reconcileAll { - var backendSecPolicyList aigv1a1.BackendSecurityPolicyList - err = b.client.List(ctx, &backendSecPolicyList) - if err != nil { - b.logger.Error(err, "failed to trigger refresh for existing backendSecPolicy resources") - } else { - for _, backendSecurityPolicy := range backendSecPolicyList.Items { - if isBackendSecurityPolicyAuthOIDC(backendSecurityPolicy.Spec) { - err = b.client.Patch(ctx, &backendSecurityPolicy, patchBackendSecurityPolicy{}) - if err != nil { - b.logger.Error(err, "failed to trigger refresh for existing backendSecPolicy resource", "name", backendSecurityPolicy.Name) - } - } - } - b.reconcileAll = false - } - } var backendSecurityPolicy aigv1a1.BackendSecurityPolicy if err = b.client.Get(ctx, req.NamespacedName, &backendSecurityPolicy); err != nil { if errors.IsNotFound(err) { @@ -74,7 +57,7 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr return ctrl.Result{}, err } - if isBackendSecurityPolicyAuthOIDC(backendSecurityPolicy.Spec) { + if getBackendSecurityPolicyAuthOIDC(backendSecurityPolicy.Spec) != nil { var requeue time.Duration requeue = time.Minute region := backendSecurityPolicy.Spec.AWSCredentials.Region @@ -94,15 +77,8 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr return } -func isBackendSecurityPolicyAuthOIDC(spec aigv1a1.BackendSecurityPolicySpec) bool { - if spec.AWSCredentials != nil { - return spec.AWSCredentials.OIDCExchangeToken != nil - } - return false -} - func getBackendSecurityPolicyAuthOIDC(spec aigv1a1.BackendSecurityPolicySpec) *egv1a1.OIDC { - if isBackendSecurityPolicyAuthOIDC(spec) { + if spec.AWSCredentials != nil && spec.AWSCredentials.OIDCExchangeToken != nil { if spec.AWSCredentials.OIDCExchangeToken != nil { return &spec.AWSCredentials.OIDCExchangeToken.OIDC } diff --git a/internal/controller/backend_security_policy_test.go b/internal/controller/backend_security_policy_test.go index 64a070dcc..97e7bae73 100644 --- a/internal/controller/backend_security_policy_test.go +++ b/internal/controller/backend_security_policy_test.go @@ -67,29 +67,6 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { require.NoError(t, err) } -func TestBackendSecurityController_IsBackendSecurityPolicyAuthOIDC(t *testing.T) { - require.False(t, isBackendSecurityPolicyAuthOIDC(aigv1a1.BackendSecurityPolicySpec{ - Type: aigv1a1.BackendSecurityPolicyTypeAPIKey, - APIKey: &aigv1a1.BackendSecurityPolicyAPIKey{}, - })) - - require.False(t, isBackendSecurityPolicyAuthOIDC(aigv1a1.BackendSecurityPolicySpec{ - Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, - AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ - Region: "us-east-1", - CredentialsFile: &aigv1a1.AWSCredentialsFile{}, - }, - })) - - require.True(t, isBackendSecurityPolicyAuthOIDC(aigv1a1.BackendSecurityPolicySpec{ - Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, - AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ - Region: "us-east-1", - OIDCExchangeToken: &aigv1a1.AWSOIDCExchangeToken{}, - }, - })) -} - func TestBackendSecurityController_GetBackendSecurityPolicyAuthOIDC(t *testing.T) { require.Nil(t, getBackendSecurityPolicyAuthOIDC(aigv1a1.BackendSecurityPolicySpec{ Type: aigv1a1.BackendSecurityPolicyTypeAPIKey, diff --git a/internal/controller/sink.go b/internal/controller/sink.go index 83cdd64c3..8e1b5d6a7 100644 --- a/internal/controller/sink.go +++ b/internal/controller/sink.go @@ -267,11 +267,11 @@ func (c *configSink) syncBackendSecurityPolicy(ctx context.Context, bsp *aigv1a1 c.syncAIServiceBackend(ctx, aiBackend) } - if isBackendSecurityPolicyAuthOIDC(bsp.Spec) { + if oidc := getBackendSecurityPolicyAuthOIDC(bsp.Spec); oidc != nil { tokenResponse, ok := c.oidcTokenCache[key] if !ok || backendauthrotators.IsExpired(preRotationWindow, tokenResponse.Expiry) { baseProvider := oauth.NewBaseProvider(c.client, c.logger) - oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(baseProvider), getBackendSecurityPolicyAuthOIDC(bsp.Spec)) + oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(baseProvider), oidc) tokenRes, err := oidcProvider.FetchToken(context.TODO()) if err != nil { @@ -297,10 +297,6 @@ func (c *configSink) syncBackendSecurityPolicy(ctx context.Context, bsp *aigv1a1 if expired { token := tokenResponse.AccessToken - if token == "" { - token = tokenResponse.AccessToken - } - err = rotator.Rotate(context.Background(), awsCredentials.Region, awsCredentials.OIDCExchangeToken.AwsRoleArn, token) if err != nil { c.logger.Error(err, "failed to rotate AWS OIDC exchange token") From 81ab6d03ed6533d8fd25eff3e4f62bed4c7713cf Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Mon, 10 Feb 2025 14:02:23 -0500 Subject: [PATCH 17/86] add test and remove unused functions Signed-off-by: Aaron Choo --- .../controller/backend_security_policy.go | 12 ----- .../backend_security_policy_test.go | 5 -- internal/controller/oauth/oidc_provider.go | 53 +++++++++---------- .../controller/oauth/oidc_provider_test.go | 14 +++-- site/crd-ref-docs/templates/type_members.tpl | 4 +- site/docs/getting-started/basic-usage.md | 2 +- 6 files changed, 36 insertions(+), 54 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 779234ce3..af3476756 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -7,13 +7,11 @@ package controller import ( "context" - "fmt" "time" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/go-logr/logr" "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/kubernetes" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" @@ -43,8 +41,6 @@ func newBackendSecurityPolicyController(client client.Client, kube kubernetes.In } } -type patchBackendSecurityPolicy struct{} - // Reconcile implements the [reconcile.TypedReconciler] for [aigv1a1.BackendSecurityPolicy]. func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctrl.Request) (res ctrl.Result, err error) { var backendSecurityPolicy aigv1a1.BackendSecurityPolicy @@ -85,11 +81,3 @@ func getBackendSecurityPolicyAuthOIDC(spec aigv1a1.BackendSecurityPolicySpec) *e } return nil } - -func (p patchBackendSecurityPolicy) Type() types.PatchType { - return types.MergePatchType -} - -func (p patchBackendSecurityPolicy) Data(_ client.Object) ([]byte, error) { - return []byte(fmt.Sprintf(`{"metadata":{"annotations":{"%s":"%s"}}}`, "reconcile", time.Now().String())), nil -} diff --git a/internal/controller/backend_security_policy_test.go b/internal/controller/backend_security_policy_test.go index 97e7bae73..baec0300e 100644 --- a/internal/controller/backend_security_policy_test.go +++ b/internal/controller/backend_security_policy_test.go @@ -55,11 +55,6 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { oidcBackendSecurityPolicy := &aigv1a1.BackendSecurityPolicy{} err = cl.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName)}, oidcBackendSecurityPolicy) require.NoError(t, err) - require.Len(t, oidcBackendSecurityPolicy.Annotations, 1) - time, ok := oidcBackendSecurityPolicy.Annotations["reconcile"] - require.True(t, ok) - require.NotEmpty(t, time) - // Test the case where the BackendSecurityPolicy is being deleted. err = cl.Delete(t.Context(), &aigv1a1.BackendSecurityPolicy{ObjectMeta: metav1.ObjectMeta{Name: backendSecurityPolicyName, Namespace: namespace}}) require.NoError(t, err) diff --git a/internal/controller/oauth/oidc_provider.go b/internal/controller/oauth/oidc_provider.go index 522ba268d..72211711b 100644 --- a/internal/controller/oauth/oidc_provider.go +++ b/internal/controller/oauth/oidc_provider.go @@ -2,12 +2,11 @@ package oauth import ( "context" - "encoding/json" "fmt" "net/http" - "strings" "time" + "github.com/coreos/go-oidc/v3/oidc" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "golang.org/x/oauth2" ) @@ -37,69 +36,65 @@ func NewOIDCProvider(tokenProvider *ClientCredentialsTokenProvider, oidcCredenti } } -// getOIDCMetadata retrieves or creates OIDC metadata for the given issuer URL -func (p *OIDCProvider) getOIDCMetadata(ctx context.Context, issuerURL string) (*OIDCMetadata, error) { +// getOIDCProviderConfig retrieves or creates OIDC config for the given issuer URL +func (p *OIDCProvider) getOIDCProviderConfig(ctx context.Context, issuerURL string) (*oidc.ProviderConfig, *[]string, error) { // Check context before proceeding if err := ctx.Err(); err != nil { - return nil, fmt.Errorf("context error before discovery: %w", err) + return nil, nil, fmt.Errorf("context error before discovery: %w", err) } - // Fetch OIDC configuration - wellKnown := strings.TrimSuffix(issuerURL, "/") + "/.well-known/openid-configuration" - req, err := http.NewRequestWithContext(ctx, "GET", wellKnown, nil) + provider, err := oidc.NewProvider(ctx, issuerURL) if err != nil { - return nil, fmt.Errorf("failed to create discovery request: %w", err) + return nil, nil, fmt.Errorf("failed to create go-oidc provider %q: %w", issuerURL, err) } - resp, err := p.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to fetch OIDC metadata: %w", err) + var config oidc.ProviderConfig + if err = provider.Claims(&config); err != nil { + return nil, nil, fmt.Errorf("failed to decode provider config claims %q: %w", issuerURL, err) } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status code from discovery endpoint: %d", resp.StatusCode) + // Unmarshall supported scopes + var claims struct { + SupportedScopes []string `json:"scopes_supported"` } - - var metadata OIDCMetadata - if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil { - return nil, fmt.Errorf("failed to decode OIDC metadata: %w", err) + if err = provider.Claims(&claims); err != nil { + return nil, nil, fmt.Errorf("failed to decode provider scope supported claims: %w", err) } // Validate required fields - if metadata.Issuer == "" { - return nil, fmt.Errorf("issuer is required in OIDC metadata") + if config.IssuerURL == "" { + return nil, nil, fmt.Errorf("issuer is required in OIDC provider config") } - if metadata.TokenEndpoint == "" { - return nil, fmt.Errorf("token_endpoint is required in OIDC metadata") + if config.TokenURL == "" { + return nil, nil, fmt.Errorf("token_endpoint is required in OIDC provider config") } - return &metadata, nil + return &config, &claims.SupportedScopes, nil } // FetchToken retrieves and validates tokens using the client credentials flow with OIDC support func (p *OIDCProvider) FetchToken(ctx context.Context) (*oauth2.Token, error) { // If issuer URL is provided, fetch OIDC metadata if issuerURL := p.oidcCredential.Provider.Issuer; issuerURL != "" { - metadata, err := p.getOIDCMetadata(ctx, issuerURL) + config, supportedScopes, err := p.getOIDCProviderConfig(ctx, issuerURL) if err != nil { - return nil, fmt.Errorf("failed to get OIDC metadata: %w", err) + return nil, fmt.Errorf("failed to get OIDC config: %w", err) } // Use discovered token endpoint if not explicitly provided if p.oidcCredential.Provider.TokenEndpoint == nil { - p.oidcCredential.Provider.TokenEndpoint = &metadata.TokenEndpoint + p.oidcCredential.Provider.TokenEndpoint = &config.TokenURL } // Add discovered scopes if available - if len(metadata.SupportedScopes) > 0 { + if supportedScopes != nil && len(*supportedScopes) > 0 { requestedScopes := make(map[string]bool) for _, scope := range p.oidcCredential.Scopes { requestedScopes[scope] = true } // Add supported scopes that aren't already requested - for _, scope := range metadata.SupportedScopes { + for _, scope := range *supportedScopes { if !requestedScopes[scope] { p.oidcCredential.Scopes = append(p.oidcCredential.Scopes, scope) } diff --git a/internal/controller/oauth/oidc_provider_test.go b/internal/controller/oauth/oidc_provider_test.go index 3aa254726..f8d714561 100644 --- a/internal/controller/oauth/oidc_provider_test.go +++ b/internal/controller/oauth/oidc_provider_test.go @@ -6,6 +6,7 @@ import ( "net/http/httptest" "testing" + oidcv3 "github.com/coreos/go-oidc/v3/oidc" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/stretchr/testify/require" corev1 "k8s.io/api/core/v1" @@ -20,7 +21,7 @@ func TestNewOIDCProvider(t *testing.T) { require.NotNil(t, NewOIDCProvider(nil, &egv1a1.OIDC{})) } -func TestOIDCProvider_GetOIDCMetadata(t *testing.T) { +func TestOIDCProvider_GetOIDCProviderConfig(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, err := w.Write([]byte(`{"issuer": "issuer", "token_endpoint": "token_endpoint", "authorization_endpoint": "authorization_endpoint", "jwks_uri": "jwks_uri", "scopes_supported": []}`)) require.NoError(t, err) @@ -43,11 +44,13 @@ func TestOIDCProvider_GetOIDCMetadata(t *testing.T) { ClientID: "some-client-id", } + ctx := oidcv3.InsecureIssuerURLContext(context.Background(), ts.URL) oidcProvider := NewOIDCProvider(NewClientCredentialsProvider(baseProvider), oidc) - metadata, err := oidcProvider.getOIDCMetadata(context.Background(), ts.URL) + config, supportedScope, err := oidcProvider.getOIDCProviderConfig(ctx, ts.URL) require.NoError(t, err) - require.Equal(t, "token_endpoint", metadata.TokenEndpoint) - require.Equal(t, "issuer", metadata.Issuer) + require.Equal(t, "token_endpoint", config.TokenURL) + require.Equal(t, "issuer", config.IssuerURL) + require.Empty(t, supportedScope) } func TestOIDCProvider_FetchToken(t *testing.T) { @@ -94,8 +97,9 @@ func TestOIDCProvider_FetchToken(t *testing.T) { clientCredentialProvider := NewClientCredentialsProvider(baseProvider) clientCredentialProvider.TokenSource = &MockClientCredentialsTokenSource{BaseProvider: baseProvider} require.NotNil(t, clientCredentialProvider) + ctx := oidcv3.InsecureIssuerURLContext(context.Background(), ts.URL) oidcProvider := NewOIDCProvider(clientCredentialProvider, oidc) - token, err := oidcProvider.FetchToken(context.Background()) + token, err := oidcProvider.FetchToken(ctx) require.NoError(t, err) require.Equal(t, "token", token.AccessToken) require.Equal(t, "Bearer", token.Type()) diff --git a/site/crd-ref-docs/templates/type_members.tpl b/site/crd-ref-docs/templates/type_members.tpl index e372c3f06..0858fac79 100644 --- a/site/crd-ref-docs/templates/type_members.tpl +++ b/site/crd-ref-docs/templates/type_members.tpl @@ -1,7 +1,7 @@ {{- define "type_members" -}} {{- $field := . -}} -{{- if eq $field.Name "metadata" -}} -Refer to Kubernetes API documentation for fields of `metadata`. +{{- if eq $field.Name "config" -}} +Refer to Kubernetes API documentation for fields of `config`. {{- else -}} {{ markdownRenderFieldDoc $field.Doc | replace "\"" "`" }} {{- end -}} diff --git a/site/docs/getting-started/basic-usage.md b/site/docs/getting-started/basic-usage.md index 5dd9f707a..ae7cd8c8f 100644 --- a/site/docs/getting-started/basic-usage.md +++ b/site/docs/getting-started/basic-usage.md @@ -75,7 +75,7 @@ Then set up port forwarding (this will block the terminal): ```shell export ENVOY_SERVICE=$(kubectl get svc -n envoy-gateway-system \ --selector=gateway.envoyproxy.io/owning-gateway-namespace=default,gateway.envoyproxy.io/owning-gateway-name=envoy-ai-gateway-basic \ - -o jsonpath='{.items[0].metadata.name}') + -o jsonpath='{.items[0].config.name}') kubectl port-forward -n envoy-gateway-system svc/$ENVOY_SERVICE 8080:80 ``` From fc8229da699feb63572b6a7d361589c69db52a50 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Mon, 10 Feb 2025 14:31:47 -0500 Subject: [PATCH 18/86] updates Signed-off-by: Aaron Choo --- .../001-ai-gateway-proposal/proposal.md | 40 +++++++++---------- go.mod | 2 + go.sum | 4 ++ internal/controller/oauth/oidc_provider.go | 9 ----- 4 files changed, 26 insertions(+), 29 deletions(-) diff --git a/docs/proposals/001-ai-gateway-proposal/proposal.md b/docs/proposals/001-ai-gateway-proposal/proposal.md index 7efcf48e6..ec8937e50 100644 --- a/docs/proposals/001-ai-gateway-proposal/proposal.md +++ b/docs/proposals/001-ai-gateway-proposal/proposal.md @@ -130,7 +130,7 @@ FilterConfig *AIGatewayFilterConfig `json:"filterConfig,omitempty"` // LLMRequestCosts specifies how to capture the cost of the LLM-related request, notably the token usage. // The AI Gateway filter will capture each specified number and store it in the Envoy's dynamic -// metadata per HTTP request. The namespaced key is "io.envoy.ai_gateway", +// config per HTTP request. The namespaced key is "io.envoy.ai_gateway", // // For example, let's say we have the following LLMRequestCosts configuration: // @@ -172,7 +172,7 @@ BackendRefs []AIGatewayRouteRuleBackendRef `json:"backendRefs,omitempty"` Matches []AIGatewayRouteRuleMatch `json:"matches,omitempty"` } -// LLMRequestCost specifies "where" the request cost is stored in the filter metadata as well as +// LLMRequestCost specifies "where" the request cost is stored in the filter config as well as // "how" the cost is calculated. By default, the cost is retrieved from "output token" in the response body. // // This can be used to subtract the usage token from the usage quota in the rate limit filter when @@ -180,7 +180,7 @@ Matches []AIGatewayRouteRuleMatch `json:"matches,omitempty"` // the rate limit configuration https://www.envoyproxy.io/docs/envoy/latest/api-v3/config/route/v3/route_components.proto#config-route-v3-ratelimit // which is introduced in Envoy 1.33 (to be released soon as of writing). type LLMRequestCost struct { -// MetadataKey is the key of the metadata storing the request cost. +// MetadataKey is the key of the config storing the request cost. MetadataKey string `json:"metadataKey"` // Type is the kind of the request cost calculation. Type LLMRequestCostType `json:"type"` @@ -311,7 +311,7 @@ type RateLimitCost struct { } // RateLimitCostSpecifier specifies where the Envoy retrieves the number to reduce the rate limit counters. // -// +kubebuilder:validation:XValidation:rule="!(has(self.number) && has(self.metadata))",message="only one of number or metadata can be specified" +// +kubebuilder:validation:XValidation:rule="!(has(self.number) && has(self.config))",message="only one of number or config can be specified" type RateLimitCostSpecifier struct { // From specifies where to get the rate limit cost. Currently, only "Number" and "Metadata" are supported. // @@ -323,19 +323,19 @@ From RateLimitCostFrom `json:"from"` // +optional // +notImplementedHide Number *uint64 `json:"number,omitempty"` -// Metadata specifies the per-request metadata to retrieve the usage number from. +// Metadata specifies the per-request config to retrieve the usage number from. // // +optional // +notImplementedHide -Metadata *RateLimitCostMetadata `json:"metadata,omitempty"` +Metadata *RateLimitCostMetadata `json:"config,omitempty"` } -// RateLimitCostMetadata specifies the filter metadata to retrieve the usage number from. +// RateLimitCostMetadata specifies the filter config to retrieve the usage number from. type RateLimitCostMetadata struct { -// Namespace is the namespace of the dynamic metadata. +// Namespace is the namespace of the dynamic config. // // +kubebuilder:validation:Required Namespace string `json:"namespace"` -// Key is the key to retrieve the usage number from the filter metadata. +// Key is the key to retrieve the usage number from the filter config. // // +kubebuilder:validation:Required Key string `json:"key"` @@ -387,12 +387,12 @@ The routing calculation in done in the `ExtProc` by analyzing the match rules on because it happens at the very end of the filter chain. The `AIServiceBackend` rules are specified on the `AIGatewayRoute` based on model header matching, in this example `anthropic.claude-3-5-sonnet` is routed to the AWS Bedrock and `llama-3.3-70b-instruction` is routed to the KServe backend for the self-hosted llama model. -`LLMRequestCost` is specified with the metadata key `llm_total_token` to store the cost of the LLM request. +`LLMRequestCost` is specified with the config key `llm_total_token` to store the cost of the LLM request. ```yaml apiVersion: aigateway.envoyproxy.io/v1alpha1 kind: AIGatewayRoute -metadata: +config: name: llmroute namespace: ai-gateway spec: @@ -417,7 +417,7 @@ spec: value: llama-3.3-70b-instruction backendRefs: - name: kserve-llama-backend - # The following metadata keys are used to store the costs from the LLM request. + # The following config keys are used to store the costs from the LLM request. llmRequestCosts: - metadataKey: llm_total_token type: TotalToken @@ -432,7 +432,7 @@ In this example API key is used to authenticate with OpenAI service and AWS cred ```yaml apiVersion: aigateway.envoyproxy.io/v1alpha1 kind: BackendSecurityPolicy -metadata: +config: name: aws-bedrock-credential namespace: default spec: @@ -446,7 +446,7 @@ spec: --- apiVersion: aigateway.envoyproxy.io/v1alpha1 kind: BackendSecurityPolicy -metadata: +config: name: openai-ai-key namespace: default spec: @@ -461,7 +461,7 @@ Based on the gateway routes, we define the AWS Bedrock and KServe `AIServiceBack ```yaml apiVersion: aigateway.envoyproxy.io/v1alpha1 kind: AIServiceBackend -metadata: +config: name: awsbedrock-backend namespace: ai-gateway spec: @@ -478,7 +478,7 @@ spec: --- apiVersion: aigateway.envoyproxy.io/v1alpha1 kind: AIServiceBackend -metadata: +config: name: kserve-llama-backend namespace: ai-gateway spec: @@ -495,7 +495,7 @@ spec: --- apiVersion: gateway.envoyproxy.io/v1alpha1 kind: Backend -metadata: +config: name: kserve-llama-backend namespace: ai-gateway spec: @@ -506,7 +506,7 @@ spec: --- apiVersion: gateway.envoyproxy.io/v1alpha1 kind: Backend -metadata: +config: name: llm-bedrock-backend namespace: ai-gateway spec: @@ -522,7 +522,7 @@ spec: ```yaml apiVersion: gateway.envoyproxy.io/v1alpha1 kind: BackendTrafficPolicy -metadata: +config: name: llama-ratelimit spec: # Applies the rate limit policy to the gateway. @@ -547,7 +547,7 @@ spec: cost: response: from: Metadata - metadata: + config: namespace: "io.envoy.ai_gateway" key: "llm_total_token" ``` diff --git a/go.mod b/go.mod index 048957dfa..99bcdff32 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.8 github.com/aws/aws-sdk-go-v2/config v1.29.6 github.com/aws/aws-sdk-go-v2/service/sts v1.33.14 + github.com/coreos/go-oidc/v3 v3.12.0 github.com/envoyproxy/gateway v1.3.0 github.com/envoyproxy/go-control-plane/envoy v1.32.4 github.com/go-logr/logr v1.4.2 @@ -144,6 +145,7 @@ require ( github.com/go-git/go-billy/v5 v5.6.0 // indirect github.com/go-git/go-git/v5 v5.13.0 // indirect github.com/go-gorp/gorp/v3 v3.1.0 // indirect + github.com/go-jose/go-jose/v4 v4.0.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/zapr v1.3.0 // indirect github.com/go-openapi/jsonpointer v0.21.0 // indirect diff --git a/go.sum b/go.sum index 1b7834776..af556ca81 100644 --- a/go.sum +++ b/go.sum @@ -182,6 +182,8 @@ github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A= github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw= +github.com/coreos/go-oidc/v3 v3.12.0 h1:sJk+8G2qq94rDI6ehZ71Bol3oUHy63qNYmkiSjrc/Jo= +github.com/coreos/go-oidc/v3 v3.12.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= @@ -295,6 +297,8 @@ github.com/go-git/go-git/v5 v5.13.0 h1:vLn5wlGIh/X78El6r3Jr+30W16Blk0CTcxTYcYPWi github.com/go-git/go-git/v5 v5.13.0/go.mod h1:Wjo7/JyVKtQgUNdXYXIepzWfJQkUEIGvkvVkiXRR/zw= github.com/go-gorp/gorp/v3 v3.1.0 h1:ItKF/Vbuj31dmV4jxA1qblpSwkl9g1typ24xoe70IGs= github.com/go-gorp/gorp/v3 v3.1.0/go.mod h1:dLEjIyyRNiXvNZ8PSmzpt1GsWAUK8kjVhEpjH8TixEw= +github.com/go-jose/go-jose/v4 v4.0.2 h1:R3l3kkBds16bO7ZFAEEcofK0MkrAJt3jlJznWZG0nvk= +github.com/go-jose/go-jose/v4 v4.0.2/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= diff --git a/internal/controller/oauth/oidc_provider.go b/internal/controller/oauth/oidc_provider.go index 72211711b..8140865ef 100644 --- a/internal/controller/oauth/oidc_provider.go +++ b/internal/controller/oauth/oidc_provider.go @@ -18,15 +18,6 @@ type OIDCProvider struct { oidcCredential *egv1a1.OIDC } -// OIDCMetadata represents the OpenID Connect provider metadata -type OIDCMetadata struct { - Issuer string `json:"issuer"` - AuthorizationEndpoint string `json:"authorization_endpoint"` - TokenEndpoint string `json:"token_endpoint"` - JWKSURI string `json:"jwks_uri"` - SupportedScopes []string `json:"scopes_supported"` -} - // NewOIDCProvider creates a new OIDC-aware provider func NewOIDCProvider(tokenProvider *ClientCredentialsTokenProvider, oidcCredentials *egv1a1.OIDC) *OIDCProvider { return &OIDCProvider{ From aaf266c81257ae04775d94152df7356bdc31f65f Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Mon, 10 Feb 2025 14:52:59 -0500 Subject: [PATCH 19/86] fix accidental refactor Signed-off-by: Aaron Choo --- .../001-ai-gateway-proposal/proposal.md | 40 +++++++++---------- site/crd-ref-docs/templates/type_members.tpl | 4 +- site/docs/getting-started/basic-usage.md | 2 +- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/docs/proposals/001-ai-gateway-proposal/proposal.md b/docs/proposals/001-ai-gateway-proposal/proposal.md index ec8937e50..7efcf48e6 100644 --- a/docs/proposals/001-ai-gateway-proposal/proposal.md +++ b/docs/proposals/001-ai-gateway-proposal/proposal.md @@ -130,7 +130,7 @@ FilterConfig *AIGatewayFilterConfig `json:"filterConfig,omitempty"` // LLMRequestCosts specifies how to capture the cost of the LLM-related request, notably the token usage. // The AI Gateway filter will capture each specified number and store it in the Envoy's dynamic -// config per HTTP request. The namespaced key is "io.envoy.ai_gateway", +// metadata per HTTP request. The namespaced key is "io.envoy.ai_gateway", // // For example, let's say we have the following LLMRequestCosts configuration: // @@ -172,7 +172,7 @@ BackendRefs []AIGatewayRouteRuleBackendRef `json:"backendRefs,omitempty"` Matches []AIGatewayRouteRuleMatch `json:"matches,omitempty"` } -// LLMRequestCost specifies "where" the request cost is stored in the filter config as well as +// LLMRequestCost specifies "where" the request cost is stored in the filter metadata as well as // "how" the cost is calculated. By default, the cost is retrieved from "output token" in the response body. // // This can be used to subtract the usage token from the usage quota in the rate limit filter when @@ -180,7 +180,7 @@ Matches []AIGatewayRouteRuleMatch `json:"matches,omitempty"` // the rate limit configuration https://www.envoyproxy.io/docs/envoy/latest/api-v3/config/route/v3/route_components.proto#config-route-v3-ratelimit // which is introduced in Envoy 1.33 (to be released soon as of writing). type LLMRequestCost struct { -// MetadataKey is the key of the config storing the request cost. +// MetadataKey is the key of the metadata storing the request cost. MetadataKey string `json:"metadataKey"` // Type is the kind of the request cost calculation. Type LLMRequestCostType `json:"type"` @@ -311,7 +311,7 @@ type RateLimitCost struct { } // RateLimitCostSpecifier specifies where the Envoy retrieves the number to reduce the rate limit counters. // -// +kubebuilder:validation:XValidation:rule="!(has(self.number) && has(self.config))",message="only one of number or config can be specified" +// +kubebuilder:validation:XValidation:rule="!(has(self.number) && has(self.metadata))",message="only one of number or metadata can be specified" type RateLimitCostSpecifier struct { // From specifies where to get the rate limit cost. Currently, only "Number" and "Metadata" are supported. // @@ -323,19 +323,19 @@ From RateLimitCostFrom `json:"from"` // +optional // +notImplementedHide Number *uint64 `json:"number,omitempty"` -// Metadata specifies the per-request config to retrieve the usage number from. +// Metadata specifies the per-request metadata to retrieve the usage number from. // // +optional // +notImplementedHide -Metadata *RateLimitCostMetadata `json:"config,omitempty"` +Metadata *RateLimitCostMetadata `json:"metadata,omitempty"` } -// RateLimitCostMetadata specifies the filter config to retrieve the usage number from. +// RateLimitCostMetadata specifies the filter metadata to retrieve the usage number from. type RateLimitCostMetadata struct { -// Namespace is the namespace of the dynamic config. +// Namespace is the namespace of the dynamic metadata. // // +kubebuilder:validation:Required Namespace string `json:"namespace"` -// Key is the key to retrieve the usage number from the filter config. +// Key is the key to retrieve the usage number from the filter metadata. // // +kubebuilder:validation:Required Key string `json:"key"` @@ -387,12 +387,12 @@ The routing calculation in done in the `ExtProc` by analyzing the match rules on because it happens at the very end of the filter chain. The `AIServiceBackend` rules are specified on the `AIGatewayRoute` based on model header matching, in this example `anthropic.claude-3-5-sonnet` is routed to the AWS Bedrock and `llama-3.3-70b-instruction` is routed to the KServe backend for the self-hosted llama model. -`LLMRequestCost` is specified with the config key `llm_total_token` to store the cost of the LLM request. +`LLMRequestCost` is specified with the metadata key `llm_total_token` to store the cost of the LLM request. ```yaml apiVersion: aigateway.envoyproxy.io/v1alpha1 kind: AIGatewayRoute -config: +metadata: name: llmroute namespace: ai-gateway spec: @@ -417,7 +417,7 @@ spec: value: llama-3.3-70b-instruction backendRefs: - name: kserve-llama-backend - # The following config keys are used to store the costs from the LLM request. + # The following metadata keys are used to store the costs from the LLM request. llmRequestCosts: - metadataKey: llm_total_token type: TotalToken @@ -432,7 +432,7 @@ In this example API key is used to authenticate with OpenAI service and AWS cred ```yaml apiVersion: aigateway.envoyproxy.io/v1alpha1 kind: BackendSecurityPolicy -config: +metadata: name: aws-bedrock-credential namespace: default spec: @@ -446,7 +446,7 @@ spec: --- apiVersion: aigateway.envoyproxy.io/v1alpha1 kind: BackendSecurityPolicy -config: +metadata: name: openai-ai-key namespace: default spec: @@ -461,7 +461,7 @@ Based on the gateway routes, we define the AWS Bedrock and KServe `AIServiceBack ```yaml apiVersion: aigateway.envoyproxy.io/v1alpha1 kind: AIServiceBackend -config: +metadata: name: awsbedrock-backend namespace: ai-gateway spec: @@ -478,7 +478,7 @@ spec: --- apiVersion: aigateway.envoyproxy.io/v1alpha1 kind: AIServiceBackend -config: +metadata: name: kserve-llama-backend namespace: ai-gateway spec: @@ -495,7 +495,7 @@ spec: --- apiVersion: gateway.envoyproxy.io/v1alpha1 kind: Backend -config: +metadata: name: kserve-llama-backend namespace: ai-gateway spec: @@ -506,7 +506,7 @@ spec: --- apiVersion: gateway.envoyproxy.io/v1alpha1 kind: Backend -config: +metadata: name: llm-bedrock-backend namespace: ai-gateway spec: @@ -522,7 +522,7 @@ spec: ```yaml apiVersion: gateway.envoyproxy.io/v1alpha1 kind: BackendTrafficPolicy -config: +metadata: name: llama-ratelimit spec: # Applies the rate limit policy to the gateway. @@ -547,7 +547,7 @@ spec: cost: response: from: Metadata - config: + metadata: namespace: "io.envoy.ai_gateway" key: "llm_total_token" ``` diff --git a/site/crd-ref-docs/templates/type_members.tpl b/site/crd-ref-docs/templates/type_members.tpl index 0858fac79..e372c3f06 100644 --- a/site/crd-ref-docs/templates/type_members.tpl +++ b/site/crd-ref-docs/templates/type_members.tpl @@ -1,7 +1,7 @@ {{- define "type_members" -}} {{- $field := . -}} -{{- if eq $field.Name "config" -}} -Refer to Kubernetes API documentation for fields of `config`. +{{- if eq $field.Name "metadata" -}} +Refer to Kubernetes API documentation for fields of `metadata`. {{- else -}} {{ markdownRenderFieldDoc $field.Doc | replace "\"" "`" }} {{- end -}} diff --git a/site/docs/getting-started/basic-usage.md b/site/docs/getting-started/basic-usage.md index ae7cd8c8f..5dd9f707a 100644 --- a/site/docs/getting-started/basic-usage.md +++ b/site/docs/getting-started/basic-usage.md @@ -75,7 +75,7 @@ Then set up port forwarding (this will block the terminal): ```shell export ENVOY_SERVICE=$(kubectl get svc -n envoy-gateway-system \ --selector=gateway.envoyproxy.io/owning-gateway-namespace=default,gateway.envoyproxy.io/owning-gateway-name=envoy-ai-gateway-basic \ - -o jsonpath='{.items[0].config.name}') + -o jsonpath='{.items[0].metadata.name}') kubectl port-forward -n envoy-gateway-system svc/$ENVOY_SERVICE 8080:80 ``` From d4bf4f0aa1432877b0296379cf09430ae48b1f56 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Mon, 10 Feb 2025 16:25:16 -0500 Subject: [PATCH 20/86] testing Signed-off-by: Aaron Choo --- .../controller/rotators/aws_common_test.go | 76 +++++++++ internal/controller/rotators/common.go | 4 +- internal/controller/rotators/common_test.go | 156 ++++++++++++++++++ 3 files changed, 234 insertions(+), 2 deletions(-) create mode 100644 internal/controller/rotators/aws_common_test.go create mode 100644 internal/controller/rotators/common_test.go diff --git a/internal/controller/rotators/aws_common_test.go b/internal/controller/rotators/aws_common_test.go new file mode 100644 index 000000000..9abc0eb66 --- /dev/null +++ b/internal/controller/rotators/aws_common_test.go @@ -0,0 +1,76 @@ +package backendauthrotators + +import ( + "fmt" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" +) + +func TestNewSTSClient(t *testing.T) { + stsClient := NewSTSClient(aws.Config{Region: "us-west-2"}) + require.NotNil(t, stsClient) + require.NotNil(t, stsClient.client) +} + +func TestParseAWSCredentialsFile(t *testing.T) { + profile := "default" + accessKey := "AKIAXXXXXXXXXXXXXXXX" + secretKey := "XXXXXXXXXXXXXXXXXXXX" + sessionToken := "XXXXXXXXXXXXXXXXXXXX" + region := "us-west-2" + awsCred := parseAWSCredentialsFile(fmt.Sprintf("[%s]\naws_access_key_id=%s\naws_secret_access_key=%s\naws_session_token=%s\nregion=%s", profile, accessKey, + secretKey, sessionToken, region)) + require.NotNil(t, awsCred) + defaultProfile, ok := awsCred.profiles[profile] + require.True(t, ok) + require.NotNil(t, defaultProfile) + require.Equal(t, accessKey, defaultProfile.accessKeyID) + require.Equal(t, secretKey, defaultProfile.secretAccessKey) + require.Equal(t, sessionToken, defaultProfile.sessionToken) + require.Equal(t, region, defaultProfile.region) +} + +func TestFormatAWSCredentialsFile(t *testing.T) { + emptyCredentialsFile := awsCredentialsFile{map[string]*awsCredentials{}} + require.Empty(t, formatAWSCredentialsFile(&emptyCredentialsFile)) + + profile := "default" + accessKey := "AKIAXXXXXXXXXXXXXXXX" + secretKey := "XXXXXXXXXXXXXXXXXXXX" + sessionToken := "XXXXXXXXXXXXXXXXXXXX" + region := "us-west-2" + credentials := awsCredentials{ + profile: profile, + accessKeyID: accessKey, + secretAccessKey: secretKey, + sessionToken: sessionToken, + region: region, + } + + awsCred := fmt.Sprintf("[%s]\naws_access_key_id = %s\naws_secret_access_key = %s\naws_session_token = %s\nregion = %s\n", profile, accessKey, + secretKey, sessionToken, region) + + require.Equal(t, awsCred, formatAWSCredentialsFile(&awsCredentialsFile{profiles: map[string]*awsCredentials{"default": &credentials}})) +} + +func TestUpdateAWSCredentialsInSecret(t *testing.T) { + secret := &corev1.Secret{} + + credentials := awsCredentials{ + profile: "default", + accessKeyID: "accessKey", + secretAccessKey: "secretKey", + sessionToken: "sessionToken", + region: "region", + } + + updateAWSCredentialsInSecret(secret, &awsCredentialsFile{profiles: map[string]*awsCredentials{"default": &credentials}}) + require.Len(t, secret.Data, 1) + + val, ok := secret.Data[credentialsKey] + require.True(t, ok) + require.NotEmpty(t, val) +} diff --git a/internal/controller/rotators/common.go b/internal/controller/rotators/common.go index f1573f328..2ba77c95a 100644 --- a/internal/controller/rotators/common.go +++ b/internal/controller/rotators/common.go @@ -55,11 +55,11 @@ func LookupSecret(ctx context.Context, k8sClient client.Client, namespace, name } // updateExpirationSecretAnnotation will set the expiration time of credentials set in secret annotation -func updateExpirationSecretAnnotation(secret *corev1.Secret, time time.Time) { +func updateExpirationSecretAnnotation(secret *corev1.Secret, updateTime time.Time) { if secret.Annotations == nil { secret.Annotations = make(map[string]string) } - secret.Annotations[ExpirationTimeAnnotationKey] = time.String() + secret.Annotations[ExpirationTimeAnnotationKey] = updateTime.Format(time.RFC3339) } // GetExpirationSecretAnnotation will get the expiration time of credentials set in secret annotation diff --git a/internal/controller/rotators/common_test.go b/internal/controller/rotators/common_test.go new file mode 100644 index 000000000..c6498965c --- /dev/null +++ b/internal/controller/rotators/common_test.go @@ -0,0 +1,156 @@ +package backendauthrotators + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" +) + +func TestNewSecret(t *testing.T) { + name := "test" + namespace := "test-namespace" + secret := newSecret(namespace, name) + + require.NotNil(t, secret) + require.Equal(t, name, secret.Name) + require.Equal(t, namespace, secret.Namespace) + require.NotNil(t, secret.Data) +} + +func TestUpdateSecret(t *testing.T) { + scheme := runtime.NewScheme() + scheme.AddKnownTypes(corev1.SchemeGroupVersion, + &corev1.Secret{}, + ) + cl := fake.NewClientBuilder().WithScheme(scheme).Build() + + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "test-namespace", + }, + Data: map[string][]byte{ + "key": []byte("value"), + }, + } + + err := cl.Get(context.Background(), client.ObjectKeyFromObject(secret), secret) + require.NoError(t, client.IgnoreNotFound(err)) + require.NoError(t, updateSecret(context.Background(), cl, secret)) + + var secretPlaceholder corev1.Secret + require.NoError(t, cl.Get(context.Background(), client.ObjectKey{ + Namespace: "test-namespace", + Name: "test", + }, &secretPlaceholder)) + require.Equal(t, secret.Name, secretPlaceholder.Name) + require.Equal(t, secret.Namespace, secretPlaceholder.Namespace) + require.Equal(t, []byte("value"), secretPlaceholder.Data["key"]) + + secret.Data["key"] = []byte("another value") + require.NoError(t, updateSecret(context.Background(), cl, secret)) + + require.NoError(t, cl.Get(context.Background(), client.ObjectKey{ + Namespace: "test-namespace", + Name: "test", + }, &secretPlaceholder)) + require.Equal(t, []byte("another value"), secretPlaceholder.Data["key"]) +} + +func TestLookupSecret(t *testing.T) { + scheme := runtime.NewScheme() + scheme.AddKnownTypes(corev1.SchemeGroupVersion, + &corev1.Secret{}, + ) + cl := fake.NewClientBuilder().WithScheme(scheme).Build() + + secretName := "test" + secretNamespace := "test-namespace" + secret, err := LookupSecret(context.Background(), cl, secretNamespace, secretName) + require.Error(t, err) + require.Nil(t, secret) + + require.NoError(t, cl.Create(context.Background(), &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: secretName, + Namespace: secretNamespace, + }, + })) + + secret, err = LookupSecret(context.Background(), cl, secretNamespace, secretName) + require.NoError(t, err) + require.NotNil(t, secret) + require.Equal(t, secretName, secret.Name) + require.Equal(t, secretNamespace, secret.Namespace) +} + +func TestUpdateExpirationSecretAnnotation(t *testing.T) { + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "test-namespace", + }, + } + timeNow := time.Now() + updateExpirationSecretAnnotation(secret, timeNow) + require.NotNil(t, secret.Annotations) + timeValue, ok := secret.Annotations[ExpirationTimeAnnotationKey] + require.True(t, ok) + require.Equal(t, timeNow.Format(time.RFC3339), timeValue) +} + +func TestGetExpirationSecretAnnotation(t *testing.T) { + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "test-namespace", + }, + } + + expirationTime, err := GetExpirationSecretAnnotation(secret) + require.Error(t, err) + require.Contains(t, err.Error(), "missing expiration time annotation") + require.Nil(t, expirationTime) + + secret.Annotations = map[string]string{ + ExpirationTimeAnnotationKey: "invalid", + } + expirationTime, err = GetExpirationSecretAnnotation(secret) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to parse") + require.Nil(t, expirationTime) + + timeNow := time.Now() + secret.Annotations = map[string]string{ + ExpirationTimeAnnotationKey: timeNow.Format(time.RFC3339), + } + expirationTime, err = GetExpirationSecretAnnotation(secret) + require.NoError(t, err) + require.Equal(t, timeNow.Format(time.RFC3339), expirationTime.Format(time.RFC3339)) +} + +func TestUpdateAndGetExpirationSecretAnnotation(t *testing.T) { + secret := &corev1.Secret{} + expirationTime, err := GetExpirationSecretAnnotation(secret) + require.Error(t, err) + require.Contains(t, err.Error(), "missing expiration time annotation") + require.Nil(t, expirationTime) + + timeNow := time.Now() + updateExpirationSecretAnnotation(secret, timeNow) + expirationTime, err = GetExpirationSecretAnnotation(secret) + require.NoError(t, err) + require.Equal(t, timeNow.Format(time.RFC3339), expirationTime.Format(time.RFC3339)) +} + +func TestIsExpired(t *testing.T) { + require.True(t, IsExpired(1*time.Minute, time.Now())) + require.False(t, IsExpired(1*time.Minute, time.Now().Add(10*time.Minute))) +} From 99184bcc42c7d19171d56ecab264dfceb2023f55 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Mon, 10 Feb 2025 16:27:23 -0500 Subject: [PATCH 21/86] remove reconcile all Signed-off-by: Aaron Choo --- internal/controller/backend_security_policy.go | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index af3476756..c93c71133 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -24,20 +24,18 @@ import ( // // This handles the BackendSecurityPolicy resource and sends it to the config sink so that it can modify configuration. type backendSecurityPolicyController struct { - client client.Client - kube kubernetes.Interface - logger logr.Logger - eventChan chan ConfigSinkEvent - reconcileAll bool + client client.Client + kube kubernetes.Interface + logger logr.Logger + eventChan chan ConfigSinkEvent } func newBackendSecurityPolicyController(client client.Client, kube kubernetes.Interface, logger logr.Logger, ch chan ConfigSinkEvent) *backendSecurityPolicyController { return &backendSecurityPolicyController{ - client: client, - kube: kube, - logger: logger, - eventChan: ch, - reconcileAll: true, + client: client, + kube: kube, + logger: logger, + eventChan: ch, } } From 89be2f64db80bd18ce5eb3c1257b90b0250a20e2 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Mon, 10 Feb 2025 16:43:43 -0500 Subject: [PATCH 22/86] check result of reconcile Signed-off-by: Aaron Choo --- .../controller/backend_security_policy.go | 2 +- .../backend_security_policy_test.go | 26 ++++++++++++------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index c93c71133..92d1ac8b2 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -72,7 +72,7 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr } func getBackendSecurityPolicyAuthOIDC(spec aigv1a1.BackendSecurityPolicySpec) *egv1a1.OIDC { - if spec.AWSCredentials != nil && spec.AWSCredentials.OIDCExchangeToken != nil { + if spec.AWSCredentials != nil { if spec.AWSCredentials.OIDCExchangeToken != nil { return &spec.AWSCredentials.OIDCExchangeToken.OIDC } diff --git a/internal/controller/backend_security_policy_test.go b/internal/controller/backend_security_policy_test.go index baec0300e..03cf25615 100644 --- a/internal/controller/backend_security_policy_test.go +++ b/internal/controller/backend_security_policy_test.go @@ -8,8 +8,6 @@ package controller import ( "context" "fmt" - "testing" - egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/stretchr/testify/require" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -19,6 +17,8 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/controller-runtime/pkg/reconcile" gwapiv1 "sigs.k8s.io/gateway-api/apis/v1" + "testing" + "time" aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1" ) @@ -37,29 +37,35 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { Spec: aigv1a1.BackendSecurityPolicySpec{ Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ - Region: "us-east-1", - OIDCExchangeToken: &aigv1a1.AWSOIDCExchangeToken{}, + Region: "us-east-1", + OIDCExchangeToken: &aigv1a1.AWSOIDCExchangeToken{ + OIDC: egv1a1.OIDC{}, + }, }, }, }) require.NoError(t, err) - _, err = c.Reconcile(context.Background(), reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: backendSecurityPolicyName}}) + res, err := c.Reconcile(context.Background(), reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: backendSecurityPolicyName}}) require.NoError(t, err) + require.False(t, res.Requeue) item, ok := <-ch require.True(t, ok) require.IsType(t, &aigv1a1.BackendSecurityPolicy{}, item) require.Equal(t, backendSecurityPolicyName, item.(*aigv1a1.BackendSecurityPolicy).Name) require.Equal(t, namespace, item.(*aigv1a1.BackendSecurityPolicy).Namespace) - // Test backendSecurityPolicy with OIDC credentials have the annotation added - oidcBackendSecurityPolicy := &aigv1a1.BackendSecurityPolicy{} - err = cl.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName)}, oidcBackendSecurityPolicy) + res, err = c.Reconcile(context.Background(), reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName)}}) require.NoError(t, err) + require.True(t, res.Requeue) + require.Equal(t, res.RequeueAfter, time.Minute) + // Test the case where the BackendSecurityPolicy is being deleted. - err = cl.Delete(t.Context(), &aigv1a1.BackendSecurityPolicy{ObjectMeta: metav1.ObjectMeta{Name: backendSecurityPolicyName, Namespace: namespace}}) + err = cl.Delete(context.Background(), &aigv1a1.BackendSecurityPolicy{ObjectMeta: metav1.ObjectMeta{Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName), Namespace: namespace}}) require.NoError(t, err) - _, err = c.Reconcile(t.Context(), reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: backendSecurityPolicyName}}) + + res, err = c.Reconcile(context.Background(), reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName)}}) require.NoError(t, err) + require.False(t, res.Requeue) } func TestBackendSecurityController_GetBackendSecurityPolicyAuthOIDC(t *testing.T) { From 898f312fc429586313a8fb84b375b19f37297067 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Tue, 11 Feb 2025 13:36:44 -0500 Subject: [PATCH 23/86] sync bsp Signed-off-by: Aaron Choo --- internal/controller/sink.go | 4 +- internal/controller/sink_test.go | 183 ++++++++++++++++++++++++++++++- 2 files changed, 179 insertions(+), 8 deletions(-) diff --git a/internal/controller/sink.go b/internal/controller/sink.go index 8e1b5d6a7..f3a3caa22 100644 --- a/internal/controller/sink.go +++ b/internal/controller/sink.go @@ -273,7 +273,7 @@ func (c *configSink) syncBackendSecurityPolicy(ctx context.Context, bsp *aigv1a1 baseProvider := oauth.NewBaseProvider(c.client, c.logger) oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(baseProvider), oidc) - tokenRes, err := oidcProvider.FetchToken(context.TODO()) + tokenRes, err := oidcProvider.FetchToken(ctx) if err != nil { c.logger.Error(err, "failed to fetch OIDC provider token") return @@ -297,7 +297,7 @@ func (c *configSink) syncBackendSecurityPolicy(ctx context.Context, bsp *aigv1a1 if expired { token := tokenResponse.AccessToken - err = rotator.Rotate(context.Background(), awsCredentials.Region, awsCredentials.OIDCExchangeToken.AwsRoleArn, token) + err = rotator.Rotate(ctx, awsCredentials.Region, awsCredentials.OIDCExchangeToken.AwsRoleArn, token) if err != nil { c.logger.Error(err, "failed to rotate AWS OIDC exchange token") return diff --git a/internal/controller/sink_test.go b/internal/controller/sink_test.go index 41416e3b6..c4ddd8a4b 100644 --- a/internal/controller/sink_test.go +++ b/internal/controller/sink_test.go @@ -7,8 +7,12 @@ package controller import ( "context" + "encoding/json" "fmt" + oidcv3 "github.com/coreos/go-oidc/v3/oidc" "log/slog" + "net/http" + "net/http/httptest" "os" "strconv" "testing" @@ -176,6 +180,106 @@ func TestConfigSink_syncBackendSecurityPolicy(t *testing.T) { }) } +func TestConfigSink_syncBackendSecurityPolicyOIDC(t *testing.T) { + fakeClient := requireNewFakeClientWithIndexes(t) + eventChan := make(chan ConfigSinkEvent) + s := newConfigSink(fakeClient, nil, logr.Discard(), eventChan, "defaultExtProcImage", "debug") + + require.Empty(t, s.oidcTokenCache) + + // Test with OIDC backend + backend := aigv1a1.AIServiceBackend{ + ObjectMeta: metav1.ObjectMeta{Name: "potato", Namespace: "ns"}, + Spec: aigv1a1.AIServiceBackendSpec{ + BackendRef: gwapiv1.BackendObjectReference{Name: "some-backend2", Namespace: ptr.To[gwapiv1.Namespace]("ns")}, + BackendSecurityPolicyRef: &gwapiv1.LocalObjectReference{Name: "orange"}, + }, + } + require.NoError(t, fakeClient.Create(context.Background(), &backend, &client.CreateOptions{})) + + clientSecret := "secretName" + secretNamespace := "ns" + secret := corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: clientSecret, + Namespace: secretNamespace, + }, + Data: map[string][]byte{ + "client-secret": []byte("client-secret"), + }, + } + require.NoError(t, fakeClient.Create(context.Background(), &secret, &client.CreateOptions{})) + + secret = corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "orange", + Namespace: secretNamespace, + Annotations: map[string]string{ + "rotators/expiration-time": "3025-01-01T01:01:00.000-00:00", + }, + }, + Data: map[string][]byte{ + "credentials": []byte("credentials"), + }, + } + + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + println("123") + + w.Header().Add("Content-Type", "application/json") + type tokenJSON struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn string `json:"expires_in"` + } + b, err := json.Marshal(tokenJSON{AccessToken: "some-access-token", TokenType: "Bearer", ExpiresIn: "60"}) + require.NoError(t, err) + _, err = w.Write(b) + require.NoError(t, err) + + })) + defer tokenServer.Close() + + discoveryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, err := w.Write([]byte(`{"issuer": "issuer", "token_endpoint": "token_endpoint", "authorization_endpoint": "authorization_endpoint", "jwks_uri": "jwks_uri", "scopes_supported": []}`)) + require.NoError(t, err) + })) + defer discoveryServer.Close() + + ctx := oidcv3.InsecureIssuerURLContext(context.Background(), discoveryServer.URL) + namespaceRef := gwapiv1.Namespace(secretNamespace) + + s.syncBackendSecurityPolicy(ctx, &aigv1a1.BackendSecurityPolicy{ + ObjectMeta: metav1.ObjectMeta{Name: "orange", Namespace: "ns"}, + Spec: aigv1a1.BackendSecurityPolicySpec{ + Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, + AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ + Region: "us-east-1", + OIDCExchangeToken: &aigv1a1.AWSOIDCExchangeToken{ + OIDC: egv1a1.OIDC{ + Provider: egv1a1.OIDCProvider{ + Issuer: discoveryServer.URL, + TokenEndpoint: &tokenServer.URL, + }, + ClientID: "some-client-id", + ClientSecret: gwapiv1.SecretObjectReference{ + Name: gwapiv1.ObjectName(clientSecret), + Namespace: &namespaceRef, + }, + }, + GrantType: "placeholder", + Aud: "placeholder", + AwsRoleArn: "placeholder", + }, + }, + }, + }) + require.Len(t, s.oidcTokenCache, 1) + token, ok := s.oidcTokenCache["orange.ns"] + require.True(t, ok) + require.Equal(t, "some-access-token", token.AccessToken) +} + func Test_newHTTPRoute(t *testing.T) { eventChan := make(chan ConfigSinkEvent) fakeClient := requireNewFakeClientWithIndexes(t) @@ -322,6 +426,16 @@ func Test_updateExtProcConfigMap(t *testing.T) { }, }, }, + { + ObjectMeta: metav1.ObjectMeta{Name: "some-backend-security-policy-3", Namespace: "ns"}, + Spec: aigv1a1.BackendSecurityPolicySpec{ + Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, + AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ + Region: "us-east-1", + OIDCExchangeToken: &aigv1a1.AWSOIDCExchangeToken{}, + }, + }, + }, } { err := fakeClient.Create(t.Context(), bsp, &client.CreateOptions{}) require.NoError(t, err) @@ -358,6 +472,13 @@ func Test_updateExtProcConfigMap(t *testing.T) { BackendSecurityPolicyRef: &gwapiv1.LocalObjectReference{Name: "some-backend-security-policy-2"}, }, }, + { + ObjectMeta: metav1.ObjectMeta{Name: "dog", Namespace: "ns"}, + Spec: aigv1a1.AIServiceBackendSpec{ + BackendRef: gwapiv1.BackendObjectReference{Name: "some-backend5", Namespace: ptr.To[gwapiv1.Namespace]("ns")}, + BackendSecurityPolicyRef: &gwapiv1.LocalObjectReference{Name: "some-backend-security-policy-3"}, + }, + }, } { err := fakeClient.Create(t.Context(), b, &client.CreateOptions{}) require.NoError(t, err) @@ -399,6 +520,14 @@ func Test_updateExtProcConfigMap(t *testing.T) { {Headers: []gwapiv1.HTTPHeaderMatch{{Name: aigv1a1.AIModelHeaderKey, Value: "another-ai-2"}}}, }, }, + { + BackendRefs: []aigv1a1.AIGatewayRouteRuleBackendRef{ + {Name: "dog", Weight: 1}, + }, + Matches: []aigv1a1.AIGatewayRouteRuleMatch{ + {Headers: []gwapiv1.HTTPHeaderMatch{{Name: aigv1a1.AIModelHeaderKey, Value: "another-ai-3"}}}, + }, + }, }, LLMRequestCosts: []aigv1a1.LLMRequestCost{ { @@ -455,6 +584,15 @@ func Test_updateExtProcConfigMap(t *testing.T) { }}}, Headers: []filterapi.HeaderMatch{{Name: aigv1a1.AIModelHeaderKey, Value: "another-ai-2"}}, }, + { + Backends: []filterapi.Backend{{Name: "dog.ns", Weight: 1, Auth: &filterapi.BackendAuth{ + AWSAuth: &filterapi.AWSAuth{ + CredentialFileName: "/etc/backend_security_policy/rule3-backref0-some-backend-security-policy-3/credentials", + Region: "us-east-1", + }, + }}}, + Headers: []filterapi.HeaderMatch{{Name: aigv1a1.AIModelHeaderKey, Value: "another-ai-3"}}, + }, }, LLMRequestCosts: []filterapi.LLMRequestCost{ {Type: filterapi.LLMRequestCostTypeOutputToken, MetadataKey: "output-token"}, @@ -668,6 +806,7 @@ func TestConfigSink_MountBackendSecurityPolicySecrets(t *testing.T) { {ObjectMeta: metav1.ObjectMeta{Name: "some-secret-policy-1"}}, {ObjectMeta: metav1.ObjectMeta{Name: "some-secret-policy-2"}}, {ObjectMeta: metav1.ObjectMeta{Name: "some-secret-policy-3"}}, + {ObjectMeta: metav1.ObjectMeta{Name: "aws-oidc-name"}}, } { require.NoError(t, fakeClient.Create(t.Context(), secret, &client.CreateOptions{})) } @@ -691,6 +830,16 @@ func TestConfigSink_MountBackendSecurityPolicySecrets(t *testing.T) { }, }, }, + { + ObjectMeta: metav1.ObjectMeta{Name: "aws-oidc-name", Namespace: "ns"}, + Spec: aigv1a1.BackendSecurityPolicySpec{ + Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, + AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ + OIDCExchangeToken: &aigv1a1.AWSOIDCExchangeToken{}, + Region: "us-east-1", + }, + }, + }, { ObjectMeta: metav1.ObjectMeta{Name: "some-other-backend-security-policy-aws", Namespace: "ns"}, Spec: aigv1a1.BackendSecurityPolicySpec{ @@ -719,7 +868,6 @@ func TestConfigSink_MountBackendSecurityPolicySecrets(t *testing.T) { BackendSecurityPolicyRef: &gwapiv1.LocalObjectReference{Name: "some-other-backend-security-policy-1"}, }, }, - { ObjectMeta: metav1.ObjectMeta{Name: "pineapple", Namespace: "ns"}, Spec: aigv1a1.AIServiceBackendSpec{ @@ -730,6 +878,16 @@ func TestConfigSink_MountBackendSecurityPolicySecrets(t *testing.T) { BackendSecurityPolicyRef: &gwapiv1.LocalObjectReference{Name: "some-other-backend-security-policy-aws"}, }, }, + { + ObjectMeta: metav1.ObjectMeta{Name: "dog", Namespace: "ns"}, + Spec: aigv1a1.AIServiceBackendSpec{ + APISchema: aigv1a1.VersionedAPISchema{ + Name: aigv1a1.APISchemaAWSBedrock, + }, + BackendRef: gwapiv1.BackendObjectReference{Name: "some-backend4", Namespace: ptr.To[gwapiv1.Namespace]("ns")}, + BackendSecurityPolicyRef: &gwapiv1.LocalObjectReference{Name: "aws-oidc-name"}, + }, + }, } { require.NoError(t, fakeClient.Create(t.Context(), backend, &client.CreateOptions{})) require.NotNil(t, s) @@ -755,6 +913,14 @@ func TestConfigSink_MountBackendSecurityPolicySecrets(t *testing.T) { {Headers: []gwapiv1.HTTPHeaderMatch{{Name: aigv1a1.AIModelHeaderKey, Value: "some-ai-2"}}}, }, }, + { + BackendRefs: []aigv1a1.AIGatewayRouteRuleBackendRef{ + {Name: "dog", Weight: 1}, + }, + Matches: []aigv1a1.AIGatewayRouteRuleMatch{ + {Headers: []gwapiv1.HTTPHeaderMatch{{Name: aigv1a1.AIModelHeaderKey, Value: "some-ai-3"}}}, + }, + }, }, }, } @@ -782,18 +948,23 @@ func TestConfigSink_MountBackendSecurityPolicySecrets(t *testing.T) { updatedSpec, err := s.mountBackendSecurityPolicySecrets(t.Context(), &spec, &aiGateway) require.NoError(t, err) - require.Len(t, updatedSpec.Volumes, 3) - require.Len(t, updatedSpec.Containers[0].VolumeMounts, 3) + require.Len(t, updatedSpec.Volumes, 4) + require.Len(t, updatedSpec.Containers[0].VolumeMounts, 4) // API Key. require.Equal(t, "some-secret-policy-1", updatedSpec.Volumes[1].VolumeSource.Secret.SecretName) require.Equal(t, "rule0-backref0-some-other-backend-security-policy-1", updatedSpec.Volumes[1].Name) require.Equal(t, "rule0-backref0-some-other-backend-security-policy-1", updatedSpec.Containers[0].VolumeMounts[1].Name) require.Equal(t, "/etc/backend_security_policy/rule0-backref0-some-other-backend-security-policy-1", updatedSpec.Containers[0].VolumeMounts[1].MountPath) - // AWS. + // AWS CredentialFile. require.Equal(t, "some-secret-policy-3", updatedSpec.Volumes[2].VolumeSource.Secret.SecretName) require.Equal(t, "rule1-backref0-some-other-backend-security-policy-aws", updatedSpec.Volumes[2].Name) require.Equal(t, "rule1-backref0-some-other-backend-security-policy-aws", updatedSpec.Containers[0].VolumeMounts[2].Name) require.Equal(t, "/etc/backend_security_policy/rule1-backref0-some-other-backend-security-policy-aws", updatedSpec.Containers[0].VolumeMounts[2].MountPath) + // AWS OIDC. + require.Equal(t, "aws-oidc-name", updatedSpec.Volumes[3].VolumeSource.Secret.SecretName) + require.Equal(t, "rule2-backref0-aws-oidc-name", updatedSpec.Volumes[3].Name) + require.Equal(t, "rule2-backref0-aws-oidc-name", updatedSpec.Containers[0].VolumeMounts[3].Name) + require.Equal(t, "/etc/backend_security_policy/rule2-backref0-aws-oidc-name", updatedSpec.Containers[0].VolumeMounts[3].MountPath) require.NoError(t, fakeClient.Delete(t.Context(), &aigv1a1.AIServiceBackend{ObjectMeta: metav1.ObjectMeta{Name: "apple", Namespace: "ns"}}, &client.DeleteOptions{})) @@ -815,8 +986,8 @@ func TestConfigSink_MountBackendSecurityPolicySecrets(t *testing.T) { updatedSpec, err = s.mountBackendSecurityPolicySecrets(t.Context(), &spec, &aiGateway) require.NoError(t, err) - require.Len(t, updatedSpec.Volumes, 3) - require.Len(t, updatedSpec.Containers[0].VolumeMounts, 3) + require.Len(t, updatedSpec.Volumes, 4) + require.Len(t, updatedSpec.Containers[0].VolumeMounts, 4) require.Equal(t, "some-secret-policy-2", updatedSpec.Volumes[1].VolumeSource.Secret.SecretName) require.Equal(t, "rule0-backref0-some-other-backend-security-policy-2", updatedSpec.Volumes[1].Name) require.Equal(t, "rule0-backref0-some-other-backend-security-policy-2", updatedSpec.Containers[0].VolumeMounts[1].Name) From 30efc02e5e6fef5b574056c401fdb289befd3e83 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Wed, 12 Feb 2025 21:11:45 -0500 Subject: [PATCH 24/86] add tests Signed-off-by: Aaron Choo --- .../controller/backend_security_policy.go | 2 +- .../controller/rotators/aws_oidc_rotator.go | 6 +- .../rotators/aws_oidc_rotator_test.go | 57 +++++++++++++++++++ internal/controller/sink.go | 14 ++--- internal/controller/sink_test.go | 40 ++++++++++--- 5 files changed, 100 insertions(+), 19 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 92d1ac8b2..db3e0ea35 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -58,7 +58,7 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr rotator, err := backendauthrotators.NewAWSOIDCRotator(b.client, b.kube, b.logger, backendSecurityPolicy.Namespace, backendSecurityPolicy.Name, preRotationWindow, region) if err != nil { b.logger.Error(err, "failed to create AWS OIDC rotator") - } else if expired, err := rotator.IsExpired(); err != nil && !expired { + } else if !rotator.IsExpired() { requeue = time.Until(*rotator.GetPreRotationTime()) if requeue.Seconds() == 0 { requeue = time.Minute diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index dc3ac8456..0bfaa9c10 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -84,12 +84,12 @@ func (r *AWSOIDCRotator) SetSTSOperations(ops STSOperations) { r.stsOps = ops } -func (r *AWSOIDCRotator) IsExpired() (bool, error) { +func (r *AWSOIDCRotator) IsExpired() bool { preRotationExpirationTime := r.GetPreRotationTime() if preRotationExpirationTime == nil { - return true, nil + return true } - return IsExpired(0, *preRotationExpirationTime), nil + return IsExpired(0, *preRotationExpirationTime) } func (r *AWSOIDCRotator) GetPreRotationTime() *time.Time { diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go index 22882dcc4..c80dcadee 100644 --- a/internal/controller/rotators/aws_oidc_rotator_test.go +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -150,3 +150,60 @@ func TestAWS_OIDCRotator(t *testing.T) { assert.Contains(t, err.Error(), "failed to assume role") }) } + +func TestAWS_GetPreRotationTime(t *testing.T) { + scheme := runtime.NewScheme() + scheme.AddKnownTypes(corev1.SchemeGroupVersion, + &corev1.Secret{}, + ) + cl := fake.NewClientBuilder().WithScheme(scheme).Build() + awsOidcRotator := AWSOIDCRotator{ + client: cl, + backendSecurityPolicyNamespace: "default", + backendSecurityPolicyName: "test-secret", + } + + require.Nil(t, awsOidcRotator.GetPreRotationTime()) + + createTestAWSSecret(t, cl, "test-secret", "OLDKEY", "OLDSECRET", "OLDTOKEN", "default") + require.Nil(t, awsOidcRotator.GetPreRotationTime()) + + secret, err := LookupSecret(context.Background(), cl, "default", "test-secret") + require.NoError(t, err) + + expiredTime := time.Now().Add(-1 * time.Hour) + updateExpirationSecretAnnotation(secret, expiredTime) + require.NoError(t, cl.Update(context.Background(), secret)) + require.Equal(t, expiredTime.Format(time.RFC3339), awsOidcRotator.GetPreRotationTime().Format(time.RFC3339)) +} + +func TestAWS_IsExpired(t *testing.T) { + scheme := runtime.NewScheme() + scheme.AddKnownTypes(corev1.SchemeGroupVersion, + &corev1.Secret{}, + ) + cl := fake.NewClientBuilder().WithScheme(scheme).Build() + awsOidcRotator := AWSOIDCRotator{ + client: cl, + backendSecurityPolicyNamespace: "default", + backendSecurityPolicyName: "test-secret", + } + + require.True(t, awsOidcRotator.IsExpired()) + + createTestAWSSecret(t, cl, "test-secret", "OLDKEY", "OLDSECRET", "OLDTOKEN", "default") + require.Nil(t, awsOidcRotator.GetPreRotationTime()) + + secret, err := LookupSecret(context.Background(), cl, "default", "test-secret") + require.NoError(t, err) + + expiredTime := time.Now().Add(-1 * time.Hour) + updateExpirationSecretAnnotation(secret, expiredTime) + require.NoError(t, cl.Update(context.Background(), secret)) + require.True(t, awsOidcRotator.IsExpired()) + + hourFromNowTime := time.Now().Add(1 * time.Hour) + updateExpirationSecretAnnotation(secret, hourFromNowTime) + require.NoError(t, cl.Update(context.Background(), secret)) + require.False(t, awsOidcRotator.IsExpired()) +} diff --git a/internal/controller/sink.go b/internal/controller/sink.go index f3a3caa22..c014f1c2b 100644 --- a/internal/controller/sink.go +++ b/internal/controller/sink.go @@ -73,6 +73,7 @@ type configSink struct { extProcImagePullPolicy corev1.PullPolicy extProcLogLevel string eventChan chan ConfigSinkEvent + StsOP backendauthrotators.STSOperations oidcTokenCache map[string]*oauth2.Token } @@ -92,6 +93,7 @@ func newConfigSink( extProcImagePullPolicy: corev1.PullIfNotPresent, extProcLogLevel: extProcLogLevel, eventChan: eventChan, + StsOP: nil, oidcTokenCache: make(map[string]*oauth2.Token), } return c @@ -289,13 +291,11 @@ func (c *configSink) syncBackendSecurityPolicy(ctx context.Context, bsp *aigv1a1 return } - expired, err := rotator.IsExpired() - if err != nil { - c.logger.Error(err, "failed to check if AWS OIDC rotator is expired") - return - } - - if expired { + if rotator.IsExpired() { + // This is to abstract the real STS behavior for testing purpose. + if c.StsOP != nil { + rotator.SetSTSOperations(c.StsOP) + } token := tokenResponse.AccessToken err = rotator.Rotate(ctx, awsCredentials.Region, awsCredentials.OIDCExchangeToken.AwsRoleArn, token) if err != nil { diff --git a/internal/controller/sink_test.go b/internal/controller/sink_test.go index c4ddd8a4b..fe083b5e7 100644 --- a/internal/controller/sink_test.go +++ b/internal/controller/sink_test.go @@ -9,7 +9,11 @@ import ( "context" "encoding/json" "fmt" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/aws/aws-sdk-go-v2/service/sts/types" oidcv3 "github.com/coreos/go-oidc/v3/oidc" + backendauthrotators "github.com/envoyproxy/ai-gateway/internal/controller/rotators" "log/slog" "net/http" "net/http/httptest" @@ -180,6 +184,20 @@ func TestConfigSink_syncBackendSecurityPolicy(t *testing.T) { }) } +// MockSTSOperations implements the STSOperations interface for testing +type MockSTSOperations struct{} + +func (m *MockSTSOperations) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { + return &sts.AssumeRoleWithWebIdentityOutput{ + Credentials: &types.Credentials{ + AccessKeyId: aws.String("NEWKEY"), + SecretAccessKey: aws.String("NEWSECRET"), + SessionToken: aws.String("NEWTOKEN"), + Expiration: aws.Time(time.Now().Add(1 * time.Hour)), + }, + }, nil +} + func TestConfigSink_syncBackendSecurityPolicyOIDC(t *testing.T) { fakeClient := requireNewFakeClientWithIndexes(t) eventChan := make(chan ConfigSinkEvent) @@ -198,11 +216,11 @@ func TestConfigSink_syncBackendSecurityPolicyOIDC(t *testing.T) { require.NoError(t, fakeClient.Create(context.Background(), &backend, &client.CreateOptions{})) clientSecret := "secretName" - secretNamespace := "ns" + sharedNamespace := "ns" secret := corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: clientSecret, - Namespace: secretNamespace, + Namespace: sharedNamespace, }, Data: map[string][]byte{ "client-secret": []byte("client-secret"), @@ -213,19 +231,18 @@ func TestConfigSink_syncBackendSecurityPolicyOIDC(t *testing.T) { secret = corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: "orange", - Namespace: secretNamespace, + Namespace: sharedNamespace, Annotations: map[string]string{ - "rotators/expiration-time": "3025-01-01T01:01:00.000-00:00", + backendauthrotators.ExpirationTimeAnnotationKey: "2024-01-01T01:01:00.000-00:00", }, }, Data: map[string][]byte{ "credentials": []byte("credentials"), }, } + require.NoError(t, fakeClient.Create(context.Background(), &secret, &client.CreateOptions{})) tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - println("123") - w.Header().Add("Content-Type", "application/json") type tokenJSON struct { AccessToken string `json:"access_token"` @@ -247,10 +264,12 @@ func TestConfigSink_syncBackendSecurityPolicyOIDC(t *testing.T) { defer discoveryServer.Close() ctx := oidcv3.InsecureIssuerURLContext(context.Background(), discoveryServer.URL) - namespaceRef := gwapiv1.Namespace(secretNamespace) + namespaceRef := gwapiv1.Namespace(sharedNamespace) + + s.StsOP = &MockSTSOperations{} s.syncBackendSecurityPolicy(ctx, &aigv1a1.BackendSecurityPolicy{ - ObjectMeta: metav1.ObjectMeta{Name: "orange", Namespace: "ns"}, + ObjectMeta: metav1.ObjectMeta{Name: "orange", Namespace: sharedNamespace}, Spec: aigv1a1.BackendSecurityPolicySpec{ Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ @@ -278,6 +297,11 @@ func TestConfigSink_syncBackendSecurityPolicyOIDC(t *testing.T) { token, ok := s.oidcTokenCache["orange.ns"] require.True(t, ok) require.Equal(t, "some-access-token", token.AccessToken) + + updatedSecret, err := backendauthrotators.LookupSecret(context.Background(), fakeClient, sharedNamespace, "orange") + require.NoError(t, err) + require.NotEqualf(t, secret.Annotations[backendauthrotators.ExpirationTimeAnnotationKey], updatedSecret.Annotations[backendauthrotators.ExpirationTimeAnnotationKey], "expected updated expiration time annotation") + } func Test_newHTTPRoute(t *testing.T) { From d575d7b235ada97c8b08c93fdfa539adb1370672 Mon Sep 17 00:00:00 2001 From: Dan Sun Date: Wed, 12 Feb 2025 21:41:36 -0500 Subject: [PATCH 25/86] Fix format Signed-off-by: Dan Sun --- internal/controller/backend_security_policy_test.go | 5 +++-- internal/controller/sink_test.go | 12 +++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/internal/controller/backend_security_policy_test.go b/internal/controller/backend_security_policy_test.go index 03cf25615..e750903b9 100644 --- a/internal/controller/backend_security_policy_test.go +++ b/internal/controller/backend_security_policy_test.go @@ -8,6 +8,9 @@ package controller import ( "context" "fmt" + "testing" + "time" + egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/stretchr/testify/require" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -17,8 +20,6 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/controller-runtime/pkg/reconcile" gwapiv1 "sigs.k8s.io/gateway-api/apis/v1" - "testing" - "time" aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1" ) diff --git a/internal/controller/sink_test.go b/internal/controller/sink_test.go index fe083b5e7..fba1a4672 100644 --- a/internal/controller/sink_test.go +++ b/internal/controller/sink_test.go @@ -9,11 +9,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/sts" - "github.com/aws/aws-sdk-go-v2/service/sts/types" - oidcv3 "github.com/coreos/go-oidc/v3/oidc" - backendauthrotators "github.com/envoyproxy/ai-gateway/internal/controller/rotators" "log/slog" "net/http" "net/http/httptest" @@ -22,6 +17,10 @@ import ( "testing" "time" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/aws/aws-sdk-go-v2/service/sts/types" + oidcv3 "github.com/coreos/go-oidc/v3/oidc" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/go-logr/logr" "github.com/stretchr/testify/require" @@ -39,6 +38,7 @@ import ( aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1" "github.com/envoyproxy/ai-gateway/filterapi" + backendauthrotators "github.com/envoyproxy/ai-gateway/internal/controller/rotators" ) func requireNewFakeClientWithIndexes(t *testing.T) client.Client { @@ -253,7 +253,6 @@ func TestConfigSink_syncBackendSecurityPolicyOIDC(t *testing.T) { require.NoError(t, err) _, err = w.Write(b) require.NoError(t, err) - })) defer tokenServer.Close() @@ -301,7 +300,6 @@ func TestConfigSink_syncBackendSecurityPolicyOIDC(t *testing.T) { updatedSecret, err := backendauthrotators.LookupSecret(context.Background(), fakeClient, sharedNamespace, "orange") require.NoError(t, err) require.NotEqualf(t, secret.Annotations[backendauthrotators.ExpirationTimeAnnotationKey], updatedSecret.Annotations[backendauthrotators.ExpirationTimeAnnotationKey], "expected updated expiration time annotation") - } func Test_newHTTPRoute(t *testing.T) { From 5e966824b55fd2f7f1955f97e49896ee7f82b708 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 10:56:23 -0500 Subject: [PATCH 26/86] fix tests Signed-off-by: Aaron Choo --- internal/controller/backend_security_policy_test.go | 2 +- internal/controller/sink_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/controller/backend_security_policy_test.go b/internal/controller/backend_security_policy_test.go index e750903b9..800641ad4 100644 --- a/internal/controller/backend_security_policy_test.go +++ b/internal/controller/backend_security_policy_test.go @@ -58,7 +58,7 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { res, err = c.Reconcile(context.Background(), reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName)}}) require.NoError(t, err) require.True(t, res.Requeue) - require.Equal(t, res.RequeueAfter, time.Minute) + require.Equal(t, time.Minute, res.RequeueAfter) // Test the case where the BackendSecurityPolicy is being deleted. err = cl.Delete(context.Background(), &aigv1a1.BackendSecurityPolicy{ObjectMeta: metav1.ObjectMeta{Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName), Namespace: namespace}}) diff --git a/internal/controller/sink_test.go b/internal/controller/sink_test.go index fba1a4672..244d2c81d 100644 --- a/internal/controller/sink_test.go +++ b/internal/controller/sink_test.go @@ -187,7 +187,7 @@ func TestConfigSink_syncBackendSecurityPolicy(t *testing.T) { // MockSTSOperations implements the STSOperations interface for testing type MockSTSOperations struct{} -func (m *MockSTSOperations) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { +func (m *MockSTSOperations) AssumeRoleWithWebIdentity(_ context.Context, _ *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { return &sts.AssumeRoleWithWebIdentityOutput{ Credentials: &types.Credentials{ AccessKeyId: aws.String("NEWKEY"), From 0bb34d9a217623d69920152dc2453d2d48b70ead Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 13:51:56 -0500 Subject: [PATCH 27/86] add test for error case Signed-off-by: Aaron Choo --- .../controller/oauth/oidc_provider_test.go | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/internal/controller/oauth/oidc_provider_test.go b/internal/controller/oauth/oidc_provider_test.go index f8d714561..42d7518b8 100644 --- a/internal/controller/oauth/oidc_provider_test.go +++ b/internal/controller/oauth/oidc_provider_test.go @@ -21,6 +21,84 @@ func TestNewOIDCProvider(t *testing.T) { require.NotNil(t, NewOIDCProvider(nil, &egv1a1.OIDC{})) } +func TestOIDCProvider_GetOIDCProviderConfigErrors(t *testing.T) { + scheme := runtime.NewScheme() + scheme.AddKnownTypes(corev1.SchemeGroupVersion, + &corev1.Secret{}, + ) + cl := fake.NewClientBuilder().WithScheme(scheme).Build() + baseProvider := NewBaseProvider(cl, ctrl.Log) + require.NotNil(t, baseProvider) + + oidc := &egv1a1.OIDC{ + Provider: egv1a1.OIDCProvider{}, + ClientID: "some-client-id", + } + + var err error + missingIssuerTestServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, err = w.Write([]byte(`{"token_endpoint": "token_endpoint", "authorization_endpoint": "authorization_endpoint", "jwks_uri": "jwks_uri"}`)) + require.NoError(t, err) + })) + defer missingIssuerTestServer.Close() + + missingTokenURLTestServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, err = w.Write([]byte(`{"issuer": "issuer", "authorization_endpoint": "authorization_endpoint", "jwks_uri": "jwks_uri"}`)) + require.NoError(t, err) + })) + defer missingTokenURLTestServer.Close() + + oidcProvider := NewOIDCProvider(NewClientCredentialsProvider(baseProvider), oidc) + cancelledContext, cancel := context.WithCancel(context.Background()) + cancel() + + for _, testcase := range []struct { + name string + provider *OIDCProvider + url string + ctx context.Context + contains string + }{ + { + name: "context error", + provider: oidcProvider, + ctx: cancelledContext, + url: "", + contains: "context error before discovery", + }, + { + name: "failed to create go oidc", + provider: oidcProvider, + url: "", + ctx: context.Background(), + contains: "failed to create go-oidc provider", + }, + { + name: "config missing token url", + provider: oidcProvider, + url: missingTokenURLTestServer.URL, + ctx: oidcv3.InsecureIssuerURLContext(context.Background(), missingTokenURLTestServer.URL), + contains: "token_endpoint is required in OIDC provider config", + }, + { + name: "config missing issuer", + provider: oidcProvider, + url: missingIssuerTestServer.URL, + ctx: oidcv3.InsecureIssuerURLContext(context.Background(), missingIssuerTestServer.URL), + contains: "issuer is required in OIDC provider config", + }, + } { + t.Run(testcase.name, func(t *testing.T) { + oidcProvider := testcase.provider + config, supportedScope, err := oidcProvider.getOIDCProviderConfig(testcase.ctx, testcase.url) + require.Error(t, err) + require.Contains(t, err.Error(), testcase.contains) + require.Nil(t, config) + require.Nil(t, supportedScope) + }) + } +} + func TestOIDCProvider_GetOIDCProviderConfig(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, err := w.Write([]byte(`{"issuer": "issuer", "token_endpoint": "token_endpoint", "authorization_endpoint": "authorization_endpoint", "jwks_uri": "jwks_uri", "scopes_supported": []}`)) From 63360b54ee10c6e030a3e65ea19b56b126f82011 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 17:22:12 -0500 Subject: [PATCH 28/86] Update internal/controller/oauth/client_credentials_token_provider.go Co-authored-by: Takeshi Yoneda Signed-off-by: Aaron Choo --- internal/controller/oauth/client_credentials_token_provider.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/controller/oauth/client_credentials_token_provider.go b/internal/controller/oauth/client_credentials_token_provider.go index 3aa43c6f8..465294ecc 100644 --- a/internal/controller/oauth/client_credentials_token_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -11,7 +11,7 @@ import ( corev1 "k8s.io/api/core/v1" ) -// ClientCredentialsTokenProvider implements the standard OAuth2 client credentials flow +// ClientCredentialsTokenProvider implements the standard OAuth2 client credentials flow. type ClientCredentialsTokenProvider struct { *BaseProvider TokenSource oauth2.TokenSource From 126cd26c7ea0475071e99e559561b7ddea5b1dea Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 18:35:33 -0500 Subject: [PATCH 29/86] Update internal/controller/oauth/client_credentials_token_provider.go Co-authored-by: Takeshi Yoneda Signed-off-by: Aaron Choo --- internal/controller/oauth/client_credentials_token_provider.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/controller/oauth/client_credentials_token_provider.go b/internal/controller/oauth/client_credentials_token_provider.go index 465294ecc..766a33d60 100644 --- a/internal/controller/oauth/client_credentials_token_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -36,7 +36,7 @@ func (p *ClientCredentialsTokenProvider) FetchToken(ctx context.Context, oidc *e return p.getTokenWithClientCredentialConfig(ctx, oidc, clientSecret) } -// getTokenWithClientCredentialFlow fetches the oauth2 token with client credential config +// getTokenWithClientCredentialFlow fetches the oauth2 token with client credential config. func (p *ClientCredentialsTokenProvider) getTokenWithClientCredentialConfig(ctx context.Context, oidc *egv1a1.OIDC, clientSecret string) (*oauth2.Token, error) { if p.TokenSource == nil { oauth2Config := clientcredentials.Config{ From 1abe7124ec3ca1bbc0743edea3512651bc38e93a Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 18:37:02 -0500 Subject: [PATCH 30/86] Update internal/controller/oauth/client_credentials_token_provider.go Co-authored-by: Takeshi Yoneda Signed-off-by: Aaron Choo --- internal/controller/oauth/client_credentials_token_provider.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/controller/oauth/client_credentials_token_provider.go b/internal/controller/oauth/client_credentials_token_provider.go index 766a33d60..f4b757aff 100644 --- a/internal/controller/oauth/client_credentials_token_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -53,7 +53,7 @@ func (p *ClientCredentialsTokenProvider) getTokenWithClientCredentialConfig(ctx return nil, fmt.Errorf("fail to get oauth2 token %w", err) } - // Handle expiration + // Handle expiration. if token.ExpiresIn > 0 { token.Expiry = time.Now().Add(time.Duration(token.ExpiresIn) * time.Second) } From de463dd3e607e29cec0ad94e957c0f8187000c6e Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 17:34:13 -0500 Subject: [PATCH 31/86] remove base config Signed-off-by: Aaron Choo --- .../controller/backend_security_policy.go | 6 ++---- .../client_credentials_token_provider.go | 9 +++++---- .../client_credentials_token_provider_test.go | 11 +++-------- .../controller/oauth/oidc_provider_test.go | 15 ++++----------- .../oauth/{base_provider.go => util.go} | 19 ++----------------- .../{base_provider_test.go => util_test.go} | 12 ++---------- internal/controller/sink.go | 3 +-- 7 files changed, 19 insertions(+), 56 deletions(-) rename internal/controller/oauth/{base_provider.go => util.go} (56%) rename internal/controller/oauth/{base_provider_test.go => util_test.go} (68%) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index db3e0ea35..4b98f801f 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -72,10 +72,8 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr } func getBackendSecurityPolicyAuthOIDC(spec aigv1a1.BackendSecurityPolicySpec) *egv1a1.OIDC { - if spec.AWSCredentials != nil { - if spec.AWSCredentials.OIDCExchangeToken != nil { - return &spec.AWSCredentials.OIDCExchangeToken.OIDC - } + if spec.AWSCredentials != nil && spec.AWSCredentials.OIDCExchangeToken != nil { + return &spec.AWSCredentials.OIDCExchangeToken.OIDC } return nil } diff --git a/internal/controller/oauth/client_credentials_token_provider.go b/internal/controller/oauth/client_credentials_token_provider.go index f4b757aff..946998480 100644 --- a/internal/controller/oauth/client_credentials_token_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -3,6 +3,7 @@ package oauth import ( "context" "fmt" + "sigs.k8s.io/controller-runtime/pkg/client" "time" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" @@ -13,20 +14,20 @@ import ( // ClientCredentialsTokenProvider implements the standard OAuth2 client credentials flow. type ClientCredentialsTokenProvider struct { - *BaseProvider TokenSource oauth2.TokenSource + client client.Client } // NewClientCredentialsProvider creates a new client credentials provider -func NewClientCredentialsProvider(base *BaseProvider) *ClientCredentialsTokenProvider { +func NewClientCredentialsProvider(cl client.Client) *ClientCredentialsTokenProvider { return &ClientCredentialsTokenProvider{ - BaseProvider: base, + client: cl, } } // FetchToken gets the client secret from the secret reference and fetches the token from the provider token URL. func (p *ClientCredentialsTokenProvider) FetchToken(ctx context.Context, oidc *egv1a1.OIDC) (*oauth2.Token, error) { - clientSecret, err := p.getClientSecret(ctx, &corev1.SecretReference{ + clientSecret, err := getClientSecret(ctx, p.client, &corev1.SecretReference{ Name: string(oidc.ClientSecret.Name), Namespace: string(*oidc.ClientSecret.Namespace), }) diff --git a/internal/controller/oauth/client_credentials_token_provider_test.go b/internal/controller/oauth/client_credentials_token_provider_test.go index f4fe28986..37d8c4795 100644 --- a/internal/controller/oauth/client_credentials_token_provider_test.go +++ b/internal/controller/oauth/client_credentials_token_provider_test.go @@ -13,15 +13,12 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" - ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client/fake" gwapiv1 "sigs.k8s.io/gateway-api/apis/v1" ) // MockClientCredentialsTokenSource implements the standard OAuth2 client credentials flow -type MockClientCredentialsTokenSource struct { - *BaseProvider -} +type MockClientCredentialsTokenSource struct{} // FetchToken gets the client secret from the secret reference and fetches the token from provider token URL. func (m *MockClientCredentialsTokenSource) Token() (*oauth2.Token, error) { @@ -43,8 +40,6 @@ func TestClientCredentialsProvider_FetchToken(t *testing.T) { &corev1.Secret{}, ) cl := fake.NewClientBuilder().WithScheme(scheme).Build() - baseProvider := NewBaseProvider(cl, ctrl.Log) - require.NotNil(t, baseProvider) secretName, secretNamespace := "secret", "secret-ns" err := cl.Create(context.Background(), &corev1.Secret{ @@ -61,8 +56,8 @@ func TestClientCredentialsProvider_FetchToken(t *testing.T) { }) require.NoError(t, err) - clientCredentialProvider := NewClientCredentialsProvider(baseProvider) - clientCredentialProvider.TokenSource = &MockClientCredentialsTokenSource{BaseProvider: baseProvider} + clientCredentialProvider := NewClientCredentialsProvider(cl) + clientCredentialProvider.TokenSource = &MockClientCredentialsTokenSource{} require.NotNil(t, clientCredentialProvider) namespaceRef := gwapiv1.Namespace(secretNamespace) diff --git a/internal/controller/oauth/oidc_provider_test.go b/internal/controller/oauth/oidc_provider_test.go index 42d7518b8..5f5d5c8fc 100644 --- a/internal/controller/oauth/oidc_provider_test.go +++ b/internal/controller/oauth/oidc_provider_test.go @@ -12,7 +12,6 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" - ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client/fake" gwapiv1 "sigs.k8s.io/gateway-api/apis/v1" ) @@ -27,8 +26,6 @@ func TestOIDCProvider_GetOIDCProviderConfigErrors(t *testing.T) { &corev1.Secret{}, ) cl := fake.NewClientBuilder().WithScheme(scheme).Build() - baseProvider := NewBaseProvider(cl, ctrl.Log) - require.NotNil(t, baseProvider) oidc := &egv1a1.OIDC{ Provider: egv1a1.OIDCProvider{}, @@ -48,7 +45,7 @@ func TestOIDCProvider_GetOIDCProviderConfigErrors(t *testing.T) { })) defer missingTokenURLTestServer.Close() - oidcProvider := NewOIDCProvider(NewClientCredentialsProvider(baseProvider), oidc) + oidcProvider := NewOIDCProvider(NewClientCredentialsProvider(cl), oidc) cancelledContext, cancel := context.WithCancel(context.Background()) cancel() @@ -111,8 +108,6 @@ func TestOIDCProvider_GetOIDCProviderConfig(t *testing.T) { &corev1.Secret{}, ) cl := fake.NewClientBuilder().WithScheme(scheme).Build() - baseProvider := NewBaseProvider(cl, ctrl.Log) - require.NotNil(t, baseProvider) oidc := &egv1a1.OIDC{ Provider: egv1a1.OIDCProvider{ @@ -123,7 +118,7 @@ func TestOIDCProvider_GetOIDCProviderConfig(t *testing.T) { } ctx := oidcv3.InsecureIssuerURLContext(context.Background(), ts.URL) - oidcProvider := NewOIDCProvider(NewClientCredentialsProvider(baseProvider), oidc) + oidcProvider := NewOIDCProvider(NewClientCredentialsProvider(cl), oidc) config, supportedScope, err := oidcProvider.getOIDCProviderConfig(ctx, ts.URL) require.NoError(t, err) require.Equal(t, "token_endpoint", config.TokenURL) @@ -143,8 +138,6 @@ func TestOIDCProvider_FetchToken(t *testing.T) { &corev1.Secret{}, ) cl := fake.NewClientBuilder().WithScheme(scheme).Build() - baseProvider := NewBaseProvider(cl, ctrl.Log) - require.NotNil(t, baseProvider) secretName, secretNamespace := "secret", "secret-ns" err := cl.Create(context.Background(), &corev1.Secret{ @@ -172,8 +165,8 @@ func TestOIDCProvider_FetchToken(t *testing.T) { Namespace: &namespaceRef, }, } - clientCredentialProvider := NewClientCredentialsProvider(baseProvider) - clientCredentialProvider.TokenSource = &MockClientCredentialsTokenSource{BaseProvider: baseProvider} + clientCredentialProvider := NewClientCredentialsProvider(cl) + clientCredentialProvider.TokenSource = &MockClientCredentialsTokenSource{} require.NotNil(t, clientCredentialProvider) ctx := oidcv3.InsecureIssuerURLContext(context.Background(), ts.URL) oidcProvider := NewOIDCProvider(clientCredentialProvider, oidc) diff --git a/internal/controller/oauth/base_provider.go b/internal/controller/oauth/util.go similarity index 56% rename from internal/controller/oauth/base_provider.go rename to internal/controller/oauth/util.go index 177779df6..515b63d77 100644 --- a/internal/controller/oauth/base_provider.go +++ b/internal/controller/oauth/util.go @@ -4,29 +4,14 @@ import ( "context" "fmt" - "github.com/go-logr/logr" corev1 "k8s.io/api/core/v1" "sigs.k8s.io/controller-runtime/pkg/client" ) -// BaseProvider implements common OAuth functionality -type BaseProvider struct { - client client.Client - logger logr.Logger -} - -// NewBaseProvider creates a new base provider -func NewBaseProvider(client client.Client, logger logr.Logger) *BaseProvider { - return &BaseProvider{ - client: client, - logger: logger, - } -} - // getClientSecret retrieves the client secret from a Kubernetes secret -func (p *BaseProvider) getClientSecret(ctx context.Context, secretRef *corev1.SecretReference) (string, error) { +func getClientSecret(ctx context.Context, cl client.Client, secretRef *corev1.SecretReference) (string, error) { secret := &corev1.Secret{} - if err := p.client.Get(ctx, client.ObjectKey{ + if err := cl.Get(ctx, client.ObjectKey{ Namespace: secretRef.Namespace, Name: secretRef.Name, }, secret); err != nil { diff --git a/internal/controller/oauth/base_provider_test.go b/internal/controller/oauth/util_test.go similarity index 68% rename from internal/controller/oauth/base_provider_test.go rename to internal/controller/oauth/util_test.go index e7039cd20..3d69ae99a 100644 --- a/internal/controller/oauth/base_provider_test.go +++ b/internal/controller/oauth/util_test.go @@ -8,23 +8,15 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" - ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client/fake" ) -func TestNewBaseProvider(t *testing.T) { - scheme := runtime.NewScheme() - cl := fake.NewClientBuilder().WithScheme(scheme).Build() - require.NotNil(t, NewBaseProvider(cl, ctrl.Log)) -} - -func TestBaseProvider_GetClientSecret(t *testing.T) { +func TestGetClientSecret(t *testing.T) { scheme := runtime.NewScheme() scheme.AddKnownTypes(corev1.SchemeGroupVersion, &corev1.Secret{}, ) cl := fake.NewClientBuilder().WithScheme(scheme).Build() - baseProvider := NewBaseProvider(cl, ctrl.Log) secretName, secretNamespace := "secret", "secret-ns" err := cl.Create(context.Background(), &corev1.Secret{ @@ -41,7 +33,7 @@ func TestBaseProvider_GetClientSecret(t *testing.T) { }) require.NoError(t, err) - secret, err := baseProvider.getClientSecret(context.Background(), &corev1.SecretReference{ + secret, err := getClientSecret(context.Background(), cl, &corev1.SecretReference{ Name: secretName, Namespace: secretNamespace, }) diff --git a/internal/controller/sink.go b/internal/controller/sink.go index c014f1c2b..060ec9020 100644 --- a/internal/controller/sink.go +++ b/internal/controller/sink.go @@ -272,8 +272,7 @@ func (c *configSink) syncBackendSecurityPolicy(ctx context.Context, bsp *aigv1a1 if oidc := getBackendSecurityPolicyAuthOIDC(bsp.Spec); oidc != nil { tokenResponse, ok := c.oidcTokenCache[key] if !ok || backendauthrotators.IsExpired(preRotationWindow, tokenResponse.Expiry) { - baseProvider := oauth.NewBaseProvider(c.client, c.logger) - oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(baseProvider), oidc) + oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(c.client), oidc) tokenRes, err := oidcProvider.FetchToken(ctx) if err != nil { From 7adc0fee60844ac9ce7ba680814c6ce9fa04f67a Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 17:40:51 -0500 Subject: [PATCH 32/86] check if pointer is nil Signed-off-by: Aaron Choo --- .../oauth/client_credentials_token_provider.go | 12 ++++++++---- .../oauth/client_credentials_token_provider_test.go | 6 +++++- internal/controller/oauth/oidc_provider_test.go | 2 +- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/internal/controller/oauth/client_credentials_token_provider.go b/internal/controller/oauth/client_credentials_token_provider.go index 946998480..a427665fd 100644 --- a/internal/controller/oauth/client_credentials_token_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -14,7 +14,7 @@ import ( // ClientCredentialsTokenProvider implements the standard OAuth2 client credentials flow. type ClientCredentialsTokenProvider struct { - TokenSource oauth2.TokenSource + tokenSource oauth2.TokenSource client client.Client } @@ -27,6 +27,10 @@ func NewClientCredentialsProvider(cl client.Client) *ClientCredentialsTokenProvi // FetchToken gets the client secret from the secret reference and fetches the token from the provider token URL. func (p *ClientCredentialsTokenProvider) FetchToken(ctx context.Context, oidc *egv1a1.OIDC) (*oauth2.Token, error) { + if oidc == nil || oidc.ClientSecret.Namespace == nil { + return nil, fmt.Errorf("oidc or oidc-client-secret is nil") + } + clientSecret, err := getClientSecret(ctx, p.client, &corev1.SecretReference{ Name: string(oidc.ClientSecret.Name), Namespace: string(*oidc.ClientSecret.Namespace), @@ -39,7 +43,7 @@ func (p *ClientCredentialsTokenProvider) FetchToken(ctx context.Context, oidc *e // getTokenWithClientCredentialFlow fetches the oauth2 token with client credential config. func (p *ClientCredentialsTokenProvider) getTokenWithClientCredentialConfig(ctx context.Context, oidc *egv1a1.OIDC, clientSecret string) (*oauth2.Token, error) { - if p.TokenSource == nil { + if p.tokenSource == nil { oauth2Config := clientcredentials.Config{ ClientID: oidc.ClientID, ClientSecret: clientSecret, @@ -47,9 +51,9 @@ func (p *ClientCredentialsTokenProvider) getTokenWithClientCredentialConfig(ctx TokenURL: *oidc.Provider.TokenEndpoint, Scopes: oidc.Scopes, } - p.TokenSource = oauth2Config.TokenSource(ctx) + p.tokenSource = oauth2Config.TokenSource(ctx) } - token, err := p.TokenSource.Token() + token, err := p.tokenSource.Token() if err != nil { return nil, fmt.Errorf("fail to get oauth2 token %w", err) } diff --git a/internal/controller/oauth/client_credentials_token_provider_test.go b/internal/controller/oauth/client_credentials_token_provider_test.go index 37d8c4795..db760dfda 100644 --- a/internal/controller/oauth/client_credentials_token_provider_test.go +++ b/internal/controller/oauth/client_credentials_token_provider_test.go @@ -57,9 +57,13 @@ func TestClientCredentialsProvider_FetchToken(t *testing.T) { require.NoError(t, err) clientCredentialProvider := NewClientCredentialsProvider(cl) - clientCredentialProvider.TokenSource = &MockClientCredentialsTokenSource{} + clientCredentialProvider.tokenSource = &MockClientCredentialsTokenSource{} require.NotNil(t, clientCredentialProvider) + _, err = clientCredentialProvider.FetchToken(context.Background(), nil) + require.Error(t, err) + require.Contains(t, err.Error(), "oidc or oidc-client-secret is nil") + namespaceRef := gwapiv1.Namespace(secretNamespace) token, err := clientCredentialProvider.FetchToken(context.Background(), &egv1a1.OIDC{ Provider: egv1a1.OIDCProvider{ diff --git a/internal/controller/oauth/oidc_provider_test.go b/internal/controller/oauth/oidc_provider_test.go index 5f5d5c8fc..98d2a1cd2 100644 --- a/internal/controller/oauth/oidc_provider_test.go +++ b/internal/controller/oauth/oidc_provider_test.go @@ -166,7 +166,7 @@ func TestOIDCProvider_FetchToken(t *testing.T) { }, } clientCredentialProvider := NewClientCredentialsProvider(cl) - clientCredentialProvider.TokenSource = &MockClientCredentialsTokenSource{} + clientCredentialProvider.tokenSource = &MockClientCredentialsTokenSource{} require.NotNil(t, clientCredentialProvider) ctx := oidcv3.InsecureIssuerURLContext(context.Background(), ts.URL) oidcProvider := NewOIDCProvider(clientCredentialProvider, oidc) From 1fbe7bfc4d37df89f44ba7d8688b3ad9b790498d Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 17:52:45 -0500 Subject: [PATCH 33/86] add descriptions to all functions Signed-off-by: Aaron Choo --- internal/controller/backend_security_policy.go | 3 ++- internal/controller/rotators/aws_oidc_rotator.go | 5 ++++- internal/controller/rotators/common.go | 5 +++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 4b98f801f..ac72c3813 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -55,7 +55,7 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr var requeue time.Duration requeue = time.Minute region := backendSecurityPolicy.Spec.AWSCredentials.Region - rotator, err := backendauthrotators.NewAWSOIDCRotator(b.client, b.kube, b.logger, backendSecurityPolicy.Namespace, backendSecurityPolicy.Name, preRotationWindow, region) + rotator, err := backendauthrotators.NewAWSOIDCRotator(ctx, b.client, b.kube, b.logger, backendSecurityPolicy.Namespace, backendSecurityPolicy.Name, preRotationWindow, region) if err != nil { b.logger.Error(err, "failed to create AWS OIDC rotator") } else if !rotator.IsExpired() { @@ -71,6 +71,7 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr return } +// getBackendSecurityPolicyAuthOIDC returns the backendSecurityPolicy's OIDC pointer or nil func getBackendSecurityPolicyAuthOIDC(spec aigv1a1.BackendSecurityPolicySpec) *egv1a1.OIDC { if spec.AWSCredentials != nil && spec.AWSCredentials.OIDCExchangeToken != nil { return &spec.AWSCredentials.OIDCExchangeToken.OIDC diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index 0bfaa9c10..0864a1dc0 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -39,6 +39,7 @@ type AWSOIDCRotator struct { // NewAWSOIDCRotator creates a new AWS OIDC rotator with the specified configuration. // It initializes the AWS STS client and sets up the rotation channels. func NewAWSOIDCRotator( + ctx context.Context, client client.Client, kube kubernetes.Interface, logger logr.Logger, @@ -47,7 +48,7 @@ func NewAWSOIDCRotator( preRotationWindow time.Duration, region string, ) (*AWSOIDCRotator, error) { - cfg, err := defaultAWSConfig(context.Background()) + cfg, err := defaultAWSConfig(ctx) if err != nil { return nil, fmt.Errorf("failed to load AWS config: %w", err) } @@ -84,6 +85,7 @@ func (r *AWSOIDCRotator) SetSTSOperations(ops STSOperations) { r.stsOps = ops } +// IsExpired checks if the preRotation time is before the current time. func (r *AWSOIDCRotator) IsExpired() bool { preRotationExpirationTime := r.GetPreRotationTime() if preRotationExpirationTime == nil { @@ -92,6 +94,7 @@ func (r *AWSOIDCRotator) IsExpired() bool { return IsExpired(0, *preRotationExpirationTime) } +// GetPreRotationTime gets the expiration time minus the preRotation interval. func (r *AWSOIDCRotator) GetPreRotationTime() *time.Time { secret, err := LookupSecret(context.Background(), r.client, r.backendSecurityPolicyNamespace, r.backendSecurityPolicyName) if err != nil { diff --git a/internal/controller/rotators/common.go b/internal/controller/rotators/common.go index 2ba77c95a..e5c866c21 100644 --- a/internal/controller/rotators/common.go +++ b/internal/controller/rotators/common.go @@ -76,6 +76,7 @@ func GetExpirationSecretAnnotation(secret *corev1.Secret) (*time.Time, error) { return &expirationTime, nil } -func IsExpired(preRotationInterval time.Duration, expirationTime time.Time) bool { - return expirationTime.Add(-preRotationInterval).Before(time.Now()) +// IsExpired checks if the expired time minus duration buffer is before the current time. +func IsExpired(buffer time.Duration, expirationTime time.Time) bool { + return expirationTime.Add(-buffer).Before(time.Now()) } From fd2224005e34cbbcdfbed87c53bd139cb0920fd6 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 17:59:58 -0500 Subject: [PATCH 34/86] require arn Signed-off-by: Aaron Choo --- api/v1alpha1/api.go | 3 +++ .../controller/oauth/client_credentials_token_provider.go | 2 +- internal/controller/rotators/aws_oidc_rotator.go | 4 ---- internal/controller/sink.go | 2 +- .../crds/aigateway.envoyproxy.io_backendsecuritypolicies.yaml | 1 + 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/api/v1alpha1/api.go b/api/v1alpha1/api.go index 371ba4bd9..2cedc783b 100644 --- a/api/v1alpha1/api.go +++ b/api/v1alpha1/api.go @@ -489,6 +489,9 @@ type AWSOIDCExchangeToken struct { // AwsRoleArn is the AWS IAM Role with the permission to use specific resources in AWS account // which maps to the temporary AWS security credentials exchanged using the authentication token issued by OIDC provider. + // + // +kubebuilder:validation:Required + // +kubebuilder:validation:MinLength=1 AwsRoleArn string `json:"awsRoleArn"` } diff --git a/internal/controller/oauth/client_credentials_token_provider.go b/internal/controller/oauth/client_credentials_token_provider.go index a427665fd..27e65382b 100644 --- a/internal/controller/oauth/client_credentials_token_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -3,13 +3,13 @@ package oauth import ( "context" "fmt" - "sigs.k8s.io/controller-runtime/pkg/client" "time" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" corev1 "k8s.io/api/core/v1" + "sigs.k8s.io/controller-runtime/pkg/client" ) // ClientCredentialsTokenProvider implements the standard OAuth2 client credentials flow. diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index 0864a1dc0..8d4dc359c 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -153,10 +153,6 @@ func (r *AWSOIDCRotator) Rotate(ctx context.Context, region, roleARN, token stri // assumeRoleWithToken exchanges an OIDC token for AWS credentials func (r *AWSOIDCRotator) assumeRoleWithToken(ctx context.Context, roleARN, token string) (*sts.AssumeRoleWithWebIdentityOutput, error) { - if roleARN == "" { - return nil, fmt.Errorf("role ARN is required in metadata") - } - return r.stsOps.AssumeRoleWithWebIdentity(ctx, &sts.AssumeRoleWithWebIdentityInput{ RoleArn: aws.String(roleARN), WebIdentityToken: aws.String(token), diff --git a/internal/controller/sink.go b/internal/controller/sink.go index 060ec9020..555080305 100644 --- a/internal/controller/sink.go +++ b/internal/controller/sink.go @@ -284,7 +284,7 @@ func (c *configSink) syncBackendSecurityPolicy(ctx context.Context, bsp *aigv1a1 } awsCredentials := bsp.Spec.AWSCredentials - rotator, err := backendauthrotators.NewAWSOIDCRotator(c.client, c.kube, c.logger, bsp.Namespace, bsp.Name, preRotationWindow, awsCredentials.Region) + rotator, err := backendauthrotators.NewAWSOIDCRotator(ctx, c.client, c.kube, c.logger, bsp.Namespace, bsp.Name, preRotationWindow, awsCredentials.Region) if err != nil { c.logger.Error(err, "failed to create AWS OIDC rotator") return diff --git a/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_backendsecuritypolicies.yaml b/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_backendsecuritypolicies.yaml index fcbbb52cc..451b15e74 100644 --- a/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_backendsecuritypolicies.yaml +++ b/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_backendsecuritypolicies.yaml @@ -176,6 +176,7 @@ spec: description: |- AwsRoleArn is the AWS IAM Role with the permission to use specific resources in AWS account which maps to the temporary AWS security credentials exchanged using the authentication token issued by OIDC provider. + minLength: 1 type: string grantType: description: GrantType is the method application gets access From 068bf3e1ca71c1f080ea53e31228d4b668b6309c Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 18:19:00 -0500 Subject: [PATCH 35/86] pass ctx Signed-off-by: Aaron Choo --- internal/controller/rotators/aws_common.go | 2 +- internal/controller/rotators/aws_oidc_rotator.go | 16 +++++++++------- .../controller/rotators/aws_oidc_rotator_test.go | 7 +++++-- internal/controller/sink.go | 2 +- 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/internal/controller/rotators/aws_common.go b/internal/controller/rotators/aws_common.go index 5a726199b..09010070c 100644 --- a/internal/controller/rotators/aws_common.go +++ b/internal/controller/rotators/aws_common.go @@ -114,7 +114,7 @@ func parseAWSCredentialsFile(data string) *awsCredentialsFile { var currentCreds *awsCredentials - for _, line := range strings.Split(data, "\n") { + for line := range strings.Lines(data) { line = strings.TrimSpace(line) if line == "" { continue diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index 8d4dc359c..deb0ee07f 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -20,6 +20,8 @@ import ( // It manages the lifecycle of temporary AWS credentials obtained through OIDC token // exchange with AWS STS. type AWSOIDCRotator struct { + // ctx provides a user specified context + ctx context.Context // client is used for Kubernetes API operations client client.Client // kube provides additional Kubernetes API capabilities @@ -96,7 +98,7 @@ func (r *AWSOIDCRotator) IsExpired() bool { // GetPreRotationTime gets the expiration time minus the preRotation interval. func (r *AWSOIDCRotator) GetPreRotationTime() *time.Time { - secret, err := LookupSecret(context.Background(), r.client, r.backendSecurityPolicyNamespace, r.backendSecurityPolicyName) + secret, err := LookupSecret(r.ctx, r.client, r.backendSecurityPolicyNamespace, r.backendSecurityPolicyName) if err != nil { if !errors.IsNotFound(err) { return nil @@ -112,18 +114,18 @@ func (r *AWSOIDCRotator) GetPreRotationTime() *time.Time { } // Rotate implements the retrieval and storage of AWS sts credentials -func (r *AWSOIDCRotator) Rotate(ctx context.Context, region, roleARN, token string) error { +func (r *AWSOIDCRotator) Rotate(region, roleARN, token string) error { r.logger.Info("rotating AWS sts temporary credentials", "namespace", r.backendSecurityPolicyNamespace, "name", r.backendSecurityPolicyName) - result, err := r.assumeRoleWithToken(ctx, roleARN, token) + result, err := r.assumeRoleWithToken(roleARN, token) if err != nil { r.logger.Error(err, "failed to assume role", "role", roleARN, "ID", token) return err } - secret, err := LookupSecret(ctx, r.client, r.backendSecurityPolicyNamespace, r.backendSecurityPolicyName) + secret, err := LookupSecret(r.ctx, r.client, r.backendSecurityPolicyNamespace, r.backendSecurityPolicyName) if err != nil { if !errors.IsNotFound(err) { return err @@ -148,12 +150,12 @@ func (r *AWSOIDCRotator) Rotate(ctx context.Context, region, roleARN, token stri } updateAWSCredentialsInSecret(secret, credsFile) - return updateSecret(ctx, r.client, secret) + return updateSecret(r.ctx, r.client, secret) } // assumeRoleWithToken exchanges an OIDC token for AWS credentials -func (r *AWSOIDCRotator) assumeRoleWithToken(ctx context.Context, roleARN, token string) (*sts.AssumeRoleWithWebIdentityOutput, error) { - return r.stsOps.AssumeRoleWithWebIdentity(ctx, &sts.AssumeRoleWithWebIdentityInput{ +func (r *AWSOIDCRotator) assumeRoleWithToken(roleARN, token string) (*sts.AssumeRoleWithWebIdentityOutput, error) { + return r.stsOps.AssumeRoleWithWebIdentity(r.ctx, &sts.AssumeRoleWithWebIdentityInput{ RoleArn: aws.String(roleARN), WebIdentityToken: aws.String(token), RoleSessionName: aws.String(fmt.Sprintf(awsSessionNameFormat, r.backendSecurityPolicyName)), diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go index c80dcadee..614f51a90 100644 --- a/internal/controller/rotators/aws_oidc_rotator_test.go +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -116,13 +116,14 @@ func TestAWS_OIDCRotator(t *testing.T) { createClientSecret(t, "test-client-secret") awsOidcRotator := AWSOIDCRotator{ + ctx: context.Background(), client: cl, stsOps: mockSTS, backendSecurityPolicyNamespace: "default", backendSecurityPolicyName: "test-secret", } - require.NoError(t, awsOidcRotator.Rotate(context.Background(), "us-east1", "test", "NEW-OIDC-TOKEN")) + require.NoError(t, awsOidcRotator.Rotate("us-east1", "test", "NEW-OIDC-TOKEN")) verifyAWSSecretCredentials(t, cl, "default", "test-secret", "NEWKEY", "NEWSECRET", "NEWTOKEN", "default") }) @@ -140,12 +141,13 @@ func TestAWS_OIDCRotator(t *testing.T) { }, } awsOidcRotator := AWSOIDCRotator{ + ctx: context.Background(), client: cl, stsOps: mockSTS, backendSecurityPolicyNamespace: "default", backendSecurityPolicyName: "test-secret", } - err := awsOidcRotator.Rotate(context.Background(), "us-east1", "test", "NEW-OIDC-TOKEN") + err := awsOidcRotator.Rotate("us-east1", "test", "NEW-OIDC-TOKEN") require.Error(t, err) assert.Contains(t, err.Error(), "failed to assume role") }) @@ -158,6 +160,7 @@ func TestAWS_GetPreRotationTime(t *testing.T) { ) cl := fake.NewClientBuilder().WithScheme(scheme).Build() awsOidcRotator := AWSOIDCRotator{ + ctx: context.Background(), client: cl, backendSecurityPolicyNamespace: "default", backendSecurityPolicyName: "test-secret", diff --git a/internal/controller/sink.go b/internal/controller/sink.go index 555080305..dde3c94b5 100644 --- a/internal/controller/sink.go +++ b/internal/controller/sink.go @@ -296,7 +296,7 @@ func (c *configSink) syncBackendSecurityPolicy(ctx context.Context, bsp *aigv1a1 rotator.SetSTSOperations(c.StsOP) } token := tokenResponse.AccessToken - err = rotator.Rotate(ctx, awsCredentials.Region, awsCredentials.OIDCExchangeToken.AwsRoleArn, token) + err = rotator.Rotate(awsCredentials.Region, awsCredentials.OIDCExchangeToken.AwsRoleArn, token) if err != nil { c.logger.Error(err, "failed to rotate AWS OIDC exchange token") return From d812d6c00af61ffcad6d0b27643c729ee0137c8b Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 18:30:47 -0500 Subject: [PATCH 36/86] rename backendauthrotators to rotators Signed-off-by: Aaron Choo --- internal/controller/backend_security_policy.go | 4 ++-- internal/controller/rotators/aws_common.go | 4 ++-- internal/controller/rotators/aws_common_test.go | 2 +- internal/controller/rotators/aws_oidc_rotator.go | 2 +- internal/controller/rotators/aws_oidc_rotator_test.go | 2 +- internal/controller/rotators/common.go | 2 +- internal/controller/rotators/common_test.go | 2 +- internal/controller/sink.go | 8 ++++---- internal/controller/sink_test.go | 8 ++++---- 9 files changed, 17 insertions(+), 17 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index ac72c3813..41d71ba22 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -17,7 +17,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1" - backendauthrotators "github.com/envoyproxy/ai-gateway/internal/controller/rotators" + "github.com/envoyproxy/ai-gateway/internal/controller/rotators" ) // backendSecurityPolicyController implements [reconcile.TypedReconciler] for [aigv1a1.BackendSecurityPolicy]. @@ -55,7 +55,7 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr var requeue time.Duration requeue = time.Minute region := backendSecurityPolicy.Spec.AWSCredentials.Region - rotator, err := backendauthrotators.NewAWSOIDCRotator(ctx, b.client, b.kube, b.logger, backendSecurityPolicy.Namespace, backendSecurityPolicy.Name, preRotationWindow, region) + rotator, err := rotators.NewAWSOIDCRotator(ctx, b.client, b.kube, b.logger, backendSecurityPolicy.Namespace, backendSecurityPolicy.Name, preRotationWindow, region) if err != nil { b.logger.Error(err, "failed to create AWS OIDC rotator") } else if !rotator.IsExpired() { diff --git a/internal/controller/rotators/aws_common.go b/internal/controller/rotators/aws_common.go index 09010070c..afd4c029d 100644 --- a/internal/controller/rotators/aws_common.go +++ b/internal/controller/rotators/aws_common.go @@ -1,5 +1,5 @@ /* -Package backendauthrotators provides credential rotation implementations. +Package rotators provides credential rotation implementations. This file contains common AWS functionality shared between different AWS credential rotators. It provides: 1. AWS Client Interfaces and Implementations: @@ -14,7 +14,7 @@ rotators. It provides: - Standard timeouts and delays - Session name formatting */ -package backendauthrotators +package rotators import ( "context" diff --git a/internal/controller/rotators/aws_common_test.go b/internal/controller/rotators/aws_common_test.go index 9abc0eb66..dd156c055 100644 --- a/internal/controller/rotators/aws_common_test.go +++ b/internal/controller/rotators/aws_common_test.go @@ -1,4 +1,4 @@ -package backendauthrotators +package rotators import ( "fmt" diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index deb0ee07f..7cd588715 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -1,4 +1,4 @@ -package backendauthrotators +package rotators import ( "context" diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go index 614f51a90..eb696116c 100644 --- a/internal/controller/rotators/aws_oidc_rotator_test.go +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -1,4 +1,4 @@ -package backendauthrotators +package rotators import ( "context" diff --git a/internal/controller/rotators/common.go b/internal/controller/rotators/common.go index e5c866c21..f880201f1 100644 --- a/internal/controller/rotators/common.go +++ b/internal/controller/rotators/common.go @@ -1,4 +1,4 @@ -package backendauthrotators +package rotators import ( "context" diff --git a/internal/controller/rotators/common_test.go b/internal/controller/rotators/common_test.go index c6498965c..cf3ceea99 100644 --- a/internal/controller/rotators/common_test.go +++ b/internal/controller/rotators/common_test.go @@ -1,4 +1,4 @@ -package backendauthrotators +package rotators import ( "context" diff --git a/internal/controller/sink.go b/internal/controller/sink.go index dde3c94b5..e731cfafc 100644 --- a/internal/controller/sink.go +++ b/internal/controller/sink.go @@ -30,7 +30,7 @@ import ( aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1" "github.com/envoyproxy/ai-gateway/filterapi" "github.com/envoyproxy/ai-gateway/internal/controller/oauth" - backendauthrotators "github.com/envoyproxy/ai-gateway/internal/controller/rotators" + "github.com/envoyproxy/ai-gateway/internal/controller/rotators" "github.com/envoyproxy/ai-gateway/internal/llmcostcel" ) @@ -73,7 +73,7 @@ type configSink struct { extProcImagePullPolicy corev1.PullPolicy extProcLogLevel string eventChan chan ConfigSinkEvent - StsOP backendauthrotators.STSOperations + StsOP rotators.STSOperations oidcTokenCache map[string]*oauth2.Token } @@ -271,7 +271,7 @@ func (c *configSink) syncBackendSecurityPolicy(ctx context.Context, bsp *aigv1a1 if oidc := getBackendSecurityPolicyAuthOIDC(bsp.Spec); oidc != nil { tokenResponse, ok := c.oidcTokenCache[key] - if !ok || backendauthrotators.IsExpired(preRotationWindow, tokenResponse.Expiry) { + if !ok || rotators.IsExpired(preRotationWindow, tokenResponse.Expiry) { oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(c.client), oidc) tokenRes, err := oidcProvider.FetchToken(ctx) @@ -284,7 +284,7 @@ func (c *configSink) syncBackendSecurityPolicy(ctx context.Context, bsp *aigv1a1 } awsCredentials := bsp.Spec.AWSCredentials - rotator, err := backendauthrotators.NewAWSOIDCRotator(ctx, c.client, c.kube, c.logger, bsp.Namespace, bsp.Name, preRotationWindow, awsCredentials.Region) + rotator, err := rotators.NewAWSOIDCRotator(ctx, c.client, c.kube, c.logger, bsp.Namespace, bsp.Name, preRotationWindow, awsCredentials.Region) if err != nil { c.logger.Error(err, "failed to create AWS OIDC rotator") return diff --git a/internal/controller/sink_test.go b/internal/controller/sink_test.go index 244d2c81d..4870cf251 100644 --- a/internal/controller/sink_test.go +++ b/internal/controller/sink_test.go @@ -38,7 +38,7 @@ import ( aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1" "github.com/envoyproxy/ai-gateway/filterapi" - backendauthrotators "github.com/envoyproxy/ai-gateway/internal/controller/rotators" + "github.com/envoyproxy/ai-gateway/internal/controller/rotators" ) func requireNewFakeClientWithIndexes(t *testing.T) client.Client { @@ -233,7 +233,7 @@ func TestConfigSink_syncBackendSecurityPolicyOIDC(t *testing.T) { Name: "orange", Namespace: sharedNamespace, Annotations: map[string]string{ - backendauthrotators.ExpirationTimeAnnotationKey: "2024-01-01T01:01:00.000-00:00", + rotators.ExpirationTimeAnnotationKey: "2024-01-01T01:01:00.000-00:00", }, }, Data: map[string][]byte{ @@ -297,9 +297,9 @@ func TestConfigSink_syncBackendSecurityPolicyOIDC(t *testing.T) { require.True(t, ok) require.Equal(t, "some-access-token", token.AccessToken) - updatedSecret, err := backendauthrotators.LookupSecret(context.Background(), fakeClient, sharedNamespace, "orange") + updatedSecret, err := rotators.LookupSecret(context.Background(), fakeClient, sharedNamespace, "orange") require.NoError(t, err) - require.NotEqualf(t, secret.Annotations[backendauthrotators.ExpirationTimeAnnotationKey], updatedSecret.Annotations[backendauthrotators.ExpirationTimeAnnotationKey], "expected updated expiration time annotation") + require.NotEqualf(t, secret.Annotations[rotators.ExpirationTimeAnnotationKey], updatedSecret.Annotations[rotators.ExpirationTimeAnnotationKey], "expected updated expiration time annotation") } func Test_newHTTPRoute(t *testing.T) { From 724408e295cf263cfc3dd8f5c5b67516b62cb64d Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 18:35:24 -0500 Subject: [PATCH 37/86] update constant name Signed-off-by: Aaron Choo --- internal/controller/rotators/aws_common.go | 6 +++--- internal/controller/rotators/aws_common_test.go | 2 +- internal/controller/rotators/aws_oidc_rotator.go | 2 +- internal/controller/rotators/aws_oidc_rotator_test.go | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/internal/controller/rotators/aws_common.go b/internal/controller/rotators/aws_common.go index afd4c029d..cfffac9c3 100644 --- a/internal/controller/rotators/aws_common.go +++ b/internal/controller/rotators/aws_common.go @@ -30,8 +30,8 @@ import ( // Common constants for AWS operations const ( - // credentialsKey is the key used to store AWS credentials in Kubernetes secrets - credentialsKey = "credentials" + // awsCredentialsKey is the key used to store AWS credentials in Kubernetes secrets + awsCredentialsKey = "credentials" // awsSessionNameFormat is the format string for AWS session names awsSessionNameFormat = "ai-gateway-%s" ) @@ -193,5 +193,5 @@ func updateAWSCredentialsInSecret(secret *corev1.Secret, creds *awsCredentialsFi if secret.Data == nil { secret.Data = make(map[string][]byte) } - secret.Data[credentialsKey] = []byte(formatAWSCredentialsFile(creds)) + secret.Data[awsCredentialsKey] = []byte(formatAWSCredentialsFile(creds)) } diff --git a/internal/controller/rotators/aws_common_test.go b/internal/controller/rotators/aws_common_test.go index dd156c055..602e57076 100644 --- a/internal/controller/rotators/aws_common_test.go +++ b/internal/controller/rotators/aws_common_test.go @@ -70,7 +70,7 @@ func TestUpdateAWSCredentialsInSecret(t *testing.T) { updateAWSCredentialsInSecret(secret, &awsCredentialsFile{profiles: map[string]*awsCredentials{"default": &credentials}}) require.Len(t, secret.Data, 1) - val, ok := secret.Data[credentialsKey] + val, ok := secret.Data[awsCredentialsKey] require.True(t, ok) require.NotEmpty(t, val) } diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index 7cd588715..d119dbcaa 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -137,7 +137,7 @@ func (r *AWSOIDCRotator) Rotate(region, roleARN, token string) error { // For now have profile as default profile := "default" - credsFile := &awsCredentialsFile{ + credsFile := awsCredentialsFile{ profiles: map[string]*awsCredentials{ profile: { profile: profile, diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go index eb696116c..03bc0371c 100644 --- a/internal/controller/rotators/aws_oidc_rotator_test.go +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -28,7 +28,7 @@ func createTestAWSSecret(t *testing.T, client client.Client, name string, access profile = "default" } data := map[string][]byte{ - credentialsKey: []byte(fmt.Sprintf("[%s]\naws_access_key_id = %s\naws_secret_access_key = %s\naws_session_token = %s\nregion = us-west-2", + awsCredentialsKey: []byte(fmt.Sprintf("[%s]\naws_access_key_id = %s\naws_secret_access_key = %s\naws_session_token = %s\nregion = us-west-2", profile, accessKey, secretKey, sessionToken)), } err := client.Create(context.Background(), &corev1.Secret{ @@ -48,7 +48,7 @@ func verifyAWSSecretCredentials(t *testing.T, client client.Client, namespace, s } secret, err := LookupSecret(context.Background(), client, namespace, secretName) require.NoError(t, err) - creds := parseAWSCredentialsFile(string(secret.Data[credentialsKey])) + creds := parseAWSCredentialsFile(string(secret.Data[awsCredentialsKey])) require.NotNil(t, creds) require.Contains(t, creds.profiles, profile) assert.Equal(t, expectedKeyID, creds.profiles[profile].accessKeyID) From 16987b58a8243ca70edd7f93e651a5db12127256 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 18:44:09 -0500 Subject: [PATCH 38/86] update comments Signed-off-by: Aaron Choo --- .../controller/backend_security_policy.go | 2 +- .../client_credentials_token_provider.go | 2 +- internal/controller/oauth/oidc_provider.go | 32 +++++++++---------- internal/controller/oauth/types.go | 2 +- internal/controller/oauth/util.go | 2 +- internal/controller/rotators/aws_common.go | 24 +++++++------- .../controller/rotators/aws_oidc_rotator.go | 26 +++++++-------- internal/controller/rotators/common.go | 10 +++--- 8 files changed, 50 insertions(+), 50 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 41d71ba22..7e78dbc0b 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -71,7 +71,7 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr return } -// getBackendSecurityPolicyAuthOIDC returns the backendSecurityPolicy's OIDC pointer or nil +// getBackendSecurityPolicyAuthOIDC returns the backendSecurityPolicy's OIDC pointer or nil. func getBackendSecurityPolicyAuthOIDC(spec aigv1a1.BackendSecurityPolicySpec) *egv1a1.OIDC { if spec.AWSCredentials != nil && spec.AWSCredentials.OIDCExchangeToken != nil { return &spec.AWSCredentials.OIDCExchangeToken.OIDC diff --git a/internal/controller/oauth/client_credentials_token_provider.go b/internal/controller/oauth/client_credentials_token_provider.go index 27e65382b..cf6dd29f5 100644 --- a/internal/controller/oauth/client_credentials_token_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -18,7 +18,7 @@ type ClientCredentialsTokenProvider struct { client client.Client } -// NewClientCredentialsProvider creates a new client credentials provider +// NewClientCredentialsProvider creates a new client credentials provider. func NewClientCredentialsProvider(cl client.Client) *ClientCredentialsTokenProvider { return &ClientCredentialsTokenProvider{ client: cl, diff --git a/internal/controller/oauth/oidc_provider.go b/internal/controller/oauth/oidc_provider.go index 8140865ef..25e690245 100644 --- a/internal/controller/oauth/oidc_provider.go +++ b/internal/controller/oauth/oidc_provider.go @@ -11,14 +11,14 @@ import ( "golang.org/x/oauth2" ) -// OIDCProvider extends ClientCredentialsTokenProvider with OIDC support +// OIDCProvider extends ClientCredentialsTokenProvider with OIDC support. type OIDCProvider struct { tokenProvider *ClientCredentialsTokenProvider httpClient *http.Client oidcCredential *egv1a1.OIDC } -// NewOIDCProvider creates a new OIDC-aware provider +// NewOIDCProvider creates a new OIDC-aware provider. func NewOIDCProvider(tokenProvider *ClientCredentialsTokenProvider, oidcCredentials *egv1a1.OIDC) *OIDCProvider { return &OIDCProvider{ tokenProvider: tokenProvider, @@ -27,9 +27,9 @@ func NewOIDCProvider(tokenProvider *ClientCredentialsTokenProvider, oidcCredenti } } -// getOIDCProviderConfig retrieves or creates OIDC config for the given issuer URL -func (p *OIDCProvider) getOIDCProviderConfig(ctx context.Context, issuerURL string) (*oidc.ProviderConfig, *[]string, error) { - // Check context before proceeding +// getOIDCProviderConfig retrieves or creates OIDC config for the given issuer URL. +func (p *OIDCProvider) getOIDCProviderConfig(ctx context.Context, issuerURL string) (*oidc.ProviderConfig, []string, error) { + // Check context before proceeding in case context is cancelled because of timeout. if err := ctx.Err(); err != nil { return nil, nil, fmt.Errorf("context error before discovery: %w", err) } @@ -44,7 +44,7 @@ func (p *OIDCProvider) getOIDCProviderConfig(ctx context.Context, issuerURL stri return nil, nil, fmt.Errorf("failed to decode provider config claims %q: %w", issuerURL, err) } - // Unmarshall supported scopes + // Unmarshall supported scopes. var claims struct { SupportedScopes []string `json:"scopes_supported"` } @@ -52,7 +52,7 @@ func (p *OIDCProvider) getOIDCProviderConfig(ctx context.Context, issuerURL stri return nil, nil, fmt.Errorf("failed to decode provider scope supported claims: %w", err) } - // Validate required fields + // Validate required fields. if config.IssuerURL == "" { return nil, nil, fmt.Errorf("issuer is required in OIDC provider config") } @@ -60,32 +60,32 @@ func (p *OIDCProvider) getOIDCProviderConfig(ctx context.Context, issuerURL stri return nil, nil, fmt.Errorf("token_endpoint is required in OIDC provider config") } - return &config, &claims.SupportedScopes, nil + return &config, claims.SupportedScopes, nil } -// FetchToken retrieves and validates tokens using the client credentials flow with OIDC support +// FetchToken retrieves and validates tokens using the client credentials flow with OIDC support. func (p *OIDCProvider) FetchToken(ctx context.Context) (*oauth2.Token, error) { - // If issuer URL is provided, fetch OIDC metadata + // If issuer URL is provided, fetch OIDC metadata. if issuerURL := p.oidcCredential.Provider.Issuer; issuerURL != "" { config, supportedScopes, err := p.getOIDCProviderConfig(ctx, issuerURL) if err != nil { return nil, fmt.Errorf("failed to get OIDC config: %w", err) } - // Use discovered token endpoint if not explicitly provided + // Use discovered token endpoint if not explicitly provided. if p.oidcCredential.Provider.TokenEndpoint == nil { p.oidcCredential.Provider.TokenEndpoint = &config.TokenURL } - // Add discovered scopes if available - if supportedScopes != nil && len(*supportedScopes) > 0 { + // Add discovered scopes if available. + if supportedScopes != nil && len(supportedScopes) > 0 { requestedScopes := make(map[string]bool) for _, scope := range p.oidcCredential.Scopes { requestedScopes[scope] = true } - // Add supported scopes that aren't already requested - for _, scope := range *supportedScopes { + // Add supported scopes that aren't already requested. + for _, scope := range supportedScopes { if !requestedScopes[scope] { p.oidcCredential.Scopes = append(p.oidcCredential.Scopes, scope) } @@ -93,7 +93,7 @@ func (p *OIDCProvider) FetchToken(ctx context.Context) (*oauth2.Token, error) { } } - // Get base token response + // Get base token response. token, err := p.tokenProvider.FetchToken(ctx, p.oidcCredential) if err != nil { return nil, fmt.Errorf("failed to get token: %w", err) diff --git a/internal/controller/oauth/types.go b/internal/controller/oauth/types.go index fa3c3795c..9775c1eb3 100644 --- a/internal/controller/oauth/types.go +++ b/internal/controller/oauth/types.go @@ -7,7 +7,7 @@ import ( "golang.org/x/oauth2" ) -// TokenProvider defines the interface for OAuth token providers +// TokenProvider defines the interface for OAuth token providers. type TokenProvider interface { FetchToken(ctx context.Context, oidc *egv1a1.OIDC) (*oauth2.Token, error) } diff --git a/internal/controller/oauth/util.go b/internal/controller/oauth/util.go index 515b63d77..c38dcf691 100644 --- a/internal/controller/oauth/util.go +++ b/internal/controller/oauth/util.go @@ -8,7 +8,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" ) -// getClientSecret retrieves the client secret from a Kubernetes secret +// getClientSecret retrieves the client secret from a Kubernetes secret. func getClientSecret(ctx context.Context, cl client.Client, secretRef *corev1.SecretReference) (string, error) { secret := &corev1.Secret{} if err := cl.Get(ctx, client.ObjectKey{ diff --git a/internal/controller/rotators/aws_common.go b/internal/controller/rotators/aws_common.go index cfffac9c3..bcaf59e23 100644 --- a/internal/controller/rotators/aws_common.go +++ b/internal/controller/rotators/aws_common.go @@ -28,11 +28,11 @@ import ( corev1 "k8s.io/api/core/v1" ) -// Common constants for AWS operations +// Common constants for AWS operations. const ( - // awsCredentialsKey is the key used to store AWS credentials in Kubernetes secrets + // awsCredentialsKey is the key used to store AWS credentials in Kubernetes secrets. awsCredentialsKey = "credentials" - // awsSessionNameFormat is the format string for AWS session names + // awsSessionNameFormat is the format string for AWS session names. awsSessionNameFormat = "ai-gateway-%s" ) @@ -48,7 +48,7 @@ func defaultAWSConfig(ctx context.Context) (aws.Config, error) { // This interface encapsulates the STS API operations needed for OIDC token exchange // and role assumption. type STSOperations interface { - // AssumeRoleWithWebIdentity exchanges a web identity token for temporary AWS credentials + // AssumeRoleWithWebIdentity exchanges a web identity token for temporary AWS credentials. AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) } @@ -77,15 +77,15 @@ func (c *STSClient) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.A // session token and region configuration. It maps to a single profile in an // AWS credentials file. type awsCredentials struct { - // profile is the name of the credentials profile + // profile is the name of the credentials profile. profile string - // accessKeyID is the AWS access key ID + // accessKeyID is the AWS access key ID. accessKeyID string - // secretAccessKey is the AWS secret access key + // secretAccessKey is the AWS secret access key. secretAccessKey string - // sessionToken is the optional AWS session token for temporary credentials + // sessionToken is the optional AWS session token for temporary credentials. sessionToken string - // region is the optional AWS region for the profile + // region is the optional AWS region for the profile. region string } @@ -93,7 +93,7 @@ type awsCredentials struct { // multiple credential profiles. It provides a structured way to manage // multiple sets of AWS credentials. type awsCredentialsFile struct { - // profiles maps profile names to their respective credentials + // profiles maps profile names to their respective credentials. profiles map[string]*awsCredentials } @@ -163,7 +163,7 @@ func parseAWSCredentialsFile(data string) *awsCredentialsFile { func formatAWSCredentialsFile(file *awsCredentialsFile) string { var builder strings.Builder - // Sort profiles to ensure consistent output + // Sort profiles to ensure consistent output. profileNames := make([]string, 0, len(file.profiles)) for profileName := range file.profiles { profileNames = append(profileNames, profileName) @@ -188,7 +188,7 @@ func formatAWSCredentialsFile(file *awsCredentialsFile) string { return builder.String() } -// updateAWSCredentialsInSecret updates AWS credentials in a secret +// updateAWSCredentialsInSecret updates AWS credentials in a secret. func updateAWSCredentialsInSecret(secret *corev1.Secret, creds *awsCredentialsFile) { if secret.Data == nil { secret.Data = make(map[string][]byte) diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index d119dbcaa..01c906e2f 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -20,21 +20,21 @@ import ( // It manages the lifecycle of temporary AWS credentials obtained through OIDC token // exchange with AWS STS. type AWSOIDCRotator struct { - // ctx provides a user specified context + // ctx provides a user specified context. ctx context.Context - // client is used for Kubernetes API operations + // client is used for Kubernetes API operations. client client.Client - // kube provides additional Kubernetes API capabilities + // kube provides additional Kubernetes API capabilities. kube kubernetes.Interface - // logger is used for structured logging + // logger is used for structured logging. logger logr.Logger - // stsOps provides AWS STS operations interface + // stsOps provides AWS STS operations interface. stsOps STSOperations - // backendSecurityPolicyName provides name of backend security policy + // backendSecurityPolicyName provides name of backend security policy. backendSecurityPolicyName string - // backendSecurityPolicyNamespace provides namespace of backend security policy + // backendSecurityPolicyNamespace provides namespace of backend security policy. backendSecurityPolicyNamespace string - // preRotationWindow specifies how long before expiry to rotate + // preRotationWindow specifies how long before expiry to rotate. preRotationWindow time.Duration } @@ -82,7 +82,7 @@ func NewAWSOIDCRotator( }, nil } -// SetSTSOperations sets the STS operations implementation - primarily used for testing +// SetSTSOperations sets the STS operations implementation - primarily used for testing. func (r *AWSOIDCRotator) SetSTSOperations(ops STSOperations) { r.stsOps = ops } @@ -113,7 +113,7 @@ func (r *AWSOIDCRotator) GetPreRotationTime() *time.Time { return &preRotationTime } -// Rotate implements the retrieval and storage of AWS sts credentials +// Rotate implements the retrieval and storage of AWS sts credentials. func (r *AWSOIDCRotator) Rotate(region, roleARN, token string) error { r.logger.Info("rotating AWS sts temporary credentials", "namespace", r.backendSecurityPolicyNamespace, @@ -135,7 +135,7 @@ func (r *AWSOIDCRotator) Rotate(region, roleARN, token string) error { updateExpirationSecretAnnotation(secret, *result.Credentials.Expiration) - // For now have profile as default + // For now have profile as default. profile := "default" credsFile := awsCredentialsFile{ profiles: map[string]*awsCredentials{ @@ -149,11 +149,11 @@ func (r *AWSOIDCRotator) Rotate(region, roleARN, token string) error { }, } - updateAWSCredentialsInSecret(secret, credsFile) + updateAWSCredentialsInSecret(secret, &credsFile) return updateSecret(r.ctx, r.client, secret) } -// assumeRoleWithToken exchanges an OIDC token for AWS credentials +// assumeRoleWithToken exchanges an OIDC token for AWS credentials. func (r *AWSOIDCRotator) assumeRoleWithToken(roleARN, token string) (*sts.AssumeRoleWithWebIdentityOutput, error) { return r.stsOps.AssumeRoleWithWebIdentity(r.ctx, &sts.AssumeRoleWithWebIdentityInput{ RoleArn: aws.String(roleARN), diff --git a/internal/controller/rotators/common.go b/internal/controller/rotators/common.go index f880201f1..b5963014a 100644 --- a/internal/controller/rotators/common.go +++ b/internal/controller/rotators/common.go @@ -13,7 +13,7 @@ import ( const ExpirationTimeAnnotationKey = "rotators/expiration-time" -// newSecret creates a new secret struct (does not persist to k8s) +// newSecret creates a new secret struct (does not persist to k8s). func newSecret(namespace, name string) *corev1.Secret { return &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ @@ -25,7 +25,7 @@ func newSecret(namespace, name string) *corev1.Secret { } } -// updateSecret updates an existing secret or creates a new one +// updateSecret updates an existing secret or creates a new one. func updateSecret(ctx context.Context, k8sClient client.Client, secret *corev1.Secret) error { if secret.ResourceVersion == "" { if err := k8sClient.Create(ctx, secret); err != nil { @@ -39,7 +39,7 @@ func updateSecret(ctx context.Context, k8sClient client.Client, secret *corev1.S return nil } -// LookupSecret retrieves an existing secret +// LookupSecret retrieves an existing secret. func LookupSecret(ctx context.Context, k8sClient client.Client, namespace, name string) (*corev1.Secret, error) { secret := &corev1.Secret{} if err := k8sClient.Get(ctx, client.ObjectKey{ @@ -54,7 +54,7 @@ func LookupSecret(ctx context.Context, k8sClient client.Client, namespace, name return secret, nil } -// updateExpirationSecretAnnotation will set the expiration time of credentials set in secret annotation +// updateExpirationSecretAnnotation will set the expiration time of credentials set in secret annotation. func updateExpirationSecretAnnotation(secret *corev1.Secret, updateTime time.Time) { if secret.Annotations == nil { secret.Annotations = make(map[string]string) @@ -62,7 +62,7 @@ func updateExpirationSecretAnnotation(secret *corev1.Secret, updateTime time.Tim secret.Annotations[ExpirationTimeAnnotationKey] = updateTime.Format(time.RFC3339) } -// GetExpirationSecretAnnotation will get the expiration time of credentials set in secret annotation +// GetExpirationSecretAnnotation will get the expiration time of credentials set in secret annotation. func GetExpirationSecretAnnotation(secret *corev1.Secret) (*time.Time, error) { expirationTimeAnnotationKey, ok := secret.Annotations[ExpirationTimeAnnotationKey] if !ok { From be80c3969c2027196f0971961e16c8e28f8a8103 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 19:49:34 -0500 Subject: [PATCH 39/86] test updating scopes Signed-off-by: Aaron Choo --- internal/controller/oauth/oidc_provider.go | 2 +- internal/controller/oauth/oidc_provider_test.go | 11 ++++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/internal/controller/oauth/oidc_provider.go b/internal/controller/oauth/oidc_provider.go index 25e690245..2d9b002e6 100644 --- a/internal/controller/oauth/oidc_provider.go +++ b/internal/controller/oauth/oidc_provider.go @@ -78,7 +78,7 @@ func (p *OIDCProvider) FetchToken(ctx context.Context) (*oauth2.Token, error) { } // Add discovered scopes if available. - if supportedScopes != nil && len(supportedScopes) > 0 { + if len(supportedScopes) > 0 { requestedScopes := make(map[string]bool) for _, scope := range p.oidcCredential.Scopes { requestedScopes[scope] = true diff --git a/internal/controller/oauth/oidc_provider_test.go b/internal/controller/oauth/oidc_provider_test.go index 98d2a1cd2..36809a717 100644 --- a/internal/controller/oauth/oidc_provider_test.go +++ b/internal/controller/oauth/oidc_provider_test.go @@ -98,7 +98,7 @@ func TestOIDCProvider_GetOIDCProviderConfigErrors(t *testing.T) { func TestOIDCProvider_GetOIDCProviderConfig(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - _, err := w.Write([]byte(`{"issuer": "issuer", "token_endpoint": "token_endpoint", "authorization_endpoint": "authorization_endpoint", "jwks_uri": "jwks_uri", "scopes_supported": []}`)) + _, err := w.Write([]byte(`{"issuer": "issuer", "token_endpoint": "token_endpoint", "authorization_endpoint": "authorization_endpoint", "jwks_uri": "jwks_uri", "scopes_supported": ["one", "openid"]}`)) require.NoError(t, err) })) defer ts.Close() @@ -114,6 +114,7 @@ func TestOIDCProvider_GetOIDCProviderConfig(t *testing.T) { Issuer: ts.URL, TokenEndpoint: &ts.URL, }, + Scopes: []string{"two", "openid"}, ClientID: "some-client-id", } @@ -123,12 +124,12 @@ func TestOIDCProvider_GetOIDCProviderConfig(t *testing.T) { require.NoError(t, err) require.Equal(t, "token_endpoint", config.TokenURL) require.Equal(t, "issuer", config.IssuerURL) - require.Empty(t, supportedScope) + require.Len(t, supportedScope, 2) } func TestOIDCProvider_FetchToken(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - _, err := w.Write([]byte(`{"issuer": "issuer", "token_endpoint": "token_endpoint", "authorization_endpoint": "authorization_endpoint", "jwks_uri": "jwks_uri", "scopes_supported": []}`)) + _, err := w.Write([]byte(`{"issuer": "issuer", "token_endpoint": "token_endpoint", "authorization_endpoint": "authorization_endpoint", "jwks_uri": "jwks_uri", "scopes_supported": ["one", "openid"]}`)) require.NoError(t, err) })) defer ts.Close() @@ -164,15 +165,19 @@ func TestOIDCProvider_FetchToken(t *testing.T) { Name: gwapiv1.ObjectName(secretName), Namespace: &namespaceRef, }, + Scopes: []string{"two", "openid"}, } clientCredentialProvider := NewClientCredentialsProvider(cl) clientCredentialProvider.tokenSource = &MockClientCredentialsTokenSource{} require.NotNil(t, clientCredentialProvider) ctx := oidcv3.InsecureIssuerURLContext(context.Background(), ts.URL) oidcProvider := NewOIDCProvider(clientCredentialProvider, oidc) + require.Len(t, oidcProvider.oidcCredential.Scopes, 2) + token, err := oidcProvider.FetchToken(ctx) require.NoError(t, err) require.Equal(t, "token", token.AccessToken) require.Equal(t, "Bearer", token.Type()) require.Equal(t, int64(3600), token.ExpiresIn) + require.Len(t, oidcProvider.oidcCredential.Scopes, 3) } From 7aa0b14067f310c1aab3ec2eafb3fe5eb3915b8b Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 19:51:45 -0500 Subject: [PATCH 40/86] region is required Signed-off-by: Aaron Choo --- internal/controller/rotators/aws_oidc_rotator.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index 01c906e2f..106789ba8 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -55,9 +55,7 @@ func NewAWSOIDCRotator( return nil, fmt.Errorf("failed to load AWS config: %w", err) } - if region != "" { - cfg.Region = region - } + cfg.Region = region if proxyURL := os.Getenv("AI_GATEWY_STS_PROXY_URL"); proxyURL != "" { cfg.HTTPClient = &http.Client{ From 567a6c06dc748079bc2085e49581a7b6d495fc9a Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 20:07:41 -0500 Subject: [PATCH 41/86] STSClient -> stsClient and STSOperations -> STSClient Signed-off-by: Aaron Choo --- internal/controller/rotators/aws_common.go | 16 +++++++++------- internal/controller/rotators/aws_common_test.go | 1 - internal/controller/rotators/aws_oidc_rotator.go | 4 ++-- .../controller/rotators/aws_oidc_rotator_test.go | 4 ++-- internal/controller/sink.go | 2 +- 5 files changed, 14 insertions(+), 13 deletions(-) diff --git a/internal/controller/rotators/aws_common.go b/internal/controller/rotators/aws_common.go index bcaf59e23..c2196efa0 100644 --- a/internal/controller/rotators/aws_common.go +++ b/internal/controller/rotators/aws_common.go @@ -44,32 +44,34 @@ func defaultAWSConfig(ctx context.Context) (aws.Config, error) { ) } -// STSOperations defines the interface for AWS STS operations required by the rotators. +// STSClient defines the interface for AWS STS operations required by the rotators. // This interface encapsulates the STS API operations needed for OIDC token exchange // and role assumption. -type STSOperations interface { +type STSClient interface { // AssumeRoleWithWebIdentity exchanges a web identity token for temporary AWS credentials. AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) } -// STSClient implements the STSOperations interface using the AWS SDK v2. +// stsClient implements the STSOperations interface using the AWS SDK v2. // It provides a concrete implementation for STS operations using the official AWS SDK. -type STSClient struct { +type stsClient struct { client *sts.Client } // NewSTSClient creates a new STSClient with the given AWS config. // The client is configured with the provided AWS configuration, which should // include appropriate credentials and region settings. -func NewSTSClient(cfg aws.Config) *STSClient { - return &STSClient{ +func NewSTSClient(cfg aws.Config) STSClient { + return &stsClient{ client: sts.NewFromConfig(cfg), } } // AssumeRoleWithWebIdentity implements the STSOperations interface by exchanging // a web identity token for temporary AWS credentials. -func (c *STSClient) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { +// +// This implements [STSClient.AssumeRoleWithWebIdentity]. +func (c *stsClient) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { return c.client.AssumeRoleWithWebIdentity(ctx, params, optFns...) } diff --git a/internal/controller/rotators/aws_common_test.go b/internal/controller/rotators/aws_common_test.go index 602e57076..3bb420ec0 100644 --- a/internal/controller/rotators/aws_common_test.go +++ b/internal/controller/rotators/aws_common_test.go @@ -12,7 +12,6 @@ import ( func TestNewSTSClient(t *testing.T) { stsClient := NewSTSClient(aws.Config{Region: "us-west-2"}) require.NotNil(t, stsClient) - require.NotNil(t, stsClient.client) } func TestParseAWSCredentialsFile(t *testing.T) { diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index 106789ba8..b6a71d328 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -29,7 +29,7 @@ type AWSOIDCRotator struct { // logger is used for structured logging. logger logr.Logger // stsOps provides AWS STS operations interface. - stsOps STSOperations + stsOps STSClient // backendSecurityPolicyName provides name of backend security policy. backendSecurityPolicyName string // backendSecurityPolicyNamespace provides namespace of backend security policy. @@ -81,7 +81,7 @@ func NewAWSOIDCRotator( } // SetSTSOperations sets the STS operations implementation - primarily used for testing. -func (r *AWSOIDCRotator) SetSTSOperations(ops STSOperations) { +func (r *AWSOIDCRotator) SetSTSOperations(ops STSClient) { r.stsOps = ops } diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go index 03bc0371c..0e987c3bf 100644 --- a/internal/controller/rotators/aws_oidc_rotator_test.go +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -94,7 +94,7 @@ func (m *MockSTSOperations) AssumeRoleWithWebIdentity(ctx context.Context, param func TestAWS_OIDCRotator(t *testing.T) { t.Run("basic rotation", func(t *testing.T) { - var mockSTS STSOperations = &MockSTSOperations{ + var mockSTS STSClient = &MockSTSOperations{ assumeRoleWithWebIdentityFunc: func(_ context.Context, _ *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { return &sts.AssumeRoleWithWebIdentityOutput{ Credentials: &types.Credentials{ @@ -135,7 +135,7 @@ func TestAWS_OIDCRotator(t *testing.T) { cl := fake.NewClientBuilder().WithScheme(scheme).Build() createTestAWSSecret(t, cl, "test-secret", "OLDKEY", "OLDSECRET", "OLDTOKEN", "default") createClientSecret(t, "test-client-secret") - var mockSTS STSOperations = &MockSTSOperations{ + var mockSTS STSClient = &MockSTSOperations{ assumeRoleWithWebIdentityFunc: func(_ context.Context, _ *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { return nil, fmt.Errorf("failed to assume role") }, diff --git a/internal/controller/sink.go b/internal/controller/sink.go index e731cfafc..34ba1250f 100644 --- a/internal/controller/sink.go +++ b/internal/controller/sink.go @@ -73,7 +73,7 @@ type configSink struct { extProcImagePullPolicy corev1.PullPolicy extProcLogLevel string eventChan chan ConfigSinkEvent - StsOP rotators.STSOperations + StsOP rotators.STSClient oidcTokenCache map[string]*oauth2.Token } From 8bc85268082855749f75dbf7b44fb306f7270ac7 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 20:07:57 -0500 Subject: [PATCH 42/86] fix indentation Signed-off-by: Aaron Choo --- manifests/charts/ai-gateway-helm/values.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/manifests/charts/ai-gateway-helm/values.yaml b/manifests/charts/ai-gateway-helm/values.yaml index b6f3d6a36..26c7d35c2 100644 --- a/manifests/charts/ai-gateway-helm/values.yaml +++ b/manifests/charts/ai-gateway-helm/values.yaml @@ -44,11 +44,11 @@ controller: podEnv: {} # Example of volumes # - mountPath: /placeholder/path - # name: volume-name - # subPath: placeholder-sub-path - # configmap: - # defaultMode: placeholder - # name: configmap-name + # name: volume-name + # subPath: placeholder-sub-path + # configmap: + # defaultMode: placeholder + # name: configmap-name volumes: {} service: type: ClusterIP From 9f8142419374e869e5f67fa8b1c04d55bd31737e Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 20:44:03 -0500 Subject: [PATCH 43/86] createSecert -> createBSPSecret and fixed tests Signed-off-by: Aaron Choo --- internal/controller/rotators/aws_oidc_rotator.go | 6 +++--- .../controller/rotators/aws_oidc_rotator_test.go | 10 +++++----- internal/controller/rotators/common.go | 12 +++++++++--- internal/controller/rotators/common_test.go | 6 +++--- internal/controller/sink.go | 2 +- internal/controller/sink_test.go | 4 ++-- 6 files changed, 23 insertions(+), 17 deletions(-) diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index b6a71d328..7dc4ef101 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -96,7 +96,7 @@ func (r *AWSOIDCRotator) IsExpired() bool { // GetPreRotationTime gets the expiration time minus the preRotation interval. func (r *AWSOIDCRotator) GetPreRotationTime() *time.Time { - secret, err := LookupSecret(r.ctx, r.client, r.backendSecurityPolicyNamespace, r.backendSecurityPolicyName) + secret, err := LookupSecret(r.ctx, r.client, r.backendSecurityPolicyNamespace, GetBSPSecretName(r.backendSecurityPolicyName)) if err != nil { if !errors.IsNotFound(err) { return nil @@ -123,12 +123,12 @@ func (r *AWSOIDCRotator) Rotate(region, roleARN, token string) error { return err } - secret, err := LookupSecret(r.ctx, r.client, r.backendSecurityPolicyNamespace, r.backendSecurityPolicyName) + secret, err := LookupSecret(r.ctx, r.client, r.backendSecurityPolicyNamespace, GetBSPSecretName(r.backendSecurityPolicyName)) if err != nil { if !errors.IsNotFound(err) { return err } - secret = newSecret(r.backendSecurityPolicyNamespace, r.backendSecurityPolicyName) + secret = newBSPSecret(r.backendSecurityPolicyNamespace, r.backendSecurityPolicyName) } updateExpirationSecretAnnotation(secret, *result.Credentials.Expiration) diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go index 0e987c3bf..5909db307 100644 --- a/internal/controller/rotators/aws_oidc_rotator_test.go +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -23,7 +23,7 @@ import ( // ----------------------------------------------------------------------------- // createTestAWSSecret creates a test secret with given credentials -func createTestAWSSecret(t *testing.T, client client.Client, name string, accessKey, secretKey, sessionToken string, profile string) { +func createTestAWSSecret(t *testing.T, client client.Client, bspName string, accessKey, secretKey, sessionToken string, profile string) { if profile == "" { profile = "default" } @@ -33,7 +33,7 @@ func createTestAWSSecret(t *testing.T, client client.Client, name string, access } err := client.Create(context.Background(), &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ - Name: name, + Name: GetBSPSecretName(bspName), Namespace: "default", }, Data: data, @@ -46,7 +46,7 @@ func verifyAWSSecretCredentials(t *testing.T, client client.Client, namespace, s if profile == "" { profile = "default" } - secret, err := LookupSecret(context.Background(), client, namespace, secretName) + secret, err := LookupSecret(context.Background(), client, namespace, GetBSPSecretName(secretName)) require.NoError(t, err) creds := parseAWSCredentialsFile(string(secret.Data[awsCredentialsKey])) require.NotNil(t, creds) @@ -171,7 +171,7 @@ func TestAWS_GetPreRotationTime(t *testing.T) { createTestAWSSecret(t, cl, "test-secret", "OLDKEY", "OLDSECRET", "OLDTOKEN", "default") require.Nil(t, awsOidcRotator.GetPreRotationTime()) - secret, err := LookupSecret(context.Background(), cl, "default", "test-secret") + secret, err := LookupSecret(context.Background(), cl, "default", GetBSPSecretName("test-secret")) require.NoError(t, err) expiredTime := time.Now().Add(-1 * time.Hour) @@ -197,7 +197,7 @@ func TestAWS_IsExpired(t *testing.T) { createTestAWSSecret(t, cl, "test-secret", "OLDKEY", "OLDSECRET", "OLDTOKEN", "default") require.Nil(t, awsOidcRotator.GetPreRotationTime()) - secret, err := LookupSecret(context.Background(), cl, "default", "test-secret") + secret, err := LookupSecret(context.Background(), cl, "default", GetBSPSecretName("test-secret")) require.NoError(t, err) expiredTime := time.Now().Add(-1 * time.Hour) diff --git a/internal/controller/rotators/common.go b/internal/controller/rotators/common.go index b5963014a..a207b6ff1 100644 --- a/internal/controller/rotators/common.go +++ b/internal/controller/rotators/common.go @@ -12,12 +12,13 @@ import ( ) const ExpirationTimeAnnotationKey = "rotators/expiration-time" +const RotatorSecretNamePrefix = "ai-eg-bsp" -// newSecret creates a new secret struct (does not persist to k8s). -func newSecret(namespace, name string) *corev1.Secret { +// newBSPSecret creates a new secret struct (does not persist to k8s). +func newBSPSecret(namespace, bspName string) *corev1.Secret { return &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ - Name: name, + Name: GetBSPSecretName(bspName), Namespace: namespace, }, Type: corev1.SecretTypeOpaque, @@ -80,3 +81,8 @@ func GetExpirationSecretAnnotation(secret *corev1.Secret) (*time.Time, error) { func IsExpired(buffer time.Duration, expirationTime time.Time) bool { return expirationTime.Add(-buffer).Before(time.Now()) } + +// GetBSPSecretName will return the bspName with rotator prefix. +func GetBSPSecretName(bspName string) string { + return fmt.Sprintf("%s-%s", RotatorSecretNamePrefix, bspName) +} diff --git a/internal/controller/rotators/common_test.go b/internal/controller/rotators/common_test.go index cf3ceea99..9053acb94 100644 --- a/internal/controller/rotators/common_test.go +++ b/internal/controller/rotators/common_test.go @@ -13,13 +13,13 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client/fake" ) -func TestNewSecret(t *testing.T) { +func TestNewBSPSecret(t *testing.T) { name := "test" namespace := "test-namespace" - secret := newSecret(namespace, name) + secret := newBSPSecret(namespace, name) require.NotNil(t, secret) - require.Equal(t, name, secret.Name) + require.Equal(t, GetBSPSecretName(name), secret.Name) require.Equal(t, namespace, secret.Namespace) require.NotNil(t, secret.Data) } diff --git a/internal/controller/sink.go b/internal/controller/sink.go index 34ba1250f..f44d131d5 100644 --- a/internal/controller/sink.go +++ b/internal/controller/sink.go @@ -654,7 +654,7 @@ func (c *configSink) mountBackendSecurityPolicySecrets(ctx context.Context, spec if backendSecurityPolicy.Spec.AWSCredentials.CredentialsFile != nil { secretName = string(backendSecurityPolicy.Spec.AWSCredentials.CredentialsFile.SecretRef.Name) } else { - secretName = backendSecurityPolicy.Name + secretName = rotators.GetBSPSecretName(backendSecurityPolicy.Name) } default: return nil, fmt.Errorf("backend security policy %s is not supported", backendSecurityPolicy.Spec.Type) diff --git a/internal/controller/sink_test.go b/internal/controller/sink_test.go index 4870cf251..a223d808c 100644 --- a/internal/controller/sink_test.go +++ b/internal/controller/sink_test.go @@ -297,7 +297,7 @@ func TestConfigSink_syncBackendSecurityPolicyOIDC(t *testing.T) { require.True(t, ok) require.Equal(t, "some-access-token", token.AccessToken) - updatedSecret, err := rotators.LookupSecret(context.Background(), fakeClient, sharedNamespace, "orange") + updatedSecret, err := rotators.LookupSecret(context.Background(), fakeClient, sharedNamespace, rotators.GetBSPSecretName("orange")) require.NoError(t, err) require.NotEqualf(t, secret.Annotations[rotators.ExpirationTimeAnnotationKey], updatedSecret.Annotations[rotators.ExpirationTimeAnnotationKey], "expected updated expiration time annotation") } @@ -983,7 +983,7 @@ func TestConfigSink_MountBackendSecurityPolicySecrets(t *testing.T) { require.Equal(t, "rule1-backref0-some-other-backend-security-policy-aws", updatedSpec.Containers[0].VolumeMounts[2].Name) require.Equal(t, "/etc/backend_security_policy/rule1-backref0-some-other-backend-security-policy-aws", updatedSpec.Containers[0].VolumeMounts[2].MountPath) // AWS OIDC. - require.Equal(t, "aws-oidc-name", updatedSpec.Volumes[3].VolumeSource.Secret.SecretName) + require.Equal(t, rotators.GetBSPSecretName("aws-oidc-name"), updatedSpec.Volumes[3].VolumeSource.Secret.SecretName) require.Equal(t, "rule2-backref0-aws-oidc-name", updatedSpec.Volumes[3].Name) require.Equal(t, "rule2-backref0-aws-oidc-name", updatedSpec.Containers[0].VolumeMounts[3].Name) require.Equal(t, "/etc/backend_security_policy/rule2-backref0-aws-oidc-name", updatedSpec.Containers[0].VolumeMounts[3].MountPath) From 4f6161d11f1e9aa349ee26779c9bde15a43e8de9 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 21:04:55 -0500 Subject: [PATCH 44/86] pass time.Time by value instead of pointers Signed-off-by: Aaron Choo --- internal/controller/backend_security_policy.go | 2 +- .../controller/rotators/aws_oidc_rotator.go | 17 +++++++---------- .../rotators/aws_oidc_rotator_test.go | 6 +++--- internal/controller/rotators/common.go | 8 ++++---- internal/controller/rotators/common_test.go | 3 --- 5 files changed, 15 insertions(+), 21 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 7e78dbc0b..afdf760c1 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -59,7 +59,7 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr if err != nil { b.logger.Error(err, "failed to create AWS OIDC rotator") } else if !rotator.IsExpired() { - requeue = time.Until(*rotator.GetPreRotationTime()) + requeue = time.Until(rotator.GetPreRotationTime()) if requeue.Seconds() == 0 { requeue = time.Minute } diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index 7dc4ef101..2fab4ae36 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -88,27 +88,24 @@ func (r *AWSOIDCRotator) SetSTSOperations(ops STSClient) { // IsExpired checks if the preRotation time is before the current time. func (r *AWSOIDCRotator) IsExpired() bool { preRotationExpirationTime := r.GetPreRotationTime() - if preRotationExpirationTime == nil { - return true - } - return IsExpired(0, *preRotationExpirationTime) + return IsExpired(0, preRotationExpirationTime) } -// GetPreRotationTime gets the expiration time minus the preRotation interval. -func (r *AWSOIDCRotator) GetPreRotationTime() *time.Time { +// GetPreRotationTime gets the expiration time minus the preRotation interval or return zero value for time. +func (r *AWSOIDCRotator) GetPreRotationTime() time.Time { secret, err := LookupSecret(r.ctx, r.client, r.backendSecurityPolicyNamespace, GetBSPSecretName(r.backendSecurityPolicyName)) if err != nil { if !errors.IsNotFound(err) { - return nil + return time.Time{} } - return nil + return time.Time{} } expirationTime, err := GetExpirationSecretAnnotation(secret) if err != nil { - return nil + return time.Time{} } preRotationTime := expirationTime.Add(-r.preRotationWindow) - return &preRotationTime + return preRotationTime } // Rotate implements the retrieval and storage of AWS sts credentials. diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go index 5909db307..bd29e232a 100644 --- a/internal/controller/rotators/aws_oidc_rotator_test.go +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -166,10 +166,10 @@ func TestAWS_GetPreRotationTime(t *testing.T) { backendSecurityPolicyName: "test-secret", } - require.Nil(t, awsOidcRotator.GetPreRotationTime()) + require.Equal(t, 0, awsOidcRotator.GetPreRotationTime().Minute()) createTestAWSSecret(t, cl, "test-secret", "OLDKEY", "OLDSECRET", "OLDTOKEN", "default") - require.Nil(t, awsOidcRotator.GetPreRotationTime()) + require.Equal(t, 0, awsOidcRotator.GetPreRotationTime().Minute()) secret, err := LookupSecret(context.Background(), cl, "default", GetBSPSecretName("test-secret")) require.NoError(t, err) @@ -195,7 +195,7 @@ func TestAWS_IsExpired(t *testing.T) { require.True(t, awsOidcRotator.IsExpired()) createTestAWSSecret(t, cl, "test-secret", "OLDKEY", "OLDSECRET", "OLDTOKEN", "default") - require.Nil(t, awsOidcRotator.GetPreRotationTime()) + require.Equal(t, 0, awsOidcRotator.GetPreRotationTime().Minute()) secret, err := LookupSecret(context.Background(), cl, "default", GetBSPSecretName("test-secret")) require.NoError(t, err) diff --git a/internal/controller/rotators/common.go b/internal/controller/rotators/common.go index a207b6ff1..51867981e 100644 --- a/internal/controller/rotators/common.go +++ b/internal/controller/rotators/common.go @@ -64,17 +64,17 @@ func updateExpirationSecretAnnotation(secret *corev1.Secret, updateTime time.Tim } // GetExpirationSecretAnnotation will get the expiration time of credentials set in secret annotation. -func GetExpirationSecretAnnotation(secret *corev1.Secret) (*time.Time, error) { +func GetExpirationSecretAnnotation(secret *corev1.Secret) (time.Time, error) { expirationTimeAnnotationKey, ok := secret.Annotations[ExpirationTimeAnnotationKey] if !ok { - return nil, fmt.Errorf("secret %s/%s missing expiration time annotation", secret.Namespace, secret.Name) + return time.Time{}, fmt.Errorf("secret %s/%s missing expiration time annotation", secret.Namespace, secret.Name) } expirationTime, err := time.Parse(time.RFC3339, expirationTimeAnnotationKey) if err != nil { - return nil, fmt.Errorf("failed to parse expiration time annotation: %w", err) + return time.Time{}, fmt.Errorf("failed to parse expiration time annotation: %w", err) } - return &expirationTime, nil + return expirationTime, nil } // IsExpired checks if the expired time minus duration buffer is before the current time. diff --git a/internal/controller/rotators/common_test.go b/internal/controller/rotators/common_test.go index 9053acb94..c462fca49 100644 --- a/internal/controller/rotators/common_test.go +++ b/internal/controller/rotators/common_test.go @@ -117,7 +117,6 @@ func TestGetExpirationSecretAnnotation(t *testing.T) { expirationTime, err := GetExpirationSecretAnnotation(secret) require.Error(t, err) require.Contains(t, err.Error(), "missing expiration time annotation") - require.Nil(t, expirationTime) secret.Annotations = map[string]string{ ExpirationTimeAnnotationKey: "invalid", @@ -125,7 +124,6 @@ func TestGetExpirationSecretAnnotation(t *testing.T) { expirationTime, err = GetExpirationSecretAnnotation(secret) require.Error(t, err) require.Contains(t, err.Error(), "failed to parse") - require.Nil(t, expirationTime) timeNow := time.Now() secret.Annotations = map[string]string{ @@ -141,7 +139,6 @@ func TestUpdateAndGetExpirationSecretAnnotation(t *testing.T) { expirationTime, err := GetExpirationSecretAnnotation(secret) require.Error(t, err) require.Contains(t, err.Error(), "missing expiration time annotation") - require.Nil(t, expirationTime) timeNow := time.Now() updateExpirationSecretAnnotation(secret, timeNow) From 03aad2b6a252a79168f42cf2b3abb5a2e9807dd0 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 22:01:35 -0500 Subject: [PATCH 45/86] remove few secret related functions and fixed typos Signed-off-by: Aaron Choo --- .../controller/rotators/aws_oidc_rotator.go | 24 ++++++- internal/controller/rotators/common.go | 31 +--------- internal/controller/rotators/common_test.go | 62 ++----------------- manifests/charts/ai-gateway-helm/values.yaml | 2 +- 4 files changed, 30 insertions(+), 89 deletions(-) diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index 2fab4ae36..371a8ceea 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -11,7 +11,9 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/go-logr/logr" + corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -57,7 +59,7 @@ func NewAWSOIDCRotator( cfg.Region = region - if proxyURL := os.Getenv("AI_GATEWY_STS_PROXY_URL"); proxyURL != "" { + if proxyURL := os.Getenv("AI_GATEWAY_STS_PROXY_URL"); proxyURL != "" { cfg.HTTPClient = &http.Client{ Transport: &http.Transport{ Proxy: func(*http.Request) (*url.URL, error) { @@ -125,7 +127,14 @@ func (r *AWSOIDCRotator) Rotate(region, roleARN, token string) error { if !errors.IsNotFound(err) { return err } - secret = newBSPSecret(r.backendSecurityPolicyNamespace, r.backendSecurityPolicyName) + secret = &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: GetBSPSecretName(r.backendSecurityPolicyName), + Namespace: r.backendSecurityPolicyNamespace, + }, + Type: corev1.SecretTypeOpaque, + Data: make(map[string][]byte), + } } updateExpirationSecretAnnotation(secret, *result.Credentials.Expiration) @@ -145,7 +154,16 @@ func (r *AWSOIDCRotator) Rotate(region, roleARN, token string) error { } updateAWSCredentialsInSecret(secret, &credsFile) - return updateSecret(r.ctx, r.client, secret) + + err = r.client.Create(r.ctx, secret) + if err != nil { + if !errors.IsAlreadyExists(err) { + return r.client.Update(r.ctx, secret) + } + return fmt.Errorf("failed to create secret: %w", err) + } + + return nil } // assumeRoleWithToken exchanges an OIDC token for AWS credentials. diff --git a/internal/controller/rotators/common.go b/internal/controller/rotators/common.go index 51867981e..d29f923f5 100644 --- a/internal/controller/rotators/common.go +++ b/internal/controller/rotators/common.go @@ -7,38 +7,13 @@ import ( corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client" ) +// ExpirationTimeAnnotationKey is exported for testing purposes within the controller. const ExpirationTimeAnnotationKey = "rotators/expiration-time" -const RotatorSecretNamePrefix = "ai-eg-bsp" -// newBSPSecret creates a new secret struct (does not persist to k8s). -func newBSPSecret(namespace, bspName string) *corev1.Secret { - return &corev1.Secret{ - ObjectMeta: metav1.ObjectMeta{ - Name: GetBSPSecretName(bspName), - Namespace: namespace, - }, - Type: corev1.SecretTypeOpaque, - Data: make(map[string][]byte), - } -} - -// updateSecret updates an existing secret or creates a new one. -func updateSecret(ctx context.Context, k8sClient client.Client, secret *corev1.Secret) error { - if secret.ResourceVersion == "" { - if err := k8sClient.Create(ctx, secret); err != nil { - return fmt.Errorf("failed to create secret: %w", err) - } - } else { - if err := k8sClient.Update(ctx, secret); err != nil { - return fmt.Errorf("failed to update secret: %w", err) - } - } - return nil -} +const rotatorSecretNamePrefix = "ai-eg-bsp" // #nosec G101 // LookupSecret retrieves an existing secret. func LookupSecret(ctx context.Context, k8sClient client.Client, namespace, name string) (*corev1.Secret, error) { @@ -84,5 +59,5 @@ func IsExpired(buffer time.Duration, expirationTime time.Time) bool { // GetBSPSecretName will return the bspName with rotator prefix. func GetBSPSecretName(bspName string) string { - return fmt.Sprintf("%s-%s", RotatorSecretNamePrefix, bspName) + return fmt.Sprintf("%s-%s", rotatorSecretNamePrefix, bspName) } diff --git a/internal/controller/rotators/common_test.go b/internal/controller/rotators/common_test.go index c462fca49..e94557f95 100644 --- a/internal/controller/rotators/common_test.go +++ b/internal/controller/rotators/common_test.go @@ -9,61 +9,9 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" - "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" ) -func TestNewBSPSecret(t *testing.T) { - name := "test" - namespace := "test-namespace" - secret := newBSPSecret(namespace, name) - - require.NotNil(t, secret) - require.Equal(t, GetBSPSecretName(name), secret.Name) - require.Equal(t, namespace, secret.Namespace) - require.NotNil(t, secret.Data) -} - -func TestUpdateSecret(t *testing.T) { - scheme := runtime.NewScheme() - scheme.AddKnownTypes(corev1.SchemeGroupVersion, - &corev1.Secret{}, - ) - cl := fake.NewClientBuilder().WithScheme(scheme).Build() - - secret := &corev1.Secret{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - Namespace: "test-namespace", - }, - Data: map[string][]byte{ - "key": []byte("value"), - }, - } - - err := cl.Get(context.Background(), client.ObjectKeyFromObject(secret), secret) - require.NoError(t, client.IgnoreNotFound(err)) - require.NoError(t, updateSecret(context.Background(), cl, secret)) - - var secretPlaceholder corev1.Secret - require.NoError(t, cl.Get(context.Background(), client.ObjectKey{ - Namespace: "test-namespace", - Name: "test", - }, &secretPlaceholder)) - require.Equal(t, secret.Name, secretPlaceholder.Name) - require.Equal(t, secret.Namespace, secretPlaceholder.Namespace) - require.Equal(t, []byte("value"), secretPlaceholder.Data["key"]) - - secret.Data["key"] = []byte("another value") - require.NoError(t, updateSecret(context.Background(), cl, secret)) - - require.NoError(t, cl.Get(context.Background(), client.ObjectKey{ - Namespace: "test-namespace", - Name: "test", - }, &secretPlaceholder)) - require.Equal(t, []byte("another value"), secretPlaceholder.Data["key"]) -} - func TestLookupSecret(t *testing.T) { scheme := runtime.NewScheme() scheme.AddKnownTypes(corev1.SchemeGroupVersion, @@ -114,14 +62,14 @@ func TestGetExpirationSecretAnnotation(t *testing.T) { }, } - expirationTime, err := GetExpirationSecretAnnotation(secret) + _, err := GetExpirationSecretAnnotation(secret) require.Error(t, err) require.Contains(t, err.Error(), "missing expiration time annotation") secret.Annotations = map[string]string{ ExpirationTimeAnnotationKey: "invalid", } - expirationTime, err = GetExpirationSecretAnnotation(secret) + _, err = GetExpirationSecretAnnotation(secret) require.Error(t, err) require.Contains(t, err.Error(), "failed to parse") @@ -129,20 +77,20 @@ func TestGetExpirationSecretAnnotation(t *testing.T) { secret.Annotations = map[string]string{ ExpirationTimeAnnotationKey: timeNow.Format(time.RFC3339), } - expirationTime, err = GetExpirationSecretAnnotation(secret) + expirationTime, err := GetExpirationSecretAnnotation(secret) require.NoError(t, err) require.Equal(t, timeNow.Format(time.RFC3339), expirationTime.Format(time.RFC3339)) } func TestUpdateAndGetExpirationSecretAnnotation(t *testing.T) { secret := &corev1.Secret{} - expirationTime, err := GetExpirationSecretAnnotation(secret) + _, err := GetExpirationSecretAnnotation(secret) require.Error(t, err) require.Contains(t, err.Error(), "missing expiration time annotation") timeNow := time.Now() updateExpirationSecretAnnotation(secret, timeNow) - expirationTime, err = GetExpirationSecretAnnotation(secret) + expirationTime, err := GetExpirationSecretAnnotation(secret) require.NoError(t, err) require.Equal(t, timeNow.Format(time.RFC3339), expirationTime.Format(time.RFC3339)) } diff --git a/manifests/charts/ai-gateway-helm/values.yaml b/manifests/charts/ai-gateway-helm/values.yaml index 26c7d35c2..860edf5b0 100644 --- a/manifests/charts/ai-gateway-helm/values.yaml +++ b/manifests/charts/ai-gateway-helm/values.yaml @@ -39,7 +39,7 @@ controller: podSecurityContext: {} securityContext: {} # Example of a podEnv - # - key: AI_GATEWY_STS_PROXY_URL + # - key: AI_GATEWAY_STS_PROXY_URL # value: some-proxy-placeholder podEnv: {} # Example of volumes From 5adc33dfdf811d8e476722975b0153e53c529094 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 22:04:47 -0500 Subject: [PATCH 46/86] check if oidc is nil before accessing Signed-off-by: Aaron Choo --- .../controller/oauth/client_credentials_token_provider.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/internal/controller/oauth/client_credentials_token_provider.go b/internal/controller/oauth/client_credentials_token_provider.go index cf6dd29f5..d7a0d711e 100644 --- a/internal/controller/oauth/client_credentials_token_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -45,11 +45,13 @@ func (p *ClientCredentialsTokenProvider) FetchToken(ctx context.Context, oidc *e func (p *ClientCredentialsTokenProvider) getTokenWithClientCredentialConfig(ctx context.Context, oidc *egv1a1.OIDC, clientSecret string) (*oauth2.Token, error) { if p.tokenSource == nil { oauth2Config := clientcredentials.Config{ - ClientID: oidc.ClientID, ClientSecret: clientSecret, + } + if oidc != nil { + oauth2Config.ClientID = oidc.ClientID + oauth2Config.Scopes = oidc.Scopes // Discovery returns the OAuth2 endpoints. - TokenURL: *oidc.Provider.TokenEndpoint, - Scopes: oidc.Scopes, + oauth2Config.TokenURL = *oidc.Provider.TokenEndpoint } p.tokenSource = oauth2Config.TokenSource(ctx) } From 38753f005c51e96ea938c5928f060e885025fd9c Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 22:06:58 -0500 Subject: [PATCH 47/86] remove requeue: true Signed-off-by: Aaron Choo --- internal/controller/backend_security_policy.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index afdf760c1..2d56092fd 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -64,7 +64,8 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr requeue = time.Minute } } - res = ctrl.Result{RequeueAfter: requeue, Requeue: true} + // TODO: Investigate how to stop stale events from re-queuing. + res = ctrl.Result{RequeueAfter: requeue} } // Send the backend security policy to the config sink so that it can modify the configuration together with the state of other resources. b.eventChan <- backendSecurityPolicy.DeepCopy() From 8ec76eb99a44fad98ad5b5a090b041eddeee77b2 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 23:17:28 -0500 Subject: [PATCH 48/86] move rotator into the controller Signed-off-by: Aaron Choo --- .../controller/backend_security_policy.go | 85 ++++++++++-- .../backend_security_policy_test.go | 103 ++++++++++++++- .../controller/rotators/aws_oidc_rotator.go | 1 + internal/controller/sink.go | 46 ------- internal/controller/sink_test.go | 125 ------------------ 5 files changed, 172 insertions(+), 188 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 2d56092fd..cfb6eb82d 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -7,35 +7,46 @@ package controller import ( "context" + "fmt" "time" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/go-logr/logr" + "golang.org/x/oauth2" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/client-go/kubernetes" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1" + "github.com/envoyproxy/ai-gateway/internal/controller/oauth" "github.com/envoyproxy/ai-gateway/internal/controller/rotators" ) +// preRotationWindow specifies how long before expiry to rotate credentials. +// Temporarily a fixed duration. +const preRotationWindow = 5 * time.Minute + // backendSecurityPolicyController implements [reconcile.TypedReconciler] for [aigv1a1.BackendSecurityPolicy]. // // This handles the BackendSecurityPolicy resource and sends it to the config sink so that it can modify configuration. type backendSecurityPolicyController struct { - client client.Client - kube kubernetes.Interface - logger logr.Logger - eventChan chan ConfigSinkEvent + client client.Client + kube kubernetes.Interface + logger logr.Logger + eventChan chan ConfigSinkEvent + StsOP rotators.STSClient + oidcTokenCache map[string]*oauth2.Token } func newBackendSecurityPolicyController(client client.Client, kube kubernetes.Interface, logger logr.Logger, ch chan ConfigSinkEvent) *backendSecurityPolicyController { return &backendSecurityPolicyController{ - client: client, - kube: kube, - logger: logger, - eventChan: ch, + client: client, + kube: kube, + logger: logger, + eventChan: ch, + StsOP: nil, + oidcTokenCache: make(map[string]*oauth2.Token), } } @@ -51,19 +62,65 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr return ctrl.Result{}, err } - if getBackendSecurityPolicyAuthOIDC(backendSecurityPolicy.Spec) != nil { + println("zero") + if oidc := getBackendSecurityPolicyAuthOIDC(backendSecurityPolicy.Spec); oidc != nil { + println("oidc is not nil") var requeue time.Duration requeue = time.Minute region := backendSecurityPolicy.Spec.AWSCredentials.Region + rotator, err := rotators.NewAWSOIDCRotator(ctx, b.client, b.kube, b.logger, backendSecurityPolicy.Namespace, backendSecurityPolicy.Name, preRotationWindow, region) if err != nil { + println("new aws oidc rotator failed to get") b.logger.Error(err, "failed to create AWS OIDC rotator") - } else if !rotator.IsExpired() { - requeue = time.Until(rotator.GetPreRotationTime()) - if requeue.Seconds() == 0 { - requeue = time.Minute + } else if rotator.IsExpired() { + bspKey := fmt.Sprintf("%s.%s", backendSecurityPolicy.Name, backendSecurityPolicy.Namespace) + + println("one") + var validToken *oauth2.Token + if tokenResponse, ok := b.oidcTokenCache[bspKey]; !ok || rotators.IsExpired(preRotationWindow, tokenResponse.Expiry) { + println("two") + oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(b.client), oidc) + // Valid Token will be nil if fetch token errors. + validToken, err = oidcProvider.FetchToken(ctx) + if err != nil { + println("three") + b.logger.Error(err, "failed to fetch OIDC provider token") + } else { + b.oidcTokenCache[bspKey] = validToken + } + } else { + println("four") + validToken = tokenResponse + } + + println("five") + if validToken != nil { + println("six") + b.oidcTokenCache[bspKey] = validToken + awsCredentials := backendSecurityPolicy.Spec.AWSCredentials + + println("seven") + // This is to abstract the real STS behavior for testing purpose. + if b.StsOP != nil { + println("eight") + rotator.SetSTSOperations(b.StsOP) + } + println("nine") + token := validToken.AccessToken + err = rotator.Rotate(awsCredentials.Region, awsCredentials.OIDCExchangeToken.AwsRoleArn, token) + if err != nil { + println("ten") + b.logger.Error(err, "failed to rotate AWS OIDC exchange token") + requeue = time.Minute + } else { + println("eleven") + requeue = time.Until(rotator.GetPreRotationTime()) + } + } } + println("twelve") // TODO: Investigate how to stop stale events from re-queuing. res = ctrl.Result{RequeueAfter: requeue} } @@ -74,7 +131,9 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr // getBackendSecurityPolicyAuthOIDC returns the backendSecurityPolicy's OIDC pointer or nil. func getBackendSecurityPolicyAuthOIDC(spec aigv1a1.BackendSecurityPolicySpec) *egv1a1.OIDC { + println("point five") if spec.AWSCredentials != nil && spec.AWSCredentials.OIDCExchangeToken != nil { + println("point 8") return &spec.AWSCredentials.OIDCExchangeToken.OIDC } return nil diff --git a/internal/controller/backend_security_policy_test.go b/internal/controller/backend_security_policy_test.go index 800641ad4..4a32d7bc1 100644 --- a/internal/controller/backend_security_policy_test.go +++ b/internal/controller/backend_security_policy_test.go @@ -7,21 +7,31 @@ package controller import ( "context" + "encoding/json" "fmt" + "net/http" + "net/http/httptest" "testing" "time" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sts" + stsTypes "github.com/aws/aws-sdk-go-v2/service/sts/types" + oidcv3 "github.com/coreos/go-oidc/v3/oidc" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" fake2 "k8s.io/client-go/kubernetes/fake" ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/controller-runtime/pkg/reconcile" gwapiv1 "sigs.k8s.io/gateway-api/apis/v1" aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1" + "github.com/envoyproxy/ai-gateway/internal/controller/rotators" ) func TestBackendSecurityController_Reconcile(t *testing.T) { @@ -31,6 +41,51 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { backendSecurityPolicyName := "mybackendSecurityPolicy" namespace := "default" + secret := corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "clientSecret", + Namespace: namespace, + }, + Data: map[string][]byte{ + "client-secret": []byte("client-secret"), + }, + } + require.NoError(t, cl.Create(context.Background(), &secret, &client.CreateOptions{})) + + secret = corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: rotators.GetBSPSecretName(fmt.Sprintf("%s-OIDC", backendSecurityPolicyName)), + Namespace: namespace, + Annotations: map[string]string{ + rotators.ExpirationTimeAnnotationKey: "2024-01-01T01:01:00.000-00:00", + }, + }, + Data: map[string][]byte{ + "credentials": []byte("credentials"), + }, + } + require.NoError(t, cl.Create(context.Background(), &secret, &client.CreateOptions{})) + + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Add("Content-Type", "application/json") + type tokenJSON struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn string `json:"expires_in"` + } + b, err := json.Marshal(tokenJSON{AccessToken: "some-access-token", TokenType: "Bearer", ExpiresIn: "60"}) + require.NoError(t, err) + _, err = w.Write(b) + require.NoError(t, err) + })) + defer tokenServer.Close() + + discoveryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, err := w.Write([]byte(`{"issuer": "issuer", "token_endpoint": "token_endpoint", "authorization_endpoint": "authorization_endpoint", "jwks_uri": "jwks_uri", "scopes_supported": []}`)) + require.NoError(t, err) + })) + defer discoveryServer.Close() + err := cl.Create(t.Context(), &aigv1a1.BackendSecurityPolicy{ObjectMeta: metav1.ObjectMeta{Name: backendSecurityPolicyName, Namespace: namespace}}) require.NoError(t, err) err = cl.Create(context.Background(), &aigv1a1.BackendSecurityPolicy{ @@ -40,7 +95,20 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ Region: "us-east-1", OIDCExchangeToken: &aigv1a1.AWSOIDCExchangeToken{ - OIDC: egv1a1.OIDC{}, + OIDC: egv1a1.OIDC{ + Provider: egv1a1.OIDCProvider{ + Issuer: discoveryServer.URL, + TokenEndpoint: &tokenServer.URL, + }, + ClientID: "some-client-id", + ClientSecret: gwapiv1.SecretObjectReference{ + Name: "clientSecret", + Namespace: (*gwapiv1.Namespace)(&namespace), + }, + }, + GrantType: "placeholder", + Aud: "placeholder", + AwsRoleArn: "placeholder", }, }, }, @@ -55,10 +123,20 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { require.Equal(t, backendSecurityPolicyName, item.(*aigv1a1.BackendSecurityPolicy).Name) require.Equal(t, namespace, item.(*aigv1a1.BackendSecurityPolicy).Namespace) - res, err = c.Reconcile(context.Background(), reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName)}}) + c.StsOP = &mockSTSOperations{} + ctx := oidcv3.InsecureIssuerURLContext(context.Background(), discoveryServer.URL) + res, err = c.Reconcile(ctx, reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName)}}) require.NoError(t, err) - require.True(t, res.Requeue) - require.Equal(t, time.Minute, res.RequeueAfter) + require.WithinRange(t, time.Now().Add(res.RequeueAfter), time.Now().Add(50*time.Minute), time.Now().Add(time.Hour)) + + require.Len(t, c.oidcTokenCache, 1) + token, ok := c.oidcTokenCache[fmt.Sprintf("%s-OIDC.%s", backendSecurityPolicyName, namespace)] + require.True(t, ok) + require.Equal(t, "some-access-token", token.AccessToken) + + updatedSecret, err := rotators.LookupSecret(context.Background(), cl, namespace, rotators.GetBSPSecretName(fmt.Sprintf("%s-OIDC", backendSecurityPolicyName))) + require.NoError(t, err) + require.NotEqualf(t, secret.Annotations[rotators.ExpirationTimeAnnotationKey], updatedSecret.Annotations[rotators.ExpirationTimeAnnotationKey], "expected updated expiration time annotation") // Test the case where the BackendSecurityPolicy is being deleted. err = cl.Delete(context.Background(), &aigv1a1.BackendSecurityPolicy{ObjectMeta: metav1.ObjectMeta{Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName), Namespace: namespace}}) @@ -69,6 +147,23 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { require.False(t, res.Requeue) } +// mockSTSOperations implements the STSOperations interface for testing +type mockSTSOperations struct{} + +// AssumeRoleWithWebIdentity will return placeholder of type aws credentials. +// +// This implements [STSClient.AssumeRoleWithWebIdentity]. +func (m *mockSTSOperations) AssumeRoleWithWebIdentity(_ context.Context, _ *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { + return &sts.AssumeRoleWithWebIdentityOutput{ + Credentials: &stsTypes.Credentials{ + AccessKeyId: aws.String("NEWKEY"), + SecretAccessKey: aws.String("NEWSECRET"), + SessionToken: aws.String("NEWTOKEN"), + Expiration: aws.Time(time.Now().Add(1 * time.Hour)), + }, + }, nil +} + func TestBackendSecurityController_GetBackendSecurityPolicyAuthOIDC(t *testing.T) { require.Nil(t, getBackendSecurityPolicyAuthOIDC(aigv1a1.BackendSecurityPolicySpec{ Type: aigv1a1.BackendSecurityPolicyTypeAPIKey, diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index 371a8ceea..35047f7cc 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -90,6 +90,7 @@ func (r *AWSOIDCRotator) SetSTSOperations(ops STSClient) { // IsExpired checks if the preRotation time is before the current time. func (r *AWSOIDCRotator) IsExpired() bool { preRotationExpirationTime := r.GetPreRotationTime() + println(preRotationExpirationTime.String()) return IsExpired(0, preRotationExpirationTime) } diff --git a/internal/controller/sink.go b/internal/controller/sink.go index f44d131d5..2e8b9bd59 100644 --- a/internal/controller/sink.go +++ b/internal/controller/sink.go @@ -9,11 +9,9 @@ import ( "context" "fmt" "path" - "time" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/go-logr/logr" - "golang.org/x/oauth2" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -29,7 +27,6 @@ import ( aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1" "github.com/envoyproxy/ai-gateway/filterapi" - "github.com/envoyproxy/ai-gateway/internal/controller/oauth" "github.com/envoyproxy/ai-gateway/internal/controller/rotators" "github.com/envoyproxy/ai-gateway/internal/llmcostcel" ) @@ -45,10 +42,6 @@ const ( // secret with backendSecurityPolicy auth instead of mounting new secret files to the external proc. const mountedExtProcSecretPath = "/etc/backend_security_policy" // #nosec G101 -// preRotationWindow specifies how long before expiry to rotate credentials -// temporarily a fixed duration -const preRotationWindow = 5 * time.Minute - // ConfigSinkEvent is the interface for the events that the configSink can handle. // It can be either an AIServiceBackend, an AIGatewayRoute, or a deletion event. // @@ -73,8 +66,6 @@ type configSink struct { extProcImagePullPolicy corev1.PullPolicy extProcLogLevel string eventChan chan ConfigSinkEvent - StsOP rotators.STSClient - oidcTokenCache map[string]*oauth2.Token } func newConfigSink( @@ -93,8 +84,6 @@ func newConfigSink( extProcImagePullPolicy: corev1.PullIfNotPresent, extProcLogLevel: extProcLogLevel, eventChan: eventChan, - StsOP: nil, - oidcTokenCache: make(map[string]*oauth2.Token), } return c } @@ -268,41 +257,6 @@ func (c *configSink) syncBackendSecurityPolicy(ctx context.Context, bsp *aigv1a1 aiBackend := &aiServiceBackends.Items[i] c.syncAIServiceBackend(ctx, aiBackend) } - - if oidc := getBackendSecurityPolicyAuthOIDC(bsp.Spec); oidc != nil { - tokenResponse, ok := c.oidcTokenCache[key] - if !ok || rotators.IsExpired(preRotationWindow, tokenResponse.Expiry) { - oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(c.client), oidc) - - tokenRes, err := oidcProvider.FetchToken(ctx) - if err != nil { - c.logger.Error(err, "failed to fetch OIDC provider token") - return - } - c.oidcTokenCache[key] = tokenRes - tokenResponse = tokenRes - } - - awsCredentials := bsp.Spec.AWSCredentials - rotator, err := rotators.NewAWSOIDCRotator(ctx, c.client, c.kube, c.logger, bsp.Namespace, bsp.Name, preRotationWindow, awsCredentials.Region) - if err != nil { - c.logger.Error(err, "failed to create AWS OIDC rotator") - return - } - - if rotator.IsExpired() { - // This is to abstract the real STS behavior for testing purpose. - if c.StsOP != nil { - rotator.SetSTSOperations(c.StsOP) - } - token := tokenResponse.AccessToken - err = rotator.Rotate(awsCredentials.Region, awsCredentials.OIDCExchangeToken.AwsRoleArn, token) - if err != nil { - c.logger.Error(err, "failed to rotate AWS OIDC exchange token") - return - } - } - } } // updateExtProcConfigMap updates the external process configmap with the new AIGatewayRoute. diff --git a/internal/controller/sink_test.go b/internal/controller/sink_test.go index a223d808c..d223b904f 100644 --- a/internal/controller/sink_test.go +++ b/internal/controller/sink_test.go @@ -7,20 +7,13 @@ package controller import ( "context" - "encoding/json" "fmt" "log/slog" - "net/http" - "net/http/httptest" "os" "strconv" "testing" "time" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/sts" - "github.com/aws/aws-sdk-go-v2/service/sts/types" - oidcv3 "github.com/coreos/go-oidc/v3/oidc" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/go-logr/logr" "github.com/stretchr/testify/require" @@ -184,124 +177,6 @@ func TestConfigSink_syncBackendSecurityPolicy(t *testing.T) { }) } -// MockSTSOperations implements the STSOperations interface for testing -type MockSTSOperations struct{} - -func (m *MockSTSOperations) AssumeRoleWithWebIdentity(_ context.Context, _ *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { - return &sts.AssumeRoleWithWebIdentityOutput{ - Credentials: &types.Credentials{ - AccessKeyId: aws.String("NEWKEY"), - SecretAccessKey: aws.String("NEWSECRET"), - SessionToken: aws.String("NEWTOKEN"), - Expiration: aws.Time(time.Now().Add(1 * time.Hour)), - }, - }, nil -} - -func TestConfigSink_syncBackendSecurityPolicyOIDC(t *testing.T) { - fakeClient := requireNewFakeClientWithIndexes(t) - eventChan := make(chan ConfigSinkEvent) - s := newConfigSink(fakeClient, nil, logr.Discard(), eventChan, "defaultExtProcImage", "debug") - - require.Empty(t, s.oidcTokenCache) - - // Test with OIDC backend - backend := aigv1a1.AIServiceBackend{ - ObjectMeta: metav1.ObjectMeta{Name: "potato", Namespace: "ns"}, - Spec: aigv1a1.AIServiceBackendSpec{ - BackendRef: gwapiv1.BackendObjectReference{Name: "some-backend2", Namespace: ptr.To[gwapiv1.Namespace]("ns")}, - BackendSecurityPolicyRef: &gwapiv1.LocalObjectReference{Name: "orange"}, - }, - } - require.NoError(t, fakeClient.Create(context.Background(), &backend, &client.CreateOptions{})) - - clientSecret := "secretName" - sharedNamespace := "ns" - secret := corev1.Secret{ - ObjectMeta: metav1.ObjectMeta{ - Name: clientSecret, - Namespace: sharedNamespace, - }, - Data: map[string][]byte{ - "client-secret": []byte("client-secret"), - }, - } - require.NoError(t, fakeClient.Create(context.Background(), &secret, &client.CreateOptions{})) - - secret = corev1.Secret{ - ObjectMeta: metav1.ObjectMeta{ - Name: "orange", - Namespace: sharedNamespace, - Annotations: map[string]string{ - rotators.ExpirationTimeAnnotationKey: "2024-01-01T01:01:00.000-00:00", - }, - }, - Data: map[string][]byte{ - "credentials": []byte("credentials"), - }, - } - require.NoError(t, fakeClient.Create(context.Background(), &secret, &client.CreateOptions{})) - - tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Add("Content-Type", "application/json") - type tokenJSON struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresIn string `json:"expires_in"` - } - b, err := json.Marshal(tokenJSON{AccessToken: "some-access-token", TokenType: "Bearer", ExpiresIn: "60"}) - require.NoError(t, err) - _, err = w.Write(b) - require.NoError(t, err) - })) - defer tokenServer.Close() - - discoveryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - _, err := w.Write([]byte(`{"issuer": "issuer", "token_endpoint": "token_endpoint", "authorization_endpoint": "authorization_endpoint", "jwks_uri": "jwks_uri", "scopes_supported": []}`)) - require.NoError(t, err) - })) - defer discoveryServer.Close() - - ctx := oidcv3.InsecureIssuerURLContext(context.Background(), discoveryServer.URL) - namespaceRef := gwapiv1.Namespace(sharedNamespace) - - s.StsOP = &MockSTSOperations{} - - s.syncBackendSecurityPolicy(ctx, &aigv1a1.BackendSecurityPolicy{ - ObjectMeta: metav1.ObjectMeta{Name: "orange", Namespace: sharedNamespace}, - Spec: aigv1a1.BackendSecurityPolicySpec{ - Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, - AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ - Region: "us-east-1", - OIDCExchangeToken: &aigv1a1.AWSOIDCExchangeToken{ - OIDC: egv1a1.OIDC{ - Provider: egv1a1.OIDCProvider{ - Issuer: discoveryServer.URL, - TokenEndpoint: &tokenServer.URL, - }, - ClientID: "some-client-id", - ClientSecret: gwapiv1.SecretObjectReference{ - Name: gwapiv1.ObjectName(clientSecret), - Namespace: &namespaceRef, - }, - }, - GrantType: "placeholder", - Aud: "placeholder", - AwsRoleArn: "placeholder", - }, - }, - }, - }) - require.Len(t, s.oidcTokenCache, 1) - token, ok := s.oidcTokenCache["orange.ns"] - require.True(t, ok) - require.Equal(t, "some-access-token", token.AccessToken) - - updatedSecret, err := rotators.LookupSecret(context.Background(), fakeClient, sharedNamespace, rotators.GetBSPSecretName("orange")) - require.NoError(t, err) - require.NotEqualf(t, secret.Annotations[rotators.ExpirationTimeAnnotationKey], updatedSecret.Annotations[rotators.ExpirationTimeAnnotationKey], "expected updated expiration time annotation") -} - func Test_newHTTPRoute(t *testing.T) { eventChan := make(chan ConfigSinkEvent) fakeClient := requireNewFakeClientWithIndexes(t) From 841e6702d07ade5bec59fdbbd388b2c05b51aea4 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 23:52:33 -0500 Subject: [PATCH 49/86] add timeout for external calls Signed-off-by: Aaron Choo --- .../controller/backend_security_policy.go | 30 ++++++++----------- .../client_credentials_token_provider.go | 18 +++++++++-- .../client_credentials_token_provider_test.go | 16 ++++++++++ .../controller/rotators/aws_oidc_rotator.go | 27 +++++++++++++---- .../rotators/aws_oidc_rotator_test.go | 8 ++++- 5 files changed, 72 insertions(+), 27 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index cfb6eb82d..939f44a7a 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -27,6 +27,9 @@ import ( // Temporarily a fixed duration. const preRotationWindow = 5 * time.Minute +// outgoingTimeOut will be used to prevent outgoing request from blocking. +const outGoingTimeOut = time.Minute + // backendSecurityPolicyController implements [reconcile.TypedReconciler] for [aigv1a1.BackendSecurityPolicy]. // // This handles the BackendSecurityPolicy resource and sends it to the config sink so that it can modify configuration. @@ -62,65 +65,58 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr return ctrl.Result{}, err } - println("zero") if oidc := getBackendSecurityPolicyAuthOIDC(backendSecurityPolicy.Spec); oidc != nil { - println("oidc is not nil") var requeue time.Duration requeue = time.Minute region := backendSecurityPolicy.Spec.AWSCredentials.Region rotator, err := rotators.NewAWSOIDCRotator(ctx, b.client, b.kube, b.logger, backendSecurityPolicy.Namespace, backendSecurityPolicy.Name, preRotationWindow, region) if err != nil { - println("new aws oidc rotator failed to get") b.logger.Error(err, "failed to create AWS OIDC rotator") } else if rotator.IsExpired() { bspKey := fmt.Sprintf("%s.%s", backendSecurityPolicy.Name, backendSecurityPolicy.Namespace) - println("one") var validToken *oauth2.Token if tokenResponse, ok := b.oidcTokenCache[bspKey]; !ok || rotators.IsExpired(preRotationWindow, tokenResponse.Expiry) { - println("two") oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(b.client), oidc) // Valid Token will be nil if fetch token errors. - validToken, err = oidcProvider.FetchToken(ctx) + + timeOutCtx, cancelFunc := context.WithTimeout(ctx, outGoingTimeOut) + defer cancelFunc() + validToken, err = oidcProvider.FetchToken(timeOutCtx) if err != nil { - println("three") b.logger.Error(err, "failed to fetch OIDC provider token") } else { b.oidcTokenCache[bspKey] = validToken } } else { - println("four") validToken = tokenResponse } - println("five") if validToken != nil { - println("six") b.oidcTokenCache[bspKey] = validToken awsCredentials := backendSecurityPolicy.Spec.AWSCredentials - println("seven") // This is to abstract the real STS behavior for testing purpose. if b.StsOP != nil { - println("eight") rotator.SetSTSOperations(b.StsOP) } - println("nine") + + // Set a timeout for rotate. + timeOutCtx, cancelFunc2 := context.WithTimeout(ctx, outGoingTimeOut) + defer cancelFunc2() + rotator.UpdateCtx(timeOutCtx) token := validToken.AccessToken err = rotator.Rotate(awsCredentials.Region, awsCredentials.OIDCExchangeToken.AwsRoleArn, token) if err != nil { - println("ten") b.logger.Error(err, "failed to rotate AWS OIDC exchange token") requeue = time.Minute } else { - println("eleven") requeue = time.Until(rotator.GetPreRotationTime()) } } } - println("twelve") // TODO: Investigate how to stop stale events from re-queuing. res = ctrl.Result{RequeueAfter: requeue} } @@ -131,9 +127,7 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr // getBackendSecurityPolicyAuthOIDC returns the backendSecurityPolicy's OIDC pointer or nil. func getBackendSecurityPolicyAuthOIDC(spec aigv1a1.BackendSecurityPolicySpec) *egv1a1.OIDC { - println("point five") if spec.AWSCredentials != nil && spec.AWSCredentials.OIDCExchangeToken != nil { - println("point 8") return &spec.AWSCredentials.OIDCExchangeToken.OIDC } return nil diff --git a/internal/controller/oauth/client_credentials_token_provider.go b/internal/controller/oauth/client_credentials_token_provider.go index d7a0d711e..f5393fb3a 100644 --- a/internal/controller/oauth/client_credentials_token_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -55,9 +55,21 @@ func (p *ClientCredentialsTokenProvider) getTokenWithClientCredentialConfig(ctx } p.tokenSource = oauth2Config.TokenSource(ctx) } - token, err := p.tokenSource.Token() - if err != nil { - return nil, fmt.Errorf("fail to get oauth2 token %w", err) + + var token *oauth2.Token + var err error + // This adds timeout via ctx from the caller. + for token == nil { + timer := time.NewTimer(time.Second) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + token, err = p.tokenSource.Token() + if err != nil { + return nil, fmt.Errorf("fail to get oauth2 token %w", err) + } + } } // Handle expiration. diff --git a/internal/controller/oauth/client_credentials_token_provider_test.go b/internal/controller/oauth/client_credentials_token_provider_test.go index db760dfda..a018240be 100644 --- a/internal/controller/oauth/client_credentials_token_provider_test.go +++ b/internal/controller/oauth/client_credentials_token_provider_test.go @@ -65,6 +65,22 @@ func TestClientCredentialsProvider_FetchToken(t *testing.T) { require.Contains(t, err.Error(), "oidc or oidc-client-secret is nil") namespaceRef := gwapiv1.Namespace(secretNamespace) + timeOutCtx, cancelFunc := context.WithTimeout(context.Background(), time.Second) + defer cancelFunc() + time.Sleep(time.Second) + _, err = clientCredentialProvider.FetchToken(timeOutCtx, &egv1a1.OIDC{ + Provider: egv1a1.OIDCProvider{ + Issuer: ts.URL, + TokenEndpoint: &ts.URL, + }, + ClientID: "some-client-id", + ClientSecret: gwapiv1.SecretObjectReference{ + Name: gwapiv1.ObjectName(secretName), + Namespace: &namespaceRef, + }, + }) + require.Error(t, err) + token, err := clientCredentialProvider.FetchToken(context.Background(), &egv1a1.OIDC{ Provider: egv1a1.OIDCProvider{ Issuer: ts.URL, diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index 35047f7cc..fec3894b3 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -90,10 +90,15 @@ func (r *AWSOIDCRotator) SetSTSOperations(ops STSClient) { // IsExpired checks if the preRotation time is before the current time. func (r *AWSOIDCRotator) IsExpired() bool { preRotationExpirationTime := r.GetPreRotationTime() - println(preRotationExpirationTime.String()) return IsExpired(0, preRotationExpirationTime) } +// UpdateCtx is used to update the context used in AWSOIDCRotator functions. +// This can be used to set timeouts for outgoing calls to assume role. +func (r *AWSOIDCRotator) UpdateCtx(ctx context.Context) { + r.ctx = ctx +} + // GetPreRotationTime gets the expiration time minus the preRotation interval or return zero value for time. func (r *AWSOIDCRotator) GetPreRotationTime() time.Time { secret, err := LookupSecret(r.ctx, r.client, r.backendSecurityPolicyNamespace, GetBSPSecretName(r.backendSecurityPolicyName)) @@ -117,10 +122,22 @@ func (r *AWSOIDCRotator) Rotate(region, roleARN, token string) error { "namespace", r.backendSecurityPolicyNamespace, "name", r.backendSecurityPolicyName) - result, err := r.assumeRoleWithToken(roleARN, token) - if err != nil { - r.logger.Error(err, "failed to assume role", "role", roleARN, "ID", token) - return err + var result *sts.AssumeRoleWithWebIdentityOutput + var err error + + // This adds timeout via ctx from the caller. + for result == nil { + timer := time.NewTimer(time.Second) + select { + case <-r.ctx.Done(): + return r.ctx.Err() + case <-timer.C: + result, err = r.assumeRoleWithToken(roleARN, token) + if err != nil { + r.logger.Error(err, "failed to assume role", "role", roleARN, "ID", token) + return err + } + } } secret, err := LookupSecret(r.ctx, r.client, r.backendSecurityPolicyNamespace, GetBSPSecretName(r.backendSecurityPolicyName)) diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go index bd29e232a..2478519b9 100644 --- a/internal/controller/rotators/aws_oidc_rotator_test.go +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -116,13 +116,19 @@ func TestAWS_OIDCRotator(t *testing.T) { createClientSecret(t, "test-client-secret") awsOidcRotator := AWSOIDCRotator{ - ctx: context.Background(), client: cl, stsOps: mockSTS, backendSecurityPolicyNamespace: "default", backendSecurityPolicyName: "test-secret", } + timeOutCtx, cancelFunc := context.WithTimeout(context.Background(), time.Second) + defer cancelFunc() + time.Sleep(time.Second) + awsOidcRotator.UpdateCtx(timeOutCtx) + require.Error(t, awsOidcRotator.Rotate("us-east1", "test", "NEW-OIDC-TOKEN")) + + awsOidcRotator.UpdateCtx(context.Background()) require.NoError(t, awsOidcRotator.Rotate("us-east1", "test", "NEW-OIDC-TOKEN")) verifyAWSSecretCredentials(t, cl, "default", "test-secret", "NEWKEY", "NEWSECRET", "NEWTOKEN", "default") }) From 856030b6b553f964f8867ff6c9a6e5f2d433a382 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 23:55:59 -0500 Subject: [PATCH 50/86] clean up comments Signed-off-by: Aaron Choo --- .../controller/oauth/client_credentials_token_provider.go | 2 ++ internal/controller/oauth/oidc_provider.go | 7 ++----- internal/controller/oauth/types.go | 1 + internal/controller/rotators/aws_oidc_rotator.go | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/controller/oauth/client_credentials_token_provider.go b/internal/controller/oauth/client_credentials_token_provider.go index f5393fb3a..bae5f8a83 100644 --- a/internal/controller/oauth/client_credentials_token_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -26,6 +26,8 @@ func NewClientCredentialsProvider(cl client.Client) *ClientCredentialsTokenProvi } // FetchToken gets the client secret from the secret reference and fetches the token from the provider token URL. +// +// This implements [TokenProvider.FetchToken]. func (p *ClientCredentialsTokenProvider) FetchToken(ctx context.Context, oidc *egv1a1.OIDC) (*oauth2.Token, error) { if oidc == nil || oidc.ClientSecret.Namespace == nil { return nil, fmt.Errorf("oidc or oidc-client-secret is nil") diff --git a/internal/controller/oauth/oidc_provider.go b/internal/controller/oauth/oidc_provider.go index 2d9b002e6..e343e5678 100644 --- a/internal/controller/oauth/oidc_provider.go +++ b/internal/controller/oauth/oidc_provider.go @@ -3,9 +3,6 @@ package oauth import ( "context" "fmt" - "net/http" - "time" - "github.com/coreos/go-oidc/v3/oidc" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "golang.org/x/oauth2" @@ -14,7 +11,6 @@ import ( // OIDCProvider extends ClientCredentialsTokenProvider with OIDC support. type OIDCProvider struct { tokenProvider *ClientCredentialsTokenProvider - httpClient *http.Client oidcCredential *egv1a1.OIDC } @@ -22,7 +18,6 @@ type OIDCProvider struct { func NewOIDCProvider(tokenProvider *ClientCredentialsTokenProvider, oidcCredentials *egv1a1.OIDC) *OIDCProvider { return &OIDCProvider{ tokenProvider: tokenProvider, - httpClient: &http.Client{Timeout: 30 * time.Second}, oidcCredential: oidcCredentials, } } @@ -64,6 +59,8 @@ func (p *OIDCProvider) getOIDCProviderConfig(ctx context.Context, issuerURL stri } // FetchToken retrieves and validates tokens using the client credentials flow with OIDC support. +// +// This implements [TokenProvider.FetchToken]. func (p *OIDCProvider) FetchToken(ctx context.Context) (*oauth2.Token, error) { // If issuer URL is provided, fetch OIDC metadata. if issuerURL := p.oidcCredential.Provider.Issuer; issuerURL != "" { diff --git a/internal/controller/oauth/types.go b/internal/controller/oauth/types.go index 9775c1eb3..335f175b4 100644 --- a/internal/controller/oauth/types.go +++ b/internal/controller/oauth/types.go @@ -9,5 +9,6 @@ import ( // TokenProvider defines the interface for OAuth token providers. type TokenProvider interface { + // FetchToken will obtain oauth token using oidc credentials. FetchToken(ctx context.Context, oidc *egv1a1.OIDC) (*oauth2.Token, error) } diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index fec3894b3..2873c8890 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -74,7 +74,7 @@ func NewAWSOIDCRotator( return &AWSOIDCRotator{ client: client, kube: kube, - logger: logger, + logger: logger.WithName("aws-oidc-rotator"), stsOps: stsClient, backendSecurityPolicyNamespace: backendSecurityPolicyNamespace, backendSecurityPolicyName: backendSecurityPolicyName, From ba376782b0915064c81311730cfb756730bf8f49 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Thu, 13 Feb 2025 23:58:03 -0500 Subject: [PATCH 51/86] minor comments Signed-off-by: Aaron Choo --- internal/controller/backend_security_policy_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/controller/backend_security_policy_test.go b/internal/controller/backend_security_policy_test.go index 4a32d7bc1..38c62e49a 100644 --- a/internal/controller/backend_security_policy_test.go +++ b/internal/controller/backend_security_policy_test.go @@ -165,11 +165,13 @@ func (m *mockSTSOperations) AssumeRoleWithWebIdentity(_ context.Context, _ *sts. } func TestBackendSecurityController_GetBackendSecurityPolicyAuthOIDC(t *testing.T) { + // APIKey is not OIDC. require.Nil(t, getBackendSecurityPolicyAuthOIDC(aigv1a1.BackendSecurityPolicySpec{ Type: aigv1a1.BackendSecurityPolicyTypeAPIKey, APIKey: &aigv1a1.BackendSecurityPolicyAPIKey{}, })) + // AWSCredentials contains OIDC but this backendSecurityPolicy does not specify OIDC. require.Nil(t, getBackendSecurityPolicyAuthOIDC(aigv1a1.BackendSecurityPolicySpec{ Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ From a9764b8920a1028c57efb51e96e0e992ec282da3 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Fri, 14 Feb 2025 00:19:16 -0500 Subject: [PATCH 52/86] replace context.Background() with t.Context() in tests Signed-off-by: Aaron Choo --- .../backend_security_policy_test.go | 19 +++++++-------- .../client_credentials_token_provider_test.go | 8 +++---- .../controller/oauth/oidc_provider_test.go | 14 +++++------ internal/controller/oauth/util_test.go | 4 ++-- .../rotators/aws_oidc_rotator_test.go | 24 +++++++++---------- internal/controller/rotators/common_test.go | 6 ++--- 6 files changed, 36 insertions(+), 39 deletions(-) diff --git a/internal/controller/backend_security_policy_test.go b/internal/controller/backend_security_policy_test.go index 38c62e49a..5f031b862 100644 --- a/internal/controller/backend_security_policy_test.go +++ b/internal/controller/backend_security_policy_test.go @@ -50,7 +50,7 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { "client-secret": []byte("client-secret"), }, } - require.NoError(t, cl.Create(context.Background(), &secret, &client.CreateOptions{})) + require.NoError(t, cl.Create(t.Context(), &secret, &client.CreateOptions{})) secret = corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ @@ -64,7 +64,7 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { "credentials": []byte("credentials"), }, } - require.NoError(t, cl.Create(context.Background(), &secret, &client.CreateOptions{})) + require.NoError(t, cl.Create(t.Context(), &secret, &client.CreateOptions{})) tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Add("Content-Type", "application/json") @@ -88,7 +88,7 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { err := cl.Create(t.Context(), &aigv1a1.BackendSecurityPolicy{ObjectMeta: metav1.ObjectMeta{Name: backendSecurityPolicyName, Namespace: namespace}}) require.NoError(t, err) - err = cl.Create(context.Background(), &aigv1a1.BackendSecurityPolicy{ + err = cl.Create(t.Context(), &aigv1a1.BackendSecurityPolicy{ ObjectMeta: metav1.ObjectMeta{Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName), Namespace: namespace}, Spec: aigv1a1.BackendSecurityPolicySpec{ Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, @@ -114,7 +114,7 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { }, }) require.NoError(t, err) - res, err := c.Reconcile(context.Background(), reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: backendSecurityPolicyName}}) + res, err := c.Reconcile(t.Context(), reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: backendSecurityPolicyName}}) require.NoError(t, err) require.False(t, res.Requeue) item, ok := <-ch @@ -124,7 +124,7 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { require.Equal(t, namespace, item.(*aigv1a1.BackendSecurityPolicy).Namespace) c.StsOP = &mockSTSOperations{} - ctx := oidcv3.InsecureIssuerURLContext(context.Background(), discoveryServer.URL) + ctx := oidcv3.InsecureIssuerURLContext(t.Context(), discoveryServer.URL) res, err = c.Reconcile(ctx, reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName)}}) require.NoError(t, err) require.WithinRange(t, time.Now().Add(res.RequeueAfter), time.Now().Add(50*time.Minute), time.Now().Add(time.Hour)) @@ -134,17 +134,14 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { require.True(t, ok) require.Equal(t, "some-access-token", token.AccessToken) - updatedSecret, err := rotators.LookupSecret(context.Background(), cl, namespace, rotators.GetBSPSecretName(fmt.Sprintf("%s-OIDC", backendSecurityPolicyName))) + updatedSecret, err := rotators.LookupSecret(t.Context(), cl, namespace, rotators.GetBSPSecretName(fmt.Sprintf("%s-OIDC", backendSecurityPolicyName))) require.NoError(t, err) require.NotEqualf(t, secret.Annotations[rotators.ExpirationTimeAnnotationKey], updatedSecret.Annotations[rotators.ExpirationTimeAnnotationKey], "expected updated expiration time annotation") // Test the case where the BackendSecurityPolicy is being deleted. - err = cl.Delete(context.Background(), &aigv1a1.BackendSecurityPolicy{ObjectMeta: metav1.ObjectMeta{Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName), Namespace: namespace}}) + err = cl.Delete(t.Context(), &aigv1a1.BackendSecurityPolicy{ObjectMeta: metav1.ObjectMeta{Name: backendSecurityPolicyName, Namespace: namespace}}) require.NoError(t, err) - - res, err = c.Reconcile(context.Background(), reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName)}}) - require.NoError(t, err) - require.False(t, res.Requeue) + _, err = c.Reconcile(t.Context(), reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: backendSecurityPolicyName}}) } // mockSTSOperations implements the STSOperations interface for testing diff --git a/internal/controller/oauth/client_credentials_token_provider_test.go b/internal/controller/oauth/client_credentials_token_provider_test.go index a018240be..935e9b023 100644 --- a/internal/controller/oauth/client_credentials_token_provider_test.go +++ b/internal/controller/oauth/client_credentials_token_provider_test.go @@ -42,7 +42,7 @@ func TestClientCredentialsProvider_FetchToken(t *testing.T) { cl := fake.NewClientBuilder().WithScheme(scheme).Build() secretName, secretNamespace := "secret", "secret-ns" - err := cl.Create(context.Background(), &corev1.Secret{ + err := cl.Create(t.Context(), &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: secretName, Namespace: secretNamespace, @@ -60,12 +60,12 @@ func TestClientCredentialsProvider_FetchToken(t *testing.T) { clientCredentialProvider.tokenSource = &MockClientCredentialsTokenSource{} require.NotNil(t, clientCredentialProvider) - _, err = clientCredentialProvider.FetchToken(context.Background(), nil) + _, err = clientCredentialProvider.FetchToken(t.Context(), nil) require.Error(t, err) require.Contains(t, err.Error(), "oidc or oidc-client-secret is nil") namespaceRef := gwapiv1.Namespace(secretNamespace) - timeOutCtx, cancelFunc := context.WithTimeout(context.Background(), time.Second) + timeOutCtx, cancelFunc := context.WithTimeout(t.Context(), time.Second) defer cancelFunc() time.Sleep(time.Second) _, err = clientCredentialProvider.FetchToken(timeOutCtx, &egv1a1.OIDC{ @@ -81,7 +81,7 @@ func TestClientCredentialsProvider_FetchToken(t *testing.T) { }) require.Error(t, err) - token, err := clientCredentialProvider.FetchToken(context.Background(), &egv1a1.OIDC{ + token, err := clientCredentialProvider.FetchToken(t.Context(), &egv1a1.OIDC{ Provider: egv1a1.OIDCProvider{ Issuer: ts.URL, TokenEndpoint: &ts.URL, diff --git a/internal/controller/oauth/oidc_provider_test.go b/internal/controller/oauth/oidc_provider_test.go index 36809a717..c9cf9c4e4 100644 --- a/internal/controller/oauth/oidc_provider_test.go +++ b/internal/controller/oauth/oidc_provider_test.go @@ -46,7 +46,7 @@ func TestOIDCProvider_GetOIDCProviderConfigErrors(t *testing.T) { defer missingTokenURLTestServer.Close() oidcProvider := NewOIDCProvider(NewClientCredentialsProvider(cl), oidc) - cancelledContext, cancel := context.WithCancel(context.Background()) + cancelledContext, cancel := context.WithCancel(t.Context()) cancel() for _, testcase := range []struct { @@ -67,21 +67,21 @@ func TestOIDCProvider_GetOIDCProviderConfigErrors(t *testing.T) { name: "failed to create go oidc", provider: oidcProvider, url: "", - ctx: context.Background(), + ctx: t.Context(), contains: "failed to create go-oidc provider", }, { name: "config missing token url", provider: oidcProvider, url: missingTokenURLTestServer.URL, - ctx: oidcv3.InsecureIssuerURLContext(context.Background(), missingTokenURLTestServer.URL), + ctx: oidcv3.InsecureIssuerURLContext(t.Context(), missingTokenURLTestServer.URL), contains: "token_endpoint is required in OIDC provider config", }, { name: "config missing issuer", provider: oidcProvider, url: missingIssuerTestServer.URL, - ctx: oidcv3.InsecureIssuerURLContext(context.Background(), missingIssuerTestServer.URL), + ctx: oidcv3.InsecureIssuerURLContext(t.Context(), missingIssuerTestServer.URL), contains: "issuer is required in OIDC provider config", }, } { @@ -118,7 +118,7 @@ func TestOIDCProvider_GetOIDCProviderConfig(t *testing.T) { ClientID: "some-client-id", } - ctx := oidcv3.InsecureIssuerURLContext(context.Background(), ts.URL) + ctx := oidcv3.InsecureIssuerURLContext(t.Context(), ts.URL) oidcProvider := NewOIDCProvider(NewClientCredentialsProvider(cl), oidc) config, supportedScope, err := oidcProvider.getOIDCProviderConfig(ctx, ts.URL) require.NoError(t, err) @@ -141,7 +141,7 @@ func TestOIDCProvider_FetchToken(t *testing.T) { cl := fake.NewClientBuilder().WithScheme(scheme).Build() secretName, secretNamespace := "secret", "secret-ns" - err := cl.Create(context.Background(), &corev1.Secret{ + err := cl.Create(t.Context(), &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: secretName, Namespace: secretNamespace, @@ -170,7 +170,7 @@ func TestOIDCProvider_FetchToken(t *testing.T) { clientCredentialProvider := NewClientCredentialsProvider(cl) clientCredentialProvider.tokenSource = &MockClientCredentialsTokenSource{} require.NotNil(t, clientCredentialProvider) - ctx := oidcv3.InsecureIssuerURLContext(context.Background(), ts.URL) + ctx := oidcv3.InsecureIssuerURLContext(t.Context(), ts.URL) oidcProvider := NewOIDCProvider(clientCredentialProvider, oidc) require.Len(t, oidcProvider.oidcCredential.Scopes, 2) diff --git a/internal/controller/oauth/util_test.go b/internal/controller/oauth/util_test.go index 3d69ae99a..1e848b6e0 100644 --- a/internal/controller/oauth/util_test.go +++ b/internal/controller/oauth/util_test.go @@ -19,7 +19,7 @@ func TestGetClientSecret(t *testing.T) { cl := fake.NewClientBuilder().WithScheme(scheme).Build() secretName, secretNamespace := "secret", "secret-ns" - err := cl.Create(context.Background(), &corev1.Secret{ + err := cl.Create(t.Context(), &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: secretName, Namespace: secretNamespace, @@ -33,7 +33,7 @@ func TestGetClientSecret(t *testing.T) { }) require.NoError(t, err) - secret, err := getClientSecret(context.Background(), cl, &corev1.SecretReference{ + secret, err := getClientSecret(t.Context(), cl, &corev1.SecretReference{ Name: secretName, Namespace: secretNamespace, }) diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go index 2478519b9..7729f9c73 100644 --- a/internal/controller/rotators/aws_oidc_rotator_test.go +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -31,7 +31,7 @@ func createTestAWSSecret(t *testing.T, client client.Client, bspName string, acc awsCredentialsKey: []byte(fmt.Sprintf("[%s]\naws_access_key_id = %s\naws_secret_access_key = %s\naws_session_token = %s\nregion = us-west-2", profile, accessKey, secretKey, sessionToken)), } - err := client.Create(context.Background(), &corev1.Secret{ + err := client.Create(t.Context(), &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: GetBSPSecretName(bspName), Namespace: "default", @@ -46,7 +46,7 @@ func verifyAWSSecretCredentials(t *testing.T, client client.Client, namespace, s if profile == "" { profile = "default" } - secret, err := LookupSecret(context.Background(), client, namespace, GetBSPSecretName(secretName)) + secret, err := LookupSecret(t.Context(), client, namespace, GetBSPSecretName(secretName)) require.NoError(t, err) creds := parseAWSCredentialsFile(string(secret.Data[awsCredentialsKey])) require.NotNil(t, creds) @@ -66,7 +66,7 @@ func createClientSecret(t *testing.T, name string) { &corev1.Secret{}, ) cl := fake.NewClientBuilder().WithScheme(scheme).Build() - err := cl.Create(context.Background(), &corev1.Secret{ + err := cl.Create(t.Context(), &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: name, Namespace: "default", @@ -122,13 +122,13 @@ func TestAWS_OIDCRotator(t *testing.T) { backendSecurityPolicyName: "test-secret", } - timeOutCtx, cancelFunc := context.WithTimeout(context.Background(), time.Second) + timeOutCtx, cancelFunc := context.WithTimeout(t.Context(), time.Second) defer cancelFunc() time.Sleep(time.Second) awsOidcRotator.UpdateCtx(timeOutCtx) require.Error(t, awsOidcRotator.Rotate("us-east1", "test", "NEW-OIDC-TOKEN")) - awsOidcRotator.UpdateCtx(context.Background()) + awsOidcRotator.UpdateCtx(t.Context()) require.NoError(t, awsOidcRotator.Rotate("us-east1", "test", "NEW-OIDC-TOKEN")) verifyAWSSecretCredentials(t, cl, "default", "test-secret", "NEWKEY", "NEWSECRET", "NEWTOKEN", "default") }) @@ -147,7 +147,7 @@ func TestAWS_OIDCRotator(t *testing.T) { }, } awsOidcRotator := AWSOIDCRotator{ - ctx: context.Background(), + ctx: t.Context(), client: cl, stsOps: mockSTS, backendSecurityPolicyNamespace: "default", @@ -166,7 +166,7 @@ func TestAWS_GetPreRotationTime(t *testing.T) { ) cl := fake.NewClientBuilder().WithScheme(scheme).Build() awsOidcRotator := AWSOIDCRotator{ - ctx: context.Background(), + ctx: t.Context(), client: cl, backendSecurityPolicyNamespace: "default", backendSecurityPolicyName: "test-secret", @@ -177,12 +177,12 @@ func TestAWS_GetPreRotationTime(t *testing.T) { createTestAWSSecret(t, cl, "test-secret", "OLDKEY", "OLDSECRET", "OLDTOKEN", "default") require.Equal(t, 0, awsOidcRotator.GetPreRotationTime().Minute()) - secret, err := LookupSecret(context.Background(), cl, "default", GetBSPSecretName("test-secret")) + secret, err := LookupSecret(t.Context(), cl, "default", GetBSPSecretName("test-secret")) require.NoError(t, err) expiredTime := time.Now().Add(-1 * time.Hour) updateExpirationSecretAnnotation(secret, expiredTime) - require.NoError(t, cl.Update(context.Background(), secret)) + require.NoError(t, cl.Update(t.Context(), secret)) require.Equal(t, expiredTime.Format(time.RFC3339), awsOidcRotator.GetPreRotationTime().Format(time.RFC3339)) } @@ -203,16 +203,16 @@ func TestAWS_IsExpired(t *testing.T) { createTestAWSSecret(t, cl, "test-secret", "OLDKEY", "OLDSECRET", "OLDTOKEN", "default") require.Equal(t, 0, awsOidcRotator.GetPreRotationTime().Minute()) - secret, err := LookupSecret(context.Background(), cl, "default", GetBSPSecretName("test-secret")) + secret, err := LookupSecret(t.Context(), cl, "default", GetBSPSecretName("test-secret")) require.NoError(t, err) expiredTime := time.Now().Add(-1 * time.Hour) updateExpirationSecretAnnotation(secret, expiredTime) - require.NoError(t, cl.Update(context.Background(), secret)) + require.NoError(t, cl.Update(t.Context(), secret)) require.True(t, awsOidcRotator.IsExpired()) hourFromNowTime := time.Now().Add(1 * time.Hour) updateExpirationSecretAnnotation(secret, hourFromNowTime) - require.NoError(t, cl.Update(context.Background(), secret)) + require.NoError(t, cl.Update(t.Context(), secret)) require.False(t, awsOidcRotator.IsExpired()) } diff --git a/internal/controller/rotators/common_test.go b/internal/controller/rotators/common_test.go index e94557f95..249f1180a 100644 --- a/internal/controller/rotators/common_test.go +++ b/internal/controller/rotators/common_test.go @@ -21,18 +21,18 @@ func TestLookupSecret(t *testing.T) { secretName := "test" secretNamespace := "test-namespace" - secret, err := LookupSecret(context.Background(), cl, secretNamespace, secretName) + secret, err := LookupSecret(t.Context(), cl, secretNamespace, secretName) require.Error(t, err) require.Nil(t, secret) - require.NoError(t, cl.Create(context.Background(), &corev1.Secret{ + require.NoError(t, cl.Create(t.Context(), &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: secretName, Namespace: secretNamespace, }, })) - secret, err = LookupSecret(context.Background(), cl, secretNamespace, secretName) + secret, err = LookupSecret(t.Context(), cl, secretNamespace, secretName) require.NoError(t, err) require.NotNil(t, secret) require.Equal(t, secretName, secret.Name) From 61c93bcc746efa66c958f1b2d5183286b73e80cf Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Fri, 14 Feb 2025 00:23:12 -0500 Subject: [PATCH 53/86] update IsExpired to IsBufferedTimeExpired Signed-off-by: Aaron Choo --- internal/controller/backend_security_policy.go | 2 +- internal/controller/rotators/aws_oidc_rotator.go | 2 +- internal/controller/rotators/common.go | 4 ++-- internal/controller/rotators/common_test.go | 7 +++---- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 939f44a7a..0df030dc9 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -77,7 +77,7 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr bspKey := fmt.Sprintf("%s.%s", backendSecurityPolicy.Name, backendSecurityPolicy.Namespace) var validToken *oauth2.Token - if tokenResponse, ok := b.oidcTokenCache[bspKey]; !ok || rotators.IsExpired(preRotationWindow, tokenResponse.Expiry) { + if tokenResponse, ok := b.oidcTokenCache[bspKey]; !ok || rotators.IsBufferedTimeExpired(preRotationWindow, tokenResponse.Expiry) { oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(b.client), oidc) // Valid Token will be nil if fetch token errors. diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index 2873c8890..d1d9c603f 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -90,7 +90,7 @@ func (r *AWSOIDCRotator) SetSTSOperations(ops STSClient) { // IsExpired checks if the preRotation time is before the current time. func (r *AWSOIDCRotator) IsExpired() bool { preRotationExpirationTime := r.GetPreRotationTime() - return IsExpired(0, preRotationExpirationTime) + return IsBufferedTimeExpired(0, preRotationExpirationTime) } // UpdateCtx is used to update the context used in AWSOIDCRotator functions. diff --git a/internal/controller/rotators/common.go b/internal/controller/rotators/common.go index d29f923f5..d049af5cb 100644 --- a/internal/controller/rotators/common.go +++ b/internal/controller/rotators/common.go @@ -52,8 +52,8 @@ func GetExpirationSecretAnnotation(secret *corev1.Secret) (time.Time, error) { return expirationTime, nil } -// IsExpired checks if the expired time minus duration buffer is before the current time. -func IsExpired(buffer time.Duration, expirationTime time.Time) bool { +// IsBufferedTimeExpired checks if the expired time minus duration buffer is before the current time. +func IsBufferedTimeExpired(buffer time.Duration, expirationTime time.Time) bool { return expirationTime.Add(-buffer).Before(time.Now()) } diff --git a/internal/controller/rotators/common_test.go b/internal/controller/rotators/common_test.go index 249f1180a..f890f023b 100644 --- a/internal/controller/rotators/common_test.go +++ b/internal/controller/rotators/common_test.go @@ -1,7 +1,6 @@ package rotators import ( - "context" "testing" "time" @@ -95,7 +94,7 @@ func TestUpdateAndGetExpirationSecretAnnotation(t *testing.T) { require.Equal(t, timeNow.Format(time.RFC3339), expirationTime.Format(time.RFC3339)) } -func TestIsExpired(t *testing.T) { - require.True(t, IsExpired(1*time.Minute, time.Now())) - require.False(t, IsExpired(1*time.Minute, time.Now().Add(10*time.Minute))) +func TestIsBufferedTimeExpired(t *testing.T) { + require.True(t, IsBufferedTimeExpired(1*time.Minute, time.Now())) + require.False(t, IsBufferedTimeExpired(1*time.Minute, time.Now().Add(10*time.Minute))) } From 80eb8e851e6742fe46b7a2a437995ebd07920e4d Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Fri, 14 Feb 2025 00:24:59 -0500 Subject: [PATCH 54/86] fix tests Signed-off-by: Aaron Choo --- internal/controller/backend_security_policy_test.go | 1 + internal/controller/oauth/oidc_provider.go | 1 + internal/controller/oauth/util_test.go | 1 - 3 files changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/controller/backend_security_policy_test.go b/internal/controller/backend_security_policy_test.go index 5f031b862..25ea81cfd 100644 --- a/internal/controller/backend_security_policy_test.go +++ b/internal/controller/backend_security_policy_test.go @@ -142,6 +142,7 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { err = cl.Delete(t.Context(), &aigv1a1.BackendSecurityPolicy{ObjectMeta: metav1.ObjectMeta{Name: backendSecurityPolicyName, Namespace: namespace}}) require.NoError(t, err) _, err = c.Reconcile(t.Context(), reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: backendSecurityPolicyName}}) + require.NoError(t, err) } // mockSTSOperations implements the STSOperations interface for testing diff --git a/internal/controller/oauth/oidc_provider.go b/internal/controller/oauth/oidc_provider.go index e343e5678..751d53cb0 100644 --- a/internal/controller/oauth/oidc_provider.go +++ b/internal/controller/oauth/oidc_provider.go @@ -3,6 +3,7 @@ package oauth import ( "context" "fmt" + "github.com/coreos/go-oidc/v3/oidc" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "golang.org/x/oauth2" diff --git a/internal/controller/oauth/util_test.go b/internal/controller/oauth/util_test.go index 1e848b6e0..33854f9cd 100644 --- a/internal/controller/oauth/util_test.go +++ b/internal/controller/oauth/util_test.go @@ -1,7 +1,6 @@ package oauth import ( - "context" "testing" "github.com/stretchr/testify/require" From 986361411c4dec36826c276302e76d7f7d76ba8a Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Fri, 14 Feb 2025 10:44:35 -0500 Subject: [PATCH 55/86] linting Signed-off-by: Aaron Choo --- .../controller/oauth/client_credentials_token_provider.go | 5 +++++ .../oauth/client_credentials_token_provider_test.go | 5 +++++ internal/controller/oauth/oidc_provider.go | 5 +++++ internal/controller/oauth/oidc_provider_test.go | 5 +++++ internal/controller/oauth/types.go | 5 +++++ internal/controller/oauth/util.go | 5 +++++ internal/controller/oauth/util_test.go | 5 +++++ internal/controller/rotators/aws_common.go | 5 +++++ internal/controller/rotators/aws_common_test.go | 5 +++++ internal/controller/rotators/aws_oidc_rotator.go | 5 +++++ internal/controller/rotators/aws_oidc_rotator_test.go | 5 +++++ internal/controller/rotators/common.go | 5 +++++ internal/controller/rotators/common_test.go | 5 +++++ 13 files changed, 65 insertions(+) diff --git a/internal/controller/oauth/client_credentials_token_provider.go b/internal/controller/oauth/client_credentials_token_provider.go index bae5f8a83..421efb735 100644 --- a/internal/controller/oauth/client_credentials_token_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -1,3 +1,8 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + package oauth import ( diff --git a/internal/controller/oauth/client_credentials_token_provider_test.go b/internal/controller/oauth/client_credentials_token_provider_test.go index 935e9b023..c255ff0b9 100644 --- a/internal/controller/oauth/client_credentials_token_provider_test.go +++ b/internal/controller/oauth/client_credentials_token_provider_test.go @@ -1,3 +1,8 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + package oauth import ( diff --git a/internal/controller/oauth/oidc_provider.go b/internal/controller/oauth/oidc_provider.go index 751d53cb0..2ff9ae92f 100644 --- a/internal/controller/oauth/oidc_provider.go +++ b/internal/controller/oauth/oidc_provider.go @@ -1,3 +1,8 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + package oauth import ( diff --git a/internal/controller/oauth/oidc_provider_test.go b/internal/controller/oauth/oidc_provider_test.go index c9cf9c4e4..24a4b5bc2 100644 --- a/internal/controller/oauth/oidc_provider_test.go +++ b/internal/controller/oauth/oidc_provider_test.go @@ -1,3 +1,8 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + package oauth import ( diff --git a/internal/controller/oauth/types.go b/internal/controller/oauth/types.go index 335f175b4..2de87b35b 100644 --- a/internal/controller/oauth/types.go +++ b/internal/controller/oauth/types.go @@ -1,3 +1,8 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + package oauth import ( diff --git a/internal/controller/oauth/util.go b/internal/controller/oauth/util.go index c38dcf691..8aee48e70 100644 --- a/internal/controller/oauth/util.go +++ b/internal/controller/oauth/util.go @@ -1,3 +1,8 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + package oauth import ( diff --git a/internal/controller/oauth/util_test.go b/internal/controller/oauth/util_test.go index 33854f9cd..c7f109c02 100644 --- a/internal/controller/oauth/util_test.go +++ b/internal/controller/oauth/util_test.go @@ -1,3 +1,8 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + package oauth import ( diff --git a/internal/controller/rotators/aws_common.go b/internal/controller/rotators/aws_common.go index c2196efa0..0d75b486d 100644 --- a/internal/controller/rotators/aws_common.go +++ b/internal/controller/rotators/aws_common.go @@ -1,3 +1,8 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + /* Package rotators provides credential rotation implementations. This file contains common AWS functionality shared between different AWS credential diff --git a/internal/controller/rotators/aws_common_test.go b/internal/controller/rotators/aws_common_test.go index 3bb420ec0..d775a74d4 100644 --- a/internal/controller/rotators/aws_common_test.go +++ b/internal/controller/rotators/aws_common_test.go @@ -1,3 +1,8 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + package rotators import ( diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index d1d9c603f..1cb4d1283 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -1,3 +1,8 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + package rotators import ( diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go index 7729f9c73..7d8f67fd1 100644 --- a/internal/controller/rotators/aws_oidc_rotator_test.go +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -1,3 +1,8 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + package rotators import ( diff --git a/internal/controller/rotators/common.go b/internal/controller/rotators/common.go index d049af5cb..b4878912a 100644 --- a/internal/controller/rotators/common.go +++ b/internal/controller/rotators/common.go @@ -1,3 +1,8 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + package rotators import ( diff --git a/internal/controller/rotators/common_test.go b/internal/controller/rotators/common_test.go index f890f023b..df8f6768e 100644 --- a/internal/controller/rotators/common_test.go +++ b/internal/controller/rotators/common_test.go @@ -1,3 +1,8 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + package rotators import ( From f4ad277e8cb90f55a4e16969ce7293e175a81485 Mon Sep 17 00:00:00 2001 From: Dan Sun Date: Sat, 15 Feb 2025 12:05:49 -0500 Subject: [PATCH 56/86] Remove TokenSource Signed-off-by: Dan Sun --- .../backend_security_policy_test.go | 8 ++--- .../client_credentials_token_provider.go | 24 ++++++-------- .../client_credentials_token_provider_test.go | 32 +++++++------------ .../controller/oauth/oidc_provider_test.go | 24 ++++++++++---- 4 files changed, 41 insertions(+), 47 deletions(-) diff --git a/internal/controller/backend_security_policy_test.go b/internal/controller/backend_security_policy_test.go index 25ea81cfd..b0102b43a 100644 --- a/internal/controller/backend_security_policy_test.go +++ b/internal/controller/backend_security_policy_test.go @@ -9,6 +9,7 @@ import ( "context" "encoding/json" "fmt" + "golang.org/x/oauth2" "net/http" "net/http/httptest" "testing" @@ -68,12 +69,7 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Add("Content-Type", "application/json") - type tokenJSON struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresIn string `json:"expires_in"` - } - b, err := json.Marshal(tokenJSON{AccessToken: "some-access-token", TokenType: "Bearer", ExpiresIn: "60"}) + b, err := json.Marshal(oauth2.Token{AccessToken: "some-access-token", TokenType: "Bearer", ExpiresIn: 60}) require.NoError(t, err) _, err = w.Write(b) require.NoError(t, err) diff --git a/internal/controller/oauth/client_credentials_token_provider.go b/internal/controller/oauth/client_credentials_token_provider.go index 421efb735..98432d61d 100644 --- a/internal/controller/oauth/client_credentials_token_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -19,8 +19,7 @@ import ( // ClientCredentialsTokenProvider implements the standard OAuth2 client credentials flow. type ClientCredentialsTokenProvider struct { - tokenSource oauth2.TokenSource - client client.Client + client client.Client } // NewClientCredentialsProvider creates a new client credentials provider. @@ -50,17 +49,14 @@ func (p *ClientCredentialsTokenProvider) FetchToken(ctx context.Context, oidc *e // getTokenWithClientCredentialFlow fetches the oauth2 token with client credential config. func (p *ClientCredentialsTokenProvider) getTokenWithClientCredentialConfig(ctx context.Context, oidc *egv1a1.OIDC, clientSecret string) (*oauth2.Token, error) { - if p.tokenSource == nil { - oauth2Config := clientcredentials.Config{ - ClientSecret: clientSecret, - } - if oidc != nil { - oauth2Config.ClientID = oidc.ClientID - oauth2Config.Scopes = oidc.Scopes - // Discovery returns the OAuth2 endpoints. - oauth2Config.TokenURL = *oidc.Provider.TokenEndpoint - } - p.tokenSource = oauth2Config.TokenSource(ctx) + oauth2Config := clientcredentials.Config{ + ClientSecret: clientSecret, + } + if oidc != nil { + oauth2Config.ClientID = oidc.ClientID + oauth2Config.Scopes = oidc.Scopes + // Discovery returns the OAuth2 endpoints. + oauth2Config.TokenURL = *oidc.Provider.TokenEndpoint } var token *oauth2.Token @@ -72,7 +68,7 @@ func (p *ClientCredentialsTokenProvider) getTokenWithClientCredentialConfig(ctx case <-ctx.Done(): return nil, ctx.Err() case <-timer.C: - token, err = p.tokenSource.Token() + token, err = oauth2Config.Token(ctx) if err != nil { return nil, fmt.Errorf("fail to get oauth2 token %w", err) } diff --git a/internal/controller/oauth/client_credentials_token_provider_test.go b/internal/controller/oauth/client_credentials_token_provider_test.go index c255ff0b9..3d79a99d2 100644 --- a/internal/controller/oauth/client_credentials_token_provider_test.go +++ b/internal/controller/oauth/client_credentials_token_provider_test.go @@ -7,6 +7,8 @@ package oauth import ( "context" + "encoding/json" + "golang.org/x/oauth2" "net/http" "net/http/httptest" "testing" @@ -14,7 +16,6 @@ import ( egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/stretchr/testify/require" - "golang.org/x/oauth2" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" @@ -22,23 +23,15 @@ import ( gwapiv1 "sigs.k8s.io/gateway-api/apis/v1" ) -// MockClientCredentialsTokenSource implements the standard OAuth2 client credentials flow -type MockClientCredentialsTokenSource struct{} - -// FetchToken gets the client secret from the secret reference and fetches the token from provider token URL. -func (m *MockClientCredentialsTokenSource) Token() (*oauth2.Token, error) { - return &oauth2.Token{ - AccessToken: "token", - ExpiresIn: 3600, - }, nil -} - func TestClientCredentialsProvider_FetchToken(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - _, err := w.Write([]byte(`{"access_token": "token", "token_type": "Bearer", "expires_in": 3600}`)) + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Add("Content-Type", "application/json") + b, err := json.Marshal(oauth2.Token{AccessToken: "token", TokenType: "Bearer", ExpiresIn: 3600}) + require.NoError(t, err) + _, err = w.Write(b) require.NoError(t, err) })) - defer ts.Close() + defer tokenServer.Close() scheme := runtime.NewScheme() scheme.AddKnownTypes(corev1.SchemeGroupVersion, @@ -62,7 +55,6 @@ func TestClientCredentialsProvider_FetchToken(t *testing.T) { require.NoError(t, err) clientCredentialProvider := NewClientCredentialsProvider(cl) - clientCredentialProvider.tokenSource = &MockClientCredentialsTokenSource{} require.NotNil(t, clientCredentialProvider) _, err = clientCredentialProvider.FetchToken(t.Context(), nil) @@ -75,8 +67,8 @@ func TestClientCredentialsProvider_FetchToken(t *testing.T) { time.Sleep(time.Second) _, err = clientCredentialProvider.FetchToken(timeOutCtx, &egv1a1.OIDC{ Provider: egv1a1.OIDCProvider{ - Issuer: ts.URL, - TokenEndpoint: &ts.URL, + Issuer: tokenServer.URL, + TokenEndpoint: &tokenServer.URL, }, ClientID: "some-client-id", ClientSecret: gwapiv1.SecretObjectReference{ @@ -88,8 +80,8 @@ func TestClientCredentialsProvider_FetchToken(t *testing.T) { token, err := clientCredentialProvider.FetchToken(t.Context(), &egv1a1.OIDC{ Provider: egv1a1.OIDCProvider{ - Issuer: ts.URL, - TokenEndpoint: &ts.URL, + Issuer: tokenServer.URL, + TokenEndpoint: &tokenServer.URL, }, ClientID: "some-client-id", ClientSecret: gwapiv1.SecretObjectReference{ diff --git a/internal/controller/oauth/oidc_provider_test.go b/internal/controller/oauth/oidc_provider_test.go index 24a4b5bc2..f4ca3a849 100644 --- a/internal/controller/oauth/oidc_provider_test.go +++ b/internal/controller/oauth/oidc_provider_test.go @@ -7,9 +7,12 @@ package oauth import ( "context" + "encoding/json" + "golang.org/x/oauth2" "net/http" "net/http/httptest" "testing" + "time" oidcv3 "github.com/coreos/go-oidc/v3/oidc" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" @@ -133,11 +136,19 @@ func TestOIDCProvider_GetOIDCProviderConfig(t *testing.T) { } func TestOIDCProvider_FetchToken(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + oidcServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, err := w.Write([]byte(`{"issuer": "issuer", "token_endpoint": "token_endpoint", "authorization_endpoint": "authorization_endpoint", "jwks_uri": "jwks_uri", "scopes_supported": ["one", "openid"]}`)) require.NoError(t, err) })) - defer ts.Close() + defer oidcServer.Close() + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Add("Content-Type", "application/json") + b, err := json.Marshal(oauth2.Token{AccessToken: "token", TokenType: "Bearer", ExpiresIn: int64(3600)}) + require.NoError(t, err) + _, err = w.Write(b) + require.NoError(t, err) + })) + defer tokenServer.Close() scheme := runtime.NewScheme() scheme.AddKnownTypes(corev1.SchemeGroupVersion, @@ -162,8 +173,8 @@ func TestOIDCProvider_FetchToken(t *testing.T) { namespaceRef := gwapiv1.Namespace(secretNamespace) oidc := &egv1a1.OIDC{ Provider: egv1a1.OIDCProvider{ - Issuer: ts.URL, - TokenEndpoint: &ts.URL, + Issuer: oidcServer.URL, + TokenEndpoint: &tokenServer.URL, }, ClientID: "some-client-id", ClientSecret: gwapiv1.SecretObjectReference{ @@ -173,9 +184,8 @@ func TestOIDCProvider_FetchToken(t *testing.T) { Scopes: []string{"two", "openid"}, } clientCredentialProvider := NewClientCredentialsProvider(cl) - clientCredentialProvider.tokenSource = &MockClientCredentialsTokenSource{} require.NotNil(t, clientCredentialProvider) - ctx := oidcv3.InsecureIssuerURLContext(t.Context(), ts.URL) + ctx := oidcv3.InsecureIssuerURLContext(t.Context(), oidcServer.URL) oidcProvider := NewOIDCProvider(clientCredentialProvider, oidc) require.Len(t, oidcProvider.oidcCredential.Scopes, 2) @@ -183,6 +193,6 @@ func TestOIDCProvider_FetchToken(t *testing.T) { require.NoError(t, err) require.Equal(t, "token", token.AccessToken) require.Equal(t, "Bearer", token.Type()) - require.Equal(t, int64(3600), token.ExpiresIn) + require.WithinRangef(t, token.Expiry, time.Now().Add(3590*time.Second), time.Now().Add(3600*time.Second), "token expires at") require.Len(t, oidcProvider.oidcCredential.Scopes, 3) } From 82c1dbcd949e443e3ca6ce4df1f051b10974283d Mon Sep 17 00:00:00 2001 From: Dan Sun Date: Sat, 15 Feb 2025 12:10:36 -0500 Subject: [PATCH 57/86] Fix awsRoleArn Signed-off-by: Dan Sun --- tests/crdcel/testdata/backendsecuritypolicies/aws_oidc.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/crdcel/testdata/backendsecuritypolicies/aws_oidc.yaml b/tests/crdcel/testdata/backendsecuritypolicies/aws_oidc.yaml index 63fa85968..d684bb978 100644 --- a/tests/crdcel/testdata/backendsecuritypolicies/aws_oidc.yaml +++ b/tests/crdcel/testdata/backendsecuritypolicies/aws_oidc.yaml @@ -13,6 +13,7 @@ spec: awsCredentials: region: us-east-1 oidcExchangeToken: + awsRoleArn: "arn" oidc: provider: issuer: placeholder From b33c481683dc12d8a185a724b465e572f2665874 Mon Sep 17 00:00:00 2001 From: Dan Sun Date: Sat, 15 Feb 2025 12:18:39 -0500 Subject: [PATCH 58/86] Fix context timeout Signed-off-by: Dan Sun --- .../backend_security_policy_test.go | 2 +- .../client_credentials_token_provider.go | 19 +++---------------- .../client_credentials_token_provider_test.go | 2 +- .../controller/oauth/oidc_provider_test.go | 2 +- 4 files changed, 6 insertions(+), 19 deletions(-) diff --git a/internal/controller/backend_security_policy_test.go b/internal/controller/backend_security_policy_test.go index b0102b43a..89b7972a3 100644 --- a/internal/controller/backend_security_policy_test.go +++ b/internal/controller/backend_security_policy_test.go @@ -9,7 +9,6 @@ import ( "context" "encoding/json" "fmt" - "golang.org/x/oauth2" "net/http" "net/http/httptest" "testing" @@ -21,6 +20,7 @@ import ( oidcv3 "github.com/coreos/go-oidc/v3/oidc" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/stretchr/testify/require" + "golang.org/x/oauth2" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" diff --git a/internal/controller/oauth/client_credentials_token_provider.go b/internal/controller/oauth/client_credentials_token_provider.go index 98432d61d..b19a1d26a 100644 --- a/internal/controller/oauth/client_credentials_token_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -58,23 +58,10 @@ func (p *ClientCredentialsTokenProvider) getTokenWithClientCredentialConfig(ctx // Discovery returns the OAuth2 endpoints. oauth2Config.TokenURL = *oidc.Provider.TokenEndpoint } - - var token *oauth2.Token - var err error - // This adds timeout via ctx from the caller. - for token == nil { - timer := time.NewTimer(time.Second) - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-timer.C: - token, err = oauth2Config.Token(ctx) - if err != nil { - return nil, fmt.Errorf("fail to get oauth2 token %w", err) - } - } + token, err := oauth2Config.Token(ctx) + if err != nil { + return nil, fmt.Errorf("fail to get oauth2 token %w", err) } - // Handle expiration. if token.ExpiresIn > 0 { token.Expiry = time.Now().Add(time.Duration(token.ExpiresIn) * time.Second) diff --git a/internal/controller/oauth/client_credentials_token_provider_test.go b/internal/controller/oauth/client_credentials_token_provider_test.go index 3d79a99d2..9ff7cf3ac 100644 --- a/internal/controller/oauth/client_credentials_token_provider_test.go +++ b/internal/controller/oauth/client_credentials_token_provider_test.go @@ -8,7 +8,6 @@ package oauth import ( "context" "encoding/json" - "golang.org/x/oauth2" "net/http" "net/http/httptest" "testing" @@ -16,6 +15,7 @@ import ( egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/stretchr/testify/require" + "golang.org/x/oauth2" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" diff --git a/internal/controller/oauth/oidc_provider_test.go b/internal/controller/oauth/oidc_provider_test.go index f4ca3a849..9644be72d 100644 --- a/internal/controller/oauth/oidc_provider_test.go +++ b/internal/controller/oauth/oidc_provider_test.go @@ -8,7 +8,6 @@ package oauth import ( "context" "encoding/json" - "golang.org/x/oauth2" "net/http" "net/http/httptest" "testing" @@ -17,6 +16,7 @@ import ( oidcv3 "github.com/coreos/go-oidc/v3/oidc" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/stretchr/testify/require" + "golang.org/x/oauth2" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" From d1ba562898a67a375dff6e1e9e69679ef80dc926 Mon Sep 17 00:00:00 2001 From: Dan Sun Date: Sat, 15 Feb 2025 13:38:38 -0500 Subject: [PATCH 59/86] Fix context Signed-off-by: Dan Sun --- .../controller/backend_security_policy.go | 3 +- .../controller/rotators/aws_oidc_rotator.go | 43 +++++-------------- .../rotators/aws_oidc_rotator_test.go | 11 +---- 3 files changed, 14 insertions(+), 43 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 0df030dc9..97d5acb5a 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -105,9 +105,8 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr // Set a timeout for rotate. timeOutCtx, cancelFunc2 := context.WithTimeout(ctx, outGoingTimeOut) defer cancelFunc2() - rotator.UpdateCtx(timeOutCtx) token := validToken.AccessToken - err = rotator.Rotate(awsCredentials.Region, awsCredentials.OIDCExchangeToken.AwsRoleArn, token) + err = rotator.Rotate(timeOutCtx, awsCredentials.Region, awsCredentials.OIDCExchangeToken.AwsRoleArn, token) if err != nil { b.logger.Error(err, "failed to rotate AWS OIDC exchange token") requeue = time.Minute diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index 1cb4d1283..d79ffd72c 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -27,8 +27,6 @@ import ( // It manages the lifecycle of temporary AWS credentials obtained through OIDC token // exchange with AWS STS. type AWSOIDCRotator struct { - // ctx provides a user specified context. - ctx context.Context // client is used for Kubernetes API operations. client client.Client // kube provides additional Kubernetes API capabilities. @@ -98,15 +96,9 @@ func (r *AWSOIDCRotator) IsExpired() bool { return IsBufferedTimeExpired(0, preRotationExpirationTime) } -// UpdateCtx is used to update the context used in AWSOIDCRotator functions. -// This can be used to set timeouts for outgoing calls to assume role. -func (r *AWSOIDCRotator) UpdateCtx(ctx context.Context) { - r.ctx = ctx -} - // GetPreRotationTime gets the expiration time minus the preRotation interval or return zero value for time. func (r *AWSOIDCRotator) GetPreRotationTime() time.Time { - secret, err := LookupSecret(r.ctx, r.client, r.backendSecurityPolicyNamespace, GetBSPSecretName(r.backendSecurityPolicyName)) + secret, err := LookupSecret(context.Background(), r.client, r.backendSecurityPolicyNamespace, GetBSPSecretName(r.backendSecurityPolicyName)) if err != nil { if !errors.IsNotFound(err) { return time.Time{} @@ -122,30 +114,17 @@ func (r *AWSOIDCRotator) GetPreRotationTime() time.Time { } // Rotate implements the retrieval and storage of AWS sts credentials. -func (r *AWSOIDCRotator) Rotate(region, roleARN, token string) error { +func (r *AWSOIDCRotator) Rotate(ctx context.Context, region, roleARN, token string) error { r.logger.Info("rotating AWS sts temporary credentials", "namespace", r.backendSecurityPolicyNamespace, "name", r.backendSecurityPolicyName) - var result *sts.AssumeRoleWithWebIdentityOutput - var err error - - // This adds timeout via ctx from the caller. - for result == nil { - timer := time.NewTimer(time.Second) - select { - case <-r.ctx.Done(): - return r.ctx.Err() - case <-timer.C: - result, err = r.assumeRoleWithToken(roleARN, token) - if err != nil { - r.logger.Error(err, "failed to assume role", "role", roleARN, "ID", token) - return err - } - } + result, err := r.assumeRoleWithToken(ctx, roleARN, token) + if err != nil { + r.logger.Error(err, "failed to assume role", "role", roleARN, "access token", token) + return err } - - secret, err := LookupSecret(r.ctx, r.client, r.backendSecurityPolicyNamespace, GetBSPSecretName(r.backendSecurityPolicyName)) + secret, err := LookupSecret(ctx, r.client, r.backendSecurityPolicyNamespace, GetBSPSecretName(r.backendSecurityPolicyName)) if err != nil { if !errors.IsNotFound(err) { return err @@ -178,10 +157,10 @@ func (r *AWSOIDCRotator) Rotate(region, roleARN, token string) error { updateAWSCredentialsInSecret(secret, &credsFile) - err = r.client.Create(r.ctx, secret) + err = r.client.Create(ctx, secret) if err != nil { if !errors.IsAlreadyExists(err) { - return r.client.Update(r.ctx, secret) + return r.client.Update(ctx, secret) } return fmt.Errorf("failed to create secret: %w", err) } @@ -190,8 +169,8 @@ func (r *AWSOIDCRotator) Rotate(region, roleARN, token string) error { } // assumeRoleWithToken exchanges an OIDC token for AWS credentials. -func (r *AWSOIDCRotator) assumeRoleWithToken(roleARN, token string) (*sts.AssumeRoleWithWebIdentityOutput, error) { - return r.stsOps.AssumeRoleWithWebIdentity(r.ctx, &sts.AssumeRoleWithWebIdentityInput{ +func (r *AWSOIDCRotator) assumeRoleWithToken(ctx context.Context, roleARN, token string) (*sts.AssumeRoleWithWebIdentityOutput, error) { + return r.stsOps.AssumeRoleWithWebIdentity(ctx, &sts.AssumeRoleWithWebIdentityInput{ RoleArn: aws.String(roleARN), WebIdentityToken: aws.String(token), RoleSessionName: aws.String(fmt.Sprintf(awsSessionNameFormat, r.backendSecurityPolicyName)), diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go index 7d8f67fd1..f0e592dcc 100644 --- a/internal/controller/rotators/aws_oidc_rotator_test.go +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -129,12 +129,7 @@ func TestAWS_OIDCRotator(t *testing.T) { timeOutCtx, cancelFunc := context.WithTimeout(t.Context(), time.Second) defer cancelFunc() - time.Sleep(time.Second) - awsOidcRotator.UpdateCtx(timeOutCtx) - require.Error(t, awsOidcRotator.Rotate("us-east1", "test", "NEW-OIDC-TOKEN")) - - awsOidcRotator.UpdateCtx(t.Context()) - require.NoError(t, awsOidcRotator.Rotate("us-east1", "test", "NEW-OIDC-TOKEN")) + require.NoError(t, awsOidcRotator.Rotate(timeOutCtx, "us-east1", "test", "NEW-OIDC-TOKEN")) verifyAWSSecretCredentials(t, cl, "default", "test-secret", "NEWKEY", "NEWSECRET", "NEWTOKEN", "default") }) @@ -152,13 +147,12 @@ func TestAWS_OIDCRotator(t *testing.T) { }, } awsOidcRotator := AWSOIDCRotator{ - ctx: t.Context(), client: cl, stsOps: mockSTS, backendSecurityPolicyNamespace: "default", backendSecurityPolicyName: "test-secret", } - err := awsOidcRotator.Rotate("us-east1", "test", "NEW-OIDC-TOKEN") + err := awsOidcRotator.Rotate(t.Context(), "us-east1", "test", "NEW-OIDC-TOKEN") require.Error(t, err) assert.Contains(t, err.Error(), "failed to assume role") }) @@ -171,7 +165,6 @@ func TestAWS_GetPreRotationTime(t *testing.T) { ) cl := fake.NewClientBuilder().WithScheme(scheme).Build() awsOidcRotator := AWSOIDCRotator{ - ctx: t.Context(), client: cl, backendSecurityPolicyNamespace: "default", backendSecurityPolicyName: "test-secret", From 0030f322d7f7de0795dce4aab53f5e60e3e45d61 Mon Sep 17 00:00:00 2001 From: Dan Sun Date: Sat, 15 Feb 2025 14:35:03 -0500 Subject: [PATCH 60/86] Fix oidc config Signed-off-by: Dan Sun --- .../controller/backend_security_policy.go | 13 +----- .../backend_security_policy_test.go | 41 ------------------- .../client_credentials_token_provider.go | 17 ++++---- .../client_credentials_token_provider_test.go | 17 ++++++-- internal/controller/oauth/oidc_provider.go | 4 +- .../controller/oauth/oidc_provider_test.go | 10 ++--- internal/controller/rotators/aws_common.go | 4 +- 7 files changed, 28 insertions(+), 78 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 97d5acb5a..56640c6b1 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -10,7 +10,6 @@ import ( "fmt" "time" - egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/go-logr/logr" "golang.org/x/oauth2" "k8s.io/apimachinery/pkg/api/errors" @@ -65,7 +64,7 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr return ctrl.Result{}, err } - if oidc := getBackendSecurityPolicyAuthOIDC(backendSecurityPolicy.Spec); oidc != nil { + if backendSecurityPolicy.Spec.AWSCredentials != nil && backendSecurityPolicy.Spec.AWSCredentials.OIDCExchangeToken != nil { var requeue time.Duration requeue = time.Minute region := backendSecurityPolicy.Spec.AWSCredentials.Region @@ -78,7 +77,7 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr var validToken *oauth2.Token if tokenResponse, ok := b.oidcTokenCache[bspKey]; !ok || rotators.IsBufferedTimeExpired(preRotationWindow, tokenResponse.Expiry) { - oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(b.client), oidc) + oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(b.client), backendSecurityPolicy.Spec.AWSCredentials.OIDCExchangeToken.OIDC) // Valid Token will be nil if fetch token errors. timeOutCtx, cancelFunc := context.WithTimeout(ctx, outGoingTimeOut) @@ -123,11 +122,3 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr b.eventChan <- backendSecurityPolicy.DeepCopy() return } - -// getBackendSecurityPolicyAuthOIDC returns the backendSecurityPolicy's OIDC pointer or nil. -func getBackendSecurityPolicyAuthOIDC(spec aigv1a1.BackendSecurityPolicySpec) *egv1a1.OIDC { - if spec.AWSCredentials != nil && spec.AWSCredentials.OIDCExchangeToken != nil { - return &spec.AWSCredentials.OIDCExchangeToken.OIDC - } - return nil -} diff --git a/internal/controller/backend_security_policy_test.go b/internal/controller/backend_security_policy_test.go index 89b7972a3..6f35b58b5 100644 --- a/internal/controller/backend_security_policy_test.go +++ b/internal/controller/backend_security_policy_test.go @@ -157,44 +157,3 @@ func (m *mockSTSOperations) AssumeRoleWithWebIdentity(_ context.Context, _ *sts. }, }, nil } - -func TestBackendSecurityController_GetBackendSecurityPolicyAuthOIDC(t *testing.T) { - // APIKey is not OIDC. - require.Nil(t, getBackendSecurityPolicyAuthOIDC(aigv1a1.BackendSecurityPolicySpec{ - Type: aigv1a1.BackendSecurityPolicyTypeAPIKey, - APIKey: &aigv1a1.BackendSecurityPolicyAPIKey{}, - })) - - // AWSCredentials contains OIDC but this backendSecurityPolicy does not specify OIDC. - require.Nil(t, getBackendSecurityPolicyAuthOIDC(aigv1a1.BackendSecurityPolicySpec{ - Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, - AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ - Region: "us-east-1", - CredentialsFile: &aigv1a1.AWSCredentialsFile{}, - }, - })) - - oidc := egv1a1.OIDC{ - Provider: egv1a1.OIDCProvider{ - Issuer: "https://oidc.example.com", - }, - ClientID: "client-id", - ClientSecret: gwapiv1.SecretObjectReference{ - Name: "client-secret", - }, - } - - actualOIDC := getBackendSecurityPolicyAuthOIDC(aigv1a1.BackendSecurityPolicySpec{ - Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, - AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ - Region: "us-east-1", - OIDCExchangeToken: &aigv1a1.AWSOIDCExchangeToken{ - OIDC: oidc, - }, - }, - }) - require.NotNil(t, actualOIDC) - require.Equal(t, oidc.ClientID, actualOIDC.ClientID) - require.Equal(t, oidc.Provider.Issuer, actualOIDC.Provider.Issuer) - require.Equal(t, oidc.ClientSecret.Name, actualOIDC.ClientSecret.Name) -} diff --git a/internal/controller/oauth/client_credentials_token_provider.go b/internal/controller/oauth/client_credentials_token_provider.go index b19a1d26a..c3fd9d1dd 100644 --- a/internal/controller/oauth/client_credentials_token_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -32,9 +32,9 @@ func NewClientCredentialsProvider(cl client.Client) *ClientCredentialsTokenProvi // FetchToken gets the client secret from the secret reference and fetches the token from the provider token URL. // // This implements [TokenProvider.FetchToken]. -func (p *ClientCredentialsTokenProvider) FetchToken(ctx context.Context, oidc *egv1a1.OIDC) (*oauth2.Token, error) { - if oidc == nil || oidc.ClientSecret.Namespace == nil { - return nil, fmt.Errorf("oidc or oidc-client-secret is nil") +func (p *ClientCredentialsTokenProvider) FetchToken(ctx context.Context, oidc egv1a1.OIDC) (*oauth2.Token, error) { + if oidc.ClientSecret.Namespace == nil { + return nil, fmt.Errorf("oidc-client-secret namespace is nil") } clientSecret, err := getClientSecret(ctx, p.client, &corev1.SecretReference{ @@ -48,15 +48,12 @@ func (p *ClientCredentialsTokenProvider) FetchToken(ctx context.Context, oidc *e } // getTokenWithClientCredentialFlow fetches the oauth2 token with client credential config. -func (p *ClientCredentialsTokenProvider) getTokenWithClientCredentialConfig(ctx context.Context, oidc *egv1a1.OIDC, clientSecret string) (*oauth2.Token, error) { +func (p *ClientCredentialsTokenProvider) getTokenWithClientCredentialConfig(ctx context.Context, oidc egv1a1.OIDC, clientSecret string) (*oauth2.Token, error) { oauth2Config := clientcredentials.Config{ ClientSecret: clientSecret, - } - if oidc != nil { - oauth2Config.ClientID = oidc.ClientID - oauth2Config.Scopes = oidc.Scopes - // Discovery returns the OAuth2 endpoints. - oauth2Config.TokenURL = *oidc.Provider.TokenEndpoint + ClientID: oidc.ClientID, + Scopes: oidc.Scopes, + TokenURL: *oidc.Provider.TokenEndpoint, } token, err := oauth2Config.Token(ctx) if err != nil { diff --git a/internal/controller/oauth/client_credentials_token_provider_test.go b/internal/controller/oauth/client_credentials_token_provider_test.go index 9ff7cf3ac..dd98929d3 100644 --- a/internal/controller/oauth/client_credentials_token_provider_test.go +++ b/internal/controller/oauth/client_credentials_token_provider_test.go @@ -57,15 +57,24 @@ func TestClientCredentialsProvider_FetchToken(t *testing.T) { clientCredentialProvider := NewClientCredentialsProvider(cl) require.NotNil(t, clientCredentialProvider) - _, err = clientCredentialProvider.FetchToken(t.Context(), nil) + _, err = clientCredentialProvider.FetchToken(t.Context(), egv1a1.OIDC{ + Provider: egv1a1.OIDCProvider{ + Issuer: tokenServer.URL, + TokenEndpoint: &tokenServer.URL, + }, + ClientID: "some-client-id", + ClientSecret: gwapiv1.SecretObjectReference{ + Name: gwapiv1.ObjectName(secretName), + }, + }) require.Error(t, err) - require.Contains(t, err.Error(), "oidc or oidc-client-secret is nil") + require.Contains(t, err.Error(), "oidc-client-secret namespace is nil") namespaceRef := gwapiv1.Namespace(secretNamespace) timeOutCtx, cancelFunc := context.WithTimeout(t.Context(), time.Second) defer cancelFunc() time.Sleep(time.Second) - _, err = clientCredentialProvider.FetchToken(timeOutCtx, &egv1a1.OIDC{ + _, err = clientCredentialProvider.FetchToken(timeOutCtx, egv1a1.OIDC{ Provider: egv1a1.OIDCProvider{ Issuer: tokenServer.URL, TokenEndpoint: &tokenServer.URL, @@ -78,7 +87,7 @@ func TestClientCredentialsProvider_FetchToken(t *testing.T) { }) require.Error(t, err) - token, err := clientCredentialProvider.FetchToken(t.Context(), &egv1a1.OIDC{ + token, err := clientCredentialProvider.FetchToken(t.Context(), egv1a1.OIDC{ Provider: egv1a1.OIDCProvider{ Issuer: tokenServer.URL, TokenEndpoint: &tokenServer.URL, diff --git a/internal/controller/oauth/oidc_provider.go b/internal/controller/oauth/oidc_provider.go index 2ff9ae92f..4a8102c66 100644 --- a/internal/controller/oauth/oidc_provider.go +++ b/internal/controller/oauth/oidc_provider.go @@ -17,11 +17,11 @@ import ( // OIDCProvider extends ClientCredentialsTokenProvider with OIDC support. type OIDCProvider struct { tokenProvider *ClientCredentialsTokenProvider - oidcCredential *egv1a1.OIDC + oidcCredential egv1a1.OIDC } // NewOIDCProvider creates a new OIDC-aware provider. -func NewOIDCProvider(tokenProvider *ClientCredentialsTokenProvider, oidcCredentials *egv1a1.OIDC) *OIDCProvider { +func NewOIDCProvider(tokenProvider *ClientCredentialsTokenProvider, oidcCredentials egv1a1.OIDC) *OIDCProvider { return &OIDCProvider{ tokenProvider: tokenProvider, oidcCredential: oidcCredentials, diff --git a/internal/controller/oauth/oidc_provider_test.go b/internal/controller/oauth/oidc_provider_test.go index 9644be72d..9babc4db9 100644 --- a/internal/controller/oauth/oidc_provider_test.go +++ b/internal/controller/oauth/oidc_provider_test.go @@ -24,10 +24,6 @@ import ( gwapiv1 "sigs.k8s.io/gateway-api/apis/v1" ) -func TestNewOIDCProvider(t *testing.T) { - require.NotNil(t, NewOIDCProvider(nil, &egv1a1.OIDC{})) -} - func TestOIDCProvider_GetOIDCProviderConfigErrors(t *testing.T) { scheme := runtime.NewScheme() scheme.AddKnownTypes(corev1.SchemeGroupVersion, @@ -35,7 +31,7 @@ func TestOIDCProvider_GetOIDCProviderConfigErrors(t *testing.T) { ) cl := fake.NewClientBuilder().WithScheme(scheme).Build() - oidc := &egv1a1.OIDC{ + oidc := egv1a1.OIDC{ Provider: egv1a1.OIDCProvider{}, ClientID: "some-client-id", } @@ -117,7 +113,7 @@ func TestOIDCProvider_GetOIDCProviderConfig(t *testing.T) { ) cl := fake.NewClientBuilder().WithScheme(scheme).Build() - oidc := &egv1a1.OIDC{ + oidc := egv1a1.OIDC{ Provider: egv1a1.OIDCProvider{ Issuer: ts.URL, TokenEndpoint: &ts.URL, @@ -171,7 +167,7 @@ func TestOIDCProvider_FetchToken(t *testing.T) { }) require.NoError(t, err) namespaceRef := gwapiv1.Namespace(secretNamespace) - oidc := &egv1a1.OIDC{ + oidc := egv1a1.OIDC{ Provider: egv1a1.OIDCProvider{ Issuer: oidcServer.URL, TokenEndpoint: &tokenServer.URL, diff --git a/internal/controller/rotators/aws_common.go b/internal/controller/rotators/aws_common.go index 0d75b486d..a77cc2969 100644 --- a/internal/controller/rotators/aws_common.go +++ b/internal/controller/rotators/aws_common.go @@ -188,9 +188,7 @@ func formatAWSCredentialsFile(file *awsCredentialsFile) string { if creds.sessionToken != "" { builder.WriteString(fmt.Sprintf("aws_session_token = %s\n", creds.sessionToken)) } - if creds.region != "" { - builder.WriteString(fmt.Sprintf("region = %s\n", creds.region)) - } + builder.WriteString(fmt.Sprintf("region = %s\n", creds.region)) } return builder.String() } From 1912b513a159b5487f15399041e5398d437c0083 Mon Sep 17 00:00:00 2001 From: Dan Sun Date: Sat, 15 Feb 2025 14:48:16 -0500 Subject: [PATCH 61/86] Address comments Signed-off-by: Dan Sun --- internal/controller/backend_security_policy.go | 3 +-- .../controller/oauth/client_credentials_token_provider.go | 1 + internal/controller/oauth/oidc_provider.go | 5 ----- internal/controller/oauth/util.go | 4 ++-- 4 files changed, 4 insertions(+), 9 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 56640c6b1..bf85deba1 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -107,7 +107,7 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr token := validToken.AccessToken err = rotator.Rotate(timeOutCtx, awsCredentials.Region, awsCredentials.OIDCExchangeToken.AwsRoleArn, token) if err != nil { - b.logger.Error(err, "failed to rotate AWS OIDC exchange token") + b.logger.Error(err, "failed to rotate AWS OIDC exchange token, retry in one minute") requeue = time.Minute } else { requeue = time.Until(rotator.GetPreRotationTime()) @@ -115,7 +115,6 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr } } - // TODO: Investigate how to stop stale events from re-queuing. res = ctrl.Result{RequeueAfter: requeue} } // Send the backend security policy to the config sink so that it can modify the configuration together with the state of other resources. diff --git a/internal/controller/oauth/client_credentials_token_provider.go b/internal/controller/oauth/client_credentials_token_provider.go index c3fd9d1dd..d78073bae 100644 --- a/internal/controller/oauth/client_credentials_token_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -33,6 +33,7 @@ func NewClientCredentialsProvider(cl client.Client) *ClientCredentialsTokenProvi // // This implements [TokenProvider.FetchToken]. func (p *ClientCredentialsTokenProvider) FetchToken(ctx context.Context, oidc egv1a1.OIDC) (*oauth2.Token, error) { + // client secret namespace is optional on egv1a1.OIDC, but it is required for AI Gateway for now. if oidc.ClientSecret.Namespace == nil { return nil, fmt.Errorf("oidc-client-secret namespace is nil") } diff --git a/internal/controller/oauth/oidc_provider.go b/internal/controller/oauth/oidc_provider.go index 4a8102c66..101fd05cf 100644 --- a/internal/controller/oauth/oidc_provider.go +++ b/internal/controller/oauth/oidc_provider.go @@ -30,11 +30,6 @@ func NewOIDCProvider(tokenProvider *ClientCredentialsTokenProvider, oidcCredenti // getOIDCProviderConfig retrieves or creates OIDC config for the given issuer URL. func (p *OIDCProvider) getOIDCProviderConfig(ctx context.Context, issuerURL string) (*oidc.ProviderConfig, []string, error) { - // Check context before proceeding in case context is cancelled because of timeout. - if err := ctx.Err(); err != nil { - return nil, nil, fmt.Errorf("context error before discovery: %w", err) - } - provider, err := oidc.NewProvider(ctx, issuerURL) if err != nil { return nil, nil, fmt.Errorf("failed to create go-oidc provider %q: %w", issuerURL, err) diff --git a/internal/controller/oauth/util.go b/internal/controller/oauth/util.go index 8aee48e70..1273204b9 100644 --- a/internal/controller/oauth/util.go +++ b/internal/controller/oauth/util.go @@ -7,6 +7,7 @@ package oauth import ( "context" + "errors" "fmt" corev1 "k8s.io/api/core/v1" @@ -25,8 +26,7 @@ func getClientSecret(ctx context.Context, cl client.Client, secretRef *corev1.Se clientSecret, ok := secret.Data["client-secret"] if !ok { - return "", fmt.Errorf("client-secret key not found in secret") + return "", errors.New("client-secret key not found in secret") } - return string(clientSecret), nil } From 56de6f43573906daa3f6d508a01abbda055127c8 Mon Sep 17 00:00:00 2001 From: Dan Sun Date: Sat, 15 Feb 2025 16:04:06 -0500 Subject: [PATCH 62/86] Extract rotate function Signed-off-by: Dan Sun --- .../controller/backend_security_policy.go | 99 ++++++++++--------- .../backend_security_policy_test.go | 3 +- internal/controller/controller.go | 2 +- .../controller/oauth/oidc_provider_test.go | 9 -- .../controller/rotators/aws_oidc_rotator.go | 18 ++-- .../rotators/aws_oidc_rotator_test.go | 4 +- 6 files changed, 61 insertions(+), 74 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index bf85deba1..7ad1e628c 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -37,17 +37,17 @@ type backendSecurityPolicyController struct { kube kubernetes.Interface logger logr.Logger eventChan chan ConfigSinkEvent - StsOP rotators.STSClient + StsClient rotators.STSClient oidcTokenCache map[string]*oauth2.Token } -func newBackendSecurityPolicyController(client client.Client, kube kubernetes.Interface, logger logr.Logger, ch chan ConfigSinkEvent) *backendSecurityPolicyController { +func newBackendSecurityPolicyController(client client.Client, stsClient rotators.STSClient, kube kubernetes.Interface, logger logr.Logger, ch chan ConfigSinkEvent) *backendSecurityPolicyController { return &backendSecurityPolicyController{ client: client, kube: kube, logger: logger, eventChan: ch, - StsOP: nil, + StsClient: stsClient, oidcTokenCache: make(map[string]*oauth2.Token), } } @@ -55,7 +55,7 @@ func newBackendSecurityPolicyController(client client.Client, kube kubernetes.In // Reconcile implements the [reconcile.TypedReconciler] for [aigv1a1.BackendSecurityPolicy]. func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctrl.Request) (res ctrl.Result, err error) { var backendSecurityPolicy aigv1a1.BackendSecurityPolicy - if err = b.client.Get(ctx, req.NamespacedName, &backendSecurityPolicy); err != nil { + if err := b.client.Get(ctx, req.NamespacedName, &backendSecurityPolicy); err != nil { if errors.IsNotFound(err) { ctrl.Log.Info("Deleting Backend Security Policy", "namespace", req.Namespace, "name", req.Name) @@ -65,54 +65,21 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr } if backendSecurityPolicy.Spec.AWSCredentials != nil && backendSecurityPolicy.Spec.AWSCredentials.OIDCExchangeToken != nil { - var requeue time.Duration - requeue = time.Minute - region := backendSecurityPolicy.Spec.AWSCredentials.Region - - rotator, err := rotators.NewAWSOIDCRotator(ctx, b.client, b.kube, b.logger, backendSecurityPolicy.Namespace, backendSecurityPolicy.Name, preRotationWindow, region) + rotator, err := rotators.NewAWSOIDCRotator(ctx, b.client, b.StsClient, b.kube, b.logger, backendSecurityPolicy.Namespace, + backendSecurityPolicy.Name, preRotationWindow, backendSecurityPolicy.Spec.AWSCredentials.Region) if err != nil { b.logger.Error(err, "failed to create AWS OIDC rotator") - } else if rotator.IsExpired() { - bspKey := fmt.Sprintf("%s.%s", backendSecurityPolicy.Name, backendSecurityPolicy.Namespace) - - var validToken *oauth2.Token - if tokenResponse, ok := b.oidcTokenCache[bspKey]; !ok || rotators.IsBufferedTimeExpired(preRotationWindow, tokenResponse.Expiry) { - oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(b.client), backendSecurityPolicy.Spec.AWSCredentials.OIDCExchangeToken.OIDC) - // Valid Token will be nil if fetch token errors. - - timeOutCtx, cancelFunc := context.WithTimeout(ctx, outGoingTimeOut) - defer cancelFunc() - validToken, err = oidcProvider.FetchToken(timeOutCtx) - if err != nil { - b.logger.Error(err, "failed to fetch OIDC provider token") - } else { - b.oidcTokenCache[bspKey] = validToken - } + return ctrl.Result{}, err + } + var requeue time.Duration + requeue = time.Minute + if rotator.IsExpired() { + err := b.rotateCredential(ctx, rotator, backendSecurityPolicy) + if err != nil { + b.logger.Error(err, "failed to rotate OIDC exchange token, retry in one minute") + requeue = time.Minute } else { - validToken = tokenResponse - } - - if validToken != nil { - b.oidcTokenCache[bspKey] = validToken - awsCredentials := backendSecurityPolicy.Spec.AWSCredentials - - // This is to abstract the real STS behavior for testing purpose. - if b.StsOP != nil { - rotator.SetSTSOperations(b.StsOP) - } - - // Set a timeout for rotate. - timeOutCtx, cancelFunc2 := context.WithTimeout(ctx, outGoingTimeOut) - defer cancelFunc2() - token := validToken.AccessToken - err = rotator.Rotate(timeOutCtx, awsCredentials.Region, awsCredentials.OIDCExchangeToken.AwsRoleArn, token) - if err != nil { - b.logger.Error(err, "failed to rotate AWS OIDC exchange token, retry in one minute") - requeue = time.Minute - } else { - requeue = time.Until(rotator.GetPreRotationTime()) - } - + requeue = time.Until(rotator.GetPreRotationTime()) } } res = ctrl.Result{RequeueAfter: requeue} @@ -121,3 +88,37 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr b.eventChan <- backendSecurityPolicy.DeepCopy() return } + +func (b *backendSecurityPolicyController) rotateCredential(ctx context.Context, rotator *rotators.AWSOIDCRotator, policy aigv1a1.BackendSecurityPolicy) error { + bspKey := fmt.Sprintf("%s.%s", policy.Name, policy.Namespace) + var validToken *oauth2.Token + var err error + if tokenResponse, ok := b.oidcTokenCache[bspKey]; !ok || rotators.IsBufferedTimeExpired(preRotationWindow, tokenResponse.Expiry) { + oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(b.client), policy.Spec.AWSCredentials.OIDCExchangeToken.OIDC) + // Valid Token will be nil if fetch token errors. + + timeOutCtx, cancelFunc := context.WithTimeout(ctx, outGoingTimeOut) + defer cancelFunc() + validToken, err = oidcProvider.FetchToken(timeOutCtx) + if err != nil { + b.logger.Error(err, "failed to fetch OIDC provider token") + return err + } else { + b.oidcTokenCache[bspKey] = validToken + } + } else { + validToken = tokenResponse + } + + if validToken != nil { + b.oidcTokenCache[bspKey] = validToken + awsCredentials := policy.Spec.AWSCredentials + + // Set a timeout for rotate. + timeOutCtx, cancelRotateFunc := context.WithTimeout(ctx, outGoingTimeOut) + defer cancelRotateFunc() + token := validToken.AccessToken + return rotator.Rotate(timeOutCtx, awsCredentials.Region, awsCredentials.OIDCExchangeToken.AwsRoleArn, token) + } + return nil +} diff --git a/internal/controller/backend_security_policy_test.go b/internal/controller/backend_security_policy_test.go index 6f35b58b5..9a464dcb1 100644 --- a/internal/controller/backend_security_policy_test.go +++ b/internal/controller/backend_security_policy_test.go @@ -38,7 +38,7 @@ import ( func TestBackendSecurityController_Reconcile(t *testing.T) { ch := make(chan ConfigSinkEvent, 100) cl := fake.NewClientBuilder().WithScheme(scheme).Build() - c := newBackendSecurityPolicyController(cl, fake2.NewClientset(), ctrl.Log, ch) + c := newBackendSecurityPolicyController(cl, &mockSTSOperations{}, fake2.NewClientset(), ctrl.Log, ch) backendSecurityPolicyName := "mybackendSecurityPolicy" namespace := "default" @@ -119,7 +119,6 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { require.Equal(t, backendSecurityPolicyName, item.(*aigv1a1.BackendSecurityPolicy).Name) require.Equal(t, namespace, item.(*aigv1a1.BackendSecurityPolicy).Namespace) - c.StsOP = &mockSTSOperations{} ctx := oidcv3.InsecureIssuerURLContext(t.Context(), discoveryServer.URL) res, err = c.Reconcile(ctx, reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName)}}) require.NoError(t, err) diff --git a/internal/controller/controller.go b/internal/controller/controller.go index fb3756ad7..acdb36e0b 100644 --- a/internal/controller/controller.go +++ b/internal/controller/controller.go @@ -93,7 +93,7 @@ func StartControllers(ctx context.Context, config *rest.Config, logger logr.Logg return fmt.Errorf("failed to create controller for AIServiceBackend: %w", err) } - backendSecurityPolicyC := newBackendSecurityPolicyController(c, kubernetes.NewForConfigOrDie(config), logger. + backendSecurityPolicyC := newBackendSecurityPolicyController(c, nil, kubernetes.NewForConfigOrDie(config), logger. WithName("backend-security-policy"), sinkChan) if err = ctrl.NewControllerManagedBy(mgr). For(&aigv1a1.BackendSecurityPolicy{}). diff --git a/internal/controller/oauth/oidc_provider_test.go b/internal/controller/oauth/oidc_provider_test.go index 9babc4db9..39bb69075 100644 --- a/internal/controller/oauth/oidc_provider_test.go +++ b/internal/controller/oauth/oidc_provider_test.go @@ -50,8 +50,6 @@ func TestOIDCProvider_GetOIDCProviderConfigErrors(t *testing.T) { defer missingTokenURLTestServer.Close() oidcProvider := NewOIDCProvider(NewClientCredentialsProvider(cl), oidc) - cancelledContext, cancel := context.WithCancel(t.Context()) - cancel() for _, testcase := range []struct { name string @@ -60,13 +58,6 @@ func TestOIDCProvider_GetOIDCProviderConfigErrors(t *testing.T) { ctx context.Context contains string }{ - { - name: "context error", - provider: oidcProvider, - ctx: cancelledContext, - url: "", - contains: "context error before discovery", - }, { name: "failed to create go oidc", provider: oidcProvider, diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index d79ffd72c..26419af3d 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -34,7 +34,7 @@ type AWSOIDCRotator struct { // logger is used for structured logging. logger logr.Logger // stsOps provides AWS STS operations interface. - stsOps STSClient + stsClient STSClient // backendSecurityPolicyName provides name of backend security policy. backendSecurityPolicyName string // backendSecurityPolicyNamespace provides namespace of backend security policy. @@ -48,6 +48,7 @@ type AWSOIDCRotator struct { func NewAWSOIDCRotator( ctx context.Context, client client.Client, + stsClient STSClient, kube kubernetes.Interface, logger logr.Logger, backendSecurityPolicyNamespace string, @@ -71,25 +72,20 @@ func NewAWSOIDCRotator( }, } } - - stsClient := NewSTSClient(cfg) - + if stsClient == nil { + stsClient = NewSTSClient(cfg) + } return &AWSOIDCRotator{ client: client, kube: kube, logger: logger.WithName("aws-oidc-rotator"), - stsOps: stsClient, + stsClient: stsClient, backendSecurityPolicyNamespace: backendSecurityPolicyNamespace, backendSecurityPolicyName: backendSecurityPolicyName, preRotationWindow: preRotationWindow, }, nil } -// SetSTSOperations sets the STS operations implementation - primarily used for testing. -func (r *AWSOIDCRotator) SetSTSOperations(ops STSClient) { - r.stsOps = ops -} - // IsExpired checks if the preRotation time is before the current time. func (r *AWSOIDCRotator) IsExpired() bool { preRotationExpirationTime := r.GetPreRotationTime() @@ -170,7 +166,7 @@ func (r *AWSOIDCRotator) Rotate(ctx context.Context, region, roleARN, token stri // assumeRoleWithToken exchanges an OIDC token for AWS credentials. func (r *AWSOIDCRotator) assumeRoleWithToken(ctx context.Context, roleARN, token string) (*sts.AssumeRoleWithWebIdentityOutput, error) { - return r.stsOps.AssumeRoleWithWebIdentity(ctx, &sts.AssumeRoleWithWebIdentityInput{ + return r.stsClient.AssumeRoleWithWebIdentity(ctx, &sts.AssumeRoleWithWebIdentityInput{ RoleArn: aws.String(roleARN), WebIdentityToken: aws.String(token), RoleSessionName: aws.String(fmt.Sprintf(awsSessionNameFormat, r.backendSecurityPolicyName)), diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go index f0e592dcc..fe20808a0 100644 --- a/internal/controller/rotators/aws_oidc_rotator_test.go +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -122,7 +122,7 @@ func TestAWS_OIDCRotator(t *testing.T) { awsOidcRotator := AWSOIDCRotator{ client: cl, - stsOps: mockSTS, + stsClient: mockSTS, backendSecurityPolicyNamespace: "default", backendSecurityPolicyName: "test-secret", } @@ -148,7 +148,7 @@ func TestAWS_OIDCRotator(t *testing.T) { } awsOidcRotator := AWSOIDCRotator{ client: cl, - stsOps: mockSTS, + stsClient: mockSTS, backendSecurityPolicyNamespace: "default", backendSecurityPolicyName: "test-secret", } From 64d6b87702932c807487fe44aac67eec8b83b897 Mon Sep 17 00:00:00 2001 From: Dan Sun Date: Sat, 15 Feb 2025 16:09:10 -0500 Subject: [PATCH 63/86] Fix lint Signed-off-by: Dan Sun --- internal/controller/backend_security_policy.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 7ad1e628c..3709a1bcd 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -103,9 +103,8 @@ func (b *backendSecurityPolicyController) rotateCredential(ctx context.Context, if err != nil { b.logger.Error(err, "failed to fetch OIDC provider token") return err - } else { - b.oidcTokenCache[bspKey] = validToken } + b.oidcTokenCache[bspKey] = validToken } else { validToken = tokenResponse } From 3bdc3db7bed37b37bd367be8d3af0e2056eb7195 Mon Sep 17 00:00:00 2001 From: Dan Sun Date: Sat, 15 Feb 2025 16:13:44 -0500 Subject: [PATCH 64/86] Fix profile Signed-off-by: Dan Sun --- internal/controller/rotators/aws_common.go | 5 +---- internal/controller/rotators/aws_oidc_rotator.go | 6 +++--- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/internal/controller/rotators/aws_common.go b/internal/controller/rotators/aws_common.go index a77cc2969..2d5868bd0 100644 --- a/internal/controller/rotators/aws_common.go +++ b/internal/controller/rotators/aws_common.go @@ -177,10 +177,7 @@ func formatAWSCredentialsFile(file *awsCredentialsFile) string { } sort.Strings(profileNames) - for i, profileName := range profileNames { - if i > 0 { - builder.WriteString("\n") - } + for _, profileName := range profileNames { creds := file.profiles[profileName] builder.WriteString(fmt.Sprintf("[%s]\n", profileName)) builder.WriteString(fmt.Sprintf("aws_access_key_id = %s\n", creds.accessKeyID)) diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index 26419af3d..b5e220ae8 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -138,11 +138,11 @@ func (r *AWSOIDCRotator) Rotate(ctx context.Context, region, roleARN, token stri updateExpirationSecretAnnotation(secret, *result.Credentials.Expiration) // For now have profile as default. - profile := "default" + const defaultProfile = "default" credsFile := awsCredentialsFile{ profiles: map[string]*awsCredentials{ - profile: { - profile: profile, + defaultProfile: { + profile: defaultProfile, accessKeyID: aws.ToString(result.Credentials.AccessKeyId), secretAccessKey: aws.ToString(result.Credentials.SecretAccessKey), sessionToken: aws.ToString(result.Credentials.SessionToken), From 380d5be47faebb1d0cd7a8063134d4ed27447639 Mon Sep 17 00:00:00 2001 From: Dan Sun Date: Sat, 15 Feb 2025 20:17:02 -0500 Subject: [PATCH 65/86] Fix aws rotator and credential files Signed-off-by: Dan Sun --- .../controller/backend_security_policy.go | 18 ++++++--- internal/controller/rotators/aws_common.go | 29 +++++---------- .../controller/rotators/aws_common_test.go | 9 ++--- .../controller/rotators/aws_oidc_rotator.go | 33 +++++++++-------- .../rotators/aws_oidc_rotator_test.go | 37 +++++++++++-------- 5 files changed, 66 insertions(+), 60 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 3709a1bcd..32d6c990d 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -66,20 +66,28 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr if backendSecurityPolicy.Spec.AWSCredentials != nil && backendSecurityPolicy.Spec.AWSCredentials.OIDCExchangeToken != nil { rotator, err := rotators.NewAWSOIDCRotator(ctx, b.client, b.StsClient, b.kube, b.logger, backendSecurityPolicy.Namespace, - backendSecurityPolicy.Name, preRotationWindow, backendSecurityPolicy.Spec.AWSCredentials.Region) + backendSecurityPolicy.Name, preRotationWindow, backendSecurityPolicy.Spec.AWSCredentials.Region, backendSecurityPolicy.Spec.AWSCredentials.OIDCExchangeToken.AwsRoleArn) if err != nil { b.logger.Error(err, "failed to create AWS OIDC rotator") return ctrl.Result{}, err } var requeue time.Duration requeue = time.Minute - if rotator.IsExpired() { + preRotationExpirationTime, err := rotator.GetPreRotationTime() + if err != nil { + return ctrl.Result{}, err + } + if rotator.IsExpired(preRotationExpirationTime) { err := b.rotateCredential(ctx, rotator, backendSecurityPolicy) if err != nil { b.logger.Error(err, "failed to rotate OIDC exchange token, retry in one minute") requeue = time.Minute } else { - requeue = time.Until(rotator.GetPreRotationTime()) + preRotationExpirationTime, err = rotator.GetPreRotationTime() + if err != nil { + return ctrl.Result{}, err + } + requeue = time.Until(preRotationExpirationTime) } } res = ctrl.Result{RequeueAfter: requeue} @@ -111,13 +119,11 @@ func (b *backendSecurityPolicyController) rotateCredential(ctx context.Context, if validToken != nil { b.oidcTokenCache[bspKey] = validToken - awsCredentials := policy.Spec.AWSCredentials - // Set a timeout for rotate. timeOutCtx, cancelRotateFunc := context.WithTimeout(ctx, outGoingTimeOut) defer cancelRotateFunc() token := validToken.AccessToken - return rotator.Rotate(timeOutCtx, awsCredentials.Region, awsCredentials.OIDCExchangeToken.AwsRoleArn, token) + return rotator.Rotate(timeOutCtx, token) } return nil } diff --git a/internal/controller/rotators/aws_common.go b/internal/controller/rotators/aws_common.go index 2d5868bd0..e4494301f 100644 --- a/internal/controller/rotators/aws_common.go +++ b/internal/controller/rotators/aws_common.go @@ -24,7 +24,6 @@ package rotators import ( "context" "fmt" - "sort" "strings" "github.com/aws/aws-sdk-go-v2/aws" @@ -101,7 +100,7 @@ type awsCredentials struct { // multiple sets of AWS credentials. type awsCredentialsFile struct { // profiles maps profile names to their respective credentials. - profiles map[string]*awsCredentials + profiles []*awsCredentials } // parseAWSCredentialsFile parses an AWS credentials file with multiple profiles. @@ -116,7 +115,7 @@ type awsCredentialsFile struct { // Returns a structured representation of the credentials file. func parseAWSCredentialsFile(data string) *awsCredentialsFile { file := &awsCredentialsFile{ - profiles: make(map[string]*awsCredentials), + profiles: make([]*awsCredentials, 0), } var currentCreds *awsCredentials @@ -130,7 +129,7 @@ func parseAWSCredentialsFile(data string) *awsCredentialsFile { if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { profileName := strings.TrimPrefix(strings.TrimSuffix(line, "]"), "[") currentCreds = &awsCredentials{profile: profileName} - file.profiles[profileName] = currentCreds + file.profiles = append(file.profiles, currentCreds) continue } @@ -170,22 +169,14 @@ func parseAWSCredentialsFile(data string) *awsCredentialsFile { func formatAWSCredentialsFile(file *awsCredentialsFile) string { var builder strings.Builder - // Sort profiles to ensure consistent output. - profileNames := make([]string, 0, len(file.profiles)) - for profileName := range file.profiles { - profileNames = append(profileNames, profileName) - } - sort.Strings(profileNames) - - for _, profileName := range profileNames { - creds := file.profiles[profileName] - builder.WriteString(fmt.Sprintf("[%s]\n", profileName)) - builder.WriteString(fmt.Sprintf("aws_access_key_id = %s\n", creds.accessKeyID)) - builder.WriteString(fmt.Sprintf("aws_secret_access_key = %s\n", creds.secretAccessKey)) - if creds.sessionToken != "" { - builder.WriteString(fmt.Sprintf("aws_session_token = %s\n", creds.sessionToken)) + for _, profile := range file.profiles { + builder.WriteString(fmt.Sprintf("[%s]\n", profile.profile)) + builder.WriteString(fmt.Sprintf("aws_access_key_id = %s\n", profile.accessKeyID)) + builder.WriteString(fmt.Sprintf("aws_secret_access_key = %s\n", profile.secretAccessKey)) + if profile.sessionToken != "" { + builder.WriteString(fmt.Sprintf("aws_session_token = %s\n", profile.sessionToken)) } - builder.WriteString(fmt.Sprintf("region = %s\n", creds.region)) + builder.WriteString(fmt.Sprintf("region = %s\n", profile.region)) } return builder.String() } diff --git a/internal/controller/rotators/aws_common_test.go b/internal/controller/rotators/aws_common_test.go index d775a74d4..bcfc2e5cc 100644 --- a/internal/controller/rotators/aws_common_test.go +++ b/internal/controller/rotators/aws_common_test.go @@ -28,8 +28,7 @@ func TestParseAWSCredentialsFile(t *testing.T) { awsCred := parseAWSCredentialsFile(fmt.Sprintf("[%s]\naws_access_key_id=%s\naws_secret_access_key=%s\naws_session_token=%s\nregion=%s", profile, accessKey, secretKey, sessionToken, region)) require.NotNil(t, awsCred) - defaultProfile, ok := awsCred.profiles[profile] - require.True(t, ok) + defaultProfile := awsCred.profiles[0] require.NotNil(t, defaultProfile) require.Equal(t, accessKey, defaultProfile.accessKeyID) require.Equal(t, secretKey, defaultProfile.secretAccessKey) @@ -38,7 +37,7 @@ func TestParseAWSCredentialsFile(t *testing.T) { } func TestFormatAWSCredentialsFile(t *testing.T) { - emptyCredentialsFile := awsCredentialsFile{map[string]*awsCredentials{}} + emptyCredentialsFile := awsCredentialsFile{[]*awsCredentials{}} require.Empty(t, formatAWSCredentialsFile(&emptyCredentialsFile)) profile := "default" @@ -57,7 +56,7 @@ func TestFormatAWSCredentialsFile(t *testing.T) { awsCred := fmt.Sprintf("[%s]\naws_access_key_id = %s\naws_secret_access_key = %s\naws_session_token = %s\nregion = %s\n", profile, accessKey, secretKey, sessionToken, region) - require.Equal(t, awsCred, formatAWSCredentialsFile(&awsCredentialsFile{profiles: map[string]*awsCredentials{"default": &credentials}})) + require.Equal(t, awsCred, formatAWSCredentialsFile(&awsCredentialsFile{profiles: []*awsCredentials{&credentials}})) } func TestUpdateAWSCredentialsInSecret(t *testing.T) { @@ -71,7 +70,7 @@ func TestUpdateAWSCredentialsInSecret(t *testing.T) { region: "region", } - updateAWSCredentialsInSecret(secret, &awsCredentialsFile{profiles: map[string]*awsCredentials{"default": &credentials}}) + updateAWSCredentialsInSecret(secret, &awsCredentialsFile{profiles: []*awsCredentials{&credentials}}) require.Len(t, secret.Data, 1) val, ok := secret.Data[awsCredentialsKey] diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index b5e220ae8..b796d3020 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -39,6 +39,10 @@ type AWSOIDCRotator struct { backendSecurityPolicyName string // backendSecurityPolicyNamespace provides namespace of backend security policy. backendSecurityPolicyNamespace string + // aws region + region string + // aws IAM role ARN + roleARN string // preRotationWindow specifies how long before expiry to rotate. preRotationWindow time.Duration } @@ -55,6 +59,7 @@ func NewAWSOIDCRotator( backendSecurityPolicyName string, preRotationWindow time.Duration, region string, + roleARN string, ) (*AWSOIDCRotator, error) { cfg, err := defaultAWSConfig(ctx) if err != nil { @@ -83,41 +88,39 @@ func NewAWSOIDCRotator( backendSecurityPolicyNamespace: backendSecurityPolicyNamespace, backendSecurityPolicyName: backendSecurityPolicyName, preRotationWindow: preRotationWindow, + roleARN: roleARN, + region: region, }, nil } // IsExpired checks if the preRotation time is before the current time. -func (r *AWSOIDCRotator) IsExpired() bool { - preRotationExpirationTime := r.GetPreRotationTime() +func (r *AWSOIDCRotator) IsExpired(preRotationExpirationTime time.Time) bool { return IsBufferedTimeExpired(0, preRotationExpirationTime) } // GetPreRotationTime gets the expiration time minus the preRotation interval or return zero value for time. -func (r *AWSOIDCRotator) GetPreRotationTime() time.Time { +func (r *AWSOIDCRotator) GetPreRotationTime() (time.Time, error) { secret, err := LookupSecret(context.Background(), r.client, r.backendSecurityPolicyNamespace, GetBSPSecretName(r.backendSecurityPolicyName)) if err != nil { - if !errors.IsNotFound(err) { - return time.Time{} - } - return time.Time{} + return time.Time{}, err } expirationTime, err := GetExpirationSecretAnnotation(secret) if err != nil { - return time.Time{} + return time.Time{}, err } preRotationTime := expirationTime.Add(-r.preRotationWindow) - return preRotationTime + return preRotationTime, nil } // Rotate implements the retrieval and storage of AWS sts credentials. -func (r *AWSOIDCRotator) Rotate(ctx context.Context, region, roleARN, token string) error { +func (r *AWSOIDCRotator) Rotate(ctx context.Context, token string) error { r.logger.Info("rotating AWS sts temporary credentials", "namespace", r.backendSecurityPolicyNamespace, "name", r.backendSecurityPolicyName) - result, err := r.assumeRoleWithToken(ctx, roleARN, token) + result, err := r.assumeRoleWithToken(ctx, r.roleARN, token) if err != nil { - r.logger.Error(err, "failed to assume role", "role", roleARN, "access token", token) + r.logger.Error(err, "failed to assume role", "role", r.roleARN, "access token", token) return err } secret, err := LookupSecret(ctx, r.client, r.backendSecurityPolicyNamespace, GetBSPSecretName(r.backendSecurityPolicyName)) @@ -140,13 +143,13 @@ func (r *AWSOIDCRotator) Rotate(ctx context.Context, region, roleARN, token stri // For now have profile as default. const defaultProfile = "default" credsFile := awsCredentialsFile{ - profiles: map[string]*awsCredentials{ - defaultProfile: { + profiles: []*awsCredentials{ + { profile: defaultProfile, accessKeyID: aws.ToString(result.Credentials.AccessKeyId), secretAccessKey: aws.ToString(result.Credentials.SecretAccessKey), sessionToken: aws.ToString(result.Credentials.SessionToken), - region: region, + region: r.region, }, }, } diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go index fe20808a0..fd22becaf 100644 --- a/internal/controller/rotators/aws_oidc_rotator_test.go +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -55,10 +55,10 @@ func verifyAWSSecretCredentials(t *testing.T, client client.Client, namespace, s require.NoError(t, err) creds := parseAWSCredentialsFile(string(secret.Data[awsCredentialsKey])) require.NotNil(t, creds) - require.Contains(t, creds.profiles, profile) - assert.Equal(t, expectedKeyID, creds.profiles[profile].accessKeyID) - assert.Equal(t, expectedSecret, creds.profiles[profile].secretAccessKey) - assert.Equal(t, expectedToken, creds.profiles[profile].sessionToken) + assert.Equal(t, profile, creds.profiles[0].profile) + assert.Equal(t, expectedKeyID, creds.profiles[0].accessKeyID) + assert.Equal(t, expectedSecret, creds.profiles[0].secretAccessKey) + assert.Equal(t, expectedToken, creds.profiles[0].sessionToken) } // createClientSecret creates the OIDC client secret @@ -125,11 +125,13 @@ func TestAWS_OIDCRotator(t *testing.T) { stsClient: mockSTS, backendSecurityPolicyNamespace: "default", backendSecurityPolicyName: "test-secret", + region: "us-east-1", + roleARN: "test-role", } timeOutCtx, cancelFunc := context.WithTimeout(t.Context(), time.Second) defer cancelFunc() - require.NoError(t, awsOidcRotator.Rotate(timeOutCtx, "us-east1", "test", "NEW-OIDC-TOKEN")) + require.NoError(t, awsOidcRotator.Rotate(timeOutCtx, "NEW-OIDC-TOKEN")) verifyAWSSecretCredentials(t, cl, "default", "test-secret", "NEWKEY", "NEWSECRET", "NEWTOKEN", "default") }) @@ -151,8 +153,10 @@ func TestAWS_OIDCRotator(t *testing.T) { stsClient: mockSTS, backendSecurityPolicyNamespace: "default", backendSecurityPolicyName: "test-secret", + region: "us-east-1", + roleARN: "test-role", } - err := awsOidcRotator.Rotate(t.Context(), "us-east1", "test", "NEW-OIDC-TOKEN") + err := awsOidcRotator.Rotate(t.Context(), "NEW-OIDC-TOKEN") require.Error(t, err) assert.Contains(t, err.Error(), "failed to assume role") }) @@ -169,11 +173,11 @@ func TestAWS_GetPreRotationTime(t *testing.T) { backendSecurityPolicyNamespace: "default", backendSecurityPolicyName: "test-secret", } - - require.Equal(t, 0, awsOidcRotator.GetPreRotationTime().Minute()) + preRotateTime, _ := awsOidcRotator.GetPreRotationTime() + require.Equal(t, 0, preRotateTime.Minute()) createTestAWSSecret(t, cl, "test-secret", "OLDKEY", "OLDSECRET", "OLDTOKEN", "default") - require.Equal(t, 0, awsOidcRotator.GetPreRotationTime().Minute()) + require.Equal(t, 0, preRotateTime.Minute()) secret, err := LookupSecret(t.Context(), cl, "default", GetBSPSecretName("test-secret")) require.NoError(t, err) @@ -181,7 +185,8 @@ func TestAWS_GetPreRotationTime(t *testing.T) { expiredTime := time.Now().Add(-1 * time.Hour) updateExpirationSecretAnnotation(secret, expiredTime) require.NoError(t, cl.Update(t.Context(), secret)) - require.Equal(t, expiredTime.Format(time.RFC3339), awsOidcRotator.GetPreRotationTime().Format(time.RFC3339)) + preRotateTime, _ = awsOidcRotator.GetPreRotationTime() + require.Equal(t, expiredTime.Format(time.RFC3339), preRotateTime.Format(time.RFC3339)) } func TestAWS_IsExpired(t *testing.T) { @@ -195,11 +200,11 @@ func TestAWS_IsExpired(t *testing.T) { backendSecurityPolicyNamespace: "default", backendSecurityPolicyName: "test-secret", } - - require.True(t, awsOidcRotator.IsExpired()) + preRotateTime, _ := awsOidcRotator.GetPreRotationTime() + require.True(t, awsOidcRotator.IsExpired(preRotateTime)) createTestAWSSecret(t, cl, "test-secret", "OLDKEY", "OLDSECRET", "OLDTOKEN", "default") - require.Equal(t, 0, awsOidcRotator.GetPreRotationTime().Minute()) + require.Equal(t, 0, preRotateTime.Minute()) secret, err := LookupSecret(t.Context(), cl, "default", GetBSPSecretName("test-secret")) require.NoError(t, err) @@ -207,10 +212,12 @@ func TestAWS_IsExpired(t *testing.T) { expiredTime := time.Now().Add(-1 * time.Hour) updateExpirationSecretAnnotation(secret, expiredTime) require.NoError(t, cl.Update(t.Context(), secret)) - require.True(t, awsOidcRotator.IsExpired()) + preRotateTime, _ = awsOidcRotator.GetPreRotationTime() + require.True(t, awsOidcRotator.IsExpired(preRotateTime)) hourFromNowTime := time.Now().Add(1 * time.Hour) updateExpirationSecretAnnotation(secret, hourFromNowTime) require.NoError(t, cl.Update(t.Context(), secret)) - require.False(t, awsOidcRotator.IsExpired()) + preRotateTime, _ = awsOidcRotator.GetPreRotationTime() + require.False(t, awsOidcRotator.IsExpired(preRotateTime)) } From aa2f2754b771b898cf481db0b02fd913207ae73d Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Fri, 14 Feb 2025 11:17:02 -0500 Subject: [PATCH 66/86] cel validation Signed-off-by: Aaron Choo --- api/v1alpha1/zz_generated.deepcopy.go | 5 ----- .../crds/aigateway.envoyproxy.io_aigatewayroutes.yaml | 5 ----- .../crds/aigateway.envoyproxy.io_aiservicebackends.yaml | 5 ----- .../aigateway.envoyproxy.io_backendsecuritypolicies.yaml | 5 ----- 4 files changed, 20 deletions(-) diff --git a/api/v1alpha1/zz_generated.deepcopy.go b/api/v1alpha1/zz_generated.deepcopy.go index 5d6c85884..22857dfef 100644 --- a/api/v1alpha1/zz_generated.deepcopy.go +++ b/api/v1alpha1/zz_generated.deepcopy.go @@ -1,8 +1,3 @@ -// Copyright Envoy AI Gateway Authors -// SPDX-License-Identifier: Apache-2.0 -// The full text of the Apache license is available in the LICENSE file at -// the root of the repo. - //go:build !ignore_autogenerated // Code generated by controller-gen. DO NOT EDIT. diff --git a/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_aigatewayroutes.yaml b/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_aigatewayroutes.yaml index bb40b3318..91e32922d 100644 --- a/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_aigatewayroutes.yaml +++ b/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_aigatewayroutes.yaml @@ -1,8 +1,3 @@ -# Copyright Envoy AI Gateway Authors -# SPDX-License-Identifier: Apache-2.0 -# The full text of the Apache license is available in the LICENSE file at -# the root of the repo. - --- apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition diff --git a/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_aiservicebackends.yaml b/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_aiservicebackends.yaml index 98b865245..7029e4806 100644 --- a/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_aiservicebackends.yaml +++ b/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_aiservicebackends.yaml @@ -1,8 +1,3 @@ -# Copyright Envoy AI Gateway Authors -# SPDX-License-Identifier: Apache-2.0 -# The full text of the Apache license is available in the LICENSE file at -# the root of the repo. - --- apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition diff --git a/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_backendsecuritypolicies.yaml b/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_backendsecuritypolicies.yaml index 451b15e74..6ef366f3d 100644 --- a/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_backendsecuritypolicies.yaml +++ b/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_backendsecuritypolicies.yaml @@ -1,8 +1,3 @@ -# Copyright Envoy AI Gateway Authors -# SPDX-License-Identifier: Apache-2.0 -# The full text of the Apache license is available in the LICENSE file at -# the root of the repo. - --- apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition From a55c5b1bc3f9bd2cb0976a77e1d376fddc579d33 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Fri, 14 Feb 2025 16:08:55 -0500 Subject: [PATCH 67/86] revamp newClientCredentialsProvider + update API requirements Signed-off-by: Aaron Choo --- api/v1alpha1/api.go | 2 ++ api/v1alpha1/zz_generated.deepcopy.go | 5 +++ .../controller/backend_security_policy.go | 2 +- .../client_credentials_token_provider.go | 32 ++++++++++------- .../client_credentials_token_provider_test.go | 35 ++++--------------- internal/controller/oauth/oidc_provider.go | 6 ++-- .../controller/oauth/oidc_provider_test.go | 8 ++--- internal/controller/oauth/types.go | 3 +- internal/controller/sink.go | 6 +++- ...gateway.envoyproxy.io_aigatewayroutes.yaml | 5 +++ ...teway.envoyproxy.io_aiservicebackends.yaml | 5 +++ ...envoyproxy.io_backendsecuritypolicies.yaml | 5 +++ .../backendsecuritypolicies/aws_oidc.yaml | 5 +-- 13 files changed, 62 insertions(+), 57 deletions(-) diff --git a/api/v1alpha1/api.go b/api/v1alpha1/api.go index 2cedc783b..f5e98d28c 100644 --- a/api/v1alpha1/api.go +++ b/api/v1alpha1/api.go @@ -475,6 +475,8 @@ type AWSCredentialsFile struct { // and store them in a temporary credentials file. type AWSOIDCExchangeToken struct { // OIDC is used to obtain oidc tokens via an SSO server which will be used to exchange for temporary AWS credentials. + // + // +kubebuilder:validation:Required OIDC egv1a1.OIDC `json:"oidc"` // GrantType is the method application gets access token. diff --git a/api/v1alpha1/zz_generated.deepcopy.go b/api/v1alpha1/zz_generated.deepcopy.go index 22857dfef..5d6c85884 100644 --- a/api/v1alpha1/zz_generated.deepcopy.go +++ b/api/v1alpha1/zz_generated.deepcopy.go @@ -1,3 +1,8 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + //go:build !ignore_autogenerated // Code generated by controller-gen. DO NOT EDIT. diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 32d6c990d..593715ab4 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -102,7 +102,7 @@ func (b *backendSecurityPolicyController) rotateCredential(ctx context.Context, var validToken *oauth2.Token var err error if tokenResponse, ok := b.oidcTokenCache[bspKey]; !ok || rotators.IsBufferedTimeExpired(preRotationWindow, tokenResponse.Expiry) { - oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(b.client), policy.Spec.AWSCredentials.OIDCExchangeToken.OIDC) + oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(b.client, policy.Spec.AWSCredentials.OIDCExchangeToken.OIDC), policy.Spec.AWSCredentials.OIDCExchangeToken.OIDC) // Valid Token will be nil if fetch token errors. timeOutCtx, cancelFunc := context.WithTimeout(ctx, outGoingTimeOut) diff --git a/internal/controller/oauth/client_credentials_token_provider.go b/internal/controller/oauth/client_credentials_token_provider.go index d78073bae..1bf5d8aaf 100644 --- a/internal/controller/oauth/client_credentials_token_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -8,6 +8,7 @@ package oauth import ( "context" "fmt" + "net/http" "time" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" @@ -17,45 +18,52 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" ) +// tokenTimeoutDuration specifies duration of token retrieval query. +const tokenTimeoutDuration = time.Minute + // ClientCredentialsTokenProvider implements the standard OAuth2 client credentials flow. type ClientCredentialsTokenProvider struct { - client client.Client + client client.Client + oidcCredential egv1a1.OIDC } // NewClientCredentialsProvider creates a new client credentials provider. -func NewClientCredentialsProvider(cl client.Client) *ClientCredentialsTokenProvider { +func NewClientCredentialsProvider(cl client.Client, oidcCredential egv1a1.OIDC) *ClientCredentialsTokenProvider { return &ClientCredentialsTokenProvider{ - client: cl, + client: cl, + oidcCredential: oidcCredential, } } // FetchToken gets the client secret from the secret reference and fetches the token from the provider token URL. // // This implements [TokenProvider.FetchToken]. -func (p *ClientCredentialsTokenProvider) FetchToken(ctx context.Context, oidc egv1a1.OIDC) (*oauth2.Token, error) { +func (p *ClientCredentialsTokenProvider) FetchToken(ctx context.Context) (*oauth2.Token, error) { // client secret namespace is optional on egv1a1.OIDC, but it is required for AI Gateway for now. - if oidc.ClientSecret.Namespace == nil { + if p.oidcCredential.ClientSecret.Namespace == nil { return nil, fmt.Errorf("oidc-client-secret namespace is nil") } clientSecret, err := getClientSecret(ctx, p.client, &corev1.SecretReference{ - Name: string(oidc.ClientSecret.Name), - Namespace: string(*oidc.ClientSecret.Namespace), + Name: string(p.oidcCredential.ClientSecret.Name), + Namespace: string(*p.oidcCredential.ClientSecret.Namespace), }) if err != nil { return nil, err } - return p.getTokenWithClientCredentialConfig(ctx, oidc, clientSecret) + return p.getTokenWithClientCredentialConfig(ctx, clientSecret) } // getTokenWithClientCredentialFlow fetches the oauth2 token with client credential config. -func (p *ClientCredentialsTokenProvider) getTokenWithClientCredentialConfig(ctx context.Context, oidc egv1a1.OIDC, clientSecret string) (*oauth2.Token, error) { +func (p *ClientCredentialsTokenProvider) getTokenWithClientCredentialConfig(ctx context.Context, clientSecret string) (*oauth2.Token, error) { oauth2Config := clientcredentials.Config{ ClientSecret: clientSecret, - ClientID: oidc.ClientID, - Scopes: oidc.Scopes, - TokenURL: *oidc.Provider.TokenEndpoint, + ClientID: p.oidcCredential.ClientID, + Scopes: p.oidcCredential.Scopes, + TokenURL: *p.oidcCredential.Provider.TokenEndpoint, } + // Underlying token call will apply http client timeout. + ctx = context.WithValue(ctx, oauth2.HTTPClient, &http.Client{Timeout: tokenTimeoutDuration}) token, err := oauth2Config.Token(ctx) if err != nil { return nil, fmt.Errorf("fail to get oauth2 token %w", err) diff --git a/internal/controller/oauth/client_credentials_token_provider_test.go b/internal/controller/oauth/client_credentials_token_provider_test.go index dd98929d3..c0121321f 100644 --- a/internal/controller/oauth/client_credentials_token_provider_test.go +++ b/internal/controller/oauth/client_credentials_token_provider_test.go @@ -6,7 +6,6 @@ package oauth import ( - "context" "encoding/json" "net/http" "net/http/httptest" @@ -54,27 +53,15 @@ func TestClientCredentialsProvider_FetchToken(t *testing.T) { }) require.NoError(t, err) - clientCredentialProvider := NewClientCredentialsProvider(cl) + namespaceRef := gwapiv1.Namespace(secretNamespace) + clientCredentialProvider := NewClientCredentialsProvider(cl, egv1a1.OIDC{}) require.NotNil(t, clientCredentialProvider) - _, err = clientCredentialProvider.FetchToken(t.Context(), egv1a1.OIDC{ - Provider: egv1a1.OIDCProvider{ - Issuer: tokenServer.URL, - TokenEndpoint: &tokenServer.URL, - }, - ClientID: "some-client-id", - ClientSecret: gwapiv1.SecretObjectReference{ - Name: gwapiv1.ObjectName(secretName), - }, - }) + _, err = clientCredentialProvider.FetchToken(t.Context()) require.Error(t, err) require.Contains(t, err.Error(), "oidc-client-secret namespace is nil") - namespaceRef := gwapiv1.Namespace(secretNamespace) - timeOutCtx, cancelFunc := context.WithTimeout(t.Context(), time.Second) - defer cancelFunc() - time.Sleep(time.Second) - _, err = clientCredentialProvider.FetchToken(timeOutCtx, egv1a1.OIDC{ + clientCredentialProvider = NewClientCredentialsProvider(cl, egv1a1.OIDC{ Provider: egv1a1.OIDCProvider{ Issuer: tokenServer.URL, TokenEndpoint: &tokenServer.URL, @@ -85,19 +72,9 @@ func TestClientCredentialsProvider_FetchToken(t *testing.T) { Namespace: &namespaceRef, }, }) - require.Error(t, err) + require.NotNil(t, clientCredentialProvider) + token, err := clientCredentialProvider.FetchToken(t.Context()) - token, err := clientCredentialProvider.FetchToken(t.Context(), egv1a1.OIDC{ - Provider: egv1a1.OIDCProvider{ - Issuer: tokenServer.URL, - TokenEndpoint: &tokenServer.URL, - }, - ClientID: "some-client-id", - ClientSecret: gwapiv1.SecretObjectReference{ - Name: gwapiv1.ObjectName(secretName), - Namespace: &namespaceRef, - }, - }) require.NoError(t, err) require.Equal(t, "token", token.AccessToken) require.WithinRangef(t, token.Expiry, time.Now().Add(3590*time.Second), time.Now().Add(3600*time.Second), "token expires at") diff --git a/internal/controller/oauth/oidc_provider.go b/internal/controller/oauth/oidc_provider.go index 101fd05cf..a5768a98f 100644 --- a/internal/controller/oauth/oidc_provider.go +++ b/internal/controller/oauth/oidc_provider.go @@ -16,12 +16,12 @@ import ( // OIDCProvider extends ClientCredentialsTokenProvider with OIDC support. type OIDCProvider struct { - tokenProvider *ClientCredentialsTokenProvider + tokenProvider TokenProvider oidcCredential egv1a1.OIDC } // NewOIDCProvider creates a new OIDC-aware provider. -func NewOIDCProvider(tokenProvider *ClientCredentialsTokenProvider, oidcCredentials egv1a1.OIDC) *OIDCProvider { +func NewOIDCProvider(tokenProvider TokenProvider, oidcCredentials egv1a1.OIDC) *OIDCProvider { return &OIDCProvider{ tokenProvider: tokenProvider, oidcCredential: oidcCredentials, @@ -92,7 +92,7 @@ func (p *OIDCProvider) FetchToken(ctx context.Context) (*oauth2.Token, error) { } // Get base token response. - token, err := p.tokenProvider.FetchToken(ctx, p.oidcCredential) + token, err := p.tokenProvider.FetchToken(ctx) if err != nil { return nil, fmt.Errorf("failed to get token: %w", err) } diff --git a/internal/controller/oauth/oidc_provider_test.go b/internal/controller/oauth/oidc_provider_test.go index 39bb69075..a22ce1b0a 100644 --- a/internal/controller/oauth/oidc_provider_test.go +++ b/internal/controller/oauth/oidc_provider_test.go @@ -49,7 +49,7 @@ func TestOIDCProvider_GetOIDCProviderConfigErrors(t *testing.T) { })) defer missingTokenURLTestServer.Close() - oidcProvider := NewOIDCProvider(NewClientCredentialsProvider(cl), oidc) + oidcProvider := NewOIDCProvider(NewClientCredentialsProvider(cl, oidc), oidc) for _, testcase := range []struct { name string @@ -114,7 +114,7 @@ func TestOIDCProvider_GetOIDCProviderConfig(t *testing.T) { } ctx := oidcv3.InsecureIssuerURLContext(t.Context(), ts.URL) - oidcProvider := NewOIDCProvider(NewClientCredentialsProvider(cl), oidc) + oidcProvider := NewOIDCProvider(NewClientCredentialsProvider(cl, oidc), oidc) config, supportedScope, err := oidcProvider.getOIDCProviderConfig(ctx, ts.URL) require.NoError(t, err) require.Equal(t, "token_endpoint", config.TokenURL) @@ -170,10 +170,8 @@ func TestOIDCProvider_FetchToken(t *testing.T) { }, Scopes: []string{"two", "openid"}, } - clientCredentialProvider := NewClientCredentialsProvider(cl) - require.NotNil(t, clientCredentialProvider) ctx := oidcv3.InsecureIssuerURLContext(t.Context(), oidcServer.URL) - oidcProvider := NewOIDCProvider(clientCredentialProvider, oidc) + oidcProvider := NewOIDCProvider(NewClientCredentialsProvider(cl, oidc), oidc) require.Len(t, oidcProvider.oidcCredential.Scopes, 2) token, err := oidcProvider.FetchToken(ctx) diff --git a/internal/controller/oauth/types.go b/internal/controller/oauth/types.go index 2de87b35b..8aaad49d5 100644 --- a/internal/controller/oauth/types.go +++ b/internal/controller/oauth/types.go @@ -8,12 +8,11 @@ package oauth import ( "context" - egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "golang.org/x/oauth2" ) // TokenProvider defines the interface for OAuth token providers. type TokenProvider interface { // FetchToken will obtain oauth token using oidc credentials. - FetchToken(ctx context.Context, oidc *egv1a1.OIDC) (*oauth2.Token, error) + FetchToken(ctx context.Context) (*oauth2.Token, error) } diff --git a/internal/controller/sink.go b/internal/controller/sink.go index 2e8b9bd59..4ad87ce89 100644 --- a/internal/controller/sink.go +++ b/internal/controller/sink.go @@ -246,7 +246,7 @@ func (c *configSink) syncAIServiceBackend(ctx context.Context, aiBackend *aigv1a } func (c *configSink) syncBackendSecurityPolicy(ctx context.Context, bsp *aigv1a1.BackendSecurityPolicy) { - key := fmt.Sprintf("%s.%s", bsp.Name, bsp.Namespace) + key := backendSecurityPolicyKey(bsp.Namespace, bsp.Name) var aiServiceBackends aigv1a1.AIServiceBackendList err := c.client.List(ctx, &aiServiceBackends, client.MatchingFields{k8sClientIndexBackendSecurityPolicyToReferencingAIServiceBackend: key}) if err != nil { @@ -659,3 +659,7 @@ func backendSecurityPolicyVolumeName(ruleIndex, backendRefIndex int, name string func backendSecurityMountPath(backendSecurityPolicyKey string) string { return fmt.Sprintf("%s/%s", mountedExtProcSecretPath, backendSecurityPolicyKey) } + +func backendSecurityPolicyKey(namespace, name string) string { + return fmt.Sprintf("%s.%s", name, namespace) +} diff --git a/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_aigatewayroutes.yaml b/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_aigatewayroutes.yaml index 91e32922d..bb40b3318 100644 --- a/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_aigatewayroutes.yaml +++ b/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_aigatewayroutes.yaml @@ -1,3 +1,8 @@ +# Copyright Envoy AI Gateway Authors +# SPDX-License-Identifier: Apache-2.0 +# The full text of the Apache license is available in the LICENSE file at +# the root of the repo. + --- apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition diff --git a/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_aiservicebackends.yaml b/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_aiservicebackends.yaml index 7029e4806..98b865245 100644 --- a/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_aiservicebackends.yaml +++ b/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_aiservicebackends.yaml @@ -1,3 +1,8 @@ +# Copyright Envoy AI Gateway Authors +# SPDX-License-Identifier: Apache-2.0 +# The full text of the Apache license is available in the LICENSE file at +# the root of the repo. + --- apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition diff --git a/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_backendsecuritypolicies.yaml b/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_backendsecuritypolicies.yaml index 6ef366f3d..451b15e74 100644 --- a/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_backendsecuritypolicies.yaml +++ b/manifests/charts/ai-gateway-helm/crds/aigateway.envoyproxy.io_backendsecuritypolicies.yaml @@ -1,3 +1,8 @@ +# Copyright Envoy AI Gateway Authors +# SPDX-License-Identifier: Apache-2.0 +# The full text of the Apache license is available in the LICENSE file at +# the root of the repo. + --- apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition diff --git a/tests/crdcel/testdata/backendsecuritypolicies/aws_oidc.yaml b/tests/crdcel/testdata/backendsecuritypolicies/aws_oidc.yaml index d684bb978..0b3a0e7d9 100644 --- a/tests/crdcel/testdata/backendsecuritypolicies/aws_oidc.yaml +++ b/tests/crdcel/testdata/backendsecuritypolicies/aws_oidc.yaml @@ -20,7 +20,4 @@ spec: clientID: placeholder clientSecret: name: placeholder - credentialsFile: - secretRef: - name: placeholder - profile: default + awsRoleARN: placeholder From 6720aee4b5c54c325c734c04059fce7233a519b7 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Fri, 14 Feb 2025 16:22:06 -0500 Subject: [PATCH 68/86] getClientSeret more descriptive error Signed-off-by: Aaron Choo --- internal/controller/oauth/util.go | 6 +++--- internal/controller/oauth/util_test.go | 12 ++++++++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/internal/controller/oauth/util.go b/internal/controller/oauth/util.go index 1273204b9..96b3db995 100644 --- a/internal/controller/oauth/util.go +++ b/internal/controller/oauth/util.go @@ -7,7 +7,6 @@ package oauth import ( "context" - "errors" "fmt" corev1 "k8s.io/api/core/v1" @@ -24,9 +23,10 @@ func getClientSecret(ctx context.Context, cl client.Client, secretRef *corev1.Se return "", fmt.Errorf("failed to get client secret: %w", err) } - clientSecret, ok := secret.Data["client-secret"] + secretDataKey := "client-secret" + clientSecret, ok := secret.Data[secretDataKey] if !ok { - return "", errors.New("client-secret key not found in secret") + return "", fmt.Errorf("failed to get client secret: no secret data found using key '%s' in secret name '%s' and namespace '%s", secretDataKey, secretRef.Name, secretRef.Namespace) } return string(clientSecret), nil } diff --git a/internal/controller/oauth/util_test.go b/internal/controller/oauth/util_test.go index c7f109c02..76ee75e19 100644 --- a/internal/controller/oauth/util_test.go +++ b/internal/controller/oauth/util_test.go @@ -23,7 +23,15 @@ func TestGetClientSecret(t *testing.T) { cl := fake.NewClientBuilder().WithScheme(scheme).Build() secretName, secretNamespace := "secret", "secret-ns" - err := cl.Create(t.Context(), &corev1.Secret{ + + secret, err := getClientSecret(t.Context(), cl, &corev1.SecretReference{ + Name: secretName, + Namespace: secretNamespace, + }) + require.Error(t, err) + require.Empty(t, secret) + + err = cl.Create(t.Context(), &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: secretName, Namespace: secretNamespace, @@ -37,7 +45,7 @@ func TestGetClientSecret(t *testing.T) { }) require.NoError(t, err) - secret, err := getClientSecret(t.Context(), cl, &corev1.SecretReference{ + secret, err = getClientSecret(t.Context(), cl, &corev1.SecretReference{ Name: secretName, Namespace: secretNamespace, }) From df1272656e884c07ba3b4445322ed14a3570d571 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Fri, 14 Feb 2025 17:09:12 -0500 Subject: [PATCH 69/86] support only one credential profile per file Signed-off-by: Aaron Choo --- internal/controller/rotators/aws_common.go | 94 ++++--------------- .../controller/rotators/aws_common_test.go | 24 +---- .../controller/rotators/aws_oidc_rotator.go | 18 ++-- .../rotators/aws_oidc_rotator_test.go | 18 ++-- 4 files changed, 31 insertions(+), 123 deletions(-) diff --git a/internal/controller/rotators/aws_common.go b/internal/controller/rotators/aws_common.go index e4494301f..23dbf12cb 100644 --- a/internal/controller/rotators/aws_common.go +++ b/internal/controller/rotators/aws_common.go @@ -8,11 +8,10 @@ Package rotators provides credential rotation implementations. This file contains common AWS functionality shared between different AWS credential rotators. It provides: 1. AWS Client Interfaces and Implementations: -- STSOperations for AWS STS API operations +- STSClient for AWS STS API operations - Concrete implementations with proper AWS SDK integration 2. Credential File Management: -- Parsing and formatting of AWS credentials files -- Support for multiple credential profiles +- Parsing and formatting of AWS credentials file - Handling of temporary credentials and session tokens 3. Common Configuration: - Default AWS configuration with adaptive retry @@ -95,89 +94,28 @@ type awsCredentials struct { region string } -// awsCredentialsFile represents a complete AWS credentials file containing -// multiple credential profiles. It provides a structured way to manage -// multiple sets of AWS credentials. +// awsCredentialsFile represents a complete AWS credentials file containing a credential profile. type awsCredentialsFile struct { - // profiles maps profile names to their respective credentials. - profiles []*awsCredentials + // creds stores the aws credentials. + creds awsCredentials } -// parseAWSCredentialsFile parses an AWS credentials file with multiple profiles. -// The file format follows the standard AWS credentials file format: -// -// [profile-name] -// aws_access_key_id = AKIAXXXXXXXXXXXXXXXX -// aws_secret_access_key = xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx -// aws_session_token = xxxxxxxx (optional) -// region = xx-xxxx-x (optional) -// -// Returns a structured representation of the credentials file. -func parseAWSCredentialsFile(data string) *awsCredentialsFile { - file := &awsCredentialsFile{ - profiles: make([]*awsCredentials, 0), - } - - var currentCreds *awsCredentials - - for line := range strings.Lines(data) { - line = strings.TrimSpace(line) - if line == "" { - continue - } - - if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { - profileName := strings.TrimPrefix(strings.TrimSuffix(line, "]"), "[") - currentCreds = &awsCredentials{profile: profileName} - file.profiles = append(file.profiles, currentCreds) - continue - } - - if currentCreds == nil { - continue - } - - parts := strings.SplitN(line, "=", 2) - if len(parts) != 2 { - continue - } - - key := strings.TrimSpace(parts[0]) - value := strings.TrimSpace(parts[1]) - - switch key { - case "aws_access_key_id": - currentCreds.accessKeyID = value - case "aws_secret_access_key": - currentCreds.secretAccessKey = value - case "aws_session_token": - currentCreds.sessionToken = value - case "region": - currentCreds.region = value - } - } - - return file -} - -// formatAWSCredentialsFile formats multiple AWS credential profiles into a credentials file. +// formatAWSCredentialsFile formats an AWS credential profile into a credentials file. // The output follows the standard AWS credentials file format and ensures: -// - Consistent ordering of profiles through sorting // - Proper formatting of all credential components -// - Optional inclusion of session tokens and regions -// - Profile isolation with proper section markers +// - Optional inclusion of session token and region func formatAWSCredentialsFile(file *awsCredentialsFile) string { var builder strings.Builder - - for _, profile := range file.profiles { - builder.WriteString(fmt.Sprintf("[%s]\n", profile.profile)) - builder.WriteString(fmt.Sprintf("aws_access_key_id = %s\n", profile.accessKeyID)) - builder.WriteString(fmt.Sprintf("aws_secret_access_key = %s\n", profile.secretAccessKey)) - if profile.sessionToken != "" { - builder.WriteString(fmt.Sprintf("aws_session_token = %s\n", profile.sessionToken)) - } - builder.WriteString(fmt.Sprintf("region = %s\n", profile.region)) + builder.WriteString(fmt.Sprintf("[%s]\n", file.creds.profile)) + builder.WriteString(fmt.Sprintf("aws_access_key_id = %s\n", file.creds.accessKeyID)) + builder.WriteString(fmt.Sprintf("aws_secret_access_key = %s\n", file.creds.secretAccessKey)) + if file.creds.sessionToken != "" { + builder.WriteString(fmt.Sprintf("aws_session_token = %s\n", file.creds.sessionToken)) } + if file.creds.region != "" { + builder.WriteString(fmt.Sprintf("region = %s\n", file.creds.region)) + } + return builder.String() } diff --git a/internal/controller/rotators/aws_common_test.go b/internal/controller/rotators/aws_common_test.go index bcfc2e5cc..7843f0e4e 100644 --- a/internal/controller/rotators/aws_common_test.go +++ b/internal/controller/rotators/aws_common_test.go @@ -19,27 +19,7 @@ func TestNewSTSClient(t *testing.T) { require.NotNil(t, stsClient) } -func TestParseAWSCredentialsFile(t *testing.T) { - profile := "default" - accessKey := "AKIAXXXXXXXXXXXXXXXX" - secretKey := "XXXXXXXXXXXXXXXXXXXX" - sessionToken := "XXXXXXXXXXXXXXXXXXXX" - region := "us-west-2" - awsCred := parseAWSCredentialsFile(fmt.Sprintf("[%s]\naws_access_key_id=%s\naws_secret_access_key=%s\naws_session_token=%s\nregion=%s", profile, accessKey, - secretKey, sessionToken, region)) - require.NotNil(t, awsCred) - defaultProfile := awsCred.profiles[0] - require.NotNil(t, defaultProfile) - require.Equal(t, accessKey, defaultProfile.accessKeyID) - require.Equal(t, secretKey, defaultProfile.secretAccessKey) - require.Equal(t, sessionToken, defaultProfile.sessionToken) - require.Equal(t, region, defaultProfile.region) -} - func TestFormatAWSCredentialsFile(t *testing.T) { - emptyCredentialsFile := awsCredentialsFile{[]*awsCredentials{}} - require.Empty(t, formatAWSCredentialsFile(&emptyCredentialsFile)) - profile := "default" accessKey := "AKIAXXXXXXXXXXXXXXXX" secretKey := "XXXXXXXXXXXXXXXXXXXX" @@ -56,7 +36,7 @@ func TestFormatAWSCredentialsFile(t *testing.T) { awsCred := fmt.Sprintf("[%s]\naws_access_key_id = %s\naws_secret_access_key = %s\naws_session_token = %s\nregion = %s\n", profile, accessKey, secretKey, sessionToken, region) - require.Equal(t, awsCred, formatAWSCredentialsFile(&awsCredentialsFile{profiles: []*awsCredentials{&credentials}})) + require.Equal(t, awsCred, formatAWSCredentialsFile(&awsCredentialsFile{credentials})) } func TestUpdateAWSCredentialsInSecret(t *testing.T) { @@ -70,7 +50,7 @@ func TestUpdateAWSCredentialsInSecret(t *testing.T) { region: "region", } - updateAWSCredentialsInSecret(secret, &awsCredentialsFile{profiles: []*awsCredentials{&credentials}}) + updateAWSCredentialsInSecret(secret, &awsCredentialsFile{credentials}) require.Len(t, secret.Data, 1) val, ok := secret.Data[awsCredentialsKey] diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index b796d3020..d05d8a8cf 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -142,17 +142,13 @@ func (r *AWSOIDCRotator) Rotate(ctx context.Context, token string) error { // For now have profile as default. const defaultProfile = "default" - credsFile := awsCredentialsFile{ - profiles: []*awsCredentials{ - { - profile: defaultProfile, - accessKeyID: aws.ToString(result.Credentials.AccessKeyId), - secretAccessKey: aws.ToString(result.Credentials.SecretAccessKey), - sessionToken: aws.ToString(result.Credentials.SessionToken), - region: r.region, - }, - }, - } + credsFile := awsCredentialsFile{awsCredentials{ + profile: defaultProfile, + accessKeyID: aws.ToString(result.Credentials.AccessKeyId), + secretAccessKey: aws.ToString(result.Credentials.SecretAccessKey), + sessionToken: aws.ToString(result.Credentials.SessionToken), + region: r.region, + }} updateAWSCredentialsInSecret(secret, &credsFile) diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go index fd22becaf..fbc689517 100644 --- a/internal/controller/rotators/aws_oidc_rotator_test.go +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -47,18 +47,14 @@ func createTestAWSSecret(t *testing.T, client client.Client, bspName string, acc } // verifyAWSSecretCredentials verifies the credentials in a secret -func verifyAWSSecretCredentials(t *testing.T, client client.Client, namespace, secretName, expectedKeyID, expectedSecret, expectedToken string, profile string) { +func verifyAWSSecretCredentials(t *testing.T, client client.Client, namespace, secretName, expectedKeyID, expectedSecret, expectedToken, profile, region string) { if profile == "" { profile = "default" } secret, err := LookupSecret(t.Context(), client, namespace, GetBSPSecretName(secretName)) require.NoError(t, err) - creds := parseAWSCredentialsFile(string(secret.Data[awsCredentialsKey])) - require.NotNil(t, creds) - assert.Equal(t, profile, creds.profiles[0].profile) - assert.Equal(t, expectedKeyID, creds.profiles[0].accessKeyID) - assert.Equal(t, expectedSecret, creds.profiles[0].secretAccessKey) - assert.Equal(t, expectedToken, creds.profiles[0].sessionToken) + expectedSecretData := fmt.Sprintf("[%s]\naws_access_key_id = %s\naws_secret_access_key = %s\naws_session_token = %s\nregion = %s\n", profile, expectedKeyID, expectedSecret, expectedToken, region) + require.Equal(t, expectedSecretData, string(secret.Data[awsCredentialsKey])) } // createClientSecret creates the OIDC client secret @@ -125,14 +121,12 @@ func TestAWS_OIDCRotator(t *testing.T) { stsClient: mockSTS, backendSecurityPolicyNamespace: "default", backendSecurityPolicyName: "test-secret", - region: "us-east-1", + region: "us-east1", roleARN: "test-role", } - timeOutCtx, cancelFunc := context.WithTimeout(t.Context(), time.Second) - defer cancelFunc() - require.NoError(t, awsOidcRotator.Rotate(timeOutCtx, "NEW-OIDC-TOKEN")) - verifyAWSSecretCredentials(t, cl, "default", "test-secret", "NEWKEY", "NEWSECRET", "NEWTOKEN", "default") + require.NoError(t, awsOidcRotator.Rotate(t.Context(), "NEW-OIDC-TOKEN")) + verifyAWSSecretCredentials(t, cl, "default", "test-secret", "NEWKEY", "NEWSECRET", "NEWTOKEN", "default", "us-east1") }) t.Run("error handling - STS assume role failure", func(t *testing.T) { From db334d4fe8acc422c191ff22948344600faf634d Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Fri, 14 Feb 2025 17:12:24 -0500 Subject: [PATCH 70/86] region always exists Signed-off-by: Aaron Choo --- internal/controller/rotators/aws_common.go | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/internal/controller/rotators/aws_common.go b/internal/controller/rotators/aws_common.go index 23dbf12cb..0d571f240 100644 --- a/internal/controller/rotators/aws_common.go +++ b/internal/controller/rotators/aws_common.go @@ -78,8 +78,8 @@ func (c *stsClient) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.A return c.client.AssumeRoleWithWebIdentity(ctx, params, optFns...) } -// awsCredentials represents a single set of AWS credentials, including optional -// session token and region configuration. It maps to a single profile in an +// awsCredentials represents an AWS credential including optional +// session token configuration. It maps to a single profile in an // AWS credentials file. type awsCredentials struct { // profile is the name of the credentials profile. @@ -103,7 +103,7 @@ type awsCredentialsFile struct { // formatAWSCredentialsFile formats an AWS credential profile into a credentials file. // The output follows the standard AWS credentials file format and ensures: // - Proper formatting of all credential components -// - Optional inclusion of session token and region +// - Optional inclusion of session token func formatAWSCredentialsFile(file *awsCredentialsFile) string { var builder strings.Builder builder.WriteString(fmt.Sprintf("[%s]\n", file.creds.profile)) @@ -112,9 +112,7 @@ func formatAWSCredentialsFile(file *awsCredentialsFile) string { if file.creds.sessionToken != "" { builder.WriteString(fmt.Sprintf("aws_session_token = %s\n", file.creds.sessionToken)) } - if file.creds.region != "" { - builder.WriteString(fmt.Sprintf("region = %s\n", file.creds.region)) - } + builder.WriteString(fmt.Sprintf("region = %s\n", file.creds.region)) return builder.String() } From 98f60016edef4247aac1d618ba1bc9dbb3583941 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Sat, 15 Feb 2025 15:49:19 -0500 Subject: [PATCH 71/86] use backendSecurityPolicyKey func to reduce replication Signed-off-by: Aaron Choo --- internal/controller/backend_security_policy.go | 5 +++++ internal/controller/controller.go | 2 +- internal/controller/sink.go | 6 +----- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 593715ab4..13ea9f4c9 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -127,3 +127,8 @@ func (b *backendSecurityPolicyController) rotateCredential(ctx context.Context, } return nil } + +// backendSecurityPolicyKey returns the key used for indexing and caching the backendSecurityPolicy +func backendSecurityPolicyKey(namespace, name string) string { + return fmt.Sprintf("%s.%s", name, namespace) +} diff --git a/internal/controller/controller.go b/internal/controller/controller.go index acdb36e0b..8808b11bf 100644 --- a/internal/controller/controller.go +++ b/internal/controller/controller.go @@ -187,7 +187,7 @@ func backendSecurityPolicyIndexFunc(o client.Object) []string { if awsCreds.CredentialsFile != nil { key = getSecretNameAndNamespace(awsCreds.CredentialsFile.SecretRef, backendSecurityPolicy.Namespace) } else if awsCreds.OIDCExchangeToken != nil { - key = fmt.Sprintf("%s.%s", backendSecurityPolicy.Name, backendSecurityPolicy.Namespace) + key = backendSecurityPolicyKey(backendSecurityPolicy.Namespace, backendSecurityPolicy.Name) } } return []string{key} diff --git a/internal/controller/sink.go b/internal/controller/sink.go index 4ad87ce89..c583d39c6 100644 --- a/internal/controller/sink.go +++ b/internal/controller/sink.go @@ -638,7 +638,7 @@ func (c *configSink) syncSecret(ctx context.Context, namespace, name string) { var backendSecurityPolicies aigv1a1.BackendSecurityPolicyList err := c.client.List(ctx, &backendSecurityPolicies, client.MatchingFields{ - k8sClientIndexSecretToReferencingBackendSecurityPolicy: fmt.Sprintf("%s.%s", name, namespace), + k8sClientIndexSecretToReferencingBackendSecurityPolicy: backendSecurityPolicyKey(namespace, name), }, ) if err != nil { @@ -659,7 +659,3 @@ func backendSecurityPolicyVolumeName(ruleIndex, backendRefIndex int, name string func backendSecurityMountPath(backendSecurityPolicyKey string) string { return fmt.Sprintf("%s/%s", mountedExtProcSecretPath, backendSecurityPolicyKey) } - -func backendSecurityPolicyKey(namespace, name string) string { - return fmt.Sprintf("%s.%s", name, namespace) -} From 35f37031025bd39066c9a98980a4d4b998a90c70 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Sat, 15 Feb 2025 15:51:17 -0500 Subject: [PATCH 72/86] stsOp -> stsClient Signed-off-by: Aaron Choo --- internal/controller/rotators/aws_oidc_rotator.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index d05d8a8cf..fb61e5c9e 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -33,7 +33,7 @@ type AWSOIDCRotator struct { kube kubernetes.Interface // logger is used for structured logging. logger logr.Logger - // stsOps provides AWS STS operations interface. + // stsClient provides AWS STS operations interface. stsClient STSClient // backendSecurityPolicyName provides name of backend security policy. backendSecurityPolicyName string @@ -118,7 +118,7 @@ func (r *AWSOIDCRotator) Rotate(ctx context.Context, token string) error { "namespace", r.backendSecurityPolicyNamespace, "name", r.backendSecurityPolicyName) - result, err := r.assumeRoleWithToken(ctx, r.roleARN, token) + result, err := r.assumeRoleWithToken(ctx, token) if err != nil { r.logger.Error(err, "failed to assume role", "role", r.roleARN, "access token", token) return err @@ -164,9 +164,9 @@ func (r *AWSOIDCRotator) Rotate(ctx context.Context, token string) error { } // assumeRoleWithToken exchanges an OIDC token for AWS credentials. -func (r *AWSOIDCRotator) assumeRoleWithToken(ctx context.Context, roleARN, token string) (*sts.AssumeRoleWithWebIdentityOutput, error) { +func (r *AWSOIDCRotator) assumeRoleWithToken(ctx context.Context, token string) (*sts.AssumeRoleWithWebIdentityOutput, error) { return r.stsClient.AssumeRoleWithWebIdentity(ctx, &sts.AssumeRoleWithWebIdentityInput{ - RoleArn: aws.String(roleARN), + RoleArn: aws.String(r.roleARN), WebIdentityToken: aws.String(token), RoleSessionName: aws.String(fmt.Sprintf(awsSessionNameFormat, r.backendSecurityPolicyName)), }) From 0a2dcd0d4b4c377c444184a39bbfe95c17341de7 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Sat, 15 Feb 2025 23:36:35 -0500 Subject: [PATCH 73/86] stop storing ctx Signed-off-by: Aaron Choo --- internal/controller/backend_security_policy.go | 6 +++--- internal/controller/rotators/aws_oidc_rotator.go | 5 +++-- internal/controller/rotators/aws_oidc_rotator_test.go | 11 ++++++----- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 13ea9f4c9..1d4fe446a 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -73,7 +73,7 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr } var requeue time.Duration requeue = time.Minute - preRotationExpirationTime, err := rotator.GetPreRotationTime() + preRotationExpirationTime, err := rotator.GetPreRotationTime(ctx) if err != nil { return ctrl.Result{}, err } @@ -83,7 +83,7 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr b.logger.Error(err, "failed to rotate OIDC exchange token, retry in one minute") requeue = time.Minute } else { - preRotationExpirationTime, err = rotator.GetPreRotationTime() + preRotationExpirationTime, err = rotator.GetPreRotationTime(ctx) if err != nil { return ctrl.Result{}, err } @@ -98,7 +98,7 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr } func (b *backendSecurityPolicyController) rotateCredential(ctx context.Context, rotator *rotators.AWSOIDCRotator, policy aigv1a1.BackendSecurityPolicy) error { - bspKey := fmt.Sprintf("%s.%s", policy.Name, policy.Namespace) + bspKey := backendSecurityPolicyKey(policy.Namespace, policy.Name) var validToken *oauth2.Token var err error if tokenResponse, ok := b.oidcTokenCache[bspKey]; !ok || rotators.IsBufferedTimeExpired(preRotationWindow, tokenResponse.Expiry) { diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index fb61e5c9e..344009621 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -99,8 +99,8 @@ func (r *AWSOIDCRotator) IsExpired(preRotationExpirationTime time.Time) bool { } // GetPreRotationTime gets the expiration time minus the preRotation interval or return zero value for time. -func (r *AWSOIDCRotator) GetPreRotationTime() (time.Time, error) { - secret, err := LookupSecret(context.Background(), r.client, r.backendSecurityPolicyNamespace, GetBSPSecretName(r.backendSecurityPolicyName)) +func (r *AWSOIDCRotator) GetPreRotationTime(ctx context.Context) (time.Time, error) { + secret, err := LookupSecret(ctx, r.client, r.backendSecurityPolicyNamespace, GetBSPSecretName(r.backendSecurityPolicyName)) if err != nil { return time.Time{}, err } @@ -123,6 +123,7 @@ func (r *AWSOIDCRotator) Rotate(ctx context.Context, token string) error { r.logger.Error(err, "failed to assume role", "role", r.roleARN, "access token", token) return err } + secret, err := LookupSecret(ctx, r.client, r.backendSecurityPolicyNamespace, GetBSPSecretName(r.backendSecurityPolicyName)) if err != nil { if !errors.IsNotFound(err) { diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go index fbc689517..fc39c9d94 100644 --- a/internal/controller/rotators/aws_oidc_rotator_test.go +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -167,7 +167,8 @@ func TestAWS_GetPreRotationTime(t *testing.T) { backendSecurityPolicyNamespace: "default", backendSecurityPolicyName: "test-secret", } - preRotateTime, _ := awsOidcRotator.GetPreRotationTime() + + preRotateTime, _ := awsOidcRotator.GetPreRotationTime(t.Context()) require.Equal(t, 0, preRotateTime.Minute()) createTestAWSSecret(t, cl, "test-secret", "OLDKEY", "OLDSECRET", "OLDTOKEN", "default") @@ -179,7 +180,7 @@ func TestAWS_GetPreRotationTime(t *testing.T) { expiredTime := time.Now().Add(-1 * time.Hour) updateExpirationSecretAnnotation(secret, expiredTime) require.NoError(t, cl.Update(t.Context(), secret)) - preRotateTime, _ = awsOidcRotator.GetPreRotationTime() + preRotateTime, _ = awsOidcRotator.GetPreRotationTime(t.Context()) require.Equal(t, expiredTime.Format(time.RFC3339), preRotateTime.Format(time.RFC3339)) } @@ -194,7 +195,7 @@ func TestAWS_IsExpired(t *testing.T) { backendSecurityPolicyNamespace: "default", backendSecurityPolicyName: "test-secret", } - preRotateTime, _ := awsOidcRotator.GetPreRotationTime() + preRotateTime, _ := awsOidcRotator.GetPreRotationTime(t.Context()) require.True(t, awsOidcRotator.IsExpired(preRotateTime)) createTestAWSSecret(t, cl, "test-secret", "OLDKEY", "OLDSECRET", "OLDTOKEN", "default") @@ -206,12 +207,12 @@ func TestAWS_IsExpired(t *testing.T) { expiredTime := time.Now().Add(-1 * time.Hour) updateExpirationSecretAnnotation(secret, expiredTime) require.NoError(t, cl.Update(t.Context(), secret)) - preRotateTime, _ = awsOidcRotator.GetPreRotationTime() + preRotateTime, _ = awsOidcRotator.GetPreRotationTime(t.Context()) require.True(t, awsOidcRotator.IsExpired(preRotateTime)) hourFromNowTime := time.Now().Add(1 * time.Hour) updateExpirationSecretAnnotation(secret, hourFromNowTime) require.NoError(t, cl.Update(t.Context(), secret)) - preRotateTime, _ = awsOidcRotator.GetPreRotationTime() + preRotateTime, _ = awsOidcRotator.GetPreRotationTime(t.Context()) require.False(t, awsOidcRotator.IsExpired(preRotateTime)) } From 987b35521638ab60b6222f94712d3769ac45f994 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Sun, 16 Feb 2025 01:07:56 -0500 Subject: [PATCH 74/86] split credential renewal from reconcile Signed-off-by: Aaron Choo --- .../controller/backend_security_policy.go | 106 +++++++++------- .../backend_security_policy_test.go | 116 ++++++++++-------- internal/controller/controller.go | 2 +- .../controller/rotators/aws_oidc_rotator.go | 22 ++-- .../rotators/aws_oidc_rotator_test.go | 6 +- internal/controller/rotators/common.go | 10 ++ 6 files changed, 154 insertions(+), 108 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 1d4fe446a..df373b0b1 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -10,6 +10,7 @@ import ( "fmt" "time" + egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "github.com/go-logr/logr" "golang.org/x/oauth2" "k8s.io/apimachinery/pkg/api/errors" @@ -26,9 +27,6 @@ import ( // Temporarily a fixed duration. const preRotationWindow = 5 * time.Minute -// outgoingTimeOut will be used to prevent outgoing request from blocking. -const outGoingTimeOut = time.Minute - // backendSecurityPolicyController implements [reconcile.TypedReconciler] for [aigv1a1.BackendSecurityPolicy]. // // This handles the BackendSecurityPolicy resource and sends it to the config sink so that it can modify configuration. @@ -37,17 +35,15 @@ type backendSecurityPolicyController struct { kube kubernetes.Interface logger logr.Logger eventChan chan ConfigSinkEvent - StsClient rotators.STSClient oidcTokenCache map[string]*oauth2.Token } -func newBackendSecurityPolicyController(client client.Client, stsClient rotators.STSClient, kube kubernetes.Interface, logger logr.Logger, ch chan ConfigSinkEvent) *backendSecurityPolicyController { +func newBackendSecurityPolicyController(client client.Client, kube kubernetes.Interface, logger logr.Logger, ch chan ConfigSinkEvent) *backendSecurityPolicyController { return &backendSecurityPolicyController{ client: client, kube: kube, logger: logger, eventChan: ch, - StsClient: stsClient, oidcTokenCache: make(map[string]*oauth2.Token), } } @@ -55,7 +51,7 @@ func newBackendSecurityPolicyController(client client.Client, stsClient rotators // Reconcile implements the [reconcile.TypedReconciler] for [aigv1a1.BackendSecurityPolicy]. func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctrl.Request) (res ctrl.Result, err error) { var backendSecurityPolicy aigv1a1.BackendSecurityPolicy - if err := b.client.Get(ctx, req.NamespacedName, &backendSecurityPolicy); err != nil { + if err = b.client.Get(ctx, req.NamespacedName, &backendSecurityPolicy); err != nil { if errors.IsNotFound(err) { ctrl.Log.Info("Deleting Backend Security Policy", "namespace", req.Namespace, "name", req.Name) @@ -64,71 +60,95 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr return ctrl.Result{}, err } - if backendSecurityPolicy.Spec.AWSCredentials != nil && backendSecurityPolicy.Spec.AWSCredentials.OIDCExchangeToken != nil { - rotator, err := rotators.NewAWSOIDCRotator(ctx, b.client, b.StsClient, b.kube, b.logger, backendSecurityPolicy.Namespace, - backendSecurityPolicy.Name, preRotationWindow, backendSecurityPolicy.Spec.AWSCredentials.Region, backendSecurityPolicy.Spec.AWSCredentials.OIDCExchangeToken.AwsRoleArn) - if err != nil { - b.logger.Error(err, "failed to create AWS OIDC rotator") - return ctrl.Result{}, err - } - var requeue time.Duration - requeue = time.Minute - preRotationExpirationTime, err := rotator.GetPreRotationTime(ctx) - if err != nil { - return ctrl.Result{}, err + if oidc := getBackendSecurityPolicyAuthOIDC(backendSecurityPolicy.Spec); oidc != nil { + var rotator rotators.Rotator + skipOIDC := false + + switch backendSecurityPolicy.Spec.Type { + case aigv1a1.BackendSecurityPolicyTypeAWSCredentials: + region := backendSecurityPolicy.Spec.AWSCredentials.Region + roleArn := backendSecurityPolicy.Spec.AWSCredentials.OIDCExchangeToken.AwsRoleArn + rotator, err = rotators.NewAWSOIDCRotator(ctx, b.client, nil, b.kube, b.logger, backendSecurityPolicy.Namespace, backendSecurityPolicy.Name, preRotationWindow, roleArn, region) + if err != nil { + return ctrl.Result{RequeueAfter: time.Minute}, err + } + default: + ctrl.Log.Error(fmt.Errorf("unsupported OIDC type %s", backendSecurityPolicy.Spec.Type), "namespace", backendSecurityPolicy.Namespace, "name", backendSecurityPolicy.Name) + skipOIDC = true } - if rotator.IsExpired(preRotationExpirationTime) { - err := b.rotateCredential(ctx, rotator, backendSecurityPolicy) + + if !skipOIDC { + var requeue time.Duration + rotationTime, err := rotator.GetPreRotationTime(ctx) if err != nil { - b.logger.Error(err, "failed to rotate OIDC exchange token, retry in one minute") requeue = time.Minute + b.logger.Error(err, "failed to rotate OIDC exchange token, retry in one minute") } else { - preRotationExpirationTime, err = rotator.GetPreRotationTime(ctx) - if err != nil { - return ctrl.Result{}, err + if rotator.IsExpired(rotationTime) { + requeue, err = b.rotateCredential(ctx, &backendSecurityPolicy, *oidc, rotator) + if err != nil { + requeue = time.Minute + b.logger.Error(err, "failed to rotate OIDC exchange token, retry in one minute") + } + } else { + requeue = time.Until(rotationTime) } - requeue = time.Until(preRotationExpirationTime) } + // TODO: Investigate how to stop stale events from re-queuing. + res = ctrl.Result{RequeueAfter: requeue} } - res = ctrl.Result{RequeueAfter: requeue} } // Send the backend security policy to the config sink so that it can modify the configuration together with the state of other resources. b.eventChan <- backendSecurityPolicy.DeepCopy() return } -func (b *backendSecurityPolicyController) rotateCredential(ctx context.Context, rotator *rotators.AWSOIDCRotator, policy aigv1a1.BackendSecurityPolicy) error { +// renewCredentials will take the backendSecurityPolicy and rotator to renew credentials and return the requeue time. +func (b *backendSecurityPolicyController) rotateCredential(ctx context.Context, policy *aigv1a1.BackendSecurityPolicy, oidcCreds egv1a1.OIDC, rotator rotators.Rotator) (time.Duration, error) { bspKey := backendSecurityPolicyKey(policy.Namespace, policy.Name) var validToken *oauth2.Token var err error - if tokenResponse, ok := b.oidcTokenCache[bspKey]; !ok || rotators.IsBufferedTimeExpired(preRotationWindow, tokenResponse.Expiry) { - oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(b.client, policy.Spec.AWSCredentials.OIDCExchangeToken.OIDC), policy.Spec.AWSCredentials.OIDCExchangeToken.OIDC) - // Valid Token will be nil if fetch token errors. - timeOutCtx, cancelFunc := context.WithTimeout(ctx, outGoingTimeOut) - defer cancelFunc() - validToken, err = oidcProvider.FetchToken(timeOutCtx) + if tokenResponse, ok := b.oidcTokenCache[bspKey]; !ok || tokenResponse == nil || rotators.IsBufferedTimeExpired(preRotationWindow, tokenResponse.Expiry) { + oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(b.client, oidcCreds), oidcCreds) + validToken, err = oidcProvider.FetchToken(ctx) if err != nil { b.logger.Error(err, "failed to fetch OIDC provider token") - return err + return time.Minute, err } b.oidcTokenCache[bspKey] = validToken } else { validToken = tokenResponse } - if validToken != nil { - b.oidcTokenCache[bspKey] = validToken - // Set a timeout for rotate. - timeOutCtx, cancelRotateFunc := context.WithTimeout(ctx, outGoingTimeOut) - defer cancelRotateFunc() - token := validToken.AccessToken - return rotator.Rotate(timeOutCtx, token) + b.oidcTokenCache[bspKey] = validToken + if validToken == nil { + return time.Minute, fmt.Errorf("found a nil token for backend security policy '%s' in '%s'", policy.Name, policy.Namespace) + } + + token := validToken.AccessToken + err = rotator.Rotate(ctx, token) + if err != nil { + b.logger.Error(err, fmt.Sprintf("failed to rotate credentials for backend security policy '%s' in '%s'", policy.Name, policy.Namespace)) + return time.Minute, err + } + rotationTime, err := rotator.GetPreRotationTime(ctx) + if err != nil { + return time.Minute, err + } + return time.Until(rotationTime), nil +} + +// getBackendSecurityPolicyAuthOIDC returns the backendSecurityPolicy's OIDC pointer or nil. +func getBackendSecurityPolicyAuthOIDC(spec aigv1a1.BackendSecurityPolicySpec) *egv1a1.OIDC { + // Currently only supports AWS. + if spec.AWSCredentials != nil && spec.AWSCredentials.OIDCExchangeToken != nil { + return &spec.AWSCredentials.OIDCExchangeToken.OIDC } return nil } -// backendSecurityPolicyKey returns the key used for indexing and caching the backendSecurityPolicy +// backendSecurityPolicyKey returns the key used for indexing and caching the backendSecurityPolicy. func backendSecurityPolicyKey(namespace, name string) string { return fmt.Sprintf("%s.%s", name, namespace) } diff --git a/internal/controller/backend_security_policy_test.go b/internal/controller/backend_security_policy_test.go index 9a464dcb1..2fc1cc504 100644 --- a/internal/controller/backend_security_policy_test.go +++ b/internal/controller/backend_security_policy_test.go @@ -38,7 +38,49 @@ import ( func TestBackendSecurityController_Reconcile(t *testing.T) { ch := make(chan ConfigSinkEvent, 100) cl := fake.NewClientBuilder().WithScheme(scheme).Build() - c := newBackendSecurityPolicyController(cl, &mockSTSOperations{}, fake2.NewClientset(), ctrl.Log, ch) + c := newBackendSecurityPolicyController(cl, fake2.NewClientset(), ctrl.Log, ch) + backendSecurityPolicyName := "mybackendSecurityPolicy" + namespace := "default" + + err := cl.Create(t.Context(), &aigv1a1.BackendSecurityPolicy{ObjectMeta: metav1.ObjectMeta{Name: backendSecurityPolicyName, Namespace: namespace}}) + require.NoError(t, err) + res, err := c.Reconcile(t.Context(), reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: backendSecurityPolicyName}}) + require.NoError(t, err) + require.False(t, res.Requeue) + item, ok := <-ch + require.True(t, ok) + require.IsType(t, &aigv1a1.BackendSecurityPolicy{}, item) + require.Equal(t, backendSecurityPolicyName, item.(*aigv1a1.BackendSecurityPolicy).Name) + require.Equal(t, namespace, item.(*aigv1a1.BackendSecurityPolicy).Namespace) + + // Test the case where the BackendSecurityPolicy is being deleted. + err = cl.Delete(t.Context(), &aigv1a1.BackendSecurityPolicy{ObjectMeta: metav1.ObjectMeta{Name: backendSecurityPolicyName, Namespace: namespace}}) + require.NoError(t, err) + _, err = c.Reconcile(t.Context(), reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: backendSecurityPolicyName}}) + require.NoError(t, err) +} + +// mockSTSOperations implements the STSOperations interface for testing +type mockSTSOperations struct{} + +// AssumeRoleWithWebIdentity will return placeholder of type aws credentials. +// +// This implements [STSClient.AssumeRoleWithWebIdentity]. +func (m *mockSTSOperations) AssumeRoleWithWebIdentity(_ context.Context, _ *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { + return &sts.AssumeRoleWithWebIdentityOutput{ + Credentials: &stsTypes.Credentials{ + AccessKeyId: aws.String("NEWKEY"), + SecretAccessKey: aws.String("NEWSECRET"), + SessionToken: aws.String("NEWTOKEN"), + Expiration: aws.Time(time.Now().Add(1 * time.Hour)), + }, + }, nil +} + +func TestBackendSecurityController_RenewCredentials(t *testing.T) { + ch := make(chan ConfigSinkEvent, 100) + cl := fake.NewClientBuilder().WithScheme(scheme).Build() + c := newBackendSecurityPolicyController(cl, fake2.NewClientset(), ctrl.Log, ch) backendSecurityPolicyName := "mybackendSecurityPolicy" namespace := "default" @@ -82,47 +124,38 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { })) defer discoveryServer.Close() - err := cl.Create(t.Context(), &aigv1a1.BackendSecurityPolicy{ObjectMeta: metav1.ObjectMeta{Name: backendSecurityPolicyName, Namespace: namespace}}) - require.NoError(t, err) - err = cl.Create(t.Context(), &aigv1a1.BackendSecurityPolicy{ + oidc := egv1a1.OIDC{ + Provider: egv1a1.OIDCProvider{ + Issuer: discoveryServer.URL, + TokenEndpoint: &tokenServer.URL, + }, + ClientID: "some-client-id", + ClientSecret: gwapiv1.SecretObjectReference{ + Name: "clientSecret", + Namespace: (*gwapiv1.Namespace)(&namespace), + }, + } + bsp := &aigv1a1.BackendSecurityPolicy{ ObjectMeta: metav1.ObjectMeta{Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName), Namespace: namespace}, Spec: aigv1a1.BackendSecurityPolicySpec{ Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ - Region: "us-east-1", OIDCExchangeToken: &aigv1a1.AWSOIDCExchangeToken{ - OIDC: egv1a1.OIDC{ - Provider: egv1a1.OIDCProvider{ - Issuer: discoveryServer.URL, - TokenEndpoint: &tokenServer.URL, - }, - ClientID: "some-client-id", - ClientSecret: gwapiv1.SecretObjectReference{ - Name: "clientSecret", - Namespace: (*gwapiv1.Namespace)(&namespace), - }, - }, - GrantType: "placeholder", - Aud: "placeholder", - AwsRoleArn: "placeholder", + OIDC: oidc, }, }, }, - }) - require.NoError(t, err) - res, err := c.Reconcile(t.Context(), reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: backendSecurityPolicyName}}) + } + err := cl.Create(t.Context(), bsp) require.NoError(t, err) - require.False(t, res.Requeue) - item, ok := <-ch - require.True(t, ok) - require.IsType(t, &aigv1a1.BackendSecurityPolicy{}, item) - require.Equal(t, backendSecurityPolicyName, item.(*aigv1a1.BackendSecurityPolicy).Name) - require.Equal(t, namespace, item.(*aigv1a1.BackendSecurityPolicy).Namespace) ctx := oidcv3.InsecureIssuerURLContext(t.Context(), discoveryServer.URL) - res, err = c.Reconcile(ctx, reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName)}}) + rotator, err := rotators.NewAWSOIDCRotator(ctx, cl, &mockSTSOperations{}, fake2.NewClientset(), ctrl.Log, namespace, bsp.Name, preRotationWindow, "placeholder", "us-east-1") require.NoError(t, err) - require.WithinRange(t, time.Now().Add(res.RequeueAfter), time.Now().Add(50*time.Minute), time.Now().Add(time.Hour)) + + res, err := c.rotateCredential(ctx, bsp, oidc, rotator) + require.NoError(t, err) + require.WithinRange(t, time.Now().Add(res), time.Now().Add(50*time.Minute), time.Now().Add(time.Hour)) require.Len(t, c.oidcTokenCache, 1) token, ok := c.oidcTokenCache[fmt.Sprintf("%s-OIDC.%s", backendSecurityPolicyName, namespace)] @@ -132,27 +165,4 @@ func TestBackendSecurityController_Reconcile(t *testing.T) { updatedSecret, err := rotators.LookupSecret(t.Context(), cl, namespace, rotators.GetBSPSecretName(fmt.Sprintf("%s-OIDC", backendSecurityPolicyName))) require.NoError(t, err) require.NotEqualf(t, secret.Annotations[rotators.ExpirationTimeAnnotationKey], updatedSecret.Annotations[rotators.ExpirationTimeAnnotationKey], "expected updated expiration time annotation") - - // Test the case where the BackendSecurityPolicy is being deleted. - err = cl.Delete(t.Context(), &aigv1a1.BackendSecurityPolicy{ObjectMeta: metav1.ObjectMeta{Name: backendSecurityPolicyName, Namespace: namespace}}) - require.NoError(t, err) - _, err = c.Reconcile(t.Context(), reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: backendSecurityPolicyName}}) - require.NoError(t, err) -} - -// mockSTSOperations implements the STSOperations interface for testing -type mockSTSOperations struct{} - -// AssumeRoleWithWebIdentity will return placeholder of type aws credentials. -// -// This implements [STSClient.AssumeRoleWithWebIdentity]. -func (m *mockSTSOperations) AssumeRoleWithWebIdentity(_ context.Context, _ *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { - return &sts.AssumeRoleWithWebIdentityOutput{ - Credentials: &stsTypes.Credentials{ - AccessKeyId: aws.String("NEWKEY"), - SecretAccessKey: aws.String("NEWSECRET"), - SessionToken: aws.String("NEWTOKEN"), - Expiration: aws.Time(time.Now().Add(1 * time.Hour)), - }, - }, nil } diff --git a/internal/controller/controller.go b/internal/controller/controller.go index 8808b11bf..c8d189a87 100644 --- a/internal/controller/controller.go +++ b/internal/controller/controller.go @@ -93,7 +93,7 @@ func StartControllers(ctx context.Context, config *rest.Config, logger logr.Logg return fmt.Errorf("failed to create controller for AIServiceBackend: %w", err) } - backendSecurityPolicyC := newBackendSecurityPolicyController(c, nil, kubernetes.NewForConfigOrDie(config), logger. + backendSecurityPolicyC := newBackendSecurityPolicyController(c, kubernetes.NewForConfigOrDie(config), logger. WithName("backend-security-policy"), sinkChan) if err = ctrl.NewControllerManagedBy(mgr). For(&aigv1a1.BackendSecurityPolicy{}). diff --git a/internal/controller/rotators/aws_oidc_rotator.go b/internal/controller/rotators/aws_oidc_rotator.go index 344009621..9db3b65ce 100644 --- a/internal/controller/rotators/aws_oidc_rotator.go +++ b/internal/controller/rotators/aws_oidc_rotator.go @@ -39,12 +39,12 @@ type AWSOIDCRotator struct { backendSecurityPolicyName string // backendSecurityPolicyNamespace provides namespace of backend security policy. backendSecurityPolicyNamespace string - // aws region - region string - // aws IAM role ARN - roleARN string // preRotationWindow specifies how long before expiry to rotate. preRotationWindow time.Duration + // roleArn is the role ARN used to obtain credentials. + roleArn string + // region is the AWS region for the credentials. + region string } // NewAWSOIDCRotator creates a new AWS OIDC rotator with the specified configuration. @@ -58,8 +58,8 @@ func NewAWSOIDCRotator( backendSecurityPolicyNamespace string, backendSecurityPolicyName string, preRotationWindow time.Duration, + roleArn string, region string, - roleARN string, ) (*AWSOIDCRotator, error) { cfg, err := defaultAWSConfig(ctx) if err != nil { @@ -88,7 +88,7 @@ func NewAWSOIDCRotator( backendSecurityPolicyNamespace: backendSecurityPolicyNamespace, backendSecurityPolicyName: backendSecurityPolicyName, preRotationWindow: preRotationWindow, - roleARN: roleARN, + roleArn: roleArn, region: region, }, nil } @@ -102,6 +102,10 @@ func (r *AWSOIDCRotator) IsExpired(preRotationExpirationTime time.Time) bool { func (r *AWSOIDCRotator) GetPreRotationTime(ctx context.Context) (time.Time, error) { secret, err := LookupSecret(ctx, r.client, r.backendSecurityPolicyNamespace, GetBSPSecretName(r.backendSecurityPolicyName)) if err != nil { + // return zero value for time if secret has not been created. + if errors.IsNotFound(err) { + return time.Time{}, nil + } return time.Time{}, err } expirationTime, err := GetExpirationSecretAnnotation(secret) @@ -113,6 +117,8 @@ func (r *AWSOIDCRotator) GetPreRotationTime(ctx context.Context) (time.Time, err } // Rotate implements the retrieval and storage of AWS sts credentials. +// +// This implements [Rotator.Rotate]. func (r *AWSOIDCRotator) Rotate(ctx context.Context, token string) error { r.logger.Info("rotating AWS sts temporary credentials", "namespace", r.backendSecurityPolicyNamespace, @@ -120,7 +126,7 @@ func (r *AWSOIDCRotator) Rotate(ctx context.Context, token string) error { result, err := r.assumeRoleWithToken(ctx, token) if err != nil { - r.logger.Error(err, "failed to assume role", "role", r.roleARN, "access token", token) + r.logger.Error(err, "failed to assume role", "role", r.roleArn, "access token", token) return err } @@ -167,7 +173,7 @@ func (r *AWSOIDCRotator) Rotate(ctx context.Context, token string) error { // assumeRoleWithToken exchanges an OIDC token for AWS credentials. func (r *AWSOIDCRotator) assumeRoleWithToken(ctx context.Context, token string) (*sts.AssumeRoleWithWebIdentityOutput, error) { return r.stsClient.AssumeRoleWithWebIdentity(ctx, &sts.AssumeRoleWithWebIdentityInput{ - RoleArn: aws.String(r.roleARN), + RoleArn: aws.String(r.roleArn), WebIdentityToken: aws.String(token), RoleSessionName: aws.String(fmt.Sprintf(awsSessionNameFormat, r.backendSecurityPolicyName)), }) diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go index fc39c9d94..4a6fee25c 100644 --- a/internal/controller/rotators/aws_oidc_rotator_test.go +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -112,7 +112,7 @@ func TestAWS_OIDCRotator(t *testing.T) { &corev1.Secret{}, ) cl := fake.NewClientBuilder().WithScheme(scheme).Build() - // Setup initial credentials and client secret + // Setup initial credentials and client secret. createTestAWSSecret(t, cl, "test-secret", "OLDKEY", "OLDSECRET", "OLDTOKEN", "default") createClientSecret(t, "test-client-secret") @@ -122,7 +122,7 @@ func TestAWS_OIDCRotator(t *testing.T) { backendSecurityPolicyNamespace: "default", backendSecurityPolicyName: "test-secret", region: "us-east1", - roleARN: "test-role", + roleArn: "test-role", } require.NoError(t, awsOidcRotator.Rotate(t.Context(), "NEW-OIDC-TOKEN")) @@ -148,7 +148,7 @@ func TestAWS_OIDCRotator(t *testing.T) { backendSecurityPolicyNamespace: "default", backendSecurityPolicyName: "test-secret", region: "us-east-1", - roleARN: "test-role", + roleArn: "test-role", } err := awsOidcRotator.Rotate(t.Context(), "NEW-OIDC-TOKEN") require.Error(t, err) diff --git a/internal/controller/rotators/common.go b/internal/controller/rotators/common.go index b4878912a..4376bdbc8 100644 --- a/internal/controller/rotators/common.go +++ b/internal/controller/rotators/common.go @@ -20,6 +20,16 @@ const ExpirationTimeAnnotationKey = "rotators/expiration-time" const rotatorSecretNamePrefix = "ai-eg-bsp" // #nosec G101 +// Rotator defines the interface for rotating provider credential. +type Rotator interface { + // IsExpired checks if the provider credentials needs to be renewed. + IsExpired(preRotationExpirationTime time.Time) bool + // GetPreRotationTime gets the time when the credentials need to be renewed. + GetPreRotationTime(ctx context.Context) (time.Time, error) + // Rotate will update the credential secret file with new credentials. + Rotate(ctx context.Context, token string) error +} + // LookupSecret retrieves an existing secret. func LookupSecret(ctx context.Context, k8sClient client.Client, namespace, name string) (*corev1.Secret, error) { secret := &corev1.Secret{} From 17e4123b54af9f725da06db4173248ad010d5747 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Sun, 16 Feb 2025 04:25:55 -0500 Subject: [PATCH 75/86] fix tests Signed-off-by: Aaron Choo --- .../controller/backend_security_policy.go | 30 +++++----- .../backend_security_policy_test.go | 56 ++++++++++++++++++- 2 files changed, 69 insertions(+), 17 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index df373b0b1..d62f16f53 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -78,16 +78,16 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr } if !skipOIDC { - var requeue time.Duration - rotationTime, err := rotator.GetPreRotationTime(ctx) + requeue := time.Minute + var rotationTime time.Time + rotationTime, err = rotator.GetPreRotationTime(ctx) if err != nil { - requeue = time.Minute b.logger.Error(err, "failed to rotate OIDC exchange token, retry in one minute") } else { if rotator.IsExpired(rotationTime) { requeue, err = b.rotateCredential(ctx, &backendSecurityPolicy, *oidc, rotator) if err != nil { - requeue = time.Minute + println(err.Error()) b.logger.Error(err, "failed to rotate OIDC exchange token, retry in one minute") } } else { @@ -106,10 +106,10 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr // renewCredentials will take the backendSecurityPolicy and rotator to renew credentials and return the requeue time. func (b *backendSecurityPolicyController) rotateCredential(ctx context.Context, policy *aigv1a1.BackendSecurityPolicy, oidcCreds egv1a1.OIDC, rotator rotators.Rotator) (time.Duration, error) { bspKey := backendSecurityPolicyKey(policy.Namespace, policy.Name) - var validToken *oauth2.Token - var err error - if tokenResponse, ok := b.oidcTokenCache[bspKey]; !ok || tokenResponse == nil || rotators.IsBufferedTimeExpired(preRotationWindow, tokenResponse.Expiry) { + var err error + validToken, ok := b.oidcTokenCache[bspKey] + if !ok || validToken == nil || rotators.IsBufferedTimeExpired(preRotationWindow, validToken.Expiry) { oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(b.client, oidcCreds), oidcCreds) validToken, err = oidcProvider.FetchToken(ctx) if err != nil { @@ -117,13 +117,6 @@ func (b *backendSecurityPolicyController) rotateCredential(ctx context.Context, return time.Minute, err } b.oidcTokenCache[bspKey] = validToken - } else { - validToken = tokenResponse - } - - b.oidcTokenCache[bspKey] = validToken - if validToken == nil { - return time.Minute, fmt.Errorf("found a nil token for backend security policy '%s' in '%s'", policy.Name, policy.Namespace) } token := validToken.AccessToken @@ -142,8 +135,13 @@ func (b *backendSecurityPolicyController) rotateCredential(ctx context.Context, // getBackendSecurityPolicyAuthOIDC returns the backendSecurityPolicy's OIDC pointer or nil. func getBackendSecurityPolicyAuthOIDC(spec aigv1a1.BackendSecurityPolicySpec) *egv1a1.OIDC { // Currently only supports AWS. - if spec.AWSCredentials != nil && spec.AWSCredentials.OIDCExchangeToken != nil { - return &spec.AWSCredentials.OIDCExchangeToken.OIDC + switch spec.Type { + case aigv1a1.BackendSecurityPolicyTypeAWSCredentials: + if spec.AWSCredentials != nil && spec.AWSCredentials.OIDCExchangeToken != nil { + return &spec.AWSCredentials.OIDCExchangeToken.OIDC + } + default: + return nil } return nil } diff --git a/internal/controller/backend_security_policy_test.go b/internal/controller/backend_security_policy_test.go index 2fc1cc504..d5549624c 100644 --- a/internal/controller/backend_security_policy_test.go +++ b/internal/controller/backend_security_policy_test.go @@ -77,7 +77,34 @@ func (m *mockSTSOperations) AssumeRoleWithWebIdentity(_ context.Context, _ *sts. }, nil } -func TestBackendSecurityController_RenewCredentials(t *testing.T) { +func TestBackendSecurityPolicyController_ReconcileOIDC(t *testing.T) { + ch := make(chan ConfigSinkEvent, 100) + cl := fake.NewClientBuilder().WithScheme(scheme).Build() + c := newBackendSecurityPolicyController(cl, fake2.NewClientset(), ctrl.Log, ch) + backendSecurityPolicyName := "mybackendSecurityPolicy" + namespace := "default" + + bsp := &aigv1a1.BackendSecurityPolicy{ + ObjectMeta: metav1.ObjectMeta{Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName), Namespace: namespace}, + Spec: aigv1a1.BackendSecurityPolicySpec{ + Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, + AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ + OIDCExchangeToken: &aigv1a1.AWSOIDCExchangeToken{ + OIDC: egv1a1.OIDC{}, + }, + }, + }, + } + err := cl.Create(t.Context(), bsp) + require.NoError(t, err) + + // Expects rotate credentials to fail due to missing OIDC details. + res, err := c.Reconcile(t.Context(), reconcile.Request{NamespacedName: types.NamespacedName{Namespace: namespace, Name: fmt.Sprintf("%s-OIDC", backendSecurityPolicyName)}}) + require.Error(t, err) + require.Equal(t, time.Minute, res.RequeueAfter) +} + +func TestBackendSecurityController_RotateCredentials(t *testing.T) { ch := make(chan ConfigSinkEvent, 100) cl := fake.NewClientBuilder().WithScheme(scheme).Build() c := newBackendSecurityPolicyController(cl, fake2.NewClientset(), ctrl.Log, ch) @@ -166,3 +193,30 @@ func TestBackendSecurityController_RenewCredentials(t *testing.T) { require.NoError(t, err) require.NotEqualf(t, secret.Annotations[rotators.ExpirationTimeAnnotationKey], updatedSecret.Annotations[rotators.ExpirationTimeAnnotationKey], "expected updated expiration time annotation") } + +func TestBackendSecurityController_GetBackendSecurityPolicyAuthOIDC(t *testing.T) { + // API Key type does not support OIDC. + require.Nil(t, getBackendSecurityPolicyAuthOIDC(aigv1a1.BackendSecurityPolicySpec{Type: aigv1a1.BackendSecurityPolicyTypeAPIKey})) + + // AWS type supports OIDC type but OIDC needs to be defined. + require.Nil(t, getBackendSecurityPolicyAuthOIDC(aigv1a1.BackendSecurityPolicySpec{ + Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, + AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ + CredentialsFile: &aigv1a1.AWSCredentialsFile{}, + }, + })) + + // AWS type with OIDC defined. + oidc := getBackendSecurityPolicyAuthOIDC(aigv1a1.BackendSecurityPolicySpec{ + Type: aigv1a1.BackendSecurityPolicyTypeAWSCredentials, + AWSCredentials: &aigv1a1.BackendSecurityPolicyAWSCredentials{ + OIDCExchangeToken: &aigv1a1.AWSOIDCExchangeToken{ + OIDC: egv1a1.OIDC{ + ClientID: "some-client-id", + }, + }, + }, + }) + require.NotNil(t, oidc) + require.Equal(t, "some-client-id", oidc.ClientID) +} From ab98e14f123e57b1bae40c6ff8e423c845d89ade Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Sun, 16 Feb 2025 05:07:50 -0500 Subject: [PATCH 76/86] sync oidc between providers Signed-off-by: Aaron Choo --- internal/controller/backend_security_policy.go | 1 - .../oauth/client_credentials_token_provider.go | 13 ++++++++++++- internal/controller/oauth/oidc_provider.go | 10 ++++++++++ internal/controller/oauth/types.go | 3 +++ 4 files changed, 25 insertions(+), 2 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index d62f16f53..6b8a97662 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -87,7 +87,6 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr if rotator.IsExpired(rotationTime) { requeue, err = b.rotateCredential(ctx, &backendSecurityPolicy, *oidc, rotator) if err != nil { - println(err.Error()) b.logger.Error(err, "failed to rotate OIDC exchange token, retry in one minute") } } else { diff --git a/internal/controller/oauth/client_credentials_token_provider.go b/internal/controller/oauth/client_credentials_token_provider.go index 1bf5d8aaf..5dac35fee 100644 --- a/internal/controller/oauth/client_credentials_token_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -54,14 +54,25 @@ func (p *ClientCredentialsTokenProvider) FetchToken(ctx context.Context) (*oauth return p.getTokenWithClientCredentialConfig(ctx, clientSecret) } +// SetOIDC will update the OIDC field in ClientCredentialsTokenProvider. +// +// This implements [TokenProvider.SetOIDC]. +func (p *ClientCredentialsTokenProvider) SetOIDC(oidc egv1a1.OIDC) { + p.oidcCredential = oidc +} + // getTokenWithClientCredentialFlow fetches the oauth2 token with client credential config. func (p *ClientCredentialsTokenProvider) getTokenWithClientCredentialConfig(ctx context.Context, clientSecret string) (*oauth2.Token, error) { oauth2Config := clientcredentials.Config{ ClientSecret: clientSecret, ClientID: p.oidcCredential.ClientID, Scopes: p.oidcCredential.Scopes, - TokenURL: *p.oidcCredential.Provider.TokenEndpoint, } + + if p.oidcCredential.Provider.TokenEndpoint != nil { + oauth2Config.TokenURL = *p.oidcCredential.Provider.TokenEndpoint + } + // Underlying token call will apply http client timeout. ctx = context.WithValue(ctx, oauth2.HTTPClient, &http.Client{Timeout: tokenTimeoutDuration}) token, err := oauth2Config.Token(ctx) diff --git a/internal/controller/oauth/oidc_provider.go b/internal/controller/oauth/oidc_provider.go index a5768a98f..b9450d508 100644 --- a/internal/controller/oauth/oidc_provider.go +++ b/internal/controller/oauth/oidc_provider.go @@ -91,6 +91,9 @@ func (p *OIDCProvider) FetchToken(ctx context.Context) (*oauth2.Token, error) { } } + // Sync OIDC with TokenProvider. + p.tokenProvider.SetOIDC(p.oidcCredential) + // Get base token response. token, err := p.tokenProvider.FetchToken(ctx) if err != nil { @@ -99,3 +102,10 @@ func (p *OIDCProvider) FetchToken(ctx context.Context) (*oauth2.Token, error) { return token, nil } + +// SetOIDC will update the OIDC field in ClientCredentialsTokenProvider. +// +// This implements [TokenProvider.SetOIDC]. +func (p *OIDCProvider) SetOIDC(oidc egv1a1.OIDC) { + p.oidcCredential = oidc +} diff --git a/internal/controller/oauth/types.go b/internal/controller/oauth/types.go index 8aaad49d5..fb231e96e 100644 --- a/internal/controller/oauth/types.go +++ b/internal/controller/oauth/types.go @@ -8,6 +8,7 @@ package oauth import ( "context" + egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "golang.org/x/oauth2" ) @@ -15,4 +16,6 @@ import ( type TokenProvider interface { // FetchToken will obtain oauth token using oidc credentials. FetchToken(ctx context.Context) (*oauth2.Token, error) + // SetOIDC will update the locally stored OIDC credentials. + SetOIDC(oidc egv1a1.OIDC) } From f592bc2fca0f29a9ebfa32f2dc9e1a70f308be6f Mon Sep 17 00:00:00 2001 From: Dan Sun Date: Sun, 16 Feb 2025 15:25:55 -0500 Subject: [PATCH 77/86] Remove the skipOIDC flag Signed-off-by: Dan Sun --- .../controller/backend_security_policy.go | 42 ++++++++----------- .../backendsecuritypolicies/aws_oidc.yaml | 3 +- 2 files changed, 19 insertions(+), 26 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 6b8a97662..53b04db1a 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -62,47 +62,43 @@ func (b *backendSecurityPolicyController) Reconcile(ctx context.Context, req ctr if oidc := getBackendSecurityPolicyAuthOIDC(backendSecurityPolicy.Spec); oidc != nil { var rotator rotators.Rotator - skipOIDC := false - switch backendSecurityPolicy.Spec.Type { case aigv1a1.BackendSecurityPolicyTypeAWSCredentials: region := backendSecurityPolicy.Spec.AWSCredentials.Region roleArn := backendSecurityPolicy.Spec.AWSCredentials.OIDCExchangeToken.AwsRoleArn rotator, err = rotators.NewAWSOIDCRotator(ctx, b.client, nil, b.kube, b.logger, backendSecurityPolicy.Namespace, backendSecurityPolicy.Name, preRotationWindow, roleArn, region) if err != nil { - return ctrl.Result{RequeueAfter: time.Minute}, err + return ctrl.Result{}, err } default: - ctrl.Log.Error(fmt.Errorf("unsupported OIDC type %s", backendSecurityPolicy.Spec.Type), "namespace", backendSecurityPolicy.Namespace, "name", backendSecurityPolicy.Name) - skipOIDC = true + err = fmt.Errorf("backend security type %s does not support OIDC token exchange", backendSecurityPolicy.Spec.Type) + ctrl.Log.Error(err, "namespace", backendSecurityPolicy.Namespace, "name", backendSecurityPolicy.Name) + return ctrl.Result{}, err } - if !skipOIDC { - requeue := time.Minute - var rotationTime time.Time - rotationTime, err = rotator.GetPreRotationTime(ctx) - if err != nil { - b.logger.Error(err, "failed to rotate OIDC exchange token, retry in one minute") - } else { - if rotator.IsExpired(rotationTime) { - requeue, err = b.rotateCredential(ctx, &backendSecurityPolicy, *oidc, rotator) - if err != nil { - b.logger.Error(err, "failed to rotate OIDC exchange token, retry in one minute") - } - } else { - requeue = time.Until(rotationTime) + requeue := time.Minute + var rotationTime time.Time + rotationTime, err = rotator.GetPreRotationTime(ctx) + if err != nil { + b.logger.Error(err, "failed to get rotation time, retry in one minute") + } else { + if rotator.IsExpired(rotationTime) { + requeue, err = b.rotateCredential(ctx, &backendSecurityPolicy, *oidc, rotator) + if err != nil { + b.logger.Error(err, "failed to rotate OIDC exchange token, retry in one minute") } + } else { + requeue = time.Until(rotationTime) } - // TODO: Investigate how to stop stale events from re-queuing. - res = ctrl.Result{RequeueAfter: requeue} } + res = ctrl.Result{RequeueAfter: requeue} } // Send the backend security policy to the config sink so that it can modify the configuration together with the state of other resources. b.eventChan <- backendSecurityPolicy.DeepCopy() return } -// renewCredentials will take the backendSecurityPolicy and rotator to renew credentials and return the requeue time. +// rotateCredential rotates the credentials using the access token from OIDC provider and return the requeue time for next rotation. func (b *backendSecurityPolicyController) rotateCredential(ctx context.Context, policy *aigv1a1.BackendSecurityPolicy, oidcCreds egv1a1.OIDC, rotator rotators.Rotator) (time.Duration, error) { bspKey := backendSecurityPolicyKey(policy.Namespace, policy.Name) @@ -112,7 +108,6 @@ func (b *backendSecurityPolicyController) rotateCredential(ctx context.Context, oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(b.client, oidcCreds), oidcCreds) validToken, err = oidcProvider.FetchToken(ctx) if err != nil { - b.logger.Error(err, "failed to fetch OIDC provider token") return time.Minute, err } b.oidcTokenCache[bspKey] = validToken @@ -121,7 +116,6 @@ func (b *backendSecurityPolicyController) rotateCredential(ctx context.Context, token := validToken.AccessToken err = rotator.Rotate(ctx, token) if err != nil { - b.logger.Error(err, fmt.Sprintf("failed to rotate credentials for backend security policy '%s' in '%s'", policy.Name, policy.Namespace)) return time.Minute, err } rotationTime, err := rotator.GetPreRotationTime(ctx) diff --git a/tests/crdcel/testdata/backendsecuritypolicies/aws_oidc.yaml b/tests/crdcel/testdata/backendsecuritypolicies/aws_oidc.yaml index 0b3a0e7d9..d6fa2cccc 100644 --- a/tests/crdcel/testdata/backendsecuritypolicies/aws_oidc.yaml +++ b/tests/crdcel/testdata/backendsecuritypolicies/aws_oidc.yaml @@ -13,11 +13,10 @@ spec: awsCredentials: region: us-east-1 oidcExchangeToken: - awsRoleArn: "arn" + awsRoleArn: placeholder oidc: provider: issuer: placeholder clientID: placeholder clientSecret: name: placeholder - awsRoleARN: placeholder From 682293fa1664723cf489b0144f53c0b50fd0204d Mon Sep 17 00:00:00 2001 From: Dan Sun Date: Sun, 16 Feb 2025 15:52:16 -0500 Subject: [PATCH 78/86] Don't export NewClientCredentialsProvider Signed-off-by: Dan Sun --- .../controller/backend_security_policy.go | 2 +- .../client_credentials_token_provider.go | 11 ++----- .../client_credentials_token_provider_test.go | 4 +-- internal/controller/oauth/oidc_provider.go | 33 +++++++------------ .../controller/oauth/oidc_provider_test.go | 10 +++--- internal/controller/oauth/types.go | 3 -- 6 files changed, 22 insertions(+), 41 deletions(-) diff --git a/internal/controller/backend_security_policy.go b/internal/controller/backend_security_policy.go index 53b04db1a..c6652c169 100644 --- a/internal/controller/backend_security_policy.go +++ b/internal/controller/backend_security_policy.go @@ -105,7 +105,7 @@ func (b *backendSecurityPolicyController) rotateCredential(ctx context.Context, var err error validToken, ok := b.oidcTokenCache[bspKey] if !ok || validToken == nil || rotators.IsBufferedTimeExpired(preRotationWindow, validToken.Expiry) { - oidcProvider := oauth.NewOIDCProvider(oauth.NewClientCredentialsProvider(b.client, oidcCreds), oidcCreds) + oidcProvider := oauth.NewOIDCProvider(b.client, oidcCreds) validToken, err = oidcProvider.FetchToken(ctx) if err != nil { return time.Minute, err diff --git a/internal/controller/oauth/client_credentials_token_provider.go b/internal/controller/oauth/client_credentials_token_provider.go index 5dac35fee..e469db019 100644 --- a/internal/controller/oauth/client_credentials_token_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -27,8 +27,8 @@ type ClientCredentialsTokenProvider struct { oidcCredential egv1a1.OIDC } -// NewClientCredentialsProvider creates a new client credentials provider. -func NewClientCredentialsProvider(cl client.Client, oidcCredential egv1a1.OIDC) *ClientCredentialsTokenProvider { +// newClientCredentialsProvider creates a new client credentials provider. +func newClientCredentialsProvider(cl client.Client, oidcCredential egv1a1.OIDC) *ClientCredentialsTokenProvider { return &ClientCredentialsTokenProvider{ client: cl, oidcCredential: oidcCredential, @@ -54,13 +54,6 @@ func (p *ClientCredentialsTokenProvider) FetchToken(ctx context.Context) (*oauth return p.getTokenWithClientCredentialConfig(ctx, clientSecret) } -// SetOIDC will update the OIDC field in ClientCredentialsTokenProvider. -// -// This implements [TokenProvider.SetOIDC]. -func (p *ClientCredentialsTokenProvider) SetOIDC(oidc egv1a1.OIDC) { - p.oidcCredential = oidc -} - // getTokenWithClientCredentialFlow fetches the oauth2 token with client credential config. func (p *ClientCredentialsTokenProvider) getTokenWithClientCredentialConfig(ctx context.Context, clientSecret string) (*oauth2.Token, error) { oauth2Config := clientcredentials.Config{ diff --git a/internal/controller/oauth/client_credentials_token_provider_test.go b/internal/controller/oauth/client_credentials_token_provider_test.go index c0121321f..35a8eb9e2 100644 --- a/internal/controller/oauth/client_credentials_token_provider_test.go +++ b/internal/controller/oauth/client_credentials_token_provider_test.go @@ -54,14 +54,14 @@ func TestClientCredentialsProvider_FetchToken(t *testing.T) { require.NoError(t, err) namespaceRef := gwapiv1.Namespace(secretNamespace) - clientCredentialProvider := NewClientCredentialsProvider(cl, egv1a1.OIDC{}) + clientCredentialProvider := newClientCredentialsProvider(cl, egv1a1.OIDC{}) require.NotNil(t, clientCredentialProvider) _, err = clientCredentialProvider.FetchToken(t.Context()) require.Error(t, err) require.Contains(t, err.Error(), "oidc-client-secret namespace is nil") - clientCredentialProvider = NewClientCredentialsProvider(cl, egv1a1.OIDC{ + clientCredentialProvider = newClientCredentialsProvider(cl, egv1a1.OIDC{ Provider: egv1a1.OIDCProvider{ Issuer: tokenServer.URL, TokenEndpoint: &tokenServer.URL, diff --git a/internal/controller/oauth/oidc_provider.go b/internal/controller/oauth/oidc_provider.go index b9450d508..d563a74aa 100644 --- a/internal/controller/oauth/oidc_provider.go +++ b/internal/controller/oauth/oidc_provider.go @@ -12,19 +12,20 @@ import ( "github.com/coreos/go-oidc/v3/oidc" egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "golang.org/x/oauth2" + "sigs.k8s.io/controller-runtime/pkg/client" ) // OIDCProvider extends ClientCredentialsTokenProvider with OIDC support. type OIDCProvider struct { - tokenProvider TokenProvider - oidcCredential egv1a1.OIDC + tokenProvider TokenProvider + oidcConfig egv1a1.OIDC } // NewOIDCProvider creates a new OIDC-aware provider. -func NewOIDCProvider(tokenProvider TokenProvider, oidcCredentials egv1a1.OIDC) *OIDCProvider { +func NewOIDCProvider(client client.Client, oidcConfig egv1a1.OIDC) *OIDCProvider { return &OIDCProvider{ - tokenProvider: tokenProvider, - oidcCredential: oidcCredentials, + tokenProvider: newClientCredentialsProvider(client, oidcConfig), + oidcConfig: oidcConfig, } } @@ -64,37 +65,34 @@ func (p *OIDCProvider) getOIDCProviderConfig(ctx context.Context, issuerURL stri // This implements [TokenProvider.FetchToken]. func (p *OIDCProvider) FetchToken(ctx context.Context) (*oauth2.Token, error) { // If issuer URL is provided, fetch OIDC metadata. - if issuerURL := p.oidcCredential.Provider.Issuer; issuerURL != "" { + if issuerURL := p.oidcConfig.Provider.Issuer; issuerURL != "" { config, supportedScopes, err := p.getOIDCProviderConfig(ctx, issuerURL) if err != nil { return nil, fmt.Errorf("failed to get OIDC config: %w", err) } // Use discovered token endpoint if not explicitly provided. - if p.oidcCredential.Provider.TokenEndpoint == nil { - p.oidcCredential.Provider.TokenEndpoint = &config.TokenURL + if p.oidcConfig.Provider.TokenEndpoint == nil { + p.oidcConfig.Provider.TokenEndpoint = &config.TokenURL } // Add discovered scopes if available. if len(supportedScopes) > 0 { requestedScopes := make(map[string]bool) - for _, scope := range p.oidcCredential.Scopes { + for _, scope := range p.oidcConfig.Scopes { requestedScopes[scope] = true } // Add supported scopes that aren't already requested. for _, scope := range supportedScopes { if !requestedScopes[scope] { - p.oidcCredential.Scopes = append(p.oidcCredential.Scopes, scope) + p.oidcConfig.Scopes = append(p.oidcConfig.Scopes, scope) } } } } - // Sync OIDC with TokenProvider. - p.tokenProvider.SetOIDC(p.oidcCredential) - - // Get base token response. + // Get token response from the provider. token, err := p.tokenProvider.FetchToken(ctx) if err != nil { return nil, fmt.Errorf("failed to get token: %w", err) @@ -102,10 +100,3 @@ func (p *OIDCProvider) FetchToken(ctx context.Context) (*oauth2.Token, error) { return token, nil } - -// SetOIDC will update the OIDC field in ClientCredentialsTokenProvider. -// -// This implements [TokenProvider.SetOIDC]. -func (p *OIDCProvider) SetOIDC(oidc egv1a1.OIDC) { - p.oidcCredential = oidc -} diff --git a/internal/controller/oauth/oidc_provider_test.go b/internal/controller/oauth/oidc_provider_test.go index a22ce1b0a..0367e8b1f 100644 --- a/internal/controller/oauth/oidc_provider_test.go +++ b/internal/controller/oauth/oidc_provider_test.go @@ -49,7 +49,7 @@ func TestOIDCProvider_GetOIDCProviderConfigErrors(t *testing.T) { })) defer missingTokenURLTestServer.Close() - oidcProvider := NewOIDCProvider(NewClientCredentialsProvider(cl, oidc), oidc) + oidcProvider := NewOIDCProvider(cl, oidc) for _, testcase := range []struct { name string @@ -114,7 +114,7 @@ func TestOIDCProvider_GetOIDCProviderConfig(t *testing.T) { } ctx := oidcv3.InsecureIssuerURLContext(t.Context(), ts.URL) - oidcProvider := NewOIDCProvider(NewClientCredentialsProvider(cl, oidc), oidc) + oidcProvider := NewOIDCProvider(cl, oidc) config, supportedScope, err := oidcProvider.getOIDCProviderConfig(ctx, ts.URL) require.NoError(t, err) require.Equal(t, "token_endpoint", config.TokenURL) @@ -171,13 +171,13 @@ func TestOIDCProvider_FetchToken(t *testing.T) { Scopes: []string{"two", "openid"}, } ctx := oidcv3.InsecureIssuerURLContext(t.Context(), oidcServer.URL) - oidcProvider := NewOIDCProvider(NewClientCredentialsProvider(cl, oidc), oidc) - require.Len(t, oidcProvider.oidcCredential.Scopes, 2) + oidcProvider := NewOIDCProvider(cl, oidc) + require.Len(t, oidcProvider.oidcConfig.Scopes, 2) token, err := oidcProvider.FetchToken(ctx) require.NoError(t, err) require.Equal(t, "token", token.AccessToken) require.Equal(t, "Bearer", token.Type()) require.WithinRangef(t, token.Expiry, time.Now().Add(3590*time.Second), time.Now().Add(3600*time.Second), "token expires at") - require.Len(t, oidcProvider.oidcCredential.Scopes, 3) + require.Len(t, oidcProvider.oidcConfig.Scopes, 3) } diff --git a/internal/controller/oauth/types.go b/internal/controller/oauth/types.go index fb231e96e..8aaad49d5 100644 --- a/internal/controller/oauth/types.go +++ b/internal/controller/oauth/types.go @@ -8,7 +8,6 @@ package oauth import ( "context" - egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" "golang.org/x/oauth2" ) @@ -16,6 +15,4 @@ import ( type TokenProvider interface { // FetchToken will obtain oauth token using oidc credentials. FetchToken(ctx context.Context) (*oauth2.Token, error) - // SetOIDC will update the locally stored OIDC credentials. - SetOIDC(oidc egv1a1.OIDC) } From dc5f31c17b79669572bc9d3519ff47c4eaf90b26 Mon Sep 17 00:00:00 2001 From: Dan Sun Date: Sun, 16 Feb 2025 15:57:12 -0500 Subject: [PATCH 79/86] Don't export NewClientCredentialsProvider Signed-off-by: Dan Sun --- .../client_credentials_token_provider.go | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/internal/controller/oauth/client_credentials_token_provider.go b/internal/controller/oauth/client_credentials_token_provider.go index e469db019..0043f2820 100644 --- a/internal/controller/oauth/client_credentials_token_provider.go +++ b/internal/controller/oauth/client_credentials_token_provider.go @@ -23,15 +23,15 @@ const tokenTimeoutDuration = time.Minute // ClientCredentialsTokenProvider implements the standard OAuth2 client credentials flow. type ClientCredentialsTokenProvider struct { - client client.Client - oidcCredential egv1a1.OIDC + client client.Client + oidcConfig egv1a1.OIDC } // newClientCredentialsProvider creates a new client credentials provider. -func newClientCredentialsProvider(cl client.Client, oidcCredential egv1a1.OIDC) *ClientCredentialsTokenProvider { +func newClientCredentialsProvider(cl client.Client, oidcConfig egv1a1.OIDC) *ClientCredentialsTokenProvider { return &ClientCredentialsTokenProvider{ - client: cl, - oidcCredential: oidcCredential, + client: cl, + oidcConfig: oidcConfig, } } @@ -40,13 +40,13 @@ func newClientCredentialsProvider(cl client.Client, oidcCredential egv1a1.OIDC) // This implements [TokenProvider.FetchToken]. func (p *ClientCredentialsTokenProvider) FetchToken(ctx context.Context) (*oauth2.Token, error) { // client secret namespace is optional on egv1a1.OIDC, but it is required for AI Gateway for now. - if p.oidcCredential.ClientSecret.Namespace == nil { + if p.oidcConfig.ClientSecret.Namespace == nil { return nil, fmt.Errorf("oidc-client-secret namespace is nil") } clientSecret, err := getClientSecret(ctx, p.client, &corev1.SecretReference{ - Name: string(p.oidcCredential.ClientSecret.Name), - Namespace: string(*p.oidcCredential.ClientSecret.Namespace), + Name: string(p.oidcConfig.ClientSecret.Name), + Namespace: string(*p.oidcConfig.ClientSecret.Namespace), }) if err != nil { return nil, err @@ -58,12 +58,12 @@ func (p *ClientCredentialsTokenProvider) FetchToken(ctx context.Context) (*oauth func (p *ClientCredentialsTokenProvider) getTokenWithClientCredentialConfig(ctx context.Context, clientSecret string) (*oauth2.Token, error) { oauth2Config := clientcredentials.Config{ ClientSecret: clientSecret, - ClientID: p.oidcCredential.ClientID, - Scopes: p.oidcCredential.Scopes, + ClientID: p.oidcConfig.ClientID, + Scopes: p.oidcConfig.Scopes, } - if p.oidcCredential.Provider.TokenEndpoint != nil { - oauth2Config.TokenURL = *p.oidcCredential.Provider.TokenEndpoint + if p.oidcConfig.Provider.TokenEndpoint != nil { + oauth2Config.TokenURL = *p.oidcConfig.Provider.TokenEndpoint } // Underlying token call will apply http client timeout. From b7c2d69c9c4d94c4c4918cf951b1198572aee776 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Tue, 18 Feb 2025 10:51:15 -0500 Subject: [PATCH 80/86] Update internal/controller/rotators/aws_oidc_rotator_test.go Co-authored-by: Takeshi Yoneda Signed-off-by: Aaron Choo --- internal/controller/rotators/aws_oidc_rotator_test.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go index 4a6fee25c..0a319ce58 100644 --- a/internal/controller/rotators/aws_oidc_rotator_test.go +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -89,10 +89,6 @@ func (m *MockSTSOperations) AssumeRoleWithWebIdentity(ctx context.Context, param return nil, fmt.Errorf("mock not implemented") } -// ----------------------------------------------------------------------------- -// Test Cases -// ----------------------------------------------------------------------------- - func TestAWS_OIDCRotator(t *testing.T) { t.Run("basic rotation", func(t *testing.T) { var mockSTS STSClient = &MockSTSOperations{ From 76fdfcf3f1d8aa94a5619346fae63de4256844b9 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Tue, 18 Feb 2025 10:51:28 -0500 Subject: [PATCH 81/86] Update internal/controller/rotators/aws_oidc_rotator_test.go Co-authored-by: Takeshi Yoneda Signed-off-by: Aaron Choo --- internal/controller/rotators/aws_oidc_rotator_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go index 0a319ce58..8aa129ba4 100644 --- a/internal/controller/rotators/aws_oidc_rotator_test.go +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -77,7 +77,7 @@ func createClientSecret(t *testing.T, name string) { require.NoError(t, err) } -// MockSTSOperations implements the STSOperations interface for testing +// MockSTSOperations implements the STSClient interface for testing type MockSTSOperations struct { assumeRoleWithWebIdentityFunc func(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) } From c121d27f8aa41c12d71c298b941f25022c0bdbfe Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Tue, 18 Feb 2025 10:51:43 -0500 Subject: [PATCH 82/86] Update internal/controller/rotators/aws_oidc_rotator_test.go Co-authored-by: Takeshi Yoneda Signed-off-by: Aaron Choo --- internal/controller/rotators/aws_oidc_rotator_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go index 8aa129ba4..75fe95868 100644 --- a/internal/controller/rotators/aws_oidc_rotator_test.go +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -78,7 +78,7 @@ func createClientSecret(t *testing.T, name string) { } // MockSTSOperations implements the STSClient interface for testing -type MockSTSOperations struct { +type mockSTSOperations struct { assumeRoleWithWebIdentityFunc func(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) } From 807f8b435e873f844527ef15e5062ad5d35b5223 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Tue, 18 Feb 2025 11:04:05 -0500 Subject: [PATCH 83/86] add explanation for AI_GATEWAY_STS_PROXY_URL Signed-off-by: Aaron Choo --- manifests/charts/ai-gateway-helm/values.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/manifests/charts/ai-gateway-helm/values.yaml b/manifests/charts/ai-gateway-helm/values.yaml index 860edf5b0..f54118e08 100644 --- a/manifests/charts/ai-gateway-helm/values.yaml +++ b/manifests/charts/ai-gateway-helm/values.yaml @@ -41,6 +41,7 @@ controller: # Example of a podEnv # - key: AI_GATEWAY_STS_PROXY_URL # value: some-proxy-placeholder + # AWS STS request when rotating OIDC credentials will be configured to use AI_GATEWAY_STS_PROXY_URL proxy if set. podEnv: {} # Example of volumes # - mountPath: /placeholder/path From ce36ffb128ccf9cfbecc8692976b5b541645773b Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Tue, 18 Feb 2025 11:10:22 -0500 Subject: [PATCH 84/86] preallocate map Signed-off-by: Aaron Choo --- internal/controller/oauth/oidc_provider.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/controller/oauth/oidc_provider.go b/internal/controller/oauth/oidc_provider.go index d563a74aa..c36ad9495 100644 --- a/internal/controller/oauth/oidc_provider.go +++ b/internal/controller/oauth/oidc_provider.go @@ -78,7 +78,7 @@ func (p *OIDCProvider) FetchToken(ctx context.Context) (*oauth2.Token, error) { // Add discovered scopes if available. if len(supportedScopes) > 0 { - requestedScopes := make(map[string]bool) + requestedScopes := make(map[string]bool, len(p.oidcConfig.Scopes)) for _, scope := range p.oidcConfig.Scopes { requestedScopes[scope] = true } From 0727c2bc8dfd5182ed826a212d4b7807f1638a34 Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Tue, 18 Feb 2025 11:12:15 -0500 Subject: [PATCH 85/86] update mockSTSOperations to lower Signed-off-by: Aaron Choo --- internal/controller/rotators/aws_oidc_rotator_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/controller/rotators/aws_oidc_rotator_test.go b/internal/controller/rotators/aws_oidc_rotator_test.go index 75fe95868..db77f2141 100644 --- a/internal/controller/rotators/aws_oidc_rotator_test.go +++ b/internal/controller/rotators/aws_oidc_rotator_test.go @@ -82,7 +82,7 @@ type mockSTSOperations struct { assumeRoleWithWebIdentityFunc func(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) } -func (m *MockSTSOperations) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { +func (m *mockSTSOperations) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { if m.assumeRoleWithWebIdentityFunc != nil { return m.assumeRoleWithWebIdentityFunc(ctx, params, optFns...) } @@ -91,7 +91,7 @@ func (m *MockSTSOperations) AssumeRoleWithWebIdentity(ctx context.Context, param func TestAWS_OIDCRotator(t *testing.T) { t.Run("basic rotation", func(t *testing.T) { - var mockSTS STSClient = &MockSTSOperations{ + var mockSTS STSClient = &mockSTSOperations{ assumeRoleWithWebIdentityFunc: func(_ context.Context, _ *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { return &sts.AssumeRoleWithWebIdentityOutput{ Credentials: &types.Credentials{ @@ -133,7 +133,7 @@ func TestAWS_OIDCRotator(t *testing.T) { cl := fake.NewClientBuilder().WithScheme(scheme).Build() createTestAWSSecret(t, cl, "test-secret", "OLDKEY", "OLDSECRET", "OLDTOKEN", "default") createClientSecret(t, "test-client-secret") - var mockSTS STSClient = &MockSTSOperations{ + var mockSTS STSClient = &mockSTSOperations{ assumeRoleWithWebIdentityFunc: func(_ context.Context, _ *sts.AssumeRoleWithWebIdentityInput, _ ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { return nil, fmt.Errorf("failed to assume role") }, From 55c243c90f066803fadb23d8a826656f15adc39a Mon Sep 17 00:00:00 2001 From: Aaron Choo Date: Tue, 18 Feb 2025 11:15:02 -0500 Subject: [PATCH 86/86] update comment Signed-off-by: Aaron Choo --- manifests/charts/ai-gateway-helm/values.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manifests/charts/ai-gateway-helm/values.yaml b/manifests/charts/ai-gateway-helm/values.yaml index f54118e08..981884b30 100644 --- a/manifests/charts/ai-gateway-helm/values.yaml +++ b/manifests/charts/ai-gateway-helm/values.yaml @@ -39,9 +39,9 @@ controller: podSecurityContext: {} securityContext: {} # Example of a podEnv + # AWS STS request when rotating OIDC credentials will be configured to use AI_GATEWAY_STS_PROXY_URL proxy if set. # - key: AI_GATEWAY_STS_PROXY_URL # value: some-proxy-placeholder - # AWS STS request when rotating OIDC credentials will be configured to use AI_GATEWAY_STS_PROXY_URL proxy if set. podEnv: {} # Example of volumes # - mountPath: /placeholder/path