diff --git a/src/ArrowFunctions.cpp b/src/ArrowFunctions.cpp index 870f8964af..2b0bfe733a 100644 --- a/src/ArrowFunctions.cpp +++ b/src/ArrowFunctions.cpp @@ -65,6 +65,31 @@ int64_t cpp_getNumRows(const char* filename, char** errMsg) { } } +int cpp_getPrecision(const char* filename, const char* colname, char** errMsg) { + try { + std::shared_ptr infile; + ARROWRESULT_OK(arrow::io::ReadableFile::Open(filename, arrow::default_memory_pool()), + infile); + + std::unique_ptr reader; + ARROWSTATUS_OK(parquet::arrow::OpenFile(infile, arrow::default_memory_pool(), &reader)); + + std::shared_ptr sc; + std::shared_ptr* out = ≻ + ARROWSTATUS_OK(reader->GetSchema(out)); + + int idx = sc -> GetFieldIndex(colname); + + const auto& decimal_type = static_cast(*sc->field(idx)->type()); + const int64_t precision = decimal_type.precision(); + + return precision; + } catch (const std::exception& e) { + *errMsg = strdup(e.what()); + return ARROWERROR; + } +} + int cpp_getType(const char* filename, const char* colname, char** errMsg) { try { std::shared_ptr infile; @@ -112,6 +137,8 @@ int cpp_getType(const char* filename, const char* colname, char** errMsg) { return ARROWDOUBLE; else if(myType->id() == arrow::Type::LIST) return ARROWLIST; + else if(myType->id() == arrow::Type::DECIMAL) + return ARROWDECIMAL; else { std::string fname(filename); std::string dname(colname); @@ -692,7 +719,7 @@ int cpp_readListColumnByName(const char* filename, void* chpl_arr, const char* c } } -int cpp_readColumnByName(const char* filename, void* chpl_arr, const char* colname, int64_t numElems, int64_t startIdx, int64_t batchSize, char** errMsg) { +int cpp_readColumnByName(const char* filename, void* chpl_arr, const char* colname, int64_t numElems, int64_t startIdx, int64_t batchSize, int64_t byteLength, char** errMsg) { try { int64_t ty = cpp_getType(filename, colname, errMsg); @@ -849,6 +876,22 @@ int cpp_readColumnByName(const char* filename, void* chpl_arr, const char* colna delete[] def_lvl; delete[] rep_lvl; } + } else if(ty == ARROWDECIMAL) { + auto chpl_ptr = (double*)chpl_arr; + parquet::FixedLenByteArray value; + parquet::FixedLenByteArrayReader* reader = + static_cast(column_reader.get()); + startIdx -= reader->Skip(startIdx); + + while (reader->HasNext() && i < numElems) { + (void)reader->ReadBatch(1, nullptr, nullptr, &value, &values_read); + arrow::Decimal128 v; + PARQUET_ASSIGN_OR_THROW(v, + ::arrow::Decimal128::FromBigEndian(value.ptr, byteLength)); + + chpl_ptr[i] = v.ToDouble(0); + i+=values_read; + } } } return 0; @@ -1905,7 +1948,8 @@ int cpp_getDatasetNames(const char* filename, char** dsetResult, bool readNested sc->field(i)->type()->id() == arrow::Type::BINARY || sc->field(i)->type()->id() == arrow::Type::FLOAT || sc->field(i)->type()->id() == arrow::Type::DOUBLE || - (sc->field(i)->type()->id() == arrow::Type::LIST && readNested) + (sc->field(i)->type()->id() == arrow::Type::LIST && readNested) || + sc->field(i)->type()->id() == arrow::Type::DECIMAL ) { if(!first) fields += ("," + sc->field(i)->name()); @@ -1954,8 +1998,8 @@ extern "C" { return cpp_readListColumnByName(filename, chpl_arr, colname, numElems, startIdx, batchSize, errMsg); } - int c_readColumnByName(const char* filename, void* chpl_arr, const char* colname, int64_t numElems, int64_t startIdx, int64_t batchSize, char** errMsg) { - return cpp_readColumnByName(filename, chpl_arr, colname, numElems, startIdx, batchSize, errMsg); + int c_readColumnByName(const char* filename, void* chpl_arr, const char* colname, int64_t numElems, int64_t startIdx, int64_t batchSize, int64_t byteLength, char** errMsg) { + return cpp_readColumnByName(filename, chpl_arr, colname, numElems, startIdx, batchSize, byteLength, errMsg); } int c_getType(const char* filename, const char* colname, char** errMsg) { @@ -2049,4 +2093,8 @@ extern "C" { int64_t compression, char** errMsg){ return cpp_writeMultiColToParquet(filename, column_names, ptr_arr, offset_arr, objTypes, datatypes, segArr_sizes, colnum, numelems, rowGroupSize, compression, errMsg); } + + int c_getPrecision(const char* filename, const char* colname, char** errMsg) { + return cpp_getPrecision(filename, colname, errMsg); + } } diff --git a/src/ArrowFunctions.h b/src/ArrowFunctions.h index 9b785c9207..ae01ea812d 100644 --- a/src/ArrowFunctions.h +++ b/src/ArrowFunctions.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include extern "C" { @@ -25,6 +26,7 @@ extern "C" { #define ARROWTIMESTAMP ARROWINT64 #define ARROWSTRING 6 #define ARROWLIST 8 +#define ARROWDECIMAL 9 #define ARROWERROR -1 #define ARRAYVIEW 0 // not currently used, but included for continuity with Chapel @@ -48,10 +50,10 @@ extern "C" { int c_readColumnByName(const char* filename, void* chpl_arr, const char* colname, int64_t numElems, int64_t startIdx, - int64_t batchSize, char** errMsg); + int64_t batchSize, int64_t byteLength, char** errMsg); int cpp_readColumnByName(const char* filename, void* chpl_arr, const char* colname, int64_t numElems, int64_t startIdx, - int64_t batchSize, char** errMsg); + int64_t batchSize, int64_t byteLength, char** errMsg); int c_readListColumnByName(const char* filename, void* chpl_arr, const char* colname, int64_t numElems, @@ -142,6 +144,9 @@ extern "C" { void** ptr_arr, void** offset_arr, void* objTypes, void* datatypes, void* segArr_sizes, int64_t colnum, int64_t numelems, int64_t rowGroupSize, int64_t compression, char** errMsg); + + int c_getPrecision(const char* filename, const char* colname, char** errMsg); + int cpp_getPrecision(const char* filename, const char* colname, char** errMsg); const char* c_getVersionInfo(void); const char* cpp_getVersionInfo(void); diff --git a/src/ParquetMsg.chpl b/src/ParquetMsg.chpl index b9c4f3ad1a..aad5bb5a45 100644 --- a/src/ParquetMsg.chpl +++ b/src/ParquetMsg.chpl @@ -58,10 +58,12 @@ module ParquetMsg { extern var ARROWLIST: c_int; extern var ARROWDOUBLE: c_int; extern var ARROWERROR: c_int; + extern var ARROWDECIMAL: c_int; enum ArrowTypes { int64, int32, uint64, uint32, stringArr, timestamp, boolean, - double, float, list, notimplemented }; + double, float, list, decimal, + notimplemented }; record parquetErrorMsg { var errMsg: c_ptr(uint(8)); @@ -119,8 +121,8 @@ module ParquetMsg { return (subdoms, (+ reduce lengths)); } - proc readFilesByName(ref A: [] ?t, filenames: [] string, sizes: [] int, dsetname: string, ty) throws { - extern proc c_readColumnByName(filename, chpl_arr, colNum, numElems, startIdx, batchSize, errMsg): int; + proc readFilesByName(ref A: [] ?t, filenames: [] string, sizes: [] int, dsetname: string, ty, byteLength=-1) throws { + extern proc c_readColumnByName(filename, chpl_arr, colNum, numElems, startIdx, batchSize, byteLength, errMsg): int; var (subdoms, length) = getSubdomains(sizes); var fileOffsets = (+ scan sizes) - sizes; @@ -137,7 +139,7 @@ module ParquetMsg { var pqErr = new parquetErrorMsg(); if c_readColumnByName(filename.localize().c_str(), c_ptrTo(A[intersection.low]), dsetname.localize().c_str(), intersection.size, intersection.low - off, - batchSize, + batchSize, byteLength, c_ptrTo(pqErr.errMsg)) == ARROWERROR { pqErr.parquetError(getLineNumber(), getRoutineName(), getModuleName()); } @@ -149,7 +151,7 @@ module ParquetMsg { } proc readStrFilesByName(A: [] ?t, filenames: [] string, sizes: [] int, dsetname: string, ty) throws { - extern proc c_readColumnByName(filename, chpl_arr, colNum, numElems, startIdx, batchSize, errMsg): int; + extern proc c_readColumnByName(filename, chpl_arr, colNum, numElems, startIdx, batchSize, byteLength, errMsg): int; var (subdoms, length) = getSubdomains(sizes); coforall loc in A.targetLocales() do on loc { @@ -166,7 +168,7 @@ module ParquetMsg { if c_readColumnByName(filename.localize().c_str(), c_ptrTo(col), dsetname.localize().c_str(), intersection.size, 0, - batchSize, c_ptrTo(pqErr.errMsg)) == ARROWERROR { + batchSize, -1, c_ptrTo(pqErr.errMsg)) == ARROWERROR { pqErr.parquetError(getLineNumber(), getRoutineName(), getModuleName()); } A[filedom] = col; @@ -356,6 +358,7 @@ module ParquetMsg { else if arrType == ARROWDOUBLE then return ArrowTypes.double; else if arrType == ARROWFLOAT then return ArrowTypes.float; else if arrType == ARROWLIST then return ArrowTypes.list; + else if arrType == ARROWDECIMAL then return ArrowTypes.decimal; throw getErrorWithContext( msg="Unrecognized Parquet data type", getLineNumber(), @@ -870,6 +873,13 @@ module ParquetMsg { var create_str: string = parseListDataset(filenames, dsetname, list_ty, len, sizes, st); rnames.pushBack((dsetname, ObjType.SEGARRAY, create_str)); } + } else if ty == ArrowTypes.decimal { + var byteLength = getByteLength(filenames[0], dsetname); + var entryVal = createSymEntry(len, real); + readFilesByName(entryVal.a, filenames, sizes, dsetname, ty, byteLength); + var valName = st.nextName(); + st.addEntry(valName, entryVal); + rnames.pushBack((dsetname, ObjType.PDARRAY, valName)); } else { var errorMsg = "DType %s not supported for Parquet reading".doFormat(ty); pqLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); @@ -900,6 +910,38 @@ module ParquetMsg { return new list(datasets.split(",")); } + // Decimal columns in Parquet have a fixed number of bytes based on the precision, + // but there isn't a way in Parquet to get the precision. Since the byte length + // will always remain the same for each precision value, here we just created a + // lookup table that maps from the precision to the byte value. + proc getByteLength(filename, colname) throws { + extern proc c_getPrecision(filename, colname, errMsg): int(32); + var pqErr = new parquetErrorMsg(); + var res: c_ptr(uint(8)); + defer { + extern proc c_free_string(ptr); + c_free_string(res); + } + + var precision = c_getPrecision(filename.c_str(), colname.c_str(), c_ptrTo(pqErr.errMsg)); + if precision < 3 then return 1; + else if precision < 5 then return 2; + else if precision < 7 then return 3; + else if precision < 10 then return 4; + else if precision < 12 then return 5; + else if precision < 15 then return 6; + else if precision < 17 then return 7; + else if precision < 19 then return 8; + else if precision < 22 then return 9; + else if precision < 24 then return 10; + else if precision < 27 then return 11; + else if precision < 29 then return 12; + else if precision < 32 then return 13; + else if precision < 34 then return 14; + else if precision < 36 then return 15; + return 16; + } + proc pdarray_toParquetMsg(msgArgs: MessageArgs, st: borrowed SymTab): bool throws { var mode = msgArgs.get("mode").getIntValue(); var filename: string = msgArgs.getValueOf("prefix"); diff --git a/tests/parquet_test.py b/tests/parquet_test.py index 687c3bd21d..845f44bb3a 100644 --- a/tests/parquet_test.py +++ b/tests/parquet_test.py @@ -558,6 +558,22 @@ def test_float_edge(self): for i in range(len(pd_l)): self.assertTrue(np.allclose(pd_l[i], ak_l[i], equal_nan=True)) + def test_decimal_reads(self): + cols = [] + data = [] + for i in range(1,39): + cols.append(("decCol" + str(i), pa.decimal128(i, 0))) + data.append([i]) + + schema = pa.schema(cols) + + table = pa.Table.from_arrays(data, schema=schema) + with tempfile.TemporaryDirectory(dir=ParquetTest.par_test_base_tmp) as tmp_dirname: + pq.write_table(table, f"{tmp_dirname}/decimal") + ak_data = ak.read(f"{tmp_dirname}/decimal") + for i in range(1,39): + self.assertTrue(np.allclose(ak_data['decCol'+str(i)].to_ndarray(), data[i-1])) + @pytest.mark.optional_parquet def test_against_standard_files(self): datadir = "resources/parquet-testing"