Skip to content

Commit

Permalink
Make --scm-base-url more fool-proof (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
ledo01 authored May 27, 2024
1 parent 8dfc8b6 commit ef0bd93
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 3 deletions.
6 changes: 3 additions & 3 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
var Format string
var Verbose bool
var ScmProvider string
var ScmBaseURL string
var ScmBaseURL scm.ScmBaseDomain
var (
Version string
Commit string
Expand Down Expand Up @@ -110,7 +110,7 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&Format, "format", "f", "pretty", "Output format (pretty, json, sarif)")
rootCmd.PersistentFlags().BoolVarP(&Verbose, "verbose", "v", false, "Enable verbose logging")
rootCmd.PersistentFlags().StringVarP(&ScmProvider, "scm", "s", "github", "SCM platform (github, gitlab)")
rootCmd.PersistentFlags().StringVarP(&ScmBaseURL, "scm-base-url", "b", "", "Base URI of the self-hosted SCM instance (optional)")
rootCmd.PersistentFlags().VarP(&ScmBaseURL, "scm-base-url", "b", "Base URI of the self-hosted SCM instance (optional)")
}

func initConfig() {
Expand Down Expand Up @@ -166,7 +166,7 @@ func GetFormatter() analyze.Formatter {
}

func GetAnalyzer(ctx context.Context, command string) (*analyze.Analyzer, error) {
scmClient, err := scm.NewScmClient(ctx, ScmProvider, ScmBaseURL, token, command)
scmClient, err := scm.NewScmClient(ctx, ScmProvider, ScmBaseURL.String(), token, command)
if err != nil {
return nil, fmt.Errorf("failed to create SCM client: %w", err)
}
Expand Down
29 changes: 29 additions & 0 deletions providers/scm/scm_domain.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package scm

import "strings"

// ScmBaseDomain represent the base domain for a SCM provider.
type ScmBaseDomain string

var schemePrefixes = []string{"https://", "http://"}

func (d *ScmBaseDomain) Set(value string) error {
for _, prefix := range schemePrefixes {
value = strings.TrimPrefix(value, prefix)
}
value = strings.TrimRight(value, "/")

*d = ScmBaseDomain(value)
return nil
}

func (d *ScmBaseDomain) String() string {
if d == nil {
return ""
}
return string(*d)
}

func (d *ScmBaseDomain) Type() string {
return "string"
}
56 changes: 56 additions & 0 deletions providers/scm/scm_domain_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package scm

import "testing"

var tests = map[string]struct {
input string
expected string
}{
"strip https": {
input: "https://scm.com",
expected: "scm.com",
},
"strip http": {
input: "http://example.scm.com",
expected: "example.scm.com",
},
"ignore": {
input: "scm.com",
expected: "scm.com",
},
"empty": {
input: "",
expected: "",
},
"trailing slash": {
input: "https://scm.com/",
expected: "scm.com",
},
"sub path": {
input: "https://scm.com/sub/domain",
expected: "scm.com/sub/domain",
},
}

func TestScmBaseDomain(t *testing.T) {
for name, test := range tests {
t.Run(name, func(t *testing.T) {
var d ScmBaseDomain
err := d.Set(test.input)
if err != nil {
t.Fatal(err)
}
s := d.String()
if s != test.expected {
t.Errorf("expected %s, got %s", test.expected, s)
}
})
}
}

func TestScmBaseDomainNil(t *testing.T) {
var d ScmBaseDomain
if d.String() != "" {
t.Error("expected default value of to be \"\"")
}
}

0 comments on commit ef0bd93

Please sign in to comment.