diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index a413a590..58f2aedf 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -3,7 +3,7 @@ default_language_version:
 exclude: "^src/atomate2/vasp/schemas/calc_types/"
 repos:
   - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: v0.1.0
+    rev: v0.1.1
    hooks:
      - id: ruff
        args: [--fix]
@@ -46,7 +46,7 @@ repos:
      - id: rst-directive-colons
      - id: rst-inline-touching-normal
  - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.6.0
+    rev: v1.6.1
    hooks:
      - id: mypy
        files: ^src/
diff --git a/examples/data_store.py b/examples/data_store.py
index 964f4112..10ee58cb 100644
--- a/examples/data_store.py
+++ b/examples/data_store.py
@@ -17,8 +17,7 @@ def generate_big_data():
     The data=True in the job decorator tells jobflow to store all outputs in the
     "data" additional store.
     """
-    mydata = list(range(1000))
-    return mydata
+    return list(range(1000))


 big_data_job = generate_big_data()
diff --git a/pyproject.toml b/pyproject.toml
index 67579bd4..fe20de09 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -121,26 +121,56 @@ exclude_lines = [
 target-version = "py39"
 ignore-init-module-imports = true
 select = [
-    "B",    # flake8-bugbear
-    "C4",   # flake8-comprehensions
-    "D",    # pydocstyle
-    "E",    # pycodestyle
-    "F",    # pyflakes
-    "I",    # isort
-    "PLE",  # pylint error
-    "PLW",  # pylint warning
-    "Q",    # flake8-quotes
-    "RUF",  # Ruff-specific rules
-    "SIM",  # flake8-simplify
-    "TID",  # tidy imports
-    "UP",   # pyupgrade
-    "W",    # pycodestyle
-    "YTT",  # flake8-2020
+    "B",      # flake8-bugbear
+    "C4",     # flake8-comprehensions
+    "D",      # pydocstyle
+    "E",      # pycodestyle error
+    "EXE",    # flake8-executable
+    "F",      # pyflakes
+    "FA",     # flake8-future-annotations
+    "FBT003", # boolean-positional-value-in-call
+    "FLY",    # flynt
+    "I",      # isort
+    "ICN",    # flake8-import-conventions
+    "ISC",    # flake8-implicit-str-concat
+    "PD",     # pandas-vet
+    "PERF",   # perflint
+    "PIE",    # flake8-pie
+    "PL",     # pylint
+    "PT",     # flake8-pytest-style
+    "PYI",    # flake8-pyi
+    "Q",      # flake8-quotes
+    "RET",    # flake8-return
+    "RSE",    # flake8-raise
+    "RUF",    # Ruff-specific rules
+    "SIM",    # flake8-simplify
+    "SLOT",   # flake8-slots
+    "TCH",    # flake8-type-checking
+    "TID",    # flake8-tidy-imports
+    "UP",     # pyupgrade
+    "W",      # pycodestyle warning
+    "YTT",    # flake8-2020
+]
+# PLR0911: too-many-return-statements
+# PLR0912: too-many-branches
+# PLR0913: too-many-arguments
+# PLR0915: too-many-statements
+ignore = [
+    "B028",
+    "PLR0911",
+    "PLR0912",
+    "PLR0913",
+    "PLR0915",
+    "PLW0603",
+    "RUF013",
 ]
-ignore = ["B028", "PLW0603", "RUF013"]
 pydocstyle.convention = "numpy"
 isort.known-first-party = ["jobflow"]

 [tool.ruff.per-file-ignores]
 "__init__.py" = ["F401"]
-"**/tests/*" = ["B018", "D"]
+# PLR2004: magic-value-comparison
+# PT004: pytest-missing-fixture-name-underscore
+# PLR0915: too-many-statements
+# PT011: pytest-raises-too-broad TODO fix these, should not be ignored
+"**/tests/*" = ["B018", "D", "PLR0915", "PLR2004", "PT004", "PT011"]
diff --git a/src/jobflow/core/flow.py b/src/jobflow/core/flow.py
index ecf3574c..4040a270 100644
--- a/src/jobflow/core/flow.py
+++ b/src/jobflow/core/flow.py
@@ -4,7 +4,6 @@

 import logging
 import warnings
-from collections.abc import Sequence
 from copy import deepcopy
 from typing import TYPE_CHECKING

@@ -15,7 +14,7 @@
 from jobflow.utils import ValueEnum, contains_flow_or_job, suuid

 if TYPE_CHECKING:
-    from collections.abc import Iterator
+    from collections.abc import Iterator, Sequence
     from typing import Any, Callable

     from networkx import DiGraph
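Note on the flow.py hunk above: moving `Sequence` under `TYPE_CHECKING` (the TCH rules newly enabled in pyproject.toml) is only safe because annotations are evaluated lazily. A minimal standalone sketch of the pattern (illustrative code, not from jobflow):

```python
from __future__ import annotations  # annotations become lazy strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for the type checker: no runtime cost, no import cycles.
    from collections.abc import Sequence


def first(items: Sequence[int]) -> int:
    """The Sequence annotation is never evaluated at runtime."""
    return items[0]


print(first([3, 2, 1]))  # -> 3
```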
diff --git a/src/jobflow/core/reference.py b/src/jobflow/core/reference.py
index 40ed9e7d..0700389a 100644
--- a/src/jobflow/core/reference.py
+++ b/src/jobflow/core/reference.py
@@ -4,7 +4,6 @@

 import contextlib
 import typing
-from collections.abc import Sequence
 from typing import Any

 from monty.json import MontyDecoder, MontyEncoder, MSONable, jsanitize
@@ -14,6 +13,8 @@
 from jobflow.utils.enum import ValueEnum

 if typing.TYPE_CHECKING:
+    from collections.abc import Sequence
+
     import jobflow


@@ -88,14 +89,14 @@ def __init__(
         uuid: str,
         attributes: tuple[tuple[str, Any], ...] = (),
         output_schema: type[BaseModel] = None,
-    ):
+    ) -> None:
         super().__init__()
         self.uuid = uuid
         self.attributes = attributes
         self.output_schema = output_schema

         for attr_type, attr in attributes:
-            if attr_type not in ("a", "i"):
+            if attr_type not in {"a", "i"}:
                 raise ValueError(
                     f"Unrecognised attribute type '{attr_type}' for attribute '{attr}'"
                 )
@@ -161,9 +162,9 @@ def resolve(
                 f"Could not resolve reference - {self.uuid}{istr} not in store or "
                 f"{index=}, {cache=}"
             )
-        elif on_missing == OnMissing.NONE and index not in cache[self.uuid]:
+        if on_missing == OnMissing.NONE and index not in cache[self.uuid]:
             return None
-        elif on_missing == OnMissing.PASS and index not in cache[self.uuid]:
+        if on_missing == OnMissing.PASS and index not in cache[self.uuid]:
             return self

         data = cache[self.uuid][index]
@@ -200,12 +201,11 @@ def set_uuid(self, uuid: str, inplace=True) -> OutputReference:
         if inplace:
             self.uuid = uuid
             return self
-        else:
-            from copy import deepcopy
+        from copy import deepcopy

-            new_reference = deepcopy(self)
-            new_reference.uuid = uuid
-            return new_reference
+        new_reference = deepcopy(self)
+        new_reference.uuid = uuid
+        return new_reference

     def __getitem__(self, item) -> OutputReference:
         """Index the reference."""
@@ -263,7 +263,7 @@ def __hash__(self) -> int:
         """Return a hash of the reference."""
         return hash(str(self))

-    def __eq__(self, other: Any) -> bool:
+    def __eq__(self, other: object) -> bool:
         """Test for equality against another reference."""
         if isinstance(other, OutputReference):
             return (
@@ -285,7 +285,7 @@ def as_dict(self):
         """Serialize the reference as a dict."""
         schema = self.output_schema
         schema_dict = MontyEncoder().default(schema) if schema is not None else None
-        data = {
+        return {
             "@module": self.__class__.__module__,
             "@class": type(self).__name__,
             "@version": None,
@@ -293,7 +293,6 @@ def as_dict(self):
             "attributes": self.attributes,
             "output_schema": schema_dict,
         }
-        return data


 def resolve_references(
@@ -376,7 +375,7 @@ def find_and_get_references(arg: Any) -> tuple[OutputReference, ...]:
         # if the argument is a reference then stop there
         return (arg,)

-    elif isinstance(arg, (float, int, str, bool)):
+    if isinstance(arg, (float, int, str, bool)):
         # argument is a primitive, we won't find a reference here
         return ()

@@ -432,7 +431,7 @@ def find_and_resolve_references(
         # if the argument is a reference then stop there
         return arg.resolve(store, cache=cache, on_missing=on_missing)

-    elif isinstance(arg, (float, int, str, bool)):
+    if isinstance(arg, (float, int, str, bool)):
         # argument is a primitive, we won't find a reference here
         return arg
diff --git a/src/jobflow/core/state.py b/src/jobflow/core/state.py
index a7cf2a92..c23010bc 100644
--- a/src/jobflow/core/state.py
+++ b/src/jobflow/core/state.py
@@ -15,8 +15,6 @@
 from monty.design_patterns import singleton

 if typing.TYPE_CHECKING:
-    pass
-
     import jobflow
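The `resolve` and `set_uuid` hunks above apply ruff's flake8-return rules (RET505/RET506): once a branch returns, the trailing `else`/`elif` is dead indentation. The same refactor in miniature (illustrative code, not jobflow's):

```python
def sign(n: int) -> str:
    # Before: branches nested under else/elif even though each one returns.
    # After (as in the hunks above): a flat chain of early returns.
    if n < 0:
        return "negative"
    if n == 0:
        return "zero"
    return "positive"


assert [sign(-2), sign(0), sign(9)] == ["negative", "zero", "positive"]
```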
diff --git a/src/jobflow/core/store.py b/src/jobflow/core/store.py
index 2916eba7..923cbb46 100644
--- a/src/jobflow/core/store.py
+++ b/src/jobflow/core/store.py
@@ -8,7 +8,6 @@
 from monty.json import MSONable

 from jobflow.core.reference import OnMissing
-from jobflow.schemas.job_output_schema import JobStoreDocument
 from jobflow.utils.find import get_root_locations

 if typing.TYPE_CHECKING:
@@ -19,6 +18,8 @@

     from maggma.core import Sort

+    from jobflow.schemas.job_output_schema import JobStoreDocument
+
 obj_type = Union[str, Enum, type[MSONable], list[Union[Enum, str, type[MSONable]]]]
 save_type = Optional[dict[str, obj_type]]
 load_type = Union[bool, dict[str, Union[bool, obj_type]]]
@@ -250,8 +251,7 @@ def query_one(
         docs = self.query(
             criteria=criteria, properties=properties, load=load, sort=sort, limit=1
         )
-        d = next(docs, None)
-        return d
+        return next(docs, None)

     def update(
         self,
@@ -496,7 +496,7 @@ def get_output(
         # this could be fixed but will require more complicated logic just to
        # catch a very unlikely event.

-        if isinstance(which, int) or which in ("last", "first"):
+        if isinstance(which, int) or which in {"last", "first"}:
             sort = -1 if which == "last" else 1

             criteria: dict[str, Any] = {"uuid": uuid}
@@ -522,28 +522,27 @@ def get_output(
             return find_and_resolve_references(
                 result["output"], self, cache=cache, on_missing=on_missing
             )
-        else:
-            results = list(
-                self.query(
-                    criteria={"uuid": uuid},
-                    properties=["output"],
-                    sort={"index": 1},
-                    load=load,
-                )
+        results = list(
+            self.query(
+                criteria={"uuid": uuid},
+                properties=["output"],
+                sort={"index": 1},
+                load=load,
             )
+        )

-            if len(results) == 0:
-                raise ValueError(f"UUID: {uuid} has no outputs.")
+        if len(results) == 0:
+            raise ValueError(f"UUID: {uuid} has no outputs.")

-            results = [r["output"] for r in results]
+        results = [r["output"] for r in results]

-            refs = find_and_get_references(results)
-            if any(ref.uuid == uuid for ref in refs):
-                raise RuntimeError("Reference cycle detected - aborting.")
+        refs = find_and_get_references(results)
+        if any(ref.uuid == uuid for ref in refs):
+            raise RuntimeError("Reference cycle detected - aborting.")

-            return find_and_resolve_references(
-                results, self, cache=cache, on_missing=on_missing
-            )
+        return find_and_resolve_references(
+            results, self, cache=cache, on_missing=on_missing
+        )

     @classmethod
     def from_file(cls: type[T], db_file: str | Path, **kwargs) -> T:
@@ -761,7 +760,7 @@ def _group_blobs(infos, locs):

     new_blobs = []
     new_locations = []
-    for _store_name, store_load in load.items():
+    for store_load in load.values():
         for blob, location in zip(blob_infos, locations):
             if store_load is True:
                 new_blobs.append(blob)
diff --git a/src/jobflow/managers/fireworks.py b/src/jobflow/managers/fireworks.py
index 8cbf1e08..dc1f0230 100644
--- a/src/jobflow/managers/fireworks.py
+++ b/src/jobflow/managers/fireworks.py
@@ -10,6 +10,7 @@
     from collections.abc import Sequence

     import jobflow
+    from jobflow.core.job import Job


 def flow_to_workflow(
@@ -146,7 +147,6 @@ class JobFiretask(FiretaskBase):
     def run_task(self, fw_spec):
         """Run the job and handle any dynamic firework submissions."""
         from jobflow import SETTINGS, initialize_logger
-        from jobflow.core.job import Job

         job: Job = self.get("job")
         store = self.get("store")
@@ -190,11 +190,10 @@ def run_task(self, fw_spec):
         else:
             detours = [detour_wf]

-        fwa = FWAction(
+        return FWAction(
             stored_data=response.stored_data,
             detours=detours,
             additions=additions,
             defuse_workflow=response.stop_jobflow,
             defuse_children=response.stop_children,
         )
-        return fwa
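fireworks.py can drop the runtime `Job` import because `job: Job = self.get("job")` only uses `Job` in a local-variable annotation, and PEP 526 local annotations are never evaluated. A self-contained sketch of the same trick (stand-in names, not jobflow's API):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from decimal import Decimal  # stands in for jobflow.core.job.Job


def describe(payload: dict) -> str:
    # PEP 526 local-variable annotations are stored for tools but never
    # evaluated at runtime, so the TYPE_CHECKING-only import is sufficient.
    amount: Decimal = payload["amount"]
    return f"amount={amount}"


print(describe({"amount": 42}))  # runs fine without importing Decimal
```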
diff --git a/src/jobflow/managers/local.py b/src/jobflow/managers/local.py
index abb0b23d..f9c15bf1 100644
--- a/src/jobflow/managers/local.py
+++ b/src/jobflow/managers/local.py
@@ -6,8 +6,6 @@
 import typing

 if typing.TYPE_CHECKING:
-    pass
-
     import jobflow


@@ -133,8 +131,7 @@ def _run_job(job: jobflow.Job, parents):
             if not all(diversion_responses):
                 return None, False

-        else:
-            return response, False
+        return response, False

     def _get_job_dir():
         if create_folders:
@@ -142,8 +139,7 @@ def _get_job_dir():
             job_dir = root_dir / f"job_{time_now}-{randint(10000, 99999)}"
             job_dir.mkdir()
             return job_dir
-        else:
-            return root_dir
+        return root_dir

     def _run(root_flow):
         encountered_bad_response = False
diff --git a/src/jobflow/utils/enum.py b/src/jobflow/utils/enum.py
index a2b444f6..8e7e6c21 100644
--- a/src/jobflow/utils/enum.py
+++ b/src/jobflow/utils/enum.py
@@ -14,8 +14,7 @@ def __eq__(self, other):
         """Compare to another enum for equality."""
         if type(self) == type(other) and self.value == other.value:
             return True
-        else:
-            return str(self.value) == str(other)
+        return str(self.value) == str(other)

     def __hash__(self):
         """Get a hash of the enum."""
diff --git a/src/jobflow/utils/find.py b/src/jobflow/utils/find.py
index 00b5d805..43e6903d 100644
--- a/src/jobflow/utils/find.py
+++ b/src/jobflow/utils/find.py
@@ -72,10 +72,8 @@ def _lookup(obj, path=None):
                 if (
                     inspect.isclass(key)
                     and issubclass(key, MSONable)
-                    and "@module" in obj
-                    and obj["@module"] == key.__module__
-                    and "@class" in obj
-                    and obj["@class"] == key.__name__
+                    and obj.get("@module") == key.__module__
+                    and obj.get("@class") == key.__name__
                 ):
                     found_items.add(path)
                     found = True
@@ -203,7 +201,7 @@ def contains_flow_or_job(obj: Any) -> bool:
         # if the argument is an flow or job then stop there
         return True

-    elif isinstance(obj, (float, int, str, bool)):
+    if isinstance(obj, (float, int, str, bool)):
         # argument is a primitive, we won't find an flow or job here
         return False
diff --git a/src/jobflow/utils/graph.py b/src/jobflow/utils/graph.py
index b58f6eb7..b1cc7380 100644
--- a/src/jobflow/utils/graph.py
+++ b/src/jobflow/utils/graph.py
@@ -8,15 +8,13 @@
 from monty.dev import requires

 try:
-    import matplotlib
+    import matplotlib as mpl
 except ImportError:
-    matplotlib = None
+    mpl = None

 import typing

 if typing.TYPE_CHECKING:
-    pass
-
     import jobflow


@@ -54,7 +52,7 @@ def itergraph(graph: nx.DiGraph):
     yield from nx.topological_sort(subgraph)


-@requires(matplotlib, "matplotlib must be installed to plot flow graphs.")
+@requires(mpl, "matplotlib must be installed to plot flow graphs.")
 def draw_graph(
     graph: nx.DiGraph,
     layout_function: typing.Callable = None,
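graph.py's `import matplotlib as mpl` follows ruff's import-convention rules (ICN) while keeping the optional-dependency guard. The general shape, as a hedged sketch:

```python
try:
    import matplotlib as mpl  # conventional alias, enforced by ICN
except ImportError:
    mpl = None  # plotting becomes a soft dependency


def can_plot() -> bool:
    """True when the optional matplotlib extra is installed."""
    return mpl is not None


print(can_plot())
```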
diff --git a/tests/conftest.py b/tests/conftest.py
index b4723dbe..083d0b45 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -26,7 +26,7 @@ def mongo_jobstore(database):
     return store


-@pytest.fixture(scope="function")
+@pytest.fixture()
 def memory_jobstore():
     from maggma.stores import MemoryStore

@@ -38,7 +38,7 @@ def memory_jobstore():
     return store


-@pytest.fixture(scope="function")
+@pytest.fixture()
 def memory_data_jobstore():
     from maggma.stores import MemoryStore

@@ -50,7 +50,7 @@ def memory_data_jobstore():
     return store


-@pytest.fixture
+@pytest.fixture()
 def clean_dir():
     import os
     import shutil
@@ -85,7 +85,7 @@ def lpad(database, debug_mode):
         lpad.db[coll].drop()


-@pytest.fixture
+@pytest.fixture()
 def no_pydot(monkeypatch):
     import builtins

@@ -93,13 +93,13 @@ def no_pydot(monkeypatch):

     def mocked_import(name, *args, **kwargs):
         if name == "pydot":
-            raise ImportError()
+            raise ImportError
         return import_orig(name, *args, **kwargs)

     monkeypatch.setattr(builtins, "__import__", mocked_import)


-@pytest.fixture
+@pytest.fixture()
 def no_matplotlib(monkeypatch):
     import builtins

@@ -107,7 +107,7 @@ def no_matplotlib(monkeypatch):

     def mocked_import(name, *args, **kwargs):
         if name == "matplotlib":
-            raise ImportError()
+            raise ImportError
         return import_orig(name, *args, **kwargs)

     monkeypatch.setattr(builtins, "__import__", mocked_import)
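Two pytest-style rules drive the conftest.py edits: fixtures are declared with explicit parentheses, and an exception class raised with no arguments loses its call parentheses (flake8-raise RSE102). In miniature:

```python
import pytest


@pytest.fixture()  # explicit parentheses, per flake8-pytest-style
def sample_config():
    return {"retries": 3}


def forbid_pydot(name: str) -> str:
    if name == "pydot":
        raise ImportError  # RSE102: no () on a bare exception class
    return name


def test_forbid_pydot(sample_config):
    assert sample_config["retries"] == 3
    assert forbid_pydot("networkx") == "networkx"
```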
diff --git a/tests/core/test_flow.py b/tests/core/test_flow.py
index d34fc2ae..2bd5a0c9 100644
--- a/tests/core/test_flow.py
+++ b/tests/core/test_flow.py
@@ -55,8 +55,7 @@ def make(self, a):

     if return_makers:
         return flow, (AddMaker, DivMaker)
-    else:
-        return flow
+    return flow


 def test_flow_of_jobs_init():
@@ -109,7 +108,9 @@ def test_flow_of_jobs_init():

     # # test all jobs included needed to generate outputs
     add_job = get_test_job()
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="jobs array does not contain all jobs needed for flow output"
+    ):
         Flow([], output=add_job.output)

     # test job given rather than outputs
@@ -125,12 +126,14 @@ def test_flow_of_jobs_init():

     # test job already belongs to another flow
     add_job = get_test_job()
     Flow([add_job])
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="already belongs to another flow"):
         Flow([add_job])

     # test that two of the same job cannot be used in the same flow
     add_job = get_test_job()
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="jobs array contains multiple jobs/flows with the same uuid"
+    ):
         Flow([add_job, add_job])


@@ -139,19 +142,19 @@ def test_flow_of_flows_init():

     # test single flow
     add_job = get_test_job()
-    subflow = Flow([add_job])
-    flow = Flow([subflow], name="add")
+    sub_flow = Flow([add_job])
+    flow = Flow([sub_flow], name="add")
     assert flow.name == "add"
     assert flow.host is None
     assert flow.output is None
     assert flow.job_uuids == (add_job.uuid,)
-    assert flow.all_uuids == (add_job.uuid, subflow.uuid)
+    assert flow.all_uuids == (add_job.uuid, sub_flow.uuid)
     assert flow.jobs[0].host == flow.uuid

     # test single flow no list
     add_job = get_test_job()
-    subflow = Flow(add_job)
-    flow = Flow(subflow, name="add")
+    sub_flow = Flow(add_job)
+    flow = Flow(sub_flow, name="add")
     assert flow.name == "add"
     assert flow.host is None
     assert flow.output is None
@@ -178,8 +181,8 @@ def test_flow_of_flows_init():

     # test single job and outputs
     add_job = get_test_job()
-    subflow = Flow([add_job], output=add_job.output)
-    flow = Flow([subflow], output=subflow.output)
+    sub_flow = Flow([add_job], output=add_job.output)
+    flow = Flow([sub_flow], output=sub_flow.output)
     assert flow.output == add_job.output

     # test multi job and list multi outputs
@@ -193,34 +196,34 @@ def test_flow_of_flows_init():

     # test all jobflow included needed to generate outputs
     add_job = get_test_job()
-    subflow = Flow([add_job], output=add_job.output)
+    sub_flow = Flow([add_job], output=add_job.output)
     with pytest.raises(ValueError):
-        Flow([], output=subflow.output)
+        Flow([], output=sub_flow.output)

     # test flow given rather than outputs
     add_job = get_test_job()
-    subflow = Flow([add_job], output=add_job.output)
+    sub_flow = Flow([add_job], output=add_job.output)
     with pytest.warns(UserWarning):
-        Flow([subflow], output=subflow)
+        Flow([sub_flow], output=sub_flow)

     # test complex object containing job given rather than outputs
     add_job = get_test_job()
-    subflow = Flow([add_job], output=add_job.output)
+    sub_flow = Flow([add_job], output=add_job.output)
     with pytest.warns(UserWarning):
-        Flow([subflow], output={1: [[{"a": subflow}]]})
+        Flow([sub_flow], output={1: [[{"a": sub_flow}]]})

     # test flow already belongs to another flow
     add_job = get_test_job()
-    subflow = Flow([add_job], output=add_job.output)
-    Flow([subflow])
+    sub_flow = Flow([add_job], output=add_job.output)
+    Flow([sub_flow])
     with pytest.raises(ValueError):
-        Flow([subflow])
+        Flow([sub_flow])

     # test that two of the same flow cannot be used in the same flow
     add_job = get_test_job()
-    subflow = Flow([add_job], output=add_job.output)
+    sub_flow = Flow([add_job], output=add_job.output)
     with pytest.raises(ValueError):
-        Flow([subflow, subflow])
+        Flow([sub_flow, sub_flow])


 def test_flow_job_mixed():
@@ -883,7 +886,7 @@ def test_flow_magic_methods():

     # test __iter__
     for job in flow2:
-        assert job in [job4, job3]
+        assert job in {job4, job3}

     # test __contains__
     assert job1 in flow1
@@ -901,7 +904,7 @@ def test_flow_magic_methods():
     assert job5 not in flow4

     # test __eq__ and __hash__
-    assert flow1 == flow1
+    assert flow1 == flow1  # noqa: PLR0124
     assert flow1 != flow2
     assert hash(flow1) != hash(flow2)
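The recurring test change above tightens bare `pytest.raises(ValueError)` with a `match=` pattern (PT011), so an unrelated ValueError can no longer satisfy the assertion. A minimal example:

```python
import pytest


def set_rate(rate: float) -> float:
    if rate < 0:
        raise ValueError("rate must be non-negative")
    return rate


def test_set_rate_rejects_negative():
    # match= is a regex, re.search-ed against str(excinfo.value)
    with pytest.raises(ValueError, match="must be non-negative"):
        set_rate(-1.0)
```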
diff --git a/tests/core/test_job.py b/tests/core/test_job.py
index 3d862565..f14c852f 100644
--- a/tests/core/test_job.py
+++ b/tests/core/test_job.py
@@ -853,7 +853,7 @@ class AddSchema(BaseModel):

     @job(output_schema=AddSchema)
     def add_schema(a, b):
-        return AddSchema(**{"result": a + b})
+        return AddSchema(result=a + b)

     @job(output_schema=AddSchema)
     def add_schema_dict(a, b):
@@ -870,7 +870,6 @@ def add_schema_wrong_key(a, b):
     @job(output_schema=AddSchema)
     def add_schema_no_output(a, b):
         a + b
-        return None

     @job(output_schema=AddSchema)
     def add_schema_response_dict(a, b):
@@ -970,7 +969,7 @@ def test_pass_manager_config():
     assert test_job2.config.manager_config == manager_config

     # test bad input
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Unrecognised jobs format"):
         pass_manager_config(["str"], manager_config)


@@ -1218,7 +1217,7 @@ def jsm_wrapped(a, b):
     assert not test_job.config.resolve_references
     assert test_job.config.pass_manager_config

-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Unknown JobConfig attribute: abc_xyz"):
         test_job.update_config(new_config, attributes="abc_xyz")

     # test dictionary config updates
@@ -1245,7 +1244,10 @@ def jsm_wrapped(a, b):
     assert not test_job.config.resolve_references
     assert test_job.config.pass_manager_config

-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="Specified attributes include a key that is not present in the config",
+    ):
         test_job.update_config(new_config_dict, attributes="abc_xyz")

     # test applied dynamic updates
@@ -1289,8 +1291,8 @@ def test_job_magic_methods():
     assert "fake-uuid" not in job1

     # test __eq__
-    assert job1 == job1
-    assert job2 == job2
+    assert job1 == job1  # noqa: PLR0124
+    assert job2 == job2  # noqa: PLR0124
     assert job1 != job2
     assert job1 != job3  # Different UUIDs
diff --git a/tests/core/test_job_output_schema.py b/tests/core/test_job_output_schema.py
index 99777e7e..2db2c330 100644
--- a/tests/core/test_job_output_schema.py
+++ b/tests/core/test_job_output_schema.py
@@ -3,7 +3,7 @@
 import pytest


-@pytest.fixture
+@pytest.fixture()
 def sample_data():
     from jobflow.schemas.job_output_schema import JobStoreDocument

diff --git a/tests/core/test_maker.py b/tests/core/test_maker.py
index d08cb352..a7838b78 100644
--- a/tests/core/test_maker.py
+++ b/tests/core/test_maker.py
@@ -253,5 +253,8 @@ def bad_func(_: Maker) -> int:
         return 1

     # test bad recursive call
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError,
+        match="Function must return a Maker object. Got instead.",
+    ):
         recursive_call(maker, bad_func)
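The `AddSchema(**{"result": a + b})` → `AddSchema(result=a + b)` change in test_job.py is a general cleanup: unpacking a dict literal of constant string keys is just keyword arguments in disguise. A dependency-free sketch with a dataclass standing in for the pydantic model:

```python
from dataclasses import dataclass


@dataclass
class AddResult:
    result: int


# Equivalent objects, but the keyword form lets linters and type
# checkers see (and verify) the argument name:
assert AddResult(**{"result": 1 + 2}) == AddResult(result=1 + 2)
```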
diff --git a/tests/core/test_reference.py b/tests/core/test_reference.py
index ae05e65d..525f1837 100644
--- a/tests/core/test_reference.py
+++ b/tests/core/test_reference.py
@@ -11,7 +11,9 @@ def test_access():
     assert ref.attributes == ()

     # test bad init
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="Unrecognised attribute type 'x' for attribute '1'"
+    ):
         OutputReference("123", (("x", 1),))

     new_ref = ref.a
@@ -162,7 +164,7 @@ class MediumSchema(BaseModel):
         s: str
         nested: InnerSchema
         nested_opt: InnerSchema = None
-        nested_u: Union[InnerSchema, dict]
+        nested_u: Union[InnerSchema, dict]  # noqa: FA100
         nested_l: list[InnerSchema]
         nested_d: dict[str, InnerSchema]

@@ -224,7 +226,7 @@ def test_resolve(memory_jobstore):
     assert ref.resolve(memory_jobstore, on_missing=OnMissing.NONE) is None
     assert ref.resolve(memory_jobstore, on_missing=OnMissing.PASS) == ref

-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Could not resolve reference"):
         ref.resolve(memory_jobstore, on_missing=OnMissing.ERROR)

     # resolve using store
@@ -308,7 +310,7 @@ def test_resolve_references(memory_jobstore):
     assert output[ref2] == ref2

     ref2 = OutputReference("12345")
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Could not resolve reference"):
         resolve_references([ref1, ref2], memory_jobstore, on_missing=OnMissing.ERROR)

     # resolve using store and empty cache
@@ -371,7 +373,7 @@ def test_find_and_resolve_references(memory_jobstore):
     memory_jobstore.update({"uuid": "1234", "index": 1, "output": {"a": "xyz", "b": 5}})

     # test no reference
-    assert find_and_resolve_references(True, memory_jobstore) is True
+    assert find_and_resolve_references(arg=True, store=memory_jobstore) is True
     assert find_and_resolve_references("xyz", memory_jobstore) == "xyz"
     assert find_and_resolve_references([101], memory_jobstore) == [101]

@@ -425,7 +427,7 @@ def test_find_and_resolve_references(memory_jobstore):
     )
     assert output == [101, None]

-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Could not resolve reference"):
         find_and_resolve_references(
             [ref1, ref3], memory_jobstore, on_missing=OnMissing.ERROR
         )
@@ -461,7 +463,7 @@ def test_reference_in_output(memory_jobstore):
     memory_jobstore.update(task_data)
     assert ref1.resolve(memory_jobstore, on_missing=OnMissing.NONE) is None
     assert ref1.resolve(memory_jobstore, on_missing=OnMissing.PASS) == ref2
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Could not resolve reference"):
         ref1.resolve(memory_jobstore, on_missing=OnMissing.ERROR)


@@ -473,6 +475,6 @@ def test_not_iterable():
     with pytest.raises(TypeError):
         next(ref)

-    with pytest.raises(TypeError):
+    with pytest.raises(TypeError):  # noqa: PT012
         for _ in ref:
             pass
diff --git a/tests/core/test_store.py b/tests/core/test_store.py
index 72a3ef95..2ddd78a7 100644
--- a/tests/core/test_store.py
+++ b/tests/core/test_store.py
@@ -1,7 +1,12 @@
+from typing import TYPE_CHECKING
+
 import pytest

+if TYPE_CHECKING:
+    from jobflow.core.store import JobStore
+

-@pytest.fixture
+@pytest.fixture()
 def memory_store():
     from maggma.stores import MemoryStore

@@ -125,10 +130,12 @@ def test_data_update(memory_data_jobstore):
     # test bad store name fails
     results["data"]["store"] = "bad_store"
     memory_data_jobstore.update(results)
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Unrecognised additional store name"):
         memory_data_jobstore.query_one(c, load=True)

-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="Unrecognised additional store name: bad_store"
+    ):
         memory_data_jobstore.update(d, save={"bad_store": "data"})

     d = {"index": 2, "uuid": 2, "e": 6, "x": 4, "data2": [1, 2, 3]}
@@ -229,8 +236,6 @@ class MyEnum(ValueEnum):
 def test_nested_msonable(memory_data_jobstore):
     from monty.json import MSONable

-    from jobflow.core.store import JobStore
-
     class Child(MSONable):
         def __init__(self, x):
             self.x = x
@@ -368,10 +373,10 @@ def test_get_output(memory_jobstore):
     output = memory_jobstore.get_output("1", which=3)
     assert output == 123

-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="UUID: 1 has no outputs"):
         memory_jobstore.get_output(1, which="first")

-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="UUID: 1 has no outputs"):
         memory_jobstore.get_output(1, which="all")

     # test resolving reference in output of job
@@ -388,7 +393,9 @@ def test_get_output(memory_jobstore):
     # test missing reference in output of job
     r = {"@module": "jobflow.core.reference", "@class": "OutputReference", "uuid": "a"}
     memory_jobstore.update({"uuid": "8", "index": 1, "output": r})
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="Could not resolve reference - a not in store or"
+    ):
         memory_jobstore.get_output("8", on_missing=OnMissing.ERROR)

     assert memory_jobstore.get_output("8", on_missing=OnMissing.NONE) is None
@@ -430,7 +437,7 @@ def test_from_db_file(test_data):
     assert data_store.name == "gridfs://localhost/jobflow_unittest/outputs_blobs"

     # test bad file
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Unrecognised database file format"):
         JobStore.from_file(test_data / "db_bad.yaml")
diff --git a/tests/managers/test_fireworks.py b/tests/managers/test_fireworks.py
index 732c67be..bc8ce1d8 100644
--- a/tests/managers/test_fireworks.py
+++ b/tests/managers/test_fireworks.py
@@ -92,7 +92,7 @@ def test_job_to_firework(
     assert type(fw) == Firework
     assert fw.name == "func"

-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Both or neither of"):
         job_to_firework(job2, memory_jobstore, parents=[job.uuid])


@@ -356,7 +356,7 @@ def test_detour_flow(lpad, mongo_jobstore, fw_dir, detour_flow, capsys):

     wf = lpad.get_wf_by_fw_id(fw_id)
     uuids = [fw.tasks[0]["job"].uuid for fw in wf.fws]
-    uuid2 = next(u for u in uuids if u != uuid1 and u != uuid3)
+    uuid2 = next(u for u in uuids if u not in {uuid1, uuid3})
     assert all(s == "COMPLETED" for s in wf.fw_states.values())

     # check store has the activity output
@@ -565,7 +565,7 @@ def test_detour_stop_flow(lpad, mongo_jobstore, fw_dir, detour_stop_flow, capsys

     wf = lpad.get_wf_by_fw_id(fw_id)
     uuids = [fw.tasks[0]["job"].uuid for fw in wf.fws]
-    uuid2 = next(u for u in uuids if u != uuid1 and u != uuid3)
+    uuid2 = next(u for u in uuids if u not in {uuid1, uuid3})

     # Sort by firework id explicitly instead of assuming they are sorted
     states_dict = dict(zip(list(wf.id_fw.keys()), list(wf.fw_states.values())))
@@ -604,7 +604,7 @@ def test_replace_and_detour_flow(

     wf = lpad.get_wf_by_fw_id(fw_id)
     uuids = [fw.tasks[0]["job"].uuid for fw in wf.fws]
-    uuid2 = next(u for u in uuids if u != uuid1 and u != uuid3)
+    uuid2 = next(u for u in uuids if u not in {uuid1, uuid3})
     assert all(s == "COMPLETED" for s in wf.fw_states.values())

diff --git a/tests/managers/test_local.py b/tests/managers/test_local.py
index e68a967d..aad435c1 100644
--- a/tests/managers/test_local.py
+++ b/tests/managers/test_local.py
@@ -150,7 +150,7 @@ def test_detour_flow(memory_jobstore, clean_dir, detour_flow):

     # run with log
     responses = run_locally(flow, store=memory_jobstore)
-    uuid2 = next(u for u in responses if u != uuid1 and u != uuid3)
+    uuid2 = next(u for u in responses if u not in {uuid1, uuid3})

     # check responses has been filled
     assert len(responses) == 3
@@ -398,7 +398,7 @@ def test_detour_stop_flow(memory_jobstore, clean_dir, detour_stop_flow):

     # run with log
     responses = run_locally(flow, store=memory_jobstore)
-    uuid2 = next(u for u in responses if u != uuid1 and u != uuid3)
+    uuid2 = next(u for u in responses if u not in {uuid1, uuid3})

     # check responses has been filled
     assert len(responses) == 2
diff --git a/tests/test_version.py b/tests/test_version.py
index b5b2bd5e..bcf3c8ef 100644
--- a/tests/test_version.py
+++ b/tests/test_version.py
@@ -5,5 +5,5 @@ def test_installed_version():
     import jobflow._version
     from jobflow import __version__

-    assert re.match(r"^\d+\.\d+\.\d+$", __version__)
+    assert re.match(r"^\d+\.\d+\.\d+.*", __version__)
     assert __version__ == jobflow._version.__version__
diff --git a/tests/utils/test_dict_mods.py b/tests/utils/test_dict_mods.py
index 29b29b07..9069788f 100644
--- a/tests/utils/test_dict_mods.py
+++ b/tests/utils/test_dict_mods.py
@@ -51,7 +51,8 @@ def test_apply_mod():
     assert e == {"List": 3}

     mod = {"_add_to_set": {"number": 3}}
-    with pytest.raises(ValueError):
+    expected_err_msg = "Keyword number does not refer to an array"
+    with pytest.raises(ValueError, match=expected_err_msg):
         apply_mod(mod, d)

     mod = {"_pull": {"List": 1}}
@@ -63,7 +64,7 @@ def test_apply_mod():
     assert d == {"Bye": "World", "List": [2, 3], "number": 10}

     mod = {"_pull": {"number": 3}}
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match=expected_err_msg):
         apply_mod(mod, d)

     mod = {"_pull_all": {"List": [2, 3]}}
@@ -71,7 +72,7 @@ def test_apply_mod():
     assert d == {"Bye": "World", "List": [], "number": 10}

     mod = {"_pull_all": {"number": 3}}
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match=expected_err_msg):
         apply_mod(mod, d)

     mod = {"_push_all": {"List": list(range(10))}}
@@ -95,7 +96,7 @@ def test_apply_mod():
     assert d == {"Bye": "World", "List": [1, 2, 3, 4, 5, 6, 7, 8], "number": 10}

     mod = {"_pop": {"number": -1}}
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match=expected_err_msg):
         apply_mod(mod, d)

     d = {}
@@ -148,7 +149,7 @@ def test_apply_mod():
     assert d == {"a": {"b": {"c": 102}, "e": {"f": [201, 301]}}}

     mod = {"_abcd": {"a": "b"}}
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="_abcd is not a supported action"):
         apply_mod(mod, d)

     mod = {"_set": {"": ""}}
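The tests/test_version.py hunk above swaps the `$` anchor for `.*` so suffixed versions (e.g. a setuptools-scm post/dev tag, an assumption here) still pass. The behavioral difference:

```python
import re

strict = r"^\d+\.\d+\.\d+$"
loose = r"^\d+\.\d+\.\d+.*"

assert re.match(strict, "0.1.9")
assert not re.match(strict, "0.1.9.post5")  # old pattern rejected suffixes
assert re.match(loose, "0.1.9.post5")       # new pattern accepts them
```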
diff --git a/tests/utils/test_find.py b/tests/utils/test_find.py
index 5a5160a1..18e97e0e 100644
--- a/tests/utils/test_find.py
+++ b/tests/utils/test_find.py
@@ -62,18 +62,18 @@ def test_contains_job_or_flow():
     job = Job(str)
     flow = Flow([])

-    assert contains_flow_or_job(True) is False
-    assert contains_flow_or_job(1) is False
-    assert contains_flow_or_job("abc") is False
-    assert contains_flow_or_job(job) is True
-    assert contains_flow_or_job(flow) is True
-    assert contains_flow_or_job([flow]) is True
-    assert contains_flow_or_job([[flow]]) is True
-    assert contains_flow_or_job({"a": flow}) is True
-    assert contains_flow_or_job({"a": [flow]}) is True
-    assert contains_flow_or_job(job) is True
-    assert contains_flow_or_job([job]) is True
-    assert contains_flow_or_job([[job]]) is True
-    assert contains_flow_or_job({"a": job}) is True
-    assert contains_flow_or_job({"a": [job]}) is True
-    assert contains_flow_or_job({"a": [job], "b": datetime.now()}) is True
+    assert contains_flow_or_job(obj=True) is False
+    assert contains_flow_or_job(obj=1) is False
+    assert contains_flow_or_job(obj="abc") is False
+    assert contains_flow_or_job(obj=job) is True
+    assert contains_flow_or_job(obj=flow) is True
+    assert contains_flow_or_job(obj=[flow]) is True
+    assert contains_flow_or_job(obj=[[flow]]) is True
+    assert contains_flow_or_job(obj={"a": flow}) is True
+    assert contains_flow_or_job(obj={"a": [flow]}) is True
+    assert contains_flow_or_job(obj=job) is True
+    assert contains_flow_or_job(obj=[job]) is True
+    assert contains_flow_or_job(obj=[[job]]) is True
+    assert contains_flow_or_job(obj={"a": job}) is True
+    assert contains_flow_or_job(obj={"a": [job]}) is True
+    assert contains_flow_or_job(obj={"a": [job], "b": datetime.now()}) is True
diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py
index 5f093df0..0b831645 100644
--- a/tests/utils/test_graph.py
+++ b/tests/utils/test_graph.py
@@ -16,24 +16,19 @@ def test_itergraph():

     # test branched
     graph = DiGraph([("a", "b"), ("b", "c"), ("a", "c"), ("d", "b")])
     result = list(itergraph(graph))
-    assert result == ["a", "d", "b", "c"] or result == ["d", "a", "b", "c"]
+    assert result in (["a", "d", "b", "c"], ["d", "a", "b", "c"])

     # test non-connected
     graph = DiGraph([("a", "b"), ("c", "d")])
     with pytest.warns(UserWarning):
         result = list(itergraph(graph))
-    assert (
-        result == ["a", "b", "c", "d"]
-        or result == ["c", "d", "a", "b"]
-        or result == ["a", "c", "b", "d"]
-        or result == ["a", "c", "d", "b"]
-        or result == ["c", "a", "b", "d"]
-        or result == ["c", "a", "d", "b"]
-    )
+    assert {*result} == {"a", "b", "c", "d"}

     # test non DAG
     graph = DiGraph([("a", "b"), ("b", "a")])
-    with pytest.raises(ValueError):
+    with pytest.raises(
+        ValueError, match="Graph is not acyclic, cannot determine dependency order"
+    ):
         list(itergraph(graph))
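The final test_graph.py hunk replaces a six-way enumeration of valid orderings with a set comparison; note this deliberately trades the ordering constraints for brevity. The idea in miniature:

```python
result = ["c", "a", "d", "b"]  # one valid order of a disconnected graph

# Before: `result == [...] or result == [...]` for every permutation.
# After: assert only the contents, since the order is not guaranteed.
assert {*result} == {"a", "b", "c", "d"}
```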