-
Notifications
You must be signed in to change notification settings - Fork 64
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Loading status checks…
Merge pull request #462 from mdekstrand/feature/pipeline
Build a Pipeline abstraction
Showing
12 changed files
with
1,859 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,3 +10,8 @@ invoker | |
CUDA | ||
subpackages | ||
recomputation | ||
Higley | ||
POPROX | ||
rankers | ||
Scikit-Learn | ||
unpickle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -50,6 +50,7 @@ Resources | |
:caption: Algorithms | ||
|
||
interfaces | ||
pipeline | ||
algorithms | ||
basic | ||
ranking | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,331 @@ | ||
Recommendation Pipelines | ||
======================== | ||
|
||
.. module:: lenskit.pipeline | ||
|
||
.. todo:: | ||
None of this has been implemented yet. | ||
|
||
Since version :ref:`2024.1`, LensKit uses a flexible “pipeline” abstraction to | ||
wire together different components such as candidate selectors, personalized | ||
item scorers, and rankers to produce predictions, recommendations, or other | ||
recommender system outputs. This is a significant change from the LensKit 0.x | ||
design of monolithic and composable components based on the Scikit-Learn API, | ||
allowing new recommendation designs to be composed without writing new classes | ||
just for the composition. It also makes recommender definition code more explicit | ||
by laying out the pipeline instead of burying composition logic in the definitions | ||
of different composition classes. | ||
|
||
If all you want to do is build a standard top-N recommendation pipeline from an | ||
item scorer, see :func:`topn_pipeline`; this is the equivalent to | ||
``Recommender.adapt`` in the old LensKit API. If you want more flexibility, you | ||
can write out the pipeline configuration yourself; the equivalent to | ||
``topn_pipeline(scorer)`` is: | ||
|
||
.. code:: python | ||
pipe = Pipeline() | ||
# define an input parameter for the user ID | ||
user = pipe.create_input('user', EntityId) | ||
# allow candidate items to be optionally specified | ||
items = pipe.create_input('items', list[EntityId], None) | ||
# look up a user's history in the training data | ||
history = pipe.add_component('lookup-user', LookupTrainingHistory(), user=user) | ||
# find candidates from the training data | ||
lookup_candidates = pipe.add_component( | ||
'select-candidates', | ||
UnratedTrainingItemsCandidateSelector(), | ||
user=history, | ||
) | ||
# if the client provided items as a pipeline input, use those; otherwise | ||
# use the candidate selector we just configured. | ||
candidates = pipe.use_first_of('candidates', items, lookup_candidates) | ||
# score the candidate items using the specified scorer | ||
score = pipe.add_component('score', scorer, user=user, items=candidates) | ||
# rank the items by score | ||
recommend = pipe.add_component('recommend', TopNRanker(50), items=score) | ||
You can then run this pipeline to produce recommendations with: | ||
|
||
.. code:: python | ||
user_recs = pipe.run(recommend, user=user_id) | ||
.. todo:: | ||
Redo some of those types with user & item data, etc. | ||
|
||
.. todo:: | ||
Provide utility functions to make more common wiring operations easy so there | ||
is middle ground between “give me a standard pipeline” and “make me do everything | ||
myself”. | ||
|
||
.. todo:: | ||
Rethink the “keyword inputs only” constraint in view of the limitation it | ||
places on fallback or other compositional components — it's hard to specify | ||
a component that implements fallback logic for an arbitrary number of | ||
inputs. | ||
|
||
Pipeline components are not limited to looking things up from training data — | ||
they can query databases, load files, and any other operations. A runtime | ||
pipeline can use some components (especially the scorer) trained from training | ||
data, and other components that query a database or REST services for things | ||
like user history and candidate set lookup. | ||
|
||
The LensKit pipeline design is heavily inspired by Haystack_ and by the pipeline | ||
abstraction Karl Higley created for POPROX_. | ||
|
||
.. _Haystack: https://docs.haystack.deepset.ai/docs/pipelines | ||
.. _POPROX: https://ccri-poprox.github.io/poprox-researcher-manual/reference/recommender/poprox_recommender.pipeline.html | ||
|
||
Common Pipelines | ||
~~~~~~~~~~~~~~~~ | ||
|
||
These functions make it easy to create common pipeline designs. | ||
|
||
.. autofunction:: topn_pipeline | ||
|
||
.. _pipeline-model: | ||
|
||
Pipeline Model | ||
~~~~~~~~~~~~~~ | ||
|
||
A pipeline has a couple key concepts: | ||
|
||
* An **input** is data that needs to be provided to the pipeline when it is run, | ||
such as the user to generate recommendations for. Inputs have specified data | ||
types, and it is an error to provide an input value of an unexpected type. | ||
* A **component** processes input data and produces an output. It can be either | ||
a Python function or object (anything that implements the :class:`Component` | ||
protocol) that takes inputs as keyword arguments and returns an output. | ||
|
||
These are arranged in a directed acyclic graph, consisting of: | ||
|
||
* **Nodes** (represented by :class:`Node`), which correspond to either *inputs* | ||
or *components*. | ||
* **Connections** from one node's input to another node's data (or to a fixed | ||
data value). This is how the pipeline knows which components depend on other | ||
components and how to provide each component with the inputs it requires; see | ||
:ref:`pipeline-connections` for details. | ||
|
||
Each node has a name that can be used to look up the node with | ||
:meth:`Pipeline.node` and appears in serialization and logging situations. Names | ||
must be unique within a pipeline. | ||
|
||
.. _pipeline-connections: | ||
|
||
Connections | ||
----------- | ||
|
||
Components declare their inputs as keyword arguments on their call signatures | ||
(either the function call signature, if it is a bare function, or the | ||
``__call__`` method if it is implemented by a class). In a pipeline, these | ||
inputs can be connected to a source, which the pipeline will use to obtain a | ||
value for that parameter when running the pipeline. Inputs can be connected to | ||
the following types: | ||
|
||
* A :class:`Node`, in which case the input will be provided from the | ||
corresponding pipeline input or component return value. Nodes are | ||
returned by :meth:`create_input` or :meth:`add_component`, and can be | ||
looked up after creation with :meth:`node`. | ||
* A Python object, in which case that value will be provided directly to | ||
the component input argument. | ||
|
||
These input connections are specified via keyword arguments to the | ||
:meth:`Pipeline.add_component` or :meth:`Pipeline.connect` methods — specify the | ||
component's input name(s) and the node or data to which each input should be | ||
wired. | ||
|
||
You can also use :meth:`Pipeline.add_default` to specify default connections. For example, | ||
you can specify a default for ``user``:: | ||
|
||
pipe.add_default('user', user_history) | ||
|
||
With this default in place, if a component has an input named ``user`` and that | ||
input is not explicitly connected to a node, then the ``user_history`` node will | ||
be used to supply its value. Judicious use of defaults can reduce the amount of | ||
code overhead needed to wire common pipelines. | ||
|
||
.. note:: | ||
|
||
You cannot directly wire an input another component using only that | ||
component's name; if you only have a name, pass it to :meth:`node` | ||
to obtain the node. This is because it would be impossible to | ||
distinguish between a string component name and a string data value. | ||
|
||
.. note:: | ||
|
||
You do not usually need to call this method directly; when possible, | ||
provide the wirings when calling :meth:`add_component`. | ||
|
||
.. _pipeline-execution: | ||
|
||
Execution | ||
--------- | ||
|
||
Once configured, a pipeline can be run with :meth:`Pipeline.run`. This | ||
method takes two types of inputs: | ||
|
||
* Positional arguments specifying the node(s) to run and whose results should | ||
be returned. This is to allow partial runs of pipelines (e.g. to only score | ||
items without ranking them), and to allow multiple return values to be | ||
obtained (e.g. initial item scores and final rankings, which may have | ||
altered scores). | ||
|
||
If no components are specified, it is the same as specifying the last | ||
component added to the pipeline. | ||
|
||
* Keyword arguments specifying the values for the pipeline's inputs, as defined by | ||
calls to :meth:`create_input`. | ||
|
||
Pipeline execution logically proceeds in the following steps: | ||
|
||
1. Determine the full list of pipeline components that need to be run | ||
in order to run the specified components. | ||
2. Run those components in order, taking their inputs from pipeline | ||
inputs or previous components as specified by the pipeline | ||
connections and defaults. | ||
3. Return the values of the specified components. If a single | ||
component is specified, its value is returned directly; if two or | ||
more components are specified, their values are returned in a tuple. | ||
|
||
.. _pipeline-names: | ||
|
||
Component Names | ||
--------------- | ||
|
||
As noted above, each component (and pipeline input) has a *name* that is unique | ||
across the pipeline. For consistency and clarity, we recommend naming | ||
components with a verb or kebab-case verb phrase that captures the action that component performs, such as: | ||
|
||
* ``recommend`` | ||
* ``rerank`` | ||
* ``score`` | ||
* ``lookup-user-history`` | ||
* ``embed-items`` | ||
|
||
Component nodes can also have *aliases*, allowing them to be accessed by more | ||
than one name. Use :meth:`Pipeline.alias` to define these aliases. | ||
|
||
Various LensKit facilities recognize several standard component names that we | ||
recommend you use when applicable: | ||
|
||
* ``score`` — compute (usually personalized) scores for items for a given user. | ||
* ``rank`` — compute a (ranked) list of recommendations for a user. If you are | ||
configuring a pipeline with rerankers whose outputs are also rankings, this | ||
name should usually be used for the last such ranker, and downstream | ||
components (if any) transform that ranking into another layout; that way the | ||
evaluation tools will operate on the last such ranking. | ||
* ``recommend`` — compute recommendations for a user. This will often be an | ||
alias for ``rank``, as in a top-*N* recommender, but may return other formats | ||
such as grids or unordered slates. | ||
* ``predict-ratings`` — predict a user's ratings for the specified items. When | ||
present, this is usually an alias for ``score``, but in some pipelines it will | ||
be a different component that transforms the scores into rating predictions. | ||
|
||
These component names replace the task-specific interfaces in pre-2024 LensKit; | ||
a ``Recommender`` is now just a pipeline with ``recommend`` and/or ``rank`` | ||
components. | ||
|
||
.. _pipeline-serialization: | ||
|
||
Pipeline Serialization | ||
---------------------- | ||
|
||
Pipelines are defined by the following: | ||
|
||
* The components and inputs (nodes) | ||
* The component input connections (edges) | ||
* The component configurations (see :class:`ConfigurableComponent`) | ||
* The components' learned parameters (see :class:`TrainableComponent`) | ||
|
||
.. todo:: | ||
Serialization support other than ``pickle`` is not yet implemented. | ||
|
||
LensKit supports serializing both pipeline descriptions (components, | ||
connections, and configurations) and pipeline parameters. There are | ||
three ways to save a pipeline or part thereof: | ||
|
||
1. Pickle the entire pipeline. This is easy, and saves everything pipeline; it | ||
has the usual downsides of pickling (arbitrary code execution, etc.). | ||
LensKit uses pickling to share pipelines with worker processes for parallel | ||
batch operations. | ||
2. Save the pipeline configuration with :meth:`Pipeline.save_config`. This saves | ||
the components, their configurations, and their connections, but **not** any | ||
learned parameter data. A new pipeline can be constructed from such a | ||
configuration can be reloaded with :meth:`Pipeline.from_config`. | ||
3. Save the pipeline parameters with :meth:`Pipeline.save_params`. This saves | ||
the learned parameters but **not** the configuration or connections. The | ||
parameters can be reloaded into a compatible pipeline with | ||
:meth:`Pipeline.load_params`; a compatible pipeline can be created by | ||
running the pipeline setup code or using a saved pipeline configuration. | ||
|
||
These can be mixed and matched; if you pickle an untrained pipeline, you can | ||
unpickle it and use :meth:`~Pipeline.load_params` to infuse it with parameters. | ||
|
||
Component implementations need to support the configuration and/or parameter | ||
values, as needed, in addition to functioning correctly with pickle (no specific | ||
logic is usually needed for this). | ||
|
||
LensKit knows how to safely save the following object types from | ||
:meth:`Component.get_params`: | ||
|
||
* :class:`torch.Tensor` (dense, CSR, and COO tensors). | ||
* :class:`numpy.ndarray`. | ||
* :class:`scipy.sparse.csr_array`, :class:`scipy.sparse.~coo_array`, | ||
:class:`scipy.sparse.~csc_array`, and the corresponding ``*_matrix`` | ||
versions. | ||
|
||
Other objects (including Pandas dataframes) are serialized by pickling, and the | ||
pipeline will emit a warning (or fail, if ``allow_pickle=False`` is passed to | ||
:meth:`~Pipeline.save_params`). | ||
|
||
.. note:: | ||
The load/save parameter operations are modeled after PyTorch's | ||
:meth:`~torch.nn.Module.state_dict` and the needs of ``safetensors``. | ||
|
||
Pipeline Class | ||
~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: Pipeline | ||
|
||
Pipeline Nodes | ||
~~~~~~~~~~~~~~ | ||
|
||
Pipeline nodes are represented by :class:`Node` objects. For the purposes of | ||
client code, these should be considered opaque objects usable only to reference | ||
a node. | ||
|
||
.. autoclass:: Node | ||
:members: name | ||
|
||
Component Interface | ||
~~~~~~~~~~~~~~~~~~~ | ||
|
||
Pipeline components are callable objects that can optionally provide training | ||
and serialization capabilities. In the simplest case, a component that requires | ||
no training or configuration can simply be a Python function; more sophisticated | ||
components can implement the :class:`TrainableComponent` and/or | ||
:class:`ConfigurableComponent` protocols to support flexible model training and | ||
pipeline serialization. | ||
|
||
Components also need to be pickleable, as LensKit uses pickling for shared | ||
memory parallelism in its batch-inference code. | ||
|
||
.. note:: | ||
|
||
The component interfaces are simply protocol definitions (defined using | ||
:class:`typing.Protocol` with :func:`~typing.runtime_checkable`), so | ||
implementations can directly implement the specified methods and do not need | ||
to explicitly inherit from the protocol classes, although they are free to | ||
do so. | ||
|
||
.. todo:: | ||
|
||
Is it clear to write these capabilities as separate protocols, or would it be | ||
better to write a single ``Component`` :class:`~abc.ABC`? | ||
|
||
.. autoclass:: Component | ||
|
||
.. autoclass:: ConfigurableComponent | ||
|
||
.. autoclass:: TrainableComponent |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# This file is part of LensKit. | ||
# Copyright (C) 2018-2023 Boise State University | ||
# Copyright (C) 2023-2024 Drexel University | ||
# Licensed under the MIT license, see LICENSE.md for details. | ||
# SPDX-License-Identifier: MIT | ||
|
||
import pandas as pd | ||
|
||
from . import Pipeline | ||
from .components import Component | ||
|
||
|
||
def topn_pipeline(scorer: Component[pd.Series], *, predicts_ratings: bool = False) -> Pipeline: | ||
""" | ||
Create a pipeline that produces top-N recommendations using the specified | ||
scorer. The scorer should have the following call signature:: | ||
def scorer(user: UserHistory, items: ItemList) -> pd.Series: ... | ||
Args: | ||
scorer: | ||
The scorer to use in the pipeline (it will added with the component | ||
name ``score``, see :ref:`pipeline-names`). | ||
predicts_ratings: | ||
If ``True``, make ``predict-ratings`` an alias for ``score`` so that | ||
evaluation components know this pipeline can predict ratings. | ||
""" | ||
raise NotImplementedError() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# This file is part of LensKit. | ||
# Copyright (C) 2018-2023 Boise State University | ||
# Copyright (C) 2023-2024 Drexel University | ||
# Licensed under the MIT license, see LICENSE.md for details. | ||
# SPDX-License-Identifier: MIT | ||
|
||
"Definition of the component interfaces." | ||
|
||
# pyright: strict | ||
from __future__ import annotations | ||
|
||
from typing import Callable, TypeAlias | ||
|
||
from typing_extensions import Any, Generic, Protocol, Self, TypeVar, runtime_checkable | ||
|
||
from lenskit.data.dataset import Dataset | ||
|
||
# COut is only return, so Component[U] can be assigned to Component[T] if U ≼ T. | ||
COut = TypeVar("COut", covariant=True) | ||
Component: TypeAlias = Callable[..., COut] | ||
|
||
|
||
@runtime_checkable | ||
class ConfigurableComponent(Generic[COut], Protocol): | ||
""" | ||
Interface for configurable pipeline components (those that have | ||
hyperparameters). A configurable component supports two additional | ||
operations: | ||
* saving its configuration with :meth:`get_config`. | ||
* creating a new instance from a saved configuration with the class method | ||
:meth:`from_config`. | ||
A component must implement both of these methods to be considered | ||
configurable. | ||
.. note:: | ||
Configuration data should be JSON-compatible (strings, numbers, etc.). | ||
.. note:: | ||
Implementations must also implement ``__call__``. | ||
""" | ||
|
||
@classmethod | ||
def from_config(cls, cfg: dict[str, Any]) -> dict[str, object]: | ||
""" | ||
Reinstantiate this component from configuration values. | ||
""" | ||
... | ||
|
||
def get_config(self) -> dict[str, object]: | ||
""" | ||
Get this component's configured hyperparameters. | ||
""" | ||
... | ||
|
||
|
||
@runtime_checkable | ||
class TrainableComponent(Generic[COut], Protocol): | ||
""" | ||
Interface for pipeline components that can learn parameters from training | ||
data, and expose those parameters for serialization as an alternative to | ||
pickling (components also need to be picklable). | ||
.. note:: | ||
Trainable components must also implement ``__call__``. | ||
""" | ||
|
||
def train(self, data: Dataset) -> Self: | ||
""" | ||
Train the pipeline component to learn its parameters from a training | ||
dataset. | ||
Args: | ||
data: | ||
The training dataset. | ||
Returns: | ||
The component. | ||
""" | ||
raise NotImplementedError() | ||
|
||
def get_params(self) -> dict[str, object]: | ||
""" | ||
Get the model's learned parameters for serialization. | ||
LensKit components that learn parameters from training data should both | ||
implement this method and work when pickled and unpickled. Pickling is | ||
sometimes used for convenience, but parameter / state dictionaries allow | ||
serializing wtih tools like ``safetensors``. | ||
Args: | ||
include_caches: | ||
Whether the parameter dictionary should include ephemeral | ||
caching structures only used for runtime performance | ||
optimizations. | ||
Returns: | ||
The model's parameters, as a dictionary from names to parameter data | ||
(usually arrays, tensors, etc.). | ||
""" | ||
raise NotImplementedError() | ||
|
||
def load_params(self, params: dict[str, object]) -> None: | ||
""" | ||
Reload model state from parameters saved via :meth:`get_params`. | ||
""" | ||
raise NotImplementedError() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# This file is part of LensKit. | ||
# Copyright (C) 2018-2023 Boise State University | ||
# Copyright (C) 2023-2024 Drexel University | ||
# Licensed under the MIT license, see LICENSE.md for details. | ||
# SPDX-License-Identifier: MIT | ||
|
||
# pyright: strict | ||
|
||
import warnings | ||
from inspect import Signature, signature | ||
|
||
from typing_extensions import Generic, TypeVar | ||
|
||
from lenskit.pipeline.types import TypecheckWarning | ||
|
||
from .components import Component | ||
|
||
# Nodes are (conceptually) immutable data containers, so Node[U] can be assigned | ||
# to Node[T] if U ≼ T. | ||
ND = TypeVar("ND", covariant=True) | ||
|
||
|
||
class Node(Generic[ND]): | ||
""" | ||
Representation of a single node in a :class:`Pipeline`. | ||
""" | ||
|
||
__match_args__ = ("name",) | ||
|
||
name: str | ||
"The name of this node." | ||
types: set[type] | None | ||
"The set of valid data types of this node, or None for no typechecking." | ||
|
||
def __init__(self, name: str, *, types: set[type] | None = None): | ||
self.name = name | ||
self.types = types | ||
|
||
def __str__(self) -> str: | ||
return f"<{self.__class__.__name__} {self.name}>" | ||
|
||
|
||
class InputNode(Node[ND], Generic[ND]): | ||
""" | ||
An input node. | ||
""" | ||
|
||
|
||
class FallbackNode(Node[ND], Generic[ND]): | ||
""" | ||
Node for trying several nodes in turn. | ||
""" | ||
|
||
__match_args__ = ("name", "alternatives") | ||
|
||
alternatives: list[Node[ND | None]] | ||
"The nodes that can possibly fulfil this node." | ||
|
||
def __init__(self, name: str, alternatives: list[Node[ND | None]]): | ||
super().__init__(name) | ||
self.alternatives = alternatives | ||
|
||
|
||
class LiteralNode(Node[ND], Generic[ND]): | ||
__match_args__ = ("name", "value") | ||
value: ND | ||
"The value associated with this node" | ||
|
||
def __init__(self, name: str, value: ND, *, types: set[type] | None = None): | ||
super().__init__(name, types=types) | ||
self.value = value | ||
|
||
|
||
class ComponentNode(Node[ND], Generic[ND]): | ||
__match_args__ = ("name", "component", "inputs", "connections") | ||
|
||
component: Component[ND] | ||
"The component associated with this node" | ||
|
||
inputs: dict[str, type | None] | ||
"The component's inputs." | ||
|
||
connections: dict[str, str] | ||
"The component's input connections." | ||
|
||
def __init__(self, name: str, component: Component[ND]): | ||
super().__init__(name) | ||
self.component = component | ||
self.connections = {} | ||
|
||
sig = signature(component) | ||
if sig.return_annotation == Signature.empty: | ||
warnings.warn( | ||
f"component {component} has no return type annotation", TypecheckWarning, 2 | ||
) | ||
else: | ||
self.types = set([sig.return_annotation]) | ||
|
||
self.inputs = { | ||
param.name: None if param.annotation == Signature.empty else param.annotation | ||
for param in sig.parameters.values() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
# This file is part of LensKit. | ||
# Copyright (C) 2018-2023 Boise State University | ||
# Copyright (C) 2023-2024 Drexel University | ||
# Licensed under the MIT license, see LICENSE.md for details. | ||
# SPDX-License-Identifier: MIT | ||
|
||
""" | ||
Pipeline runner logic. | ||
""" | ||
|
||
# pyright: strict | ||
import logging | ||
from typing import Any, Literal, TypeAlias | ||
|
||
from . import Pipeline | ||
from .components import Component | ||
from .nodes import ComponentNode, FallbackNode, InputNode, LiteralNode, Node | ||
from .types import is_compatible_data | ||
|
||
_log = logging.getLogger(__name__) | ||
State: TypeAlias = Literal["pending", "in-progress", "finished", "failed"] | ||
|
||
|
||
class PipelineRunner: | ||
""" | ||
Node status and results for a single pipeline run. | ||
This class operates recursively; pipelines should never be so deep that | ||
recursion fails. | ||
""" | ||
|
||
pipe: Pipeline | ||
inputs: dict[str, Any] | ||
status: dict[str, State] | ||
state: dict[str, Any] | ||
|
||
def __init__(self, pipe: Pipeline, inputs: dict[str, Any]): | ||
self.pipe = pipe | ||
self.inputs = inputs | ||
self.status = {n.name: "pending" for n in pipe.nodes} | ||
self.state = {} | ||
|
||
def run(self, node: Node[Any], *, required: bool = True) -> Any: | ||
""" | ||
Run the pipleline to obtain the results of a node. | ||
""" | ||
status = self.status[node.name] | ||
if status == "finished": | ||
return self.state[node.name] | ||
elif status == "in-progress": | ||
raise RuntimeError(f"pipeline cycle encountered at {node}") | ||
elif status == "failed": # pragma: nocover | ||
raise RuntimeError(f"{node} previously failed") | ||
|
||
_log.debug("processing node %s", node) | ||
self.status[node.name] = "in-progress" | ||
try: | ||
self._run_node(node, required) | ||
self.status[node.name] = "finished" | ||
except Exception as e: | ||
_log.error("node %s failed with error %s", node, e) | ||
self.status[node.name] = "failed" | ||
raise e | ||
|
||
return self.state[node.name] | ||
|
||
def _run_node(self, node: Node[Any], required: bool) -> None: | ||
match node: | ||
case LiteralNode(name, value): | ||
self.state[name] = value | ||
case InputNode(name, types=types): | ||
self._inject_input(name, types, required) | ||
case ComponentNode(name, comp, inputs, wiring): | ||
self._run_component(name, comp, inputs, wiring) | ||
case FallbackNode(name, alts): | ||
self._run_fallback(name, alts) | ||
case _: # pragma: nocover | ||
raise RuntimeError(f"invalid node {node}") | ||
|
||
def _inject_input(self, name: str, types: set[type] | None, required: bool) -> None: | ||
val = self.inputs.get(name, None) | ||
if val is None and required and types and not is_compatible_data(None, *types): | ||
raise RuntimeError(f"input {name} not specified") | ||
|
||
if val is not None and types and not is_compatible_data(val, *types): | ||
raise TypeError(f"invalid data for input {name} (expected {types}, got {type(val)})") | ||
|
||
self.state[name] = val | ||
|
||
def _run_component( | ||
self, | ||
name: str, | ||
comp: Component[Any], | ||
inputs: dict[str, type | None], | ||
wiring: dict[str, str], | ||
) -> None: | ||
in_data = {} | ||
_log.debug("processing inputs for component %s", name) | ||
for iname, itype in inputs.items(): | ||
src = wiring.get(iname, None) | ||
if src is not None: | ||
snode = self.pipe.node(src) | ||
else: | ||
snode = self.pipe.get_default(iname) | ||
|
||
if snode is None: | ||
ival = None | ||
else: | ||
if itype: | ||
required = not is_compatible_data(None, itype) | ||
else: | ||
required = False | ||
ival = self.run(snode, required=required) | ||
|
||
if itype and not is_compatible_data(ival, itype): | ||
raise TypeError( | ||
f"input {iname} for component {name}" | ||
f" has invalid type {type(ival)} (expected {itype})" | ||
) | ||
|
||
in_data[iname] = ival | ||
|
||
_log.debug("running component %s", name) | ||
self.state[name] = comp(**in_data) | ||
|
||
def _run_fallback(self, name: str, alternatives: list[Node[Any]]) -> None: | ||
for alt in alternatives: | ||
val = self.run(alt, required=False) | ||
if val is not None: | ||
self.state[name] = val | ||
return | ||
|
||
# got this far, no alternatives | ||
raise RuntimeError(f"no alternative for {name} returned data") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# This file is part of LensKit. | ||
# Copyright (C) 2018-2023 Boise State University | ||
# Copyright (C) 2023-2024 Drexel University | ||
# Licensed under the MIT license, see LICENSE.md for details. | ||
# SPDX-License-Identifier: MIT | ||
|
||
# pyright: basic | ||
from __future__ import annotations | ||
|
||
import warnings | ||
from types import GenericAlias | ||
from typing import Union, _GenericAlias, get_args, get_origin # type: ignore | ||
|
||
import numpy as np | ||
|
||
|
||
class TypecheckWarning(UserWarning): | ||
"Warnings about type-checking logic." | ||
|
||
pass | ||
|
||
|
||
def is_compatible_type(typ: type, *targets: type) -> bool: | ||
""" | ||
Make a best-effort check whether a type is compatible with at least one | ||
target type. This function is limited by limitations of the Python type | ||
system and the effort required to (re-)write a full type checker. It is | ||
written to be over-accepting instead of over-restrictive, so it can be used | ||
to reject clearly incompatible types without rejecting combinations it | ||
cannot properly check. | ||
Args: | ||
typ: | ||
The type to check. | ||
targets: | ||
One or more target types to check against. | ||
Returns: | ||
``False`` if it is clear that the specified type is incompatible with | ||
all of the targets, and ``True`` otherwise. | ||
""" | ||
for target in targets: | ||
# try a straight subclass check first, but gracefully handle incompatible types | ||
try: | ||
if issubclass(typ, target): | ||
return True | ||
except TypeError: | ||
pass | ||
|
||
if isinstance(target, (GenericAlias, _GenericAlias)): | ||
tcls = get_origin(target) | ||
# if we're matching a raw type against a generic, just check the origin | ||
if isinstance(typ, GenericAlias): | ||
warnings.warn(f"cannot type-check generic type {typ}", TypecheckWarning) | ||
cls = get_origin(typ) | ||
if issubclass(cls, tcls): # type: ignore | ||
return True | ||
elif isinstance(typ, type): | ||
print(typ, type(typ)) | ||
if issubclass(typ, tcls): # type: ignore | ||
return True | ||
elif typ == int and issubclass(target, (float, complex)): # noqa: E721 | ||
return True | ||
elif typ == float and issubclass(target, complex): # noqa: E721 | ||
return True | ||
|
||
return False | ||
|
||
|
||
def is_compatible_data(obj: object, *targets: type) -> bool: | ||
""" | ||
Make a best-effort check whether a type is compatible with at least one | ||
target type. This function is limited by limitations of the Python type | ||
system and the effort required to (re-)write a full type checker. It is | ||
written to be over-accepting instead of over-restrictive, so it can be used | ||
to reject clearly incompatible types without rejecting combinations it | ||
cannot properly check. | ||
Args: | ||
typ: | ||
The type to check. | ||
targets: | ||
One or more target types to check against. | ||
Returns: | ||
``False`` if it is clear that the specified type is incompatible with | ||
all of the targets, and ``True`` otherwise. | ||
""" | ||
for target in targets: | ||
# try a straight subclass check first, but gracefully handle incompatible types | ||
try: | ||
if isinstance(obj, target): | ||
return True | ||
except TypeError: | ||
pass | ||
|
||
if get_origin(target) == Union: | ||
types = get_args(target) | ||
if is_compatible_data(obj, *types): | ||
return True | ||
elif isinstance(target, (GenericAlias, _GenericAlias)): | ||
tcls = get_origin(target) | ||
if isinstance(obj, np.ndarray) and tcls == np.ndarray: | ||
# check for type compatibility | ||
_sz, dtw = get_args(target) | ||
(dt,) = get_args(dtw) | ||
if issubclass(obj.dtype.type, dt): | ||
return True | ||
elif isinstance(tcls, type) and isinstance(obj, tcls): | ||
warnings.warn( | ||
f"cannot type-check object of type {type(obj)} against generic", | ||
TypecheckWarning, | ||
) | ||
return True | ||
elif isinstance(obj, int) and issubclass(target, (float, complex)): # noqa: E721 | ||
return True | ||
elif isinstance(obj, float) and issubclass(target, complex): # noqa: E721 | ||
return True | ||
|
||
return False |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# This file is part of LensKit. | ||
# Copyright (C) 2018-2023 Boise State University | ||
# Copyright (C) 2023-2024 Drexel University | ||
# Licensed under the MIT license, see LICENSE.md for details. | ||
# SPDX-License-Identifier: MIT | ||
|
||
""" | ||
Tests for the pipeline type-checking functions. | ||
""" | ||
|
||
import typing | ||
from collections.abc import Iterable, Sequence | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from numpy.typing import ArrayLike, NDArray | ||
|
||
from pytest import warns | ||
|
||
from lenskit.data.dataset import Dataset, MatrixDataset | ||
from lenskit.pipeline.types import TypecheckWarning, is_compatible_data, is_compatible_type | ||
|
||
|
||
def test_type_compat_identical(): | ||
assert is_compatible_type(int, int) | ||
assert is_compatible_type(str, str) | ||
|
||
|
||
def test_type_compat_subclass(): | ||
assert is_compatible_type(MatrixDataset, Dataset) | ||
|
||
|
||
def test_type_compat_assignable(): | ||
assert is_compatible_type(int, float) | ||
|
||
|
||
def test_type_raw_compat_with_generic(): | ||
assert is_compatible_type(list, list[int]) | ||
assert not is_compatible_type(set, list[int]) | ||
|
||
|
||
def test_type_compat_protocol(): | ||
assert is_compatible_type(list, Sequence) | ||
assert is_compatible_type(list, typing.Sequence) | ||
assert not is_compatible_type(set, Sequence) | ||
assert not is_compatible_type(set, typing.Sequence) | ||
assert is_compatible_type(set, Iterable) | ||
|
||
|
||
def test_type_compat_protocol_generic(): | ||
assert is_compatible_type(list, Sequence[int]) | ||
assert is_compatible_type(list, typing.Sequence[int]) | ||
|
||
|
||
def test_type_compat_generics_with_protocol(): | ||
assert is_compatible_type(list[int], Sequence[int]) | ||
|
||
|
||
def test_type_incompat_generics(): | ||
with warns(TypecheckWarning): | ||
assert is_compatible_type(list[int], list[str]) | ||
with warns(TypecheckWarning): | ||
assert is_compatible_type(list[int], Sequence[str]) | ||
|
||
|
||
def test_data_compat_basic(): | ||
assert is_compatible_data(72, int) | ||
assert is_compatible_data("hello", str) | ||
assert not is_compatible_data(72, str) | ||
|
||
|
||
def test_data_compat_float_assignabile(): | ||
assert is_compatible_data(72, float) | ||
|
||
|
||
def test_data_compat_generic(): | ||
assert is_compatible_data(["foo"], list[str]) | ||
# this is compatible because we can't check generics | ||
with warns(TypecheckWarning): | ||
assert is_compatible_data([72], list[str]) | ||
|
||
|
||
def test_numpy_typecheck(): | ||
assert is_compatible_data(np.arange(10, dtype="i8"), NDArray[np.int64]) | ||
assert is_compatible_data(np.arange(10, dtype="i4"), NDArray[np.int32]) | ||
assert is_compatible_data(np.arange(10), ArrayLike) | ||
assert is_compatible_data(np.arange(10), NDArray[np.integer]) | ||
# numpy types can be checked | ||
assert not is_compatible_data(np.arange(10), NDArray[np.float64]) | ||
|
||
|
||
def test_pandas_typecheck(): | ||
assert is_compatible_data(pd.Series(["a", "b"]), ArrayLike) |