From a1e202c224afecd91e20f197aa9189f21e0094d5 Mon Sep 17 00:00:00 2001 From: Sergiu Ghitea <28300158+sergiught@users.noreply.github.com> Date: Tue, 31 Oct 2023 17:56:10 +0100 Subject: [PATCH] [1/3] Expand test coverage for ul customize (#892) Expand test coverage for ul customize --- internal/cli/universal_login_customize.go | 406 +++++++++--------- .../cli/universal_login_customize_test.go | 306 +++++++++++++ 2 files changed, 509 insertions(+), 203 deletions(-) diff --git a/internal/cli/universal_login_customize.go b/internal/cli/universal_login_customize.go index 2d1a87167..2377747fd 100644 --- a/internal/cli/universal_login_customize.go +++ b/internal/cli/universal_login_customize.go @@ -190,6 +190,209 @@ func ensureNewUniversalLoginExperienceIsActive(ctx context.Context, api *auth0.A ) } +func startWebSocketServer( + ctx context.Context, + api *auth0.API, + display *display.Renderer, + tenantDomain string, +) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + listener, err := net.Listen("tcp", webServerHost) + if err != nil { + return err + } + defer listener.Close() + + handler := &webSocketHandler{ + display: display, + api: api, + shutdown: cancel, + tenant: tenantDomain, + } + + assetsWithoutPrefix, err := fs.Sub(universalLoginPreviewAssets, "data/universal-login") + if err != nil { + return err + } + + router := http.NewServeMux() + router.Handle("/", http.FileServer(http.FS(assetsWithoutPrefix))) + router.Handle("/ws", handler) + + server := &http.Server{ + Handler: router, + ReadTimeout: time.Minute * 10, + WriteTimeout: time.Minute * 10, + } + + errChan := make(chan error, 1) + go func() { + errChan <- server.Serve(listener) + }() + + openWebAppInBrowser(display) + + select { + case err := <-errChan: + return err + case <-ctx.Done(): + return server.Close() + } +} + +func openWebAppInBrowser(display *display.Renderer) { + display.Infof("Perform your changes within the editor: %q", webServerURL) + + if err := browser.OpenURL(webServerURL); err != nil { + display.Warnf("Failed to open the browser. Visit the URL manually.") + } +} + +func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{ + CheckOrigin: checkOriginFunc, + } + + connection, err := upgrader.Upgrade(w, r, nil) + if err != nil { + h.display.Errorf("Failed to upgrade the connection to the WebSocket protocol: %v", err) + h.display.Warnf("Try restarting the command.") + h.shutdown() + return + } + defer func() { + _ = connection.Close() + }() + + connection.SetReadLimit(1e+6) // 1 MB. + + for { + var message webSocketMessage + if err := connection.ReadJSON(&message); err != nil { + break + } + + switch message.Type { + case fetchBrandingMessageType: + brandingData, err := fetchUniversalLoginBrandingData(r.Context(), h.api, h.tenant) + if err != nil { + h.display.Errorf("Failed to fetch Universal Login branding data: %v", err) + + errorMsg := webSocketMessage{ + Type: errorMessageType, + Payload: &errorData{ + Error: err.Error(), + }, + } + + if err := connection.WriteJSON(&errorMsg); err != nil { + h.display.Errorf("Failed to send error message: %v", err) + } + + continue + } + + loadBrandingMsg := webSocketMessage{ + Type: fetchBrandingMessageType, + Payload: brandingData, + } + + if err = connection.WriteJSON(&loadBrandingMsg); err != nil { + h.display.Errorf("Failed to send branding data message: %v", err) + } + case fetchPromptMessageType: + promptToFetch, ok := message.Payload.(*promptData) + if !ok { + h.display.Errorf("Invalid payload type: %T", message.Payload) + continue + } + + promptToSend, err := fetchPromptCustomTextWithDefaults( + r.Context(), + h.api, + promptToFetch.Prompt, + promptToFetch.Language, + ) + if err != nil { + h.display.Errorf("Failed to fetch custom text for prompt: %v", err) + + errorMsg := webSocketMessage{ + Type: errorMessageType, + Payload: &errorData{ + Error: err.Error(), + }, + } + + if err := connection.WriteJSON(&errorMsg); err != nil { + h.display.Errorf("Failed to send error message: %v", err) + } + + continue + } + + fetchPromptMsg := webSocketMessage{ + Type: fetchPromptMessageType, + Payload: promptToSend, + } + + if err = connection.WriteJSON(&fetchPromptMsg); err != nil { + h.display.Errorf("Failed to send prompt data message: %v", err) + continue + } + case saveBrandingMessageType: + saveBrandingMsg, ok := message.Payload.(*universalLoginBrandingData) + if !ok { + h.display.Errorf("Invalid payload type: %T", message.Payload) + continue + } + + if err := saveUniversalLoginBrandingData(r.Context(), h.api, saveBrandingMsg); err != nil { + h.display.Errorf("Failed to save branding data: %v", err) + + errorMsg := webSocketMessage{ + Type: errorMessageType, + Payload: &errorData{ + Error: err.Error(), + }, + } + + if err := connection.WriteJSON(&errorMsg); err != nil { + h.display.Errorf("Failed to send error message: %v", err) + } + + continue + } + + successMsg := webSocketMessage{ + Type: successMessageType, + Payload: &successData{ + Success: true, + }, + } + + if err := connection.WriteJSON(&successMsg); err != nil { + h.display.Errorf("Failed to send success message: %v", err) + } + } + } +} + +func checkOriginFunc(r *http.Request) bool { + origin := r.Header["Origin"] + if len(origin) == 0 { + return false + } + + originURL, err := url.Parse(origin[0]) + if err != nil { + return false + } + + return originURL.String() == webServerURL +} + func fetchUniversalLoginBrandingData( ctx context.Context, api *auth0.API, @@ -397,209 +600,6 @@ func fetchAllApplications(ctx context.Context, api *auth0.API) ([]*applicationDa return applications, nil } -func startWebSocketServer( - ctx context.Context, - api *auth0.API, - display *display.Renderer, - tenantDomain string, -) error { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - listener, err := net.Listen("tcp", webServerHost) - if err != nil { - return err - } - defer listener.Close() - - handler := &webSocketHandler{ - display: display, - api: api, - shutdown: cancel, - tenant: tenantDomain, - } - - assetsWithoutPrefix, err := fs.Sub(universalLoginPreviewAssets, "data/universal-login") - if err != nil { - return err - } - - router := http.NewServeMux() - router.Handle("/", http.FileServer(http.FS(assetsWithoutPrefix))) - router.Handle("/ws", handler) - - server := &http.Server{ - Handler: router, - ReadTimeout: time.Minute * 10, - WriteTimeout: time.Minute * 10, - } - - errChan := make(chan error, 1) - go func() { - errChan <- server.Serve(listener) - }() - - openWebAppInBrowser(display) - - select { - case err := <-errChan: - return err - case <-ctx.Done(): - return server.Close() - } -} - -func openWebAppInBrowser(display *display.Renderer) { - display.Infof("Perform your changes within the editor: %q", webServerURL) - - if err := browser.OpenURL(webServerURL); err != nil { - display.Warnf("Failed to open the browser. Visit the URL manually.") - } -} - -func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - upgrader := websocket.Upgrader{ - CheckOrigin: checkOriginFunc, - } - - connection, err := upgrader.Upgrade(w, r, nil) - if err != nil { - h.display.Errorf("Failed to upgrade the connection to the WebSocket protocol: %v", err) - h.display.Warnf("Try restarting the command.") - h.shutdown() - return - } - defer func() { - _ = connection.Close() - }() - - connection.SetReadLimit(1e+6) // 1 MB. - - for { - var message webSocketMessage - if err := connection.ReadJSON(&message); err != nil { - break - } - - switch message.Type { - case fetchBrandingMessageType: - brandingData, err := fetchUniversalLoginBrandingData(r.Context(), h.api, h.tenant) - if err != nil { - h.display.Errorf("Failed to fetch Universal Login branding data: %v", err) - - errorMsg := webSocketMessage{ - Type: errorMessageType, - Payload: &errorData{ - Error: err.Error(), - }, - } - - if err := connection.WriteJSON(&errorMsg); err != nil { - h.display.Errorf("Failed to send error message: %v", err) - } - - continue - } - - loadBrandingMsg := webSocketMessage{ - Type: fetchBrandingMessageType, - Payload: brandingData, - } - - if err = connection.WriteJSON(&loadBrandingMsg); err != nil { - h.display.Errorf("Failed to send branding data message: %v", err) - } - case fetchPromptMessageType: - promptToFetch, ok := message.Payload.(*promptData) - if !ok { - h.display.Errorf("Invalid payload type: %T", message.Payload) - continue - } - - promptToSend, err := fetchPromptCustomTextWithDefaults( - r.Context(), - h.api, - promptToFetch.Prompt, - promptToFetch.Language, - ) - if err != nil { - h.display.Errorf("Failed to fetch custom text for prompt: %v", err) - - errorMsg := webSocketMessage{ - Type: errorMessageType, - Payload: &errorData{ - Error: err.Error(), - }, - } - - if err := connection.WriteJSON(&errorMsg); err != nil { - h.display.Errorf("Failed to send error message: %v", err) - } - - continue - } - - fetchPromptMsg := webSocketMessage{ - Type: fetchPromptMessageType, - Payload: promptToSend, - } - - if err = connection.WriteJSON(&fetchPromptMsg); err != nil { - h.display.Errorf("Failed to send prompt data message: %v", err) - continue - } - case saveBrandingMessageType: - saveBrandingMsg, ok := message.Payload.(*universalLoginBrandingData) - if !ok { - h.display.Errorf("Invalid payload type: %T", message.Payload) - continue - } - - if err := saveUniversalLoginBrandingData(r.Context(), h.api, saveBrandingMsg); err != nil { - h.display.Errorf("Failed to save branding data: %v", err) - - errorMsg := webSocketMessage{ - Type: errorMessageType, - Payload: &errorData{ - Error: err.Error(), - }, - } - - if err := connection.WriteJSON(&errorMsg); err != nil { - h.display.Errorf("Failed to send error message: %v", err) - } - - continue - } - - successMsg := webSocketMessage{ - Type: successMessageType, - Payload: &successData{ - Success: true, - }, - } - - if err := connection.WriteJSON(&successMsg); err != nil { - h.display.Errorf("Failed to send success message: %v", err) - } - } - } -} - -func checkOriginFunc(r *http.Request) bool { - origin := r.Header["Origin"] - if len(origin) == 0 { - return false - } - - originURL, err := url.Parse(origin[0]) - if err != nil { - return false - } - - return originURL.String() == webServerURL -} - func saveUniversalLoginBrandingData(ctx context.Context, api *auth0.API, data *universalLoginBrandingData) error { group, ctx := errgroup.WithContext(ctx) diff --git a/internal/cli/universal_login_customize_test.go b/internal/cli/universal_login_customize_test.go index 8133559e5..a94469c1f 100644 --- a/internal/cli/universal_login_customize_test.go +++ b/internal/cli/universal_login_customize_test.go @@ -2,6 +2,7 @@ package cli import ( "context" + "encoding/json" "fmt" "net/http" "testing" @@ -1008,3 +1009,308 @@ func TestCheckOriginFunc(t *testing.T) { }) } } + +func TestWebSocketMessage_MarshalJSON(t *testing.T) { + var testCases = []struct { + name string + input *webSocketMessage + expected string + }{ + { + name: "it can marshal a fetch prompt data message", + input: &webSocketMessage{ + Type: "FETCH_PROMPT", + Payload: &promptData{ + Language: "en", + Prompt: "login", + CustomText: map[string]interface{}{"key": "value"}, + }, + }, + expected: `{"type":"FETCH_PROMPT","payload":{"language":"en","prompt":"login","custom_text":{"key":"value"}}}`, + }, + { + name: "it can marshal a fetch branding data message", + input: &webSocketMessage{ + Type: "FETCH_BRANDING", + Payload: &universalLoginBrandingData{}, + }, + expected: `{"type":"FETCH_BRANDING","payload":{"applications":null,"prompts":null,"settings":null,"template":null,"theme":null,"tenant":null}}`, + }, + { + name: "it can marshal a message with an empty payload", + input: &webSocketMessage{ + Type: "FETCH_BRANDING", + }, + expected: `{"type":"FETCH_BRANDING","payload":null}`, + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + actual, err := json.Marshal(test.input) + assert.NoError(t, err) + assert.Equal(t, test.expected, string(actual)) + }) + } +} + +func TestWebSocketMessage_UnmarshalJSON(t *testing.T) { + var testCases = []struct { + name string + input []byte + expected *webSocketMessage + }{ + { + name: "it can unmarshal a fetch prompt data message", + input: []byte(`{"type":"FETCH_PROMPT","payload":{"language":"en","prompt":"login","custom_text":{"key":"value"}}}`), + expected: &webSocketMessage{ + Type: "FETCH_PROMPT", + Payload: &promptData{ + Language: "en", + Prompt: "login", + CustomText: map[string]interface{}{"key": "value"}, + }, + }, + }, + { + name: "it can unmarshal a fetch branding data message", + input: []byte(`{"type":"FETCH_BRANDING","payload":{"applications":null,"prompts":null,"settings":null,"template":null,"theme":null,"tenant":null}}`), + expected: &webSocketMessage{ + Type: "FETCH_BRANDING", + Payload: &universalLoginBrandingData{}, + }, + }, + { + name: "it can unmarshal a message with an empty payload", + input: []byte(`{"type":"FETCH_BRANDING","payload":null}`), + expected: &webSocketMessage{ + Type: "FETCH_BRANDING", + }, + }, + { + name: "it can unmarshal a message with an unknown payload", + input: []byte(`{"type":"UNKNOWN","payload":{"key":"value"}}`), + expected: &webSocketMessage{ + Type: "UNKNOWN", + Payload: map[string]interface{}{"key": "value"}, + }, + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + var actual webSocketMessage + err := json.Unmarshal(test.input, &actual) + assert.NoError(t, err) + assert.Equal(t, test.expected, &actual) + }) + } +} + +func TestSaveUniversalLoginBrandingData(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + var testCases = []struct { + name string + input *universalLoginBrandingData + expectedError string + mockedAPI func() *auth0.API + }{ + { + name: "it can correctly save all of the universal login branding data", + input: &universalLoginBrandingData{ + Settings: &management.Branding{ + Colors: &management.BrandingColors{ + Primary: auth0.String("#33ddff"), + PageBackground: auth0.String("#99aacc"), + }, + }, + Template: &management.BrandingUniversalLogin{ + Body: auth0.String(""), + }, + Theme: &management.BrandingTheme{}, + Prompts: []*promptData{ + { + Language: "en", + Prompt: "login", + CustomText: map[string]interface{}{"key": "value"}, + }, + }, + }, + mockedAPI: func() *auth0.API { + mockBrandingAPI := mock.NewMockBrandingAPI(ctrl) + mockBrandingAPI.EXPECT(). + Update(gomock.Any(), &management.Branding{ + Colors: &management.BrandingColors{ + Primary: auth0.String("#33ddff"), + PageBackground: auth0.String("#99aacc"), + }, + }). + Return(nil) + mockBrandingAPI.EXPECT(). + SetUniversalLogin(gomock.Any(), &management.BrandingUniversalLogin{ + Body: auth0.String(""), + }). + Return(nil) + + mockBrandingThemeAPI := mock.NewMockBrandingThemeAPI(ctrl) + mockBrandingThemeAPI.EXPECT(). + Default(gomock.Any()). + Return(&management.BrandingTheme{ + ID: auth0.String("111"), + }, nil) + mockBrandingThemeAPI.EXPECT(). + Update(gomock.Any(), "111", &management.BrandingTheme{}). + Return(nil) + + mockPromptAPI := mock.NewMockPromptAPI(ctrl) + mockPromptAPI.EXPECT(). + SetCustomText(gomock.Any(), "login", "en", map[string]interface{}{"key": "value"}). + Return(nil) + + mockAPI := &auth0.API{ + Branding: mockBrandingAPI, + BrandingTheme: mockBrandingThemeAPI, + Prompt: mockPromptAPI, + } + + return mockAPI + }, + }, + { + name: "it fails to save the universal login branding data if the branding api returns an error", + input: &universalLoginBrandingData{ + Settings: &management.Branding{ + Colors: &management.BrandingColors{ + Primary: auth0.String("#33ddff"), + PageBackground: auth0.String("#99aacc"), + }, + }, + Template: &management.BrandingUniversalLogin{ + Body: auth0.String(""), + }, + Theme: &management.BrandingTheme{}, + Prompts: []*promptData{ + { + Language: "en", + Prompt: "login", + CustomText: map[string]interface{}{"key": "value"}, + }, + }, + }, + expectedError: "branding api failure", + mockedAPI: func() *auth0.API { + mockBrandingAPI := mock.NewMockBrandingAPI(ctrl) + mockBrandingAPI.EXPECT(). + Update(gomock.Any(), &management.Branding{ + Colors: &management.BrandingColors{ + Primary: auth0.String("#33ddff"), + PageBackground: auth0.String("#99aacc"), + }, + }). + Return(fmt.Errorf("branding api failure")) + mockBrandingAPI.EXPECT(). + SetUniversalLogin(gomock.Any(), &management.BrandingUniversalLogin{ + Body: auth0.String(""), + }). + Return(nil) + + mockBrandingThemeAPI := mock.NewMockBrandingThemeAPI(ctrl) + mockBrandingThemeAPI.EXPECT(). + Default(gomock.Any()). + Return(&management.BrandingTheme{ + ID: auth0.String("111"), + }, nil) + mockBrandingThemeAPI.EXPECT(). + Update(gomock.Any(), "111", &management.BrandingTheme{}). + Return(nil) + + mockPromptAPI := mock.NewMockPromptAPI(ctrl) + mockPromptAPI.EXPECT(). + SetCustomText(gomock.Any(), "login", "en", map[string]interface{}{"key": "value"}). + Return(nil) + + mockAPI := &auth0.API{ + Branding: mockBrandingAPI, + BrandingTheme: mockBrandingThemeAPI, + Prompt: mockPromptAPI, + } + + return mockAPI + }, + }, + { + name: "it creates the theme if not found", + input: &universalLoginBrandingData{ + Settings: &management.Branding{ + Colors: &management.BrandingColors{ + Primary: auth0.String("#33ddff"), + PageBackground: auth0.String("#99aacc"), + }, + }, + Template: &management.BrandingUniversalLogin{ + Body: auth0.String(""), + }, + Theme: &management.BrandingTheme{}, + Prompts: []*promptData{ + { + Language: "en", + Prompt: "login", + CustomText: map[string]interface{}{"key": "value"}, + }, + }, + }, + mockedAPI: func() *auth0.API { + mockBrandingAPI := mock.NewMockBrandingAPI(ctrl) + mockBrandingAPI.EXPECT(). + Update(gomock.Any(), &management.Branding{ + Colors: &management.BrandingColors{ + Primary: auth0.String("#33ddff"), + PageBackground: auth0.String("#99aacc"), + }, + }). + Return(nil) + mockBrandingAPI.EXPECT(). + SetUniversalLogin(gomock.Any(), &management.BrandingUniversalLogin{ + Body: auth0.String(""), + }). + Return(nil) + + mockBrandingThemeAPI := mock.NewMockBrandingThemeAPI(ctrl) + mockBrandingThemeAPI.EXPECT(). + Default(gomock.Any()). + Return(&management.BrandingTheme{}, fmt.Errorf("failed to find theme")) + mockBrandingThemeAPI.EXPECT(). + Create(gomock.Any(), &management.BrandingTheme{}). + Return(nil) + + mockPromptAPI := mock.NewMockPromptAPI(ctrl) + mockPromptAPI.EXPECT(). + SetCustomText(gomock.Any(), "login", "en", map[string]interface{}{"key": "value"}). + Return(nil) + + mockAPI := &auth0.API{ + Branding: mockBrandingAPI, + BrandingTheme: mockBrandingThemeAPI, + Prompt: mockPromptAPI, + } + + return mockAPI + }, + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + err := saveUniversalLoginBrandingData(context.Background(), test.mockedAPI(), test.input) + + if test.expectedError != "" { + assert.EqualError(t, err, test.expectedError) + return + } + + assert.NoError(t, err) + }) + } +}