From ef0bd93a77339ef6c749f0eee2d33e5cb2fd114f Mon Sep 17 00:00:00 2001 From: Olivier Leduc Date: Mon, 27 May 2024 10:12:23 -0400 Subject: [PATCH] Make --scm-base-url more fool-proof (#95) --- cmd/root.go | 6 ++-- providers/scm/scm_domain.go | 29 +++++++++++++++++ providers/scm/scm_domain_test.go | 56 ++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 3 deletions(-) create mode 100644 providers/scm/scm_domain.go create mode 100644 providers/scm/scm_domain_test.go diff --git a/cmd/root.go b/cmd/root.go index 8d51fc6..4101c25 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -26,7 +26,7 @@ import ( var Format string var Verbose bool var ScmProvider string -var ScmBaseURL string +var ScmBaseURL scm.ScmBaseDomain var ( Version string Commit string @@ -110,7 +110,7 @@ func init() { rootCmd.PersistentFlags().StringVarP(&Format, "format", "f", "pretty", "Output format (pretty, json, sarif)") rootCmd.PersistentFlags().BoolVarP(&Verbose, "verbose", "v", false, "Enable verbose logging") rootCmd.PersistentFlags().StringVarP(&ScmProvider, "scm", "s", "github", "SCM platform (github, gitlab)") - rootCmd.PersistentFlags().StringVarP(&ScmBaseURL, "scm-base-url", "b", "", "Base URI of the self-hosted SCM instance (optional)") + rootCmd.PersistentFlags().VarP(&ScmBaseURL, "scm-base-url", "b", "Base URI of the self-hosted SCM instance (optional)") } func initConfig() { @@ -166,7 +166,7 @@ func GetFormatter() analyze.Formatter { } func GetAnalyzer(ctx context.Context, command string) (*analyze.Analyzer, error) { - scmClient, err := scm.NewScmClient(ctx, ScmProvider, ScmBaseURL, token, command) + scmClient, err := scm.NewScmClient(ctx, ScmProvider, ScmBaseURL.String(), token, command) if err != nil { return nil, fmt.Errorf("failed to create SCM client: %w", err) } diff --git a/providers/scm/scm_domain.go b/providers/scm/scm_domain.go new file mode 100644 index 0000000..c4f1a42 --- /dev/null +++ b/providers/scm/scm_domain.go @@ -0,0 +1,29 @@ +package scm + +import "strings" + +// ScmBaseDomain represent the base domain for a SCM provider. +type ScmBaseDomain string + +var schemePrefixes = []string{"https://", "http://"} + +func (d *ScmBaseDomain) Set(value string) error { + for _, prefix := range schemePrefixes { + value = strings.TrimPrefix(value, prefix) + } + value = strings.TrimRight(value, "/") + + *d = ScmBaseDomain(value) + return nil +} + +func (d *ScmBaseDomain) String() string { + if d == nil { + return "" + } + return string(*d) +} + +func (d *ScmBaseDomain) Type() string { + return "string" +} diff --git a/providers/scm/scm_domain_test.go b/providers/scm/scm_domain_test.go new file mode 100644 index 0000000..c672385 --- /dev/null +++ b/providers/scm/scm_domain_test.go @@ -0,0 +1,56 @@ +package scm + +import "testing" + +var tests = map[string]struct { + input string + expected string +}{ + "strip https": { + input: "https://scm.com", + expected: "scm.com", + }, + "strip http": { + input: "http://example.scm.com", + expected: "example.scm.com", + }, + "ignore": { + input: "scm.com", + expected: "scm.com", + }, + "empty": { + input: "", + expected: "", + }, + "trailing slash": { + input: "https://scm.com/", + expected: "scm.com", + }, + "sub path": { + input: "https://scm.com/sub/domain", + expected: "scm.com/sub/domain", + }, +} + +func TestScmBaseDomain(t *testing.T) { + for name, test := range tests { + t.Run(name, func(t *testing.T) { + var d ScmBaseDomain + err := d.Set(test.input) + if err != nil { + t.Fatal(err) + } + s := d.String() + if s != test.expected { + t.Errorf("expected %s, got %s", test.expected, s) + } + }) + } +} + +func TestScmBaseDomainNil(t *testing.T) { + var d ScmBaseDomain + if d.String() != "" { + t.Error("expected default value of to be \"\"") + } +}