diff --git a/testscript/script_test.go b/testscript/script_test.go index 3c544b037..1372c8258 100644 --- a/testscript/script_test.go +++ b/testscript/script_test.go @@ -3,6 +3,7 @@ package testscript import ( "bytes" "context" + "database/sql" "flag" "fmt" "io" @@ -30,7 +31,6 @@ import ( "github.com/rogpeppe/go-internal/testscript" "github.com/spf13/cobra" "golang.org/x/crypto/ssh" - _ "modernc.org/sqlite" // sqlite Driver ) var update = flag.Bool("update", false, "update script files") @@ -101,6 +101,26 @@ func TestScript(t *testing.T) { cfg.LFS.Enabled = true cfg.LFS.SSHEnabled = true + dbDriver := os.Getenv("DB_DRIVER") + if dbDriver != "" { + cfg.DB.Driver = dbDriver + } + + dbDsn := os.Getenv("DB_DATA_SOURCE") + if dbDsn != "" { + cfg.DB.DataSource = dbDsn + } + + if cfg.DB.Driver == "postgres" { + err, cleanup := setupPostgres(e.T(), cfg) + if err != nil { + return err + } + if cleanup != nil { + e.Defer(cleanup) + } + } + if err := cfg.Validate(); err != nil { return err } @@ -117,7 +137,6 @@ func TestScript(t *testing.T) { defer f.Close() // nolint: errcheck } - // TODO: test postgres dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource) if err != nil { return fmt.Errorf("open database: %w", err) @@ -385,3 +404,68 @@ func cmdCurl(ts *testscript.TestScript, neg bool, args []string) { check(ts, cmd.Execute(), neg) } + +func setupPostgres(t testscript.T, cfg *config.Config) (error, func()) { + // Indicates postgres + // Create a disposable database + dbName := fmt.Sprintf("softserve_test_%d", time.Now().UnixNano()) + dbDsn := os.Getenv("DB_DATA_SOURCE") + if dbDsn == "" { + cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable" + } + + dbUrl, err := url.Parse(cfg.DB.DataSource) + if err != nil { + return err, nil + } + + connInfo := fmt.Sprintf("host=%s sslmode=disable", dbUrl.Hostname()) + username := dbUrl.User.Username() + if username != "" { + connInfo += fmt.Sprintf(" user=%s", username) + password, ok := dbUrl.User.Password() + if ok { + username = fmt.Sprintf("%s:%s", username, password) + connInfo += fmt.Sprintf(" password=%s", password) + } + username = fmt.Sprintf("%s@", username) + } else { + connInfo += " user=postgres" + } + + port := dbUrl.Port() + if port != "" { + connInfo += fmt.Sprintf(" port=%s", port) + port = fmt.Sprintf(":%s", port) + } + + cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable", + dbUrl.Scheme, + username, + dbUrl.Hostname(), + port, + dbName, + ) + + // Create the database + db, err := sql.Open(cfg.DB.Driver, connInfo) + if err != nil { + return err, nil + } + + if _, err := db.Exec("CREATE DATABASE " + dbName); err != nil { + return err, nil + } + + return nil, func() { + db, err := sql.Open(cfg.DB.Driver, connInfo) + if err != nil { + t.Log("failed to open database", dbName, err) + return + } + + if _, err := db.Exec("DROP DATABASE " + dbName); err != nil { + t.Log("failed to drop database", dbName, err) + } + } +}