Upstream the k8s data service with the new API and agent, with LinkedIn internal things removed
shuyingliang committed Dec 14, 2024
1 parent c8f98c5 commit 983dc2a
Showing 28 changed files with 1,817 additions and 9 deletions.
7 changes: 4 additions & 3 deletions .pre-commit-config.yaml
@@ -1,7 +1,7 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.6.9
+    rev: v0.8.3
     hooks:
       # Run the linter.
       - id: ruff
@@ -26,5 +26,6 @@ repos:
     rev: v2.3.0
     hooks:
       - id: codespell
-        additional_dependencies:
-          - tomli
+        args:
+          - --ignore-words-list=assertIn  # Ignore 'assertIn'
+        additional_dependencies: [tomli]
18 changes: 18 additions & 0 deletions Dockerfile.k8sdataservice
@@ -0,0 +1,18 @@
FROM python:3.10-slim-bookworm

MAINTAINER [email protected]
LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit

ARG CODE_DIR=/home/jobuser/code
COPY plugins/flytekit-k8sdataservice $CODE_DIR/flytekit-k8sdataservice
RUN ls -ls $CODE_DIR
# additional dependencies for running in k8s
RUN pip install prometheus-client grpcio-health-checking
RUN cd $CODE_DIR/flytekit-k8sdataservice && pip install .
# flytekit will autoload the agent if package is installed.
RUN pip install flytekitplugins-k8sdataservice
ENV PYTHONPATH=/usr/local/lib/python3.10/site-packages

ENV FLYTE_SDK_LOGGING_LEVEL=20

CMD pyflyte --verbose serve agent --port 8000
104 changes: 104 additions & 0 deletions plugins/flytekit-k8sdataservice/README.md
@@ -0,0 +1,104 @@
# K8s Stateful Service Plugin

This plugin adds Kubernetes StatefulSet and Service support, enabling seamless provisioning of a data service that other Kubernetes services or Flyte tasks can coordinate with. It is especially suited for deep learning use cases at scale, where distributed and parallelized data loading and caching across nodes are required.

## Features
- **Predictable and Reliable Endpoints**: The service creates consistent endpoints, facilitating communication between services or tasks within the same Kubernetes cluster.
- **Reusable Across Runs**: Service tasks can persist across task runs, ensuring consistency. Alternatively, a cleanup sensor can release cluster resources when they are no longer needed.
- **Conventional Pod Naming**: Pods in the StatefulSet follow a predictable naming pattern. For instance, if the StatefulSet name is `foo` and replicas are set to 2, the pod endpoints will be `foo-0.foo:1234` and `foo-1.foo:1234`. This simplifies endpoint construction for training or inference scripts; for example, gRPC clients can dial `foo-0.foo:1234` and `foo-1.foo:1234` directly (see the sketch below).
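
A minimal sketch of this construction (`stateful_set_endpoints` is a hypothetical helper, not part of the plugin API):

```python
def stateful_set_endpoints(name: str, replicas: int, port: int) -> list[str]:
    # Pod i of StatefulSet `name` is reachable at `{name}-{i}.{name}:{port}`
    # through the headless service that shares the StatefulSet's name.
    return [f"{name}-{i}.{name}:{port}" for i in range(replicas)]

assert stateful_set_endpoints("foo", 2, 1234) == ["foo-0.foo:1234", "foo-1.foo:1234"]
```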

## Installation

Install the plugin via pip:

```bash
pip install flytekitplugins-k8sdataservice
```

## Usage

Below is an example demonstrating how to provision and run a service in Kubernetes, making it reachable within the cluster.

**Note**: Utility functions are available to generate unique service names that can be reused across training or inference scripts.

### Example Usage

#### Provisioning a Data Service
```python
from flytekitplugins.k8sdataservice import DataServiceConfig, DataServiceTask, CleanupSensor
from utils.infra import gen_infra_name
from flytekit import kwtypes, Resources, task, workflow

# Generate a unique infrastructure name
name = gen_infra_name()

def k8s_data_service():
    gnn_config = DataServiceConfig(
        Name=name,
        Requests=Resources(cpu='1', mem='1Gi'),
        Limits=Resources(cpu='2', mem='2Gi'),
        Replicas=1,
        Image="busybox:latest",
        Command=[
            "bash",
            "-c",
            "echo Hello Flyte K8s Stateful Service! && sleep 3600",
        ],
    )

    gnn_task = DataServiceTask(
        name="K8s Stateful Data Service",
        inputs=kwtypes(ds=str),
        task_config=gnn_config,
    )
    return gnn_task

# Define a cleanup sensor
gnn_sensor = CleanupSensor(name="Cleanup")

# Define a workflow to test the data service
@workflow
def test_dataservice_wf(name: str):
    k8s_data_service()(ds="OSS Flyte K8s Data Service Demo") \
        >> gnn_sensor(
            release_name=name,
            cleanup_data_service=True,
        )

if __name__ == "__main__":
    out = test_dataservice_wf(name="example")
    print(f"Running test_dataservice_wf() {out}")
```

#### Accessing the Data Service
Other tasks or services that need to access the data service can do so in multiple ways. For example, using environment variables:

```python
from kubernetes.client import V1PodSpec, V1Container, V1EnvVar
from flytekit import task
# Assumes the Kubeflow MPI plugin (flytekitplugins-kfmpi) for MPIJob/Launcher/Worker.
from flytekitplugins.kfmpi import Launcher, MPIJob, Worker

PRIMARY_CONTAINER_NAME = "primary"

# `name` is the data service name generated in the provisioning example above.
FLYTE_POD_SPEC = V1PodSpec(
    containers=[
        V1Container(
            name=PRIMARY_CONTAINER_NAME,
            env=[
                V1EnvVar(name="MY_DATASERVICES", value=f"{name}-0.{name}:40000 {name}-1.{name}:40000"),
            ],
        )
    ],
)

task_config = MPIJob(
    launcher=Launcher(replicas=1, pod_template=FLYTE_POD_SPEC),
    worker=Worker(replicas=1, pod_template=FLYTE_POD_SPEC),
)

@task(task_config=task_config)
def mpi_task() -> str:
    return "your script uses the envs to communicate with the data service"
```
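
Inside the task container, the endpoints arrive as an ordinary environment variable, so a training or inference script can recover them with the standard library (a minimal sketch using the `MY_DATASERVICES` variable defined above):

```python
import os

# Split the space-separated endpoint list injected via the pod spec,
# e.g. ["foo-0.foo:40000", "foo-1.foo:40000"].
endpoints = os.environ.get("MY_DATASERVICES", "").split()
```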

### Key Points
- The `DataServiceConfig` defines resource requests, limits, replicas, and the container image/command.
- The `CleanupSensor` ensures resources are cleaned up when required.
- The workflow chains service provisioning with the cleanup sensor so that cluster resources are released once they are no longer needed.
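
#### Reusing an Existing Data Service
The "Reusable Across Runs" behavior is driven by the `ExistingReleaseName` field of `DataServiceConfig`: when it is set, the agent attaches to the named release instead of provisioning new resources. A minimal sketch (the release name is a placeholder; the other fields mirror the provisioning example):

```python
from flytekitplugins.k8sdataservice import DataServiceConfig, DataServiceTask
from flytekit import kwtypes

# Attach to a data service provisioned by an earlier run instead of
# creating a new StatefulSet/Service pair. "my-existing-release" is a
# placeholder for the release name produced by that earlier run.
reuse_config = DataServiceConfig(
    Name="my-existing-release",
    ExistingReleaseName="my-existing-release",
    Image="busybox:latest",
    Command=["bash", "-c", "echo reusing && sleep 3600"],
)

reuse_task = DataServiceTask(
    name="Reuse K8s Stateful Data Service",
    inputs=kwtypes(ds=str),
    task_config=reuse_config,
)
```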
1 change: 1 addition & 0 deletions plugins/flytekit-k8sdataservice/dev-requirements.txt
@@ -0,0 +1 @@
kubernetes~=23.6.0
16 changes: 16 additions & 0 deletions plugins/flytekit-k8sdataservice/flytekitplugins/k8sdataservice/__init__.py
@@ -0,0 +1,16 @@
"""
.. currentmodule:: flytekitplugins.k8sdataservice
This package contains things that are useful when extending Flytekit.
.. autosummary::
:template: custom.rst
:toctree: generated/
DataServiceTask
"""

from .agent import DataServiceAgent # noqa: F401
from .sensor import CleanupSensor # noqa: F401
from .task import DataServiceConfig, DataServiceTask # noqa: F401
# from .dataservice_sensor_engine import DSSensorEngine # noqa: F401
101 changes: 101 additions & 0 deletions plugins/flytekit-k8sdataservice/flytekitplugins/k8sdataservice/agent.py
@@ -0,0 +1,101 @@
from dataclasses import dataclass
from typing import Optional

from flyteidl.core.execution_pb2 import TaskExecution
from flytekitplugins.k8sdataservice.k8s.manager import K8sManager
from flytekitplugins.k8sdataservice.task import DataServiceConfig

from flytekit import logger
from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate


@dataclass
class DataServiceMetadata(ResourceMeta):
    dataservice_config: DataServiceConfig
    name: str


class DataServiceAgent(AsyncAgentBase):
    name = "K8s DataService Async Agent"
    # config_file_path = "/etc/config/aipflyteagent/task_logs.yaml"

    def __init__(self):
        self.k8s_manager = K8sManager()
        super().__init__(task_type_name="dataservicetask", metadata_type=DataServiceMetadata)
        self.config = None
        self.kk_execution_id = None

    def create(
        self, task_template: TaskTemplate, output_prefix: str, inputs: Optional[LiteralMap] = None, **kwargs
    ) -> DataServiceMetadata:
        graph_engine_config = task_template.custom
        self.k8s_manager.set_configs(graph_engine_config)
        logger.info(f"Loaded agent config file {self.config}")
        existing_release_name = graph_engine_config.get("ExistingReleaseName", None)
        logger.info(f"The existing data service release name is {existing_release_name}")

        name = ""
        if existing_release_name is None or existing_release_name == "":
            # No existing release was supplied: provision a fresh StatefulSet and Service.
            logger.info("Creating K8s data service resources...")
            name = self.k8s_manager.create_data_service(self.kk_execution_id)
            logger.info(f'Data service {name} with image {graph_engine_config["Image"]} completed')
        else:
            # Reuse the data service provisioned by a previous run.
            name = existing_release_name
            logger.info(f"User configs to use the existing data service release name: {name}.")
            logger.info(f"The existing execution ID found is: {self.kk_execution_id}.")

        dataservice_config = DataServiceConfig(
            Name=graph_engine_config.get("Name", None),
            Image=graph_engine_config["Image"],
            Command=graph_engine_config["Command"],
            Cluster=graph_engine_config["Cluster"],
            ExistingReleaseName=graph_engine_config.get("ExistingReleaseName", None),
        )
        metadata = DataServiceMetadata(
            dataservice_config=dataservice_config,
            name=name,
        )
        logger.info(f"Created DataService metadata {metadata}")
        return metadata

    def get(self, resource_meta: DataServiceMetadata) -> Resource:
        logger.info("K8s Data Service get is called")
        data = resource_meta.dataservice_config
        data_dict = data.__dict__ if isinstance(data, DataServiceConfig) else data
        logger.info(f"The data_dict is {data_dict}")
        self.k8s_manager.set_configs(data_dict)
        name = data.Name
        logger.info(f"Get the stateful set name {name}")

        # Map the StatefulSet status reported by K8sManager onto Flyte task phases.
        k8s_status = self.k8s_manager.check_stateful_set_status(name)
        flyte_state = None
        if k8s_status in ["failed", "timeout", "timedout", "canceled", "skipped", "internal_error"]:
            flyte_state = TaskExecution.FAILED
        elif k8s_status in ["done", "succeeded", "success"]:
            flyte_state = TaskExecution.SUCCEEDED
        elif k8s_status in ["running", "terminating", "pending"]:
            flyte_state = TaskExecution.RUNNING
        else:
            logger.error(f"Unrecognized state: {k8s_status}")
        outputs = {
            "data_service_name": name,
        }
        # TODO: Add logs for StatefulSet.
        return Resource(phase=flyte_state, outputs=outputs)

    def delete(self, resource_meta: DataServiceMetadata):
        logger.info("DataService delete is called")
        data = resource_meta.dataservice_config

        data_dict = data.__dict__ if isinstance(data, DataServiceConfig) else data
        self.k8s_manager.set_configs(data_dict)

        name = resource_meta.name
        logger.info(f"To delete the DataService (e.g., StatefulSet and Service) with name {name}")
        self.k8s_manager.delete_stateful_set(name)
        self.k8s_manager.delete_service(name)


AgentRegistry.register(DataServiceAgent())
Empty file.
@@ -0,0 +1,20 @@
from kubernetes import config

from flytekit import logger


class KubeConfig:
    def __init__(self):
        pass

    def load_kube_config(self) -> None:
        """Load the Kubernetes config prior to any K8s client usage.

        Attempts in-cluster configuration using the agent's service account.
        """
        try:
            logger.info("Attempting to load in-cluster configuration.")
            config.load_incluster_config()  # This will use the service account credentials
            logger.info("Successfully loaded in-cluster configuration using the agent service account.")
        except config.ConfigException as e:
            logger.warning(f"Failed to load in-cluster configuration. {e}")