From 1adc975ad6bee54da30a69742d93eca56383358d Mon Sep 17 00:00:00 2001 From: Sergiu Ghitea <28300158+sergiught@users.noreply.github.com> Date: Thu, 19 Oct 2023 16:54:07 +0200 Subject: [PATCH] Send branding data on websocket message request --- internal/cli/universal_login_customize.go | 92 +++++++++++------------ 1 file changed, 45 insertions(+), 47 deletions(-) diff --git a/internal/cli/universal_login_customize.go b/internal/cli/universal_login_customize.go index 8faf75823..044fdc1f9 100644 --- a/internal/cli/universal_login_customize.go +++ b/internal/cli/universal_login_customize.go @@ -15,18 +15,17 @@ import ( "github.com/spf13/cobra" "golang.org/x/sync/errgroup" - "github.com/auth0/auth0-cli/internal/ansi" "github.com/auth0/auth0-cli/internal/auth0" "github.com/auth0/auth0-cli/internal/display" ) const ( - webAppURL = "http://localhost:5173" - loadBrandingMessageType = "LOAD_BRANDING" - fetchPromptMessageType = "FETCH_PROMPT" - saveBrandingMessageType = "SAVE_BRANDING" - errorMessageType = "ERROR" - successMessageType = "SUCCESS" + webAppURL = "http://localhost:5173" + fetchBrandingMessageType = "FETCH_BRANDING" + fetchPromptMessageType = "FETCH_PROMPT" + saveBrandingMessageType = "SAVE_BRANDING" + errorMessageType = "ERROR" + successMessageType = "SUCCESS" ) type ( @@ -59,10 +58,10 @@ type ( } webSocketHandler struct { - shutdown context.CancelFunc - display *display.Renderer - api *auth0.API - brandingData *universalLoginBrandingData + shutdown context.CancelFunc + display *display.Renderer + api *auth0.API + tenant string } webSocketMessage struct { @@ -110,7 +109,7 @@ func (m *webSocketMessage) UnmarshalJSON(b []byte) error { var payload interface{} switch m.Type { - case loadBrandingMessageType, saveBrandingMessageType: + case fetchBrandingMessageType, saveBrandingMessageType: payload = &universalLoginBrandingData{} case fetchPromptMessageType: payload = &promptData{} @@ -151,16 +150,7 @@ func customizeUniversalLoginCmd(cli *cli) *cobra.Command { return err } - var universalLoginBrandingData *universalLoginBrandingData - - if err := ansi.Spinner("Fetching Universal Login branding data", func() (err error) { - universalLoginBrandingData, err = fetchUniversalLoginBrandingData(ctx, cli.api, cli.tenant) - return err - }); err != nil { - return err - } - - return startWebSocketServer(ctx, cli.api, cli.renderer, universalLoginBrandingData) + return startWebSocketServer(ctx, cli.api, cli.renderer, cli.tenant) }, } @@ -353,7 +343,7 @@ func startWebSocketServer( ctx context.Context, api *auth0.API, display *display.Renderer, - brandingData *universalLoginBrandingData, + tenantDomain string, ) error { ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -365,10 +355,10 @@ func startWebSocketServer( defer listener.Close() handler := &webSocketHandler{ - display: display, - api: api, - shutdown: cancel, - brandingData: brandingData, + display: display, + api: api, + shutdown: cancel, + tenant: tenantDomain, } server := &http.Server{ @@ -421,32 +411,40 @@ func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { connection.SetReadLimit(1e+6) // 1 MB. - loadBrandingMsg := webSocketMessage{ - Type: loadBrandingMessageType, - Payload: h.brandingData, - } - - if err = connection.WriteJSON(&loadBrandingMsg); err != nil { - h.display.Errorf("Failed to send branding data message: %v", err) - h.display.Warnf("Try restarting the command.") - h.shutdown() - return - } - for { var message webSocketMessage if err := connection.ReadJSON(&message); err != nil { - if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) || - websocket.IsUnexpectedCloseError(err, websocket.CloseAbnormalClosure) { - // The connection was closed. - break - } - - h.display.Errorf("Failed to read WebSocket message: %v", err) - continue + break } switch message.Type { + case fetchBrandingMessageType: + brandingData, err := fetchUniversalLoginBrandingData(r.Context(), h.api, h.tenant) + if err != nil { + h.display.Errorf("Failed to fetch Universal Login branding data: %v", err) + + errorMsg := webSocketMessage{ + Type: errorMessageType, + Payload: &errorData{ + Error: err.Error(), + }, + } + + if err := connection.WriteJSON(&errorMsg); err != nil { + h.display.Errorf("Failed to send error message: %v", err) + } + + continue + } + + loadBrandingMsg := webSocketMessage{ + Type: fetchBrandingMessageType, + Payload: brandingData, + } + + if err = connection.WriteJSON(&loadBrandingMsg); err != nil { + h.display.Errorf("Failed to send branding data message: %v", err) + } case fetchPromptMessageType: promptToFetch, ok := message.Payload.(*promptData) if !ok {