Weights and Biases Example

The Weights & Biases MLOps platform helps AI developers streamline their ML workflow end to end. This plugin enables the use of Weights & Biases within Flyte by configuring links between the two platforms.

from flytekit import ImageSpec, Secret, task, workflow
from flytekitplugins.wandb import wandb_init

First, we specify the project and entity that we will use with Weights and Biases. Please update WANDB_ENTITY to the value associated with your account.

WANDB_PROJECT = "flytekit-wandb-plugin"
WANDB_ENTITY = "github-username"

W&B requires an API key to authenticate with its service. In the example below, the key is stored with Flyte’s secrets manager and requested by the task through a Secret object.

SECRET_KEY = "wandb-api-key"
SECRET_GROUP = "wandb-api-group"
wandb_secret = Secret(key=SECRET_KEY, group=SECRET_GROUP)

Next, we use ImageSpec to build a container image that contains the dependencies for this task:

REGISTRY = "localhost:30000"

image = ImageSpec(
    name="wandb_example",
    python_version="3.11",
    packages=["flytekitplugins-wandb", "xgboost", "scikit-learn"],
    registry=REGISTRY,
)

# The `wandb_init` decorator calls `wandb.init` and configures it to use Flyte's
# execution id as the Weights & Biases run id. The body of the task is XGBoost training
# code, where we pass `WandbCallback` into `XGBClassifier`'s `callbacks`.
@task(
    container_image=image,
    secret_requests=[wandb_secret],
)
@wandb_init(project=WANDB_PROJECT, entity=WANDB_ENTITY, secret=wandb_secret)
def train() -> float:
    import wandb
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from wandb.integration.xgboost import WandbCallback
    from xgboost import XGBClassifier

    X, y = load_iris(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
    bst = XGBClassifier(
        n_estimators=100,
        # Iris has three classes, so use a multiclass objective.
        objective="multi:softprob",
        callbacks=[WandbCallback(log_model=True)],
    )
    bst.fit(X_train, y_train)

    test_score = bst.score(X_test, y_test)

    # Log custom metrics
    wandb.run.log({"test_score": test_score})
    return test_score


@workflow
def wf() -> float:
    return train()
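
The comment above train says that the decorator calls wandb.init before the task body runs and reuses Flyte's execution id as the run id. The following is a minimal sketch of that pattern only, not the plugin's implementation: fake_init is a stand-in for wandb.init that records its arguments, and init_with_execution_id, get_execution_id, train_stub, and the id value are hypothetical names made up for illustration.

```python
import functools

started_runs = []  # records what the stand-in `init` was called with


def fake_init(project, entity, id):
    """Stand-in for `wandb.init`; it only records its arguments."""
    started_runs.append({"project": project, "entity": entity, "id": id})


def init_with_execution_id(project, entity, get_execution_id):
    """Hypothetical decorator factory, not the plugin's actual code."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Start the run first, reusing the execution id as the run id.
            fake_init(project=project, entity=entity, id=get_execution_id())
            return fn(*args, **kwargs)

        return wrapper

    return decorator


@init_with_execution_id(
    project="flytekit-wandb-plugin",
    entity="github-username",
    get_execution_id=lambda: "f1a2b3c4d5-n0-0",  # made-up execution id
)
def train_stub() -> float:
    # Placeholder for the real training code.
    return 0.5


score = train_stub()
```

Because the run is started in the wrapper, the task body can log metrics without calling wandb.init itself, which is how train is written above.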

To enable dynamic log links, add the plugin to Flyte’s configuration file:

dynamic-log-links:
   - wandb-execution-id:
       displayName: Weights & Biases
       templateUris: '{{ .taskConfig.host }}/{{ .taskConfig.entity }}/{{ .taskConfig.project }}/runs/{{ .executionName }}-{{ .nodeId }}-{{ .taskRetryAttempt }}'
   - wandb-custom-id:
       displayName: Weights & Biases
       templateUris: '{{ .taskConfig.host }}/{{ .taskConfig.entity }}/{{ .taskConfig.project }}/runs/{{ .taskConfig.id }}'
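
To see what these templates resolve to, the sketch below fills in the placeholders by plain string substitution. It is an illustration only and not how Flyte renders links internally: render_link is a hypothetical helper, and the host, run id, execution name, node id, and retry attempt are made-up values.

```python
import re

# Templates copied from the configuration above.
EXECUTION_ID_TEMPLATE = (
    "{{ .taskConfig.host }}/{{ .taskConfig.entity }}/{{ .taskConfig.project }}"
    "/runs/{{ .executionName }}-{{ .nodeId }}-{{ .taskRetryAttempt }}"
)
CUSTOM_ID_TEMPLATE = (
    "{{ .taskConfig.host }}/{{ .taskConfig.entity }}/{{ .taskConfig.project }}"
    "/runs/{{ .taskConfig.id }}"
)

# Matches placeholders such as `{{ .taskConfig.host }}` and captures the name.
_PLACEHOLDER = re.compile(r"\{\{\s*\.([A-Za-z0-9_.]+)\s*\}\}")


def render_link(template: str, values: dict) -> str:
    """Hypothetical helper: replace each placeholder with its value."""

    def substitute(match):
        return str(values[match.group(1)])

    return _PLACEHOLDER.sub(substitute, template)


# Made-up values for illustration only.
example_values = {
    "taskConfig.host": "https://wandb.ai",
    "taskConfig.entity": "github-username",
    "taskConfig.project": "flytekit-wandb-plugin",
    "taskConfig.id": "my-custom-run",
    "executionName": "f1a2b3c4d5",
    "nodeId": "n0",
    "taskRetryAttempt": 0,
}

execution_link = render_link(EXECUTION_ID_TEMPLATE, example_values)
custom_link = render_link(CUSTOM_ID_TEMPLATE, example_values)
print(execution_link)
print(custom_link)
```

The first template builds the run id from the execution name, node id, and retry attempt, while the second uses an id supplied in the task configuration.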