diff --git a/.gitignore b/.gitignore index 8b77e3b..9f12641 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ lib/ *.swp coverage.txt report.json +*.test # for VSCode .vscode/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 2137e9e..77c6dd5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). ## [Unreleased] +- Refactor configuration to preserve case in claims + ### Changed - Update mentions of the default branch from 'master' to 'main'. [#58](https://github.com/xmidt-org/themis/pull/58) diff --git a/devMode.go b/devMode.go index 4fac593..8230221 100644 --- a/devMode.go +++ b/devMode.go @@ -65,35 +65,41 @@ token: notBeforeDelta: -15s duration: 24h claims: - mac: + - key: mac header: X-Midt-Mac-Address parameter: mac - serial: + - key: serial header: X-Midt-Serial-Number parameter: serial - uuid: + - key: uuid header: X-Midt-Uuid parameter: uuid - iss: + - key: iss value: "development" - trust: + - key: trust value: 1000 - sub: + - key: sub value: "client-supplied" - aud: + - key: aud value: "XMiDT" - capabilities: + - key: capabilities value: - x1:issuer:test:.*:all + - key: allowedResources + json: '{ + "allowedPartners": ["comcast"], + "allowedServiceAccountIds": ["1234", "5678"] + }' + metadata: - mac: + - key: mac header: X-Midt-Mac-Address parameter: mac - serial: + - key: serial header: X-Midt-Serial-Number parameter: serial - uuid: + - key: uuid header: X-Midt-Uuid parameter: uuid partnerID: diff --git a/main.go b/main.go index 74fdd9e..48372df 100644 --- a/main.go +++ b/main.go @@ -90,7 +90,7 @@ func setupViper(in config.ViperIn, v *viper.Viper) (err error) { v.Set("log.level", "DEBUG") } - return nil + return } func main() { diff --git a/themis.yaml b/themis.yaml index 68654c1..8ed0d4b 100644 --- a/themis.yaml +++ b/themis.yaml @@ -56,34 +56,39 @@ token: notBeforeDelta: -15s duration: 24h claims: - mac: + - key: mac header: X-Midt-Mac-Address parameter: mac - serial: + - key: serial header: X-Midt-Serial-Number parameter: serial - uuid: + - key: uuid header: X-Midt-Uuid parameter: uuid - iss: + - key: iss value: "development" - trust: + - key: trust value: 1000 - sub: + - key: sub value: "client-supplied" - aud: + - key: aud value: "XMiDT" - capabilities: + - key: capabilities value: - x1:issuer:test:.*:all + - key: nestedClaims + json: '{ + "casePreservedScalar": "true", + "casePreservedArray": ["casePreserved1", "casePreserved2"] + }' metadata: - mac: + - key: mac header: X-Midt-Mac-Address parameter: mac - serial: + - key: serial header: X-Midt-Serial-Number parameter: serial - uuid: + - key: uuid header: X-Midt-Uuid parameter: uuid partnerID: diff --git a/token/claimBuilder.go b/token/claimBuilder.go index 022c29f..9ea2be3 100644 --- a/token/claimBuilder.go +++ b/token/claimBuilder.go @@ -17,6 +17,7 @@ import ( var ( ErrRemoteURLRequired = errors.New("A URL for the remote claimer is required") + ErrMissingKey = errors.New("A key is required for all claims and metadata values") ) // ClaimBuilder is a strategy for building token claims, given a token Request @@ -188,16 +189,25 @@ func NewClaimBuilders(n random.Noncer, client xhttpclient.Interface, o Options) if o.Remote != nil { // scan the metadata looking for static values that should be applied when invoking the remote server metadata := make(map[string]interface{}) - for name, value := range o.Metadata { - if len(value.Header) != 0 || len(value.Parameter) != 0 || len(value.Variable) != 0 { + for _, value := range o.Metadata { + switch { + case len(value.Key) == 0: + return nil, ErrMissingKey + + case value.IsFromHTTP(): continue - } - if value.Value == nil { - return nil, fmt.Errorf("A value is required for the static metadata: %s", name) - } + case !value.IsStatic(): + return nil, fmt.Errorf("A value is required for the static metadata: %s", value.Key) - metadata[name] = value.Value + default: + msg, err := value.RawMessage() + if err != nil { + return nil, err + } + + metadata[value.Key] = msg + } } remoteClaimBuilder, err := newRemoteClaimBuilder(client, metadata, o.Remote) @@ -208,17 +218,25 @@ func NewClaimBuilders(n random.Noncer, client xhttpclient.Interface, o Options) builders = append(builders, remoteClaimBuilder) } - for name, value := range o.Claims { - if len(value.Header) != 0 || len(value.Parameter) != 0 || len(value.Variable) != 0 { - // skip any claims derived from HTTP requests + for _, value := range o.Claims { + switch { + case len(value.Key) == 0: + return nil, ErrMissingKey + + case value.IsFromHTTP(): continue - } - if value.Value == nil { - return nil, fmt.Errorf("A value is required for the static claim: %s", name) - } + case !value.IsStatic(): + return nil, fmt.Errorf("A value is required for the static claim: %s", value.Key) - staticClaimBuilder[name] = value.Value + default: + msg, err := value.RawMessage() + if err != nil { + return nil, err + } + + staticClaimBuilder[value.Key] = msg + } } if len(staticClaimBuilder) > 0 { diff --git a/token/claimBuilder_test.go b/token/claimBuilder_test.go index b90a793..6d91f80 100644 --- a/token/claimBuilder_test.go +++ b/token/claimBuilder_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -14,12 +15,99 @@ import ( "github.com/xmidt-org/themis/random/randomtest" "github.com/xmidt-org/themis/xhttp/xhttpclient" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" ) -func TestRequestClaimBuilder(t *testing.T) { - testData := []struct { +type ClaimBuildersTestSuite struct { + suite.Suite + expectedCtx context.Context + expectedErr error +} + +var _ suite.SetupAllSuite = (*ClaimBuildersTestSuite)(nil) + +func (suite *ClaimBuildersTestSuite) SetupSuite() { + suite.expectedCtx = context.WithValue(context.Background(), "foo", "bar") + suite.expectedErr = errors.New("expected AddClaims error") +} + +func (suite *ClaimBuildersTestSuite) TestSuccess() { + for _, count := range []int{0, 1, 2, 5} { + suite.Run(fmt.Sprintf("count=%d", count), func() { + var ( + builder ClaimBuilders + expectedRequest = new(Request) + expected = make(map[string]interface{}) + actual = make(map[string]interface{}) + ) + + for i := 0; i < count; i++ { + i := i + expected[strconv.Itoa(i)] = "true" + builder = append(builder, + ClaimBuilderFunc(func(actualCtx context.Context, actualRequest *Request, target map[string]interface{}) error { + suite.Equal(suite.expectedCtx, actualCtx) + suite.True(expectedRequest == actualRequest) + target[strconv.Itoa(i)] = "true" + return nil + }), + ) + } + + suite.Require().NoError( + builder.AddClaims(suite.expectedCtx, expectedRequest, actual), + ) + + suite.Equal(expected, actual) + }) + } +} + +func (suite *ClaimBuildersTestSuite) TestError() { + var ( + expectedRequest = new(Request) + expected = map[string]interface{}{ + "first": "true", + } + + actual = make(map[string]interface{}) + + builder = ClaimBuilders{ + ClaimBuilderFunc(func(actualCtx context.Context, actualRequest *Request, target map[string]interface{}) error { + suite.Equal(suite.expectedCtx, actualCtx) + suite.True(expectedRequest == actualRequest) + target["first"] = "true" + return nil + }), + ClaimBuilderFunc(func(actualCtx context.Context, actualRequest *Request, target map[string]interface{}) error { + suite.Equal(suite.expectedCtx, actualCtx) + suite.True(expectedRequest == actualRequest) + return suite.expectedErr + }), + ClaimBuilderFunc(func(actualCtx context.Context, actualRequest *Request, target map[string]interface{}) error { + suite.Fail("This claim builder should not have been called") + return nil + }), + } + ) + + suite.Error( + builder.AddClaims(suite.expectedCtx, expectedRequest, actual), + ) + + suite.Equal(expected, actual) +} + +func TestClaimBuilders(t *testing.T) { + suite.Run(t, new(ClaimBuildersTestSuite)) +} + +type RequestClaimBuilderTestSuite struct { + suite.Suite +} + +func (suite *RequestClaimBuilderTestSuite) Test() { + cases := []struct { request *Request expected map[string]interface{} }{ @@ -35,24 +123,28 @@ func TestRequestClaimBuilder(t *testing.T) { }, } - for i, record := range testData { - t.Run(strconv.Itoa(i), func(t *testing.T) { - var ( - assert = assert.New(t) - actual = make(map[string]interface{}) - ) - - assert.NoError( - requestClaimBuilder{}.AddClaims(context.Background(), record.request, actual), + for i, testCase := range cases { + suite.Run(strconv.Itoa(i), func() { + actual := make(map[string]interface{}) + suite.NoError( + requestClaimBuilder{}.AddClaims(context.Background(), testCase.request, actual), ) - assert.Equal(record.expected, actual) + suite.Equal(testCase.expected, actual) }) } } -func TestStaticClaimBuilder(t *testing.T) { - testData := []struct { +func TestRequestClaimBuilder(t *testing.T) { + suite.Run(t, new(RequestClaimBuilderTestSuite)) +} + +type StaticClaimBuilderTestSuite struct { + suite.Suite +} + +func (suite *StaticClaimBuilderTestSuite) Test() { + cases := []struct { builder staticClaimBuilder expected map[string]interface{} }{ @@ -69,152 +161,210 @@ func TestStaticClaimBuilder(t *testing.T) { }, } - for i, record := range testData { - t.Run(strconv.Itoa(i), func(t *testing.T) { - var ( - assert = assert.New(t) - actual = make(map[string]interface{}) - ) - - assert.NoError( - record.builder.AddClaims(context.Background(), new(Request), actual), + for i, testCase := range cases { + suite.Run(strconv.Itoa(i), func() { + actual := make(map[string]interface{}) + suite.NoError( + testCase.builder.AddClaims(context.Background(), new(Request), actual), ) - assert.Equal(record.expected, actual) + suite.Equal(testCase.expected, actual) }) } } -func TestTimeClaimBuilder(t *testing.T) { - var ( - expectedNow = time.Now() - now = func() time.Time { return expectedNow } - testData = []struct { - builder timeClaimBuilder - expected map[string]interface{} - }{ - { - builder: timeClaimBuilder{ - now: now, - disableNotBefore: true, - }, - expected: map[string]interface{}{ - "iat": expectedNow.UTC().Unix(), - }, +func TestStaticClaimBuilder(t *testing.T) { + suite.Run(t, new(StaticClaimBuilderTestSuite)) +} + +type TimeClaimBuilderTestSuite struct { + suite.Suite + expectedNow time.Time +} + +var _ suite.SetupAllSuite = (*TimeClaimBuilderTestSuite)(nil) + +func (suite *TimeClaimBuilderTestSuite) SetupSuite() { + suite.expectedNow = time.Now() +} + +func (suite *TimeClaimBuilderTestSuite) now() time.Time { + return suite.expectedNow +} + +func (suite *TimeClaimBuilderTestSuite) TestX() { + cases := []struct { + builder timeClaimBuilder + expected map[string]interface{} + }{ + { + builder: timeClaimBuilder{ + now: suite.now, + disableNotBefore: true, }, - { - builder: timeClaimBuilder{ - now: now, - }, - expected: map[string]interface{}{ - "iat": expectedNow.UTC().Unix(), - "nbf": expectedNow.UTC().Unix(), - }, + expected: map[string]interface{}{ + "iat": suite.expectedNow.UTC().Unix(), }, - { - builder: timeClaimBuilder{ - now: now, - duration: 24 * time.Hour, - notBeforeDelta: 5 * time.Minute, - }, - expected: map[string]interface{}{ - "iat": expectedNow.UTC().Unix(), - "nbf": expectedNow.UTC().Add(5 * time.Minute).Unix(), - "exp": expectedNow.UTC().Add(24 * time.Hour).Unix(), - }, + }, + { + builder: timeClaimBuilder{ + now: suite.now, }, - { - builder: timeClaimBuilder{ - now: now, - duration: 30 * time.Minute, - disableNotBefore: true, - }, - expected: map[string]interface{}{ - "iat": expectedNow.UTC().Unix(), - "exp": expectedNow.UTC().Add(30 * time.Minute).Unix(), - }, + expected: map[string]interface{}{ + "iat": suite.expectedNow.UTC().Unix(), + "nbf": suite.expectedNow.UTC().Unix(), }, - } - ) - - for i, record := range testData { - t.Run(strconv.Itoa(i), func(t *testing.T) { - var ( - assert = assert.New(t) - actual = make(map[string]interface{}) - ) + }, + { + builder: timeClaimBuilder{ + now: suite.now, + duration: 24 * time.Hour, + notBeforeDelta: 5 * time.Minute, + }, + expected: map[string]interface{}{ + "iat": suite.expectedNow.UTC().Unix(), + "nbf": suite.expectedNow.UTC().Add(5 * time.Minute).Unix(), + "exp": suite.expectedNow.UTC().Add(24 * time.Hour).Unix(), + }, + }, + { + builder: timeClaimBuilder{ + now: suite.now, + duration: 30 * time.Minute, + disableNotBefore: true, + }, + expected: map[string]interface{}{ + "iat": suite.expectedNow.UTC().Unix(), + "exp": suite.expectedNow.UTC().Add(30 * time.Minute).Unix(), + }, + }, + } - assert.NoError( - record.builder.AddClaims(context.Background(), new(Request), actual), + for i, testCase := range cases { + suite.Run(strconv.Itoa(i), func() { + actual := make(map[string]interface{}) + suite.NoError( + testCase.builder.AddClaims(context.Background(), new(Request), actual), ) - assert.Equal(record.expected, actual) + suite.Equal(testCase.expected, actual) }) } } +func TestTimeClaimBuilder(t *testing.T) { + suite.Run(t, new(TimeClaimBuilderTestSuite)) +} + +type NonceClaimBuilderTestSuite struct { + suite.Suite + noncer *randomtest.Noncer + builder nonceClaimBuilder + expectedErr error +} + +var _ suite.SetupTestSuite = (*NonceClaimBuilderTestSuite)(nil) +var _ suite.TearDownTestSuite = (*NonceClaimBuilderTestSuite)(nil) + +func (suite *NonceClaimBuilderTestSuite) SetupTest() { + suite.noncer = new(randomtest.Noncer) + suite.builder = nonceClaimBuilder{n: suite.noncer} + suite.expectedErr = errors.New("expected") +} + +func (suite *NonceClaimBuilderTestSuite) TearDownTest() { + suite.noncer.AssertExpectations(suite.T()) +} + +func (suite *NonceClaimBuilderTestSuite) TestSuccess() { + actual := make(map[string]interface{}) + suite.noncer.ExpectNonce().Return("test", error(nil)).Once() + suite.NoError( + suite.builder.AddClaims(context.Background(), new(Request), actual), + ) + + suite.Equal( + map[string]interface{}{"jti": "test"}, + actual, + ) +} + +func (suite *NonceClaimBuilderTestSuite) TestError() { + actual := make(map[string]interface{}) + suite.noncer.ExpectNonce().Return("", suite.expectedErr).Once() + suite.Equal( + suite.expectedErr, + suite.builder.AddClaims(context.Background(), new(Request), actual), + ) + + suite.Empty(actual) +} + func TestNonceClaimBuilder(t *testing.T) { - t.Run("Success", func(t *testing.T) { - var ( - assert = assert.New(t) - noncer = new(randomtest.Noncer) - - actual = make(map[string]interface{}) - builder = nonceClaimBuilder{n: noncer} - ) - - noncer.ExpectNonce().Return("test", error(nil)).Once() - assert.NoError(builder.AddClaims(context.Background(), new(Request), actual)) - assert.Equal( - map[string]interface{}{"jti": "test"}, - actual, - ) - }) + suite.Run(t, new(NonceClaimBuilderTestSuite)) +} - t.Run("Error", func(t *testing.T) { - var ( - assert = assert.New(t) - noncer = new(randomtest.Noncer) - expectedErr = errors.New("expected") +type RemoteClaimBuilderTestSuite struct { + suite.Suite + server *httptest.Server + goodURL string + badURL string + expectedMethod string +} - actual = make(map[string]interface{}) - builder = nonceClaimBuilder{n: noncer} - ) +var _ suite.SetupAllSuite = (*RemoteClaimBuilderTestSuite)(nil) +var _ suite.TearDownAllSuite = (*RemoteClaimBuilderTestSuite)(nil) - noncer.ExpectNonce().Return("", expectedErr).Once() - assert.Equal( - expectedErr, - builder.AddClaims(context.Background(), new(Request), actual), - ) +func (suite *RemoteClaimBuilderTestSuite) SetupSuite() { + mux := http.NewServeMux() + mux.HandleFunc("/good", suite.goodHandler) + mux.HandleFunc("/bad", suite.badHandler) + suite.server = httptest.NewServer(mux) + suite.goodURL = suite.server.URL + "/good" + suite.badURL = suite.server.URL + "/bad" +} - assert.Empty(actual) - }) +func (suite *RemoteClaimBuilderTestSuite) TearDownSuite() { + suite.server.Close() } -func TestNewRemoteClaimBuilder(t *testing.T) { - t.Run("NoURL", func(t *testing.T) { - assert := assert.New(t) - cb, err := newRemoteClaimBuilder(new(http.Client), nil, new(RemoteClaims)) - assert.Nil(cb) - assert.Error(err) - }) +func (suite *RemoteClaimBuilderTestSuite) TearDownTest() { + suite.expectedMethod = "" +} - t.Run("BadURL", func(t *testing.T) { - var ( - assert = assert.New(t) - remoteClaims = &RemoteClaims{ - URL: "this is not valid (%$&@!()&*()*%", - } - ) +func (suite *RemoteClaimBuilderTestSuite) goodHandler(response http.ResponseWriter, request *http.Request) { + suite.Equal("application/json", request.Header.Get("Content-Type")) + expectedMethod := suite.expectedMethod + if len(expectedMethod) == 0 { + expectedMethod = http.MethodPost + } - cb, err := newRemoteClaimBuilder(new(http.Client), nil, remoteClaims) - assert.Nil(cb) - assert.Error(err) - }) + suite.Equal(expectedMethod, request.Method) + + b, err := ioutil.ReadAll(request.Body) + suite.NoError(err) + + var input map[string]interface{} + suite.NoError(json.Unmarshal(b, &input)) + if input == nil { + input = make(map[string]interface{}) + } + + input["custom"] = "value" + + response.Header().Set("Content-Type", "application/json") + b, err = json.Marshal(input) + suite.NoError(err) + suite.NotEmpty(b) + response.Write(b) } -func testRemoteClaimBuilderAddClaims(t *testing.T) { - testData := []struct { +func (suite *RemoteClaimBuilderTestSuite) badHandler(response http.ResponseWriter, request *http.Request) { + response.Write([]byte("this is not JSON")) +} + +func (suite *RemoteClaimBuilderTestSuite) TestAddClaims() { + cases := []struct { method string client xhttpclient.Interface metadata map[string]interface{} @@ -245,328 +395,426 @@ func testRemoteClaimBuilderAddClaims(t *testing.T) { }, } - for i, record := range testData { - t.Run(strconv.Itoa(i), func(t *testing.T) { + for i, testCase := range cases { + suite.Run(strconv.Itoa(i), func() { var ( - assert = assert.New(t) - require = require.New(t) - remoteClaims = &RemoteClaims{Method: record.method} - - handler = http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { - assert.Equal("application/json", request.Header.Get("Content-Type")) - expectedMethod := record.method - if len(expectedMethod) == 0 { - expectedMethod = http.MethodPost - } - - assert.Equal(expectedMethod, request.Method) - - b, err := ioutil.ReadAll(request.Body) - assert.NoError(err) - - var input map[string]interface{} - assert.NoError(json.Unmarshal(b, &input)) - if input == nil { - input = make(map[string]interface{}) - } - - input["custom"] = "value" - - response.Header().Set("Content-Type", "application/json") - b, err = json.Marshal(input) - assert.NoError(err) - assert.NotEmpty(b) - response.Write(b) - }) - ) + actual = make(map[string]interface{}) - server := httptest.NewServer(handler) - defer server.Close() - remoteClaims.URL = server.URL + remoteClaims = &RemoteClaims{ + URL: suite.goodURL, + Method: testCase.method, + } - builder, err := newRemoteClaimBuilder( - record.client, - record.metadata, - remoteClaims, + builder, err = newRemoteClaimBuilder( + testCase.client, + testCase.metadata, + remoteClaims, + ) ) - require.NoError(err) - require.NotNil(builder) + suite.Require().NoError(err) + suite.Require().NotNil(builder) + suite.expectedMethod = testCase.method - actual := make(map[string]interface{}) - require.NoError(builder.AddClaims(context.Background(), record.request, actual)) - assert.Equal(record.expected, actual) + suite.Require().NoError( + builder.AddClaims(context.Background(), testCase.request, actual), + ) + + suite.Equal(testCase.expected, actual) }) } } -func testRemoteClaimBuilderRemoteError(t *testing.T) { - var ( - assert = assert.New(t) - require = require.New(t) +func (suite *RemoteClaimBuilderTestSuite) TestError() { + builder, err := newRemoteClaimBuilder(nil, nil, &RemoteClaims{URL: suite.badURL}) + suite.Require().NoError(err) + suite.Require().NotNil(builder) - handler = http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { - response.Write([]byte("this is not JSON")) - }) + suite.Error( + builder.AddClaims(context.Background(), new(Request), make(map[string]interface{})), ) +} + +func (suite *RemoteClaimBuilderTestSuite) TestNoURL() { + builder, err := newRemoteClaimBuilder(new(http.Client), nil, new(RemoteClaims)) + suite.Nil(builder) + suite.Error(err) +} - server := httptest.NewServer(handler) - defer server.Close() +func (suite *RemoteClaimBuilderTestSuite) TestBadURL() { + var ( + remoteClaims = &RemoteClaims{ + URL: "this is not valid (%$&@!()&*()*%", + } - builder, err := newRemoteClaimBuilder(nil, nil, &RemoteClaims{URL: server.URL}) - require.NoError(err) - require.NotNil(builder) + builder, err = newRemoteClaimBuilder(new(http.Client), nil, remoteClaims) + ) - assert.Error(builder.AddClaims(context.Background(), new(Request), make(map[string]interface{}))) + suite.Nil(builder) + suite.Error(err) } func TestRemoteClaimBuilder(t *testing.T) { - t.Run("AddClaims", testRemoteClaimBuilderAddClaims) - t.Run("RemoteError", testRemoteClaimBuilderRemoteError) + suite.Run(t, new(RemoteClaimBuilderTestSuite)) } -func testNewClaimBuildersMinimum(t *testing.T) { - var ( - assert = assert.New(t) - require = require.New(t) +type NewClaimBuildersTestSuite struct { + suite.Suite + server *httptest.Server + noncer *randomtest.Noncer + expectedNow time.Time +} - noncer = new(randomtest.Noncer) +var _ suite.SetupTestSuite = (*NewClaimBuildersTestSuite)(nil) +var _ suite.SetupAllSuite = (*NewClaimBuildersTestSuite)(nil) +var _ suite.TearDownTestSuite = (*NewClaimBuildersTestSuite)(nil) +var _ suite.TearDownAllSuite = (*NewClaimBuildersTestSuite)(nil) + +func (suite *NewClaimBuildersTestSuite) SetupSuite() { + suite.server = httptest.NewServer( + http.HandlerFunc(suite.handleRemoteClaims), ) +} + +func (suite *NewClaimBuildersTestSuite) SetupTest() { + suite.noncer = new(randomtest.Noncer) + suite.expectedNow = time.Now() +} + +func (suite *NewClaimBuildersTestSuite) TearDownSuite() { + suite.server.Close() +} + +func (suite *NewClaimBuildersTestSuite) TearDownTest() { + suite.noncer.AssertExpectations(suite.T()) +} - builder, err := NewClaimBuilders(noncer, nil, Options{ +func (suite *NewClaimBuildersTestSuite) now() time.Time { + return suite.expectedNow +} + +func (suite *NewClaimBuildersTestSuite) replaceNow(cb ClaimBuilders) { + for _, b := range cb { + if tb, ok := b.(*timeClaimBuilder); ok { + tb.now = suite.now + } + } +} + +func (suite *NewClaimBuildersTestSuite) rawMessage(v interface{}) json.RawMessage { + raw, err := json.Marshal(v) + suite.Require().NoError(err) + return json.RawMessage(raw) +} + +func (suite *NewClaimBuildersTestSuite) handleRemoteClaims(response http.ResponseWriter, request *http.Request) { + body, err := ioutil.ReadAll(request.Body) + suite.Require().NoError(err) + + var metadata map[string]interface{} + suite.Require().NoError(json.Unmarshal(body, &metadata)) + suite.Equal(map[string]interface{}{"extra": "extra stuff"}, metadata) + + response.Header().Set("Content-Type", "application/json") + response.Write([]byte(`{"remote": "value"}`)) +} + +func (suite *NewClaimBuildersTestSuite) TestMinimum() { + builder, err := NewClaimBuilders(suite.noncer, nil, Options{ Nonce: false, DisableTime: true, }) - require.NoError(err) - require.NotEmpty(builder) + suite.Require().NoError(err) + suite.Require().NotEmpty(builder) actual := make(map[string]interface{}) - assert.NoError( + suite.NoError( builder.AddClaims(context.Background(), &Request{Claims: map[string]interface{}{"request": 123}}, actual), ) - assert.Equal( + suite.Equal( map[string]interface{}{"request": 123}, actual, ) +} - noncer.AssertExpectations(t) +func (suite *NewClaimBuildersTestSuite) TestMissingKey() { + suite.Run("Claims", suite.testClaimsMissingKey) + suite.Run("Metadata", suite.testMetadataMissingKey) } -func testNewClaimBuildersBadValue(t *testing.T) { - var ( - assert = assert.New(t) - noncer = new(randomtest.Noncer) - ) +func (suite *NewClaimBuildersTestSuite) testClaimsMissingKey() { + builder, err := NewClaimBuilders(suite.noncer, nil, Options{ + Nonce: false, + DisableTime: true, + Claims: []Value{ + {}, // the value should have something configured + }, + }) - builder, err := NewClaimBuilders(noncer, nil, Options{ + suite.Nil(builder) + suite.Error(err) +} + +func (suite *NewClaimBuildersTestSuite) testMetadataMissingKey() { + builder, err := NewClaimBuilders(suite.noncer, nil, Options{ Nonce: false, DisableTime: true, - Claims: map[string]Value{ - "bad": Value{}, // the value should have something configured + Metadata: []Value{ + {}, // the value should have something configured }, + Remote: &RemoteClaims{}, }) - assert.Nil(builder) - assert.Error(err) + suite.Nil(builder) + suite.Error(err) +} - noncer.AssertExpectations(t) +func (suite *NewClaimBuildersTestSuite) TestMissingValue() { + suite.Run("Claims", suite.testClaimsMissingValue) + suite.Run("Metadata", suite.testMetadataMissingValue) } -func testNewClaimBuildersStatic(t *testing.T) { - var ( - assert = assert.New(t) - require = require.New(t) +func (suite *NewClaimBuildersTestSuite) testClaimsMissingValue() { + builder, err := NewClaimBuilders(suite.noncer, nil, Options{ + Nonce: false, + DisableTime: true, + Claims: []Value{ + { + Key: "test", + // either JSON or Value should be set + }, + }, + }) - noncer = new(randomtest.Noncer) - ) + suite.Nil(builder) + suite.Error(err) +} + +func (suite *NewClaimBuildersTestSuite) testMetadataMissingValue() { + builder, err := NewClaimBuilders(suite.noncer, nil, Options{ + Nonce: false, + DisableTime: true, + Metadata: []Value{ + { + Key: "test", + // either JSON or Value should be set + }, + }, + Remote: &RemoteClaims{}, + }) + + suite.Nil(builder) + suite.Error(err) +} + +func (suite *NewClaimBuildersTestSuite) TestBadJSONValue() { + suite.Run("Claims", suite.testClaimsBadJSONValue) + suite.Run("Metadata", suite.testMetadataBadJSONValue) +} - builder, err := NewClaimBuilders(noncer, nil, Options{ +func (suite *NewClaimBuildersTestSuite) testClaimsBadJSONValue() { + builder, err := NewClaimBuilders(suite.noncer, nil, Options{ Nonce: false, DisableTime: true, - Claims: map[string]Value{ - "static1": Value{ + Claims: []Value{ + { + Key: "test", + JSON: `{"this isn't valid JSON`, + }, + }, + }) + + suite.Nil(builder) + suite.Error(err) +} + +func (suite *NewClaimBuildersTestSuite) testMetadataBadJSONValue() { + builder, err := NewClaimBuilders(suite.noncer, nil, Options{ + Nonce: false, + DisableTime: true, + Metadata: []Value{ + { + Key: "test", + JSON: `{"this isn't valid JSON`, + }, + }, + Remote: &RemoteClaims{}, + }) + + suite.Nil(builder) + suite.Error(err) +} + +func (suite *NewClaimBuildersTestSuite) TestStatic() { + builder, err := NewClaimBuilders(suite.noncer, nil, Options{ + Nonce: false, + DisableTime: true, + Claims: []Value{ + { + Key: "static1", Value: -72.5, }, - "static2": Value{ + { + Key: "static2", Value: []string{"a", "b"}, }, - "http1": Value{ + { + Key: "http1", Header: "X-Ignore-Me", }, }, }) - require.NoError(err) - require.NotEmpty(builder) + suite.Require().NoError(err) + suite.Require().NotEmpty(builder) actual := make(map[string]interface{}) - assert.NoError( + suite.NoError( builder.AddClaims(context.Background(), &Request{Claims: map[string]interface{}{"request": 123}}, actual), ) - assert.Equal( - map[string]interface{}{"static1": -72.5, "static2": []string{"a", "b"}, "request": 123}, + suite.Equal( + map[string]interface{}{ + "static1": suite.rawMessage(-72.5), + "static2": suite.rawMessage([]string{"a", "b"}), + "request": 123, + }, actual, ) - - noncer.AssertExpectations(t) } -func testNewClaimBuildersNoRemote(t *testing.T) { - var ( - assert = assert.New(t) - require = require.New(t) - - noncer = new(randomtest.Noncer) - - expectedNow = time.Now() - now = func() time.Time { return expectedNow } - ) - - builder, err := NewClaimBuilders(noncer, nil, Options{ +func (suite *NewClaimBuildersTestSuite) TestNoRemote() { + builder, err := NewClaimBuilders(suite.noncer, nil, Options{ Nonce: true, Duration: 24 * time.Hour, NotBeforeDelta: 15 * time.Second, - Claims: map[string]Value{ - "static1": Value{ + Claims: []Value{ + { + Key: "static1", Value: -72.5, }, - "static2": Value{ + { + Key: "static2", Value: []string{"a", "b"}, }, - "http1": Value{ + { + Key: "http1", Header: "X-Ignore-Me", }, }, }) - require.NoError(err) - require.NotEmpty(builder) - - for _, b := range builder { - if tb, ok := b.(*timeClaimBuilder); ok { - tb.now = now - break - } - } + suite.Require().NoError(err) + suite.Require().NotEmpty(builder) - noncer.ExpectNonce().Return("test", error(nil)).Once() + suite.replaceNow(builder) + suite.noncer.ExpectNonce().Return("test", error(nil)).Once() actual := make(map[string]interface{}) - assert.NoError( + suite.NoError( builder.AddClaims(context.Background(), &Request{Claims: map[string]interface{}{"request": 123}}, actual), ) - assert.Equal( + suite.Equal( map[string]interface{}{ - "static1": -72.5, - "static2": []string{"a", "b"}, + "static1": suite.rawMessage(-72.5), + "static2": suite.rawMessage([]string{"a", "b"}), "request": 123, "jti": "test", - "iat": expectedNow.UTC().Unix(), - "nbf": expectedNow.Add(15 * time.Second).UTC().Unix(), - "exp": expectedNow.Add(24 * time.Hour).UTC().Unix(), + "iat": suite.expectedNow.UTC().Unix(), + "nbf": suite.expectedNow.Add(15 * time.Second).UTC().Unix(), + "exp": suite.expectedNow.Add(24 * time.Hour).UTC().Unix(), }, actual, ) - - noncer.AssertExpectations(t) } -func testNewClaimBuildersFull(t *testing.T) { - var ( - assert = assert.New(t) - require = require.New(t) +func (suite *NewClaimBuildersTestSuite) TestBadRemote() { + _, err := NewClaimBuilders(nil, nil, Options{ + Nonce: true, + Duration: 24 * time.Hour, + NotBeforeDelta: 15 * time.Second, + Metadata: []Value{ + { + Key: "extra1", + Value: "extra stuff", + }, + { + Key: "http2", + Parameter: "foo", + }, + }, + Remote: &RemoteClaims{}, // invalid: missing a URL + }) - noncer = new(randomtest.Noncer) + suite.Error(err) +} - expectedNow = time.Now() - now = func() time.Time { return expectedNow } - options = Options{ +func (suite *NewClaimBuildersTestSuite) TestFull() { + var ( + options = Options{ Nonce: true, Duration: 24 * time.Hour, NotBeforeDelta: 15 * time.Second, - Claims: map[string]Value{ - "static1": Value{ + Claims: []Value{ + { + Key: "static1", Value: -72.5, }, - "static2": Value{ + { + Key: "static2", Value: []string{"a", "b"}, }, - "http1": Value{ + { + Key: "http1", Header: "X-Ignore-Me", }, }, - Metadata: map[string]Value{ - "extra1": Value{ + Metadata: []Value{ + { + Key: "extra", Value: "extra stuff", }, - "http2": Value{ + { + Key: "http2", Parameter: "foo", }, }, + Remote: &RemoteClaims{ + URL: suite.server.URL, + }, } - - handler = http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { - body, err := ioutil.ReadAll(request.Body) - require.NoError(err) - - var metadata map[string]interface{} - require.NoError(json.Unmarshal(body, &metadata)) - assert.Equal(map[string]interface{}{"extra1": "extra stuff"}, metadata) - - response.Header().Set("Content-Type", "application/json") - response.Write([]byte(`{"remote": "value"}`)) - }) ) - server := httptest.NewServer(handler) - defer server.Close() - options.Remote = &RemoteClaims{ - URL: server.URL, - } - - builder, err := NewClaimBuilders(noncer, nil, options) - require.NoError(err) - require.NotEmpty(builder) - - for _, b := range builder { - if tb, ok := b.(*timeClaimBuilder); ok { - tb.now = now - break - } - } + builder, err := NewClaimBuilders(suite.noncer, nil, options) + suite.Require().NoError(err) + suite.Require().NotEmpty(builder) - noncer.ExpectNonce().Return("test", error(nil)).Once() + suite.replaceNow(builder) + suite.noncer.ExpectNonce().Return("test", error(nil)).Once() actual := make(map[string]interface{}) - assert.NoError( + suite.NoError( builder.AddClaims(context.Background(), &Request{Claims: map[string]interface{}{"request": 123}}, actual), ) - assert.Equal( + suite.Equal( map[string]interface{}{ - "static1": -72.5, - "static2": []string{"a", "b"}, + "static1": suite.rawMessage(-72.5), + "static2": suite.rawMessage([]string{"a", "b"}), "request": 123, "remote": "value", "jti": "test", - "iat": expectedNow.UTC().Unix(), - "nbf": expectedNow.Add(15 * time.Second).UTC().Unix(), - "exp": expectedNow.Add(24 * time.Hour).UTC().Unix(), + "iat": suite.expectedNow.UTC().Unix(), + "nbf": suite.expectedNow.Add(15 * time.Second).UTC().Unix(), + "exp": suite.expectedNow.Add(24 * time.Hour).UTC().Unix(), }, actual, ) - - noncer.AssertExpectations(t) } func TestNewClaimBuilders(t *testing.T) { - t.Run("Minimal", testNewClaimBuildersMinimum) - t.Run("BadValue", testNewClaimBuildersBadValue) - t.Run("Static", testNewClaimBuildersStatic) - t.Run("NoRemote", testNewClaimBuildersNoRemote) - t.Run("Full", testNewClaimBuildersFull) + suite.Run(t, new(NewClaimBuildersTestSuite)) } diff --git a/token/options.go b/token/options.go index f00bf57..0922e06 100644 --- a/token/options.go +++ b/token/options.go @@ -1,6 +1,7 @@ package token import ( + "encoding/json" "time" "github.com/xmidt-org/themis/key" @@ -17,8 +18,11 @@ type RemoteClaims struct { URL string } -// Value represents information pulled from either the HTTP request or statically, via config. +// Value describes how to extract a key/value pair from either an HTTP request or from configuration. type Value struct { + // Key is the key to use for this value. Typically, this is the name of a claim. + Key string + // Header is an HTTP header from which the value is pulled Header string @@ -28,10 +32,45 @@ type Value struct { // Variable is a URL gorilla/mux variable from with the value is pulled Variable string + // JSON is the value embedded as a JSON snippet. If this field is set, Value is ignored. + // Using this field is convenient to avoid viper's lowercasing of keys. It's also handy + // to embed arbitrary structures in claims. + JSON string + // Value is the statically assigned value from configuration Value interface{} } +// IsFromHTTP tests if this value is extracted from an HTTP request +func (v Value) IsFromHTTP() bool { + return len(v.Header) > 0 || len(v.Parameter) > 0 || len(v.Variable) > 0 +} + +// IsStatic tests if this value is statically configured and does not +// come from an HTTP request. +func (v Value) IsStatic() bool { + return len(v.JSON) > 0 || v.Value != nil +} + +// RawMessage precomputes the JSON for this value. If the JSON field is set, +// it is verified by unmarshaling. Otherwise, the Value field is marshaled. +func (v Value) RawMessage() (json.RawMessage, error) { + switch { + case len(v.JSON) > 0: + raw := []byte(v.JSON) + var m map[string]interface{} + err := json.Unmarshal(raw, &m) + return json.RawMessage(raw), err + + case v.Value != nil: + raw, err := json.Marshal(v.Value) + return json.RawMessage(raw), err + + default: + return json.RawMessage(nil), nil + } +} + // PartnerID describes how to extract the partner id from an HTTP request. Partner IDs // require some special processing. type PartnerID struct { @@ -65,10 +104,10 @@ type Options struct { // // None of these claims receive any special processing. They are copied as is from the HTTP request // or statically from configuration. For special processing around the partner id, set the PartnerID field. - Claims map[string]Value + Claims []Value // Metadata describes non-claim data, which can be statically configured or supplied via a request - Metadata map[string]Value + Metadata []Value // PartnerID is the optional partner id configuration. If unset, no partner id processing is // performed, though a partner id may still be configured as part of the claims. diff --git a/token/transport.go b/token/transport.go index 82ff5c7..61bfaff 100644 --- a/token/transport.go +++ b/token/transport.go @@ -209,24 +209,29 @@ func (prb partnerIDRequestBuilder) Build(original *http.Request, tr *Request) er // assigned values are handled by ClaimBuilder objects and are part of the Factory configuration. func NewRequestBuilders(o Options) (RequestBuilders, error) { var rb RequestBuilders - for name, value := range o.Claims { - if len(value.Header) > 0 || len(value.Parameter) > 0 { + for _, value := range o.Claims { + switch { + case len(value.Key) == 0: + return nil, ErrMissingKey + + case len(value.Header) > 0 || len(value.Parameter) > 0: if len(value.Variable) > 0 { return nil, ErrVariableNotAllowed } rb = append(rb, headerParameterRequestBuilder{ - key: name, + key: value.Key, header: http.CanonicalHeaderKey(value.Header), parameter: value.Parameter, setter: claimsSetter, }, ) - } else if len(value.Variable) > 0 { + + case len(value.Variable) > 0: rb = append(rb, variableRequestBuilder{ - key: name, + key: value.Key, variable: value.Variable, setter: claimsSetter, }, @@ -234,24 +239,29 @@ func NewRequestBuilders(o Options) (RequestBuilders, error) { } } - for name, value := range o.Metadata { - if len(value.Header) > 0 || len(value.Parameter) > 0 { + for _, value := range o.Metadata { + switch { + case len(value.Key) == 0: + return nil, ErrMissingKey + + case len(value.Header) > 0 || len(value.Parameter) > 0: if len(value.Variable) > 0 { return nil, ErrVariableNotAllowed } rb = append(rb, headerParameterRequestBuilder{ - key: name, + key: value.Key, header: http.CanonicalHeaderKey(value.Header), parameter: value.Parameter, setter: metadataSetter, }, ) - } else if len(value.Variable) > 0 { + + case len(value.Variable) > 0: rb = append(rb, variableRequestBuilder{ - key: name, + key: value.Key, variable: value.Variable, setter: metadataSetter, }, diff --git a/token/transport_test.go b/token/transport_test.go index 1e18a68..9de9953 100644 --- a/token/transport_test.go +++ b/token/transport_test.go @@ -21,8 +21,9 @@ import ( func testNewRequestBuildersInvalidClaim(t *testing.T) { assert := assert.New(t) rb, err := NewRequestBuilders(Options{ - Claims: map[string]Value{ - "bad": Value{ + Claims: []Value{ + { + Key: "bad", Header: "xxx", Parameter: "yyy", Variable: "zzz", @@ -37,8 +38,9 @@ func testNewRequestBuildersInvalidClaim(t *testing.T) { func testNewRequestBuildersInvalidMetadata(t *testing.T) { assert := assert.New(t) rb, err := NewRequestBuilders(Options{ - Metadata: map[string]Value{ - "bad": Value{ + Metadata: []Value{ + { + Key: "bad", Header: "xxx", Parameter: "yyy", Variable: "zzz", @@ -64,19 +66,23 @@ func testNewRequestBuildersSuccess(t *testing.T) { }, { options: Options{ - Claims: map[string]Value{ - "fromHeader": Value{ + Claims: []Value{ + { + Key: "fromHeader", Header: "X-Claim", }, - "missing": Value{ + { + Key: "missing", Header: "X-Missing", }, }, - Metadata: map[string]Value{ - "fromHeader": Value{ + Metadata: []Value{ + { + Key: "fromHeader", Header: "X-Metadata", }, - "missing": Value{ + { + Key: "missing", Header: "X-Missing", }, }, @@ -105,19 +111,23 @@ func testNewRequestBuildersSuccess(t *testing.T) { }, { options: Options{ - Claims: map[string]Value{ - "fromParameter": Value{ + Claims: []Value{ + { + Key: "fromParameter", Parameter: "claim", }, - "missing": Value{ + { + Key: "missing", Parameter: "missing", }, }, - Metadata: map[string]Value{ - "fromParameter": Value{ + Metadata: []Value{ + { + Key: "fromParameter", Parameter: "metadata", }, - "missing": Value{ + { + Key: "missing", Parameter: "missing", }, }, @@ -141,13 +151,15 @@ func testNewRequestBuildersSuccess(t *testing.T) { }, { options: Options{ - Claims: map[string]Value{ - "fromVariable": Value{ + Claims: []Value{ + { + Key: "fromVariable", Variable: "claim", }, }, - Metadata: map[string]Value{ - "fromVariable": Value{ + Metadata: []Value{ + { + Key: "fromVariable", Variable: "metadata", }, }, @@ -176,13 +188,15 @@ func testNewRequestBuildersSuccess(t *testing.T) { }, { options: Options{ - Claims: map[string]Value{ - "fromVariable": Value{ + Claims: []Value{ + { + Key: "fromVariable", Variable: "claim", }, }, - Metadata: map[string]Value{ - "fromVariable": Value{ + Metadata: []Value{ + { + Key: "fromVariable", Variable: "metadata", }, }, @@ -237,8 +251,9 @@ func testNewRequestBuildersMissingVariable(t *testing.T) { require = require.New(t) options = Options{ - Claims: map[string]Value{ - "missing": Value{ + Claims: []Value{ + { + Key: "missing", Variable: "missing", }, }, diff --git a/token/unmarshal_test.go b/token/unmarshal_test.go index 6936457..b3abbbc 100644 --- a/token/unmarshal_test.go +++ b/token/unmarshal_test.go @@ -52,10 +52,11 @@ func testUnmarshalClaimBuilderError(t *testing.T) { config.Json(` { "token": { - "metadata": { - "bad": { + "metadata": [ + { + "key": "bad" } - }, + ], "remote": { "url": "http://foobar.com" } @@ -114,13 +115,14 @@ func testUnmarshalRequestBuilderError(t *testing.T) { config.Json(` { "token": { - "claims": { - "bad": { + "claims": [ + { + "key": "bad", "header": "X-Bad", "parameter": "bad", "variable": "bad" } - } + ] } } `), @@ -147,11 +149,12 @@ func testUnmarshalSuccess(t *testing.T) { config.Json(` { "token": { - "claims": { - "static": { + "claims": [ + { + "key": "static", "value": "foo" } - } + ] } } `),