Skip to content

Commit

Permalink
refactor: remove duplicate artifact downloading
Browse files Browse the repository at this point in the history
Signed-off-by: nikpivkin <[email protected]>
  • Loading branch information
nikpivkin committed Sep 30, 2024
1 parent d75a584 commit b08004a
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 14 deletions.
35 changes: 22 additions & 13 deletions pkg/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,8 @@ func (c *Client) updateDownloadedAt(ctx context.Context, dbDir string) error {
}

func (c *Client) initOCIArtifact(repository name.Reference, opt types.RegistryOptions) (*oci.Artifact, error) {
if c.artifact != nil {
return c.artifact, nil
}

art, err := oci.NewArtifact(repository.String(), c.quiet, opt)
// TODO: NewArtifact never returns an error
if err != nil {
var terr *transport.Error
if errors.As(err, &terr) {
Expand All @@ -221,28 +218,40 @@ func (c *Client) initOCIArtifact(repository name.Reference, opt types.RegistryOp
return art, nil
}

func (c *Client) downloadDB(ctx context.Context, opt types.RegistryOptions, dst string) error {
downloadOpt := oci.DownloadOption{MediaType: dbMediaType}
func (c *Client) initArtifacts(opt types.RegistryOptions) ([]*oci.Artifact, error) {
if c.artifact != nil {
return c.artifact.Download(ctx, dst, downloadOpt)
return []*oci.Artifact{c.artifact}, nil
}

for i, repo := range c.dbRepositories {
artifacts := make([]*oci.Artifact, 0, len(c.dbRepositories))

for _, repo := range c.dbRepositories {
a, err := c.initOCIArtifact(repo, opt)
if err != nil {
return err
return nil, err
}
artifacts = append(artifacts, a)
}
return artifacts, nil
}

func (c *Client) downloadDB(ctx context.Context, opt types.RegistryOptions, dst string) error {
arts, err := c.initArtifacts(opt)
if err != nil {
return err
}

log.Info("Downloading vulnerability DB...", log.String("repo", repo.String()))
if err := a.Download(ctx, dst, downloadOpt); err != nil {
log.Error("Failed to download DB", log.String("repo", repo.String()), log.Err(err))
for i, art := range arts {
log.Info("Downloading vulnerability DB...", log.String("repo", art.Repository()))
if err := art.Download(ctx, dst, oci.DownloadOption{MediaType: dbMediaType}); err != nil {
log.Error("Failed to download DB", log.String("repo", art.Repository()), log.Err(err))
if i < len(c.dbRepositories)-1 {
log.Info("Trying to download DB from other repository...")
}
continue
}

log.Info("DB successfully downloaded", log.String("repo", repo.String()))
log.Info("DB successfully downloaded", log.String("repo", art.Repository()))
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func TestClient_Download(t *testing.T) {
{
name: "invalid gzip",
input: "testdata/trivy.db",
wantErr: "unexpected EOF",
wantErr: "OCI artifact error: failed to download vulnerability DB from any source",
},
}

Expand Down
4 changes: 4 additions & 0 deletions pkg/oci/artifact.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,7 @@ func (a *Artifact) Digest(ctx context.Context) (string, error) {
}
return digest.String(), nil
}

func (a *Artifact) Repository() string {
return a.repository
}

0 comments on commit b08004a

Please sign in to comment.