From 26dd6ae925e6c2cd3a6219fde250728d7ad81dba Mon Sep 17 00:00:00 2001 From: mxsm Date: Mon, 25 Nov 2024 23:38:08 +0800 Subject: [PATCH] =?UTF-8?q?[ISSUE=20#1293]=F0=9F=94=A5Rocketmq-client=20su?= =?UTF-8?q?pports=20the=20AllocateMessageQueueStrategy=20algorithm-Allocat?= =?UTF-8?q?eMessageQueueAveragelyByCircle=F0=9F=9A=80=20(#1312)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../allocate_message_queue_strategy.rs | 19 ++ .../src/consumer/rebalance_strategy.rs | 9 +- ...ocate_message_queue_averagely_by_circle.rs | 169 ++++++++++++++++++ .../src/common/message/message_queue.rs | 3 + 4 files changed, 196 insertions(+), 4 deletions(-) create mode 100644 rocketmq-client/src/consumer/rebalance_strategy/allocate_message_queue_averagely_by_circle.rs diff --git a/rocketmq-client/src/consumer/allocate_message_queue_strategy.rs b/rocketmq-client/src/consumer/allocate_message_queue_strategy.rs index 7033baa9..6875b663 100644 --- a/rocketmq-client/src/consumer/allocate_message_queue_strategy.rs +++ b/rocketmq-client/src/consumer/allocate_message_queue_strategy.rs @@ -19,7 +19,21 @@ use rocketmq_common::common::message::message_queue::MessageQueue; use crate::Result; +/// Trait for allocating message queues to consumers in a consumer group. +/// This trait is implemented by different strategies for message queue allocation. pub trait AllocateMessageQueueStrategy: Send + Sync { + /// Allocates message queues to a consumer in a consumer group. + /// + /// # Arguments + /// + /// * `consumer_group` - The name of the consumer group. + /// * `current_cid` - The ID of the current consumer. + /// * `mq_all` - A slice of all available message queues. + /// * `cid_all` - A slice of all consumer IDs in the consumer group. + /// + /// # Returns + /// + /// A `Result` containing a vector of allocated message queues or an error. fn allocate( &self, consumer_group: &CheetahString, @@ -28,5 +42,10 @@ pub trait AllocateMessageQueueStrategy: Send + Sync { cid_all: &[CheetahString], ) -> Result>; + /// Returns the name of the allocation strategy. + /// + /// # Returns + /// + /// A static string slice representing the name of the strategy. fn get_name(&self) -> &'static str; } diff --git a/rocketmq-client/src/consumer/rebalance_strategy.rs b/rocketmq-client/src/consumer/rebalance_strategy.rs index 0ec9f47e..642501ce 100644 --- a/rocketmq-client/src/consumer/rebalance_strategy.rs +++ b/rocketmq-client/src/consumer/rebalance_strategy.rs @@ -15,6 +15,7 @@ * limitations under the License. */ pub mod allocate_message_queue_averagely; +mod allocate_message_queue_averagely_by_circle; use std::collections::HashSet; @@ -26,8 +27,8 @@ use crate::error::MQClientError::IllegalArgumentError; use crate::Result; pub fn check( - consumer_group: &str, - current_cid: &str, + consumer_group: &CheetahString, + current_cid: &CheetahString, mq_all: &[MessageQueue], cid_all: &[CheetahString], ) -> Result { @@ -44,9 +45,9 @@ pub fn check( "cidAll is null or cidAll empty".to_string(), )); } - let current_cid: CheetahString = current_cid.to_string().into(); + let cid_set: HashSet<_> = cid_all.iter().collect(); - if !cid_set.contains(¤t_cid) { + if !cid_set.contains(current_cid) { info!( "[BUG] ConsumerGroup: {} The consumerId: {} not in cidAll: {:?}", consumer_group, current_cid, cid_all diff --git a/rocketmq-client/src/consumer/rebalance_strategy/allocate_message_queue_averagely_by_circle.rs b/rocketmq-client/src/consumer/rebalance_strategy/allocate_message_queue_averagely_by_circle.rs new file mode 100644 index 00000000..8f3ae5d4 --- /dev/null +++ b/rocketmq-client/src/consumer/rebalance_strategy/allocate_message_queue_averagely_by_circle.rs @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +use cheetah_string::CheetahString; +use rocketmq_common::common::message::message_queue::MessageQueue; + +use crate::consumer::allocate_message_queue_strategy::AllocateMessageQueueStrategy; +use crate::consumer::rebalance_strategy::check; + +pub struct AllocateMessageQueueAveragelyByCircle; + +impl AllocateMessageQueueStrategy for AllocateMessageQueueAveragelyByCircle { + fn allocate( + &self, + consumer_group: &CheetahString, + current_cid: &CheetahString, + mq_all: &[MessageQueue], + cid_all: &[CheetahString], + ) -> crate::Result> { + let mut result = Vec::new(); + if !check(consumer_group, current_cid, mq_all, cid_all)? { + return Ok(result); + } + let index = cid_all + .iter() + .position(|cid| cid == current_cid) + .unwrap_or(0); + for (i, item) in mq_all.iter().enumerate().skip(index) { + if i % cid_all.len() == index { + result.push(item.clone()); + } + } + Ok(result) + } + + #[inline] + fn get_name(&self) -> &'static str { + "AVG_BY_CIRCLE" + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use cheetah_string::CheetahString; + use rocketmq_common::common::message::message_queue::MessageQueue; + + use super::*; + + #[test] + fn allocate_returns_empty_when_check_fails() { + let strategy = AllocateMessageQueueAveragelyByCircle; + let consumer_group = CheetahString::from("test_group"); + let current_cid = CheetahString::from("consumer1"); + let mq_all = vec![MessageQueue::from_parts("topic", "broker", 0)]; + let cid_all = vec![CheetahString::from("consumer1")]; + + let result = strategy + .allocate(&consumer_group, ¤t_cid, &mq_all, &cid_all) + .unwrap(); + assert!(!result.is_empty()); + } + + #[test] + fn allocate_returns_correct_queues_for_single_consumer() { + let strategy = AllocateMessageQueueAveragelyByCircle; + let consumer_group = CheetahString::from("test_group"); + let current_cid = CheetahString::from("consumer1"); + let mq_all = vec![ + MessageQueue::from_parts("topic", "broker", 0), + MessageQueue::from_parts("topic", "broker", 1), + ]; + let cid_all = vec![CheetahString::from("consumer1")]; + + let result = strategy + .allocate(&consumer_group, ¤t_cid, &mq_all, &cid_all) + .unwrap(); + assert_eq!(result.len(), 2); + assert_eq!(result[0].get_queue_id(), 0); + assert_eq!(result[1].get_queue_id(), 1); + } + + #[test] + fn allocate_returns_correct_queues_for_multiple_consumers() { + let strategy = AllocateMessageQueueAveragelyByCircle; + let consumer_group = CheetahString::from("test_group"); + let current_cid = CheetahString::from("consumer2"); + let mq_all = vec![ + MessageQueue::from_parts("topic", "broker", 0), + MessageQueue::from_parts("topic", "broker", 1), + MessageQueue::from_parts("topic", "broker", 2), + ]; + let cid_all = vec![ + CheetahString::from("consumer1"), + CheetahString::from("consumer2"), + CheetahString::from("consumer3"), + ]; + + let result = strategy + .allocate(&consumer_group, ¤t_cid, &mq_all, &cid_all) + .unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].get_queue_id(), 1); + } + + #[test] + fn get_name_returns_correct_name() { + let strategy = AllocateMessageQueueAveragelyByCircle; + assert_eq!(strategy.get_name(), "AVG_BY_CIRCLE"); + } + + #[test] + fn test_allocate_message_queue_averagely_by_circle() { + let consumer_id_list = create_consumer_id_list(4); + let message_queue_list = create_message_queue_list(10); + let allocate_queues = AllocateMessageQueueAveragelyByCircle + .allocate( + &CheetahString::from(""), + &CheetahString::from("CID_PREFIX"), + &message_queue_list, + &consumer_id_list, + ) + .unwrap(); + assert_eq!(0, allocate_queues.len()); + + let mut consumer_allocate_queue = HashMap::new(); + for consumer_id in &consumer_id_list { + let queues = AllocateMessageQueueAveragelyByCircle + .allocate( + &CheetahString::from(""), + consumer_id, + &message_queue_list, + &consumer_id_list, + ) + .unwrap(); + let queue_ids: Vec = queues.iter().map(|q| q.get_queue_id()).collect(); + consumer_allocate_queue.insert(consumer_id.clone(), queue_ids); + } + assert_eq!(vec![0, 4, 8], consumer_allocate_queue["CID_PREFIX0"]); + assert_eq!(vec![1, 5, 9], consumer_allocate_queue["CID_PREFIX1"]); + assert_eq!(vec![2, 6], consumer_allocate_queue["CID_PREFIX2"]); + assert_eq!(vec![3, 7], consumer_allocate_queue["CID_PREFIX3"]); + } + fn create_consumer_id_list(size: usize) -> Vec { + (0..size) + .map(|i| format!("CID_PREFIX{}", i).into()) + .collect() + } + + fn create_message_queue_list(size: usize) -> Vec { + (0..size) + .map(|i| MessageQueue::from_parts("topic", "broker", i as i32)) + .collect() + } +} diff --git a/rocketmq-common/src/common/message/message_queue.rs b/rocketmq-common/src/common/message/message_queue.rs index 646f4eb9..b7abbb5d 100644 --- a/rocketmq-common/src/common/message/message_queue.rs +++ b/rocketmq-common/src/common/message/message_queue.rs @@ -61,6 +61,7 @@ impl MessageQueue { } } + #[inline] pub fn get_topic(&self) -> &str { &self.topic } @@ -80,6 +81,7 @@ impl MessageQueue { &self.broker_name } + #[inline] pub fn set_broker_name(&mut self, broker_name: CheetahString) { self.broker_name = broker_name; } @@ -89,6 +91,7 @@ impl MessageQueue { self.queue_id } + #[inline] pub fn set_queue_id(&mut self, queue_id: i32) { self.queue_id = queue_id; }