Merge pull request #464 from materialsproject/match-atomate2-ruff-config
Match `atomate2` `ruff` config
janosh authored Oct 21, 2023
2 parents abca261 + a34e772 commit d7e6760
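
The change bumps the `ruff` and `mypy` pre-commit hooks, expands the `ruff` rule selection in `pyproject.toml` to match `atomate2`'s, and applies mostly mechanical code fixes across the package to satisfy the newly enabled rules.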
Showing 25 changed files with 205 additions and 178 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -3,7 +3,7 @@ default_language_version:
 exclude: "^src/atomate2/vasp/schemas/calc_types/"
 repos:
   - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: v0.1.0
+    rev: v0.1.1
     hooks:
       - id: ruff
         args: [--fix]
@@ -46,7 +46,7 @@ repos:
       - id: rst-directive-colons
       - id: rst-inline-touching-normal
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.6.0
+    rev: v1.6.1
     hooks:
      - id: mypy
        files: ^src/
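
The two `rev` bumps above are routine hook updates of the kind `pre-commit autoupdate` produces; they keep the `ruff` and `mypy` hooks current with the expanded rule selection in `pyproject.toml` below.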
3 changes: 1 addition & 2 deletions examples/data_store.py
@@ -17,8 +17,7 @@ def generate_big_data():
     The data=True in the job decorator tells jobflow to store all outputs in the "data"
     additional store.
     """
-    mydata = list(range(1000))
-    return mydata
+    return list(range(1000))
 
 
 big_data_job = generate_big_data()
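
For readers skimming the diff, the surrounding example (truncated here) demonstrates jobflow's additional-store mechanism. A condensed sketch of how the visible pieces fit together — the decorator line is implied by the docstring above, and the docstring summary line is assumed:

```python
from jobflow import job


@job(data=True)  # per the docstring: store all outputs in the "data" additional store
def generate_big_data():
    """Generate some data to demonstrate additional stores."""
    return list(range(1000))


# Calling the decorated function creates a Job; the list is only produced
# when the job runs, at which point a store named "data" must be configured.
big_data_job = generate_big_data()
```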
64 changes: 47 additions & 17 deletions pyproject.toml
@@ -121,26 +121,56 @@ exclude_lines = [
 target-version = "py39"
 ignore-init-module-imports = true
 select = [
-    "B",   # flake8-bugbear
-    "C4",  # flake8-comprehensions
-    "D",   # pydocstyle
-    "E",   # pycodestyle
-    "F",   # pyflakes
-    "I",   # isort
-    "PLE", # pylint error
-    "PLW", # pylint warning
-    "Q",   # flake8-quotes
-    "RUF", # Ruff-specific rules
-    "SIM", # flake8-simplify
-    "TID", # tidy imports
-    "UP",  # pyupgrade
-    "W",   # pycodestyle
-    "YTT", # flake8-2020
+    "B",      # flake8-bugbear
+    "C4",     # flake8-comprehensions
+    "D",      # pydocstyle
+    "E",      # pycodestyle error
+    "EXE",    # flake8-executable
+    "F",      # pyflakes
+    "FA",     # flake8-future-annotations
+    "FBT003", # boolean-positional-value-in-call
+    "FLY",    # flynt
+    "I",      # isort
+    "ICN",    # flake8-import-conventions
+    "ISC",    # flake8-implicit-str-concat
+    "PD",     # pandas-vet
+    "PERF",   # perflint
+    "PIE",    # flake8-pie
+    "PL",     # pylint
+    "PT",     # flake8-pytest-style
+    "PYI",    # flake8-pyi
+    "Q",      # flake8-quotes
+    "RET",    # flake8-return
+    "RSE",    # flake8-raise
+    "RUF",    # Ruff-specific rules
+    "SIM",    # flake8-simplify
+    "SLOT",   # flake8-slots
+    "TCH",    # flake8-type-checking
+    "TID",    # flake8-tidy-imports
+    "UP",     # pyupgrade
+    "W",      # pycodestyle warning
+    "YTT",    # flake8-2020
 ]
-ignore = ["B028", "PLW0603", "RUF013"]
+# PLR0911: too-many-return-statements
+# PLR0912: too-many-branches
+# PLR0913: too-many-arguments
+# PLR0915: too-many-statements
+ignore = [
+    "B028",
+    "PLR0911",
+    "PLR0912",
+    "PLR0913",
+    "PLR0915",
+    "PLW0603",
+    "RUF013",
+]
 pydocstyle.convention = "numpy"
 isort.known-first-party = ["jobflow"]
 
 [tool.ruff.per-file-ignores]
 "__init__.py" = ["F401"]
-"**/tests/*" = ["B018", "D"]
+# PLR0915: too-many-statements
+# PLR2004: magic-value-comparison
+# PT004: pytest-missing-fixture-name-underscore
+# PT011: pytest-raises-too-broad TODO fix these, should not be ignored
+"**/tests/*" = ["B018", "D", "PLR0915", "PLR2004", "PT004", "PT011"]
3 changes: 1 addition & 2 deletions src/jobflow/core/flow.py
@@ -4,7 +4,6 @@
 
 import logging
 import warnings
-from collections.abc import Sequence
 from copy import deepcopy
 from typing import TYPE_CHECKING
 
@@ -15,7 +14,7 @@
 from jobflow.utils import ValueEnum, contains_flow_or_job, suuid
 
 if TYPE_CHECKING:
-    from collections.abc import Iterator
+    from collections.abc import Iterator, Sequence
     from typing import Any, Callable
 
     from networkx import DiGraph
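
Moving `Sequence` under `if TYPE_CHECKING:` is the `TCH` (flake8-type-checking) pattern: imports needed only for annotations are deferred to type-checking time, avoiding runtime import cost and import cycles. It is safe when `from __future__ import annotations` (which the `FA` selection encourages) keeps annotations unevaluated at runtime. A generic, self-contained sketch:

```python
from __future__ import annotations  # annotations stay unevaluated at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only while a type checker (mypy, pyright) analyses the file;
    # at runtime this block is skipped entirely.
    from collections.abc import Iterator, Sequence


def first_two(items: Sequence[int]) -> Iterator[int]:
    """Yield the first two items of a sequence."""
    yield from items[:2]
```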
29 changes: 14 additions & 15 deletions src/jobflow/core/reference.py
@@ -4,7 +4,6 @@
 
 import contextlib
 import typing
-from collections.abc import Sequence
 from typing import Any
 
 from monty.json import MontyDecoder, MontyEncoder, MSONable, jsanitize
@@ -14,6 +13,8 @@
 from jobflow.utils.enum import ValueEnum
 
 if typing.TYPE_CHECKING:
+    from collections.abc import Sequence
+
     import jobflow
 
 
@@ -88,14 +89,14 @@ def __init__(
         uuid: str,
         attributes: tuple[tuple[str, Any], ...] = (),
         output_schema: type[BaseModel] = None,
-    ):
+    ) -> None:
         super().__init__()
         self.uuid = uuid
         self.attributes = attributes
         self.output_schema = output_schema
 
         for attr_type, attr in attributes:
-            if attr_type not in ("a", "i"):
+            if attr_type not in {"a", "i"}:
                 raise ValueError(
                     f"Unrecognised attribute type '{attr_type}' for attribute '{attr}'"
                 )
@@ -161,9 +162,9 @@ def resolve(
                 f"Could not resolve reference - {self.uuid}{istr} not in store or "
                 f"{index=}, {cache=}"
             )
-        elif on_missing == OnMissing.NONE and index not in cache[self.uuid]:
+        if on_missing == OnMissing.NONE and index not in cache[self.uuid]:
             return None
-        elif on_missing == OnMissing.PASS and index not in cache[self.uuid]:
+        if on_missing == OnMissing.PASS and index not in cache[self.uuid]:
             return self
 
         data = cache[self.uuid][index]
@@ -200,12 +201,11 @@ def set_uuid(self, uuid: str, inplace=True) -> OutputReference:
         if inplace:
             self.uuid = uuid
             return self
-        else:
-            from copy import deepcopy
+        from copy import deepcopy
 
-            new_reference = deepcopy(self)
-            new_reference.uuid = uuid
-            return new_reference
+        new_reference = deepcopy(self)
+        new_reference.uuid = uuid
+        return new_reference
 
     def __getitem__(self, item) -> OutputReference:
         """Index the reference."""
@@ -263,7 +263,7 @@ def __hash__(self) -> int:
         """Return a hash of the reference."""
         return hash(str(self))
 
-    def __eq__(self, other: Any) -> bool:
+    def __eq__(self, other: object) -> bool:
         """Test for equality against another reference."""
         if isinstance(other, OutputReference):
             return (
@@ -285,15 +285,14 @@ def as_dict(self):
         """Serialize the reference as a dict."""
         schema = self.output_schema
         schema_dict = MontyEncoder().default(schema) if schema is not None else None
-        data = {
+        return {
             "@module": self.__class__.__module__,
             "@class": type(self).__name__,
             "@version": None,
             "uuid": self.uuid,
             "attributes": self.attributes,
             "output_schema": schema_dict,
         }
-        return data
 
 
 def resolve_references(
@@ -376,7 +375,7 @@ def find_and_get_references(arg: Any) -> tuple[OutputReference, ...]:
         # if the argument is a reference then stop there
         return (arg,)
 
-    elif isinstance(arg, (float, int, str, bool)):
+    if isinstance(arg, (float, int, str, bool)):
         # argument is a primitive, we won't find a reference here
         return ()
 
@@ -432,7 +431,7 @@ def find_and_resolve_references(
         # if the argument is a reference then stop there
        return arg.resolve(store, cache=cache, on_missing=on_missing)
 
-    elif isinstance(arg, (float, int, str, bool)):
+    if isinstance(arg, (float, int, str, bool)):
        # argument is a primitive, we won't find a reference here
        return arg
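
The `reference.py` changes are the same rule families at work: `elif` after `return` flattened to plain `if` (flake8-return), `-> None` added to `__init__` and `other: Any` narrowed to `other: object` (the signatures type checkers expect for `__init__` and `__eq__`), membership tests against small literal collections written as sets (`{"a", "i"}`), and the `data = {...}` / `return data` pair collapsed into a single `return {...}`.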
2 changes: 0 additions & 2 deletions src/jobflow/core/state.py
@@ -15,8 +15,6 @@
 from monty.design_patterns import singleton
 
 if typing.TYPE_CHECKING:
-    pass
-
     import jobflow
 
 
43 changes: 21 additions & 22 deletions src/jobflow/core/store.py
@@ -8,7 +8,6 @@
 from monty.json import MSONable
 
 from jobflow.core.reference import OnMissing
-from jobflow.schemas.job_output_schema import JobStoreDocument
 from jobflow.utils.find import get_root_locations
 
 if typing.TYPE_CHECKING:
@@ -19,6 +18,8 @@
 
     from maggma.core import Sort
 
+    from jobflow.schemas.job_output_schema import JobStoreDocument
+
 obj_type = Union[str, Enum, type[MSONable], list[Union[Enum, str, type[MSONable]]]]
 save_type = Optional[dict[str, obj_type]]
 load_type = Union[bool, dict[str, Union[bool, obj_type]]]
@@ -250,8 +251,7 @@ def query_one(
         docs = self.query(
             criteria=criteria, properties=properties, load=load, sort=sort, limit=1
         )
-        d = next(docs, None)
-        return d
+        return next(docs, None)
 
     def update(
         self,
@@ -496,7 +496,7 @@ def get_output(
         # this could be fixed but will require more complicated logic just to
         # catch a very unlikely event.
 
-        if isinstance(which, int) or which in ("last", "first"):
+        if isinstance(which, int) or which in {"last", "first"}:
             sort = -1 if which == "last" else 1
 
             criteria: dict[str, Any] = {"uuid": uuid}
@@ -522,28 +522,27 @@ def get_output(
             return find_and_resolve_references(
                 result["output"], self, cache=cache, on_missing=on_missing
             )
-        else:
-            results = list(
-                self.query(
-                    criteria={"uuid": uuid},
-                    properties=["output"],
-                    sort={"index": 1},
-                    load=load,
-                )
-            )
+        results = list(
+            self.query(
+                criteria={"uuid": uuid},
+                properties=["output"],
+                sort={"index": 1},
+                load=load,
+            )
+        )
 
-            if len(results) == 0:
-                raise ValueError(f"UUID: {uuid} has no outputs.")
+        if len(results) == 0:
+            raise ValueError(f"UUID: {uuid} has no outputs.")
 
-            results = [r["output"] for r in results]
+        results = [r["output"] for r in results]
 
-            refs = find_and_get_references(results)
-            if any(ref.uuid == uuid for ref in refs):
-                raise RuntimeError("Reference cycle detected - aborting.")
+        refs = find_and_get_references(results)
+        if any(ref.uuid == uuid for ref in refs):
+            raise RuntimeError("Reference cycle detected - aborting.")
 
-            return find_and_resolve_references(
-                results, self, cache=cache, on_missing=on_missing
-            )
+        return find_and_resolve_references(
+            results, self, cache=cache, on_missing=on_missing
+        )
 
     @classmethod
     def from_file(cls: type[T], db_file: str | Path, **kwargs) -> T:
@@ -761,7 +760,7 @@ def _group_blobs(infos, locs):
 
         new_blobs = []
         new_locations = []
-        for _store_name, store_load in load.items():
+        for store_load in load.values():
             for blob, location in zip(blob_infos, locations):
                 if store_load is True:
                     new_blobs.append(blob)
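
Two patterns recur in `store.py`: the `JobStoreDocument` import moves into the `if typing.TYPE_CHECKING:` block (the `TCH` pattern sketched earlier, presumably because it is only used in annotations here), and `else` blocks after `return` are dedented away, which lets the `get_output` fallback path lose a level of nesting. The `load.items()` → `load.values()` change drops an unused loop variable instead of keeping the `_store_name` placeholder.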
5 changes: 2 additions & 3 deletions src/jobflow/managers/fireworks.py
@@ -10,6 +10,7 @@
     from collections.abc import Sequence
 
     import jobflow
+    from jobflow.core.job import Job
 
 
 def flow_to_workflow(
@@ -146,7 +147,6 @@ class JobFiretask(FiretaskBase):
     def run_task(self, fw_spec):
         """Run the job and handle any dynamic firework submissions."""
         from jobflow import SETTINGS, initialize_logger
-        from jobflow.core.job import Job
 
         job: Job = self.get("job")
         store = self.get("store")
@@ -190,11 +190,10 @@ def run_task(self, fw_spec):
         else:
             detours = [detour_wf]
 
-        fwa = FWAction(
+        return FWAction(
             stored_data=response.stored_data,
             detours=detours,
             additions=additions,
             defuse_workflow=response.stop_jobflow,
             defuse_children=response.stop_children,
         )
-        return fwa
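
Here the `Job` import is hoisted out of `run_task` to the module's import block (apparently into the type-checking block, since `Job` only appears in an annotation there), and the `fwa = FWAction(...)` / `return fwa` pair collapses into a direct `return FWAction(...)` — the same unnecessary-assign cleanup noted above.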
8 changes: 2 additions & 6 deletions src/jobflow/managers/local.py
@@ -6,8 +6,6 @@
 import typing
 
 if typing.TYPE_CHECKING:
-    pass
-
     import jobflow
 
 
@@ -133,17 +131,15 @@ def _run_job(job: jobflow.Job, parents):
 
         if not all(diversion_responses):
             return None, False
-        else:
-            return response, False
+        return response, False
 
     def _get_job_dir():
         if create_folders:
             time_now = datetime.utcnow().strftime(SETTINGS.DIRECTORY_FORMAT)
             job_dir = root_dir / f"job_{time_now}-{randint(10000, 99999)}"
             job_dir.mkdir()
             return job_dir
-        else:
-            return root_dir
+        return root_dir
 
     def _run(root_flow):
         encountered_bad_response = False
3 changes: 1 addition & 2 deletions src/jobflow/utils/enum.py
@@ -14,8 +14,7 @@ def __eq__(self, other):
         """Compare to another enum for equality."""
         if type(self) == type(other) and self.value == other.value:
             return True
-        else:
-            return str(self.value) == str(other)
+        return str(self.value) == str(other)
 
     def __hash__(self):
         """Get a hash of the enum."""
(The diff view loaded only the 10 files above; the remaining 15 changed files are not shown.)
