Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flag write_site_properties = False in CifWriter for writing Structure.site_properties as _atom_site_{prop} #3550

Merged
merged 26 commits into from
Jan 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ci:

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.11
rev: v0.1.13
hooks:
- id: ruff
args: [--fix, --unsafe-fixes]
Expand Down
4 changes: 2 additions & 2 deletions pymatgen/io/abinit/abitimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,11 +888,11 @@ def scatter_hist(self, ax: plt.Axes = None, **kwargs):
# axHistx.axis["bottom"].major_ticklabels.set_visible(False)
axHistx.set_yticks([0, 50, 100])
for tl in axHistx.get_xticklabels():
tl.set_visible(False) # noqa: FBT003
tl.set_visible(False)

# axHisty.axis["left"].major_ticklabels.set_visible(False)
for tl in axHisty.get_yticklabels():
tl.set_visible(False) # noqa: FBT003
tl.set_visible(False)
axHisty.set_xticks([0, 50, 100])

# plt.draw()
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/io/abinit/netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(self, path):
# Slicing a ncvar returns a MaskedArrray and this is really annoying
# because it can lead to unexpected behavior in e.g. calls to np.matmul!
# See also https://github.com/Unidata/netcdf4-python/issues/785
self.rootgrp.set_auto_mask(False) # noqa: FBT003
self.rootgrp.set_auto_mask(False)

def __enter__(self):
"""Activated when used in the with statement."""
Expand Down
62 changes: 37 additions & 25 deletions pymatgen/io/cif.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import re
import textwrap
import warnings
from collections import deque
from collections import defaultdict, deque
from datetime import datetime
from functools import partial
from inspect import getfullargspec as getargspec
Expand Down Expand Up @@ -1313,13 +1313,14 @@ class CifWriter:

def __init__(
self,
struct,
symprec=None,
write_magmoms=False,
significant_figures=8,
angle_tolerance=5.0,
refine_struct=True,
):
struct: Structure,
symprec: float | None = None,
write_magmoms: bool = False,
significant_figures: int = 8,
angle_tolerance: float = 5,
refine_struct: bool = True,
write_site_properties: bool = False,
) -> None:
"""
Args:
struct (Structure): structure to write
Expand All @@ -1335,14 +1336,16 @@ def __init__(
is not None.
refine_struct: Used only if symprec is not None. If True, get_refined_structure
is invoked to convert input structure from primitive to conventional.
write_site_properties (bool): Whether to write the Structure.site_properties
to the CIF as _atom_site_{property name}. Defaults to False.
"""
if write_magmoms and symprec:
warnings.warn("Magnetic symmetry cannot currently be detected by pymatgen,disabling symmetry detection.")
symprec = None

format_str = f"{{:.{significant_figures}f}}"

block = {}
block: dict[str, Any] = {}
loops = []
spacegroup = ("P 1", 1)
if symprec is not None:
Expand All @@ -1367,7 +1370,7 @@ def __init__(
block["_chemical_formula_sum"] = no_oxi_comp.formula
block["_cell_volume"] = format_str.format(lattice.volume)

_reduced_comp, fu = no_oxi_comp.get_reduced_composition_and_factor()
_, fu = no_oxi_comp.get_reduced_composition_and_factor()
block["_cell_formula_units_Z"] = str(int(fu))

if symprec is None:
Expand All @@ -1388,12 +1391,12 @@ def __init__(
loops.append(["_symmetry_equiv_pos_site_id", "_symmetry_equiv_pos_as_xyz"])

try:
symbol_to_oxinum = {str(el): float(el.oxi_state) for el in sorted(comp.elements)}
block["_atom_type_symbol"] = list(symbol_to_oxinum)
block["_atom_type_oxidation_number"] = symbol_to_oxinum.values()
symbol_to_oxi_num = {str(el): float(el.oxi_state or 0) for el in sorted(comp.elements)}
block["_atom_type_symbol"] = list(symbol_to_oxi_num)
block["_atom_type_oxidation_number"] = symbol_to_oxi_num.values()
loops.append(["_atom_type_symbol", "_atom_type_oxidation_number"])
except (TypeError, AttributeError):
symbol_to_oxinum = {el.symbol: 0 for el in sorted(comp.elements)}
symbol_to_oxi_num = {el.symbol: 0 for el in sorted(comp.elements)}

atom_site_type_symbol = []
atom_site_symmetry_multiplicity = []
Expand All @@ -1406,6 +1409,7 @@ def __init__(
atom_site_moment_crystalaxis_x = []
atom_site_moment_crystalaxis_y = []
atom_site_moment_crystalaxis_z = []
atom_site_properties: dict[str, list] = defaultdict(list)
count = 0
if symprec is None:
for site in struct:
Expand Down Expand Up @@ -1437,6 +1441,10 @@ def __init__(
atom_site_moment_crystalaxis_y.append(format_str.format(moment[1]))
atom_site_moment_crystalaxis_z.append(format_str.format(moment[2]))

if write_site_properties:
for key, val in site.properties.items():
atom_site_properties[key].append(format_str.format(val))

count += 1
else:
# The following just presents a deterministic ordering.
Expand Down Expand Up @@ -1475,17 +1483,21 @@ def __init__(
block["_atom_site_fract_y"] = atom_site_fract_y
block["_atom_site_fract_z"] = atom_site_fract_z
block["_atom_site_occupancy"] = atom_site_occupancy
loops.append(
[
"_atom_site_type_symbol",
"_atom_site_label",
"_atom_site_symmetry_multiplicity",
"_atom_site_fract_x",
"_atom_site_fract_y",
"_atom_site_fract_z",
"_atom_site_occupancy",
]
)
loop_labels = [
"_atom_site_type_symbol",
"_atom_site_label",
"_atom_site_symmetry_multiplicity",
"_atom_site_fract_x",
"_atom_site_fract_y",
"_atom_site_fract_z",
"_atom_site_occupancy",
]
if write_site_properties:
for key, vals in atom_site_properties.items():
block[f"_atom_site_{key}"] = vals
loop_labels += [f"_atom_site_{key}"]
loops.append(loop_labels)

if write_magmoms:
block["_atom_site_moment_label"] = atom_site_moment_label
block["_atom_site_moment_crystalaxis_x"] = atom_site_moment_crystalaxis_x
Expand Down
11 changes: 11 additions & 0 deletions tests/io/test_cif.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,17 @@ def test_cif_writer_write_file(self):
assert len(read_structs) == 2
assert [x.formula for x in read_structs] == ["Fe4 P4 O16", "C4"]

def test_cif_writer_site_properties(self):
struct = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR")
struct.add_site_property(label := "hello", [1.0] * (len(struct) - 1) + [-1.0])
out_path = f"{self.tmp_path}/test2.cif"
CifWriter(struct, write_site_properties=True).write_file(out_path)
with open(out_path) as file:
cif_str = file.read()
assert f"_atom_site_occupancy\n _atom_site_{label}\n" in cif_str
assert "Fe Fe0 1 0.21872822 0.75000000 0.47486711 1 1.0" in cif_str
assert "O O23 1 0.95662769 0.25000000 0.29286233 1 -1.0" in cif_str


class TestMagCif(PymatgenTest):
def setUp(self):
Expand Down