Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for reading Decimal128 Parquet columns #2821

Merged
merged 6 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 52 additions & 4 deletions src/ArrowFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,31 @@ int64_t cpp_getNumRows(const char* filename, char** errMsg) {
}
}

// Looks up the decimal precision declared for column `colname` in the Parquet
// file `filename`.
//
// Parameters:
//   filename - path of the Parquet file to open
//   colname  - name of the column whose precision is requested
//   errMsg   - out-parameter; on failure it receives a strdup'ed message
//
// Returns the precision of the column's decimal type, or ARROWERROR when the
// column cannot be resolved, is not a decimal column, or an exception is
// thrown while reading the schema.
int cpp_getPrecision(const char* filename, const char* colname, char** errMsg) {
  try {
    // NOTE(review): ARROWRESULT_OK / ARROWSTATUS_OK are defined elsewhere in
    // this file; presumably they report failures through errMsg -- confirm.
    std::shared_ptr<arrow::io::ReadableFile> infile;
    ARROWRESULT_OK(arrow::io::ReadableFile::Open(filename, arrow::default_memory_pool()),
                   infile);

    std::unique_ptr<parquet::arrow::FileReader> reader;
    ARROWSTATUS_OK(parquet::arrow::OpenFile(infile, arrow::default_memory_pool(), &reader));

    std::shared_ptr<arrow::Schema> sc;
    std::shared_ptr<arrow::Schema>* out = &sc;
    ARROWSTATUS_OK(reader->GetSchema(out));

    // GetFieldIndex yields a negative index when the name cannot be resolved;
    // indexing the schema with it would be out of bounds.
    const int idx = sc->GetFieldIndex(colname);
    if (idx < 0) {
      const std::string msg = "Column '" + std::string(colname) +
                              "' not found in Parquet file " + std::string(filename);
      *errMsg = strdup(msg.c_str());
      return ARROWERROR;
    }

    // The static_cast below is only valid for decimal types, so verify the
    // type id first instead of reinterpreting an arbitrary DataType.
    const auto& col_type = sc->field(idx)->type();
    if (col_type->id() != arrow::Type::DECIMAL) {
      const std::string msg = "Column '" + std::string(colname) +
                              "' in Parquet file " + std::string(filename) +
                              " is not a decimal column";
      *errMsg = strdup(msg.c_str());
      return ARROWERROR;
    }

    const auto& decimal_type = static_cast<const ::arrow::DecimalType&>(*col_type);
    return decimal_type.precision();
  } catch (const std::exception& e) {
    *errMsg = strdup(e.what());
    return ARROWERROR;
  }
}

