diff --git a/README.md b/README.md index f4ee042..cfc6cde 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,18 @@ database := gosumer.PgDatabase{ } ``` +If you are using a custom schema, you can specify it with backticks: +```go +database := gosumer.PgDatabase{ +Host: "localhost", +Port: 5432, +User: "app", +Password: "!ChangeMe!", +Database: "app", +TableName: `"myschema"."messenger_messages"`, +} +``` + For RabbitMQ: ```go database := gosumer.RabbitMQ{ diff --git a/pgsql_transport.go b/pgsql_transport.go index 2db0d73..fe2d885 100644 --- a/pgsql_transport.go +++ b/pgsql_transport.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "strings" "time" "github.com/jackc/pgx/v5/pgxpool" @@ -32,6 +33,15 @@ func (database PgDatabase) connect() error { return nil } +func (database PgDatabase) GetChannelName() string { + // Symfony uses the format "schema.table" for channel name + if strings.Contains(database.TableName, ".") { + return fmt.Sprintf(`"%s"`, strings.Replace(database.TableName, `"`, "", -1)) + } + + return database.TableName +} + func (database PgDatabase) listen(fn process, message any, sec int) error { err := database.connect() @@ -45,12 +55,12 @@ func (database PgDatabase) listen(fn process, message any, sec int) error { log.Printf("Successfully connected to the database!") - _, err = pool.Exec(context.Background(), fmt.Sprintf("LISTEN %s", database.TableName)) + _, err = pool.Exec(context.Background(), fmt.Sprintf("LISTEN %s", database.GetChannelName())) if err != nil { return err } - defer pool.Exec(context.Background(), fmt.Sprintf("UNLISTEN %s", database.TableName)) + defer pool.Exec(context.Background(), fmt.Sprintf("UNLISTEN %s", database.GetChannelName())) conn, err := pool.Acquire(context.Background()) if err != nil { @@ -77,7 +87,7 @@ func (database PgDatabase) listenEvery(seconds int, fn process, message any) { go func() error { for { - <- time.After(delay) + <-time.After(delay) _ = database.processMessage(fn, message) } }() diff --git a/pgsql_transport_test.go b/pgsql_transport_test.go index 337f40e..c475925 100644 --- a/pgsql_transport_test.go +++ b/pgsql_transport_test.go @@ -3,6 +3,7 @@ package gosumer import ( "context" "fmt" + "strings" "testing" "time" @@ -36,12 +37,35 @@ func setupDatabase(t *testing.T) (*pgxpool.Pool, PgDatabase) { TableName: "table_name", } + return initDatabase(t, database), database +} + +func setupDatabaseWithSchema(t *testing.T) (*pgxpool.Pool, PgDatabase) { + database := PgDatabase{ + Host: "localhost", + Port: 5432, + User: "postgres", + Password: "postgres", + Database: "postgres", + TableName: `"table_name"."schema_name"`, + } + + return initDatabase(t, database), database +} + +func initDatabase(t *testing.T, database PgDatabase) (pool *pgxpool.Pool) { pool, err := pgxpool.New(context.Background(), fmt.Sprintf("postgres://%s:%s@%s:%d/%s", database.User, database.Password, database.Host, database.Port, database.Database)) if err != nil { t.Errorf("Expected no error, got %v", err) } pool.Exec(context.Background(), fmt.Sprintf("DROP TABLE IF EXISTS %s", database.TableName)) + + if strings.Contains(database.TableName, ".") { + schema := strings.Split(database.TableName, ".")[0] + pool.Exec(context.Background(), fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema)) + } + pool.Exec(context.Background(), fmt.Sprintf("CREATE TABLE %s (id BIGSERIAL NOT NULL, body TEXT NOT NULL, headers TEXT NOT NULL, queue_name VARCHAR(190) NOT NULL, created_at TIMESTAMP(0) WITHOUT TIME ZONE NOT NULL, available_at TIMESTAMP(0) WITHOUT TIME ZONE NOT NULL, delivered_at TIMESTAMP(0) WITHOUT TIME ZONE DEFAULT NULL, PRIMARY KEY(id))", database.TableName)) _, err = pool.Exec(context.Background(), fmt.Sprintf("INSERT INTO %s (body, headers, queue_name, created_at, available_at, delivered_at) VALUES ('{\"id\": 1}', '{}', 'go', NOW(), NOW(), NULL)", database.TableName)) @@ -54,7 +78,7 @@ func setupDatabase(t *testing.T) (*pgxpool.Pool, PgDatabase) { t.Errorf("Expected no error, got %v", err) } - return pool, database + return pool } func TestPgDelete(t *testing.T) { @@ -84,6 +108,18 @@ func TestPgListen(t *testing.T) { pool, database := setupDatabase(t) defer pool.Close() + doTestPgListen(t, database) +} + +func TestPgListenWithSchema(t *testing.T) { + pool, database := setupDatabaseWithSchema(t) + defer pool.Close() + + doTestPgListen(t, database) +} + +func doTestPgListen(t *testing.T, database PgDatabase) { + go func() { err := database.listen(processMessage, Message{}, 5) if err != nil {