From 19cacbcd4a62c90eabfc1315fb664ca78cbeeaa0 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 22 Jun 2021 14:50:33 -0400 Subject: [PATCH] Try to provide more helpful error if load_dataset is given a DataFrame --- seaborn/tests/test_utils.py | 10 +++++++++- seaborn/utils.py | 22 ++++++++++++++++------ 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/seaborn/tests/test_utils.py b/seaborn/tests/test_utils.py index cab8a357f2..6edcb7ac0b 100644 --- a/seaborn/tests/test_utils.py +++ b/seaborn/tests/test_utils.py @@ -371,7 +371,7 @@ def test_load_datasets(): @_network(url="https://github.com/mwaskom/seaborn-data") -def test_load_dataset_error(): +def test_load_dataset_string_error(): name = "bad_name" err = f"'{name}' is not one of the example datasets." @@ -379,6 +379,14 @@ def test_load_dataset_error(): load_dataset(name) +def test_load_dataset_passed_data_error(): + + df = pd.DataFrame() + err = "This function accepts only strings" + with pytest.raises(TypeError, match=err): + load_dataset(df) + + @_network(url="https://github.com/mwaskom/seaborn-data") def test_load_cached_datasets(): diff --git a/seaborn/utils.py b/seaborn/utils.py index 4547e61d4b..a9261028b9 100644 --- a/seaborn/utils.py +++ b/seaborn/utils.py @@ -470,18 +470,28 @@ def load_dataset(name, cache=True, data_home=None, **kws): Tabular data, possibly with some preprocessing applied. """ - path = ("https://raw.githubusercontent.com/" - "mwaskom/seaborn-data/master/{}.csv") - full_path = path.format(name) + # A common beginner mistake is to assume that one's personal data needs + # to be passed through this function to be usable with seaborn. + # Let's provide a more helpful error than you would otherwise get. + if isinstance(name, pd.DataFrame): + err = ( + "This function accepts only strings (the name of an example dataset). " + "You passed a pandas DataFrame. If you have your own dataset, " + "it is not necessary to use this function before plotting." + ) + raise TypeError(err) + + url = f"https://raw.githubusercontent.com/mwaskom/seaborn-data/master/{name}.csv" if cache: - cache_path = os.path.join(get_data_home(data_home), - os.path.basename(full_path)) + cache_path = os.path.join(get_data_home(data_home), os.path.basename(url)) if not os.path.exists(cache_path): if name not in get_dataset_names(): raise ValueError(f"'{name}' is not one of the example datasets.") - urlretrieve(full_path, cache_path) + urlretrieve(url, cache_path) full_path = cache_path + else: + full_path = url df = pd.read_csv(full_path, **kws)