From 835af272fad49342f51adb4633ff639de3cc14a1 Mon Sep 17 00:00:00 2001 From: Hank Donnay Date: Thu, 19 Nov 2020 12:59:22 -0600 Subject: [PATCH] clairctl: fix and codify import arguments This fixes a case where `./file` was getting interpreted as a URI, and adds a test to codify the behavior. Signed-off-by: Hank Donnay --- cmd/clairctl/import.go | 59 +++++++++++++++------------- cmd/clairctl/import_test.go | 76 +++++++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 27 deletions(-) create mode 100644 cmd/clairctl/import_test.go diff --git a/cmd/clairctl/import.go b/cmd/clairctl/import.go index 162c16fc4f..7757c40d9a 100644 --- a/cmd/clairctl/import.go +++ b/cmd/clairctl/import.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "fmt" "io" @@ -46,34 +47,11 @@ func importAction(c *cli.Context) error { } inName := args.First() - var in io.Reader - u, uerr := url.Parse(inName) - f, ferr := os.Open(inName) - if f != nil { - defer f.Close() - } - switch { - case uerr == nil: - req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) - if err != nil { - return err - } - res, err := cl.Do(req) - if res != nil { - defer res.Body.Close() - } - if err != nil { - return err - } - if res.StatusCode != http.StatusOK { - return fmt.Errorf("unexpected response: %d %s", res.StatusCode, res.Status) - } - in = res.Body - case ferr == nil: - in = f - default: - return fmt.Errorf("unable to make sense of argument %q", inName) + in, err := openInput(ctx, cl, inName) + if err != nil { + return err } + defer in.Close() pool, err := pgxpool.Connect(ctx, cfg.Matcher.ConnString) if err != nil { @@ -86,3 +64,30 @@ func importAction(c *cli.Context) error { } return nil } + +func openInput(ctx context.Context, c *http.Client, n string) (io.ReadCloser, error) { + f, ferr := os.Open(n) + if ferr == nil { + return f, nil + } + u, uerr := url.Parse(n) + if uerr == nil { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + if err != nil { + return nil, err + } + res, err := c.Do(req) + if err != nil { + if res != nil { + res.Body.Close() + } + return nil, err + } + if res.StatusCode != http.StatusOK { + res.Body.Close() + return nil, fmt.Errorf("unexpected response: %d %s", res.StatusCode, res.Status) + } + return res.Body, nil + } + return nil, fmt.Errorf("error opening input:\n\t%v\n\t%v", ferr, uerr) +} diff --git a/cmd/clairctl/import_test.go b/cmd/clairctl/import_test.go new file mode 100644 index 0000000000..7af4a3fdee --- /dev/null +++ b/cmd/clairctl/import_test.go @@ -0,0 +1,76 @@ +package main + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "reflect" + "testing" +) + +// TestImportArg checks that the import subcommand's magic URL handling is +// correct. +func TestImportArg(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "http") + })) + defer srv.Close() + + tt := []struct { + In string + Err bool + Ok func(*testing.T, io.ReadCloser) + }{ + { + In: "import_test.go", + Ok: func(t *testing.T, in io.ReadCloser) { + if got, want := reflect.TypeOf(in), reflect.TypeOf(&os.File{}); got != want { + t.Errorf("got: %T, want: %T", got, want) + } + }, + }, + { + In: "./import_test.go", + Ok: func(t *testing.T, in io.ReadCloser) { + if got, want := reflect.TypeOf(in), reflect.TypeOf(&os.File{}); got != want { + t.Errorf("got: %T, want: %T", got, want) + } + }, + }, + { + In: srv.URL, + Ok: func(t *testing.T, rc io.ReadCloser) { + b := bytes.Buffer{} + if _, err := b.ReadFrom(rc); err != nil { + t.Errorf("read error: %v", err) + } + if got, want := b.String(), "http"; got != want { + t.Errorf("got: %q, want: %q", got, want) + } + }, + }, + { + In: "invalid", + Err: true, + Ok: func(*testing.T, io.ReadCloser) {}, + }, + } + + ctx := context.Background() + for _, tc := range tt { + rc, err := openInput(ctx, srv.Client(), tc.In) + t.Logf("%q: %T; %v", tc.In, rc, err) + if (err != nil) && !tc.Err { + t.Error() + } + tc.Ok(t, rc) + if rc != nil { + rc.Close() + } + } +}