int cpp_getType(const char* filename, const char* colname, char** errMsg) {
try {
std::shared_ptr<arrow::io::ReadableFile> infile;
Expand Down Expand Up @@ -112,6 +137,8 @@ int cpp_getType(const char* filename, const char* colname, char** errMsg) {
return ARROWDOUBLE;
else if(myType->id() == arrow::Type::LIST)
return ARROWLIST;
else if(myType->id() == arrow::Type::DECIMAL)
return ARROWDECIMAL;
else {
std::string fname(filename);
std::string dname(colname);
Expand Down Expand Up @@ -692,7 +719,7 @@ int cpp_readListColumnByName(const char* filename, void* chpl_arr, const char* c
}
}

int cpp_readColumnByName(const char* filename, void* chpl_arr, const char* colname, int64_t numElems, int64_t startIdx, int64_t batchSize, char** errMsg) {
int cpp_readColumnByName(const char* filename, void* chpl_arr, const char* colname, int64_t numElems, int64_t startIdx, int64_t batchSize, int64_t byteLength, char** errMsg) {
try {
int64_t ty = cpp_getType(filename, colname, errMsg);

Expand Down Expand Up @@ -849,6 +876,22 @@ int cpp_readColumnByName(const char* filename, void* chpl_arr, const char* colna
delete[] def_lvl;
delete[] rep_lvl;
}
} else if(ty == ARROWDECIMAL) {
auto chpl_ptr = (double*)chpl_arr;
parquet::FixedLenByteArray value;
parquet::FixedLenByteArrayReader* reader =
static_cast<parquet::FixedLenByteArrayReader*>(column_reader.get());
startIdx -= reader->Skip(startIdx);

while (reader->HasNext() && i < numElems) {
(void)reader->ReadBatch(1, nullptr, nullptr, &value, &values_read);
arrow::Decimal128 v;
PARQUET_ASSIGN_OR_THROW(v,
::arrow::Decimal128::FromBigEndian(value.ptr, byteLength));
stress-tess marked this conversation as resolved.
Show resolved Hide resolved

chpl_ptr[i] = v.ToDouble(0);
i+=values_read;
}
}
}
return 0;
Expand Down Expand Up @@ -1905,7 +1948,8 @@ int cpp_getDatasetNames(const char* filename, char** dsetResult, bool readNested
sc->field(i)->type()->id() == arrow::Type::BINARY ||
sc->field(i)->type()->id() == arrow::Type::FLOAT ||
sc->field(i)->type()->id() == arrow::Type::DOUBLE ||
(sc->field(i)->type()->id() == arrow::Type::LIST && readNested)
(sc->field(i)->type()->id() == arrow::Type::LIST && readNested) ||
sc->field(i)->type()->id() == arrow::Type::DECIMAL
) {
if(!first)
fields += ("," + sc->field(i)->name());
Expand Down Expand Up @@ -1954,8 +1998,8 @@ extern "C" {
return cpp_readListColumnByName(filename, chpl_arr, colname, numElems, startIdx, batchSize, errMsg);
}

int c_readColumnByName(const char* filename, void* chpl_arr, const char* colname, int64_t numElems, int64_t startIdx, int64_t batchSize, char** errMsg) {
return cpp_readColumnByName(filename, chpl_arr, colname, numElems, startIdx, batchSize, errMsg);
// C-linkage entry point that forwards every argument, unchanged and in order,
// to cpp_readColumnByName and hands back its return value.
int c_readColumnByName(const char* filename, void* chpl_arr, const char* colname, int64_t numElems, int64_t startIdx, int64_t batchSize, int64_t byteLength, char** errMsg) {
  const int status = cpp_readColumnByName(filename, chpl_arr, colname,
                                          numElems, startIdx, batchSize,
                                          byteLength, errMsg);
  return status;
}

int c_getType(const char* filename, const char* colname, char** errMsg) {
Expand Down Expand Up @@ -2049,4 +2093,8 @@ extern "C" {
int64_t compression, char** errMsg){
return cpp_writeMultiColToParquet(filename, column_names, ptr_arr, offset_arr, objTypes, datatypes, segArr_sizes, colnum, numelems, rowGroupSize, compression, errMsg);
}

// C-linkage entry point that forwards to cpp_getPrecision and returns its
// result unchanged.
int c_getPrecision(const char* filename, const char* colname, char** errMsg) {
  const int precision = cpp_getPrecision(filename, colname, errMsg);
  return precision;
}
}
9 changes: 7 additions & 2 deletions src/ArrowFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <parquet/arrow/writer.h>
#include <parquet/column_reader.h>
#include <parquet/api/writer.h>
#include <parquet/schema.h>
#include <cmath>
#include <queue>
extern "C" {
Expand All @@ -25,6 +26,7 @@ extern "C" {
#define ARROWTIMESTAMP ARROWINT64
#define ARROWSTRING 6
#define ARROWLIST 8
#define ARROWDECIMAL 9
#define ARROWERROR -1

#define ARRAYVIEW 0 // not currently used, but included for continuity with Chapel
Expand All @@ -48,10 +50,10 @@ extern "C" {

int c_readColumnByName(const char* filename, void* chpl_arr,
const char* colname, int64_t numElems, int64_t startIdx,
int64_t batchSize, char** errMsg);
int64_t batchSize, int64_t byteLength, char** errMsg);
int cpp_readColumnByName(const char* filename, void* chpl_arr,
const char* colname, int64_t numElems, int64_t startIdx,
int64_t batchSize, char** errMsg);
int64_t batchSize, int64_t byteLength, char** errMsg);

int c_readListColumnByName(const char* filename, void* chpl_arr,
const char* colname, int64_t numElems,
Expand Down Expand Up @@ -142,6 +144,9 @@ extern "C" {
void** ptr_arr, void** offset_arr, void* objTypes, void* datatypes,
void* segArr_sizes, int64_t colnum, int64_t numelems, int64_t rowGroupSize,
int64_t compression, char** errMsg);

// Precision lookup for the column `colname` of the Parquet file `filename`.
// c_getPrecision is the C-linkage wrapper around cpp_getPrecision; errMsg is
// an out-parameter for an error message.
int c_getPrecision(const char* filename, const char* colname, char** errMsg);
int cpp_getPrecision(const char* filename, const char* colname, char** errMsg);

const char* c_getVersionInfo(void);
const char* cpp_getVersionInfo(void);
Expand Down
54 changes: 48 additions & 6 deletions src/ParquetMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,12 @@ module ParquetMsg {
extern var ARROWLIST: c_int;
extern var ARROWDOUBLE: c_int;
extern var ARROWERROR: c_int;
extern var ARROWDECIMAL: c_int;

enum ArrowTypes { int64, int32, uint64, uint32,
stringArr, timestamp, boolean,
double, float, list, notimplemented };
double, float, list, decimal,
notimplemented };

record parquetErrorMsg {
var errMsg: c_ptr(uint(8));
Expand Down Expand Up @@ -119,8 +121,8 @@ module ParquetMsg {
return (subdoms, (+ reduce lengths));
}

proc readFilesByName(ref A: [] ?t, filenames: [] string, sizes: [] int, dsetname: string, ty) throws {
extern proc c_readColumnByName(filename, chpl_arr, colNum, numElems, startIdx, batchSize, errMsg): int;
proc readFilesByName(ref A: [] ?t, filenames: [] string, sizes: [] int, dsetname: string, ty, byteLength=-1) throws {
extern proc c_readColumnByName(filename, chpl_arr, colNum, numElems, startIdx, batchSize, byteLength, errMsg): int;
var (subdoms, length) = getSubdomains(sizes);
var fileOffsets = (+ scan sizes) - sizes;

Expand All @@ -137,7 +139,7 @@ module ParquetMsg {
var pqErr = new parquetErrorMsg();
if c_readColumnByName(filename.localize().c_str(), c_ptrTo(A[intersection.low]),
dsetname.localize().c_str(), intersection.size, intersection.low - off,
batchSize,
batchSize, byteLength,
c_ptrTo(pqErr.errMsg)) == ARROWERROR {
pqErr.parquetError(getLineNumber(), getRoutineName(), getModuleName());
}
Expand All @@ -149,7 +151,7 @@ module ParquetMsg {
}

proc readStrFilesByName(A: [] ?t, filenames: [] string, sizes: [] int, dsetname: string, ty) throws {
extern proc c_readColumnByName(filename, chpl_arr, colNum, numElems, startIdx, batchSize, errMsg): int;
extern proc c_readColumnByName(filename, chpl_arr, colNum, numElems, startIdx, batchSize, byteLength, errMsg): int;
var (subdoms, length) = getSubdomains(sizes);

coforall loc in A.targetLocales() do on loc {
Expand All @@ -166,7 +168,7 @@ module ParquetMsg {

if c_readColumnByName(filename.localize().c_str(), c_ptrTo(col),
dsetname.localize().c_str(), intersection.size, 0,
batchSize, c_ptrTo(pqErr.errMsg)) == ARROWERROR {
batchSize, -1, c_ptrTo(pqErr.errMsg)) == ARROWERROR {
pqErr.parquetError(getLineNumber(), getRoutineName(), getModuleName());
}
A[filedom] = col;
Expand Down Expand Up @@ -356,6 +358,7 @@ module ParquetMsg {
else if arrType == ARROWDOUBLE then return ArrowTypes.double;
else if arrType == ARROWFLOAT then return ArrowTypes.float;
else if arrType == ARROWLIST then return ArrowTypes.list;
else if arrType == ARROWDECIMAL then return ArrowTypes.decimal;
throw getErrorWithContext(
msg="Unrecognized Parquet data type",
getLineNumber(),
Expand Down Expand Up @@ -870,6 +873,13 @@ module ParquetMsg {
var create_str: string = parseListDataset(filenames, dsetname, list_ty, len, sizes, st);
rnames.pushBack((dsetname, ObjType.SEGARRAY, create_str));
}
} else if ty == ArrowTypes.decimal {
var byteLength = getByteLength(filenames[0], dsetname);
var entryVal = createSymEntry(len, real);
readFilesByName(entryVal.a, filenames, sizes, dsetname, ty, byteLength);
var valName = st.nextName();
st.addEntry(valName, entryVal);
rnames.pushBack((dsetname, ObjType.PDARRAY, valName));
} else {
var errorMsg = "DType %s not supported for Parquet reading".doFormat(ty);
pqLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
Expand Down Expand Up @@ -900,6 +910,38 @@ module ParquetMsg {
return new list(datasets.split(","));
}

// Decimal columns in Parquet are stored with a fixed number of bytes that is
// determined by the column's precision. The schema query used here reports the
// precision rather than the byte length, so the precision is mapped to the
// byte length with the lookup table below.
//
// Throws if c_getPrecision reports ARROWERROR (for example when the column
// cannot be read), rather than silently treating the error code as a precision.
proc getByteLength(filename, colname) throws {
  extern proc c_getPrecision(filename, colname, errMsg): int(32);
  var pqErr = new parquetErrorMsg();

  // Localize the strings before taking C pointers, as the other extern
  // Parquet calls in this module do.
  var precision = c_getPrecision(filename.localize().c_str(),
                                 colname.localize().c_str(),
                                 c_ptrTo(pqErr.errMsg));
  if precision == ARROWERROR then
    pqErr.parquetError(getLineNumber(), getRoutineName(), getModuleName());

  if precision < 3 then return 1;
  else if precision < 5 then return 2;
  else if precision < 7 then return 3;
  else if precision < 10 then return 4;
  else if precision < 12 then return 5;
  else if precision < 15 then return 6;
  else if precision < 17 then return 7;
  else if precision < 19 then return 8;
  else if precision < 22 then return 9;
  else if precision < 24 then return 10;
  else if precision < 27 then return 11;
  else if precision < 29 then return 12;
  else if precision < 32 then return 13;
  else if precision < 34 then return 14;
  else if precision < 36 then return 15;
  return 16;
}

proc pdarray_toParquetMsg(msgArgs: MessageArgs, st: borrowed SymTab): bool throws {
var mode = msgArgs.get("mode").getIntValue();
var filename: string = msgArgs.getValueOf("prefix");
Expand Down
16 changes: 16 additions & 0 deletions tests/parquet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,22 @@ def test_float_edge(self):
for i in range(len(pd_l)):
self.assertTrue(np.allclose(pd_l[i], ak_l[i], equal_nan=True))

def test_decimal_reads(self):
    """Write one single-value decimal128 column per precision 1..38 and check
    that each column read back through ak.read is close to the value written."""
    precisions = range(1, 39)
    # Column i is named "decCol<i>", has type decimal128(i, 0) and holds [i].
    fields = [("decCol" + str(p), pa.decimal128(p, 0)) for p in precisions]
    values = [[p] for p in precisions]

    table = pa.Table.from_arrays(values, schema=pa.schema(fields))
    with tempfile.TemporaryDirectory(dir=ParquetTest.par_test_base_tmp) as tmp_dirname:
        path = f"{tmp_dirname}/decimal"
        pq.write_table(table, path)
        ak_data = ak.read(path)
        for p, expected in zip(precisions, values):
            actual = ak_data["decCol" + str(p)].to_ndarray()
            self.assertTrue(np.allclose(actual, expected))

@pytest.mark.optional_parquet
def test_against_standard_files(self):
datadir = "resources/parquet-testing"
Expand Down