diff --git a/internal/cli/universal_login_customize.go b/internal/cli/universal_login_customize.go index 2b2a8c529..bf6c595d6 100644 --- a/internal/cli/universal_login_customize.go +++ b/internal/cli/universal_login_customize.go @@ -49,6 +49,10 @@ type ( CustomText map[string]interface{} `json:"custom_text,omitempty"` } + errorData struct { + Error string `json:"error"` + } + webSocketHandler struct { shutdown context.CancelFunc display *display.Renderer @@ -57,11 +61,69 @@ type ( } webSocketMessage struct { - Type string `json:"type"` - Payload json.RawMessage `json:"payload"` + Type string `json:"type"` + Payload interface{} `json:"-"` } ) +// MarshalJSON implements the json.Marshaler interface. +func (m *webSocketMessage) MarshalJSON() ([]byte, error) { + type message webSocketMessage + type messageWrapper struct { + *message + RawPayload json.RawMessage `json:"payload"` + } + + w := &messageWrapper{(*message)(m), nil} + + if m.Payload != nil { + b, err := json.Marshal(m.Payload) + if err != nil { + return nil, err + } + + w.RawPayload = b + } + + return json.Marshal(w) +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (m *webSocketMessage) UnmarshalJSON(b []byte) error { + type message webSocketMessage + type messageWrapper struct { + *message + RawPayload json.RawMessage `json:"payload"` + } + + w := &messageWrapper{(*message)(m), nil} + + if err := json.Unmarshal(b, w); err != nil { + return err + } + + var payload interface{} + + switch m.Type { + case loadBrandingMessageType, saveBrandingMessageType: + payload = &universalLoginBrandingData{} + case fetchPromptMessageType: + payload = &promptData{} + default: + payload = make(map[string]interface{}) + } + + if w.RawPayload != nil { + if err := json.Unmarshal(w.RawPayload, &payload); err != nil { + return err + } + } + + m.Payload = payload + + return nil +} + func customizeUniversalLoginCmd(cli *cli) *cobra.Command { cmd := &cobra.Command{ Use: "customize", @@ -354,20 +416,12 @@ func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { connection.SetReadLimit(1e+6) // 1 MB. - payload, err := json.Marshal(&h.brandingData) - if err != nil { - h.display.Errorf("Failed to encode the branding data to json: %v", err) - h.display.Warnf("Try restarting the command.") - h.shutdown() - return - } - loadBrandingMsg := webSocketMessage{ Type: loadBrandingMessageType, - Payload: payload, + Payload: h.brandingData, } - if err = connection.WriteJSON(loadBrandingMsg); err != nil { + if err = connection.WriteJSON(&loadBrandingMsg); err != nil { h.display.Errorf("Failed to send branding data message: %v", err) h.display.Warnf("Try restarting the command.") h.shutdown() @@ -389,9 +443,9 @@ func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch message.Type { case fetchPromptMessageType: - var promptToFetch promptData - if err := json.Unmarshal(message.Payload, &promptToFetch); err != nil { - h.display.Errorf("Failed to unmarshal fetch prompt payload: %v", err) + promptToFetch, ok := message.Payload.(*promptData) + if !ok { + h.display.Errorf("Invalid payload type: %T", message.Payload) continue } @@ -405,49 +459,46 @@ func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.display.Errorf("Failed to fetch custom text for prompt: %v", err) errorMsg := webSocketMessage{ - Type: errorMessageType, - Payload: []byte(fmt.Sprintf(`{"error":%q}`, err.Error())), + Type: errorMessageType, + Payload: &errorData{ + Error: err.Error(), + }, } - if err := connection.WriteJSON(errorMsg); err != nil { + if err := connection.WriteJSON(&errorMsg); err != nil { h.display.Errorf("Failed to send error message: %v", err) } continue } - payload, err := json.Marshal(promptToSend) - if err != nil { - h.display.Errorf("failed to encode the branding data to json: %v", err) - h.shutdown() - return - } - fetchPromptMsg := webSocketMessage{ Type: fetchPromptMessageType, - Payload: payload, + Payload: promptToSend, } - if err = connection.WriteJSON(fetchPromptMsg); err != nil { + if err = connection.WriteJSON(&fetchPromptMsg); err != nil { h.display.Errorf("Failed to send prompt data message: %v", err) continue } case saveBrandingMessageType: - var saveBrandingMsg universalLoginBrandingData - if err := json.Unmarshal(message.Payload, &saveBrandingMsg); err != nil { - h.display.Errorf("Failed to unmarshal save branding data payload: %v", err) + saveBrandingMsg, ok := message.Payload.(*universalLoginBrandingData) + if !ok { + h.display.Errorf("Invalid payload type: %T", message.Payload) continue } - if err := saveUniversalLoginBrandingData(r.Context(), h.api, &saveBrandingMsg); err != nil { + if err := saveUniversalLoginBrandingData(r.Context(), h.api, saveBrandingMsg); err != nil { h.display.Errorf("Failed to save branding data: %v", err) errorMsg := webSocketMessage{ - Type: errorMessageType, - Payload: []byte(fmt.Sprintf(`{"error":%q}`, err.Error())), + Type: errorMessageType, + Payload: &errorData{ + Error: err.Error(), + }, } - if err := connection.WriteJSON(errorMsg); err != nil { + if err := connection.WriteJSON(&errorMsg); err != nil { h.display.Errorf("Failed to send error message: %v", err) }