diff --git a/client/vaultclient/vaultclient.go b/client/vaultclient/vaultclient.go index 7fe7958ed53..af78eea641a 100644 --- a/client/vaultclient/vaultclient.go +++ b/client/vaultclient/vaultclient.go @@ -428,6 +428,7 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error { fatal := false if renewalErr != nil && (strings.Contains(renewalErr.Error(), "lease not found or lease is not renewable") || + strings.Contains(renewalErr.Error(), "lease is not renewable") || strings.Contains(renewalErr.Error(), "token not found") || strings.Contains(renewalErr.Error(), "permission denied")) { fatal = true diff --git a/client/vaultclient/vaultclient_test.go b/client/vaultclient/vaultclient_test.go index dc0b63aa85a..7322d33511b 100644 --- a/client/vaultclient/vaultclient_test.go +++ b/client/vaultclient/vaultclient_test.go @@ -3,6 +3,7 @@ package vaultclient import ( "log" "os" + "strings" "testing" "time" @@ -197,3 +198,85 @@ func TestVaultClient_Heap(t *testing.T) { } } + +func TestVaultClient_RenewNonRenewableLease(t *testing.T) { + t.Parallel() + v := testutil.NewTestVault(t) + defer v.Stop() + + logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) + v.Config.ConnectionRetryIntv = 100 * time.Millisecond + v.Config.TaskTokenTTL = "4s" + c, err := NewVaultClient(v.Config, logger, nil) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + + c.Start() + defer c.Stop() + + // Sleep a little while to ensure that the renewal loop is active + time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second) + + tcr := &vaultapi.TokenCreateRequest{ + Policies: []string{"foo", "bar"}, + TTL: "2s", + DisplayName: "derived-for-task", + Renewable: new(bool), + } + + c.client.SetToken(v.Config.Token) + + if err := c.client.SetAddress(v.Config.Addr); err != nil { + t.Fatal(err) + } + + secret, err := c.client.Auth().Token().Create(tcr) + if err != nil { + t.Fatalf("failed to create vault token: %v", err) + } + + if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { + t.Fatal("failed to derive a wrapped vault token") + } + + _, err = c.RenewToken(secret.Auth.ClientToken, secret.Auth.LeaseDuration) + if err == nil { + t.Fatalf("expected error, got nil") + } else if !strings.Contains(err.Error(), "lease is not renewable") { + t.Fatalf("expected \"%s\" in error message, got \"%v\"", "lease is not renewable", err) + } +} + +func TestVaultClient_RenewNonExistentLease(t *testing.T) { + t.Parallel() + v := testutil.NewTestVault(t) + defer v.Stop() + + logger := log.New(os.Stderr, "TEST: ", log.Lshortfile|log.LstdFlags) + v.Config.ConnectionRetryIntv = 100 * time.Millisecond + v.Config.TaskTokenTTL = "4s" + c, err := NewVaultClient(v.Config, logger, nil) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + + c.Start() + defer c.Stop() + + // Sleep a little while to ensure that the renewal loop is active + time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second) + + c.client.SetToken(v.Config.Token) + + if err := c.client.SetAddress(v.Config.Addr); err != nil { + t.Fatal(err) + } + + _, err = c.RenewToken(c.client.Token(), 10) + if err == nil { + t.Fatalf("expected error, got nil") + } else if !strings.Contains(err.Error(), "lease not found") { + t.Fatalf("expected \"%s\" in error message, got \"%v\"", "lease not found", err) + } +}