From 6d8c3c5f524ebae91e486420fc66be6699b5d6ca Mon Sep 17 00:00:00 2001 From: PanQL <36041487+PanQL@users.noreply.github.com> Date: Mon, 17 Oct 2022 23:47:30 +0800 Subject: [PATCH] refactor(source): remove drop_source rpc in CN (#5849) * remove drop_source rpc * drop related tables when drop actor in drop_actor rpc * impl insert_source/try_drop_source for source mgr * fix unittest * exclude stream source id from actor_tables * Update src/source/src/manager.rs Co-authored-by: August * Update src/source/src/manager.rs Co-authored-by: August * rename SourceManager; clear actor_tables when doing stop_all_actors; * use Weak/Arc instead of usize to count source ref * repair unit tests * merge sources and source_ref_count into a map;clear source map when drop_all_actors. * modification: - remove TableSourceManager trait and use MemSourceManager directly. - remove useless source desc during insert_source instead. * add clear_source and call it during force_stop_actors. * improve unittest coverage. Co-authored-by: August Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- proto/stream_service.proto | 9 - src/batch/src/executor/delete.rs | 10 +- src/batch/src/executor/insert.rs | 10 +- src/batch/src/executor/update.rs | 10 +- src/batch/src/task/context.rs | 6 +- src/batch/src/task/env.rs | 13 +- src/compute/src/rpc/service/stream_service.rs | 20 +- src/compute/src/server.rs | 4 +- src/compute/tests/integration_tests.rs | 4 +- src/frontend/src/scheduler/task_context.rs | 4 +- src/meta/src/stream/stream_manager.rs | 7 - src/rpc_client/src/stream_client.rs | 1 - src/source/src/lib.rs | 1 + src/source/src/manager.rs | 183 +++++++++--------- .../src/executor/source/source_executor.rs | 18 +- src/stream/src/task/env.rs | 13 +- 16 files changed, 134 insertions(+), 179 deletions(-) diff --git a/proto/stream_service.proto b/proto/stream_service.proto index 6a28082c1786..2f9a1e243c9d 100644 --- a/proto/stream_service.proto +++ b/proto/stream_service.proto @@ -96,14 +96,6 @@ message BroadcastActorInfoTableResponse { common.Status status = 1; } -message DropSourceRequest { - uint32 source_id = 1; -} - -message DropSourceResponse { - common.Status status = 1; -} - message WaitEpochCommitRequest { uint64 epoch = 1; } @@ -119,7 +111,6 @@ service StreamService { rpc DropActors(DropActorsRequest) returns (DropActorsResponse); rpc ForceStopActors(ForceStopActorsRequest) returns (ForceStopActorsResponse); rpc InjectBarrier(InjectBarrierRequest) returns (InjectBarrierResponse); - rpc DropSource(DropSourceRequest) returns (DropSourceResponse); rpc BarrierComplete(BarrierCompleteRequest) returns (BarrierCompleteResponse); rpc WaitEpochCommit(WaitEpochCommitRequest) returns (WaitEpochCommitResponse); } diff --git a/src/batch/src/executor/delete.rs b/src/batch/src/executor/delete.rs index 06b5c8e84a2f..67677cb4f5de 100644 --- a/src/batch/src/executor/delete.rs +++ b/src/batch/src/executor/delete.rs @@ -20,7 +20,7 @@ use risingwave_common::catalog::{Field, Schema, TableId}; use risingwave_common::error::{Result, RwError}; use risingwave_common::types::DataType; use risingwave_pb::batch_plan::plan_node::NodeBody; -use risingwave_source::SourceManagerRef; +use risingwave_source::TableSourceManagerRef; use crate::error::BatchError; use crate::executor::{ @@ -33,7 +33,7 @@ use crate::task::BatchTaskContext; pub struct DeleteExecutor { /// Target table id. table_id: TableId, - source_manager: SourceManagerRef, + source_manager: TableSourceManagerRef, child: BoxedExecutor, schema: Schema, identity: String, @@ -42,7 +42,7 @@ pub struct DeleteExecutor { impl DeleteExecutor { pub fn new( table_id: TableId, - source_manager: SourceManagerRef, + source_manager: TableSourceManagerRef, child: BoxedExecutor, identity: String, ) -> Self { @@ -145,7 +145,7 @@ mod tests { use risingwave_common::catalog::schema_test_utils; use risingwave_common::test_prelude::DataChunkTestExt; use risingwave_source::table_test_utils::create_table_info; - use risingwave_source::{MemSourceManager, SourceDescBuilder, SourceManagerRef}; + use risingwave_source::{SourceDescBuilder, TableSourceManager, TableSourceManagerRef}; use super::*; use crate::executor::test_utils::MockExecutor; @@ -153,7 +153,7 @@ mod tests { #[tokio::test] async fn test_delete_executor() -> Result<()> { - let source_manager: SourceManagerRef = Arc::new(MemSourceManager::default()); + let source_manager: TableSourceManagerRef = Arc::new(TableSourceManager::default()); // Schema for mock executor. let schema = schema_test_utils::ii(); diff --git a/src/batch/src/executor/insert.rs b/src/batch/src/executor/insert.rs index 9b7f3c3e25f1..fb938941d988 100644 --- a/src/batch/src/executor/insert.rs +++ b/src/batch/src/executor/insert.rs @@ -23,7 +23,7 @@ use risingwave_common::catalog::{Field, Schema, TableId}; use risingwave_common::error::{Result, RwError}; use risingwave_common::types::DataType; use risingwave_pb::batch_plan::plan_node::NodeBody; -use risingwave_source::SourceManagerRef; +use risingwave_source::TableSourceManagerRef; use crate::executor::{ BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder, @@ -33,7 +33,7 @@ use crate::task::BatchTaskContext; pub struct InsertExecutor { /// Target table id. table_id: TableId, - source_manager: SourceManagerRef, + source_manager: TableSourceManagerRef, child: BoxedExecutor, schema: Schema, @@ -43,7 +43,7 @@ pub struct InsertExecutor { impl InsertExecutor { pub fn new( table_id: TableId, - source_manager: SourceManagerRef, + source_manager: TableSourceManagerRef, child: BoxedExecutor, identity: String, ) -> Self { @@ -159,7 +159,7 @@ mod tests { use risingwave_common::column_nonnull; use risingwave_common::types::DataType; use risingwave_source::table_test_utils::create_table_info; - use risingwave_source::{MemSourceManager, SourceDescBuilder, SourceManagerRef}; + use risingwave_source::{SourceDescBuilder, TableSourceManager, TableSourceManagerRef}; use risingwave_storage::memory::MemoryStateStore; use risingwave_storage::store::ReadOptions; use risingwave_storage::*; @@ -170,7 +170,7 @@ mod tests { #[tokio::test] async fn test_insert_executor() -> Result<()> { - let source_manager: SourceManagerRef = Arc::new(MemSourceManager::default()); + let source_manager: TableSourceManagerRef = Arc::new(TableSourceManager::default()); let store = MemoryStateStore::new(); // Make struct field diff --git a/src/batch/src/executor/update.rs b/src/batch/src/executor/update.rs index e15e65d27b3f..26bd43f08c8b 100644 --- a/src/batch/src/executor/update.rs +++ b/src/batch/src/executor/update.rs @@ -23,7 +23,7 @@ use risingwave_common::error::{Result, RwError}; use risingwave_common::types::DataType; use risingwave_expr::expr::{build_from_prost, BoxedExpression}; use risingwave_pb::batch_plan::plan_node::NodeBody; -use risingwave_source::SourceManagerRef; +use risingwave_source::TableSourceManagerRef; use crate::error::BatchError; use crate::executor::{ @@ -37,7 +37,7 @@ use crate::task::BatchTaskContext; pub struct UpdateExecutor { /// Target table id. table_id: TableId, - source_manager: SourceManagerRef, + source_manager: TableSourceManagerRef, child: BoxedExecutor, exprs: Vec, schema: Schema, @@ -47,7 +47,7 @@ pub struct UpdateExecutor { impl UpdateExecutor { pub fn new( table_id: TableId, - source_manager: SourceManagerRef, + source_manager: TableSourceManagerRef, child: BoxedExecutor, exprs: Vec, identity: String, @@ -200,7 +200,7 @@ mod tests { use risingwave_common::test_prelude::DataChunkTestExt; use risingwave_expr::expr::InputRefExpression; use risingwave_source::table_test_utils::create_table_info; - use risingwave_source::{MemSourceManager, SourceDescBuilder, SourceManagerRef}; + use risingwave_source::{SourceDescBuilder, TableSourceManager, TableSourceManagerRef}; use super::*; use crate::executor::test_utils::MockExecutor; @@ -208,7 +208,7 @@ mod tests { #[tokio::test] async fn test_update_executor() -> Result<()> { - let source_manager: SourceManagerRef = Arc::new(MemSourceManager::default()); + let source_manager: TableSourceManagerRef = Arc::new(TableSourceManager::default()); // Schema for mock executor. let schema = schema_test_utils::ii(); diff --git a/src/batch/src/task/context.rs b/src/batch/src/task/context.rs index 0422eb9ec38d..a223d13ead0c 100644 --- a/src/batch/src/task/context.rs +++ b/src/batch/src/task/context.rs @@ -17,7 +17,7 @@ use risingwave_common::config::BatchConfig; use risingwave_common::error::Result; use risingwave_common::util::addr::{is_local_address, HostAddr}; use risingwave_rpc_client::ComputeClientPoolRef; -use risingwave_source::SourceManagerRef; +use risingwave_source::TableSourceManagerRef; use risingwave_storage::StateStoreImpl; use super::TaskId; @@ -39,7 +39,7 @@ pub trait BatchTaskContext: Clone + Send + Sync + 'static { /// Whether `peer_addr` is in same as current task. fn is_local_addr(&self, peer_addr: &HostAddr) -> bool; - fn source_manager(&self) -> SourceManagerRef; + fn source_manager(&self) -> TableSourceManagerRef; fn state_store(&self) -> StateStoreImpl; @@ -78,7 +78,7 @@ impl BatchTaskContext for ComputeNodeContext { is_local_address(self.env.server_address(), peer_addr) } - fn source_manager(&self) -> SourceManagerRef { + fn source_manager(&self) -> TableSourceManagerRef { self.env.source_manager_ref() } diff --git a/src/batch/src/task/env.rs b/src/batch/src/task/env.rs index 393c41904501..0f1ed6849f55 100644 --- a/src/batch/src/task/env.rs +++ b/src/batch/src/task/env.rs @@ -17,7 +17,7 @@ use std::sync::Arc; use risingwave_common::config::BatchConfig; use risingwave_common::util::addr::HostAddr; use risingwave_rpc_client::ComputeClientPoolRef; -use risingwave_source::{SourceManager, SourceManagerRef}; +use risingwave_source::{TableSourceManager, TableSourceManagerRef}; use risingwave_storage::StateStoreImpl; use crate::executor::BatchTaskMetrics; @@ -36,7 +36,7 @@ pub struct BatchEnvironment { task_manager: Arc, /// Reference to the source manager. This is used to query the sources. - source_manager: SourceManagerRef, + source_manager: TableSourceManagerRef, /// Batch related configurations. config: Arc, @@ -57,7 +57,7 @@ pub struct BatchEnvironment { impl BatchEnvironment { #[allow(clippy::too_many_arguments)] pub fn new( - source_manager: SourceManagerRef, + source_manager: TableSourceManagerRef, task_manager: Arc, server_addr: HostAddr, config: Arc, @@ -82,13 +82,12 @@ impl BatchEnvironment { #[cfg(test)] pub fn for_test() -> Self { use risingwave_rpc_client::ComputeClientPool; - use risingwave_source::MemSourceManager; use risingwave_storage::monitor::StateStoreMetrics; BatchEnvironment { task_manager: Arc::new(BatchManager::new(None)), server_addr: "127.0.0.1:5688".parse().unwrap(), - source_manager: std::sync::Arc::new(MemSourceManager::default()), + source_manager: std::sync::Arc::new(TableSourceManager::default()), config: Arc::new(BatchConfig::default()), worker_id: WorkerNodeId::default(), state_store: StateStoreImpl::shared_in_memory_store(Arc::new( @@ -108,11 +107,11 @@ impl BatchEnvironment { } #[expect(clippy::explicit_auto_deref)] - pub fn source_manager(&self) -> &dyn SourceManager { + pub fn source_manager(&self) -> &TableSourceManager { &*self.source_manager } - pub fn source_manager_ref(&self) -> SourceManagerRef { + pub fn source_manager_ref(&self) -> TableSourceManagerRef { self.source_manager.clone() } diff --git a/src/compute/src/rpc/service/stream_service.rs b/src/compute/src/rpc/service/stream_service.rs index 365d8796a46f..1de83a63f933 100644 --- a/src/compute/src/rpc/service/stream_service.rs +++ b/src/compute/src/rpc/service/stream_service.rs @@ -16,7 +16,6 @@ use std::sync::Arc; use async_stack_trace::StackTrace; use itertools::Itertools; -use risingwave_common::catalog::TableId; use risingwave_common::error::tonic_err; use risingwave_pb::stream_service::barrier_complete_response::GroupedSstableInfo; use risingwave_pb::stream_service::stream_service_server::StreamService; @@ -117,6 +116,7 @@ impl StreamService for StreamServiceImpl { ) -> std::result::Result, Status> { let req = request.into_inner(); self.mgr.stop_all_actors().await?; + self.env.source_manager().clear_sources(); Ok(Response::new(ForceStopActorsResponse { request_id: req.request_id, status: None, @@ -197,22 +197,4 @@ impl StreamService for StreamServiceImpl { Ok(Response::new(WaitEpochCommitResponse { status: None })) } - - #[cfg_attr(coverage, no_coverage)] - async fn drop_source( - &self, - request: Request, - ) -> Result, Status> { - let id = request.into_inner().source_id; - let id = TableId::new(id); // TODO: use SourceId instead - - self.env - .source_manager() - .drop_source(&id) - .map_err(tonic_err)?; - - tracing::debug!(id = %id, "drop source"); - - Ok(Response::new(DropSourceResponse { status: None })) - } } diff --git a/src/compute/src/server.rs b/src/compute/src/server.rs index 254c752e0213..d5b09af589d8 100644 --- a/src/compute/src/server.rs +++ b/src/compute/src/server.rs @@ -30,7 +30,7 @@ use risingwave_pb::task_service::exchange_service_server::ExchangeServiceServer; use risingwave_pb::task_service::task_service_server::TaskServiceServer; use risingwave_rpc_client::{ComputeClientPool, ExtraInfoSourceRef, MetaClient}; use risingwave_source::monitor::SourceMetrics; -use risingwave_source::MemSourceManager; +use risingwave_source::TableSourceManager; use risingwave_storage::hummock::compactor::{ CompactionExecutor, Compactor, CompactorContext, Context, }; @@ -190,7 +190,7 @@ pub async fn compute_node_serve( opts.enable_async_stack_trace, opts.enable_managed_cache, )); - let source_mgr = Arc::new(MemSourceManager::new( + let source_mgr = Arc::new(TableSourceManager::new( source_metrics, stream_config.developer.stream_connector_message_buffer_size, )); diff --git a/src/compute/tests/integration_tests.rs b/src/compute/tests/integration_tests.rs index e54f93fa42d7..60ef766d9c59 100644 --- a/src/compute/tests/integration_tests.rs +++ b/src/compute/tests/integration_tests.rs @@ -34,7 +34,7 @@ use risingwave_common::test_prelude::DataChunkTestExt; use risingwave_common::types::{DataType, IntoOrdered}; use risingwave_common::util::epoch::EpochPair; use risingwave_common::util::sort_util::{OrderPair, OrderType}; -use risingwave_source::{MemSourceManager, SourceDescBuilder, SourceManagerRef}; +use risingwave_source::{SourceDescBuilder, TableSourceManager, TableSourceManagerRef}; use risingwave_storage::memory::MemoryStateStore; use risingwave_storage::table::batch_table::storage_table::StorageTable; use risingwave_storage::table::streaming_table::state_table::StateTable; @@ -92,7 +92,7 @@ async fn test_table_materialize() -> StreamResult<()> { use risingwave_stream::executor::state_table_handler::default_source_internal_table; let memory_state_store = MemoryStateStore::new(); - let source_manager: SourceManagerRef = Arc::new(MemSourceManager::default()); + let source_manager: TableSourceManagerRef = Arc::new(TableSourceManager::default()); let source_table_id = TableId::default(); let schema = Schema { fields: vec![ diff --git a/src/frontend/src/scheduler/task_context.rs b/src/frontend/src/scheduler/task_context.rs index 0691d0d98d93..8f28d6ca3816 100644 --- a/src/frontend/src/scheduler/task_context.rs +++ b/src/frontend/src/scheduler/task_context.rs @@ -21,7 +21,7 @@ use risingwave_common::config::BatchConfig; use risingwave_common::error::Result; use risingwave_common::util::addr::{is_local_address, HostAddr}; use risingwave_rpc_client::ComputeClientPoolRef; -use risingwave_source::SourceManagerRef; +use risingwave_source::TableSourceManagerRef; use crate::catalog::pg_catalog::SysCatalogReaderImpl; use crate::session::{AuthContext, FrontendEnv}; @@ -58,7 +58,7 @@ impl BatchTaskContext for FrontendBatchTaskContext { is_local_address(self.env.server_address(), peer_addr) } - fn source_manager(&self) -> SourceManagerRef { + fn source_manager(&self) -> TableSourceManagerRef { unimplemented!("not supported in local mode") } diff --git a/src/meta/src/stream/stream_manager.rs b/src/meta/src/stream/stream_manager.rs index e78bf8ab5a49..446c07e66760 100644 --- a/src/meta/src/stream/stream_manager.rs +++ b/src/meta/src/stream/stream_manager.rs @@ -859,13 +859,6 @@ mod tests { Ok(Response::new(InjectBarrierResponse::default())) } - async fn drop_source( - &self, - _request: Request, - ) -> std::result::Result, Status> { - unimplemented!() - } - async fn barrier_complete( &self, _request: Request, diff --git a/src/rpc_client/src/stream_client.rs b/src/rpc_client/src/stream_client.rs index ae0959bdcc7c..25c043b44d49 100644 --- a/src/rpc_client/src/stream_client.rs +++ b/src/rpc_client/src/stream_client.rs @@ -58,7 +58,6 @@ macro_rules! for_all_stream_rpc { ,{ 0, drop_actors, DropActorsRequest, DropActorsResponse } ,{ 0, force_stop_actors, ForceStopActorsRequest, ForceStopActorsResponse} ,{ 0, inject_barrier, InjectBarrierRequest, InjectBarrierResponse } - ,{ 0, drop_source, DropSourceRequest, DropSourceResponse } ,{ 0, barrier_complete, BarrierCompleteRequest, BarrierCompleteResponse } ,{ 0, wait_epoch_commit, WaitEpochCommitRequest, WaitEpochCommitResponse } } diff --git a/src/source/src/lib.rs b/src/source/src/lib.rs index 2d5e4385a2c2..17f4c27d864d 100644 --- a/src/source/src/lib.rs +++ b/src/source/src/lib.rs @@ -20,6 +20,7 @@ #![feature(lint_reasons)] #![feature(result_option_inspect)] #![feature(generators)] +#![feature(hash_drain_filter)] use std::collections::HashMap; use std::fmt::Debug; diff --git a/src/source/src/manager.rs b/src/source/src/manager.rs index ae905825aa76..6153bb821ecb 100644 --- a/src/source/src/manager.rs +++ b/src/source/src/manager.rs @@ -14,12 +14,11 @@ use std::collections::HashMap; use std::fmt::Debug; -use std::sync::Arc; +use std::sync::{Arc, Weak}; -use async_trait::async_trait; -use parking_lot::{Mutex, MutexGuard}; +use itertools::Itertools; +use parking_lot::Mutex; use risingwave_common::catalog::{ColumnDesc, ColumnId, TableId}; -use risingwave_common::ensure; use risingwave_common::error::ErrorCode::{ConnectorError, InternalError, ProtocolError}; use risingwave_common::error::{Result, RwError}; use risingwave_common::types::DataType; @@ -32,21 +31,8 @@ use crate::monitor::SourceMetrics; use crate::table::TableSource; use crate::{ConnectorSource, SourceFormat, SourceImpl, SourceParserImpl}; -pub type SourceRef = Arc; - -/// The local source manager on the compute node. -#[async_trait] -pub trait SourceManager: Debug + Sync + Send { - fn get_source(&self, source_id: &TableId) -> Result; - fn drop_source(&self, source_id: &TableId) -> Result<()>; - - /// Clear sources, this is used when failover happens. - fn clear_sources(&self) -> Result<()>; - - fn metrics(&self) -> Arc; - fn msg_buf_size(&self) -> usize; - fn get_sources(&self) -> Result>>; -} +pub type SourceDescRef = Arc; +type WeakSourceDescRef = Weak; /// `SourceColumnDesc` is used to describe a column in the Source and is used as the column /// counterpart in `StreamScan` @@ -103,9 +89,9 @@ impl From<&SourceColumnDesc> for ColumnDesc { } /// `SourceDesc` is used to describe a `Source` -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct SourceDesc { - pub source: SourceRef, + pub source: SourceImpl, pub format: SourceFormat, pub columns: Vec, pub metrics: Arc, @@ -115,11 +101,11 @@ pub struct SourceDesc { pub pk_column_ids: Vec, } -pub type SourceManagerRef = Arc; +pub type TableSourceManagerRef = Arc; #[derive(Debug)] -pub struct MemSourceManager { - sources: Mutex>, +pub struct TableSourceManager { + sources: Mutex>, /// local source metrics metrics: Arc, /// The capacity of the chunks in the channel that connects between `ConnectorSource` and @@ -127,30 +113,56 @@ pub struct MemSourceManager { connector_message_buffer_size: usize, } -#[async_trait] -impl SourceManager for MemSourceManager { - fn get_source(&self, table_id: &TableId) -> Result { - let sources = self.get_sources()?; - sources.get(table_id).cloned().ok_or_else(|| { - InternalError(format!("Get source table id not exists: {:?}", table_id)).into() - }) +impl TableSourceManager { + pub fn get_source(&self, source_id: &TableId) -> Result { + let sources = self.sources.lock(); + sources + .get(source_id) + .and_then(|weak_ref| weak_ref.upgrade()) + .ok_or_else(|| { + InternalError(format!("Get source table id not exists: {:?}", source_id)).into() + }) } - fn drop_source(&self, table_id: &TableId) -> Result<()> { - let mut sources = self.get_sources()?; - ensure!( - sources.contains_key(table_id), - "Source does not exist: {:?}", - table_id - ); - sources.remove(table_id); - Ok(()) + pub fn insert_source( + &self, + source_id: &TableId, + info: &TableSourceInfo, + ) -> Result { + let mut sources = self.sources.lock(); + sources.drain_filter(|_, weak_ref| weak_ref.strong_count() == 0); + if let Some(strong_ref) = sources + .get(source_id) + .and_then(|weak_ref| weak_ref.upgrade()) + { + Ok(strong_ref) + } else { + let columns = info + .columns + .iter() + .cloned() + .map(|c| ColumnDesc::from(c.column_desc.unwrap())) + .collect_vec(); + let row_id_index = info.row_id_index.as_ref().map(|index| index.index as _); + let pk_column_ids = info.pk_column_ids.clone(); + + // Table sources do not need columns and format + let strong_ref = Arc::new(SourceDesc { + columns: columns.iter().map(SourceColumnDesc::from).collect(), + source: SourceImpl::Table(TableSource::new(columns)), + format: SourceFormat::Invalid, + row_id_index, + pk_column_ids, + metrics: self.metrics.clone(), + }); + sources.insert(*source_id, Arc::downgrade(&strong_ref)); + Ok(strong_ref) + } } - fn clear_sources(&self) -> Result<()> { - let mut sources = self.get_sources()?; - sources.clear(); - Ok(()) + /// For recovery, clear all sources' weak references. + pub fn clear_sources(&self) { + self.sources.lock().clear() } fn metrics(&self) -> Arc { @@ -160,15 +172,11 @@ impl SourceManager for MemSourceManager { fn msg_buf_size(&self) -> usize { self.connector_message_buffer_size } - - fn get_sources(&self) -> Result>> { - Ok(self.sources.lock()) - } } -impl Default for MemSourceManager { +impl Default for TableSourceManager { fn default() -> Self { - MemSourceManager { + TableSourceManager { sources: Default::default(), metrics: Default::default(), connector_message_buffer_size: 16, @@ -176,9 +184,9 @@ impl Default for MemSourceManager { } } -impl MemSourceManager { +impl TableSourceManager { pub fn new(metrics: Arc, connector_message_buffer_size: usize) -> Self { - MemSourceManager { + TableSourceManager { sources: Mutex::new(HashMap::new()), metrics, connector_message_buffer_size, @@ -190,11 +198,11 @@ impl MemSourceManager { pub struct SourceDescBuilder { id: TableId, info: ProstSourceInfo, - mgr: SourceManagerRef, + mgr: TableSourceManagerRef, } impl SourceDescBuilder { - pub fn new(id: TableId, info: &ProstSourceInfo, mgr: &SourceManagerRef) -> Self { + pub fn new(id: TableId, info: &ProstSourceInfo, mgr: &TableSourceManagerRef) -> Self { Self { id, info: info.clone(), @@ -202,7 +210,7 @@ impl SourceDescBuilder { } } - pub async fn build(&self) -> Result { + pub async fn build(&self) -> Result { let Self { id, info, mgr } = self; match &info { ProstSourceInfo::TableSource(info) => Self::build_table_source(mgr, id, info), @@ -211,42 +219,17 @@ impl SourceDescBuilder { } fn build_table_source( - mgr: &SourceManagerRef, + mgr: &TableSourceManagerRef, table_id: &TableId, info: &TableSourceInfo, - ) -> Result { - let mut sources = mgr.get_sources()?; - if let Some(source_desc) = sources.get(table_id) { - return Ok(source_desc.clone()); - } - - let columns: Vec<_> = info - .columns - .iter() - .cloned() - .map(|c| ColumnDesc::from(c.column_desc.unwrap())) - .collect(); - let row_id_index = info.row_id_index.as_ref().map(|index| index.index as _); - let pk_column_ids = info.pk_column_ids.clone(); - - // Table sources do not need columns and format - let desc = SourceDesc { - columns: columns.iter().map(SourceColumnDesc::from).collect(), - source: Arc::new(SourceImpl::Table(TableSource::new(columns))), - format: SourceFormat::Invalid, - row_id_index, - pk_column_ids, - metrics: mgr.metrics(), - }; - - sources.insert(*table_id, desc.clone()); - Ok(desc) + ) -> Result { + mgr.insert_source(table_id, info) } async fn build_stream_source( - mgr: &SourceManagerRef, + mgr: &TableSourceManagerRef, info: &StreamSourceInfo, - ) -> Result { + ) -> Result { let format = match info.get_row_format()? { RowFormatType::Json => SourceFormat::Json, RowFormatType::Protobuf => SourceFormat::Protobuf, @@ -294,14 +277,14 @@ impl SourceDescBuilder { connector_message_buffer_size: mgr.msg_buf_size(), }); - Ok(SourceDesc { - source: Arc::new(source), + Ok(Arc::new(SourceDesc { + source, format, columns, row_id_index, pk_column_ids, metrics: mgr.metrics(), - }) + })) } } @@ -345,7 +328,7 @@ mod tests { }; let source_id = TableId::default(); - let mem_source_manager: SourceManagerRef = Arc::new(MemSourceManager::default()); + let mem_source_manager: TableSourceManagerRef = Arc::new(TableSourceManager::default()); let source_builder = SourceDescBuilder::new(source_id, &Info::StreamSource(info), &mem_source_manager); let source = source_builder.build().await; @@ -392,8 +375,8 @@ mod tests { let _keyspace = Keyspace::table_root(MemoryStateStore::new(), &table_id); - let mem_source_manager: SourceManagerRef = Arc::new(MemSourceManager::default()); - let source_builder = + let mem_source_manager: TableSourceManagerRef = Arc::new(TableSourceManager::default()); + let mut source_builder = SourceDescBuilder::new(table_id, &Info::TableSource(info), &mem_source_manager); let res = source_builder.build().await; assert!(res.is_ok()); @@ -402,11 +385,19 @@ mod tests { let get_source_res = mem_source_manager.get_source(&table_id); assert!(get_source_res.is_ok()); - // drop source - let drop_source_res = mem_source_manager.drop_source(&table_id); - assert!(drop_source_res.is_ok()); - let get_source_res = mem_source_manager.get_source(&table_id); - assert!(get_source_res.is_err()); + // drop all replicas of TableId(0) + drop(res); + drop(get_source_res); + // failed to get_source + let result = mem_source_manager.get_source(&table_id); + assert!(result.is_err()); + + source_builder.id = TableId::new(1u32); + let _new_source = source_builder.build().await; + + assert_eq!(mem_source_manager.sources.lock().len(), 1); + mem_source_manager.clear_sources(); + assert!(mem_source_manager.sources.lock().is_empty()); Ok(()) } diff --git a/src/stream/src/executor/source/source_executor.rs b/src/stream/src/executor/source/source_executor.rs index 54a92be067d3..695ebef04424 100644 --- a/src/stream/src/executor/source/source_executor.rs +++ b/src/stream/src/executor/source/source_executor.rs @@ -221,10 +221,10 @@ impl SourceExecutor { async fn build_stream_source_reader( &mut self, - source_desc: &SourceDesc, + source_desc: &SourceDescRef, state: ConnectorState, ) -> StreamExecutorResult { - let reader = match source_desc.source.as_ref() { + let reader = match &source_desc.source { SourceImpl::Table(t) => t .stream_reader(self.column_ids.clone()) .await @@ -387,7 +387,7 @@ impl SourceExecutor { } // Refill row id column for source. - chunk = match source_desc.source.as_ref() { + chunk = match &source_desc.source { SourceImpl::Connector(_) => { self.refill_row_id_column(chunk, true, row_id_index).await } @@ -414,7 +414,7 @@ impl SourceExecutor { async fn apply_split_change( &mut self, - source_desc: &SourceDesc, + source_desc: &SourceDescRef, stream: &mut SourceReaderStream, mapping: &HashMap>, ) -> StreamExecutorResult<()> { @@ -430,7 +430,7 @@ impl SourceExecutor { async fn replace_stream_reader_with_target_state( &mut self, - source_desc: &SourceDesc, + source_desc: &SourceDescRef, stream: &mut SourceReaderStream, target_state: Vec, ) -> StreamExecutorResult<()> { @@ -524,7 +524,7 @@ mod tests { let row_id_index = Some(0); let pk_column_ids = vec![0]; let info = create_table_info(&schema, row_id_index, pk_column_ids); - let source_manager: SourceManagerRef = Arc::new(MemSourceManager::default()); + let source_manager: TableSourceManagerRef = Arc::new(TableSourceManager::default()); let source_builder = SourceDescBuilder::new(table_id, &info, &source_manager); let source_desc = source_builder.build().await.unwrap(); @@ -626,7 +626,7 @@ mod tests { let row_id_index = Some(0); let pk_column_ids = vec![0]; let info = create_table_info(&schema, row_id_index, pk_column_ids); - let source_manager: SourceManagerRef = Arc::new(MemSourceManager::default()); + let source_manager: TableSourceManagerRef = Arc::new(TableSourceManager::default()); let source_builder = SourceDescBuilder::new(table_id, &info, &source_manager); let source_desc = source_builder.build().await.unwrap(); @@ -742,9 +742,9 @@ mod tests { async fn test_split_change_mutation() { let stream_source_info = mock_stream_source_info(); let source_table_id = TableId::default(); - let source_manager: SourceManagerRef = Arc::new(MemSourceManager::default()); + let source_manager: TableSourceManagerRef = Arc::new(TableSourceManager::default()); - let get_schema = |column_ids: &[ColumnId], source_desc: &SourceDesc| { + let get_schema = |column_ids: &[ColumnId], source_desc: &SourceDescRef| { let mut fields = Vec::with_capacity(column_ids.len()); for &column_id in column_ids { let column_desc = source_desc diff --git a/src/stream/src/task/env.rs b/src/stream/src/task/env.rs index c59ddf080001..5053de21220c 100644 --- a/src/stream/src/task/env.rs +++ b/src/stream/src/task/env.rs @@ -16,7 +16,7 @@ use std::sync::Arc; use risingwave_common::config::StreamingConfig; use risingwave_common::util::addr::HostAddr; -use risingwave_source::{SourceManager, SourceManagerRef}; +use risingwave_source::{TableSourceManager, TableSourceManagerRef}; use risingwave_storage::StateStoreImpl; pub(crate) type WorkerNodeId = u32; @@ -29,7 +29,7 @@ pub struct StreamEnvironment { server_addr: HostAddr, /// Reference to the source manager. - source_manager: SourceManagerRef, + source_manager: TableSourceManagerRef, /// Streaming related configurations. config: Arc, @@ -43,7 +43,7 @@ pub struct StreamEnvironment { impl StreamEnvironment { pub fn new( - source_manager: SourceManagerRef, + source_manager: TableSourceManagerRef, server_addr: HostAddr, config: Arc, worker_id: WorkerNodeId, @@ -61,11 +61,10 @@ impl StreamEnvironment { // Create an instance for testing purpose. #[cfg(test)] pub fn for_test() -> Self { - use risingwave_source::MemSourceManager; use risingwave_storage::monitor::StateStoreMetrics; StreamEnvironment { server_addr: "127.0.0.1:5688".parse().unwrap(), - source_manager: Arc::new(MemSourceManager::default()), + source_manager: Arc::new(TableSourceManager::default()), config: Arc::new(StreamingConfig::default()), worker_id: WorkerNodeId::default(), state_store: StateStoreImpl::shared_in_memory_store(Arc::new( @@ -79,11 +78,11 @@ impl StreamEnvironment { } #[expect(clippy::explicit_auto_deref)] - pub fn source_manager(&self) -> &dyn SourceManager { + pub fn source_manager(&self) -> &TableSourceManager { &*self.source_manager } - pub fn source_manager_ref(&self) -> SourceManagerRef { + pub fn source_manager_ref(&self) -> TableSourceManagerRef { self.source_manager.clone() }