diff --git a/pkg/registry/common/setregistrationtime/nse_server.go b/pkg/registry/common/setregistrationtime/nse_server.go index 77ac610e8..32df2e69c 100644 --- a/pkg/registry/common/setregistrationtime/nse_server.go +++ b/pkg/registry/common/setregistrationtime/nse_server.go @@ -1,5 +1,7 @@ // Copyright (c) 2021 Doc.ai and/or its affiliates. // +// Copyright (c) 2023 Cisco and/or its affiliates. +// // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,15 +21,19 @@ package setregistrationtime import ( "context" + "github.com/edwarnicke/genericsync" "github.com/golang/protobuf/ptypes/empty" "github.com/networkservicemesh/api/pkg/api/registry" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/tools/clock" ) -type setregtimeNSEServer struct{} +type setregtimeNSEServer struct { + genericsync.Map[string, *timestamppb.Timestamp] +} // NewNetworkServiceEndpointRegistryServer creates a new NetworkServiceServer chain element that sets initial // registration time. @@ -36,8 +42,13 @@ func NewNetworkServiceEndpointRegistryServer() registry.NetworkServiceEndpointRe } func (r *setregtimeNSEServer) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { - if nse.InitialRegistrationTime == nil { - nse.InitialRegistrationTime = timestamppb.New(clock.FromContext(ctx).Now()) + if v, ok := r.Load(nse.GetName()); ok { + nse.InitialRegistrationTime = v + } else { + if nse.InitialRegistrationTime == nil { + nse.InitialRegistrationTime = timestamppb.New(clock.FromContext(ctx).Now()) + } + r.Store(nse.GetName(), proto.Clone(nse.InitialRegistrationTime).(*timestamppb.Timestamp)) } return next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse) @@ -48,5 +59,6 @@ func (r *setregtimeNSEServer) Find(q *registry.NetworkServiceEndpointQuery, s re } func (r *setregtimeNSEServer) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*empty.Empty, error) { + r.Delete(nse.GetName()) return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse) } diff --git a/pkg/registry/common/setregistrationtime/nse_server_test.go b/pkg/registry/common/setregistrationtime/nse_server_test.go index 6c6c5f1dc..2f08619d9 100644 --- a/pkg/registry/common/setregistrationtime/nse_server_test.go +++ b/pkg/registry/common/setregistrationtime/nse_server_test.go @@ -1,5 +1,7 @@ // Copyright (c) 2021 Doc.ai and/or its affiliates. // +// Copyright (c) 2023 Cisco and/or its affiliates. +// // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -65,12 +67,21 @@ func TestRegTimeServer_Register(t *testing.T) { require.Len(t, nses, 1) require.True(t, proto.Equal(nses[0].InitialRegistrationTime, registeredNse.InitialRegistrationTime)) - // 3. Refresh + // 3.1 Refresh reg, err = s.Register(ctx, reg.Clone()) require.NoError(t, err) require.NotNil(t, reg.InitialRegistrationTime) require.True(t, proto.Equal(reg.InitialRegistrationTime, registeredNse.InitialRegistrationTime)) + // 3.2 Refresh with empty field + regClone := reg.Clone() + regClone.InitialRegistrationTime = nil + clockMock.Add(time.Second) + reg, err = s.Register(ctx, regClone) + require.NoError(t, err) + require.NotNil(t, reg.InitialRegistrationTime) + require.True(t, proto.Equal(reg.InitialRegistrationTime, registeredNse.InitialRegistrationTime)) + // 4. Unregister _, err = s.Unregister(ctx, reg.Clone()) require.NoError(t, err)