From a1ba34b360b877783dc4b45033656ef8d20a68ff Mon Sep 17 00:00:00 2001 From: Cyril David Date: Sat, 20 Mar 2021 15:47:11 -0700 Subject: [PATCH] Add rules (#171) * step 1: restore from our earlier iteration * Very rough first pass on create * Better name * Wrap up editor prompt / rules create flow * Rules: CRUD is done * Cleanup * Lint * Finalize rules Pick 5 templates to start. * Tweak flags so we don't repeat ourselves too much * Add docs for flags.Required() * Switch to using new ansi.Waiting for most cases * Ask name after fetching rule * Subtle fix: update && not required should not prompt * Cleanup * Add a note for now that we're going to do something for next time * Remove `Required()` use * Amend to have AlwaysPrompt logic --- internal/auth/auth.go | 1 + internal/auth/auth_test.go | 1 + internal/auth0/auth0.go | 2 + internal/auth0/rule.go | 25 ++ internal/cli/apps.go | 11 +- internal/cli/cli.go | 8 + ...rule-template-add-email-to-access-token.js | 8 + ...rule-template-check-last-password-reset.js | 12 + internal/cli/data/rule-template-empty-rule.js | 3 + .../rule-template-ip-address-allow-list.js | 12 + .../rule-template-ip-address-deny-list.js | 14 + internal/cli/flags.go | 31 +- internal/cli/root.go | 1 + internal/cli/rules.go | 380 ++++++++++++++++++ internal/cli/rules_embed.go | 22 + internal/display/rules.go | 101 +++++ internal/prompt/editor.go | 106 +++++ 17 files changed, 720 insertions(+), 18 deletions(-) create mode 100644 internal/auth0/rule.go create mode 100644 internal/cli/data/rule-template-add-email-to-access-token.js create mode 100644 internal/cli/data/rule-template-check-last-password-reset.js create mode 100644 internal/cli/data/rule-template-empty-rule.js create mode 100644 internal/cli/data/rule-template-ip-address-allow-list.js create mode 100644 internal/cli/data/rule-template-ip-address-deny-list.js create mode 100644 internal/cli/rules.go create mode 100644 internal/cli/rules_embed.go create mode 100644 internal/display/rules.go create mode 100644 internal/prompt/editor.go diff --git a/internal/auth/auth.go b/internal/auth/auth.go index a7e5543e5..a85c4426e 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -27,6 +27,7 @@ var requiredScopes = []string{ "offline_access", // <-- to get a refresh token. "create:clients", "delete:clients", "read:clients", "update:clients", "create:resource_servers", "delete:resource_servers", "read:resource_servers", "update:resource_servers", + "create:rules", "delete:rules", "read:rules", "update:rules", "read:client_keys", "read:logs", } diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 7e50244da..b05e91310 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -7,6 +7,7 @@ func TestRequiredScopes(t *testing.T) { crudResources := []string{ "clients", "resource_servers", + "rules", } crudPrefixes := []string{"create:", "delete:", "read:", "update:"} diff --git a/internal/auth0/auth0.go b/internal/auth0/auth0.go index 4c69d64d3..5ba7c7201 100644 --- a/internal/auth0/auth0.go +++ b/internal/auth0/auth0.go @@ -15,6 +15,7 @@ type API struct { Client ClientAPI Log LogAPI ResourceServer ResourceServerAPI + Rule RuleAPI } func NewAPI(m *management.Management) *API { @@ -26,6 +27,7 @@ func NewAPI(m *management.Management) *API { Client: m.Client, Log: m.Log, ResourceServer: m.ResourceServer, + Rule: m.Rule, } } diff --git a/internal/auth0/rule.go b/internal/auth0/rule.go new file mode 100644 index 000000000..de0b48621 --- /dev/null +++ b/internal/auth0/rule.go @@ -0,0 +1,25 @@ +//go:generate mockgen -source=rule.go -destination=rule_mock.go -package=auth0 + +package auth0 + +import "gopkg.in/auth0.v5/management" + +type RuleAPI interface { + // Create a new rule. + // + // Note: Changing a rule's stage of execution from the default `login_success` + // can change the rule's function signature to have user omitted. + Create(r *management.Rule, opts ...management.RequestOption) error + + // Retrieve rule details. Accepts a list of fields to include or exclude in the result. + Read(id string, opts ...management.RequestOption) (r *management.Rule, err error) + + // Update an existing rule. + Update(id string, r *management.Rule, opts ...management.RequestOption) error + + // Delete a rule. + Delete(id string, opts ...management.RequestOption) error + + // List all rules. + List(opts ...management.RequestOption) (r *management.RuleList, err error) +} diff --git a/internal/cli/apps.go b/internal/cli/apps.go index 358f843cb..403ae5e4f 100644 --- a/internal/cli/apps.go +++ b/internal/cli/apps.go @@ -49,11 +49,12 @@ var ( IsRequired: false, } appCallbacks = Flag{ - Name: "Callback URLs", - LongForm: "callbacks", - ShortForm: "c", - Help: "After the user authenticates we will only call back to any of these URLs. You can specify multiple valid URLs by comma-separating them (typically to handle different environments like QA or testing). Make sure to specify the protocol (https://) otherwise the callback may fail in some cases. With the exception of custom URI schemes for native apps, all callbacks should use protocol https://.", - IsRequired: false, + Name: "Callback URLs", + LongForm: "callbacks", + ShortForm: "c", + Help: "After the user authenticates we will only call back to any of these URLs. You can specify multiple valid URLs by comma-separating them (typically to handle different environments like QA or testing). Make sure to specify the protocol (https://) otherwise the callback may fail in some cases. With the exception of custom URI schemes for native apps, all callbacks should use protocol https://.", + IsRequired: false, + AlwaysPrompt: true, } appOrigins = Flag{ Name: "Allowed Origin URLs", diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 9cf7221ea..8d768626e 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -381,3 +381,11 @@ func prepareInteractivity(cmd *cobra.Command) { }) } } + +func flagOptionsFromMapping(mapping map[string]string) []string { + result := make([]string, 0, len(mapping)) + for k := range mapping { + result = append(result, k) + } + return result +} diff --git a/internal/cli/data/rule-template-add-email-to-access-token.js b/internal/cli/data/rule-template-add-email-to-access-token.js new file mode 100644 index 000000000..615e4964a --- /dev/null +++ b/internal/cli/data/rule-template-add-email-to-access-token.js @@ -0,0 +1,8 @@ +function addEmailToAccessToken(user, context, callback) { + // This rule adds the authenticated user's email address to the access token. + + var namespace = 'https://example.com/'; + + context.accessToken[namespace + 'email'] = user.email; + return callback(null, user, context); +} diff --git a/internal/cli/data/rule-template-check-last-password-reset.js b/internal/cli/data/rule-template-check-last-password-reset.js new file mode 100644 index 000000000..b7fb87220 --- /dev/null +++ b/internal/cli/data/rule-template-check-last-password-reset.js @@ -0,0 +1,12 @@ +function checkLastPasswordReset(user, context, callback) { + function daydiff(first, second) { + return (second - first) / (1000 * 60 * 60 * 24); + } + + const last_password_change = user.last_password_reset || user.created_at; + + if (daydiff(new Date(last_password_change), new Date()) > 30) { + return callback(new UnauthorizedError('please change your password')); + } + callback(null, user, context); +} diff --git a/internal/cli/data/rule-template-empty-rule.js b/internal/cli/data/rule-template-empty-rule.js new file mode 100644 index 000000000..abcc2116a --- /dev/null +++ b/internal/cli/data/rule-template-empty-rule.js @@ -0,0 +1,3 @@ +function(user, context, cb) { + cb(null, user, context); +} diff --git a/internal/cli/data/rule-template-ip-address-allow-list.js b/internal/cli/data/rule-template-ip-address-allow-list.js new file mode 100644 index 000000000..94fd03980 --- /dev/null +++ b/internal/cli/data/rule-template-ip-address-allow-list.js @@ -0,0 +1,12 @@ +function ipAddressAllowList(user, context, callback) { + const allowlist = ['1.2.3.4', '2.3.4.5']; // authorized IPs + const userHasAccess = allowlist.some(function (ip) { + return context.request.ip === ip; + }); + + if (!userHasAccess) { + return callback(new Error('Access denied from this IP address.')); + } + + return callback(null, user, context); +} diff --git a/internal/cli/data/rule-template-ip-address-deny-list.js b/internal/cli/data/rule-template-ip-address-deny-list.js new file mode 100644 index 000000000..a47a8403f --- /dev/null +++ b/internal/cli/data/rule-template-ip-address-deny-list.js @@ -0,0 +1,14 @@ +function ipAddressDenylist(user, context, callback) { + const denylist = ['1.2.3.4', '2.3.4.5']; // unauthorized IPs + const notAuthorized = denylist.some(function (ip) { + return context.request.ip === ip; + }); + + if (notAuthorized) { + return callback( + new UnauthorizedError('Access denied from this IP address.') + ); + } + + return callback(null, user, context); +} diff --git a/internal/cli/flags.go b/internal/cli/flags.go index 162219f9f..5ae7fbb77 100644 --- a/internal/cli/flags.go +++ b/internal/cli/flags.go @@ -7,11 +7,12 @@ import ( ) type Flag struct { - Name string - LongForm string - ShortForm string - Help string - IsRequired bool + Name string + LongForm string + ShortForm string + Help string + IsRequired bool + AlwaysPrompt bool } func (f Flag) GetName() string { @@ -94,14 +95,6 @@ func selectFlag(cmd *cobra.Command, f *Flag, value interface{}, options []string return nil } -func shouldAsk(cmd *cobra.Command, f *Flag, isUpdate bool) bool { - if isUpdate { - return shouldPromptWhenFlagless(cmd, f.LongForm) - } - - return shouldPrompt(cmd, f.LongForm) -} - func registerString(cmd *cobra.Command, f *Flag, value *string, defaultValue string, isUpdate bool) { cmd.Flags().StringVarP(value, f.LongForm, f.ShortForm, defaultValue, f.Help) @@ -134,6 +127,18 @@ func registerBool(cmd *cobra.Command, f *Flag, value *bool, defaultValue bool, i } } +func shouldAsk(cmd *cobra.Command, f *Flag, isUpdate bool) bool { + if isUpdate { + if !f.AlwaysPrompt { + return false + } + + return shouldPromptWhenFlagless(cmd, f.LongForm) + } + + return shouldPrompt(cmd, f.LongForm) +} + func markFlagRequired(cmd *cobra.Command, f *Flag, isUpdate bool) error { if f.IsRequired && !isUpdate { return cmd.MarkFlagRequired(f.LongForm) diff --git a/internal/cli/root.go b/internal/cli/root.go index 30910be74..a21b8df22 100644 --- a/internal/cli/root.go +++ b/internal/cli/root.go @@ -80,6 +80,7 @@ func Execute() { rootCmd.AddCommand(loginCmd(cli)) rootCmd.AddCommand(tenantsCmd(cli)) rootCmd.AddCommand(appsCmd(cli)) + rootCmd.AddCommand(rulesCmd(cli)) rootCmd.AddCommand(quickstartsCmd(cli)) rootCmd.AddCommand(apisCmd(cli)) rootCmd.AddCommand(testCmd(cli)) diff --git a/internal/cli/rules.go b/internal/cli/rules.go new file mode 100644 index 000000000..6eb0029c6 --- /dev/null +++ b/internal/cli/rules.go @@ -0,0 +1,380 @@ +package cli + +import ( + "fmt" + + "github.com/auth0/auth0-cli/internal/ansi" + "github.com/auth0/auth0-cli/internal/auth0" + "github.com/auth0/auth0-cli/internal/prompt" + "github.com/spf13/cobra" + "gopkg.in/auth0.v5/management" +) + +var ( + ruleName = Flag{ + Name: "Name", + LongForm: "name", + ShortForm: "n", + Help: "Name of the rule.", + IsRequired: true, + } + + ruleTemplate = Flag{ + Name: "Template", + LongForm: "template", + ShortForm: "t", + Help: "Template to use for the rule.", + } + + ruleTemplateOptions = flagOptionsFromMapping(ruleTemplateMappings) + + ruleEnabled = Flag{ + Name: "Enabled", + LongForm: "enabled", + ShortForm: "e", + Help: "Enable (or disable) a rule.", + } + + ruleTemplateMappings = map[string]string{ + "Empty rule": ruleTemplateEmptyRule, + "Add email to access token": ruleTemplateAddEmailToAccessToken, + "Check last password reset": ruleTemplateCheckLastPasswordReset, + "IP address allow list": ruleTemplateIPAddressAllowList, + "IP address deny list": ruleTemplateIPAddressDenyList, + } +) + +func rulesCmd(cli *cli) *cobra.Command { + cmd := &cobra.Command{ + Use: "rules", + Short: "Manage rules for clients", + } + + cmd.SetUsageTemplate(resourceUsageTemplate()) + cmd.AddCommand(listRulesCmd(cli)) + cmd.AddCommand(createRuleCmd(cli)) + cmd.AddCommand(showRuleCmd(cli)) + cmd.AddCommand(updateRuleCmd(cli)) + cmd.AddCommand(deleteRuleCmd(cli)) + + return cmd +} + +func listRulesCmd(cli *cli) *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "List your rules", + Long: `List the rules in your current tenant.`, + RunE: func(cmd *cobra.Command, args []string) error { + var rules []*management.Rule + err := ansi.Waiting(func() error { + ruleList, err := cli.api.Rule.List() + if err != nil { + return err + } + rules = ruleList.Rules + return nil + }) + + if err != nil { + return err + } + + cli.renderer.RulesList(rules) + return nil + }, + } + + return cmd +} + +func createRuleCmd(cli *cli) *cobra.Command { + var inputs struct { + Name string + Template string + Enabled bool + } + + cmd := &cobra.Command{ + Use: "create", + Short: "Create a new rule", + Long: `Create a new rule: + +auth0 rules create --name "My Rule" --template [empty-rule]" + `, + PreRun: func(cmd *cobra.Command, args []string) { + prepareInteractivity(cmd) + }, + RunE: func(cmd *cobra.Command, args []string) error { + if err := ruleName.Ask(cmd, &inputs.Name); err != nil { + return err + } + + if err := ruleTemplate.Select(cmd, &inputs.Template, ruleTemplateOptions); err != nil { + return err + } + + // TODO(cyx): we can re-think this once we have + // `--stdin` based commands. For now we don't have + // those yet, so keeping this simple. + script, err := prompt.CaptureInputViaEditor( + ruleTemplateMappings[inputs.Template], + inputs.Name+".*.js", + ) + if err != nil { + return fmt.Errorf("Failed to capture input from the editor: %w", err) + } + + rule := &management.Rule{ + Name: &inputs.Name, + Script: auth0.String(script), + Enabled: &inputs.Enabled, + } + + err = ansi.Waiting(func() error { + return cli.api.Rule.Create(rule) + }) + + if err != nil { + return fmt.Errorf("Unable to create rule: %w", err) + } + + cli.renderer.RuleCreate(rule) + return nil + }, + } + + ruleName.RegisterString(cmd, &inputs.Name, "") + ruleTemplate.RegisterString(cmd, &inputs.Template, "") + ruleEnabled.RegisterBool(cmd, &inputs.Enabled, true) + + return cmd +} + +func showRuleCmd(cli *cli) *cobra.Command { + var inputs struct { + ID string + } + + cmd := &cobra.Command{ + Use: "show", + Args: cobra.MaximumNArgs(1), + Short: "Show a rule", + Long: `Show a rule: + +auth0 rules show +`, + PreRun: func(cmd *cobra.Command, args []string) { + prepareInteractivity(cmd) + }, + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + inputs.ID = args[0] + } else { + // TODO(cyx): Consider making a primitive for + // Argument to ask using a provided func. + var err error + inputs.ID, err = promptForRuleViaDropdown(cli, cmd) + if err != nil { + return err + } + + if inputs.ID == "" { + cli.renderer.Infof("There are currently no rules.") + return nil + } + } + + var rule *management.Rule + + err := ansi.Waiting(func() error { + var err error + rule, err = cli.api.Rule.Read(inputs.ID) + return err + }) + + if err != nil { + return fmt.Errorf("Unable to load rule. The ID %v specified doesn't exist", inputs.ID) + } + + cli.renderer.RuleShow(rule) + return nil + }, + } + + return cmd +} + +func deleteRuleCmd(cli *cli) *cobra.Command { + var inputs struct { + ID string + } + + cmd := &cobra.Command{ + Use: "delete", + Short: "Delete a rule", + Long: `Delete a rule: + +auth0 rules delete rul_d2VSaGlyaW5n`, + PreRun: func(cmd *cobra.Command, args []string) { + prepareInteractivity(cmd) + }, + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + inputs.ID = args[0] + } else { + // TODO(cyx): Consider making a primitive for + // Argument to ask using a provided func. + var err error + inputs.ID, err = promptForRuleViaDropdown(cli, cmd) + if err != nil { + return err + } + + if inputs.ID == "" { + cli.renderer.Infof("There are currently no rules.") + return nil + } + } + + err := ansi.Spinner("Deleting rule", func() error { + return cli.api.Rule.Delete(inputs.ID) + }) + + if err != nil { + return err + } + + return nil + }, + } + + return cmd +} + +func updateRuleCmd(cli *cli) *cobra.Command { + var inputs struct { + ID string + Name string + Enabled bool + } + + cmd := &cobra.Command{ + Use: "update", + Short: "update a rule", + Long: `Update a rule: + +auth0 rules update --id rul_d2VSaGlyaW5n --name "My Updated Rule" --enabled=false + `, + PreRun: func(cmd *cobra.Command, args []string) { + prepareInteractivity(cmd) + }, + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + inputs.ID = args[0] + } else { + // TODO(cyx): Consider making a primitive for + // Argument to ask using a provided func. + var err error + inputs.ID, err = promptForRuleViaDropdown(cli, cmd) + if err != nil { + return err + } + + if inputs.ID == "" { + cli.renderer.Infof("There are currently no rules.") + return nil + } + } + + var rule *management.Rule + err := ansi.Waiting(func() error { + var err error + rule, err = cli.api.Rule.Read(inputs.ID) + return err + }) + if err != nil { + return fmt.Errorf("Failed to fetch rule with ID: %s %v", inputs.ID, err) + } + + if err := ruleName.AskU(cmd, &inputs.Name); err != nil { + return err + } + + // TODO(cyx): we can re-think this once we have + // `--stdin` based commands. For now we don't have + // those yet, so keeping this simple. + script, err := prompt.CaptureInputViaEditor( + rule.GetScript(), + rule.GetName()+".*.js", + ) + if err != nil { + return fmt.Errorf("Failed to capture input from the editor: %w", err) + } + + // Since name is optional, no need to specify what they chose. + if inputs.Name == "" { + inputs.Name = rule.GetName() + } + + // Prepare rule payload for update. This will also be + // re-hydrated by the SDK, which we'll use below during + // display. + rule = &management.Rule{ + Name: &inputs.Name, + Script: &script, + Enabled: &inputs.Enabled, + } + + err = ansi.Waiting(func() error { + return cli.api.Rule.Update(inputs.ID, rule) + }) + + if err != nil { + return err + } + + cli.renderer.RuleUpdate(rule) + return nil + }, + } + + ruleName.RegisterStringU(cmd, &inputs.Name, "") + ruleEnabled.RegisterBool(cmd, &inputs.Enabled, true) + + return cmd +} + +func promptForRuleViaDropdown(cli *cli, cmd *cobra.Command) (id string, err error) { + dropdown := Flag{Name: "Rule"} + + var rules []*management.Rule + + // == Start experimental dropdown for names => id. + // TODO(cyx): Consider extracting this + // pattern once we've done more of it. + err = ansi.Waiting(func() error { + list, err := cli.api.Rule.List() + if err != nil { + return err + } + rules = list.Rules + return nil + }) + + if err != nil || len(rules) == 0 { + return "", err + } + + mapping := map[string]string{} + for _, r := range rules { + mapping[r.GetName()] = r.GetID() + } + + var name string + if err := dropdown.Select(cmd, &name, flagOptionsFromMapping(mapping)); err != nil { + return "", err + } + + return mapping[name], nil +} diff --git a/internal/cli/rules_embed.go b/internal/cli/rules_embed.go new file mode 100644 index 000000000..14fc553b0 --- /dev/null +++ b/internal/cli/rules_embed.go @@ -0,0 +1,22 @@ +package cli + +import ( + _ "embed" +) + +var ( + //go:embed data/rule-template-empty-rule.js + ruleTemplateEmptyRule string + + //go:embed data/rule-template-add-email-to-access-token.js + ruleTemplateAddEmailToAccessToken string + + //go:embed data/rule-template-check-last-password-reset.js + ruleTemplateCheckLastPasswordReset string + + //go:embed data/rule-template-ip-address-allow-list.js + ruleTemplateIPAddressAllowList string + + //go:embed data/rule-template-ip-address-deny-list.js + ruleTemplateIPAddressDenyList string +) diff --git a/internal/display/rules.go b/internal/display/rules.go new file mode 100644 index 000000000..f254e2036 --- /dev/null +++ b/internal/display/rules.go @@ -0,0 +1,101 @@ +package display + +import ( + "fmt" + "sort" + "strconv" + + "github.com/auth0/auth0-cli/internal/ansi" + "gopkg.in/auth0.v5/management" +) + +type ruleView struct { + Name string + Enabled bool + ID string + Order int + Script string + + raw interface{} +} + +func (v *ruleView) AsTableHeader() []string { + return []string{"Id", "Name", "Enabled", "Order"} +} + +func (v *ruleView) AsTableRow() []string { + return []string{v.ID, v.Name, strconv.FormatBool(v.Enabled), fmt.Sprintf("%d", v.Order)} +} + +func (v *ruleView) KeyValues() [][]string { + return [][]string{ + []string{"NAME", v.Name}, + []string{"ID", v.ID}, + []string{"ENABLED", strconv.FormatBool(v.Enabled)}, + []string{"SCRIPT", v.Script}, + } +} + +func (v *ruleView) Object() interface{} { + return v.raw +} + +func (r *Renderer) RulesList(rules []*management.Rule) { + r.Heading(ansi.Bold(r.Tenant), "rules\n") + var res []View + + //@TODO Provide sort options via flags + sort.Slice(rules, func(i, j int) bool { + return rules[i].GetOrder() < rules[j].GetOrder() + }) + + for _, rule := range rules { + res = append(res, &ruleView{ + Name: *rule.Name, + ID: *rule.ID, + Enabled: *rule.Enabled, + Order: *rule.Order, + }) + } + + r.Results(res) + +} + +func (r *Renderer) RuleCreate(rule *management.Rule) { + r.Heading(ansi.Bold(r.Tenant), "rule created\n") + r.Result(makeRuleView(rule)) + r.Newline() + + // TODO(cyx): possibly guard this with a --no-hint flag. + r.Infof("%s: To edit this rule, do `auth0 rules update %s`", + ansi.Faint("Hint"), + rule.GetID(), + ) + + r.Infof("%s: You might wanna try `auth0 test login", + ansi.Faint("Hint"), + ) +} + +func (r *Renderer) RuleUpdate(rule *management.Rule) { + r.Heading(ansi.Bold(r.Tenant), "rule updated\n") + r.Result(makeRuleView(rule)) +} + +func (r *Renderer) RuleShow(rule *management.Rule) { + r.Heading(ansi.Bold(r.Tenant), "rule\n") + r.Result(makeRuleView(rule)) +} + +func makeRuleView(rule *management.Rule) *ruleView { + return &ruleView{ + Name: rule.GetName(), + ID: rule.GetID(), + Enabled: rule.GetEnabled(), + Order: rule.GetOrder(), + Script: rule.GetScript(), + + raw: rule, + } +} diff --git a/internal/prompt/editor.go b/internal/prompt/editor.go new file mode 100644 index 000000000..a063c0969 --- /dev/null +++ b/internal/prompt/editor.go @@ -0,0 +1,106 @@ +package prompt + +import ( + "fmt" + "os" + "os/exec" + "strings" +) + +const defaultEditor = "vim" + +var defaultEditorPrompt = &editorPrompt{defaultEditor: defaultEditor} + +// CaptureInputViaEditor is the high level function to use in this package in +// order to capture input from an editor. +// +// The arguments have been tailored for our use of strings mostly in the rest +// of the CLI even though internally we're using []byte. +func CaptureInputViaEditor(contents, pattern string) (result string, err error) { + v, err := defaultEditorPrompt.captureInput([]byte(contents), pattern) + return string(v), err +} + +type editorPrompt struct { + defaultEditor string +} + +// GetPreferredEditorFromEnvironment returns the user's editor as defined by the +// `$EDITOR` environment variable, or the `defaultEditor` if it is not set. +func (p *editorPrompt) getPreferredEditor() string { + editor := os.Getenv("EDITOR") + + if editor == "" { + return p.defaultEditor + } + + return editor +} + +func (p *editorPrompt) resolveEditorArguments(executable string, filename string) []string { + args := []string{filename} + + if strings.Contains(executable, "Visual Studio Code.app") { + args = append([]string{"--wait"}, args...) + } + + // TODO(cyx): add other common editors + + return args +} + +// openFile opens filename in the preferred text editor, resolving the +// arguments with editor specific logic. +func (p *editorPrompt) openFile(filename string) error { + // Get the full executable path for the editor. + executable, err := exec.LookPath(p.getPreferredEditor()) + if err != nil { + return err + } + + cmd := exec.Command(executable, p.resolveEditorArguments(executable, filename)...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + return cmd.Run() +} + +// captureInput opens a temporary file in a text editor and returns +// the written bytes on success or an error on failure. It handles deletion +// of the temporary file behind the scenes. +// +// If given default contents, it will write that to the file before popping +// open the editor. +func (p *editorPrompt) captureInput(contents []byte, pattern string) ([]byte, error) { + file, err := os.CreateTemp(os.TempDir(), pattern) + if err != nil { + return []byte{}, err + } + + filename := file.Name() + + if len(contents) > 0 { + if err := os.WriteFile(filename, contents, 0644); err != nil { + return nil, fmt.Errorf("Failed to write to file: %w", err) + } + } + + // Defer removal of the temporary file in case any of the next steps fail. + defer os.Remove(filename) + + if err = file.Close(); err != nil { + return nil, err + } + + if err = p.openFile(filename); err != nil { + return nil, err + } + + bytes, err := os.ReadFile(filename) + if err != nil { + return []byte{}, err + } + + return bytes, nil +}