diff --git a/internal/cli/universal_login_customize.go b/internal/cli/universal_login_customize.go index 4291e6832..522b5170c 100644 --- a/internal/cli/universal_login_customize.go +++ b/internal/cli/universal_login_customize.go @@ -2,6 +2,7 @@ package cli import ( "context" + "encoding/json" "fmt" "net" "net/http" @@ -19,7 +20,12 @@ import ( "github.com/auth0/auth0-cli/internal/display" ) -const webAppURL = "http://localhost:5173" +const ( + webAppURL = "http://localhost:5173" + loadBrandingMessageType = "LOAD_BRANDING" + fetchPromptMessageType = "FETCH_PROMPT" + saveBrandingMessageType = "SAVE_BRANDING" +) type ( universalLoginBrandingData struct { @@ -40,7 +46,7 @@ type ( promptData struct { Language string `json:"language"` Prompt string `json:"prompt"` - CustomText map[string]map[string]interface{} `json:"custom_text"` + CustomText map[string]map[string]interface{} `json:"custom_text,omitempty"` } webSocketHandler struct { @@ -49,6 +55,11 @@ type ( api *auth0.API brandingData *universalLoginBrandingData } + + webSocketMessage struct { + Type string `json:"type"` + Payload json.RawMessage `json:"payload"` + } ) func customizeUniversalLoginCmd(cli *cli) *cobra.Command { @@ -314,18 +325,77 @@ func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { connection, err := upgrader.Upgrade(w, r, nil) if err != nil { - h.display.Errorf("error accepting WebSocket connection: %v", err) + h.display.Errorf("failed to upgrade the connection to the WebSocket protocol: %v", err) h.shutdown() return } connection.SetReadLimit(1e+6) // 1 MB. - if err = connection.WriteJSON(h.brandingData); err != nil { - h.display.Errorf("failed to write json message: %v", err) + payload, err := json.Marshal(&h.brandingData) + if err != nil { + h.display.Errorf("failed to encode the branding data to json: %v", err) + h.shutdown() + return + } + + loadBrandingMsg := webSocketMessage{ + Type: loadBrandingMessageType, + Payload: payload, + } + + if err = connection.WriteJSON(loadBrandingMsg); err != nil { + h.display.Errorf("failed to send branding data message: %v", err) h.shutdown() return } + + for { + var message webSocketMessage + if err := connection.ReadJSON(&message); err != nil { + h.display.Errorf("failed to read WebSocket message: %v", err) + continue + } + + switch message.Type { + case fetchPromptMessageType: + var promptToFetch promptData + if err := json.Unmarshal(message.Payload, &promptToFetch); err != nil { + h.display.Errorf("failed to unmarshal fetch prompt payload: %v", err) + continue + } + + promptToSend, err := fetchPromptCustomTextWithDefaults( + r.Context(), + h.api, + promptToFetch.Prompt, + promptToFetch.Language, + ) + if err != nil { + h.display.Errorf("failed to fetch custom text for prompt: %v", err) + continue + } + + payload, err := json.Marshal(promptToSend) + if err != nil { + h.display.Errorf("failed to encode the branding data to json: %v", err) + h.shutdown() + return + } + + fetchPromptMsg := webSocketMessage{ + Type: fetchPromptMessageType, + Payload: payload, + } + + if err = connection.WriteJSON(fetchPromptMsg); err != nil { + h.display.Errorf("failed to send prompt data message: %v", err) + continue + } + case saveBrandingMessageType: + h.display.Warnf("not yet implemented") + } + } } func checkOriginFunc(r *http.Request) bool {