diff --git a/internal/cli/universal_login_customize.go b/internal/cli/universal_login_customize.go index 78b741287..2b2a8c529 100644 --- a/internal/cli/universal_login_customize.go +++ b/internal/cli/universal_login_customize.go @@ -25,6 +25,7 @@ const ( loadBrandingMessageType = "LOAD_BRANDING" fetchPromptMessageType = "FETCH_PROMPT" saveBrandingMessageType = "SAVE_BRANDING" + errorMessageType = "ERROR" ) type ( @@ -342,16 +343,21 @@ func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { connection, err := upgrader.Upgrade(w, r, nil) if err != nil { - h.display.Errorf("failed to upgrade the connection to the WebSocket protocol: %v", err) + h.display.Errorf("Failed to upgrade the connection to the WebSocket protocol: %v", err) + h.display.Warnf("Try restarting the command.") h.shutdown() return } + defer func() { + _ = connection.Close() + }() connection.SetReadLimit(1e+6) // 1 MB. payload, err := json.Marshal(&h.brandingData) if err != nil { - h.display.Errorf("failed to encode the branding data to json: %v", err) + h.display.Errorf("Failed to encode the branding data to json: %v", err) + h.display.Warnf("Try restarting the command.") h.shutdown() return } @@ -362,7 +368,8 @@ func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if err = connection.WriteJSON(loadBrandingMsg); err != nil { - h.display.Errorf("failed to send branding data message: %v", err) + h.display.Errorf("Failed to send branding data message: %v", err) + h.display.Warnf("Try restarting the command.") h.shutdown() return } @@ -370,7 +377,13 @@ func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { for { var message webSocketMessage if err := connection.ReadJSON(&message); err != nil { - h.display.Errorf("failed to read WebSocket message: %v", err) + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) || + websocket.IsUnexpectedCloseError(err, websocket.CloseAbnormalClosure) { + // The connection was closed. + break + } + + h.display.Errorf("Failed to read WebSocket message: %v", err) continue } @@ -378,7 +391,7 @@ func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { case fetchPromptMessageType: var promptToFetch promptData if err := json.Unmarshal(message.Payload, &promptToFetch); err != nil { - h.display.Errorf("failed to unmarshal fetch prompt payload: %v", err) + h.display.Errorf("Failed to unmarshal fetch prompt payload: %v", err) continue } @@ -389,7 +402,17 @@ func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { promptToFetch.Language, ) if err != nil { - h.display.Errorf("failed to fetch custom text for prompt: %v", err) + h.display.Errorf("Failed to fetch custom text for prompt: %v", err) + + errorMsg := webSocketMessage{ + Type: errorMessageType, + Payload: []byte(fmt.Sprintf(`{"error":%q}`, err.Error())), + } + + if err := connection.WriteJSON(errorMsg); err != nil { + h.display.Errorf("Failed to send error message: %v", err) + } + continue } @@ -406,18 +429,28 @@ func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if err = connection.WriteJSON(fetchPromptMsg); err != nil { - h.display.Errorf("failed to send prompt data message: %v", err) + h.display.Errorf("Failed to send prompt data message: %v", err) continue } case saveBrandingMessageType: var saveBrandingMsg universalLoginBrandingData if err := json.Unmarshal(message.Payload, &saveBrandingMsg); err != nil { - h.display.Errorf("failed to unmarshal save branding data payload: %v", err) + h.display.Errorf("Failed to unmarshal save branding data payload: %v", err) continue } if err := saveUniversalLoginBrandingData(r.Context(), h.api, &saveBrandingMsg); err != nil { - h.display.Errorf("failed to save branding data: %v", err) + h.display.Errorf("Failed to save branding data: %v", err) + + errorMsg := webSocketMessage{ + Type: errorMessageType, + Payload: []byte(fmt.Sprintf(`{"error":%q}`, err.Error())), + } + + if err := connection.WriteJSON(errorMsg); err != nil { + h.display.Errorf("Failed to send error message: %v", err) + } + continue } }