Add confluent schema registry support for protobuf #724

Merged · 4 commits · Aug 24, 2024
2 changes: 2 additions & 0 deletions Cargo.lock

(Generated file; diff not rendered.)

204 changes: 149 additions & 55 deletions crates/arroyo-api/src/connection_tables.rs
@@ -6,6 +6,7 @@ use axum::Json;
use axum_extra::extract::WithRejection;
use futures_util::stream::Stream;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::convert::Infallible;
use tokio::sync::mpsc::channel;
use tokio_stream::wrappers::ReceiverStream;
@@ -19,7 +20,7 @@ use arroyo_formats::{avro, json, proto};
use arroyo_operator::connector::ErasedConnector;
use arroyo_rpc::api_types::connections::{
ConnectionProfile, ConnectionSchema, ConnectionTable, ConnectionTablePost, ConnectionType,
SchemaDefinition,
SchemaDefinition, SourceField,
};
use arroyo_rpc::api_types::{ConnectionTableCollection, PaginationQueryParams};
use arroyo_rpc::formats::{AvroFormat, Format, JsonFormat, ProtobufFormat};
@@ -467,7 +468,16 @@ pub(crate) async fn expand_schema(
Format::Parquet(_) => Ok(schema),
Format::RawString(_) => Ok(schema),
Format::RawBytes(_) => Ok(schema),
Format::Protobuf(_) => expand_proto_schema(schema).await,
Format::Protobuf(_) => {
expand_proto_schema(
connector,
connection_type,
schema,
profile_config,
table_config,
)
.await
}
}
}

@@ -486,7 +496,7 @@ async fn expand_avro_schema(
let schema_response = get_schema(connector, table_config, profile_config).await?;
match connection_type {
ConnectionType::Source => {
let schema_response = schema_response.ok_or_else(|| bad_request(
let (schema_response, _) = schema_response.ok_or_else(|| bad_request(
"No schema was found; ensure that the topic exists and has a value schema configured in the schema registry".to_string()))?;

if schema_response.schema_type != ConfluentSchemaType::Avro {
@@ -535,64 +545,123 @@ Ok(schema)
Ok(schema)
}

async fn expand_proto_schema(mut schema: ConnectionSchema) -> Result<ConnectionSchema, ErrorResp> {
async fn expand_proto_schema(
connector: &str,
connection_type: ConnectionType,
mut schema: ConnectionSchema,
profile_config: &Value,
table_config: &Value,
) -> Result<ConnectionSchema, ErrorResp> {
let Some(Format::Protobuf(ProtobufFormat {
message_name,
compiled_schema,
confluent_schema_registry,
..
})) = &mut schema.format
else {
panic!("not proto");
};

if let Some(definition) = &schema.definition {
let SchemaDefinition::ProtobufSchema {
schema: protobuf_schema,
dependencies,
} = &definition
else {
return Err(bad_request("Schema is not a protobuf schema"));
};
if *confluent_schema_registry {
let schema_response = get_schema(connector, table_config, profile_config).await?;
match connection_type {
ConnectionType::Source => {
let (schema_response, dependencies) = schema_response.ok_or_else(|| bad_request(
"No schema was found; ensure that the topic exists and has a value schema configured in the schema registry".to_string()))?;

let message_name = message_name
.as_ref()
.filter(|m| !m.is_empty())
.ok_or_else(|| bad_request("message name must be provided for protobuf schemas"))?;
if schema_response.schema_type != ConfluentSchemaType::Protobuf {
return Err(bad_request(format!(
"Format configured is protobuf, but confluent schema repository returned a {:?} schema",
schema_response.schema_type
)));
}

let encoded = schema_file_to_descriptor(protobuf_schema, dependencies)
.await
.map_err(|e| bad_request(e.to_string()))?;
let dependencies: Result<HashMap<_, _>, ErrorResp> = dependencies
.into_iter()
.map(|(name, s)| {
if s.schema_type != ConfluentSchemaType::Protobuf {
return Err(bad_request(format!(
"Schema reference {} has type {:?}, but must be protobuf",
name, s.schema_type
)));
} else {
Ok((name, s.schema))
}
})
.collect();

schema.definition = Some(SchemaDefinition::ProtobufSchema {
schema: schema_response.schema,
dependencies: dependencies?,
});
}
ConnectionType::Sink => {
// don't fetch schemas for sinks for now
}
}
}

let pool = proto::schema::get_pool(&encoded)
.map_err(|e| bad_request(format!("error handling protobuf: {}", e)))?;
*compiled_schema = Some(encoded);
let Some(definition) = &schema.definition else {
return Err(bad_request("No definition for protobuf schema"));
};

let descriptor = pool.get_message_by_name(message_name).ok_or_else(|| {
bad_request(format!(
"Message '{}' not found in proto definition; messages are {}",
message_name,
pool.all_messages()
.map(|m| m.full_name().to_string())
.filter(|m| !m.starts_with("google.protobuf."))
.collect::<Vec<_>>()
.join(", ")
))
})?;
let SchemaDefinition::ProtobufSchema {
schema: protobuf_schema,
dependencies,
} = &definition
else {
return Err(bad_request("Schema is not a protobuf schema"));
};

let arrow = protobuf_to_arrow(&descriptor)
.map_err(|e| bad_request(format!("Failed to convert schema: {}", e)))?;
let (compiled, fields) =
expand_local_proto_schema(protobuf_schema, message_name, dependencies).await?;
*compiled_schema = Some(compiled);
schema.fields = fields;

let fields: Result<_, String> = arrow
.fields
.into_iter()
.map(|f| (**f).clone().try_into())
.collect();
Ok(schema)
}

schema.fields =
fields.map_err(|e| bad_request(format!("failed to convert schema: {}", e)))?;
};
async fn expand_local_proto_schema(
schema_def: &str,
message_name: &Option<String>,
dependencies: &HashMap<String, String>,
) -> Result<(Vec<u8>, Vec<SourceField>), ErrorResp> {
let message_name = message_name
.as_ref()
.filter(|m| !m.is_empty())
.ok_or_else(|| bad_request("message name must be provided for protobuf schemas"))?;

let encoded = schema_file_to_descriptor(schema_def, dependencies)
.await
.map_err(|e| bad_request(e.to_string()))?;

Ok(schema)
let pool = proto::schema::get_pool(&encoded)
.map_err(|e| bad_request(format!("error handling protobuf: {}", e)))?;

let descriptor = pool.get_message_by_name(message_name).ok_or_else(|| {
bad_request(format!(
"Message '{}' not found in proto definition; messages are {}",
message_name,
pool.all_messages()
.map(|m| m.full_name().to_string())
.filter(|m| !m.starts_with("google.protobuf."))
.collect::<Vec<_>>()
.join(", ")
))
})?;

let arrow = protobuf_to_arrow(&descriptor)
.map_err(|e| bad_request(format!("Failed to convert schema: {}", e)))?;

let fields: Result<_, String> = arrow
.fields
.into_iter()
.map(|f| (**f).clone().try_into())
.collect();

let fields = fields.map_err(|e| bad_request(format!("failed to convert schema: {}", e)))?;

Ok((encoded, fields))
}

async fn expand_json_schema(
@@ -616,15 +685,15 @@ async fn expand_json_schema(
let schema_response = schema_response.ok_or_else(|| bad_request(
"No schema was found; ensure that the topic exists and has a value schema configured in the schema registry".to_string()))?;

if schema_response.schema_type != ConfluentSchemaType::Json {
if schema_response.0.schema_type != ConfluentSchemaType::Json {
return Err(bad_request(format!(
"Format configured is json, but confluent schema repository returned a {:?} schema",
schema_response.schema_type
schema_response.0.schema_type
)));
}
confluent_schema_id.replace(schema_response.version);
confluent_schema_id.replace(schema_response.0.version);

schema.definition = Some(SchemaDefinition::JsonSchema(schema_response.schema));
schema.definition = Some(SchemaDefinition::JsonSchema(schema_response.0.schema));
}
ConnectionType::Sink => {
// don't fetch schemas for sinks for now until we're better able to conform our output to the schema
@@ -658,7 +727,13 @@ async fn get_schema(
connector: &str,
table_config: &Value,
profile_config: &Value,
) -> Result<Option<ConfluentSchemaSubjectResponse>, ErrorResp> {
) -> Result<
Option<(
ConfluentSchemaSubjectResponse,
Vec<(String, ConfluentSchemaSubjectResponse)>,
)>,
ErrorResp,
> {
let profile: KafkaConfig = match connector {
"kafka" => {
// we unwrap here because this should already have been validated
@@ -698,20 +773,30 @@
)
.map_err(|e| {
bad_request(format!(
"failed to fetch schemas from schema repository: {}",
"failed to fetch schemas from schema registry: {}",
e
))
})?;

resolver.get_schema_for_version(None).await.map_err(|e| {
let Some(resp) = resolver.get_schema_for_version(None).await.map_err(|e| {
bad_request(format!(
"failed to fetch schemas from schema repository: {}",
"failed to fetch schemas from schema registry: {}",
e.chain()
.map(|e| e.to_string())
.collect::<Vec<_>>()
.join(": ")
))
})
})?
else {
return Ok(None);
};

let references = resolver
.resolve_references(&resp.references)
.await
.map_err(|e| bad_request(e.to_string()))?;

Ok(Some((resp, references)))
}

/// Test a Connection Schema
@@ -739,8 +824,17 @@ pub(crate) async fn test_schema(
Ok(())
}
}
SchemaDefinition::ProtobufSchema { .. } => {
let _ = expand_proto_schema(req.clone()).await?;
SchemaDefinition::ProtobufSchema {
schema,
dependencies,
} => {
let Some(Format::Protobuf(ProtobufFormat { message_name, .. })) = &req.format else {
return Err(bad_request(
"Schema has a protobuf definition but is not protobuf format",
));
};

let _ = expand_local_proto_schema(schema, message_name, dependencies).await?;
Ok(())
}
_ => {
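As context for the `compiled_schema` bytes stored by `expand_local_proto_schema`: they hold an encoded file descriptor set, so a downstream consumer can rebuild a descriptor pool from them. A minimal sketch (not part of this PR; it assumes prost-reflect's `DescriptorPool::decode` over a serialized `FileDescriptorSet`, which is what the arroyo-formats helpers appear to produce):

```rust
use prost_reflect::{DescriptorPool, MessageDescriptor};

// Sketch only: rebuild a descriptor pool from the encoded FileDescriptorSet
// bytes and look up the configured message, mirroring what the PR does via
// proto::schema::get_pool.
fn descriptor_from_compiled(
    compiled_schema: &[u8],
    message_name: &str,
) -> Result<MessageDescriptor, String> {
    let pool = DescriptorPool::decode(compiled_schema)
        .map_err(|e| format!("invalid descriptor set: {}", e))?;
    pool.get_message_by_name(message_name)
        .ok_or_else(|| format!("message '{}' not found in descriptor set", message_name))
}
```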
3 changes: 2 additions & 1 deletion crates/arroyo-formats/Cargo.toml
@@ -28,4 +28,5 @@ prost-reflect = { workspace = true}
prost-build = { workspace = true }
prost-types = { workspace = true}
base64 = "0.22.1"
uuid = { version = "1.10.0", features = ["v4"] }
uuid = { version = "1.10.0", features = ["v4"] }
regex = "1.10.6"
6 changes: 5 additions & 1 deletion crates/arroyo-formats/src/proto/de.rs
@@ -9,8 +9,12 @@ use serde_json::Value as JsonValue;
pub(crate) fn deserialize_proto(
pool: &mut DescriptorPool,
proto: &ProtobufFormat,
msg: &[u8],
mut msg: &[u8],
) -> Result<serde_json::Value, SourceError> {
if proto.confluent_schema_registry {
msg = &msg[5..];
}

let message = proto.message_name.as_ref().expect("no message name");
let descriptor = pool.get_message_by_name(message).expect("no descriptor");

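The `msg = &msg[5..]` change in de.rs skips the Confluent wire-format framing before protobuf decoding. A rough sketch of the header being discarded (not part of this PR; the function name is illustrative): one magic byte (0x00) followed by a 4-byte big-endian schema ID.

```rust
// Sketch only: split the 5-byte Confluent wire-format header off a Kafka value.
// Note: Confluent's protobuf serializer may also emit message-index varints
// after this header; they are not parsed here.
fn split_confluent_header(msg: &[u8]) -> Result<(u32, &[u8]), String> {
    if msg.len() < 5 {
        return Err("message shorter than the 5-byte Confluent header".to_string());
    }
    if msg[0] != 0 {
        return Err(format!("unexpected magic byte {:#04x}", msg[0]));
    }
    // Bytes 1..5 carry the registry-assigned schema ID, big-endian.
    let schema_id = u32::from_be_bytes([msg[1], msg[2], msg[3], msg[4]]);
    Ok((schema_id, &msg[5..]))
}
```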