Skip to content

Commit

Permalink
add apigateway authorizer adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasdembelli committed Jan 13, 2025
1 parent a5b9e53 commit ca3862d
Show file tree
Hide file tree
Showing 7 changed files with 315 additions and 0 deletions.
128 changes: 128 additions & 0 deletions adapters/apigateway-authorizer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package adapters

import (
"context"
"fmt"
"strings"

"github.com/aws/aws-sdk-go-v2/service/apigateway"
"github.com/aws/aws-sdk-go-v2/service/apigateway/types"
"github.com/overmindtech/aws-source/adapterhelpers"
"github.com/overmindtech/sdp-go"
)

// convertGetAuthorizerOutputToAuthorizer converts a GetAuthorizerOutput to an Authorizer
func convertGetAuthorizerOutputToAuthorizer(output *apigateway.GetAuthorizerOutput) *types.Authorizer {
return &types.Authorizer{
Id: output.Id,
Name: output.Name,
Type: output.Type,
ProviderARNs: output.ProviderARNs,
AuthType: output.AuthType,
AuthorizerUri: output.AuthorizerUri,
AuthorizerCredentials: output.AuthorizerCredentials,
IdentitySource: output.IdentitySource,
IdentityValidationExpression: output.IdentityValidationExpression,
AuthorizerResultTtlInSeconds: output.AuthorizerResultTtlInSeconds,
}
}

func authorizerOutputMapper(scope string, awsItem *types.Authorizer) (*sdp.Item, error) {
attributes, err := adapterhelpers.ToAttributesWithExclude(awsItem, "tags")
if err != nil {
return nil, err
}

item := sdp.Item{
Type: "apigateway-authorizer",
UniqueAttribute: "Id",
Attributes: attributes,
Scope: scope,
}

return &item, nil
}

func NewAPIGatewayAuthorizerAdapter(client *apigateway.Client, accountID string, region string) *adapterhelpers.GetListAdapter[*types.Authorizer, *apigateway.Client, *apigateway.Options] {
return &adapterhelpers.GetListAdapter[*types.Authorizer, *apigateway.Client, *apigateway.Options]{
ItemType: "apigateway-authorizer",
Client: client,
AccountID: accountID,
Region: region,
AdapterMetadata: authorizerAdapterMetadata,
GetFunc: func(ctx context.Context, client *apigateway.Client, scope, query string) (*types.Authorizer, error) {
f := strings.Split(query, "/")
if len(f) != 2 {
return nil, &sdp.QueryError{
ErrorType: sdp.QueryError_NOTFOUND,
ErrorString: fmt.Sprintf("query must be in the format of: the rest-api-id/authorizer-id, but found: %s", query),
}
}
out, err := client.GetAuthorizer(ctx, &apigateway.GetAuthorizerInput{
RestApiId: &f[0],
AuthorizerId: &f[1],
})
if err != nil {
return nil, err
}
return convertGetAuthorizerOutputToAuthorizer(out), nil
},
DisableList: true,
SearchFunc: func(ctx context.Context, client *apigateway.Client, scope string, query string) ([]*types.Authorizer, error) {
f := strings.Split(query, "/")
var restAPIID string
var name string

switch len(f) {
case 1:
restAPIID = f[0]
case 2:
restAPIID = f[0]
name = f[1]
default:
return nil, &sdp.QueryError{
ErrorType: sdp.QueryError_NOTFOUND,
ErrorString: fmt.Sprintf(
"query must be in the format of: the rest-api-id/authorizer-id or rest-api-id, but found: %s",
query,
),
}
}

out, err := client.GetAuthorizers(ctx, &apigateway.GetAuthorizersInput{
RestApiId: &restAPIID,
})
if err != nil {
return nil, err
}

var items []*types.Authorizer
for _, authorizer := range out.Items {
if name != "" && strings.Contains(*authorizer.Name, name) {
items = append(items, &authorizer)
} else {
items = append(items, &authorizer)
}
}

return items, nil
},
ItemMapper: func(_, scope string, awsItem *types.Authorizer) (*sdp.Item, error) {
return authorizerOutputMapper(scope, awsItem)
},
}
}

