diff --git a/plugins/flytekit-wandb/README.md b/plugins/flytekit-wandb/README.md new file mode 100644 index 0000000000..7c3984055f --- /dev/null +++ b/plugins/flytekit-wandb/README.md @@ -0,0 +1,89 @@ +# Flytekit Weights and Biases Plugin + +The Weights and Biases MLOps platform helps AI developers streamline their ML workflow from end-to-end. This plugin +enables seamless use of Weights and Biases within Flyte by configuring links between the two platforms. + +To install the plugin, run: + +```bash +pip install flytekitplugins-wandb +``` + +Here is an example of running W&B with XGBoost using W&B for tracking: + +```python +from flytekit import task, Secret, ImageSpec, workflow + +from flytekitplugins.wandb import wandb_init + +WANDB_PROJECT = "flytekit-wandb-plugin" +WANDB_ENTITY = "github-username" +WANDB_SECRET_KEY = "wandb-api-key" +WANDB_SECRET_GROUP = "wandb-api-group" +REGISTRY = "localhost:30000" + +image = ImageSpec( + name="wandb_example", + python_version="3.11", + packages=["flytekitplugins-wandb", "xgboost", "scikit-learn"], + registry=REGISTRY, +) +wandb_secret = Secret(key=WANDB_SECRET_KEY, group=WANDB_SECRET_GROUP) + + +@task( + container_image=image, + secret_requests=[wandb_secret], +) +@wandb_init( + project=WANDB_PROJECT, + entity=WANDB_ENTITY, + secret=wandb_secret, +) +def train() -> float: + from xgboost import XGBClassifier + from wandb.integration.xgboost import WandbCallback + from sklearn.datasets import load_iris + from sklearn.model_selection import train_test_split + + import wandb + + X, y = load_iris(return_X_y=True) + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2) + bst = XGBClassifier( + n_estimators=100, + objective="binary:logistic", + callbacks=[WandbCallback(log_model=True)], + ) + bst.fit(X_train, y_train) + + test_score = bst.score(X_test, y_test) + + # Log custom metrics + wandb.run.log({"test_score": test_score}) + return test_score + + +@workflow +def main() -> float: + return train() +``` + +Weights and Biases requires an API key to authenticate with their service. In the above example, +the secret is created using +[Flyte's Secrets manager](https://docs.flyte.org/en/latest/user_guide/productionizing/secrets.html). + +To enable linking from the Flyte side panel to Weights and Biases, add the following to Flyte's +configuration + +```yaml +plugins: + logs: + dynamic-log-links: + - wandb-execution-id: + displayName: Weights & Biases + templateUris: '{{ .taskConfig.host }}/{{ .taskConfig.entity }}/{{ .taskConfig.project }}/runs/{{ .executionName }}-{{ .nodeId }}-{{ .taskRetryAttempt }}' + - wandb-custom-id: + displayName: Weights & Biases + templateUris: '{{ .taskConfig.host }}/{{ .taskConfig.entity }}/{{ .taskConfig.project }}/runs/{{ .taskConfig.id }}' +``` diff --git a/plugins/flytekit-wandb/flytekitplugins/wandb/__init__.py b/plugins/flytekit-wandb/flytekitplugins/wandb/__init__.py new file mode 100644 index 0000000000..329f90d40f --- /dev/null +++ b/plugins/flytekit-wandb/flytekitplugins/wandb/__init__.py @@ -0,0 +1,3 @@ +from .tracking import wandb_init + +__all__ = ["wandb_init"] diff --git a/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py b/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py new file mode 100644 index 0000000000..216bf176c6 --- /dev/null +++ b/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py @@ -0,0 +1,104 @@ +import os +from typing import Callable, Optional + +import wandb +from flytekit import Secret +from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.utils import ClassDecorator + +WANDB_EXECUTION_TYPE_VALUE = "wandb-execution-id" +WANDB_CUSTOM_TYPE_VALUE = "wandb-custom-id" + + +class wandb_init(ClassDecorator): + WANDB_PROJECT_KEY = "project" + WANDB_ENTITY_KEY = "entity" + WANDB_ID_KEY = "id" + WANDB_HOST_KEY = "host" + + def __init__( + self, + task_function: Optional[Callable] = None, + project: Optional[str] = None, + entity: Optional[str] = None, + secret: Optional[Secret] = None, + id: Optional[str] = None, + host: str = "https://wandb.ai", + **init_kwargs: dict, + ): + """Weights and Biases plugin. + Args: + task_function (function, optional): The user function to be decorated. Defaults to None. + project (str): The name of the project where you're sending the new run. (Required) + entity (str): An entity is a username or team name where you're sending runs. (Required) + secret (Secret): Secret with your `WANDB_API_KEY`. (Required) + id (str, optional): A unique id for this wandb run. + host (str, optional): URL to your wandb service. The default is "https://wandb.ai". + **init_kwargs (dict): The rest of the arguments are passed directly to `wandb.init`. Please see + [the `wandb.init` docs](https://docs.wandb.ai/ref/python/init) for details. + """ + if project is None: + raise ValueError("project must be set") + if entity is None: + raise ValueError("entity must be set") + if secret is None: + raise ValueError("secret must be set") + + self.project = project + self.entity = entity + self.id = id + self.init_kwargs = init_kwargs + self.secret = secret + self.host = host + + # All kwargs need to be passed up so that the function wrapping works for both + # `@wandb_init` and `@wandb_init(...)` + super().__init__( + task_function, + project=project, + entity=entity, + secret=secret, + id=id, + host=host, + **init_kwargs, + ) + + def execute(self, *args, **kwargs): + ctx = FlyteContextManager.current_context() + is_local_execution = ctx.execution_state.is_local_execution() + + if is_local_execution: + # For location execution, always use the id. If `self.id` is `None`, wandb + # will generate it's own id. + wand_id = self.id + else: + # Set secret for remote execution + secrets = ctx.user_space_params.secrets + os.environ["WANDB_API_KEY"] = secrets.get(key=self.secret.key, group=self.secret.group) + if self.id is None: + # The HOSTNAME is set to {.executionName}-{.nodeID}-{.taskRetryAttempt} + # If HOSTNAME is not defined, use the execution name as a fallback + wand_id = os.environ.get("HOSTNAME", ctx.user_space_params.execution_id.name) + else: + wand_id = self.id + + wandb.init(project=self.project, entity=self.entity, id=wand_id, **self.init_kwargs) + output = self.task_function(*args, **kwargs) + wandb.finish() + return output + + def get_extra_config(self): + extra_config = { + self.WANDB_PROJECT_KEY: self.project, + self.WANDB_ENTITY_KEY: self.entity, + self.WANDB_HOST_KEY: self.host, + } + + if self.id is None: + wandb_value = WANDB_EXECUTION_TYPE_VALUE + else: + wandb_value = WANDB_CUSTOM_TYPE_VALUE + extra_config[self.WANDB_ID_KEY] = self.id + + extra_config[self.LINK_TYPE_KEY] = wandb_value + return extra_config diff --git a/plugins/flytekit-wandb/setup.py b/plugins/flytekit-wandb/setup.py new file mode 100644 index 0000000000..6c41c28cfd --- /dev/null +++ b/plugins/flytekit-wandb/setup.py @@ -0,0 +1,37 @@ +from setuptools import setup + +PLUGIN_NAME = "wandb" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.12.0,<2.0.0", "wandb>=0.17.0"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="This package enables seamless use of Weights & Biases within Flyte", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.8", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-wandb/tests/test_wandb_init.py b/plugins/flytekit-wandb/tests/test_wandb_init.py new file mode 100644 index 0000000000..67d866b5c1 --- /dev/null +++ b/plugins/flytekit-wandb/tests/test_wandb_init.py @@ -0,0 +1,86 @@ +from unittest.mock import Mock, patch + +import pytest +from flytekitplugins.wandb import wandb_init +from flytekitplugins.wandb.tracking import WANDB_CUSTOM_TYPE_VALUE, WANDB_EXECUTION_TYPE_VALUE + +from flytekit import task + + +@pytest.mark.parametrize("id", [None, "abc123"]) +def test_wandb_extra_config(id): + wandb_decorator = wandb_init( + project="abc", + entity="xyz", + secret_key="my-secret-key", + id=id, + host="https://my_org.wandb.org", + ) + + extra_config = wandb_decorator.get_extra_config() + + if id is None: + assert extra_config[wandb_decorator.LINK_TYPE_KEY] == WANDB_EXECUTION_TYPE_VALUE + assert wandb_decorator.WANDB_ID_KEY not in extra_config + else: + assert extra_config[wandb_decorator.LINK_TYPE_KEY] == WANDB_CUSTOM_TYPE_VALUE + assert extra_config[wandb_decorator.WANDB_ID_KEY] == id + assert extra_config[wandb_decorator.WANDB_HOST_KEY] == "https://my_org.wandb.org" + + +@task +@wandb_init(project="abc", entity="xyz", secret_key="my-secret-key", secret_group="my-secret-group", tags=["my_tag"]) +def train_model(): + pass + + +@patch("flytekitplugins.wandb.tracking.wandb") +def test_local_execution(wandb_mock): + train_model() + + wandb_mock.init.assert_called_with(project="abc", entity="xyz", id=None, tags=["my_tag"]) + + +@task +@wandb_init(project="abc", entity="xyz", secret_key="my-secret-key", tags=["my_tag"], id="1234") +def train_model_with_id(): + pass + + +@patch("flytekitplugins.wandb.tracking.wandb") +def test_local_execution_with_id(wandb_mock): + train_model_with_id() + + wandb_mock.init.assert_called_with(project="abc", entity="xyz", id="1234", tags=["my_tag"]) + + +@patch("flytekitplugins.wandb.tracking.os") +@patch("flytekitplugins.wandb.tracking.FlyteContextManager") +@patch("flytekitplugins.wandb.tracking.wandb") +def test_non_local_execution(wandb_mock, manager_mock, os_mock): + # Pretend that the execution is remote + ctx_mock = Mock() + ctx_mock.execution_state.is_local_execution.return_value = False + + ctx_mock.user_space_params.secrets.get.return_value = "this_is_the_secret" + ctx_mock.user_space_params.execution_id.name = "my_execution_id" + + manager_mock.current_context.return_value = ctx_mock + os_mock.environ = {} + + train_model() + + wandb_mock.init.assert_called_with(project="abc", entity="xyz", id="my_execution_id", tags=["my_tag"]) + ctx_mock.user_space_params.secrets.get.assert_called_with(key="my-secret-key", group="my-secret-group") + assert os_mock.environ["WANDB_API_KEY"] == "this_is_the_secret" + + +def test_errors(): + with pytest.raises(ValueError, match="project must be set"): + wandb_init() + + with pytest.raises(ValueError, match="entity must be set"): + wandb_init(project="abc") + + with pytest.raises(ValueError, match="secret_key must be set"): + wandb_init(project="abc", entity="xyz")