diff --git a/clair/api.go b/clair/api.go index e8bcf0b..0aa77ae 100644 --- a/clair/api.go +++ b/clair/api.go @@ -22,6 +22,7 @@ type apiV1 struct { } type apiV3 struct { + url string client clairpb.AncestryServiceClient } @@ -52,11 +53,16 @@ func newAPIV3(url string) (*apiV3, error) { runes := []rune(url) url = string(runes[i+3:]) } + if strings.Index(url, ":") == -1 { + url = fmt.Sprintf("%s:6060", url) + } conn, err := grpc.Dial(url, grpc.WithInsecure()) if err != nil { return nil, fmt.Errorf("did not connect to %s: %v", url, err) } - return &apiV3{clairpb.NewAncestryServiceClient(conn)}, nil + return &apiV3{ + url: url, + client: clairpb.NewAncestryServiceClient(conn)}, nil } func (a *apiV1) Push(image *docker.Image) error { diff --git a/clair/api_test.go b/clair/api_test.go new file mode 100644 index 0000000..191d788 --- /dev/null +++ b/clair/api_test.go @@ -0,0 +1,79 @@ +package clair + +import "testing" + +func TestNewAPIV1(t *testing.T) { + cases := []struct { + url string + expected string + }{ + { + url: "http://localhost:6060", + expected: "http://localhost:6060", + }, + { + url: "http://localhost", + expected: "http://localhost:6060", + }, + { + url: "localhost", + expected: "http://localhost:6060", + }, + { + url: "https://localhost:6060", + expected: "https://localhost:6060", + }, + { + url: "https://localhost", + expected: "https://localhost:6060", + }, + } + for _, tc := range cases { + api := newAPIV1(tc.url) + if api.url != tc.expected { + t.Errorf("expected %s got %s", api.url, tc.expected) + } + } +} + +func TestNewAPIV3(t *testing.T) { + cases := []struct { + url string + expected string + }{ + { + url: "http://localhost:6060", + expected: "localhost:6060", + }, + { + url: "http://localhost", + expected: "localhost:6060", + }, + { + url: "localhost", + expected: "localhost:6060", + }, + { + url: "https://localhost:6060", + expected: "localhost:6060", + }, + { + url: "https://localhost", + expected: "localhost:6060", + }, + { + url: "https://localhost:9090", + expected: "localhost:9090", + }, + } + for _, tc := range cases { + api, err := newAPIV3(tc.url) + if err != nil { + t.Errorf("failed to initialize api v3: %s", err) + continue + } + if api.url != tc.expected { + t.Errorf("expected %s got %s", api.url, tc.expected) + } + } +}