Skip to content

Commit

Permalink
Try to provide more helpful error if load_dataset is given a DataFrame
Browse files — browse the repository at this point in the history
  • Loading branch information
mwaskom committed Jun 22, 2021
1 parent 445a54a commit 19cacbc
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
10 changes: 9 additions & 1 deletion seaborn/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,14 +371,22 @@ def test_load_datasets():


@_network(url="https://github.com/mwaskom/seaborn-data")
def test_load_dataset_string_error():
    """An unknown dataset name should raise a ValueError that names it.

    The diff rendering left two consecutive ``def`` lines (old and new
    function name); only the renamed version is kept here.
    """
    name = "bad_name"
    err = f"'{name}' is not one of the example datasets."
    with pytest.raises(ValueError, match=err):
        load_dataset(name)


def test_load_dataset_passed_data_error():
    """Passing a DataFrame instead of a dataset name raises a TypeError."""
    data = pd.DataFrame()
    err = "This function accepts only strings"
    with pytest.raises(TypeError, match=err):
        load_dataset(data)


@_network(url="https://github.com/mwaskom/seaborn-data")
def test_load_cached_datasets():

Expand Down
22 changes: 16 additions & 6 deletions seaborn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,18 +470,28 @@ def load_dataset(name, cache=True, data_home=None, **kws):
Tabular data, possibly with some preprocessing applied.
"""
path = ("https://raw.githubusercontent.com/"
"mwaskom/seaborn-data/master/{}.csv")
full_path = path.format(name)
# A common beginner mistake is to assume that one's personal data needs
# to be passed through this function to be usable with seaborn.
# Let's provide a more helpful error than you would otherwise get.
if isinstance(name, pd.DataFrame):
err = (
"This function accepts only strings (the name of an example dataset). "
"You passed a pandas DataFrame. If you have your own dataset, "
"it is not necessary to use this function before plotting."
)
raise TypeError(err)

url = f"https://raw.githubusercontent.com/mwaskom/seaborn-data/master/{name}.csv"

if cache:
cache_path = os.path.join(get_data_home(data_home),
os.path.basename(full_path))
cache_path = os.path.join(get_data_home(data_home), os.path.basename(url))
if not os.path.exists(cache_path):
if name not in get_dataset_names():
raise ValueError(f"'{name}' is not one of the example datasets.")
urlretrieve(full_path, cache_path)
urlretrieve(url, cache_path)
full_path = cache_path
else:
full_path = url

df = pd.read_csv(full_path, **kws)

Expand Down

0 comments on commit 19cacbc

Please sign in to comment.