[Examples] Add airflow example #3982

Merged
merged 17 commits into from
Sep 26, 2024
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -179,7 +179,7 @@ Runnable examples:
- [LocalGPT](./llm/localgpt)
- [Falcon](./llm/falcon)
- Add yours here & see more in [`llm/`](./llm)!
- Framework examples: [PyTorch DDP](https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_distributed_torch.yaml), [DeepSpeed](./examples/deepspeed-multinode/sky.yaml), [JAX/Flax on TPU](https://github.com/skypilot-org/skypilot/blob/master/examples/tpu/tpuvm_mnist.yaml), [Stable Diffusion](https://github.com/skypilot-org/skypilot/tree/master/examples/stable_diffusion), [Detectron2](https://github.com/skypilot-org/skypilot/blob/master/examples/detectron2_docker.yaml), [Distributed](https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_distributed_tf_app.py) [TensorFlow](https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_app_storage.yaml), [Ray Train](examples/distributed_ray_train/ray_train.yaml), [NeMo](https://github.com/skypilot-org/skypilot/blob/master/examples/nemo/nemo.yaml), [programmatic grid search](https://github.com/skypilot-org/skypilot/blob/master/examples/huggingface_glue_imdb_grid_search_app.py), [Docker](https://github.com/skypilot-org/skypilot/blob/master/examples/docker/echo_app.yaml), [Cog](https://github.com/skypilot-org/skypilot/blob/master/examples/cog/), [Unsloth](https://github.com/skypilot-org/skypilot/blob/master/examples/unsloth/unsloth.yaml), [Ollama](https://github.com/skypilot-org/skypilot/blob/master/llm/ollama), [llm.c](https://github.com/skypilot-org/skypilot/tree/master/llm/gpt-2) and [many more (`examples/`)](./examples).
- Framework examples: [PyTorch DDP](https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_distributed_torch.yaml), [DeepSpeed](./examples/deepspeed-multinode/sky.yaml), [JAX/Flax on TPU](https://github.com/skypilot-org/skypilot/blob/master/examples/tpu/tpuvm_mnist.yaml), [Stable Diffusion](https://github.com/skypilot-org/skypilot/tree/master/examples/stable_diffusion), [Detectron2](https://github.com/skypilot-org/skypilot/blob/master/examples/detectron2_docker.yaml), [Distributed](https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_distributed_tf_app.py) [TensorFlow](https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_app_storage.yaml), [Ray Train](examples/distributed_ray_train/ray_train.yaml), [NeMo](https://github.com/skypilot-org/skypilot/blob/master/examples/nemo/nemo.yaml), [programmatic grid search](https://github.com/skypilot-org/skypilot/blob/master/examples/huggingface_glue_imdb_grid_search_app.py), [Docker](https://github.com/skypilot-org/skypilot/blob/master/examples/docker/echo_app.yaml), [Cog](https://github.com/skypilot-org/skypilot/blob/master/examples/cog/), [Unsloth](https://github.com/skypilot-org/skypilot/blob/master/examples/unsloth/unsloth.yaml), [Ollama](https://github.com/skypilot-org/skypilot/blob/master/llm/ollama), [llm.c](https://github.com/skypilot-org/skypilot/tree/master/llm/gpt-2), [Airflow](./examples/airflow/training_workflow) and [many more (`examples/`)](./examples).

Case Studies and Integrations: [Community Spotlights](https://blog.skypilot.co/community/)

2 changes: 1 addition & 1 deletion docs/source/docs/index.rst
Original file line number Diff line number Diff line change
@@ -103,7 +103,7 @@ Runnable examples:
* `Falcon <https://github.com/skypilot-org/skypilot/tree/master/llm/falcon>`_
* Add yours here & see more in `llm/ <https://github.com/skypilot-org/skypilot/tree/master/llm>`_!

* Framework examples: `PyTorch DDP <https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_distributed_torch.yaml>`_, `DeepSpeed <https://github.com/skypilot-org/skypilot/blob/master/examples/deepspeed-multinode/sky.yaml>`_, `JAX/Flax on TPU <https://github.com/skypilot-org/skypilot/blob/master/examples/tpu/tpuvm_mnist.yaml>`_, `Stable Diffusion <https://github.com/skypilot-org/skypilot/tree/master/examples/stable_diffusion>`_, `Detectron2 <https://github.com/skypilot-org/skypilot/blob/master/examples/detectron2_docker.yaml>`_, `Distributed <https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_distributed_tf_app.py>`_ `TensorFlow <https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_app_storage.yaml>`_, `NeMo <https://github.com/skypilot-org/skypilot/blob/master/examples/nemo/nemo_gpt_train.yaml>`_, `programmatic grid search <https://github.com/skypilot-org/skypilot/blob/master/examples/huggingface_glue_imdb_grid_search_app.py>`_, `Docker <https://github.com/skypilot-org/skypilot/blob/master/examples/docker/echo_app.yaml>`_, `Cog <https://github.com/skypilot-org/skypilot/blob/master/examples/cog/>`_, `Unsloth <https://github.com/skypilot-org/skypilot/blob/master/examples/unsloth/unsloth.yaml>`_, `Ollama <https://github.com/skypilot-org/skypilot/blob/master/llm/ollama>`_, `llm.c <https://github.com/skypilot-org/skypilot/tree/master/llm/gpt-2>`__ and `many more <https://github.com/skypilot-org/skypilot/tree/master/examples>`_.
* Framework examples: `PyTorch DDP <https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_distributed_torch.yaml>`_, `DeepSpeed <https://github.com/skypilot-org/skypilot/blob/master/examples/deepspeed-multinode/sky.yaml>`_, `JAX/Flax on TPU <https://github.com/skypilot-org/skypilot/blob/master/examples/tpu/tpuvm_mnist.yaml>`_, `Stable Diffusion <https://github.com/skypilot-org/skypilot/tree/master/examples/stable_diffusion>`_, `Detectron2 <https://github.com/skypilot-org/skypilot/blob/master/examples/detectron2_docker.yaml>`_, `Distributed <https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_distributed_tf_app.py>`_ `TensorFlow <https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_app_storage.yaml>`_, `NeMo <https://github.com/skypilot-org/skypilot/blob/master/examples/nemo/nemo_gpt_train.yaml>`_, `programmatic grid search <https://github.com/skypilot-org/skypilot/blob/master/examples/huggingface_glue_imdb_grid_search_app.py>`_, `Docker <https://github.com/skypilot-org/skypilot/blob/master/examples/docker/echo_app.yaml>`_, `Cog <https://github.com/skypilot-org/skypilot/blob/master/examples/cog/>`_, `Unsloth <https://github.com/skypilot-org/skypilot/blob/master/examples/unsloth/unsloth.yaml>`_, `Ollama <https://github.com/skypilot-org/skypilot/blob/master/llm/ollama>`_, `llm.c <https://github.com/skypilot-org/skypilot/tree/master/llm/gpt-2>`__, `Airflow <https://github.com/skypilot-org/skypilot/blob/master/examples/airflow/training_workflow>`_ and `many more <https://github.com/skypilot-org/skypilot/tree/master/examples>`_.

Case Studies and Integrations: `Community Spotlights <https://blog.skypilot.co/community/>`_

9 changes: 9 additions & 0 deletions examples/airflow/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# SkyPilot Airflow integration examples

This directory contains two examples of integrating SkyPilot with Apache Airflow:
1. [training_workflow](training_workflow)
* A simple training workflow that preprocesses data, trains a model, and evaluates it.
* Showcases how SkyPilot can help easily transition from dev to production in Airflow.
2. [shared_state](shared_state)
* An example showing how SkyPilot state can be persisted across Airflow tasks.
* Useful for operating on the same shared SkyPilot clusters from different Airflow tasks.
172 changes: 172 additions & 0 deletions examples/airflow/shared_state/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Running SkyPilot tasks in an Airflow DAG

SkyPilot can be used in an orchestration framework like Airflow to launch tasks as a part of a DAG.

In this guide, we demonstrate how some simple SkyPilot operations, such as launching a cluster, getting its logs and tearing it down, can be orchestrated using Airflow.

<p align="center">
<img src="https://i.imgur.com/BVZBaR9.png" width="800">
</p>

## Prerequisites

* Airflow installed on a [Kubernetes cluster](https://airflow.apache.org/docs/helm-chart/stable/index.html) or [locally](https://airflow.apache.org/docs/apache-airflow/stable/start.html) (`SequentialExecutor`)
* A Kubernetes cluster to run tasks on. We'll use GKE in this example.
* A persistent volume storage class should be available that supports at least `ReadWriteOnce` access mode.

## Preparing the Kubernetes Cluster

1. Provision a service account on your Kubernetes cluster for SkyPilot to use to launch tasks.
```bash
kubectl apply -f sky-sa.yaml
```
For reference, here are the contents of `sky-sa.yaml`:
```yaml
# sky-sa.yaml
apiVersion: v1
kind: ServiceAccount
metadata:
  name: sky-airflow-sa
  namespace: default
---
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRoleBinding
metadata:
  name: sky-airflow-sa-binding
subjects:
- kind: ServiceAccount
  name: sky-airflow-sa
  namespace: default
roleRef:
  # For minimal permissions, refer to https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions/kubernetes.html
  kind: ClusterRole
  name: cluster-admin
  apiGroup: rbac.authorization.k8s.io
```

2. Provision a persistent volume for SkyPilot to store state across runs.
```bash
kubectl apply -f sky-pv.yaml
```
For reference, here are the contents of `sky-pv.yaml`:
```yaml
# sky-pv.yaml
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: sky-pvc
spec:
  accessModes:
    - ReadWriteOnce
  resources:
    requests:
      storage: 10Gi # 10Gi is minimum for GKE pd-balanced
  storageClassName: standard-rwo
```
Note: The `storageClassName` should be set to the appropriate storage class that's supported on your cluster. If you have many concurrent tasks, you may want to use a storage class that supports `ReadWriteMany` access mode.
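
If you need to switch storage classes or access modes often, the claim can also be generated rather than edited by hand. The following is a minimal sketch, not part of this example: the helper `make_pvc_manifest` and the storage class name `my-rwx-class` are hypothetical, and the output is JSON, which `kubectl apply -f -` accepts just like YAML.

```python
import json


def make_pvc_manifest(name="sky-pvc",
                      storage_class="standard-rwo",
                      access_mode="ReadWriteOnce",
                      size="10Gi"):
    """Return a PersistentVolumeClaim manifest as a plain dict.

    The defaults mirror sky-pv.yaml; this helper is a hypothetical
    illustration and is not shipped with SkyPilot.
    """
    return {
        "apiVersion": "v1",
        "kind": "PersistentVolumeClaim",
        "metadata": {"name": name},
        "spec": {
            "accessModes": [access_mode],
            "resources": {"requests": {"storage": size}},
            "storageClassName": storage_class,
        },
    }


# "my-rwx-class" is a placeholder; substitute a storage class that exists
# on your cluster and supports the requested access mode.
manifest = make_pvc_manifest(storage_class="my-rwx-class",
                             access_mode="ReadWriteMany")
print(json.dumps(manifest, indent=2))
```

Saved under a name of your choosing (say, `make_pvc.py`), the output can be piped straight to the cluster with `python make_pvc.py | kubectl apply -f -`.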

## Writing the Airflow DAG

We provide an example DAG in `sky_k8s_example.py` that:
1. Launches a SkyPilot cluster.
2. Writes logs from the cluster to a local file.
3. Checks the status of the cluster and prints it to the Airflow logs.
4. Tears down the cluster.

The full contents of `sky_k8s_example.py`:

```python
# sky_k8s_example.py
from airflow import DAG
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator
from airflow.utils.dates import days_ago

from kubernetes.client import models as k8s

default_args = {
    'owner': 'airflow',
    'start_date': days_ago(1),
}

def get_skypilot_task(task_id: str, sky_command: str):
    skypilot_task = KubernetesPodOperator(
        task_id=task_id,
        name="skypilot-pod",
        namespace="default",
        image="us-central1-docker.pkg.dev/skypilot-375900/skypilotk8s/skypilot:20240613",
        cmds=["/bin/bash", "-i", "-c"],
        arguments=[
            "chown -R 1000:1000 /home/sky/.sky /home/sky/.ssh && "
            "pip install skypilot-nightly[kubernetes] && "
            f"{sky_command}"],
        service_account_name="sky-airflow-sa",
        env_vars={"HOME": "/home/sky"},
        volumes=[
            k8s.V1Volume(
                name="sky-pvc",
                persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
                    claim_name="sky-pvc"
                ),
            ),
        ],
        volume_mounts=[
            k8s.V1VolumeMount(name="sky-pvc", mount_path="/home/sky/.sky",
                              sub_path="sky"),
            k8s.V1VolumeMount(name="sky-pvc", mount_path="/home/sky/.ssh",
                              sub_path="ssh"),
        ],
        is_delete_operator_pod=True,
        get_logs=True,
    )
    return skypilot_task


with DAG(dag_id='sky_k8s_example',
         default_args=default_args,
         schedule_interval=None,
         catchup=False) as dag:
    # Task to launch a SkyPilot cluster
    cmds = ("git clone https://github.com/skypilot-org/skypilot.git && "
            "sky launch -y -c train --cloud kubernetes skypilot/examples/minimal.yaml")
    sky_launch = get_skypilot_task("sky_launch", cmds)
    # Task to get the logs of the SkyPilot cluster
    sky_logs = get_skypilot_task("sky_logs", "sky logs train > task_logs.txt")
    # Task to get the list of SkyPilot clusters
    sky_status = get_skypilot_task("sky_status", "sky status")
    # Task to delete the SkyPilot cluster
    sky_down = get_skypilot_task("sky_down", "sky down train")

    sky_launch >> sky_logs >> sky_status >> sky_down
```
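
Each pod in the DAG runs a single `bash -c` string: two setup steps followed by the SkyPilot command, chained with `&&` so the SkyPilot command only runs if setup succeeds. The sketch below isolates that string-building step so it can be checked without an Airflow installation. The names `SETUP_STEPS` and `build_pod_arguments` are hypothetical and do not appear in `sky_k8s_example.py`.

```python
# Setup steps copied from the `arguments` list in the DAG above.
SETUP_STEPS = [
    "chown -R 1000:1000 /home/sky/.sky /home/sky/.ssh",
    "pip install skypilot-nightly[kubernetes]",
]


def build_pod_arguments(sky_command: str) -> list:
    """Return the one-element argument list passed to `/bin/bash -i -c`.

    Hypothetical helper: joins the setup steps and the SkyPilot command
    with `&&`, so a failed step stops the chain and fails the pod.
    """
    return [" && ".join(SETUP_STEPS + [sky_command])]


print(build_pod_arguments("sky status")[0])
```

Because `bash -c` takes the whole chain as one string, a failure in `chown` or `pip install` makes the pod exit non-zero, which Airflow reports as a failed task.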

## Running the DAG

1. Copy the DAG file to the Airflow DAGs directory.
```bash
cp sky_k8s_example.py /path/to/airflow/dags
# If your Airflow is running on Kubernetes, you may use kubectl cp to copy the file to the pod
# kubectl cp sky_k8s_example.py <airflow-pod-name>:/opt/airflow/dags
```
2. Run `airflow dags list` to confirm that the DAG is loaded.
3. Find the DAG in the Airflow UI (typically http://localhost:8080) and enable it. The UI may take a couple of minutes to reflect the changes.
4. Trigger the DAG from the Airflow UI using the `Trigger DAG` button.
5. Navigate to the run in the Airflow UI to see the DAG progress and logs of each task.
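
The trigger step can also be scripted instead of clicked. The sketch below assumes an Airflow 2.x deployment with the stable REST API enabled and basic authentication configured; the `admin`/`admin` credentials and the helper `build_trigger_request` are placeholders, not part of this example. It only builds the request, so nothing is sent until you pass it to `urllib.request.urlopen`.

```python
import base64
import json
import urllib.request


def build_trigger_request(base_url, dag_id, username, password, conf=None):
    """Build (but do not send) a POST request that creates a DAG run.

    Assumes the Airflow 2.x stable REST API endpoint
    /api/v1/dags/{dag_id}/dagRuns and basic auth; adjust both if your
    deployment is configured differently.
    """
    url = f"{base_url.rstrip('/')}/api/v1/dags/{dag_id}/dagRuns"
    token = base64.b64encode(f"{username}:{password}".encode()).decode()
    body = json.dumps({"conf": conf or {}}).encode()
    return urllib.request.Request(
        url,
        data=body,
        method="POST",
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Basic {token}",
        },
    )


# Placeholder URL and credentials; replace with your own.
request = build_trigger_request("http://localhost:8080", "sky_k8s_example",
                                "admin", "admin")
print(request.get_method(), request.full_url)
# To actually trigger the run:
# urllib.request.urlopen(request)
```

The DAG still needs to be enabled (step 3) for the triggered run to be scheduled.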

<p align="center">
<img src="https://i.imgur.com/BVZBaR9.png" width="800">
</p>
<p align="center">
<img src="https://i.imgur.com/GgqpSiU.png" width="800">
</p>

## Tips

1. **Persistent Volume**: If you have many concurrent tasks, you may want to use a storage class that supports `ReadWriteMany` access mode.
2. **Cloud credentials**: If you wish to run tasks on different clouds, you can configure cloud credentials in Kubernetes secrets and mount them in the Sky pod defined in the DAG.
3. **Logging**: All SkyPilot logs are written to container stdout, which is captured as task logs in Airflow and displayed in the UI. You can also write logs to a file and read them in subsequent tasks.
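
To make the logging tip concrete: each task runs in its own pod, and the pod is deleted when the task finishes, so a file is only visible to a later task if it lives on the persistent volume. The example DAG mounts that volume at `/home/sky/.sky`. Below is a minimal sketch of command builders that write and read logs there; the `airflow_logs` subdirectory and both helper names are hypothetical.

```python
# Hypothetical subdirectory on the persistent volume, which the example
# DAG mounts at /home/sky/.sky in every pod.
LOG_DIR = "/home/sky/.sky/airflow_logs"


def write_logs_command(cluster: str, log_name: str) -> str:
    """Shell command that saves a cluster's SkyPilot logs on the volume."""
    return f"mkdir -p {LOG_DIR} && sky logs {cluster} > {LOG_DIR}/{log_name}"


def read_logs_command(log_name: str) -> str:
    """Shell command that prints a previously saved log file."""
    return f"cat {LOG_DIR}/{log_name}"


print(write_logs_command("train", "train.log"))
print(read_logs_command("train.log"))

# In the DAG, these strings would be passed to get_skypilot_task, e.g.:
#   sky_logs = get_skypilot_task("sky_logs",
#                                write_logs_command("train", "train.log"))
#   show_logs = get_skypilot_task("show_logs",
#                                 read_logs_command("train.log"))
#   sky_logs >> show_logs
```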

## Future work: a native Airflow Executor built on SkyPilot

In the future, SkyPilot could provide a native Airflow Executor that offers an operator similar to `KubernetesPodOperator` but runs each task as a native SkyPilot task.

In such a setup, SkyPilot state management would no longer be required, as the executor would handle launching and terminating SkyPilot clusters.
11 changes: 11 additions & 0 deletions examples/airflow/shared_state/sky-pv.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: sky-pvc
spec:
  accessModes:
    - ReadWriteOnce
  resources:
    requests:
      storage: 10Gi
  storageClassName: standard-rwo
18 changes: 18 additions & 0 deletions examples/airflow/shared_state/sky-sa.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
apiVersion: v1
kind: ServiceAccount
metadata:
  name: sky-airflow-sa
  namespace: default
---
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRoleBinding
metadata:
  name: sky-airflow-sa-binding
subjects:
- kind: ServiceAccount
  name: sky-airflow-sa
  namespace: default
roleRef:
  kind: ClusterRole
  name: cluster-admin
  apiGroup: rbac.authorization.k8s.io
64 changes: 64 additions & 0 deletions examples/airflow/shared_state/sky_k8s_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from airflow import DAG
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import (
    KubernetesPodOperator)
from airflow.utils.dates import days_ago
from kubernetes.client import models as k8s

default_args = {
    'owner': 'airflow',
    'start_date': days_ago(1),
}


def get_skypilot_task(task_id: str, sky_command: str):
    skypilot_task = KubernetesPodOperator(
        task_id=task_id,
        name="skypilot-pod",
        namespace="default",
        image=
        "us-central1-docker.pkg.dev/skypilot-375900/skypilotk8s/skypilot:20240613",
        cmds=["/bin/bash", "-i", "-c"],
        arguments=[
            "chown -R 1000:1000 /home/sky/.sky /home/sky/.ssh && "
            "pip install skypilot-nightly[kubernetes] && "
            f"{sky_command}"
        ],
        service_account_name="sky-airflow-sa",
        env_vars={"HOME": "/home/sky"},
        volumes=[
            k8s.V1Volume(
                name="sky-pvc",
                persistent_volume_claim=k8s.V1PersistentVolumeClaimVolumeSource(
                    claim_name="sky-pvc"),
            ),
        ],
        volume_mounts=[
            k8s.V1VolumeMount(name="sky-pvc",
                              mount_path="/home/sky/.sky",
                              sub_path="sky"),
            k8s.V1VolumeMount(name="sky-pvc",
                              mount_path="/home/sky/.ssh",
                              sub_path="ssh"),
        ],
        is_delete_operator_pod=True,
        get_logs=True,
    )
    return skypilot_task


with DAG(dag_id='sky_k8s_example',
         default_args=default_args,
         schedule_interval=None,
         catchup=False) as dag:
    # Task to launch a SkyPilot cluster
    sky_launch = get_skypilot_task(
        "sky_launch",
        "sky launch -y -c train --cloud kubernetes -- echo training the model")
    # Task to get the logs of the SkyPilot cluster
    sky_logs = get_skypilot_task("sky_logs", "sky logs train > task_logs.txt")
    # Task to get the list of SkyPilot clusters
    sky_status = get_skypilot_task("sky_status", "sky status")
    # Task to delete the SkyPilot cluster
    sky_down = get_skypilot_task("sky_down", "sky down train")

    sky_launch >> sky_logs >> sky_status >> sky_down