Skip to content

Commit

Permalink
Send branding data on websocket message request
Browse files Browse the repository at this point in the history
  • Loading branch information
sergiught committed Oct 19, 2023
1 parent 0493b3f commit 1adc975
Showing 1 changed file with 45 additions and 47 deletions.
92 changes: 45 additions & 47 deletions internal/cli/universal_login_customize.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,17 @@ import (
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"

"github.com/auth0/auth0-cli/internal/ansi"
"github.com/auth0/auth0-cli/internal/auth0"
"github.com/auth0/auth0-cli/internal/display"
)

const (
webAppURL = "http://localhost:5173"
loadBrandingMessageType = "LOAD_BRANDING"
fetchPromptMessageType = "FETCH_PROMPT"
saveBrandingMessageType = "SAVE_BRANDING"
errorMessageType = "ERROR"
successMessageType = "SUCCESS"
webAppURL = "http://localhost:5173"
fetchBrandingMessageType = "FETCH_BRANDING"
fetchPromptMessageType = "FETCH_PROMPT"
saveBrandingMessageType = "SAVE_BRANDING"
errorMessageType = "ERROR"
successMessageType = "SUCCESS"
)

type (
Expand Down Expand Up @@ -59,10 +58,10 @@ type (
}

webSocketHandler struct {
shutdown context.CancelFunc
display *display.Renderer
api *auth0.API
brandingData *universalLoginBrandingData
shutdown context.CancelFunc
display *display.Renderer
api *auth0.API
tenant string
}

webSocketMessage struct {
Expand Down Expand Up @@ -110,7 +109,7 @@ func (m *webSocketMessage) UnmarshalJSON(b []byte) error {
var payload interface{}

switch m.Type {
case loadBrandingMessageType, saveBrandingMessageType:
case fetchBrandingMessageType, saveBrandingMessageType:
payload = &universalLoginBrandingData{}
case fetchPromptMessageType:
payload = &promptData{}
Expand Down Expand Up @@ -151,16 +150,7 @@ func customizeUniversalLoginCmd(cli *cli) *cobra.Command {
return err
}

var universalLoginBrandingData *universalLoginBrandingData

if err := ansi.Spinner("Fetching Universal Login branding data", func() (err error) {
universalLoginBrandingData, err = fetchUniversalLoginBrandingData(ctx, cli.api, cli.tenant)
return err
}); err != nil {
return err
}

return startWebSocketServer(ctx, cli.api, cli.renderer, universalLoginBrandingData)
return startWebSocketServer(ctx, cli.api, cli.renderer, cli.tenant)
},
}

Expand Down Expand Up @@ -353,7 +343,7 @@ func startWebSocketServer(
ctx context.Context,
api *auth0.API,
display *display.Renderer,
brandingData *universalLoginBrandingData,
tenantDomain string,
) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
Expand All @@ -365,10 +355,10 @@ func startWebSocketServer(
defer listener.Close()

handler := &webSocketHandler{
display: display,
api: api,
shutdown: cancel,
brandingData: brandingData,
display: display,
api: api,
shutdown: cancel,
tenant: tenantDomain,
}

server := &http.Server{
Expand Down Expand Up @@ -421,32 +411,40 @@ func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

connection.SetReadLimit(1e+6) // 1 MB.

loadBrandingMsg := webSocketMessage{
Type: loadBrandingMessageType,
Payload: h.brandingData,
}

if err = connection.WriteJSON(&loadBrandingMsg); err != nil {
h.display.Errorf("Failed to send branding data message: %v", err)
h.display.Warnf("Try restarting the command.")
h.shutdown()
return
}

for {
var message webSocketMessage
if err := connection.ReadJSON(&message); err != nil {
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) ||
websocket.IsUnexpectedCloseError(err, websocket.CloseAbnormalClosure) {
// The connection was closed.
break
}

h.display.Errorf("Failed to read WebSocket message: %v", err)
continue
break
}

switch message.Type {
case fetchBrandingMessageType:
brandingData, err := fetchUniversalLoginBrandingData(r.Context(), h.api, h.tenant)
if err != nil {
h.display.Errorf("Failed to fetch Universal Login branding data: %v", err)

errorMsg := webSocketMessage{
Type: errorMessageType,
Payload: &errorData{
Error: err.Error(),
},
}

if err := connection.WriteJSON(&errorMsg); err != nil {
h.display.Errorf("Failed to send error message: %v", err)
}

continue
}

loadBrandingMsg := webSocketMessage{
Type: fetchBrandingMessageType,
Payload: brandingData,
}

if err = connection.WriteJSON(&loadBrandingMsg); err != nil {
h.display.Errorf("Failed to send branding data message: %v", err)
}
case fetchPromptMessageType:
promptToFetch, ok := message.Payload.(*promptData)
if !ok {
Expand Down

0 comments on commit 1adc975

Please sign in to comment.