var authorizerAdapterMetadata = Metadata.Register(&sdp.AdapterMetadata{
Type: "apigateway-authorizer",
DescriptiveName: "API Gateway Authorizer",
Category: sdp.AdapterCategory_ADAPTER_CATEGORY_SECURITY,
SupportedQueryMethods: &sdp.AdapterSupportedQueryMethods{
Get: true,
List: true,
Search: true,
GetDescription: "Get an API Gateway Authorizer by its rest API ID and ID: rest-api-id/authorizer-id",
ListDescription: "List all API Gateway Authorizers",
SearchDescription: "Search for API Gateway Authorizers by their rest API ID or with rest API ID and their name: rest-api-id/authorizer-name",
},
})
51 changes: 51 additions & 0 deletions adapters/apigateway-authorizer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package adapters

import (
"testing"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/apigateway"
"github.com/aws/aws-sdk-go-v2/service/apigateway/types"
"github.com/overmindtech/aws-source/adapterhelpers"
)

func TestAuthorizerOutputMapper(t *testing.T) {
awsItem := &types.Authorizer{
Id: aws.String("authorizer-id"),
Name: aws.String("authorizer-name"),
Type: types.AuthorizerTypeRequest,
ProviderARNs: []string{"arn:aws:iam::123456789012:role/service-role"},
AuthType: aws.String("custom"),
AuthorizerUri: aws.String("arn:aws:apigateway:us-east-1:lambda:path/2015-03-31/functions/arn:aws:lambda:us-east-1:123456789012:function:my-function/invocations"),
AuthorizerCredentials: aws.String("arn:aws:iam::123456789012:role/service-role"),
IdentitySource: aws.String("method.request.header.Authorization"),
IdentityValidationExpression: aws.String(".*"),
AuthorizerResultTtlInSeconds: aws.Int32(300),
}

item, err := authorizerOutputMapper("scope", awsItem)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if err := item.Validate(); err != nil {
t.Error(err)
}
}

func TestNewAPIGatewayAuthorizerAdapter(t *testing.T) {
config, account, region := adapterhelpers.GetAutoConfig(t)

client := apigateway.NewFromConfig(config)

adapter := NewAPIGatewayAuthorizerAdapter(client, account, region)

test := adapterhelpers.E2ETest{
Adapter: adapter,
Timeout: 10 * time.Second,
SkipList: true,
}

test.Run(t)
}
75 changes: 75 additions & 0 deletions adapters/integration/apigateway/apigateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ func APIGateway(t *testing.T) {
t.Fatalf("failed to validate APIGateway API key adapter: %v", err)
}

authorizerSource := adapters.NewAPIGatewayAuthorizerAdapter(testClient, accountID, testAWSConfig.Region)

err = authorizerSource.Validate()
if err != nil {
t.Fatalf("failed to validate APIGateway authorizer adapter: %v", err)
}

scope := adapterhelpers.FormatScope(accountID, testAWSConfig.Region)

// List restApis
Expand Down Expand Up @@ -301,5 +308,73 @@ func APIGateway(t *testing.T) {
t.Fatalf("expected API key ID %s, got %s", apiKeyID, apiKeyIDFromSearch)
}

// Search authorizers by restApiID
authorizers, err := authorizerSource.Search(ctx, scope, restApiID, true)
if err != nil {
t.Fatalf("failed to search APIGateway authorizers: %v", err)
}

authorizerUniqueAttribute := authorizers[0].GetUniqueAttribute()

authorizerTestName := integration.ResourceName(integration.APIGateway, authorizerSrc, integration.TestID())
authorizerID, err := integration.GetUniqueAttributeValueBySignificantAttribute(
authorizerUniqueAttribute,
"Name",
authorizerTestName,
authorizers,
true,
)
if err != nil {
t.Fatalf("failed to get authorizer ID: %v", err)
}

// Get authorizer
query := fmt.Sprintf("%s/%s", restApiID, authorizerID)
authorizer, err := authorizerSource.Get(ctx, scope, query, true)
if err != nil {
t.Fatalf("failed to get APIGateway authorizer: %v", err)
}

authorizerIDFromGet, err := integration.GetUniqueAttributeValueBySignificantAttribute(
authorizerUniqueAttribute,
"Name",
authorizerTestName,
[]*sdp.Item{authorizer},
true,
)
if err != nil {
t.Fatalf("failed to get authorizer ID from get: %v", err)
}

if authorizerID != authorizerIDFromGet {
t.Fatalf("expected authorizer ID %s, got %s", authorizerID, authorizerIDFromGet)
}

// Search authorizer by restApiID/name
query = fmt.Sprintf("%s/%s", restApiID, authorizerTestName)
authorizersFromSearch, err := authorizerSource.Search(ctx, scope, query, true)
if err != nil {
t.Fatalf("failed to search APIGateway authorizers: %v", err)
}

if len(authorizersFromSearch) == 0 {
t.Fatalf("no authorizers found")
}

