diff --git a/go.mod b/go.mod index a687d139b..ad2a1b2f1 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.4.0 + github.com/gorilla/websocket v1.5.0 github.com/hashicorp/go-version v1.6.0 github.com/hashicorp/hc-install v0.6.1 github.com/hashicorp/terraform-exec v0.19.0 diff --git a/go.sum b/go.sum index 3ed7ebfd7..01b66a5ea 100644 --- a/go.sum +++ b/go.sum @@ -86,6 +86,8 @@ github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY= github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mOkIeek= diff --git a/internal/cli/universal_login_customize.go b/internal/cli/universal_login_customize.go index 9b1f7b627..4291e6832 100644 --- a/internal/cli/universal_login_customize.go +++ b/internal/cli/universal_login_customize.go @@ -2,15 +2,25 @@ package cli import ( "context" + "fmt" + "net" + "net/http" + "net/url" + "time" "github.com/auth0/go-auth0/management" + "github.com/gorilla/websocket" + "github.com/pkg/browser" "github.com/spf13/cobra" "golang.org/x/sync/errgroup" "github.com/auth0/auth0-cli/internal/ansi" "github.com/auth0/auth0-cli/internal/auth0" + "github.com/auth0/auth0-cli/internal/display" ) +const webAppURL = "http://localhost:5173" + type ( universalLoginBrandingData struct { AuthenticationProfile *management.Prompt `json:"auth_profile"` @@ -32,6 +42,13 @@ type ( Prompt string `json:"prompt"` CustomText map[string]map[string]interface{} `json:"custom_text"` } + + webSocketHandler struct { + shutdown context.CancelFunc + display *display.Renderer + api *auth0.API + brandingData *universalLoginBrandingData + } ) func customizeUniversalLoginCmd(cli *cli) *cobra.Command { @@ -61,9 +78,7 @@ func customizeUniversalLoginCmd(cli *cli) *cobra.Command { return err } - cli.renderer.JSONResult(universalLoginBrandingData) - - return nil + return startWebSocketServer(ctx, cli.api, cli.renderer, universalLoginBrandingData) }, } @@ -237,3 +252,92 @@ func fetchPromptCustomTextWithDefaults( CustomText: brandingTextTranslations, }, nil } + +func startWebSocketServer( + ctx context.Context, + api *auth0.API, + display *display.Renderer, + brandingData *universalLoginBrandingData, +) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return err + } + defer listener.Close() + + handler := &webSocketHandler{ + display: display, + api: api, + shutdown: cancel, + brandingData: brandingData, + } + + server := &http.Server{ + Handler: handler, + ReadTimeout: time.Minute * 10, + WriteTimeout: time.Minute * 10, + } + + errChan := make(chan error, 1) + go func() { + errChan <- server.Serve(listener) + }() + + openWebAppInBrowser(display, listener.Addr()) + + select { + case err := <-errChan: + return err + case <-ctx.Done(): + return server.Close() + } +} + +func openWebAppInBrowser(display *display.Renderer, addr net.Addr) { + port := addr.(*net.TCPAddr).Port + webAppURLWithPort := fmt.Sprintf("%s?ws_port=%d", webAppURL, port) + + display.Infof("Perform your changes within the editor: %q", webAppURLWithPort) + + if err := browser.OpenURL(webAppURLWithPort); err != nil { + display.Warnf("Failed to open the browser. Visit the URL manually.") + } +} + +func (h *webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{ + CheckOrigin: checkOriginFunc, + } + + connection, err := upgrader.Upgrade(w, r, nil) + if err != nil { + h.display.Errorf("error accepting WebSocket connection: %v", err) + h.shutdown() + return + } + + connection.SetReadLimit(1e+6) // 1 MB. + + if err = connection.WriteJSON(h.brandingData); err != nil { + h.display.Errorf("failed to write json message: %v", err) + h.shutdown() + return + } +} + +func checkOriginFunc(r *http.Request) bool { + origin := r.Header["Origin"] + if len(origin) == 0 { + return false + } + + originURL, err := url.Parse(origin[0]) + if err != nil { + return false + } + + return originURL.String() == webAppURL +} diff --git a/internal/cli/universal_login_customize_test.go b/internal/cli/universal_login_customize_test.go index 25837a5b0..bc4f0c230 100644 --- a/internal/cli/universal_login_customize_test.go +++ b/internal/cli/universal_login_customize_test.go @@ -3,6 +3,7 @@ package cli import ( "context" "fmt" + "net/http" "testing" "github.com/auth0/go-auth0/management" @@ -165,6 +166,7 @@ func TestFetchUniversalLoginBrandingData(t *testing.T) { "signupActionText": "${footerText}", "title": "Welcome friend, glad to have you!", "user-blocked": "Your account has been blocked after multiple consecutive login attempts.", + "usernameOnlyPlaceholder": "Username", "usernamePlaceholder": "Username or email address", "wrong-credentials": "Wrong username or password", "wrong-email-credentials": "Wrong email or password", @@ -304,6 +306,7 @@ func TestFetchUniversalLoginBrandingData(t *testing.T) { "signupActionText": "${footerText}", "title": "Welcome friend, glad to have you!", "user-blocked": "Your account has been blocked after multiple consecutive login attempts.", + "usernameOnlyPlaceholder": "Username", "usernamePlaceholder": "Username or email address", "wrong-credentials": "Wrong username or password", "wrong-email-credentials": "Wrong email or password", @@ -445,6 +448,7 @@ func TestFetchUniversalLoginBrandingData(t *testing.T) { "signupActionText": "${footerText}", "title": "Welcome friend, glad to have you!", "user-blocked": "Your account has been blocked after multiple consecutive login attempts.", + "usernameOnlyPlaceholder": "Username", "usernamePlaceholder": "Username or email address", "wrong-credentials": "Wrong username or password", "wrong-email-credentials": "Wrong email or password", @@ -666,6 +670,7 @@ func TestFetchUniversalLoginBrandingData(t *testing.T) { "signupActionText": "${footerText}", "title": "Welcome friend, glad to have you!", "user-blocked": "Your account has been blocked after multiple consecutive login attempts.", + "usernameOnlyPlaceholder": "Username", "usernamePlaceholder": "Username or email address", "wrong-credentials": "Wrong username or password", "wrong-email-credentials": "Wrong email or password", @@ -829,3 +834,53 @@ func TestFetchUniversalLoginBrandingData(t *testing.T) { }) } } + +func TestCheckOriginFunc(t *testing.T) { + var testCases = []struct { + testName string + request *http.Request + expected bool + }{ + { + testName: "No Origin Header", + request: &http.Request{ + Header: http.Header{}, + }, + expected: false, + }, + { + testName: "Valid Origin", + request: &http.Request{ + Header: http.Header{ + "Origin": []string{webAppURL}, + }, + }, + expected: true, + }, + { + testName: "Invalid Origin", + request: &http.Request{ + Header: http.Header{ + "Origin": []string{"https://invalid.com"}, + }, + }, + expected: false, + }, + { + testName: "Malformed Origin", + request: &http.Request{ + Header: http.Header{ + "Origin": []string{"malformed-url"}, + }, + }, + expected: false, + }, + } + + for _, test := range testCases { + t.Run(test.testName, func(t *testing.T) { + actual := checkOriginFunc(test.request) + assert.Equal(t, test.expected, actual) + }) + } +}