Skip to content

Commit

Permalink
feat: improve error message of when a set of flags is required to use…
Browse files Browse the repository at this point in the history
… together (#1358)

Signed-off-by: Xiaoxuan Wang <[email protected]>
  • Loading branch information
wangxiaoxuan273 authored Apr 30, 2024
1 parent 639d4b7 commit 4c15005
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 12 deletions.
31 changes: 25 additions & 6 deletions cmd/oras/internal/errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,34 @@ func NewErrEmptyTagOrDigest(ref string, cmd *cobra.Command, needsTag bool) error
// CheckMutuallyExclusiveFlags checks if any mutually exclusive flags are used
// at the same time, returns an error when detecting used exclusive flags.
func CheckMutuallyExclusiveFlags(fs *pflag.FlagSet, exclusiveFlagSet ...string) error {
var changedFlags []string
for _, flagName := range exclusiveFlagSet {
if fs.Changed(flagName) {
changedFlags = append(changedFlags, fmt.Sprintf("--%s", flagName))
}
}
changedFlags, _ := checkChangedFlags(fs, exclusiveFlagSet...)
if len(changedFlags) >= 2 {
flags := strings.Join(changedFlags, ", ")
return fmt.Errorf("%s cannot be used at the same time", flags)
}
return nil
}

// CheckRequiredTogetherFlags checks if any flags required together are all used,
// returns an error when detecting any flags not used while other flags have been used.
func CheckRequiredTogetherFlags(fs *pflag.FlagSet, requiredTogetherFlags ...string) error {
changed, unchanged := checkChangedFlags(fs, requiredTogetherFlags...)
unchangedCount := len(unchanged)
if unchangedCount != 0 && unchangedCount != len(requiredTogetherFlags) {
changed := strings.Join(changed, ", ")
unchanged := strings.Join(unchanged, ", ")
return fmt.Errorf("%s must be used in conjunction with %s", changed, unchanged)
}
return nil
}

func checkChangedFlags(fs *pflag.FlagSet, flagSet ...string) (changedFlags []string, unchangedFlags []string) {
for _, flagName := range flagSet {
if fs.Changed(flagName) {
changedFlags = append(changedFlags, fmt.Sprintf("--%s", flagName))
} else {
unchangedFlags = append(unchangedFlags, fmt.Sprintf("--%s", flagName))
}
}
return
}
39 changes: 39 additions & 0 deletions cmd/oras/internal/errors/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,42 @@ func TestCheckMutuallyExclusiveFlags(t *testing.T) {
})
}
}

func TestCheckRequiredTogetherFlags(t *testing.T) {
fs := &pflag.FlagSet{}
var foo, bar, hello, world bool
fs.BoolVar(&foo, "foo", false, "foo test")
fs.BoolVar(&bar, "bar", false, "bar test")
fs.BoolVar(&hello, "hello", false, "hello test")
fs.BoolVar(&world, "world", false, "world test")
fs.Lookup("foo").Changed = true
fs.Lookup("bar").Changed = true
tests := []struct {
name string
requiredTogetherFlags []string
wantErr bool
}{
{
"--foo and --bar are both used, no error is returned",
[]string{"foo", "bar"},
false,
},
{
"--foo and --hello are not both used, an error is returned",
[]string{"foo", "hello"},
true,
},
{
"none of --hello and --world is used, no error is returned",
[]string{"hello", "world"},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := CheckRequiredTogetherFlags(fs, tt.requiredTogetherFlags...); (err != nil) != tt.wantErr {
t.Errorf("CheckRequiredTogetherFlags() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
6 changes: 3 additions & 3 deletions cmd/oras/internal/option/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@ func (opts *Remote) Parse(cmd *cobra.Command) error {
if err := opts.parseCustomHeaders(); err != nil {
return err
}

cmd.MarkFlagsRequiredTogether(certFileAndKeyFileFlags...)

if err := oerrors.CheckRequiredTogetherFlags(cmd.Flags(), certFileAndKeyFileFlags...); err != nil {
return err
}
return opts.readSecret(cmd)
}

Expand Down
6 changes: 3 additions & 3 deletions test/e2e/internal/utils/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ const (
orasBinary = "oras"

// customize your own basic auth file via `htpasswd -cBb <file_name> <user_name> <password>`
Username = "hello"
Password = "oras-test"
DefaultTimeout = 10 * time.Second
Username = "hello"
Password = "oras-test"
DefaultTimeout = 10 * time.Second
// If the command hasn't exited yet, ginkgo session ExitCode is -1
notResponding = -1
)
Expand Down
5 changes: 5 additions & 0 deletions test/e2e/suite/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,11 @@ var _ = Describe("Common registry user", func() {
ORAS("login", ZOTHost, "--identity-token", Password).
MatchErrKeyWords("WARNING", "Using --identity-token via the CLI is insecure", "Use --identity-token-stdin").ExpectFailure().Exec()
})

It("should fail if --cert-file is not used with --key-file with correct error message", func() {
ORAS("login", ZOTHost, "--cert-file", "test").
MatchErrKeyWords("--cert-file", "in conjunction with", "--key-file").ExpectFailure().Exec()
})
})

When("using legacy config", func() {
Expand Down

0 comments on commit 4c15005

Please sign in to comment.