Skip to content

Commit

Permalink
Closes #4076: bug in reading and writing to/from parquet when the loc…
Browse files Browse the repository at this point in the history
…ales change
  • Loading branch information
ajpotts committed Feb 7, 2025
1 parent b41b1b3 commit 940472f
Show file tree
Hide file tree
Showing 11 changed files with 493 additions and 165 deletions.
25 changes: 24 additions & 1 deletion arkouda/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1165,6 +1165,26 @@ def _bulk_write_prep(
return datasetNames, data, col_objtypes


def _delete_arkouda_files(prefix_path: str):
"""
Delete files of the pattern prefix_path + LOCALE + <local number>
Parameters
----------
prefix_path : str
Directory and filename prefix for files to be deleted
"""
cast(
str,
generic_msg(
cmd="deleteMatchingFilenames",
args={
"prefix": prefix_path.replace("*", "").replace("+", ""),
},
),
)


def to_parquet(
columns: Union[
Mapping[str, Union[pdarray, Strings, SegArray]],
Expand Down Expand Up @@ -1222,7 +1242,8 @@ def to_parquet(
Creates one file per locale containing that locale's chunk of each pdarray.
If columns is a dictionary, the keys are used as the Parquet column names.
Otherwise, if no names are supplied, 0-up integers are used. By default,
any existing files at path_prefix will be overwritten, unless the user
any existing files at path_prefix will be deleted
(regardless of whether they would be overwritten), unless the user
specifies the 'append' mode, in which case arkouda will attempt to add
<columns> as new datasets to existing files. If the wrong number of files
is present or dataset names already exist, a RuntimeError is raised.
Expand All @@ -1246,6 +1267,8 @@ def to_parquet(
"Please write all columns to the file at once.",
DeprecationWarning,
)
if mode.lower() == "truncate":
_delete_arkouda_files(prefix_path)

datasetNames, data, col_objtypes = _bulk_write_prep(columns, names, convert_categoricals)
# append or single column use the old logic
Expand Down
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,5 @@ markers =
skip_if_rank_not_compiled
skip_if_nl_greater_than
skip_if_nl_less_than
skip_if_nl_eq
skip_if_nl_neq
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
34 changes: 34 additions & 0 deletions src/FileIO.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,37 @@ module FileIO {
return glob("%s_LOCALE*%s".format(prefix, extension));
}

/*
* Delete files matching a prefix and following the pattern <prefix>_LOCALE*.
*/
proc deleteMatchingFilenamesMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
var prefix = msgArgs["prefix"].toScalar(string);
var extension: string;
(prefix, extension) = getFileMetadata(prefix);
deleteMatchingFilenames(prefix, extension);

var repMsg = "Files deleted successfully!";
return new MsgTuple(repMsg, MsgType.NORMAL);
}

proc deleteMatchingFilenames(prefix : string, extension : string) throws {
const filenames = getMatchingFilenames(prefix, extension);
forall filename in filenames{
deleteFile(filename);
}
}

proc deleteFile(filename: string) throws {
try {
remove(filename);
fioLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),
"File %s has been deleted successfully.".format(filename));
} catch e {
fioLogger.error(getModuleName(),getRoutineName(),getLineNumber(),
"Error deleting file: %s".format(e.message()));
}
}

/*
* Returns a tuple composed of a file prefix and extension to be used to
* generate locale-specific filenames to be written to.
Expand Down Expand Up @@ -378,4 +409,7 @@ module FileIO {
}
return new MsgTuple(formatJson(filenames), MsgType.NORMAL);
}

use CommandMap;
registerFunction("deleteMatchingFilenames", deleteMatchingFilenamesMsg, getModuleName());
}
18 changes: 16 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@

def pytest_addoption(parser):
parser.addoption(
"--optional-parquet", action="store_true", default=False, help="run optional parquet tests"
"--optional-parquet",
action="store_true",
default=False,
help="run optional parquet tests",
)
parser.addoption(
"--nl",
Expand All @@ -37,7 +40,10 @@ def pytest_addoption(parser):
"be multiplied by the number of locales.",
)
parser.addoption(
"--seed", action="store", default="", help="Value to initialize random number generator."
"--seed",
action="store",
default="",
help="Value to initialize random number generator.",
)


Expand Down Expand Up @@ -160,3 +166,11 @@ def skip_by_num_locales(request):
if request.node.get_closest_marker("skip_if_nl_greater_than"):
if request.node.get_closest_marker("skip_if_nl_greater_than").args[0] < pytest.nl:
pytest.skip("this test requires server with nl =< {}".format(pytest.nl))

if request.node.get_closest_marker("skip_if_nl_eq"):
if request.node.get_closest_marker("skip_if_nl_eq").args[0] == pytest.nl:
pytest.skip("this test requires server with nl == {}".format(pytest.nl))

if request.node.get_closest_marker("skip_if_nl_neq"):
if request.node.get_closest_marker("skip_if_nl_neq").args[0] != pytest.nl:
pytest.skip("this test requires server with nl != {}".format(pytest.nl))
Loading

0 comments on commit 940472f

Please sign in to comment.