Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Add support to extension types in FFI #363

Merged
merged 1 commit into from
Sep 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@
import arrow_pyarrow_integration_testing


class UuidType(pyarrow.PyExtensionType):
def __init__(self):
super().__init__(pyarrow.binary(16))

def __reduce__(self):
return UuidType, ()


class TestCase(unittest.TestCase):
def setUp(self):
self.old_allocated_rust = (
Expand Down Expand Up @@ -179,3 +187,10 @@ def test_field_metadata(self):
result = arrow_pyarrow_integration_testing.round_trip_field(field)
assert field == result
assert field.metadata == result.metadata

# see https://issues.apache.org/jira/browse/ARROW-13855
def _test_field_extension(self):
field = pyarrow.field("aa", UuidType())
result = arrow_pyarrow_integration_testing.round_trip_field(field)
assert field == result
assert field.metadata == result.metadata
16 changes: 16 additions & 0 deletions src/datatypes/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,19 @@ impl std::fmt::Display for Field {
write!(f, "{:?}", self)
}
}

pub(crate) type Metadata = Option<BTreeMap<String, String>>;
pub(crate) type Extension = Option<(String, Option<String>)>;

pub(crate) fn get_extension(metadata: &Option<BTreeMap<String, String>>) -> Extension {
if let Some(metadata) = metadata {
if let Some(name) = metadata.get("ARROW:extension:name") {
let metadata = metadata.get("ARROW:extension:metadata").cloned();
Some((name.clone(), metadata))
} else {
None
}
} else {
None
}
}
2 changes: 2 additions & 0 deletions src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ pub use field::Field;
pub use physical_type::*;
pub use schema::Schema;

pub(crate) use field::{get_extension, Extension, Metadata};

/// The set of datatypes that are supported by this implementation of Apache Arrow.
///
/// The Arrow specification on data types includes some more types.
Expand Down
55 changes: 46 additions & 9 deletions src/ffi/schema.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{collections::BTreeMap, convert::TryInto, ffi::CStr, ffi::CString, ptr};

use crate::{
datatypes::{DataType, Field, IntervalUnit, TimeUnit},
datatypes::{DataType, Extension, Field, IntervalUnit, Metadata, TimeUnit},
error::{ArrowError, Result},
};

Expand Down Expand Up @@ -91,7 +91,26 @@ impl Ffi_ArrowSchema {
None
};

let metadata = field.metadata().as_ref().map(metadata_to_bytes);
let metadata = field.metadata();

let metadata = if let DataType::Extension(name, _, extension_metadata) = field.data_type() {
// append extension information.
let mut metadata = metadata.clone().unwrap_or_default();

// metadata
if let Some(extension_metadata) = extension_metadata {
metadata.insert(
"ARROW:extension:metadata".to_string(),
extension_metadata.clone(),
);
}

metadata.insert("ARROW:extension:name".to_string(), name.clone());

Some(metadata_to_bytes(&metadata))
} else {
metadata.as_ref().map(metadata_to_bytes)
};

let name = CString::new(name).unwrap();
let format = CString::new(format).unwrap();
Expand Down Expand Up @@ -192,7 +211,14 @@ pub fn to_field(schema: &Ffi_ArrowSchema) -> Result<Field> {
} else {
to_data_type(schema)?
};
let metadata = unsafe { metadata_from_bytes(schema.metadata) };
let (metadata, extension) = unsafe { metadata_from_bytes(schema.metadata) };

let data_type = if let Some((name, extension_metadata)) = extension {
DataType::Extension(name, Box::new(data_type), extension_metadata)
} else {
data_type
};

let mut field = Field::new(schema.name(), data_type, schema.nullable());
field.set_metadata(metadata);
Ok(field)
Expand Down Expand Up @@ -412,17 +438,17 @@ unsafe fn read_bytes(ptr: *const u8, len: usize) -> &'static str {
std::str::from_utf8(slice).unwrap()
}

unsafe fn metadata_from_bytes(
data: *const ::std::os::raw::c_char,
) -> Option<BTreeMap<String, String>> {
unsafe fn metadata_from_bytes(data: *const ::std::os::raw::c_char) -> (Metadata, Extension) {
let mut data = data as *const u8; // u8 = i8
if data.is_null() {
return None;
return (None, None);
};
let len = read_ne_i32(data);
data = data.add(4);

let mut result = BTreeMap::new();
let mut extension_name = None;
let mut extension_metadata = None;
for _ in 0..len {
let key_len = read_ne_i32(data) as usize;
data = data.add(4);
Expand All @@ -432,7 +458,18 @@ unsafe fn metadata_from_bytes(
data = data.add(4);
let value = read_bytes(data, value_len);
data = data.add(value_len);
result.insert(key.to_string(), value.to_string());
match key {
"ARROW:extension:name" => {
extension_name = Some(value.to_string());
}
"ARROW:extension:metadata" => {
extension_metadata = Some(value.to_string());
}
_ => {
result.insert(key.to_string(), value.to_string());
}
};
}
Some(result)
let extension = extension_name.map(|name| (name, extension_metadata));
(Some(result), extension)
}
20 changes: 3 additions & 17 deletions src/io/ipc/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

//! Utilities for converting between IPC types and native Arrow types

use crate::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
use crate::datatypes::{
get_extension, DataType, Extension, Field, IntervalUnit, Metadata, Schema, TimeUnit,
};
use crate::endianess::is_native_little_endian;
use crate::io::ipc::convert::ipc::UnionMode;

Expand All @@ -32,9 +34,6 @@ use std::collections::{BTreeMap, HashMap};

use DataType::*;

type Metadata = Option<BTreeMap<String, String>>;
type Extension = Option<(String, Option<String>)>;

pub fn schema_to_fb_offset<'a>(
fbb: &mut FlatBufferBuilder<'a>,
schema: &Schema,
Expand Down Expand Up @@ -84,19 +83,6 @@ fn read_metadata(field: &ipc::Field) -> Metadata {
}
}

pub(crate) fn get_extension(metadata: &Metadata) -> Extension {
if let Some(metadata) = metadata {
if let Some(name) = metadata.get("ARROW:extension:name") {
let metadata = metadata.get("ARROW:extension:metadata").cloned();
Some((name.clone(), metadata))
} else {
None
}
} else {
None
}
}

/// Convert an IPC Field to Arrow Field
impl<'a> From<ipc::Field<'a>> for Field {
fn from(field: ipc::Field) -> Field {
Expand Down
1 change: 0 additions & 1 deletion src/io/ipc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ mod compression;
mod convert;

pub use convert::fb_to_schema;
pub(crate) use convert::get_extension;
pub use gen::Message::root_as_message;
pub mod read;
pub mod write;
Expand Down
3 changes: 1 addition & 2 deletions src/io/json_integration/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ use serde_json::{json, Value};

use crate::error::{ArrowError, Result};

use crate::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
use crate::io::ipc::get_extension;
use crate::datatypes::{get_extension, DataType, Field, IntervalUnit, Schema, TimeUnit};

pub trait ToJson {
/// Generate a JSON representation
Expand Down
10 changes: 10 additions & 0 deletions tests/it/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,13 @@ fn schema() -> Result<()> {
let field = field.with_metadata(metadata);
test_round_trip_schema(field)
}

#[test]
fn extension() -> Result<()> {
let field = Field::new(
"a",
DataType::Extension("a".to_string(), Box::new(DataType::Int32), None),
true,
);
test_round_trip_schema(field)
}