From e7b960c05b19d3bca0e1027071cacbbcf131c1a5 Mon Sep 17 00:00:00 2001 From: Quentin Machu Date: Mon, 2 May 2016 18:33:03 -0400 Subject: [PATCH] database: Allow specifying datastore driver by config Fixes #145 --- clair.go | 4 +- cmd/clair/main.go | 6 +- config.example.yaml | 15 +- config/config.go | 27 ++-- config/config_test.go | 81 ---------- database/database.go | 35 +++- database/pgsql/complex_test.go | 7 +- database/pgsql/feature_test.go | 5 +- database/pgsql/keyvalue_test.go | 2 +- database/pgsql/layer.go | 4 +- database/pgsql/layer_test.go | 7 +- database/pgsql/lock_test.go | 2 +- database/pgsql/namespace_test.go | 7 +- database/pgsql/notification_test.go | 5 +- database/pgsql/pgsql.go | 232 ++++++++++++++++----------- database/pgsql/pgsql_test.go | 45 ++++++ database/pgsql/vulnerability_test.go | 9 +- 17 files changed, 264 insertions(+), 229 deletions(-) delete mode 100644 config/config_test.go create mode 100644 database/pgsql/pgsql_test.go diff --git a/clair.go b/clair.go index cfb5c8500e..788ab617d4 100644 --- a/clair.go +++ b/clair.go @@ -26,7 +26,7 @@ import ( "github.com/coreos/clair/api" "github.com/coreos/clair/api/context" "github.com/coreos/clair/config" - "github.com/coreos/clair/database/pgsql" + "github.com/coreos/clair/database" "github.com/coreos/clair/notifier" "github.com/coreos/clair/updater" "github.com/coreos/clair/utils" @@ -42,7 +42,7 @@ func Boot(config *config.Config) { st := utils.NewStopper() // Open database - db, err := pgsql.Open(config.Database) + db, err := database.Open(config.Database) if err != nil { log.Fatal(err) } diff --git a/cmd/clair/main.go b/cmd/clair/main.go index e04ddbdd08..df3d0ba2be 100644 --- a/cmd/clair/main.go +++ b/cmd/clair/main.go @@ -20,11 +20,11 @@ import ( "runtime/pprof" "strings" + "github.com/coreos/pkg/capnslog" + "github.com/coreos/clair" "github.com/coreos/clair/config" - "github.com/coreos/pkg/capnslog" - // Register components _ "github.com/coreos/clair/notifier/notifiers" @@ -43,6 +43,8 @@ import ( _ "github.com/coreos/clair/worker/detectors/namespace/lsbrelease" _ "github.com/coreos/clair/worker/detectors/namespace/osrelease" _ "github.com/coreos/clair/worker/detectors/namespace/redhatrelease" + + _ "github.com/coreos/clair/database/pgsql" ) var log = capnslog.NewPackageLogger("github.com/coreos/clair/cmd/clair", "main") diff --git a/config.example.yaml b/config.example.yaml index 166141116a..0e65926476 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -15,13 +15,16 @@ # The values specified here are the default values that Clair uses if no configuration file is specified or if the keys are not defined. clair: database: - # PostgreSQL Connection string - # http://www.postgresql.org/docs/9.4/static/libpq-connect.html - source: + # Database driver + type: pgsql + options: + # PostgreSQL Connection string + # http://www.postgresql.org/docs/9.4/static/libpq-connect.html + source: - # Number of elements kept in the cache - # Values unlikely to change (e.g. namespaces) are cached in order to save prevent needless roundtrips to the database. - cacheSize: 16384 + # Number of elements kept in the cache + # Values unlikely to change (e.g. namespaces) are cached in order to save prevent needless roundtrips to the database. + cachesize: 16384 api: # API server port diff --git a/config/config.go b/config/config.go index 7e45c47abf..9996c6514e 100644 --- a/config/config.go +++ b/config/config.go @@ -27,6 +27,14 @@ import ( // ErrDatasourceNotLoaded is returned when the datasource variable in the configuration file is not loaded properly var ErrDatasourceNotLoaded = errors.New("could not load configuration: no database source specified") +// RegistrableComponentConfig is a configuration block that can be used to +// determine which registrable component should be initialized and pass +// custom configuration to it. +type RegistrableComponentConfig struct { + Type string + Options map[string]interface{} +} + // File represents a YAML configuration file that namespaces all Clair // configuration under the top-level "clair" key. type File struct { @@ -35,19 +43,12 @@ type File struct { // Config is the global configuration for an instance of Clair. type Config struct { - Database *DatabaseConfig + Database RegistrableComponentConfig Updater *UpdaterConfig Notifier *NotifierConfig API *APIConfig } -// DatabaseConfig is the configuration used to specify how Clair connects -// to a database. -type DatabaseConfig struct { - Source string - CacheSize int -} - // UpdaterConfig is the configuration for the Updater service. type UpdaterConfig struct { Interval time.Duration @@ -72,8 +73,8 @@ type APIConfig struct { // DefaultConfig is a configuration that can be used as a fallback value. func DefaultConfig() Config { return Config{ - Database: &DatabaseConfig{ - CacheSize: 16384, + Database: RegistrableComponentConfig{ + Type: "pgsql", }, Updater: &UpdaterConfig{ Interval: 1 * time.Hour, @@ -116,12 +117,8 @@ func Load(path string) (config *Config, err error) { } config = &cfgFile.Clair - if config.Database.Source == "" { - err = ErrDatasourceNotLoaded - return - } - // Generate a pagination key if none is provided. + // TODO(Quentin-M): Move to the API code. if config.API.PaginationKey == "" { var key fernet.Key if err = key.Generate(); err != nil { diff --git a/config/config_test.go b/config/config_test.go deleted file mode 100644 index dec03fd244..0000000000 --- a/config/config_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package config - -import ( - "io/ioutil" - "log" - "os" - "testing" - - "github.com/stretchr/testify/assert" -) - -const wrongConfig = ` -dummyKey: - wrong:true -` - -const goodConfig = ` -clair: - database: - source: postgresql://postgres:root@postgres:5432?sslmode=disable - cacheSize: 16384 - api: - port: 6060 - healthport: 6061 - timeout: 900s - paginationKey: - servername: - cafile: - keyfile: - certfile: - updater: - interval: 2h - notifier: - attempts: 3 - renotifyInterval: 2h - http: - endpoint: - servername: - cafile: - keyfile: - certfile: - proxy: -` - -func TestLoadWrongConfiguration(t *testing.T) { - tmpfile, err := ioutil.TempFile("", "clair-config") - if err != nil { - log.Fatal(err) - } - - defer os.Remove(tmpfile.Name()) // clean up - if _, err := tmpfile.Write([]byte(wrongConfig)); err != nil { - log.Fatal(err) - } - if err := tmpfile.Close(); err != nil { - log.Fatal(err) - } - - _, err = Load(tmpfile.Name()) - - assert.EqualError(t, err, ErrDatasourceNotLoaded.Error()) -} - -func TestLoad(t *testing.T) { - tmpfile, err := ioutil.TempFile("", "clair-config") - if err != nil { - log.Fatal(err) - } - - defer os.Remove(tmpfile.Name()) // clean up - - if _, err := tmpfile.Write([]byte(goodConfig)); err != nil { - log.Fatal(err) - } - if err := tmpfile.Close(); err != nil { - log.Fatal(err) - } - - _, err = Load(tmpfile.Name()) - assert.NoError(t, err) -} diff --git a/database/database.go b/database/database.go index 0c84f99eca..4ca13e42c8 100644 --- a/database/database.go +++ b/database/database.go @@ -17,7 +17,10 @@ package database import ( "errors" + "fmt" "time" + + "github.com/coreos/clair/config" ) var ( @@ -28,11 +31,37 @@ var ( // ErrInconsistent is an error that occurs when a database consistency check // fails (ie. when an entity which is supposed to be unique is detected twice) ErrInconsistent = errors.New("database: inconsistent database") - - // ErrCantOpen is an error that occurs when the database could not be opened - ErrCantOpen = errors.New("database: could not open database") ) +var drivers = make(map[string]Driver) + +// Driver is a function that opens a Datastore specified by its database driver type and specific +// configuration. +type Driver func(config.RegistrableComponentConfig) (Datastore, error) + +// Register makes a Constructor available by the provided name. +// +// If this function is called twice with the same name or if the Constructor is +// nil, it panics. +func Register(name string, driver Driver) { + if driver == nil { + panic("database: could not register nil Driver") + } + if _, dup := drivers[name]; dup { + panic("database: could not register duplicate Driver: " + name) + } + drivers[name] = driver +} + +// Open opens a Datastore specified by a configuration. +func Open(cfg config.RegistrableComponentConfig) (Datastore, error) { + driver, ok := drivers[cfg.Type] + if !ok { + return nil, fmt.Errorf("database: unknown Driver %q (forgotten configuration or import?)", cfg.Type) + } + return driver(cfg) +} + // Datastore is the interface that describes a database backend implementation. type Datastore interface { // # Namespace diff --git a/database/pgsql/complex_test.go b/database/pgsql/complex_test.go index 9660a5ecf9..46ba504ab4 100644 --- a/database/pgsql/complex_test.go +++ b/database/pgsql/complex_test.go @@ -23,11 +23,12 @@ import ( "testing" "time" + "github.com/pborman/uuid" + "github.com/stretchr/testify/assert" + "github.com/coreos/clair/database" "github.com/coreos/clair/utils" "github.com/coreos/clair/utils/types" - "github.com/pborman/uuid" - "github.com/stretchr/testify/assert" ) const ( @@ -36,7 +37,7 @@ const ( ) func TestRaceAffects(t *testing.T) { - datastore, err := OpenForTest("RaceAffects", false) + datastore, err := openDatabaseForTest("RaceAffects", false) if err != nil { t.Error(err) return diff --git a/database/pgsql/feature_test.go b/database/pgsql/feature_test.go index 251fca6c26..a857c30f6b 100644 --- a/database/pgsql/feature_test.go +++ b/database/pgsql/feature_test.go @@ -17,13 +17,14 @@ package pgsql import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/coreos/clair/database" "github.com/coreos/clair/utils/types" - "github.com/stretchr/testify/assert" ) func TestInsertFeature(t *testing.T) { - datastore, err := OpenForTest("InsertFeature", false) + datastore, err := openDatabaseForTest("InsertFeature", false) if err != nil { t.Error(err) return diff --git a/database/pgsql/keyvalue_test.go b/database/pgsql/keyvalue_test.go index 3897f4037f..5c7cc43f16 100644 --- a/database/pgsql/keyvalue_test.go +++ b/database/pgsql/keyvalue_test.go @@ -21,7 +21,7 @@ import ( ) func TestKeyValue(t *testing.T) { - datastore, err := OpenForTest("KeyValue", false) + datastore, err := openDatabaseForTest("KeyValue", false) if err != nil { t.Error(err) return diff --git a/database/pgsql/layer.go b/database/pgsql/layer.go index 71e5726a51..607d2b5a5f 100644 --- a/database/pgsql/layer.go +++ b/database/pgsql/layer.go @@ -41,9 +41,7 @@ func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities boo var namespaceName sql.NullString t := time.Now() - err := pgSQL.QueryRow(searchLayer, name). - Scan(&layer.ID, &layer.Name, &layer.EngineVersion, &parentID, &parentName, &namespaceID, - &namespaceName) + err := pgSQL.QueryRow(searchLayer, name).Scan(&layer.ID, &layer.Name, &layer.EngineVersion, &parentID, &parentName, &namespaceID, &namespaceName) observeQueryTime("FindLayer", "searchLayer", t) if err != nil { diff --git a/database/pgsql/layer_test.go b/database/pgsql/layer_test.go index dec0384ce1..d79f442a3f 100644 --- a/database/pgsql/layer_test.go +++ b/database/pgsql/layer_test.go @@ -18,14 +18,15 @@ import ( "fmt" "testing" + "github.com/stretchr/testify/assert" + "github.com/coreos/clair/database" cerrors "github.com/coreos/clair/utils/errors" "github.com/coreos/clair/utils/types" - "github.com/stretchr/testify/assert" ) func TestFindLayer(t *testing.T) { - datastore, err := OpenForTest("FindLayer", true) + datastore, err := openDatabaseForTest("FindLayer", true) if err != nil { t.Error(err) return @@ -102,7 +103,7 @@ func TestFindLayer(t *testing.T) { } func TestInsertLayer(t *testing.T) { - datastore, err := OpenForTest("InsertLayer", false) + datastore, err := openDatabaseForTest("InsertLayer", false) if err != nil { t.Error(err) return diff --git a/database/pgsql/lock_test.go b/database/pgsql/lock_test.go index e401e30a17..f2330f6f2f 100644 --- a/database/pgsql/lock_test.go +++ b/database/pgsql/lock_test.go @@ -22,7 +22,7 @@ import ( ) func TestLock(t *testing.T) { - datastore, err := OpenForTest("InsertNamespace", false) + datastore, err := openDatabaseForTest("InsertNamespace", false) if err != nil { t.Error(err) return diff --git a/database/pgsql/namespace_test.go b/database/pgsql/namespace_test.go index be8de68621..b9bf96fe33 100644 --- a/database/pgsql/namespace_test.go +++ b/database/pgsql/namespace_test.go @@ -18,12 +18,13 @@ import ( "fmt" "testing" - "github.com/coreos/clair/database" "github.com/stretchr/testify/assert" + + "github.com/coreos/clair/database" ) func TestInsertNamespace(t *testing.T) { - datastore, err := OpenForTest("InsertNamespace", false) + datastore, err := openDatabaseForTest("InsertNamespace", false) if err != nil { t.Error(err) return @@ -44,7 +45,7 @@ func TestInsertNamespace(t *testing.T) { } func TestListNamespace(t *testing.T) { - datastore, err := OpenForTest("ListNamespaces", true) + datastore, err := openDatabaseForTest("ListNamespaces", true) if err != nil { t.Error(err) return diff --git a/database/pgsql/notification_test.go b/database/pgsql/notification_test.go index 787cf567b4..6ab32d6d94 100644 --- a/database/pgsql/notification_test.go +++ b/database/pgsql/notification_test.go @@ -4,14 +4,15 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/coreos/clair/database" cerrors "github.com/coreos/clair/utils/errors" "github.com/coreos/clair/utils/types" - "github.com/stretchr/testify/assert" ) func TestNotification(t *testing.T) { - datastore, err := OpenForTest("Notification", false) + datastore, err := openDatabaseForTest("Notification", false) if err != nil { t.Error(err) return diff --git a/database/pgsql/pgsql.go b/database/pgsql/pgsql.go index 15789dd9da..c849e12cd2 100644 --- a/database/pgsql/pgsql.go +++ b/database/pgsql/pgsql.go @@ -19,22 +19,23 @@ import ( "database/sql" "fmt" "io/ioutil" - "os" + "net/url" "path" "runtime" "strings" "time" "bitbucket.org/liamstask/goose/lib/goose" - "github.com/coreos/clair/config" - "github.com/coreos/clair/database" - "github.com/coreos/clair/utils" - cerrors "github.com/coreos/clair/utils/errors" "github.com/coreos/pkg/capnslog" "github.com/hashicorp/golang-lru" "github.com/lib/pq" - "github.com/pborman/uuid" "github.com/prometheus/client_golang/prometheus" + "gopkg.in/yaml.v2" + + "github.com/coreos/clair/config" + "github.com/coreos/clair/database" + "github.com/coreos/clair/utils" + cerrors "github.com/coreos/clair/utils/errors" ) var ( @@ -72,6 +73,8 @@ func init() { prometheus.MustRegister(promCacheQueriesTotal) prometheus.MustRegister(promQueryDurationMilliseconds) prometheus.MustRegister(promConcurrentLockVAFV) + + database.Register("pgsql", openDatabase) } type Queryer interface { @@ -81,46 +84,137 @@ type Queryer interface { type pgSQL struct { *sql.DB - cache *lru.ARCCache + cache *lru.ARCCache + config Config } +// Close closes the database and destroys if ManageDatabaseLifecycle has been specified in +// the configuration. func (pgSQL *pgSQL) Close() { - pgSQL.DB.Close() + if pgSQL.DB != nil { + pgSQL.DB.Close() + } + + if pgSQL.config.ManageDatabaseLifecycle { + dbName, pgSourceURL, _ := parseConnectionString(pgSQL.config.Source) + dropDatabase(pgSourceURL, dbName) + } } +// Ping verifies that the database is accessible. func (pgSQL *pgSQL) Ping() bool { return pgSQL.DB.Ping() == nil } -// Open creates a Datastore backed by a PostgreSQL database. -// -// It will run immediately every necessary migration on the database. -func Open(config *config.DatabaseConfig) (database.Datastore, error) { - // Run migrations. - if err := migrate(config.Source); err != nil { - log.Error(err) - return nil, database.ErrCantOpen +// Config is the configuration that is used by openDatabase. +type Config struct { + Source string + CacheSize int + + ManageDatabaseLifecycle bool + FixturePath string +} + +// openDatabase opens a PostgresSQL-backed Datastore using the given configuration. +// It immediately every necessary migrations. If ManageDatabaseLifecycle is specified, +// the database will be created first. If FixturePath is specified, every SQL queries that are +// present insides will be executed. +func openDatabase(registrableComponentConfig config.RegistrableComponentConfig) (database.Datastore, error) { + var pg pgSQL + var err error + + // Parse configuration. + pg.config = Config{ + CacheSize: 16384, + } + bytes, err := yaml.Marshal(registrableComponentConfig.Options) + if err != nil { + return nil, fmt.Errorf("pgsql: could not load configuration: %v", err) + } + err = yaml.Unmarshal(bytes, &pg.config) + if err != nil { + return nil, fmt.Errorf("pgsql: could not load configuration: %v", err) + } + + dbName, pgSourceURL, err := parseConnectionString(pg.config.Source) + if err != nil { + return nil, err + } + + // Create database. + if pg.config.ManageDatabaseLifecycle { + log.Info("pgsql: creating database") + if err := createDatabase(pgSourceURL, dbName); err != nil { + return nil, err + } } // Open database. - db, err := sql.Open("postgres", config.Source) + pg.DB, err = sql.Open("postgres", pg.config.Source) if err != nil { - log.Error(err) - return nil, database.ErrCantOpen + pg.Close() + return nil, fmt.Errorf("pgsql: could not open database: %v", err) + } + + // Verify database state. + if err := pg.DB.Ping(); err != nil { + pg.Close() + return nil, fmt.Errorf("pgsql: could not open database: %v", err) + } + + // Run migrations. + if err := migrate(pg.config.Source); err != nil { + pg.Close() + return nil, err + } + + // Load fixture data. + if pg.config.FixturePath != "" { + log.Info("pgsql: loading fixtures") + + d, err := ioutil.ReadFile(pg.config.FixturePath) + if err != nil { + pg.Close() + return nil, fmt.Errorf("pgsql: could not open fixture file: %v", err) + } + + _, err = pg.DB.Exec(string(d)) + if err != nil { + pg.Close() + return nil, fmt.Errorf("pgsql: an error occured while importing fixtures: %v", err) + } } // Initialize cache. // TODO(Quentin-M): Benchmark with a simple LRU Cache. - var cache *lru.ARCCache - if config.CacheSize > 0 { - cache, _ = lru.NewARC(config.CacheSize) + if pg.config.CacheSize > 0 { + pg.cache, _ = lru.NewARC(pg.config.CacheSize) + } + + return &pg, nil +} + +func parseConnectionString(source string) (dbName string, pgSourceURL string, err error) { + if source == "" { + return "", "", cerrors.NewBadRequestError("pgsql: no database connection string specified") } - return &pgSQL{DB: db, cache: cache}, nil + sourceURL, err := url.Parse(source) + if err != nil { + return "", "", cerrors.NewBadRequestError("pgsql: database connection string is not a valid URL") + } + + dbName = strings.TrimPrefix(sourceURL.Path, "/") + + pgSource := *sourceURL + pgSource.Path = "/postgres" + pgSourceURL = pgSource.String() + + return } // migrate runs all available migrations on a pgSQL database. -func migrate(dataSource string) error { +func migrate(source string) error { log.Info("running database migrations") _, filename, _, _ := runtime.Caller(1) @@ -129,7 +223,7 @@ func migrate(dataSource string) error { MigrationsDir: migrationDir, Driver: goose.DBDriver{ Name: "postgres", - OpenStr: dataSource, + OpenStr: source, Import: "github.com/lib/pq", Dialect: &goose.PostgresDialect{}, }, @@ -138,13 +232,13 @@ func migrate(dataSource string) error { // Determine the most recent revision available from the migrations folder. target, err := goose.GetMostRecentDBVersion(conf.MigrationsDir) if err != nil { - return err + return fmt.Errorf("pgsql: could not get most recent migration: %v", err) } - // Run migrations + // Run migrations. err = goose.RunMigrations(conf, conf.MigrationsDir, target) if err != nil { - return err + return fmt.Errorf("pgsql: an error occured while running migrations: %v", err) } log.Info("database migration ran successfully") @@ -152,109 +246,51 @@ func migrate(dataSource string) error { } // createDatabase creates a new database. -// The dataSource parameter should not contain a dbname. -func createDatabase(dataSource, databaseName string) error { +// The source parameter should not contain a dbname. +func createDatabase(source, dbName string) error { // Open database. - db, err := sql.Open("postgres", dataSource) + db, err := sql.Open("postgres", source) if err != nil { - return fmt.Errorf("could not open database (CreateDatabase): %v", err) + return fmt.Errorf("pgsql: could not open 'postgres' database for creation: %v", err) } defer db.Close() // Create database. - _, err = db.Exec("CREATE DATABASE " + databaseName) + _, err = db.Exec("CREATE DATABASE " + dbName) if err != nil { - return fmt.Errorf("could not create database: %v", err) + return fmt.Errorf("pgsql: could not create database: %v", err) } return nil } // dropDatabase drops an existing database. -// The dataSource parameter should not contain a dbname. -func dropDatabase(dataSource, databaseName string) error { +// The source parameter should not contain a dbname. +func dropDatabase(source, dbName string) error { // Open database. - db, err := sql.Open("postgres", dataSource) + db, err := sql.Open("postgres", source) if err != nil { return fmt.Errorf("could not open database (DropDatabase): %v", err) } defer db.Close() // Kill any opened connection. - if _, err := db.Exec(` + if _, err = db.Exec(` SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity WHERE pg_stat_activity.datname = $1 - AND pid <> pg_backend_pid()`, databaseName); err != nil { + AND pid <> pg_backend_pid()`, dbName); err != nil { return fmt.Errorf("could not drop database: %v", err) } // Drop database. - if _, err = db.Exec("DROP DATABASE " + databaseName); err != nil { + if _, err = db.Exec("DROP DATABASE " + dbName); err != nil { return fmt.Errorf("could not drop database: %v", err) } return nil } -// pgSQLTest wraps pgSQL for testing purposes. -// Its Close() method drops the database. -type pgSQLTest struct { - *pgSQL - dataSourceDefaultDatabase string - dbName string -} - -// OpenForTest creates a test Datastore backed by a new PostgreSQL database. -// It creates a new unique and prefixed ("test_") database. -// Using Close() will drop the database. -func OpenForTest(name string, withTestData bool) (*pgSQLTest, error) { - // Define the PostgreSQL connection strings. - dataSource := "host=127.0.0.1 sslmode=disable user=postgres dbname=" - if dataSourceEnv := os.Getenv("CLAIR_TEST_PGSQL"); dataSourceEnv != "" { - dataSource = dataSourceEnv + " dbname=" - } - dbName := "test_" + strings.ToLower(name) + "_" + strings.Replace(uuid.New(), "-", "_", -1) - dataSourceDefaultDatabase := dataSource + "postgres" - dataSourceTestDatabase := dataSource + dbName - - // Create database. - if err := createDatabase(dataSourceDefaultDatabase, dbName); err != nil { - log.Error(err) - return nil, database.ErrCantOpen - } - - // Open database. - db, err := Open(&config.DatabaseConfig{Source: dataSourceTestDatabase, CacheSize: 0}) - if err != nil { - dropDatabase(dataSourceDefaultDatabase, dbName) - log.Error(err) - return nil, database.ErrCantOpen - } - - // Load test data if specified. - if withTestData { - _, filename, _, _ := runtime.Caller(0) - d, _ := ioutil.ReadFile(path.Join(path.Dir(filename)) + "/testdata/data.sql") - _, err = db.(*pgSQL).Exec(string(d)) - if err != nil { - dropDatabase(dataSourceDefaultDatabase, dbName) - log.Error(err) - return nil, database.ErrCantOpen - } - } - - return &pgSQLTest{ - pgSQL: db.(*pgSQL), - dataSourceDefaultDatabase: dataSourceDefaultDatabase, - dbName: dbName}, nil -} - -func (pgSQL *pgSQLTest) Close() { - pgSQL.DB.Close() - dropDatabase(pgSQL.dataSourceDefaultDatabase, pgSQL.dbName) -} - // handleError logs an error with an extra description and masks the error if it's an SQL one. // This ensures we never return plain SQL errors and leak anything. func handleError(desc string, err error) error { diff --git a/database/pgsql/pgsql_test.go b/database/pgsql/pgsql_test.go new file mode 100644 index 0000000000..7685700788 --- /dev/null +++ b/database/pgsql/pgsql_test.go @@ -0,0 +1,45 @@ +package pgsql + +import ( + "fmt" + "os" + "path" + "runtime" + "strings" + + "github.com/coreos/clair/config" + "github.com/pborman/uuid" +) + +func openDatabaseForTest(testName string, loadFixture bool) (*pgSQL, error) { + ds, err := openDatabase(generateTestConfig(testName, loadFixture)) + if err != nil { + return nil, err + } + datastore := ds.(*pgSQL) + return datastore, nil +} + +func generateTestConfig(testName string, loadFixture bool) config.RegistrableComponentConfig { + dbName := "test_" + strings.ToLower(testName) + "_" + strings.Replace(uuid.New(), "-", "_", -1) + + var fixturePath string + if loadFixture { + _, filename, _, _ := runtime.Caller(0) + fixturePath = path.Join(path.Dir(filename)) + "/testdata/data.sql" + } + + source := fmt.Sprintf("postgresql://postgres@127.0.0.1:5432/%s?sslmode=disable", dbName) + if sourceEnv := os.Getenv("CLAIR_TEST_PGSQL"); sourceEnv != "" { + source = fmt.Sprintf(sourceEnv, dbName) + } + + return config.RegistrableComponentConfig{ + Options: map[string]interface{}{ + "source": source, + "cachesize": 0, + "managedatabaselifecycle": true, + "fixturepath": fixturePath, + }, + } +} diff --git a/database/pgsql/vulnerability_test.go b/database/pgsql/vulnerability_test.go index 8177fdb057..d20c2e35f4 100644 --- a/database/pgsql/vulnerability_test.go +++ b/database/pgsql/vulnerability_test.go @@ -18,14 +18,15 @@ import ( "reflect" "testing" + "github.com/stretchr/testify/assert" + "github.com/coreos/clair/database" cerrors "github.com/coreos/clair/utils/errors" "github.com/coreos/clair/utils/types" - "github.com/stretchr/testify/assert" ) func TestFindVulnerability(t *testing.T) { - datastore, err := OpenForTest("FindVulnerability", true) + datastore, err := openDatabaseForTest("FindVulnerability", true) if err != nil { t.Error(err) return @@ -75,7 +76,7 @@ func TestFindVulnerability(t *testing.T) { } func TestDeleteVulnerability(t *testing.T) { - datastore, err := OpenForTest("InsertVulnerability", true) + datastore, err := openDatabaseForTest("InsertVulnerability", true) if err != nil { t.Error(err) return @@ -97,7 +98,7 @@ func TestDeleteVulnerability(t *testing.T) { } func TestInsertVulnerability(t *testing.T) { - datastore, err := OpenForTest("InsertVulnerability", false) + datastore, err := openDatabaseForTest("InsertVulnerability", false) if err != nil { t.Error(err) return