Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom marshal and unmarshal funcs to websocket message #880

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 85 additions & 34 deletions internal/cli/universal_login_customize.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ type (
CustomText map[string]interface{} `json:"custom_text,omitempty"`
}

errorData struct {
Error string `json:"error"`
}

webSocketHandler struct {
shutdown context.CancelFunc
display *display.Renderer
Expand All @@ -57,11 +61,69 @@ type (
}

webSocketMessage struct {
Type string `json:"type"`
Payload json.RawMessage `json:"payload"`
Type string `json:"type"`
Payload interface{} `json:"-"`
}
)

// MarshalJSON implements the json.Marshaler interface.
func (m *webSocketMessage) MarshalJSON() ([]byte, error) {
type message webSocketMessage
type messageWrapper struct {
*message
RawPayload json.RawMessage `json:"payload"`
}

w := &messageWrapper{(*message)(m), nil}

if m.Payload != nil {
b, err := json.Marshal(m.Payload)
if err != nil {
return nil, err
}

w.RawPayload = b
}

return json.Marshal(w)
}

// UnmarshalJSON implements the json.Unmarshaler interface.
func (m *webSocketMessage) UnmarshalJSON(b []byte) error {
type message webSocketMessage
type messageWrapper struct {
*message
RawPayload json.RawMessage `json:"payload"`
}

w := &messageWrapper{(*message)(m), nil}

if err := json.Unmarshal(b, w); err != nil {
return err
}

var payload interface{}

switch m.Type {
case loadBrandingMessageType, saveBrandingMessageType:
payload = &universalLoginBrandingData{}
case fetchPromptMessageType:
payload = &promptData{}
default:
payload = make(map[string]interface{})
}

if w.RawPayload != nil {
if err := json.Unmarshal(w.RawPayload, &payload); err != nil {
return err
}
}

m.Payload = payload

return nil
}

func customizeUniversalLoginCmd(cli *cli) *cobra.Command {
cmd := &cobra.Command{
Use: "customize",
Expand Down Expand Up @@ -354,20 +416,12 @@ func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

connection.SetReadLimit(1e+6) // 1 MB.

payload, err := json.Marshal(&h.brandingData)
if err != nil {
h.display.Errorf("Failed to encode the branding data to json: %v", err)
h.display.Warnf("Try restarting the command.")
h.shutdown()
return
}

loadBrandingMsg := webSocketMessage{
Type: loadBrandingMessageType,
Payload: payload,
Payload: h.brandingData,
}

if err = connection.WriteJSON(loadBrandingMsg); err != nil {
if err = connection.WriteJSON(&loadBrandingMsg); err != nil {
h.display.Errorf("Failed to send branding data message: %v", err)
h.display.Warnf("Try restarting the command.")
h.shutdown()
Expand All @@ -389,9 +443,9 @@ func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

switch message.Type {
case fetchPromptMessageType:
var promptToFetch promptData
if err := json.Unmarshal(message.Payload, &promptToFetch); err != nil {
h.display.Errorf("Failed to unmarshal fetch prompt payload: %v", err)
promptToFetch, ok := message.Payload.(*promptData)
if !ok {
h.display.Errorf("Invalid payload type: %T", message.Payload)
continue
}

Expand All @@ -405,49 +459,46 @@ func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.display.Errorf("Failed to fetch custom text for prompt: %v", err)

errorMsg := webSocketMessage{
Type: errorMessageType,
Payload: []byte(fmt.Sprintf(`{"error":%q}`, err.Error())),
Type: errorMessageType,
Payload: &errorData{
Error: err.Error(),
},
}

if err := connection.WriteJSON(errorMsg); err != nil {
if err := connection.WriteJSON(&errorMsg); err != nil {
h.display.Errorf("Failed to send error message: %v", err)
}

continue
}

payload, err := json.Marshal(promptToSend)
if err != nil {
h.display.Errorf("failed to encode the branding data to json: %v", err)
h.shutdown()
return
}

fetchPromptMsg := webSocketMessage{
Type: fetchPromptMessageType,
Payload: payload,
Payload: promptToSend,
}

if err = connection.WriteJSON(fetchPromptMsg); err != nil {
if err = connection.WriteJSON(&fetchPromptMsg); err != nil {
h.display.Errorf("Failed to send prompt data message: %v", err)
continue
}
case saveBrandingMessageType:
var saveBrandingMsg universalLoginBrandingData
if err := json.Unmarshal(message.Payload, &saveBrandingMsg); err != nil {
h.display.Errorf("Failed to unmarshal save branding data payload: %v", err)
saveBrandingMsg, ok := message.Payload.(*universalLoginBrandingData)
if !ok {
h.display.Errorf("Invalid payload type: %T", message.Payload)
continue
}

if err := saveUniversalLoginBrandingData(r.Context(), h.api, &saveBrandingMsg); err != nil {
if err := saveUniversalLoginBrandingData(r.Context(), h.api, saveBrandingMsg); err != nil {
h.display.Errorf("Failed to save branding data: %v", err)

errorMsg := webSocketMessage{
Type: errorMessageType,
Payload: []byte(fmt.Sprintf(`{"error":%q}`, err.Error())),
Type: errorMessageType,
Payload: &errorData{
Error: err.Error(),
},
}

if err := connection.WriteJSON(errorMsg); err != nil {
if err := connection.WriteJSON(&errorMsg); err != nil {
h.display.Errorf("Failed to send error message: %v", err)
}

Expand Down
Loading