diff --git a/database/pgsql/ancestry.go b/database/pgsql/ancestry.go new file mode 100644 index 0000000000..f49b9d9f0b --- /dev/null +++ b/database/pgsql/ancestry.go @@ -0,0 +1,271 @@ +package pgsql + +import ( + "database/sql" + "errors" + "fmt" + + "strings" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/pkg/commonerr" + "github.com/lib/pq" + log "github.com/sirupsen/logrus" +) + +func (tx *pgSession) UpsertAncestry(ancestry database.Ancestry, features []database.NamespacedFeature, processedBy database.Processors) error { + if ancestry.Name == "" { + log.Warning("Empty ancestry name is not allowed") + return commonerr.NewBadRequestError("could not insert an ancestry with empty name") + } + + if len(ancestry.Layers) == 0 { + log.Warning("Empty ancestry is not allowed") + return commonerr.NewBadRequestError("could not insert an ancestry with 0 layers") + } + + err := tx.deleteAncestry(ancestry.Name) + if err != nil { + return err + } + + var ancestryID int + err = tx.QueryRow(insertAncestry, ancestry.Name).Scan(&ancestryID) + if err != nil { + return handleError("insertAncestry", err) + } + + err = tx.insertAncestryLayers(ancestryID, ancestry.Layers) + if err != nil { + return err + } + + err = tx.insertAncestryFeatures(ancestryID, features) + if err != nil { + return err + } + + return tx.persistProcessors(persistAncestryLister, + "persistAncestryLister", + persistAncestryDetector, + "persistAncestryDetector", + ancestryID, processedBy) +} + +func (tx *pgSession) FindAncestry(name string) (database.Ancestry, database.Processors, bool, error) { + ancestry := database.Ancestry{Name: name} + processed := database.Processors{} + + var ancestryID int + err := tx.QueryRow(searchAncestry, name).Scan(&ancestryID) + if err != nil { + if err == sql.ErrNoRows { + return ancestry, processed, false, nil + } + return ancestry, processed, false, handleError("searchAncestry", err) + } + + ancestry.Layers, err = tx.findAncestryLayers(ancestryID) + if err != nil { + return ancestry, processed, false, err + } + + processed.Detectors, err = tx.findProcessors(searchAncestryDetectors, "searchAncestryDetectors", "detector", ancestryID) + if err != nil { + return ancestry, processed, false, err + } + + processed.Listers, err = tx.findProcessors(searchAncestryListers, "searchAncestryListers", "lister", ancestryID) + if err != nil { + return ancestry, processed, false, err + } + + return ancestry, processed, true, nil +} + +func (tx *pgSession) FindAncestryFeatures(name string) (database.AncestryWithFeatures, bool, error) { + var ( + awf database.AncestryWithFeatures + ok bool + err error + ) + awf.Ancestry, awf.ProcessedBy, ok, err = tx.FindAncestry(name) + if err != nil { + return awf, false, err + } + + if !ok { + return awf, false, nil + } + + rows, err := tx.Query(searchAncestryFeatures, name) + if err != nil { + return awf, false, handleError("searchAncestryFeatures", err) + } + + for rows.Next() { + nf := database.NamespacedFeature{} + err := rows.Scan(&nf.Namespace.Name, &nf.Namespace.VersionFormat, &nf.Feature.Name, &nf.Feature.Version) + if err != nil { + return awf, false, handleError("searchAncestryFeatures", err) + } + nf.Feature.VersionFormat = nf.Namespace.VersionFormat + awf.Features = append(awf.Features, nf) + } + + return awf, true, nil +} + +func (tx *pgSession) deleteAncestry(name string) error { + result, err := tx.Exec(removeAncestry, name) + if err != nil { + return handleError("removeAncestry", err) + } + + _, err = result.RowsAffected() + if err != nil { + return handleError("removeAncestry", err) + } + + return nil +} + +func (tx *pgSession) findProcessors(query, queryName, processorType string, id int) ([]string, error) { + rows, err := tx.Query(query, id) + if err != nil { + if err == sql.ErrNoRows { + log.Warning("No " + processorType + " are used") + return nil, nil + } + return nil, handleError(queryName, err) + } + + var ( + processors []string + processor string + ) + + for rows.Next() { + err := rows.Scan(&processor) + if err != nil { + return nil, handleError(queryName, err) + } + processors = append(processors, processor) + } + + return processors, nil +} + +func (tx *pgSession) findAncestryLayers(ancestryID int) ([]database.Layer, error) { + rows, err := tx.Query(searchAncestryLayer, ancestryID) + if err != nil { + return nil, handleError("searchAncestryLayer", err) + } + layers := []database.Layer{} + for rows.Next() { + var layer database.Layer + err := rows.Scan(&layer.Hash) + if err != nil { + return nil, handleError("searchAncestryLayer", err) + } + layers = append(layers, layer) + } + return layers, nil +} + +func (tx *pgSession) insertAncestryLayers(ancestryID int, layers []database.Layer) error { + layerIDs := map[string]sql.NullInt64{} + for _, l := range layers { + layerIDs[l.Hash] = sql.NullInt64{} + } + + layerHashes := []string{} + for hash := range layerIDs { + layerHashes = append(layerHashes, hash) + } + + rows, err := tx.Query(searchLayerIDs, pq.Array(layerHashes)) + if err != nil { + return handleError("searchLayerIDs", err) + } + + for rows.Next() { + var ( + layerID sql.NullInt64 + layerName string + ) + err := rows.Scan(&layerID, &layerName) + if err != nil { + return handleError("searchLayerIDs", err) + } + layerIDs[layerName] = layerID + } + + notFound := []string{} + for hash, id := range layerIDs { + if !id.Valid { + notFound = append(notFound, hash) + } + } + + if len(notFound) > 0 { + return handleError("searchLayerIDs", fmt.Errorf("Layer %s is not found in database", strings.Join(notFound, ","))) + } + + stmt, err := tx.Prepare(copyinAncestryLayer) + if err != nil { + return handleError("copyinAncestryLayer", err) + } + + for index, layer := range layers { + _, err := stmt.Exec(ancestryID, index, layerIDs[layer.Hash].Int64) + if err != nil { + return handleError("copyinAncestryLayer", commonerr.CombineErrors(err, stmt.Close())) + } + } + + if _, err := stmt.Exec(); err != nil { + return handleError("copyinAncestryLayer", commonerr.CombineErrors(err, stmt.Close())) + } + + if err := stmt.Close(); err != nil { + return handleError("copyinAncestryLayer", err) + } + return nil +} + +func (tx *pgSession) insertAncestryFeatures(ancestryID int, features []database.NamespacedFeature) error { + featureIDs, err := tx.findNamespacedFeatureIDs(features) + if err != nil { + return err + } + + // bulk insert ancestry features + stmtFeatures, err := tx.Prepare(copyinAncestryFeatures) + if err != nil { + return handleError("copyinAncestryFeatures", err) + } + + for _, id := range featureIDs { + if !id.Valid { + stmtFeatures.Close() + return errors.New("requested namespaced feature is not in database") + } + + _, err := stmtFeatures.Exec(ancestryID, id) + if err != nil { + stmtFeatures.Close() + return handleError("copyinAncestryFeatures", err) + } + } + + if _, err := stmtFeatures.Exec(); err != nil { + stmtFeatures.Close() + return handleError("copyinAncestryFeatures", err) + } + + if err := stmtFeatures.Close(); err != nil { + return handleError("copyinAncestryFeatures", err) + } + return nil +} diff --git a/database/pgsql/ancestry_test.go b/database/pgsql/ancestry_test.go new file mode 100644 index 0000000000..73207c7411 --- /dev/null +++ b/database/pgsql/ancestry_test.go @@ -0,0 +1,197 @@ +package pgsql + +import ( + "testing" + + "sort" + + "github.com/coreos/clair/database" + "github.com/stretchr/testify/assert" +) + +var ( + store *pgSQL +) + +func TestUpsertAncestry(t *testing.T) { + store, tx := openSessionForTest(t, "UpsertAncestry", true) + defer closeTest(t, store, tx) + a1 := database.Ancestry{ + Name: "a1", + Layers: []database.Layer{ + {Hash: "layer-N"}, + }, + } + + a2 := database.Ancestry{} + + a3 := database.Ancestry{ + Name: "a", + Layers: []database.Layer{ + {Hash: "layer-0"}, + }, + } + + a4 := database.Ancestry{ + Name: "a", + Layers: []database.Layer{ + {Hash: "layer-1"}, + }, + } + + f1 := database.Feature{ + Name: "wechat", + Version: "0.5", + VersionFormat: "dpkg", + } + + // not in database + f2 := database.Feature{ + Name: "wechat", + Version: "0.6", + VersionFormat: "dpkg", + } + + n1 := database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + } + + p := database.Processors{ + Listers: []string{"dpkg", "non-existing"}, + Detectors: []string{"os-release", "non-existing"}, + } + + nsf1 := database.NamespacedFeature{ + Namespace: n1, + Feature: f1, + } + + // not in database + nsf2 := database.NamespacedFeature{ + Namespace: n1, + Feature: f2, + } + + // invalid case + assert.NotNil(t, tx.UpsertAncestry(a1, nil, database.Processors{})) + assert.NotNil(t, tx.UpsertAncestry(a2, nil, database.Processors{})) + // valid case + assert.Nil(t, tx.UpsertAncestry(a3, nil, database.Processors{})) + // replace invalid case + assert.NotNil(t, tx.UpsertAncestry(a4, []database.NamespacedFeature{nsf1, nsf2}, p)) + // replace valid case + assert.Nil(t, tx.UpsertAncestry(a4, []database.NamespacedFeature{nsf1}, p)) + // validate + ancestry, ok, err := tx.FindAncestryFeatures("a") + assert.Nil(t, err) + assert.True(t, ok) + assert.Equal(t, a4, ancestry.Ancestry) +} + +func assertProcessorsEqual(t *testing.T, expected database.Processors, actual database.Processors) bool { + sort.Strings(expected.Detectors) + sort.Strings(actual.Detectors) + sort.Strings(expected.Listers) + sort.Strings(actual.Listers) + return assert.Equal(t, expected.Detectors, actual.Detectors) && assert.Equal(t, expected.Listers, actual.Listers) +} + +func TestFindAncestry(t *testing.T) { + store, tx := openSessionForTest(t, "UpsertAncestry", true) + defer closeTest(t, store, tx) + + // not found + _, _, ok, err := tx.FindAncestry("ancestry-non") + assert.Nil(t, err) + assert.False(t, ok) + + expected := database.Ancestry{ + Name: "ancestry-1", + Layers: []database.Layer{ + {Hash: "layer-0"}, + {Hash: "layer-1"}, + {Hash: "layer-2"}, + {Hash: "layer-3a"}, + }, + } + + expectedProcessors := database.Processors{ + Detectors: []string{"os-release"}, + Listers: []string{"dpkg"}, + } + + // found + a, p, ok2, err := tx.FindAncestry("ancestry-1") + if assert.Nil(t, err) && assert.True(t, ok2) { + assertAncestryEqual(t, expected, a) + assertProcessorsEqual(t, expectedProcessors, p) + } +} + +func assertAncestryWithFeatureEqual(t *testing.T, expected database.AncestryWithFeatures, actual database.AncestryWithFeatures) bool { + return assertAncestryEqual(t, expected.Ancestry, actual.Ancestry) && + assertNamespacedFeatureEqual(t, expected.Features, actual.Features) && + assertProcessorsEqual(t, expected.ProcessedBy, actual.ProcessedBy) +} +func assertAncestryEqual(t *testing.T, expected database.Ancestry, actual database.Ancestry) bool { + return assert.Equal(t, expected.Name, actual.Name) && assert.Equal(t, expected.Layers, actual.Layers) +} + +func TestFindAncestryFeatures(t *testing.T) { + store, tx := openSessionForTest(t, "UpsertAncestry", true) + defer closeTest(t, store, tx) + + // invalid + _, ok, err := tx.FindAncestryFeatures("ancestry-non") + if assert.Nil(t, err) { + assert.False(t, ok) + } + + expected := database.AncestryWithFeatures{ + Ancestry: database.Ancestry{ + Name: "ancestry-2", + Layers: []database.Layer{ + {Hash: "layer-0"}, + {Hash: "layer-1"}, + {Hash: "layer-2"}, + {Hash: "layer-3b"}, + }, + }, + ProcessedBy: database.Processors{ + Detectors: []string{"os-release"}, + Listers: []string{"dpkg"}, + }, + Features: []database.NamespacedFeature{ + { + Namespace: database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + }, + Feature: database.Feature{ + Name: "wechat", + Version: "0.5", + VersionFormat: "dpkg", + }, + }, + { + Namespace: database.Namespace{ + Name: "debian:8", + VersionFormat: "dpkg", + }, + Feature: database.Feature{ + Name: "openssl", + Version: "1.0", + VersionFormat: "dpkg", + }, + }, + }, + } + // valid + ancestry, ok, err := tx.FindAncestryFeatures("ancestry-2") + if assert.Nil(t, err) && assert.True(t, ok) { + assertAncestryEqual(t, expected.Ancestry, ancestry.Ancestry) + assertNamespacedFeatureEqual(t, expected.Features, ancestry.Features) + assertProcessorsEqual(t, expected.ProcessedBy, ancestry.ProcessedBy) + } +} diff --git a/database/pgsql/complex_test.go b/database/pgsql/complex_test.go index ed038b4e0f..fb326b8766 100644 --- a/database/pgsql/complex_test.go +++ b/database/pgsql/complex_test.go @@ -17,145 +17,170 @@ package pgsql import ( "fmt" "math/rand" - "runtime" "strconv" - "sync" "testing" - "time" - "github.com/pborman/uuid" - "github.com/stretchr/testify/assert" + "sync" "github.com/coreos/clair/database" + "github.com/coreos/clair/ext/versionfmt" "github.com/coreos/clair/ext/versionfmt/dpkg" + "github.com/coreos/clair/pkg/strutil" + "github.com/pborman/uuid" + "github.com/stretchr/testify/assert" ) const ( numVulnerabilities = 100 - numFeatureVersions = 100 + numFeatures = 100 ) -func TestRaceAffects(t *testing.T) { - datastore, err := openDatabaseForTest("RaceAffects", false) - if err != nil { - t.Error(err) - return +func testGenRandomVulnerabilityAndNamespacedFeature(t *testing.T, store database.Datastore) ([]database.NamespacedFeature, []database.VulnerabilityWithAffected) { + tx, err := store.Begin() + if !assert.Nil(t, err) { + t.FailNow() } - defer datastore.Close() - - // Insert the Feature on which we'll work. - feature := database.Feature{ - Namespace: database.Namespace{ - Name: "TestRaceAffectsFeatureNamespace1", - VersionFormat: dpkg.ParserName, - }, - Name: "TestRaceAffecturesFeature1", + + featureName := "TestFeature" + featureVersionFormat := dpkg.ParserName + // Insert the namespace on which we'll work. + namespace := database.Namespace{ + Name: "TestRaceAffectsFeatureNamespace1", + VersionFormat: dpkg.ParserName, } - _, err = datastore.insertFeature(feature) - if err != nil { - t.Error(err) - return + + if !assert.Nil(t, tx.PersistNamespaces([]database.Namespace{namespace})) { + t.FailNow() } // Initialize random generator and enforce max procs. - rand.Seed(time.Now().UnixNano()) - runtime.GOMAXPROCS(runtime.NumCPU()) - - // Generate FeatureVersions. - featureVersions := make([]database.FeatureVersion, numFeatureVersions) - for i := 0; i < numFeatureVersions; i++ { - version := rand.Intn(numFeatureVersions) + rand.Seed(1) + + // Generate Distinct random features + features := make([]database.Feature, numFeatures) + nsFeatures := make([]database.NamespacedFeature, numFeatures) + for i := 0; i < numFeatures; i++ { + version := rand.Intn(numFeatures) + + features[i] = database.Feature{ + Name: featureName, + VersionFormat: featureVersionFormat, + Version: strconv.Itoa(version), + } - featureVersions[i] = database.FeatureVersion{ - Feature: feature, - Version: strconv.Itoa(version), + nsFeatures[i] = database.NamespacedFeature{ + Namespace: namespace, + Feature: features[i], } } + // insert features + if !assert.Nil(t, tx.PersistFeatures(features)) { + t.FailNow() + } + // Generate vulnerabilities. - // They are mapped by fixed version, which will make verification really easy afterwards. - vulnerabilities := make(map[int][]database.Vulnerability) + vulnerabilities := []database.VulnerabilityWithAffected{} for i := 0; i < numVulnerabilities; i++ { - version := rand.Intn(numFeatureVersions) + 1 - - // if _, ok := vulnerabilities[version]; !ok { - // vulnerabilities[version] = make([]database.Vulnerability) - // } - - vulnerability := database.Vulnerability{ - Name: uuid.New(), - Namespace: feature.Namespace, - FixedIn: []database.FeatureVersion{ + // any version less than this is vulnerable + version := rand.Intn(numFeatures) + 1 + + vulnerability := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: uuid.New(), + Namespace: namespace, + Severity: database.UnknownSeverity, + }, + Affected: []database.AffectedFeature{ { - Feature: feature, - Version: strconv.Itoa(version), + Namespace: namespace, + FeatureName: featureName, + AffectedVersion: strconv.Itoa(version), + FixedInVersion: strconv.Itoa(version), }, }, - Severity: database.UnknownSeverity, } - vulnerabilities[version] = append(vulnerabilities[version], vulnerability) + vulnerabilities = append(vulnerabilities, vulnerability) + } + tx.Commit() + + return nsFeatures, vulnerabilities +} + +func TestCaching(t *testing.T) { + store, err := openDatabaseForTest("caching_test", false) + if !assert.Nil(t, err) { + t.FailNow() } + defer store.Close() + + nsFeatures, vulnerabilities := testGenRandomVulnerabilityAndNamespacedFeature(t, store) + + fmt.Printf("%d features, %d vulnerabilities are generated", len(nsFeatures), len(vulnerabilities)) - // Insert featureversions and vulnerabilities in parallel. var wg sync.WaitGroup wg.Add(2) - go func() { defer wg.Done() - for _, vulnerabilitiesM := range vulnerabilities { - for _, vulnerability := range vulnerabilitiesM { - err = datastore.InsertVulnerabilities([]database.Vulnerability{vulnerability}, true) - assert.Nil(t, err) - } + tx, err := store.Begin() + if !assert.Nil(t, err) { + t.FailNow() } - fmt.Println("finished to insert vulnerabilities") + + assert.Nil(t, tx.PersistNamespacedFeatures(nsFeatures)) + fmt.Println("finished to insert namespaced features") + + tx.Commit() }() go func() { defer wg.Done() - for i := 0; i < len(featureVersions); i++ { - featureVersions[i].ID, err = datastore.insertFeatureVersion(featureVersions[i]) - assert.Nil(t, err) + tx, err := store.Begin() + if !assert.Nil(t, err) { + t.FailNow() } - fmt.Println("finished to insert featureVersions") + + assert.Nil(t, tx.InsertVulnerabilities(vulnerabilities)) + fmt.Println("finished to insert vulnerabilities") + tx.Commit() + }() wg.Wait() + tx, err := store.Begin() + if !assert.Nil(t, err) { + t.FailNow() + } + defer tx.Rollback() + // Verify consistency now. - var actualAffectedNames []string - var expectedAffectedNames []string - - for _, featureVersion := range featureVersions { - featureVersionVersion, _ := strconv.Atoi(featureVersion.Version) - - // Get actual affects. - rows, err := datastore.Query(searchComplexTestFeatureVersionAffects, - featureVersion.ID) - assert.Nil(t, err) - defer rows.Close() - - var vulnName string - for rows.Next() { - err = rows.Scan(&vulnName) - if !assert.Nil(t, err) { - continue - } - actualAffectedNames = append(actualAffectedNames, vulnName) - } - if assert.Nil(t, rows.Err()) { - rows.Close() + affected, err := tx.FindAffectedNamespacedFeatures(nsFeatures) + if !assert.Nil(t, err) { + t.FailNow() + } + + for _, ansf := range affected { + if !assert.True(t, ansf.Valid) { + t.FailNow() } - // Get expected affects. - for i := numVulnerabilities; i > featureVersionVersion; i-- { - for _, vulnerability := range vulnerabilities[i] { - expectedAffectedNames = append(expectedAffectedNames, vulnerability.Name) + expectedAffectedNames := []string{} + for _, vuln := range vulnerabilities { + if ok, err := versionfmt.InRange(dpkg.ParserName, ansf.Version, vuln.Affected[0].AffectedVersion); err == nil { + if ok { + expectedAffectedNames = append(expectedAffectedNames, vuln.Name) + } } } - assert.Len(t, compareStringLists(expectedAffectedNames, actualAffectedNames), 0) - assert.Len(t, compareStringLists(actualAffectedNames, expectedAffectedNames), 0) + actualAffectedNames := []string{} + for _, s := range ansf.AffectedBy { + actualAffectedNames = append(actualAffectedNames, s.Name) + } + + assert.Len(t, strutil.CompareStringLists(expectedAffectedNames, actualAffectedNames), 0) + assert.Len(t, strutil.CompareStringLists(actualAffectedNames, expectedAffectedNames), 0) } } diff --git a/database/pgsql/feature.go b/database/pgsql/feature.go index c39bd5b7a3..1ed0e4d30b 100644 --- a/database/pgsql/feature.go +++ b/database/pgsql/feature.go @@ -16,230 +16,269 @@ package pgsql import ( "database/sql" - "strings" - "time" + "errors" "github.com/coreos/clair/database" "github.com/coreos/clair/ext/versionfmt" "github.com/coreos/clair/pkg/commonerr" + "github.com/lib/pq" + "github.com/sirupsen/logrus" ) -func (pgSQL *pgSQL) insertFeature(feature database.Feature) (int, error) { - if feature.Name == "" { - return 0, commonerr.NewBadRequestError("could not find/insert invalid Feature") - } +var ( + errFeatureNotFound = errors.New("Feature not found") +) + +type vulnerabilityAffecting struct { + vulnerabilityID int + addedByID int +} - // Do cache lookup. - if pgSQL.cache != nil { - promCacheQueriesTotal.WithLabelValues("feature").Inc() - id, found := pgSQL.cache.Get("feature:" + feature.Namespace.Name + ":" + feature.Name) - if found { - promCacheHitsTotal.WithLabelValues("feature").Inc() - return id.(int), nil +func (tx *pgSession) PersistFeatures(features []database.Feature) error { + for _, f := range features { + if f.Name == "" || f.Version == "" || f.VersionFormat == "" { + return commonerr.NewBadRequestError("Empty feature name, version or version format is not allowed") + } + var id int + err := tx.QueryRow(soiFeature, f.Name, f.Version, f.VersionFormat).Scan(&id) + if err != nil { + return handleError("soiFeature", err) } } + return nil +} - // We do `defer observeQueryTime` here because we don't want to observe cached features. - defer observeQueryTime("insertFeature", "all", time.Now()) +type namespacedFeatureWithID struct { + database.NamespacedFeature - // Find or create Namespace. - namespaceID, err := pgSQL.insertNamespace(feature.Namespace) - if err != nil { - return 0, err - } + ID int +} + +func (tx *pgSession) cacheAffectedNamespacedFeatures(features map[int]database.NamespacedFeature) error { + // compute the potential vulnerabilities from vulnerability_affected_feature + // compute if the requested feature is actually affected + // lock vulnerabiilty_affected_namespaced_feature and update the table + // use vulnerability_affected_feature table to compute the vulnerability_affected_namespaced_feature table - // Find or create Feature. - var id int - err = pgSQL.QueryRow(soiFeature, feature.Name, namespaceID).Scan(&id) + ids := []int{} + for id := range features { + ids = append(ids, id) + } + rows, err := tx.Query(searchPotentialNamespacedFeatureVulnerabilities, pq.Array(ids)) if err != nil { - return 0, handleError("soiFeature", err) + return handleError("searchPotentialNamespacedFeatureVulnerabilities", err) } - if pgSQL.cache != nil { - pgSQL.cache.Add("feature:"+feature.Namespace.Name+":"+feature.Name, id) + cached := map[int][]vulnerabilityAffecting{} + for rows.Next() { + var ( + fid int + vid int + affected string + addedBy int + ) + err := rows.Scan(&fid, &vid, &affected, &addedBy) + if err != nil { + rows.Close() + return err + } + if ok, err := versionfmt.InRange(features[fid].VersionFormat, features[fid].Version, affected); err != nil { + rows.Close() + return err + } else if ok { + cached[fid] = append(cached[fid], vulnerabilityAffecting{vulnerabilityID: vid, addedByID: addedBy}) + } } - return id, nil -} - -func (pgSQL *pgSQL) insertFeatureVersion(fv database.FeatureVersion) (id int, err error) { - err = versionfmt.Valid(fv.Feature.Namespace.VersionFormat, fv.Version) - if err != nil { - return 0, commonerr.NewBadRequestError("could not find/insert invalid FeatureVersion") - } + rows.Close() - // Do cache lookup. - cacheIndex := strings.Join([]string{"featureversion", fv.Feature.Namespace.Name, fv.Feature.Name, fv.Version}, ":") - if pgSQL.cache != nil { - promCacheQueriesTotal.WithLabelValues("featureversion").Inc() - id, found := pgSQL.cache.Get(cacheIndex) - if found { - promCacheHitsTotal.WithLabelValues("featureversion").Inc() - return id.(int), nil + numCachedFeature := 0 + for featureID, vulnIDs := range cached { + for _, vuln := range vulnIDs { + affected, err := tx.Exec(persistVulnerabilityAffectedNamespacedFeature, vuln.vulnerabilityID, featureID, vuln.addedByID) + if err != nil { + return handleError("persistVulnerabilityAffectedNamespacedFeature", err) + } + if num, err := affected.RowsAffected(); err == nil { + numCachedFeature += int(num) + } } } - // We do `defer observeQueryTime` here because we don't want to observe cached featureversions. - defer observeQueryTime("insertFeatureVersion", "all", time.Now()) + logrus.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", numCachedFeature) + return nil +} - // Find or create Feature first. - t := time.Now() - featureID, err := pgSQL.insertFeature(fv.Feature) - observeQueryTime("insertFeatureVersion", "insertFeature", t) +func (tx *pgSession) PersistNamespacedFeatures(features []database.NamespacedFeature) error { + nsfMap := map[int]database.NamespacedFeature{} + // NOTE(keyboardnerd): lock is relased when the transaction ends + _, err := tx.Exec(lockVulnerabilityAffects) if err != nil { - return 0, err + return handleError("lockVulnerabilityAffects", err) } - fv.Feature.ID = featureID - - // Try to find the FeatureVersion. - // - // In a populated database, the likelihood of the FeatureVersion already being there is high. - // If we can find it here, we then avoid using a transaction and locking the database. - err = pgSQL.QueryRow(searchFeatureVersion, featureID, fv.Version).Scan(&fv.ID) - if err != nil && err != sql.ErrNoRows { - return 0, handleError("searchFeatureVersion", err) + nsIDs := map[database.Namespace]sql.NullInt64{} + fIDs := map[database.Feature]sql.NullInt64{} + for _, f := range features { + nsIDs[f.Namespace] = sql.NullInt64{} + fIDs[f.Feature] = sql.NullInt64{} } - if err == nil { - if pgSQL.cache != nil { - pgSQL.cache.Add(cacheIndex, fv.ID) - } - return fv.ID, nil + fToFind := []database.Feature{} + for f := range fIDs { + fToFind = append(fToFind, f) } - // Begin transaction. - tx, err := pgSQL.Begin() - if err != nil { - tx.Rollback() - return 0, handleError("insertFeatureVersion.Begin()", err) + if ids, err := tx.findFeatureIDs(fToFind); err == nil { + for i, id := range ids { + if !id.Valid { + return errFeatureNotFound + } + fIDs[fToFind[i]] = id + } + } else { + return err } - // Lock Vulnerability_Affects_FeatureVersion exclusively. - // We want to prevent InsertVulnerability to modify it. - promConcurrentLockVAFV.Inc() - defer promConcurrentLockVAFV.Dec() - t = time.Now() - _, err = tx.Exec(lockVulnerabilityAffects) - observeQueryTime("insertFeatureVersion", "lock", t) + nsToFind := []database.Namespace{} + for ns := range nsIDs { + nsToFind = append(nsToFind, ns) + } - if err != nil { - tx.Rollback() - return 0, handleError("insertFeatureVersion.lockVulnerabilityAffects", err) + if ids, err := tx.findNamespaceIDs(nsToFind); err == nil { + for i, id := range ids { + if !id.Valid { + return errNamespaceNotFound + } + nsIDs[nsToFind[i]] = id + } + } else { + return err } - // Find or create FeatureVersion. - var created bool + for _, f := range features { + var id int + err := tx.QueryRow(soiNamespacedFeature, fIDs[f.Feature], nsIDs[f.Namespace]).Scan(&id) + if err != nil { + return handleError("soiNamespacedFeature", err) + } + nsfMap[id] = f + } - t = time.Now() - err = tx.QueryRow(soiFeatureVersion, featureID, fv.Version).Scan(&created, &fv.ID) - observeQueryTime("insertFeatureVersion", "soiFeatureVersion", t) + return tx.cacheAffectedNamespacedFeatures(nsfMap) +} - if err != nil { - tx.Rollback() - return 0, handleError("soiFeatureVersion", err) +// FindAffectedNamespacedFeatures looks up cache table and retrieve all +// vulnerabilities associated with the feature. +func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.NamespacedFeature) ([]database.NullableAffectedNamespacedFeature, error) { + if len(features) == 0 { + return nil, nil } - if !created { - // The featureVersion already existed, no need to link it to - // vulnerabilities. - tx.Commit() + returnFeatures := make([]database.NullableAffectedNamespacedFeature, len(features)) - if pgSQL.cache != nil { - pgSQL.cache.Add(cacheIndex, fv.ID) + // featureMap is used to keep track of duplicated features. + featureMap := map[database.NamespacedFeature][]*database.NullableAffectedNamespacedFeature{} + // initialize return value and generate unique feature request queries + for i, f := range features { + returnFeatures[i] = database.NullableAffectedNamespacedFeature{ + AffectedNamespacedFeature: database.AffectedNamespacedFeature{ + NamespacedFeature: f, + }, } - return fv.ID, nil + featureMap[f] = append(featureMap[f], &returnFeatures[i]) } - // Link the new FeatureVersion with every vulnerabilities that affect it, by inserting in - // Vulnerability_Affects_FeatureVersion. - t = time.Now() - err = linkFeatureVersionToVulnerabilities(tx, fv) - observeQueryTime("insertFeatureVersion", "linkFeatureVersionToVulnerabilities", t) - - if err != nil { - tx.Rollback() - return 0, err + // query unique namespaced features + distinctFeatures := []database.NamespacedFeature{} + for f := range featureMap { + distinctFeatures = append(distinctFeatures, f) } - // Commit transaction. - err = tx.Commit() + nsFeatureIDs, err := tx.findNamespacedFeatureIDs(distinctFeatures) if err != nil { - return 0, handleError("insertFeatureVersion.Commit()", err) - } - - if pgSQL.cache != nil { - pgSQL.cache.Add(cacheIndex, fv.ID) + return nil, err } - return fv.ID, nil -} - -// TODO(Quentin-M): Batch me -func (pgSQL *pgSQL) insertFeatureVersions(featureVersions []database.FeatureVersion) ([]int, error) { - IDs := make([]int, 0, len(featureVersions)) - - for i := 0; i < len(featureVersions); i++ { - id, err := pgSQL.insertFeatureVersion(featureVersions[i]) - if err != nil { - return IDs, err + toQuery := []int64{} + featureIDMap := map[int64][]*database.NullableAffectedNamespacedFeature{} + for i, id := range nsFeatureIDs { + if id.Valid { + toQuery = append(toQuery, id.Int64) + for _, f := range featureMap[distinctFeatures[i]] { + f.Valid = id.Valid + featureIDMap[id.Int64] = append(featureIDMap[id.Int64], f) + } } - IDs = append(IDs, id) } - return IDs, nil -} - -type vulnerabilityAffectsFeatureVersion struct { - vulnerabilityID int - fixedInID int - fixedInVersion string -} - -func linkFeatureVersionToVulnerabilities(tx *sql.Tx, featureVersion database.FeatureVersion) error { - // Select every vulnerability and the fixed version that affect this Feature. - // TODO(Quentin-M): LIMIT - rows, err := tx.Query(searchVulnerabilityFixedInFeature, featureVersion.Feature.ID) + rows, err := tx.Query(searchNamespacedFeaturesVulnerabilities, pq.Array(toQuery)) if err != nil { - return handleError("searchVulnerabilityFixedInFeature", err) + return nil, handleError("searchNamespacedFeaturesVulnerabilities", err) } defer rows.Close() - var affects []vulnerabilityAffectsFeatureVersion for rows.Next() { - var affect vulnerabilityAffectsFeatureVersion - - err := rows.Scan(&affect.fixedInID, &affect.vulnerabilityID, &affect.fixedInVersion) + var ( + featureID int64 + vuln database.VulnerabilityWithFixedIn + ) + err := rows.Scan(&featureID, + &vuln.Name, + &vuln.Description, + &vuln.Link, + &vuln.Severity, + &vuln.Metadata, + &vuln.FixedInVersion, + &vuln.Namespace.Name, + &vuln.Namespace.VersionFormat, + ) if err != nil { - return handleError("searchVulnerabilityFixedInFeature.Scan()", err) + return nil, handleError("searchNamespacedFeaturesVulnerabilities", err) } - cmp, err := versionfmt.Compare(featureVersion.Feature.Namespace.VersionFormat, featureVersion.Version, affect.fixedInVersion) - if err != nil { - return err - } - if cmp < 0 { - // The version of the FeatureVersion we are inserting is lower than the fixed version on this - // Vulnerability, thus, this FeatureVersion is affected by it. - affects = append(affects, affect) + for _, f := range featureIDMap[featureID] { + f.AffectedBy = append(f.AffectedBy, vuln) } } - if err = rows.Err(); err != nil { - return handleError("searchVulnerabilityFixedInFeature.Rows()", err) + + return returnFeatures, nil +} + +// findNamespacedFeatureIDs find ids for all namespaced features in input. +func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature) ([]sql.NullInt64, error) { + ids := make([]sql.NullInt64, len(nfs)) + stmt, err := tx.Prepare(searchNamespacedFeature) + if err != nil { + return nil, err } - rows.Close() - // Insert into Vulnerability_Affects_FeatureVersion. - for _, affect := range affects { - // TODO(Quentin-M): Batch me. - _, err := tx.Exec(insertVulnerabilityAffectsFeatureVersion, affect.vulnerabilityID, - featureVersion.ID, affect.fixedInID) - if err != nil { - return handleError("insertVulnerabilityAffectsFeatureVersion", err) + defer stmt.Close() + for i, nf := range nfs { + var id sql.NullInt64 + err := stmt.QueryRow(nf.Name, nf.Version, nf.VersionFormat, nf.Namespace.Name).Scan(&id) + if err != nil && err != sql.ErrNoRows { + return nil, handleError("searchNamespacedFeature", err) + } else if err == sql.ErrNoRows { + id.Valid = false } + ids[i] = id } - return nil + return ids, nil +} + +func (tx *pgSession) findFeatureIDs(fs []database.Feature) ([]sql.NullInt64, error) { + fid := []sql.NullInt64{} + for _, f := range fs { + var id sql.NullInt64 + err := tx.QueryRow(searchFeature, f.Name, f.Version, f.VersionFormat).Scan(&id) + if err != nil && err != sql.ErrNoRows { + return nil, handleError("searchFeature", err) + } + fid = append(fid, id) + } + return fid, nil } diff --git a/database/pgsql/feature_test.go b/database/pgsql/feature_test.go index 5b7f807858..fb7097d30e 100644 --- a/database/pgsql/feature_test.go +++ b/database/pgsql/feature_test.go @@ -17,99 +17,237 @@ package pgsql import ( "testing" - "github.com/stretchr/testify/assert" - "github.com/coreos/clair/database" - "github.com/coreos/clair/ext/versionfmt/dpkg" + _ "github.com/coreos/clair/ext/featurefmt/dpkg" + "github.com/stretchr/testify/assert" ) -func TestInsertFeature(t *testing.T) { - datastore, err := openDatabaseForTest("InsertFeature", false) - if err != nil { - t.Error(err) - return +func TestPersistFeatures(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistFeatures", false) + defer closeTest(t, datastore, tx) + + f1 := database.Feature{} + f2 := database.Feature{Name: "n", Version: "v", VersionFormat: "vf"} + + // empty + assert.Nil(t, tx.PersistFeatures([]database.Feature{})) + // invalid + assert.NotNil(t, tx.PersistFeatures([]database.Feature{f1})) + // duplicated + assert.Nil(t, tx.PersistFeatures([]database.Feature{f2, f2})) + // existing + assert.Nil(t, tx.PersistFeatures([]database.Feature{f2})) + + fs := listFeatures(t, tx) + assert.Len(t, fs, 1) + assert.Equal(t, f2, fs[0]) +} + +func TestPersistNamespacedFeatures(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistNamespacedFeatures", true) + defer closeTest(t, datastore, tx) + + // existing features + f1 := database.Feature{ + Name: "wechat", + Version: "0.5", + VersionFormat: "dpkg", } - defer datastore.Close() - // Invalid Feature. - id0, err := datastore.insertFeature(database.Feature{}) - assert.NotNil(t, err) - assert.Zero(t, id0) + // non-existing features + f2 := database.Feature{ + Name: "fake!", + } - id0, err = datastore.insertFeature(database.Feature{ - Namespace: database.Namespace{}, - Name: "TestInsertFeature0", - }) - assert.NotNil(t, err) - assert.Zero(t, id0) + f3 := database.Feature{ + Name: "openssl", + Version: "2.0", + VersionFormat: "dpkg", + } - // Insert Feature and ensure we can find it. - feature := database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertFeatureNamespace1", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertFeature1", - } - id1, err := datastore.insertFeature(feature) - assert.Nil(t, err) - id2, err := datastore.insertFeature(feature) - assert.Nil(t, err) - assert.Equal(t, id1, id2) - - // Insert invalid FeatureVersion. - for _, invalidFeatureVersion := range []database.FeatureVersion{ - { - Feature: database.Feature{}, - Version: "1.0", - }, - { - Feature: database.Feature{ - Namespace: database.Namespace{}, - Name: "TestInsertFeature2", - }, - Version: "1.0", - }, - { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertFeatureNamespace2", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertFeature2", - }, - Version: "", - }, - { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertFeatureNamespace2", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertFeature2", - }, - Version: "bad version", - }, - } { - id3, err := datastore.insertFeatureVersion(invalidFeatureVersion) - assert.Error(t, err) - assert.Zero(t, id3) + // exising namespace + n1 := database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + } + + n3 := database.Namespace{ + Name: "debian:8", + VersionFormat: "dpkg", + } + + // non-existing namespace + n2 := database.Namespace{ + Name: "debian:non", + VersionFormat: "dpkg", + } + + // existing namespaced feature + nf1 := database.NamespacedFeature{ + Namespace: n1, + Feature: f1, + } + + // invalid namespaced feature + nf2 := database.NamespacedFeature{ + Namespace: n2, + Feature: f2, + } + + // new namespaced feature affected by vulnerability + nf3 := database.NamespacedFeature{ + Namespace: n3, + Feature: f3, + } + + // namespaced features with namespaces or features not in the database will + // generate error. + assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{})) + + assert.NotNil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1, nf2})) + // valid case: insert nf3 + assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1, nf3})) + + all := listNamespacedFeatures(t, tx) + assert.Contains(t, all, nf1) + assert.Contains(t, all, nf3) +} + +func TestVulnerableFeature(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistNamespacedFeatures", true) + defer closeTest(t, datastore, tx) + + f1 := database.Feature{ + Name: "openssl", + Version: "1.3", + VersionFormat: "dpkg", } - // Insert FeatureVersion and ensure we can find it. - featureVersion := database.FeatureVersion{ + n1 := database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + } + + nf1 := database.NamespacedFeature{ + Namespace: n1, + Feature: f1, + } + assert.Nil(t, tx.PersistFeatures([]database.Feature{f1})) + assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1})) + + // ensure the namespaced feature is affected correctly + anf, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{nf1}) + if assert.Nil(t, err) && + assert.Len(t, anf, 1) && + assert.True(t, anf[0].Valid) && + assert.Len(t, anf[0].AffectedBy, 1) { + assert.Equal(t, "CVE-OPENSSL-1-DEB7", anf[0].AffectedBy[0].Name) + } +} + +func TestFindAffectedNamespacedFeatures(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistNamespacedFeatures", true) + defer closeTest(t, datastore, tx) + ns := database.NamespacedFeature{ Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertFeatureNamespace1", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertFeature1", + Name: "openssl", + Version: "1.0", + VersionFormat: "dpkg", + }, + Namespace: database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", }, - Version: "2:3.0-imba", } - id4, err := datastore.insertFeatureVersion(featureVersion) - assert.Nil(t, err) - id5, err := datastore.insertFeatureVersion(featureVersion) - assert.Nil(t, err) - assert.Equal(t, id4, id5) + + ans, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{ns}) + if assert.Nil(t, err) && + assert.Len(t, ans, 1) && + assert.True(t, ans[0].Valid) && + assert.Len(t, ans[0].AffectedBy, 1) { + assert.Equal(t, "CVE-OPENSSL-1-DEB7", ans[0].AffectedBy[0].Name) + } +} + +func listNamespacedFeatures(t *testing.T, tx *pgSession) []database.NamespacedFeature { + rows, err := tx.Query(`SELECT f.name, f.version, f.version_format, n.name, n.version_format + FROM feature AS f, namespace AS n, namespaced_feature AS nf + WHERE nf.feature_id = f.id AND nf.namespace_id = n.id`) + if err != nil { + t.Error(err) + t.FailNow() + } + + nf := []database.NamespacedFeature{} + for rows.Next() { + f := database.NamespacedFeature{} + err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat, &f.Namespace.Name, &f.Namespace.VersionFormat) + if err != nil { + t.Error(err) + t.FailNow() + } + nf = append(nf, f) + } + + return nf +} + +func listFeatures(t *testing.T, tx *pgSession) []database.Feature { + rows, err := tx.Query("SELECT name, version, version_format FROM feature") + if err != nil { + t.FailNow() + } + + fs := []database.Feature{} + for rows.Next() { + f := database.Feature{} + err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat) + if err != nil { + t.FailNow() + } + fs = append(fs, f) + } + return fs +} + +func assertFeaturesEqual(t *testing.T, expected []database.Feature, actual []database.Feature) bool { + if assert.Len(t, actual, len(expected)) { + has := map[database.Feature]bool{} + for _, nf := range expected { + has[nf] = false + } + + for _, nf := range actual { + has[nf] = true + } + + for nf, visited := range has { + if !assert.True(t, visited, nf.Name+" is expected") { + return false + } + return true + } + } + return false +} + +func assertNamespacedFeatureEqual(t *testing.T, expected []database.NamespacedFeature, actual []database.NamespacedFeature) bool { + if assert.Len(t, actual, len(expected)) { + has := map[database.NamespacedFeature]bool{} + for _, nf := range expected { + has[nf] = false + } + + for _, nf := range actual { + has[nf] = true + } + + for nf, visited := range has { + if !assert.True(t, visited, nf.Namespace.Name+":"+nf.Name+" is expected") { + return false + } + } + return true + } + return false } diff --git a/database/pgsql/keyvalue.go b/database/pgsql/keyvalue.go index ab5995881e..7a3af43c74 100644 --- a/database/pgsql/keyvalue.go +++ b/database/pgsql/keyvalue.go @@ -23,63 +23,38 @@ import ( "github.com/coreos/clair/pkg/commonerr" ) -// InsertKeyValue stores (or updates) a single key / value tuple. -func (pgSQL *pgSQL) InsertKeyValue(key, value string) (err error) { +// PersistKeyValue stores (or updates) a single key / value tuple. +func (tx *pgSession) UpdateKeyValue(key, value string) (err error) { if key == "" || value == "" { log.Warning("could not insert a flag which has an empty name or value") return commonerr.NewBadRequestError("could not insert a flag which has an empty name or value") } - defer observeQueryTime("InsertKeyValue", "all", time.Now()) + defer observeQueryTime("PersistKeyValue", "all", time.Now()) - // Upsert. - // - // Note: UPSERT works only on >= PostgreSQL 9.5 which is not yet supported by AWS RDS. - // The best solution is currently the use of http://dba.stackexchange.com/a/13477 - // but the key/value storage doesn't need to be super-efficient and super-safe at the - // moment so we can just use a client-side solution with transactions, based on - // http://postgresql.org/docs/current/static/plpgsql-control-structures.html. - // TODO(Quentin-M): Enable Upsert as soon as 9.5 is stable. - - for { - // First, try to update. - r, err := pgSQL.Exec(updateKeyValue, value, key) - if err != nil { - return handleError("updateKeyValue", err) - } - if n, _ := r.RowsAffected(); n > 0 { - // Updated successfully. - return nil - } - - // Try to insert the key. - // If someone else inserts the same key concurrently, we could get a unique-key violation error. - _, err = pgSQL.Exec(insertKeyValue, key, value) - if err != nil { - if isErrUniqueViolation(err) { - // Got unique constraint violation, retry. - continue - } - return handleError("insertKeyValue", err) - } - - return nil + _, err = tx.Exec(upsertKeyValue, key, value) + if err != nil { + return handleError("insertKeyValue", err) } + + return nil } -// GetValue reads a single key / value tuple and returns an empty string if the key doesn't exist. -func (pgSQL *pgSQL) GetKeyValue(key string) (string, error) { - defer observeQueryTime("GetKeyValue", "all", time.Now()) +// FindKeyValue finds value based on `key` in KeyValue table, returns value +// and true if found, and returns empty string and false if not found. +func (tx *pgSession) FindKeyValue(key string) (string, bool, error) { + defer observeQueryTime("FindKeyValue", "all", time.Now()) var value string - err := pgSQL.QueryRow(searchKeyValue, key).Scan(&value) + err := tx.QueryRow(searchKeyValue, key).Scan(&value) if err == sql.ErrNoRows { - return "", nil + return "", false, nil } + if err != nil { - return "", handleError("searchKeyValue", err) + return "", false, handleError("searchKeyValue", err) } - return value, nil + return value, true, nil } diff --git a/database/pgsql/keyvalue_test.go b/database/pgsql/keyvalue_test.go index 4a8b6593f2..d4ea9e1764 100644 --- a/database/pgsql/keyvalue_test.go +++ b/database/pgsql/keyvalue_test.go @@ -21,32 +21,30 @@ import ( ) func TestKeyValue(t *testing.T) { - datastore, err := openDatabaseForTest("KeyValue", false) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() + datastore, tx := openSessionForTest(t, "Lock", true) + defer closeTest(t, datastore, tx) // Get non-existing key/value - f, err := datastore.GetKeyValue("test") + f, ok, err := tx.FindKeyValue("test") assert.Nil(t, err) - assert.Empty(t, "", f) + assert.False(t, ok) // Try to insert invalid key/value. - assert.Error(t, datastore.InsertKeyValue("test", "")) - assert.Error(t, datastore.InsertKeyValue("", "test")) - assert.Error(t, datastore.InsertKeyValue("", "")) + assert.Error(t, tx.UpdateKeyValue("test", "")) + assert.Error(t, tx.UpdateKeyValue("", "test")) + assert.Error(t, tx.UpdateKeyValue("", "")) // Insert and verify. - assert.Nil(t, datastore.InsertKeyValue("test", "test1")) - f, err = datastore.GetKeyValue("test") + assert.Nil(t, tx.UpdateKeyValue("test", "test1")) + f, ok, err = tx.FindKeyValue("test") assert.Nil(t, err) + assert.True(t, ok) assert.Equal(t, "test1", f) // Update and verify. - assert.Nil(t, datastore.InsertKeyValue("test", "test2")) - f, err = datastore.GetKeyValue("test") + assert.Nil(t, tx.UpdateKeyValue("test", "test2")) + f, ok, err = tx.FindKeyValue("test") assert.Nil(t, err) + assert.True(t, ok) assert.Equal(t, "test2", f) } diff --git a/database/pgsql/layer.go b/database/pgsql/layer.go index 64e9a47599..82a85732f5 100644 --- a/database/pgsql/layer.go +++ b/database/pgsql/layer.go @@ -16,464 +16,196 @@ package pgsql import ( "database/sql" - "strings" - "time" - - "github.com/guregu/null/zero" - log "github.com/sirupsen/logrus" "github.com/coreos/clair/database" "github.com/coreos/clair/pkg/commonerr" ) -func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities bool) (database.Layer, error) { - subquery := "all" - if withFeatures { - subquery += "/features" - } else if withVulnerabilities { - subquery += "/features+vulnerabilities" - } - defer observeQueryTime("FindLayer", subquery, time.Now()) +func (tx *pgSession) FindLayer(hash string) (database.Layer, database.Processors, bool, error) { + l, p, _, ok, err := tx.findLayer(hash) + return l, p, ok, err +} - // Find the layer +func (tx *pgSession) FindLayerWithContent(hash string) (database.LayerWithContent, bool, error) { var ( - layer database.Layer - parentID zero.Int - parentName zero.String - nsID zero.Int - nsName sql.NullString - nsVersionFormat sql.NullString - ) - - t := time.Now() - err := pgSQL.QueryRow(searchLayer, name).Scan( - &layer.ID, - &layer.Name, - &layer.EngineVersion, - &parentID, - &parentName, + layer database.LayerWithContent + layerID int + ok bool + err error ) - observeQueryTime("FindLayer", "searchLayer", t) + layer.Layer, layer.ProcessedBy, layerID, ok, err = tx.findLayer(hash) if err != nil { - return layer, handleError("searchLayer", err) + return layer, false, err } - if !parentID.IsZero() { - layer.Parent = &database.Layer{ - Model: database.Model{ID: int(parentID.Int64)}, - Name: parentName.String, - } + if !ok { + return layer, false, nil } - rows, err := pgSQL.Query(searchLayerNamespace, layer.ID) - defer rows.Close() - if err != nil { - return layer, handleError("searchLayerNamespace", err) - } - for rows.Next() { - err = rows.Scan(&nsID, &nsName, &nsVersionFormat) - if err != nil { - return layer, handleError("searchLayerNamespace", err) - } - if !nsID.IsZero() { - layer.Namespaces = append(layer.Namespaces, database.Namespace{ - Model: database.Model{ID: int(nsID.Int64)}, - Name: nsName.String, - VersionFormat: nsVersionFormat.String, - }) - } - } - - // Find its features - if withFeatures || withVulnerabilities { - // Create a transaction to disable hash/merge joins as our experiments have shown that - // PostgreSQL 9.4 makes bad planning decisions about: - // - joining the layer tree to feature versions and feature - // - joining the feature versions to affected/fixed feature version and vulnerabilities - // It would for instance do a merge join between affected feature versions (300 rows, estimated - // 3000 rows) and fixed in feature version (100k rows). In this case, it is much more - // preferred to use a nested loop. - tx, err := pgSQL.Begin() - if err != nil { - return layer, handleError("FindLayer.Begin()", err) - } - defer tx.Commit() - - _, err = tx.Exec(disableHashJoin) - if err != nil { - log.WithError(err).Warningf("FindLayer: could not disable hash join") - } - _, err = tx.Exec(disableMergeJoin) - if err != nil { - log.WithError(err).Warningf("FindLayer: could not disable merge join") - } - - t = time.Now() - featureVersions, err := getLayerFeatureVersions(tx, layer.ID) - observeQueryTime("FindLayer", "getLayerFeatureVersions", t) - - if err != nil { - return layer, err - } - - layer.Features = featureVersions - - if withVulnerabilities { - // Load the vulnerabilities that affect the FeatureVersions. - t = time.Now() - err := loadAffectedBy(tx, layer.Features) - observeQueryTime("FindLayer", "loadAffectedBy", t) - - if err != nil { - return layer, err - } - } - } - - return layer, nil + layer.Features, err = tx.findLayerFeatures(layerID) + layer.Namespaces, err = tx.findLayerNamespaces(layerID) + return layer, true, nil } -// getLayerFeatureVersions returns list of database.FeatureVersion that a database.Layer has. -func getLayerFeatureVersions(tx *sql.Tx, layerID int) ([]database.FeatureVersion, error) { - var featureVersions []database.FeatureVersion - - // Query. - rows, err := tx.Query(searchLayerFeatureVersion, layerID) - if err != nil { - return featureVersions, handleError("searchLayerFeatureVersion", err) - } - defer rows.Close() - - // Scan query. - var modification string - mapFeatureVersions := make(map[int]database.FeatureVersion) - for rows.Next() { - var fv database.FeatureVersion - err = rows.Scan( - &fv.ID, - &modification, - &fv.Feature.Namespace.ID, - &fv.Feature.Namespace.Name, - &fv.Feature.Namespace.VersionFormat, - &fv.Feature.ID, - &fv.Feature.Name, - &fv.ID, - &fv.Version, - &fv.AddedBy.ID, - &fv.AddedBy.Name, - ) - if err != nil { - return featureVersions, handleError("searchLayerFeatureVersion.Scan()", err) - } - - // Do transitive closure. - switch modification { - case "add": - mapFeatureVersions[fv.ID] = fv - case "del": - delete(mapFeatureVersions, fv.ID) - default: - log.WithField("modification", modification).Warning("unknown Layer_diff_FeatureVersion's modification") - return featureVersions, database.ErrInconsistent - } - } - if err = rows.Err(); err != nil { - return featureVersions, handleError("searchLayerFeatureVersion.Rows()", err) - } - - // Build result by converting our map to a slice. - for _, featureVersion := range mapFeatureVersions { - featureVersions = append(featureVersions, featureVersion) +// PersistLayer soi a layer. +func (tx *pgSession) PersistLayer(layer database.Layer) error { + if layer.Hash == "" { + return commonerr.NewBadRequestError("Empty Layer Hash is not allowed") } - return featureVersions, nil + _, err := tx.Exec(soiLayer, layer.Hash) + return err } -// loadAffectedBy returns the list of database.Vulnerability that affect the given -// FeatureVersion. -func loadAffectedBy(tx *sql.Tx, featureVersions []database.FeatureVersion) error { - if len(featureVersions) == 0 { - return nil +// PersistLayerContent relates layer identified by hash with namespaces, +// features and processors provided. If the layer, namespaces, features are not +// in database, the function returns an error. +func (tx *pgSession) PersistLayerContent(hash string, namespaces []database.Namespace, features []database.Feature, processedBy database.Processors) error { + if hash == "" { + return commonerr.NewBadRequestError("Empty layer hash is not allowed") } - // Construct list of FeatureVersion IDs, we will do a single query - featureVersionIDs := make([]int, 0, len(featureVersions)) - for i := 0; i < len(featureVersions); i++ { - featureVersionIDs = append(featureVersionIDs, featureVersions[i].ID) - } - - rows, err := tx.Query(searchFeatureVersionVulnerability, - buildInputArray(featureVersionIDs)) - if err != nil && err != sql.ErrNoRows { - return handleError("searchFeatureVersionVulnerability", err) - } - defer rows.Close() - - vulnerabilities := make(map[int][]database.Vulnerability, len(featureVersions)) - var featureversionID int - for rows.Next() { - var vulnerability database.Vulnerability - err := rows.Scan( - &featureversionID, - &vulnerability.ID, - &vulnerability.Name, - &vulnerability.Description, - &vulnerability.Link, - &vulnerability.Severity, - &vulnerability.Metadata, - &vulnerability.Namespace.Name, - &vulnerability.Namespace.VersionFormat, - &vulnerability.FixedBy, - ) - if err != nil { - return handleError("searchFeatureVersionVulnerability.Scan()", err) - } - vulnerabilities[featureversionID] = append(vulnerabilities[featureversionID], vulnerability) - } - if err = rows.Err(); err != nil { - return handleError("searchFeatureVersionVulnerability.Rows()", err) - } - - // Assign vulnerabilities to every FeatureVersions - for i := 0; i < len(featureVersions); i++ { - featureVersions[i].AffectedBy = vulnerabilities[featureVersions[i].ID] + var layerID int + err := tx.QueryRow(searchLayer, hash).Scan(&layerID) + if err != nil { + return err } - return nil -} - -// Internally, only Feature additions/removals are stored for each layer. If a layer has a parent, -// the Feature list will be compared to the parent's Feature list and the difference will be stored. -// Note that when the Namespace of a layer differs from its parent, it is expected that several -// Feature that were already included a parent will have their Namespace updated as well -// (happens when Feature detectors relies on the detected layer Namespace). However, if the listed -// Feature has the same Name/Version as its parent, InsertLayer considers that the Feature hasn't -// been modified. -func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error { - tf := time.Now() - - // Verify parameters - if layer.Name == "" { - log.Warning("could not insert a layer which has an empty Name") - return commonerr.NewBadRequestError("could not insert a layer which has an empty Name") + nsIDs, err := tx.findNamespaceIDs(namespaces) + if err != nil { + return err } - // Get a potentially existing layer. - existingLayer, err := pgSQL.FindLayer(layer.Name, true, false) - if err != nil && err != commonerr.ErrNotFound { + fIDs, err := tx.findFeatureIDs(features) + if err != nil { return err - } else if err == nil { - if existingLayer.EngineVersion >= layer.EngineVersion { - // The layer exists and has an equal or higher engine version, do nothing. - return nil - } - - layer.ID = existingLayer.ID } - // We do `defer observeQueryTime` here because we don't want to observe existing layers. - defer observeQueryTime("InsertLayer", "all", tf) - - // Get parent ID. - var parentID zero.Int - if layer.Parent != nil { - if layer.Parent.ID == 0 { - log.Warning("Parent is expected to be retrieved from database when inserting a layer.") - return commonerr.NewBadRequestError("Parent is expected to be retrieved from database when inserting a layer.") + for _, id := range nsIDs { + if !id.Valid { + return errNamespaceNotFound } - - parentID = zero.IntFrom(int64(layer.Parent.ID)) - } - - // namespaceIDs will contain inherited and new namespaces - namespaceIDs := make(map[int]struct{}) - - // try to insert the new namespaces - for _, ns := range layer.Namespaces { - n, err := pgSQL.insertNamespace(ns) + _, err := tx.Exec(persistLayerNamespace, layerID, id) if err != nil { - return handleError("pgSQL.insertNamespace", err) + return handleError("persistLayerNamespace", err) } - namespaceIDs[n] = struct{}{} } - // inherit namespaces from parent layer - if layer.Parent != nil { - for _, ns := range layer.Parent.Namespaces { - namespaceIDs[ns.ID] = struct{}{} + for _, id := range fIDs { + _, err := tx.Exec(persistLayerFeature, layerID, id) + if err != nil { + return handleError("persistLayerFeature", err) } } - // Begin transaction. - tx, err := pgSQL.Begin() + return tx.persistProcessors(persistLayerListers, "persistlayerLister", persistLayerDetectors, "persistLayerDetectors", layerID, processedBy) +} + +func (tx *pgSession) persistProcessors(listerQuery, listerQueryName, detectorQuery, detectorQueryName string, id int, processors database.Processors) error { + stmt, err := tx.Prepare(listerQuery) if err != nil { - tx.Rollback() - return handleError("InsertLayer.Begin()", err) + return handleError(listerQueryName, err) } - if layer.ID == 0 { - // Insert a new layer. - err = tx.QueryRow(insertLayer, layer.Name, layer.EngineVersion, parentID). - Scan(&layer.ID) + for _, l := range processors.Listers { + _, err := stmt.Exec(id, l) if err != nil { - tx.Rollback() - - if isErrUniqueViolation(err) { - // Ignore this error, another process collided. - log.Debug("Attempted to insert duplicate layer.") - return nil - } - return handleError("insertLayer", err) - } - } else { - // Update an existing layer. - _, err = tx.Exec(updateLayer, layer.ID, layer.EngineVersion) - if err != nil { - tx.Rollback() - return handleError("updateLayer", err) - } - - // replace the old namespace in the database - _, err := tx.Exec(removeLayerNamespace, layer.ID) - if err != nil { - tx.Rollback() - return handleError("removeLayerNamespace", err) - } - // Remove all existing Layer_diff_FeatureVersion. - _, err = tx.Exec(removeLayerDiffFeatureVersion, layer.ID) - if err != nil { - tx.Rollback() - return handleError("removeLayerDiffFeatureVersion", err) + stmt.Close() + return handleError(listerQueryName, err) } } - // insert the layer's namespaces - stmt, err := tx.Prepare(insertLayerNamespace) + if err := stmt.Close(); err != nil { + return handleError(listerQueryName, err) + } + stmt, err = tx.Prepare(detectorQuery) if err != nil { - tx.Rollback() - return handleError("failed to prepare statement", err) + return handleError(detectorQueryName, err) } - defer func() { - err = stmt.Close() + for _, d := range processors.Detectors { + _, err := stmt.Exec(id, d) if err != nil { - tx.Rollback() - log.WithError(err).Error("failed to close prepared statement") - } - }() - - for nsid := range namespaceIDs { - _, err := stmt.Exec(layer.ID, nsid) - if err != nil { - tx.Rollback() - return handleError("insertLayerNamespace", err) + stmt.Close() + return handleError(detectorQueryName, err) } } - // Update Layer_diff_FeatureVersion now. - err = pgSQL.updateDiffFeatureVersions(tx, &layer, &existingLayer) - if err != nil { - tx.Rollback() - return err - } - - // Commit transaction. - err = tx.Commit() - if err != nil { - tx.Rollback() - return handleError("InsertLayer.Commit()", err) + if err := stmt.Close(); err != nil { + return handleError(detectorQueryName, err) } return nil } -func (pgSQL *pgSQL) updateDiffFeatureVersions(tx *sql.Tx, layer, existingLayer *database.Layer) error { - // add and del are the FeatureVersion diff we should insert. - var add []database.FeatureVersion - var del []database.FeatureVersion - - if layer.Parent == nil { - // There is no parent, every Features are added. - add = append(add, layer.Features...) - } else if layer.Parent != nil { - // There is a parent, we need to diff the Features with it. +func (tx *pgSession) findLayerNamespaces(layerID int) ([]database.Namespace, error) { + var namespaces []database.Namespace - // Build name:version structures. - layerFeaturesMapNV, layerFeaturesNV := createNV(layer.Features) - parentLayerFeaturesMapNV, parentLayerFeaturesNV := createNV(layer.Parent.Features) - - // Calculate the added and deleted FeatureVersions name:version. - addNV := compareStringLists(layerFeaturesNV, parentLayerFeaturesNV) - delNV := compareStringLists(parentLayerFeaturesNV, layerFeaturesNV) + rows, err := tx.Query(searchLayerNamespaces, layerID) + if err != nil { + return nil, handleError("searchLayerFeatures", err) + } - // Fill the structures containing the added and deleted FeatureVersions. - for _, nv := range addNV { - add = append(add, *layerFeaturesMapNV[nv]) - } - for _, nv := range delNV { - del = append(del, *parentLayerFeaturesMapNV[nv]) + for rows.Next() { + ns := database.Namespace{} + err := rows.Scan(&ns.Name, &ns.VersionFormat) + if err != nil { + return nil, err } + namespaces = append(namespaces, ns) } + return namespaces, nil +} - // Insert FeatureVersions in the database. - addIDs, err := pgSQL.insertFeatureVersions(add) - if err != nil { - return err - } - delIDs, err := pgSQL.insertFeatureVersions(del) +func (tx *pgSession) findLayerFeatures(layerID int) ([]database.Feature, error) { + var features []database.Feature + + rows, err := tx.Query(searchLayerFeatures, layerID) if err != nil { - return err + return nil, handleError("searchLayerFeatures", err) } - // Insert diff in the database. - if len(addIDs) > 0 { - _, err = tx.Exec(insertLayerDiffFeatureVersion, layer.ID, "add", buildInputArray(addIDs)) - if err != nil { - return handleError("insertLayerDiffFeatureVersion.Add", err) - } - } - if len(delIDs) > 0 { - _, err = tx.Exec(insertLayerDiffFeatureVersion, layer.ID, "del", buildInputArray(delIDs)) + for rows.Next() { + f := database.Feature{} + err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat) if err != nil { - return handleError("insertLayerDiffFeatureVersion.Del", err) + return nil, err } + features = append(features, f) } - - return nil + return features, nil } -func createNV(features []database.FeatureVersion) (map[string]*database.FeatureVersion, []string) { - mapNV := make(map[string]*database.FeatureVersion, 0) - sliceNV := make([]string, 0, len(features)) +func (tx *pgSession) findLayer(hash string) (database.Layer, database.Processors, int, bool, error) { + var ( + layerID int + layer = database.Layer{Hash: hash} + processors database.Processors + ) - for i := 0; i < len(features); i++ { - fv := &features[i] - nv := strings.Join([]string{fv.Feature.Namespace.Name, fv.Feature.Name, fv.Version}, ":") - mapNV[nv] = fv - sliceNV = append(sliceNV, nv) + if hash == "" { + return layer, processors, layerID, false, commonerr.NewBadRequestError("Empty Layer Hash is not allowed") } - return mapNV, sliceNV -} - -func (pgSQL *pgSQL) DeleteLayer(name string) error { - defer observeQueryTime("DeleteLayer", "all", time.Now()) - - result, err := pgSQL.Exec(removeLayer, name) + err := tx.QueryRow(searchLayer, hash).Scan(&layerID) if err != nil { - return handleError("removeLayer", err) + if err == sql.ErrNoRows { + return layer, processors, layerID, false, nil + } + return layer, processors, layerID, false, err } - affected, err := result.RowsAffected() + processors.Detectors, err = tx.findProcessors(searchLayerDetectors, "searchLayerDetectors", "detector", layerID) if err != nil { - return handleError("removeLayer.RowsAffected()", err) + return layer, processors, layerID, false, err } - if affected <= 0 { - return commonerr.ErrNotFound + processors.Listers, err = tx.findProcessors(searchLayerListers, "searchLayerListers", "lister", layerID) + if err != nil { + return layer, processors, layerID, false, err } - return nil + return layer, processors, layerID, true, nil } diff --git a/database/pgsql/layer_test.go b/database/pgsql/layer_test.go index 6f35bbde85..16e82558fb 100644 --- a/database/pgsql/layer_test.go +++ b/database/pgsql/layer_test.go @@ -1,437 +1,103 @@ -// Copyright 2017 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// // Copyright 2017 clair authors +// // +// // Licensed under the Apache License, Version 2.0 (the "License"); +// // you may not use this file except in compliance with the License. +// // You may obtain a copy of the License at +// // +// // http://www.apache.org/licenses/LICENSE-2.0 +// // +// // Unless required by applicable law or agreed to in writing, software +// // distributed under the License is distributed on an "AS IS" BASIS, +// // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// // See the License for the specific language governing permissions and +// // limitations under the License. package pgsql import ( - "fmt" "testing" - "github.com/stretchr/testify/assert" - "github.com/coreos/clair/database" - "github.com/coreos/clair/ext/versionfmt/dpkg" - "github.com/coreos/clair/pkg/commonerr" + "github.com/stretchr/testify/assert" ) -func TestFindLayer(t *testing.T) { - datastore, err := openDatabaseForTest("FindLayer", true) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() - - // Layer-0: no parent, no namespace, no feature, no vulnerability - layer, err := datastore.FindLayer("layer-0", false, false) - if assert.Nil(t, err) && assert.NotNil(t, layer) { - assert.Equal(t, "layer-0", layer.Name) - assert.Len(t, layer.Namespaces, 0) - assert.Nil(t, layer.Parent) - assert.Equal(t, 1, layer.EngineVersion) - assert.Len(t, layer.Features, 0) - } - - layer, err = datastore.FindLayer("layer-0", true, false) - if assert.Nil(t, err) && assert.NotNil(t, layer) { - assert.Len(t, layer.Features, 0) - } - - // Layer-1: one parent, adds two features, one vulnerability - layer, err = datastore.FindLayer("layer-1", false, false) - if assert.Nil(t, err) && assert.NotNil(t, layer) { - assert.Equal(t, layer.Name, "layer-1") - assertExpectedNamespaceName(t, &layer, []string{"debian:7"}) - if assert.NotNil(t, layer.Parent) { - assert.Equal(t, "layer-0", layer.Parent.Name) - } - assert.Equal(t, 1, layer.EngineVersion) - assert.Len(t, layer.Features, 0) - } +func TestPersistLayer(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistLayer", false) + defer closeTest(t, datastore, tx) - layer, err = datastore.FindLayer("layer-1", true, false) - if assert.Nil(t, err) && assert.NotNil(t, layer) && assert.Len(t, layer.Features, 2) { - for _, featureVersion := range layer.Features { - assert.Equal(t, "debian:7", featureVersion.Feature.Namespace.Name) + l1 := database.Layer{} + l2 := database.Layer{Hash: "HESOYAM"} - switch featureVersion.Feature.Name { - case "wechat": - assert.Equal(t, "0.5", featureVersion.Version) - case "openssl": - assert.Equal(t, "1.0", featureVersion.Version) - default: - t.Errorf("unexpected package %s for layer-1", featureVersion.Feature.Name) - } - } - } - - layer, err = datastore.FindLayer("layer-1", true, true) - if assert.Nil(t, err) && assert.NotNil(t, layer) && assert.Len(t, layer.Features, 2) { - for _, featureVersion := range layer.Features { - assert.Equal(t, "debian:7", featureVersion.Feature.Namespace.Name) + // invalid + assert.NotNil(t, tx.PersistLayer(l1)) + // valid + assert.Nil(t, tx.PersistLayer(l2)) + // duplicated + assert.Nil(t, tx.PersistLayer(l2)) +} - switch featureVersion.Feature.Name { - case "wechat": - assert.Equal(t, "0.5", featureVersion.Version) - case "openssl": - assert.Equal(t, "1.0", featureVersion.Version) +func TestFindLayer(t *testing.T) { + datastore, tx := openSessionForTest(t, "TestFindLayer", true) + defer closeTest(t, datastore, tx) - if assert.Len(t, featureVersion.AffectedBy, 1) { - assert.Equal(t, "debian:7", featureVersion.AffectedBy[0].Namespace.Name) - assert.Equal(t, "CVE-OPENSSL-1-DEB7", featureVersion.AffectedBy[0].Name) - assert.Equal(t, database.HighSeverity, featureVersion.AffectedBy[0].Severity) - assert.Equal(t, "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", featureVersion.AffectedBy[0].Description) - assert.Equal(t, "http://google.com/#q=CVE-OPENSSL-1-DEB7", featureVersion.AffectedBy[0].Link) - assert.Equal(t, "2.0", featureVersion.AffectedBy[0].FixedBy) - } - default: - t.Errorf("unexpected package %s for layer-1", featureVersion.Feature.Name) - } - } + expected := database.Layer{Hash: "layer-4"} + expectedProcessors := database.Processors{ + Detectors: []string{"os-release", "apt-sources"}, + Listers: []string{"dpkg", "rpm"}, } - // Testing Multiple namespaces layer-3b has debian:7 and debian:8 namespaces - layer, err = datastore.FindLayer("layer-3b", true, true) - - if assert.Nil(t, err) && assert.NotNil(t, layer) && assert.Len(t, layer.Features, 2) { - assert.Equal(t, "layer-3b", layer.Name) - // validate the namespace - assertExpectedNamespaceName(t, &layer, []string{"debian:7", "debian:8"}) - for _, featureVersion := range layer.Features { - switch featureVersion.Feature.Namespace.Name { - case "debian:7": - assert.Equal(t, "wechat", featureVersion.Feature.Name) - assert.Equal(t, "0.5", featureVersion.Version) - case "debian:8": - assert.Equal(t, "openssl", featureVersion.Feature.Name) - assert.Equal(t, "1.0", featureVersion.Version) - default: - t.Errorf("unexpected package %s for layer-3b", featureVersion.Feature.Name) - } - } - } -} + // invalid + _, _, _, err := tx.FindLayer("") + assert.NotNil(t, err) + _, _, ok, err := tx.FindLayer("layer-non") + assert.Nil(t, err) + assert.False(t, ok) -func TestInsertLayer(t *testing.T) { - datastore, err := openDatabaseForTest("InsertLayer", false) - if err != nil { - t.Error(err) - return + // valid + layer, processors, ok2, err := tx.FindLayer("layer-4") + if assert.Nil(t, err) && assert.True(t, ok2) { + assert.Equal(t, expected, layer) + assertProcessorsEqual(t, expectedProcessors, processors) } - defer datastore.Close() - - // Insert invalid layer. - testInsertLayerInvalid(t, datastore) - - // Insert a layer tree. - testInsertLayerTree(t, datastore) - - // Update layer. - testInsertLayerUpdate(t, datastore) - - // Delete layer. - testInsertLayerDelete(t, datastore) } -func testInsertLayerInvalid(t *testing.T, datastore database.Datastore) { - invalidLayers := []database.Layer{ - {}, - {Name: "layer0", Parent: &database.Layer{}}, - {Name: "layer0", Parent: &database.Layer{Name: "UnknownLayer"}}, - } - - for _, invalidLayer := range invalidLayers { - err := datastore.InsertLayer(invalidLayer) - assert.Error(t, err) - } -} +func TestFindLayerWithContent(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistLayer", true) + defer closeTest(t, datastore, tx) -func testInsertLayerTree(t *testing.T, datastore database.Datastore) { - f1 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace2", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature1", - }, - Version: "1.0", - } - f2 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace2", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature2", - }, - Version: "0.34", - } - f3 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace2", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature3", - }, - Version: "0.56", - } - f4 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace3", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature2", - }, - Version: "0.34", - } - f5 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace3", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature3", - }, - Version: "0.56", - } - f6 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace3", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature4", - }, - Version: "0.666", - } + _, _, err := tx.FindLayerWithContent("") + assert.NotNil(t, err) + _, ok, err := tx.FindLayerWithContent("layer-non") + assert.Nil(t, err) + assert.False(t, ok) - layers := []database.Layer{ - { - Name: "TestInsertLayer1", - }, - { - Name: "TestInsertLayer2", - Parent: &database.Layer{Name: "TestInsertLayer1"}, - Namespaces: []database.Namespace{database.Namespace{ - Name: "TestInsertLayerNamespace1", - VersionFormat: dpkg.ParserName, - }}, - }, - // This layer changes the namespace and adds Features. - { - Name: "TestInsertLayer3", - Parent: &database.Layer{Name: "TestInsertLayer2"}, - Namespaces: []database.Namespace{database.Namespace{ - Name: "TestInsertLayerNamespace2", - VersionFormat: dpkg.ParserName, - }}, - Features: []database.FeatureVersion{f1, f2, f3}, + expectedL := database.LayerWithContent{ + Layer: database.Layer{ + Hash: "layer-4", }, - // This layer covers the case where the last layer doesn't provide any new Feature. - { - Name: "TestInsertLayer4a", - Parent: &database.Layer{Name: "TestInsertLayer3"}, - Features: []database.FeatureVersion{f1, f2, f3}, + Features: []database.Feature{ + {Name: "fake", Version: "2.0", VersionFormat: "rpm"}, + {Name: "openssl", Version: "2.0", VersionFormat: "dpkg"}, }, - // This layer covers the case where the last layer provides Features. - // It also modifies the Namespace ("upgrade") but keeps some Features not upgraded, their - // Namespaces should then remain unchanged. - { - Name: "TestInsertLayer4b", - Parent: &database.Layer{Name: "TestInsertLayer3"}, - Namespaces: []database.Namespace{database.Namespace{ - Name: "TestInsertLayerNamespace3", - VersionFormat: dpkg.ParserName, - }}, - Features: []database.FeatureVersion{ - // Deletes TestInsertLayerFeature1. - // Keep TestInsertLayerFeature2 (old Namespace should be kept): - f4, - // Upgrades TestInsertLayerFeature3 (with new Namespace): - f5, - // Adds TestInsertLayerFeature4: - f6, - }, + Namespaces: []database.Namespace{ + {Name: "debian:7", VersionFormat: "dpkg"}, + {Name: "fake:1.0", VersionFormat: "rpm"}, }, - } - - var err error - retrievedLayers := make(map[string]database.Layer) - for _, layer := range layers { - if layer.Parent != nil { - // Retrieve from database its parent and assign. - parent := retrievedLayers[layer.Parent.Name] - layer.Parent = &parent - } - - err = datastore.InsertLayer(layer) - assert.Nil(t, err) - - retrievedLayers[layer.Name], err = datastore.FindLayer(layer.Name, true, false) - assert.Nil(t, err) - } - - // layer inherits all namespaces from its ancestries - l4a := retrievedLayers["TestInsertLayer4a"] - assertExpectedNamespaceName(t, &l4a, []string{"TestInsertLayerNamespace2", "TestInsertLayerNamespace1"}) - assert.Len(t, l4a.Features, 3) - for _, featureVersion := range l4a.Features { - if cmpFV(featureVersion, f1) && cmpFV(featureVersion, f2) && cmpFV(featureVersion, f3) { - assert.Error(t, fmt.Errorf("TestInsertLayer4a contains an unexpected package: %#v. Should contain %#v and %#v and %#v.", featureVersion, f1, f2, f3)) - } - } - - l4b := retrievedLayers["TestInsertLayer4b"] - assertExpectedNamespaceName(t, &l4b, []string{"TestInsertLayerNamespace1", "TestInsertLayerNamespace2", "TestInsertLayerNamespace3"}) - assert.Len(t, l4b.Features, 3) - for _, featureVersion := range l4b.Features { - if cmpFV(featureVersion, f2) && cmpFV(featureVersion, f5) && cmpFV(featureVersion, f6) { - assert.Error(t, fmt.Errorf("TestInsertLayer4a contains an unexpected package: %#v. Should contain %#v and %#v and %#v.", featureVersion, f2, f4, f6)) - } - } -} - -func testInsertLayerUpdate(t *testing.T, datastore database.Datastore) { - f7 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace3", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature7", + ProcessedBy: database.Processors{ + Detectors: []string{"os-release", "apt-sources"}, + Listers: []string{"dpkg", "rpm"}, }, - Version: "0.01", - } - - l3, _ := datastore.FindLayer("TestInsertLayer3", true, false) - l3u := database.Layer{ - Name: l3.Name, - Parent: l3.Parent, - Namespaces: []database.Namespace{database.Namespace{ - Name: "TestInsertLayerNamespaceUpdated1", - VersionFormat: dpkg.ParserName, - }}, - Features: []database.FeatureVersion{f7}, - } - - l4u := database.Layer{ - Name: "TestInsertLayer4", - Parent: &database.Layer{Name: "TestInsertLayer3"}, - Features: []database.FeatureVersion{f7}, - EngineVersion: 2, - } - - // Try to re-insert without increasing the EngineVersion. - err := datastore.InsertLayer(l3u) - assert.Nil(t, err) - - l3uf, err := datastore.FindLayer(l3u.Name, true, false) - if assert.Nil(t, err) { - assertSameNamespaceName(t, &l3, &l3uf) - assert.Equal(t, l3.EngineVersion, l3uf.EngineVersion) - assert.Len(t, l3uf.Features, len(l3.Features)) - } - - // Update layer l3. - // Verify that the Namespace, EngineVersion and FeatureVersions got updated. - l3u.EngineVersion = 2 - err = datastore.InsertLayer(l3u) - assert.Nil(t, err) - - l3uf, err = datastore.FindLayer(l3u.Name, true, false) - if assert.Nil(t, err) { - assertSameNamespaceName(t, &l3u, &l3uf) - assert.Equal(t, l3u.EngineVersion, l3uf.EngineVersion) - if assert.Len(t, l3uf.Features, 1) { - assert.True(t, cmpFV(l3uf.Features[0], f7), "Updated layer should have %#v but actually have %#v", f7, l3uf.Features[0]) - } - } - - // Update layer l4. - // Verify that the Namespace got updated from its new Parent's, and also verify the - // EnginVersion and FeatureVersions. - l4u.Parent = &l3uf - err = datastore.InsertLayer(l4u) - assert.Nil(t, err) - - l4uf, err := datastore.FindLayer(l3u.Name, true, false) - if assert.Nil(t, err) { - assertSameNamespaceName(t, &l3u, &l4uf) - assert.Equal(t, l4u.EngineVersion, l4uf.EngineVersion) - if assert.Len(t, l4uf.Features, 1) { - assert.True(t, cmpFV(l3uf.Features[0], f7), "Updated layer should have %#v but actually have %#v", f7, l4uf.Features[0]) - } } -} - -func assertSameNamespaceName(t *testing.T, layer1 *database.Layer, layer2 *database.Layer) { - assert.Len(t, compareStringLists(extractNamespaceName(layer1), extractNamespaceName(layer2)), 0) -} - -func assertExpectedNamespaceName(t *testing.T, layer *database.Layer, expectedNames []string) { - assert.Len(t, compareStringLists(extractNamespaceName(layer), expectedNames), 0) -} -func extractNamespaceName(layer *database.Layer) []string { - slist := make([]string, 0, len(layer.Namespaces)) - for _, ns := range layer.Namespaces { - slist = append(slist, ns.Name) - } - return slist -} - -func testInsertLayerDelete(t *testing.T, datastore database.Datastore) { - err := datastore.DeleteLayer("TestInsertLayerX") - assert.Equal(t, commonerr.ErrNotFound, err) - - // ensure layer_namespace table is cleaned up once a layer is removed - layer3, err := datastore.FindLayer("TestInsertLayer3", false, false) - layer4a, err := datastore.FindLayer("TestInsertLayer4a", false, false) - layer4b, err := datastore.FindLayer("TestInsertLayer4b", false, false) - - err = datastore.DeleteLayer("TestInsertLayer3") - assert.Nil(t, err) - - _, err = datastore.FindLayer("TestInsertLayer3", false, false) - assert.Equal(t, commonerr.ErrNotFound, err) - assertNotInLayerNamespace(t, layer3.ID, datastore) - _, err = datastore.FindLayer("TestInsertLayer4a", false, false) - assert.Equal(t, commonerr.ErrNotFound, err) - assertNotInLayerNamespace(t, layer4a.ID, datastore) - _, err = datastore.FindLayer("TestInsertLayer4b", true, false) - assert.Equal(t, commonerr.ErrNotFound, err) - assertNotInLayerNamespace(t, layer4b.ID, datastore) -} - -func assertNotInLayerNamespace(t *testing.T, layerID int, datastore database.Datastore) { - pg, ok := datastore.(*pgSQL) - if !assert.True(t, ok) { - return - } - tx, err := pg.Begin() - if !assert.Nil(t, err) { - return + layer, ok2, err := tx.FindLayerWithContent("layer-4") + if assert.Nil(t, err) && assert.True(t, ok2) { + assertLayerWithContentEqual(t, expectedL, layer) } - rows, err := tx.Query(searchLayerNamespace, layerID) - assert.False(t, rows.Next()) } -func cmpFV(a, b database.FeatureVersion) bool { - return a.Feature.Name == b.Feature.Name && - a.Feature.Namespace.Name == b.Feature.Namespace.Name && - a.Version == b.Version +func assertLayerWithContentEqual(t *testing.T, expected database.LayerWithContent, actual database.LayerWithContent) bool { + return assert.Equal(t, expected.Layer, actual.Layer) && + assertFeaturesEqual(t, expected.Features, actual.Features) && + assertProcessorsEqual(t, expected.ProcessedBy, actual.ProcessedBy) && + assertNamespacesEqual(t, expected.Namespaces, actual.Namespaces) } diff --git a/database/pgsql/lock.go b/database/pgsql/lock.go index d3521b752a..66b35dbca5 100644 --- a/database/pgsql/lock.go +++ b/database/pgsql/lock.go @@ -15,6 +15,7 @@ package pgsql import ( + "errors" "time" log "github.com/sirupsen/logrus" @@ -22,86 +23,87 @@ import ( "github.com/coreos/clair/pkg/commonerr" ) +var ( + errLockNotFound = errors.New("lock is not in database") +) + // Lock tries to set a temporary lock in the database. // // Lock does not block, instead, it returns true and its expiration time // is the lock has been successfully acquired or false otherwise -func (pgSQL *pgSQL) Lock(name string, owner string, duration time.Duration, renew bool) (bool, time.Time) { +func (tx *pgSession) Lock(name string, owner string, duration time.Duration, renew bool) (bool, time.Time, error) { if name == "" || owner == "" || duration == 0 { log.Warning("could not create an invalid lock") - return false, time.Time{} + return false, time.Time{}, commonerr.NewBadRequestError("Invalid Lock Parameters") } - defer observeQueryTime("Lock", "all", time.Now()) - // Compute expiration. until := time.Now().Add(duration) if renew { // Renew lock. - r, err := pgSQL.Exec(updateLock, name, owner, until) + r, err := tx.Exec(updateLock, name, owner, until) if err != nil { - handleError("updateLock", err) - return false, until + return false, until, handleError("updateLock", err) } - if n, _ := r.RowsAffected(); n > 0 { + if n, err := r.RowsAffected(); err != nil { + return false, until, handleError("updateLock", err) + } else if n > 0 { // Updated successfully. - return true, until + return true, until, nil + } else { + return false, until, handleError("updateLock", errLockNotFound) } - } else { - // Prune locks. - pgSQL.pruneLocks() + } else if err := tx.pruneLocks(); err != nil { + return false, until, err } // Lock. - _, err := pgSQL.Exec(insertLock, name, owner, until) + _, err := tx.Exec(insertLock, name, owner, until) if err != nil { - if !isErrUniqueViolation(err) { - handleError("insertLock", err) - } - return false, until + return false, until, handleError("insertLock", err) } - return true, until + return true, until, nil } // Unlock unlocks a lock specified by its name if I own it -func (pgSQL *pgSQL) Unlock(name, owner string) { +func (tx *pgSession) Unlock(name, owner string) error { if name == "" || owner == "" { - log.Warning("could not delete an invalid lock") - return + return commonerr.NewBadRequestError("Invalid Lock Parameters") } defer observeQueryTime("Unlock", "all", time.Now()) - pgSQL.Exec(removeLock, name, owner) + _, err := tx.Exec(removeLock, name, owner) + return err } // FindLock returns the owner of a lock specified by its name and its // expiration time. -func (pgSQL *pgSQL) FindLock(name string) (string, time.Time, error) { +func (tx *pgSession) FindLock(name string) (string, time.Time, bool, error) { if name == "" { - log.Warning("could not find an invalid lock") - return "", time.Time{}, commonerr.NewBadRequestError("could not find an invalid lock") + return "", time.Time{}, false, commonerr.NewBadRequestError("could not find an invalid lock") } defer observeQueryTime("FindLock", "all", time.Now()) var owner string var until time.Time - err := pgSQL.QueryRow(searchLock, name).Scan(&owner, &until) + err := tx.QueryRow(searchLock, name).Scan(&owner, &until) if err != nil { - return owner, until, handleError("searchLock", err) + return owner, until, false, handleError("searchLock", err) } - return owner, until, nil + return owner, until, true, nil } // pruneLocks removes every expired locks from the database -func (pgSQL *pgSQL) pruneLocks() { +func (tx *pgSession) pruneLocks() error { defer observeQueryTime("pruneLocks", "all", time.Now()) - if _, err := pgSQL.Exec(removeLockExpired); err != nil { - handleError("removeLockExpired", err) + if _, err := tx.Exec(removeLockExpired); err != nil { + return handleError("removeLockExpired", err) } + return nil } diff --git a/database/pgsql/lock_test.go b/database/pgsql/lock_test.go index cbd2d9998f..82e24d80f9 100644 --- a/database/pgsql/lock_test.go +++ b/database/pgsql/lock_test.go @@ -1,16 +1,16 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// // Copyright 2016 clair authors +// // +// // Licensed under the Apache License, Version 2.0 (the "License"); +// // you may not use this file except in compliance with the License. +// // You may obtain a copy of the License at +// // +// // http://www.apache.org/licenses/LICENSE-2.0 +// // +// // Unless required by applicable law or agreed to in writing, software +// // distributed under the License is distributed on an "AS IS" BASIS, +// // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// // See the License for the specific language governing permissions and +// // limitations under the License. package pgsql @@ -21,49 +21,77 @@ import ( "github.com/stretchr/testify/assert" ) -func TestLock(t *testing.T) { - datastore, err := openDatabaseForTest("InsertNamespace", false) - if err != nil { - t.Error(err) - return +func restartSession(t *testing.T, datastore *pgSQL, tx *pgSession, commit bool) *pgSession { + var err error + if !commit { + err = tx.Rollback() + } else { + err = tx.Commit() + } + + if assert.Nil(t, err) { + session, err := datastore.Begin() + if assert.Nil(t, err) { + return session.(*pgSession) + } } + t.FailNow() + return nil +} + +func TestLock(t *testing.T) { + datastore, tx := openSessionForTest(t, "Lock", true) defer datastore.Close() var l bool var et time.Time // Create a first lock. - l, _ = datastore.Lock("test1", "owner1", time.Minute, false) + l, _, err := tx.Lock("test1", "owner1", time.Minute, false) + assert.Nil(t, err) assert.True(t, l) + tx = restartSession(t, datastore, tx, true) // Try to lock the same lock with another owner. - l, _ = datastore.Lock("test1", "owner2", time.Minute, true) - assert.False(t, l) + l, _, err = tx.Lock("test1", "owner2", time.Minute, true) + assert.NotNil(t, err) + tx = restartSession(t, datastore, tx, false) - l, _ = datastore.Lock("test1", "owner2", time.Minute, false) - assert.False(t, l) + l, _, err = tx.Lock("test1", "owner2", time.Minute, false) + assert.NotNil(t, err) + tx = restartSession(t, datastore, tx, false) // Renew the lock. - l, _ = datastore.Lock("test1", "owner1", 2*time.Minute, true) + l, _, err = tx.Lock("test1", "owner1", 2*time.Minute, true) + assert.Nil(t, err) assert.True(t, l) // Unlock and then relock by someone else. - datastore.Unlock("test1", "owner1") + err = tx.Unlock("test1", "owner1") + assert.Nil(t, err) - l, et = datastore.Lock("test1", "owner2", time.Minute, false) + l, et, err = tx.Lock("test1", "owner2", time.Minute, false) + assert.Nil(t, err) assert.True(t, l) // LockInfo - o, et2, err := datastore.FindLock("test1") + o, et2, ok, err := tx.FindLock("test1") + assert.True(t, ok) assert.Nil(t, err) assert.Equal(t, "owner2", o) assert.Equal(t, et.Second(), et2.Second()) // Create a second lock which is actually already expired ... - l, _ = datastore.Lock("test2", "owner1", -time.Minute, false) + l, _, err = tx.Lock("test2", "owner1", -time.Minute, false) + assert.Nil(t, err) assert.True(t, l) // Take over the lock - l, _ = datastore.Lock("test2", "owner2", time.Minute, false) + l, _, err = tx.Lock("test2", "owner2", time.Minute, false) + assert.Nil(t, err) assert.True(t, l) + + if !assert.Nil(t, tx.Rollback()) { + t.FailNow() + } } diff --git a/database/pgsql/migrations/00001_change_migrator.go b/database/pgsql/migrations/00001_change_migrator.go deleted file mode 100644 index 8fef9ea0cf..0000000000 --- a/database/pgsql/migrations/00001_change_migrator.go +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import ( - "database/sql" - - "github.com/remind101/migrate" -) - -func init() { - // This migration removes the data maintained by the previous migration tool - // (liamstask/goose), and if it was present, mark the 00002_initial_schema - // migration as done. - RegisterMigration(migrate.Migration{ - ID: 1, - Up: func(tx *sql.Tx) error { - // Verify that goose was in use before, otherwise skip this migration. - var e bool - err := tx.QueryRow("SELECT true FROM pg_class WHERE relname = $1", "goose_db_version").Scan(&e) - if err == sql.ErrNoRows { - return nil - } - if err != nil { - return err - } - - // Delete goose's data. - _, err = tx.Exec("DROP TABLE goose_db_version CASCADE") - if err != nil { - return err - } - - // Mark the '00002_initial_schema' as done. - _, err = tx.Exec("INSERT INTO schema_migrations (version) VALUES (2)") - - return err - }, - Down: migrate.Queries([]string{}), - }) -} diff --git a/database/pgsql/migrations/00001_initial_schema.go b/database/pgsql/migrations/00001_initial_schema.go new file mode 100644 index 0000000000..6f01c3d0b0 --- /dev/null +++ b/database/pgsql/migrations/00001_initial_schema.go @@ -0,0 +1,192 @@ +// Copyright 2016 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package migrations + +import "github.com/remind101/migrate" + +func init() { + RegisterMigration(migrate.Migration{ + ID: 1, + Up: migrate.Queries([]string{ + // namespaces + `CREATE TABLE IF NOT EXISTS namespace ( + id SERIAL PRIMARY KEY, + name TEXT NULL, + version_format TEXT);`, + `CREATE UNIQUE INDEX ON namespace(name, version_format);`, + + // features + `CREATE TABLE IF NOT EXISTS feature ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + version TEXT NOT NULL, + version_format TEXT NOT NULL);`, + `CREATE UNIQUE INDEX ON feature(name, version, version_format);`, + + `CREATE TABLE IF NOT EXISTS namespaced_feature ( + id SERIAL PRIMARY KEY, + namespace_id INT REFERENCES namespace, + feature_id INT REFERENCES feature);`, + `CREATE UNIQUE INDEX ON namespaced_feature(namespace_id, feature_id);`, + + // layers + `CREATE TABLE IF NOT EXISTS layer( + id SERIAL PRIMARY KEY, + hash TEXT NOT NULL);`, + `CREATE UNIQUE INDEX ON layer(hash);`, + + `CREATE TABLE IF NOT EXISTS layer_feature ( + id SERIAL PRIMARY KEY, + layer_id INT REFERENCES layer ON DELETE CASCADE, + feature_id INT REFERENCES feature ON DELETE CASCADE);`, + `CREATE UNIQUE INDEX ON layer_feature(layer_id, feature_id);`, + + `CREATE TABLE IF NOT EXISTS layer_lister ( + id SERIAL PRIMARY KEY, + layer_id INT REFERENCES layer ON DELETE CASCADE, + lister TEXT NOT NULL);`, + `CREATE UNIQUE INDEX ON layer_lister(layer_id, lister);`, + `CREATE INDEX ON layer_lister(layer_id);`, + + `CREATE TABLE IF NOT EXISTS layer_detector ( + id SERIAL PRIMARY KEY, + layer_id INT REFERENCES layer ON DELETE CASCADE, + detector TEXT);`, + `CREATE UNIQUE INDEX ON layer_detector(layer_id, detector);`, + `CREATE INDEX ON layer_detector(layer_id);`, + + `CREATE TABLE IF NOT EXISTS layer_namespace ( + id SERIAL PRIMARY KEY, + layer_id INT REFERENCES layer ON DELETE CASCADE, + namespace_id INT REFERENCES namespace ON DELETE CASCADE);`, + `CREATE UNIQUE INDEX ON layer_namespace(layer_id, namespace_id);`, + `CREATE INDEX ON layer_namespace(layer_id);`, + + // ancestry + `CREATE TABLE IF NOT EXISTS ancestry ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL);`, + `CREATE UNIQUE INDEX ON ancestry(name);`, + + `CREATE TABLE IF NOT EXISTS ancestry_layer ( + id SERIAL PRIMARY KEY, + ancestry_id INT REFERENCES ancestry ON DELETE CASCADE, + ancestry_index INT NOT NULL, + layer_id INT REFERENCES layer);`, + `CREATE UNIQUE INDEX ON ancestry_layer(ancestry_id, ancestry_index);`, + `CREATE INDEX ON ancestry_layer(ancestry_id);`, + + `CREATE TABLE IF NOT EXISTS ancestry_feature ( + id SERIAL PRIMARY KEY, + ancestry_id INT REFERENCES ancestry ON DELETE CASCADE, + namespaced_feature_id INT REFERENCES namespaced_feature ON DELETE CASCADE);`, + `CREATE UNIQUE INDEX ON ancestry_feature(ancestry_id, namespaced_feature_id);`, + `CREATE INDEX ON ancestry_feature(ancestry_id);`, + `CREATE INDEX ON ancestry_feature(namespaced_feature_id);`, + + `CREATE TABLE IF NOT EXISTS ancestry_lister ( + id SERIAL PRIMARY KEY, + ancestry_id INT REFERENCES ancestry ON DELETE CASCADE, + lister TEXT);`, + `CREATE UNIQUE INDEX ON ancestry_lister(ancestry_id, lister);`, + `CREATE INDEX ON ancestry_lister(ancestry_id);`, + + `CREATE TABLE IF NOT EXISTS ancestry_detector ( + id SERIAL PRIMARY KEY, + ancestry_id INT REFERENCES ancestry ON DELETE CASCADE, + detector TEXT);`, + `CREATE UNIQUE INDEX ON ancestry_detector(ancestry_id, detector);`, + `CREATE INDEX ON ancestry_detector(ancestry_id);`, + + `CREATE TYPE severity AS ENUM ('Unknown', 'Negligible', 'Low', 'Medium', 'High', 'Critical', 'Defcon1');`, + + // vulnerability + `CREATE TABLE IF NOT EXISTS vulnerability ( + id SERIAL PRIMARY KEY, + namespace_id INT NOT NULL REFERENCES Namespace, + name TEXT NOT NULL, + description TEXT NULL, + link TEXT NULL, + severity severity NOT NULL, + metadata TEXT NULL, + created_at TIMESTAMP WITH TIME ZONE, + deleted_at TIMESTAMP WITH TIME ZONE NULL);`, + + `CREATE INDEX ON vulnerability(namespace_id, name);`, + + `CREATE TABLE IF NOT EXISTS vulnerability_affected_feature ( + id SERIAL PRIMARY KEY, + vulnerability_id INT NOT NULL REFERENCES vulnerability ON DELETE CASCADE, + feature_name TEXT NOT NULL, + affected_version TEXT, + fixedin TEXT);`, + `CREATE UNIQUE INDEX ON vulnerability_affected_feature(vulnerability_id, feature_name);`, + + `CREATE TABLE IF NOT EXISTS vulnerability_affected_namespaced_feature( + id SERIAL PRIMARY KEY, + vulnerability_id INT NOT NULL REFERENCES vulnerability ON DELETE CASCADE, + namespaced_feature_id INT NOT NULL REFERENCES namespaced_feature ON DELETE CASCADE, + added_by INT NOT NULL REFERENCES vulnerability_affected_feature ON DELETE CASCADE);`, + `CREATE UNIQUE INDEX ON vulnerability_affected_namespaced_feature(vulnerability_id, namespaced_feature_id);`, + + `CREATE TABLE IF NOT EXISTS KeyValue ( + id SERIAL PRIMARY KEY, + key TEXT NOT NULL UNIQUE, + value TEXT);`, + + `CREATE TABLE IF NOT EXISTS Lock ( + id SERIAL PRIMARY KEY, + name VARCHAR(64) NOT NULL UNIQUE, + owner VARCHAR(64) NOT NULL, + until TIMESTAMP WITH TIME ZONE);`, + `CREATE INDEX ON Lock (owner);`, + + // Notification + `CREATE TABLE IF NOT EXISTS Vulnerability_Notification ( + id SERIAL PRIMARY KEY, + name VARCHAR(64) NOT NULL UNIQUE, + created_at TIMESTAMP WITH TIME ZONE, + notified_at TIMESTAMP WITH TIME ZONE NULL, + deleted_at TIMESTAMP WITH TIME ZONE NULL, + old_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE, + new_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE);`, + `CREATE INDEX ON Vulnerability_Notification (notified_at);`, + }), + Down: migrate.Queries([]string{ + `DROP TABLE IF EXISTS + ancestry, + ancestry_layer, + ancestry_feature, + ancestry_detector, + ancestry_lister, + feature, + namespaced_feature, + keyvalue, + layer, + layer_detector, + layer_feature, + layer_lister, + layer_namespace, + lock, + namespace, + vulnerability, + vulnerability_affected_feature, + vulnerability_affected_namespaced_feature, + vulnerability_notification + CASCADE;`, + `DROP TYPE IF EXISTS severity;`, + }), + }) +} diff --git a/database/pgsql/migrations/00002_initial_schema.go b/database/pgsql/migrations/00002_initial_schema.go deleted file mode 100644 index f7cc17e68b..0000000000 --- a/database/pgsql/migrations/00002_initial_schema.go +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import "github.com/remind101/migrate" - -func init() { - // This migration creates the initial Clair's schema. - RegisterMigration(migrate.Migration{ - ID: 2, - Up: migrate.Queries([]string{ - `CREATE TABLE IF NOT EXISTS Namespace ( - id SERIAL PRIMARY KEY, - name VARCHAR(128) NULL);`, - - `CREATE TABLE IF NOT EXISTS Layer ( - id SERIAL PRIMARY KEY, - name VARCHAR(128) NOT NULL UNIQUE, - engineversion SMALLINT NOT NULL, - parent_id INT NULL REFERENCES Layer ON DELETE CASCADE, - namespace_id INT NULL REFERENCES Namespace, - created_at TIMESTAMP WITH TIME ZONE);`, - `CREATE INDEX ON Layer (parent_id);`, - `CREATE INDEX ON Layer (namespace_id);`, - - `CREATE TABLE IF NOT EXISTS Feature ( - id SERIAL PRIMARY KEY, - namespace_id INT NOT NULL REFERENCES Namespace, - name VARCHAR(128) NOT NULL, - UNIQUE (namespace_id, name));`, - - `CREATE TABLE IF NOT EXISTS FeatureVersion ( - id SERIAL PRIMARY KEY, - feature_id INT NOT NULL REFERENCES Feature, - version VARCHAR(128) NOT NULL);`, - `CREATE INDEX ON FeatureVersion (feature_id);`, - - `CREATE TYPE modification AS ENUM ('add', 'del');`, - `CREATE TABLE IF NOT EXISTS Layer_diff_FeatureVersion ( - id SERIAL PRIMARY KEY, - layer_id INT NOT NULL REFERENCES Layer ON DELETE CASCADE, - featureversion_id INT NOT NULL REFERENCES FeatureVersion, - modification modification NOT NULL, - UNIQUE (layer_id, featureversion_id));`, - `CREATE INDEX ON Layer_diff_FeatureVersion (layer_id);`, - `CREATE INDEX ON Layer_diff_FeatureVersion (featureversion_id);`, - `CREATE INDEX ON Layer_diff_FeatureVersion (featureversion_id, layer_id);`, - - `CREATE TYPE severity AS ENUM ('Unknown', 'Negligible', 'Low', 'Medium', 'High', 'Critical', 'Defcon1');`, - `CREATE TABLE IF NOT EXISTS Vulnerability ( - id SERIAL PRIMARY KEY, - namespace_id INT NOT NULL REFERENCES Namespace, - name VARCHAR(128) NOT NULL, - description TEXT NULL, - link VARCHAR(128) NULL, - severity severity NOT NULL, - metadata TEXT NULL, - created_at TIMESTAMP WITH TIME ZONE, - deleted_at TIMESTAMP WITH TIME ZONE NULL);`, - - `CREATE TABLE IF NOT EXISTS Vulnerability_FixedIn_Feature ( - id SERIAL PRIMARY KEY, - vulnerability_id INT NOT NULL REFERENCES Vulnerability ON DELETE CASCADE, - feature_id INT NOT NULL REFERENCES Feature, - version VARCHAR(128) NOT NULL, - UNIQUE (vulnerability_id, feature_id));`, - `CREATE INDEX ON Vulnerability_FixedIn_Feature (feature_id, vulnerability_id);`, - - `CREATE TABLE IF NOT EXISTS Vulnerability_Affects_FeatureVersion ( - id SERIAL PRIMARY KEY, - vulnerability_id INT NOT NULL REFERENCES Vulnerability ON DELETE CASCADE, - featureversion_id INT NOT NULL REFERENCES FeatureVersion, - fixedin_id INT NOT NULL REFERENCES Vulnerability_FixedIn_Feature ON DELETE CASCADE, - UNIQUE (vulnerability_id, featureversion_id));`, - `CREATE INDEX ON Vulnerability_Affects_FeatureVersion (fixedin_id);`, - `CREATE INDEX ON Vulnerability_Affects_FeatureVersion (featureversion_id, vulnerability_id);`, - - `CREATE TABLE IF NOT EXISTS KeyValue ( - id SERIAL PRIMARY KEY, - key VARCHAR(128) NOT NULL UNIQUE, - value TEXT);`, - - `CREATE TABLE IF NOT EXISTS Lock ( - id SERIAL PRIMARY KEY, - name VARCHAR(64) NOT NULL UNIQUE, - owner VARCHAR(64) NOT NULL, - until TIMESTAMP WITH TIME ZONE);`, - `CREATE INDEX ON Lock (owner);`, - - `CREATE TABLE IF NOT EXISTS Vulnerability_Notification ( - id SERIAL PRIMARY KEY, - name VARCHAR(64) NOT NULL UNIQUE, - created_at TIMESTAMP WITH TIME ZONE, - notified_at TIMESTAMP WITH TIME ZONE NULL, - deleted_at TIMESTAMP WITH TIME ZONE NULL, - old_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE, - new_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE);`, - `CREATE INDEX ON Vulnerability_Notification (notified_at);`, - }), - Down: migrate.Queries([]string{ - `DROP TABLE IF EXISTS - Namespace, - Layer, - Feature, - FeatureVersion, - Layer_diff_FeatureVersion, - Vulnerability, - Vulnerability_FixedIn_Feature, - Vulnerability_Affects_FeatureVersion, - Vulnerability_Notification, - KeyValue, - Lock - CASCADE;`, - }), - }) -} diff --git a/database/pgsql/migrations/00003_add_indexes.go b/database/pgsql/migrations/00003_add_indexes.go deleted file mode 100644 index 78ccaba2dd..0000000000 --- a/database/pgsql/migrations/00003_add_indexes.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import "github.com/remind101/migrate" - -func init() { - RegisterMigration(migrate.Migration{ - ID: 3, - Up: migrate.Queries([]string{ - `CREATE UNIQUE INDEX namespace_name_key ON Namespace (name);`, - `CREATE INDEX vulnerability_name_idx ON Vulnerability (name);`, - `CREATE INDEX vulnerability_namespace_id_name_idx ON Vulnerability (namespace_id, name);`, - `CREATE UNIQUE INDEX featureversion_feature_id_version_key ON FeatureVersion (feature_id, version);`, - }), - Down: migrate.Queries([]string{ - `DROP INDEX namespace_name_key;`, - `DROP INDEX vulnerability_name_idx;`, - `DROP INDEX vulnerability_namespace_id_name_idx;`, - `DROP INDEX featureversion_feature_id_version_key;`, - }), - }) -} diff --git a/database/pgsql/migrations/00004_add_index_notification_deleted_at.go b/database/pgsql/migrations/00004_add_index_notification_deleted_at.go deleted file mode 100644 index 12f38ab281..0000000000 --- a/database/pgsql/migrations/00004_add_index_notification_deleted_at.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import "github.com/remind101/migrate" - -func init() { - RegisterMigration(migrate.Migration{ - ID: 4, - Up: migrate.Queries([]string{ - `CREATE INDEX vulnerability_notification_deleted_at_idx ON Vulnerability_Notification (deleted_at);`, - }), - Down: migrate.Queries([]string{ - `DROP INDEX vulnerability_notification_deleted_at_idx;`, - }), - }) -} diff --git a/database/pgsql/migrations/00005_ldfv_index.go b/database/pgsql/migrations/00005_ldfv_index.go deleted file mode 100644 index ec8e713713..0000000000 --- a/database/pgsql/migrations/00005_ldfv_index.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import "github.com/remind101/migrate" - -func init() { - RegisterMigration(migrate.Migration{ - ID: 5, - Up: migrate.Queries([]string{ - `CREATE INDEX layer_diff_featureversion_layer_id_modification_idx ON Layer_diff_FeatureVersion (layer_id, modification);`, - }), - Down: migrate.Queries([]string{ - `DROP INDEX layer_diff_featureversion_layer_id_modification_idx;`, - }), - }) -} diff --git a/database/pgsql/migrations/00006_add_version_format.go b/database/pgsql/migrations/00006_add_version_format.go deleted file mode 100644 index 3a08f6f059..0000000000 --- a/database/pgsql/migrations/00006_add_version_format.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import "github.com/remind101/migrate" - -func init() { - RegisterMigration(migrate.Migration{ - ID: 6, - Up: migrate.Queries([]string{ - `ALTER TABLE Namespace ADD COLUMN version_format varchar(128);`, - `UPDATE Namespace SET version_format = 'rpm' WHERE name LIKE 'rhel%' OR name LIKE 'centos%' OR name LIKE 'fedora%' OR name LIKE 'amzn%' OR name LIKE 'scientific%' OR name LIKE 'ol%' OR name LIKE 'oracle%';`, - `UPDATE Namespace SET version_format = 'dpkg' WHERE version_format is NULL;`, - }), - Down: migrate.Queries([]string{ - `ALTER TABLE Namespace DROP COLUMN version_format;`, - }), - }) -} diff --git a/database/pgsql/migrations/00007_expand_column_width.go b/database/pgsql/migrations/00007_expand_column_width.go deleted file mode 100644 index 8bfdaaab7a..0000000000 --- a/database/pgsql/migrations/00007_expand_column_width.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2017 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import "github.com/remind101/migrate" - -func init() { - RegisterMigration(migrate.Migration{ - ID: 7, - Up: migrate.Queries([]string{ - `ALTER TABLE Namespace ALTER COLUMN version_format SET DATA TYPE varchar(256);`, - `ALTER TABLE Layer ALTER COLUMN name SET DATA TYPE varchar(256);`, - }), - Down: migrate.Queries([]string{ - `ALTER TABLE Namespace ALTER COLUMN version_format SET DATA TYPE varchar(128);`, - `ALTER TABLE Layer ALTER COLUMN name SET DATA TYPE varchar(128);`, - }), - }) -} diff --git a/database/pgsql/migrations/00008_add_multiplens.go b/database/pgsql/migrations/00008_add_multiplens.go deleted file mode 100644 index ecfb476222..0000000000 --- a/database/pgsql/migrations/00008_add_multiplens.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import "github.com/remind101/migrate" - -func init() { - RegisterMigration(migrate.Migration{ - ID: 8, - Up: migrate.Queries([]string{ - // set on deletion, remove the corresponding rows in database - `CREATE TABLE IF NOT EXISTS Layer_Namespace( - id SERIAL PRIMARY KEY, - layer_id INT REFERENCES Layer(id) ON DELETE CASCADE, - namespace_id INT REFERENCES Namespace(id) ON DELETE CASCADE, - unique(layer_id, namespace_id) - );`, - `CREATE INDEX ON Layer_Namespace (namespace_id);`, - `CREATE INDEX ON Layer_Namespace (layer_id);`, - // move the namespace_id to the table - `INSERT INTO Layer_Namespace (layer_id, namespace_id) SELECT id, namespace_id FROM Layer;`, - // alter the Layer table to remove the column - `ALTER TABLE IF EXISTS Layer DROP namespace_id;`, - }), - Down: migrate.Queries([]string{ - `ALTER TABLE IF EXISTS Layer ADD namespace_id INT NULL REFERENCES Namespace;`, - `CREATE INDEX ON Layer (namespace_id);`, - `UPDATE IF EXISTS Layer SET namespace_id = (SELECT lns.namespace_id FROM Layer_Namespace lns WHERE Layer.id = lns.layer_id LIMIT 1);`, - `DROP TABLE IF EXISTS Layer_Namespace;`, - }), - }) -} diff --git a/database/pgsql/namespace.go b/database/pgsql/namespace.go index 8d4b304bae..b0f1c2bbb7 100644 --- a/database/pgsql/namespace.go +++ b/database/pgsql/namespace.go @@ -15,61 +15,40 @@ package pgsql import ( - "time" + "database/sql" + "errors" "github.com/coreos/clair/database" "github.com/coreos/clair/pkg/commonerr" ) -func (pgSQL *pgSQL) insertNamespace(namespace database.Namespace) (int, error) { - if namespace.Name == "" { - return 0, commonerr.NewBadRequestError("could not find/insert invalid Namespace") - } +var ( + errNamespaceNotFound = errors.New("Requested Namespace is not in database") +) - if pgSQL.cache != nil { - promCacheQueriesTotal.WithLabelValues("namespace").Inc() - if id, found := pgSQL.cache.Get("namespace:" + namespace.Name); found { - promCacheHitsTotal.WithLabelValues("namespace").Inc() - return id.(int), nil +// PersistNamespaces soi namespaces into database. +func (tx *pgSession) PersistNamespaces(namespaces []database.Namespace) error { + for _, ns := range namespaces { + if ns.Name == "" || ns.VersionFormat == "" { + return commonerr.NewBadRequestError("Empty namespace name or version format is not allowed") + } + _, err := tx.Exec(persistNamespace, ns.Name, ns.VersionFormat) + if err != nil { + return handleError("persistNamespace", err) } } - - // We do `defer observeQueryTime` here because we don't want to observe cached namespaces. - defer observeQueryTime("insertNamespace", "all", time.Now()) - - var id int - err := pgSQL.QueryRow(soiNamespace, namespace.Name, namespace.VersionFormat).Scan(&id) - if err != nil { - return 0, handleError("soiNamespace", err) - } - - if pgSQL.cache != nil { - pgSQL.cache.Add("namespace:"+namespace.Name, id) - } - - return id, nil + return nil } -func (pgSQL *pgSQL) ListNamespaces() (namespaces []database.Namespace, err error) { - rows, err := pgSQL.Query(listNamespace) - if err != nil { - return namespaces, handleError("listNamespace", err) - } - defer rows.Close() - - for rows.Next() { - var ns database.Namespace - - err = rows.Scan(&ns.ID, &ns.Name, &ns.VersionFormat) - if err != nil { - return namespaces, handleError("listNamespace.Scan()", err) +func (tx *pgSession) findNamespaceIDs(namespaces []database.Namespace) ([]sql.NullInt64, error) { + ids := []sql.NullInt64{} + for _, ns := range namespaces { + var id sql.NullInt64 + err := tx.QueryRow(searchNamespaceID, ns.Name, ns.VersionFormat).Scan(&id) + if err != nil && err != sql.ErrNoRows { + return nil, handleError("searchNamespace", err) } - - namespaces = append(namespaces, ns) + ids = append(ids, id) } - if err = rows.Err(); err != nil { - return namespaces, handleError("listNamespace.Rows()", err) - } - - return namespaces, err + return ids, nil } diff --git a/database/pgsql/namespace_test.go b/database/pgsql/namespace_test.go index 0990b6f4c2..4c1d00236a 100644 --- a/database/pgsql/namespace_test.go +++ b/database/pgsql/namespace_test.go @@ -1,74 +1,82 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// // Copyright 2016 clair authors +// // +// // Licensed under the Apache License, Version 2.0 (the "License"); +// // you may not use this file except in compliance with the License. +// // You may obtain a copy of the License at +// // +// // http://www.apache.org/licenses/LICENSE-2.0 +// // +// // Unless required by applicable law or agreed to in writing, software +// // distributed under the License is distributed on an "AS IS" BASIS, +// // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// // See the License for the specific language governing permissions and +// // limitations under the License. package pgsql import ( - "fmt" "testing" - "github.com/stretchr/testify/assert" - "github.com/coreos/clair/database" - "github.com/coreos/clair/ext/versionfmt/dpkg" + "github.com/stretchr/testify/assert" ) -func TestInsertNamespace(t *testing.T) { - datastore, err := openDatabaseForTest("InsertNamespace", false) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() +func TestPersistNamespaces(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistNamespaces", false) + defer closeTest(t, datastore, tx) + + ns1 := database.Namespace{} + ns2 := database.Namespace{Name: "t", VersionFormat: "b"} - // Invalid Namespace. - id0, err := datastore.insertNamespace(database.Namespace{}) - assert.NotNil(t, err) - assert.Zero(t, id0) + // Empty Case + assert.Nil(t, tx.PersistNamespaces([]database.Namespace{})) + // Invalid Case + assert.NotNil(t, tx.PersistNamespaces([]database.Namespace{ns1})) + // Duplicated Case + assert.Nil(t, tx.PersistNamespaces([]database.Namespace{ns2, ns2})) + // Existing Case + assert.Nil(t, tx.PersistNamespaces([]database.Namespace{ns2})) - // Insert Namespace and ensure we can find it. - id1, err := datastore.insertNamespace(database.Namespace{ - Name: "TestInsertNamespace1", - VersionFormat: dpkg.ParserName, - }) - assert.Nil(t, err) - id2, err := datastore.insertNamespace(database.Namespace{ - Name: "TestInsertNamespace1", - VersionFormat: dpkg.ParserName, - }) - assert.Nil(t, err) - assert.Equal(t, id1, id2) + nsList := listNamespaces(t, tx) + assert.Len(t, nsList, 1) + assert.Equal(t, ns2, nsList[0]) } -func TestListNamespace(t *testing.T) { - datastore, err := openDatabaseForTest("ListNamespaces", true) +func assertNamespacesEqual(t *testing.T, expected []database.Namespace, actual []database.Namespace) bool { + if assert.Len(t, actual, len(expected)) { + has := map[database.Namespace]bool{} + for _, i := range expected { + has[i] = false + } + for _, i := range actual { + has[i] = true + } + for key, v := range has { + if !assert.True(t, v, key.Name+"is expected") { + return false + } + } + return true + } + return false +} + +func listNamespaces(t *testing.T, tx *pgSession) []database.Namespace { + rows, err := tx.Query("SELECT name, version_format FROM namespace") if err != nil { - t.Error(err) - return + t.FailNow() } - defer datastore.Close() + defer rows.Close() - namespaces, err := datastore.ListNamespaces() - assert.Nil(t, err) - if assert.Len(t, namespaces, 2) { - for _, namespace := range namespaces { - switch namespace.Name { - case "debian:7", "debian:8": - continue - default: - assert.Error(t, fmt.Errorf("ListNamespaces should not have returned '%s'", namespace.Name)) - } + namespaces := []database.Namespace{} + for rows.Next() { + var ns database.Namespace + err := rows.Scan(&ns.Name, &ns.VersionFormat) + if err != nil { + t.FailNow() } + namespaces = append(namespaces, ns) } + + return namespaces } diff --git a/database/pgsql/notification.go b/database/pgsql/notification.go index f8c6960d28..8443d3f3e3 100644 --- a/database/pgsql/notification.go +++ b/database/pgsql/notification.go @@ -16,235 +16,228 @@ package pgsql import ( "database/sql" + "errors" "time" "github.com/guregu/null/zero" - "github.com/pborman/uuid" - log "github.com/sirupsen/logrus" "github.com/coreos/clair/database" "github.com/coreos/clair/pkg/commonerr" ) -// do it in tx so we won't insert/update a vuln without notification and vice-versa. -// name and created doesn't matter. -func createNotification(tx *sql.Tx, oldVulnerabilityID, newVulnerabilityID int) error { - defer observeQueryTime("createNotification", "all", time.Now()) +var ( + errNotificationNotFound = errors.New("requested notification is not in database") +) - // Insert Notification. - oldVulnerabilityNullableID := sql.NullInt64{Int64: int64(oldVulnerabilityID), Valid: oldVulnerabilityID != 0} - newVulnerabilityNullableID := sql.NullInt64{Int64: int64(newVulnerabilityID), Valid: newVulnerabilityID != 0} - _, err := tx.Exec(insertNotification, uuid.New(), oldVulnerabilityNullableID, newVulnerabilityNullableID) - if err != nil { - tx.Rollback() - return handleError("insertNotification", err) +func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.VulnerabilityNotification) error { + // ensure uniqueness of notifications + notiNameMap := map[string]struct{}{} + for _, noti := range notifications { + if noti.Name == "" { + return commonerr.NewBadRequestError("Empty notification name is not allowed") + } + if _, ok := notiNameMap[noti.Name]; ok { + return commonerr.NewBadRequestError("Duplicated notifications") + } + notiNameMap[noti.Name] = struct{}{} + } + + // retrieve all vulnerability IDs + newIDs := map[database.VulnerabilityID]sql.NullInt64{} + toQueryNew := make([]database.VulnerabilityID, 0, len(newIDs)) + oldIDs := map[database.VulnerabilityID]sql.NullInt64{} + toQueryOld := make([]database.VulnerabilityID, 0, len(oldIDs)) + for _, noti := range notifications { + if noti.New != nil { + key := database.VulnerabilityID{Name: noti.New.Name, Namespace: noti.New.Namespace.Name} + if _, ok := newIDs[key]; !ok { + newIDs[key] = sql.NullInt64{} + toQueryNew = append(toQueryNew, key) + } + } + if noti.Old != nil { + key := database.VulnerabilityID{Name: noti.Old.Name, Namespace: noti.Old.Namespace.Name} + if _, ok := oldIDs[key]; !ok { + oldIDs[key] = sql.NullInt64{} + toQueryOld = append(toQueryOld, key) + } + } } - return nil -} - -// Get one available notification name (!locked && !deleted && (!notified || notified_but_timed-out)). -// Does not fill new/old vuln. -func (pgSQL *pgSQL) GetAvailableNotification(renotifyInterval time.Duration) (database.VulnerabilityNotification, error) { - defer observeQueryTime("GetAvailableNotification", "all", time.Now()) - - before := time.Now().Add(-renotifyInterval) - row := pgSQL.QueryRow(searchNotificationAvailable, before) - notification, err := pgSQL.scanNotification(row, false) - - return notification, handleError("searchNotificationAvailable", err) -} - -func (pgSQL *pgSQL) GetNotification(name string, limit int, page database.VulnerabilityNotificationPageNumber) (database.VulnerabilityNotification, database.VulnerabilityNotificationPageNumber, error) { - defer observeQueryTime("GetNotification", "all", time.Now()) - - // Get Notification. - notification, err := pgSQL.scanNotification(pgSQL.QueryRow(searchNotification, name), true) + ids, err := tx.findNotDeletedVulnerabilityIDs(toQueryNew) if err != nil { - return notification, page, handleError("searchNotification", err) + return err } - // Load vulnerabilities' LayersIntroducingVulnerability. - page.OldVulnerability, err = pgSQL.loadLayerIntroducingVulnerability( - notification.OldVulnerability, - limit, - page.OldVulnerability, - ) + for i, id := range ids { + // ensure every vulnerability is in database + if !id.Valid { + return errVulnerabilityNotFound + } - if err != nil { - return notification, page, err + newIDs[toQueryNew[i]] = id } - page.NewVulnerability, err = pgSQL.loadLayerIntroducingVulnerability( - notification.NewVulnerability, - limit, - page.NewVulnerability, - ) - + ids, err = tx.findLatestDeletedVulnerabilityIDs(toQueryOld) if err != nil { - return notification, page, err + return err } - return notification, page, nil -} - -func (pgSQL *pgSQL) scanNotification(row *sql.Row, hasVulns bool) (database.VulnerabilityNotification, error) { - var notification database.VulnerabilityNotification - var created zero.Time - var notified zero.Time - var deleted zero.Time - var oldVulnerabilityNullableID sql.NullInt64 - var newVulnerabilityNullableID sql.NullInt64 - - // Scan notification. - if hasVulns { - err := row.Scan( - ¬ification.ID, - ¬ification.Name, - &created, - ¬ified, - &deleted, - &oldVulnerabilityNullableID, - &newVulnerabilityNullableID, - ) - - if err != nil { - return notification, err + for i, id := range ids { + if !id.Valid { + return errVulnerabilityNotFound } - } else { - err := row.Scan(¬ification.ID, ¬ification.Name, &created, ¬ified, &deleted) - if err != nil { - return notification, err - } + oldIDs[toQueryOld[i]] = id } - notification.Created = created.Time - notification.Notified = notified.Time - notification.Deleted = deleted.Time - - if hasVulns { - if oldVulnerabilityNullableID.Valid { - vulnerability, err := pgSQL.findVulnerabilityByIDWithDeleted(int(oldVulnerabilityNullableID.Int64)) - if err != nil { - return notification, err - } + for _, noti := range notifications { + var ( + newVulnID sql.NullInt64 + oldVulnID sql.NullInt64 + ) + if noti.New != nil { + newVulnID = newIDs[database.VulnerabilityID{Name: noti.New.Name, Namespace: noti.New.Namespace.Name}] + } - notification.OldVulnerability = &vulnerability + if noti.Old != nil { + oldVulnID = oldIDs[database.VulnerabilityID{Name: noti.Old.Name, Namespace: noti.Old.Namespace.Name}] } - if newVulnerabilityNullableID.Valid { - vulnerability, err := pgSQL.findVulnerabilityByIDWithDeleted(int(newVulnerabilityNullableID.Int64)) - if err != nil { - return notification, err - } + _, err = tx.Exec(insertNotification, noti.Name, noti.Created, + oldVulnID, + newVulnID, + ) - notification.NewVulnerability = &vulnerability + if err != nil { + return handleError("insertNotification", err) } } - return notification, nil + return nil } -// Fills Vulnerability.LayersIntroducingVulnerability. -// limit -1: won't do anything -// limit 0: will just get the startID of the second page -func (pgSQL *pgSQL) loadLayerIntroducingVulnerability(vulnerability *database.Vulnerability, limit, startID int) (int, error) { - tf := time.Now() +func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.Notification, bool, error) { + var ( + notification database.Notification + created zero.Time + notified zero.Time + deleted zero.Time + ) - if vulnerability == nil { - return -1, nil + err := tx.QueryRow(searchNotificationAvailable, notifiedBefore).Scan(¬ification.Name, &created, ¬ified, &deleted) + if err != nil { + if err == sql.ErrNoRows { + return notification, false, nil + } + return notification, false, handleError("searchNotificationAvailable", err) } - // A startID equals to -1 means that we reached the end already. - if startID == -1 || limit == -1 { - return -1, nil - } + notification.Created = created.Time + notification.Notified = notified.Time + notification.Deleted = deleted.Time - // Create a transaction to disable hash joins as our experience shows that - // PostgreSQL plans in certain cases a sequential scan and a hash on - // Layer_diff_FeatureVersion for the condition `ldfv.layer_id >= $2 AND - // ldfv.modification = 'add'` before realizing a hash inner join with - // Vulnerability_Affects_FeatureVersion. By disabling explictly hash joins, - // we force PostgreSQL to perform a bitmap index scan with - // `ldfv.featureversion_id = fv.id` on Layer_diff_FeatureVersion, followed by - // a bitmap heap scan on `ldfv.layer_id >= $2 AND ldfv.modification = 'add'`, - // thus avoiding a sequential scan on the biggest database table and - // allowing a small nested loop join instead. - tx, err := pgSQL.Begin() - if err != nil { - return -1, handleError("searchNotificationLayerIntroducingVulnerability.Begin()", err) - } - defer tx.Commit() + return notification, true, nil +} - _, err = tx.Exec(disableHashJoin) - if err != nil { - log.WithError(err).Warning("searchNotificationLayerIntroducingVulnerability: could not disable hash join") +func (tx *pgSession) FindVulnerabilityNotification(name string) (database.VulnerabilityNotification, bool, error) { + var ( + notification database.VulnerabilityNotification + created zero.Time + notified zero.Time + deleted zero.Time + oldVulnID sql.NullInt64 + newVulnID sql.NullInt64 + ) + + if name == "" { + return notification, false, commonerr.NewBadRequestError("Empty notification name is not allowed") } - // We do `defer observeQueryTime` here because we don't want to observe invalid calls. - defer observeQueryTime("loadLayerIntroducingVulnerability", "all", tf) + err := tx.QueryRow(searchNotification, name).Scan(¬ification.Name, + &created, ¬ified, &deleted, &oldVulnID, &newVulnID) - // Query with limit + 1, the last item will be used to know the next starting ID. - rows, err := tx.Query(searchNotificationLayerIntroducingVulnerability, - vulnerability.ID, startID, limit+1) if err != nil { - return 0, handleError("searchNotificationLayerIntroducingVulnerability", err) + if err == sql.ErrNoRows { + return notification, false, nil + } + return notification, false, handleError("searchNotification", err) + } + + if oldVulnID.Valid { + vuln := database.Vulnerability{} + err := tx.QueryRow(searchVulnerabilityByID, oldVulnID).Scan(&vuln.Name, + &vuln.Description, + &vuln.Link, + &vuln.Severity, + &vuln.Metadata, + &vuln.Namespace.Name, + &vuln.Namespace.VersionFormat) + if err != nil { + return notification, false, handleError("searchVulnerabilityByID", err) + } + notification.Old = &vuln + } + + if newVulnID.Valid { + vuln := database.Vulnerability{} + err := tx.QueryRow(searchVulnerabilityByID, newVulnID).Scan(&vuln.Name, + &vuln.Description, + &vuln.Link, + &vuln.Severity, + &vuln.Metadata, + &vuln.Namespace.Name, + &vuln.Namespace.VersionFormat) + if err != nil { + return notification, false, handleError("searchVulnerabilityByID", err) + } + notification.New = &vuln } - defer rows.Close() - var layers []database.Layer - for rows.Next() { - var layer database.Layer - - if err := rows.Scan(&layer.ID, &layer.Name); err != nil { - return -1, handleError("searchNotificationLayerIntroducingVulnerability.Scan()", err) - } + notification.Created = created.Time + notification.Notified = notified.Time + notification.Deleted = deleted.Time + return notification, true, nil +} - layers = append(layers, layer) - } - if err = rows.Err(); err != nil { - return -1, handleError("searchNotificationLayerIntroducingVulnerability.Rows()", err) +func (tx *pgSession) MarkNotificationNotified(name string) error { + if name == "" { + return commonerr.NewBadRequestError("Empty notification name is not allowed") } - size := limit - if len(layers) < limit { - size = len(layers) + r, err := tx.Exec(updatedNotificationNotified, name) + if err != nil { + return handleError("updatedNotificationNotified", err) } - vulnerability.LayersIntroducingVulnerability = layers[:size] - nextID := -1 - if len(layers) > limit { - nextID = layers[limit].ID + affected, err := r.RowsAffected() + if err != nil { + return handleError("updatedNotificationNotified", err) } - return nextID, nil -} - -func (pgSQL *pgSQL) SetNotificationNotified(name string) error { - defer observeQueryTime("SetNotificationNotified", "all", time.Now()) - - if _, err := pgSQL.Exec(updatedNotificationNotified, name); err != nil { - return handleError("updatedNotificationNotified", err) + if affected <= 0 { + return handleError("updatedNotificationNotified", errNotificationNotFound) } return nil } -func (pgSQL *pgSQL) DeleteNotification(name string) error { - defer observeQueryTime("DeleteNotification", "all", time.Now()) +func (tx *pgSession) DeleteNotification(name string) error { + if name == "" { + return commonerr.NewBadRequestError("Empty notification name is not allowed") + } - result, err := pgSQL.Exec(removeNotification, name) + result, err := tx.Exec(removeNotification, name) if err != nil { return handleError("removeNotification", err) } affected, err := result.RowsAffected() if err != nil { - return handleError("removeNotification.RowsAffected()", err) + return handleError("removeNotification", err) } if affected <= 0 { - return commonerr.ErrNotFound + return handleError("removeNotification", errNotificationNotFound) } return nil diff --git a/database/pgsql/notification_test.go b/database/pgsql/notification_test.go index 24e7924644..01ef1490b0 100644 --- a/database/pgsql/notification_test.go +++ b/database/pgsql/notification_test.go @@ -1,16 +1,16 @@ -// Copyright 2017 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// // Copyright 2017 clair authors +// // +// // Licensed under the Apache License, Version 2.0 (the "License"); +// // you may not use this file except in compliance with the License. +// // You may obtain a copy of the License at +// // +// // http://www.apache.org/licenses/LICENSE-2.0 +// // +// // Unless required by applicable law or agreed to in writing, software +// // distributed under the License is distributed on an "AS IS" BASIS, +// // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// // See the License for the specific language governing permissions and +// // limitations under the License. package pgsql @@ -18,214 +18,160 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" + "database/sql" "github.com/coreos/clair/database" - "github.com/coreos/clair/ext/versionfmt" - "github.com/coreos/clair/ext/versionfmt/dpkg" - "github.com/coreos/clair/pkg/commonerr" + "github.com/stretchr/testify/assert" ) -func TestNotification(t *testing.T) { - datastore, err := openDatabaseForTest("Notification", false) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() - - // Try to get a notification when there is none. - _, err = datastore.GetAvailableNotification(time.Second) - assert.Equal(t, commonerr.ErrNotFound, err) - - // Create some data. - f1 := database.Feature{ - Name: "TestNotificationFeature1", - Namespace: database.Namespace{ - Name: "TestNotificationNamespace1", - VersionFormat: dpkg.ParserName, - }, - } +func TestInsertVulnerabilityNotifications(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistNamespacedFeatures", true) + defer closeTest(t, datastore, tx) - f2 := database.Feature{ - Name: "TestNotificationFeature2", - Namespace: database.Namespace{ - Name: "TestNotificationNamespace1", - VersionFormat: dpkg.ParserName, + n1 := database.VulnerabilityNotification{} + n3 := database.VulnerabilityNotification{ + Notification: database.Notification{ + Name: "random name", + Created: time.Now(), }, + Old: nil, + New: &database.Vulnerability{}, } - - l1 := database.Layer{ - Name: "TestNotificationLayer1", - Features: []database.FeatureVersion{ - { - Feature: f1, - Version: "0.1", + n4 := database.VulnerabilityNotification{ + Notification: database.Notification{ + Name: "random name", + Created: time.Now(), + }, + Old: nil, + New: &database.Vulnerability{ + Name: "CVE-OPENSSL-1-DEB7", + Namespace: database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", }, }, } - l2 := database.Layer{ - Name: "TestNotificationLayer2", - Features: []database.FeatureVersion{ - { - Feature: f1, - Version: "0.2", - }, - }, + // invalid case + err := tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n1}) + assert.NotNil(t, err) + + // invalid case: unknown vulnerability + err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n3}) + assert.NotNil(t, err) + + // invalid case: duplicated input notification + err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4, n4}) + assert.NotNil(t, err) + + name := "" + err = tx.QueryRow("SELECT name FROM vulnerability_notification WHERE name = $1", n4.Name).Scan(&name) + if !assert.Equal(t, sql.ErrNoRows, err) { + panic("fuck") } - l3 := database.Layer{ - Name: "TestNotificationLayer3", - Features: []database.FeatureVersion{ - { - Feature: f1, - Version: "0.3", - }, - }, + // valid case + err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4}) + assert.Nil(t, err) + + // invalid case: notification is already in database + err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4}) + assert.NotNil(t, err) +} + +func TestFindNewNotification(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistNamespacedFeatures", true) + defer closeTest(t, datastore, tx) + + noti, ok, err := tx.FindNewNotification(time.Now()) + if assert.Nil(t, err) && assert.True(t, ok) { + assert.Equal(t, "test", noti.Name) + assert.Equal(t, time.Time{}, noti.Notified) + assert.Equal(t, time.Time{}, noti.Created) + assert.Equal(t, time.Time{}, noti.Deleted) } - l4 := database.Layer{ - Name: "TestNotificationLayer4", - Features: []database.FeatureVersion{ - { - Feature: f2, - Version: "0.1", - }, - }, + // can't find the notified + assert.Nil(t, tx.MarkNotificationNotified("test")) + // if the notified time is before + noti, ok, err = tx.FindNewNotification(time.Now().Add(-time.Duration(10 * time.Second))) + assert.Nil(t, err) + assert.False(t, ok) + // can find the notified after a period of time + noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(1000))) + if assert.Nil(t, err) && assert.True(t, ok) { + assert.Equal(t, "test", noti.Name) + assert.NotEqual(t, time.Time{}, noti.Notified) + assert.Equal(t, time.Time{}, noti.Created) + assert.Equal(t, time.Time{}, noti.Deleted) } - if !assert.Nil(t, datastore.InsertLayer(l1)) || - !assert.Nil(t, datastore.InsertLayer(l2)) || - !assert.Nil(t, datastore.InsertLayer(l3)) || - !assert.Nil(t, datastore.InsertLayer(l4)) { - return + assert.Nil(t, tx.DeleteNotification("test")) + // can't find in any time + noti, ok, err = tx.FindNewNotification(time.Now().Add(-time.Duration(1000))) + assert.Nil(t, err) + assert.False(t, ok) + + noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(1000))) + assert.Nil(t, err) + assert.False(t, ok) +} + +func TestFindVulnerabilityNotification(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistNamespacedFeatures", true) + defer closeTest(t, datastore, tx) + + n1 := database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", } - // Insert a new vulnerability that is introduced by three layers. v1 := database.Vulnerability{ - Name: "TestNotificationVulnerability1", - Namespace: f1.Namespace, - Description: "TestNotificationDescription1", - Link: "TestNotificationLink1", - Severity: "Unknown", - FixedIn: []database.FeatureVersion{ - { - Feature: f1, - Version: "1.0", - }, - }, - } - assert.Nil(t, datastore.insertVulnerability(v1, false, true)) - - // Get the notification associated to the previously inserted vulnerability. - notification, err := datastore.GetAvailableNotification(time.Second) - - if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) { - // Verify the renotify behaviour. - if assert.Nil(t, datastore.SetNotificationNotified(notification.Name)) { - _, err := datastore.GetAvailableNotification(time.Second) - assert.Equal(t, commonerr.ErrNotFound, err) - - time.Sleep(50 * time.Millisecond) - notificationB, err := datastore.GetAvailableNotification(20 * time.Millisecond) - assert.Nil(t, err) - assert.Equal(t, notification.Name, notificationB.Name) - - datastore.SetNotificationNotified(notification.Name) - } - - // Get notification. - filledNotification, nextPage, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage) - if assert.Nil(t, err) { - assert.NotEqual(t, database.NoVulnerabilityNotificationPage, nextPage) - assert.Nil(t, filledNotification.OldVulnerability) - - if assert.NotNil(t, filledNotification.NewVulnerability) { - assert.Equal(t, v1.Name, filledNotification.NewVulnerability.Name) - assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 2) - } - } - - // Get second page. - filledNotification, nextPage, err = datastore.GetNotification(notification.Name, 2, nextPage) - if assert.Nil(t, err) { - assert.Equal(t, database.NoVulnerabilityNotificationPage, nextPage) - assert.Nil(t, filledNotification.OldVulnerability) - - if assert.NotNil(t, filledNotification.NewVulnerability) { - assert.Equal(t, v1.Name, filledNotification.NewVulnerability.Name) - assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 1) - } - } - - // Delete notification. - assert.Nil(t, datastore.DeleteNotification(notification.Name)) - - _, err = datastore.GetAvailableNotification(time.Millisecond) - assert.Equal(t, commonerr.ErrNotFound, err) + Namespace: n1, + Name: "CVE-OPENSSL-1-DEB7", + Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", + Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7", + Severity: database.HighSeverity, } - // Update a vulnerability and ensure that the old/new vulnerabilities are correct. - v1b := v1 - v1b.Severity = database.HighSeverity - v1b.FixedIn = []database.FeatureVersion{ - { - Feature: f1, - Version: versionfmt.MinVersion, - }, - { - Feature: f2, - Version: versionfmt.MaxVersion, - }, + v2 := database.Vulnerability{ + Namespace: n1, + Name: "CVE-NOPE", + Description: "A vulnerability affecting nothing", + Severity: database.UnknownSeverity, } - if assert.Nil(t, datastore.insertVulnerability(v1b, false, true)) { - notification, err = datastore.GetAvailableNotification(time.Second) - assert.Nil(t, err) - assert.NotEmpty(t, notification.Name) - - if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) { - filledNotification, nextPage, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage) - if assert.Nil(t, err) { - if assert.NotNil(t, filledNotification.OldVulnerability) { - assert.Equal(t, v1.Name, filledNotification.OldVulnerability.Name) - assert.Equal(t, v1.Severity, filledNotification.OldVulnerability.Severity) - assert.Len(t, filledNotification.OldVulnerability.LayersIntroducingVulnerability, 2) - } - - if assert.NotNil(t, filledNotification.NewVulnerability) { - assert.Equal(t, v1b.Name, filledNotification.NewVulnerability.Name) - assert.Equal(t, v1b.Severity, filledNotification.NewVulnerability.Severity) - assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 1) - } - - assert.Equal(t, -1, nextPage.NewVulnerability) - } - - assert.Nil(t, datastore.DeleteNotification(notification.Name)) - } + noti, ok, err := tx.FindVulnerabilityNotification("test") + if assert.Nil(t, err) && assert.True(t, ok) { + assert.Equal(t, "test", noti.Name) + assert.Equal(t, time.Time{}, noti.Notified) + assert.Equal(t, time.Time{}, noti.Created) + assert.Equal(t, time.Time{}, noti.Deleted) + assert.Equal(t, v1, *noti.New) + assert.Equal(t, v2, *noti.Old) } +} - // Delete a vulnerability and verify the notification. - if assert.Nil(t, datastore.DeleteVulnerability(v1b.Namespace.Name, v1b.Name)) { - notification, err = datastore.GetAvailableNotification(time.Second) - assert.Nil(t, err) - assert.NotEmpty(t, notification.Name) - - if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) { - filledNotification, _, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage) - if assert.Nil(t, err) { - assert.Nil(t, filledNotification.NewVulnerability) - - if assert.NotNil(t, filledNotification.OldVulnerability) { - assert.Equal(t, v1b.Name, filledNotification.OldVulnerability.Name) - assert.Equal(t, v1b.Severity, filledNotification.OldVulnerability.Severity) - assert.Len(t, filledNotification.OldVulnerability.LayersIntroducingVulnerability, 1) - } - } - - assert.Nil(t, datastore.DeleteNotification(notification.Name)) - } - } +func TestMarkNotificationNotified(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistNamespacedFeatures", true) + defer closeTest(t, datastore, tx) + + // invalid case: notification doesn't exist + assert.NotNil(t, tx.MarkNotificationNotified("non-existing")) + // valid case + assert.Nil(t, tx.MarkNotificationNotified("test")) + // valid case + assert.Nil(t, tx.MarkNotificationNotified("test")) +} + +func TestDeleteNotification(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistNamespacedFeatures", true) + defer closeTest(t, datastore, tx) + + // invalid case: notification doesn't exist + assert.NotNil(t, tx.DeleteNotification("non-existing")) + // valid case + assert.Nil(t, tx.DeleteNotification("test")) + // invalid case: notification is already deleted + assert.NotNil(t, tx.DeleteNotification("test")) } diff --git a/database/pgsql/pgsql.go b/database/pgsql/pgsql.go index 34504a9a1c..3a7441df85 100644 --- a/database/pgsql/pgsql.go +++ b/database/pgsql/pgsql.go @@ -59,7 +59,7 @@ var ( promConcurrentLockVAFV = prometheus.NewGauge(prometheus.GaugeOpts{ Name: "clair_pgsql_concurrent_lock_vafv_total", - Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_FeatureVersion lock.", + Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_Feature lock.", }) ) @@ -80,10 +80,61 @@ type Queryer interface { type pgSQL struct { *sql.DB + cache *lru.ARCCache config Config } +type pgSession struct { + *sql.Tx + + // cache is the pgSQL cache, this should only be modified by Commit() + // function. All changes to cache goes to `cachedIDs`. + cache *lru.ARCCache + cachedIDs map[string]int +} + +func (pgSQL *pgSQL) Begin() (database.Session, error) { + tx, err := pgSQL.DB.Begin() + if err != nil { + return nil, err + } + return &pgSession{ + Tx: tx, + cache: pgSQL.cache, + cachedIDs: make(map[string]int), + }, nil +} + +func (tx *pgSession) lookupID(key string) (int, bool) { + if v, ok := tx.cachedIDs[key]; ok { + return v, true + } + + if v, ok := tx.cache.Get(key); ok { + if vint, convok := v.(int); convok { + return vint, true + } + return 0, false + } + + return 0, false +} + +func (tx *pgSession) Commit() error { + err := tx.Tx.Commit() + if err != nil { + return err + } + + if tx.cache != nil { + for key, id := range tx.cachedIDs { + tx.cache.Add(key, id) + } + } + return nil +} + // Close closes the database and destroys if ManageDatabaseLifecycle has been specified in // the configuration. func (pgSQL *pgSQL) Close() { @@ -270,6 +321,7 @@ func dropDatabase(source, dbName string) error { return nil } +// should be used for every database query error // handleError logs an error with an extra description and masks the error if it's an SQL one. // This ensures we never return plain SQL errors and leak anything. func handleError(desc string, err error) error { diff --git a/database/pgsql/pgsql_test.go b/database/pgsql/pgsql_test.go index 93f5314456..466c716cca 100644 --- a/database/pgsql/pgsql_test.go +++ b/database/pgsql/pgsql_test.go @@ -15,27 +15,191 @@ package pgsql import ( + "database/sql" "fmt" + "io/ioutil" "os" "path/filepath" "runtime" "strings" + "testing" "github.com/pborman/uuid" + log "github.com/sirupsen/logrus" + yaml "gopkg.in/yaml.v2" "github.com/coreos/clair/database" ) +var ( + withFixtureName, withoutFixtureName string +) + +func genTemplateDatabase(name string, loadFixture bool) (sourceURL string, dbName string) { + config := generateTestConfig(name, loadFixture, false) + source := config.Options["source"].(string) + name, url, err := parseConnectionString(source) + if err != nil { + panic(err) + } + + fixturePath := config.Options["fixturepath"].(string) + + if err := createDatabase(url, name); err != nil { + panic(err) + } + + // migration and fixture + db, err := sql.Open("postgres", source) + if err != nil { + panic(err) + } + + // Verify database state. + if err := db.Ping(); err != nil { + panic(err) + } + + // Run migrations. + if err := migrateDatabase(db); err != nil { + panic(err) + } + + if loadFixture { + log.Info("pgsql: loading fixtures") + + d, err := ioutil.ReadFile(fixturePath) + if err != nil { + panic(err) + } + + _, err = db.Exec(string(d)) + if err != nil { + panic(err) + } + } + + db.Exec("UPDATE pg_database SET datistemplate=True WHERE datname=$1", name) + db.Close() + + log.Info("Generated Template database ", name) + return url, name +} + +func dropTemplateDatabase(url string, name string) { + db, err := sql.Open("postgres", url) + if err != nil { + panic(err) + } + + if _, err := db.Exec("UPDATE pg_database SET datistemplate=False WHERE datname=$1", name); err != nil { + panic(err) + } + + if err := db.Close(); err != nil { + panic(err) + } + + if err := dropDatabase(url, name); err != nil { + panic(err) + } + +} +func TestMain(m *testing.M) { + fURL, fName := genTemplateDatabase("fixture", true) + nfURL, nfName := genTemplateDatabase("nonfixture", false) + + withFixtureName = fName + withoutFixtureName = nfName + + m.Run() + + dropTemplateDatabase(fURL, fName) + dropTemplateDatabase(nfURL, nfName) +} + +func openCopiedDatabase(testConfig database.RegistrableComponentConfig, fixture bool) (database.Datastore, error) { + var fixtureName string + if fixture { + fixtureName = withFixtureName + } else { + fixtureName = withoutFixtureName + } + + // copy the database into new database + var pg pgSQL + // Parse configuration. + pg.config = Config{ + CacheSize: 16384, + } + + bytes, err := yaml.Marshal(testConfig.Options) + if err != nil { + return nil, fmt.Errorf("pgsql: could not load configuration: %v", err) + } + err = yaml.Unmarshal(bytes, &pg.config) + if err != nil { + return nil, fmt.Errorf("pgsql: could not load configuration: %v", err) + } + + dbName, pgSourceURL, err := parseConnectionString(pg.config.Source) + if err != nil { + return nil, err + } + + // Create database. + if pg.config.ManageDatabaseLifecycle { + if err = copyDatabase(pgSourceURL, dbName, fixtureName); err != nil { + return nil, err + } + } + + // Open database. + pg.DB, err = sql.Open("postgres", pg.config.Source) + fmt.Println("database", pg.config.Source) + if err != nil { + pg.Close() + return nil, fmt.Errorf("pgsql: could not open database: %v", err) + } + + return &pg, nil +} + +// copyDatabase creates a new database with +func copyDatabase(url, name string, templateName string) error { + // Open database. + db, err := sql.Open("postgres", url) + if err != nil { + return fmt.Errorf("pgsql: could not open 'postgres' database for creation: %v", err) + } + defer db.Close() + + // Create database with copy + _, err = db.Exec("CREATE DATABASE " + name + " WITH TEMPLATE " + templateName) + if err != nil { + return fmt.Errorf("pgsql: could not create database: %v", err) + } + + return nil +} + func openDatabaseForTest(testName string, loadFixture bool) (*pgSQL, error) { - ds, err := openDatabase(generateTestConfig(testName, loadFixture)) + var ( + db database.Datastore + err error + testConfig = generateTestConfig(testName, loadFixture, true) + ) + + db, err = openCopiedDatabase(testConfig, loadFixture) + if err != nil { return nil, err } - datastore := ds.(*pgSQL) + datastore := db.(*pgSQL) return datastore, nil } -func generateTestConfig(testName string, loadFixture bool) database.RegistrableComponentConfig { +func generateTestConfig(testName string, loadFixture bool, manageLife bool) database.RegistrableComponentConfig { dbName := "test_" + strings.ToLower(testName) + "_" + strings.Replace(uuid.New(), "-", "_", -1) var fixturePath string @@ -53,8 +217,32 @@ func generateTestConfig(testName string, loadFixture bool) database.RegistrableC Options: map[string]interface{}{ "source": source, "cachesize": 0, - "managedatabaselifecycle": true, + "managedatabaselifecycle": manageLife, "fixturepath": fixturePath, }, } } + +func closeTest(t *testing.T, store database.Datastore, session database.Session) { + err := session.Rollback() + if err != nil { + t.Error(err) + t.FailNow() + } + + store.Close() +} + +func openSessionForTest(t *testing.T, name string, loadFixture bool) (*pgSQL, *pgSession) { + store, err := openDatabaseForTest(name, loadFixture) + if err != nil { + t.Error(err) + t.FailNow() + } + tx, err := store.Begin() + if err != nil { + t.Error(err) + t.FailNow() + } + return store, tx.(*pgSession) +} diff --git a/database/pgsql/queries.go b/database/pgsql/queries.go index 3fedf8d0ad..183fb86ef0 100644 --- a/database/pgsql/queries.go +++ b/database/pgsql/queries.go @@ -15,129 +15,122 @@ package pgsql import "strconv" +import "strings" +import "github.com/lib/pq" const ( - lockVulnerabilityAffects = `LOCK Vulnerability_Affects_FeatureVersion IN SHARE ROW EXCLUSIVE MODE` - disableHashJoin = `SET LOCAL enable_hashjoin = off` - disableMergeJoin = `SET LOCAL enable_mergejoin = off` + lockVulnerabilityAffects = `LOCK vulnerability_affected_namespaced_feature IN SHARE ROW EXCLUSIVE MODE` // keyvalue.go - updateKeyValue = `UPDATE KeyValue SET value = $1 WHERE key = $2` - insertKeyValue = `INSERT INTO KeyValue(key, value) VALUES($1, $2)` searchKeyValue = `SELECT value FROM KeyValue WHERE key = $1` - + upsertKeyValue = `INSERT INTO KeyValue(key, value) + VALUES ($1, $2) +ON CONFLICT ON CONSTRAINT keyvalue_key_key +DO UPDATE SET key=$1, value=$2` + // namespace.go - soiNamespace = ` - WITH new_namespace AS ( - INSERT INTO Namespace(name, version_format) - SELECT CAST($1 AS VARCHAR), CAST($2 AS VARCHAR) - WHERE NOT EXISTS (SELECT name FROM Namespace WHERE name = $1) - RETURNING id - ) - SELECT id FROM Namespace WHERE name = $1 - UNION - SELECT id FROM new_namespace` - - searchNamespace = `SELECT id FROM Namespace WHERE name = $1` - listNamespace = `SELECT id, name, version_format FROM Namespace` + persistNamespace = `INSERT INTO namespace (name, version_format) VALUES ($1, $2) ON CONFLICT DO NOTHING` + searchNamespaceID = `SELECT id FROM Namespace WHERE name = $1 AND version_format = $2` // feature.go soiFeature = ` WITH new_feature AS ( - INSERT INTO Feature(name, namespace_id) - SELECT CAST($1 AS VARCHAR), CAST($2 AS INTEGER) - WHERE NOT EXISTS (SELECT id FROM Feature WHERE name = $1 AND namespace_id = $2) + INSERT INTO feature(name, version, version_format) + SELECT CAST ($1 AS TEXT), CAST ($2 AS TEXT), CAST ($3 AS TEXT) + WHERE NOT EXISTS ( SELECT id FROM feature WHERE feature.name = $1 AND feature.version = $2 AND feature.version_format = $3) RETURNING id ) - SELECT id FROM Feature WHERE name = $1 AND namespace_id = $2 + SELECT id FROM feature WHERE feature.name = $1 AND feature.version = $2 AND feature.version_format = $3 UNION SELECT id FROM new_feature` - searchFeatureVersion = ` - SELECT id FROM FeatureVersion WHERE feature_id = $1 AND version = $2` + searchFeature = `SELECT id FROM Feature WHERE name = $1 AND version = $2 AND version_format=$3` - soiFeatureVersion = ` - WITH new_featureversion AS ( - INSERT INTO FeatureVersion(feature_id, version) - SELECT CAST($1 AS INTEGER), CAST($2 AS VARCHAR) - WHERE NOT EXISTS (SELECT id FROM FeatureVersion WHERE feature_id = $1 AND version = $2) + soiNamespacedFeature = ` + WITH new_feature_ns AS ( + INSERT INTO namespaced_feature(feature_id, namespace_id) + SELECT CAST ($1 AS INTEGER), CAST ($2 AS INTEGER) + WHERE NOT EXISTS ( SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2) RETURNING id ) - SELECT false, id FROM FeatureVersion WHERE feature_id = $1 AND version = $2 + SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2 UNION - SELECT true, id FROM new_featureversion` - - searchVulnerabilityFixedInFeature = ` - SELECT id, vulnerability_id, version FROM Vulnerability_FixedIn_Feature - WHERE feature_id = $1` - - insertVulnerabilityAffectsFeatureVersion = ` - INSERT INTO Vulnerability_Affects_FeatureVersion(vulnerability_id, featureversion_id, fixedin_id) - VALUES($1, $2, $3)` + SELECT id FROM new_feature_ns` + + searchNamespacedFeature = ` + SELECT nf.id + FROM namespaced_feature AS nf, feature AS f, namespace AS n + WHERE nf.feature_id = f.id + AND nf.namespace_id = n.id + AND f.name = $1 + AND f.version = $2 + AND f.version_format = $3 + AND n.version_format = f.version_format + AND n.name = $4` + + searchPotentialNamespacedFeatureVulnerabilities = ` + SELECT nf.id, v.id, vaf.affected_version, vaf.id + FROM vulnerability_affected_feature AS vaf, vulnerability AS v, + namespaced_feature AS nf, feature AS f + WHERE nf.id = ANY($1) + AND nf.feature_id = f.id + AND nf.namespace_id = v.namespace_id + AND vaf.feature_name = f.name + AND vaf.vulnerability_id = v.id + AND v.deleted_at IS NULL` + + persistVulnerabilityAffectedNamespacedFeature = ` + INSERT INTO vulnerability_affected_namespaced_feature(vulnerability_id, namespaced_feature_id, added_by) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING + ` + + searchNamespacedFeaturesVulnerabilities = ` + SELECT vanf.namespaced_feature_id, v.name, v.description, v.link, + v.severity, v.metadata, vaf.fixedin, n.name, n.version_format + FROM vulnerability_affected_namespaced_feature AS vanf, + Vulnerability AS v, + vulnerability_affected_feature AS vaf, + namespace AS n + WHERE vanf.namespaced_feature_id = ANY($1) + AND vaf.id = vanf.added_by + AND v.id = vanf.vulnerability_id + AND n.id = v.namespace_id + AND v.deleted_at IS NULL` // layer.go - searchLayer = ` - SELECT l.id, l.name, l.engineversion, p.id, p.name - FROM Layer l - LEFT JOIN Layer p ON l.parent_id = p.id - WHERE l.name = $1;` - - searchLayerNamespace = ` - SELECT n.id, n.name, n.version_format - FROM Namespace n - JOIN Layer_Namespace lns ON lns.namespace_id = n.id - WHERE lns.layer_id = $1` - - searchLayerFeatureVersion = ` - WITH RECURSIVE layer_tree(id, name, parent_id, depth, path, cycle) AS( - SELECT l.id, l.name, l.parent_id, 1, ARRAY[l.id], false - FROM Layer l - WHERE l.id = $1 - UNION ALL - SELECT l.id, l.name, l.parent_id, lt.depth + 1, path || l.id, l.id = ANY(path) - FROM Layer l, layer_tree lt - WHERE l.id = lt.parent_id + searchLayerIDs = `SELECT id, hash FROM layer WHERE hash = ANY($1);` + soiLayer = `WITH new_layer AS ( + INSERT INTO layer (hash) + SELECT CAST ($1 AS TEXT) + WHERE NOT EXISTS ( SELECT id FROM layer WHERE hash = $1 ) + RETURNING id ) - SELECT ldf.featureversion_id, ldf.modification, fn.id, fn.name, fn.version_format, f.id, f.name, fv.id, fv.version, ltree.id, ltree.name - FROM Layer_diff_FeatureVersion ldf - JOIN ( - SELECT row_number() over (ORDER BY depth DESC), id, name FROM layer_tree - ) AS ltree (ordering, id, name) ON ldf.layer_id = ltree.id, FeatureVersion fv, Feature f, Namespace fn - WHERE ldf.featureversion_id = fv.id AND fv.feature_id = f.id AND f.namespace_id = fn.id - ORDER BY ltree.ordering` - - searchFeatureVersionVulnerability = ` - SELECT vafv.featureversion_id, v.id, v.name, v.description, v.link, v.severity, v.metadata, - vn.name, vn.version_format, vfif.version - FROM Vulnerability_Affects_FeatureVersion vafv, Vulnerability v, - Namespace vn, Vulnerability_FixedIn_Feature vfif - WHERE vafv.featureversion_id = ANY($1::integer[]) - AND vfif.vulnerability_id = v.id - AND vafv.fixedin_id = vfif.id - AND v.namespace_id = vn.id - AND v.deleted_at IS NULL` - - insertLayer = ` - INSERT INTO Layer(name, engineversion, parent_id, created_at) - VALUES($1, $2, $3, CURRENT_TIMESTAMP) - RETURNING id` - - insertLayerNamespace = `INSERT INTO Layer_Namespace(layer_id, namespace_id) VALUES($1, $2)` - removeLayerNamespace = `DELETE FROM Layer_Namespace WHERE layer_id = $1` - - updateLayer = `UPDATE LAYER SET engineversion = $2 WHERE id = $1` - - removeLayerDiffFeatureVersion = ` - DELETE FROM Layer_diff_FeatureVersion - WHERE layer_id = $1` - - insertLayerDiffFeatureVersion = ` - INSERT INTO Layer_diff_FeatureVersion(layer_id, featureversion_id, modification) - SELECT $1, fv.id, $2 - FROM FeatureVersion fv - WHERE fv.id = ANY($3::integer[])` - - removeLayer = `DELETE FROM Layer WHERE name = $1` + SELECT id FROM layer WHERE hash = $1 + UNION + SELECT id FROM new_layer;` + + persistLayerListers = `INSERT INTO layer_lister(layer_id, lister) VALUES ($1, $2) ON CONFLICT DO NOTHING` + + persistLayerDetectors = `INSERT INTO layer_detector(layer_id, detector) VALUES ($1, $2) ON CONFLICT DO NOTHING` + + persistLayerNamespace = `INSERT INTO layer_namespace (layer_id, namespace_id) VALUES ($1, $2) ON CONFLICT DO NOTHING` + + persistLayerFeature = `INSERT INTO layer_feature (layer_id, feature_id) VALUES ($1, $2) ON CONFLICT DO NOTHING` + + searchLayerFeatures = ` + SELECT feature.Name, feature.Version, feature.version_format + FROM feature, layer_feature + WHERE layer_feature.layer_id = $1 + AND layer_feature.feature_id = feature.id` + + searchLayerNamespaces = ` + SELECT namespace.Name, namespace.version_format + FROM namespace, layer_namespace + WHERE layer_namespace.layer_id = $1 + AND layer_namespace.namespace_id = namespace.id` + + searchLayer = `SELECT id FROM layer WHERE hash = $1` + searchLayerDetectors = `SELECT detector FROM layer_detector WHERE layer_id = $1` + searchLayerListers = `SELECT lister FROM layer_lister WHERE layer_id = $1` // lock.go insertLock = `INSERT INTO Lock(name, owner, until) VALUES($1, $2, $3)` @@ -147,52 +140,89 @@ const ( removeLockExpired = `DELETE FROM LOCK WHERE until < CURRENT_TIMESTAMP` // vulnerability.go - searchVulnerabilityBase = ` - SELECT v.id, v.name, n.id, n.name, n.version_format, v.description, v.link, v.severity, v.metadata - FROM Vulnerability v JOIN Namespace n ON v.namespace_id = n.id` - searchVulnerabilityForUpdate = ` FOR UPDATE OF v` - searchVulnerabilityByNamespaceAndName = ` WHERE n.name = $1 AND v.name = $2 AND v.deleted_at IS NULL` - searchVulnerabilityByID = ` WHERE v.id = $1` - searchVulnerabilityByNamespace = ` WHERE n.name = $1 AND v.deleted_at IS NULL - AND v.id >= $2 - ORDER BY v.id - LIMIT $3` - - searchVulnerabilityFixedIn = ` - SELECT vfif.version, f.id, f.Name - FROM Vulnerability_FixedIn_Feature vfif JOIN Feature f ON vfif.feature_id = f.id - WHERE vfif.vulnerability_id = $1` + searchVulnerability = ` + SELECT v.id, v.description, v.link, v.severity, v.metadata, n.version_format + FROM vulnerability AS v, namespace AS n + WHERE v.namespace_id = n.id + AND v.name = $1 + AND n.name = $2 + AND v.deleted_at IS NULL + ` + + insertVulnerabilityAffected = ` + INSERT INTO vulnerability_affected_feature(vulnerability_id, feature_name, affected_version, fixedin) + VALUES ($1, $2, $3, $4) + RETURNING ID + ` + + searchLatestDeletedVulnerabilityID = ` + SELECT v.id FROM vulnerability AS v, namespace AS n + WHERE v.name = $1 + AND n.name = $2 + AND v.namespace_id = n.id + AND v.deleted_at IS NOT NULL + ORDER BY v.deleted_at DESC + LIMIT 1` + + searchNotDeletedVulnerabilityID = ` + SELECT v.id FROM vulnerability AS v, namespace AS n + WHERE v.name = $1 + AND n.name = $2 + AND v.namespace_id = n.id + AND v.deleted_at IS NULL + LIMIT 1` + + searchVulnerabilityAffected = ` + SELECT vulnerability_id, feature_name, affected_version, fixedin + FROM vulnerability_affected_feature + WHERE vulnerability_id = ANY($1) + ` + + searchVulnerabilityByID = ` + SELECT v.name, v.description, v.link, v.severity, v.metadata, n.name, n.version_format + FROM vulnerability AS v, namespace AS n + WHERE v.namespace_id = n.id + AND v.id = $1` + + searchVulnerabilityPotentialAffected = ` + WITH req AS ( + SELECT vaf.id AS vaf_id, n.id AS n_id, vaf.feature_name AS name, v.id AS vulnerability_id + FROM vulnerability_affected_feature AS vaf, + vulnerability AS v, + namespace AS n + WHERE vaf.vulnerability_id = ANY($1) + AND v.id = vaf.vulnerability_id + AND n.id = v.namespace_id + ) + SELECT req.vulnerability_id, nf.id, f.version, req.vaf_id AS added_by + FROM feature AS f, namespaced_feature AS nf, req + WHERE f.name = req.name + AND nf.namespace_id = req.n_id + AND nf.feature_id = f.id` + + insertVulnerabilityAffectedNamespacedFeature = `INSERT INTO vulnerability_affected_namespaced_feature(vulnerability_id, namespaced_feature_id, added_by) + VALUES ($1, $2, $3)` insertVulnerability = ` + WITH ns AS ( + SELECT id FROM namespace WHERE name = $6 AND version_format = $7 + ) INSERT INTO Vulnerability(namespace_id, name, description, link, severity, metadata, created_at) - VALUES($1, $2, $3, $4, $5, $6, CURRENT_TIMESTAMP) + VALUES((SELECT id FROM ns), $1, $2, $3, $4, $5, CURRENT_TIMESTAMP) RETURNING id` - soiVulnerabilityFixedInFeature = ` - WITH new_fixedinfeature AS ( - INSERT INTO Vulnerability_FixedIn_Feature(vulnerability_id, feature_id, version) - SELECT CAST($1 AS INTEGER), CAST($2 AS INTEGER), CAST($3 AS VARCHAR) - WHERE NOT EXISTS (SELECT id FROM Vulnerability_FixedIn_Feature WHERE vulnerability_id = $1 AND feature_id = $2) - RETURNING id - ) - SELECT false, id FROM Vulnerability_FixedIn_Feature WHERE vulnerability_id = $1 AND feature_id = $2 - UNION - SELECT true, id FROM new_fixedinfeature` - - searchFeatureVersionByFeature = `SELECT id, version FROM FeatureVersion WHERE feature_id = $1` - removeVulnerability = ` UPDATE Vulnerability - SET deleted_at = CURRENT_TIMESTAMP - WHERE namespace_id = (SELECT id FROM Namespace WHERE name = $1) - AND name = $2 - AND deleted_at IS NULL - RETURNING id` + SET deleted_at = CURRENT_TIMESTAMP + WHERE namespace_id = (SELECT id FROM Namespace WHERE name = $1) + AND name = $2 + AND deleted_at IS NULL + RETURNING id` // notification.go insertNotification = ` INSERT INTO Vulnerability_Notification(name, created_at, old_vulnerability_id, new_vulnerability_id) - VALUES($1, CURRENT_TIMESTAMP, $2, $3)` + VALUES ($1, $2, $3, $4)` updatedNotificationNotified = ` UPDATE Vulnerability_Notification @@ -202,10 +232,10 @@ const ( removeNotification = ` UPDATE Vulnerability_Notification SET deleted_at = CURRENT_TIMESTAMP - WHERE name = $1` + WHERE name = $1 AND deleted_at IS NULL` searchNotificationAvailable = ` - SELECT id, name, created_at, notified_at, deleted_at + SELECT name, created_at, notified_at, deleted_at FROM Vulnerability_Notification WHERE (notified_at IS NULL OR notified_at < $1) AND deleted_at IS NULL @@ -214,35 +244,64 @@ const ( LIMIT 1` searchNotification = ` - SELECT id, name, created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id + SELECT name, created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id FROM Vulnerability_Notification WHERE name = $1` - searchNotificationLayerIntroducingVulnerability = ` - WITH LDFV AS ( - SELECT DISTINCT ldfv.layer_id - FROM Vulnerability_Affects_FeatureVersion vafv, FeatureVersion fv, Layer_diff_FeatureVersion ldfv - WHERE ldfv.layer_id >= $2 - AND vafv.vulnerability_id = $1 - AND vafv.featureversion_id = fv.id - AND ldfv.featureversion_id = fv.id - AND ldfv.modification = 'add' - ORDER BY ldfv.layer_id - ) - SELECT l.id, l.name - FROM LDFV, Layer l - WHERE LDFV.layer_id = l.id - LIMIT $3` - - // complex_test.go - searchComplexTestFeatureVersionAffects = ` - SELECT v.name - FROM FeatureVersion fv - LEFT JOIN Vulnerability_Affects_FeatureVersion vaf ON fv.id = vaf.featureversion_id - JOIN Vulnerability v ON vaf.vulnerability_id = v.id - WHERE featureversion_id = $1` + searchNotificationVulnerableAncestry = ` + SELECT DISTINCT ON (a.id) + a.id, a.name + FROM vulnerability_affected_namespaced_feature AS vanf, + ancestry AS a, ancestry_feature AS af + WHERE vanf.vulnerability_id = $1 + AND a.id >= $2 + AND a.id = af.ancestry_id + AND af.namespaced_feature_id = vanf.namespaced_feature_id + ORDER BY a.id ASC + LIMIT $3;` + + // ancestry.go + persistAncestryLister = `INSERT INTO ancestry_lister (ancestry_id, lister) + VALUES ($1, $2) ON CONFLICT DO NOTHING` + + persistAncestryDetector = `INSERT INTO ancestry_detector (ancestry_id, detector) + VALUES ($1, $2) ON CONFLICT DO NOTHING` + + insertAncestry = ` + INSERT INTO ancestry (name) VALUES ($1) + RETURNING id` + + searchAncestryLayer = ` + SELECT layer.hash + FROM layer, ancestry_layer + WHERE ancestry_layer.ancestry_id = $1 + AND ancestry_layer.layer_id = layer.id + ORDER BY ancestry_layer.ancestry_index ASC` + + searchAncestryFeatures = ` + SELECT namespace.name, namespace.version_format, feature.name, feature.version + FROM namespace, feature, ancestry, namespaced_feature, ancestry_feature + WHERE ancestry.name = $1 + AND ancestry.id = ancestry_feature.ancestry_id + AND ancestry_feature.namespaced_feature_id = namespaced_feature.id + AND namespaced_feature.feature_id = feature.id + AND namespaced_feature.namespace_id = namespace.id` + + searchAncestry = `SELECT id FROM ancestry WHERE name = $1` + searchAncestryDetectors = `SELECT detector FROM ancestry_detector WHERE ancestry_id = $1` + searchAncestryListers = `SELECT lister FROM ancestry_lister WHERE ancestry_id = $1` + removeAncestry = `DELETE FROM ancestry WHERE name = $1` ) +var ( + copyinAncestryLayer = pq.CopyIn("ancestry_layer", "ancestry_id", "ancestry_index", "layer_id") + copyinAncestryFeatures = pq.CopyIn("ancestry_feature", "ancestry_id", "namespaced_feature_id") +) + +func buildInputStringArray(strs []string) string { + return "{" + strings.Join(strs, ",") + "}" +} + // buildInputArray constructs a PostgreSQL input array from the specified integers. // Useful to use the `= ANY($1::integer[])` syntax that let us use a IN clause while using // a single placeholder. diff --git a/database/pgsql/testdata/data.sql b/database/pgsql/testdata/data.sql index b01e170e0b..5d9403b4c9 100644 --- a/database/pgsql/testdata/data.sql +++ b/database/pgsql/testdata/data.sql @@ -1,73 +1,114 @@ --- Copyright 2015 clair authors --- --- Licensed under the Apache License, Version 2.0 (the "License"); --- you may not use this file except in compliance with the License. --- You may obtain a copy of the License at --- --- http://www.apache.org/licenses/LICENSE-2.0 --- --- Unless required by applicable law or agreed to in writing, software --- distributed under the License is distributed on an "AS IS" BASIS, --- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --- See the License for the specific language governing permissions and --- limitations under the License. - INSERT INTO namespace (id, name, version_format) VALUES - (1, 'debian:7', 'dpkg'), - (2, 'debian:8', 'dpkg'); - -INSERT INTO feature (id, namespace_id, name) VALUES - (1, 1, 'wechat'), - (2, 1, 'openssl'), - (4, 1, 'libssl'), - (3, 2, 'openssl'); - -INSERT INTO featureversion (id, feature_id, version) VALUES - (1, 1, '0.5'), - (2, 2, '1.0'), - (3, 2, '2.0'), - (4, 3, '1.0'); - -INSERT INTO layer (id, name, engineversion, parent_id) VALUES - (1, 'layer-0', 1, NULL), - (2, 'layer-1', 1, 1), - (3, 'layer-2', 1, 2), - (4, 'layer-3a', 1, 3), - (5, 'layer-3b', 1, 3); - -INSERT INTO layer_namespace (id, layer_id, namespace_id) VALUES +(1, 'debian:7', 'dpkg'), +(2, 'debian:8', 'dpkg'), +(3, 'fake:1.0', 'rpm'); + +INSERT INTO feature (id, name, version, version_format) VALUES +(1, 'wechat', '0.5', 'dpkg'), +(2, 'openssl', '1.0', 'dpkg'), +(3, 'openssl', '2.0', 'dpkg'), +(4, 'fake', '2.0', 'rpm'); + +INSERT INTO layer (id, hash) VALUES + (1, 'layer-0'), -- blank + (2, 'layer-1'), -- debian:7; wechat 0.5, openssl 1.0 + (3, 'layer-2'), -- debian:7; wechat 0.5, openssl 2.0 + (4, 'layer-3a'),-- debian:7; + (5, 'layer-3b'),-- debian:8; wechat 0.5, openssl 1.0 + (6, 'layer-4'); -- debian:7, fake:1.0; openssl 2.0 (debian), fake 2.0 (fake) + +INSERT INTO layer_namespace(id, layer_id, namespace_id) VALUES (1, 2, 1), (2, 3, 1), (3, 4, 1), (4, 5, 2), - (5, 5, 1); + (5, 6, 1), + (6, 6, 3); + +INSERT INTO layer_feature(id, layer_id, feature_id) VALUES + (1, 2, 1), + (2, 2, 2), + (3, 3, 1), + (4, 3, 3), + (5, 5, 1), + (6, 5, 2), + (7, 6, 4), + (8, 6, 3); + +INSERT INTO layer_lister(id, layer_id, lister) VALUES + (1, 1, 'dpkg'), + (2, 2, 'dpkg'), + (3, 3, 'dpkg'), + (4, 4, 'dpkg'), + (5, 5, 'dpkg'), + (6, 6, 'dpkg'), + (7, 6, 'rpm'); + +INSERT INTO layer_detector(id, layer_id, detector) VALUES + (1, 1, 'os-release'), + (2, 2, 'os-release'), + (3, 3, 'os-release'), + (4, 4, 'os-release'), + (5, 5, 'os-release'), + (6, 6, 'os-release'), + (7, 6, 'apt-sources'); -INSERT INTO layer_diff_featureversion (id, layer_id, featureversion_id, modification) VALUES - (1, 2, 1, 'add'), - (2, 2, 2, 'add'), - (3, 3, 2, 'del'), -- layer-2: Update Debian:7 OpenSSL 1.0 -> 2.0 - (4, 3, 3, 'add'), -- ^ - (5, 5, 3, 'del'), -- layer-3b: Delete Debian:7 OpenSSL 2.0 - (6, 5, 4, 'add'); -- layer-3b: Add Debian:8 OpenSSL 1.0 +INSERT INTO ancestry (id, name) VALUES + (1, 'ancestry-1'), -- layer-0, layer-1, layer-2, layer-3a + (2, 'ancestry-2'), -- layer-0, layer-1, layer-2, layer-3b + (3, 'ancestry-3'), -- empty; just for testing the vulnerable ancestry + (4, 'ancestry-4'); -- empty; just for testing the vulnerable ancestry + +INSERT INTO ancestry_lister (id, ancestry_id, lister) VALUES + (1, 1, 'dpkg'), + (2, 2, 'dpkg'); + +INSERT INTO ancestry_detector (id, ancestry_id, detector) VALUES + (1, 1, 'os-release'), + (2, 2, 'os-release'); + +INSERT INTO ancestry_layer (id, ancestry_id, layer_id, ancestry_index) VALUES + (1, 1, 1, 0),(2, 1, 2, 1),(3, 1, 3, 2),(4, 1, 4, 3), + (5, 2, 1, 0),(6, 2, 2, 1),(7, 2, 3, 2),(8, 2, 5, 3); + +INSERT INTO namespaced_feature(id, feature_id, namespace_id) VALUES + (1, 1, 1), -- wechat 0.5, debian:7 + (2, 2, 1), -- openssl 1.0, debian:7 + (3, 2, 2), -- openssl 1.0, debian:8 + (4, 3, 1); -- openssl 2.0, debian:7 + +INSERT INTO ancestry_feature (id, ancestry_id, namespaced_feature_id) VALUES + (1, 1, 1), (2, 1, 4), + (3, 2, 1), (4, 2, 3), + (5, 3, 2), (6, 4, 2); -- assume that ancestry-3 and ancestry-4 are vulnerable. INSERT INTO vulnerability (id, namespace_id, name, description, link, severity) VALUES (1, 1, 'CVE-OPENSSL-1-DEB7', 'A vulnerability affecting OpenSSL < 2.0 on Debian 7.0', 'http://google.com/#q=CVE-OPENSSL-1-DEB7', 'High'), (2, 1, 'CVE-NOPE', 'A vulnerability affecting nothing', '', 'Unknown'); -INSERT INTO vulnerability_fixedin_feature (id, vulnerability_id, feature_id, version) VALUES - (1, 1, 2, '2.0'), - (2, 1, 4, '1.9-abc'); +INSERT INTO vulnerability_affected_feature(id, vulnerability_id, feature_name, affected_version, fixedin) VALUES +(1, 1, 'openssl', '2.0', '2.0'), +(2, 1, 'libssl', '1.9-abc', '1.9-abc'); + +INSERT INTO vulnerability_affected_namespaced_feature(id, vulnerability_id, namespaced_feature_id, added_by) VALUES + (1, 1, 2, 1); -INSERT INTO vulnerability_affects_featureversion (id, vulnerability_id, featureversion_id, fixedin_id) VALUES - (1, 1, 2, 1); -- CVE-OPENSSL-1-DEB7 affects Debian:7 OpenSSL 1.0 +INSERT INTO vulnerability_notification(id, name, created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id) VALUES + (1, 'test', NULL, NULL, NULL, 2, 1); SELECT pg_catalog.setval(pg_get_serial_sequence('namespace', 'id'), (SELECT MAX(id) FROM namespace)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry', 'id'), (SELECT MAX(id) FROM ancestry)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_layer', 'id'), (SELECT MAX(id) FROM ancestry_layer)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_feature', 'id'), (SELECT MAX(id) FROM ancestry_feature)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_detector', 'id'), (SELECT MAX(id) FROM ancestry_detector)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_lister', 'id'), (SELECT MAX(id) FROM ancestry_lister)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('feature', 'id'), (SELECT MAX(id) FROM feature)+1); -SELECT pg_catalog.setval(pg_get_serial_sequence('featureversion', 'id'), (SELECT MAX(id) FROM featureversion)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('namespaced_feature', 'id'), (SELECT MAX(id) FROM namespaced_feature)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('layer', 'id'), (SELECT MAX(id) FROM layer)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('layer_namespace', 'id'), (SELECT MAX(id) FROM layer_namespace)+1); -SELECT pg_catalog.setval(pg_get_serial_sequence('layer_diff_featureversion', 'id'), (SELECT MAX(id) FROM layer_diff_featureversion)+1); -SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability', 'id'), (SELECT MAX(id) FROM vulnerability)+1); -SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_fixedin_feature', 'id'), (SELECT MAX(id) FROM vulnerability_fixedin_feature)+1); -SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_affects_featureversion', 'id'), (SELECT MAX(id) FROM vulnerability_affects_featureversion)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('layer_detector', 'id'), (SELECT MAX(id) FROM layer_detector)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('layer_lister', 'id'), (SELECT MAX(id) FROM layer_lister)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability', 'id'), (SELECT MAX(id) FROM vulnerability)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_affected_feature', 'id'), (SELECT MAX(id) FROM vulnerability_affected_feature)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_affected_namespaced_feature', 'id'), (SELECT MAX(id) FROM vulnerability_affected_namespaced_feature)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_notification', 'id'), (SELECT MAX(id) FROM vulnerability_notification)+1); \ No newline at end of file diff --git a/database/pgsql/vulnerability.go b/database/pgsql/vulnerability.go index efb5739271..fd08276059 100644 --- a/database/pgsql/vulnerability.go +++ b/database/pgsql/vulnerability.go @@ -15,354 +15,185 @@ package pgsql import ( - "database/sql" "encoding/json" - "reflect" + "errors" "time" - "github.com/guregu/null/zero" - log "github.com/sirupsen/logrus" + "database/sql" "github.com/coreos/clair/database" "github.com/coreos/clair/ext/versionfmt" "github.com/coreos/clair/pkg/commonerr" + "github.com/lib/pq" + "github.com/sirupsen/logrus" ) -// compareStringLists returns the strings that are present in X but not in Y. -func compareStringLists(X, Y []string) []string { - m := make(map[string]bool) - - for _, y := range Y { - m[y] = true - } - - diff := []string{} - for _, x := range X { - if m[x] { - continue - } - - diff = append(diff, x) - m[x] = true - } +var ( + errVulnerabilityNotFound = errors.New("vulnerability is not in database") +) - return diff +type affectRelation struct { + vulnerabilityID int + namespacedFeatureID int + addedBy int } -func compareStringListsInBoth(X, Y []string) []string { - m := make(map[string]struct{}) - - for _, y := range Y { - m[y] = struct{}{} - } - - diff := []string{} - for _, x := range X { - if _, e := m[x]; e { - diff = append(diff, x) - delete(m, x) - } - } - - return diff +type affectedFeatureRows struct { + rows map[int]database.AffectedFeature } -func (pgSQL *pgSQL) ListVulnerabilities(namespaceName string, limit int, startID int) ([]database.Vulnerability, int, error) { - defer observeQueryTime("listVulnerabilities", "all", time.Now()) - - // Query Namespace. - var id int - err := pgSQL.QueryRow(searchNamespace, namespaceName).Scan(&id) - if err != nil { - return nil, -1, handleError("searchNamespace", err) - } else if id == 0 { - return nil, -1, commonerr.ErrNotFound - } - - // Query. - query := searchVulnerabilityBase + searchVulnerabilityByNamespace - rows, err := pgSQL.Query(query, namespaceName, startID, limit+1) - if err != nil { - return nil, -1, handleError("searchVulnerabilityByNamespace", err) - } - defer rows.Close() +func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.VulnerabilityID) ([]database.NullableVulnerability, error) { + resultVuln := make([]database.NullableVulnerability, len(vulnerabilities)) + vulnIDMap := map[int64][]*database.NullableVulnerability{} + + // load vulnerabilities + for i, key := range vulnerabilities { + var ( + id sql.NullInt64 + vuln = database.NullableVulnerability{ + VulnerabilityWithAffected: database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: key.Name, + Namespace: database.Namespace{ + Name: key.Namespace, + }, + }, + }, + } + ) - var vulns []database.Vulnerability - nextID := -1 - size := 0 - // Scan query. - for rows.Next() { - var vulnerability database.Vulnerability - - err := rows.Scan( - &vulnerability.ID, - &vulnerability.Name, - &vulnerability.Namespace.ID, - &vulnerability.Namespace.Name, - &vulnerability.Namespace.VersionFormat, - &vulnerability.Description, - &vulnerability.Link, - &vulnerability.Severity, - &vulnerability.Metadata, + err := tx.QueryRow(searchVulnerability, key.Name, key.Namespace).Scan( + &id, + &vuln.Description, + &vuln.Link, + &vuln.Severity, + &vuln.Metadata, + &vuln.Namespace.VersionFormat, ) - if err != nil { - return nil, -1, handleError("searchVulnerabilityByNamespace.Scan()", err) + + if err != nil && err != sql.ErrNoRows { + return nil, handleError("searchVulnerability", err) } - size++ - if size > limit { - nextID = vulnerability.ID - } else { - vulns = append(vulns, vulnerability) + vuln.Valid = id.Valid + resultVuln[i] = vuln + if id.Valid { + vulnIDMap[id.Int64] = append(vulnIDMap[id.Int64], &resultVuln[i]) } } - if err := rows.Err(); err != nil { - return nil, -1, handleError("searchVulnerabilityByNamespace.Rows()", err) - } - - return vulns, nextID, nil -} - -func (pgSQL *pgSQL) FindVulnerability(namespaceName, name string) (database.Vulnerability, error) { - return findVulnerability(pgSQL, namespaceName, name, false) -} - -func findVulnerability(queryer Queryer, namespaceName, name string, forUpdate bool) (database.Vulnerability, error) { - defer observeQueryTime("findVulnerability", "all", time.Now()) - - queryName := "searchVulnerabilityBase+searchVulnerabilityByNamespaceAndName" - query := searchVulnerabilityBase + searchVulnerabilityByNamespaceAndName - if forUpdate { - queryName = queryName + "+searchVulnerabilityForUpdate" - query = query + searchVulnerabilityForUpdate - } - - return scanVulnerability(queryer, queryName, queryer.QueryRow(query, namespaceName, name)) -} - -func (pgSQL *pgSQL) findVulnerabilityByIDWithDeleted(id int) (database.Vulnerability, error) { - defer observeQueryTime("findVulnerabilityByIDWithDeleted", "all", time.Now()) - - queryName := "searchVulnerabilityBase+searchVulnerabilityByID" - query := searchVulnerabilityBase + searchVulnerabilityByID - - return scanVulnerability(pgSQL, queryName, pgSQL.QueryRow(query, id)) -} - -func scanVulnerability(queryer Queryer, queryName string, vulnerabilityRow *sql.Row) (database.Vulnerability, error) { - var vulnerability database.Vulnerability - - err := vulnerabilityRow.Scan( - &vulnerability.ID, - &vulnerability.Name, - &vulnerability.Namespace.ID, - &vulnerability.Namespace.Name, - &vulnerability.Namespace.VersionFormat, - &vulnerability.Description, - &vulnerability.Link, - &vulnerability.Severity, - &vulnerability.Metadata, - ) - - if err != nil { - return vulnerability, handleError(queryName+".Scan()", err) - } - - if vulnerability.ID == 0 { - return vulnerability, commonerr.ErrNotFound + toQuery := make([]int64, 0, len(vulnIDMap)) + for id := range vulnIDMap { + toQuery = append(toQuery, id) } - // Query the FixedIn FeatureVersion now. - rows, err := queryer.Query(searchVulnerabilityFixedIn, vulnerability.ID) + // load vulnerability affected features + rows, err := tx.Query(searchVulnerabilityAffected, pq.Array(toQuery)) if err != nil { - return vulnerability, handleError("searchVulnerabilityFixedIn.Scan()", err) + return nil, handleError("searchVulnerabilityAffected", err) } - defer rows.Close() for rows.Next() { - var featureVersionID zero.Int - var featureVersionVersion zero.String - var featureVersionFeatureName zero.String - - err := rows.Scan( - &featureVersionVersion, - &featureVersionID, - &featureVersionFeatureName, + var ( + id int64 + f database.AffectedFeature ) + err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FixedInVersion) if err != nil { - return vulnerability, handleError("searchVulnerabilityFixedIn.Scan()", err) + return nil, handleError("searchVulnerabilityAffected", err) } - if !featureVersionID.IsZero() { - // Note that the ID we fill in featureVersion is actually a Feature ID, and not - // a FeatureVersion ID. - featureVersion := database.FeatureVersion{ - Model: database.Model{ID: int(featureVersionID.Int64)}, - Feature: database.Feature{ - Model: database.Model{ID: int(featureVersionID.Int64)}, - Namespace: vulnerability.Namespace, - Name: featureVersionFeatureName.String, - }, - Version: featureVersionVersion.String, - } - vulnerability.FixedIn = append(vulnerability.FixedIn, featureVersion) + for _, vuln := range vulnIDMap[id] { + f.Namespace = vuln.Namespace + vuln.Affected = append(vuln.Affected, f) } } - if err := rows.Err(); err != nil { - return vulnerability, handleError("searchVulnerabilityFixedIn.Rows()", err) - } - - return vulnerability, nil -} - -// FixedIn.Namespace are not necessary, they are overwritten by the vuln. -// By setting the fixed version to minVersion, we can say that the vuln does'nt affect anymore. -func (pgSQL *pgSQL) InsertVulnerabilities(vulnerabilities []database.Vulnerability, generateNotifications bool) error { - for _, vulnerability := range vulnerabilities { - err := pgSQL.insertVulnerability(vulnerability, false, generateNotifications) - if err != nil { - return err - } - } - return nil + return resultVuln, nil } -func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability, onlyFixedIn, generateNotification bool) error { - tf := time.Now() - - // Verify parameters - if vulnerability.Name == "" || vulnerability.Namespace.Name == "" { - return commonerr.NewBadRequestError("insertVulnerability needs at least the Name and the Namespace") - } +func (tx *pgSession) InsertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) error { + var ( + // vCache is mapping from vulnerability ID to the affected features of that + // vulnerability. + vCache = map[int]affectedFeatureRows{} + vulnIDMap = map[database.VulnerabilityID]struct{}{} + vulnIDs = []database.VulnerabilityID{} + ) - for i := 0; i < len(vulnerability.FixedIn); i++ { - fifv := &vulnerability.FixedIn[i] - - if fifv.Feature.Namespace.Name == "" { - // As there is no Namespace on that FixedIn FeatureVersion, set it to the Vulnerability's - // Namespace. - fifv.Feature.Namespace = vulnerability.Namespace - } else if fifv.Feature.Namespace.Name != vulnerability.Namespace.Name { - msg := "could not insert an invalid vulnerability that contains FixedIn FeatureVersion that are not in the same namespace as the Vulnerability" - log.Warning(msg) - return commonerr.NewBadRequestError(msg) + // validating vulnerabilities, they should be NOT in the database and have + // valid fields. + for _, vuln := range vulnerabilities { + if vuln.Severity == "" || vuln.Name == "" || vuln.Namespace.Name == "" { + return commonerr.NewBadRequestError("Invalid vulnerability") } - } - // We do `defer observeQueryTime` here because we don't want to observe invalid vulnerabilities. - defer observeQueryTime("insertVulnerability", "all", tf) + key := database.VulnerabilityID{ + Name: vuln.Name, + Namespace: vuln.Namespace.Name, + } + if _, ok := vulnIDMap[key]; ok { + return commonerr.NewBadRequestError("duplicated vulnerability is not allowed") + } - // Begin transaction. - tx, err := pgSQL.Begin() - if err != nil { - tx.Rollback() - return handleError("insertVulnerability.Begin()", err) + vulnIDMap[key] = struct{}{} + vulnIDs = append(vulnIDs, key) } - // Find existing vulnerability and its Vulnerability_FixedIn_Features (for update). - existingVulnerability, err := findVulnerability(tx, vulnerability.Namespace.Name, vulnerability.Name, true) - if err != nil && err != commonerr.ErrNotFound { - tx.Rollback() + // query to ensure all the vulnerabilities don't exist + vulns, err := tx.FindVulnerabilities(vulnIDs) + if err != nil { return err } - if onlyFixedIn { - // Because this call tries to update FixedIn FeatureVersion, import all other data from the - // existing one. - if existingVulnerability.ID == 0 { - return commonerr.ErrNotFound + for _, vuln := range vulns { + if vuln.Valid { + return commonerr.NewBadRequestError("Vulnerability is already in database") } - - fixedIn := vulnerability.FixedIn - vulnerability = existingVulnerability - vulnerability.FixedIn = fixedIn } - if existingVulnerability.ID != 0 { - updateMetadata := vulnerability.Description != existingVulnerability.Description || - vulnerability.Link != existingVulnerability.Link || - vulnerability.Severity != existingVulnerability.Severity || - !reflect.DeepEqual(castMetadata(vulnerability.Metadata), existingVulnerability.Metadata) - - // Construct the entire list of FixedIn FeatureVersion, by using the - // the FixedIn list of the old vulnerability. - // - // TODO(Quentin-M): We could use !updateFixedIn to just copy FixedIn/Affects rows from the - // existing vulnerability in order to make metadata updates much faster. - var updateFixedIn bool - vulnerability.FixedIn, updateFixedIn = applyFixedInDiff(existingVulnerability.FixedIn, vulnerability.FixedIn) - - if !updateMetadata && !updateFixedIn { - tx.Commit() - return nil + // bulk insert vulnerabilities + for _, vuln := range vulnerabilities { + var vulnID int + err := tx.QueryRow(insertVulnerability, vuln.Name, vuln.Description, + vuln.Link, &vuln.Severity, &vuln.Metadata, + vuln.Namespace.Name, vuln.Namespace.VersionFormat).Scan(&vulnID) + if err != nil { + return handleError("insertVulnerability", err) } - // Mark the old vulnerability as non latest. - _, err = tx.Exec(removeVulnerability, vulnerability.Namespace.Name, vulnerability.Name) + features := affectedFeatureRows{rows: make(map[int]database.AffectedFeature)} + stmt, err := tx.Prepare(insertVulnerabilityAffected) if err != nil { - tx.Rollback() - return handleError("removeVulnerability", err) - } - } else { - // The vulnerability is new, we don't want to have any - // versionfmt.MinVersion as they are only used for diffing existing - // vulnerabilities. - var fixedIn []database.FeatureVersion - for _, fv := range vulnerability.FixedIn { - if fv.Version != versionfmt.MinVersion { - fixedIn = append(fixedIn, fv) - } + return handleError("insertVulnerabilityAffected", err) } - vulnerability.FixedIn = fixedIn - } - - // Find or insert Vulnerability's Namespace. - namespaceID, err := pgSQL.insertNamespace(vulnerability.Namespace) - if err != nil { - return err - } - // Insert vulnerability. - err = tx.QueryRow( - insertVulnerability, - namespaceID, - vulnerability.Name, - vulnerability.Description, - vulnerability.Link, - &vulnerability.Severity, - &vulnerability.Metadata, - ).Scan(&vulnerability.ID) + for _, f := range vuln.Affected { + // ensure vulnerability's affected features have the same namespace + // as the vulnerability. + if f.Namespace != vuln.Namespace { + stmt.Close() + return errors.New("Affected feature doesn't have the same namespace as the owner vulnerability") + } - if err != nil { - tx.Rollback() - return handleError("insertVulnerability", err) - } + var id int + err := stmt.QueryRow(vulnID, f.FeatureName, f.AffectedVersion, f.FixedInVersion).Scan(&id) + if err != nil { + stmt.Close() + return handleError("insertVulnerabilityAffected", err) + } - // Update Vulnerability_FixedIn_Feature and Vulnerability_Affects_FeatureVersion now. - err = pgSQL.insertVulnerabilityFixedInFeatureVersions(tx, vulnerability.ID, vulnerability.FixedIn) - if err != nil { - tx.Rollback() - return err - } + features.rows[id] = f + } - // Create a notification. - if generateNotification { - err = createNotification(tx, existingVulnerability.ID, vulnerability.ID) - if err != nil { - return err + if err := stmt.Close(); err != nil { + return handleError("insertVulnerabilityAffected", err) } - } - // Commit transaction. - err = tx.Commit() - if err != nil { - tx.Rollback() - return handleError("insertVulnerability.Commit()", err) + vCache[vulnID] = features } - return nil + return tx.cacheVulnerabiltyAffectedNamespacedFeature(vCache) } // castMetadata marshals the given database.MetadataMap and unmarshals it again to make sure that @@ -376,241 +207,210 @@ func castMetadata(m database.MetadataMap) database.MetadataMap { return c } -// applyFixedInDiff applies a FeatureVersion diff on a FeatureVersion list and returns the result. -func applyFixedInDiff(currentList, diff []database.FeatureVersion) ([]database.FeatureVersion, bool) { - currentMap, currentNames := createFeatureVersionNameMap(currentList) - diffMap, diffNames := createFeatureVersionNameMap(diff) - - addedNames := compareStringLists(diffNames, currentNames) - inBothNames := compareStringListsInBoth(diffNames, currentNames) - - different := false +// cacheVulnerabiltyAffectedNamespacedFeature takes in a map of +func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int]affectedFeatureRows) error { + var err error - for _, name := range addedNames { - if diffMap[name].Version == versionfmt.MinVersion { - // MinVersion only makes sense when a Feature is already fixed in some version, - // in which case we would be in the "inBothNames". - continue - } + // Lock Vulnerability_Affects_Feature exclusively. + // We want to prevent InsertFeature to modify it. + _, err = tx.Exec(lockVulnerabilityAffects) - currentMap[name] = diffMap[name] - different = true + if err != nil { + tx.Rollback() + return handleError("lockVulnerabilityAffects", err) } - for _, name := range inBothNames { - fv := diffMap[name] - - if fv.Version == versionfmt.MinVersion { - // MinVersion means that the Feature doesn't affect the Vulnerability anymore. - delete(currentMap, name) - different = true - } else if fv.Version != currentMap[name].Version { - // The version got updated. - currentMap[name] = diffMap[name] - different = true - } + vulnIDs := []int{} + for id := range affected { + vulnIDs = append(vulnIDs, id) } - // Convert currentMap to a slice and return it. - var newList []database.FeatureVersion - for _, fv := range currentMap { - newList = append(newList, fv) + // query for potentially affected namespaced features. The ones actually + // being affected will be computed by versionfmt.InRange using the + // vulnerability's affected + rows, err := tx.Query(searchVulnerabilityPotentialAffected, pq.Array(vulnIDs)) + if err != nil { + return handleError("searchVulnerabilityPotentialAffected", err) } - return newList, different -} + defer rows.Close() -func createFeatureVersionNameMap(features []database.FeatureVersion) (map[string]database.FeatureVersion, []string) { - m := make(map[string]database.FeatureVersion, 0) - s := make([]string, 0, len(features)) + relation := []affectRelation{} + for rows.Next() { + var ( + vulnID int + nsfID int + fVersion string + addedBy int + ) - for i := 0; i < len(features); i++ { - featureVersion := features[i] - m[featureVersion.Feature.Name] = featureVersion - s = append(s, featureVersion.Feature.Name) - } + err := rows.Scan(&vulnID, &nsfID, &fVersion, &addedBy) + if err != nil { + return handleError("searchVulnerabilityPotentialAffected", err) + } - return m, s -} + candidate, ok := affected[vulnID].rows[addedBy] -// insertVulnerabilityFixedInFeatureVersions populates Vulnerability_FixedIn_Feature for the given -// vulnerability with the specified database.FeatureVersion list and uses -// linkVulnerabilityToFeatureVersions to propagate the changes on Vulnerability_FixedIn_Feature to -// Vulnerability_Affects_FeatureVersion. -func (pgSQL *pgSQL) insertVulnerabilityFixedInFeatureVersions(tx *sql.Tx, vulnerabilityID int, fixedIn []database.FeatureVersion) error { - defer observeQueryTime("insertVulnerabilityFixedInFeatureVersions", "all", time.Now()) + if !ok { + return errors.New("vulnerability affected feature not found") + } - // Insert or find the Features. - // TODO(Quentin-M): Batch me. - var err error - var features []*database.Feature - for i := 0; i < len(fixedIn); i++ { - features = append(features, &fixedIn[i].Feature) - } - for _, feature := range features { - if feature.ID == 0 { - if feature.ID, err = pgSQL.insertFeature(*feature); err != nil { - return err + if in, err := versionfmt.InRange(candidate.Namespace.VersionFormat, + fVersion, + candidate.AffectedVersion); err == nil { + if in { + relation = append(relation, + affectRelation{ + vulnerabilityID: vulnID, + namespacedFeatureID: nsfID, + addedBy: addedBy, + }) } + } else { + return err } } - // Lock Vulnerability_Affects_FeatureVersion exclusively. - // We want to prevent InsertFeatureVersion to modify it. - promConcurrentLockVAFV.Inc() - defer promConcurrentLockVAFV.Dec() - t := time.Now() - _, err = tx.Exec(lockVulnerabilityAffects) - observeQueryTime("insertVulnerability", "lock", t) - - if err != nil { - tx.Rollback() - return handleError("insertVulnerability.lockVulnerabilityAffects", err) - } - - for _, fv := range fixedIn { - var fixedInID int - var created bool - - // Find or create entry in Vulnerability_FixedIn_Feature. - err = tx.QueryRow( - soiVulnerabilityFixedInFeature, - vulnerabilityID, fv.Feature.ID, - &fv.Version, - ).Scan(&created, &fixedInID) - + for _, r := range relation { + result, err := tx.Exec(insertVulnerabilityAffectedNamespacedFeature, r.vulnerabilityID, r.namespacedFeatureID, r.addedBy) if err != nil { - return handleError("insertVulnerabilityFixedInFeature", err) - } - - if !created { - // The relationship between the feature and the vulnerability already - // existed, no need to update Vulnerability_Affects_FeatureVersion. - continue + return handleError("insertVulnerabilityAffectedNamespacedFeature", err) } - // Insert Vulnerability_Affects_FeatureVersion. - err = linkVulnerabilityToFeatureVersions(tx, fixedInID, vulnerabilityID, fv.Feature.ID, fv.Feature.Namespace.VersionFormat, fv.Version) - if err != nil { + if num, err := result.RowsAffected(); err == nil { + if num <= 0 { + return errors.New("Nothing cached in database") + } + } else { return err } } + logrus.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", len(relation)) return nil } -func linkVulnerabilityToFeatureVersions(tx *sql.Tx, fixedInID, vulnerabilityID, featureID int, versionFormat, fixedInVersion string) error { - // Find every FeatureVersions of the Feature that the vulnerability affects. - // TODO(Quentin-M): LIMIT - rows, err := tx.Query(searchFeatureVersionByFeature, featureID) - if err != nil { - return handleError("searchFeatureVersionByFeature", err) - } - defer rows.Close() - - var affecteds []database.FeatureVersion - for rows.Next() { - var affected database.FeatureVersion - - err := rows.Scan(&affected.ID, &affected.Version) - if err != nil { - return handleError("searchFeatureVersionByFeature.Scan()", err) - } +func (tx *pgSession) DeleteVulnerabilities(vulnerabilities []database.VulnerabilityID) error { + defer observeQueryTime("DeleteVulnerability", "all", time.Now()) - cmp, err := versionfmt.Compare(versionFormat, affected.Version, fixedInVersion) + for _, vuln := range vulnerabilities { + r, err := tx.Exec(removeVulnerability, vuln.Namespace, vuln.Name) if err != nil { - return err - } - if cmp < 0 { - // The version of the FeatureVersion is lower than the fixed version of this vulnerability, - // thus, this FeatureVersion is affected by it. - affecteds = append(affecteds, affected) + return handleError("removeVulnerability", err) } - } - if err = rows.Err(); err != nil { - return handleError("searchFeatureVersionByFeature.Rows()", err) - } - rows.Close() - - // Insert into Vulnerability_Affects_FeatureVersion. - for _, affected := range affecteds { - // TODO(Quentin-M): Batch me. - _, err := tx.Exec(insertVulnerabilityAffectsFeatureVersion, vulnerabilityID, affected.ID, fixedInID) - if err != nil { - return handleError("insertVulnerabilityAffectsFeatureVersion", err) + if num, err := r.RowsAffected(); err != nil { + return handleError("removeVulnerability", err) + } else if num <= 0 { + return handleError("removeVulnerability", errors.New("Vulnerability to be removed is not in database")) } } return nil } -func (pgSQL *pgSQL) InsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName string, fixes []database.FeatureVersion) error { - defer observeQueryTime("InsertVulnerabilityFixes", "all", time.Now()) +// FindVulnerableAncestries joins vulnerability, +func (tx *pgSession) FindVulnerableAncestries(vulnID database.VulnerabilityID, limit int, page database.PageNumber) (database.PagedVulnerableAncestries, error) { + vulnAncestry := database.PagedVulnerableAncestries{ - v := database.Vulnerability{ - Name: vulnerabilityName, - Namespace: database.Namespace{ - Name: vulnerabilityNamespace, - }, - FixedIn: fixes, + Affected: make(map[int]string), + Limit: limit, + Current: page, } - return pgSQL.insertVulnerability(v, true, true) -} + if limit <= 0 { + return vulnAncestry, commonerr.NewBadRequestError("Page Limit should be greater than 0") + } -func (pgSQL *pgSQL) DeleteVulnerabilityFix(vulnerabilityNamespace, vulnerabilityName, featureName string) error { - defer observeQueryTime("DeleteVulnerabilityFix", "all", time.Now()) + if page == database.NoVulnerabilityNotificationPage { + return vulnAncestry, nil + } - v := database.Vulnerability{ - Name: vulnerabilityName, - Namespace: database.Namespace{ - Name: vulnerabilityNamespace, - }, - FixedIn: []database.FeatureVersion{ - { - Feature: database.Feature{ - Name: featureName, - Namespace: database.Namespace{ - Name: vulnerabilityNamespace, - }, - }, - Version: versionfmt.MinVersion, + vuln := database.VulnerabilityWithFixedIn{ + Vulnerability: database.Vulnerability{ + Name: vulnID.Name, + Namespace: database.Namespace{ + Name: vulnID.Namespace, }, }, } - return pgSQL.insertVulnerability(v, true, true) -} + var id int + err := tx.QueryRow(searchVulnerability, vulnID.Name, vulnID.Namespace).Scan(&id, &vuln.Description, &vuln.Link, &vuln.Severity, &vuln.Metadata, &vuln.Namespace.VersionFormat) + if err != nil { + return vulnAncestry, handleError("searchVulnerability", err) + } -func (pgSQL *pgSQL) DeleteVulnerability(namespaceName, name string) error { - defer observeQueryTime("DeleteVulnerability", "all", time.Now()) + vulnAncestry.VulnerabilityWithFixedIn = vuln - // Begin transaction. - tx, err := pgSQL.Begin() + // the last result is used for the next page's startID + rows, err := tx.Query(searchNotificationVulnerableAncestry, id, page.StartID, limit+1) if err != nil { - tx.Rollback() - return handleError("DeleteVulnerability.Begin()", err) + return vulnAncestry, handleError("searchNotificationVulnerableAncestry", err) } + defer rows.Close() - var vulnerabilityID int - err = tx.QueryRow(removeVulnerability, namespaceName, name).Scan(&vulnerabilityID) - if err != nil { - tx.Rollback() - return handleError("removeVulnerability", err) + type affectedAncestry struct { + name string + id int } - // Create a notification. - err = createNotification(tx, vulnerabilityID, 0) - if err != nil { - return err + ancestries := []affectedAncestry{} + for rows.Next() { + var ancestry affectedAncestry + err := rows.Scan(&ancestry.id, &ancestry.name) + if err != nil { + return vulnAncestry, handleError("searchNotificationVulnerableAncestry", err) + } + ancestries = append(ancestries, ancestry) } - // Commit transaction. - err = tx.Commit() - if err != nil { - tx.Rollback() - return handleError("DeleteVulnerability.Commit()", err) + // last ancestry's index in ancestries array + lastIndex := 0 + if len(ancestries)-1 < limit { + lastIndex = len(ancestries) + vulnAncestry.Next = database.NoVulnerabilityNotificationPage + } else { + lastIndex = len(ancestries) - 1 + vulnAncestry.Next = database.PageNumber{StartID: ancestries[len(ancestries)-1].id} } - return nil + for _, ancestry := range ancestries[0:lastIndex] { + vulnAncestry.Affected[ancestry.id] = ancestry.name + } + + return vulnAncestry, nil +} + +func (tx *pgSession) findLatestDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) { + ids := []sql.NullInt64{} + + for _, vulnID := range vulnIDs { + var id sql.NullInt64 + err := tx.QueryRow(searchLatestDeletedVulnerabilityID, vulnID.Name, vulnID.Namespace).Scan(&id) + if err != nil && err != sql.ErrNoRows { + return nil, handleError("searchLatestDeletedVulnerabilityID", err) + } else if err == sql.ErrNoRows { + id.Valid = false + } + ids = append(ids, id) + } + + return ids, nil +} + +func (tx *pgSession) findNotDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) { + ids := []sql.NullInt64{} + + for _, vulnID := range vulnIDs { + var id sql.NullInt64 + err := tx.QueryRow(searchNotDeletedVulnerabilityID, vulnID.Name, vulnID.Namespace).Scan(&id) + if err != nil && err != sql.ErrNoRows { + return nil, handleError("searchNotDeletedVulnerabilityID", err) + } else if err == sql.ErrNoRows { + id.Valid = false + } + ids = append(ids, id) + } + + return ids, nil } diff --git a/database/pgsql/vulnerability_test.go b/database/pgsql/vulnerability_test.go index 61d835bbb5..eaf9f51991 100644 --- a/database/pgsql/vulnerability_test.go +++ b/database/pgsql/vulnerability_test.go @@ -15,282 +15,340 @@ package pgsql import ( - "reflect" + "fmt" "testing" - "github.com/stretchr/testify/assert" - "github.com/coreos/clair/database" - "github.com/coreos/clair/ext/versionfmt" "github.com/coreos/clair/ext/versionfmt/dpkg" - "github.com/coreos/clair/pkg/commonerr" + "github.com/stretchr/testify/assert" ) -func TestFindVulnerability(t *testing.T) { - datastore, err := openDatabaseForTest("FindVulnerability", true) - if err != nil { - t.Error(err) - return +func TestInsertVulnerabilities(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistNamespacedFeatures", true) + defer closeTest(t, datastore, tx) + + ns1 := database.Namespace{ + Name: "name", + VersionFormat: "random stuff", } - defer datastore.Close() - // Find a vulnerability that does not exist. - _, err = datastore.FindVulnerability("", "") - assert.Equal(t, commonerr.ErrNotFound, err) + ns2 := database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + } - // Find a normal vulnerability. + // invalid vulnerability v1 := database.Vulnerability{ - Name: "CVE-OPENSSL-1-DEB7", - Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", - Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7", - Severity: database.HighSeverity, - Namespace: database.Namespace{ - Name: "debian:7", - VersionFormat: dpkg.ParserName, - }, - FixedIn: []database.FeatureVersion{ - { - Feature: database.Feature{Name: "openssl"}, - Version: "2.0", - }, - { - Feature: database.Feature{Name: "libssl"}, - Version: "1.9-abc", - }, - }, + Name: "invalid", + Namespace: ns1, } - v1f, err := datastore.FindVulnerability("debian:7", "CVE-OPENSSL-1-DEB7") - if assert.Nil(t, err) { - equalsVuln(t, &v1, &v1f) + vwa1 := database.VulnerabilityWithAffected{ + Vulnerability: v1, } - - // Find a vulnerability that has no link, no severity and no FixedIn. + // valid vulnerability v2 := database.Vulnerability{ - Name: "CVE-NOPE", - Description: "A vulnerability affecting nothing", - Namespace: database.Namespace{ - Name: "debian:7", - VersionFormat: dpkg.ParserName, - }, - Severity: database.UnknownSeverity, + Name: "valid", + Namespace: ns2, + Severity: database.UnknownSeverity, } - v2f, err := datastore.FindVulnerability("debian:7", "CVE-NOPE") - if assert.Nil(t, err) { - equalsVuln(t, &v2, &v2f) + vwa2 := database.VulnerabilityWithAffected{ + Vulnerability: v2, } -} -func TestDeleteVulnerability(t *testing.T) { - datastore, err := openDatabaseForTest("InsertVulnerability", true) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() + // empty + err := tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{}) + assert.Nil(t, err) + // invalid content: vwa1 is invalid + err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa1, vwa2}) + assert.NotNil(t, err) - // Delete non-existing Vulnerability. - err = datastore.DeleteVulnerability("TestDeleteVulnerabilityNamespace1", "CVE-OPENSSL-1-DEB7") - assert.Equal(t, commonerr.ErrNotFound, err) - err = datastore.DeleteVulnerability("debian:7", "TestDeleteVulnerabilityVulnerability1") - assert.Equal(t, commonerr.ErrNotFound, err) + // invalid content: duplicated input + err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2, vwa2}) + assert.NotNil(t, err) - // Delete Vulnerability. - err = datastore.DeleteVulnerability("debian:7", "CVE-OPENSSL-1-DEB7") - if assert.Nil(t, err) { - _, err := datastore.FindVulnerability("debian:7", "CVE-OPENSSL-1-DEB7") - assert.Equal(t, commonerr.ErrNotFound, err) + // valid content + err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2}) + assert.Nil(t, err) + + // ensure the content is in database + vulns, err := tx.FindVulnerabilities([]database.VulnerabilityID{{Name: "valid", Namespace: "debian:7"}}) + if assert.Nil(t, err) && assert.Len(t, vulns, 1) { + assert.True(t, vulns[0].Valid) } + + // invalid content: vwa2 is already in database + err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2}) + fmt.Println(err) + assert.NotNil(t, err) + + // valid content: vwa2 removed and inserted + err = tx.DeleteVulnerabilities([]database.VulnerabilityID{{Name: vwa2.Name, Namespace: vwa2.Namespace.Name}}) + assert.Nil(t, err) + + err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2}) + assert.Nil(t, err) } -func TestInsertVulnerability(t *testing.T) { - datastore, err := openDatabaseForTest("InsertVulnerability", false) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() +func TestCachingVulnerable(t *testing.T) { + datastore, tx := openSessionForTest(t, "CachingVulnerable", true) + defer closeTest(t, datastore, tx) - // Create some data. - n1 := database.Namespace{ - Name: "TestInsertVulnerabilityNamespace1", - VersionFormat: dpkg.ParserName, - } - n2 := database.Namespace{ - Name: "TestInsertVulnerabilityNamespace2", + ns := database.Namespace{ + Name: "debian:8", VersionFormat: dpkg.ParserName, } - f1 := database.FeatureVersion{ + f := database.NamespacedFeature{ Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion1", - Namespace: n1, + Name: "openssl", + Version: "1.0", + VersionFormat: dpkg.ParserName, }, - Version: "1.0", + Namespace: ns, } - f2 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion1", - Namespace: n2, + + vuln := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: "CVE-YAY", + Namespace: ns, + Severity: database.HighSeverity, }, - Version: "1.0", - } - f3 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion2", + Affected: []database.AffectedFeature{ + { + Namespace: ns, + FeatureName: "openssl", + AffectedVersion: "2.0", + FixedInVersion: "2.1", + }, }, - Version: versionfmt.MaxVersion, } - f4 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion2", + + vuln2 := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: "CVE-YAY2", + Namespace: ns, + Severity: database.HighSeverity, }, - Version: "1.4", - } - f5 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion3", + Affected: []database.AffectedFeature{ + { + Namespace: ns, + FeatureName: "openssl", + AffectedVersion: "2.1", + FixedInVersion: "2.2", + }, }, - Version: "1.5", } - f6 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion4", + + vulnFixed1 := database.VulnerabilityWithFixedIn{ + Vulnerability: database.Vulnerability{ + Name: "CVE-YAY", + Namespace: ns, + Severity: database.HighSeverity, }, - Version: "0.1", + FixedInVersion: "2.1", } - f7 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion5", + + vulnFixed2 := database.VulnerabilityWithFixedIn{ + Vulnerability: database.Vulnerability{ + Name: "CVE-YAY2", + Namespace: ns, + Severity: database.HighSeverity, }, - Version: versionfmt.MaxVersion, + FixedInVersion: "2.2", } - f8 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion5", - }, - Version: versionfmt.MinVersion, + + if !assert.Nil(t, tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vuln, vuln2})) { + t.FailNow() } - // Insert invalid vulnerabilities. - for _, vulnerability := range []database.Vulnerability{ - { - Name: "", - Namespace: n1, - FixedIn: []database.FeatureVersion{f1}, - Severity: database.UnknownSeverity, - }, - { - Name: "TestInsertVulnerability0", - Namespace: database.Namespace{}, - FixedIn: []database.FeatureVersion{f1}, - Severity: database.UnknownSeverity, - }, + r, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{f}) + assert.Nil(t, err) + assert.Len(t, r, 1) + for _, anf := range r { + if assert.True(t, anf.Valid) && assert.Len(t, anf.AffectedBy, 2) { + for _, a := range anf.AffectedBy { + if a.Name == "CVE-YAY" { + assert.Equal(t, vulnFixed1, a) + } else if a.Name == "CVE-YAY2" { + assert.Equal(t, vulnFixed2, a) + } else { + t.FailNow() + } + } + } + } +} + +func TestFindVulnerabilities(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistNamespacedFeatures", true) + defer closeTest(t, datastore, tx) + + vuln, err := tx.FindVulnerabilities([]database.VulnerabilityID{ + {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, + {Name: "CVE-NOPE", Namespace: "debian:7"}, + {Name: "CVE-NOT HERE"}, + }) + + ns := database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + } + + expectedExisting := []database.VulnerabilityWithAffected{ { - Name: "TestInsertVulnerability0-", - Namespace: database.Namespace{}, - FixedIn: []database.FeatureVersion{f1}, + Vulnerability: database.Vulnerability{ + Namespace: ns, + Name: "CVE-OPENSSL-1-DEB7", + Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", + Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7", + Severity: database.HighSeverity, + }, + Affected: []database.AffectedFeature{ + { + FeatureName: "openssl", + AffectedVersion: "2.0", + FixedInVersion: "2.0", + Namespace: ns, + }, + { + FeatureName: "libssl", + AffectedVersion: "1.9-abc", + FixedInVersion: "1.9-abc", + Namespace: ns, + }, + }, }, { - Name: "TestInsertVulnerability0", - Namespace: n1, - FixedIn: []database.FeatureVersion{f2}, - Severity: database.UnknownSeverity, + Vulnerability: database.Vulnerability{ + Namespace: ns, + Name: "CVE-NOPE", + Description: "A vulnerability affecting nothing", + Severity: database.UnknownSeverity, + }, }, - } { - err := datastore.InsertVulnerabilities([]database.Vulnerability{vulnerability}, true) - assert.Error(t, err) } - // Insert a simple vulnerability and find it. - v1meta := make(map[string]interface{}) - v1meta["TestInsertVulnerabilityMetadata1"] = "TestInsertVulnerabilityMetadataValue1" - v1meta["TestInsertVulnerabilityMetadata2"] = struct { - Test string - }{ - Test: "TestInsertVulnerabilityMetadataValue1", + expectedExistingMap := map[database.VulnerabilityID]database.VulnerabilityWithAffected{} + for _, v := range expectedExisting { + expectedExistingMap[database.VulnerabilityID{Name: v.Name, Namespace: v.Namespace.Name}] = v } - v1 := database.Vulnerability{ - Name: "TestInsertVulnerability1", - Namespace: n1, - FixedIn: []database.FeatureVersion{f1, f3, f6, f7}, - Severity: database.LowSeverity, - Description: "TestInsertVulnerabilityDescription1", - Link: "TestInsertVulnerabilityLink1", - Metadata: v1meta, + nonexisting := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{Name: "CVE-NOT HERE"}, } - err = datastore.InsertVulnerabilities([]database.Vulnerability{v1}, true) + if assert.Nil(t, err) { - v1f, err := datastore.FindVulnerability(n1.Name, v1.Name) - if assert.Nil(t, err) { - equalsVuln(t, &v1, &v1f) + for _, v := range vuln { + if v.Valid { + key := database.VulnerabilityID{ + Name: v.Name, + Namespace: v.Namespace.Name, + } + + expected, ok := expectedExistingMap[key] + if assert.True(t, ok, "vulnerability not found: "+key.Name+":"+key.Namespace) { + assertVulnerabilityWithAffectedEqual(t, expected, v.VulnerabilityWithAffected) + } + } else if !assert.Equal(t, nonexisting, v.VulnerabilityWithAffected) { + t.FailNow() + } } } - // Update vulnerability. - v1.Description = "TestInsertVulnerabilityLink2" - v1.Link = "TestInsertVulnerabilityLink2" - v1.Severity = database.HighSeverity - // Update f3 in f4, add fixed in f5, add fixed in f6 which already exists, - // removes fixed in f7 by adding f8 which is f7 but with MinVersion, and - // add fixed by f5 a second time (duplicated). - v1.FixedIn = []database.FeatureVersion{f4, f5, f6, f8, f5} + // same vulnerability + r, err := tx.FindVulnerabilities([]database.VulnerabilityID{ + {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, + {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, + }) - err = datastore.InsertVulnerabilities([]database.Vulnerability{v1}, true) if assert.Nil(t, err) { - v1f, err := datastore.FindVulnerability(n1.Name, v1.Name) - if assert.Nil(t, err) { - // Remove f8 from the struct for comparison as it was just here to cancel f7. - // Remove one of the f5 too as it was twice in the struct but the database - // implementation should have dedup'd it. - v1.FixedIn = v1.FixedIn[:len(v1.FixedIn)-2] - - // We already had f1 before the update. - // Add it to the struct for comparison. - v1.FixedIn = append(v1.FixedIn, f1) - - equalsVuln(t, &v1, &v1f) + for _, vuln := range r { + if assert.True(t, vuln.Valid) { + expected, _ := expectedExistingMap[database.VulnerabilityID{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}] + assertVulnerabilityWithAffectedEqual(t, expected, vuln.VulnerabilityWithAffected) + } } } } -func equalsVuln(t *testing.T, expected, actual *database.Vulnerability) { - assert.Equal(t, expected.Name, actual.Name) - assert.Equal(t, expected.Namespace.Name, actual.Namespace.Name) - assert.Equal(t, expected.Description, actual.Description) - assert.Equal(t, expected.Link, actual.Link) - assert.Equal(t, expected.Severity, actual.Severity) - assert.True(t, reflect.DeepEqual(castMetadata(expected.Metadata), actual.Metadata), "Got metadata %#v, expected %#v", actual.Metadata, castMetadata(expected.Metadata)) - - if assert.Len(t, actual.FixedIn, len(expected.FixedIn)) { - for _, actualFeatureVersion := range actual.FixedIn { - found := false - for _, expectedFeatureVersion := range expected.FixedIn { - if expectedFeatureVersion.Feature.Name == actualFeatureVersion.Feature.Name { - found = true - - assert.Equal(t, expected.Namespace.Name, actualFeatureVersion.Feature.Namespace.Name) - assert.Equal(t, expectedFeatureVersion.Version, actualFeatureVersion.Version) - } - } - if !found { - t.Errorf("unexpected package %s in %s", actualFeatureVersion.Feature.Name, expected.Name) - } +func TestDeleteVulnerabilities(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistNamespacedFeatures", true) + defer closeTest(t, datastore, tx) + + remove := []database.VulnerabilityID{} + // empty case + assert.Nil(t, tx.DeleteVulnerabilities(remove)) + // invalid case + remove = append(remove, database.VulnerabilityID{}) + assert.NotNil(t, tx.DeleteVulnerabilities(remove)) + + // valid case + validRemove := []database.VulnerabilityID{ + {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, + {Name: "CVE-NOPE", Namespace: "debian:7"}, + } + + assert.Nil(t, tx.DeleteVulnerabilities(validRemove)) + vuln, err := tx.FindVulnerabilities(validRemove) + if assert.Nil(t, err) { + for _, v := range vuln { + assert.False(t, v.Valid) + } + } +} + +func TestFindVulnerableAncestries(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistNamespacedFeatures", true) + defer closeTest(t, datastore, tx) + + // empty vulnerable + vulnID := database.VulnerabilityID{ + Name: "CVE-OPENSSL-1-DEB7", + Namespace: "debian:7", + } + + vulnerablePage, err := tx.FindVulnerableAncestries(vulnID, 1, database.VulnerabilityNotificationFirstPage) + if assert.Nil(t, err) { + assert.Len(t, vulnerablePage.Affected, 1) + a, ok := vulnerablePage.Affected[3] + if assert.True(t, ok) { + assert.Equal(t, "ancestry-3", a) + } + assert.Equal(t, 4, vulnerablePage.Next.StartID) + } + + vulnerablePage, err = tx.FindVulnerableAncestries(vulnID, 1, vulnerablePage.Next) + if assert.Nil(t, err) { + assert.Len(t, vulnerablePage.Affected, 1) + a, ok := vulnerablePage.Affected[4] + if assert.True(t, ok) { + assert.Equal(t, "ancestry-4", a) } + assert.Equal(t, database.NoVulnerabilityNotificationPage, vulnerablePage.Next) } } -func TestStringComparison(t *testing.T) { - cmp := compareStringLists([]string{"a", "b", "b", "a"}, []string{"a", "c"}) - assert.Len(t, cmp, 1) - assert.NotContains(t, cmp, "a") - assert.Contains(t, cmp, "b") - - cmp = compareStringListsInBoth([]string{"a", "a", "b", "c"}, []string{"a", "c", "c"}) - assert.Len(t, cmp, 2) - assert.NotContains(t, cmp, "b") - assert.Contains(t, cmp, "a") - assert.Contains(t, cmp, "c") +func assertVulnerabilityWithAffectedEqual(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool { + return assert.Equal(t, expected.Vulnerability, actual.Vulnerability) && assertAffectedFeaturesEqual(t, expected.Affected, actual.Affected) +} + +func assertAffectedFeaturesEqual(t *testing.T, expected []database.AffectedFeature, actual []database.AffectedFeature) bool { + if assert.Len(t, actual, len(expected)) { + has := map[database.AffectedFeature]bool{} + for _, i := range expected { + has[i] = false + } + for _, i := range actual { + if visited, ok := has[i]; !ok { + return false + } else if visited { + return false + } + has[i] = true + } + return true + } + return false }