diff --git a/internal/cli/universal_login_customize.go b/internal/cli/universal_login_customize.go index 60d8af797..244934580 100644 --- a/internal/cli/universal_login_customize.go +++ b/internal/cli/universal_login_customize.go @@ -718,7 +718,7 @@ func fetchAllPartials(ctx context.Context, api *auth0.API) ([]*management.Prompt partial, err := api.Prompt.ReadPartials(ctx, prompt) if err != nil { - if strings.Contains(err.Error(), "To create or modify prompt templates") { + if strings.Contains(err.Error(), "feature is not available for your plan") { return []*management.PromptPartials{}, nil } return nil, err diff --git a/internal/cli/universal_login_customize_test.go b/internal/cli/universal_login_customize_test.go index abd8b02f0..8f379d018 100644 --- a/internal/cli/universal_login_customize_test.go +++ b/internal/cli/universal_login_customize_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/http" + "reflect" "testing" "github.com/auth0/go-auth0/management" @@ -1623,3 +1624,112 @@ func TestSaveUniversalLoginBrandingData(t *testing.T) { }) } } + +func TestFetchAllPartials(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + var testCases = []struct { + name string + expectedData []*management.PromptPartials + expectedError string + mockedAPI func() *auth0.API + }{ + { + name: "it can fetch all partials", + expectedData: []*management.PromptPartials{ + { + Prompt: "signup", + FormContentStart: "
", + }, + { + Prompt: "signup-id", + FormContentStart: "", + }, + { + Prompt: "signup-password", + FormContentStart: "", + }, + { + Prompt: "login", + FormContentStart: "", + }, + { + Prompt: "login-id", + FormContentStart: "", + }, + { + Prompt: "login-password", + FormContentStart: "", + }, + }, + mockedAPI: func() *auth0.API { + mockPromptAPI := mock.NewMockPromptAPI(ctrl) + for _, promptType := range allowedPromptsWithPartials { + mockPromptAPI.EXPECT(). + ReadPartials(gomock.Any(), promptType). + Return(&management.PromptPartials{ + FormContentStart: "", + Prompt: promptType, + }, nil) + } + + mockAPI := &auth0.API{ + Prompt: mockPromptAPI, + } + + return mockAPI + }, + }, + { + name: "it fails to fetch partials if there's an error retrieving them", + expectedError: "failed to fetch partials", + mockedAPI: func() *auth0.API { + mockPromptAPI := mock.NewMockPromptAPI(ctrl) + mockPromptAPI.EXPECT(). + ReadPartials(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("failed to fetch partials")) + + mockAPI := &auth0.API{ + Prompt: mockPromptAPI, + } + + return mockAPI + }, + }, + { + name: "it doesn't fails if feature flag is disabled", + expectedData: []*management.PromptPartials{}, + mockedAPI: func() *auth0.API { + mockPromptAPI := mock.NewMockPromptAPI(ctrl) + mockPromptAPI.EXPECT(). + ReadPartials(gomock.Any(), gomock.Any()). + Return(nil, fmt.Errorf("feature is not available for your plan")) + mockAPI := &auth0.API{ + Prompt: mockPromptAPI, + } + + return mockAPI + }, + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + actualData, actualError := fetchAllPartials(context.Background(), test.mockedAPI()) + + if test.expectedError != "" { + if actualError == nil || actualError.Error() != test.expectedError { + t.Errorf("expected error %q, got %q", test.expectedError, actualError) + } + } else { + if actualError != nil { + t.Errorf("unexpected error: %v", actualError) + } + if !reflect.DeepEqual(actualData, test.expectedData) { + t.Errorf("expected data %v, got %v", test.expectedData, actualData) + } + } + }) + } +}