-
Notifications
You must be signed in to change notification settings - Fork 301
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Thomas J. Fan <[email protected]>
- Loading branch information
1 parent
edab1e3
commit 70332db
Showing
5 changed files
with
319 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# Flytekit Weights and Biases Plugin | ||
|
||
The Weights and Biases MLOps platform helps AI developers streamline their ML workflow from end-to-end. This plugin | ||
enables seamless use of Weights and Biases within Flyte by configuring links between the two platforms. | ||
|
||
To install the plugin, run: | ||
|
||
```bash | ||
pip install flytekitplugins-wandb | ||
``` | ||
|
||
Here is an example of training an XGBoost model in a Flyte task, using W&B for tracking: | ||
|
||
```python | ||
from flytekit import task, Secret, ImageSpec, workflow | ||
|
||
from flytekitplugins.wandb import wandb_init | ||
|
||
# W&B project and entity (a username or team name) that receive the run.
WANDB_PROJECT = "flytekit-wandb-plugin"
WANDB_ENTITY = "github-username"
# Key and group of the Flyte secret that holds the W&B API key.
WANDB_SECRET_KEY = "wandb-api-key"
WANDB_SECRET_GROUP = "wandb-api-group"
# Container registry handed to ImageSpec below.
REGISTRY = "localhost:30000"

# Task image with the plugin and the example's ML dependencies installed.
image = ImageSpec(
    name="wandb_example",
    python_version="3.11",
    packages=["flytekitplugins-wandb", "xgboost", "scikit-learn"],
    registry=REGISTRY,
)
# The same Secret object is requested by the task and passed to `wandb_init`.
wandb_secret = Secret(key=WANDB_SECRET_KEY, group=WANDB_SECRET_GROUP)
|
||
|
||
@task(
    container_image=image,
    secret_requests=[wandb_secret],
)
@wandb_init(
    project=WANDB_PROJECT,
    entity=WANDB_ENTITY,
    secret=wandb_secret,
)
def train() -> float:
    """Train an XGBoost classifier on iris and log the run to W&B.

    `wandb_init` starts the W&B run before this function body executes, so
    `wandb.run` is available here for logging custom metrics.

    Returns:
        The classifier's score on the held-out 20% test split.
    """
    # Imported inside the task: these packages are installed in the task
    # image (see `image` above).
    from xgboost import XGBClassifier
    from wandb.integration.xgboost import WandbCallback
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split

    import wandb

    X, y = load_iris(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
    bst = XGBClassifier(
        n_estimators=100,
        # Iris has three classes, so use a multiclass objective rather than
        # "binary:logistic".
        objective="multi:softprob",
        callbacks=[WandbCallback(log_model=True)],
    )
    bst.fit(X_train, y_train)

    test_score = bst.score(X_test, y_test)

    # Log custom metrics
    wandb.run.log({"test_score": test_score})
    return test_score
|
||
|
||
@workflow
def main() -> float:
    """Run the `train` task and return its test-set score."""
    return train()
``` | ||
|
||
Weights and Biases requires an API key to authenticate with their service. In the above example, | ||
the secret is created using | ||
[Flyte's Secrets manager](https://docs.flyte.org/en/latest/user_guide/productionizing/secrets.html). | ||
|
||
To enable linking from the Flyte side panel to Weights and Biases, add the following to Flyte's | ||
configuration: | ||
|
||
```yaml | ||
plugins: | ||
logs: | ||
dynamic-log-links: | ||
- wandb-execution-id: | ||
displayName: Weights & Biases | ||
templateUris: '{{ .taskConfig.host }}/{{ .taskConfig.entity }}/{{ .taskConfig.project }}/runs/{{ .executionName }}-{{ .nodeId }}-{{ .taskRetryAttempt }}' | ||
- wandb-custom-id: | ||
displayName: Weights & Biases | ||
templateUris: '{{ .taskConfig.host }}/{{ .taskConfig.entity }}/{{ .taskConfig.project }}/runs/{{ .taskConfig.id }}' | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .tracking import wandb_init | ||
|
||
__all__ = ["wandb_init"] |
104 changes: 104 additions & 0 deletions
104
plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import os | ||
from typing import Callable, Optional | ||
|
||
import wandb | ||
from flytekit import Secret | ||
from flytekit.core.context_manager import FlyteContextManager | ||
from flytekit.core.utils import ClassDecorator | ||
|
||
# Link-type identifiers emitted by `wandb_init.get_extra_config`.
# "wandb-execution-id" is used when no explicit run id was given;
# "wandb-custom-id" is used when the caller supplied `id`.
WANDB_EXECUTION_TYPE_VALUE = "wandb-execution-id"
WANDB_CUSTOM_TYPE_VALUE = "wandb-custom-id"
|
||
|
||
class wandb_init(ClassDecorator):
    """Decorator that runs a task function inside a Weights & Biases run.

    A wandb run is started before the wrapped function executes and is
    finished afterwards. The run's project, entity, host and (optional) id
    are also exposed through `get_extra_config`.
    """

    # Keys used in the dictionary returned by `get_extra_config`.
    WANDB_PROJECT_KEY = "project"
    WANDB_ENTITY_KEY = "entity"
    WANDB_ID_KEY = "id"
    WANDB_HOST_KEY = "host"

    def __init__(
        self,
        task_function: Optional[Callable] = None,
        project: Optional[str] = None,
        entity: Optional[str] = None,
        secret: Optional[Secret] = None,
        id: Optional[str] = None,
        host: str = "https://wandb.ai",
        **init_kwargs: dict,
    ):
        """Weights and Biases plugin.

        Args:
            task_function (function, optional): The user function to be decorated. Defaults to None.
            project (str): The name of the project where you're sending the new run. (Required)
            entity (str): An entity is a username or team name where you're sending runs. (Required)
            secret (Secret): Secret with your `WANDB_API_KEY`. (Required)
            id (str, optional): A unique id for this wandb run.
            host (str, optional): URL to your wandb service. The default is "https://wandb.ai".
            **init_kwargs (dict): The rest of the arguments are passed directly to `wandb.init`. Please see
                [the `wandb.init` docs](https://docs.wandb.ai/ref/python/init) for details.

        Raises:
            ValueError: If `project`, `entity` or `secret` is not provided.
        """
        if project is None:
            raise ValueError("project must be set")
        if entity is None:
            raise ValueError("entity must be set")
        if secret is None:
            raise ValueError("secret must be set")

        self.project = project
        self.entity = entity
        self.id = id
        self.init_kwargs = init_kwargs
        self.secret = secret
        self.host = host

        # All kwargs need to be passed up so that the function wrapping works for both
        # `@wandb_init` and `@wandb_init(...)`
        super().__init__(
            task_function,
            project=project,
            entity=entity,
            secret=secret,
            id=id,
            host=host,
            **init_kwargs,
        )

    def execute(self, *args, **kwargs):
        """Start a wandb run, call the wrapped task function, then finish the run.

        For remote executions the API key is read from the Flyte secret and
        exported as `WANDB_API_KEY` before `wandb.init` is called. The run is
        always finished: normally on success, and with a non-zero exit code
        if the task function raises (the exception is re-raised).

        Returns:
            Whatever the wrapped task function returns.
        """
        ctx = FlyteContextManager.current_context()
        is_local_execution = ctx.execution_state.is_local_execution()

        if is_local_execution:
            # For local execution, always use the id. If `self.id` is `None`, wandb
            # will generate its own id.
            wandb_id = self.id
        else:
            # Set secret for remote execution
            secrets = ctx.user_space_params.secrets
            os.environ["WANDB_API_KEY"] = secrets.get(key=self.secret.key, group=self.secret.group)
            if self.id is None:
                # The HOSTNAME is set to {.executionName}-{.nodeID}-{.taskRetryAttempt}
                # If HOSTNAME is not defined, use the execution name as a fallback
                wandb_id = os.environ.get("HOSTNAME", ctx.user_space_params.execution_id.name)
            else:
                wandb_id = self.id

        wandb.init(project=self.project, entity=self.entity, id=wandb_id, **self.init_kwargs)
        try:
            output = self.task_function(*args, **kwargs)
        except BaseException:
            # Do not leave the run open when the task fails: close it and mark
            # it as failed, then let the original error propagate.
            wandb.finish(exit_code=1)
            raise
        wandb.finish()
        return output

    def get_extra_config(self):
        """Return the values describing where this task's wandb run lives.

        The dictionary always contains the project, entity and host. When an
        explicit `id` was given it is included and the link type is
        `WANDB_CUSTOM_TYPE_VALUE`; otherwise the link type is
        `WANDB_EXECUTION_TYPE_VALUE`.
        """
        extra_config = {
            self.WANDB_PROJECT_KEY: self.project,
            self.WANDB_ENTITY_KEY: self.entity,
            self.WANDB_HOST_KEY: self.host,
        }

        if self.id is None:
            wandb_value = WANDB_EXECUTION_TYPE_VALUE
        else:
            wandb_value = WANDB_CUSTOM_TYPE_VALUE
            extra_config[self.WANDB_ID_KEY] = self.id

        # LINK_TYPE_KEY is not defined in this class; presumably inherited
        # from ClassDecorator — confirm.
        extra_config[self.LINK_TYPE_KEY] = wandb_value
        return extra_config
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from setuptools import setup | ||
|
||
PLUGIN_NAME = "wandb"

# Distribution name as published, e.g. `pip install flytekitplugins-wandb`.
microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=1.12.0,<2.0.0", "wandb>=0.17.0"]

# Development placeholder version.
__version__ = "0.0.0+develop"

setup(
    name=microlib_name,
    version=__version__,
    author="flyteorg",
    # Restored contact address: the previous value was an e-mail-obfuscation
    # placeholder, not a real address.
    author_email="admin@flyte.org",
    description="This package enables seamless use of Weights & Biases within Flyte",
    namespace_packages=["flytekitplugins"],
    packages=[f"flytekitplugins.{PLUGIN_NAME}"],
    install_requires=plugin_requires,
    license="apache2",
    python_requires=">=3.8",
    classifiers=[
        "Intended Audience :: Science/Research",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "Topic :: Scientific/Engineering",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Software Development",
        "Topic :: Software Development :: Libraries",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from unittest.mock import Mock, patch

import pytest
from flytekitplugins.wandb import wandb_init
from flytekitplugins.wandb.tracking import WANDB_CUSTOM_TYPE_VALUE, WANDB_EXECUTION_TYPE_VALUE

from flytekit import Secret, task
|
||
|
||
# Secret shared by every decorated task in this module. `wandb_init` takes a
# `Secret` object through its `secret` argument; `secret_key=` / `secret_group=`
# would fall into `**init_kwargs` and trigger "secret must be set".
WANDB_SECRET = Secret(key="my-secret-key", group="my-secret-group")


@pytest.mark.parametrize("id", [None, "abc123"])
def test_wandb_extra_config(id):
    """The extra config reports the host and the right link type / id."""
    wandb_decorator = wandb_init(
        project="abc",
        entity="xyz",
        secret=WANDB_SECRET,
        id=id,
        host="https://my_org.wandb.org",
    )

    extra_config = wandb_decorator.get_extra_config()

    if id is None:
        assert extra_config[wandb_decorator.LINK_TYPE_KEY] == WANDB_EXECUTION_TYPE_VALUE
        assert wandb_decorator.WANDB_ID_KEY not in extra_config
    else:
        assert extra_config[wandb_decorator.LINK_TYPE_KEY] == WANDB_CUSTOM_TYPE_VALUE
        assert extra_config[wandb_decorator.WANDB_ID_KEY] == id
    assert extra_config[wandb_decorator.WANDB_HOST_KEY] == "https://my_org.wandb.org"


@task
@wandb_init(project="abc", entity="xyz", secret=WANDB_SECRET, tags=["my_tag"])
def train_model():
    pass


@patch("flytekitplugins.wandb.tracking.wandb")
def test_local_execution(wandb_mock):
    """Locally, with no explicit id, `wandb.init` receives `id=None`."""
    train_model()

    wandb_mock.init.assert_called_with(project="abc", entity="xyz", id=None, tags=["my_tag"])


@task
@wandb_init(project="abc", entity="xyz", secret=WANDB_SECRET, tags=["my_tag"], id="1234")
def train_model_with_id():
    pass


@patch("flytekitplugins.wandb.tracking.wandb")
def test_local_execution_with_id(wandb_mock):
    """Locally, an explicit id is forwarded to `wandb.init` unchanged."""
    train_model_with_id()

    wandb_mock.init.assert_called_with(project="abc", entity="xyz", id="1234", tags=["my_tag"])


@patch("flytekitplugins.wandb.tracking.os")
@patch("flytekitplugins.wandb.tracking.FlyteContextManager")
@patch("flytekitplugins.wandb.tracking.wandb")
def test_non_local_execution(wandb_mock, manager_mock, os_mock):
    """Remotely, the secret is exported and the execution name becomes the run id."""
    # Pretend that the execution is remote
    ctx_mock = Mock()
    ctx_mock.execution_state.is_local_execution.return_value = False

    ctx_mock.user_space_params.secrets.get.return_value = "this_is_the_secret"
    ctx_mock.user_space_params.execution_id.name = "my_execution_id"

    manager_mock.current_context.return_value = ctx_mock
    # An empty environment has no HOSTNAME, so the execution name is the fallback id.
    os_mock.environ = {}

    train_model()

    wandb_mock.init.assert_called_with(project="abc", entity="xyz", id="my_execution_id", tags=["my_tag"])
    ctx_mock.user_space_params.secrets.get.assert_called_with(key="my-secret-key", group="my-secret-group")
    assert os_mock.environ["WANDB_API_KEY"] == "this_is_the_secret"


def test_errors():
    """Each required argument is validated with its own error message."""
    with pytest.raises(ValueError, match="project must be set"):
        wandb_init()

    with pytest.raises(ValueError, match="entity must be set"):
        wandb_init(project="abc")

    with pytest.raises(ValueError, match="secret must be set"):
        wandb_init(project="abc", entity="xyz")