diff --git a/changelog/24925.txt b/changelog/24925.txt new file mode 100644 index 000000000000..7bce8d0bdebc --- /dev/null +++ b/changelog/24925.txt @@ -0,0 +1,3 @@ +```release-note:improvement +sdk: Add identity token helpers to consistently apply new plugin WIF fields across integrations. +``` \ No newline at end of file diff --git a/sdk/helper/pluginidentityutil/fields.go b/sdk/helper/pluginidentityutil/fields.go new file mode 100644 index 000000000000..27a692b10421 --- /dev/null +++ b/sdk/helper/pluginidentityutil/fields.go @@ -0,0 +1,70 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package pluginidentityutil + +import ( + "errors" + "fmt" + "time" + + "github.com/hashicorp/vault/sdk/framework" +) + +// PluginIdentityTokenParams contains a set of common parameters that plugins +// can use for setting plugin identity token behavior. +type PluginIdentityTokenParams struct { + // IdentityTokenTTL is the duration that tokens will be valid for + IdentityTokenTTL time.Duration `json:"identity_token_ttl"` + // IdentityTokenAudience identifies the recipient of the token + IdentityTokenAudience string `json:"identity_token_audience"` +} + +// ParsePluginIdentityTokenFields provides common field parsing to embedding structs. +func (p *PluginIdentityTokenParams) ParsePluginIdentityTokenFields(d *framework.FieldData) error { + if tokenTTLRaw, ok := d.GetOk("identity_token_ttl"); ok { + p.IdentityTokenTTL = time.Duration(tokenTTLRaw.(int)) * time.Second + } + if p.IdentityTokenTTL == 0 { + p.IdentityTokenTTL = time.Hour + } + + if tokenAudienceRaw, ok := d.GetOk("identity_token_audience"); ok { + p.IdentityTokenAudience = tokenAudienceRaw.(string) + } + if p.IdentityTokenAudience == "" { + return errors.New("missing required identity_token_audience") + } + + return nil +} + +// PopulatePluginIdentityTokenData adds PluginIdentityTokenParams info into the given map. +func (p *PluginIdentityTokenParams) PopulatePluginIdentityTokenData(m map[string]interface{}) { + m["identity_token_ttl"] = int64(p.IdentityTokenTTL.Seconds()) + m["identity_token_audience"] = p.IdentityTokenAudience +} + +// AddPluginIdentityTokenFields adds plugin identity token fields to the given +// field schema map. +func AddPluginIdentityTokenFields(m map[string]*framework.FieldSchema) { + fields := map[string]*framework.FieldSchema{ + "identity_token_audience": { + Type: framework.TypeString, + Description: "Audience of plugin identity tokens", + Default: "", + }, + "identity_token_ttl": { + Type: framework.TypeDurationSecond, + Description: "Time-to-live of plugin identity tokens", + Default: 3600, + }, + } + + for name, schema := range fields { + if _, ok := m[name]; ok { + panic(fmt.Sprintf("adding field %q would overwrite existing field", name)) + } + m[name] = schema + } +} diff --git a/sdk/helper/pluginidentityutil/fields_test.go b/sdk/helper/pluginidentityutil/fields_test.go new file mode 100644 index 000000000000..f66196b0d31f --- /dev/null +++ b/sdk/helper/pluginidentityutil/fields_test.go @@ -0,0 +1,172 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package pluginidentityutil + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/hashicorp/vault/sdk/framework" +) + +const ( + fieldIDTokenTTL = "identity_token_ttl" + fieldIDTokenAudience = "identity_token_audience" +) + +func identityTokenFieldData(raw map[string]interface{}) *framework.FieldData { + return &framework.FieldData{ + Raw: raw, + Schema: map[string]*framework.FieldSchema{ + fieldIDTokenTTL: { + Type: framework.TypeDurationSecond, + }, + fieldIDTokenAudience: { + Type: framework.TypeString, + }, + }, + } +} + +func TestParsePluginIdentityTokenFields(t *testing.T) { + testcases := []struct { + name string + d *framework.FieldData + wantErr bool + want map[string]interface{} + }{ + { + name: "basic", + d: identityTokenFieldData(map[string]interface{}{ + fieldIDTokenTTL: 10, + fieldIDTokenAudience: "test-aud", + }), + want: map[string]interface{}{ + fieldIDTokenTTL: time.Duration(10) * time.Second, + fieldIDTokenAudience: "test-aud", + }, + }, + { + name: "empty-ttl", + d: identityTokenFieldData(map[string]interface{}{ + fieldIDTokenAudience: "test-aud", + }), + want: map[string]interface{}{ + fieldIDTokenTTL: time.Hour, + fieldIDTokenAudience: "test-aud", + }, + }, + { + name: "empty-audience", + d: identityTokenFieldData(map[string]interface{}{}), + wantErr: true, + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + p := new(PluginIdentityTokenParams) + err := p.ParsePluginIdentityTokenFields(tt.d) + if tt.wantErr { + assert.Error(t, err) + return + } + got := map[string]interface{}{ + fieldIDTokenTTL: p.IdentityTokenTTL, + fieldIDTokenAudience: p.IdentityTokenAudience, + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestPopulatePluginIdentityTokenData(t *testing.T) { + testcases := []struct { + name string + p *PluginIdentityTokenParams + want map[string]interface{} + }{ + { + name: "basic", + p: &PluginIdentityTokenParams{ + IdentityTokenAudience: "test-aud", + IdentityTokenTTL: time.Duration(10) * time.Second, + }, + want: map[string]interface{}{ + fieldIDTokenTTL: int64(10), + fieldIDTokenAudience: "test-aud", + }, + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + got := make(map[string]interface{}) + tt.p.PopulatePluginIdentityTokenData(got) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestAddPluginIdentityTokenFields(t *testing.T) { + testcases := []struct { + name string + input map[string]*framework.FieldSchema + want map[string]*framework.FieldSchema + }{ + { + name: "basic", + input: map[string]*framework.FieldSchema{}, + want: map[string]*framework.FieldSchema{ + fieldIDTokenAudience: { + Type: framework.TypeString, + Description: "Audience of plugin identity tokens", + Default: "", + }, + fieldIDTokenTTL: { + Type: framework.TypeDurationSecond, + Description: "Time-to-live of plugin identity tokens", + Default: 3600, + }, + }, + }, + { + name: "additional-fields", + input: map[string]*framework.FieldSchema{ + "test": { + Type: framework.TypeString, + Description: "Test description", + Default: "default", + }, + }, + want: map[string]*framework.FieldSchema{ + fieldIDTokenAudience: { + Type: framework.TypeString, + Description: "Audience of plugin identity tokens", + Default: "", + }, + fieldIDTokenTTL: { + Type: framework.TypeDurationSecond, + Description: "Time-to-live of plugin identity tokens", + Default: 3600, + }, + "test": { + Type: framework.TypeString, + Description: "Test description", + Default: "default", + }, + }, + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + got := tt.input + AddPluginIdentityTokenFields(got) + assert.Equal(t, tt.want, got) + }) + } +}