Skip to content

Commit

Permalink
Extend ruff linting (#293)
Browse files Browse the repository at this point in the history
* Extend ruff linting

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
jan-janssen and pre-commit-ci[bot] authored Feb 5, 2025
1 parent 34aa552 commit 6ac751c
Show file tree
Hide file tree
Showing 14 changed files with 118 additions and 93 deletions.
10 changes: 5 additions & 5 deletions .ci_support/release.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def get_setup_version_and_pattern(setup_content):
version_lst.append(dep.split("==")[1])
depend_lst.append(dep.split("==")[0])

version_high_dict = {d: v for d, v in zip(depend_lst, version_lst)}
version_high_dict = dict(zip(depend_lst, version_lst))
return version_high_dict


Expand All @@ -30,13 +30,13 @@ def get_env_version(env_content):
if len(lst) == 2:
depend_lst.append(lst[0])
version_lst.append(lst[1])
return {d: v for d, v in zip(depend_lst, version_lst)}
return dict(zip(depend_lst, version_lst))


def update_dependencies(setup_content, version_low_dict, version_high_dict):
version_combo_dict = {}
for dep, ver in version_high_dict.items():
if dep in version_low_dict.keys() and version_low_dict[dep] != ver:
if dep in version_low_dict and version_low_dict[dep] != ver:
version_combo_dict[dep] = dep + ">=" + version_low_dict[dep] + ",<=" + ver
else:
version_combo_dict[dep] = dep + "==" + ver
Expand All @@ -52,10 +52,10 @@ def update_dependencies(setup_content, version_low_dict, version_high_dict):


if __name__ == "__main__":
with open("pyproject.toml", "r") as f:
with open("pyproject.toml") as f:
setup_content = f.readlines()

with open("environment.yml", "r") as f:
with open("environment.yml") as f:
env_content = f.readlines()

setup_content_new = update_dependencies(
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ repos:
hooks:
- id: ruff
name: ruff lint
args: ["--select", "I", "--fix"]
args: ["--fix"]
files: ^pylammpsmpi/
- id: ruff-format
name: ruff format
4 changes: 2 additions & 2 deletions pylammpsmpi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

from ._version import get_versions

__all__ = [LammpsLibrary, LammpsConcurrent, LammpsBase]
__all__ = ["LammpsLibrary", "LammpsConcurrent", "LammpsBase"]
__version__ = get_versions()["version"]


try:
from pylammpsmpi.wrapper.ase import LammpsASELibrary

__all__ += [LammpsASELibrary]
__all__ += ["LammpsASELibrary"]
except ImportError:
pass
41 changes: 15 additions & 26 deletions pylammpsmpi/mpi/lmpmpi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# coding: utf-8
# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department
# Distributed under the terms of "New BSD License", see the LICENSE file.

Expand Down Expand Up @@ -127,7 +126,7 @@ def extract_atom(job, funct_args):
# extract atoms return an internal data type
# this has to be reformatted
name = str(funct_args[0])
if name not in atom_properties.keys():
if name not in atom_properties:
return []

# this block prevents error when trying to access values
Expand Down Expand Up @@ -166,14 +165,13 @@ def extract_variable(job, funct_args):
)
if MPI.COMM_WORLD.rank == 0:
return np.array(data)
else:
if MPI.COMM_WORLD.rank == 0:
# if type is 1 - reformat file
try:
data = job.extract_variable(*funct_args)
except ValueError:
return []
return data
elif MPI.COMM_WORLD.rank == 0:
# if type is 1 - reformat file
try:
data = job.extract_variable(*funct_args)
except ValueError:
return []
return data


def get_natoms(job, funct_args):
Expand All @@ -194,7 +192,7 @@ def gather_atoms(job, funct_args):
# extract atoms return an internal data type
# this has to be reformatted
name = str(funct_args[0])
if name not in atom_properties.keys():
if name not in atom_properties:
return []

# this block prevents error when trying to access values
Expand All @@ -209,18 +207,15 @@ def gather_atoms(job, funct_args):
# number of atoms - first dimension
val = list(val)
dim = atom_properties[name]["dim"]
if dim > 1:
data = [val[x : x + dim] for x in range(0, len(val), dim)]
else:
data = list(val)
data = [val[x : x + dim] for x in range(0, len(val), dim)] if dim > 1 else list(val)
return np.array(data)


def gather_atoms_concat(job, funct_args):
# extract atoms return an internal data type
# this has to be reformatted
name = str(funct_args[0])
if name not in atom_properties.keys():
if name not in atom_properties:
return []

# this block prevents error when trying to access values
Expand All @@ -235,10 +230,7 @@ def gather_atoms_concat(job, funct_args):
# number of atoms - first dimension
val = list(val)
dim = atom_properties[name]["dim"]
if dim > 1:
data = [val[x : x + dim] for x in range(0, len(val), dim)]
else:
data = list(val)
data = [val[x : x + dim] for x in range(0, len(val), dim)] if dim > 1 else list(val)
return np.array(data)


Expand All @@ -253,7 +245,7 @@ def gather_atoms_subset(job, funct_args):
for i in range(lenids):
cids[i] = ids[i]

if name not in atom_properties.keys():
if name not in atom_properties:
return []

# this block prevents error when trying to access values
Expand All @@ -272,10 +264,7 @@ def gather_atoms_subset(job, funct_args):
# number of atoms - first dimension
val = list(val)
dim = atom_properties[name]["dim"]
if dim > 1:
data = [val[x : x + dim] for x in range(0, len(val), dim)]
else:
data = list(val)
data = [val[x : x + dim] for x in range(0, len(val), dim)] if dim > 1 else list(val)
return np.array(data)


Expand Down Expand Up @@ -487,7 +476,7 @@ def _run_lammps_mpi(argument_lst):
else:
input_dict = None
input_dict = MPI.COMM_WORLD.bcast(input_dict, root=0)
if "shutdown" in input_dict.keys() and input_dict["shutdown"]:
if "shutdown" in input_dict and input_dict["shutdown"]:
job.close()
if MPI.COMM_WORLD.rank == 0:
interface_send(socket=socket, result_dict={"result": True})
Expand Down
38 changes: 16 additions & 22 deletions pylammpsmpi/wrapper/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import warnings
from ctypes import c_double, c_int
from typing import List, Optional
from typing import Optional

import numpy as np
from ase.atoms import Atoms
Expand Down Expand Up @@ -45,7 +45,7 @@ def __init__(
self._interactive_library = library
self._cores = library.cores
elif self._cores == 1:
lammps = getattr(importlib.import_module("lammps"), "lammps")
lammps = importlib.import_module("lammps").lammps
if disable_log_file:
self._interactive_library = lammps(
cmdargs=["-screen", "none", "-log", "none"],
Expand Down Expand Up @@ -89,7 +89,7 @@ def interactive_positions_getter(self) -> np.ndarray:
positions = self._prism.vector_to_ase(positions)
return positions

def interactive_positions_setter(self, positions: List[List[float]]) -> None:
def interactive_positions_setter(self, positions: list[list[float]]) -> None:
"""
Set the positions of atoms in the interactive library.
Expand Down Expand Up @@ -142,7 +142,8 @@ def interactive_cells_setter(self, cell: np.ndarray) -> None:
lx, ly, lz, xy, xz, yz = self._prism.get_lammps_prism()
if not _check_ortho_prism(prism=self._prism):
warnings.warn(
"Warning: setting upper trangular matrix might slow down the calculation"
"Warning: setting upper trangular matrix might slow down the calculation",
stacklevel=2,
)

is_skewed = cell_is_skewed(cell=cell, tolerance=1.0e-8)
Expand All @@ -152,19 +153,16 @@ def interactive_cells_setter(self, cell: np.ndarray) -> None:
if not was_skewed:
self.interactive_lib_command(command="change_box all triclinic")
self.interactive_lib_command(
command="change_box all x final 0 %f y final 0 %f z final 0 %f xy final %f xz final %f yz final %f remap units box"
% (lx, ly, lz, xy, xz, yz),
command=f"change_box all x final 0 {lx:f} y final 0 {ly:f} z final 0 {lz:f} xy final {xy:f} xz final {xz:f} yz final {yz:f} remap units box",
)
elif was_skewed:
self.interactive_lib_command(
command="change_box all x final 0 %f y final 0 %f z final 0 %f xy final %f xz final %f yz final %f remap units box"
% (lx, ly, lz, 0.0, 0.0, 0.0),
command=f"change_box all x final 0 {lx:f} y final 0 {ly:f} z final 0 {lz:f} xy final {0.0:f} xz final {0.0:f} yz final {0.0:f} remap units box",
)
self.interactive_lib_command(command="change_box all ortho")
else:
self.interactive_lib_command(
command="change_box all x final 0 %f y final 0 %f z final 0 %f remap units box"
% (lx, ly, lz),
command=f"change_box all x final 0 {lx:f} y final 0 {ly:f} z final 0 {lz:f} remap units box",
)

def interactive_volume_getter(self) -> float:
Expand Down Expand Up @@ -198,7 +196,7 @@ def interactive_structure_setter(
dimension: int,
boundary: str,
atom_style: str,
el_eam_lst: List[str],
el_eam_lst: list[str],
calc_md: bool = True,
) -> None:
"""
Expand All @@ -224,7 +222,8 @@ def interactive_structure_setter(
self._prism = Prism(structure.cell)
if not _check_ortho_prism(prism=self._prism):
warnings.warn(
"Warning: setting upper trangular matrix might slow down the calculation"
"Warning: setting upper trangular matrix might slow down the calculation",
stacklevel=2,
)
xhi, yhi, zhi, xy, xz, yz = self._prism.get_lammps_prism()
if self._prism.is_skewed():
Expand Down Expand Up @@ -274,13 +273,11 @@ def interactive_structure_setter(
for id_eam, el_eam in enumerate(el_eam_lst):
if el_eam in el_struct_lst:
self.interactive_lib_command(
command="mass {0:3d} {1:f}".format(
id_eam + 1, atomic_masses[atomic_numbers[el_eam]]
),
command=f"mass {id_eam + 1:3d} {atomic_masses[atomic_numbers[el_eam]]:f}",
)
else:
self.interactive_lib_command(
command="mass {0:3d} {1:f}".format(id_eam + 1, 1.00),
command=f"mass {id_eam + 1:3d} {1.00:f}",
)
if not _check_ortho_prism(prism=self._prism):
positions = self._prism.vector_to_lammps(structure.positions).flatten()
Expand Down Expand Up @@ -476,10 +473,7 @@ def cell_is_skewed(cell, tolerance=1.0e-8):
"""
volume = np.abs(np.linalg.det(cell))
prod = np.linalg.norm(cell, axis=-1).prod()
if volume > 0:
if abs(volume - prod) / volume < tolerance:
return False
return True
return not (volume > 0 and abs(volume - prod) / volume < tolerance)


def _check_ortho_prism(prism, rtol=0.0, atol=1e-08):
Expand Down Expand Up @@ -553,11 +547,11 @@ def get_fixed_atom_boolean_vector(structure):
fixed_atom_vector[c_dict["kwargs"]["indices"]] = [True, True, True]
elif c_dict["name"] == "FixedPlane":
if all(np.isin(c_dict["kwargs"]["direction"], [0, 1])):
if "indices" in c_dict["kwargs"].keys():
if "indices" in c_dict["kwargs"]:
fixed_atom_vector[c_dict["kwargs"]["indices"]] = np.array(
c_dict["kwargs"]["direction"]
).astype(bool)
elif "a" in c_dict["kwargs"].keys():
elif "a" in c_dict["kwargs"]:
fixed_atom_vector[c_dict["kwargs"]["a"]] = np.array(
c_dict["kwargs"]["direction"]
).astype(bool)
Expand Down
35 changes: 17 additions & 18 deletions pylammpsmpi/wrapper/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# coding: utf-8
# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department
# Distributed under the terms of "New BSD License", see the LICENSE file.

from typing import List, Union
from typing import Union

from pylammpsmpi.wrapper.concurrent import LammpsConcurrent

Expand Down Expand Up @@ -72,7 +71,7 @@ def extract_global(self, name: str) -> Union[int, float, str]:
"""
return super().extract_global(name=name).result()

def extract_box(self) -> List[Union[float, List[float], List[int]]]:
def extract_box(self) -> list[Union[float, list[float], list[int]]]:
"""
Get the simulation box
Expand All @@ -87,7 +86,7 @@ def extract_box(self) -> List[Union[float, List[float], List[int]]]:
"""
return super().extract_box().result()

def extract_atom(self, name: str) -> Union[List[int], List[float]]:
def extract_atom(self, name: str) -> Union[list[int], list[float]]:
"""
Extract a property of the atoms
Expand All @@ -102,7 +101,7 @@ def extract_atom(self, name: str) -> Union[List[int], List[float]]:
"""
return super().extract_atom(name=name).result()

def extract_fix(self, *args) -> Union[int, float, List[Union[int, float]]]:
def extract_fix(self, *args) -> Union[int, float, list[Union[int, float]]]:
"""
Extract a fix value
Expand All @@ -116,7 +115,7 @@ def extract_fix(self, *args) -> Union[int, float, List[Union[int, float]]]:
"""
return super().extract_fix(*args).result()

def extract_variable(self, *args) -> Union[int, float, List[Union[int, float]]]:
def extract_variable(self, *args) -> Union[int, float, list[Union[int, float]]]:
"""
Extract the value of a variable
Expand Down Expand Up @@ -180,11 +179,11 @@ def reset_box(self, *args) -> None:

def generate_atoms(
self,
ids: List[int] = None,
type: List[int] = None,
x: List[float] = None,
v: List[float] = None,
image: List[int] = None,
ids: list[int] = None,
type: list[int] = None,
x: list[float] = None,
v: list[float] = None,
image: list[int] = None,
shrinkexceed: bool = False,
) -> None:
"""
Expand Down Expand Up @@ -224,11 +223,11 @@ def generate_atoms(
def create_atoms(
self,
n: int,
id: List[int],
type: List[int],
x: List[float],
v: List[float] = None,
image: List[int] = None,
id: list[int],
type: list[int],
x: list[float],
v: list[float] = None,
image: list[int] = None,
shrinkexceed: bool = False,
) -> None:
"""
Expand Down Expand Up @@ -290,7 +289,7 @@ def has_ffmpeg_support(self) -> bool:
return super().has_ffmpeg_support.result()

@property
def installed_packages(self) -> List[str]:
def installed_packages(self) -> list[str]:
return super().installed_packages.result()

def set_fix_external_callback(self, *args) -> None:
Expand All @@ -305,7 +304,7 @@ def get_neighlist(self, *args):
"""
return super().get_neighlist(*args).result()

def find_pair_neighlist(self, style: str) -> int:
def find_pair_neighlist(self, *args) -> int:
"""Find neighbor list index of pair style neighbor list
Try finding pair instance that matches style. If exact is set, the pair must
match style exactly. If exact is 0, style must only be contained. If pair is
Expand Down
Loading

0 comments on commit 6ac751c

Please sign in to comment.