diff --git a/physical/azure/azure.go b/physical/azure/azure.go index 99ce0a0d5336..08ace98f9e2b 100644 --- a/physical/azure/azure.go +++ b/physical/azure/azure.go @@ -12,17 +12,20 @@ import ( "time" storage "github.com/Azure/azure-sdk-for-go/storage" - log "github.com/hashicorp/go-hclog" - "github.com/armon/go-metrics" "github.com/hashicorp/errwrap" cleanhttp "github.com/hashicorp/go-cleanhttp" + log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/physical" ) -// MaxBlobSize at this time -var MaxBlobSize = 1024 * 1024 * 4 +const ( + // MaxBlobSize at this time + MaxBlobSize = 1024 * 1024 * 4 + // MaxListResults is the current default value, setting explicitly + MaxListResults = 5000 +) // AzureBackend is a physical backend that stores data // within an Azure blob container. @@ -180,22 +183,35 @@ func (a *AzureBackend) List(ctx context.Context, prefix string) ([]string, error defer metrics.MeasureSince([]string{"azure", "list"}, time.Now()) a.permitPool.Acquire() - list, err := a.container.ListBlobs(storage.ListBlobsParameters{Prefix: prefix}) - if err != nil { - // Break early. - a.permitPool.Release() - return nil, err - } - a.permitPool.Release() + defer a.permitPool.Release() + var marker string keys := []string{} - for _, blob := range list.Blobs { - key := strings.TrimPrefix(blob.Name, prefix) - if i := strings.Index(key, "/"); i == -1 { - keys = append(keys, key) - } else { - keys = strutil.AppendIfMissing(keys, key[:i+1]) + for { + list, err := a.container.ListBlobs(storage.ListBlobsParameters{ + Prefix: prefix, + Marker: marker, + MaxResults: MaxListResults, + }) + if err != nil { + return nil, err + } + + for _, blob := range list.Blobs { + key := strings.TrimPrefix(blob.Name, prefix) + if i := strings.Index(key, "/"); i == -1 { + // file + keys = append(keys, key) + } else { + // subdirectory + keys = strutil.AppendIfMissing(keys, key[:i+1]) + } + } + + if list.NextMarker == "" { + break } + marker = list.NextMarker } sort.Strings(keys) diff --git a/physical/azure/azure_test.go b/physical/azure/azure_test.go index dfcd8b941356..a2929b194f79 100644 --- a/physical/azure/azure_test.go +++ b/physical/azure/azure_test.go @@ -1,17 +1,18 @@ package azure import ( + "context" "fmt" "os" + "strconv" "testing" "time" + storage "github.com/Azure/azure-sdk-for-go/storage" cleanhttp "github.com/hashicorp/go-cleanhttp" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/helper/logging" "github.com/hashicorp/vault/physical" - - storage "github.com/Azure/azure-sdk-for-go/storage" ) func TestAzureBackend(t *testing.T) { @@ -50,3 +51,55 @@ func TestAzureBackend(t *testing.T) { physical.ExerciseBackend(t, backend) physical.ExerciseBackend_ListPrefix(t, backend) } + +func TestAzureBackend_ListPaging(t *testing.T) { + if os.Getenv("AZURE_ACCOUNT_NAME") == "" || + os.Getenv("AZURE_ACCOUNT_KEY") == "" { + t.SkipNow() + } + + accountName := os.Getenv("AZURE_ACCOUNT_NAME") + accountKey := os.Getenv("AZURE_ACCOUNT_KEY") + + ts := time.Now().UnixNano() + name := fmt.Sprintf("vault-test-%d", ts) + + cleanupClient, _ := storage.NewBasicClient(accountName, accountKey) + cleanupClient.HTTPClient = cleanhttp.DefaultPooledClient() + + logger := logging.NewVaultLogger(log.Debug) + + backend, err := NewAzureBackend(map[string]string{ + "container": name, + "accountName": accountName, + "accountKey": accountKey, + }, logger) + + defer func() { + blobService := cleanupClient.GetBlobService() + container := blobService.GetContainerReference(name) + container.DeleteIfExists(nil) + }() + + if err != nil { + t.Fatalf("err: %s", err) + } + + // by default, azure returns 5000 results in a page, load up more than that + for i := 0; i < MaxListResults+100; i++ { + if err := backend.Put(context.Background(), &physical.Entry{ + Key: strconv.Itoa(i), + Value: []byte(strconv.Itoa(i)), + }); err != nil { + t.Fatalf("err: %s", err) + } + } + + results, err := backend.List(context.Background(), "") + if err != nil { + t.Fatalf("err: %s", err) + } + if len(results) != MaxListResults+100 { + t.Fatalf("expected %d, got %d", MaxListResults+100, len(results)) + } +}