Skip to content

Commit

Permalink
[Arrow] Make unknown Arrow extensions throw at scan instead of bind (#…
Browse files Browse the repository at this point in the history
…14015)

This allows people to scan Arrow objects that contain unsupported extension
types, as long as the columns with those types are projected out of the query.

For reference: #13931
  • Loading branch information
Mytherin authored Sep 19, 2024
2 parents cb27c04 + 7dd9dd1 commit d1037da
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 23 deletions.
56 changes: 34 additions & 22 deletions src/function/table/arrow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ static unique_ptr<ArrowType> GetArrowExtensionType(const ArrowSchemaMetadata &ex
// Check for arrow canonical extensions
if (arrow_extension == "arrow.uuid") {
if (format != "w:16") {
throw InvalidInputException(
"arrow.uuid must be a fixed-size binary of 16 bytes (i.e., \'w:16\'). It is incorrectly defined as: %s",
format);
std::ostringstream error;
error
<< "arrow.uuid must be a fixed-size binary of 16 bytes (i.e., \'w:16\'). It is incorrectly defined as:"
<< format;
return make_uniq<ArrowType>(error.str());
}
return make_uniq<ArrowType>(LogicalType::UUID);
} else if (arrow_extension == "arrow.json") {
Expand All @@ -49,40 +51,47 @@ static unique_ptr<ArrowType> GetArrowExtensionType(const ArrowSchemaMetadata &ex
} else if (format == "vu") {
return make_uniq<ArrowType>(LogicalType::JSON(), make_uniq<ArrowStringInfo>(ArrowVariableSizeType::VIEW));
} else {
throw InvalidInputException("arrow.json must be of a varchar format (i.e., \'u\',\'U\' or \'vu\'). It is "
"incorrectly defined as: %s",
format);
std::ostringstream error;
error
<< "arrow.json must be of a varchar format (i.e., \'u\',\'U\' or \'vu\'). It is incorrectly defined as:"
<< format;
return make_uniq<ArrowType>(error.str());
}
}
// Check for DuckDB canonical extensions
else if (arrow_extension == "duckdb.hugeint") {
if (format != "w:16") {
throw InvalidInputException("duckdb.hugeint must be a fixed-size binary of 16 bytes (i.e., \'w:16\'). It "
"is incorrectly defined as: %s",
format);
std::ostringstream error;
error << "duckdb.hugeint must be a fixed-size binary of 16 bytes (i.e., \'w:16\'). It is incorrectly "
"defined as:"
<< format;
return make_uniq<ArrowType>(error.str());
}
return make_uniq<ArrowType>(LogicalType::HUGEINT);

} else if (arrow_extension == "duckdb.uhugeint") {
if (format != "w:16") {
throw InvalidInputException("duckdb.hugeint must be a fixed-size binary of 16 bytes (i.e., \'w:16\'). It "
"is incorrectly defined as: %s",
format);
std::ostringstream error;
error << "duckdb.uhugeint must be a fixed-size binary of 16 bytes (i.e., \'w:16\'). It is incorrectly "
"defined as:"
<< format;
return make_uniq<ArrowType>(error.str());
}
return make_uniq<ArrowType>(LogicalType::UHUGEINT);
} else if (arrow_extension == "duckdb.time_tz") {
if (format != "w:8") {
throw InvalidInputException("duckdb.time_tz must be a fixed-size binary of 8 bytes (i.e., \'w:8\'). It "
"is incorrectly defined as: %s",
format);
std::ostringstream error;
error << "duckdb.time_tz must be a fixed-size binary of 8 bytes (i.e., \'w:8\'). It is incorrectly defined "
"as:"
<< format;
return make_uniq<ArrowType>(error.str());
}
return make_uniq<ArrowType>(LogicalType::TIME_TZ,
make_uniq<ArrowDateTimeInfo>(ArrowDateTimeType::MICROSECONDS));
} else if (arrow_extension == "duckdb.bit") {
if (format != "z" && format != "Z") {
throw InvalidInputException("duckdb.bit must be a blob (i.e., \'z\' or \'Z\'). It "
"is incorrectly defined as: %s",
format);
std::ostringstream error;
error << "duckdb.bit must be a blob (i.e., \'z\' or \'Z\'). It is incorrectly defined as:" << format;
return make_uniq<ArrowType>(error.str());
} else if (format == "z") {
auto type_info = make_uniq<ArrowStringInfo>(ArrowVariableSizeType::NORMAL);
return make_uniq<ArrowType>(LogicalType::BIT, std::move(type_info));
Expand All @@ -91,9 +100,10 @@ static unique_ptr<ArrowType> GetArrowExtensionType(const ArrowSchemaMetadata &ex
return make_uniq<ArrowType>(LogicalType::BIT, std::move(type_info));

} else {
throw NotImplementedException(
"Arrow Type with extension name: %s and format: %s, is not currently supported in DuckDB ", arrow_extension,
format);
std::ostringstream error;
error << "Arrow Type with extension name: " << arrow_extension << " and format: " << format
<< ", is not currently supported in DuckDB.";
return make_uniq<ArrowType>(error.str(), true);
}
}
static unique_ptr<ArrowType> GetArrowLogicalTypeNoDictionary(ArrowSchema &schema) {
Expand Down Expand Up @@ -384,10 +394,12 @@ unique_ptr<ArrowArrayStreamWrapper> ProduceArrowScan(const ArrowScanFunctionData
//! Generate Projection Pushdown Vector
ArrowStreamParameters parameters;
D_ASSERT(!column_ids.empty());
auto &arrow_types = function.arrow_table.GetColumns();
for (idx_t idx = 0; idx < column_ids.size(); idx++) {
auto col_idx = column_ids[idx];
if (col_idx != COLUMN_IDENTIFIER_ROW_ID) {
auto &schema = *function.schema_root.arrow_schema.children[col_idx];
arrow_types.at(col_idx)->ThrowIfInvalid();
parameters.projected_columns.projection_map[idx] = schema.name;
parameters.projected_columns.columns.emplace_back(schema.name);
parameters.projected_columns.filter_to_col[idx] = col_idx;
Expand Down
9 changes: 9 additions & 0 deletions src/function/table/arrow/arrow_duck_schema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ bool ArrowType::RunEndEncoded() const {
return run_end_encoded;
}

//! Throws the stored error message if this ArrowType is invalid (its type id is LogicalTypeId::INVALID).
//! Throws NotImplementedException when not_implemented is set, InvalidInputException otherwise.
//! Does nothing for any other type id.
void ArrowType::ThrowIfInvalid() const {
	const bool is_valid = type.id() != LogicalTypeId::INVALID;
	if (is_valid) {
		return;
	}
	if (!not_implemented) {
		throw InvalidInputException(error_message);
	}
	throw NotImplementedException(error_message);
}

LogicalType ArrowType::GetDuckType(bool use_dictionary) const {
if (use_dictionary && dictionary_type) {
return dictionary_type->GetDuckType();
Expand Down
11 changes: 11 additions & 0 deletions src/include/duckdb/function/table/arrow/arrow_duck_schema.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#pragma once

#include <utility>

#include "duckdb/common/types.hpp"
#include "duckdb/common/unordered_map.hpp"
#include "duckdb/common/vector.hpp"
Expand All @@ -22,6 +24,10 @@ class ArrowType {
explicit ArrowType(LogicalType type_p, unique_ptr<ArrowTypeInfo> type_info = nullptr)
: type(std::move(type_p)), type_info(std::move(type_info)) {
}
//! Constructs an invalid ArrowType: the type is set to LogicalTypeId::INVALID and no type info is attached.
//! error_message_p is stored as the error message; not_implemented_p is stored as the not_implemented flag.
explicit ArrowType(string error_message_p, bool not_implemented_p = false)
    : type(LogicalTypeId::INVALID),
      type_info(),
      error_message(std::move(error_message_p)),
      not_implemented(not_implemented_p) {
}

public:
LogicalType GetDuckType(bool use_dictionary = false) const;
Expand All @@ -37,6 +43,7 @@ class ArrowType {
const T &GetTypeInfo() const {
return type_info->Cast<T>();
}
void ThrowIfInvalid() const;

private:
LogicalType type;
Expand All @@ -45,6 +52,10 @@ class ArrowType {
//! Is run-end-encoded
bool run_end_encoded = false;
unique_ptr<ArrowTypeInfo> type_info;
//! Error message in case of an invalid type (i.e., from an unsupported extension)
string error_message;
//! If true, an invalid type throws NotImplementedException; otherwise it throws InvalidInputException
bool not_implemented = false;
};

using arrow_column_map_t = unordered_map<idx_t, unique_ptr<ArrowType>>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,10 +256,15 @@ def __arrow_ext_deserialize__(cls, storage_type, serialized):
my_type = MyType()
storage_array = my_type.wrap_array(storage_array)

arrow_table = pa.Table.from_arrays([storage_array], names=['pedro_pedro_pedro'])
age_array = pa.array([29], pa.int32())

arrow_table = pa.Table.from_arrays([storage_array, age_array], names=['pedro_pedro_pedro', 'age'])

with pytest.raises(duckdb.NotImplementedException, match=" Arrow Type with extension name: pedro.binary"):
duck_arrow = duckdb_cursor.execute('FROM arrow_table').arrow()
duck_res = duckdb_cursor.execute('SELECT age FROM arrow_table').fetchall()
# This works because the unknown extension array is projected out
assert duck_res == [(29,)]

def test_hugeint(self, arrow_duckdb_hugeint):
con = duckdb.connect()
Expand Down

0 comments on commit d1037da

Please sign in to comment.