Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(protobuf): support any for protobuf message source #12291

Merged
merged 36 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
2b973e4
feat(protobuf): support Any as message source
xzhseh Sep 12, 2023
dbaa746
update
xzhseh Sep 13, 2023
a23b50b
support any protobuf
xzhseh Sep 20, 2023
bb1cd88
Merge branch 'main' into xzhseh/feat-protobuf-any
xzhseh Sep 20, 2023
3ab91ad
fix test
xzhseh Sep 21, 2023
31bd84b
Merge branch 'main' into xzhseh/feat-protobuf-any
xzhseh Sep 21, 2023
b38bfc5
Merge branch 'main' into xzhseh/feat-protobuf-any
xzhseh Sep 22, 2023
e90f067
update
xzhseh Sep 27, 2023
eb85fcd
support nested messages & address issues
xzhseh Sep 27, 2023
ed7c424
update the output format
xzhseh Oct 18, 2023
aa0e2ad
use full_name() instead of DataType::Jsonb
xzhseh Oct 23, 2023
c4a4cb9
fix(protobuf): recursive Any field (#13008)
Rossil2012 Oct 23, 2023
314d8df
Merge branch 'main' into xzhseh/feat-protobuf-any
xzhseh Oct 23, 2023
b1d8d77
bring back risedev.yml
xzhseh Oct 23, 2023
8fe6fb3
Fix "cargo-hakari"
xzhseh Oct 23, 2023
5f0c105
fix format
xzhseh Oct 23, 2023
d221ef9
Merge branch 'main' into xzhseh/feat-protobuf-any
xzhseh Oct 24, 2023
28f2b88
change back to name for unit test
xzhseh Oct 26, 2023
8eac261
Merge branch 'main' into xzhseh/feat-protobuf-any
xzhseh Oct 26, 2023
90ab076
fix check
xzhseh Oct 26, 2023
74f48f1
bump back preserve order for serde_json
xzhseh Oct 26, 2023
813600d
Merge branch 'main' into xzhseh/feat-protobuf-any
xzhseh Oct 26, 2023
d28be41
Fix "cargo-hakari"
xzhseh Oct 26, 2023
2e5570f
add log for e2e_test debug
xzhseh Oct 26, 2023
a32c64e
add type hint for any type
xzhseh Oct 26, 2023
f7f4557
Merge branch 'main' into xzhseh/feat-protobuf-any
xzhseh Oct 26, 2023
2130071
fix test
xzhseh Oct 26, 2023
b143a7a
fix test
xzhseh Oct 27, 2023
a93fcf4
fix test
xzhseh Oct 27, 2023
9e02c3d
add default value for any type when the corresponding value does not …
xzhseh Oct 27, 2023
232d2e7
fix format
xzhseh Oct 27, 2023
643a8c4
Merge branch 'main' into xzhseh/feat-protobuf-any
xzhseh Oct 27, 2023
707f601
Merge branch 'main' into xzhseh/feat-protobuf-any
xzhseh Oct 27, 2023
b4155cf
fix err_msg
xzhseh Oct 31, 2023
828b46d
Merge branch 'main' into xzhseh/feat-protobuf-any
xzhseh Oct 31, 2023
ffb9e87
fix format
xzhseh Oct 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
289 changes: 255 additions & 34 deletions src/connector/src/parser/protobuf/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

use std::path::Path;
use std::sync::Arc;

use itertools::Itertools;
use prost_reflect::{
Expand All @@ -39,6 +40,7 @@ use crate::parser::{AccessBuilder, EncodingProperties};
/// Builder that decodes raw protobuf payloads into `ProtobufAccess` accessors.
///
/// All three fields are taken verbatim from a `ProtobufParserConfig`.
pub struct ProtobufAccessBuilder {
/// Copied from `ProtobufParserConfig::confluent_wire_type`.
/// NOTE(review): presumably controls whether a Confluent wire-format header
/// is stripped before decoding — the code that reads it is outside this view.
confluent_wire_type: bool,
/// Descriptor of the top-level message; each payload is decoded against it
/// with `DynamicMessage::decode`.
message_descriptor: MessageDescriptor,
/// Shared descriptor pool, cloned (by `Arc`) into every `ProtobufAccess` that
/// is generated.
descriptor_pool: Arc<DescriptorPool>,
}

impl AccessBuilder for ProtobufAccessBuilder {
Expand All @@ -53,7 +55,10 @@ impl AccessBuilder for ProtobufAccessBuilder {
let message = DynamicMessage::decode(self.message_descriptor.clone(), payload)
.map_err(|e| ProtocolError(format!("parse message failed: {}", e)))?;

Ok(AccessImpl::Protobuf(ProtobufAccess::new(message)))
Ok(AccessImpl::Protobuf(ProtobufAccess::new(
message,
Arc::clone(&self.descriptor_pool),
)))
}
}

Expand All @@ -62,10 +67,13 @@ impl ProtobufAccessBuilder {
let ProtobufParserConfig {
confluent_wire_type,
message_descriptor,
descriptor_pool,
} = config;

Ok(Self {
confluent_wire_type,
message_descriptor,
descriptor_pool,
})
}
}
Expand All @@ -74,6 +82,8 @@ impl ProtobufAccessBuilder {
/// Resolved configuration for the protobuf parser: the target message
/// descriptor together with the descriptor pool it was looked up in.
pub struct ProtobufParserConfig {
/// Set from `use_schema_registry` of the protobuf encoding config.
confluent_wire_type: bool,
/// Descriptor of the configured message, found in the pool by its name.
message_descriptor: MessageDescriptor,
/// The whole descriptor pool, kept so that the concrete message type named by
/// an `Any` field's `type_url` can be looked up while converting values.
///
/// Note that the pub(crate) here is merely for testing
pub(crate) descriptor_pool: Arc<DescriptorPool>,
}

impl ProtobufParserConfig {
Expand Down Expand Up @@ -135,15 +145,18 @@ impl ProtobufParserConfig {
location, e
))
})?;

