diff --git a/edvart/report.py b/edvart/report.py index 9b1a143..e46ca5f 100755 --- a/edvart/report.py +++ b/edvart/report.py @@ -3,7 +3,6 @@ # Standard imports import base64 import logging -import os import pickle from abc import ABC from typing import List, Optional, Tuple, Union @@ -24,6 +23,7 @@ from edvart.report_sections.table_of_contents import TableOfContents from edvart.report_sections.timeseries_analysis import TimeseriesAnalysis from edvart.report_sections.univariate_analysis import UnivariateAnalysis +from edvart.utils import env_var class ReportBase(ABC): @@ -228,12 +228,8 @@ def _export_html( # Workaround for a warning from `nbconvert` regarding debugging # and frozen modules. We are not debugging, so we can safely ignore it. - disable_validation_env_var_name = "PYDEVD_DISABLE_FILE_VALIDATION" - env_original = os.environ.copy() - os.environ[disable_validation_env_var_name] = "1" - - html = html_exporter.from_notebook_node(nb)[0] - os.environ = env_original + with env_var("PYDEVD_DISABLE_FILE_VALIDATION", "1"): + html = html_exporter.from_notebook_node(nb)[0] # Save HTML to file with open(html_filepath, "w") as html_file: diff --git a/edvart/utils.py b/edvart/utils.py index 1b1edd3..27f6d4f 100755 --- a/edvart/utils.py +++ b/edvart/utils.py @@ -1,6 +1,8 @@ """Utils package.""" -from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union +import os +from contextlib import contextmanager +from typing import Any, Dict, Iterable, Iterator, List, Literal, Optional, Tuple, Union import pandas as pd import seaborn as sns @@ -542,3 +544,23 @@ def contingency_table(df: pd.DataFrame) -> pd.DataFrame: """ table = sm.stats.Table.from_data(df) return table.table_orig.astype(int) + + +@contextmanager +def env_var(name: str, value: str) -> Iterator[None]: + """ + Set an environment variable for the duration of the context. + + Parameters + ---------- + name : str + Name of the environment variable. + value : str + Value of the environment variable. + """ + original_env = os.environ.copy() + os.environ[name] = value + try: + yield + finally: + os.environ = original_env diff --git a/tests/test_utils.py b/tests/test_utils.py index 5d11adf..cdbf91b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,5 @@ import math +import os import warnings import numpy as np @@ -31,3 +32,10 @@ def test_full_na_series(): assert utils.is_numeric(series) assert utils.is_categorical(series) assert utils.num_unique_values(series) == 0 + +def test_env_var(): + test_var_name = "TEST_VAR" + test_var_value = "test" + with utils.env_var(test_var_name, test_var_value): + assert os.environ[test_var_name] == test_var_value + assert test_var_value not in os.environ