Skip to content

Commit

Permalink
feat(DRAFT): Add optional backend parameter for load(...)
Browse files Browse the repository at this point in the history
  • Loading branch information
dangotbanned committed Nov 18, 2024
1 parent 9544d9b commit 7b3a89e
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 43 deletions.
94 changes: 80 additions & 14 deletions altair/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Generic, overload
from typing import TYPE_CHECKING, Generic, final, overload

from narwhals.typing import IntoDataFrameT, IntoFrameT

Expand Down Expand Up @@ -320,28 +320,94 @@ def __repr__(self) -> str:
return f"{type(self).__name__}[{self._reader._name}]"


load: Loader[Any, Any]
@final
class _Load(Loader[IntoDataFrameT, IntoFrameT]):
@overload
def __call__( # pyright: ignore[reportOverlappingOverload]
self,
name: Dataset | LiteralString,
suffix: Extension | None = ...,
/,
tag: Version | None = ...,
backend: None = ...,
**kwds: Any,
) -> IntoDataFrameT: ...
@overload
def __call__(
self,
name: Dataset | LiteralString,
suffix: Extension | None = ...,
/,
tag: Version | None = ...,
backend: Literal["polars", "polars[pyarrow]"] = ...,
**kwds: Any,
) -> pl.DataFrame: ...
@overload
def __call__(
self,
name: Dataset | LiteralString,
suffix: Extension | None = ...,
/,
tag: Version | None = ...,
backend: Literal["pandas", "pandas[pyarrow]"] = ...,
**kwds: Any,
) -> pd.DataFrame: ...
@overload
def __call__(
self,
name: Dataset | LiteralString,
suffix: Extension | None = ...,
/,
tag: Version | None = ...,
backend: Literal["pyarrow"] = ...,
**kwds: Any,
) -> pa.Table: ...
def __call__(
self,
name: Dataset | LiteralString,
suffix: Extension | None = None,
/,
tag: Version | None = None,
backend: _Backend | None = None,
**kwds: Any,
) -> IntoDataFrameT | pl.DataFrame | pd.DataFrame | pa.Table:
if backend is None:
return super().__call__(name, suffix, tag, **kwds)
else:
return self.from_backend(backend)(name, suffix, tag=tag, **kwds)


load: _Load[Any, Any]
"""
For full IDE completions, instead use:
from altair.datasets import Loader
load = Loader.from_backend("polars")
cars = load("cars")
movies = load("movies")
Alternatively, specify ``backend`` during a call:
from altair.datasets import load
cars = load("cars", backend="polars")
movies = load("movies", backend="polars")
Related
-------
- https://github.com/vega/altair/pull/3631#issuecomment-2480832609
- https://github.com/vega/altair/pull/3631#discussion_r1847111064
- https://github.com/vega/altair/pull/3631#discussion_r1847176465
"""


def __getattr__(name):
if name == "load":
import warnings

from altair.datasets._readers import infer_backend

reader = infer_backend()
global load
load = Loader.__new__(Loader)
load = _Load.__new__(_Load)
load._reader = reader

warnings.warn(
"For full IDE completions, instead use:\n\n"
" from altair.datasets import Loader\n"
" load = Loader.from_backend(...)\n\n"
"Related: https://github.com/vega/altair/pull/3631#issuecomment-2480832609",
UserWarning,
stacklevel=3,
)
return load
else:
msg = f"module {__name__!r} has no attribute {name!r}"
Expand Down
81 changes: 52 additions & 29 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
import datetime as dt
import re
import sys
import warnings
from functools import partial
from importlib import import_module
from importlib.util import find_spec
from typing import TYPE_CHECKING, Any, cast, get_args
from urllib.error import URLError

import pytest
from narwhals.dependencies import is_into_dataframe, is_polars_dataframe
from narwhals.dependencies import (
is_into_dataframe,
is_pandas_dataframe,
is_polars_dataframe,
is_pyarrow_table,
)
from narwhals.stable import v1 as nw

from altair.datasets import Loader
Expand Down Expand Up @@ -138,47 +142,66 @@ def test_load(monkeypatch: pytest.MonkeyPatch) -> None:
priority: Sequence[_Backend] = "polars", "pandas[pyarrow]", "pandas", "pyarrow"
"""
import altair.datasets
from altair.datasets import load

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
from altair.datasets import load
assert load._reader._name == "polars"
monkeypatch.delattr(altair.datasets, "load")

monkeypatch.setitem(sys.modules, "polars", None)

assert load._reader._name == "polars"
from altair.datasets import load

if find_spec("pyarrow") is None:
# NOTE: We can end the test early for the CI job that removes `pyarrow`
assert load._reader._name == "pandas"
monkeypatch.delattr(altair.datasets, "load")
monkeypatch.setitem(sys.modules, "pandas", None)
with pytest.raises(NotImplementedError, match="no.+backend"):
from altair.datasets import load
else:
assert load._reader._name == "pandas[pyarrow]"
monkeypatch.delattr(altair.datasets, "load")

monkeypatch.setitem(sys.modules, "polars", None)
monkeypatch.setitem(sys.modules, "pyarrow", None)

from altair.datasets import load

if find_spec("pyarrow") is None:
# NOTE: We can end the test early for the CI job that removes `pyarrow`
assert load._reader._name == "pandas"
monkeypatch.delattr(altair.datasets, "load")
monkeypatch.setitem(sys.modules, "pandas", None)
with pytest.raises(NotImplementedError, match="no.+backend"):
from altair.datasets import load
else:
assert load._reader._name == "pandas[pyarrow]"
monkeypatch.delattr(altair.datasets, "load")
assert load._reader._name == "pandas"
monkeypatch.delattr(altair.datasets, "load")

monkeypatch.setitem(sys.modules, "pandas", None)
monkeypatch.delitem(sys.modules, "pyarrow")
monkeypatch.setitem(sys.modules, "pyarrow", import_module("pyarrow"))
from altair.datasets import load

monkeypatch.setitem(sys.modules, "pyarrow", None)
assert load._reader._name == "pyarrow"
monkeypatch.delattr(altair.datasets, "load")
monkeypatch.setitem(sys.modules, "pyarrow", None)

with pytest.raises(NotImplementedError, match="no.+backend"):
from altair.datasets import load

assert load._reader._name == "pandas"
monkeypatch.delattr(altair.datasets, "load")

monkeypatch.setitem(sys.modules, "pandas", None)
monkeypatch.delitem(sys.modules, "pyarrow")
monkeypatch.setitem(sys.modules, "pyarrow", import_module("pyarrow"))
from altair.datasets import load
@requires_pyarrow
def test_load_call(monkeypatch: pytest.MonkeyPatch) -> None:
import altair.datasets

monkeypatch.delattr(altair.datasets, "load", raising=False)

load = altair.datasets.load
assert load._reader._name == "polars"

assert load._reader._name == "pyarrow"
monkeypatch.delattr(altair.datasets, "load")
monkeypatch.setitem(sys.modules, "pyarrow", None)
default = load("cars")
df_pyarrow = load("cars", backend="pyarrow")
df_pandas = load("cars", backend="pandas[pyarrow]")
default_2 = load("cars")
df_polars = load("cars", backend="polars")

with pytest.raises(NotImplementedError, match="no.+backend"):
from altair.datasets import load
assert is_polars_dataframe(default)
assert is_pyarrow_table(df_pyarrow)
assert is_pandas_dataframe(df_pandas)
assert is_polars_dataframe(default_2)
assert is_polars_dataframe(df_polars)


@backends
Expand Down

0 comments on commit 7b3a89e

Please sign in to comment.