Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(aws): handle ECR repositories in different regions #6217

Merged
merged 1 commit into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions pkg/fanal/image/registry/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,19 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"golang.org/x/xerrors"

"github.com/aquasecurity/trivy/pkg/fanal/image/registry/intf"
"github.com/aquasecurity/trivy/pkg/fanal/types"
)

type Registry struct {
type RegistryClient struct {
domain string
scope string
cloud cloud.Configuration
}

type Registry struct {
}

const (
azureURL = ".azurecr.io"
chinaAzureURL = ".azurecr.cn"
Expand All @@ -31,23 +35,25 @@ const (
scheme = "https"
)

func (r *Registry) CheckOptions(domain string, _ types.RegistryOptions) error {
func (r *Registry) CheckOptions(domain string, _ types.RegistryOptions) (intf.RegistryClient, error) {
if strings.HasSuffix(domain, azureURL) {
r.domain = domain
r.scope = scope
r.cloud = cloud.AzurePublic
return nil
return &RegistryClient{
domain: domain,
scope: scope,
cloud: cloud.AzurePublic,
}, nil
} else if strings.HasSuffix(domain, chinaAzureURL) {
r.domain = domain
r.scope = chinaScope
r.cloud = cloud.AzureChina
return nil
return &RegistryClient{
domain: domain,
scope: scope,
cloud: cloud.AzureChina,
}, nil
}

return xerrors.Errorf("Azure registry: %w", types.InvalidURLPattern)
return nil, xerrors.Errorf("Azure registry: %w", types.InvalidURLPattern)
}

func (r *Registry) GetCredential(ctx context.Context) (string, string, error) {
func (r *RegistryClient) GetCredential(ctx context.Context) (string, string, error) {
opts := azcore.ClientOptions{Cloud: r.cloud}
cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{ClientOptions: opts})
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/fanal/image/registry/azure/azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestRegistry_CheckOptions(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := azure.Registry{}
err := r.CheckOptions(tt.domain, types.RegistryOptions{})
_, err := r.CheckOptions(tt.domain, types.RegistryOptions{})
if tt.wantErr != "" {
assert.EqualError(t, err, tt.wantErr)
} else {
Expand Down
54 changes: 40 additions & 14 deletions pkg/fanal/image/registry/ecr/ecr.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ecr
import (
"context"
"encoding/base64"
"regexp"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
Expand All @@ -11,48 +12,73 @@ import (
"github.com/aws/aws-sdk-go-v2/service/ecr"
"golang.org/x/xerrors"

"github.com/aquasecurity/trivy/pkg/fanal/image/registry/intf"
"github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
)

const ecrURLSuffix = ".amazonaws.com"
const ecrURLPartial = ".dkr.ecr"

type ecrAPI interface {
GetAuthorizationToken(ctx context.Context, params *ecr.GetAuthorizationTokenInput, optFns ...func(*ecr.Options)) (*ecr.GetAuthorizationTokenOutput, error)
}

type ECR struct {
}

type ECRClient struct {
Client ecrAPI
}

func getSession(option types.RegistryOptions) (aws.Config, error) {
func getSession(domain, region string, option types.RegistryOptions) (aws.Config, error) {
// create custom credential information if option is valid
if option.AWSSecretKey != "" && option.AWSAccessKey != "" && option.AWSRegion != "" {
if region != option.AWSRegion {
log.Warnf("The region from AWS_REGION (%s) is being overridden. The region from domain (%s) was used.", option.AWSRegion, domain)
}
return config.LoadDefaultConfig(
context.TODO(),
config.WithRegion(option.AWSRegion),
config.WithRegion(region),
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(option.AWSAccessKey, option.AWSSecretKey, option.AWSSessionToken)),
)
}
return config.LoadDefaultConfig(context.TODO())
return config.LoadDefaultConfig(context.TODO(), config.WithRegion(region))
}

func (e *ECR) CheckOptions(domain string, option types.RegistryOptions) error {
if !strings.HasSuffix(domain, ecrURLSuffix) && !strings.Contains(domain, ecrURLPartial) {
return xerrors.Errorf("ECR : %w", types.InvalidURLPattern)
func (e *ECR) CheckOptions(domain string, option types.RegistryOptions) (intf.RegistryClient, error) {
region := determineRegion(domain)
if region == "" {
return nil, xerrors.Errorf("ECR : %w", types.InvalidURLPattern)
}

cfg, err := getSession(option)
cfg, err := getSession(domain, region, option)
if err != nil {
return err
return nil, err
}

svc := ecr.NewFromConfig(cfg)
e.Client = svc
return nil
return &ECRClient{Client: svc}, nil
}

// Endpoints take the form
// <registry-id>.dkr.ecr.<region>.amazonaws.com
// <registry-id>.dkr.ecr-fips.<region>.amazonaws.com
// <registry-id>.dkr.ecr.<region>.amazonaws.com.cn
// <registry-id>.dkr.ecr.<region>.sc2s.sgov.gov
// <registry-id>.dkr.ecr.<region>.c2s.ic.gov
// see
// - https://docs.aws.amazon.com/general/latest/gr/ecr.html
// - https://docs.amazonaws.cn/en_us/aws/latest/userguide/endpoints-arns.html
// - https://github.com/boto/botocore/blob/1.34.51/botocore/data/endpoints.json
var ecrEndpointMatch = regexp.MustCompile(`^[^.]+\.dkr\.ecr(?:-fips)?\.([^.]+)\.(?:amazonaws\.com(?:\.cn)?|sc2s\.sgov\.gov|c2s\.ic\.gov)$`)

func determineRegion(domain string) string {
matches := ecrEndpointMatch.FindStringSubmatch(domain)
if matches != nil {
return matches[1]
}
return ""
}

func (e *ECR) GetCredential(ctx context.Context) (username, password string, err error) {
func (e *ECRClient) GetCredential(ctx context.Context) (username, password string, err error) {
input := &ecr.GetAuthorizationTokenInput{}
result, err := e.Client.GetAuthorizationToken(ctx, input)
if err != nil {
Expand Down
68 changes: 63 additions & 5 deletions pkg/fanal/image/registry/ecr/ecr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,20 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ecr"
awstypes "github.com/aws/aws-sdk-go-v2/service/ecr/types"
"github.com/stretchr/testify/require"

"github.com/aquasecurity/trivy/pkg/fanal/types"
)

type testECRClient interface {
Options() ecr.Options
}

func TestCheckOptions(t *testing.T) {
var tests = map[string]struct {
domain string
wantErr error
domain string
expectedRegion string
wantErr error
}{
"InvalidURL": {
domain: "alpine:3.9",
Expand All @@ -30,19 +36,71 @@ func TestCheckOptions(t *testing.T) {
wantErr: types.InvalidURLPattern,
},
"NoOption": {
domain: "xxx.ecr.ap-northeast-1.amazonaws.com",
domain: "xxx.dkr.ecr.ap-northeast-1.amazonaws.com",
expectedRegion: "ap-northeast-1",
},
"region-1": {
domain: "xxx.dkr.ecr.region-1.amazonaws.com",
expectedRegion: "region-1",
},
"region-2": {
domain: "xxx.dkr.ecr.region-2.amazonaws.com",
expectedRegion: "region-2",
},
"fips-region-1": {
domain: "xxx.dkr.ecr-fips.fips-region.amazonaws.com",
expectedRegion: "fips-region",
},
"cn-region-1": {
domain: "xxx.dkr.ecr.region-1.amazonaws.com.cn",
expectedRegion: "region-1",
},
"cn-region-2": {
domain: "xxx.dkr.ecr.region-2.amazonaws.com.cn",
expectedRegion: "region-2",
},
"sc2s-region-1": {
domain: "xxx.dkr.ecr.sc2s-region.sc2s.sgov.gov",
expectedRegion: "sc2s-region",
},
"c2s-region-1": {
domain: "xxx.dkr.ecr.c2s-region.c2s.ic.gov",
expectedRegion: "c2s-region",
},
"invalid-ecr": {
domain: "xxx.dkrecr.region-1.amazonaws.com",
wantErr: types.InvalidURLPattern,
},
"invalid-fips": {
domain: "xxx.dkr.ecrfips.fips-region.amazonaws.com",
wantErr: types.InvalidURLPattern,
},
"invalid-cn": {
domain: "xxx.dkr.ecr.region-2.amazonaws.cn",
wantErr: types.InvalidURLPattern,
},
"invalid-sc2s": {
domain: "xxx.dkr.ecr.sc2s-region.sc2s.sgov",
wantErr: types.InvalidURLPattern,
},
"invalid-cs2": {
domain: "xxx.dkr.ecr.c2s-region.c2s.ic",
wantErr: types.InvalidURLPattern,
},
}

for testname, v := range tests {
a := &ECR{}
err := a.CheckOptions(v.domain, types.RegistryOptions{})
ecrClient, err := a.CheckOptions(v.domain, types.RegistryOptions{})
if err != nil {
if !errors.Is(err, v.wantErr) {
t.Errorf("[%s]\nexpected error based on %v\nactual : %v", testname, v.wantErr, err)
}
continue
}

client := (ecrClient.(*ECRClient)).Client.(testECRClient)
require.Equal(t, v.expectedRegion, client.Options().Region)
}
}

Expand Down Expand Up @@ -90,7 +148,7 @@ func TestECRGetCredential(t *testing.T) {
}

for i, c := range cases {
e := ECR{
e := ECRClient{
Client: mockedECR{Resp: c.Resp},
}
username, password, err := e.GetCredential(context.Background())
Expand Down
18 changes: 11 additions & 7 deletions pkg/fanal/image/registry/google/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,37 @@ import (
"github.com/GoogleCloudPlatform/docker-credential-gcr/store"
"golang.org/x/xerrors"

"github.com/aquasecurity/trivy/pkg/fanal/image/registry/intf"
"github.com/aquasecurity/trivy/pkg/fanal/types"
)

type Registry struct {
type GoogleRegistryClient struct {
Store store.GCRCredStore
domain string
}

type Registry struct {
}

// Google container registry
const gcrURLDomain = "gcr.io"
const gcrURLSuffix = ".gcr.io"

// Google artifact registry
const garURLSuffix = "-docker.pkg.dev"

func (g *Registry) CheckOptions(domain string, option types.RegistryOptions) error {
func (g *Registry) CheckOptions(domain string, option types.RegistryOptions) (intf.RegistryClient, error) {
if domain != gcrURLDomain && !strings.HasSuffix(domain, gcrURLSuffix) && !strings.HasSuffix(domain, garURLSuffix) {
return xerrors.Errorf("Google registry: %w", types.InvalidURLPattern)
return nil, xerrors.Errorf("Google registry: %w", types.InvalidURLPattern)
}
g.domain = domain
client := GoogleRegistryClient{domain: domain}
if option.GCPCredPath != "" {
g.Store = store.NewGCRCredStore(option.GCPCredPath)
client.Store = store.NewGCRCredStore(option.GCPCredPath)
}
return nil
return &client, nil
}

func (g *Registry) GetCredential(_ context.Context) (username, password string, err error) {
func (g *GoogleRegistryClient) GetCredential(_ context.Context) (username, password string, err error) {
var credStore store.GCRCredStore
if g.Store == nil {
credStore, err = store.DefaultGCRCredStore()
Expand Down
12 changes: 6 additions & 6 deletions pkg/fanal/image/registry/google/google_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func TestCheckOptions(t *testing.T) {
var tests = map[string]struct {
domain string
opt types.RegistryOptions
gcr *Registry
grc *GoogleRegistryClient
wantErr error
}{
"InvalidURL": {
Expand All @@ -27,12 +27,12 @@ func TestCheckOptions(t *testing.T) {
},
"NoOption": {
domain: "gcr.io",
gcr: &Registry{domain: "gcr.io"},
grc: &GoogleRegistryClient{domain: "gcr.io"},
},
"CredOption": {
domain: "gcr.io",
opt: types.RegistryOptions{GCPCredPath: "/path/to/file.json"},
gcr: &Registry{
DmitriyLewen marked this conversation as resolved.
Show resolved Hide resolved
grc: &GoogleRegistryClient{
domain: "gcr.io",
Store: store.NewGCRCredStore("/path/to/file.json"),
},
Expand All @@ -41,7 +41,7 @@ func TestCheckOptions(t *testing.T) {

for testname, v := range tests {
g := &Registry{}
err := g.CheckOptions(v.domain, v.opt)
grc, err := g.CheckOptions(v.domain, v.opt)
if v.wantErr != nil {
if err == nil {
t.Errorf("%s : expected error but no error", testname)
Expand All @@ -52,8 +52,8 @@ func TestCheckOptions(t *testing.T) {
}
continue
}
if !reflect.DeepEqual(v.gcr, g) {
t.Errorf("[%s]\nexpected : %v\nactual : %v", testname, v.gcr, g)
if !reflect.DeepEqual(v.grc, grc) {
t.Errorf("[%s]\nexpected : %v\nactual : %v", testname, v.grc, grc)
}
}
}
15 changes: 15 additions & 0 deletions pkg/fanal/image/registry/intf/registry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package intf

import (
"context"

"github.com/aquasecurity/trivy/pkg/fanal/types"
)

type RegistryClient interface {
GetCredential(ctx context.Context) (string, string, error)
}

type Registry interface {
CheckOptions(domain string, option types.RegistryOptions) (RegistryClient, error)
}
Loading