authorizerIDFromSearch, err := integration.GetUniqueAttributeValueBySignificantAttribute(
authorizerUniqueAttribute,
"Name",
authorizerTestName,
authorizersFromSearch,
true,
)
if err != nil {
t.Fatalf("failed to get authorizer ID from search: %v", err)
}

if authorizerID != authorizerIDFromSearch {
t.Fatalf("expected authorizer ID %s, got %s", authorizerID, authorizerIDFromSearch)
}

t.Log("APIGateway integration test completed")
}
32 changes: 32 additions & 0 deletions adapters/integration/apigateway/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package apigateway
import (
"context"
"errors"
"github.com/aws/aws-sdk-go-v2/service/apigateway/types"
"log/slog"
"strings"

Expand Down Expand Up @@ -196,3 +197,34 @@ func createAPIKey(ctx context.Context, logger *slog.Logger, client *apigateway.C

return nil
}

func createAuthorizer(ctx context.Context, logger *slog.Logger, client *apigateway.Client, restAPIID, testID string) error {
// check if an authorizer with the same name already exists
id, err := findAuthorizerByName(ctx, client, restAPIID, integration.ResourceName(integration.APIGateway, authorizerSrc, testID))
if err != nil {
if errors.As(err, new(integration.NotFoundError)) {
logger.InfoContext(ctx, "Creating authorizer")
} else {
return err
}
}

if id != nil {
logger.InfoContext(ctx, "Authorizer already exists")
return nil
}

identitySource := "method.request.header.Authorization"
_, err = client.CreateAuthorizer(ctx, &apigateway.CreateAuthorizerInput{
RestApiId: &restAPIID,
Name: adapterhelpers.PtrString(integration.ResourceName(integration.APIGateway, authorizerSrc, testID)),
Type: types.AuthorizerTypeToken,
IdentitySource: &identitySource,
AuthorizerUri: adapterhelpers.PtrString("arn:aws:apigateway:us-east-1:lambda:path/2015-03-31/functions/arn:aws:lambda:us-east-1:123456789012:function:auth-function/invocations"),
})
if err != nil {
return err
}

return nil
}
21 changes: 21 additions & 0 deletions adapters/integration/apigateway/find.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,24 @@ func findAPIKeyByName(ctx context.Context, client *apigateway.Client, name strin

return nil, integration.NewNotFoundError(integration.ResourceName(integration.APIGateway, apiKeySrc, name))
}

func findAuthorizerByName(ctx context.Context, client *apigateway.Client, restAPIID, name string) (*string, error) {
result, err := client.GetAuthorizers(ctx, &apigateway.GetAuthorizersInput{
RestApiId: &restAPIID,
})
if err != nil {
return nil, err
}

if len(result.Items) == 0 {
return nil, integration.NewNotFoundError(integration.ResourceName(integration.APIGateway, authorizerSrc, name))
}

for _, authorizer := range result.Items {
if *authorizer.Name == name {
return authorizer.Id, nil
}
}

return nil, integration.NewNotFoundError(integration.ResourceName(integration.APIGateway, authorizerSrc, name))
}
7 changes: 7 additions & 0 deletions adapters/integration/apigateway/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ const (
methodResponseSrc = "method-response"
integrationSrc = "integration"
apiKeySrc = "api-key"
authorizerSrc = "authorizer"
)

func setup(ctx context.Context, logger *slog.Logger, client *apigateway.Client) error {
Expand Down Expand Up @@ -62,5 +63,11 @@ func setup(ctx context.Context, logger *slog.Logger, client *apigateway.Client)
return err
}

// Create Authorizer
err = createAuthorizer(ctx, logger, client, *restApiID, testID)
if err != nil {
return err
}

return nil
}
1 change: 1 addition & 0 deletions proc/proc.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ func InitializeAwsSourceEngine(ctx context.Context, ec *discovery.EngineConfig,
adapters.NewAPIGatewayIntegrationAdapter(apigatewayClient, *callerID.Account, cfg.Region),
adapters.NewAPIGatewayVpcLinkAdapter(apigatewayClient, *callerID.Account, cfg.Region),
adapters.NewAPIGatewayApiKeyAdapter(apigatewayClient, *callerID.Account, cfg.Region),
adapters.NewAPIGatewayAuthorizerAdapter(apigatewayClient, *callerID.Account, cfg.Region),

// SSM
adapters.NewSSMParameterAdapter(ssmClient, *callerID.Account, cfg.Region),
Expand Down

0 comments on commit ca3862d

Please sign in to comment.