diff --git a/src/substrait_extension.cpp b/src/substrait_extension.cpp index 7361623..c6bbb7f 100644 --- a/src/substrait_extension.cpp +++ b/src/substrait_extension.cpp @@ -275,6 +275,7 @@ static unique_ptr FromSubstraitBind(ClientContext &context, TableFunct static unique_ptr FromSubstraitBindJSON(ClientContext &context, TableFunctionBindInput &input) { return SubstraitBind(context, input, true); +} //! Container for TableFnExplainSubstrait to get data from BindFnExplainSubstrait struct FromSubstraitFunctionData : public TableFunctionData { @@ -284,8 +285,8 @@ struct FromSubstraitFunctionData : public TableFunctionData { unique_ptr conn; }; -static unique_ptr BindFnExplainSubstrait(ClientContext &context, TableFunctionBindInput &input - vector &return_types, vector &names) { +static unique_ptr BindFnExplainSubstrait(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { if (input.inputs[0].IsNull()) { throw BinderException("explain_substrait cannot be called with a NULL parameter"); } diff --git a/test/python/test_substrait_explain.py b/test/python/test_substrait_explain.py new file mode 100644 index 0000000..b4ef39d --- /dev/null +++ b/test/python/test_substrait_explain.py @@ -0,0 +1,26 @@ +import pandas as pd +import duckdb + +EXPECTED_RESULT = ''' +┌───────────────┬──────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ explain_key │ explain_value │ +│ varchar │ varchar │ +├───────────────┼──────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│ physical_plan │ ┌───────────────────────────┐\n│ STREAMING_LIMIT │\n└─────────────┬─────────────┘\n┌────… │ +└───────────────┴──────────────────────────────────────────────────────────────────────────────────────────────────────┘ + +''' + +def test_roundtrip_substrait(require): + connection = require('substrait') + connection.execute('CREATE TABLE integers (i integer)') + connection.execute('INSERT INTO integers VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9),(NULL)') + + translate_result = connection.get_substrait('SELECT * FROM integers LIMIT 5') + proto_bytes = translate_result.fetchone()[0] + + expected = pd.Series([EXPECTED_RESULT], name='Explain Plan', dtype='str') + actual = connection.table_function('explain_substrait', proto_bytes).execute() + + pd.testing.assert_series_equal(actual.df()['Explain Plan'], expected) +