Skip to content

Commit

Permalink
Adds flytekitplugin.wandb (#2405)
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas J. Fan <[email protected]>
  • Loading branch information
thomasjpfan authored May 16, 2024
1 parent edab1e3 commit 70332db
Show file tree
Hide file tree
Showing 5 changed files with 319 additions and 0 deletions.
89 changes: 89 additions & 0 deletions plugins/flytekit-wandb/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Flytekit Weights and Biases Plugin

The Weights and Biases MLOps platform helps AI developers streamline their ML workflow from end-to-end. This plugin
enables seamless use of Weights and Biases within Flyte by configuring links between the two platforms.

To install the plugin, run:

```bash
pip install flytekitplugins-wandb
```

Here is an example of running W&B with XGBoost using W&B for tracking:

```python
from flytekit import task, Secret, ImageSpec, workflow

from flytekitplugins.wandb import wandb_init

WANDB_PROJECT = "flytekit-wandb-plugin"
WANDB_ENTITY = "github-username"
WANDB_SECRET_KEY = "wandb-api-key"
WANDB_SECRET_GROUP = "wandb-api-group"
REGISTRY = "localhost:30000"

image = ImageSpec(
name="wandb_example",
python_version="3.11",
packages=["flytekitplugins-wandb", "xgboost", "scikit-learn"],
registry=REGISTRY,
)
wandb_secret = Secret(key=WANDB_SECRET_KEY, group=WANDB_SECRET_GROUP)


@task(
container_image=image,
secret_requests=[wandb_secret],
)
@wandb_init(
project=WANDB_PROJECT,
entity=WANDB_ENTITY,
secret=wandb_secret,
)
def train() -> float:
from xgboost import XGBClassifier
from wandb.integration.xgboost import WandbCallback
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

import wandb

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
bst = XGBClassifier(
n_estimators=100,
objective="binary:logistic",
callbacks=[WandbCallback(log_model=True)],
)
bst.fit(X_train, y_train)

test_score = bst.score(X_test, y_test)

# Log custom metrics
wandb.run.log({"test_score": test_score})
return test_score


@workflow
def main() -> float:
return train()
```

Weights and Biases requires an API key to authenticate with their service. In the above example,
the secret is created using
[Flyte's Secrets manager](https://docs.flyte.org/en/latest/user_guide/productionizing/secrets.html).

To enable linking from the Flyte side panel to Weights and Biases, add the following to Flyte's
configuration

```yaml
plugins:
logs:
dynamic-log-links:
- wandb-execution-id:
displayName: Weights & Biases
templateUris: '{{ .taskConfig.host }}/{{ .taskConfig.entity }}/{{ .taskConfig.project }}/runs/{{ .executionName }}-{{ .nodeId }}-{{ .taskRetryAttempt }}'
- wandb-custom-id:
displayName: Weights & Biases
templateUris: '{{ .taskConfig.host }}/{{ .taskConfig.entity }}/{{ .taskConfig.project }}/runs/{{ .taskConfig.id }}'
```
3 changes: 3 additions & 0 deletions plugins/flytekit-wandb/flytekitplugins/wandb/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .tracking import wandb_init

__all__ = ["wandb_init"]
104 changes: 104 additions & 0 deletions plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import os
from typing import Callable, Optional

import wandb
from flytekit import Secret
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.utils import ClassDecorator

WANDB_EXECUTION_TYPE_VALUE = "wandb-execution-id"
WANDB_CUSTOM_TYPE_VALUE = "wandb-custom-id"


class wandb_init(ClassDecorator):
WANDB_PROJECT_KEY = "project"
WANDB_ENTITY_KEY = "entity"
WANDB_ID_KEY = "id"
WANDB_HOST_KEY = "host"

def __init__(
self,
task_function: Optional[Callable] = None,
project: Optional[str] = None,
entity: Optional[str] = None,
secret: Optional[Secret] = None,
id: Optional[str] = None,
host: str = "https://wandb.ai",
**init_kwargs: dict,
):
"""Weights and Biases plugin.
Args:
task_function (function, optional): The user function to be decorated. Defaults to None.
project (str): The name of the project where you're sending the new run. (Required)
entity (str): An entity is a username or team name where you're sending runs. (Required)
secret (Secret): Secret with your `WANDB_API_KEY`. (Required)
id (str, optional): A unique id for this wandb run.
host (str, optional): URL to your wandb service. The default is "https://wandb.ai".
**init_kwargs (dict): The rest of the arguments are passed directly to `wandb.init`. Please see
[the `wandb.init` docs](https://docs.wandb.ai/ref/python/init) for details.
"""
if project is None:
raise ValueError("project must be set")
if entity is None:
raise ValueError("entity must be set")
if secret is None:
raise ValueError("secret must be set")

self.project = project
self.entity = entity
self.id = id
self.init_kwargs = init_kwargs
self.secret = secret
self.host = host

# All kwargs need to be passed up so that the function wrapping works for both
# `@wandb_init` and `@wandb_init(...)`
super().__init__(
task_function,
project=project,
entity=entity,
secret=secret,
id=id,
host=host,
**init_kwargs,
)

def execute(self, *args, **kwargs):
ctx = FlyteContextManager.current_context()
is_local_execution = ctx.execution_state.is_local_execution()

if is_local_execution:
# For location execution, always use the id. If `self.id` is `None`, wandb
# will generate it's own id.
wand_id = self.id
else:
# Set secret for remote execution
secrets = ctx.user_space_params.secrets
os.environ["WANDB_API_KEY"] = secrets.get(key=self.secret.key, group=self.secret.group)
if self.id is None:
# The HOSTNAME is set to {.executionName}-{.nodeID}-{.taskRetryAttempt}
# If HOSTNAME is not defined, use the execution name as a fallback
wand_id = os.environ.get("HOSTNAME", ctx.user_space_params.execution_id.name)
else:
wand_id = self.id

wandb.init(project=self.project, entity=self.entity, id=wand_id, **self.init_kwargs)
output = self.task_function(*args, **kwargs)
wandb.finish()
return output

def get_extra_config(self):
extra_config = {
self.WANDB_PROJECT_KEY: self.project,
self.WANDB_ENTITY_KEY: self.entity,
self.WANDB_HOST_KEY: self.host,
}

if self.id is None:
wandb_value = WANDB_EXECUTION_TYPE_VALUE
else:
wandb_value = WANDB_CUSTOM_TYPE_VALUE
extra_config[self.WANDB_ID_KEY] = self.id

extra_config[self.LINK_TYPE_KEY] = wandb_value
return extra_config
37 changes: 37 additions & 0 deletions plugins/flytekit-wandb/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from setuptools import setup

PLUGIN_NAME = "wandb"

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=1.12.0,<2.0.0", "wandb>=0.17.0"]

__version__ = "0.0.0+develop"

setup(
name=microlib_name,
version=__version__,
author="flyteorg",
author_email="[email protected]",
description="This package enables seamless use of Weights & Biases within Flyte",
namespace_packages=["flytekitplugins"],
packages=[f"flytekitplugins.{PLUGIN_NAME}"],
install_requires=plugin_requires,
license="apache2",
python_requires=">=3.8",
classifiers=[
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
)
86 changes: 86 additions & 0 deletions plugins/flytekit-wandb/tests/test_wandb_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from unittest.mock import Mock, patch

import pytest
from flytekitplugins.wandb import wandb_init
from flytekitplugins.wandb.tracking import WANDB_CUSTOM_TYPE_VALUE, WANDB_EXECUTION_TYPE_VALUE

from flytekit import task


@pytest.mark.parametrize("id", [None, "abc123"])
def test_wandb_extra_config(id):
wandb_decorator = wandb_init(
project="abc",
entity="xyz",
secret_key="my-secret-key",
id=id,
host="https://my_org.wandb.org",
)

extra_config = wandb_decorator.get_extra_config()

if id is None:
assert extra_config[wandb_decorator.LINK_TYPE_KEY] == WANDB_EXECUTION_TYPE_VALUE
assert wandb_decorator.WANDB_ID_KEY not in extra_config
else:
assert extra_config[wandb_decorator.LINK_TYPE_KEY] == WANDB_CUSTOM_TYPE_VALUE
assert extra_config[wandb_decorator.WANDB_ID_KEY] == id
assert extra_config[wandb_decorator.WANDB_HOST_KEY] == "https://my_org.wandb.org"


@task
@wandb_init(project="abc", entity="xyz", secret_key="my-secret-key", secret_group="my-secret-group", tags=["my_tag"])
def train_model():
pass


@patch("flytekitplugins.wandb.tracking.wandb")
def test_local_execution(wandb_mock):
train_model()

wandb_mock.init.assert_called_with(project="abc", entity="xyz", id=None, tags=["my_tag"])


@task
@wandb_init(project="abc", entity="xyz", secret_key="my-secret-key", tags=["my_tag"], id="1234")
def train_model_with_id():
pass


@patch("flytekitplugins.wandb.tracking.wandb")
def test_local_execution_with_id(wandb_mock):
train_model_with_id()

wandb_mock.init.assert_called_with(project="abc", entity="xyz", id="1234", tags=["my_tag"])


@patch("flytekitplugins.wandb.tracking.os")
@patch("flytekitplugins.wandb.tracking.FlyteContextManager")
@patch("flytekitplugins.wandb.tracking.wandb")
def test_non_local_execution(wandb_mock, manager_mock, os_mock):
# Pretend that the execution is remote
ctx_mock = Mock()
ctx_mock.execution_state.is_local_execution.return_value = False

ctx_mock.user_space_params.secrets.get.return_value = "this_is_the_secret"
ctx_mock.user_space_params.execution_id.name = "my_execution_id"

manager_mock.current_context.return_value = ctx_mock
os_mock.environ = {}

train_model()

wandb_mock.init.assert_called_with(project="abc", entity="xyz", id="my_execution_id", tags=["my_tag"])
ctx_mock.user_space_params.secrets.get.assert_called_with(key="my-secret-key", group="my-secret-group")
assert os_mock.environ["WANDB_API_KEY"] == "this_is_the_secret"


def test_errors():
with pytest.raises(ValueError, match="project must be set"):
wandb_init()

with pytest.raises(ValueError, match="entity must be set"):
wandb_init(project="abc")

with pytest.raises(ValueError, match="secret_key must be set"):
wandb_init(project="abc", entity="xyz")

0 comments on commit 70332db

Please sign in to comment.