From 6487328ce5cfcf3f06707ac8e0063c3b6b171d82 Mon Sep 17 00:00:00 2001 From: Oliver Cvetkovski Date: Wed, 27 Sep 2023 12:29:56 +0200 Subject: [PATCH] ns: Add limit to deduplication of accumulated metadata --- pkg/networkserver/grpc_gsns.go | 17 ++- .../redis/uplink_deduplicator.go | 6 +- pkg/redis/redis.go | 23 +++- pkg/redis/redis_test.go | 105 ++++++++++++++---- 4 files changed, 117 insertions(+), 34 deletions(-) diff --git a/pkg/networkserver/grpc_gsns.go b/pkg/networkserver/grpc_gsns.go index e0f00e46e6..cce31eb588 100644 --- a/pkg/networkserver/grpc_gsns.go +++ b/pkg/networkserver/grpc_gsns.go @@ -57,19 +57,26 @@ const ( // This parameter is separated from the uplink collection period since the JoinRequest may have to be // served by a Join Server which is either geographically far away, or simply slow to respond. joinRequestCollectionWindow = 6 * time.Second + + // DeduplicationLimit is the number of metadata to deduplicate for a single transmission. + deduplicationLimit = 50 ) // UplinkDeduplicator represents an entity, that deduplicates uplinks and accumulates metadata. type UplinkDeduplicator interface { // DeduplicateUplink deduplicates an uplink message for specified time.Duration, in the provided round. // DeduplicateUplink returns true if the uplink is not a duplicate or false and error, if any, otherwise. - DeduplicateUplink(ctx context.Context, up *ttnpb.UplinkMessage, window time.Duration, round uint64) (first bool, err error) + DeduplicateUplink( + ctx context.Context, up *ttnpb.UplinkMessage, window time.Duration, limit int, round uint64, + ) (first bool, err error) // AccumulatedMetadata returns accumulated metadata for specified uplink message in the provided round and error, if any. AccumulatedMetadata(ctx context.Context, up *ttnpb.UplinkMessage, round uint64) (mds []*ttnpb.RxMetadata, err error) } -func (ns *NetworkServer) deduplicateUplink(ctx context.Context, up *ttnpb.UplinkMessage, window time.Duration, round uint64) (bool, error) { - ok, err := ns.uplinkDeduplicator.DeduplicateUplink(ctx, up, window, round) +func (ns *NetworkServer) deduplicateUplink( + ctx context.Context, up *ttnpb.UplinkMessage, window time.Duration, limit int, round uint64, +) (bool, error) { + ok, err := ns.uplinkDeduplicator.DeduplicateUplink(ctx, up, window, limit, round) if err != nil { log.FromContext(ctx).WithError(err).Error("Failed to deduplicate uplink") return false, err @@ -887,7 +894,7 @@ func (ns *NetworkServer) handleDataUplink(ctx context.Context, up *ttnpb.UplinkM "uplink_f_cnt", pld.FHdr.FCnt, )) - ok, err := ns.deduplicateUplink(ctx, up, ns.collectionWindow(ctx), initialDeduplicationRound) + ok, err := ns.deduplicateUplink(ctx, up, ns.collectionWindow(ctx), deduplicationLimit, initialDeduplicationRound) if err != nil { return err } @@ -1182,7 +1189,7 @@ func (ns *NetworkServer) handleJoinRequest(ctx context.Context, up *ttnpb.Uplink "join_eui", types.MustEUI64(pld.JoinEui).OrZero(), )) - ok, err := ns.deduplicateUplink(ctx, up, joinRequestCollectionWindow, initialDeduplicationRound) + ok, err := ns.deduplicateUplink(ctx, up, joinRequestCollectionWindow, deduplicationLimit, initialDeduplicationRound) if err != nil { return err } diff --git a/pkg/networkserver/redis/uplink_deduplicator.go b/pkg/networkserver/redis/uplink_deduplicator.go index 2f747cdca4..7af56d81a4 100644 --- a/pkg/networkserver/redis/uplink_deduplicator.go +++ b/pkg/networkserver/redis/uplink_deduplicator.go @@ -52,7 +52,9 @@ func uplinkHash(ctx context.Context, up *ttnpb.UplinkMessage, round uint64) (str } // DeduplicateUplink deduplicates up for window. Since highest precision allowed by Redis is milliseconds, window is truncated to milliseconds. -func (d *UplinkDeduplicator) DeduplicateUplink(ctx context.Context, up *ttnpb.UplinkMessage, window time.Duration, round uint64) (bool, error) { +func (d *UplinkDeduplicator) DeduplicateUplink( + ctx context.Context, up *ttnpb.UplinkMessage, window time.Duration, limit int, round uint64, +) (bool, error) { h, err := uplinkHash(ctx, up, round) if err != nil { return false, err @@ -61,7 +63,7 @@ func (d *UplinkDeduplicator) DeduplicateUplink(ctx context.Context, up *ttnpb.Up for _, md := range up.RxMetadata { msgs = append(msgs, md) } - return ttnredis.DeduplicateProtos(ctx, d.Redis, d.Redis.Key(h), window, msgs...) + return ttnredis.DeduplicateProtos(ctx, d.Redis, d.Redis.Key(h), window, limit, msgs...) } // AccumulatedMetadata returns accumulated metadata for up. diff --git a/pkg/redis/redis.go b/pkg/redis/redis.go index 25da7c615c..7c66bb1d06 100644 --- a/pkg/redis/redis.go +++ b/pkg/redis/redis.go @@ -696,13 +696,17 @@ func (q *TaskQueue) Pop(ctx context.Context, consumerID string, r redis.Cmdable, return popTask(ctx, r, q.Group, consumerID, f, q.Key, q.StreamBlockLimit) } -var deduplicateProtosScript = redis.NewScript(`local exp = ARGV[1] +var deduplicateProtosScript = redis.NewScript(`local exp = table.remove(ARGV, 1) +local limit = tonumber(table.remove(ARGV, 1)) local ok = redis.call('set', KEYS[1], '', 'px', exp, 'nx') -if #ARGV > 1 then - table.remove(ARGV, 1) + +if #ARGV > 0 then redis.call('rpush', KEYS[2], unpack(ARGV)) local ttl = redis.call('pttl', KEYS[1]) redis.call('pexpire', KEYS[2], ttl) + if limit > 0 then + redis.call('ltrim', KEYS[2], -limit, -1) + end end if ok then return 1 @@ -728,12 +732,19 @@ func milliseconds(d time.Duration) int64 { return ms } -// DeduplicateProtos deduplicates protos using key k. It stores a lock at LockKey(k) and the list of collected protos at ListKey(k). +// DeduplicateProtos deduplicates protos using key k. It stores a lock at LockKey(k) +// and the list of collected protos at ListKey(k). +// If the number of protos exceeds limit, the messages are trimmed from the start of the list. func DeduplicateProtos( - ctx context.Context, r redis.Scripter, k string, window time.Duration, msgs ...proto.Message, + ctx context.Context, r redis.Scripter, k string, window time.Duration, limit int, msgs ...proto.Message, ) (bool, error) { - args := make([]any, 0, 1+len(msgs)) + args := make([]any, 0, 2+len(msgs)) args = append(args, milliseconds(window)) + args = append(args, limit) + if n := len(msgs) - limit; n > 0 { + msgs = msgs[n:] + } + for _, msg := range msgs { s, err := MarshalProto(msg) if err != nil { diff --git a/pkg/redis/redis_test.go b/pkg/redis/redis_test.go index ed247ceca2..ce433375d6 100644 --- a/pkg/redis/redis_test.go +++ b/pkg/redis/redis_test.go @@ -487,57 +487,62 @@ func TestTaskQueue(t *testing.T) { } } +func makeProto(t *testing.T, s string) proto.Message { + t.Helper() + return &ttnpb.APIKey{Id: s} +} + +func makeProtoString(t *testing.T, s string) string { + t.Helper() + m := makeProto(t, s) + return test.Must(MarshalProto(m)) +} + func TestProtoDeduplicator(t *testing.T) { a, ctx := test.New(t) cl, flush := test.NewRedis(ctx, "redis_test") defer flush() defer cl.Close() - - makeProto := func(s string) proto.Message { - return &ttnpb.APIKey{Id: s} - } - makeProtoString := func(s string) string { - m := makeProto(s) - s, _ = MarshalProto(m) - return s - } + limit := 50 ttl := (1 << 12) * test.Delay key1 := cl.Key("test1") key2 := cl.Key("test2") - v, err := DeduplicateProtos(ctx, cl, key1, ttl) + v, err := DeduplicateProtos(ctx, cl, key1, ttl, limit) if !a.So(err, should.BeNil) { t.FailNow() } a.So(v, should.BeTrue) - v, err = DeduplicateProtos(ctx, cl, key1, ttl, makeProto("proto1")) + v, err = DeduplicateProtos(ctx, cl, key1, ttl, limit, makeProto(t, "proto1")) if !a.So(err, should.BeNil) { t.FailNow() } a.So(v, should.BeFalse) - v, err = DeduplicateProtos(ctx, cl, key2, ttl, makeProto("proto1")) + v, err = DeduplicateProtos(ctx, cl, key2, ttl, limit, makeProto(t, "proto1")) if !a.So(err, should.BeNil) { t.FailNow() } a.So(v, should.BeTrue) - v, err = DeduplicateProtos(ctx, cl, key1, ttl, makeProto("proto1")) + v, err = DeduplicateProtos(ctx, cl, key1, ttl, limit, makeProto(t, "proto1")) if !a.So(err, should.BeNil) { t.FailNow() } a.So(v, should.BeFalse) - v, err = DeduplicateProtos(ctx, cl, key1, ttl, makeProto("proto2"), makeProto("proto3")) + v, err = DeduplicateProtos( + ctx, cl, key1, ttl, limit, makeProto(t, "proto2"), makeProto(t, "proto3"), + ) if !a.So(err, should.BeNil) { t.FailNow() } a.So(v, should.BeFalse) - v, err = DeduplicateProtos(ctx, cl, key2, ttl, makeProto("proto2")) + v, err = DeduplicateProtos(ctx, cl, key2, ttl, limit, makeProto(t, "proto2")) if !a.So(err, should.BeNil) { t.FailNow() } @@ -556,10 +561,10 @@ func TestProtoDeduplicator(t *testing.T) { t.FailNow() } a.So(ss, should.Resemble, []string{ - makeProtoString("proto1"), - makeProtoString("proto1"), - makeProtoString("proto2"), - makeProtoString("proto3"), + makeProtoString(t, "proto1"), + makeProtoString(t, "proto1"), + makeProtoString(t, "proto2"), + makeProtoString(t, "proto3"), }) a.So(lockTTL, should.BeGreaterThan, 0) a.So(lockTTL, should.BeLessThanOrEqualTo, ttl) @@ -579,8 +584,8 @@ func TestProtoDeduplicator(t *testing.T) { t.FailNow() } a.So(ss, should.Resemble, []string{ - makeProtoString("proto1"), - makeProtoString("proto2"), + makeProtoString(t, "proto1"), + makeProtoString(t, "proto2"), }) a.So(lockTTL, should.BeGreaterThan, 0) a.So(lockTTL, should.BeLessThanOrEqualTo, ttl) @@ -588,6 +593,64 @@ func TestProtoDeduplicator(t *testing.T) { a.So(listTTL, should.BeLessThanOrEqualTo, ttl) } +func TestProtoDeduplicatorRespectsLimit(t *testing.T) { + t.Parallel() + a, ctx := test.New(t) + cl, flush := test.NewRedis(ctx, "redis_test") + defer flush() + defer cl.Close() + + ttl := (1 << 12) * test.Delay + key := cl.Key("test1") + limit := 30 + protoID := 0 + + for i := 0; i < limit+3; i++ { + s := fmt.Sprintf("proto%d", protoID) + _, err := DeduplicateProtos(ctx, cl, key, ttl, limit, makeProto(t, s)) + if !a.So(err, should.BeNil) { + t.FailNow() + } + protoID++ + } + + actual, err := cl.LRange(ctx, ListKey(key), 0, -1).Result() + if !a.So(err, should.BeNil) { + t.FailNow() + } + a.So(actual, should.HaveLength, limit) + expected := make([]string, limit) + for i := limit - 1; i >= 0; i-- { + s := fmt.Sprintf("proto%d", protoID-limit+i) + expected[i] = makeProtoString(t, s) + } + a.So(actual, should.Resemble, expected) + + bulkedProtosLen := limit + 5 + bulkedProtos := make([]proto.Message, bulkedProtosLen) + for i := 0; i < bulkedProtosLen; i++ { + s := fmt.Sprintf("proto%d", protoID) + bulkedProtos[i] = makeProto(t, s) + protoID++ + } + + if _, err := DeduplicateProtos(ctx, cl, key, ttl, limit, bulkedProtos...); !a.So(err, should.BeNil) { + t.FailNow() + } + + actual, err = cl.LRange(ctx, ListKey(key), 0, -1).Result() + if !a.So(err, should.BeNil) { + t.FailNow() + } + a.So(actual, should.HaveLength, limit) + expected = make([]string, limit) + for i := limit - 1; i >= 0; i-- { + s := fmt.Sprintf("proto%d", protoID-limit+i) + expected[i] = makeProtoString(t, s) + } + a.So(actual, should.Resemble, expected) +} + func TestMutex(t *testing.T) { a, ctx := test.New(t)