diff --git a/dev-requirements.in b/dev-requirements.in
index c1c82230a9..e9726227fb 100644
--- a/dev-requirements.in
+++ b/dev-requirements.in
@@ -15,3 +15,4 @@ IPython
 # Newer versions of torch bring in nvidia dependencies that are not present in windows, so
 # we put this constraint while we do not have per-environment requirements files
 torch<=1.12.1
+scikit-learn
diff --git a/dev-requirements.txt b/dev-requirements.txt
index 6e76021029..5a73546c0d 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -8,6 +8,8 @@
     # via
     #   -c requirements.txt
     #   pytest-flyte
+appnope==0.1.3
+    # via ipython
 arrow==1.2.3
     # via
     #   -c requirements.txt
@@ -77,7 +79,6 @@ cryptography==38.0.3
     #   -c requirements.txt
     #   paramiko
     #   pyopenssl
-    #   secretstorage
 dataclasses-json==0.5.7
     # via
     #   -c requirements.txt
@@ -192,11 +193,6 @@ jaraco-classes==3.2.3
     #   keyring
 jedi==0.18.1
     # via ipython
-jeepney==0.8.0
-    # via
-    #   -c requirements.txt
-    #   keyring
-    #   secretstorage
 jinja2==3.1.2
     # via
     #   -c requirements.txt
@@ -212,6 +208,7 @@ joblib==1.2.0
     #   -c requirements.txt
     #   -r dev-requirements.in
     #   flytekit
+    #   scikit-learn
 jsonschema==3.2.0
     # via
     #   -c requirements.txt
@@ -265,6 +262,8 @@ numpy==1.21.6
     #   flytekit
     #   pandas
     #   pyarrow
+    #   scikit-learn
+    #   scipy
 packaging==21.3
     # via
     #   -c requirements.txt
@@ -419,10 +418,10 @@ retry==0.9.2
     #   flytekit
 rsa==4.9
     # via google-auth
-secretstorage==3.3.3
-    # via
-    #   -c requirements.txt
-    #   keyring
+scikit-learn==1.0.2
+    # via -r dev-requirements.in
+scipy==1.7.3
+    # via scikit-learn
 singledispatchmethod==1.0
     # via
     #   -c requirements.txt
@@ -451,6 +450,8 @@ text-unidecode==1.3
     #   python-slugify
 texttable==1.6.4
     # via docker-compose
+threadpoolctl==3.1.0
+    # via scikit-learn
 toml==0.10.2
     # via
     #   -c requirements.txt
diff --git a/doc-requirements.in b/doc-requirements.in
index 7c5dcc3e96..d61955fbff 100644
--- a/doc-requirements.in
+++ b/doc-requirements.in
@@ -44,3 +44,4 @@ tensorflow==2.9.0 # onnxtensorflow
 whylogs # whylogs
 whylabs-client # whylogs
 ray # ray
+scikit-learn # scikit-learn
diff --git a/docs/source/extras.sklearn.rst b/docs/source/extras.sklearn.rst
new file mode 100644
index 0000000000..a2efcfa84b
--- /dev/null
+++ b/docs/source/extras.sklearn.rst
@@ -0,0 +1,7 @@
+############
+Sklearn Type
+############
+.. automodule:: flytekit.extras.sklearn
+   :no-members:
+   :no-inherited-members:
+   :no-special-members:
diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py
index ef345aadda..dccbaec803 100644
--- a/flytekit/core/base_task.py
+++ b/flytekit/core/base_task.py
@@ -487,6 +487,7 @@ def dispatch_execute(
         # Short circuit the translation to literal map because what's returned may be a dj spec (or an
         # already-constructed LiteralMap if the dynamic task was a no-op), not python native values
+        # dynamic_execute also returns a LiteralMap when run locally, so that path triggers this branch too.
         if isinstance(native_outputs, _literal_models.LiteralMap) or isinstance(
             native_outputs, _dynamic_job.DynamicJobSpec
         ):
diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py
index 0f74c1dee1..7e4600b3bb 100644
--- a/flytekit/core/context_manager.py
+++ b/flytekit/core/context_manager.py
@@ -450,20 +450,20 @@ class Mode(Enum):
         """
        Defines the possible execution modes, which in turn affects execution behavior.
         """
-        #: This is the mode that is used when a task execution mimics the actual runtime environment.
-        #: NOTE: This is important to understand the difference between TASK_EXECUTION and LOCAL_TASK_EXECUTION
-        #: LOCAL_TASK_EXECUTION, is the mode that is run purely locally and in some cases the difference between local
-        #: and runtime environment may be different. For example for Dynamic tasks local_task_execution will just run it
-        #: as a regular function, while task_execution will extract a runtime spec
+        # This is the mode used when a task execution mimics the actual runtime environment.
+        # NOTE: It is important to understand the difference between TASK_EXECUTION and LOCAL_TASK_EXECUTION.
+        # LOCAL_TASK_EXECUTION runs purely locally, and local behavior can differ from the runtime
+        # environment. For example, for dynamic tasks LOCAL_TASK_EXECUTION will just run the function as a
+        # regular function, while TASK_EXECUTION will extract a runtime spec.
         TASK_EXECUTION = 1

-        #: This represents when flytekit is locally running a workflow. The behavior of tasks differs in this case
-        #: because instead of running a task's user defined function directly, it'll need to wrap the return values in
-        #: NodeOutput
+        # This represents when flytekit is locally running a workflow. The behavior of tasks differs in this case
+        # because instead of running a task's user-defined function directly, it needs to wrap the return values in
+        # a NodeOutput.
         LOCAL_WORKFLOW_EXECUTION = 2

-        #: This is the mode that is used to to indicate a purely local task execution - i.e. running without a container
-        #: or propeller.
+        # This is the mode used to indicate a purely local task execution, i.e. running without a container
+        # or propeller.
         LOCAL_TASK_EXECUTION = 3

     mode: Optional[ExecutionState.Mode]
diff --git a/flytekit/core/node.py b/flytekit/core/node.py
index 95fdebb4eb..d849ef5397 100644
--- a/flytekit/core/node.py
+++ b/flytekit/core/node.py
@@ -108,6 +108,8 @@ def with_overrides(self, *args, **kwargs):
             )
         if "interruptible" in kwargs:
             self._metadata._interruptible = kwargs["interruptible"]
+        if "name" in kwargs:
+            self._metadata._name = kwargs["name"]

         return self

@@ -134,7 +136,10 @@ def _convert_resource_overrides(
         )
     if resources.ephemeral_storage is not None:
         resource_entries.append(
-            _resources_model.ResourceEntry(_resources_model.ResourceName.EPHEMERAL_STORAGE, resources.ephemeral_storage)
+            _resources_model.ResourceEntry(
+                _resources_model.ResourceName.EPHEMERAL_STORAGE,
+                resources.ephemeral_storage,
+            )
         )
     return resource_entries
diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py
index fb428a89a2..bcb80f34ca 100644
--- a/flytekit/core/python_function_task.py
+++ b/flytekit/core/python_function_task.py
@@ -19,12 +19,11 @@
 from enum import Enum
 from typing import Any, Callable, List, Optional, TypeVar, Union

-from flytekit.configuration import SerializationSettings
-from flytekit.configuration.default_images import DefaultImages
 from flytekit.core.base_task import Task, TaskResolverMixin
 from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
 from flytekit.core.docstring import Docstring
 from flytekit.core.interface import transform_function_to_interface
+from flytekit.core.promise import VoidPromise, translate_inputs_to_literals
 from flytekit.core.python_auto_container import PythonAutoContainerTask, default_task_resolver
 from flytekit.core.tracker import extract_task_module, is_functools_wrapped_module_level, isnested, istestfunction
 from flytekit.core.workflow import (
@@ -34,6 +33,7 @@
     WorkflowMetadataDefaults,
 )
 from flytekit.exceptions import scopes as exception_scopes
+from flytekit.exceptions.user import FlyteValueException
 from flytekit.loggers import logger
 from flytekit.models import dynamic_job as _dynamic_job
 from flytekit.models import literals as _literal_models
@@ -145,6 +145,7 @@ def __init__(
         )
         self._task_function = task_function
         self._execution_mode = execution_mode
+        self._wf = None  # For dynamic tasks

     @property
     def execution_mode(self) -> ExecutionBehavior:
@@ -164,6 +165,14 @@ def execute(self, **kwargs) -> Any:
         elif self.execution_mode == self.ExecutionBehavior.DYNAMIC:
             return self.dynamic_execute(self._task_function, **kwargs)

+    def _create_and_cache_dynamic_workflow(self):
+        if self._wf is None:
+            workflow_meta = WorkflowMetadata(on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY)
+            defaults = WorkflowMetadataDefaults(
+                interruptible=self.metadata.interruptible if self.metadata.interruptible is not None else False
+            )
+            self._wf = PythonFunctionWorkflow(self._task_function, metadata=workflow_meta, default_metadata=defaults)
+
     def compile_into_workflow(
         self, ctx: FlyteContext, task_function: Callable, **kwargs
     ) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]:
@@ -183,12 +192,7 @@ def compile_into_workflow(
             # TODO: Resolve circular import
             from flytekit.tools.translator import get_serializable

-            workflow_metadata = WorkflowMetadata(on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY)
-            defaults = WorkflowMetadataDefaults(
-                interruptible=self.metadata.interruptible if self.metadata.interruptible is not None else False
-            )
-
-            self._wf = PythonFunctionWorkflow(task_function, metadata=workflow_metadata, default_metadata=defaults)
+            self._create_and_cache_dynamic_workflow()
             self._wf.compile(**kwargs)

             wf = self._wf
@@ -259,19 +263,44 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
         representing that newly generated workflow, instead of executing it.
         """
         ctx = FlyteContextManager.current_context()
-        # This is a placeholder SerializationSettings placeholder and is only used to test compilation for dynamic tasks
-        # when run locally. The output of the compilation should never actually be used anywhere.
-        _LOCAL_ONLY_SS = SerializationSettings.for_image(DefaultImages.default_image(), "v", "p", "d")
-
         if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
-            updated_exec_state = ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
-            with FlyteContextManager.with_context(
-                ctx.with_execution_state(updated_exec_state).with_serialization_settings(_LOCAL_ONLY_SS)
-            ) as ctx:
-                logger.debug(f"Running compilation for {self} as part of local run as check")
-                self.compile_into_workflow(ctx, task_function, **kwargs)
-
-            logger.info("Executing Dynamic workflow, using raw inputs")
-            return exception_scopes.user_entry_point(task_function)(**kwargs)
+            # The rest of this function mimics the local_execute of the workflow. We can't use the workflow's
+            # local_execute directly, though, since that converts inputs into Promises.
+ logger.debug(f"Executing Dynamic workflow, using raw inputs {kwargs}") + self._create_and_cache_dynamic_workflow() + function_outputs = self._wf.execute(**kwargs) + + if isinstance(function_outputs, VoidPromise) or function_outputs is None: + return VoidPromise(self.name) + + if len(self._wf.python_interface.outputs) == 0: + raise FlyteValueException(function_outputs, "Interface output should've been VoidPromise or None.") + + # TODO: This will need to be cleaned up when we revisit top-level tuple support. + expected_output_names = list(self.python_interface.outputs.keys()) + if len(expected_output_names) == 1: + # Here we have to handle the fact that the wf could've been declared with a typing.NamedTuple of + # length one. That convention is used for naming outputs - and single-length-NamedTuples are + # particularly troublesome but elegant handling of them is not a high priority + # Again, we're using the output_tuple_name as a proxy. + if self.python_interface.output_tuple_name and isinstance(function_outputs, tuple): + wf_outputs_as_map = {expected_output_names[0]: function_outputs[0]} + else: + wf_outputs_as_map = {expected_output_names[0]: function_outputs} + else: + wf_outputs_as_map = { + expected_output_names[i]: function_outputs[i] for i, _ in enumerate(function_outputs) + } + + # In a normal workflow, we'd repackage the promises coming from tasks into new Promises matching the + # workflow's interface. For a dynamic workflow, just return the literal map. + wf_outputs_as_literal_dict = translate_inputs_to_literals( + ctx, + wf_outputs_as_map, + flyte_interface_types=self.interface.outputs, + native_types=self.python_interface.outputs, + ) + return _literal_models.LiteralMap(literals=wf_outputs_as_literal_dict) if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION: return self.compile_into_workflow(ctx, task_function, **kwargs) diff --git a/flytekit/extras/sklearn/__init__.py b/flytekit/extras/sklearn/__init__.py new file mode 100644 index 0000000000..0a1bf2dda5 --- /dev/null +++ b/flytekit/extras/sklearn/__init__.py @@ -0,0 +1,26 @@ +""" +Flytekit Sklearn +========================================= +.. currentmodule:: flytekit.extras.sklearn + +.. 
+.. autosummary::
+   :template: custom.rst
+   :toctree: generated/
+"""
+from flytekit.loggers import logger
+
+# TODO: abstract this out so that there's an established pattern for registering plugins
+# that have soft dependencies
+try:
+    # isolate the exception to the sklearn import
+    import sklearn
+
+    _sklearn_installed = True
+except (ImportError, OSError):
+    _sklearn_installed = False
+
+
+if _sklearn_installed:
+    from .native import SklearnEstimatorTransformer
+else:
+    logger.info("We won't register SklearnEstimatorTransformer because scikit-learn is not installed.")
diff --git a/flytekit/extras/sklearn/native.py b/flytekit/extras/sklearn/native.py
new file mode 100644
index 0000000000..59ecca70c5
--- /dev/null
+++ b/flytekit/extras/sklearn/native.py
@@ -0,0 +1,79 @@
+import pathlib
+from typing import Generic, Type, TypeVar
+
+import joblib
+import sklearn
+
+from flytekit.core.context_manager import FlyteContext
+from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
+from flytekit.models.core import types as _core_types
+from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
+from flytekit.models.types import LiteralType
+
+T = TypeVar("T")
+
+
+class SklearnTypeTransformer(TypeTransformer, Generic[T]):
+    def get_literal_type(self, t: Type[T]) -> LiteralType:
+        return LiteralType(
+            blob=_core_types.BlobType(
+                format=self.SKLEARN_FORMAT,
+                dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
+            )
+        )
+
+    def to_literal(
+        self,
+        ctx: FlyteContext,
+        python_val: T,
+        python_type: Type[T],
+        expected: LiteralType,
+    ) -> Literal:
+        meta = BlobMetadata(
+            type=_core_types.BlobType(
+                format=self.SKLEARN_FORMAT,
+                dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
+            )
+        )
+
+        local_path = ctx.file_access.get_random_local_path() + ".joblib"
+        pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)
+
+        # save sklearn estimator to a file
+        joblib.dump(python_val, local_path)
+
+        remote_path = ctx.file_access.get_random_remote_path(local_path)
+        ctx.file_access.put_data(local_path, remote_path, is_multipart=False)
+        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))
+
+    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
+        try:
+            uri = lv.scalar.blob.uri
+        except AttributeError:
+            raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")
+
+        local_path = ctx.file_access.get_random_local_path()
+        ctx.file_access.get_data(uri, local_path, is_multipart=False)
+
+        # load sklearn estimator from a file
+        return joblib.load(local_path)
+
+    def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
+        if (
+            literal_type.blob is not None
+            and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE
+            and literal_type.blob.format == self.SKLEARN_FORMAT
+        ):
+            return T
+
+        raise ValueError(f"Transformer {self} cannot reverse {literal_type}")
+
+
+class SklearnEstimatorTransformer(SklearnTypeTransformer[sklearn.base.BaseEstimator]):
+    SKLEARN_FORMAT = "SklearnEstimator"
+
+    def __init__(self):
+        super().__init__(name="Sklearn Estimator", t=sklearn.base.BaseEstimator)
+
+
+TypeEngine.register(SklearnEstimatorTransformer())
diff --git a/plugins/flytekit-dbt/flytekitplugins/dbt/schema.py b/plugins/flytekit-dbt/flytekitplugins/dbt/schema.py
index 2d535f1fb7..d680559c07 100644
--- a/plugins/flytekit-dbt/flytekitplugins/dbt/schema.py
+++ b/plugins/flytekit-dbt/flytekitplugins/dbt/schema.py
@@ -251,11 +251,8 @@ class DBTFreshnessOutput(BaseDBTOutput):

     Attributes
     ----------
-    raw_run_result : str
-        Raw value of DBT's ``run_result.json``.
-    raw_manifest : str
-        Raw value of DBT's ``manifest.json``.
+    raw_sources : str
+        Raw value of DBT's ``sources.json``.
     """

-    raw_run_result: str
-    raw_manifest: str
+    raw_sources: str
\ No newline at end of file
diff --git a/plugins/flytekit-dbt/flytekitplugins/dbt/task.py b/plugins/flytekit-dbt/flytekitplugins/dbt/task.py
index d78fb56c42..8a0621cb80 100644
--- a/plugins/flytekit-dbt/flytekitplugins/dbt/task.py
+++ b/plugins/flytekit-dbt/flytekitplugins/dbt/task.py
@@ -354,18 +354,13 @@ def my_workflow() -> DBTFreshnessOutput:
             raise DBTUnhandledError(f"unhandled error while executing {full_command}", logs)

         output_dir = os.path.join(task_input.project_dir, task_input.output_path)
-        run_result_path = os.path.join(output_dir, "run_results.json")
-        with open(run_result_path) as file:
-            run_result = file.read()
+        sources_path = os.path.join(output_dir, "sources.json")
+        with open(sources_path) as file:
+            sources = file.read()

-        # read manifest.json
-        manifest_path = os.path.join(output_dir, "manifest.json")
-        with open(manifest_path) as file:
-            manifest = file.read()

         return DBTFreshnessOutput(
             command=full_command,
             exit_code=exit_code,
-            raw_run_result=run_result,
-            raw_manifest=manifest,
+            raw_sources=sources,
         )
diff --git a/plugins/flytekit-dbt/tests/test_task.py b/plugins/flytekit-dbt/tests/test_task.py
index fa628e3550..1c91008552 100644
--- a/plugins/flytekit-dbt/tests/test_task.py
+++ b/plugins/flytekit-dbt/tests/test_task.py
@@ -219,10 +219,7 @@ def test_task_output(self):
             == f"dbt --log-format json source freshness --project-dir {DBT_PROJECT_DIR} --profiles-dir {DBT_PROFILES_DIR} --profile {DBT_PROFILE}"
         )

-        with open(f"{DBT_PROJECT_DIR}/target/run_results.json", "r") as fp:
-            exp_run_result = fp.read()
-        assert output.raw_run_result == exp_run_result
+        with open(f"{DBT_PROJECT_DIR}/target/sources.json", "r") as fp:
+            exp_sources = fp.read()

-        with open(f"{DBT_PROJECT_DIR}/target/manifest.json", "r") as fp:
-            exp_manifest = fp.read()
-        assert output.raw_manifest == exp_manifest
+        assert output.raw_sources == exp_sources
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 6690a67f85..e4d3fbcdc0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -72,10 +72,6 @@ importlib-metadata==5.0.0
     #   keyring
 jaraco-classes==3.2.3
     # via keyring
-jeepney==0.8.0
-    # via
-    #   keyring
-    #   secretstorage
 jinja2==3.1.2
     # via
     #   cookiecutter
@@ -173,8 +169,6 @@ responses==0.22.0
     # via flytekit
 retry==0.9.2
     # via flytekit
-secretstorage==3.3.3
-    # via keyring
 singledispatchmethod==1.0
     # via flytekit
 six==1.16.0
diff --git a/tests/flytekit/unit/core/test_dynamic.py b/tests/flytekit/unit/core/test_dynamic.py
index d19f47ff75..cccf406c71 100644
--- a/tests/flytekit/unit/core/test_dynamic.py
+++ b/tests/flytekit/unit/core/test_dynamic.py
@@ -8,9 +8,11 @@
 from flytekit.core import context_manager
 from flytekit.core.context_manager import ExecutionState
 from flytekit.core.node_creation import create_node
+from flytekit.core.resources import Resources
 from flytekit.core.task import task
 from flytekit.core.type_engine import TypeEngine
 from flytekit.core.workflow import workflow
+from flytekit.models.literals import LiteralMap

 settings = flytekit.configuration.SerializationSettings(
     project="test_proj",
@@ -155,6 +157,45 @@ def wf() -> str:
     assert wf() == "hello"


+def test_dynamic_local_rshift():
+    @task
+    def task1(s: str) -> str:
+        return s
+
+    @task
+    def task2(s: str) -> str:
+        return s
+
+    @dynamic
+    def dynamic_wf() -> str:
+        to1 = task1(s="hello").with_overrides(requests=Resources(cpu="3", mem="5Gi"))
+        to2 = task2(s="world")
+        to1 >> to2  # noqa
+
+        return to1
+
+    @workflow
+    def wf() -> str:
+        return dynamic_wf()
+
+    assert wf() == "hello"
+
+    with context_manager.FlyteContextManager.with_context(
+        context_manager.FlyteContextManager.current_context().with_serialization_settings(settings)
+    ) as ctx:
+        with context_manager.FlyteContextManager.with_context(
+            ctx.with_execution_state(
+                ctx.execution_state.with_params(
+                    mode=ExecutionState.Mode.TASK_EXECUTION,
+                )
+            )
+        ) as ctx:
+            dynamic_job_spec = dynamic_wf.dispatch_execute(ctx, LiteralMap(literals={}))
+            assert dynamic_job_spec.nodes[1].upstream_node_ids == ["dn0"]
+            assert dynamic_job_spec.nodes[0].task_node.overrides.resources.requests[0].value == "3"
+            assert dynamic_job_spec.nodes[0].task_node.overrides.resources.requests[1].value == "5Gi"
+
+
 def test_dynamic_return_dict():
     @dynamic
     def t1(v: str) -> typing.Dict[str, str]:
@@ -175,3 +216,44 @@ def wf():
         t3(v="c")

     wf()
+
+
+def test_nested_dynamic_locals():
+    @task
+    def t1(a: int) -> str:
+        a = a + 2
+        return "fast-" + str(a)
+
+    @task
+    def t2(b: str) -> str:
+        return f"In t2 string is {b}"
+
+    @task
+    def t3(b: str) -> str:
+        return f"In t3 string is {b}"
+
+    @workflow()
+    def normalwf(a: int) -> str:
+        x = t1(a=a)
+        return x
+
+    @dynamic
+    def dt(ss: str) -> typing.List[str]:
+        if ss == "hello":
+            bb = t2(b=ss)
+            bbb = t3(b=bb)
+        else:
+            bb = t2(b=ss + "hi again")
+            bbb = "static"
+        return [bb, bbb]
+
+    @workflow
+    def wf(wf_in: str) -> typing.List[str]:
+        x = dt(ss=wf_in)
+        return x
+
+    res = wf(wf_in="hello")
+    assert res == ["In t2 string is hello", "In t3 string is In t2 string is hello"]
+
+    res = dt(ss="hello")
+    assert res == ["In t2 string is hello", "In t3 string is In t2 string is hello"]
diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py
index bb1773ec5c..47c8af9830 100644
--- a/tests/flytekit/unit/core/test_node_creation.py
+++ b/tests/flytekit/unit/core/test_node_creation.py
@@ -87,6 +87,7 @@ def empty_wf2():
     wf_spec = get_serializable(OrderedDict(), serialization_settings, empty_wf2)
     assert wf_spec.template.nodes[0].upstream_node_ids[0] == "n1"
     assert wf_spec.template.nodes[0].id == "n0"
+    assert wf_spec.template.nodes[0].metadata.name == "t2"

     with pytest.raises(FlyteAssertion):

@@ -334,7 +335,8 @@ def my_wf(a: str) -> str:


 @pytest.mark.parametrize(
-    "retries,expected", [(None, _literal_models.RetryStrategy(0)), (3, _literal_models.RetryStrategy(3))]
+    "retries,expected",
+    [(None, _literal_models.RetryStrategy(0)), (3, _literal_models.RetryStrategy(3))],
 )
 def test_retries_override(retries, expected):
     @task
@@ -401,3 +403,24 @@ def my_wf(a: str):
         _resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "1"),
         _resources_models.ResourceEntry(_resources_models.ResourceName.MEMORY, "100"),
     ]
+
+
+def test_name_override():
+    @task
+    def t1(a: str) -> str:
+        return f"*~*~*~{a}*~*~*~"
+
+    @workflow
+    def my_wf(a: str) -> str:
+        return t1(a=a).with_overrides(name="foo")
+
+    serialization_settings = flytekit.configuration.SerializationSettings(
+        project="test_proj",
+        domain="test_domain",
+        version="abc",
+        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
+        env={},
+    )
+    wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
+    assert len(wf_spec.template.nodes) == 1
+    assert wf_spec.template.nodes[0].metadata.name == "foo"
diff --git a/tests/flytekit/unit/extras/sklearn/__init__.py b/tests/flytekit/unit/extras/sklearn/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/flytekit/unit/extras/sklearn/test_native.py b/tests/flytekit/unit/extras/sklearn/test_native.py
new file mode 100644
index 0000000000..7d704597c1
--- /dev/null
+++ b/tests/flytekit/unit/extras/sklearn/test_native.py
@@ -0,0 +1,48 @@
+import numpy as np
+from sklearn.linear_model import LinearRegression
+from sklearn.pipeline import Pipeline
+from sklearn.preprocessing import StandardScaler
+
+from flytekit import task, workflow
+
+
+@task
+def get_preprocessor() -> StandardScaler:
+    return StandardScaler()
+
+
+@task
+def get_model() -> LinearRegression:
+    return LinearRegression()
+
+
+@task
+def make_pipeline(preprocessor: StandardScaler, model: LinearRegression) -> Pipeline:
+    return Pipeline([("scaler", preprocessor), ("model", model)])
+
+
+@task
+def fit_pipeline(pipeline: Pipeline) -> Pipeline:
+    x = np.random.normal(size=(10, 2))
+    y = np.random.randint(2, size=(10,))
+    pipeline.fit(x, y)
+    return pipeline
+
+
+@task
+def num_features(pipeline: Pipeline) -> int:
+    return pipeline.n_features_in_
+
+
+@workflow
+def wf():
+    preprocessor = get_preprocessor()
+    model = get_model()
+    pipeline = make_pipeline(preprocessor=preprocessor, model=model)
+    pipeline = fit_pipeline(pipeline=pipeline)
+    num_features(pipeline=pipeline)
+
+
+@workflow
+def test_wf():
+    wf()
diff --git a/tests/flytekit/unit/extras/sklearn/test_transformations.py b/tests/flytekit/unit/extras/sklearn/test_transformations.py
new file mode 100644
index 0000000000..39343f9180
--- /dev/null
+++ b/tests/flytekit/unit/extras/sklearn/test_transformations.py
@@ -0,0 +1,96 @@
+from collections import OrderedDict
+from functools import partial
+
+import numpy as np
+import pytest
+from sklearn.base import BaseEstimator
+from sklearn.linear_model import LinearRegression
+from sklearn.svm import SVC
+
+import flytekit
+from flytekit import task
+from flytekit.configuration import Image, ImageConfig
+from flytekit.core import context_manager
+from flytekit.extras.sklearn import SklearnEstimatorTransformer
+from flytekit.models.core.types import BlobType
+from flytekit.models.literals import BlobMetadata
+from flytekit.models.types import LiteralType
+from flytekit.tools.translator import get_serializable
+
+default_img = Image(name="default", fqn="test", tag="tag")
+serialization_settings = flytekit.configuration.SerializationSettings(
+    project="project",
+    domain="domain",
+    version="version",
+    env=None,
+    image_config=ImageConfig(default_image=default_img, images=[default_img]),
+)
+
+
+def get_model(model_type: str) -> BaseEstimator:
+    models_map = {
+        "lr": LinearRegression,
+        "svc": partial(SVC, kernel="linear"),
+    }
+    x = np.random.normal(size=(10, 2))
+    y = np.random.randint(2, size=(10,))
+    model = models_map[model_type]()
+    model.fit(x, y)
+    return model
+
+
+@pytest.mark.parametrize(
+    "transformer,python_type,format",
+    [
+        (SklearnEstimatorTransformer(), BaseEstimator, SklearnEstimatorTransformer.SKLEARN_FORMAT),
+    ],
+)
+def test_get_literal_type(transformer, python_type, format):
+    tf = transformer
+    lt = tf.get_literal_type(python_type)
+    assert lt == LiteralType(blob=BlobType(format=format, dimensionality=BlobType.BlobDimensionality.SINGLE))
+
+
+@pytest.mark.parametrize(
+    "transformer,python_type,format,python_val",
+    [
+        (
+            SklearnEstimatorTransformer(),
+            BaseEstimator,
+            SklearnEstimatorTransformer.SKLEARN_FORMAT,
+            get_model("lr"),
+        ),
+        (
+            SklearnEstimatorTransformer(),
+            BaseEstimator,
+            SklearnEstimatorTransformer.SKLEARN_FORMAT,
+            get_model("svc"),
+        ),
+    ],
+)
+def test_to_python_value_and_literal(transformer, python_type, format, python_val):
+    ctx = context_manager.FlyteContext.current_context()
+    tf = transformer
+    lt = tf.get_literal_type(python_type)
+
+    lv = tf.to_literal(ctx, python_val, type(python_val), lt)  # type: ignore
+    assert lv.scalar.blob.metadata == BlobMetadata(
+        type=BlobType(
+            format=format,
+            dimensionality=BlobType.BlobDimensionality.SINGLE,
+        )
+    )
+    assert lv.scalar.blob.uri is not None
+
+    output = tf.to_python_value(ctx, lv, python_type)
+
+    np.testing.assert_array_equal(output.coef_, python_val.coef_)
+
+
+def test_example_estimator():
+    @task
+    def t1() -> BaseEstimator:
+        return get_model("lr")
+
+    task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
+    assert task_spec.template.interface.outputs["o0"].type.blob.format == SklearnEstimatorTransformer.SKLEARN_FORMAT
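
A minimal usage sketch of what this patch enables, outside the diff itself. It assumes a flytekit install with scikit-learn available; flytekit.extras.sklearn is imported explicitly to register SklearnEstimatorTransformer, and the task names and the node name "train-linear-model" below are illustrative, not taken from the patch.

from sklearn.linear_model import LinearRegression

import flytekit.extras.sklearn  # noqa: F401 -- registers SklearnEstimatorTransformer
from flytekit import dynamic, task, workflow


@task
def train() -> LinearRegression:
    # The fitted estimator is serialized to a single blob via joblib by the
    # new SklearnTypeTransformer when it crosses the task boundary.
    model = LinearRegression()
    model.fit([[0.0], [1.0]], [0.0, 1.0])
    return model


@dynamic
def check(model: LinearRegression) -> int:
    # With this patch, a locally executed dynamic workflow goes through the
    # same literal-map path as a runtime execution, instead of being invoked
    # as a plain Python function.
    return model.n_features_in_


@workflow
def wf() -> int:
    # with_overrides(name=...) is the node-name override added to
    # flytekit/core/node.py in this patch.
    model = train().with_overrides(name="train-linear-model")
    return check(model=model)


if __name__ == "__main__":
    # Local execution: the training data above has exactly one feature column.
    assert wf() == 1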