feat: multifile handling in pin_upload/pin_download (#319)
* first prototype of working pin_upload

* save single file uploads as Paths too

* handle pin_download as well

* change back connect api

* update tests

* add tests for upload/download

* return hashes in a list
isabelizimm authored Dec 13, 2024
1 parent e64874f commit 368641f
Showing 6 changed files with 122 additions and 72 deletions.
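In practice, the new behavior looks like this; a minimal sketch assuming a temporary board and two small local CSVs (all file and pin names hypothetical):

```python
import pandas as pd
from pins import board_temp

board = board_temp()
pd.DataFrame({"x": [1, 2, 3]}).to_csv("data1.csv", index=False)
pd.DataFrame({"y": [4, 5, 6]}).to_csv("data2.csv", index=False)

# upload several files under one pin, then get one local path back per file
board.pin_upload(["data1.csv", "data2.csv"], "my_pin")
paths = board.pin_download("my_pin")  # one cached path per uploaded file
```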
74 changes: 52 additions & 22 deletions pins/boards.py
@@ -15,7 +15,7 @@

from .cache import PinsCache
from .config import get_allow_rsc_short_name
-from .drivers import default_title, load_data, load_file, save_data
+from .drivers import REQUIRES_SINGLE_FILE, default_title, load_data, load_file, save_data
from .errors import PinsError, PinsVersionError
from .meta import Meta, MetaFactory, MetaRaw
from .utils import ExtendMethodDoc, inform, warn_deprecated
@@ -243,9 +243,17 @@ def _pin_store(
             if isinstance(x, (tuple, list)) and len(x) == 1:
                 x = x[0]
 
-            _p = Path(x)
-            _base_len = len(_p.name) - len("".join(_p.suffixes))
-            object_name = _p.name[:_base_len]
+            if not isinstance(x, (list, tuple)):
+                _p = Path(x)
+                _base_len = len(_p.name) - len("".join(_p.suffixes))
+                object_name = _p.name[:_base_len]
+            else:
+                # multifile upload, keep list of filenames
+                object_name = []
+                for file in x:
+                    _p = Path(file)
+                    object_name.append(_p.name)
         else:
             object_name = None
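For reference, the single-file branch derives the object name by stripping every suffix, while the multifile branch keeps the full file names; a quick illustration (file names hypothetical):

```python
from pathlib import Path

p = Path("model.tar.gz")
base_len = len(p.name) - len("".join(p.suffixes))
p.name[:base_len]  # 'model'

# multifile: names are kept as-is
[Path(f).name for f in ["data1.csv", "data2.csv"]]  # ['data1.csv', 'data2.csv']
```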

@@ -415,20 +423,32 @@ def pin_download(self, name, version=None, hash=None) -> Sequence[str]:
         if hash is not None:
             raise NotImplementedError("TODO: validate hash")
 
+        fnames = [meta.file] if isinstance(meta.file, str) else meta.file
+        pin_type = meta.type
+
+        if len(fnames) > 1 and pin_type in REQUIRES_SINGLE_FILE:
+            raise ValueError("Cannot load data when more than 1 file")
+
         pin_name = self.path_to_pin(name)
+        files = []
 
-        # TODO: raise for multiple files
-        # fetch file
-        with load_file(
-            meta, self.fs, self.construct_path([pin_name, meta.version.version])
-        ) as f:
-            # could also check whether f isinstance of PinCache
-            fname = getattr(f, "name", None)
-
-        if fname is None:
-            raise PinsError("pin_download requires a cache.")
-
-        return [str(Path(fname).absolute())]
+        for fname in fnames:
+            # fetch file
+            with load_file(
+                fname,
+                self.fs,
+                self.construct_path([pin_name, meta.version.version]),
+                pin_type,
+            ) as f:
+                # could also check whether f isinstance of PinCache
+                fname = getattr(f, "name", None)
+
+            if fname is None:
+                raise PinsError("pin_download requires a cache.")
+
+            files.append(str(Path(fname).absolute()))
+
+        return files
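Continuing the sketch from the top of this page, `pin_download` now always returns a list with one absolute path per file recorded in the pin, so single-file pins can simply be unpacked (pin names hypothetical):

```python
# single-file pin: unpack the one-element list
(path,) = board.pin_download("one_file_pin")

# multifile pin: one cached path per uploaded file
paths = board.pin_download("multi_file_pin")
```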

     def pin_upload(
         self,
@@ -461,6 +481,12 @@ def pin_upload
         This gets stored on the Meta.user field.
         """
 
+        if isinstance(paths, (list, tuple)):
+            # check if all paths exist
+            for path in paths:
+                if not Path(path).is_file():
+                    raise PinsError(f"Path is not a valid file: {path}")
+
         return self._pin_store(
             paths,
             name,
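With the new validation, a bad path in the list fails fast, before anything is stored; roughly (continuing the same sketch, `missing.csv` hypothetical):

```python
from pins.errors import PinsError

try:
    board.pin_upload(["data1.csv", "missing.csv"], "my_pin")
except PinsError as e:
    print(e)  # Path is not a valid file: missing.csv
```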
@@ -665,7 +691,7 @@ def prepare_pin_version(
         metadata: Mapping | None = None,
         versioned: bool | None = None,
         created: datetime | None = None,
-        object_name: str | None = None,
+        object_name: str | list[str] | None = None,
     ):
         meta = self._create_meta(
             pin_dir_path,
@@ -710,14 +736,18 @@ def _create_meta(
         # create metadata from object on disk ---------------------------------
         # save all pin data to a temporary folder (including data.txt), so we
         # can fs.put it all straight onto the backend filesystem
 
-        if object_name is None:
-            p_obj = Path(pin_dir_path) / name
+        apply_suffix = True
+        if isinstance(object_name, (list, tuple)):
+            apply_suffix = False
+            p_obj = []
+            for obj in object_name:
+                p_obj.append(str(Path(pin_dir_path) / obj))
+        elif object_name is None:
+            p_obj = str(Path(pin_dir_path) / name)
         else:
-            p_obj = Path(pin_dir_path) / object_name
+            p_obj = str(Path(pin_dir_path) / object_name)
 
         # file is saved locally in order to hash, calc size
-        file_names = save_data(x, str(p_obj), type)
+        file_names = save_data(x, p_obj, type, apply_suffix)
 
         meta = self.meta_factory.create(
             pin_dir_path,
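The list branch above could equally be written as a comprehension; an equivalent sketch with hypothetical values, not part of the commit:

```python
from pathlib import Path

pin_dir_path = "/tmp/pin_v1"  # hypothetical staging directory
object_name = ["data1.csv", "data2.csv"]

apply_suffix = False
p_obj = [str(Path(pin_dir_path) / obj) for obj in object_name]
# ['/tmp/pin_v1/data1.csv', '/tmp/pin_v1/data2.csv']
```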
@@ -910,7 +940,7 @@ def pin_download(self, name, version=None, hash=None) -> Sequence[str]:
         meta = self.pin_meta(name, version)
 
         if isinstance(meta, MetaRaw):
-            f = load_file(meta, self.fs, None)
+            f = load_file(meta.file, self.fs, None, meta.type)
         else:
             raise NotImplementedError(
                 "TODO: pin_download currently can only read a url to a single file."
68 changes: 34 additions & 34 deletions pins/drivers.py
@@ -10,7 +10,7 @@


 UNSAFE_TYPES = frozenset(["joblib"])
-REQUIRES_SINGLE_FILE = frozenset(["csv", "joblib", "file"])
+REQUIRES_SINGLE_FILE = frozenset(["csv", "joblib"])
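Dropping "file" from REQUIRES_SINGLE_FILE is what unlocks multifile downloads: only types that must load as a single object still reject multiple files. Roughly:

```python
REQUIRES_SINGLE_FILE = frozenset(["csv", "joblib"])

fnames, pin_type = ["data1.csv", "data2.csv"], "file"

# "file" pins may now span many files; "csv" and "joblib" still may not
if len(fnames) > 1 and pin_type in REQUIRES_SINGLE_FILE:
    raise ValueError("Cannot load data when more than 1 file")
```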


def _assert_is_pandas_df(x, file_type: str) -> None:
@@ -22,35 +22,24 @@ def _assert_is_pandas_df(x, file_type: str) -> None:
         )


-def load_path(meta, path_to_version):
-    # Check that only a single file name was given
-    fnames = [meta.file] if isinstance(meta.file, str) else meta.file
-
-    _type = meta.type
-
-    if len(fnames) > 1 and _type in REQUIRES_SINGLE_FILE:
-        raise ValueError("Cannot load data when more than 1 file")
-
+def load_path(filename: str, path_to_version, pin_type=None):
     # file path creation ------------------------------------------------------
 
-    if _type == "table":
+    if pin_type == "table":
         # this type contains an rds and csv files named data.{ext}, so we match
         # R pins behavior and hardcode the name
-        target_fname = "data.csv"
-    else:
-        target_fname = fnames[0]
+        filename = "data.csv"
 
     if path_to_version is not None:
-        path_to_file = f"{path_to_version}/{target_fname}"
+        path_to_file = f"{path_to_version}/{filename}"
     else:
         # BoardUrl doesn't have versions, and the file is the full url
-        path_to_file = target_fname
+        path_to_file = filename
 
     return path_to_file
 
 
-def load_file(meta: Meta, fs, path_to_version):
-    return fs.open(load_path(meta, path_to_version))
+def load_file(filename: str, fs, path_to_version, pin_type):
+    return fs.open(load_path(filename, path_to_version, pin_type))
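The reworked load_path resolves one filename at a time, which is what lets pin_download loop over files; a few sample calls (version string hypothetical):

```python
load_path("data1.csv", "v1")                  # 'v1/data1.csv'
load_path("ignored.csv", "v1", "table")       # 'v1/data.csv' (table pins hardcode the name)
load_path("https://example.com/a.csv", None)  # the URL itself, for BoardUrl
```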


def load_data(
@@ -81,7 +70,7 @@ def load_data(
" * https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations"
)

-    with load_file(meta, fs, path_to_version) as f:
+    with load_file(meta.file, fs, path_to_version, meta.type) as f:
if meta.type == "csv":
import pandas as pd

@@ -136,7 +125,9 @@ def load_data(
raise NotImplementedError(f"No driver for type {meta.type}")


-def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequence[str]":
+def save_data(
+    obj, fname, pin_type=None, apply_suffix: bool = True
+) -> "str | Sequence[str]":
     # TODO: extensible saving with deferred importing
    # TODO: how to encode arguments to saving / loading drivers?
     # e.g. pandas index options
@@ -145,59 +136,68 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequence[str]":
     # of saving / loading objects different ways.

     if apply_suffix:
-        if type == "file":
+        if pin_type == "file":
             suffix = "".join(Path(obj).suffixes)
         else:
-            suffix = f".{type}"
+            suffix = f".{pin_type}"
     else:
         suffix = ""
 
-    final_name = f"{fname}{suffix}"
+    if isinstance(fname, list):
+        final_name = fname
+    else:
+        final_name = f"{fname}{suffix}"
 
-    if type == "csv":
-        _assert_is_pandas_df(obj, file_type=type)
+    if pin_type == "csv":
+        _assert_is_pandas_df(obj, file_type=pin_type)
 
         obj.to_csv(final_name, index=False)
 
-    elif type == "arrow":
+    elif pin_type == "arrow":
         # NOTE: R pins accepts the type arrow, and saves it as feather.
         # we allow reading this type, but raise an error for writing.
-        _assert_is_pandas_df(obj, file_type=type)
+        _assert_is_pandas_df(obj, file_type=pin_type)
 
         obj.to_feather(final_name)
 
-    elif type == "feather":
-        _assert_is_pandas_df(obj, file_type=type)
+    elif pin_type == "feather":
+        _assert_is_pandas_df(obj, file_type=pin_type)
 
         raise NotImplementedError(
             'Saving data as type "feather" no longer supported. Use type "arrow" instead.'
         )
 
-    elif type == "parquet":
-        _assert_is_pandas_df(obj, file_type=type)
+    elif pin_type == "parquet":
+        _assert_is_pandas_df(obj, file_type=pin_type)
 
         obj.to_parquet(final_name)
 
-    elif type == "joblib":
+    elif pin_type == "joblib":
         import joblib
 
         joblib.dump(obj, final_name)
 
-    elif type == "json":
+    elif pin_type == "json":
         import json
 
         json.dump(obj, open(final_name, "w"))
 
-    elif type == "file":
+    elif pin_type == "file":
         import contextlib
         import shutil
 
+        if isinstance(obj, list):
+            for file, final in zip(obj, final_name):
+                with contextlib.suppress(shutil.SameFileError):
+                    shutil.copyfile(str(file), final)
+            return obj
         # ignore the case where the source is the same as the target
         with contextlib.suppress(shutil.SameFileError):
             shutil.copyfile(str(obj), final_name)
 
     else:
-        raise NotImplementedError(f"Cannot save type: {type}")
+        raise NotImplementedError(f"Cannot save type: {pin_type}")
 
     return final_name
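For multifile pins, save_data now receives parallel lists of source and target paths, copies each pair, and returns the source list; a sketch with hypothetical paths:

```python
sources = ["data1.csv", "data2.csv"]  # hypothetical existing files
targets = ["/tmp/pin_v1/data1.csv", "/tmp/pin_v1/data2.csv"]

# copies each source to its paired target and returns the source list
save_data(sources, targets, pin_type="file", apply_suffix=False)
```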

7 changes: 6 additions & 1 deletion pins/meta.py
@@ -245,7 +245,12 @@ def create(

raise NotImplementedError("Cannot create from file object.")
else:
raise NotImplementedError("TODO: creating meta from multiple files")
if isinstance(files, (list, tuple)):
from pathlib import Path

file_name = [Path(f).name for f in files]
file_size = [Path(f).stat().st_size for f in files]
version = Version.from_files(files, created)

         return Meta(
             title=title,
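After a multifile upload, the resulting metadata stores lists where single-file pins store scalars; continuing the earlier sketch (sizes hypothetical):

```python
meta = board.pin_upload(["data1.csv", "data2.csv"], "my_pin")
meta.file       # ['data1.csv', 'data2.csv']
meta.file_size  # e.g. [16, 16]
meta.type       # 'file'
```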
20 changes: 20 additions & 0 deletions pins/tests/test_boards.py
@@ -266,6 +266,26 @@ def test_board_pin_upload_path_list(board_with_cache, tmp_path):
     (pin_path,) = board_with_cache.pin_download("cool_pin")
 
 
+def test_board_pin_download_filename_multifile(board_with_cache, tmp_path):
+    # create and save data
+    df = pd.DataFrame({"x": [1, 2, 3]})
+
+    path1, path2 = tmp_path / "data1.csv", tmp_path / "data2.csv"
+    df.to_csv(path1, index=False)
+    df.to_csv(path2, index=False)
+
+    meta = board_with_cache.pin_upload([path1, path2], "cool_pin")
+
+    assert meta.type == "file"
+    assert meta.file == ["data1.csv", "data2.csv"]
+
+    pin_path = board_with_cache.pin_download("cool_pin")
+
+    assert len(pin_path) == 2
+    assert Path(pin_path[0]).name == "data1.csv"
+    assert Path(pin_path[1]).name == "data2.csv"


 def test_board_pin_write_rsc_index_html(board, tmp_path: Path, snapshot):
     if board.fs.protocol != "rsc":
         pytest.skip()
14 changes: 3 additions & 11 deletions pins/tests/test_drivers.py
@@ -164,31 +164,23 @@ def test_driver_apply_suffix_false(tmp_path: Path):


 class TestLoadFile:
-    def test_multi_file_raises(self):
-        class _MockMetaMultiFile:
-            file: str | list[str] = ["a", "b"]
-            type: str = "csv"
-
-        with pytest.raises(ValueError, match="Cannot load data when more than 1 file"):
-            load_path(_MockMetaMultiFile(), None)
-
     def test_str_file(self):
         class _MockMetaStrFile:
             file: str = "a"
             type: str = "csv"
 
-        assert load_path(_MockMetaStrFile(), None) == "a"
+        assert load_path(_MockMetaStrFile().file, None, _MockMetaStrFile().type) == "a"
 
     def test_table(self):
         class _MockMetaTable:
             file: str = "a"
             type: str = "table"
 
-        assert load_path(_MockMetaTable(), None) == "data.csv"
+        assert load_path(_MockMetaTable().file, None, _MockMetaTable().type) == "data.csv"
 
     def test_version(self):
         class _MockMetaTable:
             file: str = "a"
             type: str = "csv"
 
-        assert load_path(_MockMetaTable(), "v1") == "v1/a"
+        assert load_path(_MockMetaTable().file, "v1", _MockMetaTable().type) == "v1/a"
11 changes: 7 additions & 4 deletions pins/versions.py
@@ -3,6 +3,7 @@
 import logging
 from dataclasses import asdict, dataclass
 from datetime import datetime
+from pathlib import Path
 from typing import Mapping, Sequence
 
 from xxhash import xxh64
@@ -56,9 +57,7 @@ def render_created(self):
 def hash_file(f: IOBase, block_size: int = -1) -> str:
     # TODO: what kind of things implement the "buffer API"?
     hasher = xxh64()
-
     buf = f.read(block_size)
-
     while len(buf) > 0:
         hasher.update(buf)
         buf = f.read(block_size)
@@ -99,14 +98,18 @@ def from_files(
     ) -> Version:
         hashes = []
         for f in files:
-            hash_ = cls.hash_file(open(f, "rb") if isinstance(f, str) else f)
+            hash_ = cls.hash_file(open(f, "rb") if isinstance(f, (str, Path)) else f)
             hashes.append(hash_)
 
         if created is None:
             created = datetime.now()
 
         if len(hashes) > 1:
-            raise NotImplementedError("Only 1 file may be currently be hashed")
+            # Combine the hashes into a single string
+            combined_hashes = "".join(hashes)
+
+            # Create an xxh64 hash of the combined string
+            hashes = [xxh64(combined_hashes).hexdigest()]
 
         return cls(created, hashes[0])
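Instead of refusing multiple files, from_files now concatenates the per-file digests and hashes that string once more to produce a single version hash; equivalently:

```python
from xxhash import xxh64

digests = ["0011aabbccddeeff", "ffeeddccbbaa1100"]  # hypothetical per-file xxh64 digests
version_hash = xxh64("".join(digests)).hexdigest()
```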
