diff --git a/sdks/rust/Cargo.lock b/sdks/rust/Cargo.lock index 8cdc3e2dee92..8c122b94dacd 100644 --- a/sdks/rust/Cargo.lock +++ b/sdks/rust/Cargo.lock @@ -30,7 +30,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", - "strum_macros", + "strum", "tokio", "tokio-stream", "tokio-util", @@ -1370,6 +1370,15 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "strum" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "063e6045c0e62079840579a7e47a355ae92f60eb74daaf156fb1e84ba164e63f" +dependencies = [ + "strum_macros", +] + [[package]] name = "strum_macros" version = "0.24.3" diff --git a/sdks/rust/Cargo.toml b/sdks/rust/Cargo.toml index a4589d616a7e..d2fe9af2118d 100644 --- a/sdks/rust/Cargo.toml +++ b/sdks/rust/Cargo.toml @@ -43,7 +43,7 @@ rand = "0.7" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" serde_yaml = "0.9.14" -strum_macros = "0.24" +strum = { version = "0.24", features = ["derive"] } tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "sync", "time"] } tokio-util = "0.7.4" tokio-stream = "0.1" diff --git a/sdks/rust/src/coders/register_coders/mod.rs b/sdks/rust/src/coders/register_coders/mod.rs index 413a6dd597a3..79c02b334f68 100644 --- a/sdks/rust/src/coders/register_coders/mod.rs +++ b/sdks/rust/src/coders/register_coders/mod.rs @@ -48,8 +48,8 @@ macro_rules! register_coders { } #[ctor::ctor] - fn init_coder_from_urn() { - $crate::worker::CODER_FROM_URN.set($crate::worker::CoderFromUrn { + fn init_custom_coder_from_urn() { + $crate::worker::CUSTOM_CODER_FROM_URN.set($crate::worker::CustomCoderFromUrn { enc: encode_from_urn, dec: decode_from_urn, }).unwrap(); diff --git a/sdks/rust/src/coders/required_coders.rs b/sdks/rust/src/coders/required_coders.rs index ac292e984900..80608653e28b 100644 --- a/sdks/rust/src/coders/required_coders.rs +++ b/sdks/rust/src/coders/required_coders.rs @@ -133,23 +133,22 @@ impl fmt::Debug for BytesCoder { } /// A coder for a key-value pair -#[derive(Clone, Debug)] pub struct KVCoder { phantom: PhantomData, } impl CoderUrn for KVCoder> where - K: ElemType + Clone + fmt::Debug, - V: ElemType + Clone + fmt::Debug, + K: ElemType, + V: ElemType, { const URN: &'static str = KV_CODER_URN; } impl Coder for KVCoder> where - K: ElemType + Clone + fmt::Debug, - V: ElemType + Clone + fmt::Debug, + K: ElemType, + V: ElemType, { /// Encode the input element (a key-value pair) into a byte output stream. They key and value are encoded one after the /// other (first key, then value). The key is encoded with `Context::NeedsDelimiters`, while the value is encoded with @@ -173,10 +172,20 @@ where } } +impl fmt::Debug for KVCoder> +where + K: ElemType, + V: ElemType, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("KVCoder").finish() + } +} + impl Default for KVCoder> where - K: Send, - V: Send, + K: ElemType, + V: ElemType, { fn default() -> Self { Self { diff --git a/sdks/rust/src/coders/urns.rs b/sdks/rust/src/coders/urns.rs index 3084d597f87b..c003d633ff0f 100644 --- a/sdks/rust/src/coders/urns.rs +++ b/sdks/rust/src/coders/urns.rs @@ -16,6 +16,44 @@ * limitations under the License. */ +use strum::{EnumDiscriminants, EnumIter}; + +#[derive(Debug, EnumDiscriminants, EnumIter)] +pub(crate) enum PresetCoderUrn { + Bytes, + Kv, + Iterable, + StrUtf8, + VarInt, + Unit, + GeneralObject, +} + +impl PresetCoderUrn { + pub(crate) fn as_str(&self) -> &str { + match self { + // ******* Standard coders ******* + Self::Bytes => BYTES_CODER_URN, + Self::Kv => KV_CODER_URN, + Self::Iterable => ITERABLE_CODER_URN, + + // ******* Required coders ******* + Self::StrUtf8 => STR_UTF8_CODER_URN, + Self::VarInt => VARINT_CODER_URN, + + // ******* Rust coders ******* + Self::Unit => UNIT_CODER_URN, + Self::GeneralObject => GENERAL_OBJECT_CODER_URN, + } + } +} + +impl AsRef for PresetCoderUrn { + fn as_ref(&self) -> &str { + self.as_str() + } +} + // ******* Standard coders ******* pub const BYTES_CODER_URN: &str = "beam:coder:bytes:v1"; pub const KV_CODER_URN: &str = "beam:coder:kvcoder:v1"; diff --git a/sdks/rust/src/lib.rs b/sdks/rust/src/lib.rs index 397f78c6a73e..c196c2b3cf1b 100644 --- a/sdks/rust/src/lib.rs +++ b/sdks/rust/src/lib.rs @@ -28,6 +28,3 @@ pub mod proto; pub mod runners; pub mod transforms; pub mod worker; - -#[macro_use] -extern crate strum_macros; diff --git a/sdks/rust/src/worker/coder_from_urn/custom_coder_from_urn/mod.rs b/sdks/rust/src/worker/coder_from_urn/custom_coder_from_urn/mod.rs new file mode 100644 index 000000000000..d72e657e10d6 --- /dev/null +++ b/sdks/rust/src/worker/coder_from_urn/custom_coder_from_urn/mod.rs @@ -0,0 +1,47 @@ +use std::fmt; + +use once_cell::sync::OnceCell; + +use crate::coders::{DecodeFromUrnFn, EncodeFromUrnFn}; + +/// The visibility is `pub` because this is used internally from `register_coders!` macro. +pub static CUSTOM_CODER_FROM_URN: OnceCell = OnceCell::new(); + +/// The visibility is `pub` because this is instantiated internally from `register_coders!` macro. +pub struct CustomCoderFromUrn { + pub enc: EncodeFromUrnFn, + pub dec: DecodeFromUrnFn, +} + +impl CustomCoderFromUrn { + pub(in crate::worker::coder_from_urn) fn global() -> &'static CustomCoderFromUrn { + CUSTOM_CODER_FROM_URN + .get() + .expect("you might forget calling `register_coders!(CustomCoder1, CustomCoder2)`") + } + + pub(in crate::worker::coder_from_urn) fn encode_from_urn( + &self, + urn: &str, + elem: &dyn crate::elem_types::ElemType, + writer: &mut dyn std::io::Write, + context: &crate::coders::Context, + ) -> Result { + (self.enc)(urn, elem, writer, context) + } + + pub(in crate::worker::coder_from_urn) fn decode_from_urn( + &self, + urn: &str, + reader: &mut dyn std::io::Read, + context: &crate::coders::Context, + ) -> Result, std::io::Error> { + (self.dec)(urn, reader, context) + } +} + +impl fmt::Debug for CustomCoderFromUrn { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CodersFromUrn").finish() + } +} diff --git a/sdks/rust/src/worker/coder_from_urn/mod.rs b/sdks/rust/src/worker/coder_from_urn/mod.rs index c54b34e15db4..30e7209cbcee 100644 --- a/sdks/rust/src/worker/coder_from_urn/mod.rs +++ b/sdks/rust/src/worker/coder_from_urn/mod.rs @@ -1,47 +1,33 @@ -use std::fmt; +mod custom_coder_from_urn; +pub use custom_coder_from_urn::{CustomCoderFromUrn, CUSTOM_CODER_FROM_URN}; -use once_cell::sync::OnceCell; +use crate::worker::coder_from_urn::preset_coder_from_urn::PresetCoderFromUrn; -use crate::coders::{DecodeFromUrnFn, EncodeFromUrnFn}; +mod preset_coder_from_urn; -/// Called internally from `register_coders!` macro. -pub static CODER_FROM_URN: OnceCell = OnceCell::new(); - -/// Called internally from `register_coders!` macro. -pub struct CoderFromUrn { - pub enc: EncodeFromUrnFn, - pub dec: DecodeFromUrnFn, -} +pub(in crate::worker) struct CoderFromUrn; impl CoderFromUrn { - pub fn global() -> &'static CoderFromUrn { - CODER_FROM_URN - .get() - .expect("you might forget calling `register_coders!(CustomCoder1, CustomCoder2)`") - } - - pub fn encode_from_urn( - &self, + pub(in crate::worker) fn encode_from_urn( urn: &str, elem: &dyn crate::elem_types::ElemType, writer: &mut dyn std::io::Write, context: &crate::coders::Context, ) -> Result { - (self.enc)(urn, elem, writer, context) + PresetCoderFromUrn::encode_from_urn(urn, elem, writer, context).unwrap_or_else(|| { + let custom = CustomCoderFromUrn::global(); + (custom.enc)(urn, elem, writer, context) + }) } - pub fn decode_from_urn( - &self, + pub(in crate::worker) fn decode_from_urn( urn: &str, reader: &mut dyn std::io::Read, context: &crate::coders::Context, ) -> Result, std::io::Error> { - (self.dec)(urn, reader, context) - } -} - -impl fmt::Debug for CoderFromUrn { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("CodersFromUrn").finish() + PresetCoderFromUrn::decode_from_urn(urn, reader, context).unwrap_or_else(|| { + let custom = CustomCoderFromUrn::global(); + (custom.dec)(urn, reader, context) + }) } } diff --git a/sdks/rust/src/worker/coder_from_urn/preset_coder_from_urn/mod.rs b/sdks/rust/src/worker/coder_from_urn/preset_coder_from_urn/mod.rs new file mode 100644 index 000000000000..d848eac92aa0 --- /dev/null +++ b/sdks/rust/src/worker/coder_from_urn/preset_coder_from_urn/mod.rs @@ -0,0 +1,51 @@ +use crate::coders::{ + required_coders::BytesCoder, rust_coders::GeneralObjectCoder, standard_coders::StrUtf8Coder, + urns::PresetCoderUrn, Coder, +}; +use strum::IntoEnumIterator; + +#[derive(Eq, PartialEq, Debug)] +pub(in crate::worker::coder_from_urn) struct PresetCoderFromUrn; + +impl PresetCoderFromUrn { + /// Returns `None` if the urn is not a preset coder urn. + pub(in crate::worker) fn encode_from_urn( + urn: &str, + elem: &dyn crate::elem_types::ElemType, + writer: &mut dyn std::io::Write, + context: &crate::coders::Context, + ) -> Option> { + let opt_variant = PresetCoderUrn::iter().find(|variant| variant.as_str() == urn); + + opt_variant.map(|variant| match variant { + PresetCoderUrn::Bytes => BytesCoder::default().encode(elem, writer, context), + PresetCoderUrn::Kv => todo!("create full type including components (not only urn but also full proto maybe required"), + PresetCoderUrn::Iterable => todo!("create full type including components (not only urn but also full proto maybe required"), + PresetCoderUrn::StrUtf8 => StrUtf8Coder::default().encode(elem, writer, context), + PresetCoderUrn::VarInt => todo!("create full type including components (not only urn but also full proto maybe required"), + PresetCoderUrn::Unit => todo!("make UnitCoder"), + PresetCoderUrn::GeneralObject => { + GeneralObjectCoder::default().encode(elem, writer, context) + } + }) + } + + /// Returns `None` if the urn is not a preset coder urn. + pub(in crate::worker) fn decode_from_urn( + urn: &str, + reader: &mut dyn std::io::Read, + context: &crate::coders::Context, + ) -> Option, std::io::Error>> { + let opt_variant = PresetCoderUrn::iter().find(|variant| variant.as_str() == urn); + + opt_variant.map(|variant| match variant { + PresetCoderUrn::Bytes => BytesCoder::default().decode(reader, context), + PresetCoderUrn::Kv => todo!("create full type including components (not only urn but also full proto maybe required"), + PresetCoderUrn::Iterable => todo!("create full type including components (not only urn but also full proto maybe required"), + PresetCoderUrn::StrUtf8 => StrUtf8Coder::default().decode(reader, context), + PresetCoderUrn::VarInt => todo!("create full type including components (not only urn but also full proto maybe required"), + PresetCoderUrn::Unit => todo!("make UnitCoder"), + PresetCoderUrn::GeneralObject => GeneralObjectCoder::default().decode(reader, context), + }) + } +} diff --git a/sdks/rust/src/worker/mod.rs b/sdks/rust/src/worker/mod.rs index b92c43abd01e..9e2c17c36076 100644 --- a/sdks/rust/src/worker/mod.rs +++ b/sdks/rust/src/worker/mod.rs @@ -20,7 +20,9 @@ mod external_worker_service; pub mod operators; mod coder_from_urn; -pub use coder_from_urn::{CoderFromUrn, CODER_FROM_URN}; +pub(in crate::worker) use coder_from_urn::CoderFromUrn; +pub use coder_from_urn::{CustomCoderFromUrn, CUSTOM_CODER_FROM_URN}; + mod interceptors; pub use external_worker_service::ExternalWorkerPool; @@ -46,10 +48,7 @@ mod tests { use std::{collections::HashMap, sync::Arc}; use crate::internals::urns; - use crate::proto::{ - fn_execution_v1, - pipeline_v1, - }; + use crate::proto::{fn_execution_v1, pipeline_v1}; use crate::{ worker::sdk_worker::BundleProcessor, @@ -145,3 +144,212 @@ mod tests { ); } } + +#[cfg(test)] +mod serde_preset_coder_test { + mod sdk_launcher { + use crate::{ + coders::{standard_coders::StrUtf8Coder, Coder}, + proto::pipeline::v1 as pipeline_v1, + }; + + pub fn launcher_register_coder_proto() -> pipeline_v1::Coder { + // in the proto registration (in the pipeline construction) + let coder = StrUtf8Coder::default(); + coder.to_proto(vec![]) + } + } + + mod sdk_harness { + use bytes::Buf; + use std::io; + + use crate::{ + coders::Context, elem_types::ElemType, proto::pipeline::v1 as pipeline_v1, + worker::CoderFromUrn, + }; + + fn receive_coder() -> pipeline_v1::Coder { + // serialized coder is sent from the launcher + super::sdk_launcher::launcher_register_coder_proto() + } + + fn create_element() -> String { + // A PTransform (UDF) create an instance of i32 + "hello".to_string() + } + + fn encode_element(element: &dyn ElemType, coder: &pipeline_v1::Coder) -> Vec { + let urn = &coder.spec.as_ref().unwrap().urn; + + let mut encoded_element = vec![]; + CoderFromUrn::encode_from_urn( + urn, + element, + &mut encoded_element, + &Context::WholeStream, + ) + .unwrap(); + + encoded_element + } + + fn decode_element( + elem_reader: &mut dyn io::Read, + coder: &pipeline_v1::Coder, + ) -> Box { + let urn = &coder.spec.as_ref().unwrap().urn; + + let decoded_element_dyn = + CoderFromUrn::decode_from_urn(urn, elem_reader, &Context::WholeStream).unwrap(); + + decoded_element_dyn + } + + pub fn test() { + let coder = receive_coder(); + let element = create_element(); + + let encoded_element = encode_element(&element, &coder); + let decoded_element_dyn = decode_element(&mut encoded_element.reader(), &coder); + + let decoded_element = decoded_element_dyn.as_any().downcast_ref::().unwrap(); + + assert_eq!(decoded_element, &element); + } + } + + #[test] + fn serde_custom_coder() { + sdk_harness::test(); + } +} + +#[cfg(test)] +mod serde_costom_coder_test { + mod sdk_launcher { + use crate::{ + coders::{Coder, Context}, + elem_types::ElemType, + proto::pipeline::v1 as pipeline_v1, + register_coders, + }; + + #[derive(Clone, PartialEq, Eq, Debug)] + pub struct MyElement { + pub some_field: String, + } + + impl ElemType for MyElement {} + + #[derive(Debug, Default)] + struct MyCoder; + + impl Coder for MyCoder { + fn encode( + &self, + element: &dyn ElemType, + writer: &mut dyn std::io::Write, + _context: &Context, + ) -> Result { + let element = element.as_any().downcast_ref::().unwrap(); + + writer + .write_all(format!("ENCPREFIX{}", element.some_field).as_bytes()) + .map(|_| 0) // TODO make Result to Result<(), std::io::Error> + } + + fn decode( + &self, + reader: &mut dyn std::io::Read, + _context: &Context, + ) -> Result, std::io::Error> { + let mut buf = Vec::new(); + reader.read_to_end(&mut buf)?; + + let encoded_element = String::from_utf8(buf).unwrap(); + let element = encoded_element.strip_prefix("ENCPREFIX").unwrap(); + Ok(Box::new(MyElement { + some_field: element.to_string(), + })) + } + } + + register_coders!(MyCoder); + + pub fn launcher_register_coder_proto() -> pipeline_v1::Coder { + // in the proto registration (in the pipeline construction) + let my_coder = MyCoder::default(); + my_coder.to_proto(vec![]) + } + } + + mod sdk_harness { + use bytes::Buf; + use std::io; + + use crate::{ + coders::Context, elem_types::ElemType, proto::pipeline::v1 as pipeline_v1, + worker::CoderFromUrn, + }; + + fn receive_coder() -> pipeline_v1::Coder { + // serialized coder is sent from the launcher + super::sdk_launcher::launcher_register_coder_proto() + } + + fn create_my_element() -> super::sdk_launcher::MyElement { + // A PTransform (UDF) create an instance of MyElement + super::sdk_launcher::MyElement { + some_field: "some_value".to_string(), + } + } + + fn encode_element(element: &dyn ElemType, coder: &pipeline_v1::Coder) -> Vec { + let urn = &coder.spec.as_ref().unwrap().urn; + + let mut encoded_element = vec![]; + CoderFromUrn::encode_from_urn( + urn, + element, + &mut encoded_element, + &Context::WholeStream, + ) + .unwrap(); + + encoded_element + } + + fn decode_element( + elem_reader: &mut dyn io::Read, + coder: &pipeline_v1::Coder, + ) -> Box { + let urn = &coder.spec.as_ref().unwrap().urn; + + let decoded_element_dyn = + CoderFromUrn::decode_from_urn(urn, elem_reader, &Context::WholeStream).unwrap(); + + decoded_element_dyn + } + + pub fn test() { + let coder = receive_coder(); + let element = create_my_element(); + + let encoded_element = encode_element(&element, &coder); + let decoded_element_dyn = decode_element(&mut encoded_element.reader(), &coder); + + let decoded_element = decoded_element_dyn + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(decoded_element, &element); + } + } + + #[test] + fn serde_custom_coder() { + sdk_harness::test(); + } +} diff --git a/sdks/rust/src/worker/operators.rs b/sdks/rust/src/worker/operators.rs index 50e20c6c0e8b..5abc84bd5923 100644 --- a/sdks/rust/src/worker/operators.rs +++ b/sdks/rust/src/worker/operators.rs @@ -27,6 +27,7 @@ use std::sync::{Arc, Mutex}; use once_cell::sync::Lazy; use serde_json; +use strum::EnumDiscriminants; use crate::elem_types::ElemType; use crate::internals::serialize; diff --git a/sdks/rust/tests/serde_custom_coder_test.rs b/sdks/rust/tests/serde_custom_coder_test.rs deleted file mode 100644 index cdfa2f50bf9b..000000000000 --- a/sdks/rust/tests/serde_custom_coder_test.rs +++ /dev/null @@ -1,122 +0,0 @@ -mod sdk_launcher { - use apache_beam::{ - coders::{Coder, Context}, - elem_types::ElemType, - proto::pipeline::v1 as pipeline_v1, - register_coders, - }; - - #[derive(Clone, PartialEq, Eq, Debug)] - pub struct MyElement { - pub some_field: String, - } - - impl ElemType for MyElement {} - - #[derive(Debug, Default)] - struct MyCoder; - - impl Coder for MyCoder { - fn encode( - &self, - element: &dyn ElemType, - writer: &mut dyn std::io::Write, - _context: &Context, - ) -> Result { - let element = element.as_any().downcast_ref::().unwrap(); - - writer - .write_all(format!("ENCPREFIX{}", element.some_field).as_bytes()) - .map(|_| 0) // TODO make Result to Result<(), std::io::Error> - } - - fn decode( - &self, - reader: &mut dyn std::io::Read, - _context: &Context, - ) -> Result, std::io::Error> { - let mut buf = Vec::new(); - reader.read_to_end(&mut buf)?; - - let encoded_element = String::from_utf8(buf).unwrap(); - let element = encoded_element.strip_prefix("ENCPREFIX").unwrap(); - Ok(Box::new(MyElement { - some_field: element.to_string(), - })) - } - } - - register_coders!(MyCoder); - - pub fn launcher_register_coder_proto() -> pipeline_v1::Coder { - // in the proto registration (in the pipeline construction) - let my_coder = MyCoder::default(); - my_coder.to_proto(vec![]) - } -} - -mod sdk_harness { - use bytes::Buf; - use std::io; - - use apache_beam::{ - coders::Context, elem_types::ElemType, proto::pipeline::v1 as pipeline_v1, - worker::CoderFromUrn, - }; - - fn receive_coder() -> pipeline_v1::Coder { - // serialized coder is sent from the launcher - super::sdk_launcher::launcher_register_coder_proto() - } - - fn create_my_element() -> super::sdk_launcher::MyElement { - // A PTransform (UDF) create an instance of MyElement - super::sdk_launcher::MyElement { - some_field: "some_value".to_string(), - } - } - - fn encode_element(element: &dyn ElemType, coder: &pipeline_v1::Coder) -> Vec { - let urn = &coder.spec.as_ref().unwrap().urn; - - let mut encoded_element = vec![]; - CoderFromUrn::global() - .encode_from_urn(urn, element, &mut encoded_element, &Context::WholeStream) - .unwrap(); - - encoded_element - } - - fn decode_element( - elem_reader: &mut dyn io::Read, - coder: &pipeline_v1::Coder, - ) -> Box { - let urn = &coder.spec.as_ref().unwrap().urn; - - let decoded_element_dyn = CoderFromUrn::global() - .decode_from_urn(urn, elem_reader, &Context::WholeStream) - .unwrap(); - - decoded_element_dyn - } - - pub fn test() { - let coder = receive_coder(); - let element = create_my_element(); - - let encoded_element = encode_element(&element, &coder); - let decoded_element_dyn = decode_element(&mut encoded_element.reader(), &coder); - - let decoded_element = decoded_element_dyn - .as_any() - .downcast_ref::() - .unwrap(); - - assert_eq!(decoded_element, &element); - } -} - -#[test] -fn serde_custom_coder() { - sdk_harness::test(); -}