Skip to content

Commit

Permalink
feat: support postgres custom channel (#21)
Browse files Browse the repository at this point in the history
* feat: add test for schema listen failure

* fix: schema support

* fix: update documentation for postgres schema
  • Loading branch information
wesnick authored Mar 18, 2024
1 parent 3a27cda commit f7de1c0
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 4 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,18 @@ database := gosumer.PgDatabase{
}
```

If you are using a custom schema, you can specify it with backticks:
```go
database := gosumer.PgDatabase{
Host: "localhost",
Port: 5432,
User: "app",
Password: "!ChangeMe!",
Database: "app",
TableName: `"myschema"."messenger_messages"`,
}
```

For RabbitMQ:
```go
database := gosumer.RabbitMQ{
Expand Down
16 changes: 13 additions & 3 deletions pgsql_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log"
"strings"
"time"

"github.com/jackc/pgx/v5/pgxpool"
Expand Down Expand Up @@ -32,6 +33,15 @@ func (database PgDatabase) connect() error {
return nil
}

func (database PgDatabase) GetChannelName() string {
// Symfony uses the format "schema.table" for channel name
if strings.Contains(database.TableName, ".") {
return fmt.Sprintf(`"%s"`, strings.Replace(database.TableName, `"`, "", -1))
}

return database.TableName
}

func (database PgDatabase) listen(fn process, message any, sec int) error {
err := database.connect()

Expand All @@ -45,12 +55,12 @@ func (database PgDatabase) listen(fn process, message any, sec int) error {

log.Printf("Successfully connected to the database!")

_, err = pool.Exec(context.Background(), fmt.Sprintf("LISTEN %s", database.TableName))
_, err = pool.Exec(context.Background(), fmt.Sprintf("LISTEN %s", database.GetChannelName()))
if err != nil {
return err
}

defer pool.Exec(context.Background(), fmt.Sprintf("UNLISTEN %s", database.TableName))
defer pool.Exec(context.Background(), fmt.Sprintf("UNLISTEN %s", database.GetChannelName()))

conn, err := pool.Acquire(context.Background())
if err != nil {
Expand All @@ -77,7 +87,7 @@ func (database PgDatabase) listenEvery(seconds int, fn process, message any) {

go func() error {
for {
<- time.After(delay)
<-time.After(delay)
_ = database.processMessage(fn, message)
}
}()
Expand Down
38 changes: 37 additions & 1 deletion pgsql_transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gosumer
import (
"context"
"fmt"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -36,12 +37,35 @@ func setupDatabase(t *testing.T) (*pgxpool.Pool, PgDatabase) {
TableName: "table_name",
}

return initDatabase(t, database), database
}

func setupDatabaseWithSchema(t *testing.T) (*pgxpool.Pool, PgDatabase) {
database := PgDatabase{
Host: "localhost",
Port: 5432,
User: "postgres",
Password: "postgres",
Database: "postgres",
TableName: `"table_name"."schema_name"`,
}

return initDatabase(t, database), database
}

func initDatabase(t *testing.T, database PgDatabase) (pool *pgxpool.Pool) {
pool, err := pgxpool.New(context.Background(), fmt.Sprintf("postgres://%s:%s@%s:%d/%s", database.User, database.Password, database.Host, database.Port, database.Database))
if err != nil {
t.Errorf("Expected no error, got %v", err)
}

pool.Exec(context.Background(), fmt.Sprintf("DROP TABLE IF EXISTS %s", database.TableName))

if strings.Contains(database.TableName, ".") {
schema := strings.Split(database.TableName, ".")[0]
pool.Exec(context.Background(), fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema))
}

pool.Exec(context.Background(), fmt.Sprintf("CREATE TABLE %s (id BIGSERIAL NOT NULL, body TEXT NOT NULL, headers TEXT NOT NULL, queue_name VARCHAR(190) NOT NULL, created_at TIMESTAMP(0) WITHOUT TIME ZONE NOT NULL, available_at TIMESTAMP(0) WITHOUT TIME ZONE NOT NULL, delivered_at TIMESTAMP(0) WITHOUT TIME ZONE DEFAULT NULL, PRIMARY KEY(id))", database.TableName))

_, err = pool.Exec(context.Background(), fmt.Sprintf("INSERT INTO %s (body, headers, queue_name, created_at, available_at, delivered_at) VALUES ('{\"id\": 1}', '{}', 'go', NOW(), NOW(), NULL)", database.TableName))
Expand All @@ -54,7 +78,7 @@ func setupDatabase(t *testing.T) (*pgxpool.Pool, PgDatabase) {
t.Errorf("Expected no error, got %v", err)
}

return pool, database
return pool
}

func TestPgDelete(t *testing.T) {
Expand Down Expand Up @@ -84,6 +108,18 @@ func TestPgListen(t *testing.T) {
pool, database := setupDatabase(t)
defer pool.Close()

doTestPgListen(t, database)
}

func TestPgListenWithSchema(t *testing.T) {
pool, database := setupDatabaseWithSchema(t)
defer pool.Close()

doTestPgListen(t, database)
}

func doTestPgListen(t *testing.T, database PgDatabase) {

go func() {
err := database.listen(processMessage, Message{}, 5)
if err != nil {
Expand Down

0 comments on commit f7de1c0

Please sign in to comment.