diff --git a/pkg/networkservice/common/begin/client.go b/pkg/networkservice/common/begin/client.go index a0978408a..7bb8ebafe 100644 --- a/pkg/networkservice/common/begin/client.go +++ b/pkg/networkservice/common/begin/client.go @@ -65,6 +65,7 @@ func (b *beginClient) Request(ctx context.Context, request *networkservice.Netwo conn, err = b.Request(ctx, request, opts...) return } + eventFactoryClient.updateContext(ctx) ctx = withEventFactory(ctx, eventFactoryClient) request.Connection = mergeConnection(eventFactoryClient.returnedConnection, request.GetConnection(), eventFactoryClient.request.GetConnection()) diff --git a/pkg/networkservice/common/begin/event_factory.go b/pkg/networkservice/common/begin/event_factory.go index a1db68fcf..b1607b36b 100644 --- a/pkg/networkservice/common/begin/event_factory.go +++ b/pkg/networkservice/common/begin/event_factory.go @@ -60,11 +60,7 @@ func newEventFactoryClient(ctx context.Context, afterClose func(), opts ...grpc. client: next.Client(ctx), opts: opts, } - ctxFunc := postpone.ContextWithValues(ctx) - f.ctxFunc = func() (context.Context, context.CancelFunc) { - eventCtx, cancel := ctxFunc() - return withEventFactory(eventCtx, f), cancel - } + f.updateContext(ctx) f.afterCloseFunc = func() { f.state = closed @@ -75,6 +71,14 @@ func newEventFactoryClient(ctx context.Context, afterClose func(), opts ...grpc. return f } +func (f *eventFactoryClient) updateContext(ctx context.Context) { + ctxFunc := postpone.ContextWithValues(ctx) + f.ctxFunc = func() (context.Context, context.CancelFunc) { + eventCtx, cancel := ctxFunc() + return withEventFactory(eventCtx, f), cancel + } +} + func (f *eventFactoryClient) Request(opts ...Option) <-chan error { o := &option{ cancelCtx: context.Background(), @@ -155,11 +159,7 @@ func newEventFactoryServer(ctx context.Context, afterClose func()) *eventFactory f := &eventFactoryServer{ server: next.Server(ctx), } - ctxFunc := postpone.ContextWithValues(ctx) - f.ctxFunc = func() (context.Context, context.CancelFunc) { - eventCtx, cancel := ctxFunc() - return withEventFactory(eventCtx, f), cancel - } + f.updateContext(ctx) f.afterCloseFunc = func() { f.state = closed @@ -168,6 +168,14 @@ func newEventFactoryServer(ctx context.Context, afterClose func()) *eventFactory return f } +func (f *eventFactoryServer) updateContext(ctx context.Context) { + ctxFunc := postpone.ContextWithValues(ctx) + f.ctxFunc = func() (context.Context, context.CancelFunc) { + eventCtx, cancel := ctxFunc() + return withEventFactory(eventCtx, f), cancel + } +} + func (f *eventFactoryServer) Request(opts ...Option) <-chan error { o := &option{ cancelCtx: context.Background(), diff --git a/pkg/networkservice/common/begin/event_factory_client_test.go b/pkg/networkservice/common/begin/event_factory_client_test.go index 3b6484a17..d70bf1985 100644 --- a/pkg/networkservice/common/begin/event_factory_client_test.go +++ b/pkg/networkservice/common/begin/event_factory_client_test.go @@ -33,6 +33,48 @@ import ( "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" ) +// This test reproduces the situation when refresh changes the eventFactory context +// nolint:dupl +func TestRefresh_Client(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + syncChan := make(chan struct{}) + checkCtxCl := &checkContextClient{t: t} + eventFactoryCl := &eventFactoryClient{ch: syncChan} + client := chain.NewNetworkServiceClient( + begin.NewClient(), + checkCtxCl, + eventFactoryCl, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Set any value to context + ctx = context.WithValue(ctx, contextKey{}, "value_1") + checkCtxCl.setExpectedValue("value_1") + + // Do Request with this context + request := testRequest("1") + conn, err := client.Request(ctx, request.Clone()) + assert.NotNil(t, t, conn) + assert.NoError(t, err) + + // Change context value before refresh Request + ctx = context.WithValue(ctx, contextKey{}, "value_2") + checkCtxCl.setExpectedValue("value_2") + request.Connection = conn.Clone() + + // Call refresh + conn, err = client.Request(ctx, request.Clone()) + assert.NotNil(t, t, conn) + assert.NoError(t, err) + + // Call refresh from eventFactory. We are expecting updated value in the context + eventFactoryCl.callRefresh() + <-syncChan +} + // This test reproduces the situation when Close and Request were called at the same time // nolint:dupl func TestRefreshDuringClose_Client(t *testing.T) { diff --git a/pkg/networkservice/common/begin/event_factory_server_test.go b/pkg/networkservice/common/begin/event_factory_server_test.go index d4ad8e9e5..8013ba8de 100644 --- a/pkg/networkservice/common/begin/event_factory_server_test.go +++ b/pkg/networkservice/common/begin/event_factory_server_test.go @@ -32,6 +32,48 @@ import ( "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" ) +// This test reproduces the situation when refresh changes the eventFactory context +// nolint:dupl +func TestRefresh_Server(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + syncChan := make(chan struct{}) + checkCtxServ := &checkContextServer{t: t} + eventFactoryServ := &eventFactoryServer{ch: syncChan} + server := chain.NewNetworkServiceServer( + begin.NewServer(), + checkCtxServ, + eventFactoryServ, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Set any value to context + ctx = context.WithValue(ctx, contextKey{}, "value_1") + checkCtxServ.setExpectedValue("value_1") + + // Do Request with this context + request := testRequest("1") + conn, err := server.Request(ctx, request.Clone()) + assert.NotNil(t, t, conn) + assert.NoError(t, err) + + // Change context value before refresh Request + ctx = context.WithValue(ctx, contextKey{}, "value_2") + checkCtxServ.setExpectedValue("value_2") + request.Connection = conn.Clone() + + // Call refresh + conn, err = server.Request(ctx, request.Clone()) + assert.NotNil(t, t, conn) + assert.NoError(t, err) + + // Call refresh from eventFactory. We are expecting updated value in the context + eventFactoryServ.callRefresh() + <-syncChan +} + // This test reproduces the situation when Close and Request were called at the same time // nolint:dupl func TestRefreshDuringClose_Server(t *testing.T) { diff --git a/pkg/networkservice/common/begin/server.go b/pkg/networkservice/common/begin/server.go index 790916c1d..e8d1bdf03 100644 --- a/pkg/networkservice/common/begin/server.go +++ b/pkg/networkservice/common/begin/server.go @@ -61,6 +61,8 @@ func (b *beginServer) Request(ctx context.Context, request *networkservice.Netwo conn, err = b.Request(ctx, request) return } + eventFactoryServer.updateContext(ctx) + ctx = withEventFactory(ctx, eventFactoryServer) conn, err = next.Server(ctx).Request(ctx, request) if err != nil { diff --git a/pkg/registry/common/begin/ns_client.go b/pkg/registry/common/begin/ns_client.go index fd35ee9a5..52d41b592 100644 --- a/pkg/registry/common/begin/ns_client.go +++ b/pkg/registry/common/begin/ns_client.go @@ -61,6 +61,7 @@ func (b *beginNSClient) Register(ctx context.Context, in *registry.NetworkServic resp, err = b.Register(ctx, in, opts...) return } + eventFactoryClient.updateContext(ctx) ctx = withEventFactory(ctx, eventFactoryClient) resp, err = next.NetworkServiceRegistryClient(ctx).Register(ctx, in, opts...) diff --git a/pkg/registry/common/begin/ns_event_factory.go b/pkg/registry/common/begin/ns_event_factory.go index 7af3f8c71..ebecd053a 100644 --- a/pkg/registry/common/begin/ns_event_factory.go +++ b/pkg/registry/common/begin/ns_event_factory.go @@ -43,11 +43,7 @@ func newEventNSFactoryClient(ctx context.Context, afterClose func(), opts ...grp client: next.NetworkServiceRegistryClient(ctx), opts: opts, } - ctxFunc := postpone.ContextWithValues(ctx) - f.ctxFunc = func() (context.Context, context.CancelFunc) { - eventCtx, cancel := ctxFunc() - return withEventFactory(eventCtx, f), cancel - } + f.updateContext(ctx) f.afterCloseFunc = func() { f.state = closed @@ -58,6 +54,14 @@ func newEventNSFactoryClient(ctx context.Context, afterClose func(), opts ...grp return f } +func (f *eventNSFactoryClient) updateContext(ctx context.Context) { + ctxFunc := postpone.ContextWithValues(ctx) + f.ctxFunc = func() (context.Context, context.CancelFunc) { + eventCtx, cancel := ctxFunc() + return withEventFactory(eventCtx, f), cancel + } +} + func (f *eventNSFactoryClient) Register(opts ...Option) <-chan error { o := &option{ cancelCtx: context.Background(), @@ -129,11 +133,7 @@ func newNSEventFactoryServer(ctx context.Context, afterClose func()) *eventNSFac f := &eventNSFactoryServer{ server: next.NetworkServiceRegistryServer(ctx), } - ctxFunc := postpone.ContextWithValues(ctx) - f.ctxFunc = func() (context.Context, context.CancelFunc) { - eventCtx, cancel := ctxFunc() - return withEventFactory(eventCtx, f), cancel - } + f.updateContext(ctx) f.afterCloseFunc = func() { f.state = closed @@ -142,6 +142,14 @@ func newNSEventFactoryServer(ctx context.Context, afterClose func()) *eventNSFac return f } +func (f *eventNSFactoryServer) updateContext(ctx context.Context) { + ctxFunc := postpone.ContextWithValues(ctx) + f.ctxFunc = func() (context.Context, context.CancelFunc) { + eventCtx, cancel := ctxFunc() + return withEventFactory(eventCtx, f), cancel + } +} + func (f *eventNSFactoryServer) Register(opts ...Option) <-chan error { o := &option{ cancelCtx: context.Background(), diff --git a/pkg/registry/common/begin/ns_server.go b/pkg/registry/common/begin/ns_server.go index 55e35b26b..5ac9360b4 100644 --- a/pkg/registry/common/begin/ns_server.go +++ b/pkg/registry/common/begin/ns_server.go @@ -60,6 +60,8 @@ func (b *beginNSServer) Register(ctx context.Context, in *registry.NetworkServic resp, err = b.Register(ctx, in) return } + eventFactoryServer.updateContext(ctx) + ctx = withEventFactory(ctx, eventFactoryServer) resp, err = next.NetworkServiceRegistryServer(ctx).Register(ctx, in) if err != nil { diff --git a/pkg/registry/common/begin/nse_client.go b/pkg/registry/common/begin/nse_client.go index a2c53c344..1ec2ed9ee 100644 --- a/pkg/registry/common/begin/nse_client.go +++ b/pkg/registry/common/begin/nse_client.go @@ -61,6 +61,7 @@ func (b *beginNSEClient) Register(ctx context.Context, in *registry.NetworkServi resp, err = b.Register(ctx, in, opts...) return } + eventFactoryClient.updateContext(ctx) ctx = withEventFactory(ctx, eventFactoryClient) resp, err = next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, in, opts...) diff --git a/pkg/registry/common/begin/nse_event_factory.go b/pkg/registry/common/begin/nse_event_factory.go index c6dddca00..92c7ca280 100644 --- a/pkg/registry/common/begin/nse_event_factory.go +++ b/pkg/registry/common/begin/nse_event_factory.go @@ -43,11 +43,7 @@ func newEventNSEFactoryClient(ctx context.Context, afterClose func(), opts ...gr client: next.NetworkServiceEndpointRegistryClient(ctx), opts: opts, } - ctxFunc := postpone.ContextWithValues(ctx) - f.ctxFunc = func() (context.Context, context.CancelFunc) { - eventCtx, cancel := ctxFunc() - return withEventFactory(eventCtx, f), cancel - } + f.updateContext(ctx) f.afterCloseFunc = func() { f.state = closed @@ -58,6 +54,14 @@ func newEventNSEFactoryClient(ctx context.Context, afterClose func(), opts ...gr return f } +func (f *eventNSEFactoryClient) updateContext(ctx context.Context) { + ctxFunc := postpone.ContextWithValues(ctx) + f.ctxFunc = func() (context.Context, context.CancelFunc) { + eventCtx, cancel := ctxFunc() + return withEventFactory(eventCtx, f), cancel + } +} + func (f *eventNSEFactoryClient) Register(opts ...Option) <-chan error { o := &option{ cancelCtx: context.Background(), @@ -129,11 +133,7 @@ func newNSEEventFactoryServer(ctx context.Context, afterClose func()) *eventNSEF f := &eventNSEFactoryServer{ server: next.NetworkServiceEndpointRegistryServer(ctx), } - ctxFunc := postpone.ContextWithValues(ctx) - f.ctxFunc = func() (context.Context, context.CancelFunc) { - eventCtx, cancel := ctxFunc() - return withEventFactory(eventCtx, f), cancel - } + f.updateContext(ctx) f.afterCloseFunc = func() { f.state = closed @@ -142,6 +142,14 @@ func newNSEEventFactoryServer(ctx context.Context, afterClose func()) *eventNSEF return f } +func (f *eventNSEFactoryServer) updateContext(ctx context.Context) { + ctxFunc := postpone.ContextWithValues(ctx) + f.ctxFunc = func() (context.Context, context.CancelFunc) { + eventCtx, cancel := ctxFunc() + return withEventFactory(eventCtx, f), cancel + } +} + func (f *eventNSEFactoryServer) Register(opts ...Option) <-chan error { o := &option{ cancelCtx: context.Background(), diff --git a/pkg/registry/common/begin/nse_event_factory_client_test.go b/pkg/registry/common/begin/nse_event_factory_client_test.go index 00e5be845..8ada5695d 100644 --- a/pkg/registry/common/begin/nse_event_factory_client_test.go +++ b/pkg/registry/common/begin/nse_event_factory_client_test.go @@ -33,6 +33,48 @@ import ( "google.golang.org/grpc" ) +// This test reproduces the situation when refresh changes the eventFactory context +func TestRefresh_Client(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + syncChan := make(chan struct{}) + checkCtxCl := &checkContextClient{t: t} + eventFactoryCl := &eventFactoryClient{ch: syncChan} + client := chain.NewNetworkServiceEndpointRegistryClient( + begin.NewNetworkServiceEndpointRegistryClient(), + checkCtxCl, + eventFactoryCl, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Set any value to context + ctx = context.WithValue(ctx, contextKey{}, "value_1") + checkCtxCl.setExpectedValue("value_1") + + // Do Register with this context + nse := ®istry.NetworkServiceEndpoint{ + Name: "1", + } + conn, err := client.Register(ctx, nse.Clone()) + assert.NotNil(t, t, conn) + assert.NoError(t, err) + + // Change context value before refresh + ctx = context.WithValue(ctx, contextKey{}, "value_2") + checkCtxCl.setExpectedValue("value_2") + + // Call refresh + conn, err = client.Register(ctx, nse.Clone()) + assert.NotNil(t, t, conn) + assert.NoError(t, err) + + // Call refresh from eventFactory. We are expecting updated value in the context + eventFactoryCl.callRefresh() + <-syncChan +} + // This test reproduces the situation when Unregister and Register were called at the same time func TestRefreshDuringUnregister_Client(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) diff --git a/pkg/registry/common/begin/nse_event_factory_server_test.go b/pkg/registry/common/begin/nse_event_factory_server_test.go index ea980bbfe..03b880d1b 100644 --- a/pkg/registry/common/begin/nse_event_factory_server_test.go +++ b/pkg/registry/common/begin/nse_event_factory_server_test.go @@ -32,6 +32,48 @@ import ( "go.uber.org/goleak" ) +// This test reproduces the situation when refresh changes the eventFactory context +func TestRefresh_Server(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + syncChan := make(chan struct{}) + checkCtxServ := &checkContextServer{t: t} + eventFactoryServ := &eventFactoryServer{ch: syncChan} + server := chain.NewNetworkServiceEndpointRegistryServer( + begin.NewNetworkServiceEndpointRegistryServer(), + checkCtxServ, + eventFactoryServ, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Set any value to context + ctx = context.WithValue(ctx, contextKey{}, "value_1") + checkCtxServ.setExpectedValue("value_1") + + // Do Register with this context + nse := ®istry.NetworkServiceEndpoint{ + Name: "1", + } + conn, err := server.Register(ctx, nse.Clone()) + assert.NotNil(t, t, conn) + assert.NoError(t, err) + + // Change context value before refresh + ctx = context.WithValue(ctx, contextKey{}, "value_2") + checkCtxServ.setExpectedValue("value_2") + + // Call refresh + conn, err = server.Register(ctx, nse.Clone()) + assert.NotNil(t, t, conn) + assert.NoError(t, err) + + // Call refresh from eventFactory. We are expecting updated value in the context + eventFactoryServ.callRefresh() + <-syncChan +} + // This test reproduces the situation when Unregister and Register were called at the same time func TestRefreshDuringUnregister_Server(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) diff --git a/pkg/registry/common/begin/nse_server.go b/pkg/registry/common/begin/nse_server.go index 28221d05a..2ee332267 100644 --- a/pkg/registry/common/begin/nse_server.go +++ b/pkg/registry/common/begin/nse_server.go @@ -60,6 +60,8 @@ func (b *beginNSEServer) Register(ctx context.Context, in *registry.NetworkServi resp, err = b.Register(ctx, in) return } + eventFactoryServer.updateContext(ctx) + ctx = withEventFactory(ctx, eventFactoryServer) resp, err = next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, in) if err != nil {