From ee33778b88b0df70be3b77943241afb40cdbb064 Mon Sep 17 00:00:00 2001 From: Micah Wylde Date: Sun, 15 Dec 2024 20:31:42 -0800 Subject: [PATCH 01/14] WIP on lookup joins --- crates/arroyo-api/src/connection_tables.rs | 12 ++ crates/arroyo-connectors/src/lib.rs | 10 ++ crates/arroyo-connectors/src/redis/lookup.rs | 75 ++++++++ crates/arroyo-connectors/src/redis/mod.rs | 161 ++++++++++-------- .../src/redis/operator/mod.rs | 1 - .../src/redis/{operator => }/sink.rs | 92 +++++----- crates/arroyo-connectors/src/redis/table.json | 9 + crates/arroyo-datastream/src/logical.rs | 2 + crates/arroyo-planner/src/extension/lookup.rs | 121 +++++++++++++ crates/arroyo-planner/src/extension/mod.rs | 2 + crates/arroyo-planner/src/extension/sink.rs | 1 + crates/arroyo-planner/src/lib.rs | 5 +- crates/arroyo-planner/src/plan/join.rs | 89 +++++++++- crates/arroyo-planner/src/rewriters.rs | 15 ++ crates/arroyo-planner/src/tables.rs | 42 +++-- crates/arroyo-rpc/proto/api.proto | 7 + .../arroyo-rpc/src/api_types/connections.rs | 2 + crates/arroyo-worker/src/arrow/lookup_join.rs | 31 ++++ crates/arroyo-worker/src/arrow/mod.rs | 1 + crates/arroyo-worker/src/engine.rs | 1 + 20 files changed, 547 insertions(+), 132 deletions(-) create mode 100644 crates/arroyo-connectors/src/redis/lookup.rs delete mode 100644 crates/arroyo-connectors/src/redis/operator/mod.rs rename crates/arroyo-connectors/src/redis/{operator => }/sink.rs (84%) create mode 100644 crates/arroyo-planner/src/extension/lookup.rs create mode 100644 crates/arroyo-worker/src/arrow/lookup_join.rs diff --git a/crates/arroyo-api/src/connection_tables.rs b/crates/arroyo-api/src/connection_tables.rs index 82f705886..23b3df140 100644 --- a/crates/arroyo-api/src/connection_tables.rs +++ b/crates/arroyo-api/src/connection_tables.rs @@ -509,6 +509,9 @@ async fn expand_avro_schema( ConnectionType::Sink => { // don't fetch schemas for sinks for now } + ConnectionType::Lookup => { + todo!("lookup tables cannot be created via the UI") + } } } @@ -521,6 +524,9 @@ async fn expand_avro_schema( schema.inferred = Some(true); Ok(schema) } + ConnectionType::Lookup => { + todo!("lookup tables cannot be created via the UI") + } }; }; @@ -596,6 +602,9 @@ async fn expand_proto_schema( ConnectionType::Sink => { // don't fetch schemas for sinks for now } + ConnectionType::Lookup => { + todo!("lookup tables cannot be created via the UI") + } } } @@ -697,6 +706,9 @@ async fn expand_json_schema( // don't fetch schemas for sinks for now until we're better able to conform our output to the schema schema.inferred = Some(true); } + ConnectionType::Lookup => { + todo!("lookup tables cannot be created via the UI") + } } } diff --git a/crates/arroyo-connectors/src/lib.rs b/crates/arroyo-connectors/src/lib.rs index bc4892269..087342c9d 100644 --- a/crates/arroyo-connectors/src/lib.rs +++ b/crates/arroyo-connectors/src/lib.rs @@ -11,6 +11,8 @@ use reqwest::Client; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::time::Duration; +use arrow::array::{ArrayRef, RecordBatch}; +use async_trait::async_trait; use tokio::sync::mpsc::Sender; use tracing::warn; @@ -64,6 +66,14 @@ pub fn connectors() -> HashMap<&'static str, Box> { #[derive(Serialize, Deserialize)] pub struct EmptyConfig {} +#[async_trait] +pub trait LookupConnector { + fn name(&self) -> String; + + async fn lookup(&mut self, keys: &[ArrayRef]) -> RecordBatch; +} + + pub(crate) async fn send(tx: &mut Sender, msg: TestSourceMessage) { if tx.send(msg).await.is_err() { warn!("Test API rx closed while sending 
message"); diff --git a/crates/arroyo-connectors/src/redis/lookup.rs b/crates/arroyo-connectors/src/redis/lookup.rs new file mode 100644 index 000000000..2af114259 --- /dev/null +++ b/crates/arroyo-connectors/src/redis/lookup.rs @@ -0,0 +1,75 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use arrow::array::{ArrayRef, AsArray, RecordBatch}; +use arrow::compute::StringArrayType; +use arrow::datatypes::DataType; +use async_trait::async_trait; +use futures::future::OptionFuture; +use futures::stream::FuturesOrdered; +use futures::StreamExt; +use redis::{AsyncCommands, RedisFuture, RedisResult}; +use arroyo_formats::de::ArrowDeserializer; +use crate::LookupConnector; +use crate::redis::{RedisClient, RedisConnector}; +use crate::redis::sink::GeneralConnection; + +pub struct RedisLookup { + deserializer: ArrowDeserializer, + client: RedisClient, + connection: Option, +} + +// pub enum RedisFutureOrNull<'a> { +// RedisFuture(RedisFuture<'a, String>), +// Null +// } +// +// impl <'a> Future for RedisFutureOrNull<'a> { +// type Output = RedisResult>; +// +// fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { +// match self { +// RedisFutureOrNull::RedisFuture(f) => (**f).poll().map(|t| Some(t)), +// RedisFutureOrNull::Null => Poll::Ready(None), +// } +// } +// } + +#[async_trait] +impl LookupConnector for RedisLookup { + fn name(&self) -> String { + "RedisLookup".to_string() + } + + async fn lookup(&mut self, keys: &[ArrayRef]) -> RecordBatch { + if self.connection.is_none() { + self.connection = Some(self.client.get_connection().await.unwrap()); + } + + assert_eq!(keys.len(), 1, "redis lookup can only have a single key"); + assert_eq!(*keys[0].data_type(), DataType::Utf8, "redis lookup key must be a string"); + + let key = keys[0].as_string(); + + let connection = self.connection.as_mut().unwrap(); + + let result = connection.mget::<_, Vec>(&key.iter().filter_map(|k| k).collect::>()) + .await + .unwrap(); + + let mut result_iter = result.iter(); + + for k in key.iter() { + if k.is_some() { + self.deserializer.deserialize_slice()result_iter.next() + } + } + + while let Some(t) = futures.next().await { + + }; + + Ok(()) + } +} \ No newline at end of file diff --git a/crates/arroyo-connectors/src/redis/mod.rs b/crates/arroyo-connectors/src/redis/mod.rs index 7a78e88d0..8a678bf98 100644 --- a/crates/arroyo-connectors/src/redis/mod.rs +++ b/crates/arroyo-connectors/src/redis/mod.rs @@ -1,4 +1,5 @@ -mod operator; +pub mod sink; +pub mod lookup; use anyhow::{anyhow, bail}; use arroyo_formats::ser::ArrowSerializer; @@ -19,8 +20,8 @@ use arroyo_rpc::api_types::connections::{ }; use arroyo_rpc::OperatorConfig; -use crate::redis::operator::sink::{GeneralConnection, RedisSinkFunc}; use crate::{pull_opt, pull_option_to_u64}; +use crate::redis::sink::{GeneralConnection, RedisSinkFunc}; pub struct RedisConnector {} @@ -289,52 +290,61 @@ impl Connector for RedisConnector { } let sink = match typ.as_str() { - "sink" => TableType::Target(match pull_opt("target", options)?.as_str() { - "string" => Target::StringTable { - key_prefix: pull_opt("target.key_prefix", options)?, - key_column: options - .remove("target.key_column") - .map(|name| validate_column(schema, name, "target.key_column")) - .transpose()?, - ttl_secs: pull_option_to_u64("target.ttl_secs", options)? 
- .map(|t| t.try_into()) - .transpose() - .map_err(|_| anyhow!("target.ttl_secs must be greater than 0"))?, - }, - "list" => Target::ListTable { - list_prefix: pull_opt("target.key_prefix", options)?, - list_key_column: options - .remove("target.key_column") - .map(|name| validate_column(schema, name, "target.key_column")) - .transpose()?, - max_length: pull_option_to_u64("target.max_length", options)? - .map(|t| t.try_into()) - .transpose() - .map_err(|_| anyhow!("target.max_length must be greater than 0"))?, - operation: match options.remove("target.operation").as_deref() { - Some("append") | None => ListOperation::Append, - Some("prepend") => ListOperation::Prepend, - Some(op) => { - bail!("'{}' is not a valid value for target.operation; must be one of 'append' or 'prepend'", op); - } - }, - }, - "hash" => Target::HashTable { - hash_field_column: validate_column( - schema, - pull_opt("target.field_column", options)?, - "targets.field_column", - )?, - hash_key_column: options - .remove("target.key_column") - .map(|name| validate_column(schema, name, "target.key_column")) - .transpose()?, - hash_key_prefix: pull_opt("target.key_prefix", options)?, - }, - s => { - bail!("'{}' is not a valid redis target", s); + "lookup" => { + TableType::Lookup { + lookup: Default::default(), } - }), + } + "sink" => { + let target = match pull_opt("target", options)?.as_str() { + "string" => Target::StringTable { + key_prefix: pull_opt("target.key_prefix", options)?, + key_column: options + .remove("target.key_column") + .map(|name| validate_column(schema, name, "target.key_column")) + .transpose()?, + ttl_secs: pull_option_to_u64("target.ttl_secs", options)? + .map(|t| t.try_into()) + .transpose() + .map_err(|_| anyhow!("target.ttl_secs must be greater than 0"))?, + }, + "list" => Target::ListTable { + list_prefix: pull_opt("target.key_prefix", options)?, + list_key_column: options + .remove("target.key_column") + .map(|name| validate_column(schema, name, "target.key_column")) + .transpose()?, + max_length: pull_option_to_u64("target.max_length", options)? + .map(|t| t.try_into()) + .transpose() + .map_err(|_| anyhow!("target.max_length must be greater than 0"))?, + operation: match options.remove("target.operation").as_deref() { + Some("append") | None => ListOperation::Append, + Some("prepend") => ListOperation::Prepend, + Some(op) => { + bail!("'{}' is not a valid value for target.operation; must be one of 'append' or 'prepend'", op); + } + }, + }, + "hash" => Target::HashTable { + hash_field_column: validate_column( + schema, + pull_opt("target.field_column", options)?, + "targets.field_column", + )?, + hash_key_column: options + .remove("target.key_column") + .map(|name| validate_column(schema, name, "target.key_column")) + .transpose()?, + hash_key_prefix: pull_opt("target.key_prefix", options)?, + }, + s => { + bail!("'{}' is not a valid redis target", s); + } + }; + + TableType::Sink { target } + }, s => { bail!("'{}' is not a valid type; must be `sink`", s); } @@ -371,6 +381,15 @@ impl Connector for RedisConnector { let _ = RedisClient::new(&config)?; + let (connection_type, description) = match &table.connector_type { + TableType::Sink { .. } => { + (ConnectionType::Sink, "RedisSink") + } + TableType::Lookup { .. 
} => { + (ConnectionType::Lookup, "RedisLookup") + } + }; + let config = OperatorConfig { connection: serde_json::to_value(config).unwrap(), table: serde_json::to_value(table).unwrap(), @@ -380,15 +399,15 @@ impl Connector for RedisConnector { framing: schema.framing.clone(), metadata_fields: vec![], }; - + Ok(Connection { id, connector: self.name(), name: name.to_string(), - connection_type: ConnectionType::Sink, + connection_type, schema, config: serde_json::to_string(&config).unwrap(), - description: "RedisSink".to_string(), + description: description.to_string(), }) } @@ -400,22 +419,30 @@ impl Connector for RedisConnector { ) -> anyhow::Result { let client = RedisClient::new(&profile)?; - let (tx, cmd_rx) = tokio::sync::mpsc::channel(128); - let (cmd_tx, rx) = tokio::sync::mpsc::channel(128); - - Ok(ConstructedOperator::from_operator(Box::new( - RedisSinkFunc { - serializer: ArrowSerializer::new( - config.format.expect("redis table must have a format"), - ), - table, - client, - cmd_q: Some((cmd_tx, cmd_rx)), - tx, - rx, - key_index: None, - hash_index: None, - }, - ))) + match table.connector_type { + TableType::Sink { target } => { + let (tx, cmd_rx) = tokio::sync::mpsc::channel(128); + let (cmd_tx, rx) = tokio::sync::mpsc::channel(128); + + Ok(ConstructedOperator::from_operator(Box::new( + RedisSinkFunc { + serializer: ArrowSerializer::new( + config.format.expect("redis table must have a format"), + ), + target, + client, + cmd_q: Some((cmd_tx, cmd_rx)), + tx, + rx, + key_index: None, + hash_index: None, + }, + ))) + + } + TableType::Lookup { .. } => { + todo!() + } + } } } diff --git a/crates/arroyo-connectors/src/redis/operator/mod.rs b/crates/arroyo-connectors/src/redis/operator/mod.rs deleted file mode 100644 index 0ecbfb920..000000000 --- a/crates/arroyo-connectors/src/redis/operator/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod sink; diff --git a/crates/arroyo-connectors/src/redis/operator/sink.rs b/crates/arroyo-connectors/src/redis/sink.rs similarity index 84% rename from crates/arroyo-connectors/src/redis/operator/sink.rs rename to crates/arroyo-connectors/src/redis/sink.rs index 93d9e76ba..b40f8c99a 100644 --- a/crates/arroyo-connectors/src/redis/operator/sink.rs +++ b/crates/arroyo-connectors/src/redis/sink.rs @@ -19,7 +19,7 @@ const FLUSH_BYTES: usize = 10 * 1024 * 1024; pub struct RedisSinkFunc { pub serializer: ArrowSerializer, - pub table: RedisTable, + pub target: Target, pub client: RedisClient, pub cmd_q: Option<(Sender, Receiver)>, @@ -229,19 +229,19 @@ impl ArrowOperator for RedisSinkFunc { } async fn on_start(&mut self, ctx: &mut OperatorContext) { - match &self.table.connector_type { - TableType::Target(Target::ListTable { + match &self.target { + Target::ListTable { list_key_column: Some(key), .. - }) - | TableType::Target(Target::StringTable { + } + | Target::StringTable { key_column: Some(key), .. - }) - | TableType::Target(Target::HashTable { + } + | Target::HashTable { hash_key_column: Some(key), .. - }) => { + } => { self.key_index = Some( ctx.in_schemas .first() @@ -258,9 +258,9 @@ impl ArrowOperator for RedisSinkFunc { _ => {} } - if let TableType::Target(Target::HashTable { + if let Target::HashTable { hash_field_column, .. 
- }) = &self.table.connector_type + } = &self.target { self.hash_index = Some(ctx.in_schemas.first().expect("no in-schema for redis sink!") .schema @@ -282,17 +282,17 @@ impl ArrowOperator for RedisSinkFunc { size_estimate: 0, last_flushed: Instant::now(), max_push_keys: HashSet::new(), - behavior: match self.table.connector_type { - TableType::Target(Target::StringTable { ttl_secs, .. }) => { + behavior: match &self.target { + Target::StringTable { ttl_secs, .. } => { RedisBehavior::Set { ttl: ttl_secs.map(|t| t.get() as usize), } } - TableType::Target(Target::ListTable { + Target::ListTable { max_length, operation, .. - }) => { + } => { let max = max_length.map(|x| x.get() as usize); match operation { ListOperation::Append => { @@ -303,7 +303,7 @@ impl ArrowOperator for RedisSinkFunc { } } } - TableType::Target(Target::HashTable { .. }) => RedisBehavior::Hash, + Target::HashTable { .. } => RedisBehavior::Hash, }, } .start(); @@ -328,39 +328,37 @@ impl ArrowOperator for RedisSinkFunc { _: &mut dyn Collector, ) { for (i, value) in self.serializer.serialize(&batch).enumerate() { - match &self.table.connector_type { - TableType::Target(target) => match &target { - Target::StringTable { key_prefix, .. } => { - let key = self.make_key(key_prefix, &batch, i); - self.tx - .send(RedisCmd::Data { key, value }) - .await - .expect("Redis writer panicked"); - } - Target::ListTable { list_prefix, .. } => { - let key = self.make_key(list_prefix, &batch, i); + match &self.target { + Target::StringTable { key_prefix, .. } => { + let key = self.make_key(key_prefix, &batch, i); + self.tx + .send(RedisCmd::Data { key, value }) + .await + .expect("Redis writer panicked"); + } + Target::ListTable { list_prefix, .. } => { + let key = self.make_key(list_prefix, &batch, i); - self.tx - .send(RedisCmd::Data { key, value }) - .await - .expect("Redis writer panicked"); - } - Target::HashTable { - hash_key_prefix, .. - } => { - let key = self.make_key(hash_key_prefix, &batch, i); - let field = batch - .column(self.hash_index.expect("no hash index")) - .as_string::() - .value(i) - .to_string(); - - self.tx - .send(RedisCmd::HData { key, field, value }) - .await - .expect("Redis writer panicked"); - } - }, + self.tx + .send(RedisCmd::Data { key, value }) + .await + .expect("Redis writer panicked"); + } + Target::HashTable { + hash_key_prefix, .. 
+ } => { + let key = self.make_key(hash_key_prefix, &batch, i); + let field = batch + .column(self.hash_index.expect("no hash index")) + .as_string::() + .value(i) + .to_string(); + + self.tx + .send(RedisCmd::HData { key, field, value }) + .await + .expect("Redis writer panicked"); + } }; } } diff --git a/crates/arroyo-connectors/src/redis/table.json b/crates/arroyo-connectors/src/redis/table.json index e1dd45234..be1be2212 100644 --- a/crates/arroyo-connectors/src/redis/table.json +++ b/crates/arroyo-connectors/src/redis/table.json @@ -114,6 +114,15 @@ "target" ], "additionalProperties": false + }, + { + "type": "object", + "title": "Lookup", + "properties": { + "lookup": { + "type": "object" + } + } } ] } diff --git a/crates/arroyo-datastream/src/logical.rs b/crates/arroyo-datastream/src/logical.rs index f4689c995..2f4f957a7 100644 --- a/crates/arroyo-datastream/src/logical.rs +++ b/crates/arroyo-datastream/src/logical.rs @@ -32,6 +32,7 @@ pub enum OperatorName { AsyncUdf, Join, InstantJoin, + LookupJoin, WindowFunction, TumblingWindowAggregate, SlidingWindowAggregate, @@ -376,6 +377,7 @@ impl LogicalProgram { OperatorName::Join => "join-with-expiration".to_string(), OperatorName::InstantJoin => "windowed-join".to_string(), OperatorName::WindowFunction => "sql-window-function".to_string(), + OperatorName::LookupJoin => "lookup-join".to_string(), OperatorName::TumblingWindowAggregate => { "sql-tumbling-window-aggregate".to_string() } diff --git a/crates/arroyo-planner/src/extension/lookup.rs b/crates/arroyo-planner/src/extension/lookup.rs new file mode 100644 index 000000000..1b6dbc19f --- /dev/null +++ b/crates/arroyo-planner/src/extension/lookup.rs @@ -0,0 +1,121 @@ +use std::fmt::Formatter; +use datafusion::common::{internal_err, DFSchemaRef}; +use datafusion::logical_expr::{Expr, Join, LogicalPlan, UserDefinedLogicalNodeCore}; +use datafusion::sql::TableReference; +use arroyo_rpc::df::{ArroyoSchema, ArroyoSchemaRef}; +use crate::builder::{NamedNode, Planner}; +use crate::extension::{ArroyoExtension, NodeWithIncomingEdges}; +use crate::multifield_partial_ord; +use crate::tables::ConnectorTable; + +pub const SOURCE_EXTENSION_NAME: &str = "LookupSource"; +pub const JOIN_EXTENSION_NAME: &str = "LookupJoin"; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct LookupSource { + pub(crate) table: ConnectorTable, + pub(crate) schema: DFSchemaRef, +} + +multifield_partial_ord!(LookupSource, table); + +impl UserDefinedLogicalNodeCore for LookupSource { + fn name(&self) -> &str { + SOURCE_EXTENSION_NAME + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "LookupSource: {}", self.schema) + } + + fn with_exprs_and_inputs(&self, _exprs: Vec, inputs: Vec) -> datafusion::common::Result { + if !inputs.is_empty() { + return internal_err!("LookupSource cannot have inputs"); + } + + + Ok(Self { + table: self.table.clone(), + schema: self.schema.clone(), + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct LookupJoin { + pub(crate) input: LogicalPlan, + pub(crate) schema: DFSchemaRef, + pub(crate) connector: ConnectorTable, + pub(crate) on: Vec<(Expr, Expr)>, + pub(crate) filter: Option, + pub(crate) alias: Option, +} + +multifield_partial_ord!(LookupJoin, input, connector, on, filter, alias); + +impl ArroyoExtension for LookupJoin { + fn node_name(&self) -> Option { + todo!() + } + + 
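+    // Stubbed for now: plan_node will need to compile the `on` key expressions
+    // and pack the connector config into the LookupJoinOperator proto added in
+    // api.proto, and output_schema should mirror the extension's join schema.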
fn plan_node(&self, planner: &Planner, index: usize, input_schemas: Vec) -> datafusion::common::Result { + todo!() + } + + fn output_schema(&self) -> ArroyoSchema { + todo!() + } +} + +impl UserDefinedLogicalNodeCore for LookupJoin { + fn name(&self) -> &str { + JOIN_EXTENSION_NAME + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + let mut e: Vec<_> = self.on.iter() + .flat_map(|(l, r)| vec![l.clone(), r.clone()]) + .collect(); + + if let Some(filter) = &self.filter { + e.push(filter.clone()); + } + + e + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "LookupJoinExtension: {}", self.schema) + } + + fn with_exprs_and_inputs(&self, _: Vec, inputs: Vec) -> datafusion::common::Result { + Ok(Self { + input: inputs[0].clone(), + schema: self.schema.clone(), + connector: self.connector.clone(), + on: self.on.clone(), + filter: self.filter.clone(), + alias: self.alias.clone(), + }) + } +} \ No newline at end of file diff --git a/crates/arroyo-planner/src/extension/mod.rs b/crates/arroyo-planner/src/extension/mod.rs index 7b576b2a0..5b88f06bf 100644 --- a/crates/arroyo-planner/src/extension/mod.rs +++ b/crates/arroyo-planner/src/extension/mod.rs @@ -40,6 +40,8 @@ pub(crate) mod table_source; pub(crate) mod updating_aggregate; pub(crate) mod watermark_node; pub(crate) mod window_fn; +pub(crate) mod lookup; + pub(crate) trait ArroyoExtension: Debug { // if the extension has a name, return it so that we can memoize. fn node_name(&self) -> Option; diff --git a/crates/arroyo-planner/src/extension/sink.rs b/crates/arroyo-planner/src/extension/sink.rs index 0e559d175..5141e6e3a 100644 --- a/crates/arroyo-planner/src/extension/sink.rs +++ b/crates/arroyo-planner/src/extension/sink.rs @@ -61,6 +61,7 @@ impl SinkExtension { (false, false) => {} } } + Table::LookupTable(..) => return plan_err!("cannot use a lookup table as a sink"), Table::MemoryTable { .. } => return plan_err!("memory tables not supported"), Table::TableFromQuery { .. } => {} Table::PreviewSink { .. } => { diff --git a/crates/arroyo-planner/src/lib.rs b/crates/arroyo-planner/src/lib.rs index a9972e1af..7ead899ec 100644 --- a/crates/arroyo-planner/src/lib.rs +++ b/crates/arroyo-planner/src/lib.rs @@ -826,8 +826,11 @@ pub async fn parse_and_get_arrow_program( logical_plan.replace(plan_rewrite); continue; } + Table::LookupTable(_) => { + plan_err!("lookup (temporary) tables cannot be inserted into") + } Table::TableFromQuery { .. } => { - plan_err!("Shouldn't be inserting more data into a table made with CREATE TABLE AS") + plan_err!("shouldn't be inserting more data into a table made with CREATE TABLE AS") } Table::PreviewSink { .. 
} => { plan_err!("queries shouldn't be able insert into preview sink.") diff --git a/crates/arroyo-planner/src/plan/join.rs b/crates/arroyo-planner/src/plan/join.rs index 9dea42c56..ae5d3c2e3 100644 --- a/crates/arroyo-planner/src/plan/join.rs +++ b/crates/arroyo-planner/src/plan/join.rs @@ -4,7 +4,7 @@ use crate::plan::WindowDetectingVisitor; use crate::{fields_with_qualifiers, schema_from_df_fields_with_metadata, ArroyoSchemaProvider}; use arroyo_datastream::WindowType; use arroyo_rpc::UPDATING_META_FIELD; -use datafusion::common::tree_node::{Transformed, TreeNodeRewriter}; +use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor}; use datafusion::common::{ not_impl_err, plan_err, Column, DataFusionError, JoinConstraint, JoinType, Result, ScalarValue, TableReference, @@ -16,6 +16,8 @@ use datafusion::logical_expr::{ }; use datafusion::prelude::coalesce; use std::sync::Arc; +use crate::extension::lookup::{LookupJoin, LookupSource}; +use crate::tables::ConnectorTable; pub(crate) struct JoinRewriter<'a> { pub schema_provider: &'a ArroyoSchemaProvider, @@ -189,6 +191,86 @@ impl JoinRewriter<'_> { } } +#[derive(Default)] +struct FindLookupExtension { + table: Option, + filter: Option, + alias: Option, +} + +impl <'a> TreeNodeVisitor<'a> for FindLookupExtension { + type Node = LogicalPlan; + + fn f_down(&mut self, node: &Self::Node) -> Result { + match node { + LogicalPlan::Extension(e) => { + if let Some(s) = e.node.as_any().downcast_ref::() { + self.table = Some(s.table.clone()); + return Ok(TreeNodeRecursion::Stop); + } + } + LogicalPlan::Filter(filter) => { + if self.filter.replace(filter.predicate.clone()).is_some() { + return plan_err!("multiple filters found in lookup join, which is not supported"); + } + } + LogicalPlan::SubqueryAlias(s) => { + self.alias = Some(s.alias.clone()); + } + _ => { + return plan_err!("lookup tables must be used directly within a join"); + } + } + Ok(TreeNodeRecursion::Continue) + } +} + +fn has_lookup(plan: &LogicalPlan) -> Result { + plan.exists(|p| Ok(match p { + LogicalPlan::Extension(e) => e.node.as_any().is::(), + _ => false + })) +} + +fn maybe_plan_lookup_join(join: &Join) -> Result> { + if has_lookup(&join.left)? { + return plan_err!("lookup sources must be on the right side of an inner or left join"); + } + + if !has_lookup(&join.right)? { + return Ok(None); + } + + println!("JOin = {:?} {:?}\n{:#?}", join.join_constraint, join.join_type, join.on); + + match join.join_type { + JoinType::Inner | JoinType::Left => {} + t => { + return plan_err!("{} join is not supported for lookup tables; must be a left or inner join", t); + } + } + + if join.filter.is_some() { + return plan_err!("filter join conditions are not supported for lookup joins; must have an equality condition"); + } + + let mut lookup = FindLookupExtension::default(); + join.right.visit(&mut lookup)?; + + let connector = lookup.table.expect("right side of join does not have lookup"); + + Ok(Some(LogicalPlan::Extension(Extension { + node: Arc::new(LookupJoin { + input: (*join.left).clone(), + schema: join.schema.clone(), + connector, + on: join.on.clone(), + filter: lookup.filter, + alias: lookup.alias, + }) + }))) +} + impl TreeNodeRewriter for JoinRewriter<'_> { type Node = LogicalPlan; @@ -196,6 +278,11 @@ impl TreeNodeRewriter for JoinRewriter<'_> { let LogicalPlan::Join(join) = node else { return Ok(Transformed::no(node)); }; + + if let Some(plan) = maybe_plan_lookup_join(&join)? 
{ + return Ok(Transformed::yes(plan)); + } + let is_instant = Self::check_join_windowing(&join)?; let Join { diff --git a/crates/arroyo-planner/src/rewriters.rs b/crates/arroyo-planner/src/rewriters.rs index 611822719..12cbfd21f 100644 --- a/crates/arroyo-planner/src/rewriters.rs +++ b/crates/arroyo-planner/src/rewriters.rs @@ -36,6 +36,7 @@ use std::collections::HashMap; use std::collections::HashSet; use std::sync::Arc; use std::time::Duration; +use crate::extension::lookup::LookupSource; /// Rewrites a logical plan to move projections out of table scans /// and into a separate projection node which may include virtual fields, @@ -215,6 +216,19 @@ impl SourceRewriter<'_> { }))) } + fn mutate_lookup_table( + &self, + table_scan: &TableScan, + table: &ConnectorTable + ) -> DFResult> { + Ok(Transformed::yes(LogicalPlan::Extension(Extension { + node: Arc::new(LookupSource { + table: table.clone(), + schema: table_scan.projected_schema.clone(), + }) + }))) + } + fn mutate_table_from_query( &self, table_scan: &TableScan, @@ -273,6 +287,7 @@ impl TreeNodeRewriter for SourceRewriter<'_> { match table { Table::ConnectorTable(table) => self.mutate_connector_table(&table_scan, table), + Table::LookupTable(table) => self.mutate_lookup_table(&table_scan, table), Table::MemoryTable { name, fields: _, diff --git a/crates/arroyo-planner/src/tables.rs b/crates/arroyo-planner/src/tables.rs index dd8f7141f..8bc133313 100644 --- a/crates/arroyo-planner/src/tables.rs +++ b/crates/arroyo-planner/src/tables.rs @@ -418,7 +418,7 @@ impl ConnectorTable { pub fn as_sql_source(&self) -> Result { match self.connection_type { ConnectionType::Source => {} - ConnectionType::Sink => { + ConnectionType::Sink | ConnectionType::Lookup => { return plan_err!("cannot read from sink"); } }; @@ -463,6 +463,7 @@ pub struct SourceOperator { #[allow(clippy::enum_variant_names)] #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Table { + LookupTable(ConnectorTable), ConnectorTable(ConnectorTable), MemoryTable { name: String, @@ -673,6 +674,7 @@ impl Table { columns, with_options, query: None, + temporary, .. }) = statement { @@ -750,17 +752,23 @@ impl Table { ), None => None, }; - Ok(Some(Table::ConnectorTable( - ConnectorTable::from_options( - &name, - connector, - fields, - primary_keys, - &mut with_map, - connection_profile, - ) - .map_err(|e| e.context(format!("Failed to create table {}", name)))?, - ))) + let table = ConnectorTable::from_options( + &name, + connector, + fields, + primary_keys, + &mut with_map, + connection_profile, + ).map_err(|e| e.context(format!("Failed to create table {}", name)))?; + + Ok(Some(match table.connection_type { + ConnectionType::Source | ConnectionType::Sink => { + Table::ConnectorTable(table) + } + ConnectionType::Lookup => { + Table::LookupTable(table) + } + })) } } } else { @@ -798,7 +806,7 @@ impl Table { pub fn name(&self) -> &str { match self { Table::MemoryTable { name, .. } | Table::TableFromQuery { name, .. } => name.as_str(), - Table::ConnectorTable(c) => c.name.as_str(), + Table::ConnectorTable(c) | Table::LookupTable(c) => c.name.as_str(), Table::PreviewSink { .. } => "preview", } } @@ -836,7 +844,11 @@ impl Table { fields, inferred_fields, .. - }) => inferred_fields + }) | Table::LookupTable(ConnectorTable { + fields, + inferred_fields, + .. 
+ }) => inferred_fields .as_ref() .map(|fs| fs.iter().map(|f| f.field().clone()).collect()) .unwrap_or_else(|| { @@ -856,7 +868,7 @@ impl Table { pub fn connector_op(&self) -> Result { match self { - Table::ConnectorTable(c) => Ok(c.connector_op()), + Table::ConnectorTable(c) | Table::LookupTable(c) => Ok(c.connector_op()), Table::MemoryTable { .. } => plan_err!("can't write to a memory table"), Table::TableFromQuery { .. } => todo!(), Table::PreviewSink { logical_plan: _ } => Ok(default_sink()), diff --git a/crates/arroyo-rpc/proto/api.proto b/crates/arroyo-rpc/proto/api.proto index 24c120c11..09c8621f5 100644 --- a/crates/arroyo-rpc/proto/api.proto +++ b/crates/arroyo-rpc/proto/api.proto @@ -70,6 +70,13 @@ message JoinOperator { optional uint64 ttl_micros = 6; } +message LookupJoinOperator { + string name = 1; + ArroyoSchema schema = 2; + ConnectorOp connector = 3; + +} + message WindowFunctionOperator { string name = 1; ArroyoSchema input_schema = 2; diff --git a/crates/arroyo-rpc/src/api_types/connections.rs b/crates/arroyo-rpc/src/api_types/connections.rs index 93ca72d60..3ce2a8f72 100644 --- a/crates/arroyo-rpc/src/api_types/connections.rs +++ b/crates/arroyo-rpc/src/api_types/connections.rs @@ -51,6 +51,7 @@ pub struct ConnectionProfilePost { pub enum ConnectionType { Source, Sink, + Lookup, } impl Display for ConnectionType { @@ -58,6 +59,7 @@ impl Display for ConnectionType { match self { ConnectionType::Source => write!(f, "SOURCE"), ConnectionType::Sink => write!(f, "SINK"), + ConnectionType::Lookup => write!(f, "LOOKUP"), } } } diff --git a/crates/arroyo-worker/src/arrow/lookup_join.rs b/crates/arroyo-worker/src/arrow/lookup_join.rs new file mode 100644 index 000000000..c11927e22 --- /dev/null +++ b/crates/arroyo-worker/src/arrow/lookup_join.rs @@ -0,0 +1,31 @@ +use arrow_array::RecordBatch; +use async_trait::async_trait; +use datafusion::physical_expr::PhysicalExpr; +use arroyo_connectors::LookupConnector; +use arroyo_operator::context::{Collector, OperatorContext}; +use arroyo_operator::operator::ArrowOperator; + + +pub struct LookupJoin { + connector: Box, + key_exprs: Vec, +} + +#[async_trait] +impl ArrowOperator for LookupJoin { + fn name(&self) -> String { + format!("LookupJoin<{}>", self.connector.name()) + } + + async fn process_batch(&mut self, batch: RecordBatch, ctx: &mut OperatorContext, collector: &mut dyn Collector) { + let keys = self.key_exprs.iter() + .map(|expr| expr.evaluate(&batch).unwrap().into_array().unwrap()) + .collect::>(); + + + + for i in 0..keys.num_rows() { + + } + } +} \ No newline at end of file diff --git a/crates/arroyo-worker/src/arrow/mod.rs b/crates/arroyo-worker/src/arrow/mod.rs index 70353877a..56ff81e7b 100644 --- a/crates/arroyo-worker/src/arrow/mod.rs +++ b/crates/arroyo-worker/src/arrow/mod.rs @@ -35,6 +35,7 @@ pub mod tumbling_aggregating_window; pub mod updating_aggregator; pub mod watermark_generator; pub mod window_fn; +mod lookup_join; pub struct ValueExecutionOperator { name: String, diff --git a/crates/arroyo-worker/src/engine.rs b/crates/arroyo-worker/src/engine.rs index 879bab37f..ab01a3240 100644 --- a/crates/arroyo-worker/src/engine.rs +++ b/crates/arroyo-worker/src/engine.rs @@ -874,6 +874,7 @@ pub fn construct_operator( OperatorName::ExpressionWatermark => Box::new(WatermarkGeneratorConstructor), OperatorName::Join => Box::new(JoinWithExpirationConstructor), OperatorName::InstantJoin => Box::new(InstantJoinConstructor), + OperatorName::LookupJoin => todo!(), OperatorName::WindowFunction => 
Box::new(WindowFunctionConstructor), OperatorName::ConnectorSource | OperatorName::ConnectorSink => { let op: api::ConnectorOp = prost::Message::decode(config).unwrap(); From 56ee3b8d0c6360f0555d2b4f186afadaf3660d5e Mon Sep 17 00:00:00 2001 From: Micah Wylde Date: Thu, 26 Dec 2024 16:09:39 -0800 Subject: [PATCH 02/14] work --- crates/arroyo-connectors/src/redis/lookup.rs | 39 +++++---- crates/arroyo-planner/src/plan/join.rs | 4 +- .../src/test/queries/lookup_join.sql | 39 +++++++++ crates/arroyo-worker/src/arrow/lookup_join.rs | 86 ++++++++++++++++--- 4 files changed, 140 insertions(+), 28 deletions(-) create mode 100644 crates/arroyo-planner/src/test/queries/lookup_join.sql diff --git a/crates/arroyo-connectors/src/redis/lookup.rs b/crates/arroyo-connectors/src/redis/lookup.rs index 2af114259..24e962a21 100644 --- a/crates/arroyo-connectors/src/redis/lookup.rs +++ b/crates/arroyo-connectors/src/redis/lookup.rs @@ -1,14 +1,15 @@ use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -use arrow::array::{ArrayRef, AsArray, RecordBatch}; +use arrow::array::{ArrayRef, AsArray, OffsetSizeTrait, RecordBatch}; use arrow::compute::StringArrayType; use arrow::datatypes::DataType; use async_trait::async_trait; use futures::future::OptionFuture; use futures::stream::FuturesOrdered; use futures::StreamExt; -use redis::{AsyncCommands, RedisFuture, RedisResult}; +use redis::{cmd, AsyncCommands, Pipeline, RedisFuture, RedisResult, Value}; +use redis::aio::ConnectionLike; use arroyo_formats::de::ArrowDeserializer; use crate::LookupConnector; use crate::redis::{RedisClient, RedisConnector}; @@ -50,26 +51,32 @@ impl LookupConnector for RedisLookup { assert_eq!(keys.len(), 1, "redis lookup can only have a single key"); assert_eq!(*keys[0].data_type(), DataType::Utf8, "redis lookup key must be a string"); - let key = keys[0].as_string(); let connection = self.connection.as_mut().unwrap(); - let result = connection.mget::<_, Vec>(&key.iter().filter_map(|k| k).collect::>()) - .await - .unwrap(); + let mut mget = cmd("mget"); - let mut result_iter = result.iter(); + for k in keys[0].as_string::() { + mget.arg(k.unwrap()); + } + + let Value::Array(vs) = connection.req_packed_command(&mget).await.unwrap() else { + panic!("value was not an array"); + }; - for k in key.iter() { - if k.is_some() { - self.deserializer.deserialize_slice()result_iter.next() + for v in vs { + match v { + Value::Nil => {} + Value::SimpleString(s) => { + self.deserializer.deserialize_slice() + } + v => { + panic!("unexpected type {:?}", v); + } } - } - - while let Some(t) = futures.next().await { - - }; + } + - Ok(()) + todo!() } } \ No newline at end of file diff --git a/crates/arroyo-planner/src/plan/join.rs b/crates/arroyo-planner/src/plan/join.rs index ae5d3c2e3..ba374bd96 100644 --- a/crates/arroyo-planner/src/plan/join.rs +++ b/crates/arroyo-planner/src/plan/join.rs @@ -17,6 +17,7 @@ use datafusion::logical_expr::{ use datafusion::prelude::coalesce; use std::sync::Arc; use crate::extension::lookup::{LookupJoin, LookupSource}; +use crate::schemas::add_timestamp_field; use crate::tables::ConnectorTable; pub(crate) struct JoinRewriter<'a> { @@ -233,6 +234,7 @@ fn has_lookup(plan: &LogicalPlan) -> Result { } fn maybe_plan_lookup_join(join: &Join) -> Result> { + println!("Planning lookup join"); if has_lookup(&join.left)? 
{ return plan_err!("lookup sources must be on the right side of an inner or left join"); } @@ -262,7 +264,7 @@ fn maybe_plan_lookup_join(join: &Join) -> Result> { Ok(Some(LogicalPlan::Extension(Extension { node: Arc::new(LookupJoin { input: (*join.left).clone(), - schema: join.schema.clone(), + schema: add_timestamp_field(join.schema.clone(), None)?, connector, on: join.on.clone(), filter: lookup.filter, diff --git a/crates/arroyo-planner/src/test/queries/lookup_join.sql b/crates/arroyo-planner/src/test/queries/lookup_join.sql new file mode 100644 index 000000000..76c9e84e5 --- /dev/null +++ b/crates/arroyo-planner/src/test/queries/lookup_join.sql @@ -0,0 +1,39 @@ +CREATE TABLE orders ( + order_id INT, + user_id INT, + product_id INT, + quantity INT, + order_timestamp TIMESTAMP +) with ( + connector = 'kafka', + bootstrap_servers = 'localhost:9092', + type = 'source', + topic = 'orders', + format = 'json' +); + +CREATE TEMPORARY TABLE products ( + product_id INT PRIMARY KEY, + product_name TEXT, + unit_price FLOAT, + category TEXT, + last_updated TIMESTAMP +) with ( + connector = 'redis', + format = 'json', + type = 'lookup', + address = 'redis://localhost:6379' +); + +SELECT + o.order_id, + o.user_id, + o.quantity, + o.order_timestamp, + p.product_name, + p.unit_price, + p.category, + (o.quantity * p.unit_price) as total_amount +FROM orders o + JOIN products p + ON o.product_id = p.product_id; \ No newline at end of file diff --git a/crates/arroyo-worker/src/arrow/lookup_join.rs b/crates/arroyo-worker/src/arrow/lookup_join.rs index c11927e22..92419176b 100644 --- a/crates/arroyo-worker/src/arrow/lookup_join.rs +++ b/crates/arroyo-worker/src/arrow/lookup_join.rs @@ -1,14 +1,25 @@ -use arrow_array::RecordBatch; +use std::collections::HashMap; +use std::sync::Arc; + +use arrow_array::{RecordBatch}; +use arrow::row::{OwnedRow, RowConverter}; use async_trait::async_trait; +use datafusion::common::DFSchemaRef; use datafusion::physical_expr::PhysicalExpr; + use arroyo_connectors::LookupConnector; use arroyo_operator::context::{Collector, OperatorContext}; use arroyo_operator::operator::ArrowOperator; +use arroyo_types::JoinType; - +/// A simple in-operator cache storing the entire “right side” row batch keyed by a string. 
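+/// Keys are computed from the configured key expressions, converted to the Arrow
+/// row format, and used both to deduplicate lookups within a batch and as cache
+/// keys; misses are fetched in bulk through the wrapped LookupConnector.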
pub struct LookupJoin { connector: Box, - key_exprs: Vec, + key_exprs: Vec>, + cache: HashMap, OwnedRow>, + key_row_converter: RowConverter, + result_row_converter: RowConverter, + join_type: JoinType, } #[async_trait] @@ -17,15 +28,68 @@ impl ArrowOperator for LookupJoin { format!("LookupJoin<{}>", self.connector.name()) } - async fn process_batch(&mut self, batch: RecordBatch, ctx: &mut OperatorContext, collector: &mut dyn Collector) { - let keys = self.key_exprs.iter() - .map(|expr| expr.evaluate(&batch).unwrap().into_array().unwrap()) - .collect::>(); + async fn process_batch( + &mut self, + batch: RecordBatch, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, + ) { + let num_rows = batch.num_rows(); - - - for i in 0..keys.num_rows() { + let key_arrays: Vec<_> = self + .key_exprs + .iter() + .map(|expr| { + expr.evaluate(&batch) + .unwrap() + .into_array(num_rows) + .unwrap() + }) + .collect(); + + let rows = self.key_row_converter.convert_columns(&key_arrays).unwrap(); + + let mut key_map: HashMap<_, Vec> = HashMap::new(); + for (i, row) in rows.iter().enumerate() { + key_map.entry(row.owned()).or_default().push(i); + } + + let mut uncached_keys = Vec::new(); + for k in key_map.keys() { + if !self.cache.contains_key(k.row().as_ref()) { + uncached_keys.push(k.clone()); + } + } + if !uncached_keys.is_empty() { + let cols = self.key_row_converter.convert_rows(uncached_keys.iter().map(|r| r.row())).unwrap(); + + let result_batch = self + .connector + .lookup(&cols) + .await; + + let result_rows = self.result_row_converter.convert_columns(result_batch.columns()) + .unwrap(); + + assert_eq!(result_rows.num_rows(), uncached_keys.len()); + + for (k, v) in uncached_keys.iter().zip(result_rows.iter()) { + self.cache.insert(k.as_ref().to_vec(), v.owned()); + } + } + + let mut output_rows = self.result_row_converter.empty_rows(batch.num_rows(), batch.num_rows() * 10); + + for row in rows.iter() { + output_rows.push(self.cache.get(row.data()).expect("row should be cached").row()); } + + let right_side = self.result_row_converter.convert_rows(output_rows.iter()).unwrap(); + let mut result = batch.columns().to_vec(); + result.extend(right_side); + + collector.collect(RecordBatch::try_new(ctx.out_schema.as_ref().unwrap().schema.clone(), result).unwrap()) + .await; } -} \ No newline at end of file +} From 448cc7020bf22f5b87d41a0cb45ff6a83321f70b Mon Sep 17 00:00:00 2001 From: Micah Wylde Date: Thu, 26 Dec 2024 16:43:32 -0800 Subject: [PATCH 03/14] move buffers into deserializer --- crates/arroyo-connectors/src/kafka/mod.rs | 9 +- crates/arroyo-connectors/src/lib.rs | 4 +- crates/arroyo-connectors/src/redis/lookup.rs | 41 ++----- crates/arroyo-formats/src/de.rs | 110 +++++++++++------- crates/arroyo-operator/src/context.rs | 65 ++--------- crates/arroyo-planner/src/extension/join.rs | 2 +- crates/arroyo-planner/src/extension/lookup.rs | 6 +- crates/arroyo-planner/src/extension/mod.rs | 2 + crates/arroyo-planner/src/plan/join.rs | 1 - crates/arroyo-rpc/proto/api.proto | 3 +- crates/arroyo-worker/src/arrow/lookup_join.rs | 13 ++- 11 files changed, 110 insertions(+), 146 deletions(-) diff --git a/crates/arroyo-connectors/src/kafka/mod.rs b/crates/arroyo-connectors/src/kafka/mod.rs index 9d4e402b4..acb653aec 100644 --- a/crates/arroyo-connectors/src/kafka/mod.rs +++ b/crates/arroyo-connectors/src/kafka/mod.rs @@ -645,10 +645,9 @@ impl KafkaTester { BadData::Fail {}, Arc::new(schema_resolver), ); - let mut builders = aschema.builders(); let mut error = deserializer - .deserialize_slice(&mut 
builders, &msg, SystemTime::now(), None) + .deserialize_slice(&msg, SystemTime::now(), None) .await .into_iter() .next(); @@ -667,10 +666,9 @@ impl KafkaTester { None, BadData::Fail {}, ); - let mut builders = aschema.builders(); let mut error = deserializer - .deserialize_slice(&mut builders, &msg, SystemTime::now(), None) + .deserialize_slice(&msg, SystemTime::now(), None) .await .into_iter() .next(); @@ -701,10 +699,9 @@ impl KafkaTester { let aschema: ArroyoSchema = schema.clone().into(); let mut deserializer = ArrowDeserializer::new(format.clone(), aschema.clone(), None, BadData::Fail {}); - let mut builders = aschema.builders(); let mut error = deserializer - .deserialize_slice(&mut builders, &msg, SystemTime::now(), None) + .deserialize_slice(&msg, SystemTime::now(), None) .await .into_iter() .next(); diff --git a/crates/arroyo-connectors/src/lib.rs b/crates/arroyo-connectors/src/lib.rs index 087342c9d..e4378bb48 100644 --- a/crates/arroyo-connectors/src/lib.rs +++ b/crates/arroyo-connectors/src/lib.rs @@ -5,7 +5,7 @@ use arroyo_rpc::api_types::connections::{ }; use arroyo_rpc::primitive_to_sql; use arroyo_rpc::var_str::VarStr; -use arroyo_types::string_to_map; +use arroyo_types::{string_to_map, SourceError}; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -70,7 +70,7 @@ pub struct EmptyConfig {} pub trait LookupConnector { fn name(&self) -> String; - async fn lookup(&mut self, keys: &[ArrayRef]) -> RecordBatch; + async fn lookup(&mut self, keys: &[ArrayRef]) -> Option>; } diff --git a/crates/arroyo-connectors/src/redis/lookup.rs b/crates/arroyo-connectors/src/redis/lookup.rs index 24e962a21..8d002aa0f 100644 --- a/crates/arroyo-connectors/src/redis/lookup.rs +++ b/crates/arroyo-connectors/src/redis/lookup.rs @@ -1,18 +1,13 @@ -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; -use arrow::array::{ArrayRef, AsArray, OffsetSizeTrait, RecordBatch}; -use arrow::compute::StringArrayType; +use std::time::SystemTime; +use arrow::array::{ArrayRef, AsArray, RecordBatch}; use arrow::datatypes::DataType; use async_trait::async_trait; -use futures::future::OptionFuture; -use futures::stream::FuturesOrdered; -use futures::StreamExt; -use redis::{cmd, AsyncCommands, Pipeline, RedisFuture, RedisResult, Value}; +use redis::{cmd, Value}; use redis::aio::ConnectionLike; use arroyo_formats::de::ArrowDeserializer; +use arroyo_types::SourceError; use crate::LookupConnector; -use crate::redis::{RedisClient, RedisConnector}; +use crate::redis::{RedisClient}; use crate::redis::sink::GeneralConnection; pub struct RedisLookup { @@ -21,29 +16,13 @@ pub struct RedisLookup { connection: Option, } -// pub enum RedisFutureOrNull<'a> { -// RedisFuture(RedisFuture<'a, String>), -// Null -// } -// -// impl <'a> Future for RedisFutureOrNull<'a> { -// type Output = RedisResult>; -// -// fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { -// match self { -// RedisFutureOrNull::RedisFuture(f) => (**f).poll().map(|t| Some(t)), -// RedisFutureOrNull::Null => Poll::Ready(None), -// } -// } -// } - #[async_trait] impl LookupConnector for RedisLookup { fn name(&self) -> String { "RedisLookup".to_string() } - async fn lookup(&mut self, keys: &[ArrayRef]) -> RecordBatch { + async fn lookup(&mut self, keys: &[ArrayRef]) -> Option> { if self.connection.is_none() { self.connection = Some(self.client.get_connection().await.unwrap()); } @@ -66,9 +45,11 @@ impl LookupConnector for RedisLookup { for v in vs { match v { - 
Value::Nil => {} + Value::Nil => { + todo!("handle missing values") + } Value::SimpleString(s) => { - self.deserializer.deserialize_slice() + self.deserializer.deserialize_slice(s.as_bytes(), SystemTime::now(), None).await; } v => { panic!("unexpected type {:?}", v); @@ -77,6 +58,6 @@ impl LookupConnector for RedisLookup { } - todo!() + self.deserializer.flush_buffer() } } \ No newline at end of file diff --git a/crates/arroyo-formats/src/de.rs b/crates/arroyo-formats/src/de.rs index 3938fe0fc..62a4a59b8 100644 --- a/crates/arroyo-formats/src/de.rs +++ b/crates/arroyo-formats/src/de.rs @@ -3,9 +3,7 @@ use crate::proto::schema::get_pool; use crate::{proto, should_flush}; use arrow::array::{Int32Builder, Int64Builder}; use arrow::compute::kernels; -use arrow_array::builder::{ - ArrayBuilder, GenericByteBuilder, StringBuilder, TimestampNanosecondBuilder, -}; +use arrow_array::builder::{make_builder, ArrayBuilder, GenericByteBuilder, StringBuilder, TimestampNanosecondBuilder}; use arrow_array::types::GenericBinaryType; use arrow_array::RecordBatch; use arroyo_rpc::df::ArroyoSchema; @@ -19,6 +17,7 @@ use serde_json::Value; use std::collections::HashMap; use std::sync::Arc; use std::time::{Instant, SystemTime}; +use arrow_schema::SchemaRef; use tokio::sync::Mutex; #[derive(Debug, Clone)] @@ -29,6 +28,45 @@ pub enum FieldValueType<'a> { // Extend with more types as needed } +struct ContextBuffer { + buffer: Vec>, + created: Instant, + schema: SchemaRef, +} + +impl ContextBuffer { + fn new(schema: SchemaRef) -> Self { + let buffer = schema + .fields + .iter() + .map(|f| make_builder(f.data_type(), 16)) + .collect(); + + Self { + buffer, + created: Instant::now(), + schema, + } + } + + pub fn size(&self) -> usize { + self.buffer[0].len() + } + + pub fn should_flush(&self) -> bool { + should_flush(self.size(), self.created) + } + + pub fn finish(&mut self) -> RecordBatch { + RecordBatch::try_new( + self.schema.clone(), + self.buffer.iter_mut().map(|a| a.finish()).collect(), + ) + .unwrap() + } +} + + pub struct FramingIterator<'a> { framing: Option>, buf: &'a [u8], @@ -92,6 +130,7 @@ pub struct ArrowDeserializer { proto_pool: DescriptorPool, schema_resolver: Arc, additional_fields_builder: Option>>, + buffer: ContextBuffer, } impl ArrowDeserializer { @@ -161,6 +200,7 @@ impl ArrowDeserializer { }), format: Arc::new(format), framing: framing.map(Arc::new), + buffer: ContextBuffer::new(schema.schema.clone()), schema, schema_registry: Arc::new(Mutex::new(HashMap::new())), bad_data, @@ -174,25 +214,29 @@ impl ArrowDeserializer { pub async fn deserialize_slice( &mut self, - buffer: &mut [Box], msg: &[u8], timestamp: SystemTime, additional_fields: Option<&HashMap<&String, FieldValueType<'_>>>, ) -> Vec { match &*self.format { - Format::Avro(_) => self.deserialize_slice_avro(buffer, msg, timestamp).await, + Format::Avro(_) => self.deserialize_slice_avro(msg, timestamp).await, _ => FramingIterator::new(self.framing.clone(), msg) - .map(|t| self.deserialize_single(buffer, t, timestamp, additional_fields)) + .map(|t| self.deserialize_single(t, timestamp, additional_fields)) .filter_map(|t| t.err()) .collect(), } } pub fn should_flush(&self) -> bool { - should_flush(self.buffered_count, self.buffered_since) + self.buffer.should_flush() || + should_flush(self.buffered_count, self.buffered_since) } pub fn flush_buffer(&mut self) -> Option> { + if self.buffer.size() > 0 { + return Some(Ok(self.buffer.finish())); + } + let (decoder, timestamp) = self.json_decoder.as_mut()?; self.buffered_since = Instant::now(); 
self.buffered_count = 0; @@ -244,7 +288,6 @@ impl ArrowDeserializer { fn deserialize_single( &mut self, - buffer: &mut [Box], msg: &[u8], timestamp: SystemTime, additional_fields: Option<&HashMap<&String, FieldValueType>>, @@ -254,20 +297,20 @@ impl ArrowDeserializer { | Format::Json(JsonFormat { unstructured: true, .. }) => { - self.deserialize_raw_string(buffer, msg); - add_timestamp(buffer, self.schema.timestamp_index, timestamp); + self.deserialize_raw_string(msg); + add_timestamp(&mut self.buffer.buffer, self.schema.timestamp_index, timestamp); if let Some(fields) = additional_fields { for (k, v) in fields.iter() { - add_additional_fields(buffer, &self.schema, k, v); + add_additional_fields(&mut self.buffer.buffer, &self.schema, k, v); } } } Format::RawBytes(_) => { - self.deserialize_raw_bytes(buffer, msg); - add_timestamp(buffer, self.schema.timestamp_index, timestamp); + self.deserialize_raw_bytes(msg); + add_timestamp(&mut self.buffer.buffer, self.schema.timestamp_index, timestamp); if let Some(fields) = additional_fields { for (k, v) in fields.iter() { - add_additional_fields(buffer, &self.schema, k, v); + add_additional_fields(&mut self.buffer.buffer, &self.schema, k, v); } } } @@ -317,7 +360,7 @@ impl ArrowDeserializer { let json = proto::de::deserialize_proto(&mut self.proto_pool, proto, msg)?; if proto.into_unstructured_json { - self.decode_into_json(buffer, json, timestamp); + self.decode_into_json(json, timestamp); } else { let Some((decoder, timestamp_builder)) = &mut self.json_decoder else { panic!("json decoder not initialized"); @@ -345,7 +388,6 @@ impl ArrowDeserializer { fn decode_into_json( &mut self, - builders: &mut [Box], value: Value, timestamp: SystemTime, ) { @@ -354,19 +396,18 @@ impl ArrowDeserializer { .schema .column_with_name("value") .expect("no 'value' column for unstructured avro"); - let array = builders[idx] + let array = self.buffer.buffer[idx] .as_any_mut() .downcast_mut::() .expect("'value' column has incorrect type"); array.append_value(value.to_string()); - add_timestamp(builders, self.schema.timestamp_index, timestamp); + add_timestamp(&mut self.buffer.buffer, self.schema.timestamp_index, timestamp); self.buffered_count += 1; } pub async fn deserialize_slice_avro<'a>( &mut self, - builders: &mut [Box], msg: &'a [u8], timestamp: SystemTime, ) -> Vec { @@ -398,7 +439,7 @@ impl ArrowDeserializer { })?; if into_json { - self.decode_into_json(builders, de::avro_to_json(value), timestamp); + self.decode_into_json(de::avro_to_json(value), timestamp); } else { // for now round-trip through json in order to handle unsupported avro features // as that allows us to rely on raw json deserialization @@ -421,26 +462,26 @@ impl ArrowDeserializer { .collect() } - fn deserialize_raw_string(&mut self, buffer: &mut [Box], msg: &[u8]) { + fn deserialize_raw_string(&mut self, msg: &[u8]) { let (col, _) = self .schema .schema .column_with_name("value") .expect("no 'value' column for RawString format"); - buffer[col] + self.buffer.buffer[col] .as_any_mut() .downcast_mut::() .expect("'value' column has incorrect type") .append_value(String::from_utf8_lossy(msg)); } - fn deserialize_raw_bytes(&mut self, buffer: &mut [Box], msg: &[u8]) { + fn deserialize_raw_bytes(&mut self, msg: &[u8]) { let (col, _) = self .schema .schema .column_with_name("value") .expect("no 'value' column for RawBytes format"); - buffer[col] + self.buffer.buffer[col] .as_any_mut() .downcast_mut::>>() .expect("'value' column has incorrect type") @@ -651,7 +692,7 @@ mod tests { ); } - fn 
setup_deserializer(bad_data: BadData) -> (Vec>, ArrowDeserializer) { + fn setup_deserializer(bad_data: BadData) -> ArrowDeserializer { let schema = Arc::new(Schema::new(vec![ arrow_schema::Field::new("x", arrow_schema::DataType::Int64, true), arrow_schema::Field::new( @@ -661,12 +702,6 @@ mod tests { ), ])); - let arrays: Vec<_> = schema - .fields - .iter() - .map(|f| make_builder(f.data_type(), 16)) - .collect(); - let schema = ArroyoSchema::from_schema_unkeyed(schema).unwrap(); let deserializer = ArrowDeserializer::new( @@ -683,19 +718,18 @@ mod tests { bad_data, ); - (arrays, deserializer) + deserializer } #[tokio::test] async fn test_bad_data_drop() { - let (mut arrays, mut deserializer) = setup_deserializer(BadData::Drop {}); + let mut deserializer = setup_deserializer(BadData::Drop {}); let now = SystemTime::now(); assert_eq!( deserializer .deserialize_slice( - &mut arrays[..], json!({ "x": 5 }).to_string().as_bytes(), now, None, @@ -706,7 +740,6 @@ mod tests { assert_eq!( deserializer .deserialize_slice( - &mut arrays[..], json!({ "x": "hello" }).to_string().as_bytes(), now, None, @@ -728,12 +761,11 @@ mod tests { #[tokio::test] async fn test_bad_data_fail() { - let (mut arrays, mut deserializer) = setup_deserializer(BadData::Fail {}); + let mut deserializer = setup_deserializer(BadData::Fail {}); assert_eq!( deserializer .deserialize_slice( - &mut arrays[..], json!({ "x": 5 }).to_string().as_bytes(), SystemTime::now(), None, @@ -744,7 +776,6 @@ mod tests { assert_eq!( deserializer .deserialize_slice( - &mut arrays[..], json!({ "x": "hello" }).to_string().as_bytes(), SystemTime::now(), None, @@ -786,7 +817,7 @@ mod tests { let time = SystemTime::now(); let result = deserializer - .deserialize_slice(&mut arrays, &[0, 1, 2, 3, 4, 5], time, None) + .deserialize_slice(&[0, 1, 2, 3, 4, 5], time, None) .await; assert!(result.is_empty()); @@ -853,7 +884,6 @@ mod tests { let result = deserializer .deserialize_slice( - &mut arrays, json!({ "x": 5 }).to_string().as_bytes(), time, Some(&additional_fields), diff --git a/crates/arroyo-operator/src/context.rs b/crates/arroyo-operator/src/context.rs index 32fb3610f..38dd8c353 100644 --- a/crates/arroyo-operator/src/context.rs +++ b/crates/arroyo-operator/src/context.rs @@ -1,9 +1,8 @@ use crate::{server_for_hash_array, RateLimiter}; -use arrow::array::{make_builder, Array, ArrayBuilder, PrimitiveArray, RecordBatch}; +use arrow::array::{Array,PrimitiveArray, RecordBatch}; use arrow::compute::{partition, sort_to_indices, take}; -use arrow::datatypes::{SchemaRef, UInt64Type}; +use arrow::datatypes::{UInt64Type}; use arroyo_formats::de::{ArrowDeserializer, FieldValueType}; -use arroyo_formats::should_flush; use arroyo_metrics::{register_queue_gauge, QueueGauges, TaskCounters}; use arroyo_rpc::config::config; use arroyo_rpc::df::ArroyoSchema; @@ -23,7 +22,7 @@ use std::collections::HashMap; use std::mem::size_of_val; use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; use std::sync::Arc; -use std::time::{Instant, SystemTime}; +use std::time::{SystemTime}; use tokio::sync::mpsc::error::SendError; use tokio::sync::mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender}; use tokio::sync::Notify; @@ -205,44 +204,6 @@ pub fn batch_bounded(size: u32) -> (BatchSender, BatchReceiver) { ) } -struct ContextBuffer { - buffer: Vec>, - created: Instant, - schema: SchemaRef, -} - -impl ContextBuffer { - fn new(schema: SchemaRef) -> Self { - let buffer = schema - .fields - .iter() - .map(|f| make_builder(f.data_type(), 16)) - .collect(); 
- - Self { - buffer, - created: Instant::now(), - schema, - } - } - - pub fn size(&self) -> usize { - self.buffer[0].len() - } - - pub fn should_flush(&self) -> bool { - should_flush(self.size(), self.created) - } - - pub fn finish(&mut self) -> RecordBatch { - RecordBatch::try_new( - self.schema.clone(), - self.buffer.iter_mut().map(|a| a.finish()).collect(), - ) - .unwrap() - } -} - pub struct SourceContext { pub out_schema: ArroyoSchema, pub error_reporter: ErrorReporter, @@ -303,7 +264,6 @@ impl SourceContext { pub struct SourceCollector { deserializer: Option, - buffer: ContextBuffer, buffered_error: Option, error_rate_limiter: RateLimiter, pub out_schema: ArroyoSchema, @@ -322,7 +282,6 @@ impl SourceCollector { task_info: &Arc, ) -> Self { Self { - buffer: ContextBuffer::new(out_schema.schema.clone()), out_schema, collector, control_tx, @@ -373,12 +332,11 @@ impl SourceCollector { } pub fn should_flush(&self) -> bool { - self.buffer.should_flush() - || self - .deserializer - .as_ref() - .map(|d| d.should_flush()) - .unwrap_or(false) + self + .deserializer + .as_ref() + .map(|d| d.should_flush()) + .unwrap_or(false) } pub async fn deserialize_slice( @@ -393,7 +351,7 @@ impl SourceCollector { .expect("deserializer not initialized!"); let errors = deserializer - .deserialize_slice(&mut self.buffer.buffer, msg, time, additional_fields) + .deserialize_slice( msg, time, additional_fields) .await; self.collect_source_errors(errors).await?; @@ -443,11 +401,6 @@ impl SourceCollector { } pub async fn flush_buffer(&mut self) -> Result<(), UserError> { - if self.buffer.size() > 0 { - let batch = self.buffer.finish(); - self.collector.collect(batch).await; - } - if let Some(deserializer) = self.deserializer.as_mut() { if let Some(buffer) = deserializer.flush_buffer() { match buffer { diff --git a/crates/arroyo-planner/src/extension/join.rs b/crates/arroyo-planner/src/extension/join.rs index 5695bcc26..38dc27581 100644 --- a/crates/arroyo-planner/src/extension/join.rs +++ b/crates/arroyo-planner/src/extension/join.rs @@ -80,7 +80,7 @@ impl ArroyoExtension for JoinExtension { } fn output_schema(&self) -> ArroyoSchema { - ArroyoSchema::from_schema_unkeyed(Arc::new(self.schema().as_ref().clone().into())).unwrap() + ArroyoSchema::from_schema_unkeyed(self.schema().inner().clone()).unwrap() } } diff --git a/crates/arroyo-planner/src/extension/lookup.rs b/crates/arroyo-planner/src/extension/lookup.rs index 1b6dbc19f..a762c2499 100644 --- a/crates/arroyo-planner/src/extension/lookup.rs +++ b/crates/arroyo-planner/src/extension/lookup.rs @@ -67,15 +67,15 @@ multifield_partial_ord!(LookupJoin, input, connector, on, filter, alias); impl ArroyoExtension for LookupJoin { fn node_name(&self) -> Option { - todo!() + None } fn plan_node(&self, planner: &Planner, index: usize, input_schemas: Vec) -> datafusion::common::Result { - todo!() + let keys = } fn output_schema(&self) -> ArroyoSchema { - todo!() + ArroyoSchema::from_schema_unkeyed(self.schema.inner().clone()).unwrap() } } diff --git a/crates/arroyo-planner/src/extension/mod.rs b/crates/arroyo-planner/src/extension/mod.rs index 5b88f06bf..2fec17b0b 100644 --- a/crates/arroyo-planner/src/extension/mod.rs +++ b/crates/arroyo-planner/src/extension/mod.rs @@ -29,6 +29,7 @@ use crate::builder::{NamedNode, Planner}; use crate::schemas::{add_timestamp_field, has_timestamp_field}; use crate::{fields_with_qualifiers, schema_from_df_fields, DFField, ASYNC_RESULT_FIELD}; use join::JoinExtension; +use crate::extension::lookup::LookupJoin; pub(crate) mod aggregate; 
pub(crate) mod debezium; @@ -87,6 +88,7 @@ impl<'a> TryFrom<&'a dyn UserDefinedLogicalNode> for &'a dyn ArroyoExtension { .or_else(|_| try_from_t::(node)) .or_else(|_| try_from_t::(node)) .or_else(|_| try_from_t::(node)) + .or_else(|_| try_from_t::(node)) .map_err(|_| DataFusionError::Plan(format!("unexpected node: {}", node.name()))) } } diff --git a/crates/arroyo-planner/src/plan/join.rs b/crates/arroyo-planner/src/plan/join.rs index ba374bd96..af7b5e4fe 100644 --- a/crates/arroyo-planner/src/plan/join.rs +++ b/crates/arroyo-planner/src/plan/join.rs @@ -234,7 +234,6 @@ fn has_lookup(plan: &LogicalPlan) -> Result { } fn maybe_plan_lookup_join(join: &Join) -> Result> { - println!("Planning lookup join"); if has_lookup(&join.left)? { return plan_err!("lookup sources must be on the right side of an inner or left join"); } diff --git a/crates/arroyo-rpc/proto/api.proto b/crates/arroyo-rpc/proto/api.proto index 09c8621f5..1fa9aa63c 100644 --- a/crates/arroyo-rpc/proto/api.proto +++ b/crates/arroyo-rpc/proto/api.proto @@ -74,7 +74,8 @@ message LookupJoinOperator { string name = 1; ArroyoSchema schema = 2; ConnectorOp connector = 3; - + repeated bytes key_exprs = 4; + JoinType join_type = 5; } message WindowFunctionOperator { diff --git a/crates/arroyo-worker/src/arrow/lookup_join.rs b/crates/arroyo-worker/src/arrow/lookup_join.rs index 92419176b..07a4d598c 100644 --- a/crates/arroyo-worker/src/arrow/lookup_join.rs +++ b/crates/arroyo-worker/src/arrow/lookup_join.rs @@ -4,7 +4,6 @@ use std::sync::Arc; use arrow_array::{RecordBatch}; use arrow::row::{OwnedRow, RowConverter}; use async_trait::async_trait; -use datafusion::common::DFSchemaRef; use datafusion::physical_expr::PhysicalExpr; use arroyo_connectors::LookupConnector; @@ -69,13 +68,15 @@ impl ArrowOperator for LookupJoin { .lookup(&cols) .await; - let result_rows = self.result_row_converter.convert_columns(result_batch.columns()) - .unwrap(); + if let Some(result_batch) = result_batch { + let result_rows = self.result_row_converter.convert_columns(result_batch.unwrap().columns()) + .unwrap(); - assert_eq!(result_rows.num_rows(), uncached_keys.len()); + assert_eq!(result_rows.num_rows(), uncached_keys.len()); - for (k, v) in uncached_keys.iter().zip(result_rows.iter()) { - self.cache.insert(k.as_ref().to_vec(), v.owned()); + for (k, v) in uncached_keys.iter().zip(result_rows.iter()) { + self.cache.insert(k.as_ref().to_vec(), v.owned()); + } } } From eb5f9edcf5507f58ae8cff6a40554b3835dfd54e Mon Sep 17 00:00:00 2001 From: Micah Wylde Date: Mon, 30 Dec 2024 14:48:04 -0800 Subject: [PATCH 04/14] work on lookup joins --- .../src/filesystem/sink/mod.rs | 3 +- crates/arroyo-connectors/src/kafka/mod.rs | 12 +- crates/arroyo-connectors/src/lib.rs | 12 +- crates/arroyo-connectors/src/redis/lookup.rs | 45 ++++--- crates/arroyo-connectors/src/redis/mod.rs | 69 ++++++---- crates/arroyo-connectors/src/redis/sink.rs | 8 +- crates/arroyo-datastream/src/logical.rs | 30 +++-- crates/arroyo-formats/src/de.rs | 62 ++++----- crates/arroyo-operator/src/connector.rs | 43 ++++++- crates/arroyo-operator/src/context.rs | 31 +++-- crates/arroyo-operator/src/operator.rs | 2 +- crates/arroyo-planner/src/builder.rs | 2 +- crates/arroyo-planner/src/extension/lookup.rs | 102 ++++++++++++--- crates/arroyo-planner/src/extension/mod.rs | 4 +- crates/arroyo-planner/src/plan/join.rs | 81 ++++++++---- crates/arroyo-planner/src/rewriters.rs | 6 +- crates/arroyo-planner/src/schemas.rs | 2 +- crates/arroyo-planner/src/tables.rs | 14 +-- .../src/test/queries/lookup_join.sql | 4 +- 
crates/arroyo-rpc/proto/api.proto | 11 +- crates/arroyo-rpc/src/lib.rs | 11 ++ crates/arroyo-worker/src/arrow/lookup_join.rs | 119 ++++++++++++++---- crates/arroyo-worker/src/arrow/mod.rs | 2 +- .../src/arrow/tumbling_aggregating_window.rs | 2 +- crates/arroyo-worker/src/engine.rs | 21 ++-- 25 files changed, 472 insertions(+), 226 deletions(-) diff --git a/crates/arroyo-connectors/src/filesystem/sink/mod.rs b/crates/arroyo-connectors/src/filesystem/sink/mod.rs index 46521e056..caa68632c 100644 --- a/crates/arroyo-connectors/src/filesystem/sink/mod.rs +++ b/crates/arroyo-connectors/src/filesystem/sink/mod.rs @@ -1564,8 +1564,7 @@ impl TwoPhaseCommitter for FileSystemSink, ) -> Result<()> { - self.start(Arc::new(ctx.in_schemas.first().unwrap().clone())) - .await?; + self.start(Arc::new(ctx.in_schemas.first().unwrap().clone()))?; let mut max_file_index = 0; let mut recovered_files = Vec::new(); diff --git a/crates/arroyo-connectors/src/kafka/mod.rs b/crates/arroyo-connectors/src/kafka/mod.rs index acb653aec..fd095d56c 100644 --- a/crates/arroyo-connectors/src/kafka/mod.rs +++ b/crates/arroyo-connectors/src/kafka/mod.rs @@ -641,7 +641,7 @@ impl KafkaTester { let mut deserializer = ArrowDeserializer::with_schema_resolver( format.clone(), None, - aschema.clone(), + Arc::new(aschema), BadData::Fail {}, Arc::new(schema_resolver), ); @@ -662,7 +662,7 @@ impl KafkaTester { let aschema: ArroyoSchema = schema.clone().into(); let mut deserializer = ArrowDeserializer::new( format.clone(), - aschema.clone(), + Arc::new(aschema), None, BadData::Fail {}, ); @@ -697,8 +697,12 @@ impl KafkaTester { } Format::Protobuf(_) => { let aschema: ArroyoSchema = schema.clone().into(); - let mut deserializer = - ArrowDeserializer::new(format.clone(), aschema.clone(), None, BadData::Fail {}); + let mut deserializer = ArrowDeserializer::new( + format.clone(), + Arc::new(aschema), + None, + BadData::Fail {}, + ); let mut error = deserializer .deserialize_slice(&msg, SystemTime::now(), None) diff --git a/crates/arroyo-connectors/src/lib.rs b/crates/arroyo-connectors/src/lib.rs index e4378bb48..191d66627 100644 --- a/crates/arroyo-connectors/src/lib.rs +++ b/crates/arroyo-connectors/src/lib.rs @@ -1,4 +1,5 @@ use anyhow::{anyhow, bail, Context}; +use arrow::array::{ArrayRef, RecordBatch}; use arroyo_operator::connector::ErasedConnector; use arroyo_rpc::api_types::connections::{ ConnectionSchema, ConnectionType, FieldType, SourceField, SourceFieldType, TestSourceMessage, @@ -6,13 +7,12 @@ use arroyo_rpc::api_types::connections::{ use arroyo_rpc::primitive_to_sql; use arroyo_rpc::var_str::VarStr; use arroyo_types::{string_to_map, SourceError}; +use async_trait::async_trait; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use reqwest::Client; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::time::Duration; -use arrow::array::{ArrayRef, RecordBatch}; -use async_trait::async_trait; use tokio::sync::mpsc::Sender; use tracing::warn; @@ -66,14 +66,6 @@ pub fn connectors() -> HashMap<&'static str, Box> { #[derive(Serialize, Deserialize)] pub struct EmptyConfig {} -#[async_trait] -pub trait LookupConnector { - fn name(&self) -> String; - - async fn lookup(&mut self, keys: &[ArrayRef]) -> Option>; -} - - pub(crate) async fn send(tx: &mut Sender, msg: TestSourceMessage) { if tx.send(msg).await.is_err() { warn!("Test API rx closed while sending message"); diff --git a/crates/arroyo-connectors/src/redis/lookup.rs b/crates/arroyo-connectors/src/redis/lookup.rs index 8d002aa0f..106fc6108 100644 --- 
a/crates/arroyo-connectors/src/redis/lookup.rs +++ b/crates/arroyo-connectors/src/redis/lookup.rs @@ -1,19 +1,19 @@ -use std::time::SystemTime; +use crate::redis::sink::GeneralConnection; +use crate::redis::RedisClient; use arrow::array::{ArrayRef, AsArray, RecordBatch}; use arrow::datatypes::DataType; -use async_trait::async_trait; -use redis::{cmd, Value}; -use redis::aio::ConnectionLike; use arroyo_formats::de::ArrowDeserializer; +use arroyo_operator::connector::LookupConnector; use arroyo_types::SourceError; -use crate::LookupConnector; -use crate::redis::{RedisClient}; -use crate::redis::sink::GeneralConnection; +use async_trait::async_trait; +use redis::aio::ConnectionLike; +use redis::{cmd, Value}; +use std::time::SystemTime; pub struct RedisLookup { - deserializer: ArrowDeserializer, - client: RedisClient, - connection: Option, + pub(crate) deserializer: ArrowDeserializer, + pub(crate) client: RedisClient, + pub(crate) connection: Option, } #[async_trait] @@ -28,13 +28,16 @@ impl LookupConnector for RedisLookup { } assert_eq!(keys.len(), 1, "redis lookup can only have a single key"); - assert_eq!(*keys[0].data_type(), DataType::Utf8, "redis lookup key must be a string"); - + assert_eq!( + *keys[0].data_type(), + DataType::Utf8, + "redis lookup key must be a string" + ); let connection = self.connection.as_mut().unwrap(); - + let mut mget = cmd("mget"); - + for k in keys[0].as_string::() { mget.arg(k.unwrap()); } @@ -42,22 +45,24 @@ impl LookupConnector for RedisLookup { let Value::Array(vs) = connection.req_packed_command(&mget).await.unwrap() else { panic!("value was not an array"); }; - + for v in vs { match v { Value::Nil => { - todo!("handle missing values") + self.deserializer.deserialize_slice("null".as_bytes(), SystemTime::now(), None) + .await; } Value::SimpleString(s) => { - self.deserializer.deserialize_slice(s.as_bytes(), SystemTime::now(), None).await; + self.deserializer + .deserialize_slice(s.as_bytes(), SystemTime::now(), None) + .await; } v => { panic!("unexpected type {:?}", v); } } } - - + self.deserializer.flush_buffer() } -} \ No newline at end of file +} diff --git a/crates/arroyo-connectors/src/redis/mod.rs b/crates/arroyo-connectors/src/redis/mod.rs index 8a678bf98..9673ddb63 100644 --- a/crates/arroyo-connectors/src/redis/mod.rs +++ b/crates/arroyo-connectors/src/redis/mod.rs @@ -1,27 +1,30 @@ -pub mod sink; pub mod lookup; +pub mod sink; use anyhow::{anyhow, bail}; +use arroyo_formats::de::ArrowDeserializer; use arroyo_formats::ser::ArrowSerializer; -use arroyo_operator::connector::{Connection, Connector}; +use arroyo_operator::connector::{Connection, Connector, LookupConnector}; use arroyo_operator::operator::ConstructedOperator; +use arroyo_rpc::api_types::connections::{ + ConnectionProfile, ConnectionSchema, ConnectionType, FieldType, PrimitiveType, + TestSourceMessage, +}; +use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::var_str::VarStr; +use arroyo_rpc::OperatorConfig; use redis::aio::ConnectionManager; use redis::cluster::ClusterClient; use redis::{Client, ConnectionInfo, IntoConnectionInfo}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use std::sync::Arc; use tokio::sync::oneshot::Receiver; use typify::import_types; -use arroyo_rpc::api_types::connections::{ - ConnectionProfile, ConnectionSchema, ConnectionType, FieldType, PrimitiveType, - TestSourceMessage, -}; -use arroyo_rpc::OperatorConfig; - -use crate::{pull_opt, pull_option_to_u64}; +use crate::redis::lookup::RedisLookup; use crate::redis::sink::{GeneralConnection, 
RedisSinkFunc}; +use crate::{pull_opt, pull_option_to_u64}; pub struct RedisConnector {} @@ -290,11 +293,9 @@ impl Connector for RedisConnector { } let sink = match typ.as_str() { - "lookup" => { - TableType::Lookup { - lookup: Default::default(), - } - } + "lookup" => TableType::Lookup { + lookup: Default::default(), + }, "sink" => { let target = match pull_opt("target", options)?.as_str() { "string" => Target::StringTable { @@ -342,9 +343,9 @@ impl Connector for RedisConnector { bail!("'{}' is not a valid redis target", s); } }; - + TableType::Sink { target } - }, + } s => { bail!("'{}' is not a valid type; must be `sink`", s); } @@ -382,14 +383,10 @@ impl Connector for RedisConnector { let _ = RedisClient::new(&config)?; let (connection_type, description) = match &table.connector_type { - TableType::Sink { .. } => { - (ConnectionType::Sink, "RedisSink") - } - TableType::Lookup { .. } => { - (ConnectionType::Lookup, "RedisLookup") - } + TableType::Sink { .. } => (ConnectionType::Sink, "RedisSink"), + TableType::Lookup { .. } => (ConnectionType::Lookup, "RedisLookup"), }; - + let config = OperatorConfig { connection: serde_json::to_value(config).unwrap(), table: serde_json::to_value(table).unwrap(), @@ -399,7 +396,7 @@ impl Connector for RedisConnector { framing: schema.framing.clone(), metadata_fields: vec![], }; - + Ok(Connection { id, connector: self.name(), @@ -438,11 +435,31 @@ impl Connector for RedisConnector { hash_index: None, }, ))) - } TableType::Lookup { .. } => { - todo!() + bail!("Cannot construct a lookup table as an operator"); } } } + + fn make_lookup( + &self, + profile: Self::ProfileT, + table: Self::TableT, + config: OperatorConfig, + schema: Arc, + ) -> anyhow::Result> { + Ok(Box::new(RedisLookup { + deserializer: ArrowDeserializer::new( + config + .format + .ok_or_else(|| anyhow!("Redis table must have a format"))?, + schema, + config.framing, + config.bad_data.unwrap_or_default(), + ), + client: RedisClient::new(&profile)?, + connection: None, + })) + } } diff --git a/crates/arroyo-connectors/src/redis/sink.rs b/crates/arroyo-connectors/src/redis/sink.rs index b40f8c99a..6820c8b91 100644 --- a/crates/arroyo-connectors/src/redis/sink.rs +++ b/crates/arroyo-connectors/src/redis/sink.rs @@ -283,11 +283,9 @@ impl ArrowOperator for RedisSinkFunc { last_flushed: Instant::now(), max_push_keys: HashSet::new(), behavior: match &self.target { - Target::StringTable { ttl_secs, .. } => { - RedisBehavior::Set { - ttl: ttl_secs.map(|t| t.get() as usize), - } - } + Target::StringTable { ttl_secs, .. 
} => RedisBehavior::Set { + ttl: ttl_secs.map(|t| t.get() as usize), + }, Target::ListTable { max_length, operation, diff --git a/crates/arroyo-datastream/src/logical.rs b/crates/arroyo-datastream/src/logical.rs index 2f4f957a7..f393d00c0 100644 --- a/crates/arroyo-datastream/src/logical.rs +++ b/crates/arroyo-datastream/src/logical.rs @@ -134,16 +134,22 @@ impl TryFrom for PipelineGraph { #[derive(Clone, Debug, Eq, PartialEq)] pub struct LogicalEdge { pub edge_type: LogicalEdgeType, - pub schema: ArroyoSchema, + pub schema: Arc, } impl LogicalEdge { pub fn new(edge_type: LogicalEdgeType, schema: ArroyoSchema) -> Self { - LogicalEdge { edge_type, schema } + LogicalEdge { + edge_type, + schema: Arc::new(schema), + } } pub fn project_all(edge_type: LogicalEdgeType, schema: ArroyoSchema) -> Self { - LogicalEdge { edge_type, schema } + LogicalEdge { + edge_type, + schema: Arc::new(schema), + } } } @@ -157,7 +163,7 @@ pub struct ChainedLogicalOperator { #[derive(Clone, Debug)] pub struct OperatorChain { pub(crate) operators: Vec, - pub(crate) edges: Vec, + pub(crate) edges: Vec>, } impl OperatorChain { @@ -168,7 +174,9 @@ impl OperatorChain { } } - pub fn iter(&self) -> impl Iterator)> { + pub fn iter( + &self, + ) -> impl Iterator>)> { self.operators .iter() .zip_longest(self.edges.iter()) @@ -178,10 +186,10 @@ impl OperatorChain { pub fn iter_mut( &mut self, - ) -> impl Iterator)> { + ) -> impl Iterator>)> { self.operators .iter_mut() - .zip_longest(self.edges.iter_mut()) + .zip_longest(self.edges.iter()) .map(|e| e.left_and_right()) .map(|(l, r)| (l.unwrap(), r)) } @@ -438,7 +446,7 @@ impl TryFrom for LogicalProgram { edges: node .edges .into_iter() - .map(|e| Ok(e.try_into()?)) + .map(|e| Ok(Arc::new(e.try_into()?))) .collect::>>()?, }, parallelism: node.parallelism as usize, @@ -456,7 +464,7 @@ impl TryFrom for LogicalProgram { target, LogicalEdge { edge_type: edge.edge_type().into(), - schema: schema.clone().try_into()?, + schema: Arc::new(schema.clone().try_into()?), }, ); } @@ -623,7 +631,7 @@ impl From for ArrowProgram { .operator_chain .edges .iter() - .map(|edge| edge.clone().into()) + .map(|edge| (**edge).clone().into()) .collect(), } }) @@ -639,7 +647,7 @@ impl From for ArrowProgram { api::ArrowEdge { source: source.index() as i32, target: target.index() as i32, - schema: Some(edge.schema.clone().into()), + schema: Some((*edge.schema).clone().into()), edge_type: edge_type as i32, } }) diff --git a/crates/arroyo-formats/src/de.rs b/crates/arroyo-formats/src/de.rs index 62a4a59b8..052662d0d 100644 --- a/crates/arroyo-formats/src/de.rs +++ b/crates/arroyo-formats/src/de.rs @@ -3,9 +3,12 @@ use crate::proto::schema::get_pool; use crate::{proto, should_flush}; use arrow::array::{Int32Builder, Int64Builder}; use arrow::compute::kernels; -use arrow_array::builder::{make_builder, ArrayBuilder, GenericByteBuilder, StringBuilder, TimestampNanosecondBuilder}; +use arrow_array::builder::{ + make_builder, ArrayBuilder, GenericByteBuilder, StringBuilder, TimestampNanosecondBuilder, +}; use arrow_array::types::GenericBinaryType; use arrow_array::RecordBatch; +use arrow_schema::SchemaRef; use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::formats::{ AvroFormat, BadData, Format, Framing, FramingMethod, JsonFormat, ProtobufFormat, @@ -17,7 +20,6 @@ use serde_json::Value; use std::collections::HashMap; use std::sync::Arc; use std::time::{Instant, SystemTime}; -use arrow_schema::SchemaRef; use tokio::sync::Mutex; #[derive(Debug, Clone)] @@ -62,11 +64,10 @@ impl ContextBuffer { self.schema.clone(), 
self.buffer.iter_mut().map(|a| a.finish()).collect(), ) - .unwrap() + .unwrap() } } - pub struct FramingIterator<'a> { framing: Option>, buf: &'a [u8], @@ -121,7 +122,7 @@ impl<'a> Iterator for FramingIterator<'a> { pub struct ArrowDeserializer { format: Arc, framing: Option>, - schema: ArroyoSchema, + schema: Arc, bad_data: BadData, json_decoder: Option<(arrow::json::reader::Decoder, TimestampNanosecondBuilder)>, buffered_count: usize, @@ -136,7 +137,7 @@ pub struct ArrowDeserializer { impl ArrowDeserializer { pub fn new( format: Format, - schema: ArroyoSchema, + schema: Arc, framing: Option, bad_data: BadData, ) -> Self { @@ -157,7 +158,7 @@ impl ArrowDeserializer { pub fn with_schema_resolver( format: Format, framing: Option, - schema: ArroyoSchema, + schema: Arc, bad_data: BadData, schema_resolver: Arc, ) -> Self { @@ -228,15 +229,14 @@ impl ArrowDeserializer { } pub fn should_flush(&self) -> bool { - self.buffer.should_flush() || - should_flush(self.buffered_count, self.buffered_since) + self.buffer.should_flush() || should_flush(self.buffered_count, self.buffered_since) } pub fn flush_buffer(&mut self) -> Option> { if self.buffer.size() > 0 { return Some(Ok(self.buffer.finish())); } - + let (decoder, timestamp) = self.json_decoder.as_mut()?; self.buffered_since = Instant::now(); self.buffered_count = 0; @@ -298,7 +298,11 @@ impl ArrowDeserializer { unstructured: true, .. }) => { self.deserialize_raw_string(msg); - add_timestamp(&mut self.buffer.buffer, self.schema.timestamp_index, timestamp); + add_timestamp( + &mut self.buffer.buffer, + self.schema.timestamp_index, + timestamp, + ); if let Some(fields) = additional_fields { for (k, v) in fields.iter() { add_additional_fields(&mut self.buffer.buffer, &self.schema, k, v); @@ -307,7 +311,11 @@ impl ArrowDeserializer { } Format::RawBytes(_) => { self.deserialize_raw_bytes(msg); - add_timestamp(&mut self.buffer.buffer, self.schema.timestamp_index, timestamp); + add_timestamp( + &mut self.buffer.buffer, + self.schema.timestamp_index, + timestamp, + ); if let Some(fields) = additional_fields { for (k, v) in fields.iter() { add_additional_fields(&mut self.buffer.buffer, &self.schema, k, v); @@ -386,11 +394,7 @@ impl ArrowDeserializer { Ok(()) } - fn decode_into_json( - &mut self, - value: Value, - timestamp: SystemTime, - ) { + fn decode_into_json(&mut self, value: Value, timestamp: SystemTime) { let (idx, _) = self .schema .schema @@ -402,7 +406,11 @@ impl ArrowDeserializer { .expect("'value' column has incorrect type"); array.append_value(value.to_string()); - add_timestamp(&mut self.buffer.buffer, self.schema.timestamp_index, timestamp); + add_timestamp( + &mut self.buffer.buffer, + self.schema.timestamp_index, + timestamp, + ); self.buffered_count += 1; } @@ -702,7 +710,7 @@ mod tests { ), ])); - let schema = ArroyoSchema::from_schema_unkeyed(schema).unwrap(); + let schema = Arc::new(ArroyoSchema::from_schema_unkeyed(schema).unwrap()); let deserializer = ArrowDeserializer::new( Format::Json(JsonFormat { @@ -729,21 +737,13 @@ mod tests { assert_eq!( deserializer - .deserialize_slice( - json!({ "x": 5 }).to_string().as_bytes(), - now, - None, - ) + .deserialize_slice(json!({ "x": 5 }).to_string().as_bytes(), now, None,) .await, vec![] ); assert_eq!( deserializer - .deserialize_slice( - json!({ "x": "hello" }).to_string().as_bytes(), - now, - None, - ) + .deserialize_slice(json!({ "x": "hello" }).to_string().as_bytes(), now, None,) .await, vec![] ); @@ -806,7 +806,7 @@ mod tests { .map(|f| make_builder(f.data_type(), 16)) .collect(); - let 
arroyo_schema = ArroyoSchema::from_schema_unkeyed(schema.clone()).unwrap(); + let arroyo_schema = Arc::new(ArroyoSchema::from_schema_unkeyed(schema.clone()).unwrap()); let mut deserializer = ArrowDeserializer::new( Format::RawBytes(RawBytesFormat {}), @@ -858,7 +858,7 @@ mod tests { .map(|f| make_builder(f.data_type(), 16)) .collect(); - let arroyo_schema = ArroyoSchema::from_schema_unkeyed(schema.clone()).unwrap(); + let arroyo_schema = Arc::new(ArroyoSchema::from_schema_unkeyed(schema.clone()).unwrap()); let mut deserializer = ArrowDeserializer::new( Format::Json(JsonFormat { diff --git a/crates/arroyo-operator/src/connector.rs b/crates/arroyo-operator/src/connector.rs index d079879f5..51c0a38d2 100644 --- a/crates/arroyo-operator/src/connector.rs +++ b/crates/arroyo-operator/src/connector.rs @@ -1,15 +1,19 @@ use crate::operator::ConstructedOperator; use anyhow::{anyhow, bail}; +use arrow::array::{ArrayRef, RecordBatch}; use arrow::datatypes::{DataType, Field}; use arroyo_rpc::api_types::connections::{ ConnectionProfile, ConnectionSchema, ConnectionType, TestSourceMessage, }; +use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::OperatorConfig; -use arroyo_types::DisplayAsSql; +use arroyo_types::{DisplayAsSql, SourceError}; +use async_trait::async_trait; use serde::de::DeserializeOwned; use serde::ser::Serialize; use serde_json::value::Value; use std::collections::HashMap; +use std::sync::Arc; use tokio::sync::mpsc::Sender; use tokio::sync::oneshot; @@ -118,6 +122,17 @@ pub trait Connector: Send { table: Self::TableT, config: OperatorConfig, ) -> anyhow::Result; + + #[allow(unused)] + fn make_lookup( + &self, + profile: Self::ProfileT, + table: Self::TableT, + config: OperatorConfig, + schema: Arc, + ) -> anyhow::Result> { + bail!("{} is not a lookup connector", self.name()) + } } #[allow(clippy::type_complexity)] #[allow(clippy::wrong_self_convention)] @@ -187,6 +202,12 @@ pub trait ErasedConnector: Send { ) -> anyhow::Result; fn make_operator(&self, config: OperatorConfig) -> anyhow::Result; + + fn make_lookup( + &self, + config: OperatorConfig, + schema: Arc, + ) -> anyhow::Result>; } impl ErasedConnector for C { @@ -335,4 +356,24 @@ impl ErasedConnector for C { config, ) } + + fn make_lookup( + &self, + config: OperatorConfig, + schema: Arc, + ) -> anyhow::Result> { + self.make_lookup( + self.parse_config(&config.connection)?, + self.parse_table(&config.table)?, + config, + schema, + ) + } +} + +#[async_trait] +pub trait LookupConnector { + fn name(&self) -> String; + + async fn lookup(&mut self, keys: &[ArrayRef]) -> Option>; } diff --git a/crates/arroyo-operator/src/context.rs b/crates/arroyo-operator/src/context.rs index 38dd8c353..aeba7a2a5 100644 --- a/crates/arroyo-operator/src/context.rs +++ b/crates/arroyo-operator/src/context.rs @@ -1,7 +1,7 @@ use crate::{server_for_hash_array, RateLimiter}; -use arrow::array::{Array,PrimitiveArray, RecordBatch}; +use arrow::array::{Array, PrimitiveArray, RecordBatch}; use arrow::compute::{partition, sort_to_indices, take}; -use arrow::datatypes::{UInt64Type}; +use arrow::datatypes::UInt64Type; use arroyo_formats::de::{ArrowDeserializer, FieldValueType}; use arroyo_metrics::{register_queue_gauge, QueueGauges, TaskCounters}; use arroyo_rpc::config::config; @@ -22,7 +22,7 @@ use std::collections::HashMap; use std::mem::size_of_val; use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; use std::sync::Arc; -use std::time::{SystemTime}; +use std::time::SystemTime; use tokio::sync::mpsc::error::SendError; use 
tokio::sync::mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender}; use tokio::sync::Notify; @@ -205,7 +205,7 @@ pub fn batch_bounded(size: u32) -> (BatchSender, BatchReceiver) { } pub struct SourceContext { - pub out_schema: ArroyoSchema, + pub out_schema: Arc, pub error_reporter: ErrorReporter, pub control_tx: Sender, pub control_rx: Receiver, @@ -266,7 +266,7 @@ pub struct SourceCollector { deserializer: Option, buffered_error: Option, error_rate_limiter: RateLimiter, - pub out_schema: ArroyoSchema, + pub out_schema: Arc, pub(crate) collector: ArrowCollector, control_tx: Sender, chain_info: Arc, @@ -275,7 +275,7 @@ pub struct SourceCollector { impl SourceCollector { pub fn new( - out_schema: ArroyoSchema, + out_schema: Arc, collector: ArrowCollector, control_tx: Sender, chain_info: &Arc, @@ -332,8 +332,7 @@ impl SourceCollector { } pub fn should_flush(&self) -> bool { - self - .deserializer + self.deserializer .as_ref() .map(|d| d.should_flush()) .unwrap_or(false) @@ -351,7 +350,7 @@ impl SourceCollector { .expect("deserializer not initialized!"); let errors = deserializer - .deserialize_slice( msg, time, additional_fields) + .deserialize_slice(msg, time, additional_fields) .await; self.collect_source_errors(errors).await?; @@ -453,8 +452,8 @@ pub struct OperatorContext { pub task_info: Arc, pub control_tx: Sender, pub watermarks: WatermarkHolder, - pub in_schemas: Vec, - pub out_schema: Option, + pub in_schemas: Vec>, + pub out_schema: Option>, pub table_manager: TableManager, pub error_reporter: ErrorReporter, } @@ -489,7 +488,7 @@ pub trait Collector: Send { #[derive(Clone)] pub struct ArrowCollector { pub chain_info: Arc, - out_schema: Option, + out_schema: Option>, out_qs: Vec>, tx_queue_rem_gauges: QueueGauges, tx_queue_size_gauges: QueueGauges, @@ -607,7 +606,7 @@ impl Collector for ArrowCollector { impl ArrowCollector { pub fn new( chain_info: Arc, - out_schema: Option, + out_schema: Option>, out_qs: Vec>, ) -> Self { let tx_queue_size_gauges = register_queue_gauge( @@ -673,8 +672,8 @@ impl OperatorContext { restore_from: Option<&CheckpointMetadata>, control_tx: Sender, input_partitions: usize, - in_schemas: Vec, - out_schema: Option, + in_schemas: Vec>, + out_schema: Option>, tables: HashMap, ) -> Self { let (table_manager, watermark) = @@ -820,7 +819,7 @@ mod tests { let mut collector = ArrowCollector { chain_info, - out_schema: Some(ArroyoSchema::new_keyed(schema, 1, vec![0])), + out_schema: Some(Arc::new(ArroyoSchema::new_keyed(schema, 1, vec![0]))), out_qs, tx_queue_rem_gauges, tx_queue_size_gauges, diff --git a/crates/arroyo-operator/src/operator.rs b/crates/arroyo-operator/src/operator.rs index d34eaf4ff..70fc09bc0 100644 --- a/crates/arroyo-operator/src/operator.rs +++ b/crates/arroyo-operator/src/operator.rs @@ -226,7 +226,7 @@ impl OperatorNode { control_rx: Receiver, mut in_qs: Vec, out_qs: Vec>, - out_schema: Option, + out_schema: Option>, ready: Arc, ) { info!( diff --git a/crates/arroyo-planner/src/builder.rs b/crates/arroyo-planner/src/builder.rs index f11f30200..d19cf0468 100644 --- a/crates/arroyo-planner/src/builder.rs +++ b/crates/arroyo-planner/src/builder.rs @@ -172,7 +172,7 @@ impl<'a> Planner<'a> { let (partial_schema, timestamp_index) = if add_timestamp_field { ( - add_timestamp_field_arrow(partial_schema.clone()), + add_timestamp_field_arrow((*partial_schema).clone()), partial_schema.fields().len(), ) } else { diff --git a/crates/arroyo-planner/src/extension/lookup.rs b/crates/arroyo-planner/src/extension/lookup.rs index 
a762c2499..2d72a05f3 100644 --- a/crates/arroyo-planner/src/extension/lookup.rs +++ b/crates/arroyo-planner/src/extension/lookup.rs @@ -1,12 +1,19 @@ -use std::fmt::Formatter; -use datafusion::common::{internal_err, DFSchemaRef}; -use datafusion::logical_expr::{Expr, Join, LogicalPlan, UserDefinedLogicalNodeCore}; -use datafusion::sql::TableReference; -use arroyo_rpc::df::{ArroyoSchema, ArroyoSchemaRef}; use crate::builder::{NamedNode, Planner}; use crate::extension::{ArroyoExtension, NodeWithIncomingEdges}; use crate::multifield_partial_ord; use crate::tables::ConnectorTable; +use arroyo_datastream::logical::{LogicalEdge, LogicalEdgeType, LogicalNode, OperatorName}; +use arroyo_rpc::df::{ArroyoSchema, ArroyoSchemaRef}; +use arroyo_rpc::grpc::api::{ConnectorOp, LookupJoinCondition, LookupJoinOperator}; +use datafusion::common::{internal_err, plan_err, Column, DFSchemaRef, JoinType}; +use datafusion::logical_expr::{Expr, LogicalPlan, UserDefinedLogicalNodeCore}; +use datafusion::sql::TableReference; +use datafusion_proto::physical_plan::to_proto::serialize_physical_expr; +use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; +use prost::Message; +use std::fmt::Formatter; +use std::sync::Arc; +use crate::schemas::{add_timestamp_field_arrow}; pub const SOURCE_EXTENSION_NAME: &str = "LookupSource"; pub const JOIN_EXTENSION_NAME: &str = "LookupJoin"; @@ -40,12 +47,15 @@ impl UserDefinedLogicalNodeCore for LookupSource { write!(f, "LookupSource: {}", self.schema) } - fn with_exprs_and_inputs(&self, _exprs: Vec, inputs: Vec) -> datafusion::common::Result { + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + inputs: Vec, + ) -> datafusion::common::Result { if !inputs.is_empty() { return internal_err!("LookupSource cannot have inputs"); } - - + Ok(Self { table: self.table.clone(), schema: self.schema.clone(), @@ -57,10 +67,11 @@ impl UserDefinedLogicalNodeCore for LookupSource { pub struct LookupJoin { pub(crate) input: LogicalPlan, pub(crate) schema: DFSchemaRef, - pub(crate) connector: ConnectorTable, - pub(crate) on: Vec<(Expr, Expr)>, + pub(crate) connector: ConnectorTable, + pub(crate) on: Vec<(Expr, Column)>, pub(crate) filter: Option, pub(crate) alias: Option, + pub(crate) join_type: JoinType, } multifield_partial_ord!(LookupJoin, input, connector, on, filter, alias); @@ -70,8 +81,58 @@ impl ArroyoExtension for LookupJoin { None } - fn plan_node(&self, planner: &Planner, index: usize, input_schemas: Vec) -> datafusion::common::Result { - let keys = + fn plan_node( + &self, + planner: &Planner, + index: usize, + input_schemas: Vec, + ) -> datafusion::common::Result { + let schema = ArroyoSchema::from_schema_unkeyed(Arc::new(self.schema.as_ref().into()))?; + let lookup_schema = ArroyoSchema::from_schema_unkeyed( + add_timestamp_field_arrow(self.connector.physical_schema()))?; + let join_config = LookupJoinOperator { + input_schema: Some(schema.into()), + lookup_schema: Some(lookup_schema.into()), + connector: Some(ConnectorOp { + connector: self.connector.connector.clone(), + config: self.connector.config.clone(), + description: self.connector.description.clone(), + }), + key_exprs: self + .on + .iter() + .map(|(l, r)| { + let expr = planner.create_physical_expr(l, &self.schema)?; + let expr = serialize_physical_expr(&expr, &DefaultPhysicalExtensionCodec {})?; + Ok(LookupJoinCondition { + left_expr: expr.encode_to_vec(), + right_key: r.name.clone(), + }) + }) + .collect::>>()?, + join_type: match self.join_type { + JoinType::Inner => arroyo_rpc::grpc::api::JoinType::Inner as 
i32, + JoinType::Left => arroyo_rpc::grpc::api::JoinType::Left as i32, + j => { + return plan_err!("unsupported join type '{j}' for lookup join; only inner and left joins are supported"); + } + }, + }; + + let incoming_edge = + LogicalEdge::project_all(LogicalEdgeType::Shuffle, (*input_schemas[0]).clone()); + + Ok(NodeWithIncomingEdges { + node: LogicalNode::single( + index as u32, + format!("lookupjoin_{}", index), + OperatorName::LookupJoin, + join_config.encode_to_vec(), + format!("LookupJoin<{}>", self.connector.name), + 1, + ), + edges: vec![incoming_edge], + }) } fn output_schema(&self) -> ArroyoSchema { @@ -93,14 +154,12 @@ impl UserDefinedLogicalNodeCore for LookupJoin { } fn expressions(&self) -> Vec { - let mut e: Vec<_> = self.on.iter() - .flat_map(|(l, r)| vec![l.clone(), r.clone()]) - .collect(); - + let mut e: Vec<_> = self.on.iter().map(|(l, _)| l.clone()).collect(); + if let Some(filter) = &self.filter { e.push(filter.clone()); } - + e } @@ -108,7 +167,11 @@ impl UserDefinedLogicalNodeCore for LookupJoin { write!(f, "LookupJoinExtension: {}", self.schema) } - fn with_exprs_and_inputs(&self, _: Vec, inputs: Vec) -> datafusion::common::Result { + fn with_exprs_and_inputs( + &self, + _: Vec, + inputs: Vec, + ) -> datafusion::common::Result { Ok(Self { input: inputs[0].clone(), schema: self.schema.clone(), @@ -116,6 +179,7 @@ impl UserDefinedLogicalNodeCore for LookupJoin { on: self.on.clone(), filter: self.filter.clone(), alias: self.alias.clone(), + join_type: self.join_type, }) } -} \ No newline at end of file +} diff --git a/crates/arroyo-planner/src/extension/mod.rs b/crates/arroyo-planner/src/extension/mod.rs index 2fec17b0b..c6463d3bd 100644 --- a/crates/arroyo-planner/src/extension/mod.rs +++ b/crates/arroyo-planner/src/extension/mod.rs @@ -26,22 +26,22 @@ use self::{ window_fn::WindowFunctionExtension, }; use crate::builder::{NamedNode, Planner}; +use crate::extension::lookup::LookupJoin; use crate::schemas::{add_timestamp_field, has_timestamp_field}; use crate::{fields_with_qualifiers, schema_from_df_fields, DFField, ASYNC_RESULT_FIELD}; use join::JoinExtension; -use crate::extension::lookup::LookupJoin; pub(crate) mod aggregate; pub(crate) mod debezium; pub(crate) mod join; pub(crate) mod key_calculation; +pub(crate) mod lookup; pub(crate) mod remote_table; pub(crate) mod sink; pub(crate) mod table_source; pub(crate) mod updating_aggregate; pub(crate) mod watermark_node; pub(crate) mod window_fn; -pub(crate) mod lookup; pub(crate) trait ArroyoExtension: Debug { // if the extension has a name, return it so that we can memoize. 
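`plan_node` above ships each left-side key expression to the worker as bytes: the logical `Expr` is lowered to a `PhysicalExpr`, serialized with `serialize_physical_expr`, and stored in `LookupJoinCondition.left_expr`, while the lookup side travels as a bare column name in `right_key`. The worker-side `LookupJoinConstructor`, later in this series, reverses the encoding. Both halves appear in this patch; condensed here with placeholder bindings and error handling elided:

    // Plan side: PhysicalExpr -> protobuf bytes
    let expr = planner.create_physical_expr(&left_expr, &df_schema)?;
    let node = serialize_physical_expr(&expr, &DefaultPhysicalExtensionCodec {})?;
    let bytes = node.encode_to_vec();

    // Worker side: protobuf bytes -> PhysicalExpr bound to the input schema
    let decoded = PhysicalExprNode::decode(&mut bytes.as_slice())?;
    let expr = parse_physical_expr(
        &decoded,
        registry.as_ref(),
        &input_schema.schema,
        &DefaultPhysicalExtensionCodec {},
    )?;

Restricting the right side to plain columns (enforced in `plan/join.rs` below) is what makes one serialized expression per condition sufficient: the lookup key can be resolved by name against the connector's schema instead of being shipped as a second expression.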
diff --git a/crates/arroyo-planner/src/plan/join.rs b/crates/arroyo-planner/src/plan/join.rs index af7b5e4fe..b0a045432 100644 --- a/crates/arroyo-planner/src/plan/join.rs +++ b/crates/arroyo-planner/src/plan/join.rs @@ -1,10 +1,15 @@ use crate::extension::join::JoinExtension; use crate::extension::key_calculation::KeyCalculationExtension; +use crate::extension::lookup::{LookupJoin, LookupSource}; use crate::plan::WindowDetectingVisitor; +use crate::schemas::add_timestamp_field; +use crate::tables::ConnectorTable; use crate::{fields_with_qualifiers, schema_from_df_fields_with_metadata, ArroyoSchemaProvider}; use arroyo_datastream::WindowType; use arroyo_rpc::UPDATING_META_FIELD; -use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor}; +use datafusion::common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, +}; use datafusion::common::{ not_impl_err, plan_err, Column, DataFusionError, JoinConstraint, JoinType, Result, ScalarValue, TableReference, @@ -15,10 +20,8 @@ use datafusion::logical_expr::{ build_join_schema, BinaryExpr, Case, Expr, Extension, Join, LogicalPlan, Projection, }; use datafusion::prelude::coalesce; +use datafusion::sql::unparser::expr_to_sql; use std::sync::Arc; -use crate::extension::lookup::{LookupJoin, LookupSource}; -use crate::schemas::add_timestamp_field; -use crate::tables::ConnectorTable; pub(crate) struct JoinRewriter<'a> { pub schema_provider: &'a ArroyoSchemaProvider, @@ -81,7 +84,6 @@ impl JoinRewriter<'_> { } fn create_join_key_plan( - &self, input: Arc, join_expressions: Vec, name: &'static str, @@ -199,7 +201,7 @@ struct FindLookupExtension { alias: Option, } -impl <'a> TreeNodeVisitor<'a> for FindLookupExtension { +impl<'a> TreeNodeVisitor<'a> for FindLookupExtension { type Node = LogicalPlan; fn f_down(&mut self, node: &Self::Node) -> Result { @@ -212,7 +214,9 @@ impl <'a> TreeNodeVisitor<'a> for FindLookupExtension { } LogicalPlan::Filter(filter) => { if self.filter.replace(filter.predicate.clone()).is_some() { - return plan_err!("multiple filters found in lookup join, which is not supported"); + return plan_err!( + "multiple filters found in lookup join, which is not supported" + ); } } LogicalPlan::SubqueryAlias(s) => { @@ -227,48 +231,75 @@ impl <'a> TreeNodeVisitor<'a> for FindLookupExtension { } fn has_lookup(plan: &LogicalPlan) -> Result { - plan.exists(|p| Ok(match p { - LogicalPlan::Extension(e) => e.node.as_any().is::(), - _ => false - })) + plan.exists(|p| { + Ok(match p { + LogicalPlan::Extension(e) => e.node.as_any().is::(), + _ => false, + }) + }) } fn maybe_plan_lookup_join(join: &Join) -> Result> { if has_lookup(&join.left)? { return plan_err!("lookup sources must be on the right side of an inner or left join"); } - + if !has_lookup(&join.right)? 
{ return Ok(None); } - - println!("JOin = {:?} {:?}\n{:#?}", join.join_constraint, join.join_type, join.on); - + + println!( + "JOin = {:?} {:?}\n{:#?}", + join.join_constraint, join.join_type, join.on + ); + match join.join_type { JoinType::Inner | JoinType::Left => {} t => { - return plan_err!("{} join is not supported for lookup tables; must be a left or inner join", t); + return plan_err!( + "{} join is not supported for lookup tables; must be a left or inner join", + t + ); } } - + if join.filter.is_some() { return plan_err!("filter join conditions are not supported for lookup joins; must have an equality condition"); } + let on = join.on.iter().map(|(l, r)| { + match r { + Expr::Column(c) => Ok((l.clone(), c.clone())), + e => { + return plan_err!("invalid right-side condition for lookup join: `{}`; only column references are supported", + expr_to_sql(e).map(|e| e.to_string()).unwrap_or_else(|_| e.to_string())); + } + } + }).collect::>()?; + let mut lookup = FindLookupExtension::default(); join.right.visit(&mut lookup)?; - let connector = lookup.table.expect("right side of join does not have lookup"); + let connector = lookup + .table + .expect("right side of join does not have lookup"); + + let left_input = JoinRewriter::create_join_key_plan( + join.left.clone(), + join.on.iter().map(|(l, _)| l.clone()).collect(), + "left", + )?; Ok(Some(LogicalPlan::Extension(Extension { node: Arc::new(LookupJoin { - input: (*join.left).clone(), + input: left_input, schema: add_timestamp_field(join.schema.clone(), None)?, connector, - on: join.on.clone(), + on, filter: lookup.filter, alias: lookup.alias, - }) + join_type: join.join_type, + }), }))) } @@ -279,11 +310,11 @@ impl TreeNodeRewriter for JoinRewriter<'_> { let LogicalPlan::Join(join) = node else { return Ok(Transformed::no(node)); }; - + if let Some(plan) = maybe_plan_lookup_join(&join)? 
{ return Ok(Transformed::yes(plan)); } - + let is_instant = Self::check_join_windowing(&join)?; let Join { @@ -308,8 +339,8 @@ impl TreeNodeRewriter for JoinRewriter<'_> { let (left_expressions, right_expressions): (Vec<_>, Vec<_>) = on.clone().into_iter().unzip(); - let left_input = self.create_join_key_plan(left, left_expressions, "left")?; - let right_input = self.create_join_key_plan(right, right_expressions, "right")?; + let left_input = Self::create_join_key_plan(left, left_expressions, "left")?; + let right_input = Self::create_join_key_plan(right, right_expressions, "right")?; let rewritten_join = LogicalPlan::Join(Join { schema: Arc::new(build_join_schema( left_input.schema(), diff --git a/crates/arroyo-planner/src/rewriters.rs b/crates/arroyo-planner/src/rewriters.rs index 12cbfd21f..608f87eba 100644 --- a/crates/arroyo-planner/src/rewriters.rs +++ b/crates/arroyo-planner/src/rewriters.rs @@ -19,6 +19,7 @@ use arroyo_rpc::TIMESTAMP_FIELD; use arroyo_rpc::UPDATING_META_FIELD; use datafusion::logical_expr::UserDefinedLogicalNode; +use crate::extension::lookup::LookupSource; use crate::extension::AsyncUDFExtension; use arroyo_udf_host::parse::{AsyncOptions, UdfType}; use datafusion::common::tree_node::{ @@ -36,7 +37,6 @@ use std::collections::HashMap; use std::collections::HashSet; use std::sync::Arc; use std::time::Duration; -use crate::extension::lookup::LookupSource; /// Rewrites a logical plan to move projections out of table scans /// and into a separate projection node which may include virtual fields, @@ -219,13 +219,13 @@ impl SourceRewriter<'_> { fn mutate_lookup_table( &self, table_scan: &TableScan, - table: &ConnectorTable + table: &ConnectorTable, ) -> DFResult> { Ok(Transformed::yes(LogicalPlan::Extension(Extension { node: Arc::new(LookupSource { table: table.clone(), schema: table_scan.projected_schema.clone(), - }) + }), }))) } diff --git a/crates/arroyo-planner/src/schemas.rs b/crates/arroyo-planner/src/schemas.rs index 9d80453c4..3d0c53448 100644 --- a/crates/arroyo-planner/src/schemas.rs +++ b/crates/arroyo-planner/src/schemas.rs @@ -49,7 +49,7 @@ pub(crate) fn has_timestamp_field(schema: &DFSchemaRef) -> bool { .any(|field| field.name() == "_timestamp") } -pub fn add_timestamp_field_arrow(schema: SchemaRef) -> SchemaRef { +pub fn add_timestamp_field_arrow(schema: Schema) -> SchemaRef { let mut fields = schema.fields().to_vec(); fields.push(Arc::new(Field::new( "_timestamp", diff --git a/crates/arroyo-planner/src/tables.rs b/crates/arroyo-planner/src/tables.rs index 8bc133313..7b2638a51 100644 --- a/crates/arroyo-planner/src/tables.rs +++ b/crates/arroyo-planner/src/tables.rs @@ -759,15 +759,14 @@ impl Table { primary_keys, &mut with_map, connection_profile, - ).map_err(|e| e.context(format!("Failed to create table {}", name)))?; + ) + .map_err(|e| e.context(format!("Failed to create table {}", name)))?; Ok(Some(match table.connection_type { ConnectionType::Source | ConnectionType::Sink => { Table::ConnectorTable(table) } - ConnectionType::Lookup => { - Table::LookupTable(table) - } + ConnectionType::Lookup => Table::LookupTable(table), })) } } @@ -844,11 +843,12 @@ impl Table { fields, inferred_fields, .. - }) | Table::LookupTable(ConnectorTable { + }) + | Table::LookupTable(ConnectorTable { fields, inferred_fields, - .. - }) => inferred_fields + .. 
+ }) => inferred_fields .as_ref() .map(|fs| fs.iter().map(|f| f.field().clone()).collect()) .unwrap_or_else(|| { diff --git a/crates/arroyo-planner/src/test/queries/lookup_join.sql b/crates/arroyo-planner/src/test/queries/lookup_join.sql index 76c9e84e5..d0c12c31d 100644 --- a/crates/arroyo-planner/src/test/queries/lookup_join.sql +++ b/crates/arroyo-planner/src/test/queries/lookup_join.sql @@ -13,7 +13,7 @@ CREATE TABLE orders ( ); CREATE TEMPORARY TABLE products ( - product_id INT PRIMARY KEY, + key TEXT PRIMARY KEY, product_name TEXT, unit_price FLOAT, category TEXT, @@ -36,4 +36,4 @@ SELECT (o.quantity * p.unit_price) as total_amount FROM orders o JOIN products p - ON o.product_id = p.product_id; \ No newline at end of file + ON concat('blah', o.product_id) = p.key; \ No newline at end of file diff --git a/crates/arroyo-rpc/proto/api.proto b/crates/arroyo-rpc/proto/api.proto index 1fa9aa63c..06bacfa47 100644 --- a/crates/arroyo-rpc/proto/api.proto +++ b/crates/arroyo-rpc/proto/api.proto @@ -70,11 +70,16 @@ message JoinOperator { optional uint64 ttl_micros = 6; } +message LookupJoinCondition { + bytes left_expr = 1; + string right_key = 2; +} + message LookupJoinOperator { - string name = 1; - ArroyoSchema schema = 2; + ArroyoSchema input_schema = 1; + ArroyoSchema lookup_schema = 2; ConnectorOp connector = 3; - repeated bytes key_exprs = 4; + repeated LookupJoinCondition key_exprs = 4; JoinType join_type = 5; } diff --git a/crates/arroyo-rpc/src/lib.rs b/crates/arroyo-rpc/src/lib.rs index 5bdb01bf2..6b693d06c 100644 --- a/crates/arroyo-rpc/src/lib.rs +++ b/crates/arroyo-rpc/src/lib.rs @@ -37,6 +37,17 @@ pub mod grpc { pub mod api { #![allow(clippy::derive_partial_eq_without_eq, deprecated)] tonic::include_proto!("api"); + + impl From for arroyo_types::JoinType { + fn from(value: JoinType) -> Self { + match value { + JoinType::Inner => arroyo_types::JoinType::Inner, + JoinType::Left => arroyo_types::JoinType::Left, + JoinType::Right => arroyo_types::JoinType::Right, + JoinType::Full => arroyo_types::JoinType::Full, + } + } + } } pub const API_FILE_DESCRIPTOR_SET: &[u8] = diff --git a/crates/arroyo-worker/src/arrow/lookup_join.rs b/crates/arroyo-worker/src/arrow/lookup_join.rs index 07a4d598c..05e376c40 100644 --- a/crates/arroyo-worker/src/arrow/lookup_join.rs +++ b/crates/arroyo-worker/src/arrow/lookup_join.rs @@ -1,15 +1,23 @@ use std::collections::HashMap; use std::sync::Arc; -use arrow_array::{RecordBatch}; -use arrow::row::{OwnedRow, RowConverter}; -use async_trait::async_trait; -use datafusion::physical_expr::PhysicalExpr; - -use arroyo_connectors::LookupConnector; +use arrow::row::{OwnedRow, RowConverter, SortField}; +use arrow_array::RecordBatch; +use arroyo_connectors::{connectors}; +use arroyo_operator::connector::LookupConnector; use arroyo_operator::context::{Collector, OperatorContext}; -use arroyo_operator::operator::ArrowOperator; +use arroyo_operator::operator::{ + ArrowOperator, ConstructedOperator, OperatorConstructor, Registry, +}; +use arroyo_rpc::df::ArroyoSchema; +use arroyo_rpc::grpc::api; use arroyo_types::JoinType; +use async_trait::async_trait; +use datafusion::physical_expr::PhysicalExpr; +use datafusion_proto::physical_plan::from_proto::parse_physical_expr; +use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; +use datafusion_proto::protobuf::PhysicalExprNode; +use prost::Message; /// A simple in-operator cache storing the entire “right side” row batch keyed by a string. 
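The operator defined just below drives the `LookupConnector` trait that this commit moves into `arroyo-operator`. The trait's angle-bracketed types were garbled in the `connector.rs` hunk above; reconstructed from the Redis implementation, whose `lookup` returns `self.deserializer.flush_buffer()`, it reads:

    use arrow::array::{ArrayRef, RecordBatch};
    use arroyo_types::SourceError;
    use async_trait::async_trait;

    // One ArrayRef per key expression goes in; one deserialized row per
    // input key comes back, or None if nothing was buffered. The Redis
    // implementation encodes missing keys as JSON `null` rows, which is
    // what lets LEFT joins keep their unmatched input rows.
    #[async_trait]
    pub trait LookupConnector {
        fn name(&self) -> String;

        async fn lookup(&mut self, keys: &[ArrayRef])
            -> Option<Result<RecordBatch, SourceError>>;
    }
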
pub struct LookupJoin { @@ -38,12 +46,7 @@ impl ArrowOperator for LookupJoin { let key_arrays: Vec<_> = self .key_exprs .iter() - .map(|expr| { - expr.evaluate(&batch) - .unwrap() - .into_array(num_rows) - .unwrap() - }) + .map(|expr| expr.evaluate(&batch).unwrap().into_array(num_rows).unwrap()) .collect(); let rows = self.key_row_converter.convert_columns(&key_arrays).unwrap(); @@ -61,15 +64,17 @@ impl ArrowOperator for LookupJoin { } if !uncached_keys.is_empty() { - let cols = self.key_row_converter.convert_rows(uncached_keys.iter().map(|r| r.row())).unwrap(); + let cols = self + .key_row_converter + .convert_rows(uncached_keys.iter().map(|r| r.row())) + .unwrap(); - let result_batch = self - .connector - .lookup(&cols) - .await; + let result_batch = self.connector.lookup(&cols).await; if let Some(result_batch) = result_batch { - let result_rows = self.result_row_converter.convert_columns(result_batch.unwrap().columns()) + let result_rows = self + .result_row_converter + .convert_columns(result_batch.unwrap().columns()) .unwrap(); assert_eq!(result_rows.num_rows(), uncached_keys.len()); @@ -80,17 +85,81 @@ impl ArrowOperator for LookupJoin { } } - let mut output_rows = self.result_row_converter.empty_rows(batch.num_rows(), batch.num_rows() * 10); + let mut output_rows = self + .result_row_converter + .empty_rows(batch.num_rows(), batch.num_rows() * 10); for row in rows.iter() { - output_rows.push(self.cache.get(row.data()).expect("row should be cached").row()); + output_rows.push( + self.cache + .get(row.data()) + .expect("row should be cached") + .row(), + ); } - - let right_side = self.result_row_converter.convert_rows(output_rows.iter()).unwrap(); + + let right_side = self + .result_row_converter + .convert_rows(output_rows.iter()) + .unwrap(); let mut result = batch.columns().to_vec(); result.extend(right_side); - - collector.collect(RecordBatch::try_new(ctx.out_schema.as_ref().unwrap().schema.clone(), result).unwrap()) + + collector + .collect( + RecordBatch::try_new(ctx.out_schema.as_ref().unwrap().schema.clone(), result) + .unwrap(), + ) .await; } } + +pub struct LookupJoinConstructor; +impl OperatorConstructor for LookupJoinConstructor { + type ConfigT = api::LookupJoinOperator; + fn with_config( + &self, + config: Self::ConfigT, + registry: Arc, + ) -> anyhow::Result { + let join_type = config.join_type(); + let input_schema: ArroyoSchema = config.input_schema.unwrap().try_into()?; + let lookup_schema: ArroyoSchema = config.lookup_schema.unwrap().try_into()?; + + let exprs = config + .key_exprs + .iter() + .map(|e| { + let expr = PhysicalExprNode::decode(&mut e.left_expr.as_slice())?; + Ok(parse_physical_expr( + &expr, + registry.as_ref(), + &input_schema.schema, + &DefaultPhysicalExtensionCodec {}, + )?) 
+ }) + .collect::>>()?; + + let op = config.connector.unwrap(); + let operator_config = serde_json::from_str(&op.config)?; + + let result_row_converter = RowConverter::new(lookup_schema.schema.fields.iter().map(|f| + SortField::new(f.data_type().clone())).collect())?; + + let connector = connectors() + .get(op.connector.as_str()) + .unwrap_or_else(|| panic!("No connector with name '{}'", op.connector)) + .make_lookup(operator_config, Arc::new(lookup_schema))?; + + Ok(ConstructedOperator::from_operator(Box::new(LookupJoin { + connector, + cache: Default::default(), + key_row_converter: RowConverter::new(exprs.iter().map(|e| + Ok(SortField::new(e.data_type(&input_schema.schema)?))) + .collect::>()?)?, + key_exprs: exprs, + result_row_converter, + join_type: join_type.into(), + }))) + } +} diff --git a/crates/arroyo-worker/src/arrow/mod.rs b/crates/arroyo-worker/src/arrow/mod.rs index 56ff81e7b..437e75188 100644 --- a/crates/arroyo-worker/src/arrow/mod.rs +++ b/crates/arroyo-worker/src/arrow/mod.rs @@ -28,6 +28,7 @@ use std::sync::RwLock; pub mod async_udf; pub mod instant_join; pub mod join_with_expiration; +pub mod lookup_join; pub mod session_aggregating_window; pub mod sliding_aggregating_window; pub(crate) mod sync; @@ -35,7 +36,6 @@ pub mod tumbling_aggregating_window; pub mod updating_aggregator; pub mod watermark_generator; pub mod window_fn; -mod lookup_join; pub struct ValueExecutionOperator { name: String, diff --git a/crates/arroyo-worker/src/arrow/tumbling_aggregating_window.rs b/crates/arroyo-worker/src/arrow/tumbling_aggregating_window.rs index 8a2a89052..ffcad9981 100644 --- a/crates/arroyo-worker/src/arrow/tumbling_aggregating_window.rs +++ b/crates/arroyo-worker/src/arrow/tumbling_aggregating_window.rs @@ -181,7 +181,7 @@ impl OperatorConstructor for TumblingAggregateWindowConstructor { .transpose()?; let aggregate_with_timestamp_schema = - add_timestamp_field_arrow(finish_execution_plan.schema()); + add_timestamp_field_arrow((*finish_execution_plan.schema()).clone()); Ok(ConstructedOperator::from_operator(Box::new( TumblingAggregatingWindowFunc { diff --git a/crates/arroyo-worker/src/engine.rs b/crates/arroyo-worker/src/engine.rs index ab01a3240..da147c451 100644 --- a/crates/arroyo-worker/src/engine.rs +++ b/crates/arroyo-worker/src/engine.rs @@ -1,6 +1,7 @@ use crate::arrow::async_udf::AsyncUdfConstructor; use crate::arrow::instant_join::InstantJoinConstructor; use crate::arrow::join_with_expiration::JoinWithExpirationConstructor; +use crate::arrow::lookup_join::LookupJoinConstructor; use crate::arrow::session_aggregating_window::SessionAggregatingWindowConstructor; use crate::arrow::sliding_aggregating_window::SlidingAggregatingWindowConstructor; use crate::arrow::tumbling_aggregating_window::TumblingAggregateWindowConstructor; @@ -54,8 +55,8 @@ pub struct SubtaskNode { pub node_id: u32, pub subtask_idx: usize, pub parallelism: usize, - pub in_schemas: Vec, - pub out_schema: Option, + pub in_schemas: Vec>, + pub out_schema: Option>, pub node: OperatorNode, } @@ -97,7 +98,7 @@ pub struct PhysicalGraphEdge { edge_idx: usize, in_logical_idx: usize, out_logical_idx: usize, - schema: ArroyoSchema, + schema: Arc, edge: LogicalEdgeType, tx: Option, rx: Option, @@ -773,17 +774,19 @@ pub async fn construct_node( subtask_idx: u32, parallelism: u32, input_partitions: u32, - in_schemas: Vec, - out_schema: Option, + in_schemas: Vec>, + out_schema: Option>, restore_from: Option<&CheckpointMetadata>, control_tx: Sender, registry: Arc, ) -> OperatorNode { if chain.is_source() { let 
(head, _) = chain.iter().next().unwrap(); - let ConstructedOperator::Source(operator) = - construct_operator(head.operator_name, &head.operator_config, registry) - else { + let ConstructedOperator::Source(operator) = construct_operator( + head.operator_name, + &head.operator_config, + registry, + ) else { unreachable!(); }; @@ -874,7 +877,7 @@ pub fn construct_operator( OperatorName::ExpressionWatermark => Box::new(WatermarkGeneratorConstructor), OperatorName::Join => Box::new(JoinWithExpirationConstructor), OperatorName::InstantJoin => Box::new(InstantJoinConstructor), - OperatorName::LookupJoin => todo!(), + OperatorName::LookupJoin => Box::new(LookupJoinConstructor), OperatorName::WindowFunction => Box::new(WindowFunctionConstructor), OperatorName::ConnectorSource | OperatorName::ConnectorSink => { let op: api::ConnectorOp = prost::Message::decode(config).unwrap(); From cb312455ff22cd8865cf91b78185a9769a188e36 Mon Sep 17 00:00:00 2001 From: Micah Wylde Date: Mon, 30 Dec 2024 20:04:38 -0800 Subject: [PATCH 05/14] Refactor deserialization to reduce code duplication and fix gaps in additional field deser --- crates/arroyo-connectors/src/lib.rs | 4 +- crates/arroyo-connectors/src/redis/lookup.rs | 3 +- crates/arroyo-formats/src/avro/de.rs | 26 +- crates/arroyo-formats/src/de.rs | 573 ++++++++---------- crates/arroyo-planner/src/extension/lookup.rs | 7 +- crates/arroyo-rpc/src/lib.rs | 2 +- crates/arroyo-worker/src/arrow/lookup_join.rs | 30 +- crates/arroyo-worker/src/engine.rs | 8 +- 8 files changed, 302 insertions(+), 351 deletions(-) diff --git a/crates/arroyo-connectors/src/lib.rs b/crates/arroyo-connectors/src/lib.rs index 191d66627..bc4892269 100644 --- a/crates/arroyo-connectors/src/lib.rs +++ b/crates/arroyo-connectors/src/lib.rs @@ -1,13 +1,11 @@ use anyhow::{anyhow, bail, Context}; -use arrow::array::{ArrayRef, RecordBatch}; use arroyo_operator::connector::ErasedConnector; use arroyo_rpc::api_types::connections::{ ConnectionSchema, ConnectionType, FieldType, SourceField, SourceFieldType, TestSourceMessage, }; use arroyo_rpc::primitive_to_sql; use arroyo_rpc::var_str::VarStr; -use arroyo_types::{string_to_map, SourceError}; -use async_trait::async_trait; +use arroyo_types::string_to_map; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use reqwest::Client; use serde::{Deserialize, Serialize}; diff --git a/crates/arroyo-connectors/src/redis/lookup.rs b/crates/arroyo-connectors/src/redis/lookup.rs index 106fc6108..f14903c3e 100644 --- a/crates/arroyo-connectors/src/redis/lookup.rs +++ b/crates/arroyo-connectors/src/redis/lookup.rs @@ -49,7 +49,8 @@ impl LookupConnector for RedisLookup { for v in vs { match v { Value::Nil => { - self.deserializer.deserialize_slice("null".as_bytes(), SystemTime::now(), None) + self.deserializer + .deserialize_slice("null".as_bytes(), SystemTime::now(), None) .await; } Value::SimpleString(s) => { diff --git a/crates/arroyo-formats/src/avro/de.rs b/crates/arroyo-formats/src/avro/de.rs index f371bf70c..4c2bfd115 100644 --- a/crates/arroyo-formats/src/avro/de.rs +++ b/crates/arroyo-formats/src/avro/de.rs @@ -214,7 +214,7 @@ mod tests { fn deserializer_with_schema( format: AvroFormat, writer_schema: Option<&str>, - ) -> (ArrowDeserializer, Vec>, ArroyoSchema) { + ) -> (ArrowDeserializer, ArroyoSchema) { let arrow_schema = if format.into_unstructured_json { Schema::new(vec![Field::new("value", DataType::Utf8, false)]) } else { @@ -239,13 +239,6 @@ mod tests { ArroyoSchema::from_schema_keys(Arc::new(Schema::new(fields)), vec![]).unwrap() }; - let 
builders: Vec<_> = arroyo_schema - .schema - .fields - .iter() - .map(|f| make_builder(f.data_type(), 8)) - .collect(); - let resolver: Arc = if let Some(schema) = &writer_schema { Arc::new(FixedSchemaResolver::new( if format.confluent_schema_registry { @@ -263,11 +256,10 @@ mod tests { ArrowDeserializer::with_schema_resolver( Format::Avro(format), None, - arroyo_schema.clone(), + Arc::new(arroyo_schema.clone()), BadData::Fail {}, resolver, ), - builders, arroyo_schema, ) } @@ -277,23 +269,15 @@ mod tests { writer_schema: Option<&str>, message: &[u8], ) -> Vec> { - let (mut deserializer, mut builders, arroyo_schema) = + let (mut deserializer, arroyo_schema) = deserializer_with_schema(format.clone(), writer_schema); let errors = deserializer - .deserialize_slice(&mut builders, message, SystemTime::now(), None) + .deserialize_slice(message, SystemTime::now(), None) .await; assert_eq!(errors, vec![]); - let batch = if format.into_unstructured_json { - RecordBatch::try_new( - arroyo_schema.schema, - builders.into_iter().map(|mut b| b.finish()).collect(), - ) - .unwrap() - } else { - deserializer.flush_buffer().unwrap().unwrap() - }; + let batch = deserializer.flush_buffer().unwrap().unwrap(); record_batch_to_vec(&batch, true, arrow_json::writer::TimestampFormat::RFC3339) .unwrap() diff --git a/crates/arroyo-formats/src/de.rs b/crates/arroyo-formats/src/de.rs index 052662d0d..9eb02a366 100644 --- a/crates/arroyo-formats/src/de.rs +++ b/crates/arroyo-formats/src/de.rs @@ -7,7 +7,7 @@ use arrow_array::builder::{ make_builder, ArrayBuilder, GenericByteBuilder, StringBuilder, TimestampNanosecondBuilder, }; use arrow_array::types::GenericBinaryType; -use arrow_array::RecordBatch; +use arrow_array::{ArrayRef, BooleanArray, RecordBatch}; use arrow_schema::SchemaRef; use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::formats::{ @@ -33,7 +33,6 @@ pub enum FieldValueType<'a> { struct ContextBuffer { buffer: Vec>, created: Instant, - schema: SchemaRef, } impl ContextBuffer { @@ -47,7 +46,6 @@ impl ContextBuffer { Self { buffer, created: Instant::now(), - schema, } } @@ -59,12 +57,8 @@ impl ContextBuffer { should_flush(self.size(), self.created) } - pub fn finish(&mut self) -> RecordBatch { - RecordBatch::try_new( - self.schema.clone(), - self.buffer.iter_mut().map(|a| a.finish()).collect(), - ) - .unwrap() + pub fn finish(&mut self) -> Vec { + self.buffer.iter_mut().map(|a| a.finish()).collect() } } @@ -119,19 +113,111 @@ impl<'a> Iterator for FramingIterator<'a> { } } +enum BufferDecoder { + Buffer(ContextBuffer), + JsonDecoder { + decoder: arrow::json::reader::Decoder, + buffered_count: usize, + buffered_since: Instant, + }, +} + +impl BufferDecoder { + fn should_flush(&self) -> bool { + match self { + BufferDecoder::Buffer(b) => b.should_flush(), + BufferDecoder::JsonDecoder { + buffered_count, + buffered_since, + .. + } => should_flush(*buffered_count, *buffered_since), + } + } + + fn flush( + &mut self, + bad_data: &BadData, + ) -> Option, Option), SourceError>> { + match self { + BufferDecoder::Buffer(buffer) => { + if buffer.size() > 0 { + Some(Ok((buffer.finish(), None))) + } else { + None + } + } + BufferDecoder::JsonDecoder { + decoder, + buffered_since, + buffered_count, + } => { + *buffered_since = Instant::now(); + *buffered_count = 0; + Some(match bad_data { + BadData::Fail { .. } => decoder + .flush() + .map_err(|e| { + SourceError::bad_data(format!("JSON does not match schema: {:?}", e)) + }) + .transpose()? + .map(|batch| (batch.columns().to_vec(), None)), + BadData::Drop { .. 
+
+    fn decode_json(&mut self, msg: &[u8]) -> Result<(), SourceError> {
+        match self {
+            BufferDecoder::Buffer(_) => {
+                unreachable!("Tried to decode JSON for non-JSON deserializer");
+            }
+            BufferDecoder::JsonDecoder {
+                decoder,
+                buffered_count,
+                ..
+            } => {
+                decoder
+                    .decode(msg)
+                    .map_err(|e| SourceError::bad_data(format!("invalid JSON: {:?}", e)))?;
+
+                *buffered_count += 1;
+
+                Ok(())
+            }
+        }
+    }
+
+    fn get_buffer(&mut self) -> &mut ContextBuffer {
+        match self {
+            BufferDecoder::Buffer(buffer) => buffer,
+            BufferDecoder::JsonDecoder { .. } => {
+                panic!("tried to get a raw buffer from a JSON deserializer");
+            }
+        }
+    }
+}
+
 pub struct ArrowDeserializer {
     format: Arc<Format>,
     framing: Option<Arc<Framing>>,
     schema: Arc<ArroyoSchema>,
     bad_data: BadData,
-    json_decoder: Option<(arrow::json::reader::Decoder, TimestampNanosecondBuilder)>,
-    buffered_count: usize,
-    buffered_since: Instant,
     schema_registry: Arc>>,
     proto_pool: DescriptorPool,
     schema_resolver: Arc<dyn SchemaResolver>,
     additional_fields_builder: Option<HashMap<String, Box<dyn ArrayBuilder>>>,
-    buffer: ContextBuffer,
+    timestamp_builder: Option<TimestampNanosecondBuilder>,
+    buffer_decoder: BufferDecoder,
 }
 
 impl ArrowDeserializer {
@@ -172,43 +258,42 @@ impl ArrowDeserializer {
             DescriptorPool::global()
         };
 
+        let buffer_decoder = match format {
+            Format::Json(..)
+            | Format::Avro(AvroFormat {
+                into_unstructured_json: false,
+                ..
+            })
+            | Format::Protobuf(ProtobufFormat {
+                into_unstructured_json: false,
+                ..
+            }) => BufferDecoder::JsonDecoder {
+                decoder: arrow_json::reader::ReaderBuilder::new(Arc::new(
+                    schema.schema_without_timestamp(),
+                ))
+                .with_limit_to_batch_size(false)
+                .with_strict_mode(false)
+                .with_allow_bad_data(matches!(bad_data, BadData::Drop { ..
})) - .build_decoder() - .unwrap(), - TimestampNanosecondBuilder::new(), - ) - }), format: Arc::new(format), framing: framing.map(Arc::new), - buffer: ContextBuffer::new(schema.schema.clone()), + buffer_decoder, + timestamp_builder: Some(TimestampNanosecondBuilder::with_capacity(128)), schema, schema_registry: Arc::new(Mutex::new(HashMap::new())), bad_data, schema_resolver, proto_pool, - buffered_count: 0, - buffered_since: Instant::now(), additional_fields_builder: None, } } @@ -219,108 +304,120 @@ impl ArrowDeserializer { timestamp: SystemTime, additional_fields: Option<&HashMap<&String, FieldValueType<'_>>>, ) -> Vec { - match &*self.format { - Format::Avro(_) => self.deserialize_slice_avro(msg, timestamp).await, - _ => FramingIterator::new(self.framing.clone(), msg) - .map(|t| self.deserialize_single(t, timestamp, additional_fields)) - .filter_map(|t| t.err()) - .collect(), + self.deserialize_slice_int(msg, Some(timestamp), additional_fields) + .await + } + + async fn deserialize_slice_int( + &mut self, + msg: &[u8], + timestamp: Option, + additional_fields: Option<&HashMap<&String, FieldValueType<'_>>>, + ) -> Vec { + let (count, errors) = match &*self.format { + Format::Avro(_) => self.deserialize_slice_avro(msg).await, + _ => { + let mut count = 0; + let errors = FramingIterator::new(self.framing.clone(), msg) + .map(|t| self.deserialize_single(t)) + .filter_map(|t| { + if t.is_ok() { + count += 1; + } + t.err() + }) + .collect(); + (count, errors) + } + }; + + if let Some(timestamp) = timestamp { + let b = self + .timestamp_builder + .as_mut() + .expect("tried to serialize timestamp to a schema without a timestamp column"); + + for _ in 0..count { + b.append_value(to_nanos(timestamp) as i64); + } } + + if let Some(additional_fields) = additional_fields { + if self.additional_fields_builder.is_none() { + let mut builders = HashMap::new(); + for (key, value) in additional_fields.iter() { + let builder: Box = match value { + FieldValueType::Int32(_) => Box::new(Int32Builder::new()), + FieldValueType::Int64(_) => Box::new(Int64Builder::new()), + FieldValueType::String(_) => Box::new(StringBuilder::new()), + }; + builders.insert(key.to_string(), builder); + } + self.additional_fields_builder = Some(builders); + } + + let builders = self.additional_fields_builder.as_mut().unwrap(); + + for (k, v) in additional_fields { + add_additional_fields(builders, k, v, count); + } + } + + errors } pub fn should_flush(&self) -> bool { - self.buffer.should_flush() || should_flush(self.buffered_count, self.buffered_since) + self.buffer_decoder.should_flush() } pub fn flush_buffer(&mut self) -> Option> { - if self.buffer.size() > 0 { - return Some(Ok(self.buffer.finish())); - } + let (mut arrays, error_mask) = match self.buffer_decoder.flush(&self.bad_data)? { + Ok((a, b)) => (a, b), + Err(e) => return Some(Err(e)), + }; - let (decoder, timestamp) = self.json_decoder.as_mut()?; - self.buffered_since = Instant::now(); - self.buffered_count = 0; - match self.bad_data { - BadData::Fail { .. } => Some( - decoder - .flush() - .map_err(|e| { - SourceError::bad_data(format!("JSON does not match schema: {:?}", e)) - }) - .transpose()? - .map(|batch| { - let mut columns = batch.columns().to_vec(); - columns.insert(self.schema.timestamp_index, Arc::new(timestamp.finish())); - flush_additional_fields_builders( - &mut self.additional_fields_builder, - &self.schema, - &mut columns, - ); - RecordBatch::try_new(self.schema.schema.clone(), columns).unwrap() - }), - ), - BadData::Drop { .. 
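// NOTE: deserialize_slice_int (added above) returns a count of successfully decoded
// records alongside the errors; the timestamp builder and each additional-field
// builder are then appended `count` times in one pass, replacing the per-record
// bookkeeping that the removed code below interleaved with decoding.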
} => Some( - decoder - .flush_with_bad_data() - .map_err(|e| { - SourceError::bad_data(format!( - "Something went wrong decoding JSON: {:?}", - e - )) - }) - .transpose()? - .map(|(batch, mask, _)| { - let mut columns = batch.columns().to_vec(); - let timestamp = - kernels::filter::filter(×tamp.finish(), &mask).unwrap(); - - columns.insert(self.schema.timestamp_index, Arc::new(timestamp)); - flush_additional_fields_builders( - &mut self.additional_fields_builder, - &self.schema, - &mut columns, - ); - RecordBatch::try_new(self.schema.schema.clone(), columns).unwrap() - }), - ), + if let Some(additional_fields) = &mut self.additional_fields_builder { + for (name, builder) in additional_fields { + let (idx, _) = self + .schema + .schema + .column_with_name(&name) + .unwrap_or_else(|| panic!("Field '{}' not found in schema", name)); + + let mut array = builder.finish(); + if let Some(error_mask) = &error_mask { + array = kernels::filter::filter(&array, error_mask).unwrap(); + } + + arrays[idx] = array; + } + }; + + if let Some(timestamp) = &mut self.timestamp_builder { + let array = if let Some(error_mask) = &error_mask { + kernels::filter::filter(×tamp.finish(), error_mask).unwrap() + } else { + Arc::new(timestamp.finish()) + }; + + arrays.insert(self.schema.timestamp_index, array); } + + Some(Ok( + RecordBatch::try_new(self.schema.schema.clone(), arrays).unwrap() + )) } - fn deserialize_single( - &mut self, - msg: &[u8], - timestamp: SystemTime, - additional_fields: Option<&HashMap<&String, FieldValueType>>, - ) -> Result<(), SourceError> { + fn deserialize_single(&mut self, msg: &[u8]) -> Result<(), SourceError> { match &*self.format { Format::RawString(_) | Format::Json(JsonFormat { unstructured: true, .. }) => { self.deserialize_raw_string(msg); - add_timestamp( - &mut self.buffer.buffer, - self.schema.timestamp_index, - timestamp, - ); - if let Some(fields) = additional_fields { - for (k, v) in fields.iter() { - add_additional_fields(&mut self.buffer.buffer, &self.schema, k, v); - } - } } Format::RawBytes(_) => { self.deserialize_raw_bytes(msg); - add_timestamp( - &mut self.buffer.buffer, - self.schema.timestamp_index, - timestamp, - ); - if let Some(fields) = additional_fields { - for (k, v) in fields.iter() { - add_additional_fields(&mut self.buffer.buffer, &self.schema, k, v); - } - } } Format::Json(json) => { let msg = if json.confluent_schema_registry { @@ -329,62 +426,17 @@ impl ArrowDeserializer { msg }; - let Some((decoder, timestamp_builder)) = &mut self.json_decoder else { - panic!("json decoder not initialized"); - }; - - if self.additional_fields_builder.is_none() { - if let Some(fields) = additional_fields.as_ref() { - let mut builders = HashMap::new(); - for (key, value) in fields.iter() { - let builder: Box = match value { - FieldValueType::Int32(_) => Box::new(Int32Builder::new()), - FieldValueType::Int64(_) => Box::new(Int64Builder::new()), - FieldValueType::String(_) => Box::new(StringBuilder::new()), - }; - builders.insert(key, builder); - } - self.additional_fields_builder = Some( - builders - .into_iter() - .map(|(k, v)| ((*k).clone(), v)) - .collect(), - ); - } - } - - decoder - .decode(msg) - .map_err(|e| SourceError::bad_data(format!("invalid JSON: {:?}", e)))?; - timestamp_builder.append_value(to_nanos(timestamp) as i64); - - add_additional_fields_using_builder( - additional_fields, - &mut self.additional_fields_builder, - ); - self.buffered_count += 1; + self.buffer_decoder.decode_json(msg)?; } Format::Protobuf(proto) => { let json = 
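// NOTE: protobuf takes the same route as avro does below: the decoded message is
// rendered to JSON and pushed through the shared buffer_decoder, so schema handling
// lives in one place; only the into_unstructured_json variants write the raw value
// column directly.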
proto::de::deserialize_proto(&mut self.proto_pool, proto, msg)?; if proto.into_unstructured_json { - self.decode_into_json(json, timestamp); + self.decode_into_json(json); } else { - let Some((decoder, timestamp_builder)) = &mut self.json_decoder else { - panic!("json decoder not initialized"); - }; - - decoder - .decode(json.to_string().as_bytes()) + self.buffer_decoder + .decode_json(json.to_string().as_bytes()) .map_err(|e| SourceError::bad_data(format!("invalid JSON: {:?}", e)))?; - timestamp_builder.append_value(to_nanos(timestamp) as i64); - - add_additional_fields_using_builder( - additional_fields, - &mut self.additional_fields_builder, - ); - - self.buffered_count += 1; } } Format::Avro(_) => unreachable!("this should not be called for avro"), @@ -394,31 +446,21 @@ impl ArrowDeserializer { Ok(()) } - fn decode_into_json(&mut self, value: Value, timestamp: SystemTime) { + fn decode_into_json(&mut self, value: Value) { let (idx, _) = self .schema .schema .column_with_name("value") .expect("no 'value' column for unstructured avro"); - let array = self.buffer.buffer[idx] + let array = self.buffer_decoder.get_buffer().buffer[idx] .as_any_mut() .downcast_mut::() .expect("'value' column has incorrect type"); array.append_value(value.to_string()); - add_timestamp( - &mut self.buffer.buffer, - self.schema.timestamp_index, - timestamp, - ); - self.buffered_count += 1; } - pub async fn deserialize_slice_avro<'a>( - &mut self, - msg: &'a [u8], - timestamp: SystemTime, - ) -> Vec { + async fn deserialize_slice_avro(&mut self, msg: &[u8]) -> (usize, Vec) { let Format::Avro(format) = &*self.format else { unreachable!("not avro"); }; @@ -433,13 +475,14 @@ impl ArrowDeserializer { { Ok(messages) => messages, Err(e) => { - return vec![e]; + return (0, vec![e]); } }; let into_json = format.into_unstructured_json; - messages + let mut count = 0; + let errors = messages .into_iter() .map(|record| { let value = record.map_err(|e| { @@ -447,27 +490,25 @@ impl ArrowDeserializer { })?; if into_json { - self.decode_into_json(de::avro_to_json(value), timestamp); + self.decode_into_json(de::avro_to_json(value)); } else { // for now round-trip through json in order to handle unsupported avro features // as that allows us to rely on raw json deserialization let json = de::avro_to_json(value).to_string(); - let Some((decoder, timestamp_builder)) = &mut self.json_decoder else { - panic!("json decoder not initialized"); - }; - - decoder - .decode(json.as_bytes()) + self.buffer_decoder + .decode_json(json.as_bytes()) .map_err(|e| SourceError::bad_data(format!("invalid JSON: {:?}", e)))?; - self.buffered_count += 1; - timestamp_builder.append_value(to_nanos(timestamp) as i64); } + count += 1; + Ok(()) }) .filter_map(|r: Result<(), SourceError>| r.err()) - .collect() + .collect(); + + (count, errors) } fn deserialize_raw_string(&mut self, msg: &[u8]) { @@ -476,7 +517,7 @@ impl ArrowDeserializer { .schema .column_with_name("value") .expect("no 'value' column for RawString format"); - self.buffer.buffer[col] + self.buffer_decoder.get_buffer().buffer[col] .as_any_mut() .downcast_mut::() .expect("'value' column has incorrect type") @@ -489,7 +530,7 @@ impl ArrowDeserializer { .schema .column_with_name("value") .expect("no 'value' column for RawBytes format"); - self.buffer.buffer[col] + self.buffer_decoder.get_buffer().buffer[col] .as_any_mut() .downcast_mut::>>() .expect("'value' column has incorrect type") @@ -501,111 +542,42 @@ impl ArrowDeserializer { } } -pub(crate) fn add_timestamp( - builder: &mut [Box], - idx: 
usize, - timestamp: SystemTime, -) { - builder[idx] - .as_any_mut() - .downcast_mut::() - .expect("_timestamp column has incorrect type") - .append_value(to_nanos(timestamp) as i64); -} - -pub(crate) fn add_additional_fields( - builder: &mut [Box], - schema: &ArroyoSchema, +fn add_additional_fields( + builders: &mut HashMap>, key: &str, value: &FieldValueType<'_>, + count: usize, ) { - let (idx, _) = schema - .schema - .column_with_name(key) - .unwrap_or_else(|| panic!("no '{}' column for additional fields", key)); + let builder = builders + .get_mut(key) + .unwrap_or_else(|| panic!("unexpected additional field '{}'", key)) + .as_any_mut(); match value { FieldValueType::Int32(i) => { - builder[idx] - .as_any_mut() + let b = builder .downcast_mut::() - .expect("additional field has incorrect type") - .append_value(*i); + .expect("additional field has incorrect type"); + + for _ in 0..count { + b.append_value(*i); + } } FieldValueType::Int64(i) => { - builder[idx] - .as_any_mut() + let b = builder .downcast_mut::() - .expect("additional field has incorrect type") - .append_value(*i); - } - FieldValueType::String(s) => { - builder[idx] - .as_any_mut() - .downcast_mut::() - .expect("additional field has incorrect type") - .append_value(s); - } - } -} + .expect("additional field has incorrect type"); -pub(crate) fn add_additional_fields_using_builder( - additional_fields: Option<&HashMap<&String, FieldValueType<'_>>>, - additional_fields_builder: &mut Option>>, -) { - if let Some(fields) = additional_fields { - for (k, v) in fields.iter() { - if let Some(builder) = additional_fields_builder - .as_mut() - .and_then(|b| b.get_mut(*k)) - { - match v { - FieldValueType::Int32(i) => { - builder - .as_any_mut() - .downcast_mut::() - .expect("additional field has incorrect type") - .append_value(*i); - } - FieldValueType::Int64(i) => { - builder - .as_any_mut() - .downcast_mut::() - .expect("additional field has incorrect type") - .append_value(*i); - } - FieldValueType::String(s) => { - builder - .as_any_mut() - .downcast_mut::() - .expect("additional field has incorrect type") - .append_value(s); - } - } + for _ in 0..count { + b.append_value(*i); } } - } -} + FieldValueType::String(s) => { + let b = builder + .downcast_mut::() + .expect("additional field has incorrect type"); -pub(crate) fn flush_additional_fields_builders( - additional_fields_builder: &mut Option>>, - schema: &ArroyoSchema, - columns: &mut [Arc], -) { - if let Some(additional_fields) = additional_fields_builder.take() { - for (field_name, mut builder) in additional_fields { - if let Some((idx, _)) = schema.schema.column_with_name(&field_name) { - let expected_type = schema.schema.fields[idx].data_type(); - let built_column = builder.as_mut().finish(); - let actual_type = built_column.data_type(); - if expected_type != actual_type { - panic!( - "Type mismatch for column '{}': expected {:?}, got {:?}", - field_name, expected_type, actual_type - ); - } - columns[idx] = Arc::new(built_column); - } else { - panic!("Field '{}' not found in schema", field_name); + for _ in 0..count { + b.append_value(*s); } } } @@ -615,10 +587,8 @@ pub(crate) fn flush_additional_fields_builders( mod tests { use crate::de::{ArrowDeserializer, FieldValueType, FramingIterator}; use arrow::datatypes::Int32Type; - use arrow_array::builder::{make_builder, ArrayBuilder}; use arrow_array::cast::AsArray; use arrow_array::types::{GenericBinaryType, Int64Type, TimestampNanosecondType}; - use arrow_array::RecordBatch; use arrow_schema::{Schema, TimeUnit}; use 
arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::formats::{ @@ -800,12 +770,6 @@ mod tests { ), ])); - let mut arrays: Vec<_> = schema - .fields - .iter() - .map(|f| make_builder(f.data_type(), 16)) - .collect(); - let arroyo_schema = Arc::new(ArroyoSchema::from_schema_unkeyed(schema.clone()).unwrap()); let mut deserializer = ArrowDeserializer::new( @@ -821,8 +785,7 @@ mod tests { .await; assert!(result.is_empty()); - let arrays: Vec<_> = arrays.into_iter().map(|mut a| a.finish()).collect(); - let batch = RecordBatch::try_new(schema, arrays).unwrap(); + let batch = deserializer.flush_buffer().unwrap().unwrap(); assert_eq!(batch.num_rows(), 1); assert_eq!( @@ -852,12 +815,6 @@ mod tests { ), ])); - let mut arrays: Vec<_> = schema - .fields - .iter() - .map(|f| make_builder(f.data_type(), 16)) - .collect(); - let arroyo_schema = Arc::new(ArroyoSchema::from_schema_unkeyed(schema.clone()).unwrap()); let mut deserializer = ArrowDeserializer::new( diff --git a/crates/arroyo-planner/src/extension/lookup.rs b/crates/arroyo-planner/src/extension/lookup.rs index 2d72a05f3..210d19d38 100644 --- a/crates/arroyo-planner/src/extension/lookup.rs +++ b/crates/arroyo-planner/src/extension/lookup.rs @@ -1,6 +1,7 @@ use crate::builder::{NamedNode, Planner}; use crate::extension::{ArroyoExtension, NodeWithIncomingEdges}; use crate::multifield_partial_ord; +use crate::schemas::add_timestamp_field_arrow; use crate::tables::ConnectorTable; use arroyo_datastream::logical::{LogicalEdge, LogicalEdgeType, LogicalNode, OperatorName}; use arroyo_rpc::df::{ArroyoSchema, ArroyoSchemaRef}; @@ -13,7 +14,6 @@ use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; use prost::Message; use std::fmt::Formatter; use std::sync::Arc; -use crate::schemas::{add_timestamp_field_arrow}; pub const SOURCE_EXTENSION_NAME: &str = "LookupSource"; pub const JOIN_EXTENSION_NAME: &str = "LookupJoin"; @@ -88,8 +88,9 @@ impl ArroyoExtension for LookupJoin { input_schemas: Vec, ) -> datafusion::common::Result { let schema = ArroyoSchema::from_schema_unkeyed(Arc::new(self.schema.as_ref().into()))?; - let lookup_schema = ArroyoSchema::from_schema_unkeyed( - add_timestamp_field_arrow(self.connector.physical_schema()))?; + let lookup_schema = ArroyoSchema::from_schema_unkeyed(add_timestamp_field_arrow( + self.connector.physical_schema(), + ))?; let join_config = LookupJoinOperator { input_schema: Some(schema.into()), lookup_schema: Some(lookup_schema.into()), diff --git a/crates/arroyo-rpc/src/lib.rs b/crates/arroyo-rpc/src/lib.rs index 6b693d06c..49e150bfa 100644 --- a/crates/arroyo-rpc/src/lib.rs +++ b/crates/arroyo-rpc/src/lib.rs @@ -37,7 +37,7 @@ pub mod grpc { pub mod api { #![allow(clippy::derive_partial_eq_without_eq, deprecated)] tonic::include_proto!("api"); - + impl From for arroyo_types::JoinType { fn from(value: JoinType) -> Self { match value { diff --git a/crates/arroyo-worker/src/arrow/lookup_join.rs b/crates/arroyo-worker/src/arrow/lookup_join.rs index 05e376c40..1403e0eaf 100644 --- a/crates/arroyo-worker/src/arrow/lookup_join.rs +++ b/crates/arroyo-worker/src/arrow/lookup_join.rs @@ -2,8 +2,8 @@ use std::collections::HashMap; use std::sync::Arc; use arrow::row::{OwnedRow, RowConverter, SortField}; -use arrow_array::RecordBatch; -use arroyo_connectors::{connectors}; +use arrow_array::{Array, RecordBatch}; +use arroyo_connectors::connectors; use arroyo_operator::connector::LookupConnector; use arroyo_operator::context::{Collector, OperatorContext}; use arroyo_operator::operator::{ @@ -105,6 +105,9 @@ impl ArrowOperator for 
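// NOTE: the operator below keys an in-memory cache by the row-encoded join key:
// process_batch converts the key expressions with key_row_converter, sends only the
// uncached keys to the connector's lookup(), and rebuilds the right-hand columns from
// cached OwnedRows via result_row_converter before emitting the joined batch.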
LookupJoin { let mut result = batch.columns().to_vec(); result.extend(right_side); + println!("SCHEMA = {:?}", ctx.out_schema.as_ref().unwrap().schema); + println!("RESULT COLS = {:?}", result.iter().map(|s| s.data_type())); + collector .collect( RecordBatch::try_new(ctx.out_schema.as_ref().unwrap().schema.clone(), result) @@ -143,20 +146,29 @@ impl OperatorConstructor for LookupJoinConstructor { let op = config.connector.unwrap(); let operator_config = serde_json::from_str(&op.config)?; - let result_row_converter = RowConverter::new(lookup_schema.schema.fields.iter().map(|f| - SortField::new(f.data_type().clone())).collect())?; - + let result_row_converter = RowConverter::new( + lookup_schema + .schema + .fields + .iter() + .map(|f| SortField::new(f.data_type().clone())) + .collect(), + )?; + let connector = connectors() .get(op.connector.as_str()) .unwrap_or_else(|| panic!("No connector with name '{}'", op.connector)) .make_lookup(operator_config, Arc::new(lookup_schema))?; - + Ok(ConstructedOperator::from_operator(Box::new(LookupJoin { connector, cache: Default::default(), - key_row_converter: RowConverter::new(exprs.iter().map(|e| - Ok(SortField::new(e.data_type(&input_schema.schema)?))) - .collect::>()?)?, + key_row_converter: RowConverter::new( + exprs + .iter() + .map(|e| Ok(SortField::new(e.data_type(&input_schema.schema)?))) + .collect::>()?, + )?, key_exprs: exprs, result_row_converter, join_type: join_type.into(), diff --git a/crates/arroyo-worker/src/engine.rs b/crates/arroyo-worker/src/engine.rs index da147c451..c7ed5ae93 100644 --- a/crates/arroyo-worker/src/engine.rs +++ b/crates/arroyo-worker/src/engine.rs @@ -782,11 +782,9 @@ pub async fn construct_node( ) -> OperatorNode { if chain.is_source() { let (head, _) = chain.iter().next().unwrap(); - let ConstructedOperator::Source(operator) = construct_operator( - head.operator_name, - &head.operator_config, - registry, - ) else { + let ConstructedOperator::Source(operator) = + construct_operator(head.operator_name, &head.operator_config, registry) + else { unreachable!(); }; From 3ba8c4221bf8bb1eead39838541ad393ed0e9796 Mon Sep 17 00:00:00 2001 From: Micah Wylde Date: Thu, 9 Jan 2025 17:14:45 -0800 Subject: [PATCH 06/14] More progress on lookup joins --- .../src/filesystem/sink/local.rs | 2 +- .../src/filesystem/sink/mod.rs | 2 +- crates/arroyo-connectors/src/kafka/mod.rs | 1 + .../arroyo-connectors/src/kafka/source/mod.rs | 3 +- .../src/kafka/source/test.rs | 7 +- .../arroyo-connectors/src/mqtt/source/mod.rs | 2 +- crates/arroyo-connectors/src/redis/lookup.rs | 51 ++++-- crates/arroyo-connectors/src/redis/mod.rs | 24 ++- crates/arroyo-formats/src/de.rs | 160 +++++++++++++----- crates/arroyo-operator/src/connector.rs | 9 +- crates/arroyo-operator/src/context.rs | 6 +- .../arroyo-rpc/src/api_types/connections.rs | 1 + crates/arroyo-rpc/src/df.rs | 15 +- crates/arroyo-rpc/src/lib.rs | 2 + crates/arroyo-types/src/lib.rs | 2 + crates/arroyo-worker/src/arrow/lookup_join.rs | 37 ++-- 16 files changed, 241 insertions(+), 83 deletions(-) diff --git a/crates/arroyo-connectors/src/filesystem/sink/local.rs b/crates/arroyo-connectors/src/filesystem/sink/local.rs index a9a8ae07b..7aec35964 100644 --- a/crates/arroyo-connectors/src/filesystem/sink/local.rs +++ b/crates/arroyo-connectors/src/filesystem/sink/local.rs @@ -233,7 +233,7 @@ impl TwoPhaseCommitter for LocalFileSystemWrite let storage_provider = StorageProvider::for_url(&self.final_dir).await?; - let schema = Arc::new(ctx.in_schemas[0].clone()); + let schema = 
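// NOTE: ctx.in_schemas now appears to carry Arc<ArroyoSchema> values already (the
// kafka test in this patch constructs Some(Arc::new(ArroyoSchema::new_unkeyed(..)))),
// so the filesystem sinks clone the Arc here instead of re-wrapping the schema in
// Arc::new.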
ctx.in_schemas[0].clone(); self.commit_state = Some(match self.file_settings.commit_style.unwrap() { CommitStyle::DeltaLake => CommitState::DeltaLake { diff --git a/crates/arroyo-connectors/src/filesystem/sink/mod.rs b/crates/arroyo-connectors/src/filesystem/sink/mod.rs index caa68632c..20b318074 100644 --- a/crates/arroyo-connectors/src/filesystem/sink/mod.rs +++ b/crates/arroyo-connectors/src/filesystem/sink/mod.rs @@ -1564,7 +1564,7 @@ impl TwoPhaseCommitter for FileSystemSink, ) -> Result<()> { - self.start(Arc::new(ctx.in_schemas.first().unwrap().clone()))?; + self.start(ctx.in_schemas.first().unwrap().clone()).await?; let mut max_file_index = 0; let mut recovered_files = Vec::new(); diff --git a/crates/arroyo-connectors/src/kafka/mod.rs b/crates/arroyo-connectors/src/kafka/mod.rs index fd095d56c..515999627 100644 --- a/crates/arroyo-connectors/src/kafka/mod.rs +++ b/crates/arroyo-connectors/src/kafka/mod.rs @@ -642,6 +642,7 @@ impl KafkaTester { format.clone(), None, Arc::new(aschema), + &schema.metadata_fields(), BadData::Fail {}, Arc::new(schema_resolver), ); diff --git a/crates/arroyo-connectors/src/kafka/source/mod.rs b/crates/arroyo-connectors/src/kafka/source/mod.rs index 38c4a1cd2..23cb382e6 100644 --- a/crates/arroyo-connectors/src/kafka/source/mod.rs +++ b/crates/arroyo-connectors/src/kafka/source/mod.rs @@ -173,6 +173,7 @@ impl KafkaSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &self.metadata_fields, schema_resolver.clone(), ); } else { @@ -201,7 +202,7 @@ impl KafkaSourceFunc { let connector_metadata = if !self.metadata_fields.is_empty() { let mut connector_metadata = HashMap::new(); for f in &self.metadata_fields { - connector_metadata.insert(&f.field_name, match f.key.as_str() { + connector_metadata.insert(f.field_name.as_str(), match f.key.as_str() { "offset_id" => FieldValueType::Int64(msg.offset()), "partition" => FieldValueType::Int32(msg.partition()), "topic" => FieldValueType::String(topic), diff --git a/crates/arroyo-connectors/src/kafka/source/test.rs b/crates/arroyo-connectors/src/kafka/source/test.rs index 3dab460c3..b0a6d2a84 100644 --- a/crates/arroyo-connectors/src/kafka/source/test.rs +++ b/crates/arroyo-connectors/src/kafka/source/test.rs @@ -11,7 +11,7 @@ use std::collections::{HashMap, VecDeque}; use std::num::NonZeroU32; use std::sync::Arc; use std::time::{Duration, SystemTime}; - +use arrow::datatypes::DataType::UInt64; use crate::kafka::SourceOffset; use arroyo_operator::context::{ batch_bounded, ArrowCollector, BatchReceiver, OperatorContext, SourceCollector, SourceContext, @@ -389,6 +389,7 @@ async fn test_kafka_with_metadata_fields() { let metadata_fields = vec![MetadataField { field_name: "offset".to_string(), key: "offset_id".to_string(), + data_type: Some(UInt64), }]; // Set metadata fields in KafkaSourceFunc @@ -420,7 +421,7 @@ async fn test_kafka_with_metadata_fields() { command_tx.clone(), 1, vec![], - Some(ArroyoSchema::new_unkeyed( + Some(Arc::new(ArroyoSchema::new_unkeyed( Arc::new(Schema::new(vec![ Field::new( "_timestamp", @@ -431,7 +432,7 @@ async fn test_kafka_with_metadata_fields() { Field::new("offset", DataType::Int64, false), ])), 0, - )), + ))), kafka.tables(), ) .await; diff --git a/crates/arroyo-connectors/src/mqtt/source/mod.rs b/crates/arroyo-connectors/src/mqtt/source/mod.rs index 6c2d51577..2710682f7 100644 --- a/crates/arroyo-connectors/src/mqtt/source/mod.rs +++ b/crates/arroyo-connectors/src/mqtt/source/mod.rs @@ -152,7 +152,7 @@ impl MqttSourceFunc { let connector_metadata = if 
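// NOTE: connector metadata maps are now keyed by &str instead of &String; each source
// builds one small map per record, e.g. in the kafka source above:
//
//     connector_metadata.insert(f.field_name.as_str(), FieldValueType::Int64(msg.offset()));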
!self.metadata_fields.is_empty() { let mut connector_metadata = HashMap::new(); for mf in &self.metadata_fields { - connector_metadata.insert(&mf.field_name, match mf.key.as_str() { + connector_metadata.insert(mf.field_name.as_str(), match mf.key.as_str() { "topic" => FieldValueType::String(&topic), k => unreachable!("invalid metadata key '{}' for mqtt", k) }); diff --git a/crates/arroyo-connectors/src/redis/lookup.rs b/crates/arroyo-connectors/src/redis/lookup.rs index f14903c3e..a2907f54e 100644 --- a/crates/arroyo-connectors/src/redis/lookup.rs +++ b/crates/arroyo-connectors/src/redis/lookup.rs @@ -1,19 +1,21 @@ +use std::collections::HashMap; use crate::redis::sink::GeneralConnection; use crate::redis::RedisClient; -use arrow::array::{ArrayRef, AsArray, RecordBatch}; +use arrow::array::{Array, ArrayRef, AsArray, RecordBatch}; use arrow::datatypes::DataType; -use arroyo_formats::de::ArrowDeserializer; +use arroyo_formats::de::{ArrowDeserializer, FieldValueType}; use arroyo_operator::connector::LookupConnector; -use arroyo_types::SourceError; +use arroyo_types::{SourceError, LOOKUP_KEY_INDEX_FIELD}; use async_trait::async_trait; use redis::aio::ConnectionLike; use redis::{cmd, Value}; -use std::time::SystemTime; +use arroyo_rpc::MetadataField; pub struct RedisLookup { pub(crate) deserializer: ArrowDeserializer, pub(crate) client: RedisClient, pub(crate) connection: Option, + pub(crate) metadata_fields: Vec, } #[async_trait] @@ -38,29 +40,54 @@ impl LookupConnector for RedisLookup { let mut mget = cmd("mget"); - for k in keys[0].as_string::() { + let keys = keys[0].as_string::(); + + for k in keys { mget.arg(k.unwrap()); + println!("GETTTING {:?}", k); } let Value::Array(vs) = connection.req_packed_command(&mget).await.unwrap() else { panic!("value was not an array"); }; - for v in vs { - match v { + assert_eq!(vs.len(), keys.len(), "Redis sent back the wrong number of values"); + + let mut additional = HashMap::new(); + + for (idx, (v, k)) in vs.iter().zip(keys).enumerate() { + additional.insert(LOOKUP_KEY_INDEX_FIELD, FieldValueType::UInt64(idx as u64)); + for m in &self.metadata_fields { + additional.insert(m.field_name.as_str(), match m.key.as_str() { + "key" => FieldValueType::String(k.unwrap()), + k => unreachable!("Invalid metadata key '{}'", k) + }); + } + + let errors = match v { Value::Nil => { - self.deserializer - .deserialize_slice("null".as_bytes(), SystemTime::now(), None) - .await; + println!("GOt null"); + vec![] } Value::SimpleString(s) => { + println!("Got {:?}", s); + self.deserializer + .deserialize_without_timestamp(s.as_bytes(), Some(&additional)) + .await + } + Value::BulkString(v) => { + println!("Got {:?}", String::from_utf8(v.clone())); self.deserializer - .deserialize_slice(s.as_bytes(), SystemTime::now(), None) - .await; + .deserialize_without_timestamp(&v, Some(&additional)) + .await } v => { panic!("unexpected type {:?}", v); } + }; + + if !errors.is_empty() { + return Some(Err(errors.into_iter().next().unwrap())); } } diff --git a/crates/arroyo-connectors/src/redis/mod.rs b/crates/arroyo-connectors/src/redis/mod.rs index 9673ddb63..8d8057be8 100644 --- a/crates/arroyo-connectors/src/redis/mod.rs +++ b/crates/arroyo-connectors/src/redis/mod.rs @@ -4,7 +4,7 @@ pub mod sink; use anyhow::{anyhow, bail}; use arroyo_formats::de::ArrowDeserializer; use arroyo_formats::ser::ArrowSerializer; -use arroyo_operator::connector::{Connection, Connector, LookupConnector}; +use arroyo_operator::connector::{Connection, Connector, LookupConnector, MetadataDef}; use 
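// NOTE: RedisLookup (above) issues a single MGET for all keys in the batch and tags
// every deserialized row with LOOKUP_KEY_INDEX_FIELD, a synthetic UInt64 column that
// holds the position of the key that produced it, so the join operator can realign
// results with input rows even when some keys are missing from Redis.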
arroyo_operator::operator::ConstructedOperator; use arroyo_rpc::api_types::connections::{ ConnectionProfile, ConnectionSchema, ConnectionType, FieldType, PrimitiveType, @@ -19,9 +19,10 @@ use redis::{Client, ConnectionInfo, IntoConnectionInfo}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; +use arrow::datatypes::{DataType, Schema}; use tokio::sync::oneshot::Receiver; use typify::import_types; - +use arroyo_rpc::schema_resolver::FailingSchemaResolver; use crate::redis::lookup::RedisLookup; use crate::redis::sink::{GeneralConnection, RedisSinkFunc}; use crate::{pull_opt, pull_option_to_u64}; @@ -41,7 +42,7 @@ import_types!( import_types!(schema = "src/redis/table.json"); -enum RedisClient { +pub(crate) enum RedisClient { Standard(Client), Clustered(ClusterClient), } @@ -178,6 +179,13 @@ impl Connector for RedisConnector { } } + fn metadata_defs(&self) -> &'static [MetadataDef] { + &[MetadataDef { + name: "key", + data_type: DataType::Utf8, + }] + } + fn table_type(&self, _: Self::ProfileT, _: Self::TableT) -> ConnectionType { ConnectionType::Source } @@ -394,7 +402,7 @@ impl Connector for RedisConnector { format: Some(format), bad_data: schema.bad_data.clone(), framing: schema.framing.clone(), - metadata_fields: vec![], + metadata_fields: schema.metadata_fields(), }; Ok(Connection { @@ -447,19 +455,21 @@ impl Connector for RedisConnector { profile: Self::ProfileT, table: Self::TableT, config: OperatorConfig, - schema: Arc, + schema: Arc, ) -> anyhow::Result> { Ok(Box::new(RedisLookup { - deserializer: ArrowDeserializer::new( + deserializer: ArrowDeserializer::for_lookup( config .format .ok_or_else(|| anyhow!("Redis table must have a format"))?, schema, - config.framing, + &config.metadata_fields, config.bad_data.unwrap_or_default(), + Arc::new(FailingSchemaResolver::new()), ), client: RedisClient::new(&profile)?, connection: None, + metadata_fields: config.metadata_fields, })) } } diff --git a/crates/arroyo-formats/src/de.rs b/crates/arroyo-formats/src/de.rs index 9eb02a366..de681954c 100644 --- a/crates/arroyo-formats/src/de.rs +++ b/crates/arroyo-formats/src/de.rs @@ -3,28 +3,28 @@ use crate::proto::schema::get_pool; use crate::{proto, should_flush}; use arrow::array::{Int32Builder, Int64Builder}; use arrow::compute::kernels; -use arrow_array::builder::{ - make_builder, ArrayBuilder, GenericByteBuilder, StringBuilder, TimestampNanosecondBuilder, -}; +use arrow_array::builder::{make_builder, ArrayBuilder, GenericByteBuilder, StringBuilder, TimestampNanosecondBuilder, UInt64Builder}; use arrow_array::types::GenericBinaryType; use arrow_array::{ArrayRef, BooleanArray, RecordBatch}; -use arrow_schema::SchemaRef; +use arrow_schema::{DataType, Schema, SchemaRef}; use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::formats::{ AvroFormat, BadData, Format, Framing, FramingMethod, JsonFormat, ProtobufFormat, }; use arroyo_rpc::schema_resolver::{FailingSchemaResolver, FixedSchemaResolver, SchemaResolver}; -use arroyo_types::{to_nanos, SourceError}; +use arroyo_types::{to_nanos, SourceError, LOOKUP_KEY_INDEX_FIELD}; use prost_reflect::DescriptorPool; use serde_json::Value; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::time::{Instant, SystemTime}; use tokio::sync::Mutex; +use arroyo_rpc::MetadataField; #[derive(Debug, Clone)] pub enum FieldValueType<'a> { Int64(i64), + UInt64(u64), Int32(i32), String(&'a str), // Extend with more types as needed @@ -50,7 +50,7 @@ impl ContextBuffer { } pub fn size(&self) 
-> usize { - self.buffer[0].len() + self.buffer.iter().map(|b| b.len()).max().unwrap() } pub fn should_flush(&self) -> bool { @@ -210,13 +210,14 @@ impl BufferDecoder { pub struct ArrowDeserializer { format: Arc, framing: Option>, - schema: Arc, + final_schema: Arc, + decoder_schema: Arc, bad_data: BadData, schema_registry: Arc>>, proto_pool: DescriptorPool, schema_resolver: Arc, additional_fields_builder: Option>>, - timestamp_builder: Option, + timestamp_builder: Option<(usize, TimestampNanosecondBuilder)>, buffer_decoder: BufferDecoder, } @@ -238,13 +239,60 @@ impl ArrowDeserializer { Arc::new(FailingSchemaResolver::new()) as Arc }; - Self::with_schema_resolver(format, framing, schema, bad_data, resolver) + Self::with_schema_resolver(format, framing, schema, &[], bad_data, resolver) } - + pub fn with_schema_resolver( format: Format, framing: Option, schema: Arc, + metadata_fields: &[MetadataField], + bad_data: BadData, + schema_resolver: Arc, + ) -> Self { + Self::with_schema_resolver_and_raw_schema( + format, + framing, + Arc::new(schema.schema_without_timestamp()), + Some(schema.timestamp_index), + metadata_fields, + bad_data, + schema_resolver + ) + } + + pub fn for_lookup( + format: Format, + schema: Arc, + metadata_fields: &[MetadataField], + bad_data: BadData, + schema_resolver: Arc, + ) -> Self { + let mut metadata_fields = metadata_fields.to_vec(); + metadata_fields + .push(MetadataField { + field_name: LOOKUP_KEY_INDEX_FIELD.to_string(), + key: LOOKUP_KEY_INDEX_FIELD.to_string(), + data_type: Some(DataType::UInt64), + }); + + Self::with_schema_resolver_and_raw_schema( + format, + None, + schema, + None, + &metadata_fields, + bad_data, + schema_resolver + ) + } + + fn with_schema_resolver_and_raw_schema( + format: Format, + framing: Option, + schema_without_timestamp: Arc, + timestamp_idx: Option, + metadata_fields: &[MetadataField], bad_data: BadData, schema_resolver: Arc, ) -> Self { @@ -257,7 +305,20 @@ impl ArrowDeserializer { } else { DescriptorPool::global() }; - + + let metadata_names: HashSet<_> = metadata_fields.iter().map(|f| &f.field_name).collect(); + + let schema_without_additional = { + let fields = schema_without_timestamp.fields().iter() + .filter(|f| !metadata_names.contains(f.name())) + .map(|f| f.clone()) + .collect::>(); + Arc::new(Schema::new_with_metadata(fields, schema_without_timestamp.metadata.clone())) + }; + + println!("Schema without additional = {:?}", schema_without_additional); + println!("metadata_names = {:?}", metadata_names); + let buffer_decoder = match format { Format::Json(..) | Format::Avro(AvroFormat { @@ -268,9 +329,7 @@ impl ArrowDeserializer { into_unstructured_json: false, .. }) => BufferDecoder::JsonDecoder { - decoder: arrow_json::reader::ReaderBuilder::new(Arc::new( - schema.schema_without_timestamp(), - )) + decoder: arrow_json::reader::ReaderBuilder::new(schema_without_additional.clone()) .with_limit_to_batch_size(false) .with_strict_mode(false) .with_allow_bad_data(matches!(bad_data, BadData::Drop { .. 
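// NOTE: the decoder is now built over schema_without_additional, i.e. the declared
// schema minus the timestamp and metadata fields; flush_buffer() splices those columns
// back in afterwards at the indices that final_schema dictates.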
})) @@ -279,17 +338,16 @@ impl ArrowDeserializer { buffered_count: 0, buffered_since: Instant::now(), }, - _ => BufferDecoder::Buffer(ContextBuffer::new(Arc::new( - schema.schema_without_timestamp(), - ))), + _ => BufferDecoder::Buffer(ContextBuffer::new(schema_without_additional.clone())), }; Self { format: Arc::new(format), framing: framing.map(Arc::new), buffer_decoder, - timestamp_builder: Some(TimestampNanosecondBuilder::with_capacity(128)), - schema, + timestamp_builder: timestamp_idx.map(|i| (i, TimestampNanosecondBuilder::with_capacity(128))), + final_schema: schema_without_timestamp, + decoder_schema: schema_without_additional, schema_registry: Arc::new(Mutex::new(HashMap::new())), bad_data, schema_resolver, @@ -298,21 +356,31 @@ impl ArrowDeserializer { } } + #[must_use] pub async fn deserialize_slice( &mut self, msg: &[u8], timestamp: SystemTime, - additional_fields: Option<&HashMap<&String, FieldValueType<'_>>>, + additional_fields: Option<&HashMap<&str, FieldValueType<'_>>>, ) -> Vec { self.deserialize_slice_int(msg, Some(timestamp), additional_fields) .await } + #[must_use] + pub async fn deserialize_without_timestamp( + &mut self, + msg: &[u8], + additional_fields: Option<&HashMap<&str, FieldValueType<'_>>>, + ) -> Vec { + self.deserialize_slice_int(msg, None, additional_fields).await + } + async fn deserialize_slice_int( &mut self, msg: &[u8], timestamp: Option, - additional_fields: Option<&HashMap<&String, FieldValueType<'_>>>, + additional_fields: Option<&HashMap<&str, FieldValueType<'_>>>, ) -> Vec { let (count, errors) = match &*self.format { Format::Avro(_) => self.deserialize_slice_avro(msg).await, @@ -332,7 +400,7 @@ impl ArrowDeserializer { }; if let Some(timestamp) = timestamp { - let b = self + let (_, b) = self .timestamp_builder .as_mut() .expect("tried to serialize timestamp to a schema without a timestamp column"); @@ -349,6 +417,7 @@ impl ArrowDeserializer { let builder: Box = match value { FieldValueType::Int32(_) => Box::new(Int32Builder::new()), FieldValueType::Int64(_) => Box::new(Int64Builder::new()), + FieldValueType::UInt64(_) => Box::new(UInt64Builder::new()), FieldValueType::String(_) => Box::new(StringBuilder::new()), }; builders.insert(key.to_string(), builder); @@ -369,18 +438,20 @@ impl ArrowDeserializer { pub fn should_flush(&self) -> bool { self.buffer_decoder.should_flush() } - + pub fn flush_buffer(&mut self) -> Option> { let (mut arrays, error_mask) = match self.buffer_decoder.flush(&self.bad_data)? 
{ Ok((a, b)) => (a, b), Err(e) => return Some(Err(e)), }; + + println!("ARrays = {:?}", arrays); + println!("error mask = {:?}", error_mask); if let Some(additional_fields) = &mut self.additional_fields_builder { for (name, builder) in additional_fields { let (idx, _) = self - .schema - .schema + .final_schema .column_with_name(&name) .unwrap_or_else(|| panic!("Field '{}' not found in schema", name)); @@ -388,23 +459,29 @@ impl ArrowDeserializer { if let Some(error_mask) = &error_mask { array = kernels::filter::filter(&array, error_mask).unwrap(); } + + println!("Schema = {:?}", self.final_schema); - arrays[idx] = array; + arrays.insert(idx, array); } }; - if let Some(timestamp) = &mut self.timestamp_builder { + println!("With add = {:?}", arrays); + + if let Some((idx, timestamp)) = &mut self.timestamp_builder { let array = if let Some(error_mask) = &error_mask { kernels::filter::filter(×tamp.finish(), error_mask).unwrap() } else { Arc::new(timestamp.finish()) }; - arrays.insert(self.schema.timestamp_index, array); + arrays.insert(*idx, array); } + + println!("With timestamp ={:?}", arrays); Some(Ok( - RecordBatch::try_new(self.schema.schema.clone(), arrays).unwrap() + RecordBatch::try_new(self.final_schema.clone(), arrays).unwrap() )) } @@ -448,8 +525,7 @@ impl ArrowDeserializer { fn decode_into_json(&mut self, value: Value) { let (idx, _) = self - .schema - .schema + .decoder_schema .column_with_name("value") .expect("no 'value' column for unstructured avro"); let array = self.buffer_decoder.get_buffer().buffer[idx] @@ -512,9 +588,9 @@ impl ArrowDeserializer { } fn deserialize_raw_string(&mut self, msg: &[u8]) { + println!("Deserializing raw string {:?}", msg); let (col, _) = self - .schema - .schema + .decoder_schema .column_with_name("value") .expect("no 'value' column for RawString format"); self.buffer_decoder.get_buffer().buffer[col] @@ -526,8 +602,7 @@ impl ArrowDeserializer { fn deserialize_raw_bytes(&mut self, msg: &[u8]) { let (col, _) = self - .schema - .schema + .decoder_schema .column_with_name("value") .expect("no 'value' column for RawBytes format"); self.buffer_decoder.get_buffer().buffer[col] @@ -571,6 +646,15 @@ fn add_additional_fields( b.append_value(*i); } } + FieldValueType::UInt64(i) => { + let b = builder + .downcast_mut::() + .expect("additional field has incorrect type"); + + for _ in 0..count { + b.append_value(*i); + } + } FieldValueType::String(s) => { let b = builder .downcast_mut::() @@ -834,10 +918,10 @@ mod tests { let time = SystemTime::now(); let mut additional_fields = std::collections::HashMap::new(); let binding = "y".to_string(); - additional_fields.insert(&binding, FieldValueType::Int32(5)); + additional_fields.insert(binding.as_str(), FieldValueType::Int32(5)); let z_value = "hello".to_string(); let binding = "z".to_string(); - additional_fields.insert(&binding, FieldValueType::String(&z_value)); + additional_fields.insert(binding.as_str(), FieldValueType::String(&z_value)); let result = deserializer .deserialize_slice( diff --git a/crates/arroyo-operator/src/connector.rs b/crates/arroyo-operator/src/connector.rs index 51c0a38d2..54d678214 100644 --- a/crates/arroyo-operator/src/connector.rs +++ b/crates/arroyo-operator/src/connector.rs @@ -1,11 +1,10 @@ use crate::operator::ConstructedOperator; use anyhow::{anyhow, bail}; use arrow::array::{ArrayRef, RecordBatch}; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, Schema}; use arroyo_rpc::api_types::connections::{ ConnectionProfile, ConnectionSchema, 
ConnectionType, TestSourceMessage,
 };
-use arroyo_rpc::df::ArroyoSchema;
 use arroyo_rpc::OperatorConfig;
 use arroyo_types::{DisplayAsSql, SourceError};
 use async_trait::async_trait;
@@ -129,7 +128,7 @@ pub trait Connector: Send {
         profile: Self::ProfileT,
         table: Self::TableT,
         config: OperatorConfig,
-        schema: Arc<ArroyoSchema>,
+        schema: Arc<Schema>,
     ) -> anyhow::Result<Box<dyn LookupConnector>> {
         bail!("{} is not a lookup connector", self.name())
     }
@@ -206,7 +205,7 @@ pub trait ErasedConnector: Send {
     fn make_lookup(
         &self,
         config: OperatorConfig,
-        schema: Arc<ArroyoSchema>,
+        schema: Arc<Schema>,
     ) -> anyhow::Result<Box<dyn LookupConnector>>;
 }
 
@@ -360,7 +359,7 @@ impl<C: Connector> ErasedConnector for C {
     fn make_lookup(
         &self,
         config: OperatorConfig,
-        schema: Arc<ArroyoSchema>,
+        schema: Arc<Schema>,
     ) -> anyhow::Result<Box<dyn LookupConnector>> {
         self.make_lookup(
             self.parse_config(&config.connection)?,
diff --git a/crates/arroyo-operator/src/context.rs b/crates/arroyo-operator/src/context.rs
index aeba7a2a5..747ed91e4 100644
--- a/crates/arroyo-operator/src/context.rs
+++ b/crates/arroyo-operator/src/context.rs
@@ -9,7 +9,7 @@ use arroyo_rpc::df::ArroyoSchema;
 use arroyo_rpc::formats::{BadData, Format, Framing};
 use arroyo_rpc::grpc::rpc::{CheckpointMetadata, TableConfig, TaskCheckpointEventType};
 use arroyo_rpc::schema_resolver::SchemaResolver;
-use arroyo_rpc::{get_hasher, CompactionResult, ControlMessage, ControlResp};
+use arroyo_rpc::{get_hasher, CompactionResult, ControlMessage, ControlResp, MetadataField};
 use arroyo_state::tables::table_manager::TableManager;
 use arroyo_types::{
     ArrowMessage, ChainInfo, CheckpointBarrier, SignalMessage, SourceError, TaskInfo, UserError,
@@ -302,12 +302,14 @@ impl SourceCollector {
         format: Format,
         framing: Option<Framing>,
         bad_data: Option<BadData>,
+        metadata_fields: &[MetadataField],
         schema_resolver: Arc<dyn SchemaResolver>,
     ) {
         self.deserializer = Some(ArrowDeserializer::with_schema_resolver(
             format,
             framing,
             self.out_schema.clone(),
+            metadata_fields,
             bad_data.unwrap_or_default(),
             schema_resolver,
         ));
@@ -342,7 +344,7 @@ impl SourceCollector {
         &mut self,
         msg: &[u8],
         time: SystemTime,
-        additional_fields: Option<&HashMap<&String, FieldValueType<'_>>>,
+        additional_fields: Option<&HashMap<&str, FieldValueType<'_>>>,
     ) -> Result<(), UserError> {
         let deserializer = self
             .deserializer
diff --git a/crates/arroyo-rpc/src/api_types/connections.rs b/crates/arroyo-rpc/src/api_types/connections.rs
index 3ce2a8f72..a47d58151 100644
--- a/crates/arroyo-rpc/src/api_types/connections.rs
+++ b/crates/arroyo-rpc/src/api_types/connections.rs
@@ -323,6 +323,7 @@ impl ConnectionSchema {
                 Some(MetadataField {
                     field_name: f.field_name.clone(),
                     key: f.metadata_key.clone()?,
+                    data_type: Some(Field::from(f.clone()).data_type().clone()),
                 })
             })
             .collect()
diff --git a/crates/arroyo-rpc/src/df.rs b/crates/arroyo-rpc/src/df.rs
index 3bbec89c2..a2e587d9f 100644
--- a/crates/arroyo-rpc/src/df.rs
+++ b/crates/arroyo-rpc/src/df.rs
@@ -1,6 +1,6 @@
 use crate::grpc::api;
 use crate::{grpc, Converter, TIMESTAMP_FIELD};
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, bail, Result};
 use arrow::compute::kernels::numeric::div;
 use arrow::compute::{filter_record_batch, take};
 use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit};
@@ -384,6 +384,19 @@ impl ArroyoSchema {
             key_indices: None,
         })
     }
+
+    pub fn with_field(&self, name: &str, data_type: DataType, nullable: bool) -> Result<Self> {
+        if self.schema.field_with_name(name).is_ok() {
+            bail!("cannot add field '{}' to schema, it is already present", name);
+        }
+        let mut fields = self.schema.fields().to_vec();
+        fields.push(Arc::new(Field::new(name, data_type, nullable)));
+        Ok(Self {
+            schema: Arc::new(Schema::new_with_metadata(fields, self.schema.metadata.clone())),
+            timestamp_index: self.timestamp_index,
+            key_indices: self.key_indices.clone(),
+        })
+    }
 }
 
 pub fn server_for_hash_array(
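NOTE: ArroyoSchema::with_field refuses duplicate names and otherwise returns a copy of
the schema with one extra column appended; the lookup join uses it to thread the
synthetic key-index column through the lookup schema. A minimal usage sketch
(hypothetical variable names; the concrete call site appears in the lookup_join.rs hunk
of this patch):

    let extended = schema.with_field(LOOKUP_KEY_INDEX_FIELD, DataType::UInt64, false)?;
    assert!(extended.schema.field_with_name(LOOKUP_KEY_INDEX_FIELD).is_ok());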
diff --git a/crates/arroyo-rpc/src/lib.rs b/crates/arroyo-rpc/src/lib.rs
index 49e150bfa..b06ef387f 100644
--- a/crates/arroyo-rpc/src/lib.rs
+++ b/crates/arroyo-rpc/src/lib.rs
@@ -201,6 +201,8 @@ pub struct RateLimit {
 pub struct MetadataField {
     pub field_name: String,
     pub key: String,
+    #[serde(default)]
+    pub data_type: Option<DataType>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
diff --git a/crates/arroyo-types/src/lib.rs b/crates/arroyo-types/src/lib.rs
index bcc5df76d..c9b8274cd 100644
--- a/crates/arroyo-types/src/lib.rs
+++ b/crates/arroyo-types/src/lib.rs
@@ -693,3 +693,5 @@ mod tests {
         );
     }
 }
+
+pub const LOOKUP_KEY_INDEX_FIELD: &str = "__lookup_key_index";
\ No newline at end of file
diff --git a/crates/arroyo-worker/src/arrow/lookup_join.rs b/crates/arroyo-worker/src/arrow/lookup_join.rs
index 1403e0eaf..46dbbc754 100644
--- a/crates/arroyo-worker/src/arrow/lookup_join.rs
+++ b/crates/arroyo-worker/src/arrow/lookup_join.rs
@@ -1,17 +1,19 @@
 use std::collections::HashMap;
+use std::ops::Index;
 use std::sync::Arc;
 
 use arrow::row::{OwnedRow, RowConverter, SortField};
 use arrow_array::{Array, RecordBatch};
+use arrow_array::cast::AsArray;
+use arrow_array::types::UInt64Type;
+use arrow_schema::DataType;
 use arroyo_connectors::connectors;
 use arroyo_operator::connector::LookupConnector;
 use arroyo_operator::context::{Collector, OperatorContext};
-use arroyo_operator::operator::{
-    ArrowOperator, ConstructedOperator, OperatorConstructor, Registry,
-};
+use arroyo_operator::operator::{ArrowOperator, ConstructedOperator, DisplayableOperator, OperatorConstructor, Registry};
 use arroyo_rpc::df::ArroyoSchema;
 use arroyo_rpc::grpc::api;
-use arroyo_types::JoinType;
+use arroyo_types::{JoinType, LOOKUP_KEY_INDEX_FIELD};
 use async_trait::async_trait;
 use datafusion::physical_expr::PhysicalExpr;
 use datafusion_proto::physical_plan::from_proto::parse_physical_expr;
@@ -32,7 +34,7 @@ pub struct LookupJoin {
 #[async_trait]
 impl ArrowOperator for LookupJoin {
     fn name(&self) -> String {
-        format!("LookupJoin<{}>", self.connector.name())
+        format!("LookupJoin({})", self.connector.name())
     }
 
     async fn process_batch(
@@ -71,16 +73,22 @@ impl ArrowOperator for LookupJoin {
 
             let result_batch = self.connector.lookup(&cols).await;
 
+            println!("Batch = {:?}", result_batch);
+
             if let Some(result_batch) = result_batch {
+                let mut result_batch = result_batch.unwrap();
+                let key_idx_col = result_batch.schema().index_of(LOOKUP_KEY_INDEX_FIELD).unwrap();
+
+                let keys = result_batch.remove_column(key_idx_col);
+                let keys = keys.as_primitive::<UInt64Type>();
+
                 let result_rows = self
                     .result_row_converter
-                    .convert_columns(result_batch.unwrap().columns())
+                    .convert_columns(result_batch.columns())
                     .unwrap();
 
-                assert_eq!(result_rows.num_rows(), uncached_keys.len());
-
-                for (k, v) in uncached_keys.iter().zip(result_rows.iter()) {
-                    self.cache.insert(k.as_ref().to_vec(), v.owned());
+                for (v, idx) in result_rows.iter().zip(keys) {
+                    self.cache.insert(uncached_keys[idx.unwrap() as usize].as_ref().to_vec(), v.owned());
                 }
             }
         }
@@ -89,6 +97,8 @@ impl ArrowOperator for LookupJoin {
             .result_row_converter
             .empty_rows(batch.num_rows(), batch.num_rows() * 10);
 
+        println!("Cache {:?}", self.cache);
+
         for row in rows.iter() {
             output_rows.push(
                 self.cache
@@ -102,6 +112,7 @@ impl ArrowOperator for LookupJoin {
             .result_row_converter
.convert_rows(output_rows.iter()) .unwrap(); + let mut result = batch.columns().to_vec(); result.extend(right_side); @@ -148,13 +159,17 @@ impl OperatorConstructor for LookupJoinConstructor { let result_row_converter = RowConverter::new( lookup_schema - .schema + .schema_without_timestamp() .fields .iter() .map(|f| SortField::new(f.data_type().clone())) .collect(), )?; + let lookup_schema = lookup_schema + .with_field(LOOKUP_KEY_INDEX_FIELD, DataType::UInt64, false)? + .schema_without_timestamp(); + let connector = connectors() .get(op.connector.as_str()) .unwrap_or_else(|| panic!("No connector with name '{}'", op.connector)) From 924e92c7afc3c0dfae7e1eb790917910d9cab277 Mon Sep 17 00:00:00 2001 From: Micah Wylde Date: Fri, 10 Jan 2025 16:13:23 -0800 Subject: [PATCH 07/14] checkpoint --- crates/arroyo-worker/src/arrow/lookup_join.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/crates/arroyo-worker/src/arrow/lookup_join.rs b/crates/arroyo-worker/src/arrow/lookup_join.rs index 46dbbc754..14e32bb60 100644 --- a/crates/arroyo-worker/src/arrow/lookup_join.rs +++ b/crates/arroyo-worker/src/arrow/lookup_join.rs @@ -113,8 +113,18 @@ impl ArrowOperator for LookupJoin { .convert_rows(output_rows.iter()) .unwrap(); - let mut result = batch.columns().to_vec(); + let in_schema = ctx.in_schemas.first().unwrap(); + let key_indices = in_schema.key_indices.as_ref().unwrap(); + let non_keys: Vec<_> = (0..batch.num_columns()) + .into_iter() + .filter(|i| !key_indices.contains(i) && *i != in_schema.timestamp_index) + .collect(); + + let mut result = batch.project(&non_keys) + .unwrap().columns() + .to_vec(); result.extend(right_side); + result.push(batch.column(in_schema.timestamp_index).clone()); println!("SCHEMA = {:?}", ctx.out_schema.as_ref().unwrap().schema); println!("RESULT COLS = {:?}", result.iter().map(|s| s.data_type())); From c8d06026ee2f2c7e21a3b4081e899ccfb58a1c28 Mon Sep 17 00:00:00 2001 From: Micah Wylde Date: Sun, 12 Jan 2025 11:28:42 -0800 Subject: [PATCH 08/14] join types --- .../src/kafka/source/test.rs | 12 +- crates/arroyo-connectors/src/redis/lookup.rs | 32 ++-- crates/arroyo-connectors/src/redis/mod.rs | 12 +- crates/arroyo-formats/src/de.rs | 153 ++++++++++++------ crates/arroyo-planner/src/plan/join.rs | 5 - crates/arroyo-rpc/src/df.rs | 12 +- crates/arroyo-types/src/lib.rs | 4 +- crates/arroyo-worker/src/arrow/lookup_join.rs | 112 +++++++++---- 8 files changed, 225 insertions(+), 117 deletions(-) diff --git a/crates/arroyo-connectors/src/kafka/source/test.rs b/crates/arroyo-connectors/src/kafka/source/test.rs index b0a6d2a84..801243870 100644 --- a/crates/arroyo-connectors/src/kafka/source/test.rs +++ b/crates/arroyo-connectors/src/kafka/source/test.rs @@ -5,14 +5,10 @@ use arroyo_state::tables::ErasedTable; use arroyo_state::{BackingStore, StateBackend}; use rand::random; +use crate::kafka::SourceOffset; use arrow::array::{Array, StringArray}; -use arrow::datatypes::TimeUnit; -use std::collections::{HashMap, VecDeque}; -use std::num::NonZeroU32; -use std::sync::Arc; -use std::time::{Duration, SystemTime}; use arrow::datatypes::DataType::UInt64; -use crate::kafka::SourceOffset; +use arrow::datatypes::TimeUnit; use arroyo_operator::context::{ batch_bounded, ArrowCollector, BatchReceiver, OperatorContext, SourceCollector, SourceContext, }; @@ -29,6 +25,10 @@ use rdkafka::admin::{AdminClient, AdminOptions, NewTopic}; use rdkafka::producer::{BaseProducer, BaseRecord}; use rdkafka::ClientConfig; use serde::{Deserialize, Serialize}; +use 
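// NOTE: patch 08 starts wiring join semantics through the lookup path: a Redis nil now
// calls deserialize_null() (below), which appends an all-null row via
// BufferDecoder::push_null, so that join types other than inner can represent missing
// keys as null right-hand columns instead of silently dropping the row.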
std::collections::{HashMap, VecDeque}; +use std::num::NonZeroU32; +use std::sync::Arc; +use std::time::{Duration, SystemTime}; use tokio::sync::mpsc::{channel, Receiver, Sender}; use super::KafkaSourceFunc; diff --git a/crates/arroyo-connectors/src/redis/lookup.rs b/crates/arroyo-connectors/src/redis/lookup.rs index a2907f54e..0f9967de3 100644 --- a/crates/arroyo-connectors/src/redis/lookup.rs +++ b/crates/arroyo-connectors/src/redis/lookup.rs @@ -1,15 +1,15 @@ -use std::collections::HashMap; use crate::redis::sink::GeneralConnection; use crate::redis::RedisClient; use arrow::array::{Array, ArrayRef, AsArray, RecordBatch}; use arrow::datatypes::DataType; use arroyo_formats::de::{ArrowDeserializer, FieldValueType}; use arroyo_operator::connector::LookupConnector; +use arroyo_rpc::MetadataField; use arroyo_types::{SourceError, LOOKUP_KEY_INDEX_FIELD}; use async_trait::async_trait; use redis::aio::ConnectionLike; use redis::{cmd, Value}; -use arroyo_rpc::MetadataField; +use std::collections::HashMap; pub struct RedisLookup { pub(crate) deserializer: ArrowDeserializer, @@ -41,42 +41,46 @@ impl LookupConnector for RedisLookup { let mut mget = cmd("mget"); let keys = keys[0].as_string::(); - + for k in keys { mget.arg(k.unwrap()); - println!("GETTTING {:?}", k); } let Value::Array(vs) = connection.req_packed_command(&mget).await.unwrap() else { panic!("value was not an array"); }; - assert_eq!(vs.len(), keys.len(), "Redis sent back the wrong number of values"); + assert_eq!( + vs.len(), + keys.len(), + "Redis sent back the wrong number of values" + ); let mut additional = HashMap::new(); - + for (idx, (v, k)) in vs.iter().zip(keys).enumerate() { additional.insert(LOOKUP_KEY_INDEX_FIELD, FieldValueType::UInt64(idx as u64)); for m in &self.metadata_fields { - additional.insert(m.field_name.as_str(), match m.key.as_str() { - "key" => FieldValueType::String(k.unwrap()), - k => unreachable!("Invalid metadata key '{}'", k) - }); + additional.insert( + m.field_name.as_str(), + match m.key.as_str() { + "key" => FieldValueType::String(k.unwrap()), + k => unreachable!("Invalid metadata key '{}'", k), + }, + ); } let errors = match v { Value::Nil => { - println!("GOt null"); + self.deserializer.deserialize_null(Some(&additional)); vec![] } Value::SimpleString(s) => { - println!("Got {:?}", s); self.deserializer .deserialize_without_timestamp(s.as_bytes(), Some(&additional)) .await } Value::BulkString(v) => { - println!("Got {:?}", String::from_utf8(v.clone())); self.deserializer .deserialize_without_timestamp(&v, Some(&additional)) .await @@ -85,7 +89,7 @@ impl LookupConnector for RedisLookup { panic!("unexpected type {:?}", v); } }; - + if !errors.is_empty() { return Some(Err(errors.into_iter().next().unwrap())); } diff --git a/crates/arroyo-connectors/src/redis/mod.rs b/crates/arroyo-connectors/src/redis/mod.rs index 8d8057be8..9e435c712 100644 --- a/crates/arroyo-connectors/src/redis/mod.rs +++ b/crates/arroyo-connectors/src/redis/mod.rs @@ -1,7 +1,11 @@ pub mod lookup; pub mod sink; +use crate::redis::lookup::RedisLookup; +use crate::redis::sink::{GeneralConnection, RedisSinkFunc}; +use crate::{pull_opt, pull_option_to_u64}; use anyhow::{anyhow, bail}; +use arrow::datatypes::{DataType, Schema}; use arroyo_formats::de::ArrowDeserializer; use arroyo_formats::ser::ArrowSerializer; use arroyo_operator::connector::{Connection, Connector, LookupConnector, MetadataDef}; @@ -11,6 +15,7 @@ use arroyo_rpc::api_types::connections::{ TestSourceMessage, }; use arroyo_rpc::df::ArroyoSchema; +use 
arroyo_rpc::schema_resolver::FailingSchemaResolver; use arroyo_rpc::var_str::VarStr; use arroyo_rpc::OperatorConfig; use redis::aio::ConnectionManager; @@ -19,13 +24,8 @@ use redis::{Client, ConnectionInfo, IntoConnectionInfo}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; -use arrow::datatypes::{DataType, Schema}; use tokio::sync::oneshot::Receiver; use typify::import_types; -use arroyo_rpc::schema_resolver::FailingSchemaResolver; -use crate::redis::lookup::RedisLookup; -use crate::redis::sink::{GeneralConnection, RedisSinkFunc}; -use crate::{pull_opt, pull_option_to_u64}; pub struct RedisConnector {} @@ -463,7 +463,7 @@ impl Connector for RedisConnector { .format .ok_or_else(|| anyhow!("Redis table must have a format"))?, schema, - &config.metadata_fields, + &config.metadata_fields, config.bad_data.unwrap_or_default(), Arc::new(FailingSchemaResolver::new()), ), diff --git a/crates/arroyo-formats/src/de.rs b/crates/arroyo-formats/src/de.rs index de681954c..2089fe616 100644 --- a/crates/arroyo-formats/src/de.rs +++ b/crates/arroyo-formats/src/de.rs @@ -3,7 +3,10 @@ use crate::proto::schema::get_pool; use crate::{proto, should_flush}; use arrow::array::{Int32Builder, Int64Builder}; use arrow::compute::kernels; -use arrow_array::builder::{make_builder, ArrayBuilder, GenericByteBuilder, StringBuilder, TimestampNanosecondBuilder, UInt64Builder}; +use arrow_array::builder::{ + make_builder, ArrayBuilder, BinaryBuilder, GenericByteBuilder, StringBuilder, + TimestampNanosecondBuilder, UInt64Builder, +}; use arrow_array::types::GenericBinaryType; use arrow_array::{ArrayRef, BooleanArray, RecordBatch}; use arrow_schema::{DataType, Schema, SchemaRef}; @@ -12,6 +15,7 @@ use arroyo_rpc::formats::{ AvroFormat, BadData, Format, Framing, FramingMethod, JsonFormat, ProtobufFormat, }; use arroyo_rpc::schema_resolver::{FailingSchemaResolver, FixedSchemaResolver, SchemaResolver}; +use arroyo_rpc::{MetadataField, TIMESTAMP_FIELD}; use arroyo_types::{to_nanos, SourceError, LOOKUP_KEY_INDEX_FIELD}; use prost_reflect::DescriptorPool; use serde_json::Value; @@ -19,7 +23,6 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::time::{Instant, SystemTime}; use tokio::sync::Mutex; -use arroyo_rpc::MetadataField; #[derive(Debug, Clone)] pub enum FieldValueType<'a> { @@ -205,6 +208,41 @@ impl BufferDecoder { } } } + + fn push_null(&mut self, schema: &Schema) { + match self { + BufferDecoder::Buffer(b) => { + for (f, b) in schema.fields.iter().zip(b.buffer.iter_mut()) { + match f.data_type() { + DataType::Binary => { + b.as_any_mut() + .downcast_mut::() + .unwrap() + .append_null(); + } + DataType::Utf8 => { + b.as_any_mut() + .downcast_mut::() + .unwrap() + .append_null(); + } + dt => { + unreachable!("unsupported datatype {}", dt); + } + } + } + } + BufferDecoder::JsonDecoder { + decoder, + buffered_count, + .. 
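// `push_null` above appends one null row to the raw-format builders (the JSON
// path instead decodes "{}" so every nullable column gets a null). This is
// the downcast pattern it relies on, shown in isolation; only Utf8 and Binary
// builders are expected for the raw formats:

use arrow_array::builder::{ArrayBuilder, BinaryBuilder, StringBuilder};
use arrow_schema::DataType;

fn append_null(builder: &mut Box<dyn ArrayBuilder>, dt: &DataType) {
    match dt {
        DataType::Utf8 => builder
            .as_any_mut()
            .downcast_mut::<StringBuilder>()
            .unwrap()
            .append_null(),
        DataType::Binary => builder
            .as_any_mut()
            .downcast_mut::<BinaryBuilder>()
            .unwrap()
            .append_null(),
        dt => unreachable!("unsupported datatype {}", dt),
    }
}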
+ } => { + decoder.decode("{}".as_bytes()).unwrap(); + + *buffered_count += 1; + } + } + } } pub struct ArrowDeserializer { @@ -241,7 +279,7 @@ impl ArrowDeserializer { Self::with_schema_resolver(format, framing, schema, &[], bad_data, resolver) } - + pub fn with_schema_resolver( format: Format, framing: Option, @@ -257,7 +295,7 @@ impl ArrowDeserializer { Some(schema.timestamp_index), metadata_fields, bad_data, - schema_resolver + schema_resolver, ) } @@ -269,12 +307,11 @@ impl ArrowDeserializer { schema_resolver: Arc, ) -> Self { let mut metadata_fields = metadata_fields.to_vec(); - metadata_fields - .push(MetadataField { - field_name: LOOKUP_KEY_INDEX_FIELD.to_string(), - key: LOOKUP_KEY_INDEX_FIELD.to_string(), - data_type: Some(DataType::UInt64), - }); + metadata_fields.push(MetadataField { + field_name: LOOKUP_KEY_INDEX_FIELD.to_string(), + key: LOOKUP_KEY_INDEX_FIELD.to_string(), + data_type: Some(DataType::UInt64), + }); Self::with_schema_resolver_and_raw_schema( format, @@ -283,10 +320,10 @@ impl ArrowDeserializer { None, &metadata_fields, bad_data, - schema_resolver + schema_resolver, ) } - + fn with_schema_resolver_and_raw_schema( format: Format, framing: Option, @@ -305,20 +342,22 @@ impl ArrowDeserializer { } else { DescriptorPool::global() }; - + let metadata_names: HashSet<_> = metadata_fields.iter().map(|f| &f.field_name).collect(); let schema_without_additional = { - let fields = schema_without_timestamp.fields().iter() + let fields = schema_without_timestamp + .fields() + .iter() .filter(|f| !metadata_names.contains(f.name())) .map(|f| f.clone()) .collect::>(); - Arc::new(Schema::new_with_metadata(fields, schema_without_timestamp.metadata.clone())) + Arc::new(Schema::new_with_metadata( + fields, + schema_without_timestamp.metadata.clone(), + )) }; - - println!("Schema without additional = {:?}", schema_without_additional); - println!("metadata_names = {:?}", metadata_names); - + let buffer_decoder = match format { Format::Json(..) | Format::Avro(AvroFormat { @@ -330,11 +369,11 @@ impl ArrowDeserializer { .. }) => BufferDecoder::JsonDecoder { decoder: arrow_json::reader::ReaderBuilder::new(schema_without_additional.clone()) - .with_limit_to_batch_size(false) - .with_strict_mode(false) - .with_allow_bad_data(matches!(bad_data, BadData::Drop { .. })) - .build_decoder() - .unwrap(), + .with_limit_to_batch_size(false) + .with_strict_mode(false) + .with_allow_bad_data(matches!(bad_data, BadData::Drop { .. 
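// What `schema_without_additional` computes above: metadata-sourced fields
// are filled in by the connector rather than parsed from the payload, so they
// are stripped from the schema handed to the JSON decoder. A standalone
// sketch of that filtering step:

use std::collections::HashSet;
use std::sync::Arc;

use arrow_schema::Schema;

fn decoder_schema(full: &Schema, metadata_names: &HashSet<String>) -> Arc<Schema> {
    let fields: Vec<_> = full
        .fields()
        .iter()
        .filter(|f| !metadata_names.contains(f.name()))
        .cloned()
        .collect();
    Arc::new(Schema::new_with_metadata(fields, full.metadata().clone()))
}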
})) + .build_decoder() + .unwrap(), buffered_count: 0, buffered_since: Instant::now(), }, @@ -345,7 +384,8 @@ impl ArrowDeserializer { format: Arc::new(format), framing: framing.map(Arc::new), buffer_decoder, - timestamp_builder: timestamp_idx.map(|i| (i, TimestampNanosecondBuilder::with_capacity(128))), + timestamp_builder: timestamp_idx + .map(|i| (i, TimestampNanosecondBuilder::with_capacity(128))), final_schema: schema_without_timestamp, decoder_schema: schema_without_additional, schema_registry: Arc::new(Mutex::new(HashMap::new())), @@ -373,9 +413,18 @@ impl ArrowDeserializer { msg: &[u8], additional_fields: Option<&HashMap<&str, FieldValueType<'_>>>, ) -> Vec { - self.deserialize_slice_int(msg, None, additional_fields).await + self.deserialize_slice_int(msg, None, additional_fields) + .await + } + + pub fn deserialize_null( + &mut self, + additional_fields: Option<&HashMap<&str, FieldValueType<'_>>>, + ) { + self.buffer_decoder.push_null(&self.decoder_schema); + self.add_additional_fields(additional_fields, 1); } - + async fn deserialize_slice_int( &mut self, msg: &[u8], @@ -399,6 +448,8 @@ impl ArrowDeserializer { } }; + self.add_additional_fields(additional_fields, count); + if let Some(timestamp) = timestamp { let (_, b) = self .timestamp_builder @@ -410,6 +461,14 @@ impl ArrowDeserializer { } } + errors + } + + fn add_additional_fields( + &mut self, + additional_fields: Option<&HashMap<&str, FieldValueType<'_>>>, + count: usize, + ) { if let Some(additional_fields) = additional_fields { if self.additional_fields_builder.is_none() { let mut builders = HashMap::new(); @@ -431,54 +490,51 @@ impl ArrowDeserializer { add_additional_fields(builders, k, v, count); } } - - errors } pub fn should_flush(&self) -> bool { self.buffer_decoder.should_flush() } - + pub fn flush_buffer(&mut self) -> Option> { - let (mut arrays, error_mask) = match self.buffer_decoder.flush(&self.bad_data)? { + let (arrays, error_mask) = match self.buffer_decoder.flush(&self.bad_data)? 
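// The `add_additional_fields` helper factored out above appends each metadata
// value `count` times, once per decoded row, so the metadata columns stay
// row-aligned with the payload columns. A simplified sketch for u64 values
// only; the real code dispatches on FieldValueType:

use std::collections::HashMap;

use arrow_array::builder::UInt64Builder;

fn add_u64_field(
    builders: &mut HashMap<String, UInt64Builder>,
    name: &str,
    value: u64,
    count: usize,
) {
    // builders are created lazily, one per additional field
    let b = builders.entry(name.to_string()).or_default();
    for _ in 0..count {
        b.append_value(value);
    }
}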
{ Ok((a, b)) => (a, b), Err(e) => return Some(Err(e)), }; - - println!("ARrays = {:?}", arrays); - println!("error mask = {:?}", error_mask); + + let mut arrays: HashMap<_, _> = arrays + .into_iter() + .zip(self.decoder_schema.fields.iter()) + .map(|(a, f)| (f.name().as_str(), a)) + .collect(); if let Some(additional_fields) = &mut self.additional_fields_builder { for (name, builder) in additional_fields { - let (idx, _) = self - .final_schema - .column_with_name(&name) - .unwrap_or_else(|| panic!("Field '{}' not found in schema", name)); - let mut array = builder.finish(); if let Some(error_mask) = &error_mask { array = kernels::filter::filter(&array, error_mask).unwrap(); } - - println!("Schema = {:?}", self.final_schema); - arrays.insert(idx, array); + arrays.insert(name.as_str(), array); } }; - println!("With add = {:?}", arrays); - - if let Some((idx, timestamp)) = &mut self.timestamp_builder { + if let Some((_, timestamp)) = &mut self.timestamp_builder { let array = if let Some(error_mask) = &error_mask { kernels::filter::filter(×tamp.finish(), error_mask).unwrap() } else { Arc::new(timestamp.finish()) }; - arrays.insert(*idx, array); + arrays.insert(TIMESTAMP_FIELD, array); } - - println!("With timestamp ={:?}", arrays); + + let arrays = self + .final_schema + .fields + .iter() + .map(|f| arrays.get(f.name().as_str()).unwrap().clone()) + .collect(); Some(Ok( RecordBatch::try_new(self.final_schema.clone(), arrays).unwrap() @@ -588,7 +644,6 @@ impl ArrowDeserializer { } fn deserialize_raw_string(&mut self, msg: &[u8]) { - println!("Deserializing raw string {:?}", msg); let (col, _) = self .decoder_schema .column_with_name("value") diff --git a/crates/arroyo-planner/src/plan/join.rs b/crates/arroyo-planner/src/plan/join.rs index b0a045432..a2ade4bb0 100644 --- a/crates/arroyo-planner/src/plan/join.rs +++ b/crates/arroyo-planner/src/plan/join.rs @@ -248,11 +248,6 @@ fn maybe_plan_lookup_join(join: &Join) -> Result> { return Ok(None); } - println!( - "JOin = {:?} {:?}\n{:#?}", - join.join_constraint, join.join_type, join.on - ); - match join.join_type { JoinType::Inner | JoinType::Left => {} t => { diff --git a/crates/arroyo-rpc/src/df.rs b/crates/arroyo-rpc/src/df.rs index a2e587d9f..03f62fc10 100644 --- a/crates/arroyo-rpc/src/df.rs +++ b/crates/arroyo-rpc/src/df.rs @@ -384,15 +384,21 @@ impl ArroyoSchema { key_indices: None, }) } - + pub fn with_field(&self, name: &str, data_type: DataType, nullable: bool) -> Result { if self.schema.field_with_name(name).is_ok() { - bail!("cannot add field '{}' to schema, it is already present", name); + bail!( + "cannot add field '{}' to schema, it is already present", + name + ); } let mut fields = self.schema.fields().to_vec(); fields.push(Arc::new(Field::new(name, data_type, nullable))); Ok(Self { - schema: Arc::new(Schema::new_with_metadata(fields, self.schema.metadata.clone())), + schema: Arc::new(Schema::new_with_metadata( + fields, + self.schema.metadata.clone(), + )), timestamp_index: self.timestamp_index, key_indices: self.key_indices.clone(), }) diff --git a/crates/arroyo-types/src/lib.rs b/crates/arroyo-types/src/lib.rs index c9b8274cd..f4290e8ed 100644 --- a/crates/arroyo-types/src/lib.rs +++ b/crates/arroyo-types/src/lib.rs @@ -350,7 +350,7 @@ impl Serialize for DebeziumOp { } } -#[derive(Clone, Encode, Decode, Debug, Serialize, Deserialize, PartialEq)] +#[derive(Copy, Clone, Encode, Decode, Debug, Serialize, Deserialize, PartialEq)] pub enum JoinType { /// Inner Join Inner, @@ -694,4 +694,4 @@ mod tests { } } -pub const 
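// `ArroyoSchema::with_field`, added above, is the one schema-mutation helper
// the lookup join needs: it appends the synthetic key-index column while
// rejecting duplicates. The same logic without the Arroyo wrapper type:

use std::sync::Arc;

use anyhow::{bail, Result};
use arrow_schema::{DataType, Field, Schema};

fn with_field(
    schema: &Schema,
    name: &str,
    data_type: DataType,
    nullable: bool,
) -> Result<Arc<Schema>> {
    if schema.field_with_name(name).is_ok() {
        bail!("cannot add field '{}' to schema, it is already present", name);
    }
    let mut fields = schema.fields().to_vec();
    fields.push(Arc::new(Field::new(name, data_type, nullable)));
    Ok(Arc::new(Schema::new_with_metadata(
        fields,
        schema.metadata().clone(),
    )))
}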
LOOKUP_KEY_INDEX_FIELD: &str = "__lookup_key_index"; \ No newline at end of file +pub const LOOKUP_KEY_INDEX_FIELD: &str = "__lookup_key_index"; diff --git a/crates/arroyo-worker/src/arrow/lookup_join.rs b/crates/arroyo-worker/src/arrow/lookup_join.rs index 14e32bb60..69b372e3b 100644 --- a/crates/arroyo-worker/src/arrow/lookup_join.rs +++ b/crates/arroyo-worker/src/arrow/lookup_join.rs @@ -1,25 +1,33 @@ -use std::collections::HashMap; -use std::ops::Index; -use std::sync::Arc; - +use arrow::compute::{filter_record_batch, is_null}; use arrow::row::{OwnedRow, RowConverter, SortField}; -use arrow_array::{Array, RecordBatch}; use arrow_array::cast::AsArray; use arrow_array::types::UInt64Type; -use arrow_schema::DataType; +use arrow_array::{Array, BooleanArray, RecordBatch}; +use arrow_schema::{DataType, Schema}; use arroyo_connectors::connectors; use arroyo_operator::connector::LookupConnector; use arroyo_operator::context::{Collector, OperatorContext}; -use arroyo_operator::operator::{ArrowOperator, ConstructedOperator, DisplayableOperator, OperatorConstructor, Registry}; +use arroyo_operator::operator::{ + ArrowOperator, ConstructedOperator, OperatorConstructor, Registry, +}; use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::grpc::api; -use arroyo_types::{JoinType, LOOKUP_KEY_INDEX_FIELD}; +use arroyo_rpc::{MetadataField, OperatorConfig}; +use arroyo_types::LOOKUP_KEY_INDEX_FIELD; use async_trait::async_trait; use datafusion::physical_expr::PhysicalExpr; use datafusion_proto::physical_plan::from_proto::parse_physical_expr; use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; use datafusion_proto::protobuf::PhysicalExprNode; use prost::Message; +use std::collections::HashMap; +use std::sync::Arc; + +#[derive(Copy, Clone, PartialEq)] +pub(crate) enum LookupJoinType { + Left, + Inner, +} /// A simple in-operator cache storing the entire “right side” row batch keyed by a string. 
pub struct LookupJoin { @@ -28,7 +36,9 @@ pub struct LookupJoin { cache: HashMap, OwnedRow>, key_row_converter: RowConverter, result_row_converter: RowConverter, - join_type: JoinType, + join_type: LookupJoinType, + lookup_schema: Arc, + metadata_fields: Vec, } #[async_trait] @@ -73,11 +83,12 @@ impl ArrowOperator for LookupJoin { let result_batch = self.connector.lookup(&cols).await; - println!("Batch = {:?}", result_batch); - if let Some(result_batch) = result_batch { let mut result_batch = result_batch.unwrap(); - let key_idx_col = result_batch.schema().index_of(LOOKUP_KEY_INDEX_FIELD).unwrap(); + let key_idx_col = result_batch + .schema() + .index_of(LOOKUP_KEY_INDEX_FIELD) + .unwrap(); let keys = result_batch.remove_column(key_idx_col); let keys = keys.as_primitive::(); @@ -88,7 +99,10 @@ impl ArrowOperator for LookupJoin { .unwrap(); for (v, idx) in result_rows.iter().zip(keys) { - self.cache.insert(uncached_keys[idx.unwrap() as usize].as_ref().to_vec(), v.owned()); + self.cache.insert( + uncached_keys[idx.unwrap() as usize].as_ref().to_vec(), + v.owned(), + ); } } } @@ -97,8 +111,6 @@ impl ArrowOperator for LookupJoin { .result_row_converter .empty_rows(batch.num_rows(), batch.num_rows() * 10); - println!("Cache {:?}", self.cache); - for row in rows.iter() { output_rows.push( self.cache @@ -113,6 +125,37 @@ impl ArrowOperator for LookupJoin { .convert_rows(output_rows.iter()) .unwrap(); + let nonnull = (self.join_type == LookupJoinType::Inner).then(|| { + let mut nonnull = vec![false; batch.num_rows()]; + + for (_, a) in self + .lookup_schema + .fields + .iter() + .zip(right_side.iter()) + .filter(|(f, _)| { + !self + .metadata_fields + .iter() + .any(|m| &m.field_name == f.name()) + }) + { + if let Some(nulls) = a.logical_nulls() { + for (a, b) in nulls.iter().zip(nonnull.iter_mut()) { + *b |= a; + } + } else { + for b in &mut nonnull { + *b = true; + } + break; + } + } + + + BooleanArray::from(nonnull) + }); + let in_schema = ctx.in_schemas.first().unwrap(); let key_indices = in_schema.key_indices.as_ref().unwrap(); let non_keys: Vec<_> = (0..batch.num_columns()) @@ -120,21 +163,18 @@ impl ArrowOperator for LookupJoin { .filter(|i| !key_indices.contains(i) && *i != in_schema.timestamp_index) .collect(); - let mut result = batch.project(&non_keys) - .unwrap().columns() - .to_vec(); + let mut result = batch.project(&non_keys).unwrap().columns().to_vec(); result.extend(right_side); result.push(batch.column(in_schema.timestamp_index).clone()); - println!("SCHEMA = {:?}", ctx.out_schema.as_ref().unwrap().schema); - println!("RESULT COLS = {:?}", result.iter().map(|s| s.data_type())); + let mut batch = + RecordBatch::try_new(ctx.out_schema.as_ref().unwrap().schema.clone(), result).unwrap(); - collector - .collect( - RecordBatch::try_new(ctx.out_schema.as_ref().unwrap().schema.clone(), result) - .unwrap(), - ) - .await; + if let Some(nonnull) = nonnull { + batch = filter_record_batch(&batch, &nonnull).unwrap(); + } + + collector.collect(batch).await; } } @@ -165,7 +205,7 @@ impl OperatorConstructor for LookupJoinConstructor { .collect::>>()?; let op = config.connector.unwrap(); - let operator_config = serde_json::from_str(&op.config)?; + let operator_config: OperatorConfig = serde_json::from_str(&op.config)?; let result_row_converter = RowConverter::new( lookup_schema @@ -176,14 +216,16 @@ impl OperatorConstructor for LookupJoinConstructor { .collect(), )?; - let lookup_schema = lookup_schema - .with_field(LOOKUP_KEY_INDEX_FIELD, DataType::UInt64, false)? 
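// The inner-join mask built above keeps a row only if at least one
// non-metadata lookup column is valid; a lookup miss deserializes to all-null
// columns, so those rows are dropped for inner joins and passed through for
// left joins. The same computation over a plain slice of columns:

use arrow_array::{Array, ArrayRef, BooleanArray};

fn inner_join_keep_mask(lookup_columns: &[ArrayRef], num_rows: usize) -> BooleanArray {
    let mut keep = vec![false; num_rows];
    for col in lookup_columns {
        if let Some(nulls) = col.logical_nulls() {
            for (valid, k) in nulls.iter().zip(keep.iter_mut()) {
                *k |= valid;
            }
        } else {
            // a column with no null buffer is valid everywhere: keep all rows
            keep.iter_mut().for_each(|k| *k = true);
            break;
        }
    }
    BooleanArray::from(keep)
}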
- .schema_without_timestamp(); + let lookup_schema = Arc::new( + lookup_schema + .with_field(LOOKUP_KEY_INDEX_FIELD, DataType::UInt64, false)? + .schema_without_timestamp(), + ); let connector = connectors() .get(op.connector.as_str()) .unwrap_or_else(|| panic!("No connector with name '{}'", op.connector)) - .make_lookup(operator_config, Arc::new(lookup_schema))?; + .make_lookup(operator_config.clone(), lookup_schema.clone())?; Ok(ConstructedOperator::from_operator(Box::new(LookupJoin { connector, @@ -196,7 +238,13 @@ impl OperatorConstructor for LookupJoinConstructor { )?, key_exprs: exprs, result_row_converter, - join_type: join_type.into(), + join_type: match join_type { + api::JoinType::Inner => LookupJoinType::Inner, + api::JoinType::Left => LookupJoinType::Left, + jt => unreachable!("invalid lookup join type {:?}", jt), + }, + lookup_schema, + metadata_fields: operator_config.metadata_fields, }))) } } From 7388440b7968bf998100f3ac878a7f7848ad1958 Mon Sep 17 00:00:00 2001 From: Micah Wylde Date: Sun, 12 Jan 2025 11:40:00 -0800 Subject: [PATCH 09/14] format & clippy --- crates/arroyo-connectors/src/kafka/sink/test.rs | 2 +- crates/arroyo-connectors/src/kafka/source/test.rs | 4 ++-- crates/arroyo-connectors/src/mqtt/sink/test.rs | 2 +- crates/arroyo-connectors/src/mqtt/source/test.rs | 4 ++-- crates/arroyo-connectors/src/redis/lookup.rs | 2 +- crates/arroyo-connectors/src/redis/mod.rs | 3 +-- crates/arroyo-connectors/src/redis/sink.rs | 4 ++-- crates/arroyo-formats/src/avro/de.rs | 7 +++---- crates/arroyo-formats/src/de.rs | 9 ++++----- crates/arroyo-planner/src/extension/join.rs | 1 - crates/arroyo-planner/src/plan/join.rs | 6 +++--- crates/arroyo-types/src/lib.rs | 4 ++-- crates/arroyo-worker/src/arrow/lookup_join.rs | 4 +--- 13 files changed, 23 insertions(+), 29 deletions(-) diff --git a/crates/arroyo-connectors/src/kafka/sink/test.rs b/crates/arroyo-connectors/src/kafka/sink/test.rs index 30896e80b..26684e643 100644 --- a/crates/arroyo-connectors/src/kafka/sink/test.rs +++ b/crates/arroyo-connectors/src/kafka/sink/test.rs @@ -96,7 +96,7 @@ impl KafkaTopicTester { None, command_tx, 1, - vec![ArroyoSchema::new_unkeyed(schema(), 0)], + vec![Arc::new(ArroyoSchema::new_unkeyed(schema(), 0))], None, HashMap::new(), ) diff --git a/crates/arroyo-connectors/src/kafka/source/test.rs b/crates/arroyo-connectors/src/kafka/source/test.rs index 801243870..316e93135 100644 --- a/crates/arroyo-connectors/src/kafka/source/test.rs +++ b/crates/arroyo-connectors/src/kafka/source/test.rs @@ -108,7 +108,7 @@ impl KafkaTopicTester { operator_ids: vec![task_info.operator_id.clone()], }); - let out_schema = Some(ArroyoSchema::new_unkeyed( + let out_schema = Some(Arc::new(ArroyoSchema::new_unkeyed( Arc::new(Schema::new(vec![ Field::new( "_timestamp", @@ -118,7 +118,7 @@ impl KafkaTopicTester { Field::new("value", DataType::Utf8, false), ])), 0, - )); + ))); let task_info = Arc::new(task_info); diff --git a/crates/arroyo-connectors/src/mqtt/sink/test.rs b/crates/arroyo-connectors/src/mqtt/sink/test.rs index 60c1d03e6..7f2ec9826 100644 --- a/crates/arroyo-connectors/src/mqtt/sink/test.rs +++ b/crates/arroyo-connectors/src/mqtt/sink/test.rs @@ -84,7 +84,7 @@ impl MqttTopicTester { None, command_tx, 1, - vec![ArroyoSchema::new_unkeyed(schema(), 0)], + vec![Arc::new(ArroyoSchema::new_unkeyed(schema(), 0))], None, HashMap::new(), ) diff --git a/crates/arroyo-connectors/src/mqtt/source/test.rs b/crates/arroyo-connectors/src/mqtt/source/test.rs index 4690cdc88..f184a9342 100644 --- 
a/crates/arroyo-connectors/src/mqtt/source/test.rs +++ b/crates/arroyo-connectors/src/mqtt/source/test.rs @@ -141,7 +141,7 @@ impl MqttTopicTester { command_tx.clone(), 1, vec![], - Some(ArroyoSchema::new_unkeyed( + Some(Arc::new(ArroyoSchema::new_unkeyed( Arc::new(Schema::new(vec![ Field::new( "_timestamp", @@ -151,7 +151,7 @@ impl MqttTopicTester { Field::new("value", DataType::UInt64, false), ])), 0, - )), + ))), mqtt.tables(), ) .await; diff --git a/crates/arroyo-connectors/src/redis/lookup.rs b/crates/arroyo-connectors/src/redis/lookup.rs index 0f9967de3..79400bc4d 100644 --- a/crates/arroyo-connectors/src/redis/lookup.rs +++ b/crates/arroyo-connectors/src/redis/lookup.rs @@ -82,7 +82,7 @@ impl LookupConnector for RedisLookup { } Value::BulkString(v) => { self.deserializer - .deserialize_without_timestamp(&v, Some(&additional)) + .deserialize_without_timestamp(v, Some(&additional)) .await } v => { diff --git a/crates/arroyo-connectors/src/redis/mod.rs b/crates/arroyo-connectors/src/redis/mod.rs index 9e435c712..8dd358f74 100644 --- a/crates/arroyo-connectors/src/redis/mod.rs +++ b/crates/arroyo-connectors/src/redis/mod.rs @@ -14,7 +14,6 @@ use arroyo_rpc::api_types::connections::{ ConnectionProfile, ConnectionSchema, ConnectionType, FieldType, PrimitiveType, TestSourceMessage, }; -use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::schema_resolver::FailingSchemaResolver; use arroyo_rpc::var_str::VarStr; use arroyo_rpc::OperatorConfig; @@ -453,7 +452,7 @@ impl Connector for RedisConnector { fn make_lookup( &self, profile: Self::ProfileT, - table: Self::TableT, + _: Self::TableT, config: OperatorConfig, schema: Arc, ) -> anyhow::Result> { diff --git a/crates/arroyo-connectors/src/redis/sink.rs b/crates/arroyo-connectors/src/redis/sink.rs index 6820c8b91..25eefa475 100644 --- a/crates/arroyo-connectors/src/redis/sink.rs +++ b/crates/arroyo-connectors/src/redis/sink.rs @@ -1,4 +1,4 @@ -use crate::redis::{ListOperation, RedisClient, RedisTable, TableType, Target}; +use crate::redis::{ListOperation, RedisClient, Target}; use arrow::array::{AsArray, RecordBatch}; use arroyo_formats::ser::ArrowSerializer; use arroyo_operator::context::{Collector, ErrorReporter, OperatorContext}; @@ -20,7 +20,7 @@ const FLUSH_BYTES: usize = 10 * 1024 * 1024; pub struct RedisSinkFunc { pub serializer: ArrowSerializer, pub target: Target, - pub client: RedisClient, + pub(crate) client: RedisClient, pub cmd_q: Option<(Sender, Receiver)>, pub rx: Receiver, diff --git a/crates/arroyo-formats/src/avro/de.rs b/crates/arroyo-formats/src/avro/de.rs index 4c2bfd115..7e90deffc 100644 --- a/crates/arroyo-formats/src/avro/de.rs +++ b/crates/arroyo-formats/src/avro/de.rs @@ -132,8 +132,7 @@ pub(crate) fn avro_to_json(value: AvroValue) -> JsonValue { mod tests { use crate::avro::schema::to_arrow; use crate::de::ArrowDeserializer; - use arrow_array::builder::{make_builder, ArrayBuilder}; - use arrow_array::RecordBatch; + use arrow_json::writer::record_batch_to_vec; use arrow_schema::{DataType, Field, Schema, TimeUnit}; use arroyo_rpc::df::ArroyoSchema; @@ -257,6 +256,7 @@ mod tests { Format::Avro(format), None, Arc::new(arroyo_schema.clone()), + &[], BadData::Fail {}, resolver, ), @@ -269,8 +269,7 @@ mod tests { writer_schema: Option<&str>, message: &[u8], ) -> Vec> { - let (mut deserializer, arroyo_schema) = - deserializer_with_schema(format.clone(), writer_schema); + let (mut deserializer, _) = deserializer_with_schema(format.clone(), writer_schema); let errors = deserializer .deserialize_slice(message, SystemTime::now(), 
None) diff --git a/crates/arroyo-formats/src/de.rs b/crates/arroyo-formats/src/de.rs index 2089fe616..1bdaced38 100644 --- a/crates/arroyo-formats/src/de.rs +++ b/crates/arroyo-formats/src/de.rs @@ -137,6 +137,7 @@ impl BufferDecoder { } } + #[allow(clippy::type_complexity)] fn flush( &mut self, bad_data: &BadData, @@ -350,7 +351,7 @@ impl ArrowDeserializer { .fields() .iter() .filter(|f| !metadata_names.contains(f.name())) - .map(|f| f.clone()) + .cloned() .collect::>(); Arc::new(Schema::new_with_metadata( fields, @@ -821,7 +822,7 @@ mod tests { let schema = Arc::new(ArroyoSchema::from_schema_unkeyed(schema).unwrap()); - let deserializer = ArrowDeserializer::new( + ArrowDeserializer::new( Format::Json(JsonFormat { confluent_schema_registry: false, schema_id: None, @@ -833,9 +834,7 @@ mod tests { schema, None, bad_data, - ); - - deserializer + ) } #[tokio::test] diff --git a/crates/arroyo-planner/src/extension/join.rs b/crates/arroyo-planner/src/extension/join.rs index 38dc27581..5769e67c4 100644 --- a/crates/arroyo-planner/src/extension/join.rs +++ b/crates/arroyo-planner/src/extension/join.rs @@ -10,7 +10,6 @@ use datafusion::logical_expr::{LogicalPlan, UserDefinedLogicalNodeCore}; use datafusion_proto::generated::datafusion::PhysicalPlanNode; use datafusion_proto::physical_plan::AsExecutionPlan; use prost::Message; -use std::sync::Arc; use std::time::Duration; pub(crate) const JOIN_NODE_NAME: &str = "JoinNode"; diff --git a/crates/arroyo-planner/src/plan/join.rs b/crates/arroyo-planner/src/plan/join.rs index a2ade4bb0..9780f1ea4 100644 --- a/crates/arroyo-planner/src/plan/join.rs +++ b/crates/arroyo-planner/src/plan/join.rs @@ -201,7 +201,7 @@ struct FindLookupExtension { alias: Option, } -impl<'a> TreeNodeVisitor<'a> for FindLookupExtension { +impl TreeNodeVisitor<'_> for FindLookupExtension { type Node = LogicalPlan; fn f_down(&mut self, node: &Self::Node) -> Result { @@ -266,8 +266,8 @@ fn maybe_plan_lookup_join(join: &Join) -> Result> { match r { Expr::Column(c) => Ok((l.clone(), c.clone())), e => { - return plan_err!("invalid right-side condition for lookup join: `{}`; only column references are supported", - expr_to_sql(e).map(|e| e.to_string()).unwrap_or_else(|_| e.to_string())); + plan_err!("invalid right-side condition for lookup join: `{}`; only column references are supported", + expr_to_sql(e).map(|e| e.to_string()).unwrap_or_else(|_| e.to_string())) } } }).collect::>()?; diff --git a/crates/arroyo-types/src/lib.rs b/crates/arroyo-types/src/lib.rs index f4290e8ed..31a0a0db3 100644 --- a/crates/arroyo-types/src/lib.rs +++ b/crates/arroyo-types/src/lib.rs @@ -641,6 +641,8 @@ pub fn range_for_server(i: usize, n: usize) -> RangeInclusive { start..=end } +pub const LOOKUP_KEY_INDEX_FIELD: &str = "__lookup_key_index"; + #[cfg(test)] mod tests { use super::*; @@ -693,5 +695,3 @@ mod tests { ); } } - -pub const LOOKUP_KEY_INDEX_FIELD: &str = "__lookup_key_index"; diff --git a/crates/arroyo-worker/src/arrow/lookup_join.rs b/crates/arroyo-worker/src/arrow/lookup_join.rs index 69b372e3b..7e1fc1fad 100644 --- a/crates/arroyo-worker/src/arrow/lookup_join.rs +++ b/crates/arroyo-worker/src/arrow/lookup_join.rs @@ -1,4 +1,4 @@ -use arrow::compute::{filter_record_batch, is_null}; +use arrow::compute::filter_record_batch; use arrow::row::{OwnedRow, RowConverter, SortField}; use arrow_array::cast::AsArray; use arrow_array::types::UInt64Type; @@ -152,14 +152,12 @@ impl ArrowOperator for LookupJoin { } } - BooleanArray::from(nonnull) }); let in_schema = ctx.in_schemas.first().unwrap(); let 
key_indices = in_schema.key_indices.as_ref().unwrap(); let non_keys: Vec<_> = (0..batch.num_columns()) - .into_iter() .filter(|i| !key_indices.contains(i) && *i != in_schema.timestamp_index) .collect(); From bcf7d54e6f0b81d4f677d18fbce9b001dcdbe1aa Mon Sep 17 00:00:00 2001 From: Micah Wylde Date: Sun, 12 Jan 2025 15:50:30 -0800 Subject: [PATCH 10/14] configuration and tests --- Cargo.lock | 96 +++++++++++++++++-- .../src/filesystem/source.rs | 1 + crates/arroyo-connectors/src/fluvio/source.rs | 1 + .../arroyo-connectors/src/kafka/source/mod.rs | 1 + .../arroyo-connectors/src/kinesis/source.rs | 1 + .../arroyo-connectors/src/mqtt/source/mod.rs | 1 + .../arroyo-connectors/src/nats/source/mod.rs | 1 + .../src/polling_http/operator.rs | 1 + .../arroyo-connectors/src/rabbitmq/source.rs | 1 + .../src/single_file/source.rs | 1 + crates/arroyo-connectors/src/sse/operator.rs | 1 + .../src/websocket/operator.rs | 1 + crates/arroyo-formats/src/de.rs | 29 ++++-- crates/arroyo-operator/src/context.rs | 2 + crates/arroyo-planner/src/extension/lookup.rs | 5 + crates/arroyo-planner/src/tables.rs | 43 ++++++++- crates/arroyo-rpc/proto/api.proto | 2 + crates/arroyo-worker/Cargo.toml | 1 + crates/arroyo-worker/src/arrow/lookup_join.rs | 71 +++++++++----- 19 files changed, 219 insertions(+), 41 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 564a4d56d..6bdf3f55e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1162,6 +1162,7 @@ dependencies = [ "local-ip-address", "md-5", "memchr", + "mini-moka", "object_store", "once_cell", "ordered-float 3.9.2", @@ -2202,6 +2203,12 @@ version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +[[package]] +name = "bytecount" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce" + [[package]] name = "bytemuck" version = "1.19.0" @@ -2281,6 +2288,19 @@ dependencies = [ "serde", ] +[[package]] +name = "cargo_metadata" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4acbb09d9ee8e23699b9634375c72795d095bf268439da88562cf9b501f181fa" +dependencies = [ + "camino", + "cargo-platform", + "semver", + "serde", + "serde_json", +] + [[package]] name = "cargo_metadata" version = "0.18.1" @@ -3784,6 +3804,15 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "error-chain" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d2f06b9cac1506ece98fe3231e3cc9c4410ec3d5b1f24ae1c8946f0742cdefc" +dependencies = [ + "version_check", +] + [[package]] name = "event-listener" version = "2.5.3" @@ -5963,6 +5992,21 @@ dependencies = [ "unicase", ] +[[package]] +name = "mini-moka" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c325dfab65f261f386debee8b0969da215b3fa0037e74c8a1234db7ba986d803" +dependencies = [ + "crossbeam-channel", + "crossbeam-utils", + "dashmap 5.5.3", + "skeptic", + "smallvec", + "tagptr", + "triomphe", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -7253,8 +7297,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15" dependencies = [ "bytes", - "heck 0.4.1", - "itertools 0.10.5", + "heck 0.5.0", + "itertools 0.13.0", "log", "multimap", "once_cell", @@ 
-7300,7 +7344,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e9552f850d5f0964a4e4d0bf306459ac29323ddfbae05e35a7c0d35cb0803cc5" dependencies = [ "anyhow", - "itertools 0.10.5", + "itertools 0.13.0", "proc-macro2", "quote", "syn 2.0.87", @@ -7380,6 +7424,17 @@ dependencies = [ "cc", ] +[[package]] +name = "pulldown-cmark" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57206b407293d2bcd3af849ce869d52068623f19e1b5ff8e8778e3309439682b" +dependencies = [ + "bitflags 2.6.0", + "memchr", + "unicase", +] + [[package]] name = "pyo3" version = "0.21.2" @@ -7390,7 +7445,7 @@ dependencies = [ "indoc", "libc", "memoffset", - "parking_lot 0.11.2", + "parking_lot 0.12.3", "portable-atomic", "pyo3-build-config", "pyo3-ffi", @@ -8739,6 +8794,21 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" +[[package]] +name = "skeptic" +version = "0.13.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16d23b015676c90a0f01c197bfdc786c20342c73a0afdda9025adb0bc42940a8" +dependencies = [ + "bytecount", + "cargo_metadata 0.14.2", + "error-chain", + "glob", + "pulldown-cmark", + "tempfile", + "walkdir", +] + [[package]] name = "slab" version = "0.4.9" @@ -8775,7 +8845,7 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03c3c6b7927ffe7ecaa769ee0e3994da3b8cafc8f444578982c83ecb161af917" dependencies = [ - "heck 0.4.1", + "heck 0.5.0", "proc-macro2", "quote", "syn 2.0.87", @@ -9074,6 +9144,12 @@ dependencies = [ "libc", ] +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "tap" version = "1.0.1" @@ -9811,6 +9887,12 @@ dependencies = [ "tracing-serde", ] +[[package]] +name = "triomphe" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef8f7726da4807b58ea5c96fdc122f80702030edc33b35aff9190a51148ccc85" + [[package]] name = "try-lock" version = "0.2.5" @@ -10186,7 +10268,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2990d9ea5967266ea0ccf413a4aa5c42a93dbcfda9cb49a97de6931726b12566" dependencies = [ "anyhow", - "cargo_metadata", + "cargo_metadata 0.18.1", "cfg-if", "regex", "rustversion", @@ -10404,7 +10486,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] diff --git a/crates/arroyo-connectors/src/filesystem/source.rs b/crates/arroyo-connectors/src/filesystem/source.rs index f2792f897..0f120fcbe 100644 --- a/crates/arroyo-connectors/src/filesystem/source.rs +++ b/crates/arroyo-connectors/src/filesystem/source.rs @@ -128,6 +128,7 @@ impl FileSystemSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &[], ); let parallelism = ctx.task_info.parallelism; let task_index = ctx.task_info.task_index; diff --git a/crates/arroyo-connectors/src/fluvio/source.rs b/crates/arroyo-connectors/src/fluvio/source.rs index 491a82b54..a64581053 100644 --- a/crates/arroyo-connectors/src/fluvio/source.rs +++ 
b/crates/arroyo-connectors/src/fluvio/source.rs @@ -57,6 +57,7 @@ impl SourceOperator for FluvioSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &[], ); match self.run_int(ctx, collector).await { diff --git a/crates/arroyo-connectors/src/kafka/source/mod.rs b/crates/arroyo-connectors/src/kafka/source/mod.rs index 23cb382e6..020c0da7a 100644 --- a/crates/arroyo-connectors/src/kafka/source/mod.rs +++ b/crates/arroyo-connectors/src/kafka/source/mod.rs @@ -181,6 +181,7 @@ impl KafkaSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &self.metadata_fields, ); } diff --git a/crates/arroyo-connectors/src/kinesis/source.rs b/crates/arroyo-connectors/src/kinesis/source.rs index ce43a0411..7cbaf4970 100644 --- a/crates/arroyo-connectors/src/kinesis/source.rs +++ b/crates/arroyo-connectors/src/kinesis/source.rs @@ -173,6 +173,7 @@ impl SourceOperator for KinesisSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &[], ); match self.run_int(ctx, collector).await { diff --git a/crates/arroyo-connectors/src/mqtt/source/mod.rs b/crates/arroyo-connectors/src/mqtt/source/mod.rs index 2710682f7..4f1508fd2 100644 --- a/crates/arroyo-connectors/src/mqtt/source/mod.rs +++ b/crates/arroyo-connectors/src/mqtt/source/mod.rs @@ -101,6 +101,7 @@ impl MqttSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &self.metadata_fields, ); if ctx.task_info.task_index > 0 { diff --git a/crates/arroyo-connectors/src/nats/source/mod.rs b/crates/arroyo-connectors/src/nats/source/mod.rs index 7cee2bc1e..c592b73c2 100644 --- a/crates/arroyo-connectors/src/nats/source/mod.rs +++ b/crates/arroyo-connectors/src/nats/source/mod.rs @@ -333,6 +333,7 @@ impl NatsSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &[], ); let nats_client = get_nats_client(&self.connection) diff --git a/crates/arroyo-connectors/src/polling_http/operator.rs b/crates/arroyo-connectors/src/polling_http/operator.rs index f6217421c..13a8373ae 100644 --- a/crates/arroyo-connectors/src/polling_http/operator.rs +++ b/crates/arroyo-connectors/src/polling_http/operator.rs @@ -208,6 +208,7 @@ impl PollingHttpSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &[], ); // since there's no way to partition across an http source, only read on the first task diff --git a/crates/arroyo-connectors/src/rabbitmq/source.rs b/crates/arroyo-connectors/src/rabbitmq/source.rs index 05d7a53e5..b2af95ff3 100644 --- a/crates/arroyo-connectors/src/rabbitmq/source.rs +++ b/crates/arroyo-connectors/src/rabbitmq/source.rs @@ -50,6 +50,7 @@ impl SourceOperator for RabbitmqStreamSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &[], ); match self.run_int(ctx, collector).await { diff --git a/crates/arroyo-connectors/src/single_file/source.rs b/crates/arroyo-connectors/src/single_file/source.rs index 800d21fd9..d5f79f8fe 100644 --- a/crates/arroyo-connectors/src/single_file/source.rs +++ b/crates/arroyo-connectors/src/single_file/source.rs @@ -100,6 +100,7 @@ impl SourceOperator for SingleFileSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &[], ); let state: &mut arroyo_state::tables::global_keyed_map::GlobalKeyedView = diff --git a/crates/arroyo-connectors/src/sse/operator.rs b/crates/arroyo-connectors/src/sse/operator.rs index df00ee4f9..0095447fb 100644 --- a/crates/arroyo-connectors/src/sse/operator.rs +++ 
b/crates/arroyo-connectors/src/sse/operator.rs @@ -149,6 +149,7 @@ impl SSESourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &[], ); let mut client = eventsource_client::ClientBuilder::for_url(&self.url).unwrap(); diff --git a/crates/arroyo-connectors/src/websocket/operator.rs b/crates/arroyo-connectors/src/websocket/operator.rs index c009dd548..a477079c9 100644 --- a/crates/arroyo-connectors/src/websocket/operator.rs +++ b/crates/arroyo-connectors/src/websocket/operator.rs @@ -65,6 +65,7 @@ impl SourceOperator for WebsocketSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &[], ); match self.run_int(ctx, collector).await { diff --git a/crates/arroyo-formats/src/de.rs b/crates/arroyo-formats/src/de.rs index 1bdaced38..c937a36e2 100644 --- a/crates/arroyo-formats/src/de.rs +++ b/crates/arroyo-formats/src/de.rs @@ -264,6 +264,7 @@ impl ArrowDeserializer { pub fn new( format: Format, schema: Arc, + metadata_fields: &[MetadataField], framing: Option, bad_data: BadData, ) -> Self { @@ -278,7 +279,7 @@ impl ArrowDeserializer { Arc::new(FailingSchemaResolver::new()) as Arc }; - Self::with_schema_resolver(format, framing, schema, &[], bad_data, resolver) + Self::with_schema_resolver(format, framing, schema, metadata_fields, bad_data, resolver) } pub fn with_schema_resolver( @@ -729,12 +730,13 @@ mod tests { use arrow::datatypes::Int32Type; use arrow_array::cast::AsArray; use arrow_array::types::{GenericBinaryType, Int64Type, TimestampNanosecondType}; - use arrow_schema::{Schema, TimeUnit}; + use arrow_schema::{DataType, Schema, TimeUnit}; use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::formats::{ BadData, Format, Framing, FramingMethod, JsonFormat, NewlineDelimitedFraming, RawBytesFormat, }; + use arroyo_rpc::MetadataField; use arroyo_types::{to_nanos, SourceError}; use serde_json::json; use std::sync::Arc; @@ -832,6 +834,7 @@ mod tests { timestamp_format: Default::default(), }), schema, + &[], None, bad_data, ) @@ -913,6 +916,7 @@ mod tests { let mut deserializer = ArrowDeserializer::new( Format::RawBytes(RawBytesFormat {}), arroyo_schema, + &[], None, BadData::Fail {}, ); @@ -941,7 +945,7 @@ mod tests { } #[tokio::test] - async fn test_additional_fields_deserialisation() { + async fn test_additional_fields_deserialization() { let schema = Arc::new(Schema::new(vec![ arrow_schema::Field::new("x", arrow_schema::DataType::Int64, true), arrow_schema::Field::new("y", arrow_schema::DataType::Int32, true), @@ -965,17 +969,26 @@ mod tests { timestamp_format: Default::default(), }), arroyo_schema, + &[ + MetadataField { + field_name: "y".to_string(), + key: "y".to_string(), + data_type: Some(DataType::Int64), + }, + MetadataField { + field_name: "z".to_string(), + key: "z".to_string(), + data_type: Some(DataType::Utf8), + }, + ], None, BadData::Drop {}, ); let time = SystemTime::now(); let mut additional_fields = std::collections::HashMap::new(); - let binding = "y".to_string(); - additional_fields.insert(binding.as_str(), FieldValueType::Int32(5)); - let z_value = "hello".to_string(); - let binding = "z".to_string(); - additional_fields.insert(binding.as_str(), FieldValueType::String(&z_value)); + additional_fields.insert("y", FieldValueType::Int32(5)); + additional_fields.insert("z", FieldValueType::String("hello")); let result = deserializer .deserialize_slice( diff --git a/crates/arroyo-operator/src/context.rs b/crates/arroyo-operator/src/context.rs index 747ed91e4..e013feb99 100644 --- a/crates/arroyo-operator/src/context.rs +++ 
b/crates/arroyo-operator/src/context.rs @@ -320,6 +320,7 @@ impl SourceCollector { format: Format, framing: Option, bad_data: Option, + metadata_fields: &[MetadataField], ) { if self.deserializer.is_some() { panic!("Deserialize already initialized"); @@ -328,6 +329,7 @@ impl SourceCollector { self.deserializer = Some(ArrowDeserializer::new( format, self.out_schema.clone(), + metadata_fields, framing, bad_data.unwrap_or_default(), )); diff --git a/crates/arroyo-planner/src/extension/lookup.rs b/crates/arroyo-planner/src/extension/lookup.rs index 210d19d38..4e2a8f4e6 100644 --- a/crates/arroyo-planner/src/extension/lookup.rs +++ b/crates/arroyo-planner/src/extension/lookup.rs @@ -118,6 +118,11 @@ impl ArroyoExtension for LookupJoin { return plan_err!("unsupported join type '{j}' for lookup join; only inner and left joins are supported"); } }, + ttl_micros: self + .connector + .lookup_cache_ttl + .map(|t| t.as_micros() as u64), + max_capacity_bytes: self.connector.lookup_cache_max_bytes, }; let incoming_edge = diff --git a/crates/arroyo-planner/src/tables.rs b/crates/arroyo-planner/src/tables.rs index 7b2638a51..b7f5b8f9d 100644 --- a/crates/arroyo-planner/src/tables.rs +++ b/crates/arroyo-planner/src/tables.rs @@ -1,10 +1,10 @@ +use arrow::compute::kernels::cast_utils::parse_interval_day_time; +use arrow_schema::{DataType, Field, FieldRef, Schema}; +use arroyo_connectors::connector_for_type; use std::str::FromStr; use std::sync::Arc; use std::{collections::HashMap, time::Duration}; -use arrow_schema::{DataType, Field, FieldRef, Schema}; -use arroyo_connectors::connector_for_type; - use crate::extension::remote_table::RemoteTableExtension; use crate::types::convert_data_type; use crate::{ @@ -72,8 +72,11 @@ pub struct ConnectorTable { pub watermark_field: Option, pub idle_time: Option, pub primary_keys: Arc>, - pub inferred_fields: Option>, + + // for lookup tables + pub lookup_cache_max_bytes: Option, + pub lookup_cache_ttl: Option, } multifield_partial_ord!( @@ -206,6 +209,8 @@ impl From for ConnectorTable { idle_time: DEFAULT_IDLE_TIME, primary_keys: Arc::new(vec![]), inferred_fields: None, + lookup_cache_max_bytes: None, + lookup_cache_ttl: None, } } } @@ -214,6 +219,7 @@ impl ConnectorTable { fn from_options( name: &str, connector: &str, + temporary: bool, mut fields: Vec, primary_keys: Vec, options: &mut HashMap, @@ -247,6 +253,14 @@ impl ConnectorTable { let framing = Framing::from_opts(options) .map_err(|e| DataFusionError::Plan(format!("invalid framing: '{e}'")))?; + if temporary { + if let Some(t) = options.insert("type".to_string(), "lookup".to_string()) { + if t != "lookup" { + return plan_err!("Cannot have a temporary table with type '{}'; temporary tables must be type 'lookup'", t); + } + } + } + let mut input_to_schema_fields = fields.clone(); if let Some(Format::Json(JsonFormat { debezium: true, .. })) = &format { @@ -318,6 +332,26 @@ impl ConnectorTable { .filter(|t| *t <= 0) .map(|t| Duration::from_micros(t as u64)); + table.lookup_cache_max_bytes = options + .remove("lookup.cache.max_bytes") + .map(|t| u64::from_str(&t)) + .transpose() + .map_err(|_| { + DataFusionError::Plan("lookup.cache.max_bytes must be set to a number".to_string()) + })?; + + table.lookup_cache_ttl = options + .remove("lookup.cache.ttl") + .map(|t| parse_interval_day_time(&t)) + .transpose() + .map_err(|e| { + DataFusionError::Plan(format!("lookup.cache.ttl must be a valid interval ({})", e)) + })? 
+ .map(|t| { + Duration::from_secs(t.days as u64 * 60 * 60 * 24) + + Duration::from_millis(t.milliseconds as u64) + }); + if !options.is_empty() { let keys: Vec = options.keys().map(|s| format!("'{}'", s)).collect(); return plan_err!( @@ -755,6 +789,7 @@ impl Table { let table = ConnectorTable::from_options( &name, connector, + *temporary, fields, primary_keys, &mut with_map, diff --git a/crates/arroyo-rpc/proto/api.proto b/crates/arroyo-rpc/proto/api.proto index 06bacfa47..74874b6cb 100644 --- a/crates/arroyo-rpc/proto/api.proto +++ b/crates/arroyo-rpc/proto/api.proto @@ -81,6 +81,8 @@ message LookupJoinOperator { ConnectorOp connector = 3; repeated LookupJoinCondition key_exprs = 4; JoinType join_type = 5; + optional uint64 ttl_micros = 6; + optional uint64 max_capacity_bytes = 7; } message WindowFunctionOperator { diff --git a/crates/arroyo-worker/Cargo.toml b/crates/arroyo-worker/Cargo.toml index 270b22000..27c79ef08 100644 --- a/crates/arroyo-worker/Cargo.toml +++ b/crates/arroyo-worker/Cargo.toml @@ -77,6 +77,7 @@ itertools = "0.12.0" async-ffi = "0.5.0" dlopen2 = "0.7.0" dlopen2_derive = "0.4.0" +mini-moka = { version = "0.10.3" } [dev-dependencies] diff --git a/crates/arroyo-worker/src/arrow/lookup_join.rs b/crates/arroyo-worker/src/arrow/lookup_join.rs index 7e1fc1fad..d80c1423e 100644 --- a/crates/arroyo-worker/src/arrow/lookup_join.rs +++ b/crates/arroyo-worker/src/arrow/lookup_join.rs @@ -19,9 +19,11 @@ use datafusion::physical_expr::PhysicalExpr; use datafusion_proto::physical_plan::from_proto::parse_physical_expr; use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; use datafusion_proto::protobuf::PhysicalExprNode; +use mini_moka::sync::Cache; use prost::Message; use std::collections::HashMap; use std::sync::Arc; +use std::time::Duration; #[derive(Copy, Clone, PartialEq)] pub(crate) enum LookupJoinType { @@ -33,7 +35,7 @@ pub(crate) enum LookupJoinType { pub struct LookupJoin { connector: Box, key_exprs: Vec>, - cache: HashMap, OwnedRow>, + cache: Option>, key_row_converter: RowConverter, result_row_converter: RowConverter, join_type: LookupJoinType, @@ -68,12 +70,22 @@ impl ArrowOperator for LookupJoin { key_map.entry(row.owned()).or_default().push(i); } - let mut uncached_keys = Vec::new(); - for k in key_map.keys() { - if !self.cache.contains_key(k.row().as_ref()) { - uncached_keys.push(k.clone()); + let uncached_keys = if let Some(cache) = &mut self.cache { + let mut uncached_keys = Vec::new(); + for k in key_map.keys() { + if !cache.contains_key(k) { + uncached_keys.push(k); + } } - } + uncached_keys + } else { + key_map.keys().collect() + }; + + let mut results = HashMap::new(); + + #[allow(unused_assignments)] + let mut result_rows = None; if !uncached_keys.is_empty() { let cols = self @@ -93,16 +105,17 @@ impl ArrowOperator for LookupJoin { let keys = result_batch.remove_column(key_idx_col); let keys = keys.as_primitive::(); - let result_rows = self - .result_row_converter - .convert_columns(result_batch.columns()) - .unwrap(); + result_rows = Some( + self.result_row_converter + .convert_columns(result_batch.columns()) + .unwrap(), + ); - for (v, idx) in result_rows.iter().zip(keys) { - self.cache.insert( - uncached_keys[idx.unwrap() as usize].as_ref().to_vec(), - v.owned(), - ); + for (v, idx) in result_rows.as_ref().unwrap().iter().zip(keys) { + results.insert(uncached_keys[idx.unwrap() as usize].as_ref(), v); + if let Some(cache) = &mut self.cache { + cache.insert(uncached_keys[idx.unwrap() as usize].clone(), v.owned()); + } } } } @@ -112,12 +125,13 
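// The two table options introduced above: `lookup.cache.max_bytes` parses as
// a plain u64 and `lookup.cache.ttl` as a SQL day-time interval folded into a
// Duration, using the same arrow cast-util the patch imports. A sketch of the
// TTL half:

use std::time::Duration;

use arrow::compute::kernels::cast_utils::parse_interval_day_time;

fn parse_lookup_ttl(s: &str) -> Result<Duration, String> {
    let interval = parse_interval_day_time(s)
        .map_err(|e| format!("lookup.cache.ttl must be a valid interval ({e})"))?;
    Ok(Duration::from_secs(interval.days as u64 * 24 * 60 * 60)
        + Duration::from_millis(interval.milliseconds as u64))
}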
@@ impl ArrowOperator for LookupJoin { .empty_rows(batch.num_rows(), batch.num_rows() * 10); for row in rows.iter() { - output_rows.push( - self.cache - .get(row.data()) - .expect("row should be cached") - .row(), - ); + let row = self + .cache + .as_mut() + .and_then(|c| c.get(&row.owned())) + .unwrap_or_else(|| results.get(row.as_ref()).unwrap().owned()); + + output_rows.push(row.row()); } let right_side = self @@ -225,9 +239,22 @@ impl OperatorConstructor for LookupJoinConstructor { .unwrap_or_else(|| panic!("No connector with name '{}'", op.connector)) .make_lookup(operator_config.clone(), lookup_schema.clone())?; + let max_capacity_bytes = config.max_capacity_bytes.unwrap_or_else(|| 8 * 1024 * 1024); + let cache = (max_capacity_bytes > 0).then(|| { + let mut c = Cache::builder() + .weigher(|k: &OwnedRow, v: &OwnedRow| (k.as_ref().len() + v.as_ref().len()) as u32) + .max_capacity(max_capacity_bytes); + + if let Some(ttl) = config.ttl_micros { + c = c.time_to_live(Duration::from_micros(ttl)); + } + + c.build() + }); + Ok(ConstructedOperator::from_operator(Box::new(LookupJoin { connector, - cache: Default::default(), + cache, key_row_converter: RowConverter::new( exprs .iter() From 539383eef1d7ec87ce1c5c881c1929ee7b79a5a1 Mon Sep 17 00:00:00 2001 From: Micah Wylde Date: Sun, 12 Jan 2025 20:49:24 -0800 Subject: [PATCH 11/14] tests --- crates/arroyo-connectors/src/kafka/mod.rs | 2 ++ crates/arroyo-formats/src/de.rs | 21 +++++++++++---------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/crates/arroyo-connectors/src/kafka/mod.rs b/crates/arroyo-connectors/src/kafka/mod.rs index 515999627..14bb5455e 100644 --- a/crates/arroyo-connectors/src/kafka/mod.rs +++ b/crates/arroyo-connectors/src/kafka/mod.rs @@ -664,6 +664,7 @@ impl KafkaTester { let mut deserializer = ArrowDeserializer::new( format.clone(), Arc::new(aschema), + &schema.metadata_fields(), None, BadData::Fail {}, ); @@ -701,6 +702,7 @@ impl KafkaTester { let mut deserializer = ArrowDeserializer::new( format.clone(), Arc::new(aschema), + &schema.metadata_fields(), None, BadData::Fail {}, ); diff --git a/crates/arroyo-formats/src/de.rs b/crates/arroyo-formats/src/de.rs index c937a36e2..7d089b9c6 100644 --- a/crates/arroyo-formats/src/de.rs +++ b/crates/arroyo-formats/src/de.rs @@ -293,7 +293,7 @@ impl ArrowDeserializer { Self::with_schema_resolver_and_raw_schema( format, framing, - Arc::new(schema.schema_without_timestamp()), + schema.schema.clone(), Some(schema.timestamp_index), metadata_fields, bad_data, @@ -329,7 +329,7 @@ impl ArrowDeserializer { fn with_schema_resolver_and_raw_schema( format: Format, framing: Option, - schema_without_timestamp: Arc, + schema: Arc, timestamp_idx: Option, metadata_fields: &[MetadataField], bad_data: BadData, @@ -348,20 +348,20 @@ impl ArrowDeserializer { let metadata_names: HashSet<_> = metadata_fields.iter().map(|f| &f.field_name).collect(); let schema_without_additional = { - let fields = schema_without_timestamp + let fields = schema .fields() .iter() - .filter(|f| !metadata_names.contains(f.name())) + .filter(|f| !metadata_names.contains(f.name()) && f.name() != TIMESTAMP_FIELD) .cloned() .collect::>(); - Arc::new(Schema::new_with_metadata( - fields, - schema_without_timestamp.metadata.clone(), - )) + Arc::new(Schema::new_with_metadata(fields, schema.metadata.clone())) }; let buffer_decoder = match format { - Format::Json(..) + Format::Json(JsonFormat { + unstructured: false, + .. + }) | Format::Avro(AvroFormat { into_unstructured_json: false, .. 
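// The cache wired up above, reduced to its configuration: mini-moka with a
// byte-based weigher, a size cap (a cap of 0 disables caching entirely, which
// is why the operator's cache field is an Option), and an optional TTL.
// String keys and values stand in here for the OwnedRow types the operator
// actually stores:

use std::time::Duration;

use mini_moka::sync::Cache;

fn build_lookup_cache(
    max_capacity_bytes: u64,
    ttl: Option<Duration>,
) -> Option<Cache<String, String>> {
    (max_capacity_bytes > 0).then(|| {
        let mut b = Cache::builder()
            .weigher(|k: &String, v: &String| (k.len() + v.len()) as u32)
            .max_capacity(max_capacity_bytes);
        if let Some(ttl) = ttl {
            b = b.time_to_live(ttl);
        }
        b.build()
    })
}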
@@ -388,7 +388,7 @@ impl ArrowDeserializer { buffer_decoder, timestamp_builder: timestamp_idx .map(|i| (i, TimestampNanosecondBuilder::with_capacity(128))), - final_schema: schema_without_timestamp, + final_schema: schema, decoder_schema: schema_without_additional, schema_registry: Arc::new(Mutex::new(HashMap::new())), bad_data, @@ -1000,6 +1000,7 @@ mod tests { assert!(result.is_empty()); let batch = deserializer.flush_buffer().unwrap().unwrap(); + println!("batch ={:?}", batch); assert_eq!(batch.num_rows(), 1); assert_eq!(batch.columns()[0].as_primitive::().value(0), 5); assert_eq!(batch.columns()[1].as_primitive::().value(0), 5); From 0aaf1a707d92d8fec7914a18301baf3c7c59fcfd Mon Sep 17 00:00:00 2001 From: Micah Wylde Date: Sun, 12 Jan 2025 21:14:27 -0800 Subject: [PATCH 12/14] clippy --- crates/arroyo-worker/src/arrow/lookup_join.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/arroyo-worker/src/arrow/lookup_join.rs b/crates/arroyo-worker/src/arrow/lookup_join.rs index d80c1423e..061fa3476 100644 --- a/crates/arroyo-worker/src/arrow/lookup_join.rs +++ b/crates/arroyo-worker/src/arrow/lookup_join.rs @@ -239,7 +239,7 @@ impl OperatorConstructor for LookupJoinConstructor { .unwrap_or_else(|| panic!("No connector with name '{}'", op.connector)) .make_lookup(operator_config.clone(), lookup_schema.clone())?; - let max_capacity_bytes = config.max_capacity_bytes.unwrap_or_else(|| 8 * 1024 * 1024); + let max_capacity_bytes = config.max_capacity_bytes.unwrap_or(8 * 1024 * 1024); let cache = (max_capacity_bytes > 0).then(|| { let mut c = Cache::builder() .weigher(|k: &OwnedRow, v: &OwnedRow| (k.as_ref().len() + v.as_ref().len()) as u32) From 129033353327f41b895b95facba2978fe8eaf415 Mon Sep 17 00:00:00 2001 From: Micah Wylde Date: Mon, 13 Jan 2025 14:23:11 -0800 Subject: [PATCH 13/14] Enforce primary keys for lookup joins --- crates/arroyo-connectors/src/impulse/mod.rs | 1 + crates/arroyo-connectors/src/nexmark/mod.rs | 1 + crates/arroyo-connectors/src/redis/mod.rs | 20 +++++-- crates/arroyo-planner/src/plan/join.rs | 22 +++++--- crates/arroyo-planner/src/tables.rs | 1 + .../error_lookup_join_non_primary_key.sql | 21 ++++++++ .../test/queries/error_missing_redis_key.sql | 19 +++++++ .../src/test/queries/lookup_join.sql | 52 ++++++++----------- .../arroyo-rpc/src/api_types/connections.rs | 10 ++-- 9 files changed, 103 insertions(+), 44 deletions(-) create mode 100644 crates/arroyo-planner/src/test/queries/error_lookup_join_non_primary_key.sql create mode 100644 crates/arroyo-planner/src/test/queries/error_missing_redis_key.sql diff --git a/crates/arroyo-connectors/src/impulse/mod.rs b/crates/arroyo-connectors/src/impulse/mod.rs index 0e7d83cd5..22cdaca2d 100644 --- a/crates/arroyo-connectors/src/impulse/mod.rs +++ b/crates/arroyo-connectors/src/impulse/mod.rs @@ -34,6 +34,7 @@ pub fn impulse_schema() -> ConnectionSchema { ], definition: None, inferred: None, + primary_keys: Default::default(), } } diff --git a/crates/arroyo-connectors/src/nexmark/mod.rs b/crates/arroyo-connectors/src/nexmark/mod.rs index e7b7a859b..752163f12 100644 --- a/crates/arroyo-connectors/src/nexmark/mod.rs +++ b/crates/arroyo-connectors/src/nexmark/mod.rs @@ -91,6 +91,7 @@ pub fn nexmark_schema() -> ConnectionSchema { .collect(), definition: None, inferred: None, + primary_keys: Default::default(), } } diff --git a/crates/arroyo-connectors/src/redis/mod.rs b/crates/arroyo-connectors/src/redis/mod.rs index 8dd358f74..2f84469e4 100644 --- a/crates/arroyo-connectors/src/redis/mod.rs +++ 
b/crates/arroyo-connectors/src/redis/mod.rs @@ -300,9 +300,23 @@ impl Connector for RedisConnector { } let sink = match typ.as_str() { - "lookup" => TableType::Lookup { - lookup: Default::default(), - }, + "lookup" => { + // for look-up tables, we require that there's a primary key metadata field + for f in &schema.fields { + if schema.primary_keys.contains(&f.field_name) { + if f.metadata_key.as_ref().map(|k| k != "key").unwrap_or(true) { + bail!( + "Redis lookup tables must have a PRIMARY KEY field defined as \ + `field_name TEXT GENERATED ALWAYS AS (metadata('key')) STORED`" + ); + } + } + } + + TableType::Lookup { + lookup: Default::default(), + } + } "sink" => { let target = match pull_opt("target", options)?.as_str() { "string" => Target::StringTable { diff --git a/crates/arroyo-planner/src/plan/join.rs b/crates/arroyo-planner/src/plan/join.rs index 9780f1ea4..0f532d84d 100644 --- a/crates/arroyo-planner/src/plan/join.rs +++ b/crates/arroyo-planner/src/plan/join.rs @@ -262,9 +262,22 @@ fn maybe_plan_lookup_join(join: &Join) -> Result> { return plan_err!("filter join conditions are not supported for lookup joins; must have an equality condition"); } + let mut lookup = FindLookupExtension::default(); + join.right.visit(&mut lookup)?; + + let connector = lookup + .table + .expect("right side of join does not have lookup"); + let on = join.on.iter().map(|(l, r)| { match r { - Expr::Column(c) => Ok((l.clone(), c.clone())), + Expr::Column(c) => { + if !connector.primary_keys.contains(&c.name) { + plan_err!("the right-side of a look-up join condition must be a PRIMARY KEY column, but '{}' is not", c.name) + } else { + Ok((l.clone(), c.clone())) + } + }, e => { plan_err!("invalid right-side condition for lookup join: `{}`; only column references are supported", expr_to_sql(e).map(|e| e.to_string()).unwrap_or_else(|_| e.to_string())) @@ -272,13 +285,6 @@ fn maybe_plan_lookup_join(join: &Join) -> Result> { } }).collect::>()?; - let mut lookup = FindLookupExtension::default(); - join.right.visit(&mut lookup)?; - - let connector = lookup - .table - .expect("right side of join does not have lookup"); - let left_input = JoinRewriter::create_join_key_plan( join.left.clone(), join.on.iter().map(|(l, _)| l.clone()).collect(), diff --git a/crates/arroyo-planner/src/tables.rs b/crates/arroyo-planner/src/tables.rs index b7f5b8f9d..6d8a06b3b 100644 --- a/crates/arroyo-planner/src/tables.rs +++ b/crates/arroyo-planner/src/tables.rs @@ -308,6 +308,7 @@ impl ConnectorTable { schema_fields, None, Some(fields.is_empty()), + primary_keys.iter().cloned().collect(), ) .map_err(|e| DataFusionError::Plan(format!("could not create connection schema: {}", e)))?; diff --git a/crates/arroyo-planner/src/test/queries/error_lookup_join_non_primary_key.sql b/crates/arroyo-planner/src/test/queries/error_lookup_join_non_primary_key.sql new file mode 100644 index 000000000..cf5b486bf --- /dev/null +++ b/crates/arroyo-planner/src/test/queries/error_lookup_join_non_primary_key.sql @@ -0,0 +1,21 @@ +--fail=the right-side of a look-up join condition must be a PRIMARY KEY column, but 'value' is not +create table impulse with ( + connector = 'impulse', + event_rate = '2' +); + +create temporary table lookup ( + key TEXT PRIMARY KEY GENERATED ALWAYS AS (metadata('key')) STORED, + value TEXT, + len INT +) with ( + connector = 'redis', + format = 'raw_string', + address = 'redis://localhost:6379', + format = 'json', + 'lookup.cache.max_bytes' = '100000' +); + +select A.counter, B.key, B.value, len +from impulse A inner join lookup B 
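// The planner check above in plain form: every right-side column of an
// equality condition must be one of the lookup table's primary keys, since
// the primary key is the only thing the connector can fetch by. Simplified to
// strings rather than DataFusion expressions:

use std::collections::HashSet;

fn check_lookup_join_keys(
    right_columns: &[String],
    primary_keys: &HashSet<String>,
) -> Result<(), String> {
    for c in right_columns {
        if !primary_keys.contains(c) {
            return Err(format!(
                "the right-side of a look-up join condition must be a PRIMARY KEY column, but '{c}' is not"
            ));
        }
    }
    Ok(())
}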
+on cast((A.counter % 10) as TEXT) = B.value;
\ No newline at end of file
diff --git a/crates/arroyo-planner/src/test/queries/error_missing_redis_key.sql b/crates/arroyo-planner/src/test/queries/error_missing_redis_key.sql
new file mode 100644
index 000000000..537fbaf07
--- /dev/null
+++ b/crates/arroyo-planner/src/test/queries/error_missing_redis_key.sql
@@ -0,0 +1,19 @@
+--fail=Redis lookup tables must have a PRIMARY KEY field defined as `field_name TEXT GENERATED ALWAYS AS (metadata('key')) STORED`
+create table impulse with (
+    connector = 'impulse',
+    event_rate = '2'
+);
+
+create table lookup (
+    key TEXT PRIMARY KEY,
+    value TEXT
+) with (
+    connector = 'redis',
+    format = 'json',
+    address = 'redis://localhost:6379',
+    type = 'lookup'
+);
+
+select A.counter, B.key, B.value
+from impulse A left join lookup B
+on cast((A.counter % 10) as TEXT) = B.key;
diff --git a/crates/arroyo-planner/src/test/queries/lookup_join.sql b/crates/arroyo-planner/src/test/queries/lookup_join.sql
index d0c12c31d..05ba36ace 100644
--- a/crates/arroyo-planner/src/test/queries/lookup_join.sql
+++ b/crates/arroyo-planner/src/test/queries/lookup_join.sql
@@ -1,39 +1,30 @@
-CREATE TABLE orders (
-    order_id INT,
-    user_id INT,
-    product_id INT,
-    quantity INT,
-    order_timestamp TIMESTAMP
-) with (
+CREATE TABLE events (
+    event_id TEXT,
+    timestamp TIMESTAMP,
+    customer_id TEXT,
+    event_type TEXT
+) WITH (
     connector = 'kafka',
-    bootstrap_servers = 'localhost:9092',
+    topic = 'events',
     type = 'source',
-    topic = 'orders',
-    format = 'json'
+    format = 'json',
+    bootstrap_servers = 'broker:9092'
 );
 
-CREATE TEMPORARY TABLE products (
-    key TEXT PRIMARY KEY,
-    product_name TEXT,
-    unit_price FLOAT,
-    category TEXT,
-    last_updated TIMESTAMP
+create temporary table customers (
+    customer_id TEXT PRIMARY KEY GENERATED ALWAYS AS (metadata('key')) STORED,
+    customer_name TEXT,
+    plan TEXT
 ) with (
     connector = 'redis',
+    address = 'redis://localhost:6379',
     format = 'json',
-    type = 'lookup',
-    address = 'redis://localhost:6379'
+    'lookup.cache.max_bytes' = '1000000',
+    'lookup.cache.ttl' = '5 second'
 );
 
-SELECT
-    o.order_id,
-    o.user_id,
-    o.quantity,
-    o.order_timestamp,
-    p.product_name,
-    p.unit_price,
-    p.category,
-    (o.quantity * p.unit_price) as total_amount
-FROM orders o
-    JOIN products p
-    ON concat('blah', o.product_id) = p.key;
\ No newline at end of file
+SELECT e.event_id, e.timestamp, e.customer_id, e.event_type, c.customer_name, c.plan
+FROM events e
+LEFT JOIN customers c
+ON e.customer_id = c.customer_id
+WHERE c.plan = 'Premium';
diff --git a/crates/arroyo-rpc/src/api_types/connections.rs b/crates/arroyo-rpc/src/api_types/connections.rs
index a47d58151..63a30336a 100644
--- a/crates/arroyo-rpc/src/api_types/connections.rs
+++ b/crates/arroyo-rpc/src/api_types/connections.rs
@@ -1,14 +1,14 @@
+use crate::df::{ArroyoSchema, ArroyoSchemaRef};
 use crate::formats::{BadData, Format, Framing};
 use crate::{primitive_to_sql, MetadataField};
+use ahash::HashSet;
 use anyhow::bail;
 use arrow_schema::{DataType, Field, Fields, TimeUnit};
+use arroyo_types::ArroyoExtensionType;
 use serde::{Deserialize, Serialize};
 use std::collections::{BTreeMap, HashMap};
 use std::fmt::{Display, Formatter};
 use std::sync::Arc;
-
-use crate::df::{ArroyoSchema, ArroyoSchemaRef};
-use arroyo_types::ArroyoExtensionType;
 use utoipa::{IntoParams, ToSchema};
 
 #[derive(Serialize, Deserialize, Clone, Debug, ToSchema)]
@@ -252,6 +252,8 @@ pub struct ConnectionSchema {
     pub fields: Vec<SourceField>,
     pub definition: Option<SchemaDefinition>,
    pub inferred: Option<bool>,
+    #[serde(default)]
+    pub primary_keys: HashSet<String>,
 }
 
 impl ConnectionSchema {
@@ -263,6 +265,7 @@ impl ConnectionSchema {
         fields: Vec<SourceField>,
         definition: Option<SchemaDefinition>,
         inferred: Option<bool>,
+        primary_keys: HashSet<String>,
     ) -> anyhow::Result<Self> {
         let s = ConnectionSchema {
             format,
@@ -272,6 +275,7 @@ impl ConnectionSchema {
             fields,
             definition,
             inferred,
+            primary_keys,
         };
 
         s.validate()

From 450f212c90c1e642ec39817084f028f54b6c52b2 Mon Sep 17 00:00:00 2001
From: Micah Wylde
Date: Mon, 13 Jan 2025 14:43:35 -0800
Subject: [PATCH 14/14] clippy

---
 crates/arroyo-connectors/src/redis/mod.rs      | 14 +++++++-------
 crates/arroyo-rpc/src/api_types/connections.rs |  1 +
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/crates/arroyo-connectors/src/redis/mod.rs b/crates/arroyo-connectors/src/redis/mod.rs
index 2f84469e4..46348c6c7 100644
--- a/crates/arroyo-connectors/src/redis/mod.rs
+++ b/crates/arroyo-connectors/src/redis/mod.rs
@@ -303,13 +303,13 @@ impl Connector for RedisConnector {
             "lookup" => {
                 // for lookup tables, we require that there's a primary key metadata field
                 for f in &schema.fields {
-                    if schema.primary_keys.contains(&f.field_name) {
-                        if f.metadata_key.as_ref().map(|k| k != "key").unwrap_or(true) {
-                            bail!(
-                                "Redis lookup tables must have a PRIMARY KEY field defined as \
-                                `field_name TEXT GENERATED ALWAYS AS (metadata('key')) STORED`"
-                            );
-                        }
+                    if schema.primary_keys.contains(&f.field_name)
+                        && f.metadata_key.as_ref().map(|k| k != "key").unwrap_or(true)
+                    {
+                        bail!(
+                            "Redis lookup tables must have a PRIMARY KEY field defined as \
+                            `field_name TEXT GENERATED ALWAYS AS (metadata('key')) STORED`"
+                        );
                     }
                 }
 
diff --git a/crates/arroyo-rpc/src/api_types/connections.rs b/crates/arroyo-rpc/src/api_types/connections.rs
index 63a30336a..5dfbd3429 100644
--- a/crates/arroyo-rpc/src/api_types/connections.rs
+++ b/crates/arroyo-rpc/src/api_types/connections.rs
@@ -257,6 +257,7 @@ pub struct ConnectionSchema {
 }
 
 impl ConnectionSchema {
+    #[allow(clippy::too_many_arguments)]
     pub fn try_new(
         format: Option<Format>,
         bad_data: Option<BadData>,
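
For reference, the invariant enforced by the `lookup` branch above can be read as a small predicate: a schema is a valid Redis lookup table only if every PRIMARY KEY field is generated from the `key` metadata. The following is a minimal standalone sketch of that check; the `Field` struct is a stand-in for the connector's schema field type and `valid_redis_lookup` is an illustrative helper, neither of which is part of the patch:

use std::collections::HashSet;

// Stand-in for the connector's schema field type; only the two fields the
// check actually reads are modeled here.
struct Field {
    field_name: String,
    metadata_key: Option<String>,
}

// Mirrors the bail! condition: reject any PRIMARY KEY field that does not
// carry metadata('key').
fn valid_redis_lookup(primary_keys: &HashSet<String>, fields: &[Field]) -> bool {
    fields.iter().all(|f| {
        !primary_keys.contains(&f.field_name) || f.metadata_key.as_deref() == Some("key")
    })
}

fn main() {
    let pks: HashSet<String> = std::iter::once("key".to_string()).collect();

    // accepted: `key TEXT PRIMARY KEY GENERATED ALWAYS AS (metadata('key')) STORED`
    let good = [Field { field_name: "key".into(), metadata_key: Some("key".into()) }];
    // rejected: `key TEXT PRIMARY KEY` with no metadata, as in error_missing_redis_key.sql
    let bad = [Field { field_name: "key".into(), metadata_key: None }];

    assert!(valid_redis_lookup(&pks, &good));
    assert!(!valid_redis_lookup(&pks, &bad));
}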