Skip to content

Commit

Permalink
clairctl: fix and codify import arguments
Browse files Browse the repository at this point in the history
This fixes a case where `./file` was getting interpreted as a URI, and
adds a test to codify the behavior.

Signed-off-by: Hank Donnay <[email protected]>
  • Loading branch information
hdonnay committed Nov 20, 2020
1 parent b9ef107 commit 835af27
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 27 deletions.
59 changes: 32 additions & 27 deletions cmd/clairctl/import.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -46,34 +47,11 @@ func importAction(c *cli.Context) error {
}
inName := args.First()

var in io.Reader
u, uerr := url.Parse(inName)
f, ferr := os.Open(inName)
if f != nil {
defer f.Close()
}
switch {
case uerr == nil:
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
return err
}
res, err := cl.Do(req)
if res != nil {
defer res.Body.Close()
}
if err != nil {
return err
}
if res.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected response: %d %s", res.StatusCode, res.Status)
}
in = res.Body
case ferr == nil:
in = f
default:
return fmt.Errorf("unable to make sense of argument %q", inName)
in, err := openInput(ctx, cl, inName)
if err != nil {
return err
}
defer in.Close()

pool, err := pgxpool.Connect(ctx, cfg.Matcher.ConnString)
if err != nil {
Expand All @@ -86,3 +64,30 @@ func importAction(c *cli.Context) error {
}
return nil
}

func openInput(ctx context.Context, c *http.Client, n string) (io.ReadCloser, error) {
f, ferr := os.Open(n)
if ferr == nil {
return f, nil
}
u, uerr := url.Parse(n)
if uerr == nil {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
return nil, err
}
res, err := c.Do(req)
if err != nil {
if res != nil {
res.Body.Close()
}
return nil, err
}
if res.StatusCode != http.StatusOK {
res.Body.Close()
return nil, fmt.Errorf("unexpected response: %d %s", res.StatusCode, res.Status)
}
return res.Body, nil
}
return nil, fmt.Errorf("error opening input:\n\t%v\n\t%v", ferr, uerr)
}
76 changes: 76 additions & 0 deletions cmd/clairctl/import_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package main

import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"reflect"
"testing"
)

// TestImportArg checks that the import subcommand's magic URL handling is
// correct.
func TestImportArg(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, "http")
}))
defer srv.Close()

tt := []struct {
In string
Err bool
Ok func(*testing.T, io.ReadCloser)
}{
{
In: "import_test.go",
Ok: func(t *testing.T, in io.ReadCloser) {
if got, want := reflect.TypeOf(in), reflect.TypeOf(&os.File{}); got != want {
t.Errorf("got: %T, want: %T", got, want)
}
},
},
{
In: "./import_test.go",
Ok: func(t *testing.T, in io.ReadCloser) {
if got, want := reflect.TypeOf(in), reflect.TypeOf(&os.File{}); got != want {
t.Errorf("got: %T, want: %T", got, want)
}
},
},
{
In: srv.URL,
Ok: func(t *testing.T, rc io.ReadCloser) {
b := bytes.Buffer{}
if _, err := b.ReadFrom(rc); err != nil {
t.Errorf("read error: %v", err)
}
if got, want := b.String(), "http"; got != want {
t.Errorf("got: %q, want: %q", got, want)
}
},
},
{
In: "invalid",
Err: true,
Ok: func(*testing.T, io.ReadCloser) {},
},
}

ctx := context.Background()
for _, tc := range tt {
rc, err := openInput(ctx, srv.Client(), tc.In)
t.Logf("%q: %T; %v", tc.In, rc, err)
if (err != nil) && !tc.Err {
t.Error()
}
tc.Ok(t, rc)
if rc != nil {
rc.Close()
}
}
}

0 comments on commit 835af27

Please sign in to comment.