Skip to content

Commit

Permalink
Add custom marshal and unmarshal funcs to websocket message
Browse files Browse the repository at this point in the history
  • Loading branch information
sergiught committed Oct 17, 2023
1 parent 64c08a4 commit 8425a56
Showing 1 changed file with 85 additions and 34 deletions.
119 changes: 85 additions & 34 deletions internal/cli/universal_login_customize.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ type (
CustomText map[string]interface{} `json:"custom_text,omitempty"`
}

errorData struct {
Error string `json:"error"`
}

webSocketHandler struct {
shutdown context.CancelFunc
display *display.Renderer
Expand All @@ -57,11 +61,69 @@ type (
}

webSocketMessage struct {
Type string `json:"type"`
Payload json.RawMessage `json:"payload"`
Type string `json:"type"`
Payload interface{} `json:"-"`
}
)

// MarshalJSON implements the json.Marshaler interface.
func (m *webSocketMessage) MarshalJSON() ([]byte, error) {
type message webSocketMessage
type messageWrapper struct {
*message
RawPayload json.RawMessage `json:"payload"`
}

w := &messageWrapper{(*message)(m), nil}

if m.Payload != nil {
b, err := json.Marshal(m.Payload)
if err != nil {
return nil, err
}

w.RawPayload = b
}

return json.Marshal(w)
}

// UnmarshalJSON implements the json.Unmarshaler interface.
func (m *webSocketMessage) UnmarshalJSON(b []byte) error {
type message webSocketMessage
type messageWrapper struct {
*message
RawPayload json.RawMessage `json:"payload"`
}

w := &messageWrapper{(*message)(m), nil}

if err := json.Unmarshal(b, w); err != nil {
return err
}

var payload interface{}

switch m.Type {
case loadBrandingMessageType, saveBrandingMessageType:
payload = &universalLoginBrandingData{}
case fetchPromptMessageType:
payload = &promptData{}
default:
payload = make(map[string]interface{})
}

if w.RawPayload != nil {
if err := json.Unmarshal(w.RawPayload, &payload); err != nil {
return err
}
}

m.Payload = payload

return nil
}

func customizeUniversalLoginCmd(cli *cli) *cobra.Command {
cmd := &cobra.Command{
Use: "customize",
Expand Down Expand Up @@ -354,20 +416,12 @@ func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

connection.SetReadLimit(1e+6) // 1 MB.

payload, err := json.Marshal(&h.brandingData)
if err != nil {
h.display.Errorf("Failed to encode the branding data to json: %v", err)
h.display.Warnf("Try restarting the command.")
h.shutdown()
return
}

loadBrandingMsg := webSocketMessage{
Type: loadBrandingMessageType,
Payload: payload,
Payload: h.brandingData,
}

if err = connection.WriteJSON(loadBrandingMsg); err != nil {
if err = connection.WriteJSON(&loadBrandingMsg); err != nil {
h.display.Errorf("Failed to send branding data message: %v", err)
h.display.Warnf("Try restarting the command.")
h.shutdown()
Expand All @@ -389,9 +443,9 @@ func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

switch message.Type {
case fetchPromptMessageType:
var promptToFetch promptData
if err := json.Unmarshal(message.Payload, &promptToFetch); err != nil {
h.display.Errorf("Failed to unmarshal fetch prompt payload: %v", err)
promptToFetch, ok := message.Payload.(*promptData)
if !ok {
h.display.Errorf("Invalid payload type: %T", message.Payload)
continue
}

Expand All @@ -405,49 +459,46 @@ func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.display.Errorf("Failed to fetch custom text for prompt: %v", err)

errorMsg := webSocketMessage{
Type: errorMessageType,
Payload: []byte(fmt.Sprintf(`{"error":%q}`, err.Error())),
Type: errorMessageType,
Payload: &errorData{
Error: err.Error(),
},
}

if err := connection.WriteJSON(errorMsg); err != nil {
if err := connection.WriteJSON(&errorMsg); err != nil {
h.display.Errorf("Failed to send error message: %v", err)
}

continue
}

payload, err := json.Marshal(promptToSend)
if err != nil {
h.display.Errorf("failed to encode the branding data to json: %v", err)
h.shutdown()
return
}

fetchPromptMsg := webSocketMessage{
Type: fetchPromptMessageType,
Payload: payload,
Payload: promptToSend,
}

if err = connection.WriteJSON(fetchPromptMsg); err != nil {
if err = connection.WriteJSON(&fetchPromptMsg); err != nil {
h.display.Errorf("Failed to send prompt data message: %v", err)
continue
}
case saveBrandingMessageType:
var saveBrandingMsg universalLoginBrandingData
if err := json.Unmarshal(message.Payload, &saveBrandingMsg); err != nil {
h.display.Errorf("Failed to unmarshal save branding data payload: %v", err)
saveBrandingMsg, ok := message.Payload.(*universalLoginBrandingData)
if !ok {
h.display.Errorf("Invalid payload type: %T", message.Payload)
continue
}

if err := saveUniversalLoginBrandingData(r.Context(), h.api, &saveBrandingMsg); err != nil {
if err := saveUniversalLoginBrandingData(r.Context(), h.api, saveBrandingMsg); err != nil {
h.display.Errorf("Failed to save branding data: %v", err)

errorMsg := webSocketMessage{
Type: errorMessageType,
Payload: []byte(fmt.Sprintf(`{"error":%q}`, err.Error())),
Type: errorMessageType,
Payload: &errorData{
Error: err.Error(),
},
}

if err := connection.WriteJSON(errorMsg); err != nil {
if err := connection.WriteJSON(&errorMsg); err != nil {
h.display.Errorf("Failed to send error message: %v", err)
}

Expand Down

0 comments on commit 8425a56

Please sign in to comment.