Skip to content

Commit

Permalink
GH-39865: [C++] Strip extension metadata when importing a registered …
Browse files Browse the repository at this point in the history
…extension (#39866)

### Rationale for this change

When importing an extension type from the C Data Interface and the extension type is registered, we would still leave the extension-related metadata on the storage type.

### What changes are included in this PR?

Strip extension-related metadata on the storage type if we succeed in recreating the extension type.
This matches the behavior of the IPC layer and allows for more exact roundtripping.

### Are these changes tested?

Yes.

### Are there any user-facing changes?

No, unless people mistakingly rely on the presence of said metadata.
* Closes: #39865

Authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
  • Loading branch information
pitrou authored and raulcd committed Feb 20, 2024
1 parent 4f7819a commit 91be098
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 31 deletions.
6 changes: 6 additions & 0 deletions cpp/src/arrow/c/bridge.cc
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,8 @@ struct DecodedMetadata {
std::shared_ptr<KeyValueMetadata> metadata;
std::string extension_name;
std::string extension_serialized;
int extension_name_index = -1; // index of extension_name in metadata
int extension_serialized_index = -1; // index of extension_serialized in metadata
};

Result<DecodedMetadata> DecodeMetadata(const char* metadata) {
Expand Down Expand Up @@ -956,8 +958,10 @@ Result<DecodedMetadata> DecodeMetadata(const char* metadata) {
RETURN_NOT_OK(read_string(&values[i]));
if (keys[i] == kExtensionTypeKeyName) {
decoded.extension_name = values[i];
decoded.extension_name_index = i;
} else if (keys[i] == kExtensionMetadataKeyName) {
decoded.extension_serialized = values[i];
decoded.extension_serialized_index = i;
}
}
decoded.metadata = key_value_metadata(std::move(keys), std::move(values));
Expand Down Expand Up @@ -1046,6 +1050,8 @@ struct SchemaImporter {
ARROW_ASSIGN_OR_RAISE(
type_, registered_ext_type->Deserialize(std::move(type_),
metadata_.extension_serialized));
RETURN_NOT_OK(metadata_.metadata->DeleteMany(
{metadata_.extension_name_index, metadata_.extension_serialized_index}));
}
}

Expand Down
48 changes: 32 additions & 16 deletions cpp/src/arrow/c/bridge_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1870,7 +1870,7 @@ class TestSchemaImport : public ::testing::Test, public SchemaStructBuilder {
ASSERT_TRUE(ArrowSchemaIsReleased(&c_struct_));
Reset(); // for further tests
cb.AssertCalled(); // was released
AssertTypeEqual(*expected, *type);
AssertTypeEqual(*expected, *type, /*check_metadata=*/true);
}

void CheckImport(const std::shared_ptr<Field>& expected) {
Expand All @@ -1890,7 +1890,7 @@ class TestSchemaImport : public ::testing::Test, public SchemaStructBuilder {
ASSERT_TRUE(ArrowSchemaIsReleased(&c_struct_));
Reset(); // for further tests
cb.AssertCalled(); // was released
AssertSchemaEqual(*expected, *schema);
AssertSchemaEqual(*expected, *schema, /*check_metadata=*/true);
}

void CheckImportError() {
Expand Down Expand Up @@ -3569,7 +3569,7 @@ class TestSchemaRoundtrip : public ::testing::Test {
// Recreate the type
ASSERT_OK_AND_ASSIGN(actual, ImportType(&c_schema));
type = factory_expected();
AssertTypeEqual(*type, *actual);
AssertTypeEqual(*type, *actual, /*check_metadata=*/true);
type.reset();
actual.reset();

Expand Down Expand Up @@ -3600,7 +3600,7 @@ class TestSchemaRoundtrip : public ::testing::Test {
// Recreate the schema
ASSERT_OK_AND_ASSIGN(actual, ImportSchema(&c_schema));
schema = factory();
AssertSchemaEqual(*schema, *actual);
AssertSchemaEqual(*schema, *actual, /*check_metadata=*/true);
schema.reset();
actual.reset();

Expand Down Expand Up @@ -3693,13 +3693,27 @@ TEST_F(TestSchemaRoundtrip, Dictionary) {
}
}

// Given an extension type, return a field of its storage type + the
// serialized extension metadata.
std::shared_ptr<Field> GetStorageWithMetadata(const std::string& field_name,
const std::shared_ptr<DataType>& type) {
const auto& ext_type = checked_cast<const ExtensionType&>(*type);
auto storage_type = ext_type.storage_type();
auto md = KeyValueMetadata::Make({kExtensionTypeKeyName, kExtensionMetadataKeyName},
{ext_type.extension_name(), ext_type.Serialize()});
return field(field_name, storage_type, /*nullable=*/true, md);
}

TEST_F(TestSchemaRoundtrip, UnregisteredExtension) {
TestWithTypeFactory(uuid, []() { return fixed_size_binary(16); });
TestWithTypeFactory(dict_extension_type, []() { return dictionary(int8(), utf8()); });

// Inside nested type
TestWithTypeFactory([]() { return list(dict_extension_type()); },
[]() { return list(dictionary(int8(), utf8())); });
// Inside nested type.
// When an extension type is not known by the importer, it is imported
// as its storage type and the extension metadata is preserved on the field.
TestWithTypeFactory(
[]() { return list(dict_extension_type()); },
[]() { return list(GetStorageWithMetadata("item", dict_extension_type())); });
}

TEST_F(TestSchemaRoundtrip, RegisteredExtension) {
Expand All @@ -3708,7 +3722,9 @@ TEST_F(TestSchemaRoundtrip, RegisteredExtension) {
TestWithTypeFactory(dict_extension_type);
TestWithTypeFactory(complex128);

// Inside nested type
// Inside nested type.
// When the extension type is registered, the extension metadata is removed
// from the storage type's field to ensure roundtripping (GH-39865).
TestWithTypeFactory([]() { return list(uuid()); });
TestWithTypeFactory([]() { return list(dict_extension_type()); });
TestWithTypeFactory([]() { return list(complex128()); });
Expand Down Expand Up @@ -3808,7 +3824,7 @@ class TestArrayRoundtrip : public ::testing::Test {
{
std::shared_ptr<Array> expected;
ASSERT_OK_AND_ASSIGN(expected, ToResult(factory_expected()));
AssertTypeEqual(*expected->type(), *array->type());
AssertTypeEqual(*expected->type(), *array->type(), /*check_metadata=*/true);
AssertArraysEqual(*expected, *array, true);
}
array.reset();
Expand Down Expand Up @@ -3848,7 +3864,7 @@ class TestArrayRoundtrip : public ::testing::Test {
{
std::shared_ptr<RecordBatch> expected;
ASSERT_OK_AND_ASSIGN(expected, ToResult(factory()));
AssertSchemaEqual(*expected->schema(), *batch->schema());
AssertSchemaEqual(*expected->schema(), *batch->schema(), /*check_metadata=*/true);
AssertBatchesEqual(*expected, *batch);
}
batch.reset();
Expand Down Expand Up @@ -4228,7 +4244,7 @@ class TestDeviceArrayRoundtrip : public ::testing::Test {
{
std::shared_ptr<Array> expected;
ASSERT_OK_AND_ASSIGN(expected, ToResult(factory_expected()));
AssertTypeEqual(*expected->type(), *array->type());
AssertTypeEqual(*expected->type(), *array->type(), /*check_metadata=*/true);
AssertArraysEqual(*expected, *array, true);
}
array.reset();
Expand Down Expand Up @@ -4274,7 +4290,7 @@ class TestDeviceArrayRoundtrip : public ::testing::Test {
{
std::shared_ptr<RecordBatch> expected;
ASSERT_OK_AND_ASSIGN(expected, ToResult(factory()));
AssertSchemaEqual(*expected->schema(), *batch->schema());
AssertSchemaEqual(*expected->schema(), *batch->schema(), /*check_metadata=*/true);
AssertBatchesEqual(*expected, *batch);
}
batch.reset();
Expand Down Expand Up @@ -4351,7 +4367,7 @@ class TestArrayStreamExport : public BaseArrayStreamTest {
SchemaExportGuard schema_guard(&c_schema);
ASSERT_FALSE(ArrowSchemaIsReleased(&c_schema));
ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema));
AssertSchemaEqual(expected, *schema);
AssertSchemaEqual(expected, *schema, /*check_metadata=*/true);
}

void AssertStreamEnd(struct ArrowArrayStream* c_stream) {
Expand Down Expand Up @@ -4435,7 +4451,7 @@ TEST_F(TestArrayStreamExport, ArrayLifetime) {
{
SchemaExportGuard schema_guard(&c_schema);
ASSERT_OK_AND_ASSIGN(auto got_schema, ImportSchema(&c_schema));
AssertSchemaEqual(*schema, *got_schema);
AssertSchemaEqual(*schema, *got_schema, /*check_metadata=*/true);
}

ASSERT_GT(pool_->bytes_allocated(), orig_allocated_);
Expand All @@ -4460,7 +4476,7 @@ TEST_F(TestArrayStreamExport, Errors) {
{
SchemaExportGuard schema_guard(&c_schema);
ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema));
AssertSchemaEqual(schema, arrow::schema({}));
AssertSchemaEqual(schema, arrow::schema({}), /*check_metadata=*/true);
}

struct ArrowArray c_array;
Expand Down Expand Up @@ -4537,7 +4553,7 @@ TEST_F(TestArrayStreamRoundtrip, Simple) {
ASSERT_OK_AND_ASSIGN(auto reader, RecordBatchReader::Make(batches, orig_schema));

Roundtrip(std::move(reader), [&](const std::shared_ptr<RecordBatchReader>& reader) {
AssertSchemaEqual(*orig_schema, *reader->schema());
AssertSchemaEqual(*orig_schema, *reader->schema(), /*check_metadata=*/true);
AssertReaderNext(reader, *batches[0]);
AssertReaderNext(reader, *batches[1]);
AssertReaderEnd(reader);
Expand Down
18 changes: 8 additions & 10 deletions cpp/src/arrow/util/key_value_metadata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ void KeyValueMetadata::Append(std::string key, std::string value) {
values_.push_back(std::move(value));
}

Result<std::string> KeyValueMetadata::Get(const std::string& key) const {
Result<std::string> KeyValueMetadata::Get(std::string_view key) const {
auto index = FindKey(key);
if (index < 0) {
return Status::KeyError(key);
Expand Down Expand Up @@ -129,7 +129,7 @@ Status KeyValueMetadata::DeleteMany(std::vector<int64_t> indices) {
return Status::OK();
}

Status KeyValueMetadata::Delete(const std::string& key) {
Status KeyValueMetadata::Delete(std::string_view key) {
auto index = FindKey(key);
if (index < 0) {
return Status::KeyError(key);
Expand All @@ -138,20 +138,18 @@ Status KeyValueMetadata::Delete(const std::string& key) {
}
}

Status KeyValueMetadata::Set(const std::string& key, const std::string& value) {
Status KeyValueMetadata::Set(std::string key, std::string value) {
auto index = FindKey(key);
if (index < 0) {
Append(key, value);
Append(std::move(key), std::move(value));
} else {
keys_[index] = key;
values_[index] = value;
keys_[index] = std::move(key);
values_[index] = std::move(value);
}
return Status::OK();
}

bool KeyValueMetadata::Contains(const std::string& key) const {
return FindKey(key) >= 0;
}
bool KeyValueMetadata::Contains(std::string_view key) const { return FindKey(key) >= 0; }

void KeyValueMetadata::reserve(int64_t n) {
DCHECK_GE(n, 0);
Expand Down Expand Up @@ -188,7 +186,7 @@ std::vector<std::pair<std::string, std::string>> KeyValueMetadata::sorted_pairs(
return pairs;
}

int KeyValueMetadata::FindKey(const std::string& key) const {
int KeyValueMetadata::FindKey(std::string_view key) const {
for (size_t i = 0; i < keys_.size(); ++i) {
if (keys_[i] == key) {
return static_cast<int>(i);
Expand Down
11 changes: 6 additions & 5 deletions cpp/src/arrow/util/key_value_metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <cstdint>
#include <memory>
#include <string>
#include <string_view>
#include <unordered_map>
#include <utility>
#include <vector>
Expand All @@ -44,13 +45,13 @@ class ARROW_EXPORT KeyValueMetadata {
void ToUnorderedMap(std::unordered_map<std::string, std::string>* out) const;
void Append(std::string key, std::string value);

Result<std::string> Get(const std::string& key) const;
bool Contains(const std::string& key) const;
Result<std::string> Get(std::string_view key) const;
bool Contains(std::string_view key) const;
// Note that deleting may invalidate known indices
Status Delete(const std::string& key);
Status Delete(std::string_view key);
Status Delete(int64_t index);
Status DeleteMany(std::vector<int64_t> indices);
Status Set(const std::string& key, const std::string& value);
Status Set(std::string key, std::string value);

void reserve(int64_t n);

Expand All @@ -63,7 +64,7 @@ class ARROW_EXPORT KeyValueMetadata {
std::vector<std::pair<std::string, std::string>> sorted_pairs() const;

/// \brief Perform linear search for key, returning -1 if not found
int FindKey(const std::string& key) const;
int FindKey(std::string_view key) const;

std::shared_ptr<KeyValueMetadata> Copy() const;

Expand Down

0 comments on commit 91be098

Please sign in to comment.