From 973442e3d329286cfa31ed43c5d51852d3b6cc91 Mon Sep 17 00:00:00 2001 From: OpenIM-Gordon <46924906+FGadvancer@users.noreply.github.com> Date: Mon, 3 Jun 2024 11:24:37 +0800 Subject: [PATCH] refactor: db cache batch refactor and batch consume message. (#2325) * refactor: cmd update. * refactor: msg transfer refactor. * refactor: msg transfer refactor. * refactor: msg transfer refactor. * fix: read prometheus port when flag set to enable and prevent failure during startup. * fix: notification has counted unread counts bug fix. * fix: merge opensource code into local. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * refactor: delete message and message batch use lua. * fix: add protective measures against memory overflow. --- config/redis.yml | 1 - go.mod | 5 +- go.sum | 42 +- internal/api/init.go | 15 +- internal/msggateway/hub_server.go | 4 +- internal/msggateway/init.go | 6 +- internal/msgtransfer/init.go | 58 +-- .../msgtransfer/online_history_msg_handler.go | 463 ++++++------------ .../online_msg_to_mongo_handler.go | 1 - internal/push/push_handler.go | 20 +- internal/rpc/group/group.go | 23 +- internal/rpc/group/notification.go | 4 - internal/rpc/msg/server.go | 3 +- pkg/common/config/config.go | 13 +- pkg/common/startrpc/start.go | 22 +- pkg/common/storage/cache/cachekey/msg.go | 4 - pkg/common/storage/cache/conversation.go | 3 - pkg/common/storage/cache/msg.go | 9 +- .../storage/cache/redis/batch_handler.go | 43 +- .../storage/cache/redis/conversation.go | 8 - pkg/common/storage/cache/redis/lua_script.go | 125 +++++ .../storage/cache/redis/lua_script_test.go | 75 +++ pkg/common/storage/cache/redis/meta_cache.go | 15 - pkg/common/storage/cache/redis/msg.go | 364 +++----------- pkg/common/storage/cache/redis/msg_test.go | 453 ++++------------- .../cache/redis/redis_shard_manager.go | 197 ++++++++ pkg/common/storage/cache/user.go | 6 +- pkg/common/storage/controller/msg.go | 115 +---- pkg/common/storage/database/mgo/msg.go | 25 +- pkg/tools/batcher/batcher.go | 272 ++++++++++ pkg/tools/batcher/batcher_test.go | 66 +++ 31 files changed, 1137 insertions(+), 1323 deletions(-) create mode 100644 pkg/common/storage/cache/redis/lua_script.go create mode 100644 pkg/common/storage/cache/redis/lua_script_test.go delete mode 100644 pkg/common/storage/cache/redis/meta_cache.go create mode 100644 pkg/common/storage/cache/redis/redis_shard_manager.go create mode 100644 pkg/tools/batcher/batcher.go create mode 100644 pkg/tools/batcher/batcher_test.go diff --git a/config/redis.yml b/config/redis.yml index 26becd887d..6fe0dd02d4 100644 --- a/config/redis.yml +++ b/config/redis.yml @@ -1,7 +1,6 @@ address: [ localhost:16379 ] username: '' password: openIM123 -enablePipeline: false clusterMode: false db: 0 maxRetry: 10 \ No newline at end of file diff --git a/go.mod b/go.mod index 54e8a8e0e5..e34e3e4bd8 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/dtm-labs/rockscache v0.1.1 github.com/gin-gonic/gin v1.9.1 github.com/go-playground/validator/v10 v10.18.0 - github.com/gogo/protobuf v1.3.2 + github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt/v4 v4.5.0 github.com/gorilla/websocket v1.5.1 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 @@ -31,6 +31,7 @@ require ( github.com/IBM/sarama v1.43.0 github.com/fatih/color v1.14.1 github.com/go-redis/redis v6.15.9+incompatible + github.com/go-redis/redismock/v9 v9.2.0 github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/kelindar/bitmap v1.5.2 github.com/likexian/gokit v0.25.13 @@ -112,8 +113,6 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect github.com/mozillazg/go-httpheader v0.4.0 // indirect - github.com/onsi/ginkgo v1.16.5 // indirect - github.com/onsi/gomega v1.18.1 // indirect github.com/pelletier/go-toml/v2 v2.1.0 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect diff --git a/go.sum b/go.sum index 5611a6ca65..b2fa7f318d 100644 --- a/go.sum +++ b/go.sum @@ -38,9 +38,6 @@ github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= -github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= -github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= -github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/clbanning/mxj v1.8.4 h1:HuhwZtbyvyOw+3Z1AowPkU87JkJUSv751ELWaiTpj8I= github.com/clbanning/mxj v1.8.4/go.mod h1:BVjHeAH+rl9rs6f+QIpeRl0tfu10SXn1pUSa5PVGJng= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= @@ -83,8 +80,6 @@ github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8 github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= @@ -110,7 +105,8 @@ github.com/go-playground/validator/v10 v10.18.0 h1:BvolUXjp4zuvkZ5YN5t7ebzbhlUtP github.com/go-playground/validator/v10 v10.18.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg= github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= -github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/go-redis/redismock/v9 v9.2.0 h1:ZrMYQeKPECZPjOj5u9eyOjg8Nnb0BS9lkVIZ6IpsKLw= +github.com/go-redis/redismock/v9 v9.2.0/go.mod h1:18KHfGDK4Y6c2R0H38EUGWAdc7ZQS9gfYxc94k7rWT0= github.com/go-zookeeper/zk v1.0.3 h1:7M2kwOsc//9VeeFiPtf+uSJlVpU66x9Ba5+8XK7/TDg= github.com/go-zookeeper/zk v1.0.3/go.mod h1:nOB03cncLtlp4t+UAkGSV+9beXP/akpekBwL+UX1Qcw= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= @@ -133,7 +129,6 @@ github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrU github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= @@ -157,7 +152,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian/v3 v3.3.2 h1:IqNFLAmvJOgVlpdEBiQbDc2EwKW77amAycfTuWKdfvw= github.com/google/martian/v3 v3.3.2/go.mod h1:oBOf6HBosgwRXnUGWUB05QECsc6uvmMiJ3+6W4l/CUk= -github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -186,8 +180,6 @@ github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= -github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= -github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= @@ -270,20 +262,12 @@ github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJ github.com/mozillazg/go-httpheader v0.2.1/go.mod h1:jJ8xECTlalr6ValeXYdOF8fFUISeBAdw6E61aqQma60= github.com/mozillazg/go-httpheader v0.4.0 h1:aBn6aRXtFzyDLZ4VIRLsZbbJloagQfMnCiYgOq6hK4w= github.com/mozillazg/go-httpheader v0.4.0/go.mod h1:PuT8h0pw6efvp8ZeUec1Rs7dwjK08bt6gKSReGMqtdA= -github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= -github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= -github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= -github.com/onsi/ginkgo/v2 v2.0.0/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c= -github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= -github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= -github.com/onsi/gomega v1.17.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= -github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= -github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs= +github.com/onsi/gomega v1.25.0 h1:Vw7br2PCDYijJHSfBOWhov+8cAnUf8MfMaIOV323l6Y= +github.com/onsi/gomega v1.25.0/go.mod h1:r+zV744Re+DiYCIPRlYOTxn0YkOLcAnW8k1xXdMPGhM= github.com/openimsdk/gomake v0.0.13 h1:xLDe/moqgWpRoptHzI4packAWzs4C16b+sVY+txNJp0= github.com/openimsdk/gomake v0.0.13/go.mod h1:PndCozNc2IsQIciyn9mvEblYWZwJmAI+06z94EY+csI= github.com/openimsdk/protocol v0.0.65 h1:SPT9qyUsFRTTKSKb/FjpS+xr6sxz/Kbnu+su1bxYagc= @@ -348,7 +332,6 @@ github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= @@ -438,18 +421,15 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= @@ -467,21 +447,12 @@ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -494,7 +465,6 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= @@ -509,7 +479,6 @@ golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3 golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -556,14 +525,11 @@ google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHh gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/api/init.go b/internal/api/init.go index b49a145696..23866c4a07 100644 --- a/internal/api/init.go +++ b/internal/api/init.go @@ -48,10 +48,6 @@ func Start(ctx context.Context, index int, config *Config) error { if err != nil { return err } - prometheusPort, err := datautil.GetElemByIndex(config.API.Prometheus.Ports, index) - if err != nil { - return err - } var client discovery.SvcDiscoveryRegistry @@ -62,13 +58,20 @@ func Start(ctx context.Context, index int, config *Config) error { } var ( - netDone = make(chan struct{}, 1) - netErr error + netDone = make(chan struct{}, 1) + netErr error + prometheusPort int ) router := newGinRouter(client, config) if config.API.Prometheus.Enable { go func() { + prometheusPort, err = datautil.GetElemByIndex(config.API.Prometheus.Ports, index) + if err != nil { + netErr = err + netDone <- struct{}{} + return + } p := ginprom.NewPrometheus("app", prommetrics.GetGinCusMetrics("Api")) p.SetListenAddress(fmt.Sprintf(":%d", prometheusPort)) if err = p.Use(router); err != nil && err != http.ErrServerClosed { diff --git a/internal/msggateway/hub_server.go b/internal/msggateway/hub_server.go index f9bb699ed9..8ff6d10018 100644 --- a/internal/msggateway/hub_server.go +++ b/internal/msggateway/hub_server.go @@ -47,7 +47,6 @@ func (s *Server) Start(ctx context.Context, index int, conf *Config) error { type Server struct { rpcPort int - prometheusPort int LongConnServer LongConnServer config *Config pushTerminal map[int]struct{} @@ -57,10 +56,9 @@ func (s *Server) SetLongConnServer(LongConnServer LongConnServer) { s.LongConnServer = LongConnServer } -func NewServer(rpcPort int, proPort int, longConnServer LongConnServer, conf *Config) *Server { +func NewServer(rpcPort int, longConnServer LongConnServer, conf *Config) *Server { s := &Server{ rpcPort: rpcPort, - prometheusPort: proPort, LongConnServer: longConnServer, pushTerminal: make(map[int]struct{}), config: conf, diff --git a/internal/msggateway/init.go b/internal/msggateway/init.go index ef24d1bf93..f4d8b0381a 100644 --- a/internal/msggateway/init.go +++ b/internal/msggateway/init.go @@ -38,10 +38,6 @@ func Start(ctx context.Context, index int, conf *Config) error { if err != nil { return err } - prometheusPort, err := datautil.GetElemByIndex(conf.MsgGateway.Prometheus.Ports, index) - if err != nil { - return err - } rpcPort, err := datautil.GetElemByIndex(conf.MsgGateway.RPC.Ports, index) if err != nil { return err @@ -57,7 +53,7 @@ func Start(ctx context.Context, index int, conf *Config) error { return err } - hubServer := NewServer(rpcPort, prometheusPort, longServer, conf) + hubServer := NewServer(rpcPort, longServer, conf) netDone := make(chan error) go func() { err = hubServer.Start(ctx, index, conf) diff --git a/internal/msgtransfer/init.go b/internal/msgtransfer/init.go index ba82abacfd..65d04f3810 100644 --- a/internal/msgtransfer/init.go +++ b/internal/msgtransfer/init.go @@ -44,15 +44,14 @@ import ( ) type MsgTransfer struct { - // This consumer aggregated messages, subscribed to the topic:ws2ms_chat, - // the modification notification is sent to msg_to_modify topic, the message is stored in redis, Incr Redis, - // and then the message is sent to ms2pschat topic for push, and the message is sent to msg_to_mongo topic for persistence - historyCH *OnlineHistoryRedisConsumerHandler + // This consumer aggregated messages, subscribed to the topic:toRedis, + // the message is stored in redis, Incr Redis, and then the message is sent to toPush topic for push, + // and the message is sent to toMongo topic for persistence + historyCH *OnlineHistoryRedisConsumerHandler + //This consumer handle message to mongo historyMongoCH *OnlineHistoryMongoConsumerHandler - // mongoDB batch insert, delete messages in redis after success, - // and handle the deletion notification message deleted subscriptions topic: msg_to_mongo - ctx context.Context - cancel context.CancelFunc + ctx context.Context + cancel context.CancelFunc } type Config struct { @@ -82,8 +81,7 @@ func Start(ctx context.Context, index int, config *Config) error { } client.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin"))) - //todo MsgCacheTimeout - msgModel := redis.NewMsgCache(rdb, config.RedisConfig.EnablePipeline) + msgModel := redis.NewMsgCache(rdb) seqModel := redis.NewSeqCache(rdb) msgDocModel, err := mgo.NewMsgMongo(mgocli.GetDB()) if err != nil { @@ -95,37 +93,23 @@ func Start(ctx context.Context, index int, config *Config) error { } conversationRpcClient := rpcclient.NewConversationRpcClient(client, config.Share.RpcRegisterName.Conversation) groupRpcClient := rpcclient.NewGroupRpcClient(client, config.Share.RpcRegisterName.Group) - msgTransfer, err := NewMsgTransfer(&config.KafkaConfig, msgDatabase, &conversationRpcClient, &groupRpcClient) + historyCH, err := NewOnlineHistoryRedisConsumerHandler(&config.KafkaConfig, msgDatabase, &conversationRpcClient, &groupRpcClient) if err != nil { return err } - return msgTransfer.Start(index, config) -} - -func NewMsgTransfer(kafkaConf *config.Kafka, msgDatabase controller.CommonMsgDatabase, - conversationRpcClient *rpcclient.ConversationRpcClient, groupRpcClient *rpcclient.GroupRpcClient) (*MsgTransfer, error) { - historyCH, err := NewOnlineHistoryRedisConsumerHandler(kafkaConf, msgDatabase, conversationRpcClient, groupRpcClient) - if err != nil { - return nil, err - } - historyMongoCH, err := NewOnlineHistoryMongoConsumerHandler(kafkaConf, msgDatabase) + historyMongoCH, err := NewOnlineHistoryMongoConsumerHandler(&config.KafkaConfig, msgDatabase) if err != nil { - return nil, err + return err } - - return &MsgTransfer{ + msgTransfer := &MsgTransfer{ historyCH: historyCH, historyMongoCH: historyMongoCH, - }, nil + } + return msgTransfer.Start(index, config) } func (m *MsgTransfer) Start(index int, config *Config) error { - prometheusPort, err := datautil.GetElemByIndex(config.MsgTransfer.Prometheus.Ports, index) - if err != nil { - return err - } m.ctx, m.cancel = context.WithCancel(context.Background()) - var ( netDone = make(chan struct{}, 1) netErr error @@ -133,16 +117,26 @@ func (m *MsgTransfer) Start(index int, config *Config) error { go m.historyCH.historyConsumerGroup.RegisterHandleAndConsumer(m.ctx, m.historyCH) go m.historyMongoCH.historyConsumerGroup.RegisterHandleAndConsumer(m.ctx, m.historyMongoCH) + err := m.historyCH.redisMessageBatches.Start() + if err != nil { + return err + } if config.MsgTransfer.Prometheus.Enable { go func() { + prometheusPort, err := datautil.GetElemByIndex(config.MsgTransfer.Prometheus.Ports, index) + if err != nil { + netErr = err + netDone <- struct{}{} + return + } proreg := prometheus.NewRegistry() proreg.MustRegister( collectors.NewGoCollector(), ) proreg.MustRegister(prommetrics.GetGrpcCusMetrics("Transfer", &config.Share)...) http.Handle("/metrics", promhttp.HandlerFor(proreg, promhttp.HandlerOpts{Registry: proreg})) - err := http.ListenAndServe(fmt.Sprintf(":%d", prometheusPort), nil) + err = http.ListenAndServe(fmt.Sprintf(":%d", prometheusPort), nil) if err != nil && err != http.ErrServerClosed { netErr = errs.WrapMsg(err, "prometheus start error", "prometheusPort", prometheusPort) netDone <- struct{}{} @@ -157,11 +151,13 @@ func (m *MsgTransfer) Start(index int, config *Config) error { program.SIGTERMExit() // graceful close kafka client. m.cancel() + m.historyCH.redisMessageBatches.Close() m.historyCH.historyConsumerGroup.Close() m.historyMongoCH.historyConsumerGroup.Close() return nil case <-netDone: m.cancel() + m.historyCH.redisMessageBatches.Close() m.historyCH.historyConsumerGroup.Close() m.historyMongoCH.historyConsumerGroup.Close() close(netDone) diff --git a/internal/msgtransfer/online_history_msg_handler.go b/internal/msgtransfer/online_history_msg_handler.go index 194b70187e..d671ec52a2 100644 --- a/internal/msgtransfer/online_history_msg_handler.go +++ b/internal/msgtransfer/online_history_msg_handler.go @@ -16,51 +16,34 @@ package msgtransfer import ( "context" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - "github.com/IBM/sarama" "github.com/go-redis/redis" "github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/controller" "github.com/openimsdk/open-im-server/v3/pkg/msgprocessor" "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" + "github.com/openimsdk/open-im-server/v3/pkg/tools/batcher" "github.com/openimsdk/protocol/constant" "github.com/openimsdk/protocol/sdkws" "github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/log" "github.com/openimsdk/tools/mcontext" "github.com/openimsdk/tools/mq/kafka" - "github.com/openimsdk/tools/utils/idutil" "github.com/openimsdk/tools/utils/stringutil" "google.golang.org/protobuf/proto" + "strconv" + "strings" + "time" ) const ( - ConsumerMsgs = 3 - SourceMessages = 4 - MongoMessages = 5 - ChannelNum = 100 + size = 500 + mainDataBuffer = 500 + subChanBuffer = 50 + worker = 50 + interval = 100 * time.Millisecond ) -type MsgChannelValue struct { - uniqueKey string - ctx context.Context - ctxMsgList []*ContextMsg -} - -type TriggerChannelValue struct { - ctx context.Context - cMsgList []*sarama.ConsumerMessage -} - -type Cmd2Value struct { - Cmd int - Value any -} type ContextMsg struct { message *sdkws.MsgData ctx context.Context @@ -68,13 +51,8 @@ type ContextMsg struct { type OnlineHistoryRedisConsumerHandler struct { historyConsumerGroup *kafka.MConsumerGroup - chArrays [ChannelNum]chan Cmd2Value - msgDistributionCh chan Cmd2Value - // singleMsgSuccessCount uint64 - // singleMsgFailedCount uint64 - // singleMsgSuccessCountMutex sync.Mutex - // singleMsgFailedCountMutex sync.Mutex + redisMessageBatches *batcher.Batcher[sarama.ConsumerMessage] msgDatabase controller.CommonMsgDatabase conversationRpcClient *rpcclient.ConversationRpcClient @@ -83,89 +61,82 @@ type OnlineHistoryRedisConsumerHandler struct { func NewOnlineHistoryRedisConsumerHandler(kafkaConf *config.Kafka, database controller.CommonMsgDatabase, conversationRpcClient *rpcclient.ConversationRpcClient, groupRpcClient *rpcclient.GroupRpcClient) (*OnlineHistoryRedisConsumerHandler, error) { - historyConsumerGroup, err := kafka.NewMConsumerGroup(kafkaConf.Build(), kafkaConf.ToRedisGroupID, []string{kafkaConf.ToRedisTopic}, true) + historyConsumerGroup, err := kafka.NewMConsumerGroup(kafkaConf.Build(), kafkaConf.ToRedisGroupID, []string{kafkaConf.ToRedisTopic}, false) if err != nil { return nil, err } var och OnlineHistoryRedisConsumerHandler och.msgDatabase = database - och.msgDistributionCh = make(chan Cmd2Value) // no buffer channel - go och.MessagesDistributionHandle() - for i := 0; i < ChannelNum; i++ { - och.chArrays[i] = make(chan Cmd2Value, 50) - go och.Run(i) + + b := batcher.New[sarama.ConsumerMessage]( + batcher.WithSize(size), + batcher.WithWorker(worker), + batcher.WithInterval(interval), + batcher.WithDataBuffer(mainDataBuffer), + batcher.WithSyncWait(true), + batcher.WithBuffer(subChanBuffer), + ) + b.Sharding = func(key string) int { + hashCode := stringutil.GetHashCode(key) + return int(hashCode) % och.redisMessageBatches.Worker() + } + b.Key = func(consumerMessage *sarama.ConsumerMessage) string { + return string(consumerMessage.Key) } + b.Do = och.do + och.redisMessageBatches = b och.conversationRpcClient = conversationRpcClient och.groupRpcClient = groupRpcClient och.historyConsumerGroup = historyConsumerGroup return &och, err } +func (och *OnlineHistoryRedisConsumerHandler) do(ctx context.Context, channelID int, val *batcher.Msg[sarama.ConsumerMessage]) { + ctx = mcontext.WithTriggerIDContext(ctx, val.TriggerID()) + ctxMessages := och.parseConsumerMessages(ctx, val.Val()) + ctx = withAggregationCtx(ctx, ctxMessages) + log.ZInfo(ctx, "msg arrived channel", "channel id", channelID, "msgList length", len(ctxMessages), + "key", val.Key()) + + storageMsgList, notStorageMsgList, storageNotificationList, notStorageNotificationList := + och.categorizeMessageLists(ctxMessages) + log.ZDebug(ctx, "number of categorized messages", "storageMsgList", len(storageMsgList), "notStorageMsgList", + len(notStorageMsgList), "storageNotificationList", len(storageNotificationList), "notStorageNotificationList", + len(notStorageNotificationList)) + + conversationIDMsg := msgprocessor.GetChatConversationIDByMsg(ctxMessages[0].message) + conversationIDNotification := msgprocessor.GetNotificationConversationIDByMsg(ctxMessages[0].message) + och.handleMsg(ctx, val.Key(), conversationIDMsg, storageMsgList, notStorageMsgList) + och.handleNotification(ctx, val.Key(), conversationIDNotification, storageNotificationList, notStorageNotificationList) +} -func (och *OnlineHistoryRedisConsumerHandler) Run(channelID int) { - for cmd := range och.chArrays[channelID] { - switch cmd.Cmd { - case SourceMessages: - msgChannelValue := cmd.Value.(MsgChannelValue) - ctxMsgList := msgChannelValue.ctxMsgList - ctx := msgChannelValue.ctx - log.ZDebug( - ctx, - "msg arrived channel", - "channel id", - channelID, - "msgList length", - len(ctxMsgList), - "uniqueKey", - msgChannelValue.uniqueKey, - ) - storageMsgList, notStorageMsgList, storageNotificationList, notStorageNotificationList, modifyMsgList := och.getPushStorageMsgList( - ctxMsgList, - ) - log.ZDebug( - ctx, - "msg lens", - "storageMsgList", - len(storageMsgList), - "notStorageMsgList", - len(notStorageMsgList), - "storageNotificationList", - len(storageNotificationList), - "notStorageNotificationList", - len(notStorageNotificationList), - "modifyMsgList", - len(modifyMsgList), - ) - conversationIDMsg := msgprocessor.GetChatConversationIDByMsg(ctxMsgList[0].message) - conversationIDNotification := msgprocessor.GetNotificationConversationIDByMsg(ctxMsgList[0].message) - och.handleMsg(ctx, msgChannelValue.uniqueKey, conversationIDMsg, storageMsgList, notStorageMsgList) - och.handleNotification( - ctx, - msgChannelValue.uniqueKey, - conversationIDNotification, - storageNotificationList, - notStorageNotificationList, - ) - if err := och.msgDatabase.MsgToModifyMQ(ctx, msgChannelValue.uniqueKey, conversationIDNotification, modifyMsgList); err != nil { - log.ZError(ctx, "msg to modify mq error", err, "uniqueKey", msgChannelValue.uniqueKey, "modifyMsgList", modifyMsgList) - } +func (och *OnlineHistoryRedisConsumerHandler) parseConsumerMessages(ctx context.Context, consumerMessages []*sarama.ConsumerMessage) []*ContextMsg { + var ctxMessages []*ContextMsg + for i := 0; i < len(consumerMessages); i++ { + ctxMsg := &ContextMsg{} + msgFromMQ := &sdkws.MsgData{} + err := proto.Unmarshal(consumerMessages[i].Value, msgFromMQ) + if err != nil { + log.ZWarn(ctx, "msg_transfer Unmarshal msg err", err, string(consumerMessages[i].Value)) + continue } + var arr []string + for i, header := range consumerMessages[i].Headers { + arr = append(arr, strconv.Itoa(i), string(header.Key), string(header.Value)) + } + log.ZDebug(ctx, "consumer.kafka.GetContextWithMQHeader", "len", len(consumerMessages[i].Headers), + "header", strings.Join(arr, ", ")) + ctxMsg.ctx = kafka.GetContextWithMQHeader(consumerMessages[i].Headers) + ctxMsg.message = msgFromMQ + log.ZDebug(ctx, "message parse finish", "message", msgFromMQ, "key", + string(consumerMessages[i].Key)) + ctxMessages = append(ctxMessages, ctxMsg) } + return ctxMessages } // Get messages/notifications stored message list, not stored and pushed message list. -func (och *OnlineHistoryRedisConsumerHandler) getPushStorageMsgList( - totalMsgs []*ContextMsg, -) (storageMsgList, notStorageMsgList, storageNotificatoinList, notStorageNotificationList, modifyMsgList []*sdkws.MsgData) { - isStorage := func(msg *sdkws.MsgData) bool { - options2 := msgprocessor.Options(msg.Options) - if options2.IsHistory() { - return true - } - // if !(!options2.IsSenderSync() && conversationID == msg.MsgData.SendID) { - // return false - // } - return false - } +func (och *OnlineHistoryRedisConsumerHandler) categorizeMessageLists(totalMsgs []*ContextMsg) (storageMsgList, + notStorageMsgList, storageNotificationList, notStorageNotificationList []*ContextMsg) { for _, v := range totalMsgs { options := msgprocessor.Options(v.message.Options) if !options.IsNotNotification() { @@ -185,176 +156,106 @@ func (och *OnlineHistoryRedisConsumerHandler) getPushStorageMsgList( msgprocessor.WithOfflinePush(false), msgprocessor.WithUnreadCount(false), ) - storageMsgList = append(storageMsgList, msg) + ctxMsg := &ContextMsg{ + message: msg, + ctx: v.ctx, + } + storageMsgList = append(storageMsgList, ctxMsg) } - if isStorage(v.message) { - storageNotificatoinList = append(storageNotificatoinList, v.message) + if options.IsHistory() { + storageNotificationList = append(storageNotificationList, v) } else { - notStorageNotificationList = append(notStorageNotificationList, v.message) + notStorageNotificationList = append(notStorageNotificationList, v) } } else { - if isStorage(v.message) { - storageMsgList = append(storageMsgList, v.message) + if options.IsHistory() { + storageMsgList = append(storageMsgList, v) } else { - notStorageMsgList = append(notStorageMsgList, v.message) + notStorageMsgList = append(notStorageMsgList, v) } } - if v.message.ContentType == constant.ReactionMessageModifier || - v.message.ContentType == constant.ReactionMessageDeleter { - modifyMsgList = append(modifyMsgList, v.message) - } } return } -func (och *OnlineHistoryRedisConsumerHandler) handleNotification( - ctx context.Context, - key, conversationID string, - storageList, notStorageList []*sdkws.MsgData, -) { +func (och *OnlineHistoryRedisConsumerHandler) handleMsg(ctx context.Context, key, conversationID string, storageList, notStorageList []*ContextMsg) { och.toPushTopic(ctx, key, conversationID, notStorageList) - if len(storageList) > 0 { - lastSeq, _, err := och.msgDatabase.BatchInsertChat2Cache(ctx, conversationID, storageList) - if err != nil { - log.ZError( - ctx, - "notification batch insert to redis error", - err, - "conversationID", - conversationID, - "storageList", - storageList, - ) - return - } - log.ZDebug(ctx, "success to next topic", "conversationID", conversationID) - err = och.msgDatabase.MsgToMongoMQ(ctx, key, conversationID, storageList, lastSeq) - if err != nil { - log.ZError(ctx, "MsgToMongoMQ error", err) - } - och.toPushTopic(ctx, key, conversationID, storageList) - } -} - -func (och *OnlineHistoryRedisConsumerHandler) toPushTopic(ctx context.Context, key, conversationID string, msgs []*sdkws.MsgData) { - for _, v := range msgs { - och.msgDatabase.MsgToPushMQ(ctx, key, conversationID, v) // nolint: errcheck + var storageMessageList []*sdkws.MsgData + for _, msg := range storageList { + storageMessageList = append(storageMessageList, msg.message) } -} - -func (och *OnlineHistoryRedisConsumerHandler) handleMsg(ctx context.Context, key, conversationID string, storageList, notStorageList []*sdkws.MsgData) { - och.toPushTopic(ctx, key, conversationID, notStorageList) - if len(storageList) > 0 { - lastSeq, isNewConversation, err := och.msgDatabase.BatchInsertChat2Cache(ctx, conversationID, storageList) + if len(storageMessageList) > 0 { + msg := storageMessageList[0] + lastSeq, isNewConversation, err := och.msgDatabase.BatchInsertChat2Cache(ctx, conversationID, storageMessageList) if err != nil && errs.Unwrap(err) != redis.Nil { - log.ZError(ctx, "batch data insert to redis err", err, "storageMsgList", storageList) + log.ZError(ctx, "batch data insert to redis err", err, "storageMsgList", storageMessageList) return } if isNewConversation { - switch storageList[0].SessionType { + switch msg.SessionType { case constant.ReadGroupChatType: log.ZInfo(ctx, "group chat first create conversation", "conversationID", conversationID) - userIDs, err := och.groupRpcClient.GetGroupMemberIDs(ctx, storageList[0].GroupID) + userIDs, err := och.groupRpcClient.GetGroupMemberIDs(ctx, msg.GroupID) if err != nil { log.ZWarn(ctx, "get group member ids error", err, "conversationID", conversationID) } else { if err := och.conversationRpcClient.GroupChatFirstCreateConversation(ctx, - storageList[0].GroupID, userIDs); err != nil { + msg.GroupID, userIDs); err != nil { log.ZWarn(ctx, "single chat first create conversation error", err, "conversationID", conversationID) } } case constant.SingleChatType, constant.NotificationChatType: - if err := och.conversationRpcClient.SingleChatFirstCreateConversation(ctx, storageList[0].RecvID, - storageList[0].SendID, conversationID, storageList[0].SessionType); err != nil { + if err := och.conversationRpcClient.SingleChatFirstCreateConversation(ctx, msg.RecvID, + msg.SendID, conversationID, msg.SessionType); err != nil { log.ZWarn(ctx, "single chat or notification first create conversation error", err, - "conversationID", conversationID, "sessionType", storageList[0].SessionType) + "conversationID", conversationID, "sessionType", msg.SessionType) } default: log.ZWarn(ctx, "unknown session type", nil, "sessionType", - storageList[0].SessionType) + msg.SessionType) } } log.ZDebug(ctx, "success incr to next topic") - err = och.msgDatabase.MsgToMongoMQ(ctx, key, conversationID, storageList, lastSeq) + err = och.msgDatabase.MsgToMongoMQ(ctx, key, conversationID, storageMessageList, lastSeq) if err != nil { - log.ZError(ctx, "MsgToMongoMQ error", err) + log.ZError(ctx, "Msg To MongoDB MQ error", err, "conversationID", + conversationID, "storageList", storageMessageList, "lastSeq", lastSeq) } och.toPushTopic(ctx, key, conversationID, storageList) } } -func (och *OnlineHistoryRedisConsumerHandler) MessagesDistributionHandle() { - for { - aggregationMsgs := make(map[string][]*ContextMsg, ChannelNum) - select { - case cmd := <-och.msgDistributionCh: - switch cmd.Cmd { - case ConsumerMsgs: - triggerChannelValue := cmd.Value.(TriggerChannelValue) - ctx := triggerChannelValue.ctx - consumerMessages := triggerChannelValue.cMsgList - // Aggregation map[userid]message list - log.ZDebug(ctx, "batch messages come to distribution center", "length", len(consumerMessages)) - for i := 0; i < len(consumerMessages); i++ { - ctxMsg := &ContextMsg{} - msgFromMQ := &sdkws.MsgData{} - err := proto.Unmarshal(consumerMessages[i].Value, msgFromMQ) - if err != nil { - log.ZError(ctx, "msg_transfer Unmarshal msg err", err, string(consumerMessages[i].Value)) - continue - } - var arr []string - for i, header := range consumerMessages[i].Headers { - arr = append(arr, strconv.Itoa(i), string(header.Key), string(header.Value)) - } - log.ZInfo(ctx, "consumer.kafka.GetContextWithMQHeader", "len", len(consumerMessages[i].Headers), - "header", strings.Join(arr, ", ")) - ctxMsg.ctx = kafka.GetContextWithMQHeader(consumerMessages[i].Headers) - ctxMsg.message = msgFromMQ - log.ZDebug( - ctx, - "single msg come to distribution center", - "message", - msgFromMQ, - "key", - string(consumerMessages[i].Key), - ) - // aggregationMsgs[string(consumerMessages[i].Key)] = - // append(aggregationMsgs[string(consumerMessages[i].Key)], ctxMsg) - if oldM, ok := aggregationMsgs[string(consumerMessages[i].Key)]; ok { - oldM = append(oldM, ctxMsg) - aggregationMsgs[string(consumerMessages[i].Key)] = oldM - } else { - m := make([]*ContextMsg, 0, 100) - m = append(m, ctxMsg) - aggregationMsgs[string(consumerMessages[i].Key)] = m - } - } - log.ZDebug(ctx, "generate map list users len", "length", len(aggregationMsgs)) - for uniqueKey, v := range aggregationMsgs { - if len(v) >= 0 { - hashCode := stringutil.GetHashCode(uniqueKey) - channelID := hashCode % ChannelNum - newCtx := withAggregationCtx(ctx, v) - log.ZDebug( - newCtx, - "generate channelID", - "hashCode", - hashCode, - "channelID", - channelID, - "uniqueKey", - uniqueKey, - ) - och.chArrays[channelID] <- Cmd2Value{Cmd: SourceMessages, Value: MsgChannelValue{uniqueKey: uniqueKey, ctxMsgList: v, ctx: newCtx}} - } - } - } +func (och *OnlineHistoryRedisConsumerHandler) handleNotification(ctx context.Context, key, conversationID string, + storageList, notStorageList []*ContextMsg) { + och.toPushTopic(ctx, key, conversationID, notStorageList) + var storageMessageList []*sdkws.MsgData + for _, msg := range storageList { + storageMessageList = append(storageMessageList, msg.message) + } + if len(storageMessageList) > 0 { + lastSeq, _, err := och.msgDatabase.BatchInsertChat2Cache(ctx, conversationID, storageMessageList) + if err != nil { + log.ZError(ctx, "notification batch insert to redis error", err, "conversationID", conversationID, + "storageList", storageMessageList) + return + } + log.ZDebug(ctx, "success to next topic", "conversationID", conversationID) + err = och.msgDatabase.MsgToMongoMQ(ctx, key, conversationID, storageMessageList, lastSeq) + if err != nil { + log.ZError(ctx, "Msg To MongoDB MQ error", err, "conversationID", + conversationID, "storageList", storageMessageList, "lastSeq", lastSeq) } + och.toPushTopic(ctx, key, conversationID, storageList) + } +} + +func (och *OnlineHistoryRedisConsumerHandler) toPushTopic(_ context.Context, key, conversationID string, msgs []*ContextMsg) { + for _, v := range msgs { + och.msgDatabase.MsgToPushMQ(v.ctx, key, conversationID, v.message) } } @@ -377,106 +278,30 @@ func (och *OnlineHistoryRedisConsumerHandler) Cleanup(_ sarama.ConsumerGroupSess return nil } -func (och *OnlineHistoryRedisConsumerHandler) ConsumeClaim( - sess sarama.ConsumerGroupSession, - claim sarama.ConsumerGroupClaim, -) error { // a instance in the consumer group - for { - if sess == nil { - log.ZWarn(context.Background(), "sess == nil, waiting", nil) - time.Sleep(100 * time.Millisecond) - } else { - break - } - } +func (och *OnlineHistoryRedisConsumerHandler) ConsumeClaim(session sarama.ConsumerGroupSession, + claim sarama.ConsumerGroupClaim) error { // a instance in the consumer group log.ZInfo(context.Background(), "online new session msg come", "highWaterMarkOffset", claim.HighWaterMarkOffset(), "topic", claim.Topic(), "partition", claim.Partition()) - - var ( - split = 1000 - rwLock = new(sync.RWMutex) - messages = make([]*sarama.ConsumerMessage, 0, 1000) - ticker = time.NewTicker(time.Millisecond * 100) - - wg = sync.WaitGroup{} - running = new(atomic.Bool) - ) - running.Store(true) - - wg.Add(1) - go func() { - defer wg.Done() - - for { - select { - case <-ticker.C: - // if the buffer is empty and running is false, return loop. - if len(messages) == 0 { - if !running.Load() { - return - } - - continue - } - - rwLock.Lock() - buffer := make([]*sarama.ConsumerMessage, 0, len(messages)) - buffer = append(buffer, messages...) - - // reuse slice, set cap to 0 - messages = messages[:0] - rwLock.Unlock() - - start := time.Now() - ctx := mcontext.WithTriggerIDContext(context.Background(), idutil.OperationIDGenerator()) - log.ZDebug(ctx, "timer trigger msg consumer start", "length", len(buffer)) - for i := 0; i < len(buffer)/split; i++ { - och.msgDistributionCh <- Cmd2Value{Cmd: ConsumerMsgs, Value: TriggerChannelValue{ - ctx: ctx, cMsgList: buffer[i*split : (i+1)*split], - }} - } - if (len(buffer) % split) > 0 { - och.msgDistributionCh <- Cmd2Value{Cmd: ConsumerMsgs, Value: TriggerChannelValue{ - ctx: ctx, cMsgList: buffer[split*(len(buffer)/split):], - }} - } - - log.ZDebug(ctx, "timer trigger msg consumer end", - "length", len(buffer), "time_cost", time.Since(start), - ) + och.redisMessageBatches.OnComplete = func(lastMessage *sarama.ConsumerMessage, totalCount int) { + session.MarkMessage(lastMessage, "") + session.Commit() + } + for { + select { + case msg, ok := <-claim.Messages(): + if !ok { + return nil } - } - }() - wg.Add(1) - go func() { - defer wg.Done() - - for running.Load() { - select { - case msg, ok := <-claim.Messages(): - if !ok { - running.Store(false) - return - } - - if len(msg.Value) == 0 { - continue - } - - rwLock.Lock() - messages = append(messages, msg) - rwLock.Unlock() - - sess.MarkMessage(msg, "") - - case <-sess.Context().Done(): - running.Store(false) - return + if len(msg.Value) == 0 { + continue + } + err := och.redisMessageBatches.Put(context.Background(), msg) + if err != nil { + log.ZWarn(context.Background(), "put msg to error", err, "msg", msg) } + case <-session.Context().Done(): + return nil } - }() - - wg.Wait() - return nil + } } diff --git a/internal/msgtransfer/online_msg_to_mongo_handler.go b/internal/msgtransfer/online_msg_to_mongo_handler.go index 0fa9fe0d12..e5651012c6 100644 --- a/internal/msgtransfer/online_msg_to_mongo_handler.go +++ b/internal/msgtransfer/online_msg_to_mongo_handler.go @@ -89,7 +89,6 @@ func (mc *OnlineHistoryMongoConsumerHandler) handleChatWs2Mongo(ctx context.Cont msgFromMQ.ConversationID, ) } - mc.msgDatabase.DelUserDeleteMsgsList(ctx, msgFromMQ.ConversationID, seqs) } func (OnlineHistoryMongoConsumerHandler) Setup(_ sarama.ConsumerGroupSession) error { return nil } diff --git a/internal/push/push_handler.go b/internal/push/push_handler.go index bf0ede375c..03c299b7ab 100644 --- a/internal/push/push_handler.go +++ b/internal/push/push_handler.go @@ -17,6 +17,7 @@ package push import ( "context" "encoding/json" + "github.com/IBM/sarama" "github.com/openimsdk/open-im-server/v3/internal/push/offlinepush" "github.com/openimsdk/open-im-server/v3/internal/push/offlinepush/options" "github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics" @@ -25,20 +26,18 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/rpccache" "github.com/openimsdk/open-im-server/v3/pkg/rpcclient" "github.com/openimsdk/open-im-server/v3/pkg/util/conversationutil" - "github.com/openimsdk/protocol/sdkws" - "github.com/openimsdk/tools/discovery" - "github.com/openimsdk/tools/mcontext" - "github.com/openimsdk/tools/utils/jsonutil" - "github.com/redis/go-redis/v9" - - "github.com/IBM/sarama" "github.com/openimsdk/protocol/constant" pbchat "github.com/openimsdk/protocol/msg" pbpush "github.com/openimsdk/protocol/push" + "github.com/openimsdk/protocol/sdkws" + "github.com/openimsdk/tools/discovery" "github.com/openimsdk/tools/log" + "github.com/openimsdk/tools/mcontext" "github.com/openimsdk/tools/mq/kafka" "github.com/openimsdk/tools/utils/datautil" + "github.com/openimsdk/tools/utils/jsonutil" "github.com/openimsdk/tools/utils/timeutil" + "github.com/redis/go-redis/v9" "google.golang.org/protobuf/proto" ) @@ -162,7 +161,8 @@ func (c *ConsumerHandler) Push2User(ctx context.Context, userIDs []string, msg * err = c.offlinePushMsg(ctx, msg, offlinePUshUserID) if err != nil { - return err + log.ZWarn(ctx, "offlinePushMsg failed", err, "offlinePUshUserID", offlinePUshUserID, "msg", msg) + return nil } return nil @@ -223,8 +223,8 @@ func (c *ConsumerHandler) Push2Group(ctx context.Context, groupID string, msg *s err = c.offlinePushMsg(ctx, msg, needOfflinePushUserIDs) if err != nil { - log.ZError(ctx, "offlinePushMsg failed", err, "groupID", groupID, "msg", msg) - return err + log.ZWarn(ctx, "offlinePushMsg failed", err, "groupID", groupID, "msg", msg) + return nil } } diff --git a/internal/rpc/group/group.go b/internal/rpc/group/group.go index 51fd2d7b66..a9cea4ff22 100644 --- a/internal/rpc/group/group.go +++ b/internal/rpc/group/group.go @@ -292,28 +292,7 @@ func (s *groupServer) CreateGroup(ctx context.Context, req *pbgroup.CreateGroupR break } } - if req.GroupInfo.GroupType == constant.SuperGroup { - go func() { - for _, userID := range userIDs { - s.notification.SuperGroupNotification(ctx, userID, userID) - } - }() - } else { - tips := &sdkws.GroupCreatedTips{ - Group: resp.GroupInfo, - OperationTime: group.CreateTime.UnixMilli(), - GroupOwnerUser: s.groupMemberDB2PB(groupMembers[0], userMap[groupMembers[0].UserID].AppMangerLevel), - } - for _, member := range groupMembers { - member.Nickname = userMap[member.UserID].Nickname - tips.MemberList = append(tips.MemberList, s.groupMemberDB2PB(member, userMap[member.UserID].AppMangerLevel)) - if member.UserID == opUserID { - tips.OpUser = s.groupMemberDB2PB(member, userMap[member.UserID].AppMangerLevel) - break - } - } - s.notification.GroupCreatedNotification(ctx, tips) - } + s.notification.GroupCreatedNotification(ctx, tips) reqCallBackAfter := &pbgroup.CreateGroupReq{ MemberUserIDs: userIDs, diff --git a/internal/rpc/group/notification.go b/internal/rpc/group/notification.go index f0f054d0ab..cfa62c85db 100644 --- a/internal/rpc/group/notification.go +++ b/internal/rpc/group/notification.go @@ -715,7 +715,3 @@ func (g *GroupNotificationSender) GroupMemberSetToOrdinaryUserNotification(ctx c } g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.GroupMemberSetToOrdinaryUserNotification, tips) } - -func (g *GroupNotificationSender) SuperGroupNotification(ctx context.Context, sendID, recvID string) { - g.Notification(ctx, sendID, recvID, constant.SuperGroupUpdateNotification, nil) -} diff --git a/internal/rpc/msg/server.go b/internal/rpc/msg/server.go index 6ff45605e4..f1fb28ffff 100644 --- a/internal/rpc/msg/server.go +++ b/internal/rpc/msg/server.go @@ -85,8 +85,7 @@ func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryReg if err != nil { return err } - //todo MsgCacheTimeout - msgModel := redis.NewMsgCache(rdb, config.RedisConfig.EnablePipeline) + msgModel := redis.NewMsgCache(rdb) seqModel := redis.NewSeqCache(rdb) conversationClient := rpcclient.NewConversationRpcClient(client, config.Share.RpcRegisterName.Conversation) userRpcClient := rpcclient.NewUserRpcClient(client, config.Share.RpcRegisterName.User, config.Share.IMAdminUserID) diff --git a/pkg/common/config/config.go b/pkg/common/config/config.go index a75d45ebbd..5313c196ac 100644 --- a/pkg/common/config/config.go +++ b/pkg/common/config/config.go @@ -323,13 +323,12 @@ type User struct { } type Redis struct { - Address []string `mapstructure:"address"` - Username string `mapstructure:"username"` - Password string `mapstructure:"password"` - EnablePipeline bool `mapstructure:"enablePipeline"` - ClusterMode bool `mapstructure:"clusterMode"` - DB int `mapstructure:"storage"` - MaxRetry int `mapstructure:"MaxRetry"` + Address []string `mapstructure:"address"` + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + ClusterMode bool `mapstructure:"clusterMode"` + DB int `mapstructure:"storage"` + MaxRetry int `mapstructure:"MaxRetry"` } type BeforeConfig struct { diff --git a/pkg/common/startrpc/start.go b/pkg/common/startrpc/start.go index a36bcfe1c8..069c92012a 100644 --- a/pkg/common/startrpc/start.go +++ b/pkg/common/startrpc/start.go @@ -52,12 +52,9 @@ func Start[T any](ctx context.Context, discovery *config2.Discovery, prometheusC if err != nil { return err } - prometheusPort, err := datautil.GetElemByIndex(prometheusConfig.Ports, index) - if err != nil { - return err - } + log.CInfo(ctx, "RPC server is initializing", "rpcRegisterName", rpcRegisterName, "rpcPort", rpcPort, - "prometheusPort", prometheusPort) + "prometheusPorts", prometheusConfig.Ports) rpcTcpAddr := net.JoinHostPort(network.GetListenIP(listenIP), strconv.Itoa(rpcPort)) listener, err := net.Listen( "tcp", @@ -117,9 +114,14 @@ func Start[T any](ctx context.Context, discovery *config2.Discovery, prometheusC netErr error httpServer *http.Server ) - - go func() { - if prometheusConfig.Enable && prometheusPort != 0 { + if prometheusConfig.Enable { + go func() { + prometheusPort, err := datautil.GetElemByIndex(prometheusConfig.Ports, index) + if err != nil { + netErr = err + netDone <- struct{}{} + return + } metric.InitializeMetrics(srv) // Create a HTTP server for prometheus. httpServer = &http.Server{Handler: promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), Addr: fmt.Sprintf("0.0.0.0:%d", prometheusPort)} @@ -127,8 +129,8 @@ func Start[T any](ctx context.Context, discovery *config2.Discovery, prometheusC netErr = errs.WrapMsg(err, "prometheus start err", httpServer.Addr) netDone <- struct{}{} } - } - }() + }() + } go func() { err := srv.Serve(listener) diff --git a/pkg/common/storage/cache/cachekey/msg.go b/pkg/common/storage/cache/cachekey/msg.go index d1e8eeb7bc..8e05b64f1f 100644 --- a/pkg/common/storage/cache/cachekey/msg.go +++ b/pkg/common/storage/cache/cachekey/msg.go @@ -31,10 +31,6 @@ const ( reactionNotification = "EX_NOTIFICATION_" ) -func GetAllMessageCacheKey(conversationID string) string { - return messageCache + conversationID + "_*" -} - func GetMessageCacheKey(conversationID string, seq int64) string { return messageCache + conversationID + "_" + strconv.Itoa(int(seq)) } diff --git a/pkg/common/storage/cache/conversation.go b/pkg/common/storage/cache/conversation.go index bf85af0c5f..f34fd599f8 100644 --- a/pkg/common/storage/cache/conversation.go +++ b/pkg/common/storage/cache/conversation.go @@ -52,9 +52,6 @@ type ConversationCache interface { // GetUserAllHasReadSeqs(ctx context.Context, ownerUserID string) (map[string]int64, error) DelUserAllHasReadSeqs(ownerUserID string, conversationIDs ...string) ConversationCache - GetConversationsByConversationID(ctx context.Context, - conversationIDs []string) ([]*relationtb.Conversation, error) - DelConversationByConversationID(conversationIDs ...string) ConversationCache GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error) DelConversationNotReceiveMessageUserIDs(conversationIDs ...string) ConversationCache } diff --git a/pkg/common/storage/cache/msg.go b/pkg/common/storage/cache/msg.go index 0adbb35726..00eb28c02e 100644 --- a/pkg/common/storage/cache/msg.go +++ b/pkg/common/storage/cache/msg.go @@ -23,13 +23,8 @@ import ( type MsgCache interface { GetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsg []*sdkws.MsgData, failedSeqList []int64, err error) - SetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error) - UserDeleteMsgs(ctx context.Context, conversationID string, seqs []int64, userID string) error - DelUserDeleteMsgsList(ctx context.Context, conversationID string, seqs []int64) - DeleteMessages(ctx context.Context, conversationID string, seqs []int64) error - GetUserDelList(ctx context.Context, userID, conversationID string) (seqs []int64, err error) - CleanUpOneConversationAllMsg(ctx context.Context, conversationID string) error - DelMsgFromCache(ctx context.Context, userID string, seqList []int64) error + SetMessagesToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error) + DeleteMessagesFromCache(ctx context.Context, conversationID string, seqs []int64) error SetSendMsgStatus(ctx context.Context, id string, status int32) error GetSendMsgStatus(ctx context.Context, id string) (int32, error) JudgeMessageReactionExist(ctx context.Context, clientMsgID string, sessionType int32) (bool, error) diff --git a/pkg/common/storage/cache/redis/batch_handler.go b/pkg/common/storage/cache/redis/batch_handler.go index b68d8f86dc..95f6699047 100644 --- a/pkg/common/storage/cache/redis/batch_handler.go +++ b/pkg/common/storage/cache/redis/batch_handler.go @@ -62,17 +62,13 @@ func (c *BatchDeleterRedis) ChainExecDel(ctx context.Context) error { func (c *BatchDeleterRedis) execDel(ctx context.Context, keys []string) error { if len(keys) > 0 { log.ZDebug(ctx, "delete cache", "topic", c.redisPubTopics, "keys", keys) - slotMapKeys, err := groupKeysBySlot(ctx, c.redisClient, keys) + // Batch delete keys + err := ProcessKeysBySlot(ctx, c.redisClient, keys, func(ctx context.Context, slot int64, keys []string) error { + return c.rocksClient.TagAsDeletedBatch2(ctx, keys) + }) if err != nil { return err } - // Batch delete keys - for slot, singleSlotKeys := range slotMapKeys { - if err := c.rocksClient.TagAsDeletedBatch2(ctx, singleSlotKeys); err != nil { - log.ZWarn(ctx, "Batch delete cache failed", err, "slot", slot, "keys", singleSlotKeys) - continue - } - } // Publish the keys that have been deleted to Redis to update the local cache information of other nodes if len(c.redisPubTopics) > 0 && len(keys) > 0 { keysByTopic := localcache.GetPublishKeysByTopic(c.redisPubTopics, keys) @@ -117,37 +113,6 @@ func GetRocksCacheOptions() *rockscache.Options { return &opts } -// groupKeysBySlot groups keys by their Redis cluster hash slots. -func groupKeysBySlot(ctx context.Context, redisClient redis.UniversalClient, keys []string) (map[int64][]string, error) { - slots := make(map[int64][]string) - clusterClient, isCluster := redisClient.(*redis.ClusterClient) - if isCluster { - pipe := clusterClient.Pipeline() - cmds := make([]*redis.IntCmd, len(keys)) - for i, key := range keys { - cmds[i] = pipe.ClusterKeySlot(ctx, key) - } - _, err := pipe.Exec(ctx) - if err != nil { - return nil, errs.WrapMsg(err, "get slot err") - } - - for i, cmd := range cmds { - slot, err := cmd.Result() - if err != nil { - log.ZWarn(ctx, "some key get slot err", err, "key", keys[i]) - continue - } - slots[slot] = append(slots[slot], keys[i]) - } - } else { - // If not a cluster client, put all keys in the same slot (0) - slots[0] = keys - } - - return slots, nil -} - func getCache[T any](ctx context.Context, rcClient *rockscache.Client, key string, expire time.Duration, fn func(ctx context.Context) (T, error)) (T, error) { var t T var write bool diff --git a/pkg/common/storage/cache/redis/conversation.go b/pkg/common/storage/cache/redis/conversation.go index 5fac79a7e1..8c0393dd56 100644 --- a/pkg/common/storage/cache/redis/conversation.go +++ b/pkg/common/storage/cache/redis/conversation.go @@ -222,14 +222,6 @@ func (c *ConversationRedisCache) DelUserAllHasReadSeqs(ownerUserID string, conve return cache } -func (c *ConversationRedisCache) GetConversationsByConversationID(ctx context.Context, conversationIDs []string) ([]*model.Conversation, error) { - panic("implement me") -} - -func (c *ConversationRedisCache) DelConversationByConversationID(conversationIDs ...string) cache.ConversationCache { - panic("implement me") -} - func (c *ConversationRedisCache) GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error) { return getCache(ctx, c.rcClient, c.getConversationNotReceiveMessageUserIDsKey(conversationID), c.expireTime, func(ctx context.Context) ([]string, error) { return c.conversationDB.GetConversationNotReceiveMessageUserIDs(ctx, conversationID) diff --git a/pkg/common/storage/cache/redis/lua_script.go b/pkg/common/storage/cache/redis/lua_script.go new file mode 100644 index 0000000000..c7609cb443 --- /dev/null +++ b/pkg/common/storage/cache/redis/lua_script.go @@ -0,0 +1,125 @@ +package redis + +import ( + "context" + "fmt" + "github.com/openimsdk/open-im-server/v3/pkg/common/servererrs" + "github.com/openimsdk/tools/errs" + "github.com/openimsdk/tools/log" + "github.com/redis/go-redis/v9" +) + +var ( + setBatchWithCommonExpireScript = redis.NewScript(` +local expire = tonumber(ARGV[1]) +for i, key in ipairs(KEYS) do + redis.call('SET', key, ARGV[i + 1]) + redis.call('EXPIRE', key, expire) +end +return #KEYS +`) + + setBatchWithIndividualExpireScript = redis.NewScript(` +local n = #KEYS +for i = 1, n do + redis.call('SET', KEYS[i], ARGV[i]) + redis.call('EXPIRE', KEYS[i], ARGV[i + n]) +end +return n +`) + + deleteBatchScript = redis.NewScript(` +for i, key in ipairs(KEYS) do + redis.call('DEL', key) +end +return #KEYS +`) + + getBatchScript = redis.NewScript(` +local values = {} +for i, key in ipairs(KEYS) do + local value = redis.call('GET', key) + table.insert(values, value) +end +return values +`) +) + +func callLua(ctx context.Context, rdb redis.Scripter, script *redis.Script, keys []string, args []any) (any, error) { + log.ZDebug(ctx, "callLua args", "scriptHash", script.Hash(), "keys", keys, "args", args) + r := script.EvalSha(ctx, rdb, keys, args) + if redis.HasErrorPrefix(r.Err(), "NOSCRIPT") { + if err := script.Load(ctx, rdb).Err(); err != nil { + r = script.Eval(ctx, rdb, keys, args) + } else { + r = script.EvalSha(ctx, rdb, keys, args) + } + } + v, err := r.Result() + if err == redis.Nil { + err = nil + } + return v, errs.WrapMsg(err, "call lua err", "scriptHash", script.Hash(), "keys", keys, "args", args) +} + +func LuaSetBatchWithCommonExpire(ctx context.Context, rdb redis.Scripter, keys []string, values []string, expire int) error { + // Check if the lengths of keys and values match + if len(keys) != len(values) { + return errs.New("keys and values length mismatch").Wrap() + } + + // Ensure allocation size does not overflow + maxAllowedLen := (1 << 31) - 1 // 2GB limit (maximum address space for 32-bit systems) + + if len(values) > maxAllowedLen-1 { + return fmt.Errorf("values length is too large, causing overflow") + } + var vals = make([]any, 0, 1+len(values)) + vals = append(vals, expire) + for _, v := range values { + vals = append(vals, v) + } + _, err := callLua(ctx, rdb, setBatchWithCommonExpireScript, keys, vals) + return err +} + +func LuaSetBatchWithIndividualExpire(ctx context.Context, rdb redis.Scripter, keys []string, values []string, expires []int) error { + // Check if the lengths of keys, values, and expires match + if len(keys) != len(values) || len(keys) != len(expires) { + return errs.New("keys and values length mismatch").Wrap() + } + + // Ensure the allocation size does not overflow + maxAllowedLen := (1 << 31) - 1 // 2GB limit (maximum address space for 32-bit systems) + + if len(values) > maxAllowedLen-1 { + return errs.New(fmt.Sprintf("values length %d exceeds the maximum allowed length %d", len(values), maxAllowedLen-1)).Wrap() + } + var vals = make([]any, 0, len(values)+len(expires)) + for _, v := range values { + vals = append(vals, v) + } + for _, ex := range expires { + vals = append(vals, ex) + } + _, err := callLua(ctx, rdb, setBatchWithIndividualExpireScript, keys, vals) + return err +} + +func LuaDeleteBatch(ctx context.Context, rdb redis.Scripter, keys []string) error { + _, err := callLua(ctx, rdb, deleteBatchScript, keys, nil) + return err +} + +func LuaGetBatch(ctx context.Context, rdb redis.Scripter, keys []string) ([]any, error) { + v, err := callLua(ctx, rdb, getBatchScript, keys, nil) + if err != nil { + return nil, err + } + values, ok := v.([]any) + if !ok { + return nil, servererrs.ErrArgs.WrapMsg("invalid lua get batch result") + } + return values, nil + +} diff --git a/pkg/common/storage/cache/redis/lua_script_test.go b/pkg/common/storage/cache/redis/lua_script_test.go new file mode 100644 index 0000000000..1566b59a0b --- /dev/null +++ b/pkg/common/storage/cache/redis/lua_script_test.go @@ -0,0 +1,75 @@ +package redis + +import ( + "context" + "github.com/go-redis/redismock/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" +) + +func TestLuaSetBatchWithCommonExpire(t *testing.T) { + rdb, mock := redismock.NewClientMock() + ctx := context.Background() + + keys := []string{"key1", "key2"} + values := []string{"value1", "value2"} + expire := 10 + + mock.ExpectEvalSha(setBatchWithCommonExpireScript.Hash(), keys, []any{expire, "value1", "value2"}).SetVal(int64(len(keys))) + + err := LuaSetBatchWithCommonExpire(ctx, rdb, keys, values, expire) + require.NoError(t, err) + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestLuaSetBatchWithIndividualExpire(t *testing.T) { + rdb, mock := redismock.NewClientMock() + ctx := context.Background() + + keys := []string{"key1", "key2"} + values := []string{"value1", "value2"} + expires := []int{10, 20} + + args := make([]any, 0, len(values)+len(expires)) + for _, v := range values { + args = append(args, v) + } + for _, ex := range expires { + args = append(args, ex) + } + + mock.ExpectEvalSha(setBatchWithIndividualExpireScript.Hash(), keys, args).SetVal(int64(len(keys))) + + err := LuaSetBatchWithIndividualExpire(ctx, rdb, keys, values, expires) + require.NoError(t, err) + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestLuaDeleteBatch(t *testing.T) { + rdb, mock := redismock.NewClientMock() + ctx := context.Background() + + keys := []string{"key1", "key2"} + + mock.ExpectEvalSha(deleteBatchScript.Hash(), keys, []any{}).SetVal(int64(len(keys))) + + err := LuaDeleteBatch(ctx, rdb, keys) + require.NoError(t, err) + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestLuaGetBatch(t *testing.T) { + rdb, mock := redismock.NewClientMock() + ctx := context.Background() + + keys := []string{"key1", "key2"} + expectedValues := []any{"value1", "value2"} + + mock.ExpectEvalSha(getBatchScript.Hash(), keys, []any{}).SetVal(expectedValues) + + values, err := LuaGetBatch(ctx, rdb, keys) + require.NoError(t, err) + assert.NoError(t, mock.ExpectationsWereMet()) + assert.Equal(t, expectedValues, values) +} diff --git a/pkg/common/storage/cache/redis/meta_cache.go b/pkg/common/storage/cache/redis/meta_cache.go deleted file mode 100644 index 4c2fcacd13..0000000000 --- a/pkg/common/storage/cache/redis/meta_cache.go +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright © 2023 OpenIM. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package redis diff --git a/pkg/common/storage/cache/redis/msg.go b/pkg/common/storage/cache/redis/msg.go index df69bc6451..2d21cfe135 100644 --- a/pkg/common/storage/cache/redis/msg.go +++ b/pkg/common/storage/cache/redis/msg.go @@ -16,37 +16,25 @@ package redis import ( "context" - "errors" - "github.com/gogo/protobuf/jsonpb" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey" "github.com/openimsdk/open-im-server/v3/pkg/msgprocessor" - "github.com/openimsdk/protocol/constant" "github.com/openimsdk/protocol/sdkws" "github.com/openimsdk/tools/errs" - "github.com/openimsdk/tools/log" - "github.com/openimsdk/tools/utils/stringutil" + "github.com/openimsdk/tools/utils/datautil" "github.com/redis/go-redis/v9" - "golang.org/x/sync/errgroup" "time" -) +) // -const msgCacheTimeout = 86400 * time.Second +// msgCacheTimeout is expiration time of message cache, 86400 seconds +const msgCacheTimeout = 86400 -var concurrentLimit = 3 - -func NewMsgCache(client redis.UniversalClient, redisEnablePipeline bool) cache.MsgCache { - return &msgCache{rdb: client, msgCacheTimeout: msgCacheTimeout, redisEnablePipeline: redisEnablePipeline} +func NewMsgCache(client redis.UniversalClient) cache.MsgCache { + return &msgCache{rdb: client} } type msgCache struct { - rdb redis.UniversalClient - msgCacheTimeout time.Duration - redisEnablePipeline bool -} - -func (c *msgCache) getAllMessageCacheKey(conversationID string) string { - return cachekey.GetAllMessageCacheKey(conversationID) + rdb redis.UniversalClient } func (c *msgCache) getMessageCacheKey(conversationID string, seq int64) string { @@ -72,218 +60,41 @@ func (c *msgCache) getMessageReactionExPrefix(clientMsgID string, sessionType in return cachekey.GetMessageReactionExKey(clientMsgID, sessionType) } -func (c *msgCache) SetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error) { - if c.redisEnablePipeline { - return c.PipeSetMessageToCache(ctx, conversationID, msgs) - } - return c.ParallelSetMessageToCache(ctx, conversationID, msgs) -} - -func (c *msgCache) PipeSetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error) { - pipe := c.rdb.Pipeline() - for _, msg := range msgs { - s, err := msgprocessor.Pb2String(msg) - if err != nil { - return 0, err - } - - key := c.getMessageCacheKey(conversationID, msg.Seq) - _ = pipe.Set(ctx, key, s, c.msgCacheTimeout) - } - - results, err := pipe.Exec(ctx) - if err != nil { - return 0, errs.Wrap(err) - } - - for _, res := range results { - if res.Err() != nil { - return 0, errs.Wrap(err) - } - } - - return len(msgs), nil -} - -func (c *msgCache) ParallelSetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error) { - wg := errgroup.Group{} - wg.SetLimit(concurrentLimit) - - for _, msg := range msgs { - msg := msg // closure safe var - wg.Go(func() error { - s, err := msgprocessor.Pb2String(msg) - if err != nil { - return errs.Wrap(err) - } - - key := c.getMessageCacheKey(conversationID, msg.Seq) - if err := c.rdb.Set(ctx, key, s, c.msgCacheTimeout).Err(); err != nil { - return errs.Wrap(err) - } - return nil - }) - } - - err := wg.Wait() - if err != nil { - return 0, errs.WrapMsg(err, "wg.Wait failed") - } - - return len(msgs), nil -} - -func (c *msgCache) UserDeleteMsgs(ctx context.Context, conversationID string, seqs []int64, userID string) error { - for _, seq := range seqs { - delUserListKey := c.getMessageDelUserListKey(conversationID, seq) - userDelListKey := c.getUserDelList(conversationID, userID) - err := c.rdb.SAdd(ctx, delUserListKey, userID).Err() - if err != nil { - return errs.Wrap(err) - } - err = c.rdb.SAdd(ctx, userDelListKey, seq).Err() - if err != nil { - return errs.Wrap(err) - } - if err := c.rdb.Expire(ctx, delUserListKey, c.msgCacheTimeout).Err(); err != nil { - return errs.Wrap(err) - } - if err := c.rdb.Expire(ctx, userDelListKey, c.msgCacheTimeout).Err(); err != nil { - return errs.Wrap(err) - } - } - - return nil -} - -func (c *msgCache) GetUserDelList(ctx context.Context, userID, conversationID string) (seqs []int64, err error) { - result, err := c.rdb.SMembers(ctx, c.getUserDelList(conversationID, userID)).Result() - if err != nil { - return nil, errs.Wrap(err) - } - seqs = make([]int64, len(result)) - for i, v := range result { - seqs[i] = stringutil.StringToInt64(v) - } - - return seqs, nil -} - -func (c *msgCache) DelUserDeleteMsgsList(ctx context.Context, conversationID string, seqs []int64) { - for _, seq := range seqs { - delUsers, err := c.rdb.SMembers(ctx, c.getMessageDelUserListKey(conversationID, seq)).Result() - if err != nil { - log.ZWarn(ctx, "DelUserDeleteMsgsList failed", err, "conversationID", conversationID, "seq", seq) - - continue - } - if len(delUsers) > 0 { - var failedFlag bool - for _, userID := range delUsers { - err = c.rdb.SRem(ctx, c.getUserDelList(conversationID, userID), seq).Err() +func (c *msgCache) SetMessagesToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error) { + msgMap := datautil.SliceToMap(msgs, func(msg *sdkws.MsgData) string { + return c.getMessageCacheKey(conversationID, msg.Seq) + }) + keys := datautil.Slice(msgs, func(msg *sdkws.MsgData) string { + return c.getMessageCacheKey(conversationID, msg.Seq) + }) + err := ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error { + var values []string + for _, key := range keys { + if msg, ok := msgMap[key]; ok { + s, err := msgprocessor.Pb2String(msg) if err != nil { - failedFlag = true - log.ZWarn(ctx, "DelUserDeleteMsgsList failed", err, "conversationID", conversationID, "seq", seq, "userID", userID) - } - } - if !failedFlag { - if err := c.rdb.Del(ctx, c.getMessageDelUserListKey(conversationID, seq)).Err(); err != nil { - log.ZWarn(ctx, "DelUserDeleteMsgsList failed", err, "conversationID", conversationID, "seq", seq) + return err } + values = append(values, s) } } - } -} - -func (c *msgCache) DeleteMessages(ctx context.Context, conversationID string, seqs []int64) error { - if c.redisEnablePipeline { - return c.PipeDeleteMessages(ctx, conversationID, seqs) - } - - return c.ParallelDeleteMessages(ctx, conversationID, seqs) -} - -func (c *msgCache) ParallelDeleteMessages(ctx context.Context, conversationID string, seqs []int64) error { - wg := errgroup.Group{} - wg.SetLimit(concurrentLimit) - - for _, seq := range seqs { - seq := seq - wg.Go(func() error { - err := c.rdb.Del(ctx, c.getMessageCacheKey(conversationID, seq)).Err() - if err != nil { - return errs.Wrap(err) - } - return nil - }) - } - - return wg.Wait() -} - -func (c *msgCache) PipeDeleteMessages(ctx context.Context, conversationID string, seqs []int64) error { - pipe := c.rdb.Pipeline() - for _, seq := range seqs { - _ = pipe.Del(ctx, c.getMessageCacheKey(conversationID, seq)) - } - - results, err := pipe.Exec(ctx) - if err != nil { - return errs.WrapMsg(err, "pipe.del") - } - - for _, res := range results { - if res.Err() != nil { - return errs.Wrap(err) - } - } - - return nil -} - -func (c *msgCache) CleanUpOneConversationAllMsg(ctx context.Context, conversationID string) error { - vals, err := c.rdb.Keys(ctx, c.getAllMessageCacheKey(conversationID)).Result() - if errors.Is(err, redis.Nil) { - return nil - } + return LuaSetBatchWithCommonExpire(ctx, c.rdb, keys, values, msgCacheTimeout) + }) if err != nil { - return errs.Wrap(err) - } - for _, v := range vals { - if err := c.rdb.Del(ctx, v).Err(); err != nil { - return errs.Wrap(err) - } + return 0, err } - return nil + return len(msgs), nil } -func (c *msgCache) DelMsgFromCache(ctx context.Context, userID string, seqs []int64) error { +func (c *msgCache) DeleteMessagesFromCache(ctx context.Context, conversationID string, seqs []int64) error { + var keys []string for _, seq := range seqs { - key := c.getMessageCacheKey(userID, seq) - result, err := c.rdb.Get(ctx, key).Result() - if err != nil { - if errors.Is(err, redis.Nil) { - continue - } - - return errs.Wrap(err) - } - var msg sdkws.MsgData - err = jsonpb.UnmarshalString(result, &msg) - if err != nil { - return err - } - msg.Status = constant.MsgDeleted - s, err := msgprocessor.Pb2String(&msg) - if err != nil { - return errs.Wrap(err) - } - if err := c.rdb.Set(ctx, key, s, c.msgCacheTimeout).Err(); err != nil { - return errs.Wrap(err) - } + keys = append(keys, c.getMessageCacheKey(conversationID, seq)) } - return nil + return ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error { + return LuaDeleteBatch(ctx, c.rdb, keys) + }) } func (c *msgCache) SetSendMsgStatus(ctx context.Context, id string, status int32) error { @@ -338,102 +149,39 @@ func (c *msgCache) DeleteOneMessageKey(ctx context.Context, clientMsgID string, } func (c *msgCache) GetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsgs []*sdkws.MsgData, failedSeqs []int64, err error) { - if c.redisEnablePipeline { - return c.PipeGetMessagesBySeq(ctx, conversationID, seqs) - } - - return c.ParallelGetMessagesBySeq(ctx, conversationID, seqs) -} - -func (c *msgCache) PipeGetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsgs []*sdkws.MsgData, failedSeqs []int64, err error) { - pipe := c.rdb.Pipeline() - - results := []*redis.StringCmd{} + var keys []string + keySeqMap := make(map[string]int64, 10) for _, seq := range seqs { - results = append(results, pipe.Get(ctx, c.getMessageCacheKey(conversationID, seq))) + key := c.getMessageCacheKey(conversationID, seq) + keys = append(keys, key) + keySeqMap[key] = seq } - - _, err = pipe.Exec(ctx) - if err != nil && err != redis.Nil { - return seqMsgs, failedSeqs, errs.WrapMsg(err, "pipe.get") - } - - for idx, res := range results { - seq := seqs[idx] - if res.Err() != nil { - log.ZError(ctx, "GetMessagesBySeq failed", err, "conversationID", conversationID, "seq", seq, "err", res.Err()) - failedSeqs = append(failedSeqs, seq) - continue - } - - msg := sdkws.MsgData{} - if err = msgprocessor.String2Pb(res.Val(), &msg); err != nil { - log.ZError(ctx, "GetMessagesBySeq Unmarshal failed", err, "res", res, "conversationID", conversationID, "seq", seq) - failedSeqs = append(failedSeqs, seq) - continue - } - - if msg.Status == constant.MsgDeleted { - failedSeqs = append(failedSeqs, seq) - continue + err = ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error { + result, err := LuaGetBatch(ctx, c.rdb, keys) + if err != nil { + return err } - - seqMsgs = append(seqMsgs, &msg) - } - - return -} - -func (c *msgCache) ParallelGetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsgs []*sdkws.MsgData, failedSeqs []int64, err error) { - type entry struct { - err error - msg *sdkws.MsgData - } - - wg := errgroup.Group{} - wg.SetLimit(concurrentLimit) - - results := make([]entry, len(seqs)) // set slice len/cap to length of seqs. - for idx, seq := range seqs { - // closure safe var - idx := idx - seq := seq - - wg.Go(func() error { - res, err := c.rdb.Get(ctx, c.getMessageCacheKey(conversationID, seq)).Result() - if err != nil { - log.ZError(ctx, "GetMessagesBySeq failed", err, "conversationID", conversationID, "seq", seq) - results[idx] = entry{err: err} - return nil - } - - msg := sdkws.MsgData{} - if err = msgprocessor.String2Pb(res, &msg); err != nil { - log.ZError(ctx, "GetMessagesBySeq Unmarshal failed", err, "res", res, "conversationID", conversationID, "seq", seq) - results[idx] = entry{err: err} - return nil + for i, value := range result { + seq := keySeqMap[keys[i]] + if value == nil { + failedSeqs = append(failedSeqs, seq) + continue } - if msg.Status == constant.MsgDeleted { - results[idx] = entry{err: err} - return nil + msg := &sdkws.MsgData{} + msgString, ok := value.(string) + if !ok || msgprocessor.String2Pb(msgString, msg) != nil { + failedSeqs = append(failedSeqs, seq) + continue } + seqMsgs = append(seqMsgs, msg) - results[idx] = entry{msg: &msg} - return nil - }) - } - - _ = wg.Wait() - - for idx, res := range results { - if res.err != nil { - failedSeqs = append(failedSeqs, seqs[idx]) - continue } - - seqMsgs = append(seqMsgs, res.msg) + return nil + }) + if err != nil { + return nil, nil, err } + return seqMsgs, failedSeqs, nil - return } diff --git a/pkg/common/storage/cache/redis/msg_test.go b/pkg/common/storage/cache/redis/msg_test.go index d47fa18e18..10b9ce18b0 100644 --- a/pkg/common/storage/cache/redis/msg_test.go +++ b/pkg/common/storage/cache/redis/msg_test.go @@ -4,14 +4,13 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - package redis import ( @@ -20,381 +19,115 @@ import ( "github.com/openimsdk/protocol/sdkws" "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" - "math/rand" + "google.golang.org/protobuf/proto" "testing" ) -func TestParallelSetMessageToCache(t *testing.T) { - var ( - cid = fmt.Sprintf("cid-%v", rand.Int63()) - seqFirst = rand.Int63() - msgs = []*sdkws.MsgData{} - ) - - for i := 0; i < 100; i++ { - msgs = append(msgs, &sdkws.MsgData{ - Seq: seqFirst + int64(i), - }) - } - - testParallelSetMessageToCache(t, cid, msgs) -} - -func testParallelSetMessageToCache(t *testing.T, cid string, msgs []*sdkws.MsgData) { - rdb := redis.NewClient(&redis.Options{}) - defer rdb.Close() - - cacher := msgCache{rdb: rdb} - - ret, err := cacher.ParallelSetMessageToCache(context.Background(), cid, msgs) - assert.Nil(t, err) - assert.Equal(t, len(msgs), ret) - - // validate - for _, msg := range msgs { - key := cacher.getMessageCacheKey(cid, msg.Seq) - val, err := rdb.Exists(context.Background(), key).Result() - assert.Nil(t, err) - assert.EqualValues(t, 1, val) - } -} - -func TestPipeSetMessageToCache(t *testing.T) { - var ( - cid = fmt.Sprintf("cid-%v", rand.Int63()) - seqFirst = rand.Int63() - msgs = []*sdkws.MsgData{} - ) - - for i := 0; i < 100; i++ { - msgs = append(msgs, &sdkws.MsgData{ - Seq: seqFirst + int64(i), +func Test_msgCache_SetMessagesToCache(t *testing.T) { + type fields struct { + rdb redis.UniversalClient + } + type args struct { + ctx context.Context + conversationID string + msgs []*sdkws.MsgData + } + tests := []struct { + name string + fields fields + args args + want int + wantErr assert.ErrorAssertionFunc + }{ + {"test1", fields{rdb: redis.NewClient(&redis.Options{Addr: "localhost:16379", Username: "", Password: "openIM123", DB: 0})}, args{context.Background(), + "cid", []*sdkws.MsgData{{Seq: 1}, {Seq: 2}, {Seq: 3}}}, 3, assert.NoError}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &msgCache{ + rdb: tt.fields.rdb, + } + got, err := c.SetMessagesToCache(tt.args.ctx, tt.args.conversationID, tt.args.msgs) + if !tt.wantErr(t, err, fmt.Sprintf("SetMessagesToCache(%v, %v, %v)", tt.args.ctx, tt.args.conversationID, tt.args.msgs)) { + return + } + assert.Equalf(t, tt.want, got, "SetMessagesToCache(%v, %v, %v)", tt.args.ctx, tt.args.conversationID, tt.args.msgs) }) } - - testPipeSetMessageToCache(t, cid, msgs) -} - -func testPipeSetMessageToCache(t *testing.T, cid string, msgs []*sdkws.MsgData) { - rdb := redis.NewClient(&redis.Options{}) - defer rdb.Close() - - cacher := msgCache{rdb: rdb} - - ret, err := cacher.PipeSetMessageToCache(context.Background(), cid, msgs) - assert.Nil(t, err) - assert.Equal(t, len(msgs), ret) - - // validate - for _, msg := range msgs { - key := cacher.getMessageCacheKey(cid, msg.Seq) - val, err := rdb.Exists(context.Background(), key).Result() - assert.Nil(t, err) - assert.EqualValues(t, 1, val) - } -} - -func TestGetMessagesBySeq(t *testing.T) { - var ( - cid = fmt.Sprintf("cid-%v", rand.Int63()) - seqFirst = rand.Int63() - msgs = []*sdkws.MsgData{} - ) - - seqs := []int64{} - for i := 0; i < 100; i++ { - msgs = append(msgs, &sdkws.MsgData{ - Seq: seqFirst + int64(i), - SendID: fmt.Sprintf("fake-sendid-%v", i), - }) - seqs = append(seqs, seqFirst+int64(i)) - } - - // set data to cache - testPipeSetMessageToCache(t, cid, msgs) - - // get data from cache with parallet mode - testParallelGetMessagesBySeq(t, cid, seqs, msgs) - - // get data from cache with pipeline mode - testPipeGetMessagesBySeq(t, cid, seqs, msgs) -} - -func testParallelGetMessagesBySeq(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) { - rdb := redis.NewClient(&redis.Options{}) - defer rdb.Close() - - cacher := msgCache{rdb: rdb} - - respMsgs, failedSeqs, err := cacher.ParallelGetMessagesBySeq(context.Background(), cid, seqs) - assert.Nil(t, err) - assert.Equal(t, 0, len(failedSeqs)) - assert.Equal(t, len(respMsgs), len(seqs)) - - // validate - for idx, msg := range respMsgs { - assert.Equal(t, msg.Seq, inputMsgs[idx].Seq) - assert.Equal(t, msg.SendID, inputMsgs[idx].SendID) - } -} - -func testPipeGetMessagesBySeq(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) { - rdb := redis.NewClient(&redis.Options{}) - defer rdb.Close() - - cacher := msgCache{rdb: rdb} - - respMsgs, failedSeqs, err := cacher.PipeGetMessagesBySeq(context.Background(), cid, seqs) - assert.Nil(t, err) - assert.Equal(t, 0, len(failedSeqs)) - assert.Equal(t, len(respMsgs), len(seqs)) - - // validate - for idx, msg := range respMsgs { - assert.Equal(t, msg.Seq, inputMsgs[idx].Seq) - assert.Equal(t, msg.SendID, inputMsgs[idx].SendID) - } -} - -func TestGetMessagesBySeqWithEmptySeqs(t *testing.T) { - var ( - cid = fmt.Sprintf("cid-%v", rand.Int63()) - seqFirst int64 = 0 - msgs = []*sdkws.MsgData{} - ) - - seqs := []int64{} - for i := 0; i < 100; i++ { - msgs = append(msgs, &sdkws.MsgData{ - Seq: seqFirst + int64(i), - SendID: fmt.Sprintf("fake-sendid-%v", i), - }) - seqs = append(seqs, seqFirst+int64(i)) - } - - // don't set cache, only get data from cache. - - // get data from cache with parallet mode - testParallelGetMessagesBySeqWithEmptry(t, cid, seqs, msgs) - - // get data from cache with pipeline mode - testPipeGetMessagesBySeqWithEmptry(t, cid, seqs, msgs) -} - -func testParallelGetMessagesBySeqWithEmptry(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) { - rdb := redis.NewClient(&redis.Options{}) - defer rdb.Close() - - cacher := msgCache{rdb: rdb} - - respMsgs, failedSeqs, err := cacher.ParallelGetMessagesBySeq(context.Background(), cid, seqs) - assert.Nil(t, err) - assert.Equal(t, len(seqs), len(failedSeqs)) - assert.Equal(t, 0, len(respMsgs)) } -func testPipeGetMessagesBySeqWithEmptry(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) { - rdb := redis.NewClient(&redis.Options{}) - defer rdb.Close() - - cacher := msgCache{rdb: rdb} - - respMsgs, failedSeqs, err := cacher.PipeGetMessagesBySeq(context.Background(), cid, seqs) - assert.Equal(t, err, redis.Nil) - assert.Equal(t, len(seqs), len(failedSeqs)) - assert.Equal(t, 0, len(respMsgs)) -} - -func TestGetMessagesBySeqWithLostHalfSeqs(t *testing.T) { - var ( - cid = fmt.Sprintf("cid-%v", rand.Int63()) - seqFirst int64 = 0 - msgs = []*sdkws.MsgData{} - ) - - seqs := []int64{} - for i := 0; i < 100; i++ { - msgs = append(msgs, &sdkws.MsgData{ - Seq: seqFirst + int64(i), - SendID: fmt.Sprintf("fake-sendid-%v", i), +func Test_msgCache_GetMessagesBySeq(t *testing.T) { + type fields struct { + rdb redis.UniversalClient + } + type args struct { + ctx context.Context + conversationID string + seqs []int64 + } + var failedSeq []int64 + tests := []struct { + name string + fields fields + args args + wantSeqMsgs []*sdkws.MsgData + wantFailedSeqs []int64 + wantErr assert.ErrorAssertionFunc + }{ + {"test1", fields{rdb: redis.NewClient(&redis.Options{Addr: "localhost:16379", Password: "openIM123", DB: 0})}, + args{context.Background(), "cid", []int64{1, 2, 3}}, + []*sdkws.MsgData{{Seq: 1}, {Seq: 2}, {Seq: 3}}, failedSeq, assert.NoError}, + {"test2", fields{rdb: redis.NewClient(&redis.Options{Addr: "localhost:16379", Password: "openIM123", DB: 0})}, + args{context.Background(), "cid", []int64{4, 5, 6}}, + nil, []int64{4, 5, 6}, assert.NoError}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &msgCache{ + rdb: tt.fields.rdb, + } + gotSeqMsgs, gotFailedSeqs, err := c.GetMessagesBySeq(tt.args.ctx, tt.args.conversationID, tt.args.seqs) + if !tt.wantErr(t, err, fmt.Sprintf("GetMessagesBySeq(%v, %v, %v)", tt.args.ctx, tt.args.conversationID, tt.args.seqs)) { + return + } + equalMsgDataSlices(t, tt.wantSeqMsgs, gotSeqMsgs) + assert.Equalf(t, tt.wantFailedSeqs, gotFailedSeqs, "GetMessagesBySeq(%v, %v, %v)", tt.args.ctx, tt.args.conversationID, tt.args.seqs) }) - seqs = append(seqs, seqFirst+int64(i)) } - - // Only set half the number of messages. - testParallelSetMessageToCache(t, cid, msgs[:50]) - - // get data from cache with parallet mode - testParallelGetMessagesBySeqWithLostHalfSeqs(t, cid, seqs, msgs) - - // get data from cache with pipeline mode - testPipeGetMessagesBySeqWithLostHalfSeqs(t, cid, seqs, msgs) } -func testParallelGetMessagesBySeqWithLostHalfSeqs(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) { - rdb := redis.NewClient(&redis.Options{}) - defer rdb.Close() - - cacher := msgCache{rdb: rdb} - - respMsgs, failedSeqs, err := cacher.ParallelGetMessagesBySeq(context.Background(), cid, seqs) - assert.Nil(t, err) - assert.Equal(t, len(seqs)/2, len(failedSeqs)) - assert.Equal(t, len(seqs)/2, len(respMsgs)) - - for idx, msg := range respMsgs { - assert.Equal(t, msg.Seq, seqs[idx]) +func equalMsgDataSlices(t *testing.T, expected, actual []*sdkws.MsgData) { + assert.Equal(t, len(expected), len(actual), "Slices have different lengths") + for i := range expected { + assert.True(t, proto.Equal(expected[i], actual[i]), "Element %d not equal: expected %v, got %v", i, expected[i], actual[i]) } } -func testPipeGetMessagesBySeqWithLostHalfSeqs(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) { - rdb := redis.NewClient(&redis.Options{}) - defer rdb.Close() - - cacher := msgCache{rdb: rdb} - - respMsgs, failedSeqs, err := cacher.PipeGetMessagesBySeq(context.Background(), cid, seqs) - assert.Nil(t, err) - assert.Equal(t, len(seqs)/2, len(failedSeqs)) - assert.Equal(t, len(seqs)/2, len(respMsgs)) - - for idx, msg := range respMsgs { - assert.Equal(t, msg.Seq, seqs[idx]) +func Test_msgCache_DeleteMessagesFromCache(t *testing.T) { + type fields struct { + rdb redis.UniversalClient } -} - -func TestPipeDeleteMessages(t *testing.T) { - var ( - cid = fmt.Sprintf("cid-%v", rand.Int63()) - seqFirst = rand.Int63() - msgs = []*sdkws.MsgData{} - ) - - var seqs []int64 - for i := 0; i < 100; i++ { - msgs = append(msgs, &sdkws.MsgData{ - Seq: seqFirst + int64(i), - }) - seqs = append(seqs, msgs[i].Seq) + type args struct { + ctx context.Context + conversationID string + seqs []int64 } - - testPipeSetMessageToCache(t, cid, msgs) - testPipeDeleteMessagesOK(t, cid, seqs, msgs) - - // set again - testPipeSetMessageToCache(t, cid, msgs) - testPipeDeleteMessagesMix(t, cid, seqs[:90], msgs) -} - -func testPipeDeleteMessagesOK(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) { - rdb := redis.NewClient(&redis.Options{}) - defer rdb.Close() - - cacher := msgCache{rdb: rdb} - - err := cacher.PipeDeleteMessages(context.Background(), cid, seqs) - assert.Nil(t, err) - - // validate - for _, msg := range inputMsgs { - key := cacher.getMessageCacheKey(cid, msg.Seq) - val := rdb.Exists(context.Background(), key).Val() - assert.EqualValues(t, 0, val) + tests := []struct { + name string + fields fields + args args + wantErr assert.ErrorAssertionFunc + }{ + {"test1", fields{rdb: redis.NewClient(&redis.Options{Addr: "localhost:16379", Password: "openIM123"})}, + args{context.Background(), "cid", []int64{1, 2, 3}}, assert.NoError}, } -} - -func testPipeDeleteMessagesMix(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) { - rdb := redis.NewClient(&redis.Options{}) - defer rdb.Close() - - cacher := msgCache{rdb: rdb} - - err := cacher.PipeDeleteMessages(context.Background(), cid, seqs) - assert.Nil(t, err) - - // validate - for idx, msg := range inputMsgs { - key := cacher.getMessageCacheKey(cid, msg.Seq) - val, err := rdb.Exists(context.Background(), key).Result() - assert.Nil(t, err) - if idx < 90 { - assert.EqualValues(t, 0, val) // not exists - continue - } - - assert.EqualValues(t, 1, val) // exists - } -} - -func TestParallelDeleteMessages(t *testing.T) { - var ( - cid = fmt.Sprintf("cid-%v", rand.Int63()) - seqFirst = rand.Int63() - msgs = []*sdkws.MsgData{} - ) - - var seqs []int64 - for i := 0; i < 100; i++ { - msgs = append(msgs, &sdkws.MsgData{ - Seq: seqFirst + int64(i), + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &msgCache{ + rdb: tt.fields.rdb, + } + tt.wantErr(t, c.DeleteMessagesFromCache(tt.args.ctx, tt.args.conversationID, tt.args.seqs), + fmt.Sprintf("DeleteMessagesFromCache(%v, %v, %v)", tt.args.ctx, tt.args.conversationID, tt.args.seqs)) }) - seqs = append(seqs, msgs[i].Seq) - } - - randSeqs := []int64{} - for i := seqFirst + 100; i < seqFirst+200; i++ { - randSeqs = append(randSeqs, i) - } - - testParallelSetMessageToCache(t, cid, msgs) - testParallelDeleteMessagesOK(t, cid, seqs, msgs) - - // set again - testParallelSetMessageToCache(t, cid, msgs) - testParallelDeleteMessagesMix(t, cid, seqs[:90], msgs, 90) - testParallelDeleteMessagesOK(t, cid, seqs[90:], msgs[:90]) - - // set again - testParallelSetMessageToCache(t, cid, msgs) - testParallelDeleteMessagesMix(t, cid, randSeqs, msgs, 0) -} - -func testParallelDeleteMessagesOK(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) { - rdb := redis.NewClient(&redis.Options{}) - defer rdb.Close() - - cacher := msgCache{rdb: rdb} - - err := cacher.PipeDeleteMessages(context.Background(), cid, seqs) - assert.Nil(t, err) - - // validate - for _, msg := range inputMsgs { - key := cacher.getMessageCacheKey(cid, msg.Seq) - val := rdb.Exists(context.Background(), key).Val() - assert.EqualValues(t, 0, val) - } -} - -func testParallelDeleteMessagesMix(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData, lessValNonExists int) { - rdb := redis.NewClient(&redis.Options{}) - defer rdb.Close() - - cacher := msgCache{rdb: rdb} - - err := cacher.PipeDeleteMessages(context.Background(), cid, seqs) - assert.Nil(t, err) - - // validate - for idx, msg := range inputMsgs { - key := cacher.getMessageCacheKey(cid, msg.Seq) - val, err := rdb.Exists(context.Background(), key).Result() - assert.Nil(t, err) - if idx < lessValNonExists { - assert.EqualValues(t, 0, val) // not exists - continue - } - - assert.EqualValues(t, 1, val) // exists } } diff --git a/pkg/common/storage/cache/redis/redis_shard_manager.go b/pkg/common/storage/cache/redis/redis_shard_manager.go new file mode 100644 index 0000000000..98d70dabf9 --- /dev/null +++ b/pkg/common/storage/cache/redis/redis_shard_manager.go @@ -0,0 +1,197 @@ +package redis + +import ( + "context" + "github.com/openimsdk/tools/errs" + "github.com/openimsdk/tools/log" + "github.com/redis/go-redis/v9" + "golang.org/x/sync/errgroup" +) + +const ( + defaultBatchSize = 50 + defaultConcurrentLimit = 3 +) + +// RedisShardManager is a class for sharding and processing keys +type RedisShardManager struct { + redisClient redis.UniversalClient + config *Config +} +type Config struct { + batchSize int + continueOnError bool + concurrentLimit int +} + +// Option is a function type for configuring Config +type Option func(c *Config) + +// NewRedisShardManager creates a new RedisShardManager instance +func NewRedisShardManager(redisClient redis.UniversalClient, opts ...Option) *RedisShardManager { + config := &Config{ + batchSize: defaultBatchSize, // Default batch size is 50 keys + continueOnError: false, + concurrentLimit: defaultConcurrentLimit, // Default concurrent limit is 3 + } + for _, opt := range opts { + opt(config) + } + rsm := &RedisShardManager{ + redisClient: redisClient, + config: config, + } + return rsm +} + +// WithBatchSize sets the number of keys to process per batch +func WithBatchSize(size int) Option { + return func(c *Config) { + c.batchSize = size + } +} + +// WithContinueOnError sets whether to continue processing on error +func WithContinueOnError(continueOnError bool) Option { + return func(c *Config) { + c.continueOnError = continueOnError + } +} + +// WithConcurrentLimit sets the concurrency limit +func WithConcurrentLimit(limit int) Option { + return func(c *Config) { + c.concurrentLimit = limit + } +} + +// ProcessKeysBySlot groups keys by their Redis cluster hash slots and processes them using the provided function. +func (rsm *RedisShardManager) ProcessKeysBySlot( + ctx context.Context, + keys []string, + processFunc func(ctx context.Context, slot int64, keys []string) error, +) error { + + // Group keys by slot + slots, err := groupKeysBySlot(ctx, rsm.redisClient, keys) + if err != nil { + return err + } + + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(rsm.config.concurrentLimit) + + // Process keys in each slot using the provided function + for slot, singleSlotKeys := range slots { + batches := splitIntoBatches(singleSlotKeys, rsm.config.batchSize) + for _, batch := range batches { + slot, batch := slot, batch // Avoid closure capture issue + g.Go(func() error { + err := processFunc(ctx, slot, batch) + if err != nil { + log.ZWarn(ctx, "Batch processFunc failed", err, "slot", slot, "keys", batch) + if !rsm.config.continueOnError { + return err + } + } + return nil + }) + } + } + + if err := g.Wait(); err != nil { + return err + } + return nil +} + +// groupKeysBySlot groups keys by their Redis cluster hash slots. +func groupKeysBySlot(ctx context.Context, redisClient redis.UniversalClient, keys []string) (map[int64][]string, error) { + slots := make(map[int64][]string) + clusterClient, isCluster := redisClient.(*redis.ClusterClient) + if isCluster { + pipe := clusterClient.Pipeline() + cmds := make([]*redis.IntCmd, len(keys)) + for i, key := range keys { + cmds[i] = pipe.ClusterKeySlot(ctx, key) + } + _, err := pipe.Exec(ctx) + if err != nil { + return nil, errs.WrapMsg(err, "get slot err") + } + + for i, cmd := range cmds { + slot, err := cmd.Result() + if err != nil { + log.ZWarn(ctx, "some key get slot err", err, "key", keys[i]) + return nil, errs.WrapMsg(err, "get slot err", "key", keys[i]) + } + slots[slot] = append(slots[slot], keys[i]) + } + } else { + // If not a cluster client, put all keys in the same slot (0) + slots[0] = keys + } + + return slots, nil +} + +// splitIntoBatches splits keys into batches of the specified size +func splitIntoBatches(keys []string, batchSize int) [][]string { + var batches [][]string + for batchSize < len(keys) { + keys, batches = keys[batchSize:], append(batches, keys[0:batchSize:batchSize]) + } + return append(batches, keys) +} + +// ProcessKeysBySlot groups keys by their Redis cluster hash slots and processes them using the provided function. +func ProcessKeysBySlot( + ctx context.Context, + redisClient redis.UniversalClient, + keys []string, + processFunc func(ctx context.Context, slot int64, keys []string) error, + opts ...Option, +) error { + + config := &Config{ + batchSize: defaultBatchSize, + continueOnError: false, + concurrentLimit: defaultConcurrentLimit, + } + for _, opt := range opts { + opt(config) + } + + // Group keys by slot + slots, err := groupKeysBySlot(ctx, redisClient, keys) + if err != nil { + return err + } + + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(config.concurrentLimit) + + // Process keys in each slot using the provided function + for slot, singleSlotKeys := range slots { + batches := splitIntoBatches(singleSlotKeys, config.batchSize) + for _, batch := range batches { + slot, batch := slot, batch // Avoid closure capture issue + g.Go(func() error { + err := processFunc(ctx, slot, batch) + if err != nil { + log.ZWarn(ctx, "Batch processFunc failed", err, "slot", slot, "keys", batch) + if !config.continueOnError { + return err + } + } + return nil + }) + } + } + + if err := g.Wait(); err != nil { + return err + } + return nil +} diff --git a/pkg/common/storage/cache/user.go b/pkg/common/storage/cache/user.go index 4a129ddd18..5101c0b6ce 100644 --- a/pkg/common/storage/cache/user.go +++ b/pkg/common/storage/cache/user.go @@ -16,15 +16,15 @@ package cache import ( "context" - relationtb "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" "github.com/openimsdk/protocol/user" ) type UserCache interface { BatchDeleter CloneUserCache() UserCache - GetUserInfo(ctx context.Context, userID string) (userInfo *relationtb.User, err error) - GetUsersInfo(ctx context.Context, userIDs []string) ([]*relationtb.User, error) + GetUserInfo(ctx context.Context, userID string) (userInfo *model.User, err error) + GetUsersInfo(ctx context.Context, userIDs []string) ([]*model.User, error) DelUsersInfo(userIDs ...string) UserCache GetUserGlobalRecvMsgOpt(ctx context.Context, userID string) (opt int, err error) DelUsersGlobalRecvMsgOpt(userIDs ...string) UserCache diff --git a/pkg/common/storage/controller/msg.go b/pkg/common/storage/controller/msg.go index ce107e9237..8eb9e8e6fd 100644 --- a/pkg/common/storage/controller/msg.go +++ b/pkg/common/storage/controller/msg.go @@ -54,8 +54,6 @@ type CommonMsgDatabase interface { MarkSingleChatMsgsAsRead(ctx context.Context, userID string, conversationID string, seqs []int64) error // DeleteMessagesFromCache deletes message caches from Redis by sequence numbers. DeleteMessagesFromCache(ctx context.Context, conversationID string, seqs []int64) error - // DelUserDeleteMsgsList deletes user's message deletion list. - DelUserDeleteMsgsList(ctx context.Context, conversationID string, seqs []int64) // BatchInsertChat2Cache increments the sequence number and then batch inserts messages into the cache. BatchInsertChat2Cache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (seq int64, isNewConversation bool, err error) // GetMsgBySeqsRange retrieves messages from MongoDB by a range of sequence numbers. @@ -98,7 +96,6 @@ type CommonMsgDatabase interface { // to mq MsgToMQ(ctx context.Context, key string, msg2mq *sdkws.MsgData) error - MsgToModifyMQ(ctx context.Context, key, conversarionID string, msgs []*sdkws.MsgData) error MsgToPushMQ(ctx context.Context, key, conversarionID string, msg2mq *sdkws.MsgData) (int32, int64, error) MsgToMongoMQ(ctx context.Context, key, conversarionID string, msgs []*sdkws.MsgData, lastSeq int64) error @@ -150,14 +147,13 @@ func NewCommonMsgDatabase(msgDocModel database.Msg, msg cache.MsgCache, seq cach //} type commonMsgDatabase struct { - msgDocDatabase database.Msg - msgTable model.MsgDocModel - msg cache.MsgCache - seq cache.SeqCache - producer *kafka.Producer - producerToMongo *kafka.Producer - producerToModify *kafka.Producer - producerToPush *kafka.Producer + msgDocDatabase database.Msg + msgTable model.MsgDocModel + msg cache.MsgCache + seq cache.SeqCache + producer *kafka.Producer + producerToMongo *kafka.Producer + producerToPush *kafka.Producer } func (db *commonMsgDatabase) MsgToMQ(ctx context.Context, key string, msg2mq *sdkws.MsgData) error { @@ -165,14 +161,6 @@ func (db *commonMsgDatabase) MsgToMQ(ctx context.Context, key string, msg2mq *sd return err } -func (db *commonMsgDatabase) MsgToModifyMQ(ctx context.Context, key, conversationID string, messages []*sdkws.MsgData) error { - if len(messages) > 0 { - _, _, err := db.producerToModify.SendMessage(ctx, key, &pbmsg.MsgDataToModifyByMQ{ConversationID: conversationID, Messages: messages}) - return err - } - return nil -} - func (db *commonMsgDatabase) MsgToPushMQ(ctx context.Context, key, conversationID string, msg2mq *sdkws.MsgData) (int32, int64, error) { partition, offset, err := db.producerToPush.SendMessage(ctx, key, &pbmsg.PushMsgDataToMQ{MsgData: msg2mq, ConversationID: conversationID}) if err != nil { @@ -357,11 +345,7 @@ func (db *commonMsgDatabase) MarkSingleChatMsgsAsRead(ctx context.Context, userI } func (db *commonMsgDatabase) DeleteMessagesFromCache(ctx context.Context, conversationID string, seqs []int64) error { - return db.msg.DeleteMessages(ctx, conversationID, seqs) -} - -func (db *commonMsgDatabase) DelUserDeleteMsgsList(ctx context.Context, conversationID string, seqs []int64) { - db.msg.DelUserDeleteMsgsList(ctx, conversationID, seqs) + return db.msg.DeleteMessagesFromCache(ctx, conversationID, seqs) } func (db *commonMsgDatabase) BatchInsertChat2Cache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (seq int64, isNew bool, err error) { @@ -388,7 +372,7 @@ func (db *commonMsgDatabase) BatchInsertChat2Cache(ctx context.Context, conversa userSeqMap[m.SendID] = m.Seq } - failedNum, err := db.msg.SetMessageToCache(ctx, conversationID, msgs) + failedNum, err := db.msg.SetMessagesToCache(ctx, conversationID, msgs) if err != nil { prommetrics.MsgInsertRedisFailedCounter.Add(float64(failedNum)) log.ZError(ctx, "setMessageToCache error", err, "len", len(msgs), "conversationID", conversationID) @@ -584,6 +568,7 @@ func (db *commonMsgDatabase) GetMsgBySeqsRange(ctx context.Context, userID strin } newBegin := seqs[0] newEnd := seqs[len(seqs)-1] + var successMsgs []*sdkws.MsgData log.ZDebug(ctx, "GetMsgBySeqsRange", "first seqs", seqs, "newBegin", newBegin, "newEnd", newEnd) cachedMsgs, failedSeqs, err := db.msg.GetMessagesBySeq(ctx, conversationID, seqs) if err != nil { @@ -592,54 +577,12 @@ func (db *commonMsgDatabase) GetMsgBySeqsRange(ctx context.Context, userID strin log.ZError(ctx, "get message from redis exception", err, "conversationID", conversationID, "seqs", seqs) } } - var successMsgs []*sdkws.MsgData - if len(cachedMsgs) > 0 { - delSeqs, err := db.msg.GetUserDelList(ctx, userID, conversationID) - if err != nil && errs.Unwrap(err) != redis.Nil { - return 0, 0, nil, err - } - var cacheDelNum int - for _, msg := range cachedMsgs { - if !datautil.Contain(msg.Seq, delSeqs...) { - successMsgs = append(successMsgs, msg) - } else { - cacheDelNum += 1 - } - } - log.ZDebug(ctx, "get delSeqs from redis", "delSeqs", delSeqs, "userID", userID, "conversationID", conversationID, "cacheDelNum", cacheDelNum) - var reGetSeqsCache []int64 - for i := 1; i <= cacheDelNum; { - newSeq := newBegin - int64(i) - if newSeq >= begin { - if !datautil.Contain(newSeq, delSeqs...) { - log.ZDebug(ctx, "seq del in cache, a new seq in range append", "new seq", newSeq) - reGetSeqsCache = append(reGetSeqsCache, newSeq) - i++ - } - } else { - break - } - } - if len(reGetSeqsCache) > 0 { - log.ZDebug(ctx, "reGetSeqsCache", "reGetSeqsCache", reGetSeqsCache) - cachedMsgs, failedSeqs2, err := db.msg.GetMessagesBySeq(ctx, conversationID, reGetSeqsCache) - if err != nil { - if err != redis.Nil { - - log.ZError(ctx, "get message from redis exception", err, "conversationID", conversationID, "seqs", reGetSeqsCache) - } - } - failedSeqs = append(failedSeqs, failedSeqs2...) - successMsgs = append(successMsgs, cachedMsgs...) - } - } - log.ZDebug(ctx, "get msgs from cache", "successMsgs", successMsgs) - if len(failedSeqs) != 0 { - log.ZDebug(ctx, "msgs not exist in redis", "seqs", failedSeqs) - } - // get from cache or storage + successMsgs = append(successMsgs, cachedMsgs...) + log.ZDebug(ctx, "get msgs from cache", "cachedMsgs", cachedMsgs) + // get from cache or db if len(failedSeqs) > 0 { + log.ZDebug(ctx, "msgs not exist in redis", "seqs", failedSeqs) mongoMsgs, err := db.getMsgBySeqsRange(ctx, userID, conversationID, failedSeqs, begin, end) if err != nil { @@ -679,7 +622,7 @@ func (db *commonMsgDatabase) GetMsgBySeqs(ctx context.Context, userID string, co log.ZError(ctx, "get message from redis exception", err, "failedSeqs", failedSeqs, "conversationID", conversationID) } } - log.ZDebug(ctx, "storage.seq.GetMessagesBySeq", "userID", userID, "conversationID", conversationID, "seqs", + log.ZDebug(ctx, "db.seq.GetMessagesBySeq", "userID", userID, "conversationID", conversationID, "seqs", seqs, "len(successMsgs)", len(successMsgs), "failedSeqs", failedSeqs) if len(failedSeqs) > 0 { @@ -705,12 +648,6 @@ func (db *commonMsgDatabase) DeleteConversationMsgsAndSetMinSeq(ctx context.Cont if minSeq == 0 { return nil } - if remainTime == 0 { - err = db.msg.CleanUpOneConversationAllMsg(ctx, conversationID) - if err != nil { - log.ZWarn(ctx, "CleanUpOneUserAllMsg", err, "conversationID", conversationID) - } - } return db.seq.SetMinSeq(ctx, conversationID, minSeq) } @@ -830,7 +767,7 @@ func (db *commonMsgDatabase) deleteMsgRecursion(ctx context.Context, conversatio } func (db *commonMsgDatabase) DeleteMsgsPhysicalBySeqs(ctx context.Context, conversationID string, allSeqs []int64) error { - if err := db.msg.DeleteMessages(ctx, conversationID, allSeqs); err != nil { + if err := db.msg.DeleteMessagesFromCache(ctx, conversationID, allSeqs); err != nil { return err } for docID, seqs := range db.msgTable.GetDocIDSeqsMap(conversationID, allSeqs) { @@ -846,21 +783,9 @@ func (db *commonMsgDatabase) DeleteMsgsPhysicalBySeqs(ctx context.Context, conve } func (db *commonMsgDatabase) DeleteUserMsgsBySeqs(ctx context.Context, userID string, conversationID string, seqs []int64) error { - cachedMsgs, _, err := db.msg.GetMessagesBySeq(ctx, conversationID, seqs) - if err != nil && errs.Unwrap(err) != redis.Nil { - log.ZWarn(ctx, "DeleteUserMsgsBySeqs", err, "conversationID", conversationID, "seqs", seqs) + if err := db.msg.DeleteMessagesFromCache(ctx, conversationID, seqs); err != nil { return err } - if len(cachedMsgs) > 0 { - var cacheSeqs []int64 - for _, msg := range cachedMsgs { - cacheSeqs = append(cacheSeqs, msg.Seq) - } - if err := db.msg.UserDeleteMsgs(ctx, conversationID, cacheSeqs, userID); err != nil { - return err - } - } - for docID, seqs := range db.msgTable.GetDocIDSeqsMap(conversationID, seqs) { for _, seq := range seqs { if _, err := db.msgDocDatabase.PushUnique(ctx, docID, db.msgTable.GetMsgIndex(seq), "del_list", []string{userID}); err != nil { @@ -1085,14 +1010,14 @@ func (db *commonMsgDatabase) DeleteDocMsgBefore(ctx context.Context, ts int64, d } } -//func (storage *commonMsgDatabase) ClearMsg(ctx context.Context, ts int64) (err error) { +//func (db *commonMsgDatabase) ClearMsg(ctx context.Context, ts int64) (err error) { // var ( // docNum int // msgNum int // start = time.Now() // ) // for { -// msgs, err := storage.msgDocDatabase.GetBeforeMsg(ctx, ts, 100) +// msgs, err := db.msgDocDatabase.GetBeforeMsg(ctx, ts, 100) // if err != nil { // return err // } @@ -1100,7 +1025,7 @@ func (db *commonMsgDatabase) DeleteDocMsgBefore(ctx context.Context, ts int64, d // return nil // } // for _, msg := range msgs { -// num, err := storage.deleteOneMsg(ctx, ts, msg) +// num, err := db.deleteOneMsg(ctx, ts, msg) // if err != nil { // return err // } diff --git a/pkg/common/storage/database/mgo/msg.go b/pkg/common/storage/database/mgo/msg.go index f676c1f59c..a7291fcc8f 100644 --- a/pkg/common/storage/database/mgo/msg.go +++ b/pkg/common/storage/database/mgo/msg.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" + "github.com/openimsdk/tools/utils/datautil" "time" "github.com/openimsdk/protocol/constant" @@ -108,29 +109,11 @@ func (m *MsgMgo) GetMsgBySeqIndexIn1Doc(ctx context.Context, userID, docID strin {Key: "input", Value: indexs}, {Key: "as", Value: "index"}, {Key: "in", Value: bson.D{ - {Key: "$let", Value: bson.D{ - {Key: "vars", Value: bson.D{ - {Key: "currentMsg", Value: bson.D{ - {Key: "$arrayElemAt", Value: bson.A{"$msgs", "$$index"}}, - }}, - }}, - {Key: "in", Value: bson.D{ - {Key: "$cond", Value: bson.D{ - {Key: "if", Value: bson.D{ - {Key: "$in", Value: bson.A{userID, "$$currentMsg.del_list"}}, - }}, - {Key: "then", Value: nil}, - {Key: "else", Value: "$$currentMsg"}, - }}, - }}, - }}, + {Key: "$arrayElemAt", Value: bson.A{"$msgs", "$$index"}}, }}, }}, }}, }}}, - bson.D{{Key: "$project", Value: bson.D{ - {Key: "msgs.del_list", Value: 0}, - }}}, } msgDocModel, err := mongoutil.Aggregate[*model.MsgDocModel](ctx, m.coll, pipeline) if err != nil { @@ -145,6 +128,10 @@ func (m *MsgMgo) GetMsgBySeqIndexIn1Doc(ctx context.Context, userID, docID strin if msg == nil || msg.Msg == nil { continue } + if datautil.Contain(userID, msg.DelList...) { + msg.Msg.Content = "" + msg.Msg.Status = constant.MsgDeleted + } if msg.Revoke != nil { revokeContent := sdkws.MessageRevokedContent{ RevokerID: msg.Revoke.UserID, diff --git a/pkg/tools/batcher/batcher.go b/pkg/tools/batcher/batcher.go new file mode 100644 index 0000000000..163aeed399 --- /dev/null +++ b/pkg/tools/batcher/batcher.go @@ -0,0 +1,272 @@ +package batcher + +import ( + "context" + "fmt" + "github.com/openimsdk/tools/errs" + "github.com/openimsdk/tools/utils/idutil" + "strings" + "sync" + "time" +) + +var ( + DefaultDataChanSize = 1000 + DefaultSize = 100 + DefaultBuffer = 100 + DefaultWorker = 5 + DefaultInterval = time.Second +) + +type Config struct { + size int // Number of message aggregations + buffer int // The number of caches running in a single coroutine + dataBuffer int // The size of the main data channel + worker int // Number of coroutines processed in parallel + interval time.Duration // Time of message aggregations + syncWait bool // Whether to wait synchronously after distributing messages have been consumed +} + +type Option func(c *Config) + +func WithSize(s int) Option { + return func(c *Config) { + c.size = s + } +} + +func WithBuffer(b int) Option { + return func(c *Config) { + c.buffer = b + } +} + +func WithWorker(w int) Option { + return func(c *Config) { + c.worker = w + } +} + +func WithInterval(i time.Duration) Option { + return func(c *Config) { + c.interval = i + } +} + +func WithSyncWait(wait bool) Option { + return func(c *Config) { + c.syncWait = wait + } +} + +func WithDataBuffer(size int) Option { + return func(c *Config) { + c.dataBuffer = size + } +} + +type Batcher[T any] struct { + config *Config + + globalCtx context.Context + cancel context.CancelFunc + Do func(ctx context.Context, channelID int, val *Msg[T]) + OnComplete func(lastMessage *T, totalCount int) + Sharding func(key string) int + Key func(data *T) string + HookFunc func(triggerID string, messages map[string][]*T, totalCount int, lastMessage *T) + data chan *T + chArrays []chan *Msg[T] + wait sync.WaitGroup + counter sync.WaitGroup +} + +func emptyOnComplete[T any](*T, int) {} +func emptyHookFunc[T any](string, map[string][]*T, int, *T) { +} + +func New[T any](opts ...Option) *Batcher[T] { + b := &Batcher[T]{ + OnComplete: emptyOnComplete[T], + HookFunc: emptyHookFunc[T], + } + config := &Config{ + size: DefaultSize, + buffer: DefaultBuffer, + worker: DefaultWorker, + interval: DefaultInterval, + } + for _, opt := range opts { + opt(config) + } + b.config = config + b.data = make(chan *T, DefaultDataChanSize) + b.globalCtx, b.cancel = context.WithCancel(context.Background()) + + b.chArrays = make([]chan *Msg[T], b.config.worker) + for i := 0; i < b.config.worker; i++ { + b.chArrays[i] = make(chan *Msg[T], b.config.buffer) + } + return b +} + +func (b *Batcher[T]) Worker() int { + return b.config.worker +} + +func (b *Batcher[T]) Start() error { + if b.Sharding == nil { + return errs.New("Sharding function is required").Wrap() + } + if b.Do == nil { + return errs.New("Do function is required").Wrap() + } + if b.Key == nil { + return errs.New("Key function is required").Wrap() + } + b.wait.Add(b.config.worker) + for i := 0; i < b.config.worker; i++ { + go b.run(i, b.chArrays[i]) + } + b.wait.Add(1) + go b.scheduler() + return nil +} + +func (b *Batcher[T]) Put(ctx context.Context, data *T) error { + if data == nil { + return errs.New("data can not be nil").Wrap() + } + select { + case <-b.globalCtx.Done(): + return errs.New("data channel is closed").Wrap() + case <-ctx.Done(): + return ctx.Err() + case b.data <- data: + return nil + } +} + +func (b *Batcher[T]) scheduler() { + ticker := time.NewTicker(b.config.interval) + defer func() { + ticker.Stop() + for _, ch := range b.chArrays { + close(ch) + } + close(b.data) + b.wait.Done() + }() + + vals := make(map[string][]*T) + count := 0 + var lastAny *T + + for { + select { + case data, ok := <-b.data: + if !ok { + // If the data channel is closed unexpectedly + return + } + if data == nil { + if count > 0 { + b.distributeMessage(vals, count, lastAny) + } + return + } + + key := b.Key(data) + vals[key] = append(vals[key], data) + lastAny = data + + count++ + if count >= b.config.size { + + b.distributeMessage(vals, count, lastAny) + vals = make(map[string][]*T) + count = 0 + } + + case <-ticker.C: + if count > 0 { + + b.distributeMessage(vals, count, lastAny) + vals = make(map[string][]*T) + count = 0 + } + } + } +} + +type Msg[T any] struct { + key string + triggerID string + val []*T +} + +func (m Msg[T]) Key() string { + return m.key +} + +func (m Msg[T]) TriggerID() string { + return m.triggerID +} + +func (m Msg[T]) Val() []*T { + return m.val +} + +func (m Msg[T]) String() string { + var sb strings.Builder + sb.WriteString("Key: ") + sb.WriteString(m.key) + sb.WriteString(", Values: [") + for i, v := range m.val { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(fmt.Sprintf("%v", *v)) + } + sb.WriteString("]") + return sb.String() +} + +func (b *Batcher[T]) distributeMessage(messages map[string][]*T, totalCount int, lastMessage *T) { + triggerID := idutil.OperationIDGenerator() + b.HookFunc(triggerID, messages, totalCount, lastMessage) + for key, data := range messages { + if b.config.syncWait { + b.counter.Add(1) + } + channelID := b.Sharding(key) + b.chArrays[channelID] <- &Msg[T]{key: key, triggerID: triggerID, val: data} + } + if b.config.syncWait { + b.counter.Wait() + } + b.OnComplete(lastMessage, totalCount) +} + +func (b *Batcher[T]) run(channelID int, ch <-chan *Msg[T]) { + defer b.wait.Done() + for { + select { + case messages, ok := <-ch: + if !ok { + return + } + b.Do(context.Background(), channelID, messages) + if b.config.syncWait { + b.counter.Done() + } + } + } +} + +func (b *Batcher[T]) Close() { + b.cancel() // Signal to stop put data + b.data <- nil + //wait all goroutines exit + b.wait.Wait() +} diff --git a/pkg/tools/batcher/batcher_test.go b/pkg/tools/batcher/batcher_test.go new file mode 100644 index 0000000000..90e0284490 --- /dev/null +++ b/pkg/tools/batcher/batcher_test.go @@ -0,0 +1,66 @@ +package batcher + +import ( + "context" + "fmt" + "github.com/openimsdk/tools/utils/stringutil" + "testing" + "time" +) + +func TestBatcher(t *testing.T) { + config := Config{ + size: 1000, + buffer: 10, + worker: 10, + interval: 5 * time.Millisecond, + } + + b := New[string]( + WithSize(config.size), + WithBuffer(config.buffer), + WithWorker(config.worker), + WithInterval(config.interval), + WithSyncWait(true), + ) + + // Mock Do function to simply print values for demonstration + b.Do = func(ctx context.Context, channelID int, vals *Msg[string]) { + t.Logf("Channel %d Processed batch: %v", channelID, vals) + } + b.OnComplete = func(lastMessage *string, totalCount int) { + t.Logf("Completed processing with last message: %v, total count: %d", *lastMessage, totalCount) + } + b.Sharding = func(key string) int { + hashCode := stringutil.GetHashCode(key) + return int(hashCode) % config.worker + } + b.Key = func(data *string) string { + return *data + } + + err := b.Start() + if err != nil { + t.Fatal(err) + } + + // Test normal data processing + for i := 0; i < 10000; i++ { + data := "data" + fmt.Sprintf("%d", i) + if err := b.Put(context.Background(), &data); err != nil { + t.Fatal(err) + } + } + + time.Sleep(time.Duration(1) * time.Second) + start := time.Now() + // Wait for all processing to finish + b.Close() + + elapsed := time.Since(start) + t.Logf("Close took %s", elapsed) + + if len(b.data) != 0 { + t.Error("Data channel should be empty after closing") + } +}