Skip to content

Commit

Permalink
use internal avro schema representation
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewinci committed Jan 5, 2023
1 parent 62f3597 commit 6d3465c
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 78 deletions.
30 changes: 24 additions & 6 deletions backend/src/lib/avro/avro_schema.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::collections::{BTreeMap, HashMap};

use apache_avro::{schema::Name, Schema};

Expand Down Expand Up @@ -29,16 +29,30 @@ pub enum AvroSchema {
Array(Box<AvroSchema>),
Map(Box<AvroSchema>),
Union(Vec<AvroSchema>),
Record { name: Name, fields: Vec<RecordField> },
Enum { name: Name, symbols: Vec<String> },
Fixed { name: Name, size: usize },
Decimal { precision: usize, scale: usize },
Record {
name: Name,
fields: Vec<RecordField>,
lookup: BTreeMap<String, usize>,
},
Enum {
name: Name,
symbols: Vec<String>,
},
Fixed {
name: Name,
size: usize,
},
Decimal {
precision: usize,
scale: usize,
},
}

#[derive(Clone, Debug, PartialEq)]
pub struct ResolvedAvroSchema {
pub id: i32,
pub schema: AvroSchema,
pub inner_schema: Schema,
}

impl ResolvedAvroSchema {
Expand Down Expand Up @@ -67,8 +81,11 @@ impl ResolvedAvroSchema {
Schema::Union(s) => {
AvroSchema::Union(s.variants().iter().map(|s| map(s, parent_ns, references)).collect())
}
Schema::Record { name, fields, .. } => AvroSchema::Record {
Schema::Record {
name, fields, lookup, ..
} => AvroSchema::Record {
name: name.clone(),
lookup: lookup.clone(),
fields: fields
.iter()
.map(|i| RecordField {
Expand Down Expand Up @@ -99,6 +116,7 @@ impl ResolvedAvroSchema {
Self {
id,
schema: map(schema, &None, &references),
inner_schema: schema.clone(),
}
}
}
Expand Down
35 changes: 9 additions & 26 deletions backend/src/lib/avro/avro_to_json.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use super::avro_schema::{AvroSchema as Schema, RecordField};
use super::{
avro_parser::AvroParser,
error::{AvroError, AvroResult},
helpers::{get_schema_id_from_record_header, get_schema_name},
schema_provider::SchemaProvider,
};
use apache_avro::{from_avro_datum, schema::Name, types::Value as AvroValue, Schema};
use apache_avro::{from_avro_datum, schema::Name, types::Value as AvroValue};
use num_bigint::BigInt;
use rust_decimal::Decimal;
use serde_json::{json, Map, Value as JsonValue};
Expand All @@ -19,9 +20,9 @@ impl<S: SchemaProvider> AvroParser<S> {
let mut data = Cursor::new(&raw[5..]);

// parse the avro record into an AvroValue
let record =
from_avro_datum(&schema.schema, &mut data, None).map_err(|err| AvroError::ParseAvroValue(err.to_string()))?;
let json = map(&record, &schema.schema, &None, &schema.resolved_schemas)?;
let record = from_avro_datum(&schema.inner_schema, &mut data, None)
.map_err(|err| AvroError::ParseAvroValue(err.to_string()))?;
let json = map(&record, &schema.schema, &None, &HashMap::new())?;
let res = serde_json::to_string(&json).map_err(|err| AvroError::ParseJsonValue(err.to_string()))?;
Ok(res)
}
Expand Down Expand Up @@ -67,7 +68,7 @@ fn map(
if *v == Box::new(AvroValue::Null) {
Ok(JsonValue::Null)
} else {
let schema = s.variants().get(*i as usize).ok_or_else(|| {
let schema = s.get(*i as usize).ok_or_else(|| {
AvroError::InvalidUnion(format!("Missing schema index {} in the union {:?}", *i, s))
})?;
let value = map(v, schema, parent_ns, ref_cache)?;
Expand All @@ -76,31 +77,13 @@ fn map(
}
(AvroValue::Enum(_, v), Schema::Enum { name: _, .. }) => Ok(json!(*v)),
(AvroValue::Fixed(_, v), Schema::Fixed { .. }) => Ok(json!(*v)),
(value, Schema::Ref { name }) => parse_ref(ref_cache, name, parent_ns, value),
(_, s) => Err(AvroError::Unsupported(format!(
"Unexpected value/schema tuple. Schema: {:?}",
s
))),
}
}

fn parse_ref(
ref_cache: &HashMap<Name, Schema>,
name: &Name,
parent_ns: &Option<String>,
value: &AvroValue,
) -> AvroResult<JsonValue> {
let schema = ref_cache
.get(
&(Name {
namespace: name.namespace.clone().or_else(|| parent_ns.to_owned()),
name: name.name.clone(),
}),
)
.ok_or_else(|| AvroError::MissingAvroSchemaReference(name.to_string()))?;
map(value, schema, &name.namespace, ref_cache)
}

fn parse_decimal(v: &apache_avro::Decimal, scale: &usize) -> AvroResult<JsonValue> {
// the representation of the decimal in avro is the number in binary with
// the scale encoded in the schema. Therefore we convert the bin array into a big int
Expand All @@ -118,7 +101,7 @@ fn parse_decimal(v: &apache_avro::Decimal, scale: &usize) -> AvroResult<JsonValu
fn parse_record(
vec: &[(String, AvroValue)],
lookup: &std::collections::BTreeMap<String, usize>,
fields: &[apache_avro::schema::RecordField],
fields: &[RecordField],
name: &Name,
parent_ns: &Option<String>,
ref_cache: &HashMap<Name, Schema>,
Expand Down Expand Up @@ -172,7 +155,7 @@ mod tests {
use apache_avro::{to_avro_datum, types::Record, types::Value as AvroValue, Schema as ApacheAvroSchema, Writer};
use async_trait::async_trait;

use crate::lib::{avro::error::AvroResult, schema_registry::ResolvedAvroSchema};
use crate::lib::avro::{error::AvroResult, ResolvedAvroSchema};

use super::{AvroParser, SchemaProvider};
struct MockSchemaRegistry {
Expand All @@ -184,7 +167,7 @@ mod tests {
async fn get_schema_by_id(&self, _: i32) -> AvroResult<ResolvedAvroSchema> {
Ok(ResolvedAvroSchema::from(
123,
ApacheAvroSchema::parse_str(&self.schema).unwrap(),
&ApacheAvroSchema::parse_str(&self.schema).unwrap(),
))
}
async fn get_schema_by_name(&self, _name: &str) -> AvroResult<ResolvedAvroSchema> {
Expand Down
1 change: 0 additions & 1 deletion backend/src/lib/avro/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use crate::lib::schema_registry::SchemaRegistryError;
#[derive(Debug, PartialEq)]
pub enum AvroError {
InvalidNumber(String),
MissingAvroSchemaReference(String),
MissingField(String),
SchemaProvider(String, SchemaRegistryError),
InvalidUnion(String),
Expand Down
4 changes: 2 additions & 2 deletions backend/src/lib/avro/helpers.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use apache_avro::{schema::Name, Schema};
use super::avro_schema::AvroSchema as Schema;
use apache_avro::schema::Name;
use log::error;

use super::error::{AvroError, AvroResult};
Expand Down Expand Up @@ -44,7 +45,6 @@ pub(super) fn get_schema_name<'a>(s: &'a Schema, parent_ns: Option<&'a str>) ->
Schema::String => "string".into(),
Schema::Record { name, .. } => ns_name(name, parent_ns),
Schema::Enum { name, .. } => ns_name(name, parent_ns),
Schema::Ref { name, .. } => ns_name(name, parent_ns),
_ => {
//todo: support the other types
let message = format!("Unable to retrieve the name of the schema {:?}", s);
Expand Down
73 changes: 30 additions & 43 deletions backend/src/lib/avro/json_to_avro.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::{collections::HashMap, str::FromStr};

use apache_avro::{schema::Name, to_avro_datum, types::Value as AvroValue, Schema};
use apache_avro::{schema::Name, to_avro_datum, types::Value as AvroValue};
use num_bigint::BigInt;
use uuid::Uuid;

use super::{
avro_parser::AvroParser,
avro_schema::{AvroSchema as Schema, RecordField},
error::AvroResult,
helpers::{build_record_header, get_schema_name},
schema_provider::SchemaProvider,
Expand All @@ -22,10 +23,10 @@ impl<S: SchemaProvider> AvroParser<S> {

pub fn json_to_avro_with_schema(&self, json: &str, schema: ResolvedAvroSchema) -> AvroResult<Vec<u8>> {
let json_value = JsonValue::from_str(json).map_err(|err| AvroError::ParseJsonValue(err.to_string()))?;
let mut res = build_record_header(schema.schema_id);
let avro_value = json_to_avro_map(&json_value, &schema.schema, None, &schema.resolved_schemas)?;
let mut res = build_record_header(schema.id);
let avro_value = json_to_avro_map(&json_value, &schema.schema, None, &HashMap::new())?;
println!("Parsing: {:?}\n\tUsing schema: {:?}", avro_value, &schema.schema);
let mut avro_record = to_avro_datum(&schema.schema, avro_value.clone())
let mut avro_record = to_avro_datum(&schema.inner_schema, avro_value.clone())
.map_err(|err| AvroError::ParseAvroValue(err.to_string()))?;
res.append(&mut avro_record);
Ok(res)
Expand All @@ -46,16 +47,20 @@ fn json_to_avro_map(
(Schema::Array(items_schema), JsonValue::Array(values)) => {
map_json_array_to_avro(values, items_schema, parent_ns, ref_map)
}
(Schema::Union(union_schema), JsonValue::Null) => {
let (position, _) = union_schema.find_schema(&AvroValue::Null).ok_or_else(|| {
AvroError::InvalidUnion(format!(
"Cannot set null to the union. Supported options are: {:?}",
union_schema.variants()
))
})?;
(Schema::Union(union_schemas), JsonValue::Null) => {
let (position, _) = union_schemas
.iter()
.enumerate()
.find(|(_, s)| *s == &Schema::Null)
.ok_or_else(|| {
AvroError::InvalidUnion(format!(
"Cannot set null to the union. Supported options are: {:?}",
union_schemas
))
})?;
Ok(AvroValue::Union(position as u32, AvroValue::Null.into()))
}
(Schema::Union(union_schema), JsonValue::Object(obj)) => map_union(obj, union_schema, parent_ns, ref_map),
(Schema::Union(union_schemas), JsonValue::Object(obj)) => map_union(obj, union_schemas, parent_ns, ref_map),
(Schema::Map(schema), JsonValue::Object(obj)) => {
let mut avro_map = HashMap::new();
for (key, value) in obj {
Expand Down Expand Up @@ -138,14 +143,6 @@ fn json_to_avro_map(
})?;
Ok(AvroValue::TimestampMicros(n))
}
// references
(Schema::Ref { name }, value) => {
let ns_name = name.fully_qualified_name(&parent_ns.map(|v| v.to_string()));
let schema = ref_map
.get(&ns_name)
.ok_or_else(|| AvroError::MissingAvroSchemaReference(format!("Unable to resolve reference {}", name)))?;
json_to_avro_map(value, schema, parent_ns, ref_map)
}
(Schema::Uuid, JsonValue::String(v)) => {
let uuid =
Uuid::parse_str(v).map_err(|_| AvroError::InvalidUUID(format!("Unable to parse {} into a uuid", v)))?;
Expand Down Expand Up @@ -181,29 +178,27 @@ fn parse_decimal(n: &str, scale: u32) -> AvroResult<apache_avro::Decimal> {

fn map_union(
obj: &serde_json::Map<String, JsonValue>,
union_schema: &apache_avro::schema::UnionSchema,
union_schemas: &Vec<Schema>,
parent_ns: Option<&str>,
ref_map: &HashMap<Name, Schema>,
) -> Result<AvroValue, AvroError> {
let fields_vec: Vec<(&String, &JsonValue)> = obj.iter().collect();
if fields_vec.len() != 1 {
Err(AvroError::InvalidUnion(format!(
"Invalid union. Expected one of: {:?}",
union_schema.variants()
union_schemas
)))
} else {
let (union_branch_name, value) = *fields_vec.first().unwrap();
let index_schema = union_schema
.variants()
let index_schema = union_schemas
.iter()
.enumerate()
.find(|(_, schema)| get_schema_name(schema, parent_ns).eq(union_branch_name));
if let Some((index, current_schema)) = index_schema {
let value = json_to_avro_map(value, current_schema, parent_ns, ref_map)?;
Ok(AvroValue::Union(index as u32, value.into()))
} else {
let union_variants: Vec<_> = union_schema
.variants()
let union_variants: Vec<_> = union_schemas
.iter()
.map(|schema| get_schema_name(schema, parent_ns))
.collect();
Expand All @@ -230,7 +225,7 @@ fn map_json_array_to_avro(
}

fn map_json_fields_to_record(
fields: &Vec<apache_avro::schema::RecordField>,
fields: &Vec<RecordField>,
obj: &serde_json::Map<String, JsonValue>,
parent_ns: Option<&str>,
ref_map: &HashMap<Name, Schema>,
Expand All @@ -252,10 +247,10 @@ mod tests {
use std::collections::BTreeMap;
use std::collections::HashMap;

use apache_avro::{schema::RecordField, Schema};

use super::map_json_fields_to_record;
use super::parse_decimal;
use crate::lib::avro::avro_schema::AvroSchema;
use crate::lib::avro::avro_schema::RecordField;
use crate::lib::avro::AvroError;

use apache_avro::types::Value as AvroValue;
Expand Down Expand Up @@ -287,7 +282,7 @@ mod tests {
obj_map.insert("sample".to_string(), json!(1));
obj_map
};
let fields = vec![build_record_field("sample", apache_avro::Schema::Int)];
let fields = vec![build_record_field("sample", AvroSchema::Int)];

// happy path
{
Expand All @@ -301,8 +296,8 @@ mod tests {
// parse a json object with a missing field return an error
{
let fields = vec![
build_record_field("sample", apache_avro::Schema::Int),
build_record_field("sample_2", apache_avro::Schema::Int),
build_record_field("sample", AvroSchema::Int),
build_record_field("sample_2", AvroSchema::Int),
];
let res = map_json_fields_to_record(&fields, &obj, None, &HashMap::new());
assert_eq!(res, Err(AvroError::MissingField("sample_2".into())))
Expand All @@ -316,16 +311,13 @@ mod tests {
obj_map.insert("nested".into(), JsonValue::Object(obj));
obj_map
};
let nested_schema = Schema::Record {
let nested_schema = AvroSchema::Record {
name: "Nested".into(),
aliases: None,
doc: None,
fields: fields,
lookup: BTreeMap::new(),
attributes: BTreeMap::new(),
};
let fields = vec![
build_record_field("sample", apache_avro::Schema::Int),
build_record_field("sample", AvroSchema::Int),
build_record_field("nested", nested_schema),
];
let res = map_json_fields_to_record(&fields, &obj_parent, None, &HashMap::new());
Expand All @@ -342,15 +334,10 @@ mod tests {
}
}

fn build_record_field(name: &str, schema: Schema) -> RecordField {
fn build_record_field(name: &str, schema: AvroSchema) -> RecordField {
RecordField {
name: name.into(),
doc: Default::default(),
default: Default::default(),
schema: schema,
order: apache_avro::schema::RecordFieldOrder::Ignore,
position: Default::default(),
custom_attributes: BTreeMap::new(),
}
}
}

0 comments on commit 6d3465c

Please sign in to comment.