From 81bc429d6d5eb2948fd10368c5bc4243185f44ec Mon Sep 17 00:00:00 2001 From: Micah Wylde Date: Mon, 13 Jan 2025 15:05:40 -0800 Subject: [PATCH] Lookup joins (#821) --- Cargo.lock | 96 +- crates/arroyo-api/src/connection_tables.rs | 12 + .../src/filesystem/sink/local.rs | 2 +- .../src/filesystem/sink/mod.rs | 3 +- .../src/filesystem/source.rs | 1 + crates/arroyo-connectors/src/fluvio/source.rs | 1 + crates/arroyo-connectors/src/impulse/mod.rs | 1 + crates/arroyo-connectors/src/kafka/mod.rs | 24 +- .../arroyo-connectors/src/kafka/sink/test.rs | 2 +- .../arroyo-connectors/src/kafka/source/mod.rs | 4 +- .../src/kafka/source/test.rs | 21 +- .../arroyo-connectors/src/kinesis/source.rs | 1 + .../arroyo-connectors/src/mqtt/sink/test.rs | 2 +- .../arroyo-connectors/src/mqtt/source/mod.rs | 3 +- .../arroyo-connectors/src/mqtt/source/test.rs | 4 +- .../arroyo-connectors/src/nats/source/mod.rs | 1 + crates/arroyo-connectors/src/nexmark/mod.rs | 1 + .../src/polling_http/operator.rs | 1 + .../arroyo-connectors/src/rabbitmq/source.rs | 1 + crates/arroyo-connectors/src/redis/lookup.rs | 100 +++ crates/arroyo-connectors/src/redis/mod.rs | 221 +++-- .../src/redis/operator/mod.rs | 1 - .../src/redis/{operator => }/sink.rs | 102 ++- crates/arroyo-connectors/src/redis/table.json | 9 + .../src/single_file/source.rs | 1 + crates/arroyo-connectors/src/sse/operator.rs | 1 + .../src/websocket/operator.rs | 1 + crates/arroyo-datastream/src/logical.rs | 32 +- crates/arroyo-formats/src/avro/de.rs | 31 +- crates/arroyo-formats/src/de.rs | 827 ++++++++++-------- crates/arroyo-operator/src/connector.rs | 44 +- crates/arroyo-operator/src/context.rs | 92 +- crates/arroyo-operator/src/operator.rs | 2 +- crates/arroyo-planner/src/builder.rs | 2 +- crates/arroyo-planner/src/extension/join.rs | 3 +- crates/arroyo-planner/src/extension/lookup.rs | 191 ++++ crates/arroyo-planner/src/extension/mod.rs | 4 + crates/arroyo-planner/src/extension/sink.rs | 1 + crates/arroyo-planner/src/lib.rs | 5 +- crates/arroyo-planner/src/plan/join.rs | 128 ++- crates/arroyo-planner/src/rewriters.rs | 15 + crates/arroyo-planner/src/schemas.rs | 2 +- crates/arroyo-planner/src/tables.rs | 84 +- .../error_lookup_join_non_primary_key.sql | 21 + .../test/queries/error_missing_redis_key.sql | 19 + .../src/test/queries/lookup_join.sql | 31 + crates/arroyo-rpc/proto/api.proto | 15 + .../arroyo-rpc/src/api_types/connections.rs | 14 +- crates/arroyo-rpc/src/df.rs | 21 +- crates/arroyo-rpc/src/lib.rs | 13 + crates/arroyo-types/src/lib.rs | 4 +- crates/arroyo-worker/Cargo.toml | 1 + crates/arroyo-worker/src/arrow/lookup_join.rs | 275 ++++++ crates/arroyo-worker/src/arrow/mod.rs | 1 + .../src/arrow/tumbling_aggregating_window.rs | 2 +- crates/arroyo-worker/src/engine.rs | 12 +- 56 files changed, 1854 insertions(+), 655 deletions(-) create mode 100644 crates/arroyo-connectors/src/redis/lookup.rs delete mode 100644 crates/arroyo-connectors/src/redis/operator/mod.rs rename crates/arroyo-connectors/src/redis/{operator => }/sink.rs (82%) create mode 100644 crates/arroyo-planner/src/extension/lookup.rs create mode 100644 crates/arroyo-planner/src/test/queries/error_lookup_join_non_primary_key.sql create mode 100644 crates/arroyo-planner/src/test/queries/error_missing_redis_key.sql create mode 100644 crates/arroyo-planner/src/test/queries/lookup_join.sql create mode 100644 crates/arroyo-worker/src/arrow/lookup_join.rs diff --git a/Cargo.lock b/Cargo.lock index 564a4d56d..6bdf3f55e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1162,6 +1162,7 @@ dependencies = [ 
"local-ip-address", "md-5", "memchr", + "mini-moka", "object_store", "once_cell", "ordered-float 3.9.2", @@ -2202,6 +2203,12 @@ version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +[[package]] +name = "bytecount" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce" + [[package]] name = "bytemuck" version = "1.19.0" @@ -2281,6 +2288,19 @@ dependencies = [ "serde", ] +[[package]] +name = "cargo_metadata" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4acbb09d9ee8e23699b9634375c72795d095bf268439da88562cf9b501f181fa" +dependencies = [ + "camino", + "cargo-platform", + "semver", + "serde", + "serde_json", +] + [[package]] name = "cargo_metadata" version = "0.18.1" @@ -3784,6 +3804,15 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "error-chain" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d2f06b9cac1506ece98fe3231e3cc9c4410ec3d5b1f24ae1c8946f0742cdefc" +dependencies = [ + "version_check", +] + [[package]] name = "event-listener" version = "2.5.3" @@ -5963,6 +5992,21 @@ dependencies = [ "unicase", ] +[[package]] +name = "mini-moka" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c325dfab65f261f386debee8b0969da215b3fa0037e74c8a1234db7ba986d803" +dependencies = [ + "crossbeam-channel", + "crossbeam-utils", + "dashmap 5.5.3", + "skeptic", + "smallvec", + "tagptr", + "triomphe", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -7253,8 +7297,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15" dependencies = [ "bytes", - "heck 0.4.1", - "itertools 0.10.5", + "heck 0.5.0", + "itertools 0.13.0", "log", "multimap", "once_cell", @@ -7300,7 +7344,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e9552f850d5f0964a4e4d0bf306459ac29323ddfbae05e35a7c0d35cb0803cc5" dependencies = [ "anyhow", - "itertools 0.10.5", + "itertools 0.13.0", "proc-macro2", "quote", "syn 2.0.87", @@ -7380,6 +7424,17 @@ dependencies = [ "cc", ] +[[package]] +name = "pulldown-cmark" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57206b407293d2bcd3af849ce869d52068623f19e1b5ff8e8778e3309439682b" +dependencies = [ + "bitflags 2.6.0", + "memchr", + "unicase", +] + [[package]] name = "pyo3" version = "0.21.2" @@ -7390,7 +7445,7 @@ dependencies = [ "indoc", "libc", "memoffset", - "parking_lot 0.11.2", + "parking_lot 0.12.3", "portable-atomic", "pyo3-build-config", "pyo3-ffi", @@ -8739,6 +8794,21 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" +[[package]] +name = "skeptic" +version = "0.13.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16d23b015676c90a0f01c197bfdc786c20342c73a0afdda9025adb0bc42940a8" +dependencies = [ + "bytecount", + "cargo_metadata 0.14.2", + "error-chain", + "glob", + "pulldown-cmark", + "tempfile", + "walkdir", +] + [[package]] name = "slab" version = "0.4.9" @@ -8775,7 +8845,7 @@ version = "0.8.5" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "03c3c6b7927ffe7ecaa769ee0e3994da3b8cafc8f444578982c83ecb161af917" dependencies = [ - "heck 0.4.1", + "heck 0.5.0", "proc-macro2", "quote", "syn 2.0.87", @@ -9074,6 +9144,12 @@ dependencies = [ "libc", ] +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "tap" version = "1.0.1" @@ -9811,6 +9887,12 @@ dependencies = [ "tracing-serde", ] +[[package]] +name = "triomphe" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef8f7726da4807b58ea5c96fdc122f80702030edc33b35aff9190a51148ccc85" + [[package]] name = "try-lock" version = "0.2.5" @@ -10186,7 +10268,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2990d9ea5967266ea0ccf413a4aa5c42a93dbcfda9cb49a97de6931726b12566" dependencies = [ "anyhow", - "cargo_metadata", + "cargo_metadata 0.18.1", "cfg-if", "regex", "rustversion", @@ -10404,7 +10486,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] diff --git a/crates/arroyo-api/src/connection_tables.rs b/crates/arroyo-api/src/connection_tables.rs index 82f705886..23b3df140 100644 --- a/crates/arroyo-api/src/connection_tables.rs +++ b/crates/arroyo-api/src/connection_tables.rs @@ -509,6 +509,9 @@ async fn expand_avro_schema( ConnectionType::Sink => { // don't fetch schemas for sinks for now } + ConnectionType::Lookup => { + todo!("lookup tables cannot be created via the UI") + } } } @@ -521,6 +524,9 @@ async fn expand_avro_schema( schema.inferred = Some(true); Ok(schema) } + ConnectionType::Lookup => { + todo!("lookup tables cannot be created via the UI") + } }; }; @@ -596,6 +602,9 @@ async fn expand_proto_schema( ConnectionType::Sink => { // don't fetch schemas for sinks for now } + ConnectionType::Lookup => { + todo!("lookup tables cannot be created via the UI") + } } } @@ -697,6 +706,9 @@ async fn expand_json_schema( // don't fetch schemas for sinks for now until we're better able to conform our output to the schema schema.inferred = Some(true); } + ConnectionType::Lookup => { + todo!("lookup tables cannot be created via the UI") + } } } diff --git a/crates/arroyo-connectors/src/filesystem/sink/local.rs b/crates/arroyo-connectors/src/filesystem/sink/local.rs index a9a8ae07b..7aec35964 100644 --- a/crates/arroyo-connectors/src/filesystem/sink/local.rs +++ b/crates/arroyo-connectors/src/filesystem/sink/local.rs @@ -233,7 +233,7 @@ impl TwoPhaseCommitter for LocalFileSystemWrite let storage_provider = StorageProvider::for_url(&self.final_dir).await?; - let schema = Arc::new(ctx.in_schemas[0].clone()); + let schema = ctx.in_schemas[0].clone(); self.commit_state = Some(match self.file_settings.commit_style.unwrap() { CommitStyle::DeltaLake => CommitState::DeltaLake { diff --git a/crates/arroyo-connectors/src/filesystem/sink/mod.rs b/crates/arroyo-connectors/src/filesystem/sink/mod.rs index 46521e056..20b318074 100644 --- a/crates/arroyo-connectors/src/filesystem/sink/mod.rs +++ b/crates/arroyo-connectors/src/filesystem/sink/mod.rs @@ -1564,8 +1564,7 @@ impl TwoPhaseCommitter for FileSystemSink, ) -> Result<()> { - 
self.start(Arc::new(ctx.in_schemas.first().unwrap().clone())) - .await?; + self.start(ctx.in_schemas.first().unwrap().clone()).await?; let mut max_file_index = 0; let mut recovered_files = Vec::new(); diff --git a/crates/arroyo-connectors/src/filesystem/source.rs b/crates/arroyo-connectors/src/filesystem/source.rs index f2792f897..0f120fcbe 100644 --- a/crates/arroyo-connectors/src/filesystem/source.rs +++ b/crates/arroyo-connectors/src/filesystem/source.rs @@ -128,6 +128,7 @@ impl FileSystemSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &[], ); let parallelism = ctx.task_info.parallelism; let task_index = ctx.task_info.task_index; diff --git a/crates/arroyo-connectors/src/fluvio/source.rs b/crates/arroyo-connectors/src/fluvio/source.rs index 491a82b54..a64581053 100644 --- a/crates/arroyo-connectors/src/fluvio/source.rs +++ b/crates/arroyo-connectors/src/fluvio/source.rs @@ -57,6 +57,7 @@ impl SourceOperator for FluvioSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &[], ); match self.run_int(ctx, collector).await { diff --git a/crates/arroyo-connectors/src/impulse/mod.rs b/crates/arroyo-connectors/src/impulse/mod.rs index 0e7d83cd5..22cdaca2d 100644 --- a/crates/arroyo-connectors/src/impulse/mod.rs +++ b/crates/arroyo-connectors/src/impulse/mod.rs @@ -34,6 +34,7 @@ pub fn impulse_schema() -> ConnectionSchema { ], definition: None, inferred: None, + primary_keys: Default::default(), } } diff --git a/crates/arroyo-connectors/src/kafka/mod.rs b/crates/arroyo-connectors/src/kafka/mod.rs index 9d4e402b4..14bb5455e 100644 --- a/crates/arroyo-connectors/src/kafka/mod.rs +++ b/crates/arroyo-connectors/src/kafka/mod.rs @@ -641,14 +641,14 @@ impl KafkaTester { let mut deserializer = ArrowDeserializer::with_schema_resolver( format.clone(), None, - aschema.clone(), + Arc::new(aschema), + &schema.metadata_fields(), BadData::Fail {}, Arc::new(schema_resolver), ); - let mut builders = aschema.builders(); let mut error = deserializer - .deserialize_slice(&mut builders, &msg, SystemTime::now(), None) + .deserialize_slice(&msg, SystemTime::now(), None) .await .into_iter() .next(); @@ -663,14 +663,14 @@ impl KafkaTester { let aschema: ArroyoSchema = schema.clone().into(); let mut deserializer = ArrowDeserializer::new( format.clone(), - aschema.clone(), + Arc::new(aschema), + &schema.metadata_fields(), None, BadData::Fail {}, ); - let mut builders = aschema.builders(); let mut error = deserializer - .deserialize_slice(&mut builders, &msg, SystemTime::now(), None) + .deserialize_slice(&msg, SystemTime::now(), None) .await .into_iter() .next(); @@ -699,12 +699,16 @@ impl KafkaTester { } Format::Protobuf(_) => { let aschema: ArroyoSchema = schema.clone().into(); - let mut deserializer = - ArrowDeserializer::new(format.clone(), aschema.clone(), None, BadData::Fail {}); - let mut builders = aschema.builders(); + let mut deserializer = ArrowDeserializer::new( + format.clone(), + Arc::new(aschema), + &schema.metadata_fields(), + None, + BadData::Fail {}, + ); let mut error = deserializer - .deserialize_slice(&mut builders, &msg, SystemTime::now(), None) + .deserialize_slice(&msg, SystemTime::now(), None) .await .into_iter() .next(); diff --git a/crates/arroyo-connectors/src/kafka/sink/test.rs b/crates/arroyo-connectors/src/kafka/sink/test.rs index 30896e80b..26684e643 100644 --- a/crates/arroyo-connectors/src/kafka/sink/test.rs +++ b/crates/arroyo-connectors/src/kafka/sink/test.rs @@ -96,7 +96,7 @@ impl KafkaTopicTester { None, command_tx, 1, - 
vec![ArroyoSchema::new_unkeyed(schema(), 0)], + vec![Arc::new(ArroyoSchema::new_unkeyed(schema(), 0))], None, HashMap::new(), ) diff --git a/crates/arroyo-connectors/src/kafka/source/mod.rs b/crates/arroyo-connectors/src/kafka/source/mod.rs index 38c4a1cd2..020c0da7a 100644 --- a/crates/arroyo-connectors/src/kafka/source/mod.rs +++ b/crates/arroyo-connectors/src/kafka/source/mod.rs @@ -173,6 +173,7 @@ impl KafkaSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &self.metadata_fields, schema_resolver.clone(), ); } else { @@ -180,6 +181,7 @@ impl KafkaSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &self.metadata_fields, ); } @@ -201,7 +203,7 @@ impl KafkaSourceFunc { let connector_metadata = if !self.metadata_fields.is_empty() { let mut connector_metadata = HashMap::new(); for f in &self.metadata_fields { - connector_metadata.insert(&f.field_name, match f.key.as_str() { + connector_metadata.insert(f.field_name.as_str(), match f.key.as_str() { "offset_id" => FieldValueType::Int64(msg.offset()), "partition" => FieldValueType::Int32(msg.partition()), "topic" => FieldValueType::String(topic), diff --git a/crates/arroyo-connectors/src/kafka/source/test.rs b/crates/arroyo-connectors/src/kafka/source/test.rs index 3dab460c3..316e93135 100644 --- a/crates/arroyo-connectors/src/kafka/source/test.rs +++ b/crates/arroyo-connectors/src/kafka/source/test.rs @@ -5,14 +5,10 @@ use arroyo_state::tables::ErasedTable; use arroyo_state::{BackingStore, StateBackend}; use rand::random; +use crate::kafka::SourceOffset; use arrow::array::{Array, StringArray}; +use arrow::datatypes::DataType::UInt64; use arrow::datatypes::TimeUnit; -use std::collections::{HashMap, VecDeque}; -use std::num::NonZeroU32; -use std::sync::Arc; -use std::time::{Duration, SystemTime}; - -use crate::kafka::SourceOffset; use arroyo_operator::context::{ batch_bounded, ArrowCollector, BatchReceiver, OperatorContext, SourceCollector, SourceContext, }; @@ -29,6 +25,10 @@ use rdkafka::admin::{AdminClient, AdminOptions, NewTopic}; use rdkafka::producer::{BaseProducer, BaseRecord}; use rdkafka::ClientConfig; use serde::{Deserialize, Serialize}; +use std::collections::{HashMap, VecDeque}; +use std::num::NonZeroU32; +use std::sync::Arc; +use std::time::{Duration, SystemTime}; use tokio::sync::mpsc::{channel, Receiver, Sender}; use super::KafkaSourceFunc; @@ -108,7 +108,7 @@ impl KafkaTopicTester { operator_ids: vec![task_info.operator_id.clone()], }); - let out_schema = Some(ArroyoSchema::new_unkeyed( + let out_schema = Some(Arc::new(ArroyoSchema::new_unkeyed( Arc::new(Schema::new(vec![ Field::new( "_timestamp", @@ -118,7 +118,7 @@ impl KafkaTopicTester { Field::new("value", DataType::Utf8, false), ])), 0, - )); + ))); let task_info = Arc::new(task_info); @@ -389,6 +389,7 @@ async fn test_kafka_with_metadata_fields() { let metadata_fields = vec![MetadataField { field_name: "offset".to_string(), key: "offset_id".to_string(), + data_type: Some(UInt64), }]; // Set metadata fields in KafkaSourceFunc @@ -420,7 +421,7 @@ async fn test_kafka_with_metadata_fields() { command_tx.clone(), 1, vec![], - Some(ArroyoSchema::new_unkeyed( + Some(Arc::new(ArroyoSchema::new_unkeyed( Arc::new(Schema::new(vec![ Field::new( "_timestamp", @@ -431,7 +432,7 @@ async fn test_kafka_with_metadata_fields() { Field::new("offset", DataType::Int64, false), ])), 0, - )), + ))), kafka.tables(), ) .await; diff --git a/crates/arroyo-connectors/src/kinesis/source.rs b/crates/arroyo-connectors/src/kinesis/source.rs index 
ce43a0411..7cbaf4970 100644 --- a/crates/arroyo-connectors/src/kinesis/source.rs +++ b/crates/arroyo-connectors/src/kinesis/source.rs @@ -173,6 +173,7 @@ impl SourceOperator for KinesisSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &[], ); match self.run_int(ctx, collector).await { diff --git a/crates/arroyo-connectors/src/mqtt/sink/test.rs b/crates/arroyo-connectors/src/mqtt/sink/test.rs index 60c1d03e6..7f2ec9826 100644 --- a/crates/arroyo-connectors/src/mqtt/sink/test.rs +++ b/crates/arroyo-connectors/src/mqtt/sink/test.rs @@ -84,7 +84,7 @@ impl MqttTopicTester { None, command_tx, 1, - vec![ArroyoSchema::new_unkeyed(schema(), 0)], + vec![Arc::new(ArroyoSchema::new_unkeyed(schema(), 0))], None, HashMap::new(), ) diff --git a/crates/arroyo-connectors/src/mqtt/source/mod.rs b/crates/arroyo-connectors/src/mqtt/source/mod.rs index 6c2d51577..4f1508fd2 100644 --- a/crates/arroyo-connectors/src/mqtt/source/mod.rs +++ b/crates/arroyo-connectors/src/mqtt/source/mod.rs @@ -101,6 +101,7 @@ impl MqttSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &self.metadata_fields, ); if ctx.task_info.task_index > 0 { @@ -152,7 +153,7 @@ impl MqttSourceFunc { let connector_metadata = if !self.metadata_fields.is_empty() { let mut connector_metadata = HashMap::new(); for mf in &self.metadata_fields { - connector_metadata.insert(&mf.field_name, match mf.key.as_str() { + connector_metadata.insert(mf.field_name.as_str(), match mf.key.as_str() { "topic" => FieldValueType::String(&topic), k => unreachable!("invalid metadata key '{}' for mqtt", k) }); diff --git a/crates/arroyo-connectors/src/mqtt/source/test.rs b/crates/arroyo-connectors/src/mqtt/source/test.rs index 4690cdc88..f184a9342 100644 --- a/crates/arroyo-connectors/src/mqtt/source/test.rs +++ b/crates/arroyo-connectors/src/mqtt/source/test.rs @@ -141,7 +141,7 @@ impl MqttTopicTester { command_tx.clone(), 1, vec![], - Some(ArroyoSchema::new_unkeyed( + Some(Arc::new(ArroyoSchema::new_unkeyed( Arc::new(Schema::new(vec![ Field::new( "_timestamp", @@ -151,7 +151,7 @@ impl MqttTopicTester { Field::new("value", DataType::UInt64, false), ])), 0, - )), + ))), mqtt.tables(), ) .await; diff --git a/crates/arroyo-connectors/src/nats/source/mod.rs b/crates/arroyo-connectors/src/nats/source/mod.rs index 7cee2bc1e..c592b73c2 100644 --- a/crates/arroyo-connectors/src/nats/source/mod.rs +++ b/crates/arroyo-connectors/src/nats/source/mod.rs @@ -333,6 +333,7 @@ impl NatsSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &[], ); let nats_client = get_nats_client(&self.connection) diff --git a/crates/arroyo-connectors/src/nexmark/mod.rs b/crates/arroyo-connectors/src/nexmark/mod.rs index e7b7a859b..752163f12 100644 --- a/crates/arroyo-connectors/src/nexmark/mod.rs +++ b/crates/arroyo-connectors/src/nexmark/mod.rs @@ -91,6 +91,7 @@ pub fn nexmark_schema() -> ConnectionSchema { .collect(), definition: None, inferred: None, + primary_keys: Default::default(), } } diff --git a/crates/arroyo-connectors/src/polling_http/operator.rs b/crates/arroyo-connectors/src/polling_http/operator.rs index f6217421c..13a8373ae 100644 --- a/crates/arroyo-connectors/src/polling_http/operator.rs +++ b/crates/arroyo-connectors/src/polling_http/operator.rs @@ -208,6 +208,7 @@ impl PollingHttpSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &[], ); // since there's no way to partition across an http source, only read on the first task diff --git 
a/crates/arroyo-connectors/src/rabbitmq/source.rs b/crates/arroyo-connectors/src/rabbitmq/source.rs index 05d7a53e5..b2af95ff3 100644 --- a/crates/arroyo-connectors/src/rabbitmq/source.rs +++ b/crates/arroyo-connectors/src/rabbitmq/source.rs @@ -50,6 +50,7 @@ impl SourceOperator for RabbitmqStreamSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &[], ); match self.run_int(ctx, collector).await { diff --git a/crates/arroyo-connectors/src/redis/lookup.rs b/crates/arroyo-connectors/src/redis/lookup.rs new file mode 100644 index 000000000..79400bc4d --- /dev/null +++ b/crates/arroyo-connectors/src/redis/lookup.rs @@ -0,0 +1,100 @@ +use crate::redis::sink::GeneralConnection; +use crate::redis::RedisClient; +use arrow::array::{Array, ArrayRef, AsArray, RecordBatch}; +use arrow::datatypes::DataType; +use arroyo_formats::de::{ArrowDeserializer, FieldValueType}; +use arroyo_operator::connector::LookupConnector; +use arroyo_rpc::MetadataField; +use arroyo_types::{SourceError, LOOKUP_KEY_INDEX_FIELD}; +use async_trait::async_trait; +use redis::aio::ConnectionLike; +use redis::{cmd, Value}; +use std::collections::HashMap; + +pub struct RedisLookup { + pub(crate) deserializer: ArrowDeserializer, + pub(crate) client: RedisClient, + pub(crate) connection: Option, + pub(crate) metadata_fields: Vec, +} + +#[async_trait] +impl LookupConnector for RedisLookup { + fn name(&self) -> String { + "RedisLookup".to_string() + } + + async fn lookup(&mut self, keys: &[ArrayRef]) -> Option> { + if self.connection.is_none() { + self.connection = Some(self.client.get_connection().await.unwrap()); + } + + assert_eq!(keys.len(), 1, "redis lookup can only have a single key"); + assert_eq!( + *keys[0].data_type(), + DataType::Utf8, + "redis lookup key must be a string" + ); + + let connection = self.connection.as_mut().unwrap(); + + let mut mget = cmd("mget"); + + let keys = keys[0].as_string::(); + + for k in keys { + mget.arg(k.unwrap()); + } + + let Value::Array(vs) = connection.req_packed_command(&mget).await.unwrap() else { + panic!("value was not an array"); + }; + + assert_eq!( + vs.len(), + keys.len(), + "Redis sent back the wrong number of values" + ); + + let mut additional = HashMap::new(); + + for (idx, (v, k)) in vs.iter().zip(keys).enumerate() { + additional.insert(LOOKUP_KEY_INDEX_FIELD, FieldValueType::UInt64(idx as u64)); + for m in &self.metadata_fields { + additional.insert( + m.field_name.as_str(), + match m.key.as_str() { + "key" => FieldValueType::String(k.unwrap()), + k => unreachable!("Invalid metadata key '{}'", k), + }, + ); + } + + let errors = match v { + Value::Nil => { + self.deserializer.deserialize_null(Some(&additional)); + vec![] + } + Value::SimpleString(s) => { + self.deserializer + .deserialize_without_timestamp(s.as_bytes(), Some(&additional)) + .await + } + Value::BulkString(v) => { + self.deserializer + .deserialize_without_timestamp(v, Some(&additional)) + .await + } + v => { + panic!("unexpected type {:?}", v); + } + }; + + if !errors.is_empty() { + return Some(Err(errors.into_iter().next().unwrap())); + } + } + + self.deserializer.flush_buffer() + } +} diff --git a/crates/arroyo-connectors/src/redis/mod.rs b/crates/arroyo-connectors/src/redis/mod.rs index 7a78e88d0..46348c6c7 100644 --- a/crates/arroyo-connectors/src/redis/mod.rs +++ b/crates/arroyo-connectors/src/redis/mod.rs @@ -1,27 +1,31 @@ -mod operator; +pub mod lookup; +pub mod sink; +use crate::redis::lookup::RedisLookup; +use crate::redis::sink::{GeneralConnection, RedisSinkFunc}; +use 
crate::{pull_opt, pull_option_to_u64}; use anyhow::{anyhow, bail}; +use arrow::datatypes::{DataType, Schema}; +use arroyo_formats::de::ArrowDeserializer; use arroyo_formats::ser::ArrowSerializer; -use arroyo_operator::connector::{Connection, Connector}; +use arroyo_operator::connector::{Connection, Connector, LookupConnector, MetadataDef}; use arroyo_operator::operator::ConstructedOperator; +use arroyo_rpc::api_types::connections::{ + ConnectionProfile, ConnectionSchema, ConnectionType, FieldType, PrimitiveType, + TestSourceMessage, +}; +use arroyo_rpc::schema_resolver::FailingSchemaResolver; use arroyo_rpc::var_str::VarStr; +use arroyo_rpc::OperatorConfig; use redis::aio::ConnectionManager; use redis::cluster::ClusterClient; use redis::{Client, ConnectionInfo, IntoConnectionInfo}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use std::sync::Arc; use tokio::sync::oneshot::Receiver; use typify::import_types; -use arroyo_rpc::api_types::connections::{ - ConnectionProfile, ConnectionSchema, ConnectionType, FieldType, PrimitiveType, - TestSourceMessage, -}; -use arroyo_rpc::OperatorConfig; - -use crate::redis::operator::sink::{GeneralConnection, RedisSinkFunc}; -use crate::{pull_opt, pull_option_to_u64}; - pub struct RedisConnector {} const CONFIG_SCHEMA: &str = include_str!("./profile.json"); @@ -37,7 +41,7 @@ import_types!( import_types!(schema = "src/redis/table.json"); -enum RedisClient { +pub(crate) enum RedisClient { Standard(Client), Clustered(ClusterClient), } @@ -174,6 +178,13 @@ impl Connector for RedisConnector { } } + fn metadata_defs(&self) -> &'static [MetadataDef] { + &[MetadataDef { + name: "key", + data_type: DataType::Utf8, + }] + } + fn table_type(&self, _: Self::ProfileT, _: Self::TableT) -> ConnectionType { ConnectionType::Source } @@ -289,52 +300,73 @@ impl Connector for RedisConnector { } let sink = match typ.as_str() { - "sink" => TableType::Target(match pull_opt("target", options)?.as_str() { - "string" => Target::StringTable { - key_prefix: pull_opt("target.key_prefix", options)?, - key_column: options - .remove("target.key_column") - .map(|name| validate_column(schema, name, "target.key_column")) - .transpose()?, - ttl_secs: pull_option_to_u64("target.ttl_secs", options)? - .map(|t| t.try_into()) - .transpose() - .map_err(|_| anyhow!("target.ttl_secs must be greater than 0"))?, - }, - "list" => Target::ListTable { - list_prefix: pull_opt("target.key_prefix", options)?, - list_key_column: options - .remove("target.key_column") - .map(|name| validate_column(schema, name, "target.key_column")) - .transpose()?, - max_length: pull_option_to_u64("target.max_length", options)? 
- .map(|t| t.try_into()) - .transpose() - .map_err(|_| anyhow!("target.max_length must be greater than 0"))?, - operation: match options.remove("target.operation").as_deref() { - Some("append") | None => ListOperation::Append, - Some("prepend") => ListOperation::Prepend, - Some(op) => { - bail!("'{}' is not a valid value for target.operation; must be one of 'append' or 'prepend'", op); - } - }, - }, - "hash" => Target::HashTable { - hash_field_column: validate_column( - schema, - pull_opt("target.field_column", options)?, - "targets.field_column", - )?, - hash_key_column: options - .remove("target.key_column") - .map(|name| validate_column(schema, name, "target.key_column")) - .transpose()?, - hash_key_prefix: pull_opt("target.key_prefix", options)?, - }, - s => { - bail!("'{}' is not a valid redis target", s); + "lookup" => { + // for look-up tables, we require that there's a primary key metadata field + for f in &schema.fields { + if schema.primary_keys.contains(&f.field_name) + && f.metadata_key.as_ref().map(|k| k != "key").unwrap_or(true) + { + bail!( + "Redis lookup tables must have a PRIMARY KEY field defined as \ + `field_name TEXT GENERATED ALWAYS AS (metadata('key')) STORED`" + ); + } } - }), + + TableType::Lookup { + lookup: Default::default(), + } + } + "sink" => { + let target = match pull_opt("target", options)?.as_str() { + "string" => Target::StringTable { + key_prefix: pull_opt("target.key_prefix", options)?, + key_column: options + .remove("target.key_column") + .map(|name| validate_column(schema, name, "target.key_column")) + .transpose()?, + ttl_secs: pull_option_to_u64("target.ttl_secs", options)? + .map(|t| t.try_into()) + .transpose() + .map_err(|_| anyhow!("target.ttl_secs must be greater than 0"))?, + }, + "list" => Target::ListTable { + list_prefix: pull_opt("target.key_prefix", options)?, + list_key_column: options + .remove("target.key_column") + .map(|name| validate_column(schema, name, "target.key_column")) + .transpose()?, + max_length: pull_option_to_u64("target.max_length", options)? + .map(|t| t.try_into()) + .transpose() + .map_err(|_| anyhow!("target.max_length must be greater than 0"))?, + operation: match options.remove("target.operation").as_deref() { + Some("append") | None => ListOperation::Append, + Some("prepend") => ListOperation::Prepend, + Some(op) => { + bail!("'{}' is not a valid value for target.operation; must be one of 'append' or 'prepend'", op); + } + }, + }, + "hash" => Target::HashTable { + hash_field_column: validate_column( + schema, + pull_opt("target.field_column", options)?, + "targets.field_column", + )?, + hash_key_column: options + .remove("target.key_column") + .map(|name| validate_column(schema, name, "target.key_column")) + .transpose()?, + hash_key_prefix: pull_opt("target.key_prefix", options)?, + }, + s => { + bail!("'{}' is not a valid redis target", s); + } + }; + + TableType::Sink { target } + } s => { bail!("'{}' is not a valid type; must be `sink`", s); } @@ -371,6 +403,11 @@ impl Connector for RedisConnector { let _ = RedisClient::new(&config)?; + let (connection_type, description) = match &table.connector_type { + TableType::Sink { .. } => (ConnectionType::Sink, "RedisSink"), + TableType::Lookup { .. 
} => (ConnectionType::Lookup, "RedisLookup"), + }; + let config = OperatorConfig { connection: serde_json::to_value(config).unwrap(), table: serde_json::to_value(table).unwrap(), @@ -378,17 +415,17 @@ impl Connector for RedisConnector { format: Some(format), bad_data: schema.bad_data.clone(), framing: schema.framing.clone(), - metadata_fields: vec![], + metadata_fields: schema.metadata_fields(), }; Ok(Connection { id, connector: self.name(), name: name.to_string(), - connection_type: ConnectionType::Sink, + connection_type, schema, config: serde_json::to_string(&config).unwrap(), - description: "RedisSink".to_string(), + description: description.to_string(), }) } @@ -400,22 +437,52 @@ impl Connector for RedisConnector { ) -> anyhow::Result { let client = RedisClient::new(&profile)?; - let (tx, cmd_rx) = tokio::sync::mpsc::channel(128); - let (cmd_tx, rx) = tokio::sync::mpsc::channel(128); - - Ok(ConstructedOperator::from_operator(Box::new( - RedisSinkFunc { - serializer: ArrowSerializer::new( - config.format.expect("redis table must have a format"), - ), - table, - client, - cmd_q: Some((cmd_tx, cmd_rx)), - tx, - rx, - key_index: None, - hash_index: None, - }, - ))) + match table.connector_type { + TableType::Sink { target } => { + let (tx, cmd_rx) = tokio::sync::mpsc::channel(128); + let (cmd_tx, rx) = tokio::sync::mpsc::channel(128); + + Ok(ConstructedOperator::from_operator(Box::new( + RedisSinkFunc { + serializer: ArrowSerializer::new( + config.format.expect("redis table must have a format"), + ), + target, + client, + cmd_q: Some((cmd_tx, cmd_rx)), + tx, + rx, + key_index: None, + hash_index: None, + }, + ))) + } + TableType::Lookup { .. } => { + bail!("Cannot construct a lookup table as an operator"); + } + } + } + + fn make_lookup( + &self, + profile: Self::ProfileT, + _: Self::TableT, + config: OperatorConfig, + schema: Arc, + ) -> anyhow::Result> { + Ok(Box::new(RedisLookup { + deserializer: ArrowDeserializer::for_lookup( + config + .format + .ok_or_else(|| anyhow!("Redis table must have a format"))?, + schema, + &config.metadata_fields, + config.bad_data.unwrap_or_default(), + Arc::new(FailingSchemaResolver::new()), + ), + client: RedisClient::new(&profile)?, + connection: None, + metadata_fields: config.metadata_fields, + })) } } diff --git a/crates/arroyo-connectors/src/redis/operator/mod.rs b/crates/arroyo-connectors/src/redis/operator/mod.rs deleted file mode 100644 index 0ecbfb920..000000000 --- a/crates/arroyo-connectors/src/redis/operator/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod sink; diff --git a/crates/arroyo-connectors/src/redis/operator/sink.rs b/crates/arroyo-connectors/src/redis/sink.rs similarity index 82% rename from crates/arroyo-connectors/src/redis/operator/sink.rs rename to crates/arroyo-connectors/src/redis/sink.rs index 93d9e76ba..25eefa475 100644 --- a/crates/arroyo-connectors/src/redis/operator/sink.rs +++ b/crates/arroyo-connectors/src/redis/sink.rs @@ -1,4 +1,4 @@ -use crate::redis::{ListOperation, RedisClient, RedisTable, TableType, Target}; +use crate::redis::{ListOperation, RedisClient, Target}; use arrow::array::{AsArray, RecordBatch}; use arroyo_formats::ser::ArrowSerializer; use arroyo_operator::context::{Collector, ErrorReporter, OperatorContext}; @@ -19,8 +19,8 @@ const FLUSH_BYTES: usize = 10 * 1024 * 1024; pub struct RedisSinkFunc { pub serializer: ArrowSerializer, - pub table: RedisTable, - pub client: RedisClient, + pub target: Target, + pub(crate) client: RedisClient, pub cmd_q: Option<(Sender, Receiver)>, pub rx: Receiver, @@ -229,19 
+229,19 @@ impl ArrowOperator for RedisSinkFunc { } async fn on_start(&mut self, ctx: &mut OperatorContext) { - match &self.table.connector_type { - TableType::Target(Target::ListTable { + match &self.target { + Target::ListTable { list_key_column: Some(key), .. - }) - | TableType::Target(Target::StringTable { + } + | Target::StringTable { key_column: Some(key), .. - }) - | TableType::Target(Target::HashTable { + } + | Target::HashTable { hash_key_column: Some(key), .. - }) => { + } => { self.key_index = Some( ctx.in_schemas .first() @@ -258,9 +258,9 @@ impl ArrowOperator for RedisSinkFunc { _ => {} } - if let TableType::Target(Target::HashTable { + if let Target::HashTable { hash_field_column, .. - }) = &self.table.connector_type + } = &self.target { self.hash_index = Some(ctx.in_schemas.first().expect("no in-schema for redis sink!") .schema @@ -282,17 +282,15 @@ impl ArrowOperator for RedisSinkFunc { size_estimate: 0, last_flushed: Instant::now(), max_push_keys: HashSet::new(), - behavior: match self.table.connector_type { - TableType::Target(Target::StringTable { ttl_secs, .. }) => { - RedisBehavior::Set { - ttl: ttl_secs.map(|t| t.get() as usize), - } - } - TableType::Target(Target::ListTable { + behavior: match &self.target { + Target::StringTable { ttl_secs, .. } => RedisBehavior::Set { + ttl: ttl_secs.map(|t| t.get() as usize), + }, + Target::ListTable { max_length, operation, .. - }) => { + } => { let max = max_length.map(|x| x.get() as usize); match operation { ListOperation::Append => { @@ -303,7 +301,7 @@ impl ArrowOperator for RedisSinkFunc { } } } - TableType::Target(Target::HashTable { .. }) => RedisBehavior::Hash, + Target::HashTable { .. } => RedisBehavior::Hash, }, } .start(); @@ -328,39 +326,37 @@ impl ArrowOperator for RedisSinkFunc { _: &mut dyn Collector, ) { for (i, value) in self.serializer.serialize(&batch).enumerate() { - match &self.table.connector_type { - TableType::Target(target) => match &target { - Target::StringTable { key_prefix, .. } => { - let key = self.make_key(key_prefix, &batch, i); - self.tx - .send(RedisCmd::Data { key, value }) - .await - .expect("Redis writer panicked"); - } - Target::ListTable { list_prefix, .. } => { - let key = self.make_key(list_prefix, &batch, i); + match &self.target { + Target::StringTable { key_prefix, .. } => { + let key = self.make_key(key_prefix, &batch, i); + self.tx + .send(RedisCmd::Data { key, value }) + .await + .expect("Redis writer panicked"); + } + Target::ListTable { list_prefix, .. } => { + let key = self.make_key(list_prefix, &batch, i); - self.tx - .send(RedisCmd::Data { key, value }) - .await - .expect("Redis writer panicked"); - } - Target::HashTable { - hash_key_prefix, .. - } => { - let key = self.make_key(hash_key_prefix, &batch, i); - let field = batch - .column(self.hash_index.expect("no hash index")) - .as_string::() - .value(i) - .to_string(); - - self.tx - .send(RedisCmd::HData { key, field, value }) - .await - .expect("Redis writer panicked"); - } - }, + self.tx + .send(RedisCmd::Data { key, value }) + .await + .expect("Redis writer panicked"); + } + Target::HashTable { + hash_key_prefix, .. 
+ } => { + let key = self.make_key(hash_key_prefix, &batch, i); + let field = batch + .column(self.hash_index.expect("no hash index")) + .as_string::() + .value(i) + .to_string(); + + self.tx + .send(RedisCmd::HData { key, field, value }) + .await + .expect("Redis writer panicked"); + } }; } } diff --git a/crates/arroyo-connectors/src/redis/table.json b/crates/arroyo-connectors/src/redis/table.json index e1dd45234..be1be2212 100644 --- a/crates/arroyo-connectors/src/redis/table.json +++ b/crates/arroyo-connectors/src/redis/table.json @@ -114,6 +114,15 @@ "target" ], "additionalProperties": false + }, + { + "type": "object", + "title": "Lookup", + "properties": { + "lookup": { + "type": "object" + } + } } ] } diff --git a/crates/arroyo-connectors/src/single_file/source.rs b/crates/arroyo-connectors/src/single_file/source.rs index 800d21fd9..d5f79f8fe 100644 --- a/crates/arroyo-connectors/src/single_file/source.rs +++ b/crates/arroyo-connectors/src/single_file/source.rs @@ -100,6 +100,7 @@ impl SourceOperator for SingleFileSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &[], ); let state: &mut arroyo_state::tables::global_keyed_map::GlobalKeyedView = diff --git a/crates/arroyo-connectors/src/sse/operator.rs b/crates/arroyo-connectors/src/sse/operator.rs index df00ee4f9..0095447fb 100644 --- a/crates/arroyo-connectors/src/sse/operator.rs +++ b/crates/arroyo-connectors/src/sse/operator.rs @@ -149,6 +149,7 @@ impl SSESourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &[], ); let mut client = eventsource_client::ClientBuilder::for_url(&self.url).unwrap(); diff --git a/crates/arroyo-connectors/src/websocket/operator.rs b/crates/arroyo-connectors/src/websocket/operator.rs index c009dd548..a477079c9 100644 --- a/crates/arroyo-connectors/src/websocket/operator.rs +++ b/crates/arroyo-connectors/src/websocket/operator.rs @@ -65,6 +65,7 @@ impl SourceOperator for WebsocketSourceFunc { self.format.clone(), self.framing.clone(), self.bad_data.clone(), + &[], ); match self.run_int(ctx, collector).await { diff --git a/crates/arroyo-datastream/src/logical.rs b/crates/arroyo-datastream/src/logical.rs index f4689c995..f393d00c0 100644 --- a/crates/arroyo-datastream/src/logical.rs +++ b/crates/arroyo-datastream/src/logical.rs @@ -32,6 +32,7 @@ pub enum OperatorName { AsyncUdf, Join, InstantJoin, + LookupJoin, WindowFunction, TumblingWindowAggregate, SlidingWindowAggregate, @@ -133,16 +134,22 @@ impl TryFrom for PipelineGraph { #[derive(Clone, Debug, Eq, PartialEq)] pub struct LogicalEdge { pub edge_type: LogicalEdgeType, - pub schema: ArroyoSchema, + pub schema: Arc, } impl LogicalEdge { pub fn new(edge_type: LogicalEdgeType, schema: ArroyoSchema) -> Self { - LogicalEdge { edge_type, schema } + LogicalEdge { + edge_type, + schema: Arc::new(schema), + } } pub fn project_all(edge_type: LogicalEdgeType, schema: ArroyoSchema) -> Self { - LogicalEdge { edge_type, schema } + LogicalEdge { + edge_type, + schema: Arc::new(schema), + } } } @@ -156,7 +163,7 @@ pub struct ChainedLogicalOperator { #[derive(Clone, Debug)] pub struct OperatorChain { pub(crate) operators: Vec, - pub(crate) edges: Vec, + pub(crate) edges: Vec>, } impl OperatorChain { @@ -167,7 +174,9 @@ impl OperatorChain { } } - pub fn iter(&self) -> impl Iterator)> { + pub fn iter( + &self, + ) -> impl Iterator>)> { self.operators .iter() .zip_longest(self.edges.iter()) @@ -177,10 +186,10 @@ impl OperatorChain { pub fn iter_mut( &mut self, - ) -> impl Iterator)> { + ) -> impl Iterator>)> { 
self.operators .iter_mut() - .zip_longest(self.edges.iter_mut()) + .zip_longest(self.edges.iter()) .map(|e| e.left_and_right()) .map(|(l, r)| (l.unwrap(), r)) } @@ -376,6 +385,7 @@ impl LogicalProgram { OperatorName::Join => "join-with-expiration".to_string(), OperatorName::InstantJoin => "windowed-join".to_string(), OperatorName::WindowFunction => "sql-window-function".to_string(), + OperatorName::LookupJoin => "lookup-join".to_string(), OperatorName::TumblingWindowAggregate => { "sql-tumbling-window-aggregate".to_string() } @@ -436,7 +446,7 @@ impl TryFrom for LogicalProgram { edges: node .edges .into_iter() - .map(|e| Ok(e.try_into()?)) + .map(|e| Ok(Arc::new(e.try_into()?))) .collect::>>()?, }, parallelism: node.parallelism as usize, @@ -454,7 +464,7 @@ impl TryFrom for LogicalProgram { target, LogicalEdge { edge_type: edge.edge_type().into(), - schema: schema.clone().try_into()?, + schema: Arc::new(schema.clone().try_into()?), }, ); } @@ -621,7 +631,7 @@ impl From for ArrowProgram { .operator_chain .edges .iter() - .map(|edge| edge.clone().into()) + .map(|edge| (**edge).clone().into()) .collect(), } }) @@ -637,7 +647,7 @@ impl From for ArrowProgram { api::ArrowEdge { source: source.index() as i32, target: target.index() as i32, - schema: Some(edge.schema.clone().into()), + schema: Some((*edge.schema).clone().into()), edge_type: edge_type as i32, } }) diff --git a/crates/arroyo-formats/src/avro/de.rs b/crates/arroyo-formats/src/avro/de.rs index f371bf70c..7e90deffc 100644 --- a/crates/arroyo-formats/src/avro/de.rs +++ b/crates/arroyo-formats/src/avro/de.rs @@ -132,8 +132,7 @@ pub(crate) fn avro_to_json(value: AvroValue) -> JsonValue { mod tests { use crate::avro::schema::to_arrow; use crate::de::ArrowDeserializer; - use arrow_array::builder::{make_builder, ArrayBuilder}; - use arrow_array::RecordBatch; + use arrow_json::writer::record_batch_to_vec; use arrow_schema::{DataType, Field, Schema, TimeUnit}; use arroyo_rpc::df::ArroyoSchema; @@ -214,7 +213,7 @@ mod tests { fn deserializer_with_schema( format: AvroFormat, writer_schema: Option<&str>, - ) -> (ArrowDeserializer, Vec>, ArroyoSchema) { + ) -> (ArrowDeserializer, ArroyoSchema) { let arrow_schema = if format.into_unstructured_json { Schema::new(vec![Field::new("value", DataType::Utf8, false)]) } else { @@ -239,13 +238,6 @@ mod tests { ArroyoSchema::from_schema_keys(Arc::new(Schema::new(fields)), vec![]).unwrap() }; - let builders: Vec<_> = arroyo_schema - .schema - .fields - .iter() - .map(|f| make_builder(f.data_type(), 8)) - .collect(); - let resolver: Arc = if let Some(schema) = &writer_schema { Arc::new(FixedSchemaResolver::new( if format.confluent_schema_registry { @@ -263,11 +255,11 @@ mod tests { ArrowDeserializer::with_schema_resolver( Format::Avro(format), None, - arroyo_schema.clone(), + Arc::new(arroyo_schema.clone()), + &[], BadData::Fail {}, resolver, ), - builders, arroyo_schema, ) } @@ -277,23 +269,14 @@ mod tests { writer_schema: Option<&str>, message: &[u8], ) -> Vec> { - let (mut deserializer, mut builders, arroyo_schema) = - deserializer_with_schema(format.clone(), writer_schema); + let (mut deserializer, _) = deserializer_with_schema(format.clone(), writer_schema); let errors = deserializer - .deserialize_slice(&mut builders, message, SystemTime::now(), None) + .deserialize_slice(message, SystemTime::now(), None) .await; assert_eq!(errors, vec![]); - let batch = if format.into_unstructured_json { - RecordBatch::try_new( - arroyo_schema.schema, - builders.into_iter().map(|mut b| b.finish()).collect(), - ) - 
.unwrap() - } else { - deserializer.flush_buffer().unwrap().unwrap() - }; + let batch = deserializer.flush_buffer().unwrap().unwrap(); record_batch_to_vec(&batch, true, arrow_json::writer::TimestampFormat::RFC3339) .unwrap() diff --git a/crates/arroyo-formats/src/de.rs b/crates/arroyo-formats/src/de.rs index 3938fe0fc..7d089b9c6 100644 --- a/crates/arroyo-formats/src/de.rs +++ b/crates/arroyo-formats/src/de.rs @@ -4,19 +4,22 @@ use crate::{proto, should_flush}; use arrow::array::{Int32Builder, Int64Builder}; use arrow::compute::kernels; use arrow_array::builder::{ - ArrayBuilder, GenericByteBuilder, StringBuilder, TimestampNanosecondBuilder, + make_builder, ArrayBuilder, BinaryBuilder, GenericByteBuilder, StringBuilder, + TimestampNanosecondBuilder, UInt64Builder, }; use arrow_array::types::GenericBinaryType; -use arrow_array::RecordBatch; +use arrow_array::{ArrayRef, BooleanArray, RecordBatch}; +use arrow_schema::{DataType, Schema, SchemaRef}; use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::formats::{ AvroFormat, BadData, Format, Framing, FramingMethod, JsonFormat, ProtobufFormat, }; use arroyo_rpc::schema_resolver::{FailingSchemaResolver, FixedSchemaResolver, SchemaResolver}; -use arroyo_types::{to_nanos, SourceError}; +use arroyo_rpc::{MetadataField, TIMESTAMP_FIELD}; +use arroyo_types::{to_nanos, SourceError, LOOKUP_KEY_INDEX_FIELD}; use prost_reflect::DescriptorPool; use serde_json::Value; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::time::{Instant, SystemTime}; use tokio::sync::Mutex; @@ -24,11 +27,44 @@ use tokio::sync::Mutex; #[derive(Debug, Clone)] pub enum FieldValueType<'a> { Int64(i64), + UInt64(u64), Int32(i32), String(&'a str), // Extend with more types as needed } +struct ContextBuffer { + buffer: Vec>, + created: Instant, +} + +impl ContextBuffer { + fn new(schema: SchemaRef) -> Self { + let buffer = schema + .fields + .iter() + .map(|f| make_builder(f.data_type(), 16)) + .collect(); + + Self { + buffer, + created: Instant::now(), + } + } + + pub fn size(&self) -> usize { + self.buffer.iter().map(|b| b.len()).max().unwrap() + } + + pub fn should_flush(&self) -> bool { + should_flush(self.size(), self.created) + } + + pub fn finish(&mut self) -> Vec { + self.buffer.iter_mut().map(|a| a.finish()).collect() + } +} + pub struct FramingIterator<'a> { framing: Option>, buf: &'a [u8], @@ -80,24 +116,155 @@ impl<'a> Iterator for FramingIterator<'a> { } } +enum BufferDecoder { + Buffer(ContextBuffer), + JsonDecoder { + decoder: arrow::json::reader::Decoder, + buffered_count: usize, + buffered_since: Instant, + }, +} + +impl BufferDecoder { + fn should_flush(&self) -> bool { + match self { + BufferDecoder::Buffer(b) => b.should_flush(), + BufferDecoder::JsonDecoder { + buffered_count, + buffered_since, + .. + } => should_flush(*buffered_count, *buffered_since), + } + } + + #[allow(clippy::type_complexity)] + fn flush( + &mut self, + bad_data: &BadData, + ) -> Option, Option), SourceError>> { + match self { + BufferDecoder::Buffer(buffer) => { + if buffer.size() > 0 { + Some(Ok((buffer.finish(), None))) + } else { + None + } + } + BufferDecoder::JsonDecoder { + decoder, + buffered_since, + buffered_count, + } => { + *buffered_since = Instant::now(); + *buffered_count = 0; + Some(match bad_data { + BadData::Fail { .. } => decoder + .flush() + .map_err(|e| { + SourceError::bad_data(format!("JSON does not match schema: {:?}", e)) + }) + .transpose()? + .map(|batch| (batch.columns().to_vec(), None)), + BadData::Drop { .. 
} => decoder + .flush_with_bad_data() + .map_err(|e| { + SourceError::bad_data(format!( + "Something went wrong decoding JSON: {:?}", + e + )) + }) + .transpose()? + .map(|(batch, mask, _)| (batch.columns().to_vec(), Some(mask))), + }) + } + } + } + + fn decode_json(&mut self, msg: &[u8]) -> Result<(), SourceError> { + match self { + BufferDecoder::Buffer(_) => { + unreachable!("Tried to decode JSON for non-JSON deserializer"); + } + BufferDecoder::JsonDecoder { + decoder, + buffered_count, + .. + } => { + decoder + .decode(msg) + .map_err(|e| SourceError::bad_data(format!("invalid JSON: {:?}", e)))?; + + *buffered_count += 1; + + Ok(()) + } + } + } + + fn get_buffer(&mut self) -> &mut ContextBuffer { + match self { + BufferDecoder::Buffer(buffer) => buffer, + BufferDecoder::JsonDecoder { .. } => { + panic!("tried to get a raw buffer from a JSON deserializer"); + } + } + } + + fn push_null(&mut self, schema: &Schema) { + match self { + BufferDecoder::Buffer(b) => { + for (f, b) in schema.fields.iter().zip(b.buffer.iter_mut()) { + match f.data_type() { + DataType::Binary => { + b.as_any_mut() + .downcast_mut::() + .unwrap() + .append_null(); + } + DataType::Utf8 => { + b.as_any_mut() + .downcast_mut::() + .unwrap() + .append_null(); + } + dt => { + unreachable!("unsupported datatype {}", dt); + } + } + } + } + BufferDecoder::JsonDecoder { + decoder, + buffered_count, + .. + } => { + decoder.decode("{}".as_bytes()).unwrap(); + + *buffered_count += 1; + } + } + } +} + pub struct ArrowDeserializer { format: Arc, framing: Option>, - schema: ArroyoSchema, + final_schema: Arc, + decoder_schema: Arc, bad_data: BadData, - json_decoder: Option<(arrow::json::reader::Decoder, TimestampNanosecondBuilder)>, - buffered_count: usize, - buffered_since: Instant, schema_registry: Arc>>, proto_pool: DescriptorPool, schema_resolver: Arc, additional_fields_builder: Option>>, + timestamp_builder: Option<(usize, TimestampNanosecondBuilder)>, + buffer_decoder: BufferDecoder, } impl ArrowDeserializer { pub fn new( format: Format, - schema: ArroyoSchema, + schema: Arc, + metadata_fields: &[MetadataField], framing: Option, bad_data: BadData, ) -> Self { @@ -112,13 +279,59 @@ impl ArrowDeserializer { Arc::new(FailingSchemaResolver::new()) as Arc }; - Self::with_schema_resolver(format, framing, schema, bad_data, resolver) + Self::with_schema_resolver(format, framing, schema, metadata_fields, bad_data, resolver) } pub fn with_schema_resolver( format: Format, framing: Option, - schema: ArroyoSchema, + schema: Arc, + metadata_fields: &[MetadataField], + bad_data: BadData, + schema_resolver: Arc, + ) -> Self { + Self::with_schema_resolver_and_raw_schema( + format, + framing, + schema.schema.clone(), + Some(schema.timestamp_index), + metadata_fields, + bad_data, + schema_resolver, + ) + } + + pub fn for_lookup( + format: Format, + schema: Arc, + metadata_fields: &[MetadataField], + bad_data: BadData, + schema_resolver: Arc, + ) -> Self { + let mut metadata_fields = metadata_fields.to_vec(); + metadata_fields.push(MetadataField { + field_name: LOOKUP_KEY_INDEX_FIELD.to_string(), + key: LOOKUP_KEY_INDEX_FIELD.to_string(), + data_type: Some(DataType::UInt64), + }); + + Self::with_schema_resolver_and_raw_schema( + format, + None, + schema, + None, + &metadata_fields, + bad_data, + schema_resolver, + ) + } + + fn with_schema_resolver_and_raw_schema( + format: Format, + framing: Option, + schema: Arc, + timestamp_idx: Option, + metadata_fields: &[MetadataField], bad_data: BadData, schema_resolver: Arc, ) -> Self { @@ -132,144 
+345,214 @@ impl ArrowDeserializer { DescriptorPool::global() }; - Self { - json_decoder: matches!( - format, - Format::Json(..) - | Format::Avro(AvroFormat { - into_unstructured_json: false, - .. - }) - | Format::Protobuf(ProtobufFormat { - into_unstructured_json: false, - .. - }) - ) - .then(|| { - // exclude the timestamp field - ( - arrow_json::reader::ReaderBuilder::new(Arc::new( - schema.schema_without_timestamp(), - )) + let metadata_names: HashSet<_> = metadata_fields.iter().map(|f| &f.field_name).collect(); + + let schema_without_additional = { + let fields = schema + .fields() + .iter() + .filter(|f| !metadata_names.contains(f.name()) && f.name() != TIMESTAMP_FIELD) + .cloned() + .collect::>(); + Arc::new(Schema::new_with_metadata(fields, schema.metadata.clone())) + }; + + let buffer_decoder = match format { + Format::Json(JsonFormat { + unstructured: false, + .. + }) + | Format::Avro(AvroFormat { + into_unstructured_json: false, + .. + }) + | Format::Protobuf(ProtobufFormat { + into_unstructured_json: false, + .. + }) => BufferDecoder::JsonDecoder { + decoder: arrow_json::reader::ReaderBuilder::new(schema_without_additional.clone()) .with_limit_to_batch_size(false) .with_strict_mode(false) .with_allow_bad_data(matches!(bad_data, BadData::Drop { .. })) .build_decoder() .unwrap(), - TimestampNanosecondBuilder::new(), - ) - }), + buffered_count: 0, + buffered_since: Instant::now(), + }, + _ => BufferDecoder::Buffer(ContextBuffer::new(schema_without_additional.clone())), + }; + + Self { format: Arc::new(format), framing: framing.map(Arc::new), - schema, + buffer_decoder, + timestamp_builder: timestamp_idx + .map(|i| (i, TimestampNanosecondBuilder::with_capacity(128))), + final_schema: schema, + decoder_schema: schema_without_additional, schema_registry: Arc::new(Mutex::new(HashMap::new())), bad_data, schema_resolver, proto_pool, - buffered_count: 0, - buffered_since: Instant::now(), additional_fields_builder: None, } } + #[must_use] pub async fn deserialize_slice( &mut self, - buffer: &mut [Box], msg: &[u8], timestamp: SystemTime, - additional_fields: Option<&HashMap<&String, FieldValueType<'_>>>, + additional_fields: Option<&HashMap<&str, FieldValueType<'_>>>, ) -> Vec { - match &*self.format { - Format::Avro(_) => self.deserialize_slice_avro(buffer, msg, timestamp).await, - _ => FramingIterator::new(self.framing.clone(), msg) - .map(|t| self.deserialize_single(buffer, t, timestamp, additional_fields)) - .filter_map(|t| t.err()) - .collect(), + self.deserialize_slice_int(msg, Some(timestamp), additional_fields) + .await + } + + #[must_use] + pub async fn deserialize_without_timestamp( + &mut self, + msg: &[u8], + additional_fields: Option<&HashMap<&str, FieldValueType<'_>>>, + ) -> Vec { + self.deserialize_slice_int(msg, None, additional_fields) + .await + } + + pub fn deserialize_null( + &mut self, + additional_fields: Option<&HashMap<&str, FieldValueType<'_>>>, + ) { + self.buffer_decoder.push_null(&self.decoder_schema); + self.add_additional_fields(additional_fields, 1); + } + + async fn deserialize_slice_int( + &mut self, + msg: &[u8], + timestamp: Option, + additional_fields: Option<&HashMap<&str, FieldValueType<'_>>>, + ) -> Vec { + let (count, errors) = match &*self.format { + Format::Avro(_) => self.deserialize_slice_avro(msg).await, + _ => { + let mut count = 0; + let errors = FramingIterator::new(self.framing.clone(), msg) + .map(|t| self.deserialize_single(t)) + .filter_map(|t| { + if t.is_ok() { + count += 1; + } + t.err() + }) + .collect(); + (count, errors) + } + }; + 
+ self.add_additional_fields(additional_fields, count); + + if let Some(timestamp) = timestamp { + let (_, b) = self + .timestamp_builder + .as_mut() + .expect("tried to serialize timestamp to a schema without a timestamp column"); + + for _ in 0..count { + b.append_value(to_nanos(timestamp) as i64); + } + } + + errors + } + + fn add_additional_fields( + &mut self, + additional_fields: Option<&HashMap<&str, FieldValueType<'_>>>, + count: usize, + ) { + if let Some(additional_fields) = additional_fields { + if self.additional_fields_builder.is_none() { + let mut builders = HashMap::new(); + for (key, value) in additional_fields.iter() { + let builder: Box = match value { + FieldValueType::Int32(_) => Box::new(Int32Builder::new()), + FieldValueType::Int64(_) => Box::new(Int64Builder::new()), + FieldValueType::UInt64(_) => Box::new(UInt64Builder::new()), + FieldValueType::String(_) => Box::new(StringBuilder::new()), + }; + builders.insert(key.to_string(), builder); + } + self.additional_fields_builder = Some(builders); + } + + let builders = self.additional_fields_builder.as_mut().unwrap(); + + for (k, v) in additional_fields { + add_additional_fields(builders, k, v, count); + } } } pub fn should_flush(&self) -> bool { - should_flush(self.buffered_count, self.buffered_since) + self.buffer_decoder.should_flush() } pub fn flush_buffer(&mut self) -> Option> { - let (decoder, timestamp) = self.json_decoder.as_mut()?; - self.buffered_since = Instant::now(); - self.buffered_count = 0; - match self.bad_data { - BadData::Fail { .. } => Some( - decoder - .flush() - .map_err(|e| { - SourceError::bad_data(format!("JSON does not match schema: {:?}", e)) - }) - .transpose()? - .map(|batch| { - let mut columns = batch.columns().to_vec(); - columns.insert(self.schema.timestamp_index, Arc::new(timestamp.finish())); - flush_additional_fields_builders( - &mut self.additional_fields_builder, - &self.schema, - &mut columns, - ); - RecordBatch::try_new(self.schema.schema.clone(), columns).unwrap() - }), - ), - BadData::Drop { .. } => Some( - decoder - .flush_with_bad_data() - .map_err(|e| { - SourceError::bad_data(format!( - "Something went wrong decoding JSON: {:?}", - e - )) - }) - .transpose()? - .map(|(batch, mask, _)| { - let mut columns = batch.columns().to_vec(); - let timestamp = - kernels::filter::filter(×tamp.finish(), &mask).unwrap(); - - columns.insert(self.schema.timestamp_index, Arc::new(timestamp)); - flush_additional_fields_builders( - &mut self.additional_fields_builder, - &self.schema, - &mut columns, - ); - RecordBatch::try_new(self.schema.schema.clone(), columns).unwrap() - }), - ), + let (arrays, error_mask) = match self.buffer_decoder.flush(&self.bad_data)? 
{ + Ok((a, b)) => (a, b), + Err(e) => return Some(Err(e)), + }; + + let mut arrays: HashMap<_, _> = arrays + .into_iter() + .zip(self.decoder_schema.fields.iter()) + .map(|(a, f)| (f.name().as_str(), a)) + .collect(); + + if let Some(additional_fields) = &mut self.additional_fields_builder { + for (name, builder) in additional_fields { + let mut array = builder.finish(); + if let Some(error_mask) = &error_mask { + array = kernels::filter::filter(&array, error_mask).unwrap(); + } + + arrays.insert(name.as_str(), array); + } + }; + + if let Some((_, timestamp)) = &mut self.timestamp_builder { + let array = if let Some(error_mask) = &error_mask { + kernels::filter::filter(×tamp.finish(), error_mask).unwrap() + } else { + Arc::new(timestamp.finish()) + }; + + arrays.insert(TIMESTAMP_FIELD, array); } + + let arrays = self + .final_schema + .fields + .iter() + .map(|f| arrays.get(f.name().as_str()).unwrap().clone()) + .collect(); + + Some(Ok( + RecordBatch::try_new(self.final_schema.clone(), arrays).unwrap() + )) } - fn deserialize_single( - &mut self, - buffer: &mut [Box], - msg: &[u8], - timestamp: SystemTime, - additional_fields: Option<&HashMap<&String, FieldValueType>>, - ) -> Result<(), SourceError> { + fn deserialize_single(&mut self, msg: &[u8]) -> Result<(), SourceError> { match &*self.format { Format::RawString(_) | Format::Json(JsonFormat { unstructured: true, .. }) => { - self.deserialize_raw_string(buffer, msg); - add_timestamp(buffer, self.schema.timestamp_index, timestamp); - if let Some(fields) = additional_fields { - for (k, v) in fields.iter() { - add_additional_fields(buffer, &self.schema, k, v); - } - } + self.deserialize_raw_string(msg); } Format::RawBytes(_) => { - self.deserialize_raw_bytes(buffer, msg); - add_timestamp(buffer, self.schema.timestamp_index, timestamp); - if let Some(fields) = additional_fields { - for (k, v) in fields.iter() { - add_additional_fields(buffer, &self.schema, k, v); - } - } + self.deserialize_raw_bytes(msg); } Format::Json(json) => { let msg = if json.confluent_schema_registry { @@ -278,62 +561,17 @@ impl ArrowDeserializer { msg }; - let Some((decoder, timestamp_builder)) = &mut self.json_decoder else { - panic!("json decoder not initialized"); - }; - - if self.additional_fields_builder.is_none() { - if let Some(fields) = additional_fields.as_ref() { - let mut builders = HashMap::new(); - for (key, value) in fields.iter() { - let builder: Box = match value { - FieldValueType::Int32(_) => Box::new(Int32Builder::new()), - FieldValueType::Int64(_) => Box::new(Int64Builder::new()), - FieldValueType::String(_) => Box::new(StringBuilder::new()), - }; - builders.insert(key, builder); - } - self.additional_fields_builder = Some( - builders - .into_iter() - .map(|(k, v)| ((*k).clone(), v)) - .collect(), - ); - } - } - - decoder - .decode(msg) - .map_err(|e| SourceError::bad_data(format!("invalid JSON: {:?}", e)))?; - timestamp_builder.append_value(to_nanos(timestamp) as i64); - - add_additional_fields_using_builder( - additional_fields, - &mut self.additional_fields_builder, - ); - self.buffered_count += 1; + self.buffer_decoder.decode_json(msg)?; } Format::Protobuf(proto) => { let json = proto::de::deserialize_proto(&mut self.proto_pool, proto, msg)?; if proto.into_unstructured_json { - self.decode_into_json(buffer, json, timestamp); + self.decode_into_json(json); } else { - let Some((decoder, timestamp_builder)) = &mut self.json_decoder else { - panic!("json decoder not initialized"); - }; - - decoder - .decode(json.to_string().as_bytes()) + 
self.buffer_decoder + .decode_json(json.to_string().as_bytes()) .map_err(|e| SourceError::bad_data(format!("invalid JSON: {:?}", e)))?; - timestamp_builder.append_value(to_nanos(timestamp) as i64); - - add_additional_fields_using_builder( - additional_fields, - &mut self.additional_fields_builder, - ); - - self.buffered_count += 1; } } Format::Avro(_) => unreachable!("this should not be called for avro"), @@ -343,33 +581,20 @@ impl ArrowDeserializer { Ok(()) } - fn decode_into_json( - &mut self, - builders: &mut [Box], - value: Value, - timestamp: SystemTime, - ) { + fn decode_into_json(&mut self, value: Value) { let (idx, _) = self - .schema - .schema + .decoder_schema .column_with_name("value") .expect("no 'value' column for unstructured avro"); - let array = builders[idx] + let array = self.buffer_decoder.get_buffer().buffer[idx] .as_any_mut() .downcast_mut::() .expect("'value' column has incorrect type"); array.append_value(value.to_string()); - add_timestamp(builders, self.schema.timestamp_index, timestamp); - self.buffered_count += 1; } - pub async fn deserialize_slice_avro<'a>( - &mut self, - builders: &mut [Box], - msg: &'a [u8], - timestamp: SystemTime, - ) -> Vec { + async fn deserialize_slice_avro(&mut self, msg: &[u8]) -> (usize, Vec) { let Format::Avro(format) = &*self.format else { unreachable!("not avro"); }; @@ -384,13 +609,14 @@ impl ArrowDeserializer { { Ok(messages) => messages, Err(e) => { - return vec![e]; + return (0, vec![e]); } }; let into_json = format.into_unstructured_json; - messages + let mut count = 0; + let errors = messages .into_iter() .map(|record| { let value = record.map_err(|e| { @@ -398,49 +624,45 @@ impl ArrowDeserializer { })?; if into_json { - self.decode_into_json(builders, de::avro_to_json(value), timestamp); + self.decode_into_json(de::avro_to_json(value)); } else { // for now round-trip through json in order to handle unsupported avro features // as that allows us to rely on raw json deserialization let json = de::avro_to_json(value).to_string(); - let Some((decoder, timestamp_builder)) = &mut self.json_decoder else { - panic!("json decoder not initialized"); - }; - - decoder - .decode(json.as_bytes()) + self.buffer_decoder + .decode_json(json.as_bytes()) .map_err(|e| SourceError::bad_data(format!("invalid JSON: {:?}", e)))?; - self.buffered_count += 1; - timestamp_builder.append_value(to_nanos(timestamp) as i64); } + count += 1; + Ok(()) }) .filter_map(|r: Result<(), SourceError>| r.err()) - .collect() + .collect(); + + (count, errors) } - fn deserialize_raw_string(&mut self, buffer: &mut [Box], msg: &[u8]) { + fn deserialize_raw_string(&mut self, msg: &[u8]) { let (col, _) = self - .schema - .schema + .decoder_schema .column_with_name("value") .expect("no 'value' column for RawString format"); - buffer[col] + self.buffer_decoder.get_buffer().buffer[col] .as_any_mut() .downcast_mut::() .expect("'value' column has incorrect type") .append_value(String::from_utf8_lossy(msg)); } - fn deserialize_raw_bytes(&mut self, buffer: &mut [Box], msg: &[u8]) { + fn deserialize_raw_bytes(&mut self, msg: &[u8]) { let (col, _) = self - .schema - .schema + .decoder_schema .column_with_name("value") .expect("no 'value' column for RawBytes format"); - buffer[col] + self.buffer_decoder.get_buffer().buffer[col] .as_any_mut() .downcast_mut::>>() .expect("'value' column has incorrect type") @@ -452,111 +674,51 @@ impl ArrowDeserializer { } } -pub(crate) fn add_timestamp( - builder: &mut [Box], - idx: usize, - timestamp: SystemTime, -) { - builder[idx] - .as_any_mut() 
- .downcast_mut::() - .expect("_timestamp column has incorrect type") - .append_value(to_nanos(timestamp) as i64); -} - -pub(crate) fn add_additional_fields( - builder: &mut [Box], - schema: &ArroyoSchema, +fn add_additional_fields( + builders: &mut HashMap>, key: &str, value: &FieldValueType<'_>, + count: usize, ) { - let (idx, _) = schema - .schema - .column_with_name(key) - .unwrap_or_else(|| panic!("no '{}' column for additional fields", key)); + let builder = builders + .get_mut(key) + .unwrap_or_else(|| panic!("unexpected additional field '{}'", key)) + .as_any_mut(); match value { FieldValueType::Int32(i) => { - builder[idx] - .as_any_mut() + let b = builder .downcast_mut::() - .expect("additional field has incorrect type") - .append_value(*i); + .expect("additional field has incorrect type"); + + for _ in 0..count { + b.append_value(*i); + } } FieldValueType::Int64(i) => { - builder[idx] - .as_any_mut() + let b = builder .downcast_mut::() - .expect("additional field has incorrect type") - .append_value(*i); - } - FieldValueType::String(s) => { - builder[idx] - .as_any_mut() - .downcast_mut::() - .expect("additional field has incorrect type") - .append_value(s); + .expect("additional field has incorrect type"); + + for _ in 0..count { + b.append_value(*i); + } } - } -} + FieldValueType::UInt64(i) => { + let b = builder + .downcast_mut::() + .expect("additional field has incorrect type"); -pub(crate) fn add_additional_fields_using_builder( - additional_fields: Option<&HashMap<&String, FieldValueType<'_>>>, - additional_fields_builder: &mut Option>>, -) { - if let Some(fields) = additional_fields { - for (k, v) in fields.iter() { - if let Some(builder) = additional_fields_builder - .as_mut() - .and_then(|b| b.get_mut(*k)) - { - match v { - FieldValueType::Int32(i) => { - builder - .as_any_mut() - .downcast_mut::() - .expect("additional field has incorrect type") - .append_value(*i); - } - FieldValueType::Int64(i) => { - builder - .as_any_mut() - .downcast_mut::() - .expect("additional field has incorrect type") - .append_value(*i); - } - FieldValueType::String(s) => { - builder - .as_any_mut() - .downcast_mut::() - .expect("additional field has incorrect type") - .append_value(s); - } - } + for _ in 0..count { + b.append_value(*i); } } - } -} + FieldValueType::String(s) => { + let b = builder + .downcast_mut::() + .expect("additional field has incorrect type"); -pub(crate) fn flush_additional_fields_builders( - additional_fields_builder: &mut Option>>, - schema: &ArroyoSchema, - columns: &mut [Arc], -) { - if let Some(additional_fields) = additional_fields_builder.take() { - for (field_name, mut builder) in additional_fields { - if let Some((idx, _)) = schema.schema.column_with_name(&field_name) { - let expected_type = schema.schema.fields[idx].data_type(); - let built_column = builder.as_mut().finish(); - let actual_type = built_column.data_type(); - if expected_type != actual_type { - panic!( - "Type mismatch for column '{}': expected {:?}, got {:?}", - field_name, expected_type, actual_type - ); - } - columns[idx] = Arc::new(built_column); - } else { - panic!("Field '{}' not found in schema", field_name); + for _ in 0..count { + b.append_value(*s); } } } @@ -566,16 +728,15 @@ pub(crate) fn flush_additional_fields_builders( mod tests { use crate::de::{ArrowDeserializer, FieldValueType, FramingIterator}; use arrow::datatypes::Int32Type; - use arrow_array::builder::{make_builder, ArrayBuilder}; use arrow_array::cast::AsArray; use arrow_array::types::{GenericBinaryType, Int64Type, 
TimestampNanosecondType}; - use arrow_array::RecordBatch; - use arrow_schema::{Schema, TimeUnit}; + use arrow_schema::{DataType, Schema, TimeUnit}; use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::formats::{ BadData, Format, Framing, FramingMethod, JsonFormat, NewlineDelimitedFraming, RawBytesFormat, }; + use arroyo_rpc::MetadataField; use arroyo_types::{to_nanos, SourceError}; use serde_json::json; use std::sync::Arc; @@ -651,7 +812,7 @@ mod tests { ); } - fn setup_deserializer(bad_data: BadData) -> (Vec>, ArrowDeserializer) { + fn setup_deserializer(bad_data: BadData) -> ArrowDeserializer { let schema = Arc::new(Schema::new(vec![ arrow_schema::Field::new("x", arrow_schema::DataType::Int64, true), arrow_schema::Field::new( @@ -661,15 +822,9 @@ mod tests { ), ])); - let arrays: Vec<_> = schema - .fields - .iter() - .map(|f| make_builder(f.data_type(), 16)) - .collect(); - - let schema = ArroyoSchema::from_schema_unkeyed(schema).unwrap(); + let schema = Arc::new(ArroyoSchema::from_schema_unkeyed(schema).unwrap()); - let deserializer = ArrowDeserializer::new( + ArrowDeserializer::new( Format::Json(JsonFormat { confluent_schema_registry: false, schema_id: None, @@ -679,38 +834,27 @@ mod tests { timestamp_format: Default::default(), }), schema, + &[], None, bad_data, - ); - - (arrays, deserializer) + ) } #[tokio::test] async fn test_bad_data_drop() { - let (mut arrays, mut deserializer) = setup_deserializer(BadData::Drop {}); + let mut deserializer = setup_deserializer(BadData::Drop {}); let now = SystemTime::now(); assert_eq!( deserializer - .deserialize_slice( - &mut arrays[..], - json!({ "x": 5 }).to_string().as_bytes(), - now, - None, - ) + .deserialize_slice(json!({ "x": 5 }).to_string().as_bytes(), now, None,) .await, vec![] ); assert_eq!( deserializer - .deserialize_slice( - &mut arrays[..], - json!({ "x": "hello" }).to_string().as_bytes(), - now, - None, - ) + .deserialize_slice(json!({ "x": "hello" }).to_string().as_bytes(), now, None,) .await, vec![] ); @@ -728,12 +872,11 @@ mod tests { #[tokio::test] async fn test_bad_data_fail() { - let (mut arrays, mut deserializer) = setup_deserializer(BadData::Fail {}); + let mut deserializer = setup_deserializer(BadData::Fail {}); assert_eq!( deserializer .deserialize_slice( - &mut arrays[..], json!({ "x": 5 }).to_string().as_bytes(), SystemTime::now(), None, @@ -744,7 +887,6 @@ mod tests { assert_eq!( deserializer .deserialize_slice( - &mut arrays[..], json!({ "x": "hello" }).to_string().as_bytes(), SystemTime::now(), None, @@ -769,29 +911,23 @@ mod tests { ), ])); - let mut arrays: Vec<_> = schema - .fields - .iter() - .map(|f| make_builder(f.data_type(), 16)) - .collect(); - - let arroyo_schema = ArroyoSchema::from_schema_unkeyed(schema.clone()).unwrap(); + let arroyo_schema = Arc::new(ArroyoSchema::from_schema_unkeyed(schema.clone()).unwrap()); let mut deserializer = ArrowDeserializer::new( Format::RawBytes(RawBytesFormat {}), arroyo_schema, + &[], None, BadData::Fail {}, ); let time = SystemTime::now(); let result = deserializer - .deserialize_slice(&mut arrays, &[0, 1, 2, 3, 4, 5], time, None) + .deserialize_slice(&[0, 1, 2, 3, 4, 5], time, None) .await; assert!(result.is_empty()); - let arrays: Vec<_> = arrays.into_iter().map(|mut a| a.finish()).collect(); - let batch = RecordBatch::try_new(schema, arrays).unwrap(); + let batch = deserializer.flush_buffer().unwrap().unwrap(); assert_eq!(batch.num_rows(), 1); assert_eq!( @@ -809,7 +945,7 @@ mod tests { } #[tokio::test] - async fn test_additional_fields_deserialisation() { + async fn 
test_additional_fields_deserialization() { let schema = Arc::new(Schema::new(vec![ arrow_schema::Field::new("x", arrow_schema::DataType::Int64, true), arrow_schema::Field::new("y", arrow_schema::DataType::Int32, true), @@ -821,13 +957,7 @@ mod tests { ), ])); - let mut arrays: Vec<_> = schema - .fields - .iter() - .map(|f| make_builder(f.data_type(), 16)) - .collect(); - - let arroyo_schema = ArroyoSchema::from_schema_unkeyed(schema.clone()).unwrap(); + let arroyo_schema = Arc::new(ArroyoSchema::from_schema_unkeyed(schema.clone()).unwrap()); let mut deserializer = ArrowDeserializer::new( Format::Json(JsonFormat { @@ -839,21 +969,29 @@ mod tests { timestamp_format: Default::default(), }), arroyo_schema, + &[ + MetadataField { + field_name: "y".to_string(), + key: "y".to_string(), + data_type: Some(DataType::Int64), + }, + MetadataField { + field_name: "z".to_string(), + key: "z".to_string(), + data_type: Some(DataType::Utf8), + }, + ], None, BadData::Drop {}, ); let time = SystemTime::now(); let mut additional_fields = std::collections::HashMap::new(); - let binding = "y".to_string(); - additional_fields.insert(&binding, FieldValueType::Int32(5)); - let z_value = "hello".to_string(); - let binding = "z".to_string(); - additional_fields.insert(&binding, FieldValueType::String(&z_value)); + additional_fields.insert("y", FieldValueType::Int32(5)); + additional_fields.insert("z", FieldValueType::String("hello")); let result = deserializer .deserialize_slice( - &mut arrays, json!({ "x": 5 }).to_string().as_bytes(), time, Some(&additional_fields), @@ -862,6 +1000,7 @@ mod tests { assert!(result.is_empty()); let batch = deserializer.flush_buffer().unwrap().unwrap(); + println!("batch ={:?}", batch); assert_eq!(batch.num_rows(), 1); assert_eq!(batch.columns()[0].as_primitive::().value(0), 5); assert_eq!(batch.columns()[1].as_primitive::().value(0), 5); diff --git a/crates/arroyo-operator/src/connector.rs b/crates/arroyo-operator/src/connector.rs index d079879f5..54d678214 100644 --- a/crates/arroyo-operator/src/connector.rs +++ b/crates/arroyo-operator/src/connector.rs @@ -1,15 +1,18 @@ use crate::operator::ConstructedOperator; use anyhow::{anyhow, bail}; -use arrow::datatypes::{DataType, Field}; +use arrow::array::{ArrayRef, RecordBatch}; +use arrow::datatypes::{DataType, Field, Schema}; use arroyo_rpc::api_types::connections::{ ConnectionProfile, ConnectionSchema, ConnectionType, TestSourceMessage, }; use arroyo_rpc::OperatorConfig; -use arroyo_types::DisplayAsSql; +use arroyo_types::{DisplayAsSql, SourceError}; +use async_trait::async_trait; use serde::de::DeserializeOwned; use serde::ser::Serialize; use serde_json::value::Value; use std::collections::HashMap; +use std::sync::Arc; use tokio::sync::mpsc::Sender; use tokio::sync::oneshot; @@ -118,6 +121,17 @@ pub trait Connector: Send { table: Self::TableT, config: OperatorConfig, ) -> anyhow::Result; + + #[allow(unused)] + fn make_lookup( + &self, + profile: Self::ProfileT, + table: Self::TableT, + config: OperatorConfig, + schema: Arc, + ) -> anyhow::Result> { + bail!("{} is not a lookup connector", self.name()) + } } #[allow(clippy::type_complexity)] #[allow(clippy::wrong_self_convention)] @@ -187,6 +201,12 @@ pub trait ErasedConnector: Send { ) -> anyhow::Result; fn make_operator(&self, config: OperatorConfig) -> anyhow::Result; + + fn make_lookup( + &self, + config: OperatorConfig, + schema: Arc, + ) -> anyhow::Result>; } impl ErasedConnector for C { @@ -335,4 +355,24 @@ impl ErasedConnector for C { config, ) } + + fn make_lookup( + &self, + 
config: OperatorConfig, + schema: Arc, + ) -> anyhow::Result> { + self.make_lookup( + self.parse_config(&config.connection)?, + self.parse_table(&config.table)?, + config, + schema, + ) + } +} + +#[async_trait] +pub trait LookupConnector { + fn name(&self) -> String; + + async fn lookup(&mut self, keys: &[ArrayRef]) -> Option>; } diff --git a/crates/arroyo-operator/src/context.rs b/crates/arroyo-operator/src/context.rs index 32fb3610f..e013feb99 100644 --- a/crates/arroyo-operator/src/context.rs +++ b/crates/arroyo-operator/src/context.rs @@ -1,16 +1,15 @@ use crate::{server_for_hash_array, RateLimiter}; -use arrow::array::{make_builder, Array, ArrayBuilder, PrimitiveArray, RecordBatch}; +use arrow::array::{Array, PrimitiveArray, RecordBatch}; use arrow::compute::{partition, sort_to_indices, take}; -use arrow::datatypes::{SchemaRef, UInt64Type}; +use arrow::datatypes::UInt64Type; use arroyo_formats::de::{ArrowDeserializer, FieldValueType}; -use arroyo_formats::should_flush; use arroyo_metrics::{register_queue_gauge, QueueGauges, TaskCounters}; use arroyo_rpc::config::config; use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::formats::{BadData, Format, Framing}; use arroyo_rpc::grpc::rpc::{CheckpointMetadata, TableConfig, TaskCheckpointEventType}; use arroyo_rpc::schema_resolver::SchemaResolver; -use arroyo_rpc::{get_hasher, CompactionResult, ControlMessage, ControlResp}; +use arroyo_rpc::{get_hasher, CompactionResult, ControlMessage, ControlResp, MetadataField}; use arroyo_state::tables::table_manager::TableManager; use arroyo_types::{ ArrowMessage, ChainInfo, CheckpointBarrier, SignalMessage, SourceError, TaskInfo, UserError, @@ -23,7 +22,7 @@ use std::collections::HashMap; use std::mem::size_of_val; use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; use std::sync::Arc; -use std::time::{Instant, SystemTime}; +use std::time::SystemTime; use tokio::sync::mpsc::error::SendError; use tokio::sync::mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender}; use tokio::sync::Notify; @@ -205,46 +204,8 @@ pub fn batch_bounded(size: u32) -> (BatchSender, BatchReceiver) { ) } -struct ContextBuffer { - buffer: Vec>, - created: Instant, - schema: SchemaRef, -} - -impl ContextBuffer { - fn new(schema: SchemaRef) -> Self { - let buffer = schema - .fields - .iter() - .map(|f| make_builder(f.data_type(), 16)) - .collect(); - - Self { - buffer, - created: Instant::now(), - schema, - } - } - - pub fn size(&self) -> usize { - self.buffer[0].len() - } - - pub fn should_flush(&self) -> bool { - should_flush(self.size(), self.created) - } - - pub fn finish(&mut self) -> RecordBatch { - RecordBatch::try_new( - self.schema.clone(), - self.buffer.iter_mut().map(|a| a.finish()).collect(), - ) - .unwrap() - } -} - pub struct SourceContext { - pub out_schema: ArroyoSchema, + pub out_schema: Arc, pub error_reporter: ErrorReporter, pub control_tx: Sender, pub control_rx: Receiver, @@ -303,10 +264,9 @@ impl SourceContext { pub struct SourceCollector { deserializer: Option, - buffer: ContextBuffer, buffered_error: Option, error_rate_limiter: RateLimiter, - pub out_schema: ArroyoSchema, + pub out_schema: Arc, pub(crate) collector: ArrowCollector, control_tx: Sender, chain_info: Arc, @@ -315,14 +275,13 @@ pub struct SourceCollector { impl SourceCollector { pub fn new( - out_schema: ArroyoSchema, + out_schema: Arc, collector: ArrowCollector, control_tx: Sender, chain_info: &Arc, task_info: &Arc, ) -> Self { Self { - buffer: ContextBuffer::new(out_schema.schema.clone()), out_schema, collector, 
control_tx, @@ -343,12 +302,14 @@ impl SourceCollector { format: Format, framing: Option, bad_data: Option, + metadata_fields: &[MetadataField], schema_resolver: Arc, ) { self.deserializer = Some(ArrowDeserializer::with_schema_resolver( format, framing, self.out_schema.clone(), + metadata_fields, bad_data.unwrap_or_default(), schema_resolver, )); @@ -359,6 +320,7 @@ impl SourceCollector { format: Format, framing: Option, bad_data: Option, + metadata_fields: &[MetadataField], ) { if self.deserializer.is_some() { panic!("Deserialize already initialized"); @@ -367,25 +329,24 @@ impl SourceCollector { self.deserializer = Some(ArrowDeserializer::new( format, self.out_schema.clone(), + metadata_fields, framing, bad_data.unwrap_or_default(), )); } pub fn should_flush(&self) -> bool { - self.buffer.should_flush() - || self - .deserializer - .as_ref() - .map(|d| d.should_flush()) - .unwrap_or(false) + self.deserializer + .as_ref() + .map(|d| d.should_flush()) + .unwrap_or(false) } pub async fn deserialize_slice( &mut self, msg: &[u8], time: SystemTime, - additional_fields: Option<&HashMap<&String, FieldValueType<'_>>>, + additional_fields: Option<&HashMap<&str, FieldValueType<'_>>>, ) -> Result<(), UserError> { let deserializer = self .deserializer @@ -393,7 +354,7 @@ impl SourceCollector { .expect("deserializer not initialized!"); let errors = deserializer - .deserialize_slice(&mut self.buffer.buffer, msg, time, additional_fields) + .deserialize_slice(msg, time, additional_fields) .await; self.collect_source_errors(errors).await?; @@ -443,11 +404,6 @@ impl SourceCollector { } pub async fn flush_buffer(&mut self) -> Result<(), UserError> { - if self.buffer.size() > 0 { - let batch = self.buffer.finish(); - self.collector.collect(batch).await; - } - if let Some(deserializer) = self.deserializer.as_mut() { if let Some(buffer) = deserializer.flush_buffer() { match buffer { @@ -500,8 +456,8 @@ pub struct OperatorContext { pub task_info: Arc, pub control_tx: Sender, pub watermarks: WatermarkHolder, - pub in_schemas: Vec, - pub out_schema: Option, + pub in_schemas: Vec>, + pub out_schema: Option>, pub table_manager: TableManager, pub error_reporter: ErrorReporter, } @@ -536,7 +492,7 @@ pub trait Collector: Send { #[derive(Clone)] pub struct ArrowCollector { pub chain_info: Arc, - out_schema: Option, + out_schema: Option>, out_qs: Vec>, tx_queue_rem_gauges: QueueGauges, tx_queue_size_gauges: QueueGauges, @@ -654,7 +610,7 @@ impl Collector for ArrowCollector { impl ArrowCollector { pub fn new( chain_info: Arc, - out_schema: Option, + out_schema: Option>, out_qs: Vec>, ) -> Self { let tx_queue_size_gauges = register_queue_gauge( @@ -720,8 +676,8 @@ impl OperatorContext { restore_from: Option<&CheckpointMetadata>, control_tx: Sender, input_partitions: usize, - in_schemas: Vec, - out_schema: Option, + in_schemas: Vec>, + out_schema: Option>, tables: HashMap, ) -> Self { let (table_manager, watermark) = @@ -867,7 +823,7 @@ mod tests { let mut collector = ArrowCollector { chain_info, - out_schema: Some(ArroyoSchema::new_keyed(schema, 1, vec![0])), + out_schema: Some(Arc::new(ArroyoSchema::new_keyed(schema, 1, vec![0]))), out_qs, tx_queue_rem_gauges, tx_queue_size_gauges, diff --git a/crates/arroyo-operator/src/operator.rs b/crates/arroyo-operator/src/operator.rs index d34eaf4ff..70fc09bc0 100644 --- a/crates/arroyo-operator/src/operator.rs +++ b/crates/arroyo-operator/src/operator.rs @@ -226,7 +226,7 @@ impl OperatorNode { control_rx: Receiver, mut in_qs: Vec, out_qs: Vec>, - out_schema: Option, + out_schema: 
Option>, ready: Arc, ) { info!( diff --git a/crates/arroyo-planner/src/builder.rs b/crates/arroyo-planner/src/builder.rs index f11f30200..d19cf0468 100644 --- a/crates/arroyo-planner/src/builder.rs +++ b/crates/arroyo-planner/src/builder.rs @@ -172,7 +172,7 @@ impl<'a> Planner<'a> { let (partial_schema, timestamp_index) = if add_timestamp_field { ( - add_timestamp_field_arrow(partial_schema.clone()), + add_timestamp_field_arrow((*partial_schema).clone()), partial_schema.fields().len(), ) } else { diff --git a/crates/arroyo-planner/src/extension/join.rs b/crates/arroyo-planner/src/extension/join.rs index 5695bcc26..5769e67c4 100644 --- a/crates/arroyo-planner/src/extension/join.rs +++ b/crates/arroyo-planner/src/extension/join.rs @@ -10,7 +10,6 @@ use datafusion::logical_expr::{LogicalPlan, UserDefinedLogicalNodeCore}; use datafusion_proto::generated::datafusion::PhysicalPlanNode; use datafusion_proto::physical_plan::AsExecutionPlan; use prost::Message; -use std::sync::Arc; use std::time::Duration; pub(crate) const JOIN_NODE_NAME: &str = "JoinNode"; @@ -80,7 +79,7 @@ impl ArroyoExtension for JoinExtension { } fn output_schema(&self) -> ArroyoSchema { - ArroyoSchema::from_schema_unkeyed(Arc::new(self.schema().as_ref().clone().into())).unwrap() + ArroyoSchema::from_schema_unkeyed(self.schema().inner().clone()).unwrap() } } diff --git a/crates/arroyo-planner/src/extension/lookup.rs b/crates/arroyo-planner/src/extension/lookup.rs new file mode 100644 index 000000000..4e2a8f4e6 --- /dev/null +++ b/crates/arroyo-planner/src/extension/lookup.rs @@ -0,0 +1,191 @@ +use crate::builder::{NamedNode, Planner}; +use crate::extension::{ArroyoExtension, NodeWithIncomingEdges}; +use crate::multifield_partial_ord; +use crate::schemas::add_timestamp_field_arrow; +use crate::tables::ConnectorTable; +use arroyo_datastream::logical::{LogicalEdge, LogicalEdgeType, LogicalNode, OperatorName}; +use arroyo_rpc::df::{ArroyoSchema, ArroyoSchemaRef}; +use arroyo_rpc::grpc::api::{ConnectorOp, LookupJoinCondition, LookupJoinOperator}; +use datafusion::common::{internal_err, plan_err, Column, DFSchemaRef, JoinType}; +use datafusion::logical_expr::{Expr, LogicalPlan, UserDefinedLogicalNodeCore}; +use datafusion::sql::TableReference; +use datafusion_proto::physical_plan::to_proto::serialize_physical_expr; +use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; +use prost::Message; +use std::fmt::Formatter; +use std::sync::Arc; + +pub const SOURCE_EXTENSION_NAME: &str = "LookupSource"; +pub const JOIN_EXTENSION_NAME: &str = "LookupJoin"; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct LookupSource { + pub(crate) table: ConnectorTable, + pub(crate) schema: DFSchemaRef, +} + +multifield_partial_ord!(LookupSource, table); + +impl UserDefinedLogicalNodeCore for LookupSource { + fn name(&self) -> &str { + SOURCE_EXTENSION_NAME + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "LookupSource: {}", self.schema) + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + inputs: Vec, + ) -> datafusion::common::Result { + if !inputs.is_empty() { + return internal_err!("LookupSource cannot have inputs"); + } + + Ok(Self { + table: self.table.clone(), + schema: self.schema.clone(), + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct LookupJoin { + pub(crate) input: LogicalPlan, + pub(crate) 
schema: DFSchemaRef, + pub(crate) connector: ConnectorTable, + pub(crate) on: Vec<(Expr, Column)>, + pub(crate) filter: Option, + pub(crate) alias: Option, + pub(crate) join_type: JoinType, +} + +multifield_partial_ord!(LookupJoin, input, connector, on, filter, alias); + +impl ArroyoExtension for LookupJoin { + fn node_name(&self) -> Option { + None + } + + fn plan_node( + &self, + planner: &Planner, + index: usize, + input_schemas: Vec, + ) -> datafusion::common::Result { + let schema = ArroyoSchema::from_schema_unkeyed(Arc::new(self.schema.as_ref().into()))?; + let lookup_schema = ArroyoSchema::from_schema_unkeyed(add_timestamp_field_arrow( + self.connector.physical_schema(), + ))?; + let join_config = LookupJoinOperator { + input_schema: Some(schema.into()), + lookup_schema: Some(lookup_schema.into()), + connector: Some(ConnectorOp { + connector: self.connector.connector.clone(), + config: self.connector.config.clone(), + description: self.connector.description.clone(), + }), + key_exprs: self + .on + .iter() + .map(|(l, r)| { + let expr = planner.create_physical_expr(l, &self.schema)?; + let expr = serialize_physical_expr(&expr, &DefaultPhysicalExtensionCodec {})?; + Ok(LookupJoinCondition { + left_expr: expr.encode_to_vec(), + right_key: r.name.clone(), + }) + }) + .collect::>>()?, + join_type: match self.join_type { + JoinType::Inner => arroyo_rpc::grpc::api::JoinType::Inner as i32, + JoinType::Left => arroyo_rpc::grpc::api::JoinType::Left as i32, + j => { + return plan_err!("unsupported join type '{j}' for lookup join; only inner and left joins are supported"); + } + }, + ttl_micros: self + .connector + .lookup_cache_ttl + .map(|t| t.as_micros() as u64), + max_capacity_bytes: self.connector.lookup_cache_max_bytes, + }; + + let incoming_edge = + LogicalEdge::project_all(LogicalEdgeType::Shuffle, (*input_schemas[0]).clone()); + + Ok(NodeWithIncomingEdges { + node: LogicalNode::single( + index as u32, + format!("lookupjoin_{}", index), + OperatorName::LookupJoin, + join_config.encode_to_vec(), + format!("LookupJoin<{}>", self.connector.name), + 1, + ), + edges: vec![incoming_edge], + }) + } + + fn output_schema(&self) -> ArroyoSchema { + ArroyoSchema::from_schema_unkeyed(self.schema.inner().clone()).unwrap() + } +} + +impl UserDefinedLogicalNodeCore for LookupJoin { + fn name(&self) -> &str { + JOIN_EXTENSION_NAME + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + let mut e: Vec<_> = self.on.iter().map(|(l, _)| l.clone()).collect(); + + if let Some(filter) = &self.filter { + e.push(filter.clone()); + } + + e + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "LookupJoinExtension: {}", self.schema) + } + + fn with_exprs_and_inputs( + &self, + _: Vec, + inputs: Vec, + ) -> datafusion::common::Result { + Ok(Self { + input: inputs[0].clone(), + schema: self.schema.clone(), + connector: self.connector.clone(), + on: self.on.clone(), + filter: self.filter.clone(), + alias: self.alias.clone(), + join_type: self.join_type, + }) + } +} diff --git a/crates/arroyo-planner/src/extension/mod.rs b/crates/arroyo-planner/src/extension/mod.rs index 7b576b2a0..c6463d3bd 100644 --- a/crates/arroyo-planner/src/extension/mod.rs +++ b/crates/arroyo-planner/src/extension/mod.rs @@ -26,6 +26,7 @@ use self::{ window_fn::WindowFunctionExtension, }; use crate::builder::{NamedNode, Planner}; +use crate::extension::lookup::LookupJoin; use 
crate::schemas::{add_timestamp_field, has_timestamp_field}; use crate::{fields_with_qualifiers, schema_from_df_fields, DFField, ASYNC_RESULT_FIELD}; use join::JoinExtension; @@ -34,12 +35,14 @@ pub(crate) mod aggregate; pub(crate) mod debezium; pub(crate) mod join; pub(crate) mod key_calculation; +pub(crate) mod lookup; pub(crate) mod remote_table; pub(crate) mod sink; pub(crate) mod table_source; pub(crate) mod updating_aggregate; pub(crate) mod watermark_node; pub(crate) mod window_fn; + pub(crate) trait ArroyoExtension: Debug { // if the extension has a name, return it so that we can memoize. fn node_name(&self) -> Option; @@ -85,6 +88,7 @@ impl<'a> TryFrom<&'a dyn UserDefinedLogicalNode> for &'a dyn ArroyoExtension { .or_else(|_| try_from_t::(node)) .or_else(|_| try_from_t::(node)) .or_else(|_| try_from_t::(node)) + .or_else(|_| try_from_t::(node)) .map_err(|_| DataFusionError::Plan(format!("unexpected node: {}", node.name()))) } } diff --git a/crates/arroyo-planner/src/extension/sink.rs b/crates/arroyo-planner/src/extension/sink.rs index 0e559d175..5141e6e3a 100644 --- a/crates/arroyo-planner/src/extension/sink.rs +++ b/crates/arroyo-planner/src/extension/sink.rs @@ -61,6 +61,7 @@ impl SinkExtension { (false, false) => {} } } + Table::LookupTable(..) => return plan_err!("cannot use a lookup table as a sink"), Table::MemoryTable { .. } => return plan_err!("memory tables not supported"), Table::TableFromQuery { .. } => {} Table::PreviewSink { .. } => { diff --git a/crates/arroyo-planner/src/lib.rs b/crates/arroyo-planner/src/lib.rs index a9972e1af..7ead899ec 100644 --- a/crates/arroyo-planner/src/lib.rs +++ b/crates/arroyo-planner/src/lib.rs @@ -826,8 +826,11 @@ pub async fn parse_and_get_arrow_program( logical_plan.replace(plan_rewrite); continue; } + Table::LookupTable(_) => { + plan_err!("lookup (temporary) tables cannot be inserted into") + } Table::TableFromQuery { .. } => { - plan_err!("Shouldn't be inserting more data into a table made with CREATE TABLE AS") + plan_err!("shouldn't be inserting more data into a table made with CREATE TABLE AS") } Table::PreviewSink { .. 
} => { plan_err!("queries shouldn't be able insert into preview sink.") diff --git a/crates/arroyo-planner/src/plan/join.rs b/crates/arroyo-planner/src/plan/join.rs index 9dea42c56..0f532d84d 100644 --- a/crates/arroyo-planner/src/plan/join.rs +++ b/crates/arroyo-planner/src/plan/join.rs @@ -1,10 +1,15 @@ use crate::extension::join::JoinExtension; use crate::extension::key_calculation::KeyCalculationExtension; +use crate::extension::lookup::{LookupJoin, LookupSource}; use crate::plan::WindowDetectingVisitor; +use crate::schemas::add_timestamp_field; +use crate::tables::ConnectorTable; use crate::{fields_with_qualifiers, schema_from_df_fields_with_metadata, ArroyoSchemaProvider}; use arroyo_datastream::WindowType; use arroyo_rpc::UPDATING_META_FIELD; -use datafusion::common::tree_node::{Transformed, TreeNodeRewriter}; +use datafusion::common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, +}; use datafusion::common::{ not_impl_err, plan_err, Column, DataFusionError, JoinConstraint, JoinType, Result, ScalarValue, TableReference, @@ -15,6 +20,7 @@ use datafusion::logical_expr::{ build_join_schema, BinaryExpr, Case, Expr, Extension, Join, LogicalPlan, Projection, }; use datafusion::prelude::coalesce; +use datafusion::sql::unparser::expr_to_sql; use std::sync::Arc; pub(crate) struct JoinRewriter<'a> { @@ -78,7 +84,6 @@ impl JoinRewriter<'_> { } fn create_join_key_plan( - &self, input: Arc, join_expressions: Vec, name: &'static str, @@ -189,6 +194,116 @@ impl JoinRewriter<'_> { } } +#[derive(Default)] +struct FindLookupExtension { + table: Option, + filter: Option, + alias: Option, +} + +impl TreeNodeVisitor<'_> for FindLookupExtension { + type Node = LogicalPlan; + + fn f_down(&mut self, node: &Self::Node) -> Result { + match node { + LogicalPlan::Extension(e) => { + if let Some(s) = e.node.as_any().downcast_ref::() { + self.table = Some(s.table.clone()); + return Ok(TreeNodeRecursion::Stop); + } + } + LogicalPlan::Filter(filter) => { + if self.filter.replace(filter.predicate.clone()).is_some() { + return plan_err!( + "multiple filters found in lookup join, which is not supported" + ); + } + } + LogicalPlan::SubqueryAlias(s) => { + self.alias = Some(s.alias.clone()); + } + _ => { + return plan_err!("lookup tables must be used directly within a join"); + } + } + Ok(TreeNodeRecursion::Continue) + } +} + +fn has_lookup(plan: &LogicalPlan) -> Result { + plan.exists(|p| { + Ok(match p { + LogicalPlan::Extension(e) => e.node.as_any().is::(), + _ => false, + }) + }) +} + +fn maybe_plan_lookup_join(join: &Join) -> Result> { + if has_lookup(&join.left)? { + return plan_err!("lookup sources must be on the right side of an inner or left join"); + } + + if !has_lookup(&join.right)? 
{ + return Ok(None); + } + + match join.join_type { + JoinType::Inner | JoinType::Left => {} + t => { + return plan_err!( + "{} join is not supported for lookup tables; must be a left or inner join", + t + ); + } + } + + if join.filter.is_some() { + return plan_err!("filter join conditions are not supported for lookup joins; must have an equality condition"); + } + + let mut lookup = FindLookupExtension::default(); + join.right.visit(&mut lookup)?; + + let connector = lookup + .table + .expect("right side of join does not have lookup"); + + let on = join.on.iter().map(|(l, r)| { + match r { + Expr::Column(c) => { + if !connector.primary_keys.contains(&c.name) { + plan_err!("the right-side of a look-up join condition must be a PRIMARY KEY column, but '{}' is not", c.name) + } else { + Ok((l.clone(), c.clone())) + } + }, + e => { + plan_err!("invalid right-side condition for lookup join: `{}`; only column references are supported", + expr_to_sql(e).map(|e| e.to_string()).unwrap_or_else(|_| e.to_string())) + } + } + }).collect::>()?; + + let left_input = JoinRewriter::create_join_key_plan( + join.left.clone(), + join.on.iter().map(|(l, _)| l.clone()).collect(), + "left", + )?; + + Ok(Some(LogicalPlan::Extension(Extension { + node: Arc::new(LookupJoin { + input: left_input, + schema: add_timestamp_field(join.schema.clone(), None)?, + connector, + on, + filter: lookup.filter, + alias: lookup.alias, + join_type: join.join_type, + }), + }))) +} + impl TreeNodeRewriter for JoinRewriter<'_> { type Node = LogicalPlan; @@ -196,6 +311,11 @@ impl TreeNodeRewriter for JoinRewriter<'_> { let LogicalPlan::Join(join) = node else { return Ok(Transformed::no(node)); }; + + if let Some(plan) = maybe_plan_lookup_join(&join)? { + return Ok(Transformed::yes(plan)); + } + let is_instant = Self::check_join_windowing(&join)?; let Join { @@ -220,8 +340,8 @@ impl TreeNodeRewriter for JoinRewriter<'_> { let (left_expressions, right_expressions): (Vec<_>, Vec<_>) = on.clone().into_iter().unzip(); - let left_input = self.create_join_key_plan(left, left_expressions, "left")?; - let right_input = self.create_join_key_plan(right, right_expressions, "right")?; + let left_input = Self::create_join_key_plan(left, left_expressions, "left")?; + let right_input = Self::create_join_key_plan(right, right_expressions, "right")?; let rewritten_join = LogicalPlan::Join(Join { schema: Arc::new(build_join_schema( left_input.schema(), diff --git a/crates/arroyo-planner/src/rewriters.rs b/crates/arroyo-planner/src/rewriters.rs index 611822719..608f87eba 100644 --- a/crates/arroyo-planner/src/rewriters.rs +++ b/crates/arroyo-planner/src/rewriters.rs @@ -19,6 +19,7 @@ use arroyo_rpc::TIMESTAMP_FIELD; use arroyo_rpc::UPDATING_META_FIELD; use datafusion::logical_expr::UserDefinedLogicalNode; +use crate::extension::lookup::LookupSource; use crate::extension::AsyncUDFExtension; use arroyo_udf_host::parse::{AsyncOptions, UdfType}; use datafusion::common::tree_node::{ @@ -215,6 +216,19 @@ impl SourceRewriter<'_> { }))) } + fn mutate_lookup_table( + &self, + table_scan: &TableScan, + table: &ConnectorTable, + ) -> DFResult> { + Ok(Transformed::yes(LogicalPlan::Extension(Extension { + node: Arc::new(LookupSource { + table: table.clone(), + schema: table_scan.projected_schema.clone(), + }), + }))) + } + fn mutate_table_from_query( &self, table_scan: &TableScan, @@ -273,6 +287,7 @@ impl TreeNodeRewriter for SourceRewriter<'_> { match table { Table::ConnectorTable(table) => self.mutate_connector_table(&table_scan, table), + Table::LookupTable(table) 
=> self.mutate_lookup_table(&table_scan, table), Table::MemoryTable { name, fields: _, diff --git a/crates/arroyo-planner/src/schemas.rs b/crates/arroyo-planner/src/schemas.rs index 9d80453c4..3d0c53448 100644 --- a/crates/arroyo-planner/src/schemas.rs +++ b/crates/arroyo-planner/src/schemas.rs @@ -49,7 +49,7 @@ pub(crate) fn has_timestamp_field(schema: &DFSchemaRef) -> bool { .any(|field| field.name() == "_timestamp") } -pub fn add_timestamp_field_arrow(schema: SchemaRef) -> SchemaRef { +pub fn add_timestamp_field_arrow(schema: Schema) -> SchemaRef { let mut fields = schema.fields().to_vec(); fields.push(Arc::new(Field::new( "_timestamp", diff --git a/crates/arroyo-planner/src/tables.rs b/crates/arroyo-planner/src/tables.rs index dd8f7141f..6d8a06b3b 100644 --- a/crates/arroyo-planner/src/tables.rs +++ b/crates/arroyo-planner/src/tables.rs @@ -1,10 +1,10 @@ +use arrow::compute::kernels::cast_utils::parse_interval_day_time; +use arrow_schema::{DataType, Field, FieldRef, Schema}; +use arroyo_connectors::connector_for_type; use std::str::FromStr; use std::sync::Arc; use std::{collections::HashMap, time::Duration}; -use arrow_schema::{DataType, Field, FieldRef, Schema}; -use arroyo_connectors::connector_for_type; - use crate::extension::remote_table::RemoteTableExtension; use crate::types::convert_data_type; use crate::{ @@ -72,8 +72,11 @@ pub struct ConnectorTable { pub watermark_field: Option, pub idle_time: Option, pub primary_keys: Arc>, - pub inferred_fields: Option>, + + // for lookup tables + pub lookup_cache_max_bytes: Option, + pub lookup_cache_ttl: Option, } multifield_partial_ord!( @@ -206,6 +209,8 @@ impl From for ConnectorTable { idle_time: DEFAULT_IDLE_TIME, primary_keys: Arc::new(vec![]), inferred_fields: None, + lookup_cache_max_bytes: None, + lookup_cache_ttl: None, } } } @@ -214,6 +219,7 @@ impl ConnectorTable { fn from_options( name: &str, connector: &str, + temporary: bool, mut fields: Vec, primary_keys: Vec, options: &mut HashMap, @@ -247,6 +253,14 @@ impl ConnectorTable { let framing = Framing::from_opts(options) .map_err(|e| DataFusionError::Plan(format!("invalid framing: '{e}'")))?; + if temporary { + if let Some(t) = options.insert("type".to_string(), "lookup".to_string()) { + if t != "lookup" { + return plan_err!("Cannot have a temporary table with type '{}'; temporary tables must be type 'lookup'", t); + } + } + } + let mut input_to_schema_fields = fields.clone(); if let Some(Format::Json(JsonFormat { debezium: true, .. })) = &format { @@ -294,6 +308,7 @@ impl ConnectorTable { schema_fields, None, Some(fields.is_empty()), + primary_keys.iter().cloned().collect(), ) .map_err(|e| DataFusionError::Plan(format!("could not create connection schema: {}", e)))?; @@ -318,6 +333,26 @@ impl ConnectorTable { .filter(|t| *t <= 0) .map(|t| Duration::from_micros(t as u64)); + table.lookup_cache_max_bytes = options + .remove("lookup.cache.max_bytes") + .map(|t| u64::from_str(&t)) + .transpose() + .map_err(|_| { + DataFusionError::Plan("lookup.cache.max_bytes must be set to a number".to_string()) + })?; + + table.lookup_cache_ttl = options + .remove("lookup.cache.ttl") + .map(|t| parse_interval_day_time(&t)) + .transpose() + .map_err(|e| { + DataFusionError::Plan(format!("lookup.cache.ttl must be a valid interval ({})", e)) + })? 
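+        // parse_interval_day_time yields a days + milliseconds interval; fold both
+        // parts into a single Duration to use as the lookup cache TTL.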
+ .map(|t| { + Duration::from_secs(t.days as u64 * 60 * 60 * 24) + + Duration::from_millis(t.milliseconds as u64) + }); + if !options.is_empty() { let keys: Vec = options.keys().map(|s| format!("'{}'", s)).collect(); return plan_err!( @@ -418,7 +453,7 @@ impl ConnectorTable { pub fn as_sql_source(&self) -> Result { match self.connection_type { ConnectionType::Source => {} - ConnectionType::Sink => { + ConnectionType::Sink | ConnectionType::Lookup => { return plan_err!("cannot read from sink"); } }; @@ -463,6 +498,7 @@ pub struct SourceOperator { #[allow(clippy::enum_variant_names)] #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Table { + LookupTable(ConnectorTable), ConnectorTable(ConnectorTable), MemoryTable { name: String, @@ -673,6 +709,7 @@ impl Table { columns, with_options, query: None, + temporary, .. }) = statement { @@ -750,17 +787,23 @@ impl Table { ), None => None, }; - Ok(Some(Table::ConnectorTable( - ConnectorTable::from_options( - &name, - connector, - fields, - primary_keys, - &mut with_map, - connection_profile, - ) - .map_err(|e| e.context(format!("Failed to create table {}", name)))?, - ))) + let table = ConnectorTable::from_options( + &name, + connector, + *temporary, + fields, + primary_keys, + &mut with_map, + connection_profile, + ) + .map_err(|e| e.context(format!("Failed to create table {}", name)))?; + + Ok(Some(match table.connection_type { + ConnectionType::Source | ConnectionType::Sink => { + Table::ConnectorTable(table) + } + ConnectionType::Lookup => Table::LookupTable(table), + })) } } } else { @@ -798,7 +841,7 @@ impl Table { pub fn name(&self) -> &str { match self { Table::MemoryTable { name, .. } | Table::TableFromQuery { name, .. } => name.as_str(), - Table::ConnectorTable(c) => c.name.as_str(), + Table::ConnectorTable(c) | Table::LookupTable(c) => c.name.as_str(), Table::PreviewSink { .. } => "preview", } } @@ -836,6 +879,11 @@ impl Table { fields, inferred_fields, .. + }) + | Table::LookupTable(ConnectorTable { + fields, + inferred_fields, + .. }) => inferred_fields .as_ref() .map(|fs| fs.iter().map(|f| f.field().clone()).collect()) @@ -856,7 +904,7 @@ impl Table { pub fn connector_op(&self) -> Result { match self { - Table::ConnectorTable(c) => Ok(c.connector_op()), + Table::ConnectorTable(c) | Table::LookupTable(c) => Ok(c.connector_op()), Table::MemoryTable { .. } => plan_err!("can't write to a memory table"), Table::TableFromQuery { .. 
} => todo!(), Table::PreviewSink { logical_plan: _ } => Ok(default_sink()), diff --git a/crates/arroyo-planner/src/test/queries/error_lookup_join_non_primary_key.sql b/crates/arroyo-planner/src/test/queries/error_lookup_join_non_primary_key.sql new file mode 100644 index 000000000..cf5b486bf --- /dev/null +++ b/crates/arroyo-planner/src/test/queries/error_lookup_join_non_primary_key.sql @@ -0,0 +1,21 @@ +--fail=the right-side of a look-up join condition must be a PRIMARY KEY column, but 'value' is not +create table impulse with ( + connector = 'impulse', + event_rate = '2' +); + +create temporary table lookup ( + key TEXT PRIMARY KEY GENERATED ALWAYS AS (metadata('key')) STORED, + value TEXT, + len INT +) with ( + connector = 'redis', + format = 'raw_string', + address = 'redis://localhost:6379', + format = 'json', + 'lookup.cache.max_bytes' = '100000' +); + +select A.counter, B.key, B.value, len +from impulse A inner join lookup B +on cast((A.counter % 10) as TEXT) = B.value; \ No newline at end of file diff --git a/crates/arroyo-planner/src/test/queries/error_missing_redis_key.sql b/crates/arroyo-planner/src/test/queries/error_missing_redis_key.sql new file mode 100644 index 000000000..537fbaf07 --- /dev/null +++ b/crates/arroyo-planner/src/test/queries/error_missing_redis_key.sql @@ -0,0 +1,19 @@ +--fail=Redis lookup tables must have a PRIMARY KEY field defined as `field_name TEXT GENERATED ALWAYS AS (metadata('key')) STORED` +create table impulse with ( + connector = 'impulse', + event_rate = '2' +); + +create table lookup ( + key TEXT PRIMARY KEY, + value TEXT +) with ( + connector = 'redis', + format = 'json', + address = 'redis://localhost:6379', + type = 'lookup' +); + +select A.counter, B.key, B.value +from impulse A left join lookup B +on cast((A.counter % 10) as TEXT) = B.key; diff --git a/crates/arroyo-planner/src/test/queries/lookup_join.sql b/crates/arroyo-planner/src/test/queries/lookup_join.sql new file mode 100644 index 000000000..05ba36ace --- /dev/null +++ b/crates/arroyo-planner/src/test/queries/lookup_join.sql @@ -0,0 +1,31 @@ +CREATE TABLE events ( + event_id TEXT, + timestamp TIMESTAMP, + customer_id TEXT, + event_type TEXT +) WITH ( + connector = 'kafka', + topic = 'events', + type = 'source', + format = 'json', + bootstrap_servers = 'broker:9092' +); + +create temporary table customers ( + customer_id TEXT PRIMARY KEY GENERATED ALWAYS AS (metadata('key')) STORED, + customer_name TEXT, + plan TEXT +) with ( + connector = 'redis', + format = 'raw_string', + address = 'redis://localhost:6379', + format = 'json', + 'lookup.cache.max_bytes' = '1000000', + 'lookup.cache.ttl' = '5 second' +); + +SELECT e.event_id, e.timestamp, e.customer_id, e.event_type, c.customer_name, c.plan +FROM events e +LEFT JOIN customers c +ON e.customer_id = c.customer_id +WHERE c.plan = 'Premium'; diff --git a/crates/arroyo-rpc/proto/api.proto b/crates/arroyo-rpc/proto/api.proto index 24c120c11..74874b6cb 100644 --- a/crates/arroyo-rpc/proto/api.proto +++ b/crates/arroyo-rpc/proto/api.proto @@ -70,6 +70,21 @@ message JoinOperator { optional uint64 ttl_micros = 6; } +message LookupJoinCondition { + bytes left_expr = 1; + string right_key = 2; +} + +message LookupJoinOperator { + ArroyoSchema input_schema = 1; + ArroyoSchema lookup_schema = 2; + ConnectorOp connector = 3; + repeated LookupJoinCondition key_exprs = 4; + JoinType join_type = 5; + optional uint64 ttl_micros = 6; + optional uint64 max_capacity_bytes = 7; +} + message WindowFunctionOperator { string name = 1; ArroyoSchema 
input_schema = 2; diff --git a/crates/arroyo-rpc/src/api_types/connections.rs b/crates/arroyo-rpc/src/api_types/connections.rs index 93ca72d60..5dfbd3429 100644 --- a/crates/arroyo-rpc/src/api_types/connections.rs +++ b/crates/arroyo-rpc/src/api_types/connections.rs @@ -1,14 +1,14 @@ +use crate::df::{ArroyoSchema, ArroyoSchemaRef}; use crate::formats::{BadData, Format, Framing}; use crate::{primitive_to_sql, MetadataField}; +use ahash::HashSet; use anyhow::bail; use arrow_schema::{DataType, Field, Fields, TimeUnit}; +use arroyo_types::ArroyoExtensionType; use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap}; use std::fmt::{Display, Formatter}; use std::sync::Arc; - -use crate::df::{ArroyoSchema, ArroyoSchemaRef}; -use arroyo_types::ArroyoExtensionType; use utoipa::{IntoParams, ToSchema}; #[derive(Serialize, Deserialize, Clone, Debug, ToSchema)] @@ -51,6 +51,7 @@ pub struct ConnectionProfilePost { pub enum ConnectionType { Source, Sink, + Lookup, } impl Display for ConnectionType { @@ -58,6 +59,7 @@ impl Display for ConnectionType { match self { ConnectionType::Source => write!(f, "SOURCE"), ConnectionType::Sink => write!(f, "SINK"), + ConnectionType::Lookup => write!(f, "LOOKUP"), } } } @@ -250,9 +252,12 @@ pub struct ConnectionSchema { pub fields: Vec, pub definition: Option, pub inferred: Option, + #[serde(default)] + pub primary_keys: HashSet, } impl ConnectionSchema { + #[allow(clippy::too_many_arguments)] pub fn try_new( format: Option, bad_data: Option, @@ -261,6 +266,7 @@ impl ConnectionSchema { fields: Vec, definition: Option, inferred: Option, + primary_keys: HashSet, ) -> anyhow::Result { let s = ConnectionSchema { format, @@ -270,6 +276,7 @@ impl ConnectionSchema { fields, definition, inferred, + primary_keys, }; s.validate() @@ -321,6 +328,7 @@ impl ConnectionSchema { Some(MetadataField { field_name: f.field_name.clone(), key: f.metadata_key.clone()?, + data_type: Some(Field::from(f.clone()).data_type().clone()), }) }) .collect() diff --git a/crates/arroyo-rpc/src/df.rs b/crates/arroyo-rpc/src/df.rs index 3bbec89c2..03f62fc10 100644 --- a/crates/arroyo-rpc/src/df.rs +++ b/crates/arroyo-rpc/src/df.rs @@ -1,6 +1,6 @@ use crate::grpc::api; use crate::{grpc, Converter, TIMESTAMP_FIELD}; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, bail, Result}; use arrow::compute::kernels::numeric::div; use arrow::compute::{filter_record_batch, take}; use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit}; @@ -384,6 +384,25 @@ impl ArroyoSchema { key_indices: None, }) } + + pub fn with_field(&self, name: &str, data_type: DataType, nullable: bool) -> Result { + if self.schema.field_with_name(name).is_ok() { + bail!( + "cannot add field '{}' to schema, it is already present", + name + ); + } + let mut fields = self.schema.fields().to_vec(); + fields.push(Arc::new(Field::new(name, data_type, nullable))); + Ok(Self { + schema: Arc::new(Schema::new_with_metadata( + fields, + self.schema.metadata.clone(), + )), + timestamp_index: self.timestamp_index, + key_indices: self.key_indices.clone(), + }) + } } pub fn server_for_hash_array( diff --git a/crates/arroyo-rpc/src/lib.rs b/crates/arroyo-rpc/src/lib.rs index 5bdb01bf2..b06ef387f 100644 --- a/crates/arroyo-rpc/src/lib.rs +++ b/crates/arroyo-rpc/src/lib.rs @@ -37,6 +37,17 @@ pub mod grpc { pub mod api { #![allow(clippy::derive_partial_eq_without_eq, deprecated)] tonic::include_proto!("api"); + + impl From for arroyo_types::JoinType { + fn from(value: JoinType) -> Self { + match value { + JoinType::Inner 
=> arroyo_types::JoinType::Inner, + JoinType::Left => arroyo_types::JoinType::Left, + JoinType::Right => arroyo_types::JoinType::Right, + JoinType::Full => arroyo_types::JoinType::Full, + } + } + } } pub const API_FILE_DESCRIPTOR_SET: &[u8] = @@ -190,6 +201,8 @@ pub struct RateLimit { pub struct MetadataField { pub field_name: String, pub key: String, + #[serde(default)] + pub data_type: Option, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] diff --git a/crates/arroyo-types/src/lib.rs b/crates/arroyo-types/src/lib.rs index bcc5df76d..31a0a0db3 100644 --- a/crates/arroyo-types/src/lib.rs +++ b/crates/arroyo-types/src/lib.rs @@ -350,7 +350,7 @@ impl Serialize for DebeziumOp { } } -#[derive(Clone, Encode, Decode, Debug, Serialize, Deserialize, PartialEq)] +#[derive(Copy, Clone, Encode, Decode, Debug, Serialize, Deserialize, PartialEq)] pub enum JoinType { /// Inner Join Inner, @@ -641,6 +641,8 @@ pub fn range_for_server(i: usize, n: usize) -> RangeInclusive { start..=end } +pub const LOOKUP_KEY_INDEX_FIELD: &str = "__lookup_key_index"; + #[cfg(test)] mod tests { use super::*; diff --git a/crates/arroyo-worker/Cargo.toml b/crates/arroyo-worker/Cargo.toml index 270b22000..27c79ef08 100644 --- a/crates/arroyo-worker/Cargo.toml +++ b/crates/arroyo-worker/Cargo.toml @@ -77,6 +77,7 @@ itertools = "0.12.0" async-ffi = "0.5.0" dlopen2 = "0.7.0" dlopen2_derive = "0.4.0" +mini-moka = { version = "0.10.3" } [dev-dependencies] diff --git a/crates/arroyo-worker/src/arrow/lookup_join.rs b/crates/arroyo-worker/src/arrow/lookup_join.rs new file mode 100644 index 000000000..061fa3476 --- /dev/null +++ b/crates/arroyo-worker/src/arrow/lookup_join.rs @@ -0,0 +1,275 @@ +use arrow::compute::filter_record_batch; +use arrow::row::{OwnedRow, RowConverter, SortField}; +use arrow_array::cast::AsArray; +use arrow_array::types::UInt64Type; +use arrow_array::{Array, BooleanArray, RecordBatch}; +use arrow_schema::{DataType, Schema}; +use arroyo_connectors::connectors; +use arroyo_operator::connector::LookupConnector; +use arroyo_operator::context::{Collector, OperatorContext}; +use arroyo_operator::operator::{ + ArrowOperator, ConstructedOperator, OperatorConstructor, Registry, +}; +use arroyo_rpc::df::ArroyoSchema; +use arroyo_rpc::grpc::api; +use arroyo_rpc::{MetadataField, OperatorConfig}; +use arroyo_types::LOOKUP_KEY_INDEX_FIELD; +use async_trait::async_trait; +use datafusion::physical_expr::PhysicalExpr; +use datafusion_proto::physical_plan::from_proto::parse_physical_expr; +use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; +use datafusion_proto::protobuf::PhysicalExprNode; +use mini_moka::sync::Cache; +use prost::Message; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +#[derive(Copy, Clone, PartialEq)] +pub(crate) enum LookupJoinType { + Left, + Inner, +} + +/// A simple in-operator cache storing the entire “right side” row batch keyed by a string. 
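+/// Incoming join keys are converted to Arrow rows; keys that miss the cache are
+/// fetched in one batched `lookup` call against the connector, and the returned
+/// rows are memoized in a mini-moka cache bounded by byte weight and, when
+/// configured, a TTL.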
+pub struct LookupJoin { + connector: Box, + key_exprs: Vec>, + cache: Option>, + key_row_converter: RowConverter, + result_row_converter: RowConverter, + join_type: LookupJoinType, + lookup_schema: Arc, + metadata_fields: Vec, +} + +#[async_trait] +impl ArrowOperator for LookupJoin { + fn name(&self) -> String { + format!("LookupJoin({})", self.connector.name()) + } + + async fn process_batch( + &mut self, + batch: RecordBatch, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, + ) { + let num_rows = batch.num_rows(); + + let key_arrays: Vec<_> = self + .key_exprs + .iter() + .map(|expr| expr.evaluate(&batch).unwrap().into_array(num_rows).unwrap()) + .collect(); + + let rows = self.key_row_converter.convert_columns(&key_arrays).unwrap(); + + let mut key_map: HashMap<_, Vec> = HashMap::new(); + for (i, row) in rows.iter().enumerate() { + key_map.entry(row.owned()).or_default().push(i); + } + + let uncached_keys = if let Some(cache) = &mut self.cache { + let mut uncached_keys = Vec::new(); + for k in key_map.keys() { + if !cache.contains_key(k) { + uncached_keys.push(k); + } + } + uncached_keys + } else { + key_map.keys().collect() + }; + + let mut results = HashMap::new(); + + #[allow(unused_assignments)] + let mut result_rows = None; + + if !uncached_keys.is_empty() { + let cols = self + .key_row_converter + .convert_rows(uncached_keys.iter().map(|r| r.row())) + .unwrap(); + + let result_batch = self.connector.lookup(&cols).await; + + if let Some(result_batch) = result_batch { + let mut result_batch = result_batch.unwrap(); + let key_idx_col = result_batch + .schema() + .index_of(LOOKUP_KEY_INDEX_FIELD) + .unwrap(); + + let keys = result_batch.remove_column(key_idx_col); + let keys = keys.as_primitive::(); + + result_rows = Some( + self.result_row_converter + .convert_columns(result_batch.columns()) + .unwrap(), + ); + + for (v, idx) in result_rows.as_ref().unwrap().iter().zip(keys) { + results.insert(uncached_keys[idx.unwrap() as usize].as_ref(), v); + if let Some(cache) = &mut self.cache { + cache.insert(uncached_keys[idx.unwrap() as usize].clone(), v.owned()); + } + } + } + } + + let mut output_rows = self + .result_row_converter + .empty_rows(batch.num_rows(), batch.num_rows() * 10); + + for row in rows.iter() { + let row = self + .cache + .as_mut() + .and_then(|c| c.get(&row.owned())) + .unwrap_or_else(|| results.get(row.as_ref()).unwrap().owned()); + + output_rows.push(row.row()); + } + + let right_side = self + .result_row_converter + .convert_rows(output_rows.iter()) + .unwrap(); + + let nonnull = (self.join_type == LookupJoinType::Inner).then(|| { + let mut nonnull = vec![false; batch.num_rows()]; + + for (_, a) in self + .lookup_schema + .fields + .iter() + .zip(right_side.iter()) + .filter(|(f, _)| { + !self + .metadata_fields + .iter() + .any(|m| &m.field_name == f.name()) + }) + { + if let Some(nulls) = a.logical_nulls() { + for (a, b) in nulls.iter().zip(nonnull.iter_mut()) { + *b |= a; + } + } else { + for b in &mut nonnull { + *b = true; + } + break; + } + } + + BooleanArray::from(nonnull) + }); + + let in_schema = ctx.in_schemas.first().unwrap(); + let key_indices = in_schema.key_indices.as_ref().unwrap(); + let non_keys: Vec<_> = (0..batch.num_columns()) + .filter(|i| !key_indices.contains(i) && *i != in_schema.timestamp_index) + .collect(); + + let mut result = batch.project(&non_keys).unwrap().columns().to_vec(); + result.extend(right_side); + result.push(batch.column(in_schema.timestamp_index).clone()); + + let mut batch = + 
+            RecordBatch::try_new(ctx.out_schema.as_ref().unwrap().schema.clone(), result).unwrap();
+
+        if let Some(nonnull) = nonnull {
+            batch = filter_record_batch(&batch, &nonnull).unwrap();
+        }
+
+        collector.collect(batch).await;
+    }
+}
+
+pub struct LookupJoinConstructor;
+impl OperatorConstructor for LookupJoinConstructor {
+    type ConfigT = api::LookupJoinOperator;
+    fn with_config(
+        &self,
+        config: Self::ConfigT,
+        registry: Arc<Registry>,
+    ) -> anyhow::Result<ConstructedOperator> {
+        let join_type = config.join_type();
+        let input_schema: ArroyoSchema = config.input_schema.unwrap().try_into()?;
+        let lookup_schema: ArroyoSchema = config.lookup_schema.unwrap().try_into()?;
+
+        // decode and parse the physical expressions for the join keys
+        let exprs = config
+            .key_exprs
+            .iter()
+            .map(|e| {
+                let expr = PhysicalExprNode::decode(&mut e.left_expr.as_slice())?;
+                Ok(parse_physical_expr(
+                    &expr,
+                    registry.as_ref(),
+                    &input_schema.schema,
+                    &DefaultPhysicalExtensionCodec {},
+                )?)
+            })
+            .collect::<anyhow::Result<Vec<_>>>()?;
+
+        let op = config.connector.unwrap();
+        let operator_config: OperatorConfig = serde_json::from_str(&op.config)?;
+
+        let result_row_converter = RowConverter::new(
+            lookup_schema
+                .schema_without_timestamp()
+                .fields
+                .iter()
+                .map(|f| SortField::new(f.data_type().clone()))
+                .collect(),
+        )?;
+
+        // the lookup schema gains a synthetic UInt64 column recording which key each returned row belongs to
+        let lookup_schema = Arc::new(
+            lookup_schema
+                .with_field(LOOKUP_KEY_INDEX_FIELD, DataType::UInt64, false)?
+                .schema_without_timestamp(),
+        );
+
+        let connector = connectors()
+            .get(op.connector.as_str())
+            .unwrap_or_else(|| panic!("No connector with name '{}'", op.connector))
+            .make_lookup(operator_config.clone(), lookup_schema.clone())?;
+
+        // a max capacity of 0 disables the cache entirely
+        let max_capacity_bytes = config.max_capacity_bytes.unwrap_or(8 * 1024 * 1024);
+        let cache = (max_capacity_bytes > 0).then(|| {
+            let mut c = Cache::builder()
+                .weigher(|k: &OwnedRow, v: &OwnedRow| (k.as_ref().len() + v.as_ref().len()) as u32)
+                .max_capacity(max_capacity_bytes);
+
+            if let Some(ttl) = config.ttl_micros {
+                c = c.time_to_live(Duration::from_micros(ttl));
+            }
+
+            c.build()
+        });
+
+        Ok(ConstructedOperator::from_operator(Box::new(LookupJoin {
+            connector,
+            cache,
+            key_row_converter: RowConverter::new(
+                exprs
+                    .iter()
+                    .map(|e| Ok(SortField::new(e.data_type(&input_schema.schema)?)))
+                    .collect::<anyhow::Result<Vec<_>>>()?,
+            )?,
+            key_exprs: exprs,
+            result_row_converter,
+            join_type: match join_type {
+                api::JoinType::Inner => LookupJoinType::Inner,
+                api::JoinType::Left => LookupJoinType::Left,
+                jt => unreachable!("invalid lookup join type {:?}", jt),
+            },
+            lookup_schema,
+            metadata_fields: operator_config.metadata_fields,
+        })))
+    }
+}
diff --git a/crates/arroyo-worker/src/arrow/mod.rs b/crates/arroyo-worker/src/arrow/mod.rs
index 70353877a..437e75188 100644
--- a/crates/arroyo-worker/src/arrow/mod.rs
+++ b/crates/arroyo-worker/src/arrow/mod.rs
@@ -28,6 +28,7 @@ use std::sync::RwLock;
 pub mod async_udf;
 pub mod instant_join;
 pub mod join_with_expiration;
+pub mod lookup_join;
 pub mod session_aggregating_window;
 pub mod sliding_aggregating_window;
 pub(crate) mod sync;
diff --git a/crates/arroyo-worker/src/arrow/tumbling_aggregating_window.rs b/crates/arroyo-worker/src/arrow/tumbling_aggregating_window.rs
index 8a2a89052..ffcad9981 100644
--- a/crates/arroyo-worker/src/arrow/tumbling_aggregating_window.rs
+++ b/crates/arroyo-worker/src/arrow/tumbling_aggregating_window.rs
@@ -181,7 +181,7 @@ impl OperatorConstructor for TumblingAggregateWindowConstructor {
             .transpose()?;
 
         let aggregate_with_timestamp_schema =
-            add_timestamp_field_arrow(finish_execution_plan.schema());
+            add_timestamp_field_arrow((*finish_execution_plan.schema()).clone());

         Ok(ConstructedOperator::from_operator(Box::new(
             TumblingAggregatingWindowFunc {
diff --git a/crates/arroyo-worker/src/engine.rs b/crates/arroyo-worker/src/engine.rs
index 879bab37f..c7ed5ae93 100644
--- a/crates/arroyo-worker/src/engine.rs
+++ b/crates/arroyo-worker/src/engine.rs
@@ -1,6 +1,7 @@
 use crate::arrow::async_udf::AsyncUdfConstructor;
 use crate::arrow::instant_join::InstantJoinConstructor;
 use crate::arrow::join_with_expiration::JoinWithExpirationConstructor;
+use crate::arrow::lookup_join::LookupJoinConstructor;
 use crate::arrow::session_aggregating_window::SessionAggregatingWindowConstructor;
 use crate::arrow::sliding_aggregating_window::SlidingAggregatingWindowConstructor;
 use crate::arrow::tumbling_aggregating_window::TumblingAggregateWindowConstructor;
@@ -54,8 +55,8 @@ pub struct SubtaskNode {
     pub node_id: u32,
     pub subtask_idx: usize,
     pub parallelism: usize,
-    pub in_schemas: Vec<ArroyoSchema>,
-    pub out_schema: Option<ArroyoSchema>,
+    pub in_schemas: Vec<Arc<ArroyoSchema>>,
+    pub out_schema: Option<Arc<ArroyoSchema>>,
     pub node: OperatorNode,
 }
 
@@ -97,7 +98,7 @@ pub struct PhysicalGraphEdge {
     edge_idx: usize,
     in_logical_idx: usize,
     out_logical_idx: usize,
-    schema: ArroyoSchema,
+    schema: Arc<ArroyoSchema>,
     edge: LogicalEdgeType,
     tx: Option,
     rx: Option,
@@ -773,8 +774,8 @@ pub async fn construct_node(
     subtask_idx: u32,
     parallelism: u32,
     input_partitions: u32,
-    in_schemas: Vec<ArroyoSchema>,
-    out_schema: Option<ArroyoSchema>,
+    in_schemas: Vec<Arc<ArroyoSchema>>,
+    out_schema: Option<Arc<ArroyoSchema>>,
     restore_from: Option<&CheckpointMetadata>,
    control_tx: Sender<ControlResp>,
     registry: Arc<Registry>,
@@ -874,6 +875,7 @@ pub fn construct_operator(
         OperatorName::ExpressionWatermark => Box::new(WatermarkGeneratorConstructor),
         OperatorName::Join => Box::new(JoinWithExpirationConstructor),
         OperatorName::InstantJoin => Box::new(InstantJoinConstructor),
+        OperatorName::LookupJoin => Box::new(LookupJoinConstructor),
         OperatorName::WindowFunction => Box::new(WindowFunctionConstructor),
         OperatorName::ConnectorSource | OperatorName::ConnectorSink => {
             let op: api::ConnectorOp = prost::Message::decode(config).unwrap();