Skip to content

Commit

Permalink
fix(postgres): Fix the non-default dbname error (testcontainers#2489)
Browse files Browse the repository at this point in the history
* Fix the non-default dbname error

The linked issue described in great detail an issue where we assumed everyone would use the default database user, whose home DB defaults to the postgres database. When that was not the case, the snapshots would fail silently as the user would not connect to the right database to take the commands.

This PR fixes the issue by adding the dbname by default in the command, and adds a test to validate this works as intended. In addition, it also adds some logic to handle any error that does not cause the exec command to fail, such as database access failures.

Run the added test to test this works as intended.

Closes testcontainers#2474

* Document the postgres dbname issue in the docs
  • Loading branch information
Minivera authored and mdelapenya committed Apr 23, 2024
1 parent e855bee commit 9f03b53
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 6 deletions.
6 changes: 6 additions & 0 deletions docs/modules/postgres.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ This example shows the usage of the postgres module's Snapshot feature to give e
to recreate the database container on every test or run heavy scripts to clean your database. This makes the individual
tests very modular, since they always run on a brand-new database.

!!!tip
You should never pass the `"postgres"` system database as the container database name if you want to use snapshots.
The Snapshot logic requires dropping the connected database and using the system database to run commands, which will
not work if the database for the container is set to `"postgres"`.


<!--codeinclude-->
[Test with a reusable Postgres container](../../modules/postgres/postgres_test.go) inside_block:snapshotAndReset
<!--/codeinclude-->
31 changes: 29 additions & 2 deletions modules/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package postgres
import (
"context"
"fmt"
"io"
"net"
"path/filepath"
"strings"
Expand Down Expand Up @@ -184,6 +185,10 @@ func (c *PostgresContainer) Snapshot(ctx context.Context, opts ...SnapshotOption
snapshotName = config.snapshotName
}

if c.dbName == "postgres" {
return fmt.Errorf("cannot snapshot the postgres system database as it cannot be dropped to be restored")
}

// execute the commands to create the snapshot, in order
cmds := []string{
// Drop the snapshot database if it already exists
Expand All @@ -195,10 +200,19 @@ func (c *PostgresContainer) Snapshot(ctx context.Context, opts ...SnapshotOption
}

for _, cmd := range cmds {
_, _, err := c.Exec(ctx, []string{"psql", "-U", c.user, "-c", cmd})
exitCode, reader, err := c.Exec(ctx, []string{"psql", "-U", c.user, "-d", c.dbName, "-c", cmd})
if err != nil {
return err
}
if exitCode != 0 {
buf := new(strings.Builder)
_, err := io.Copy(buf, reader)
if err != nil {
return fmt.Errorf("non-zero exit code for snapshot command, could not read command output: %w", err)
}

return fmt.Errorf("non-zero exit code for snapshot command: %s", buf.String())
}
}

c.snapshotName = snapshotName
Expand All @@ -219,6 +233,10 @@ func (c *PostgresContainer) Restore(ctx context.Context, opts ...SnapshotOption)
snapshotName = config.snapshotName
}

if c.dbName == "postgres" {
return fmt.Errorf("cannot restore the postgres system database as it cannot be dropped to be restored")
}

// execute the commands to restore the snapshot, in order
cmds := []string{
// Drop the entire database by connecting to the postgres global database
Expand All @@ -228,10 +246,19 @@ func (c *PostgresContainer) Restore(ctx context.Context, opts ...SnapshotOption)
}

for _, cmd := range cmds {
_, _, err := c.Exec(ctx, []string{"psql", "-U", c.user, "-d", "postgres", "-c", cmd})
exitCode, reader, err := c.Exec(ctx, []string{"psql", "-v", "ON_ERROR_STOP=1", "-U", c.user, "-d", "postgres", "-c", cmd})
if err != nil {
return err
}
if exitCode != 0 {
buf := new(strings.Builder)
_, err := io.Copy(buf, reader)
if err != nil {
return fmt.Errorf("non-zero exit code for restore command, could not read command output: %w", err)
}

return fmt.Errorf("non-zero exit code for restore command: %s", buf.String())
}
}

return nil
Expand Down
81 changes: 77 additions & 4 deletions modules/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,12 @@ func TestPostgres(t *testing.T) {
connStr, err := container.ConnectionString(ctx, "sslmode=disable", "application_name=test")
// }
require.NoError(t, err)
mustConnStr := container.MustConnectionString(ctx,"sslmode=disable", "application_name=test")
if mustConnStr!=connStr{

mustConnStr := container.MustConnectionString(ctx, "sslmode=disable", "application_name=test")
if mustConnStr != connStr {
t.Errorf("ConnectionString was not equal to MustConnectionString")
}

// Ensure connection string is using generic format
id, err := container.MappedPort(ctx, "5432/tcp")
require.NoError(t, err)
Expand Down Expand Up @@ -327,3 +327,76 @@ func TestSnapshot(t *testing.T) {
})
// }
}

func TestSnapshotWithOverrides(t *testing.T) {
ctx := context.Background()

dbname := "other-db"
user := "other-user"
password := "other-password"

container, err := postgres.RunContainer(
ctx,
testcontainers.WithImage("docker.io/postgres:16-alpine"),
postgres.WithDatabase(dbname),
postgres.WithUsername(user),
postgres.WithPassword(password),
testcontainers.WithWaitStrategy(
wait.ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(5*time.Second)),
)
if err != nil {
t.Fatal(err)
}

_, _, err = container.Exec(ctx, []string{"psql", "-U", user, "-d", dbname, "-c", "CREATE TABLE users (id SERIAL, name TEXT NOT NULL, age INT NOT NULL)"})
if err != nil {
t.Fatal(err)
}

err = container.Snapshot(ctx, postgres.WithSnapshotName("other-snapshot"))
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := container.Terminate(ctx); err != nil {
t.Fatalf("failed to terminate container: %s", err)
}
})

dbURL, err := container.ConnectionString(ctx)
if err != nil {
t.Fatal(err)
}

t.Run("Test that the restore works when not using defaults", func(t *testing.T) {
_, _, err = container.Exec(ctx, []string{"psql", "-U", user, "-d", dbname, "-c", "INSERT INTO users(name, age) VALUES ('test', 42)"})
if err != nil {
t.Fatal(err)
}

// Doing the restore before we connect since this resets the pgx connection
err = container.Restore(ctx)
if err != nil {
t.Fatal(err)
}

conn, err := pgx.Connect(context.Background(), dbURL)
if err != nil {
t.Fatal(err)
}
defer conn.Close(context.Background())

var count int64
err = conn.QueryRow(context.Background(), "SELECT COUNT(1) FROM users").Scan(&count)
if err != nil {
t.Fatal(err)
}

if count != 0 {
t.Fatalf("Expected %d to equal `0`", count)
}
})
}

0 comments on commit 9f03b53

Please sign in to comment.