diff --git a/storage/tenant.go b/storage/tenant.go index 48e88d69434..c00dfdd257d 100644 --- a/storage/tenant.go +++ b/storage/tenant.go @@ -26,7 +26,7 @@ const ( // WithTenant creates a Context with a tenant association func WithTenant(ctx context.Context, tenant string) context.Context { - return context.WithValue(context.Background(), tenantKey, tenant) + return context.WithValue(ctx, tenantKey, tenant) } // GetTenant retrieves a tenant associated with a Context diff --git a/storage/tenant_test.go b/storage/tenant_test.go index 5b44ff8657c..563a511c6e3 100644 --- a/storage/tenant_test.go +++ b/storage/tenant_test.go @@ -21,11 +21,22 @@ import ( "github.com/stretchr/testify/assert" ) +type testContextKey string + func TestContextTenantHandling(t *testing.T) { ctxWithTenant := WithTenant(context.Background(), "tenant1") assert.Equal(t, "tenant1", GetTenant(ctxWithTenant)) } +func TestContextPreserved(t *testing.T) { + key := testContextKey("expected-key") + val := "expected-value" + ctxWithValue := context.WithValue(context.Background(), key, val) + ctxWithTenant := WithTenant(ctxWithValue, "tenant1") + assert.Equal(t, "tenant1", GetTenant(ctxWithTenant)) + assert.Equal(t, val, ctxWithTenant.Value(key)) +} + func TestNoTenant(t *testing.T) { // If no tenant in context, GetTenant should return the empty string assert.Equal(t, "", GetTenant(context.Background()))