From 66252ab70613f8e2c57d309334ca47b08dda673d Mon Sep 17 00:00:00 2001 From: Will Vedder Date: Wed, 20 Sep 2023 10:55:11 -0400 Subject: [PATCH] Adding custom JSON deserializer for UserInfo type (#851) * Adding custom JSON deserializer for UserInfo type * Adding error test case * Errorf instead of sprintf * Removing error stutter --------- Co-authored-by: Will Vedder --- internal/auth/authutil/user_info.go | 37 ++++++++++++++++++++++++ internal/auth/authutil/user_info_test.go | 27 ++++++++++++++++- 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/internal/auth/authutil/user_info.go b/internal/auth/authutil/user_info.go index b38b3bc44..eebcc5fb2 100644 --- a/internal/auth/authutil/user_info.go +++ b/internal/auth/authutil/user_info.go @@ -5,6 +5,8 @@ import ( "fmt" "net/http" "net/url" + "reflect" + "strconv" "time" ) @@ -31,6 +33,41 @@ type UserInfo struct { UpdatedAt *time.Time `json:"updated_at,omitempty"` } +// UnmarshalJSON is a custom deserializer for the UserInfo type. +// A custom solution is necessary due to possible inconsistencies in value types. +func (u *UserInfo) UnmarshalJSON(b []byte) error { + type userInfo UserInfo + type userAlias struct { + *userInfo + RawEmailVerified interface{} `json:"email_verified,omitempty"` + } + + alias := &userAlias{(*userInfo)(u), nil} + + err := json.Unmarshal(b, alias) + if err != nil { + return err + } + + if alias.RawEmailVerified != nil { + var emailVerified bool + switch rawEmailVerified := alias.RawEmailVerified.(type) { + case bool: + emailVerified = rawEmailVerified + case string: + emailVerified, err = strconv.ParseBool(rawEmailVerified) + if err != nil { + return err + } + default: + return fmt.Errorf("email_verified field expected to be bool or string, got: %s", reflect.TypeOf(rawEmailVerified)) + } + alias.EmailVerified = &emailVerified + } + + return nil +} + // FetchUserInfo fetches and parses user information with the provided access token. func FetchUserInfo(httpClient *http.Client, baseDomain, token string) (*UserInfo, error) { endpoint := url.URL{Scheme: "https", Host: baseDomain, Path: "/userinfo"} diff --git a/internal/auth/authutil/user_info_test.go b/internal/auth/authutil/user_info_test.go index 8622c7b54..d0f79d1ec 100644 --- a/internal/auth/authutil/user_info_test.go +++ b/internal/auth/authutil/user_info_test.go @@ -16,7 +16,7 @@ func TestUserInfo(t *testing.T) { assert.Equal(t, "Bearer token", r.Header.Get("authorization")) w.Header().Set("Content-Type", "application/json") - io.WriteString(w, `{"name": "Joe Bloggs"}`) + io.WriteString(w, `{"name": "Joe Bloggs","email_verified":true}`) })) defer ts.Close() @@ -27,6 +27,25 @@ func TestUserInfo(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "Joe Bloggs", *user.Name) + assert.Equal(t, true, *user.EmailVerified) + }) + + t.Run("Successfully call user info endpoint with string encoded email verified field", func(t *testing.T) { + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "Bearer token", r.Header.Get("authorization")) + + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, `{"email_verified":"true"}`) + })) + + defer ts.Close() + parsedURL, err := url.Parse(ts.URL) + assert.NoError(t, err) + + user, err := FetchUserInfo(ts.Client(), parsedURL.Host, "token") + + assert.NoError(t, err) + assert.Equal(t, true, *user.EmailVerified) }) testCases := []struct { @@ -46,6 +65,12 @@ func TestUserInfo(t *testing.T) { httpStatus: http.StatusOK, response: `{ "foo": "bar" `, }, + { + name: "Email verified field not string or bool", + expect: "cannot decode response: email_verified field expected to be bool or string, got: float64", + httpStatus: http.StatusOK, + response: `{ "email_verified": 0 }`, + }, } for _, testCase := range testCases {