From a9f51a6a1bb94459a85b2ec5938ec23fac04f5e7 Mon Sep 17 00:00:00 2001 From: Avik Basu <3485425+ab93@users.noreply.github.com> Date: Wed, 7 Jun 2023 16:34:28 -0700 Subject: [PATCH] feat!: introduce numalogic blocks (#206) - block as a way of abstraction of ML related tasks - block pipeline to chain multiple blocks together - support saving/loading of artifacts (only redis registry can support this) - improve typing --------- Signed-off-by: Avik Basu --- numalogic/base.py | 6 +- numalogic/blocks/__init__.py | 29 +++++ numalogic/blocks/_base.py | 138 +++++++++++++++++++++ numalogic/blocks/_nn.py | 93 +++++++++++++++ numalogic/blocks/_transform.py | 132 ++++++++++++++++++++ numalogic/blocks/pipeline.py | 173 +++++++++++++++++++++++++++ numalogic/registry/artifact.py | 29 ++++- numalogic/tools/types.py | 16 ++- poetry.lock | 207 +++----------------------------- pyproject.toml | 1 + tests/blocks/__init__.py | 0 tests/blocks/test_blocks.py | 32 +++++ tests/blocks/test_pipeline.py | 212 +++++++++++++++++++++++++++++++++ 13 files changed, 868 insertions(+), 200 deletions(-) create mode 100644 numalogic/blocks/__init__.py create mode 100644 numalogic/blocks/_base.py create mode 100644 numalogic/blocks/_nn.py create mode 100644 numalogic/blocks/_transform.py create mode 100644 numalogic/blocks/pipeline.py create mode 100644 tests/blocks/__init__.py create mode 100644 tests/blocks/test_blocks.py create mode 100644 tests/blocks/test_pipeline.py diff --git a/numalogic/base.py b/numalogic/base.py index 259b611b..db0dab9b 100644 --- a/numalogic/base.py +++ b/numalogic/base.py @@ -16,10 +16,10 @@ import numpy.typing as npt import pytorch_lightning as pl -from sklearn.base import TransformerMixin, BaseEstimator, OutlierMixin +from sklearn.base import TransformerMixin, OutlierMixin -class BaseTransformer(TransformerMixin, BaseEstimator): +class BaseTransformer(TransformerMixin): """Base class for all transformer classes.""" pass @@ -47,7 +47,7 @@ class TorchModel(pl.LightningModule, metaclass=ABCMeta): pass -class BaseThresholdModel(OutlierMixin, BaseEstimator): +class BaseThresholdModel(OutlierMixin): """Base class for all threshold models.""" pass diff --git a/numalogic/blocks/__init__.py b/numalogic/blocks/__init__.py new file mode 100644 index 00000000..748319c3 --- /dev/null +++ b/numalogic/blocks/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Module for numalogic blocks which are units of computation that can be +chained together to form a pipeline if needed. A block can be stateful or stateless. +""" + +from numalogic.blocks._base import Block +from numalogic.blocks._nn import NNBlock +from numalogic.blocks._transform import PreprocessBlock, PostprocessBlock, ThresholdBlock +from numalogic.blocks.pipeline import BlockPipeline + +__all__ = [ + "Block", + "NNBlock", + "PreprocessBlock", + "PostprocessBlock", + "ThresholdBlock", + "BlockPipeline", +] diff --git a/numalogic/blocks/_base.py b/numalogic/blocks/_base.py new file mode 100644 index 00000000..072431ac --- /dev/null +++ b/numalogic/blocks/_base.py @@ -0,0 +1,138 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABCMeta, abstractmethod +from typing import Generic, Union + +import numpy.typing as npt + +from numalogic.tools.types import artifact_t, state_dict_t + + +class Block(Generic[artifact_t], metaclass=ABCMeta): + """ + Base class for all blocks. + + A block is a unit of computation that can be + chained together to form a pipeline. A block can be stateful or stateless. + + A stateful block is one that has a state that can be updated by calling the + block with new data. A stateless block is one that does not have a state and + can be called with new data without any side effects. + + A block can be used as a callable. The call method is an alias for the run method. + + Args: + ---- + artifact: The artifact that the block operates on. + name: The name of the block + stateful: Whether the block is stateful or not. (default: True) + """ + + __slots__ = ("_name", "_stateful", "_artifact") + + def __init__(self, artifact: artifact_t, name: str, stateful: bool = True): + self._artifact = artifact + self._name = name + self._stateful = stateful + + @property + def name(self) -> str: + """The name of the block.""" + return self._name + + @property + def stateful(self) -> bool: + """Whether the block is stateful or not.""" + return self._stateful + + @property + def artifact(self) -> artifact_t: + """The artifact that the block operates on.""" + return self._artifact + + @property + def artifact_state(self) -> Union[artifact_t, state_dict_t]: + """ + The state of the artifact that needs to be serialized for saving. + + This needs to be overridden if something other than the artifact itself + needs to be serialized, e.g. statedict, or a torchscript module. + """ + return self._artifact + + @artifact_state.setter + def artifact_state(self, state: Union[artifact_t, state_dict_t]) -> None: + """ + The state of the artifact that needs to be deserialized for loading. + + This needs to be overridden if something other than the artifact itself + needs to be deserialized, e.g. statedict, or a torchscript module. + """ + self._artifact = state + + def __call__(self, *args, **kwargs) -> npt.NDArray[float]: + """Alias for the run method.""" + return self.run(*args, **kwargs) + + @abstractmethod + def fit(self, data: npt.NDArray[float], *args, **kwargs): + """ + Train the block on the input data. + + Implement this method to train the block, using the block's artifact. + + Args: + ---- + data: The input data to train the block on. + *args: Additional arguments for the block. + **kwargs: Additional keyword arguments for fitting the block. + """ + pass + + @abstractmethod + def run(self, stream: npt.NDArray[float], *args, **kwargs) -> npt.NDArray[float]: + """ + Run inference on the block on the streaming input data. + + Implement this method to run inference on the block, + using the block's artifact. + + Args: + ---- + stream: The streaming input data. + *args: Additional arguments for the block. + **kwargs: Additional keyword arguments for the block. + """ + pass + + +class StatelessBlock(Block, metaclass=ABCMeta): + """ + Base class for all stateless blocks. + + A stateless block is one that does not have a state and + can be called with new data without any side effects. + """ + + def __init__(self, artifact: artifact_t, name: str): + super().__init__(artifact, name, stateful=False) + + def fit(self, data: npt.NDArray[float], *args, **kwargs) -> npt.NDArray[float]: + """ + A no-op for stateless blocks. + + Args: + ---- + data: The input data to train the block on. + *args: Additional arguments for the block. + **kwargs: Additional keyword arguments for fitting the block. + """ + return self.run(data, *args, **kwargs) diff --git a/numalogic/blocks/_nn.py b/numalogic/blocks/_nn.py new file mode 100644 index 00000000..a343b8d0 --- /dev/null +++ b/numalogic/blocks/_nn.py @@ -0,0 +1,93 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch.utils.data import DataLoader +import numpy.typing as npt + +from numalogic.blocks import Block +from numalogic.models.autoencoder import AutoencoderTrainer +from numalogic.tools.data import StreamingDataset +from numalogic.tools.types import nn_model_t, state_dict_t + + +class NNBlock(Block): + """ + A block that uses a neural network model to operate on the artifact. + + Serialization is done by saving state dict of the model. + + Args: + ---- + model: The neural network model. + seq_len: The sequence length of the input data. + name: The name of the block. Defaults to "nn". + """ + + __slots__ = ("seq_len",) + + def __init__(self, model: nn_model_t, seq_len: int, name: str = "nn"): + super().__init__(artifact=model, name=name) + self.seq_len = seq_len + + @property + def artifact_state(self) -> state_dict_t: + """The state dict of the model.""" + return self._artifact.state_dict() + + @artifact_state.setter + def artifact_state(self, artifact_state: state_dict_t) -> None: + """Set the state dict of the model.""" + self._artifact.load_state_dict(artifact_state) + + def fit( + self, input_: npt.NDArray[float], batch_size: int = 64, **trainer_kwargs + ) -> npt.NDArray[float]: + """ + Train the model on the input data. + + Args: + ---- + input_: The input data. + batch_size: The batch size to use for training. + trainer_kwargs: Keyword arguments to pass to the lightning trainer. + + Returns + ------- + The error of the model on the input data. + """ + trainer = AutoencoderTrainer(**trainer_kwargs) + ds = StreamingDataset(input_, self.seq_len) + trainer.fit(self._artifact, train_dataloaders=DataLoader(ds, batch_size=batch_size)) + reconerr = trainer.predict( + self._artifact, dataloaders=DataLoader(ds, batch_size=batch_size) + ) + return reconerr.numpy() + + def run(self, input_: npt.NDArray[float], **_) -> npt.NDArray[float]: + """ + Perform forward pass on the streaming input data. + + Args: + ---- + input_: The streaming input data. + + Returns + ------- + The error of the model on the input data. + """ + input_ = torch.from_numpy(input_).float() + # Add a batch dimension + input_ = torch.unsqueeze(input_, dim=0).contiguous() + self._artifact.eval() + with torch.no_grad(): + reconerr = self._artifact.predict_step(input_, batch_idx=0) + return torch.squeeze(reconerr, dim=0).numpy() diff --git a/numalogic/blocks/_transform.py b/numalogic/blocks/_transform.py new file mode 100644 index 00000000..6f62f8f1 --- /dev/null +++ b/numalogic/blocks/_transform.py @@ -0,0 +1,132 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy.typing as npt + +from numalogic.blocks._base import Block, StatelessBlock +from numalogic.tools.types import transform_t, thresh_t + + +class PreprocessBlock(Block): + """ + A stateful block that is used to preprocess the input data, before it is fed to an ML model. + + Serialization is done by saving the preprocessor object. + + Args: + ---- + preprocessor: The preprocessor object. + name: The name of the block. Defaults to "preprocess". + stateful: Whether the block is stateful or not. Defaults to True. + """ + + def __init__(self, preprocessor: transform_t, name: str = "preprocess", stateful: bool = True): + super().__init__(artifact=preprocessor, name=name, stateful=stateful) + + def fit(self, input_: npt.NDArray[float], **__) -> npt.NDArray[float]: + """ + Fit the preprocessor on the input data. + + Args: + ---- + input_: The input data to train on. + + Returns + ------- + The transformed/scaled input data. + """ + return self._artifact.fit_transform(input_) + + def run(self, input_: npt.NDArray[float], **__) -> npt.NDArray[float]: + """ + Transform the streaming input data. + + Args: + ---- + input_: The streaming input data. + + Returns + ------- + The transformed/scaled input data. + """ + return self._artifact.transform(input_) + + +class ThresholdBlock(Block): + """ + A stateful block that is used to threshold the output of an ML model. + + Serialization is done by saving the threshold object. + + Args: + ---- + thresh_model: The threshold model object. + name: The name of the block. Defaults to "threshold". + """ + + def __init__(self, thresh_model: thresh_t, name: str = "threshold"): + super().__init__(artifact=thresh_model, name=name) + + def fit(self, input_: npt.NDArray[float], **__) -> npt.NDArray[float]: + """ + Fit the threshold model on the training data. + + Args: + ---- + input_: The input data to train on. + + Returns + ------- + The anomaly scores of the training data. + """ + self._artifact.fit(input_) + return self._artifact.score_samples(input_) + + def run(self, input_: npt.NDArray[float], **__) -> npt.NDArray[float]: + """ + Transform the streaming input data. + + Args: + ---- + input_: The streaming input data. + + Returns + ------- + The anomaly score of the streaming data. + """ + return self._artifact.score_samples(input_) + + +class PostprocessBlock(StatelessBlock): + """ + A stateless block that is used to postprocess the output of an ML model. + + Args: + ---- + postprocessor: The postprocessor object. + """ + + def __init__(self, postprocessor: transform_t, name: str = "postprocess"): + super().__init__(artifact=postprocessor, name=name) + + def run(self, input_: npt.NDArray[float], **__) -> npt.NDArray[float]: + """ + Transform the streaming input data. + + Args: + ---- + input_: The streaming input data. + + Returns + ------- + The postprocessed streaming data. + """ + return self._artifact.transform(input_) diff --git a/numalogic/blocks/pipeline.py b/numalogic/blocks/pipeline.py new file mode 100644 index 00000000..748afa4e --- /dev/null +++ b/numalogic/blocks/pipeline.py @@ -0,0 +1,173 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence +from collections.abc import Iterator + +import numpy.typing as npt + +from numalogic.blocks._transform import Block +from numalogic.registry import ArtifactManager +from numalogic.tools.types import artifact_t + + +class BlockPipeline(Sequence[Block]): + """ + A pipeline of blocks. + + A pipeline is a sequence of blocks that can be chained together to form a + pipeline. A pipeline can be used as a callable. The call method is an alias + for the run method. + + Args: + ---- + blocks: A list/tuple of blocks that form the pipeline. + registry: The registry to use for storing artifacts. + """ + + __slots__ = ("_blocks", "_registry") + + def __init__(self, *blocks: Block, registry: ArtifactManager = None): + self._blocks = blocks + self._registry = registry + + def __call__(self, *args, **kwargs): + return self.run(*args, **kwargs) + + def __getitem__(self, idx: int) -> Block: + """Get the block at the given index.""" + return self._blocks[idx] + + def __len__(self) -> int: + """Get the number of blocks in the pipeline.""" + return len(self._blocks) + + def __iter__(self) -> Iterator[Block]: + """Get an iterator over the blocks in the pipeline.""" + return iter(self._blocks) + + def named_blocks(self) -> Iterator[tuple[str, Block]]: + """Get an iterator over the blocks in the pipeline along with their names.""" + names = [block.name for block in self._blocks] + return zip(names, self._blocks) + + def _get_block_params(self, **fit_params) -> dict[str, dict]: + """ + Get the parameters for each block from the fit_params. + + Inspired by sklearn Pipeline + (https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.make_pipeline.html). + + Args: + ---- + fit_params : keyword dict of string -> object + + Returns + ------- + A nested dict of blockname -> parameter -> value + + Raises + ------ + ValueError: If the keyword arguments are not of the form + blockname__parameter, e.g. `block_pipeline.fit(data, nn__max_epochs=50)` + """ + block_params = {name: {} for name, block in self.named_blocks()} + err_msg = ( + "Invalid kwarg: {pname} found. Keyword args of " + "BlockPipeline must be of the form blockname__parameter, " + "e.g. `block_pipeline.fit(data, nn__max_epochs=50)`" + ) + for pname, pval in fit_params.items(): + if "__" not in pname: + raise ValueError(err_msg.format(pname=pname)) + blockname, param = pname.split("__", 1) + block_params[blockname][param] = pval + return block_params + + def fit(self, input_: npt.NDArray[float], **fit_params) -> npt.NDArray[float]: + """ + Fit the pipeline on the input data. + + Args: + ---- + input_: The input data to fit the pipeline on. + fit_params : dict of string -> object + Parameters passed to the ``fit`` method of each block, where + each parameter name is prefixed such that parameter ``p`` for step + ``s`` has key ``s__p``. + + Returns + ------- + Final fit block output. + """ + fit_params = self._get_block_params(**fit_params) + for block in self._blocks: + input_ = block.fit(input_, **fit_params.get(block.name, {})) + return input_ + + def run(self, data: npt.NDArray[float]) -> npt.NDArray[float]: + """ + Perform inference on streaming data. + + Args: + ---- + data: Streaming input data + + """ + for block in self._blocks: + data = block.run(data) + return data + + def save(self, skeys: Sequence[str], dkeys: Sequence[str]) -> None: + """ + Save the state of the pipeline. + + Args: + ---- + skeys: Sequence of source keys. + dkeys: Sequence of destination keys. + + Raises + ------ + ValueError: If no registry is provided. + """ + if not self._registry: + raise ValueError("No registry provided.") + + artifacts: dict[str, artifact_t] = {} + for block in self._blocks: + if not block.stateful: + continue + artifacts[block.name] = block.artifact_state + self._registry.save(skeys, dkeys, artifacts) + + def load(self, skeys: Sequence[str], dkeys: Sequence[str]) -> None: + """ + Load the state of the pipeline. + + Args: + ---- + skeys: Sequence of source keys. + dkeys: Sequence of destination keys. + + Raises + ------ + ValueError: If no registry is provided. + """ + if not self._registry: + raise ValueError("No registry provided.") + + artifact_data = self._registry.load(skeys, dkeys) + artifacts = artifact_data.artifact + for block in self._blocks: + if not block.stateful: + continue + block.artifact_state = artifacts[block.name] diff --git a/numalogic/registry/artifact.py b/numalogic/registry/artifact.py index f6e5137e..e519eed5 100644 --- a/numalogic/registry/artifact.py +++ b/numalogic/registry/artifact.py @@ -11,18 +11,27 @@ from dataclasses import dataclass -from typing import Any, Generic, TypeVar +from typing import Any, Generic, TypeVar, Union -from numalogic.tools.types import artifact_t, KEYS, META_T, META_VT, EXTRA_T +from numalogic.tools.types import artifact_t, KEYS, META_T, META_VT, EXTRA_T, state_dict_t @dataclass class ArtifactData: - """Dataclass to hold the artifact, its metadata and other extra info.""" + """ + Dataclass to hold the artifact, its metadata and other extra info. + + Args: + ---- + artifact: artifact to be saved; can be a model instance, a state_dict. + metadata: additional metadata surrounding the artifact that needs to be saved. + extras: any other extra information that needs to be saved. + + """ __slots__ = ("artifact", "metadata", "extras") - artifact: artifact_t + artifact: Union[artifact_t, state_dict_t] metadata: META_T extras: EXTRA_T @@ -34,7 +43,9 @@ class ArtifactData: class ArtifactManager(Generic[KEYS, A_D]): """Abstract base class for artifact save, load and delete. - :param uri: server/connection uri + Args: + ---- + uri: server/connection uri """ __slots__ = ("uri",) @@ -56,7 +67,13 @@ def load( """ raise NotImplementedError("Please implement this method!") - def save(self, skeys: KEYS, dkeys: KEYS, artifact: artifact_t, **metadata: META_VT) -> Any: + def save( + self, + skeys: KEYS, + dkeys: KEYS, + artifact: Union[artifact_t, state_dict_t], + **metadata: META_VT + ) -> Any: r"""Saves the artifact into mlflow registry and updates version. Args: diff --git a/numalogic/tools/types.py b/numalogic/tools/types.py index d883d665..f73988da 100644 --- a/numalogic/tools/types.py +++ b/numalogic/tools/types.py @@ -10,11 +10,13 @@ # limitations under the License. -from typing import Union, TypeVar from collections.abc import Sequence +from typing import Union, TypeVar + from sklearn.base import BaseEstimator -from torch import nn +from torch import Tensor +from numalogic.base import TorchModel, BaseThresholdModel, BaseTransformer try: from redis.client import AbstractRedis @@ -23,7 +25,15 @@ else: redis_client_t = TypeVar("redis_client_t", bound=AbstractRedis, covariant=True) -artifact_t = TypeVar("artifact_t", bound=Union[nn.Module, BaseEstimator], covariant=True) +artifact_t = TypeVar( + "artifact_t", + bound=Union[TorchModel, BaseThresholdModel, BaseTransformer], + covariant=True, +) +nn_model_t = TypeVar("nn_model_t", bound=TorchModel) +state_dict_t = TypeVar("state_dict_t", bound=dict[str, Tensor], covariant=True) +transform_t = TypeVar("transform_t", bound=Union[BaseTransformer, BaseEstimator], covariant=True) +thresh_t = TypeVar("thresh_t", bound=BaseThresholdModel, covariant=True) META_T = TypeVar("META_T", bound=dict[str, Union[str, float, int, list, dict]]) META_VT = TypeVar("META_VT", str, int, float, list, dict) EXTRA_T = TypeVar("EXTRA_T", bound=dict[str, Union[str, list, dict]]) diff --git a/poetry.lock b/poetry.lock index 8a86407c..da4f11b3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,10 +1,9 @@ -# This file is automatically @generated by Poetry and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. [[package]] name = "aiohttp" version = "3.8.4" description = "Async http client/server framework (asyncio)" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -113,7 +112,6 @@ speedups = ["Brotli", "aiodns", "cchardet"] name = "aiosignal" version = "1.3.1" description = "aiosignal: a list of registered asynchronous callbacks" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -128,7 +126,6 @@ frozenlist = ">=1.1.0" name = "antlr4-python3-runtime" version = "4.9.3" description = "ANTLR 4.9.3 runtime for Python 3.7" -category = "main" optional = false python-versions = "*" files = [ @@ -139,7 +136,6 @@ files = [ name = "anyio" version = "3.7.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -161,7 +157,6 @@ trio = ["trio (<0.22)"] name = "appnope" version = "0.1.3" description = "Disable App Nap on macOS >= 10.9" -category = "dev" optional = false python-versions = "*" files = [ @@ -173,7 +168,6 @@ files = [ name = "argon2-cffi" version = "21.3.0" description = "The secure Argon2 password hashing algorithm." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -193,7 +187,6 @@ tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pytest"] name = "argon2-cffi-bindings" version = "21.2.0" description = "Low-level CFFI bindings for Argon2" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -231,7 +224,6 @@ tests = ["pytest"] name = "arrow" version = "1.2.3" description = "Better dates & times for Python" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -246,7 +238,6 @@ python-dateutil = ">=2.7.0" name = "asttokens" version = "2.2.1" description = "Annotate AST trees with source code positions" -category = "dev" optional = false python-versions = "*" files = [ @@ -264,7 +255,6 @@ test = ["astroid", "pytest"] name = "async-timeout" version = "4.0.2" description = "Timeout context manager for asyncio programs" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -276,7 +266,6 @@ files = [ name = "attrs" version = "23.1.0" description = "Classes Without Boilerplate" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -295,7 +284,6 @@ tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pyte name = "backcall" version = "0.2.0" description = "Specifications for callback functions passed in to an API" -category = "dev" optional = false python-versions = "*" files = [ @@ -307,7 +295,6 @@ files = [ name = "beautifulsoup4" version = "4.12.2" description = "Screen-scraping library" -category = "dev" optional = false python-versions = ">=3.6.0" files = [ @@ -326,7 +313,6 @@ lxml = ["lxml"] name = "black" version = "23.3.0" description = "The uncompromising code formatter." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -376,7 +362,6 @@ uvloop = ["uvloop (>=0.15.2)"] name = "bleach" version = "6.0.0" description = "An easy safelist-based HTML-sanitizing tool." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -395,7 +380,6 @@ css = ["tinycss2 (>=1.1.0,<1.2)"] name = "cachetools" version = "5.3.1" description = "Extensible memoizing collections and decorators" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -407,7 +391,6 @@ files = [ name = "certifi" version = "2023.5.7" description = "Python package for providing Mozilla's CA Bundle." -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -419,7 +402,6 @@ files = [ name = "cffi" version = "1.15.1" description = "Foreign Function Interface for Python calling C code." -category = "dev" optional = false python-versions = "*" files = [ @@ -496,7 +478,6 @@ pycparser = "*" name = "cfgv" version = "3.3.1" description = "Validate configuration and produce human readable error messages." -category = "dev" optional = false python-versions = ">=3.6.1" files = [ @@ -508,7 +489,6 @@ files = [ name = "charset-normalizer" version = "3.1.0" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." -category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -593,7 +573,6 @@ files = [ name = "click" version = "8.1.3" description = "Composable command line interface toolkit" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -608,7 +587,6 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "cloudpickle" version = "2.2.1" description = "Extended pickling support for Python objects" -category = "main" optional = true python-versions = ">=3.6" files = [ @@ -620,7 +598,6 @@ files = [ name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." -category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -632,7 +609,6 @@ files = [ name = "comm" version = "0.1.3" description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -652,7 +628,6 @@ typing = ["mypy (>=0.990)"] name = "contourpy" version = "1.0.7" description = "Python library for calculating contours of 2D quadrilateral grids" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -727,7 +702,6 @@ test-no-images = ["pytest"] name = "coverage" version = "7.2.7" description = "Code coverage measurement for Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -803,7 +777,6 @@ toml = ["tomli"] name = "cycler" version = "0.11.0" description = "Composable style cycles" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -815,7 +788,6 @@ files = [ name = "databricks-cli" version = "0.17.7" description = "A command line interface for Databricks" -category = "main" optional = true python-versions = "*" files = [ @@ -836,7 +808,6 @@ urllib3 = ">=1.26.7,<2.0.0" name = "debugpy" version = "1.6.7" description = "An implementation of the Debug Adapter Protocol for Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -864,7 +835,6 @@ files = [ name = "decorator" version = "5.1.1" description = "Decorators for Humans" -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -876,7 +846,6 @@ files = [ name = "defusedxml" version = "0.7.1" description = "XML bomb protection for Python stdlib modules" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -888,7 +857,6 @@ files = [ name = "distlib" version = "0.3.6" description = "Distribution utilities" -category = "dev" optional = false python-versions = "*" files = [ @@ -900,7 +868,6 @@ files = [ name = "entrypoints" version = "0.4" description = "Discover and load entry points from installed packages." -category = "main" optional = true python-versions = ">=3.6" files = [ @@ -912,7 +879,6 @@ files = [ name = "exceptiongroup" version = "1.1.1" description = "Backport of PEP 654 (exception groups)" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -927,7 +893,6 @@ test = ["pytest (>=6)"] name = "executing" version = "1.2.0" description = "Get the currently executing AST node of a frame, and other information" -category = "dev" optional = false python-versions = "*" files = [ @@ -942,7 +907,6 @@ tests = ["asttokens", "littleutils", "pytest", "rich"] name = "fakeredis" version = "2.14.1" description = "Python implementation of redis API, can be used for testing purposes." -category = "dev" optional = false python-versions = ">=3.7,<4.0" files = [ @@ -962,7 +926,6 @@ lua = ["lupa (>=1.14,<2.0)"] name = "fastjsonschema" version = "2.17.1" description = "Fastest Python implementation of JSON schema" -category = "dev" optional = false python-versions = "*" files = [ @@ -977,7 +940,6 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc name = "filelock" version = "3.12.0" description = "A platform independent file lock." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -993,7 +955,6 @@ testing = ["covdefaults (>=2.3)", "coverage (>=7.2.3)", "diff-cover (>=7.5)", "p name = "fonttools" version = "4.39.4" description = "Tools to manipulate font files" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1019,7 +980,6 @@ woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] name = "fqdn" version = "1.5.1" description = "Validates fully-qualified domain names against RFC 1123, so that they are acceptable to modern bowsers" -category = "dev" optional = false python-versions = ">=2.7, !=3.0, !=3.1, !=3.2, !=3.3, !=3.4, <4" files = [ @@ -1031,7 +991,6 @@ files = [ name = "freezegun" version = "1.2.2" description = "Let your Python tests travel through time" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1046,7 +1005,6 @@ python-dateutil = ">=2.7" name = "frozenlist" version = "1.3.3" description = "A list-like structure which implements collections.abc.MutableSequence" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1130,7 +1088,6 @@ files = [ name = "fsspec" version = "2023.5.0" description = "File-system specification" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1170,7 +1127,6 @@ tqdm = ["tqdm"] name = "gitdb" version = "4.0.10" description = "Git Object Database" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1185,7 +1141,6 @@ smmap = ">=3.0.1,<6" name = "gitpython" version = "3.1.31" description = "GitPython is a Python library used to interact with Git repositories" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1200,7 +1155,6 @@ gitdb = ">=4.0.1,<5" name = "google-api-core" version = "2.11.0" description = "Google API client core library" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1223,7 +1177,6 @@ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0dev)"] name = "google-auth" version = "2.19.1" description = "Google Authentication Library" -category = "main" optional = true python-versions = ">=3.6" files = [ @@ -1249,7 +1202,6 @@ requests = ["requests (>=2.20.0,<3.0.0dev)"] name = "google-cloud" version = "0.34.0" description = "API Client library for Google Cloud" -category = "main" optional = true python-versions = "*" files = [ @@ -1261,7 +1213,6 @@ files = [ name = "googleapis-common-protos" version = "1.59.0" description = "Common protobufs used in Google APIs" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1279,7 +1230,6 @@ grpc = ["grpcio (>=1.44.0,<2.0.0dev)"] name = "grpcio" version = "1.54.2" description = "HTTP/2-based RPC framework" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1337,7 +1287,6 @@ protobuf = ["grpcio-tools (>=1.54.2)"] name = "grpcio-tools" version = "1.54.2" description = "Protobuf code generator for gRPC" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1397,7 +1346,6 @@ setuptools = "*" name = "hiredis" version = "2.2.3" description = "Python wrapper for hiredis" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1496,7 +1444,6 @@ files = [ name = "identify" version = "2.5.24" description = "File identification library for Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1511,7 +1458,6 @@ license = ["ukkonen"] name = "idna" version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" -category = "main" optional = false python-versions = ">=3.5" files = [ @@ -1523,7 +1469,6 @@ files = [ name = "importlib-metadata" version = "6.6.0" description = "Read metadata from Python packages" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1543,7 +1488,6 @@ testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", "packag name = "importlib-resources" version = "5.12.0" description = "Read resources from Python packages" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1562,7 +1506,6 @@ testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-chec name = "iniconfig" version = "2.0.0" description = "brain-dead simple config-ini parsing" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1574,7 +1517,6 @@ files = [ name = "ipykernel" version = "6.23.1" description = "IPython Kernel for Jupyter" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1588,7 +1530,7 @@ comm = ">=0.1.1" debugpy = ">=1.6.5" ipython = ">=7.23.1" jupyter-client = ">=6.1.12" -jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" matplotlib-inline = ">=0.1" nest-asyncio = "*" packaging = "*" @@ -1608,7 +1550,6 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio" name = "ipympl" version = "0.9.3" description = "Matplotlib Jupyter Extension" -category = "dev" optional = false python-versions = "*" files = [ @@ -1632,7 +1573,6 @@ docs = ["Sphinx (>=1.5)", "myst-nb", "sphinx-book-theme", "sphinx-copybutton", " name = "ipython" version = "8.14.0" description = "IPython: Productive Interactive Computing" -category = "dev" optional = false python-versions = ">=3.9" files = [ @@ -1672,7 +1612,6 @@ test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.21)", "pa name = "ipython-autotime" version = "0.3.1" description = "Time everything in IPython" -category = "dev" optional = false python-versions = "*" files = [ @@ -1687,7 +1626,6 @@ ipython = "*" name = "ipython-genutils" version = "0.2.0" description = "Vestigial utilities from IPython" -category = "dev" optional = false python-versions = "*" files = [ @@ -1699,7 +1637,6 @@ files = [ name = "ipywidgets" version = "8.0.6" description = "Jupyter interactive widgets" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1721,7 +1658,6 @@ test = ["ipykernel", "jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"] name = "isoduration" version = "20.11.0" description = "Operations with ISO 8601 durations" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1736,7 +1672,6 @@ arrow = ">=0.15.0" name = "jedi" version = "0.18.2" description = "An autocompletion tool for Python that can be used for text editors." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1756,7 +1691,6 @@ testing = ["Django (<3.1)", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] name = "jinja2" version = "3.1.2" description = "A very fast and expressive template engine." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1774,7 +1708,6 @@ i18n = ["Babel (>=2.7)"] name = "joblib" version = "1.2.0" description = "Lightweight pipelining with Python functions" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1786,7 +1719,6 @@ files = [ name = "jsonpointer" version = "2.3" description = "Identify specific nodes in a JSON document (RFC 6901)" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -1798,7 +1730,6 @@ files = [ name = "jsonschema" version = "4.17.3" description = "An implementation of JSON Schema validation for Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1826,7 +1757,6 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339- name = "jupyter" version = "1.0.0" description = "Jupyter metapackage. Install all the Jupyter components in one go." -category = "dev" optional = false python-versions = "*" files = [ @@ -1847,7 +1777,6 @@ qtconsole = "*" name = "jupyter-client" version = "8.2.0" description = "Jupyter protocol implementation and client libraries" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1857,7 +1786,7 @@ files = [ [package.dependencies] importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} -jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" python-dateutil = ">=2.8.2" pyzmq = ">=23.0" tornado = ">=6.2" @@ -1871,7 +1800,6 @@ test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pyt name = "jupyter-console" version = "6.6.3" description = "Jupyter terminal console" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1883,7 +1811,7 @@ files = [ ipykernel = ">=6.14" ipython = "*" jupyter-client = ">=7.0.0" -jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" prompt-toolkit = ">=3.0.30" pygments = "*" pyzmq = ">=17" @@ -1896,7 +1824,6 @@ test = ["flaky", "pexpect", "pytest"] name = "jupyter-core" version = "5.3.0" description = "Jupyter core package. A base package on which Jupyter projects rely." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1917,7 +1844,6 @@ test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] name = "jupyter-events" version = "0.6.3" description = "Jupyter Event System library" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1942,7 +1868,6 @@ test = ["click", "coverage", "pre-commit", "pytest (>=7.0)", "pytest-asyncio (>= name = "jupyter-server" version = "2.6.0" description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1955,7 +1880,7 @@ anyio = ">=3.1.0" argon2-cffi = "*" jinja2 = "*" jupyter-client = ">=7.4.4" -jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" jupyter-events = ">=0.6.0" jupyter-server-terminals = "*" nbconvert = ">=6.4.4" @@ -1979,7 +1904,6 @@ test = ["ipykernel", "pre-commit", "pytest (>=7.0)", "pytest-console-scripts", " name = "jupyter-server-terminals" version = "0.4.4" description = "A Jupyter Server Extension Providing Terminals." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1999,7 +1923,6 @@ test = ["coverage", "jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-cov", name = "jupyterlab-pygments" version = "0.2.2" description = "Pygments theme using JupyterLab CSS variables" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2011,7 +1934,6 @@ files = [ name = "jupyterlab-widgets" version = "3.0.7" description = "Jupyter interactive widgets for JupyterLab" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2023,7 +1945,6 @@ files = [ name = "kiwisolver" version = "1.4.4" description = "A fast implementation of the Cassowary constraint solver" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2101,7 +2022,6 @@ files = [ name = "lightning-utilities" version = "0.8.0" description = "PyTorch Lightning Sample project." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2123,7 +2043,6 @@ typing = ["mypy (>=1.0.0)"] name = "markupsafe" version = "2.1.3" description = "Safely add untrusted strings to HTML/XML markup." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2183,7 +2102,6 @@ files = [ name = "matplotlib" version = "3.7.1" description = "Python plotting package" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2246,7 +2164,6 @@ python-dateutil = ">=2.7" name = "matplotlib-inline" version = "0.1.6" description = "Inline Matplotlib backend for Jupyter" -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -2261,7 +2178,6 @@ traitlets = "*" name = "mistune" version = "2.0.5" description = "A sane Markdown parser with useful plugins and renderers" -category = "dev" optional = false python-versions = "*" files = [ @@ -2273,7 +2189,6 @@ files = [ name = "mlflow-skinny" version = "2.4.0" description = "MLflow: A Platform for ML Development and Productionization" -category = "main" optional = true python-versions = ">=3.8" files = [ @@ -2305,7 +2220,6 @@ sqlserver = ["mlflow-dbstore"] name = "mpmath" version = "1.3.0" description = "Python library for arbitrary-precision floating-point arithmetic" -category = "dev" optional = false python-versions = "*" files = [ @@ -2323,7 +2237,6 @@ tests = ["pytest (>=4.6)"] name = "multidict" version = "6.0.4" description = "multidict implementation" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2407,7 +2320,6 @@ files = [ name = "mypy-extensions" version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -2419,7 +2331,6 @@ files = [ name = "nb-black" version = "1.0.7" description = "A simple extension for Jupyter Notebook and Jupyter Lab to beautify Python code automatically using Black." -category = "dev" optional = false python-versions = "*" files = [ @@ -2433,7 +2344,6 @@ ipython = "*" name = "nbclassic" version = "1.0.0" description = "Jupyter Notebook as a Jupyter Server extension." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2469,7 +2379,6 @@ test = ["coverage", "nbval", "pytest", "pytest-cov", "pytest-jupyter", "pytest-p name = "nbclient" version = "0.8.0" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." -category = "dev" optional = false python-versions = ">=3.8.0" files = [ @@ -2479,7 +2388,7 @@ files = [ [package.dependencies] jupyter-client = ">=6.1.12" -jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" nbformat = ">=5.1" traitlets = ">=5.4" @@ -2492,7 +2401,6 @@ test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>= name = "nbconvert" version = "7.4.0" description = "Converting Jupyter Notebooks" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2531,7 +2439,6 @@ webpdf = ["pyppeteer (>=1,<1.1)"] name = "nbformat" version = "5.9.0" description = "The Jupyter Notebook format" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2553,7 +2460,6 @@ test = ["pep440", "pre-commit", "pytest", "testpath"] name = "nest-asyncio" version = "1.5.6" description = "Patch asyncio to allow nested event loops" -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -2565,7 +2471,6 @@ files = [ name = "networkx" version = "3.1" description = "Python package for creating and manipulating graphs and networks" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2584,7 +2489,6 @@ test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"] name = "nodeenv" version = "1.8.0" description = "Node.js virtual environment builder" -category = "dev" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" files = [ @@ -2599,7 +2503,6 @@ setuptools = "*" name = "notebook" version = "6.5.4" description = "A web-based notebook environment for interactive computing" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2634,7 +2537,6 @@ test = ["coverage", "nbval", "pytest", "pytest-cov", "requests", "requests-unixs name = "notebook-shim" version = "0.2.3" description = "A shim layer for notebook traits and config" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2652,7 +2554,6 @@ test = ["pytest", "pytest-console-scripts", "pytest-jupyter", "pytest-tornasync" name = "numpy" version = "1.24.3" description = "Fundamental package for array computing in Python" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2690,7 +2591,6 @@ files = [ name = "oauthlib" version = "3.2.2" description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" -category = "main" optional = true python-versions = ">=3.6" files = [ @@ -2707,7 +2607,6 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] name = "omegaconf" version = "2.3.0" description = "A flexible configuration library" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -2716,14 +2615,13 @@ files = [ ] [package.dependencies] -antlr4-python3-runtime = ">=4.9.0,<4.10.0" +antlr4-python3-runtime = "==4.9.*" PyYAML = ">=5.1.0" [[package]] name = "overrides" version = "7.3.1" description = "A decorator to automatically detect mismatch when overriding a method." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2735,7 +2633,6 @@ files = [ name = "packaging" version = "23.1" description = "Core utilities for Python packages" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2747,7 +2644,6 @@ files = [ name = "pandas" version = "2.0.2" description = "Powerful data structures for data analysis, time series, and statistics" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2814,7 +2710,6 @@ xml = ["lxml (>=4.6.3)"] name = "pandocfilters" version = "1.5.0" description = "Utilities for writing pandoc filters in python" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2826,7 +2721,6 @@ files = [ name = "parso" version = "0.8.3" description = "A Python Parser" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2842,7 +2736,6 @@ testing = ["docopt", "pytest (<6.0.0)"] name = "pathspec" version = "0.11.1" description = "Utility library for gitignore style pattern matching of file paths." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2854,7 +2747,6 @@ files = [ name = "pexpect" version = "4.8.0" description = "Pexpect allows easy control of interactive console applications." -category = "dev" optional = false python-versions = "*" files = [ @@ -2869,7 +2761,6 @@ ptyprocess = ">=0.5" name = "pickleshare" version = "0.7.5" description = "Tiny 'shelve'-like database with concurrency support" -category = "dev" optional = false python-versions = "*" files = [ @@ -2881,7 +2772,6 @@ files = [ name = "pillow" version = "9.5.0" description = "Python Imaging Library (Fork)" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2961,7 +2851,6 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa name = "platformdirs" version = "3.5.1" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2977,7 +2866,6 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest- name = "pluggy" version = "1.0.0" description = "plugin and hook calling mechanisms for python" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2993,7 +2881,6 @@ testing = ["pytest", "pytest-benchmark"] name = "pre-commit" version = "3.3.2" description = "A framework for managing and maintaining multi-language pre-commit hooks." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3012,7 +2899,6 @@ virtualenv = ">=20.10.0" name = "prometheus-client" version = "0.17.0" description = "Python client for the Prometheus monitoring system." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3027,7 +2913,6 @@ twisted = ["twisted"] name = "prompt-toolkit" version = "3.0.38" description = "Library for building powerful interactive command lines in Python" -category = "dev" optional = false python-versions = ">=3.7.0" files = [ @@ -3042,7 +2927,6 @@ wcwidth = "*" name = "protobuf" version = "4.23.2" description = "" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3065,7 +2949,6 @@ files = [ name = "psutil" version = "5.9.5" description = "Cross-platform lib for process and system monitoring in Python." -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3092,7 +2975,6 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] name = "ptyprocess" version = "0.7.0" description = "Run a subprocess in a pseudo terminal" -category = "dev" optional = false python-versions = "*" files = [ @@ -3104,7 +2986,6 @@ files = [ name = "pure-eval" version = "0.2.2" description = "Safely evaluate AST nodes without side effects" -category = "dev" optional = false python-versions = "*" files = [ @@ -3119,7 +3000,6 @@ tests = ["pytest"] name = "pyasn1" version = "0.5.0" description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" -category = "main" optional = true python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ @@ -3131,7 +3011,6 @@ files = [ name = "pyasn1-modules" version = "0.3.0" description = "A collection of ASN.1-based protocols modules" -category = "main" optional = true python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ @@ -3146,7 +3025,6 @@ pyasn1 = ">=0.4.6,<0.6.0" name = "pycparser" version = "2.21" description = "C parser in Python" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3158,7 +3036,6 @@ files = [ name = "pygments" version = "2.15.1" description = "Pygments is a syntax highlighting package written in Python." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3173,7 +3050,6 @@ plugins = ["importlib-metadata"] name = "pyjwt" version = "2.7.0" description = "JSON Web Token implementation in Python" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3187,11 +3063,21 @@ dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pyte docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] +[[package]] +name = "pympler" +version = "1.0.1" +description = "A development tool to measure, monitor and analyze the memory behavior of Python objects." +optional = false +python-versions = ">=3.6" +files = [ + {file = "Pympler-1.0.1-py3-none-any.whl", hash = "sha256:d260dda9ae781e1eab6ea15bacb84015849833ba5555f141d2d9b7b7473b307d"}, + {file = "Pympler-1.0.1.tar.gz", hash = "sha256:993f1a3599ca3f4fcd7160c7545ad06310c9e12f70174ae7ae8d4e25f6c5d3fa"}, +] + [[package]] name = "pynumaflow" version = "0.4.1" description = "Provides the interfaces of writing Python User Defined Functions and Sinks for NumaFlow." -category = "main" optional = true python-versions = ">=3.9,<3.12" files = [ @@ -3210,7 +3096,6 @@ protobuf = ">=3.20,<5.0" name = "pyparsing" version = "3.0.9" description = "pyparsing module - Classes and methods to define and execute parsing grammars" -category = "dev" optional = false python-versions = ">=3.6.8" files = [ @@ -3225,7 +3110,6 @@ diagrams = ["jinja2", "railroad-diagrams"] name = "pyrsistent" version = "0.19.3" description = "Persistent/Functional/Immutable data structures" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3262,7 +3146,6 @@ files = [ name = "pytest" version = "7.3.1" description = "pytest: simple powerful testing with Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3285,7 +3168,6 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no name = "pytest-cov" version = "4.1.0" description = "Pytest plugin for measuring coverage." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3304,7 +3186,6 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtuale name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" -category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ @@ -3319,7 +3200,6 @@ six = ">=1.5" name = "python-json-logger" version = "2.0.7" description = "A python library adding a json log formatter" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3331,7 +3211,6 @@ files = [ name = "pytorch-lightning" version = "2.0.2" description = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3363,7 +3242,6 @@ test = ["cloudpickle (>=1.3)", "coverage (==6.5.0)", "fastapi (<0.87.0)", "onnx name = "pytz" version = "2023.3" description = "World timezone definitions, modern and historical" -category = "main" optional = false python-versions = "*" files = [ @@ -3375,7 +3253,6 @@ files = [ name = "pywin32" version = "306" description = "Python for Window Extensions" -category = "dev" optional = false python-versions = "*" files = [ @@ -3399,7 +3276,6 @@ files = [ name = "pywinpty" version = "2.0.10" description = "Pseudo terminal support for Windows from Python." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3415,7 +3291,6 @@ files = [ name = "pyyaml" version = "6.0" description = "YAML parser and emitter for Python" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3465,7 +3340,6 @@ files = [ name = "pyzmq" version = "25.1.0" description = "Python bindings for 0MQ" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3555,7 +3429,6 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""} name = "qtconsole" version = "5.4.3" description = "Jupyter Qt console" -category = "dev" optional = false python-versions = ">= 3.7" files = [ @@ -3582,7 +3455,6 @@ test = ["flaky", "pytest", "pytest-qt"] name = "qtpy" version = "2.3.1" description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5/6 and PySide2/6)." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3600,7 +3472,6 @@ test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"] name = "redis" version = "4.5.5" description = "Python client for Redis database and key-value store" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3620,7 +3491,6 @@ ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)" name = "requests" version = "2.31.0" description = "Python HTTP for Humans." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3642,7 +3512,6 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] name = "rfc3339-validator" version = "0.1.4" description = "A pure python RFC3339 validator" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -3657,7 +3526,6 @@ six = "*" name = "rfc3986-validator" version = "0.1.1" description = "Pure python rfc3986 validator" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -3669,7 +3537,6 @@ files = [ name = "rsa" version = "4.9" description = "Pure-Python RSA implementation" -category = "main" optional = true python-versions = ">=3.6,<4" files = [ @@ -3684,7 +3551,6 @@ pyasn1 = ">=0.1.3" name = "ruff" version = "0.0.264" description = "An extremely fast Python linter, written in Rust." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3711,7 +3577,6 @@ files = [ name = "scikit-learn" version = "1.2.2" description = "A set of python modules for machine learning and data mining" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -3754,7 +3619,6 @@ tests = ["black (>=22.3.0)", "flake8 (>=3.8.2)", "matplotlib (>=3.1.3)", "mypy ( name = "scipy" version = "1.10.1" description = "Fundamental algorithms for scientific computing in Python" -category = "main" optional = false python-versions = "<3.12,>=3.8" files = [ @@ -3793,7 +3657,6 @@ test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeo name = "send2trash" version = "1.8.2" description = "Send file to trash natively under Mac OS X, Windows and Linux" -category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" files = [ @@ -3810,7 +3673,6 @@ win32 = ["pywin32"] name = "setuptools" version = "67.8.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3827,7 +3689,6 @@ testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs ( name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" -category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -3839,7 +3700,6 @@ files = [ name = "smmap" version = "5.0.0" description = "A pure Python implementation of a sliding window memory map manager" -category = "main" optional = true python-versions = ">=3.6" files = [ @@ -3851,7 +3711,6 @@ files = [ name = "sniffio" version = "1.3.0" description = "Sniff out which async library your code is running under" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3863,7 +3722,6 @@ files = [ name = "sortedcontainers" version = "2.4.0" description = "Sorted Containers -- Sorted List, Sorted Dict, Sorted Set" -category = "dev" optional = false python-versions = "*" files = [ @@ -3875,7 +3733,6 @@ files = [ name = "soupsieve" version = "2.4.1" description = "A modern CSS selector implementation for Beautiful Soup." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3887,7 +3744,6 @@ files = [ name = "sqlparse" version = "0.4.4" description = "A non-validating SQL parser." -category = "main" optional = true python-versions = ">=3.5" files = [ @@ -3904,7 +3760,6 @@ test = ["pytest", "pytest-cov"] name = "stack-data" version = "0.6.2" description = "Extract data from python stack frames and tracebacks for informative displays" -category = "dev" optional = false python-versions = "*" files = [ @@ -3924,7 +3779,6 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] name = "sympy" version = "1.12" description = "Computer algebra system (CAS) in Python" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3939,7 +3793,6 @@ mpmath = ">=0.19" name = "tabulate" version = "0.9.0" description = "Pretty-print tabular data" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3954,7 +3807,6 @@ widechars = ["wcwidth"] name = "terminado" version = "0.17.1" description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3975,7 +3827,6 @@ test = ["pre-commit", "pytest (>=7.0)", "pytest-timeout"] name = "threadpoolctl" version = "3.1.0" description = "threadpoolctl" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3987,7 +3838,6 @@ files = [ name = "tinycss2" version = "1.2.1" description = "A tiny CSS parser" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4006,7 +3856,6 @@ test = ["flake8", "isort", "pytest"] name = "tomli" version = "2.0.1" description = "A lil' TOML parser" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4018,7 +3867,6 @@ files = [ name = "torch" version = "2.0.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" -category = "dev" optional = false python-versions = ">=3.8.0" files = [ @@ -4058,7 +3906,6 @@ opt-einsum = ["opt-einsum (>=3.3)"] name = "torchinfo" version = "1.8.0" description = "Model summary in PyTorch, based off of the original torchsummary." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4070,7 +3917,6 @@ files = [ name = "torchmetrics" version = "0.11.4" description = "PyTorch native Metrics" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4096,7 +3942,6 @@ text = ["nltk (>=3.6)", "regex (>=2021.9.24)", "tqdm (>=4.41.0)"] name = "tornado" version = "6.3.2" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." -category = "dev" optional = false python-versions = ">= 3.8" files = [ @@ -4117,7 +3962,6 @@ files = [ name = "tqdm" version = "4.65.0" description = "Fast, Extensible Progress Meter" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4138,7 +3982,6 @@ telegram = ["requests"] name = "traitlets" version = "5.9.0" description = "Traitlets Python configuration system" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4154,7 +3997,6 @@ test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"] name = "typing-extensions" version = "4.6.3" description = "Backported and Experimental Type Hints for Python 3.7+" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4166,7 +4008,6 @@ files = [ name = "tzdata" version = "2023.3" description = "Provider of IANA time zone data" -category = "main" optional = false python-versions = ">=2" files = [ @@ -4178,7 +4019,6 @@ files = [ name = "uri-template" version = "1.2.0" description = "RFC 6570 URI Template Processor" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -4193,7 +4033,6 @@ dev = ["flake8 (<4.0.0)", "flake8-annotations", "flake8-bugbear", "flake8-commas name = "urllib3" version = "1.26.16" description = "HTTP library with thread-safe connection pooling, file post, and more." -category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ @@ -4210,7 +4049,6 @@ socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] name = "virtualenv" version = "20.23.0" description = "Virtual Python Environment builder" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4231,7 +4069,6 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.3)", "coverage-enable-subprocess name = "wcwidth" version = "0.2.6" description = "Measures the displayed width of unicode strings in a terminal" -category = "dev" optional = false python-versions = "*" files = [ @@ -4243,7 +4080,6 @@ files = [ name = "webcolors" version = "1.13" description = "A library for working with the color formats defined by HTML and CSS." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4259,7 +4095,6 @@ tests = ["pytest", "pytest-cov"] name = "webencodings" version = "0.5.1" description = "Character encoding aliases for legacy web content" -category = "dev" optional = false python-versions = "*" files = [ @@ -4271,7 +4106,6 @@ files = [ name = "websocket-client" version = "1.5.2" description = "WebSocket client for Python with low level API options" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4288,7 +4122,6 @@ test = ["websockets"] name = "widgetsnbextension" version = "4.0.7" description = "Jupyter interactive widgets for Jupyter Notebook" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4300,7 +4133,6 @@ files = [ name = "yarl" version = "1.9.2" description = "Yet another URL library" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4388,7 +4220,6 @@ multidict = ">=4.0" name = "zipp" version = "3.15.0" description = "Backport of pathlib-compatible object wrapper for zip files" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4408,4 +4239,4 @@ redis = ["redis"] [metadata] lock-version = "2.0" python-versions = ">=3.9, <3.11" -content-hash = "88dde9f56415172bf16235f0f6a2dc113c46ff0fb8c869178530436530164e8b" +content-hash = "50f593285a0b4d475c8ff29a8c2fe1be9b8d03e802b57a102f9d4da9e9661b87" diff --git a/pyproject.toml b/pyproject.toml index d83e1bc9..af84c49a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ ruff = "^0.0.264" pre-commit = "^3.3.1" fakeredis = "^2.11.2" freezegun = "^1.2.2" +pympler = "^1.0.1" [tool.poetry.group.jupyter] optional = true diff --git a/tests/blocks/__init__.py b/tests/blocks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/blocks/test_blocks.py b/tests/blocks/test_blocks.py new file mode 100644 index 00000000..e7027cfc --- /dev/null +++ b/tests/blocks/test_blocks.py @@ -0,0 +1,32 @@ +import unittest + +import numpy as np + +from numalogic.blocks import Block +from sklearn.ensemble import IsolationForest + + +class DummyBlock(Block): + def fit(self, input_: np.ndarray, **__) -> np.ndarray: + return self._artifact.fit_predict(input_).reshape(-1, 1) + + def run(self, input_: np.ndarray, **__) -> np.ndarray: + return self._artifact.predict(input_).reshape(-1, 1) + + +class TestBlock(unittest.TestCase): + def test_random_block(self): + block = DummyBlock(IsolationForest(), name="isolation_forest") + self.assertEqual(block.name, "isolation_forest") + + block.fit(np.arange(100).reshape(-1, 2)) + out = block(np.arange(10).reshape(-1, 2)) + self.assertEqual(out.shape, (5, 1)) + + self.assertIsInstance(block.artifact, IsolationForest) + self.assertIsInstance(block.artifact_state, IsolationForest) + self.assertTrue(block.stateful) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/blocks/test_pipeline.py b/tests/blocks/test_pipeline.py new file mode 100644 index 00000000..9bfe2684 --- /dev/null +++ b/tests/blocks/test_pipeline.py @@ -0,0 +1,212 @@ +import os +import unittest + +import pandas as pd +import torch +from fakeredis import FakeRedis, FakeServer +from pympler import asizeof +from sklearn.preprocessing import StandardScaler + +from numalogic._constants import TESTS_DIR +from numalogic.blocks import ( + BlockPipeline, + PreprocessBlock, + NNBlock, + PostprocessBlock, + ThresholdBlock, +) +from numalogic.models.autoencoder.variants import ( + VanillaAE, + LSTMAE, + Conv1dAE, + TransformerAE, + SparseVanillaAE, + SparseConv1dAE, +) +from numalogic.models.threshold import StdDevThreshold +from numalogic.registry import RedisRegistry +from numalogic.transforms import TanhScaler, TanhNorm, LogTransformer + +ROOT_DIR = os.path.join(TESTS_DIR, "resources", "data") +DATA_FILE = os.path.join(ROOT_DIR, "interactionstatus.csv") +server = FakeServer() +SEQ_LEN = 10 + + +class TestBlockPipeline(unittest.TestCase): + x_train = None + x_stream = None + + @classmethod + def setUpClass(cls) -> None: + df = pd.read_csv(DATA_FILE, nrows=1000) + df = df[["success", "failure"]] + cls.x_train = df[:990].to_numpy() + cls.x_stream = df[-10:].to_numpy() + assert cls.x_train.shape == (990, 2) + assert cls.x_stream.shape == (10, 2) + + def setUp(self) -> None: + self.reg = RedisRegistry(client=FakeRedis(server=server)) + + def test_pipeline_01(self): + block_pl = BlockPipeline( + PreprocessBlock(TanhScaler()), + NNBlock(VanillaAE(SEQ_LEN, n_features=2), SEQ_LEN), + ThresholdBlock(StdDevThreshold()), + PostprocessBlock(TanhNorm()), + registry=self.reg, + ) + block_pl.fit(self.x_train, nn__max_epochs=1) + out = block_pl(self.x_stream) + + self.assertTupleEqual(self.x_stream.shape, out.shape) + self.assertEqual(4, len(block_pl)) + self.assertIsInstance(block_pl[1], NNBlock) + + def test_pipeline_02(self): + block_pl = BlockPipeline( + PreprocessBlock(StandardScaler()), + NNBlock(LSTMAE(SEQ_LEN, no_features=2, embedding_dim=4), SEQ_LEN), + PostprocessBlock(TanhNorm()), + registry=self.reg, + ) + block_pl.fit( + self.x_train, + nn__max_epochs=1, + nn__accelerator="cpu", + ) + out = block_pl.run(self.x_stream) + self.assertTupleEqual(self.x_stream.shape, out.shape) + + def test_pipeline_03(self): + block_pl = BlockPipeline( + PreprocessBlock(LogTransformer(), stateful=False), + NNBlock(Conv1dAE(SEQ_LEN, in_channels=2), SEQ_LEN), + registry=self.reg, + ) + block_pl.fit( + self.x_train, + nn__max_epochs=1, + nn__accelerator="cpu", + ) + out = block_pl.run(self.x_stream) + self.assertTupleEqual(self.x_stream.shape, out.shape) + + def test_pipeline_04(self): + block_pl = BlockPipeline( + PreprocessBlock(StandardScaler()), + NNBlock(TransformerAE(SEQ_LEN, n_features=2), SEQ_LEN), + ThresholdBlock(StdDevThreshold()), + registry=self.reg, + ) + block_pl.fit( + self.x_train, + nn__max_epochs=1, + nn__accelerator="cpu", + ) + out = block_pl.run(self.x_stream) + self.assertTupleEqual(self.x_stream.shape, out.shape) + for block in block_pl: + self.assertTrue(block.stateful) + + def test_pipeline_05(self): + block_pl = BlockPipeline( + PreprocessBlock(LogTransformer(), stateful=False), + NNBlock(SparseVanillaAE(seq_len=SEQ_LEN, n_features=2), SEQ_LEN), + PostprocessBlock(TanhNorm()), + registry=self.reg, + ) + block_pl.fit( + self.x_train, + nn__max_epochs=1, + ) + out = block_pl.run(self.x_stream) + self.assertTupleEqual(self.x_stream.shape, out.shape) + + def test_pipeline_persistence(self): + skeys = ["test"] + dkeys = ["pipeline"] + # Pipeline for saving + pl_1 = BlockPipeline( + PreprocessBlock(TanhScaler()), + NNBlock(SparseConv1dAE(seq_len=SEQ_LEN, in_channels=2), SEQ_LEN), + ThresholdBlock(StdDevThreshold()), + PostprocessBlock(TanhNorm()), + registry=self.reg, + ) + pl_1.fit( + self.x_train, + nn__accelerator="cpu", + nn__max_epochs=1, + ) + + _preweights = [] + with torch.no_grad(): + for params in pl_1[1].artifact.parameters(): + _preweights.append(torch.mean(params)) + + pl_1.save(skeys, dkeys) + + # Pipeline for loading + pl_2 = BlockPipeline( + PreprocessBlock(TanhScaler()), + NNBlock(SparseConv1dAE(seq_len=SEQ_LEN, in_channels=2), SEQ_LEN), + PostprocessBlock(TanhNorm()), + registry=self.reg, + ) + pl_2.load(skeys, dkeys) + + _postweights = [] + with torch.no_grad(): + for params in pl_2[1].artifact.parameters(): + _postweights.append(torch.mean(params)) + + self.assertListEqual(_preweights, _postweights) + out = pl_2(self.x_stream) + self.assertTupleEqual(self.x_stream.shape, out.shape) + + def test_pipeline_save_err(self): + block_pl = BlockPipeline( + PreprocessBlock(TanhScaler()), + NNBlock(VanillaAE(SEQ_LEN, n_features=2), SEQ_LEN), + PostprocessBlock(TanhNorm()), + ) + block_pl.fit(self.x_train, nn__max_epochs=1) + self.assertRaises(ValueError, block_pl.save, ["ml"], ["pl"]) + self.assertRaises(ValueError, block_pl.load, ["ml"], ["pl"]) + + def test_pipeline_fit_err(self): + block_pl = BlockPipeline( + PreprocessBlock(TanhScaler()), + NNBlock(VanillaAE(SEQ_LEN, n_features=2), SEQ_LEN), + PostprocessBlock(TanhNorm()), + ) + self.assertRaises(ValueError, block_pl.fit, self.x_train, max_epochs=1) + + @unittest.skip("Just for testing memory usage") + def test_memory_usage(self): + model = SparseConv1dAE(seq_len=SEQ_LEN, in_channels=2) + block_nn = NNBlock(model, SEQ_LEN) + block_pre = PreprocessBlock(TanhScaler()) + block_post = PostprocessBlock(TanhNorm()) + block_th = ThresholdBlock(StdDevThreshold()) + + pl = BlockPipeline( + block_pre, + block_nn, + block_th, + block_post, + registry=self.reg, + ) + + print(asizeof.asizeof(model) / 1024) + print(asizeof.asizeof(block_nn) / 1024) + print(asizeof.asizeof(block_pre) / 1024) + print(asizeof.asizeof(block_post) / 1024) + print(asizeof.asizeof(block_th) / 1024) + print(asizeof.asizeof(pl) / 1024) + + +if __name__ == "__main__": + unittest.main()