Skip to content

Commit

Permalink
estore int tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ritch committed Nov 12, 2024
1 parent 852d6dc commit 4f8ea26
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 16 deletions.
17 changes: 9 additions & 8 deletions fiftyone/factory/repo_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,16 @@ def delegated_operation_repo() -> DelegatedOperationRepo:
def execution_store_repo(
dataset_id: Optional[ObjectId] = None,
) -> ExecutionStoreRepo:
if (
MongoExecutionStoreRepo.COLLECTION_NAME
not in RepositoryFactory.repos
):
RepositoryFactory.repos[
MongoExecutionStoreRepo.COLLECTION_NAME
] = MongoExecutionStoreRepo(
repo_key = (
f"{MongoExecutionStoreRepo.COLLECTION_NAME}_{dataset_id}"
if dataset_id
else MongoExecutionStoreRepo.COLLECTION_NAME
)

if repo_key not in RepositoryFactory.repos:
RepositoryFactory.repos[repo_key] = MongoExecutionStoreRepo(
collection=_get_db()[MongoExecutionStoreRepo.COLLECTION_NAME],
dataset_id=dataset_id,
)

return RepositoryFactory.repos[MongoExecutionStoreRepo.COLLECTION_NAME]
return RepositoryFactory.repos[repo_key]
37 changes: 29 additions & 8 deletions fiftyone/factory/repos/execution_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,21 +198,42 @@ def has_store_global(self, store_name):
)
return bool(result)

def list_stores_global(self) -> list[str]:
"""Lists the stores in the execution store across all datasets and the
global context.
"""
result = self._collection.find(
dict(key="__store__"), {"store_name": 1}
)
return [d["store_name"] for d in result]
def list_stores_global(self) -> list[StoreDocument]:
"""Lists stores across all datasets and the global context."""
pipeline = [
{
"$group": {
"_id": {
"store_name": "$store_name",
"dataset_id": "$dataset_id",
}
}
},
{
"$project": {
"_id": 0,
"store_name": "$_id.store_name",
"dataset_id": "$_id.dataset_id",
}
},
]

result = self._collection.aggregate(pipeline)
return [StoreDocument(**d) for d in result]

def count_stores_global(self) -> int:
"""Counts the stores in the execution store across all datasets and the
global context.
"""
return self._collection.count_documents(dict(key="__store__"))

def delete_store_global(self, store_name) -> int:
"""Deletes the specified store across all datasets and the global
context.
"""
result = self._collection.delete_many(dict(store_name=store_name))
return result.deleted_count


class MongoExecutionStoreRepo(ExecutionStoreRepo):
"""MongoDB implementation of execution store repository."""
Expand Down
2 changes: 2 additions & 0 deletions fiftyone/operators/store/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def to_mongo_dict(self, exclude_id: bool = True) -> dict[str, Any]:
data.pop("_id", None)
if self.dataset_id is None:
data.pop("dataset_id", None)

print(data)
return data


Expand Down
12 changes: 12 additions & 0 deletions fiftyone/operators/store/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,15 @@ def count_stores_global(self) -> int:
the number of stores
"""
return self._repo.count_stores_global()

def delete_store_global(self, store_name) -> int:
"""Deletes the specified store across all datasets and the global
context.
Args:
store_name: the name of the store
Returns:
the number of stores deleted
"""
return self._repo.delete_store_global(store_name)
19 changes: 19 additions & 0 deletions tests/unittests/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from functools import wraps
import platform
import unittest
import fnmatch

import fiftyone as fo
import fiftyone.operators.store as foos


def drop_datasets(func):
Expand Down Expand Up @@ -43,6 +45,23 @@ async def wrapper(*args, **kwargs):
return wrapper


def drop_stores(func, pattern="*"):
"""Decorator that drops all stores from the database before running a test."""

@wraps(func)
def wrapper(*args, **kwargs):
svc = foos.ExecutionStoreService()
stores = svc.list_stores_global()
for store in stores:
store_name = store.store_name
if fnmatch.fnmatch(store_name, pattern):
print(f"Deleting store: {store_name}", pattern)
svc.delete_store_global(store_name)
return func(*args, **kwargs)

return wrapper


def skip_windows(func):
"""Decorator that skips a test when running on Windows."""

Expand Down
95 changes: 95 additions & 0 deletions tests/unittests/execution_store_int_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
FiftyOne execution store related unit tests.
| Copyright 2017-2024, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""

import pytest

import fiftyone as fo
from fiftyone.operators.store import ExecutionStoreService

from decorators import drop_stores


@pytest.fixture
def svc():
return ExecutionStoreService()


@pytest.fixture
def svc_with_dataset():
dataset = fo.Dataset(name="test_dataset")
dataset.save()
yield ExecutionStoreService(dataset_id=dataset._doc.id)
dataset.delete()


@drop_stores
def test_store_creation(svc):
NAME = "test_store"
created_store = svc.create_store(NAME)

assert (
created_store.store_name == NAME
), "Store name should match the given name"
assert (
created_store.dataset_id is None
), "Dataset ID should be None when not provided"
assert svc.count_stores() == 1, "Store count should be 1"


@drop_stores
def test_store_creation_with_dataset(svc_with_dataset):
NAME = "test_store"
created_store = svc_with_dataset.create_store(NAME)

assert (
created_store.store_name == NAME
), "Store name should match the given name"
assert (
created_store.dataset_id is not None
), "Dataset ID should be set when provided"
assert svc_with_dataset.count_stores() == 1, "Store count should be 1"


@drop_stores
def test_set_get_key(svc):
NAME = "test_store"
KEY = "test_key"
VALUE = "test_value"

svc.set_key(NAME, KEY, VALUE)
assert (
svc.count_keys(NAME) == 1
), "Store should have 1 key after setting it"
assert (
svc.get_key(NAME, KEY).value == VALUE
), "Retrieved value should match the set value"


@drop_stores
def test_scoping(svc, svc_with_dataset):
NAME = "test_store"
KEY = "test_key"
VALUE = "test_value"
svc.set_key(NAME, KEY, VALUE)
svc_with_dataset.set_key(NAME, KEY, VALUE)
global_list = svc.list_stores_global()
global_names = [store.store_name for store in global_list]
assert global_names == [NAME, NAME], "Global store should be listed"
assert svc.count_keys(NAME) == 1, "Global store should have 1 key"
assert (
svc_with_dataset.count_keys(NAME) == 1
), "Dataset store should have 1 key"
svc_with_dataset.delete_store(NAME)
assert svc.count_keys(NAME) == 1, "Global store should still have 1 key"
assert (
svc_with_dataset.count_keys(NAME) == 0
), "Dataset store should have 0 keys"
svc.delete_store(NAME)
assert svc.count_keys(NAME) == 0, "Global store should have 0 keys"
global_list = svc.list_stores_global()
assert NAME not in global_list, "Global store should not be listed"
File renamed without changes.

0 comments on commit 4f8ea26

Please sign in to comment.