Skip to content

Commit

Permalink
switch to ruff (#466)
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri authored Jan 31, 2025
1 parent cd84743 commit e2b5e2d
Show file tree
Hide file tree
Showing 46 changed files with 72 additions and 107 deletions.
1 change: 0 additions & 1 deletion examples/ase/run_ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
# First, we start by importing the necessary libraries, including the integration of ASE
# calculators for metatensor atomistic models.


import ase.md
import ase.md.velocitydistribution
import ase.units
Expand Down
26 changes: 15 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,21 @@ source = [
".tox/*/lib/python*/site-packages/metatrain"
]

[tool.black]
exclude = 'docs/src/examples'

[tool.isort]
skip = "__init__.py"
profile = "black"
line_length = 88
indent = 4
include_trailing_comma = true
lines_after_imports = 2
known_first_party = "metatrain"
[tool.ruff]
exclude = ["docs/src/examples/**"]
line-length = 88

[tool.ruff.lint]
select = ["E", "F", "B", "I"]
ignore = ["B018", "B904"]

[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["metatrain"]
known-third-party = ["torch"]

[tool.ruff.format]
docstring-code-format = true

[tool.mypy]
exclude = [
Expand Down
3 changes: 2 additions & 1 deletion src/metatrain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path
import secrets
from pathlib import Path


PACKAGE_ROOT = Path(__file__).parent.resolve()

Expand Down
4 changes: 2 additions & 2 deletions src/metatrain/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,8 @@ def _eval_targets(
std_per_atom = np.std(timings_per_atom)
logger.info(
f"evaluation time: {total_time:.2f} s "
f"[{1000.0*mean_per_atom:.4f} ± "
f"{1000.0*std_per_atom:.4f} ms per atom]"
f"[{1000.0 * mean_per_atom:.4f} ± "
f"{1000.0 * std_per_atom:.4f} ms per atom]"
)

if return_predictions:
Expand Down
1 change: 1 addition & 0 deletions src/metatrain/experimental/gap/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .model import GAP
from .trainer import Trainer


__model__ = GAP
__trainer__ = Trainer

Expand Down
4 changes: 1 addition & 3 deletions src/metatrain/experimental/gap/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,6 @@ def forward(
return return_dict

def export(self) -> MetatensorAtomisticModel:

interaction_ranges = [self.hypers["soap"]["cutoff"]]
for additive_model in self.additive_models:
if hasattr(additive_model, "cutoff_radius"):
Expand Down Expand Up @@ -458,8 +457,7 @@ def partial_fit(self, KNM, Y, accumulate_only=False, rcond=None):
Phi = KNM
else:
raise ValueError(
"Partial fit can only be realized with "
"solver = 'RKHS' or 'solve'"
"Partial fit can only be realized with solver = 'RKHS' or 'solve'"
)
if self._KY is None:
self._KY = np.zeros((self._nM, Y.shape[1]))
Expand Down
1 change: 1 addition & 0 deletions src/metatrain/experimental/gap/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from metatrain.utils.architectures import get_default_hypers


DEFAULT_HYPERS = get_default_hypers("experimental.gap")
DATASET_PATH = str(Path(__file__).parents[5] / "tests/resources/qm9_reduced_100.xyz")

Expand Down
1 change: 1 addition & 0 deletions src/metatrain/experimental/nanopet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .model import NanoPET
from .trainer import Trainer


__model__ = NanoPET
__trainer__ = Trainer

Expand Down
2 changes: 0 additions & 2 deletions src/metatrain/experimental/nanopet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,6 @@ def requested_neighbor_lists(

@classmethod
def load_checkpoint(cls, path: Union[str, Path]) -> "NanoPET":

# Load the checkpoint
checkpoint = torch.load(path, weights_only=False, map_location="cpu")
model_data = checkpoint["model_data"]
Expand Down Expand Up @@ -539,7 +538,6 @@ def export(self) -> MetatensorAtomisticModel:
return MetatensorAtomisticModel(self.eval(), ModelMetadata(), capabilities)

def _add_output(self, target_name: str, target_info: TargetInfo) -> None:

# one output shape for each tensor block, grouped by target (i.e. tensormap)
self.output_shapes[target_name] = {}
for key, block in target_info.layout.items():
Expand Down
1 change: 0 additions & 1 deletion src/metatrain/experimental/nanopet/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def forward(
inputs: torch.Tensor, # seq_len hidden_size
radial_mask: torch.Tensor, # seq_len
) -> torch.Tensor: # seq_len hidden_size

# Pre-layer normalization
normed_inputs = self.layernorm(inputs)

Expand Down
2 changes: 0 additions & 2 deletions src/metatrain/experimental/nanopet/modules/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def _apply_wigner_D_matrices(
transformations: List[torch.Tensor],
wigner_D_matrices: Dict[int, List[torch.Tensor]],
) -> TensorMap:

new_blocks: List[TensorBlock] = []
for key, block in target_tmap.items():
ell, sigma = int(key[0]), int(key[1])
Expand Down Expand Up @@ -166,7 +165,6 @@ def _apply_random_augmentations(
transformations: List[torch.Tensor],
wigner_D_matrices: Dict[int, List[torch.Tensor]],
) -> Tuple[List[System], Dict[str, TensorMap]]:

# Apply the transformations to the systems
new_systems: List[System] = []
for system, transformation in zip(systems, transformations):
Expand Down
1 change: 0 additions & 1 deletion src/metatrain/experimental/nanopet/modules/feedforward.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def forward(
self,
inputs: torch.Tensor, # hidden_size
) -> torch.Tensor: # hidden_size

# Pre-layer normalization
normed_inputs = self.layernorm(inputs)

Expand Down
1 change: 0 additions & 1 deletion src/metatrain/experimental/nanopet/modules/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
def concatenate_structures(
systems: List[System], neighbor_list_options: NeighborListOptions
):

positions = []
centers = []
neighbors = []
Expand Down
2 changes: 0 additions & 2 deletions src/metatrain/experimental/nanopet/modules/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def forward(
inputs: torch.Tensor,
radial_mask: torch.Tensor,
) -> torch.Tensor:

attention_output = self.attention_block(inputs, radial_mask)
output = self.ff_block(attention_output)

Expand Down Expand Up @@ -73,7 +72,6 @@ def forward(
inputs,
radial_mask,
):

x = inputs
for layer in self.layers:
x = layer(x, radial_mask)
Expand Down
2 changes: 2 additions & 0 deletions src/metatrain/experimental/nanopet/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from pathlib import Path

from metatrain.utils.architectures import get_default_hypers


DATASET_PATH = str(Path(__file__).parents[5] / "tests/resources/qm9_reduced_100.xyz")

DEFAULT_HYPERS = get_default_hypers("experimental.nanopet")
Expand Down
1 change: 0 additions & 1 deletion src/metatrain/experimental/nanopet/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,6 @@ def save_checkpoint(self, model, path: Union[str, Path]):

@classmethod
def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":

# Load the checkpoint
checkpoint = torch.load(path, weights_only=False)
epoch = checkpoint["epoch"]
Expand Down
1 change: 1 addition & 0 deletions src/metatrain/experimental/pet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .model import PET
from .trainer import Trainer


__model__ = PET
__trainer__ = Trainer
__capabilities__ = {
Expand Down
1 change: 0 additions & 1 deletion src/metatrain/experimental/pet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,6 @@ def forward(

@classmethod
def load_checkpoint(cls, path: Union[str, Path]) -> "PET":

checkpoint = torch.load(path, weights_only=False, map_location="cpu")
hypers = checkpoint["hypers"]
model_hypers = hypers["ARCHITECTURAL_HYPERS"]
Expand Down
1 change: 1 addition & 0 deletions src/metatrain/experimental/pet/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path


DATASET_PATH = str(Path(__file__).parents[5] / "tests/resources/carbon_reduced_100.xyz")
1 change: 0 additions & 1 deletion src/metatrain/experimental/pet/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,6 @@ def train(
FITTING_SCHEME.USE_SHIFT_AGNOSTIC_LOSS,
)
if MLIP_SETTINGS.USE_FORCES:

if FITTING_SCHEME.MULTI_GPU:
forces_list = [el.forces for el in batch]
batch_forces = torch.cat(forces_list, dim=0).to(device)
Expand Down
5 changes: 3 additions & 2 deletions src/metatrain/experimental/pet/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .systems_to_batch_dict import systems_to_batch_dict
from .dataset_to_ase import dataset_to_ase
from .load_raw_pet_model import load_raw_pet_model
from .systems_to_batch_dict import systems_to_batch_dict
from .update_hypers import update_hypers
from .update_state_dict import update_state_dict
from .load_raw_pet_model import load_raw_pet_model


__all__ = [
"systems_to_batch_dict",
Expand Down
1 change: 1 addition & 0 deletions src/metatrain/experimental/soap_bpnn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .model import SoapBpnn
from .trainer import Trainer


__model__ = SoapBpnn
__trainer__ = Trainer

Expand Down
4 changes: 0 additions & 4 deletions src/metatrain/experimental/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ class MLPHeadMap(ModuleMap):
def __init__(
self, in_keys: Labels, num_features: int, out_properties: List[Labels]
) -> None:

# hardcoded for now, but could be a hyperparameter
activation_function = torch.nn.SiLU()

Expand All @@ -119,7 +118,6 @@ def __init__(


class SoapBpnn(torch.nn.Module):

__supported_devices__ = ["cuda", "cpu"]
__supported_dtypes__ = [torch.float64, torch.float32]

Expand Down Expand Up @@ -428,7 +426,6 @@ def forward(

@classmethod
def load_checkpoint(cls, path: Union[str, Path]) -> "SoapBpnn":

# Load the checkpoint
checkpoint = torch.load(path, weights_only=False, map_location="cpu")
model_data = checkpoint["model_data"]
Expand Down Expand Up @@ -469,7 +466,6 @@ def export(self) -> MetatensorAtomisticModel:
return MetatensorAtomisticModel(self.eval(), ModelMetadata(), capabilities)

def _add_output(self, target_name: str, target: TargetInfo) -> None:

# register bases of spherical tensors (TensorBasis)
self.num_properties[target_name] = {}
self.basis_calculators[target_name] = torch.nn.ModuleDict({})
Expand Down
9 changes: 4 additions & 5 deletions src/metatrain/experimental/soap_bpnn/spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,9 @@ def forward(
sh_1[:, lam * lam : (lam + 1) * (lam + 1)],
sh_2[
:,
(self.o3_lambda - lam)
* (self.o3_lambda - lam) : ((self.o3_lambda - lam) + 1)
(self.o3_lambda - lam) * (self.o3_lambda - lam) : (
(self.o3_lambda - lam) + 1
)
* ((self.o3_lambda - lam) + 1),
],
self.cgs[
Expand All @@ -235,8 +236,7 @@ def forward(
sh_1[:, lam * lam : (lam + 1) * (lam + 1)],
sh_2[
:,
(self.o3_lambda - lam - 1)
* (self.o3_lambda - lam - 1) : (
(self.o3_lambda - lam - 1) * (self.o3_lambda - lam - 1) : (
(self.o3_lambda - lam - 1) + 1
)
* ((self.o3_lambda - lam - 1) + 1),
Expand Down Expand Up @@ -294,7 +294,6 @@ def get_cg_coefficients(l_max):


class ClebschGordanReal:

def __init__(self):
self._cgs = {}

Expand Down
2 changes: 2 additions & 0 deletions src/metatrain/experimental/soap_bpnn/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from pathlib import Path

from metatrain.utils.architectures import get_default_hypers


DATASET_PATH = str(Path(__file__).parents[5] / "tests/resources/qm9_reduced_100.xyz")

DEFAULT_HYPERS = get_default_hypers("experimental.soap_bpnn")
Expand Down
1 change: 0 additions & 1 deletion src/metatrain/experimental/soap_bpnn/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,6 @@ def save_checkpoint(self, model, path: Union[str, Path]):

@classmethod
def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":

# Load the checkpoint
checkpoint = torch.load(path, weights_only=False, map_location="cpu")
epoch = checkpoint["epoch"]
Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/utils/additive/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .composition import CompositionModel # noqa: F401
from .zbl import ZBL # noqa: F401
from .remove import remove_additive # noqa: F401
from .zbl import ZBL # noqa: F401
1 change: 0 additions & 1 deletion src/metatrain/utils/additive/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def train_model(
# Fill the weights for each "new" target (i.e. those that do not already
# have composition weights from a previous training run)
for target_key in self.new_targets:

if target_key in fixed_weights:
# The fixed weights are provided for this target. Use them:
if not sorted(fixed_weights[target_key].keys()) == self.atomic_types:
Expand Down
14 changes: 7 additions & 7 deletions src/metatrain/utils/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from .combine_dataloaders import CombinedDataLoader # noqa: F401
from .dataset import ( # noqa: F401
Dataset,
DatasetInfo,
get_atomic_types,
get_all_targets,
collate_fn,
check_datasets,
collate_fn,
get_all_targets,
get_atomic_types,
get_stats,
)
from .target_info import TargetInfo # noqa: F401
from .get_dataset import get_dataset # noqa: F401
from .readers import read_systems, read_targets # noqa: F401
from .writers import write_predictions # noqa: F401
from .combine_dataloaders import CombinedDataLoader # noqa: F401
from .system_to_ase import system_to_ase # noqa: F401
from .get_dataset import get_dataset # noqa: F401
from .target_info import TargetInfo # noqa: F401
from .writers import write_predictions # noqa: F401
2 changes: 0 additions & 2 deletions src/metatrain/utils/data/readers/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def _read_forces_ase(filename: str, key: str = "energy") -> List[TensorBlock]:

blocks = []
for i_system, atoms in enumerate(frames):

if key not in atoms.arrays:
raise ValueError(
f"forces key {key!r} was not found in system {filename!r} at index "
Expand Down Expand Up @@ -300,7 +299,6 @@ def read_generic(target: DictConfig) -> Tuple[List[TensorMap], TargetInfo]:

tensor_maps = []
for i_system, atoms in enumerate(frames):

if not per_atom and target_key not in atoms.info:
raise ValueError(
f"Target key {target_key!r} was not found in system {filename!r} at "
Expand Down
1 change: 0 additions & 1 deletion src/metatrain/utils/data/readers/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def read_targets(
standard_outputs_list = ["energy"]

for target_key, target in conf.items():

is_standard_target = target_key in standard_outputs_list
if not is_standard_target and not target_key.startswith("mtt::"):
if target_key.lower() in ["force", "forces", "virial", "stress"]:
Expand Down
6 changes: 3 additions & 3 deletions src/metatrain/utils/data/target_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,9 @@ def _check_layout(self, layout: TensorMap) -> None:
f"Found '{layout.keys.names}' instead."
)
for key, block in layout.items():
o3_lambda, o3_sigma = int(key.values[0].item()), int(
key.values[1].item()
o3_lambda, o3_sigma = (
int(key.values[0].item()),
int(key.values[1].item()),
)
if o3_sigma not in [-1, 1]:
raise ValueError(
Expand Down Expand Up @@ -223,7 +224,6 @@ def get_energy_target_info(
add_position_gradients: bool = False,
add_strain_gradients: bool = False,
) -> TargetInfo:

block = TensorBlock(
# float64: otherwise metatensor can't serialize
values=torch.empty(0, 1, dtype=torch.float64),
Expand Down
Loading

0 comments on commit e2b5e2d

Please sign in to comment.