From 8c3c25e8fc3b9a76a0c6ba1bf30573973b3d3f8e Mon Sep 17 00:00:00 2001 From: Ewan Harris Date: Wed, 13 Dec 2023 17:48:23 +0000 Subject: [PATCH] Export an Authentication Error to allow type assertions (#330) Co-authored-by: Rita Zerrizuela --- authentication/authentication_error.go | 16 +++++++++------- authentication/authentication_error_test.go | 10 +++++----- authentication/authentication_test.go | 2 +- authentication/doc.go | 18 ++++++++++++++++++ 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/authentication/authentication_error.go b/authentication/authentication_error.go index 0df60783..7d47a1a8 100644 --- a/authentication/authentication_error.go +++ b/authentication/authentication_error.go @@ -6,17 +6,19 @@ import ( "net/http" ) -type authenticationError struct { +// Error represents errors returned from the Authentication API. The `Err` property can +// be used to check the error code returned from the API. +type Error struct { StatusCode int `json:"statusCode"` Err string `json:"error"` Message string `json:"error_description"` } func newError(response *http.Response) error { - apiError := &authenticationError{} + apiError := &Error{} if err := json.NewDecoder(response.Body).Decode(apiError); err != nil { - return &authenticationError{ + return &Error{ StatusCode: response.StatusCode, Err: http.StatusText(response.StatusCode), Message: fmt.Errorf("failed to decode json error response payload: %w", err).Error(), @@ -34,20 +36,20 @@ func newError(response *http.Response) error { } // Error formats the error into a string representation. -func (a *authenticationError) Error() string { +func (a *Error) Error() string { return fmt.Sprintf("%d %s: %s", a.StatusCode, a.Err, a.Message) } // Status returns the status code of the error. -func (a *authenticationError) Status() int { +func (a *Error) Status() int { return a.StatusCode } // UnmarshalJSON implements the json.Unmarshaler interface. // // It is required to handle the differences between error responses between the APIs. -func (a *authenticationError) UnmarshalJSON(b []byte) error { - type authError authenticationError +func (a *Error) UnmarshalJSON(b []byte) error { + type authError Error type authErrorWrapper struct { *authError Code string `json:"code"` diff --git a/authentication/authentication_error_test.go b/authentication/authentication_error_test.go index d86dce2e..7b657f50 100644 --- a/authentication/authentication_error_test.go +++ b/authentication/authentication_error_test.go @@ -13,7 +13,7 @@ func Test_newError(t *testing.T) { var testCases = []struct { name string givenResponse http.Response - expectedError authenticationError + expectedError Error }{ { name: "it fails to decode if body is not json", @@ -21,7 +21,7 @@ func Test_newError(t *testing.T) { StatusCode: http.StatusForbidden, Body: io.NopCloser(strings.NewReader("Hello, I'm not JSON.")), }, - expectedError: authenticationError{ + expectedError: Error{ StatusCode: 403, Err: "Forbidden", Message: "failed to decode json error response payload: invalid character 'H' looking for beginning of value", @@ -33,7 +33,7 @@ func Test_newError(t *testing.T) { StatusCode: http.StatusBadRequest, Body: io.NopCloser(strings.NewReader(`{"statusCode":400,"error":"invalid_scope","error_description":"Scope must be an array or a string"}`)), }, - expectedError: authenticationError{ + expectedError: Error{ StatusCode: 400, Err: "invalid_scope", Message: "Scope must be an array or a string", @@ -45,7 +45,7 @@ func Test_newError(t *testing.T) { StatusCode: http.StatusInternalServerError, Body: io.NopCloser(strings.NewReader(`{"errorMessage":"wrongStruct"}`)), }, - expectedError: authenticationError{ + expectedError: Error{ StatusCode: 500, Err: "Internal Server Error", Message: "", @@ -57,7 +57,7 @@ func Test_newError(t *testing.T) { StatusCode: http.StatusBadRequest, Body: io.NopCloser(strings.NewReader(`{"name":"BadRequestError","code":"invalid_signup","description":"Invalid sign up","statusCode":400}`)), }, - expectedError: authenticationError{ + expectedError: Error{ StatusCode: 400, Err: "invalid_signup", Message: "Invalid sign up", diff --git a/authentication/authentication_test.go b/authentication/authentication_test.go index dcbdfdda..c0d35dca 100644 --- a/authentication/authentication_test.go +++ b/authentication/authentication_test.go @@ -411,7 +411,7 @@ func TestRetries(t *testing.T) { assert.NoError(t, err) _, err = a.UserInfo(context.Background(), "123") - assert.Equal(t, http.StatusBadGateway, err.(*authenticationError).StatusCode) + assert.Equal(t, http.StatusBadGateway, err.(*Error).StatusCode) assert.Equal(t, 1, i) }) diff --git a/authentication/doc.go b/authentication/doc.go index 12bcb8b2..ccdceaa5 100644 --- a/authentication/doc.go +++ b/authentication/doc.go @@ -51,4 +51,22 @@ // authentication.WithClientSecret(secret), // Optional depending on the grants used // authentication.WithClockTolerance(10 * time.Second), // ) +// +// # Handling Errors +// +// This package exports an [authentication.Error] type that can be used to check errors +// returned from the Authentication API and handle them as necessary, for example +// +// tokens, err := auth.OAuth.LoginWithPassword(context.Background(), oauth.LoginWithPasswordRequest{ +// Username: "test@example.com", +// Password: "hunter2", +// }, oauth.IDTokenValidationOptions{}) +// +// if err != nil { +// if aerr, ok := err.(*authentication.Error); ok { +// if aerr.Err == "mfa_required" { +// // Handle prompting for MFA usage +// } +// } +// } package authentication