diff --git a/diffpy/nmf_mapping/__init__.py b/diffpy/nmf_mapping/__init__.py index ff7b2cc..6fc19a1 100644 --- a/diffpy/nmf_mapping/__init__.py +++ b/diffpy/nmf_mapping/__init__.py @@ -24,4 +24,6 @@ # top-level import from diffpy.nmf_mapping.nmf_mapping import nmf_mapping_code as nmf +__all__ = ["nmf"] + # End of file diff --git a/diffpy/nmf_mapping/nmf_mapping/main.py b/diffpy/nmf_mapping/nmf_mapping/main.py index a81a7d4..f833902 100644 --- a/diffpy/nmf_mapping/nmf_mapping/main.py +++ b/diffpy/nmf_mapping/nmf_mapping/main.py @@ -34,11 +34,14 @@ def main(args=None): parser = ArgumentParser(prog="nmf_mapping", description=_BANNER, formatter_class=RawTextHelpFormatter) def tup(s): + if not isinstance(s, str): + raise TypeError("Input must be a string of two integers separated by a comma.") + try: l, h = map(int, s.split(",")) return l, h - except: - raise TypeError("r range must be low, high") + except ValueError: + raise ValueError("Input must be two integers separated by a comma (e.g., '1,5')") # args parser.add_argument( diff --git a/diffpy/nmf_mapping/nmf_mapping/nmf_mapping_code.py b/diffpy/nmf_mapping/nmf_mapping/nmf_mapping_code.py index b7e96d2..4ad293a 100644 --- a/diffpy/nmf_mapping/nmf_mapping/nmf_mapping_code.py +++ b/diffpy/nmf_mapping/nmf_mapping/nmf_mapping_code.py @@ -3,24 +3,20 @@ Local NMF Analysis of PDFs for PDFitc. """ +import re +import warnings from pathlib import Path import matplotlib.pyplot as plt import numpy as np import pandas as pd - -try: - from bg_mpl_stylesheet.bg_mpl_stylesheet import bg_mpl_style -except ImportError: - print("bg_mpl_style not found. Using generic matplotlib style.") -import re -import warnings - +from bg_mpl_stylesheets.styles import all_styles from diffpy.utils.parsers.loaddata import loadData from scipy import interpolate from sklearn.decomposition import NMF, PCA from sklearn.exceptions import ConvergenceWarning +plt.style.use(all_styles["bg_style"]) warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=ConvergenceWarning) @@ -74,8 +70,8 @@ def load_data(dir, xrd=False): raise FileNotFoundError("No gr or dat files found") n = len(data_list) data_list.sort(key=natural_keys_file_name) - l = loadData(data_list[0]).shape[0] - data_arr = np.zeros((n, l, 2)) + data_length = loadData(data_list[0]).shape[0] + data_arr = np.zeros((n, data_length, 2)) # if not on the same x grid, interpolate, use first data set as standard x_set = loadData(data_list[0])[:, 0] @@ -158,7 +154,8 @@ def NMF_decomposition( nmf_loss = [] pca_explained_variance = [] - # Assuming that the algorithm won't be able to decompose a timeseries of less than x scans into x or more components + # Assuming that the algorithm won't be able to decompose a timeseries of less than x scans into + # x or more components if thresh is None: if len(x_vs_y_df.columns) < 10: max_comp = len(x_vs_y_df.columns) @@ -193,7 +190,8 @@ def NMF_decomposition( if additional_comp: thresh += 1 - # Assuming that the algorithm won't be able to decompose a timeseries of less than x scans into x or more components + # Assuming that the algorithm won't be able to decompose a timeseries of less than x scans into + # x or more components if len(x_vs_y_df.columns) < thresh: n_comp = len(x_vs_y_df.columns) else: @@ -235,10 +233,7 @@ def component_plot(df_components, xrd=False, x_units=None, show=True): figure on absolute scale """ - try: - plt.style.use(bg_mpl_style) - except: - pass + df = df_components.copy() data_list = df.columns @@ -287,13 +282,9 @@ def component_ratio_plot(df_component_weight_timeseries, show=True): figure on absolute scale """ - try: - plt.style.use(bg_mpl_style) - except: - pass + df = df_component_weight_timeseries.copy() component_list = df.index - fig, ax = plt.subplots(figsize=(6, 8)) # seq to align with input phase for component in component_list: @@ -328,10 +319,7 @@ def reconstruction_error_plot(df_reconstruction_error, show=True): figure on absolute scale with removed files """ - try: - plt.style.use(bg_mpl_style) - except: - pass + df = df_reconstruction_error.copy() fig, ax = plt.subplots(figsize=(6, 8)) @@ -370,10 +358,7 @@ def explained_variance_plot(df_explained_var_ratio, show=True): figure on absolute scale with removed files """ - try: - plt.style.use(bg_mpl_style) - except: - pass + df = df_explained_var_ratio.copy() fig, ax = plt.subplots(figsize=(6, 8)) diff --git a/requirements/run.txt b/requirements/run.txt index 794072e..862eca2 100644 --- a/requirements/run.txt +++ b/requirements/run.txt @@ -4,3 +4,4 @@ scipy diffpy.utils pandas matplotlib +bg-mpl-stylesheets diff --git a/setup.py b/setup.py index a9a4fab..05f92d9 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ long_description = fh.read() -setuptools.setup( +setup( name="diffpy.nmf_mapping", version="1.0.0", author="Simon J.L. Billinge",