diff --git a/courier/test/persistence.go b/courier/test/persistence.go index 254b17ac5d0b..b63af9640a7c 100644 --- a/courier/test/persistence.go +++ b/courier/test/persistence.go @@ -24,11 +24,11 @@ type PersisterWrapper interface { courier.Persister } -type NetworkWrapper func() (uuid.UUID, PersisterWrapper) +type NetworkWrapper func(t *testing.T, ctx context.Context) (uuid.UUID, PersisterWrapper) func TestPersister(ctx context.Context, newNetworkUnlessExisting NetworkWrapper, newNetwork NetworkWrapper) func(t *testing.T) { return func(t *testing.T) { - nid, p := newNetworkUnlessExisting() + nid, p := newNetworkUnlessExisting(t, ctx) t.Run("case=no messages in queue", func(t *testing.T) { m, err := p.NextMessages(ctx, 10) @@ -116,7 +116,7 @@ func TestPersister(ctx context.Context, newNetworkUnlessExisting NetworkWrapper, }) t.Run("can not get on another network", func(t *testing.T) { - _, p := newNetwork() + _, p := newNetwork(t, ctx) _, err := p.LatestQueuedMessage(ctx) require.ErrorIs(t, err, courier.ErrQueueEmpty) @@ -126,7 +126,7 @@ func TestPersister(ctx context.Context, newNetworkUnlessExisting NetworkWrapper, }) t.Run("can not update on another network", func(t *testing.T) { - _, p := newNetwork() + _, p := newNetwork(t, ctx) err := p.SetMessageStatus(ctx, id, courier.MessageStatusProcessing) require.ErrorIs(t, err, sqlcon.ErrNoRows) }) diff --git a/persistence/sql/persister_test.go b/persistence/sql/persister_test.go index c16754356f8a..890abc8bf3b6 100644 --- a/persistence/sql/persister_test.go +++ b/persistence/sql/persister_test.go @@ -180,7 +180,7 @@ func TestPersister(t *testing.T) { }) t.Run("contract=courier.TestPersister", func(t *testing.T) { pop.SetLogger(pl(t)) - upsert, insert := sqltesthelpers.DefaultNetworkWrapper(t, ctx, p) + upsert, insert := sqltesthelpers.DefaultNetworkWrapper(p) courier.TestPersister(ctx, upsert, insert)(t) }) t.Run("contract=verification.TestPersister", func(t *testing.T) { diff --git a/persistence/sql/testhelpers/network.go b/persistence/sql/testhelpers/network.go index 70fa2419657a..eb530f1d4c8d 100644 --- a/persistence/sql/testhelpers/network.go +++ b/persistence/sql/testhelpers/network.go @@ -11,10 +11,10 @@ import ( "github.com/ory/kratos/persistence" ) -func DefaultNetworkWrapper(t *testing.T, ctx context.Context, p persistence.Persister) (courier.NetworkWrapper, courier.NetworkWrapper) { - return func() (db.UUID, courier.PersisterWrapper) { +func DefaultNetworkWrapper(p persistence.Persister) (courier.NetworkWrapper, courier.NetworkWrapper) { + return func(t *testing.T, ctx context.Context) (db.UUID, courier.PersisterWrapper) { return testhelpers.NewNetworkUnlessExisting(t, ctx, p) - }, func() (db.UUID, courier.PersisterWrapper) { + }, func(t *testing.T, ctx context.Context) (db.UUID, courier.PersisterWrapper) { return testhelpers.NewNetwork(t, ctx, p) } }