diff --git a/internal/cli/try_login.go b/internal/cli/try_login.go index c392e92c4..cf9eee731 100644 --- a/internal/cli/try_login.go +++ b/internal/cli/try_login.go @@ -24,6 +24,10 @@ const ( ) func tryLoginCmd(cli *cli) *cobra.Command { + var clientID string + var connectionName string + var reveal bool + cmd := &cobra.Command{ Use: "try-login", Short: "Try out your universal login box", @@ -40,7 +44,6 @@ Launch a browser to try out your universal login box for the given client. // "CLI Login Testing" client if none passed. This client is only // used for testing login from the CLI and will be created if it // does not exist. - clientID, _ := cmd.Flags().GetString("client-id") if clientID == "" { client, err := getOrCreateCLITesterClient(cli.api.Client) if err != nil { @@ -67,11 +70,15 @@ Launch a browser to try out your universal login box for the given client. } if client.GetInitiateLoginURI() != cliLoginTestingInitiateLoginURI { + if connectionName != "" { + cli.renderer.Warnf("Specific connections are not supported when using a non-default client, ignoring.") + cli.renderer.Warnf("You should ensure the connection you wish to test is enabled for the client you want to use in the Auth0 Dashboard.") + } return open.URL(client.GetInitiateLoginURI()) } // Build a login URL and initiate login in a browser window. - loginURL, err := buildInitiateLoginURL(tenant.Domain, client.GetClientID()) + loginURL, err := buildInitiateLoginURL(tenant.Domain, client.GetClientID(), connectionName) if err != nil { return err } @@ -107,15 +114,15 @@ Launch a browser to try out your universal login box for the given client. return err } - reveal, _ := cmd.Flags().GetBool("reveal") cli.renderer.TryLogin(userInfo, tokenResponse, reveal) return nil }, } cmd.SetUsageTemplate(resourceUsageTemplate()) - cmd.Flags().StringP("client-id", "c", "", "Client ID for which to test login.") - cmd.Flags().BoolP("reveal", "r", false, "⚠️ Reveal tokens after successful login.") + cmd.Flags().StringVarP(&clientID, "client-id", "c", "", "Client ID for which to test login.") + cmd.Flags().StringVarP(&connectionName, "connection", "", "", "Connection to test during login.") + cmd.Flags().BoolVarP(&reveal, "reveal", "r", false, "⚠️ Reveal tokens after successful login.") return cmd } @@ -146,7 +153,7 @@ func getOrCreateCLITesterClient(clientManager auth0.ClientAPI) (*management.Clie // buildInitiateLoginURL constructs a URL + query string that can be used to // initiate a login-flow from the CLI. -func buildInitiateLoginURL(domain, clientID string) (string, error) { +func buildInitiateLoginURL(domain, clientID, connectionName string) (string, error) { var path string = "/authorize" q := url.Values{} @@ -156,6 +163,10 @@ func buildInitiateLoginURL(domain, clientID string) (string, error) { q.Add("scope", cliLoginTestingScopes) q.Add("redirect_uri", cliLoginTestingCallbackURL) + if connectionName != "" { + q.Add("connection", connectionName) + } + u := &url.URL{ Scheme: "https", Host: domain, @@ -171,20 +182,32 @@ func buildInitiateLoginURL(domain, clientID string) (string, error) { // `code` is extracted from the query string (if any), and returns it to the // caller. func waitForBrowserCallback() (string, error) { - codeCh := make(chan string) + type callback struct { + code string + err string + errDescription string + } + + cbCh := make(chan *callback) errCh := make(chan error) m := http.NewServeMux() s := http.Server{Addr: cliLoginTestingCallbackAddr, Handler: m} m.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - authCode := r.URL.Query().Get("code") - if authCode == "" { + cb := &callback{ + code: r.URL.Query().Get("code"), + err: r.URL.Query().Get("error"), + errDescription: r.URL.Query().Get("error_description"), + } + + if cb.code == "" { _, _ = w.Write([]byte("

❌ Unable to extract code from request, please try authenticating again

")) } else { _, _ = w.Write([]byte("

👋 You can close the window and go back to the CLI to see the user info and tokens

")) } - codeCh <- authCode + + cbCh <- cb }) go func() { @@ -194,12 +217,16 @@ func waitForBrowserCallback() (string, error) { }() select { - case code := <-codeCh: + case cb := <-cbCh: ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() + defer func(c context.Context) { _ = s.Shutdown(ctx) }(ctx) - err := s.Shutdown(ctx) - return code, err + var err error + if cb.err != "" { + err = fmt.Errorf("%s: %s", cb.err, cb.errDescription) + } + return cb.code, err case err := <-errCh: return "", err }