From 2027c9016fc96b791a1b63ab154b0d87da853422 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Fri, 8 Dec 2023 18:36:09 -0800 Subject: [PATCH 1/6] use in-place arithmetic ops --- chgnet/model/composition_model.py | 8 ++-- chgnet/model/encoders.py | 4 +- chgnet/model/layers.py | 2 +- chgnet/model/model.py | 2 +- examples/basics.ipynb | 79 ++++++++++++++++++++----------- 5 files changed, 59 insertions(+), 36 deletions(-) diff --git a/chgnet/model/composition_model.py b/chgnet/model/composition_model.py index 7f2c2825..172c29e9 100644 --- a/chgnet/model/composition_model.py +++ b/chgnet/model/composition_model.py @@ -53,7 +53,7 @@ def _get_energy(self, composition_feas: Tensor) -> Tensor: prediction associated with each composition [batchsize]. """ composition_feas = self.activation(self.fc1(composition_feas)) - composition_feas = composition_feas + self.gated_mlp(composition_feas) + composition_feas += self.gated_mlp(composition_feas) return self.fc2(composition_feas).view(-1) def forward(self, graphs: list[CrystalGraph]) -> Tensor: @@ -77,7 +77,7 @@ def _assemble_graphs(self, graphs: list[CrystalGraph]): ) if self.is_intensive: n_atom = graph.atomic_number.shape[0] - composition_fea = composition_fea / n_atom + composition_fea /= n_atom composition_feas.append(composition_fea) return torch.stack(composition_feas, dim=0) @@ -150,7 +150,7 @@ def fit( atomic_number - 1, minlength=self.max_num_elements ) if self.is_intensive: - composition_fea = composition_fea / atomic_number.shape[0] + composition_fea /= atomic_number.shape[0] composition_feas[index, :] = composition_fea e[index] = energy @@ -181,7 +181,7 @@ def _assemble_graphs(self, graphs: list[CrystalGraph]): ) if self.is_intensive: n_atom = graph.atomic_number.shape[0] - composition_fea = composition_fea / n_atom + composition_fea /= n_atom composition_feas.append(composition_fea) return torch.stack(composition_feas, dim=0).float() diff --git a/chgnet/model/encoders.py b/chgnet/model/encoders.py index d2eb4059..2b4145e1 100644 --- a/chgnet/model/encoders.py +++ b/chgnet/model/encoders.py @@ -94,11 +94,11 @@ def forward( bond_vectors (Tensor): normalized bond vectors, for tracking the bond directions [n_bond, 3] """ - neighbor = neighbor + image @ lattice + neighbor += image @ lattice bond_vectors = center - neighbor bond_lengths = torch.norm(bond_vectors, dim=1) # Normalize the bond vectors - bond_vectors = bond_vectors / bond_lengths[:, None] + bond_vectors /= bond_lengths[:, None] # We create bond features only for undirected bonds # atom1 -> atom2 and atom2 -> atom1 should share same bond_basis diff --git a/chgnet/model/layers.py b/chgnet/model/layers.py index f087ecc4..725bba5e 100644 --- a/chgnet/model/layers.py +++ b/chgnet/model/layers.py @@ -117,7 +117,7 @@ def forward( # smooth out message by bond_weights bond_weight = torch.index_select(bond_weights, 0, directed2undirected) - messages = messages * bond_weight + messages *= bond_weight # Aggregate messages new_atom_feas = aggregate( diff --git a/chgnet/model/model.py b/chgnet/model/model.py index e337c9d7..537a380d 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -521,7 +521,7 @@ def _compute( # Normalize energy if model is intensive if self.is_intensive: - energy = energy / atoms_per_graph + energy /= atoms_per_graph prediction["e"] = energy return prediction diff --git a/examples/basics.ipynb b/examples/basics.ipynb index cb348136..ef3bd9e4 100644 --- a/examples/basics.ipynb +++ b/examples/basics.ipynb @@ -19,7 +19,7 @@ " from chgnet.model import CHGNet\n", "except ImportError:\n", " # install CHGNet (only needed on Google Colab or if you didn't install CHGNet yet)\n", - " !pip install chgnet\n" + " !pip install chgnet" ] }, { @@ -37,7 +37,7 @@ "# If the above line fails in Google Colab due to numpy version issue,\n", "# please restart the runtime, and the problem will be solved\n", "\n", - "np.set_printoptions(precision=4, suppress=True)\n" + "np.set_printoptions(precision=4, suppress=True)" ] }, { @@ -89,7 +89,7 @@ " cif = urlopen(url).read().decode(\"utf-8\")\n", " structure = Structure.from_str(cif, fmt=\"cif\")\n", "\n", - "print(structure)\n" + "print(structure)" ] }, { @@ -110,7 +110,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "CHGNet initialized with 400,438 parameters\n" + "CHGNet v0.3.0 initialized with 412,525 parameters\n" ] } ], @@ -118,7 +118,7 @@ "chgnet = CHGNet.load()\n", "\n", "# Alternatively you can read your own model\n", - "# chgnet = CHGNet.from_file(model_path)\n" + "# chgnet = CHGNet.from_file(model_path)" ] }, { @@ -176,7 +176,7 @@ " (\"stress\", \"GPa\"),\n", " (\"magmom\", \"mu_B\"),\n", "]:\n", - " print(f\"CHGNet-predicted {key} ({unit}):\\n{prediction[key[0]]}\\n\")\n" + " print(f\"CHGNet-predicted {key} ({unit}):\\n{prediction[key[0]]}\\n\")" ] }, { @@ -197,7 +197,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "CHGNet initialized with 400,438 parameters\n", + "CHGNet v0.3.0 initialized with 412,525 parameters\n", "CHGNet will run on cpu\n" ] } @@ -205,7 +205,7 @@ "source": [ "from chgnet.model import StructOptimizer\n", "\n", - "relaxer = StructOptimizer()\n" + "relaxer = StructOptimizer()" ] }, { @@ -219,31 +219,54 @@ "output_type": "stream", "text": [ "\n", - "CHGNet took 29 steps. Relaxed structure:\n", + "ExpCellFilter took 52 steps. Relaxed structure:\n", "Full Formula (Li2 Mn2 O4)\n", "Reduced Formula: LiMnO2\n", - "abc : 2.865864 4.648716 5.827764\n", - "angles: 89.917211 90.239405 89.975425\n", + "abc : 2.872010 4.650813 5.834364\n", + "angles: 89.984000 90.166300 89.959166\n", "pbc : True True True\n", "Sites (8)\n", " # SP a b c magmom\n", "--- ---- --------- --------- -------- ----------\n", - " 0 Li+ 0.494018 0.479737 0.387171 0.00498427\n", - " 1 Li+ 0.008464 0.006131 0.625817 0.00512926\n", - " 2 Mn3+ 0.50073 0.502478 0.869608 3.85374\n", - " 3 Mn3+ 0.997815 -0.000319 0.139344 3.859\n", - " 4 O2- 0.502142 0.009453 0.363411 0.0253105\n", - " 5 O2- 1.00293 0.502592 0.104559 0.0366638\n", - " 6 O2- 0.493749 0.998092 0.903592 0.0365367\n", - " 7 O2- -0.002278 0.495108 0.645655 0.0248522\n" + " 0 Li+ 0.466097 0.471908 0.378517 0.00247121\n", + " 1 Li+ 0.979868 0.970313 0.606578 0.00228718\n", + " 2 Mn3+ 0.484925 0.498502 0.859377 3.87809\n", + " 3 Mn3+ 0.984186 0.998235 0.126078 3.87656\n", + " 4 O2- 0.483681 0.009578 0.350723 0.0461\n", + " 5 O2- -0.012445 0.497563 0.093668 0.0405012\n", + " 6 O2- 0.484366 -0.001507 0.89173 0.0405898\n", + " 7 O2- 0.982041 0.509921 0.634653 0.0462106\n", + "\n", + "FrechetCellFilter took 26 steps. Relaxed structure:\n", + "Full Formula (Li2 Mn2 O4)\n", + "Reduced Formula: LiMnO2\n", + "abc : 2.876522 4.657650 5.839542\n", + "angles: 90.012820 90.017011 90.010059\n", + "pbc : True True True\n", + "Sites (8)\n", + " # SP a b c magmom\n", + "--- ---- --------- --------- -------- ----------\n", + " 0 Li+ 0.452761 0.467844 0.381994 0.00225034\n", + " 1 Li+ 0.976853 0.965974 0.60083 0.0018343\n", + " 2 Mn3+ 0.487451 0.500263 0.862042 3.87584\n", + " 3 Mn3+ 0.987112 0.997857 0.123798 3.89048\n", + " 4 O2- 0.483031 0.008843 0.349475 0.0478891\n", + " 5 O2- -0.008411 0.502525 0.095111 0.0408283\n", + " 6 O2- 0.489891 -7.2e-05 0.890263 0.0426996\n", + " 7 O2- 0.984032 0.51128 0.637812 0.0465174\n" ] } ], "source": [ + "from ase.filters import ExpCellFilter, FrechetCellFilter\n", + "\n", "structure.perturb(0.1)\n", - "result = relaxer.relax(structure, verbose=False)\n", - "print(f\"\\nCHGNet took {len(result['trajectory'])} steps. Relaxed structure:\")\n", - "print(result[\"final_structure\"])\n" + "for ase_filter in (ExpCellFilter, FrechetCellFilter):\n", + " result = relaxer.relax(structure, verbose=False, ase_filter=ase_filter)\n", + " print(\n", + " f\"\\n{ase_filter.__name__} took {len(result['trajectory'])} steps. Relaxed structure:\"\n", + " )\n", + " print(result[\"final_structure\"])" ] }, { @@ -289,7 +312,7 @@ " logfile=\"md_out.log\",\n", " loginterval=100,\n", ")\n", - "md.run(50) # run a 0.1 ps MD simulation\n" + "md.run(50) # run a 0.1 ps MD simulation" ] }, { @@ -316,7 +339,7 @@ ], "source": [ "supercell = structure.make_supercell([2, 2, 2], in_place=False)\n", - "print(supercell.composition)\n" + "print(supercell.composition)" ] }, { @@ -340,7 +363,7 @@ "remove_ids = random.sample(list(range(n_Li)), n_Li // 2)\n", "\n", "supercell.remove_sites(remove_ids)\n", - "print(supercell.composition)\n" + "print(supercell.composition)" ] }, { @@ -756,7 +779,7 @@ } ], "source": [ - "result = relaxer.relax(supercell)\n" + "result = relaxer.relax(supercell)" ] }, { @@ -769,7 +792,7 @@ "import pandas as pd\n", "\n", "df_magmom = pd.DataFrame({\"Unrelaxed\": chgnet.predict_structure(supercell)[\"m\"]})\n", - "df_magmom[\"CHGNet relaxed\"] = result[\"final_structure\"].site_properties[\"magmom\"]\n" + "df_magmom[\"CHGNet relaxed\"] = result[\"final_structure\"].site_properties[\"magmom\"]" ] }, { @@ -1820,7 +1843,7 @@ ")\n", "fig.layout.legend.update(title=\"\", x=1, y=1, xanchor=\"right\", yanchor=\"top\")\n", "fig.layout.xaxis.title = \"Magnetic moment\"\n", - "fig.show()\n" + "fig.show()" ] } ], From 7a48f3c39d0aaa28aadbfe33e2dbee0ea8e019c3 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Mon, 11 Dec 2023 09:30:23 -0800 Subject: [PATCH 2/6] bump actions/setup-python to v5 --- .github/workflows/lint.yml | 2 +- .github/workflows/test.yml | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index d8802a8a..f6df8f1d 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -16,7 +16,7 @@ jobs: - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.11" cache: pip diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index acaf8436..3f0c747f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,7 +27,7 @@ jobs: uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.9 cache: pip @@ -36,7 +36,7 @@ jobs: - name: Install dependencies run: | pip install cython - # install ase from main branch until FrechetCellFilter is release + # install ase from main branch until FrechetCellFilter is released # TODO remove pip install git+https://gitlab.com/ase/ase pip install git+https://gitlab.com/ase/ase python setup.py build_ext --inplace @@ -61,7 +61,7 @@ jobs: - name: Check out repo uses: actions/checkout@v4 - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 name: Install Python with: python-version: "3.10" From 4ae27dad434bff75ec3625e65ffddfe307a11abe Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Mon, 11 Dec 2023 09:34:57 -0800 Subject: [PATCH 3/6] fix torch non-differentiable in-place op errors --- chgnet/model/composition_model.py | 6 +++--- chgnet/model/encoders.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/chgnet/model/composition_model.py b/chgnet/model/composition_model.py index 172c29e9..d9d6f544 100644 --- a/chgnet/model/composition_model.py +++ b/chgnet/model/composition_model.py @@ -77,7 +77,7 @@ def _assemble_graphs(self, graphs: list[CrystalGraph]): ) if self.is_intensive: n_atom = graph.atomic_number.shape[0] - composition_fea /= n_atom + composition_fea = composition_fea / n_atom composition_feas.append(composition_fea) return torch.stack(composition_feas, dim=0) @@ -150,7 +150,7 @@ def fit( atomic_number - 1, minlength=self.max_num_elements ) if self.is_intensive: - composition_fea /= atomic_number.shape[0] + composition_fea = composition_fea / atomic_number.shape[0] composition_feas[index, :] = composition_fea e[index] = energy @@ -181,7 +181,7 @@ def _assemble_graphs(self, graphs: list[CrystalGraph]): ) if self.is_intensive: n_atom = graph.atomic_number.shape[0] - composition_fea /= n_atom + composition_fea = composition_fea / n_atom composition_feas.append(composition_fea) return torch.stack(composition_feas, dim=0).float() diff --git a/chgnet/model/encoders.py b/chgnet/model/encoders.py index 2b4145e1..e2675879 100644 --- a/chgnet/model/encoders.py +++ b/chgnet/model/encoders.py @@ -98,7 +98,7 @@ def forward( bond_vectors = center - neighbor bond_lengths = torch.norm(bond_vectors, dim=1) # Normalize the bond vectors - bond_vectors /= bond_lengths[:, None] + bond_vectors = bond_vectors / bond_lengths[:, None] # We create bond features only for undirected bonds # atom1 -> atom2 and atom2 -> atom1 should share same bond_basis From a672fe6d02c1b9ff4a07ff0b017fda955a5c522e Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Mon, 11 Dec 2023 09:35:43 -0800 Subject: [PATCH 4/6] add ase git install advice and recommend to use FrechetCellFilter for CHGNet structural relaxation --- chgnet/model/dynamics.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py index 95ea7c20..43f07b4c 100644 --- a/chgnet/model/dynamics.py +++ b/chgnet/model/dynamics.py @@ -10,7 +10,6 @@ import torch from ase import Atoms, units from ase.calculators.calculator import Calculator, all_changes, all_properties -from ase.filters import Filter, FrechetCellFilter from ase.md.npt import NPT from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen, NPTBerendsen from ase.md.nvtberendsen import NVTBerendsen @@ -33,6 +32,18 @@ from ase.io import Trajectory from ase.optimize.optimize import Optimizer +try: + from ase.filters import Filter, FrechetCellFilter +except ImportError: + print( + "We recommend using ase's unreleased FrechetCellFilter over ExpCellFilter for " + "CHGNet structural relaxation. ExpCellFilter has a bug in its calculation " + "of cell gradients which was fixed in FrechetCellFilter. Otherwise the two " + "are identical. ExpCellFilter was kept only for backwards compatibility and " + "should no longer be used. Run pip install git+https://gitlab.com/ase/ase to " + "install from main branch." + ) + # We would like to thank M3GNet develop team for this module # source: https://github.com/materialsvirtuallab/m3gnet From 2fb533d22714884a8593d5857843e674a3670404 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Mon, 11 Dec 2023 10:05:19 -0800 Subject: [PATCH 5/6] allow StructOptimizer ase_filter keyword to be str allowed values err msg on invalid name --- chgnet/model/dynamics.py | 21 ++++++++++++-- examples/basics.ipynb | 63 ++++++++++++++++++++-------------------- 2 files changed, 49 insertions(+), 35 deletions(-) diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py index 43f07b4c..0e6f1fcf 100644 --- a/chgnet/model/dynamics.py +++ b/chgnet/model/dynamics.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +import inspect import io import pickle import sys @@ -222,7 +223,7 @@ def relax( fmax: float | None = 0.1, steps: int | None = 500, relax_cell: bool | None = True, - ase_filter: Filter = FrechetCellFilter, + ase_filter: str | Filter = FrechetCellFilter, save_path: str | None = None, loginterval: int | None = 1, crystal_feas_save_path: str | None = None, @@ -239,8 +240,8 @@ def relax( Default = 500 relax_cell (bool | None): Whether to relax the cell as well. Default = True - ase_filter (ase.filters.Filter): The filter to apply to the atoms object - for relaxation. Default = FrechetCellFilter + ase_filter (str | ase.filters.Filter): The filter to apply to the atoms + object for relaxation. Default = FrechetCellFilter Used to default to ExpCellFilter but was removed due to bug reported in https://gitlab.com/ase/ase/-/issues/1321 and fixed in https://gitlab.com/ase/ase/-/merge_requests/3024. @@ -259,6 +260,20 @@ def relax( dict[str, Structure | TrajectoryObserver]: A dictionary with 'final_structure' and 'trajectory'. """ + if isinstance(ase_filter, str): + try: + import ase.filters + + ase_filter = getattr(ase.filters, ase_filter) + except AttributeError as exc: + valid_filter_names = [ + name + for name, cls in inspect.getmembers(ase.filters, inspect.isclass) + if issubclass(cls, Filter) + ] + raise ValueError( + f"Invalid {ase_filter=}, must be one of {valid_filter_names}. " + ) from exc if isinstance(atoms, Structure): atoms = atoms.to_ase_atoms() diff --git a/examples/basics.ipynb b/examples/basics.ipynb index ef3bd9e4..cedde93b 100644 --- a/examples/basics.ipynb +++ b/examples/basics.ipynb @@ -219,53 +219,52 @@ "output_type": "stream", "text": [ "\n", - "ExpCellFilter took 52 steps. Relaxed structure:\n", + "FrechetCellFilter took 49 steps. Relaxed structure:\n", + "\n", "Full Formula (Li2 Mn2 O4)\n", "Reduced Formula: LiMnO2\n", - "abc : 2.872010 4.650813 5.834364\n", - "angles: 89.984000 90.166300 89.959166\n", + "abc : 2.876179 4.609830 5.862965\n", + "angles: 89.863012 89.706707 89.946402\n", "pbc : True True True\n", "Sites (8)\n", - " # SP a b c magmom\n", - "--- ---- --------- --------- -------- ----------\n", - " 0 Li+ 0.466097 0.471908 0.378517 0.00247121\n", - " 1 Li+ 0.979868 0.970313 0.606578 0.00228718\n", - " 2 Mn3+ 0.484925 0.498502 0.859377 3.87809\n", - " 3 Mn3+ 0.984186 0.998235 0.126078 3.87656\n", - " 4 O2- 0.483681 0.009578 0.350723 0.0461\n", - " 5 O2- -0.012445 0.497563 0.093668 0.0405012\n", - " 6 O2- 0.484366 -0.001507 0.89173 0.0405898\n", - " 7 O2- 0.982041 0.509921 0.634653 0.0462106\n", + " # SP a b c magmom\n", + "--- ---- -------- -------- -------- ----------\n", + " 0 Li+ 0.481659 0.517177 0.377268 0.00253999\n", + " 1 Li+ 0.012506 0.04361 0.611305 0.00241017\n", + " 2 Mn3+ 0.508999 0.531361 0.860983 3.88035\n", + " 3 Mn3+ 1.00639 0.032734 0.130496 3.87339\n", + " 4 O2- 0.506676 0.039479 0.353671 0.0450131\n", + " 5 O2- 0.009236 0.531155 0.094449 0.0397426\n", + " 6 O2- 0.503994 0.029776 0.896556 0.0407182\n", + " 7 O2- 0.009769 0.521232 0.636952 0.0473042\n", + "\n", + "ExpCellFilter took 83 steps. Relaxed structure:\n", "\n", - "FrechetCellFilter took 26 steps. Relaxed structure:\n", "Full Formula (Li2 Mn2 O4)\n", "Reduced Formula: LiMnO2\n", - "abc : 2.876522 4.657650 5.839542\n", - "angles: 90.012820 90.017011 90.010059\n", + "abc : 2.874395 4.611958 5.852410\n", + "angles: 89.943237 89.910969 89.994579\n", "pbc : True True True\n", "Sites (8)\n", - " # SP a b c magmom\n", - "--- ---- --------- --------- -------- ----------\n", - " 0 Li+ 0.452761 0.467844 0.381994 0.00225034\n", - " 1 Li+ 0.976853 0.965974 0.60083 0.0018343\n", - " 2 Mn3+ 0.487451 0.500263 0.862042 3.87584\n", - " 3 Mn3+ 0.987112 0.997857 0.123798 3.89048\n", - " 4 O2- 0.483031 0.008843 0.349475 0.0478891\n", - " 5 O2- -0.008411 0.502525 0.095111 0.0408283\n", - " 6 O2- 0.489891 -7.2e-05 0.890263 0.0426996\n", - " 7 O2- 0.984032 0.51128 0.637812 0.0465174\n" + " # SP a b c magmom\n", + "--- ---- -------- -------- -------- ----------\n", + " 0 Li+ 0.474099 0.522936 0.375014 0.00291404\n", + " 1 Li+ 0.007464 0.033067 0.61184 0.00261617\n", + " 2 Mn3+ 0.512206 0.531325 0.861133 3.87057\n", + " 3 Mn3+ 1.00718 0.030874 0.130145 3.86706\n", + " 4 O2- 0.504485 0.035636 0.353984 0.0443497\n", + " 5 O2- 0.010645 0.532059 0.095251 0.0381828\n", + " 6 O2- 0.510762 0.030743 0.896625 0.0382355\n", + " 7 O2- 0.012389 0.529884 0.637688 0.0455911\n" ] } ], "source": [ - "from ase.filters import ExpCellFilter, FrechetCellFilter\n", - "\n", "structure.perturb(0.1)\n", - "for ase_filter in (ExpCellFilter, FrechetCellFilter):\n", + "for ase_filter in (\"FrechetCellFilter\", \"ExpCellFilter\"):\n", " result = relaxer.relax(structure, verbose=False, ase_filter=ase_filter)\n", - " print(\n", - " f\"\\n{ase_filter.__name__} took {len(result['trajectory'])} steps. Relaxed structure:\"\n", - " )\n", + " n_steps = len(result[\"trajectory\"])\n", + " print(f\"\\n{ase_filter} took {n_steps} steps. Relaxed structure:\\n\")\n", " print(result[\"final_structure\"])" ] }, From ef23049798e10dcba453cabb698d9af32291204d Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Mon, 11 Dec 2023 12:00:17 -0800 Subject: [PATCH 6/6] revert BondEncoder neighbor calc --- chgnet/model/encoders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chgnet/model/encoders.py b/chgnet/model/encoders.py index e2675879..d2eb4059 100644 --- a/chgnet/model/encoders.py +++ b/chgnet/model/encoders.py @@ -94,7 +94,7 @@ def forward( bond_vectors (Tensor): normalized bond vectors, for tracking the bond directions [n_bond, 3] """ - neighbor += image @ lattice + neighbor = neighbor + image @ lattice bond_vectors = center - neighbor bond_lengths = torch.norm(bond_vectors, dim=1) # Normalize the bond vectors