Skip to content

Commit

Permalink
Fix ValueError: Invalid fmt with Structure.to(fmt='yml') (#3557)
Browse files Browse the repository at this point in the history
* fix Structure.to(fmt='yml'), add Structure.FileFormats to ensure consistent format support and value err msg in Structure.to and Structure.from_file

* test_structure.py cover fmt='yml' in test_to_from_file_string

fix pytest.raises expected msg
  • Loading branch information
janosh authored Jan 16, 2024
1 parent 88921da commit c14d67e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 18 deletions.
32 changes: 17 additions & 15 deletions pymatgen/core/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@

from pymatgen.util.typing import CompositionLike, SpeciesLike

FileFormats = Literal["cif", "poscar", "cssr", "json", "yaml", "yml", "xsf", "mcsqs", "res", ""]


class Neighbor(Site):
"""Simple Site subclass to contain a neighboring atom that skips all the unnecessary checks for speed. Can be
Expand Down Expand Up @@ -447,13 +449,13 @@ def is_valid(self, tol: float = DISTANCE_TOLERANCE) -> bool:
return np.min(all_dists) > tol

@abstractmethod
def to(self, filename: str = "", fmt: str = "") -> str | None:
def to(self, filename: str = "", fmt: FileFormats = "") -> str | None:
"""Generates string representations (cif, json, poscar, ....) of SiteCollections (e.g.,
molecules / structures). Should return str or None if written to a file.
"""
raise NotImplementedError

def to_file(self, filename: str = "", fmt: str = "") -> str | None:
def to_file(self, filename: str = "", fmt: FileFormats = "") -> str | None:
"""A more intuitive alias for .to()."""
return self.to(filename, fmt)

Expand Down Expand Up @@ -2653,7 +2655,7 @@ def from_dict(cls, dct: dict[str, Any], fmt: Literal["abivars"] | None = None) -
charge = dct.get("charge")
return cls.from_sites(sites, charge=charge, properties=dct.get("properties"))

def to(self, filename: str | Path = "", fmt: str = "", **kwargs) -> str:
def to(self, filename: str | Path = "", fmt: FileFormats = "", **kwargs) -> str:
"""Outputs the structure to a file or string.
Args:
Expand All @@ -2663,7 +2665,7 @@ def to(self, filename: str | Path = "", fmt: str = "", **kwargs) -> str:
fmt (str): Format to output to. Defaults to JSON unless filename
is provided. If fmt is specified, it overrides whatever the
filename is. Options include "cif", "poscar", "cssr", "json",
"xsf", "mcsqs", "prismatic", "yaml", "fleur-inpgen".
"xsf", "mcsqs", "prismatic", "yaml", "yml", "fleur-inpgen".
Non-case sensitive.
**kwargs: Kwargs passthru to relevant methods. E.g., This allows
the passing of parameters like symprec to the
Expand All @@ -2673,7 +2675,7 @@ def to(self, filename: str | Path = "", fmt: str = "", **kwargs) -> str:
str: String representation of structure in given format. If a filename
is provided, the same string is written to the file.
"""
filename, fmt = str(filename), fmt.lower()
filename, fmt = str(filename), cast(FileFormats, fmt.lower())

if fmt == "cif" or fnmatch(filename.lower(), "*.cif*"):
from pymatgen.io.cif import CifWriter
Expand Down Expand Up @@ -2722,7 +2724,7 @@ def to(self, filename: str | Path = "", fmt: str = "", **kwargs) -> str:
from pymatgen.io.prismatic import Prismatic

return Prismatic(self).to_str()
elif fmt == "yaml" or fnmatch(filename, "*.yaml*") or fnmatch(filename, "*.yml*"):
elif fmt in ("yaml", "yml") or fnmatch(filename, "*.yaml*") or fnmatch(filename, "*.yml*"):
yaml = YAML()
str_io = StringIO()
yaml.dump(self.as_dict(), str_io)
Expand All @@ -2747,7 +2749,7 @@ def to(self, filename: str | Path = "", fmt: str = "", **kwargs) -> str:
else:
if fmt == "":
raise ValueError(f"Format not specified and could not infer from {filename=}")
raise ValueError(f"Invalid format={fmt!r}")
raise ValueError(f"Invalid {fmt=}, valid options are {get_args(FileFormats)}")

if filename:
writer.write_file(filename)
Expand All @@ -2757,7 +2759,7 @@ def to(self, filename: str | Path = "", fmt: str = "", **kwargs) -> str:
def from_str( # type: ignore[override]
cls,
input_string: str,
fmt: Literal["cif", "poscar", "cssr", "json", "yaml", "xsf", "mcsqs", "res"],
fmt: FileFormats,
primitive: bool = False,
sort: bool = False,
merge_tol: float = 0.0,
Expand All @@ -2768,7 +2770,7 @@ def from_str( # type: ignore[override]
Args:
input_string (str): String to parse.
fmt (str): A file format specification. One of "cif", "poscar", "cssr",
"json", "yaml", "xsf", "mcsqs".
"json", "yaml", "yml", "xsf", "mcsqs", "res".
primitive (bool): Whether to find a primitive cell. Defaults to
False.
sort (bool): Whether to sort the sites in accordance to the default
Expand Down Expand Up @@ -2797,12 +2799,12 @@ def from_str( # type: ignore[override]
cssr = Cssr.from_str(input_string, **kwargs)
struct = cssr.structure
elif fmt_low == "json":
d = json.loads(input_string)
struct = Structure.from_dict(d)
elif fmt_low == "yaml":
dct = json.loads(input_string)
struct = Structure.from_dict(dct)
elif fmt_low in ("yaml", "yml"):
yaml = YAML()
d = yaml.load(input_string)
struct = Structure.from_dict(d)
dct = yaml.load(input_string)
struct = Structure.from_dict(dct)
elif fmt_low == "xsf":
from pymatgen.io.xcrysden import XSF

Expand All @@ -2825,7 +2827,7 @@ def from_str( # type: ignore[override]

struct = ResIO.structure_from_str(input_string, **kwargs)
else:
raise ValueError(f"Unrecognized format `{fmt}`!")
raise ValueError(f"Invalid {fmt=}, valid options are {get_args(FileFormats)}")

if sort:
struct = struct.get_sorted_structure()
Expand Down
6 changes: 3 additions & 3 deletions tests/core/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,7 @@ def test_get_dist_matrix(self):
assert_allclose(self.struct.distance_matrix, ans)

def test_to_from_file_and_string(self):
for fmt in ["cif", "json", "poscar", "cssr"]:
for fmt in ("cif", "json", "poscar", "cssr"):
struct = self.struct.to(fmt=fmt)
assert struct is not None
ss = IStructure.from_str(struct, fmt=fmt)
Expand Down Expand Up @@ -851,7 +851,7 @@ def test_to_from_file_and_string(self):

with pytest.raises(ValueError, match="Format not specified and could not infer from filename='whatever'"):
self.struct.to(filename="whatever")
with pytest.raises(ValueError, match="Invalid format='badformat'"):
with pytest.raises(ValueError, match="Invalid fmt='badformat'"):
self.struct.to(fmt="badformat")

self.struct.to(filename=(gz_json_path := "POSCAR.testing.gz"))
Expand Down Expand Up @@ -1284,7 +1284,7 @@ def test_to_from_abivars(self):

def test_to_from_file_string(self):
# to/from string
for fmt in ["cif", "json", "poscar", "cssr", "yaml", "xsf", "res"]:
for fmt in ("cif", "json", "poscar", "cssr", "yaml", "yml", "xsf", "res"):
struct = self.struct.to(fmt=fmt)
assert struct is not None
ss = Structure.from_str(struct, fmt=fmt)
Expand Down

0 comments on commit c14d67e

Please sign in to comment.