Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
fix tmppath issue

fix tests

make error message bad again
  • Loading branch information
agoscinski committed Feb 4, 2025
1 parent 2a34c7d commit 3474a15
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 47 deletions.
1 change: 1 addition & 0 deletions disk_objectstore/backup_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def call_rsync( # pylint: disable=too-many-arguments,too-many-branches

def get_existing_backup_folders(self):
"""Get all folders matching the backup folder name pattern."""

success, stdout = self.run_cmd(
[
"find",
Expand Down
49 changes: 27 additions & 22 deletions disk_objectstore/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def __init__(self, folder: str | Path) -> None:
self._folder = Path(folder).resolve()
# Will be populated by the _get_session function
self._session: Session | None = None
self._keep_open_session: Session | None = None

# These act as caches and will be populated by the corresponding properties
# IMPORTANT! IF YOU ADD MORE, REMEMBER TO CLEAR THEM IN `init_container()`!
Expand All @@ -134,9 +135,17 @@ def get_folder(self) -> Path:
def close(self) -> None:
"""Close open files (in particular, the connection to the SQLite DB)."""
if self._session is not None:
engine = self._session.bind
self._session.close()
engine.dispose()
self._session = None

if self._keep_open_session is not None:
engine = self._keep_open_session.bind
self._keep_open_session.close()
engine.dispose()
self._keep_open_session = None

def __enter__(self) -> Container:
"""Return a context manager that will close the session when exiting the context."""
return self
Expand Down Expand Up @@ -180,32 +189,21 @@ def _get_config_file(self) -> Path:
"""Return the path to the container config file."""
return self._folder / "config.json"

@overload
def _get_session(
self, create: bool = False, raise_if_missing: Literal[True] = True
def _create_init_session(
self
) -> Session:
...

@overload
def _get_session(
self, create: bool = False, raise_if_missing: Literal[False] = False
) -> Session | None:
...

def _get_session(
self, create: bool = False, raise_if_missing: bool = False
) -> Session | None:
"""Return a new session to connect to the pack-index SQLite DB.
:param create: if True, creates the sqlite file and schema.
:param raise_if_missing: ignored if create==True. If create==False, and the index file
is missing, either raise an exception (FileNotFoundError) if this flag is True, or return None
"""
return get_session(
self._get_pack_index_path(),
create=create,
raise_if_missing=raise_if_missing,
)
if self._keep_open_session is None:
self._keep_open_session = get_session(
self._get_pack_index_path(),
create=True,
)
return self._keep_open_session

def _get_cached_session(self) -> Session:
"""Return the SQLAlchemy session to access the SQLite file,
Expand All @@ -214,7 +212,10 @@ def _get_cached_session(self) -> Session:
# the latter means that in the previous run the pack file was missing
# but maybe by now it has been created!
if self._session is None:
self._session = self._get_session(create=False, raise_if_missing=True)
self._session = get_session(
self._get_pack_index_path(),
create=False,
)
return self._session

def _get_loose_path_from_hashkey(self, hashkey: str) -> Path:
Expand Down Expand Up @@ -332,6 +333,7 @@ def init_container(
raise ValueError(f'Unknown hash type "{hash_type}"')

if clear:
self.close()
if self._folder.exists():
shutil.rmtree(self._folder)

Expand Down Expand Up @@ -391,7 +393,7 @@ def init_container(
]:
os.makedirs(folder)

self._get_session(create=True)
self._create_init_session()

def _get_repository_config(self) -> dict[str, int | str]:
"""Return the repository config."""
Expand Down Expand Up @@ -1141,7 +1143,7 @@ def get_total_size(self) -> TotalSize:

retval["total_size_packindexes_on_disk"] = (
self._get_pack_index_path().stat().st_size
)
)

total_size_loose = 0
for loose_hashkey in self._list_loose():
Expand Down Expand Up @@ -1916,6 +1918,9 @@ def add_objects_to_pack( # pylint: disable=too-many-arguments
:return: a list of object hash keys
"""
# TODO: raise a dedicated exception type here instead of ValueError (which type is still undecided)
if not self.is_initialised:
raise ValueError("Invalid use of function, please first initialise the container.")

Check warning on line 1923 in disk_objectstore/container.py

View check run for this annotation

Codecov / codecov/patch

disk_objectstore/container.py#L1923

Added line #L1923 was not covered by tests
stream_list: list[StreamSeekBytesType] = [
io.BytesIO(content) for content in content_list
]
Expand Down
6 changes: 2 additions & 4 deletions disk_objectstore/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class Obj(Base): # pylint: disable=too-few-public-methods


def get_session(
path: Path, create: bool = False, raise_if_missing: bool = False
path: Path, create: bool = False
) -> Optional[Session]:
"""Return a new session to connect to the pack-index SQLite DB.
Expand All @@ -41,9 +41,7 @@ def get_session(
is missing, either raise an exception (FileNotFoundError) if this flag is True, or return None
"""
if not create and not path.exists():
if raise_if_missing:
raise FileNotFoundError("Pack index does not exist")
return None
raise FileNotFoundError("Pack index does not exist")

engine = create_engine(f"sqlite:///{path}", future=True)

Expand Down
13 changes: 3 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,21 +81,14 @@ def temp_container(temp_dir): # pylint: disable=redefined-outer-name


@pytest.fixture(scope="function")
def temp_dir():
def temp_dir(tmp_path):
"""Get a temporary directory.
:return: The path to the directory
:rtype: pathlib.Path
"""
import gc
gc.collect()

try:
dirpath = tempfile.mkdtemp()
yield Path(dirpath)
finally:
# after the test function has completed, remove the directory again
shutil.rmtree(dirpath)
dirpath = tempfile.mkdtemp(dir=str(tmp_path))
yield Path(dirpath)


@pytest.fixture(scope="function")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def test_backup(temp_container, temp_dir, remote, verbosity):
if verbosity:
args += [f"--verbosity={verbosity}"]

result = CliRunner().invoke(cli.backup, args, obj=obj)
result = CliRunner().invoke(cli.backup, args, obj=obj, catch_exceptions=False)

assert result.exit_code == 0
assert path.exists()
Expand Down
20 changes: 10 additions & 10 deletions tests/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,15 +677,16 @@ def test_initialisation(temp_dir):

# Check that the session cannot be obtained before initialising
with pytest.raises(FileNotFoundError):
container._get_session(create=False, raise_if_missing=True)
assert container._get_session(create=False, raise_if_missing=False) is None
container._get_cached_session()

container.init_container()
assert container.is_initialised
container.close()

# This call should go through
container.init_container(clear=True)
assert container.is_initialised
container.close()

with pytest.raises(FileExistsError) as excinfo:
container.init_container()
Expand Down Expand Up @@ -717,31 +718,32 @@ def test_initialisation(temp_dir):

@pytest.mark.parametrize("hash_type", ["sha256", "sha1"])
@pytest.mark.parametrize("compress", [True, False])
def test_check_hash_computation(temp_container, hash_type, compress):
def test_check_hash_computation(temp_dir, hash_type, compress):
"""Check that the hashes are correctly computed, when storing loose,
directly to packs, and while repacking all loose.
Check both compressed and uncompressed packed objects.
"""
# Re-init the container with the correct hash type
temp_container.init_container(hash_type=hash_type, clear=True)
container = Container(temp_dir)
container.init_container(hash_type=hash_type, clear=True)
content1 = b"1"
content2 = b"222"
content3 = b"n2fwd"

expected_hasher = getattr(hashlib, hash_type)

hashkey1 = temp_container.add_object(content1)
hashkey1 = container.add_object(content1)
assert hashkey1 == expected_hasher(content1).hexdigest()

hashkey2, hashkey3 = temp_container.add_objects_to_pack(
hashkey2, hashkey3 = container.add_objects_to_pack(
[content2, content3], compress=compress
)
assert hashkey2 == expected_hasher(content2).hexdigest()
assert hashkey3 == expected_hasher(content3).hexdigest()

# No exceptions should be raised
temp_container.pack_all_loose(compress=compress, validate_objects=True)
container.pack_all_loose(compress=compress, validate_objects=True)


@pytest.mark.parametrize("validate_objects", [True, False])
Expand Down Expand Up @@ -1064,9 +1066,7 @@ def test_sizes(
temp_container, generate_random_data, compress_packs, compression_algorithm
):
"""Check that the information on size is reliable."""
temp_container.init_container(
clear=True, compression_algorithm=compression_algorithm
)
temp_container.init_container( clear=True, compression_algorithm=compression_algorithm)
size_info = temp_container.get_total_size()
assert size_info["total_size_packed"] == 0
assert size_info["total_size_packed_on_disk"] == 0
Expand Down

0 comments on commit 3474a15

Please sign in to comment.