Skip to content

Commit

Permalink
Preserve dict_id on Field during serde roundtrip (#8457)
Browse files (browse the repository at this point in the history)
* Failing test

* Passing test
  • Loading branch information
avantgardnerio authored Dec 8, 2023
1 parent 205e315 commit a8d74a7
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 2 deletions.
2 changes: 2 additions & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,8 @@ message Field {
// for complex data types like structs, unions
repeated Field children = 4;
map<string, string> metadata = 5;
int64 dict_id = 6;
bool dict_ordered = 7;
}

message FixedSizeBinary{
Expand Down
39 changes: 39 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 14 additions & 2 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,20 @@ impl TryFrom<&protobuf::Field> for Field {
type Error = Error;
/// Builds an Arrow `Field` from its protobuf representation.
///
/// The dictionary id and the dictionary-ordered flag stored in the message
/// are always forwarded to the resulting field, together with its metadata.
///
/// # Errors
///
/// Propagates the error produced by `required("arrow_type")` when the
/// message's `arrow_type` cannot be turned into a data type.
fn try_from(field: &protobuf::Field) -> Result<Self, Self::Error> {
    let datatype = field.arrow_type.as_deref().required("arrow_type")?;

    // Forward `dict_id` and `dict_ordered` unconditionally. Guarding on
    // `dict_id != 0` would silently drop `dict_ordered == true` for fields
    // that use dictionary id 0. For messages carrying the protobuf defaults
    // (0 / false) this is expected to be equivalent to `Field::new`, which
    // uses the same defaults in arrow-schema.
    let decoded = Self::new_dict(
        field.name.as_str(),
        datatype,
        field.nullable,
        field.dict_id,
        field.dict_ordered,
    )
    .with_metadata(field.metadata.clone());

    Ok(decoded)
}
}

Expand Down
2 changes: 2 additions & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ impl TryFrom<&Field> for protobuf::Field {
nullable: field.is_nullable(),
children: Vec::new(),
metadata: field.metadata().clone(),
dict_id: field.dict_id().unwrap_or(0),
dict_ordered: field.dict_is_ordered().unwrap_or(false),
})
}
}
Expand Down
39 changes: 39 additions & 0 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,45 @@ fn round_trip_datatype() {
}
}

/// Checks that the dictionary id of a list's element field is still present
/// after the enclosing schema is encoded to protobuf bytes and decoded again.
#[test]
fn roundtrip_dict_id() -> Result<()> {
    let dict_id = 42;

    // Schema under test: a non-nullable list column named "keys" whose
    // element field is a dictionary field carrying an explicit dictionary id.
    let dictionary_type =
        DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8));
    let element = Field::new_dict("item", dictionary_type, true, dict_id, false);
    let list_column = Field::new("keys", DataType::List(Arc::new(element)), false);
    let schema = Arc::new(Schema::new(vec![list_column]));

    // Serialize the schema into protobuf bytes.
    let outgoing: datafusion_proto::generated::datafusion::Schema =
        schema.try_into().unwrap();
    let mut bytes: Vec<u8> = vec![];
    outgoing.encode(&mut bytes).unwrap();

    // Parse the bytes back and rebuild the schema from the message.
    let incoming =
        datafusion_proto::generated::datafusion::Schema::decode(bytes.as_slice()).unwrap();
    let restored: Schema = (&incoming).try_into()?;

    // The element field of the last column must still report the id.
    let column = restored.fields().iter().last().unwrap();
    if let DataType::List(element) = column.data_type() {
        assert_eq!(element.dict_id(), Some(dict_id), "dict_id should be retained");
    } else {
        panic!("Invalid type");
    }

    Ok(())
}

#[test]
fn roundtrip_null_scalar_values() {
let test_types = vec![
Expand Down

0 comments on commit a8d74a7

Please sign in to comment.