Skip to content

Commit

Permalink
sdk: add identity token helpers to consistently apply fields in plugi…
Browse files Browse the repository at this point in the history
…ns (#24925)
  • Loading branch information
vinay-gopalan authored Jan 17, 2024
1 parent fd92f2c commit 5f3ff6b
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 0 deletions.
3 changes: 3 additions & 0 deletions changelog/24925.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
sdk: Add identity token helpers to consistently apply new plugin WIF fields across integrations.
```
70 changes: 70 additions & 0 deletions sdk/helper/pluginidentityutil/fields.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package pluginidentityutil

import (
"errors"
"fmt"
"time"

"github.com/hashicorp/vault/sdk/framework"
)

// PluginIdentityTokenParams contains a set of common parameters that plugins
// can use for setting plugin identity token behavior.
type PluginIdentityTokenParams struct {
// IdentityTokenTTL is the duration that tokens will be valid for
IdentityTokenTTL time.Duration `json:"identity_token_ttl"`
// IdentityTokenAudience identifies the recipient of the token
IdentityTokenAudience string `json:"identity_token_audience"`
}

// ParsePluginIdentityTokenFields provides common field parsing to embedding structs.
func (p *PluginIdentityTokenParams) ParsePluginIdentityTokenFields(d *framework.FieldData) error {
if tokenTTLRaw, ok := d.GetOk("identity_token_ttl"); ok {
p.IdentityTokenTTL = time.Duration(tokenTTLRaw.(int)) * time.Second
}
if p.IdentityTokenTTL == 0 {
p.IdentityTokenTTL = time.Hour
}

if tokenAudienceRaw, ok := d.GetOk("identity_token_audience"); ok {
p.IdentityTokenAudience = tokenAudienceRaw.(string)
}
if p.IdentityTokenAudience == "" {
return errors.New("missing required identity_token_audience")
}

return nil
}

// PopulatePluginIdentityTokenData adds PluginIdentityTokenParams info into the given map.
func (p *PluginIdentityTokenParams) PopulatePluginIdentityTokenData(m map[string]interface{}) {
m["identity_token_ttl"] = int64(p.IdentityTokenTTL.Seconds())
m["identity_token_audience"] = p.IdentityTokenAudience
}

// AddPluginIdentityTokenFields adds plugin identity token fields to the given
// field schema map.
func AddPluginIdentityTokenFields(m map[string]*framework.FieldSchema) {
fields := map[string]*framework.FieldSchema{
"identity_token_audience": {
Type: framework.TypeString,
Description: "Audience of plugin identity tokens",
Default: "",
},
"identity_token_ttl": {
Type: framework.TypeDurationSecond,
Description: "Time-to-live of plugin identity tokens",
Default: 3600,
},
}

for name, schema := range fields {
if _, ok := m[name]; ok {
panic(fmt.Sprintf("adding field %q would overwrite existing field", name))
}
m[name] = schema
}
}
172 changes: 172 additions & 0 deletions sdk/helper/pluginidentityutil/fields_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package pluginidentityutil

import (
"testing"
"time"

"github.com/stretchr/testify/assert"

"github.com/hashicorp/vault/sdk/framework"
)

const (
fieldIDTokenTTL = "identity_token_ttl"
fieldIDTokenAudience = "identity_token_audience"
)

func identityTokenFieldData(raw map[string]interface{}) *framework.FieldData {
return &framework.FieldData{
Raw: raw,
Schema: map[string]*framework.FieldSchema{
fieldIDTokenTTL: {
Type: framework.TypeDurationSecond,
},
fieldIDTokenAudience: {
Type: framework.TypeString,
},
},
}
}

func TestParsePluginIdentityTokenFields(t *testing.T) {
testcases := []struct {
name string
d *framework.FieldData
wantErr bool
want map[string]interface{}
}{
{
name: "basic",
d: identityTokenFieldData(map[string]interface{}{
fieldIDTokenTTL: 10,
fieldIDTokenAudience: "test-aud",
}),
want: map[string]interface{}{
fieldIDTokenTTL: time.Duration(10) * time.Second,
fieldIDTokenAudience: "test-aud",
},
},
{
name: "empty-ttl",
d: identityTokenFieldData(map[string]interface{}{
fieldIDTokenAudience: "test-aud",
}),
want: map[string]interface{}{
fieldIDTokenTTL: time.Hour,
fieldIDTokenAudience: "test-aud",
},
},
{
name: "empty-audience",
d: identityTokenFieldData(map[string]interface{}{}),
wantErr: true,
},
}

for _, tt := range testcases {
t.Run(tt.name, func(t *testing.T) {
p := new(PluginIdentityTokenParams)
err := p.ParsePluginIdentityTokenFields(tt.d)
if tt.wantErr {
assert.Error(t, err)
return
}
got := map[string]interface{}{
fieldIDTokenTTL: p.IdentityTokenTTL,
fieldIDTokenAudience: p.IdentityTokenAudience,
}
assert.Equal(t, tt.want, got)
})
}
}

func TestPopulatePluginIdentityTokenData(t *testing.T) {
testcases := []struct {
name string
p *PluginIdentityTokenParams
want map[string]interface{}
}{
{
name: "basic",
p: &PluginIdentityTokenParams{
IdentityTokenAudience: "test-aud",
IdentityTokenTTL: time.Duration(10) * time.Second,
},
want: map[string]interface{}{
fieldIDTokenTTL: int64(10),
fieldIDTokenAudience: "test-aud",
},
},
}

for _, tt := range testcases {
t.Run(tt.name, func(t *testing.T) {
got := make(map[string]interface{})
tt.p.PopulatePluginIdentityTokenData(got)
assert.Equal(t, tt.want, got)
})
}
}

func TestAddPluginIdentityTokenFields(t *testing.T) {
testcases := []struct {
name string
input map[string]*framework.FieldSchema
want map[string]*framework.FieldSchema
}{
{
name: "basic",
input: map[string]*framework.FieldSchema{},
want: map[string]*framework.FieldSchema{
fieldIDTokenAudience: {
Type: framework.TypeString,
Description: "Audience of plugin identity tokens",
Default: "",
},
fieldIDTokenTTL: {
Type: framework.TypeDurationSecond,
Description: "Time-to-live of plugin identity tokens",
Default: 3600,
},
},
},
{
name: "additional-fields",
input: map[string]*framework.FieldSchema{
"test": {
Type: framework.TypeString,
Description: "Test description",
Default: "default",
},
},
want: map[string]*framework.FieldSchema{
fieldIDTokenAudience: {
Type: framework.TypeString,
Description: "Audience of plugin identity tokens",
Default: "",
},
fieldIDTokenTTL: {
Type: framework.TypeDurationSecond,
Description: "Time-to-live of plugin identity tokens",
Default: 3600,
},
"test": {
Type: framework.TypeString,
Description: "Test description",
Default: "default",
},
},
},
}

for _, tt := range testcases {
t.Run(tt.name, func(t *testing.T) {
got := tt.input
AddPluginIdentityTokenFields(got)
assert.Equal(t, tt.want, got)
})
}
}

0 comments on commit 5f3ff6b

Please sign in to comment.