Skip to content

Commit

Permalink
Add bold flag to latexify (#3516)
Browse files Browse the repository at this point in the history
* better var names

* add bold flag to latexify to support bold subscripts

latex doesn't respect the boldness of surrounding text for subscripts in math mode

* simplify tick generation in PhononBSPlotter._make_ticks

* add n_bands property to PhononBSPlotter
  • Loading branch information
janosh authored Dec 15, 2023
1 parent d5a7c92 commit 4a94f9c
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ def elastic_centered_graph(self, start_node=None):
return centered_connected_subgraph

@staticmethod
def _edgekey_to_edgedictkey(key):
def _edge_key_to_edge_dict_key(key):
if isinstance(key, int):
return str(key)
if isinstance(key, str):
Expand Down Expand Up @@ -817,7 +817,7 @@ def as_dict(self):
in2 = node2stringindex[n2]
new_dict_of_dicts[in1][in2] = {}
for ie, edge_data in edges_dict.items():
ied = self._edgekey_to_edgedictkey(ie)
ied = self._edge_key_to_edge_dict_key(ie)
new_dict_of_dicts[in1][in2][ied] = jsanitize(edge_data)
return {
"@module": type(self).__module__,
Expand Down
12 changes: 6 additions & 6 deletions pymatgen/phonon/bandstructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,17 @@ def __init__(
self.nac_frequencies = []
self.nac_eigendisplacements = []
if nac_frequencies is not None:
for t in nac_frequencies:
self.nac_frequencies.append(([i / np.linalg.norm(t[0]) for i in t[0]], t[1]))
for freq in nac_frequencies:
self.nac_frequencies.append(([idx / np.linalg.norm(freq[0]) for idx in freq[0]], freq[1]))
if nac_eigendisplacements is not None:
for t in nac_eigendisplacements:
self.nac_eigendisplacements.append(([i / np.linalg.norm(t[0]) for i in t[0]], t[1]))
for freq in nac_eigendisplacements:
self.nac_eigendisplacements.append(([idx / np.linalg.norm(freq[0]) for idx in freq[0]], freq[1]))

def min_freq(self) -> tuple[Kpoint, float]:
"""Returns the point where the minimum frequency is reached and its value."""
i = np.unravel_index(np.argmin(self.bands), self.bands.shape)
idx = np.unravel_index(np.argmin(self.bands), self.bands.shape)

return self.qpoints[i[1]], self.bands[i]
return self.qpoints[idx[1]], self.bands[idx]

def has_imaginary_freq(self, tol: float = 1e-5) -> bool:
"""True if imaginary frequencies are present in the BS."""
Expand Down
58 changes: 27 additions & 31 deletions pymatgen/phonon/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,30 +268,26 @@ def __init__(self, bs: PhononBandStructureSymmLine, label: str | None = None) ->
"not along symmetry lines won't work)"
)
self._bs = bs
self._nb_bands = bs.nb_bands
self._label = label

@property
def n_bands(self) -> int:
"""Number of bands."""
return self._bs.nb_bands

def _make_ticks(self, ax: Axes) -> Axes:
"""Utility private method to add ticks to a band structure."""
ticks = self.get_ticks()
# Sanitize only plot the uniq values
uniq_d = []
uniq_l = []
temp_ticks = list(zip(ticks["distance"], ticks["label"]))
for idx, tt in enumerate(temp_ticks):
if idx == 0:
uniq_d.append(tt[0])
uniq_l.append(tt[1])
else:
uniq_d.append(tt[0])
uniq_l.append(tt[1])

ax.set_xticks(uniq_d)
ax.set_xticklabels(uniq_l)
# zip to sanitize, only plot the uniq values
ticks_labels = list(zip(*zip(ticks["distance"], ticks["label"])))
if ticks_labels:
ax.set_xticks(ticks_labels[0])
ax.set_xticklabels(ticks_labels[1])

for idx, label in enumerate(ticks["label"]):
if label is not None:
ax.axvline(ticks["distance"][idx], color="k")
ax.axvline(ticks["distance"][idx], color="black")
return ax

def bs_plot_data(self) -> dict[str, Any]:
Expand All @@ -315,7 +311,7 @@ def bs_plot_data(self) -> dict[str, Any]:
frequency.append([])
distance.append([self._bs.distance[j] for j in range(branch["start_index"], branch["end_index"] + 1)])

for idx in range(self._nb_bands):
for idx in range(self.n_bands):
frequency[-1].append(
[self._bs.bands[idx][j] for j in range(branch["start_index"], branch["end_index"] + 1)]
)
Expand Down Expand Up @@ -346,14 +342,14 @@ def get_plot(
data = self.bs_plot_data()
kwargs.setdefault("color", "blue")
for dists, freqs in zip(data["distances"], data["frequency"]):
for idx in range(self._nb_bands):
for idx in range(self.n_bands):
ys = [freqs[idx][j] * u.factor for j in range(len(dists))]
ax.plot(dists, ys, **kwargs)

self._make_ticks(ax)

# plot y=0 line
ax.axhline(0, linewidth=1, color="k")
ax.axhline(0, linewidth=1, color="black")

# Main X and Y Labels
ax.set_xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30)
Expand All @@ -374,7 +370,7 @@ def _get_weight(self, vec: np.ndarray, indices: list[list[int]]) -> np.ndarray:
"""Compute the weight for each combination of sites according to the
eigenvector.
"""
num_atom = int(self._nb_bands / 3)
num_atom = int(self.n_bands / 3)
new_vec = np.zeros(num_atom)
for idx in range(num_atom):
new_vec[idx] = np.linalg.norm(vec[idx * 3 : idx * 3 + 3])
Expand Down Expand Up @@ -461,13 +457,13 @@ def get_proj_plot(
for d in range(1, len(k_dist)):
# consider 2 k points each time so they connect
colors = []
for idx in range(self._nb_bands):
for idx in range(self.n_bands):
eigenvec_1 = self._bs.eigendisplacements[idx][d - 1].flatten()
eigenvec_2 = self._bs.eigendisplacements[idx][d].flatten()
colors1 = self._get_weight(eigenvec_1, indices)
colors2 = self._get_weight(eigenvec_2, indices)
colors.append(self._make_color((colors1 + colors2) / 2))
seg = np.zeros((self._nb_bands, 2, 2))
seg = np.zeros((self.n_bands, 2, 2))
seg[:, :, 1] = self._bs.bands[:, d - 1 : d + 1] * u.factor
seg[:, 0, 0] = k_dist[d - 1]
seg[:, 1, 0] = k_dist[d]
Expand Down Expand Up @@ -631,10 +627,10 @@ def plot_compare(
other_kwargs = other_kwargs or {}
legend_kwargs.setdefault("fontsize", 20)

data_orig = self.bs_plot_data()
data = other_plotter.bs_plot_data()
self_data = self.bs_plot_data()
other_data = other_plotter.bs_plot_data()

if len(data_orig["distances"]) != len(data["distances"]):
if len(self_data["distances"]) != len(other_data["distances"]):
if on_incompatible == "raise":
raise ValueError("The two band structures are not compatible.")
if on_incompatible == "warn":
Expand All @@ -647,10 +643,10 @@ def plot_compare(

kwargs.setdefault("color", "red") # don't move this line up! it would mess up self.get_plot color

for band_idx in range(other_plotter._nb_bands):
for dist_idx, dists in enumerate(data_orig["distances"]):
for band_idx in range(other_plotter.n_bands):
for dist_idx, dists in enumerate(self_data["distances"]):
xs = dists
ys = [data["frequency"][dist_idx][band_idx][j] * unit.factor for j in range(len(dists))]
ys = [other_data["frequency"][dist_idx][band_idx][j] * unit.factor for j in range(len(dists))]
ax.plot(xs, ys, **(kwargs | other_kwargs))

# add legend showing which color corresponds to which band structure
Expand Down Expand Up @@ -998,7 +994,7 @@ def bs_plot_data(self) -> dict[str, Any]:
gruneisen.append([])
distance.append([self._bs.distance[j] for j in range(branch["start_index"], branch["end_index"] + 1)])

for idx in range(self._nb_bands):
for idx in range(self.n_bands):
frequency[-1].append(
[self._bs.bands[idx][j] for j in range(branch["start_index"], branch["end_index"] + 1)]
)
Expand Down Expand Up @@ -1030,15 +1026,15 @@ def get_plot_gs(self, ylim: float | None = None, **kwargs) -> Axes:

data = self.bs_plot_data()
for dist_idx in range(len(data["distances"])):
for band_idx in range(self._nb_bands):
for band_idx in range(self.n_bands):
ys = [data["gruneisen"][dist_idx][band_idx][idx] for idx in range(len(data["distances"][dist_idx]))]

ax.plot(data["distances"][dist_idx], ys, "b-", **kwargs)

self._make_ticks(ax)

# plot y=0 line
ax.axhline(0, linewidth=1, color="k")
ax.axhline(0, linewidth=1, color="black")

# Main X and Y Labels
ax.set_xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30)
Expand Down Expand Up @@ -1105,7 +1101,7 @@ def plot_compare_gs(self, other_plotter: GruneisenPhononBSPlotter) -> Axes:

ax = self.get_plot()
band_linewidth = 1
for band_idx in range(other_plotter._nb_bands):
for band_idx in range(other_plotter.n_bands):
for dist_idx in range(len(data_orig["distances"])):
ax.plot(
data_orig["distances"][dist_idx],
Expand Down
5 changes: 3 additions & 2 deletions pymatgen/util/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def charge_string(charge, brackets=True, explicit_one=True):
return chg_str


def latexify(formula):
def latexify(formula: str, bold: bool = False):
"""Generates a LaTeX formatted formula. E.g., Fe2O3 is transformed to
Fe$_{2}$O$_{3}$.
Expand All @@ -154,11 +154,12 @@ def latexify(formula):
Args:
formula (str): Input formula.
bold (bool): Whether to make the subscripts bold. Defaults to False.
Returns:
Formula suitable for display as in LaTeX with proper subscripts.
"""
return re.sub(r"([A-Za-z\(\)])([\d\.]+)", r"\1$_{\2}$", formula)
return re.sub(r"([A-Za-z\(\)])([\d\.]+)", r"\1$_{\\mathbf{\2}}$" if bold else r"\1$_{\2}$", formula)


def htmlify(formula):
Expand Down
20 changes: 10 additions & 10 deletions tests/analysis/chemenv/connectivity/test_connected_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@ def test_serialization(self):
sorted_edges = sorted(sorted(e) for e in cc.graph.edges())
assert sorted_edges == ref_sorted_edges

ccfromdict = ConnectedComponent.from_dict(cc.as_dict())
ccfromjson = ConnectedComponent.from_dict(json.loads(json.dumps(cc.as_dict())))
loaded_cc_list = [ccfromdict, ccfromjson]
cc_from_dict = ConnectedComponent.from_dict(cc.as_dict())
cc_from_json = ConnectedComponent.from_dict(json.loads(json.dumps(cc.as_dict())))
loaded_cc_list = [cc_from_dict, cc_from_json]
if bson is not None:
bson_data = bson.BSON.encode(cc.as_dict())
ccfrombson = ConnectedComponent.from_dict(bson_data.decode())
loaded_cc_list.append(ccfrombson)
cc_from_bson = ConnectedComponent.from_dict(bson_data.decode())
loaded_cc_list.append(cc_from_bson)
for loaded_cc in loaded_cc_list:
assert loaded_cc.graph.number_of_nodes() == 3
assert loaded_cc.graph.number_of_edges() == 2
Expand All @@ -145,18 +145,18 @@ def test_serialization(self):
assert isinstance(node.central_site, PeriodicSite)

def test_serialization_private_methods(self):
# Testing _edgekey_to_edgedictkey
key = ConnectedComponent._edgekey_to_edgedictkey(3)
# Testing _edge_key_to_edge_dict_key
key = ConnectedComponent._edge_key_to_edge_dict_key(3)
assert key == "3"
with pytest.raises(
RuntimeError,
match=r"Cannot pass an edge key which is a str representation of an int\x2E",
):
key = ConnectedComponent._edgekey_to_edgedictkey("5")
key = ConnectedComponent._edgekey_to_edgedictkey("some-key")
key = ConnectedComponent._edge_key_to_edge_dict_key("5")
key = ConnectedComponent._edge_key_to_edge_dict_key("some-key")
assert key == "some-key"
with pytest.raises(ValueError, match=r"Edge key should be either a str or an int\x2E"):
key = ConnectedComponent._edgekey_to_edgedictkey(0.2)
key = ConnectedComponent._edge_key_to_edge_dict_key(0.2)

def test_periodicity(self):
env_node1 = EnvironmentNode(central_site="Si", i_central_site=3, ce_symbol="T:4")
Expand Down
2 changes: 1 addition & 1 deletion tests/files/.pytest-split-durations
Original file line number Diff line number Diff line change
Expand Up @@ -2306,7 +2306,7 @@
"tests/phonon/test_gruneisen.py::TestGruneisenParameter::test_average_gruneisen": 7.851578625966795,
"tests/phonon/test_gruneisen.py::TestGruneisenParameter::test_debye_temp_phonopy": 7.661829167045653,
"tests/phonon/test_gruneisen.py::TestGruneisenParameter::test_frequencies": 7.70038720802404,
"tests/phonon/test_gruneisen.py::TestGruneisenParameter::test_fromdict_asdict": 7.64449966698885,
"tests/phonon/test_gruneisen.py::TestGruneisenParameter::test_as_from_dict": 7.64449966698885,
"tests/phonon/test_gruneisen.py::TestGruneisenParameter::test_gruneisen": 7.717112623970024,
"tests/phonon/test_gruneisen.py::TestGruneisenParameter::test_multi": 6.107562876015436,
"tests/phonon/test_gruneisen.py::TestGruneisenParameter::test_phdos": 7.646594374033157,
Expand Down
2 changes: 1 addition & 1 deletion tests/phonon/test_gruneisen.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_plot(self):
ax = plotter.get_plot(units="mev")
assert isinstance(ax, plt.Axes)

def test_fromdict_asdict(self):
def test_as_from_dict(self):
new_dict = self.gruneisen_obj.as_dict()
self.gruneisen_obj2 = GruneisenParameter.from_dict(new_dict)

Expand Down

0 comments on commit 4a94f9c

Please sign in to comment.