let message_descriptor = pool.get_message_by_name(message_name).ok_or_else(|| {
ProtocolError(format!(
"cannot find message {} in schema: {}.\n poll is {:?}",
"Cannot find message {} in schema: {}.\nDescriptor pool is {:?}",
message_name, location, pool
))
})?;

Ok(Self {
message_descriptor,
confluent_wire_type: protobuf_config.use_schema_registry,
descriptor_pool: Arc::new(pool),
})
}

Expand Down Expand Up @@ -224,7 +237,11 @@ fn detect_loop_and_push(trace: &mut Vec<String>, fd: &FieldDescriptor) -> Result
Ok(())
}

pub fn from_protobuf_value(field_desc: &FieldDescriptor, value: &Value) -> Result<Datum> {
pub fn from_protobuf_value(
field_desc: &FieldDescriptor,
value: &Value,
descriptor_pool: &Arc<DescriptorPool>,
) -> Result<Datum> {
let v = match value {
Value::Bool(v) => ScalarImpl::Bool(*v),
Value::I32(i) => ScalarImpl::Int32(*i),
Expand All @@ -250,30 +267,76 @@ pub fn from_protobuf_value(field_desc: &FieldDescriptor, value: &Value) -> Resul
ScalarImpl::Utf8(enum_symbol.name().into())
}
Value::Message(dyn_msg) => {
let mut rw_values = Vec::with_capacity(dyn_msg.descriptor().fields().len());
// fields is a btree map in descriptor
// so its order is the same as datatype
for field_desc in dyn_msg.descriptor().fields() {
// missing field
if !dyn_msg.has_field(&field_desc)
&& field_desc.cardinality() == Cardinality::Required
{
let err_msg = format!(
"protobuf parse error.missing required field {:?}",
field_desc
);
return Err(RwError::from(ProtocolError(err_msg)));
if dyn_msg.has_field_by_name("type_url") && dyn_msg.has_field_by_name("value") {
// The message is of type `Any`
xzhseh marked this conversation as resolved.
Show resolved Hide resolved
debug_assert!(
dyn_msg.fields().count() == 2,
"Expected only two fields for Any Type MessageDescriptor"
);

let type_url = dyn_msg
.get_field_by_name("type_url")
.expect("Expect type_url in dyn_msg");
let payload = dyn_msg
.get_field_by_name("value")
.expect("Expect value (payload) in dyn_msg")
.as_ref()
.clone();
xzhseh marked this conversation as resolved.
Show resolved Hide resolved

let type_url =
type_url.to_string().split('/').collect::<Vec<&str>>()[1].to_string();
let type_url = type_url[..type_url.len() - 1].to_string();
let payload_field_desc = dyn_msg.descriptor().get_field_by_name("value").unwrap();
let Some(ScalarImpl::Bytea(payload)) =
from_protobuf_value(&payload_field_desc, &payload, descriptor_pool)?
else {
panic!("Expected ScalarImpl::Bytea for payload");
xzhseh marked this conversation as resolved.
Show resolved Hide resolved
};

// Get the corresponding schema from the descriptor pool
let msg_desc = descriptor_pool
.get_message_by_name(&type_url)
.ok_or_else(|| {
ProtocolError(format!(
"Cannot find message {} in from_protobuf_value.\nDescriptor pool is {:#?}",
type_url, descriptor_pool
))
})?;

// Decode the payload based on the `msg_desc`
let decoded_value = DynamicMessage::decode(msg_desc, payload.as_ref()).unwrap();

return from_protobuf_value(
field_desc,
&Value::Message(decoded_value),
descriptor_pool,
);
} else {
let mut rw_values = Vec::with_capacity(dyn_msg.descriptor().fields().len());
// fields is a btree map in descriptor
// so its order is the same as datatype
for field_desc in dyn_msg.descriptor().fields() {
// missing field
if !dyn_msg.has_field(&field_desc)
&& field_desc.cardinality() == Cardinality::Required
{
let err_msg = format!(
"protobuf parse error.missing required field {:?}",
field_desc
);
return Err(RwError::from(ProtocolError(err_msg)));
}
// use default value if dyn_msg doesn't have this field
let value = dyn_msg.get_field(&field_desc);
rw_values.push(from_protobuf_value(&field_desc, &value, descriptor_pool)?);
}
// use default value if dyn_msg doesn't have this field
let value = dyn_msg.get_field(&field_desc);
rw_values.push(from_protobuf_value(&field_desc, &value)?);
ScalarImpl::Struct(StructValue::new(rw_values))
tabVersion marked this conversation as resolved.
Show resolved Hide resolved
}
ScalarImpl::Struct(StructValue::new(rw_values))
}
Value::List(values) => {
let rw_values = values
.iter()
.map(|value| from_protobuf_value(field_desc, value))
.map(|value| from_protobuf_value(field_desc, value, descriptor_pool))
.collect::<Result<Vec<_>>>()?;
ScalarImpl::List(ListValue::new(rw_values))
}
Expand Down Expand Up @@ -313,6 +376,7 @@ fn protobuf_type_mapping(
.map(|f| protobuf_type_mapping(&f, parse_trace))
.collect::<Result<Vec<_>>>()?;
let field_names = m.fields().map(|f| f.name().to_string()).collect_vec();

DataType::new_struct(fields, field_names)
}
Kind::Enum(_) => DataType::Varchar,
Expand All @@ -334,7 +398,7 @@ fn protobuf_type_mapping(
pub(crate) fn resolve_pb_header(payload: &[u8]) -> Result<&[u8]> {
// there's a message index array at the front of payload
// if it is the first message in proto def, the array is just a `0`
// TODO: support parsing more complex indec array
// TODO: support parsing more complex index array
let (_, remained) = extract_schema_id(payload)?;
match remained.first() {
Some(0) => Ok(&remained[1..]),
Expand Down Expand Up @@ -655,18 +719,18 @@ mod test {
Some(ScalarImpl::Int32(500000000)),
])),
);
pb_eq(
a,
"any_field",
S::Struct(StructValue::new(vec![
Some(ScalarImpl::Utf8(
m.any_field.as_ref().unwrap().type_url.as_str().into(),
)),
Some(ScalarImpl::Bytea(
m.any_field.as_ref().unwrap().value.clone().into(),
)),
])),
);
// pb_eq(
// a,
// "any_field",
// S::Struct(StructValue::new(vec![
// Some(ScalarImpl::Utf8(
// m.any_field.as_ref().unwrap().type_url.as_str().into(),
// )),
// Some(ScalarImpl::Bytea(
// m.any_field.as_ref().unwrap().value.clone().into(),
// )),
// ])),
// );
xzhseh marked this conversation as resolved.
Show resolved Hide resolved
pb_eq(
a,
"int32_value_field",
Expand Down Expand Up @@ -729,4 +793,161 @@ mod test {
example_oneof: Some(ExampleOneof::OneofInt32(123)),
}
}

// Wire-format bytes of a message that is decoded as `test.TestAny` in
// `test_any_schema`. In protobuf text format the message is:
//
// id: 12345
// name {
// type_url: "type.googleapis.com/test.StringValue"
// value: "\n\010John Doe"
// }
//
// Byte layout: `08 b9 60` is field 1 (varint 12345); `12 32` opens field 2
// (50 bytes), which holds `0a 24` + the 36-byte type URL and `12 0a` + the
// 10-byte payload (`0a 08` followed by "John Doe").
static ANY_GEN_PROTO_DATA: &[u8] = b"\x08\xb9\x60\x12\x32\x0a\x24\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x53\x74\x72\x69\x6e\x67\x56\x61\x6c\x75\x65\x12\x0a\x0a\x08\x4a\x6f\x68\x6e\x20\x44\x6f\x65";

/// Decodes `ANY_GEN_PROTO_DATA` as `test.TestAny` and checks that the `Any`
/// field wrapping a `test.StringValue` is unpacked into a nested struct.
///
/// Every expectation is a hard assertion: a `None` datum or an unexpected
/// variant fails the test instead of being skipped silently.
#[tokio::test]
async fn test_any_schema() -> Result<()> {
    let location = schema_dir() + "/any-schema.pb";
    let message_name = "test.TestAny";
    let info = StreamSourceInfo {
        proto_message_name: message_name.to_string(),
        row_schema_location: location,
        use_schema_registry: false,
        ..Default::default()
    };

    let parser_config = SpecificParserConfig::new(
        SourceStruct::new(SourceFormat::Plain, SourceEncode::Protobuf),
        &info,
        &HashMap::new(),
    )?;

    let conf = ProtobufParserConfig::new(parser_config.encoding_config).await?;

    let value = DynamicMessage::decode(conf.message_descriptor.clone(), ANY_GEN_PROTO_DATA)
        .expect("ANY_GEN_PROTO_DATA must decode as test.TestAny");

    // `from_protobuf_value` requires a field descriptor, but for a top-level
    // non-`Any` `Value::Message` it walks the message's own fields instead, so
    // any descriptor will do here.
    let field = value
        .fields()
        .next()
        .expect("decoded message has at least one field")
        .0;

    let ret = from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool)
        .unwrap()
        .expect("Expected a non-null datum for ANY_GEN_PROTO_DATA");

    let ScalarImpl::Struct(struct_value) = ret else {
        panic!("Expected ScalarImpl::Struct");
    };

    let fields = struct_value.fields();

    // Field 0: the plain `id` field.
    match &fields[0] {
        Some(ScalarImpl::Int32(v)) => assert_eq!(*v, 12345),
        other => panic!("Expected ScalarImpl::Int32, got {other:?}"),
    }

    // Field 1: the `Any` field, unpacked into the struct of the wrapped message.
    match &fields[1] {
        Some(ScalarImpl::Struct(sv)) => {
            let inner = sv.fields();
            // `assert_eq!` rather than `debug_assert!`, so that the check also
            // runs when tests are built in release mode.
            assert_eq!(inner.len(), 1, "Expected only one field");
            match &inner[0] {
                Some(ScalarImpl::Utf8(v)) => assert_eq!(v.to_string(), "John Doe"),
                other => panic!("Expected ScalarImpl::Utf8, got {other:?}"),
            }
        }
        other => panic!("Expected ScalarImpl::Struct, got {other:?}"),
    }

    Ok(())
}

// Wire-format bytes of a message that is decoded as `test.TestAny` in
// `test_any_schema_1`. In protobuf text format the message is:
//
// id: 12345
// name {
// type_url: "type.googleapis.com/test.Int32Value"
// value: "\010\322\376\006"
// }
// Unpacked Int32Value from Any: value: 114514
//
// Byte layout: `08 b9 60` is field 1 (varint 12345); `12 2b` opens field 2
// (43 bytes), which holds `0a 23` + the 35-byte type URL and `12 04` + the
// 4-byte payload `08 d2 fe 06` (field 1, varint 114514).
static ANY_GEN_PROTO_DATA_1: &[u8] = b"\x08\xb9\x60\x12\x2b\x0a\x23\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x49\x6e\x74\x33\x32\x56\x61\x6c\x75\x65\x12\x04\x08\xd2\xfe\x06";

/// Decodes `ANY_GEN_PROTO_DATA_1` as `test.TestAny` and checks that the `Any`
/// field wrapping a `test.Int32Value` is unpacked into a nested struct.
///
/// Every expectation is a hard assertion: a `None` datum or an unexpected
/// variant fails the test instead of being skipped silently.
#[tokio::test]
async fn test_any_schema_1() -> Result<()> {
    let location = schema_dir() + "/any-schema.pb";
    let message_name = "test.TestAny";
    let info = StreamSourceInfo {
        proto_message_name: message_name.to_string(),
        row_schema_location: location,
        use_schema_registry: false,
        ..Default::default()
    };

    let parser_config = SpecificParserConfig::new(
        SourceStruct::new(SourceFormat::Plain, SourceEncode::Protobuf),
        &info,
        &HashMap::new(),
    )?;

    let conf = ProtobufParserConfig::new(parser_config.encoding_config).await?;

    let value = DynamicMessage::decode(conf.message_descriptor.clone(), ANY_GEN_PROTO_DATA_1)
        .expect("ANY_GEN_PROTO_DATA_1 must decode as test.TestAny");

    // `from_protobuf_value` requires a field descriptor, but for a top-level
    // non-`Any` `Value::Message` it walks the message's own fields instead, so
    // any descriptor will do here.
    let field = value
        .fields()
        .next()
        .expect("decoded message has at least one field")
        .0;

    let ret = from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool)
        .unwrap()
        .expect("Expected a non-null datum for ANY_GEN_PROTO_DATA_1");

    let ScalarImpl::Struct(struct_value) = ret else {
        panic!("Expected ScalarImpl::Struct");
    };

    let fields = struct_value.fields();

    // Field 0: the plain `id` field.
    match &fields[0] {
        Some(ScalarImpl::Int32(v)) => assert_eq!(*v, 12345),
        other => panic!("Expected ScalarImpl::Int32, got {other:?}"),
    }

    // Field 1: the `Any` field, unpacked into the struct of the wrapped message.
    match &fields[1] {
        Some(ScalarImpl::Struct(sv)) => {
            let inner = sv.fields();
            // `assert_eq!` rather than `debug_assert!`, so that the check also
            // runs when tests are built in release mode.
            assert_eq!(inner.len(), 1, "Expected only one field");
            match &inner[0] {
                Some(ScalarImpl::Int32(v)) => assert_eq!(*v, 114514),
                other => panic!("Expected ScalarImpl::Int32, got {other:?}"),
            }
        }
        other => panic!("Expected ScalarImpl::Struct, got {other:?}"),
    }

    Ok(())
}
}
Loading