Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Match atomate2 ruff config #464

Merged
merged 9 commits into from
Oct 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ default_language_version:
exclude: "^src/atomate2/vasp/schemas/calc_types/"
repos:
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.1.0
rev: v0.1.1
hooks:
- id: ruff
args: [--fix]
Expand Down Expand Up @@ -46,7 +46,7 @@ repos:
- id: rst-directive-colons
- id: rst-inline-touching-normal
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.6.0
rev: v1.6.1
hooks:
- id: mypy
files: ^src/
Expand Down
3 changes: 1 addition & 2 deletions examples/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ def generate_big_data():
The data=True in the job decorator tells jobflow to store all outputs in the "data"
additional store.
"""
mydata = list(range(1000))
return mydata
return list(range(1000))


big_data_job = generate_big_data()
Expand Down
64 changes: 47 additions & 17 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,26 +121,56 @@ exclude_lines = [
target-version = "py39"
ignore-init-module-imports = true
select = [
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"D", # pydocstyle
"E", # pycodestyle
"F", # pyflakes
"I", # isort
"PLE", # pylint error
"PLW", # pylint warning
"Q", # flake8-quotes
"RUF", # Ruff-specific rules
"SIM", # flake8-simplify
"TID", # tidy imports
"UP", # pyupgrade
"W", # pycodestyle
"YTT", # flake8-2020
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"D", # pydocstyle
"E", # pycodestyle error
"EXE", # flake8-executable
"F", # pyflakes
"FA", # flake8-future-annotations
"FBT003", # boolean-positional-value-in-call
"FLY", # flynt
"I", # isort
"ICN", # flake8-import-conventions
"ISC", # flake8-implicit-str-concat
"PD", # pandas-vet
"PERF", # perflint
"PIE", # flake8-pie
"PL", # pylint
"PT", # flake8-pytest-style
"PYI", # flakes8-pyi
"Q", # flake8-quotes
"RET", # flake8-return
"RSE", # flake8-raise
"RUF", # Ruff-specific rules
"SIM", # flake8-simplify
"SLOT", # flake8-slots
"TCH", # flake8-type-checking
"TID", # flake8-tidy-imports
"UP", # pyupgrade
"W", # pycodestyle warning
"YTT", # flake8-2020
]
# PLR0911: too-many-return-statements
# PLR0912: too-many-branches
# PLR0913: too-many-arguments
# PLR0915: too-many-statements
ignore = [
"B028",
"PLR0911",
"PLR0912",
"PLR0913",
"PLR0915",
"PLW0603",
"RUF013",
]
ignore = ["B028", "PLW0603", "RUF013"]
pydocstyle.convention = "numpy"
isort.known-first-party = ["jobflow"]

[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"]
"**/tests/*" = ["B018", "D"]
# PLR2004: magic-value-comparison
# PT004: pytest-missing-fixture-name-underscore
# PLR0915: too-many-statements
# PT011: pytest-raises-too-broad TODO fix these, should not be ignored
"**/tests/*" = ["B018", "D", "PLR0915", "PLR2004", "PT004", "PT011"]
3 changes: 1 addition & 2 deletions src/jobflow/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import logging
import warnings
from collections.abc import Sequence
from copy import deepcopy
from typing import TYPE_CHECKING

Expand All @@ -15,7 +14,7 @@
from jobflow.utils import ValueEnum, contains_flow_or_job, suuid

if TYPE_CHECKING:
from collections.abc import Iterator
from collections.abc import Iterator, Sequence
from typing import Any, Callable

from networkx import DiGraph
Expand Down
29 changes: 14 additions & 15 deletions src/jobflow/core/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import contextlib
import typing
from collections.abc import Sequence
from typing import Any

from monty.json import MontyDecoder, MontyEncoder, MSONable, jsanitize
Expand All @@ -14,6 +13,8 @@
from jobflow.utils.enum import ValueEnum

if typing.TYPE_CHECKING:
from collections.abc import Sequence

import jobflow


Expand Down Expand Up @@ -88,14 +89,14 @@ def __init__(
uuid: str,
attributes: tuple[tuple[str, Any], ...] = (),
output_schema: type[BaseModel] = None,
):
) -> None:
super().__init__()
self.uuid = uuid
self.attributes = attributes
self.output_schema = output_schema

for attr_type, attr in attributes:
if attr_type not in ("a", "i"):
if attr_type not in {"a", "i"}:
raise ValueError(
f"Unrecognised attribute type '{attr_type}' for attribute '{attr}'"
)
Expand Down Expand Up @@ -161,9 +162,9 @@ def resolve(
f"Could not resolve reference - {self.uuid}{istr} not in store or "
f"{index=}, {cache=}"
)
elif on_missing == OnMissing.NONE and index not in cache[self.uuid]:
if on_missing == OnMissing.NONE and index not in cache[self.uuid]:
return None
elif on_missing == OnMissing.PASS and index not in cache[self.uuid]:
if on_missing == OnMissing.PASS and index not in cache[self.uuid]:
return self

data = cache[self.uuid][index]
Expand Down Expand Up @@ -200,12 +201,11 @@ def set_uuid(self, uuid: str, inplace=True) -> OutputReference:
if inplace:
self.uuid = uuid
return self
else:
from copy import deepcopy
from copy import deepcopy

new_reference = deepcopy(self)
new_reference.uuid = uuid
return new_reference
new_reference = deepcopy(self)
new_reference.uuid = uuid
return new_reference

def __getitem__(self, item) -> OutputReference:
"""Index the reference."""
Expand Down Expand Up @@ -263,7 +263,7 @@ def __hash__(self) -> int:
"""Return a hash of the reference."""
return hash(str(self))

def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
"""Test for equality against another reference."""
if isinstance(other, OutputReference):
return (
Expand All @@ -285,15 +285,14 @@ def as_dict(self):
"""Serialize the reference as a dict."""
schema = self.output_schema
schema_dict = MontyEncoder().default(schema) if schema is not None else None
data = {
return {
"@module": self.__class__.__module__,
"@class": type(self).__name__,
"@version": None,
"uuid": self.uuid,
"attributes": self.attributes,
"output_schema": schema_dict,
}
return data


def resolve_references(
Expand Down Expand Up @@ -376,7 +375,7 @@ def find_and_get_references(arg: Any) -> tuple[OutputReference, ...]:
# if the argument is a reference then stop there
return (arg,)

elif isinstance(arg, (float, int, str, bool)):
if isinstance(arg, (float, int, str, bool)):
# argument is a primitive, we won't find a reference here
return ()

Expand Down Expand Up @@ -432,7 +431,7 @@ def find_and_resolve_references(
# if the argument is a reference then stop there
return arg.resolve(store, cache=cache, on_missing=on_missing)

elif isinstance(arg, (float, int, str, bool)):
if isinstance(arg, (float, int, str, bool)):
# argument is a primitive, we won't find a reference here
return arg

Expand Down
2 changes: 0 additions & 2 deletions src/jobflow/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from monty.design_patterns import singleton

if typing.TYPE_CHECKING:
pass

import jobflow


Expand Down
43 changes: 21 additions & 22 deletions src/jobflow/core/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from monty.json import MSONable

from jobflow.core.reference import OnMissing
from jobflow.schemas.job_output_schema import JobStoreDocument
from jobflow.utils.find import get_root_locations

if typing.TYPE_CHECKING:
Expand All @@ -19,6 +18,8 @@

from maggma.core import Sort

from jobflow.schemas.job_output_schema import JobStoreDocument

obj_type = Union[str, Enum, type[MSONable], list[Union[Enum, str, type[MSONable]]]]
save_type = Optional[dict[str, obj_type]]
load_type = Union[bool, dict[str, Union[bool, obj_type]]]
Expand Down Expand Up @@ -250,8 +251,7 @@ def query_one(
docs = self.query(
criteria=criteria, properties=properties, load=load, sort=sort, limit=1
)
d = next(docs, None)
return d
return next(docs, None)

def update(
self,
Expand Down Expand Up @@ -496,7 +496,7 @@ def get_output(
# this could be fixed but will require more complicated logic just to
# catch a very unlikely event.

if isinstance(which, int) or which in ("last", "first"):
if isinstance(which, int) or which in {"last", "first"}:
sort = -1 if which == "last" else 1

criteria: dict[str, Any] = {"uuid": uuid}
Expand All @@ -522,28 +522,27 @@ def get_output(
return find_and_resolve_references(
result["output"], self, cache=cache, on_missing=on_missing
)
else:
results = list(
self.query(
criteria={"uuid": uuid},
properties=["output"],
sort={"index": 1},
load=load,
)
results = list(
self.query(
criteria={"uuid": uuid},
properties=["output"],
sort={"index": 1},
load=load,
)
)

if len(results) == 0:
raise ValueError(f"UUID: {uuid} has no outputs.")
if len(results) == 0:
raise ValueError(f"UUID: {uuid} has no outputs.")

results = [r["output"] for r in results]
results = [r["output"] for r in results]

refs = find_and_get_references(results)
if any(ref.uuid == uuid for ref in refs):
raise RuntimeError("Reference cycle detected - aborting.")
refs = find_and_get_references(results)
if any(ref.uuid == uuid for ref in refs):
raise RuntimeError("Reference cycle detected - aborting.")

return find_and_resolve_references(
results, self, cache=cache, on_missing=on_missing
)
return find_and_resolve_references(
results, self, cache=cache, on_missing=on_missing
)

@classmethod
def from_file(cls: type[T], db_file: str | Path, **kwargs) -> T:
Expand Down Expand Up @@ -761,7 +760,7 @@ def _group_blobs(infos, locs):

new_blobs = []
new_locations = []
for _store_name, store_load in load.items():
for store_load in load.values():
for blob, location in zip(blob_infos, locations):
if store_load is True:
new_blobs.append(blob)
Expand Down
5 changes: 2 additions & 3 deletions src/jobflow/managers/fireworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from collections.abc import Sequence

import jobflow
from jobflow.core.job import Job


def flow_to_workflow(
Expand Down Expand Up @@ -146,7 +147,6 @@ class JobFiretask(FiretaskBase):
def run_task(self, fw_spec):
"""Run the job and handle any dynamic firework submissions."""
from jobflow import SETTINGS, initialize_logger
from jobflow.core.job import Job

job: Job = self.get("job")
store = self.get("store")
Expand Down Expand Up @@ -190,11 +190,10 @@ def run_task(self, fw_spec):
else:
detours = [detour_wf]

fwa = FWAction(
return FWAction(
stored_data=response.stored_data,
detours=detours,
additions=additions,
defuse_workflow=response.stop_jobflow,
defuse_children=response.stop_children,
)
return fwa
8 changes: 2 additions & 6 deletions src/jobflow/managers/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import typing

if typing.TYPE_CHECKING:
pass

import jobflow


Expand Down Expand Up @@ -133,17 +131,15 @@ def _run_job(job: jobflow.Job, parents):

if not all(diversion_responses):
return None, False
else:
return response, False
return response, False

def _get_job_dir():
if create_folders:
time_now = datetime.utcnow().strftime(SETTINGS.DIRECTORY_FORMAT)
job_dir = root_dir / f"job_{time_now}-{randint(10000, 99999)}"
job_dir.mkdir()
return job_dir
else:
return root_dir
return root_dir

def _run(root_flow):
encountered_bad_response = False
Expand Down
3 changes: 1 addition & 2 deletions src/jobflow/utils/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ def __eq__(self, other):
"""Compare to another enum for equality."""
if type(self) == type(other) and self.value == other.value:
return True
else:
return str(self.value) == str(other)
return str(self.value) == str(other)

def __hash__(self):
"""Get a hash of the enum."""
Expand Down
Loading