Skip to content

Commit

Permalink
ruff fixes + type annotations (#156)
Browse files Browse the repository at this point in the history
* ruff auto fixes

* fix ruff FBT001 FBT002

* ruff select = ["ALL"] and fix legacy errors

* fix TypeError: unhashable type: 'list'

chgnet/model/functions.py:71: in __init__
    if hidden_dim in {None, 0}:
  • Loading branch information
janosh authored May 20, 2024
1 parent 455f4d8 commit d0632a1
Show file tree
Hide file tree
Showing 17 changed files with 152 additions and 129 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.3
rev: v0.4.4
hooks:
- id: ruff
args: [--fix]
Expand Down Expand Up @@ -46,7 +46,7 @@ repos:
- svelte

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v9.2.0
rev: v9.3.0
hooks:
- id: eslint
types: [file]
Expand Down
28 changes: 19 additions & 9 deletions chgnet/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
structures: list[Structure],
energies: list[float],
forces: list[Sequence[Sequence[float]]],
*,
stresses: list[Sequence[Sequence[float]]] | None = None,
magmoms: list[Sequence[Sequence[float]]] | None = None,
structure_ids: list | None = None,
Expand Down Expand Up @@ -63,7 +64,7 @@ def __init__(
"""
for idx, struct in enumerate(structures):
if not isinstance(struct, Structure):
raise ValueError(f"{idx} is not a pymatgen Structure object: {struct}")
raise TypeError(f"{idx} is not a pymatgen Structure object: {struct}")
for name in "energies forces stresses magmoms structure_ids".split():
labels = locals()[name]
if labels is not None and len(labels) != len(structures):
Expand All @@ -80,7 +81,7 @@ def __init__(
self.keys = np.arange(len(structures))
if shuffle:
random.shuffle(self.keys)
print(f"{len(structures)} structures imported")
print(f"{type(self).__name__} imported {len(structures):,} structures")
self.graph_converter = graph_converter or CrystalGraphConverter(
atom_graph_cutoff=6, bond_graph_cutoff=3
)
Expand All @@ -91,11 +92,12 @@ def __init__(
def from_vasp(
cls,
file_root: str,
*,
check_electronic_convergence: bool = True,
save_path: str | None = None,
graph_converter: CrystalGraphConverter | None = None,
shuffle: bool = True,
):
) -> StructureData:
"""Parse VASP output files into structures and labels and feed into the dataset.
Args:
Expand Down Expand Up @@ -196,6 +198,7 @@ class CIFData(Dataset):
def __init__(
self,
cif_path: str,
*,
labels: str | dict = "labels.json",
targets: TrainTask = "efsm",
graph_converter: CrystalGraphConverter | None = None,
Expand Down Expand Up @@ -311,6 +314,7 @@ class GraphData(Dataset):
def __init__(
self,
graph_path: str,
*,
labels: str | dict = "labels.json",
targets: TrainTask = "efsm",
exclude: str | list | None = None,
Expand Down Expand Up @@ -429,6 +433,7 @@ def get_train_val_test_loader(
self,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
*,
train_key: list[str] | None = None,
val_key: list[str] | None = None,
test_key: list[str] | None = None,
Expand Down Expand Up @@ -541,6 +546,7 @@ def __init__(
self,
data: str | dict,
graph_converter: CrystalGraphConverter,
*,
targets: TrainTask = "efsm",
energy_key: str = "energy_per_atom",
force_key: str = "force",
Expand Down Expand Up @@ -580,14 +586,14 @@ def __init__(
elif isinstance(data, dict):
self.data = data
else:
raise ValueError(f"data must be JSON path or dictionary, got {type(data)}")
raise TypeError(f"data must be JSON path or dictionary, got {type(data)}")

self.keys = [
(mp_id, graph_id) for mp_id, dct in self.data.items() for graph_id in dct
]
if shuffle:
random.shuffle(self.keys)
print(f"{len(self.data)} mp_ids, {len(self)} structures imported")
print(f"{len(self.data)} MP IDs, {len(self)} structures imported")
self.graph_converter = graph_converter
self.energy_key = energy_key
self.force_key = force_key
Expand All @@ -602,7 +608,7 @@ def __len__(self) -> int:
return len(self.keys)

@functools.cache # Cache loaded structures
def __getitem__(self, idx):
def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict[str, Tensor]]:
"""Get one item in the dataset.
Returns:
Expand Down Expand Up @@ -654,6 +660,7 @@ def get_train_val_test_loader(
self,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
*,
train_key: list[str] | None = None,
val_key: list[str] | None = None,
test_key: list[str] | None = None,
Expand Down Expand Up @@ -747,7 +754,7 @@ def get_train_val_test_loader(
return train_loader, val_loader, test_loader


def collate_graphs(batch_data: list):
def collate_graphs(batch_data: list) -> tuple[list[CrystalGraph], dict[str, Tensor]]:
"""Collate of list of (graph, target) into batch data.
Args:
Expand Down Expand Up @@ -777,13 +784,14 @@ def collate_graphs(batch_data: list):

def get_train_val_test_loader(
dataset: Dataset,
*,
batch_size: int = 64,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
return_test: bool = True,
num_workers: int = 0,
pin_memory: bool = True,
):
) -> tuple[DataLoader, DataLoader, DataLoader]:
"""Randomly partition a dataset into train, val, test loaders.
Args:
Expand Down Expand Up @@ -842,7 +850,9 @@ def get_train_val_test_loader(
return train_loader, val_loader


def get_loader(dataset, batch_size=64, num_workers=0, pin_memory=True):
def get_loader(
dataset, *, batch_size: int = 64, num_workers: int = 0, pin_memory: bool = True
) -> DataLoader:
"""Get a dataloader from a dataset.
Args:
Expand Down
2 changes: 1 addition & 1 deletion chgnet/graph/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class CrystalGraphConverter(nn.Module):

def __init__(
self,
*,
atom_graph_cutoff: float = 6,
bond_graph_cutoff: float = 3,
algorithm: Literal["legacy", "fast"] = "fast",
Expand Down Expand Up @@ -274,7 +275,6 @@ def set_isolated_atom_response(
None
"""
self.on_isolated_atoms = on_isolated_atoms
return

def as_dict(self) -> dict[str, str | float]:
"""Save the args of the graph converter."""
Expand Down
2 changes: 1 addition & 1 deletion chgnet/graph/crystalgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def __repr__(self) -> str:
)

@property
def num_isolated_atoms(self):
def num_isolated_atoms(self) -> int:
"""Number of isolated atoms given the atom graph cutoff
Isolated atoms are disconnected nodes in the atom graph
that will not get updated in CHGNet.
Expand Down
52 changes: 28 additions & 24 deletions chgnet/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class UndirectedEdge(Edge):

__hash__ = Edge.__hash__

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
"""Check if two undirected edges are equal."""
return set(self.nodes) == set(other.nodes) and self.info == other.info

Expand Down Expand Up @@ -178,16 +178,16 @@ def add_edge(self, center_index, neighbor_index, image, distance) -> None:
):
# There is an undirected edge with similar length and only one of
# the directed edges associated has been added
added_DE = self.directed_edges_list[
added_dir_edge = self.directed_edges_list[
undirected_edge.info["directed_edge_index"][0]
]

# See if the DE that's associated to this UDE
# is the reverse of our DE
if added_DE == this_directed_edge:
if added_dir_edge == this_directed_edge:
# Add UDE index to this DE
this_directed_edge.info["undirected_edge_index"] = (
added_DE.info["undirected_edge_index"]
added_dir_edge.info["undirected_edge_index"]
)

# At the center node, draw edge with this DE
Expand Down Expand Up @@ -217,7 +217,7 @@ def add_edge(self, center_index, neighbor_index, image, distance) -> None:
self.nodes[center_index].add_neighbor(neighbor_index, this_directed_edge)
self.directed_edges_list.append(this_directed_edge)

def adjacency_list(self):
def adjacency_list(self) -> tuple[list[list[int]], list[int]]:
"""Get the adjacency list
Return:
graph: the adjacency list
Expand All @@ -240,7 +240,7 @@ def adjacency_list(self):
]
return graph, directed2undirected

def line_graph_adjacency_list(self, cutoff):
def line_graph_adjacency_list(self, cutoff) -> tuple[list[list[int]], list[int]]:
"""Get the line graph adjacency list.
Args:
Expand All @@ -264,11 +264,12 @@ def line_graph_adjacency_list(self, cutoff):
a list of length = num_undirected_edge that
maps the undirected edge index to one of its directed edges indices
"""
assert len(self.directed_edges_list) == 2 * len(self.undirected_edges_list), (
f"Error: number of directed edges={len(self.directed_edges_list)} != 2 * "
f"number of undirected edges={len(self.directed_edges_list)}!"
f"This indicates directed edges are not complete"
)
if len(self.directed_edges_list) != 2 * len(self.undirected_edges_list):
raise ValueError(
f"Error: number of directed edges={len(self.directed_edges_list)} != 2 "
f"* number of undirected edges={len(self.directed_edges_list)}!"
f"This indicates directed edges are not complete"
)
line_graph = []
undirected2directed = []

Expand All @@ -285,39 +286,42 @@ def line_graph_adjacency_list(self, cutoff):
# if encountered exception,
# it means after Atom_Graph creation, the UDE has only 1 DE associated
# This exception is not encountered from the develop team's experience
assert len(u_edge.info["directed_edge_index"]) == 2, (
"Did not find 2 Directed_edges !!!"
f"undirected edge {u_edge} has:"
f"edge.info['directed_edge_index'] = "
f"{u_edge.info['directed_edge_index']}"
f"len directed_edges_list = {len(self.directed_edges_list)}"
f"len undirected_edges_list = {len(self.undirected_edges_list)}"
)
if len(u_edge.info["directed_edge_index"]) != 2:
raise ValueError(
"Did not find 2 Directed_edges !!!"
f"undirected edge {u_edge} has:"
f"edge.info['directed_edge_index'] = "
f"{u_edge.info['directed_edge_index']}"
f"len directed_edges_list = {len(self.directed_edges_list)}"
f"len undirected_edges_list = {len(self.undirected_edges_list)}"
)

# This UDE is valid to be considered as a node in Bond_Graph

# Get the two ends (centers) and the two DE associated with this UDE
# DE1 should have center=center1 and DE2 should have center=center2
# We will need to find directed edges with center = center1
# and create angles with DE1, then do the same for center2 and DE2
for center, DE in zip(u_edge.nodes, u_edge.info["directed_edge_index"]):
for center, dir_edge in zip(
u_edge.nodes, u_edge.info["directed_edge_index"]
):
for directed_edges in self.nodes[center].neighbors.values():
for directed_edge in directed_edges:
if directed_edge.index == DE:
if directed_edge.index == dir_edge:
continue
if directed_edge.info["distance"] < cutoff:
line_graph.append(
[
center,
u_edge.index,
DE,
dir_edge,
directed_edge.info["undirected_edge_index"],
directed_edge.index,
]
)
return line_graph, undirected2directed

def undirected2directed(self):
def undirected2directed(self) -> list[int]:
"""The index map from undirected_edge index to one of its directed_edge
index.
"""
Expand All @@ -326,7 +330,7 @@ def undirected2directed(self):
for undirected_edge in self.undirected_edges_list
]

def as_dict(self):
def as_dict(self) -> dict:
"""Return dictionary serialization of a Graph."""
return {
"nodes": self.nodes,
Expand Down
15 changes: 9 additions & 6 deletions chgnet/model/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class Fourier(nn.Module):
"""Fourier Expansion for angle features."""

def __init__(self, order: int = 5, learnable: bool = False) -> None:
def __init__(self, *, order: int = 5, learnable: bool = False) -> None:
"""Initialize the Fourier expansion.
Args:
Expand Down Expand Up @@ -47,6 +47,7 @@ class RadialBessel(torch.nn.Module):

def __init__(
self,
*,
num_radial: int = 9,
cutoff: float = 5,
learnable: bool = False,
Expand Down Expand Up @@ -90,7 +91,7 @@ def __init__(
self.smooth_cutoff = None

def forward(
self, dist: Tensor, return_smooth_factor: bool = False
self, dist: Tensor, *, return_smooth_factor: bool = False
) -> Tensor | tuple[Tensor, Tensor]:
"""Apply Bessel expansion to a feature Tensor.
Expand Down Expand Up @@ -122,8 +123,8 @@ class GaussianExpansion(nn.Module):

def __init__(
self,
min: float = 0,
max: float = 5,
min: float = 0, # noqa: A002
max: float = 5, # noqa: A002
step: float = 0.5,
var: float | None = None,
) -> None:
Expand All @@ -137,8 +138,10 @@ def __init__(
var (float): variance in gaussian filter, default to step
"""
super().__init__()
assert min < max
assert max - min > step
if min >= max:
raise ValueError(f"{min=} must be less than {max=}")
if max - min <= step:
raise ValueError(f"{max - min=} must be greater than {step=}")
self.register_buffer("gaussian_centers", torch.arange(min, max + step, step))
self.var = var or step
if self.var <= 0:
Expand Down
Loading

0 comments on commit d0632a1

Please sign in to comment.