Skip to content

Commit

Permalink
Fix the non-default dbname error
Browse files Browse the repository at this point in the history
The linked issue described in great detail an issue where we assumed everyone would use the default database user, whose home DB defaults to the postgres database. When that was not the case, the snapshots would fail silently as the user would not connect to the right database to take the commands.

This PR fixes the issue by adding the dbname by default in the command, and adds a test to validate this works as intended. In addition, it also adds some logic to handle any error that does not cause the exec command to fail, such as database access failures.

Run the added test to test this works as intended.

Closes testcontainers#2474
  • Loading branch information
Minivera committed Apr 18, 2024
1 parent 9f1d656 commit c82de90
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 8 deletions.
34 changes: 30 additions & 4 deletions modules/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package postgres
import (
"context"
"fmt"
"io"
"net"
"path/filepath"
"strings"
Expand All @@ -26,10 +27,9 @@ type PostgresContainer struct {
snapshotName string
}


// MustConnectionString panics if the address cannot be determined.
func (c *PostgresContainer) MustConnectionString(ctx context.Context, args ...string) string {
addr, err := c.ConnectionString(ctx,args...)
addr, err := c.ConnectionString(ctx, args...)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -185,6 +185,10 @@ func (c *PostgresContainer) Snapshot(ctx context.Context, opts ...SnapshotOption
snapshotName = config.snapshotName
}

if c.dbName == "postgres" {
return fmt.Errorf("cannot snapshot the postgres system database as it cannot be dropped to be restored")
}

// execute the commands to create the snapshot, in order
cmds := []string{
// Drop the snapshot database if it already exists
Expand All @@ -196,10 +200,19 @@ func (c *PostgresContainer) Snapshot(ctx context.Context, opts ...SnapshotOption
}

for _, cmd := range cmds {
_, _, err := c.Exec(ctx, []string{"psql", "-U", c.user, "-c", cmd})
exitCode, reader, err := c.Exec(ctx, []string{"psql", "-U", c.user, "-d", c.dbName, "-c", cmd})
if err != nil {
return err
}
if exitCode != 0 {
buf := new(strings.Builder)
_, err := io.Copy(buf, reader)
if err != nil {
return fmt.Errorf("non-zero exit code for snapshot command, could not read command output: %w", err)
}

return fmt.Errorf("non-zero exit code for snapshot command: %s", buf.String())
}
}

c.snapshotName = snapshotName
Expand All @@ -220,6 +233,10 @@ func (c *PostgresContainer) Restore(ctx context.Context, opts ...SnapshotOption)
snapshotName = config.snapshotName
}

if c.dbName == "postgres" {
return fmt.Errorf("cannot restore the postgres system database as it cannot be dropped to be restored")
}

// execute the commands to restore the snapshot, in order
cmds := []string{
// Drop the entire database by connecting to the postgres global database
Expand All @@ -229,10 +246,19 @@ func (c *PostgresContainer) Restore(ctx context.Context, opts ...SnapshotOption)
}

for _, cmd := range cmds {
_, _, err := c.Exec(ctx, []string{"psql", "-U", c.user, "-d", "postgres", "-c", cmd})
exitCode, reader, err := c.Exec(ctx, []string{"psql", "-v", "ON_ERROR_STOP=1", "-U", c.user, "-d", "postgres", "-c", cmd})
if err != nil {
return err
}
if exitCode != 0 {
buf := new(strings.Builder)
_, err := io.Copy(buf, reader)
if err != nil {
return fmt.Errorf("non-zero exit code for restore command, could not read command output: %w", err)
}

return fmt.Errorf("non-zero exit code for restore command: %s", buf.String())
}
}

return nil
Expand Down
81 changes: 77 additions & 4 deletions modules/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,12 @@ func TestPostgres(t *testing.T) {
connStr, err := container.ConnectionString(ctx, "sslmode=disable", "application_name=test")
// }
require.NoError(t, err)
mustConnStr := container.MustConnectionString(ctx,"sslmode=disable", "application_name=test")
if mustConnStr!=connStr{

mustConnStr := container.MustConnectionString(ctx, "sslmode=disable", "application_name=test")
if mustConnStr != connStr {
t.Errorf("ConnectionString was not equal to MustConnectionString")
}

// Ensure connection string is using generic format
id, err := container.MappedPort(ctx, "5432/tcp")
require.NoError(t, err)
Expand Down Expand Up @@ -327,3 +327,76 @@ func TestSnapshot(t *testing.T) {
})
// }
}

func TestSnapshotWithOverrides(t *testing.T) {
ctx := context.Background()

dbname := "other-db"
user := "other-user"
password := "other-password"

container, err := postgres.RunContainer(
ctx,
testcontainers.WithImage("docker.io/postgres:16-alpine"),
postgres.WithDatabase(dbname),
postgres.WithUsername(user),
postgres.WithPassword(password),
testcontainers.WithWaitStrategy(
wait.ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(5*time.Second)),
)
if err != nil {
t.Fatal(err)
}

_, _, err = container.Exec(ctx, []string{"psql", "-U", user, "-d", dbname, "-c", "CREATE TABLE users (id SERIAL, name TEXT NOT NULL, age INT NOT NULL)"})
if err != nil {
t.Fatal(err)
}

err = container.Snapshot(ctx, postgres.WithSnapshotName("other-snapshot"))
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := container.Terminate(ctx); err != nil {
t.Fatalf("failed to terminate container: %s", err)
}
})

dbURL, err := container.ConnectionString(ctx)
if err != nil {
t.Fatal(err)
}

t.Run("Test that the restore works when not using defaults", func(t *testing.T) {
_, _, err = container.Exec(ctx, []string{"psql", "-U", user, "-d", dbname, "-c", "INSERT INTO users(name, age) VALUES ('test', 42)"})
if err != nil {
t.Fatal(err)
}

// Doing the restore before we connect since this resets the pgx connection
err = container.Restore(ctx)
if err != nil {
t.Fatal(err)
}

conn, err := pgx.Connect(context.Background(), dbURL)
if err != nil {
t.Fatal(err)
}
defer conn.Close(context.Background())

var count int64
err = conn.QueryRow(context.Background(), "SELECT COUNT(1) FROM users").Scan(&count)
if err != nil {
t.Fatal(err)
}

if count != 0 {
t.Fatalf("Expected %d to equal `0`", count)
}
})
}

0 comments on commit c82de90

Please sign in to comment.