diff --git a/iwutil/__init__.py b/iwutil/__init__.py
index a82484d..4db77e0 100644
--- a/iwutil/__init__.py
+++ b/iwutil/__init__.py
@@ -6,6 +6,8 @@ from pathlib import Path
 import shutil
 import sys
 
+import json
+
 
 def subplots_autolayout(
     n, *args, n_rows=None, figsize=None, layout="constrained", **kwargs
@@ -109,39 +111,71 @@ def check_and_combine_options(default_options, custom_options=None):
 
 
 @singledispatch
-def read_df(file):
+def read_df(file, **kwargs):
+    """
+    Read a dataframe from a file. Currently supports csv, xls, xlsx, json, and parquet.
+
+    Parameters
+    ----------
+    file : str, Path, or pd.DataFrame
+        File to read; a dataframe is returned unchanged
+    **kwargs : dict
+        Additional keyword arguments to pass to the read function (ignored for a dataframe)
+
+    Returns
+    -------
+    pd.DataFrame
+        The dataframe read from the file
+    """
     raise NotImplementedError(f"Reading type {type(file)} not implemented")
 
 
 @read_df.register
-def _(file: str):
-    return iwutil_file_path_helper(file)
+def _(file: str, **kwargs):
+    return iwutil_file_path_helper(file, **kwargs)
 
 
 @read_df.register
-def _(file: Path):
-    return iwutil_file_path_helper(file)
+def _(file: Path, **kwargs):
+    return iwutil_file_path_helper(file, **kwargs)
 
 
 @read_df.register
-def _(file: pd.DataFrame):
+def _(file: pd.DataFrame, **kwargs):
     return file
 
 
-def iwutil_file_path_helper(file_name: str | Path):
+def iwutil_file_path_helper(file_name: str | Path, **kwargs):
     file_extension = Path(file_name).suffix[1:]
     if file_extension == "csv":
-        return pd.read_csv(file_name)
+        return pd.read_csv(file_name, **kwargs)
     elif file_extension in ["xls", "xlsx"]:
-        return pd.read_excel(file_name)
+        return pd.read_excel(file_name, **kwargs)
     elif file_extension == "json":
-        return pd.read_json(file_name)
+        return pd.read_json(file_name, **kwargs)
     elif file_extension == "parquet":
-        return pd.read_parquet(file_name)
+        return pd.read_parquet(file_name, **kwargs)
     else:
         raise ValueError(f"Unsupported file type: {file_extension}")
 
 
+def read_json(file_name):
+    """
+    Read a json file
+
+    Parameters
+    ----------
+    file_name : str or Path
+        File to read
+
+    Returns
+    -------
+    object
+        The decoded JSON content (e.g. a dict or list)
+    """
+    with open(file_name, encoding="utf-8") as f:
+        return json.load(f)
+
 
 def copyfile(src, dst):
     """
diff --git a/tests/test_save_read.py b/tests/test_save_read.py
index 5d42b83..1bba912 100644
--- a/tests/test_save_read.py
+++ b/tests/test_save_read.py
@@ -34,3 +34,22 @@ def test_save_read_df(file_format, path_format):
 
         df_read = iwutil.read_df(file)
         assert df.equals(df_read)
+
+
+def test_read_df_kwargs():
+    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        file = temp_dir + "/test.csv"
+        iwutil.save.csv(df, file)
+
+        df_read = iwutil.read_df(file, usecols=["a"])
+        assert df_read.equals(pd.DataFrame({"a": [1, 2, 3]}))
+
+
+def test_read_json():
+    data = {"a": 1, "b": 4}
+    with tempfile.TemporaryDirectory() as temp_dir:
+        file = temp_dir + "/test.json"
+        iwutil.save.json(data, file)
+        assert iwutil.read_json(file) == data