add initial version of generator using HF TGI #34

Draft: wants to merge 3 commits into main
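This draft adds HFGenerator, which serves a model with Hugging Face Text Generation Inference (TGI) in a Docker container managed by torchx, waits for the server to become healthy, and fans generation requests out asynchronously, collecting the results into a dask_cudf DataFrame. A minimal usage sketch, mirroring the included example (the model id is illustrative; any TGI-compatible Hub id or an absolute path to local weights should work):

    import crossfit as cf
    from crossfit.dataset.load import load_dataset

    # Load a small query set to run generation against the served model.
    dataset = load_dataset("beir/fiqa", tiny_sample=True)
    query = dataset.query.ddf()

    # The context manager starts the TGI container and stops it on exit.
    with cf.HFGenerator("mistralai/Mistral-7B-v0.1", num_gpus=1) as generator:
        results = generator.infer(query, col="text")

    results.to_parquet("generated_text.parquet")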
3 changes: 2 additions & 1 deletion crossfit/__init__.py
@@ -28,7 +28,7 @@


try:
-from crossfit.backend.torch import HFModel, SentenceTransformerModel, TorchExactSearch
+from crossfit.backend.torch import HFModel, HFGenerator, SentenceTransformerModel, TorchExactSearch
from crossfit.dataset.base import IRDataset, MultiDataset
from crossfit.dataset.load import load_dataset
from crossfit.report.beir.embed import embed
@@ -44,6 +44,7 @@
"HFModel",
"MultiDataset",
"IRDataset",
"HFGenerator",
]
)
except ImportError as e:
2 changes: 1 addition & 1 deletion crossfit/backend/dask/cluster.py
@@ -238,7 +238,7 @@ def __init__(
self._client = client or "auto" # Cannot be `None`
self.cluster_type = cluster_type or ("cuda" if HAS_GPU else "cpu")

-        if torch_rmm and "rmm_pool_size" not in cluster_options:
+        if self.cluster_type == "cuda" and torch_rmm and "rmm_pool_size" not in cluster_options:
cluster_options["rmm_pool_size"] = True

self.cluster_options = cluster_options
3 changes: 2 additions & 1 deletion crossfit/backend/torch/__init__.py
@@ -1,4 +1,5 @@
from crossfit.backend.torch.hf.model import HFModel, SentenceTransformerModel
+from crossfit.backend.torch.hf.generate import HFGenerator
from crossfit.backend.torch.op.vector_search import TorchExactSearch

__all__ = ["HFModel", "SentenceTransformerModel", "TorchExactSearch"]
__all__ = ["HFModel", "HFGenerator", "SentenceTransformerModel", "TorchExactSearch"]
183 changes: 183 additions & 0 deletions crossfit/backend/torch/hf/generate.py
@@ -0,0 +1,183 @@
import asyncio
import logging
import os
import time
import urllib.error
import urllib.request
from typing import Optional

import aiohttp
import cudf
import dask.dataframe as dd
import dask_cudf
import torchx.runner as runner
import torchx.specs as specs
from dask.distributed import Client
from tqdm.asyncio import tqdm_asyncio


logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

TGI_IMAGE_NAME = "ghcr.io/huggingface/text-generation-inference"
TGI_IMAGE_VERSION = "1.1.1"


class HFGenerator:
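    """Generate text with a Hugging Face model served by Text Generation
    Inference (TGI).

    The TGI server runs in a Docker container managed by torchx.
    ``path_or_name`` may be a Hub model id or an absolute path to local
    weights, which is bind-mounted into the container at ``/data``.
    """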
def __init__(
self,
path_or_name,
image_name: str = TGI_IMAGE_NAME,
image_version: str = TGI_IMAGE_VERSION,
num_gpus: int = 1,
mem_gb: int = 8,
max_wait_seconds: int = 1800,
max_tokens: int = 384,
) -> None:
self.path_or_name = path_or_name
self.image_name = image_name
self.image_version = image_version
self.num_gpus = num_gpus
self.mem_gb = mem_gb
self.max_wait_seconds = max_wait_seconds
self.max_tokens = max_tokens

self.runner = None
self.app_handle = None

def start(self):
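        """Launch the TGI container, resolve its IP address on the ``torchx``
        Docker network, and block until ``/health`` responds, the app leaves
        the RUNNING state, or ``max_wait_seconds`` elapses.
        """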
inference_server = get_tgi_app_def(
self.path_or_name,
image_name=self.image_name,
image_version=self.image_version,
num_gpus=self.num_gpus,
mem_gb=self.mem_gb,
)

self.runner = runner.get_runner()

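        # The local_docker scheduler requires a running Docker daemon; the
        # TGI image is pulled on first use.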
self.app_handle = self.runner.run(
inference_server,
scheduler="local_docker",
)

self.status = self.runner.status(self.app_handle)

container_name = self.app_handle.split("/")[-1]
local_docker_client = self.runner._scheduler_instances["local_docker"]._docker_client
        networked_containers = local_docker_client.networks.get("torchx").attrs["Containers"]

self.ip_address = None
for _, container_config in networked_containers.items():
if container_name in container_config["Name"]:
self.ip_address = container_config["IPv4Address"].split("/")[0]
break
if not self.ip_address:
raise RuntimeError("Unable to get server IP address.")

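        # Poll /health once per second until the server answers, the app
        # stops running, or max_wait_seconds elapses.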
self.health = None
for i in range(self.max_wait_seconds):
try:
urllib.request.urlopen(f"http://{self.ip_address}/health")
self.health = "OK"
except urllib.error.URLError:
time.sleep(1)

if self.health == "OK" or self.status.state != specs.AppState.RUNNING:
break

if i % 10 == 1:
logger.info("Waiting for server to be ready...")

self.status = self.runner.status(self.app_handle)

logger.info(self.status)

def __enter__(self):
self.start()

return self

def __exit__(self, exc_type, exc_value, traceback):
self.stop()

    def stop(self):
        if self.runner is None:
            return
        # Refresh the status before deciding whether the app still needs stopping.
        self.status = self.runner.status(self.app_handle)
        if self.status.state == specs.AppState.RUNNING:
            self.runner.stop(self.app_handle)

async def infer_async(self, data):
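        """Submit one ``fetch_async`` task per input string through an
        ephemeral asynchronous Dask client and gather the generated texts
        with a progress bar.
        """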
address = self.ip_address
async with Client(asynchronous=True) as dask_client:
tasks = [
dask_client.submit(fetch_async, address, string, max_tokens=self.max_tokens)
for string in data
]
return await tqdm_asyncio.gather(*tasks)

    def infer(self, data, col: Optional[str] = None):
        # Capture the partition count before `data` is converted to a list.
        npartitions = getattr(data, "npartitions", self.num_gpus)
        if isinstance(data, dd.DataFrame):
            if not col:
                raise ValueError("Column name must be provided for a dataframe.")
            series = data[col].compute()
            # cudf-backed series expose to_pandas(); pandas series do not need it.
            if hasattr(series, "to_pandas"):
                series = series.to_pandas()
            data = series.tolist()
        generated_text = asyncio.run(self.infer_async(data))

        input_col = col or "inputs"
        output_col = "generated_text"
        ddf = dask_cudf.from_cudf(
            cudf.DataFrame({input_col: data, output_col: generated_text}),
            npartitions=npartitions,
        )
        return ddf


async def fetch_async(http_address: str, prompt: str, max_tokens: int) -> str:
async with aiohttp.ClientSession() as session:
url = f"http://{http_address}/generate"
payload = {
"inputs": prompt,
"parameters": {"max_new_tokens": max_tokens},
}
async with session.post(url, json=payload) as response:
response_json = await response.json()
return response_json["generated_text"]


def get_tgi_app_def(
path_or_name: str,
image_name: str = TGI_IMAGE_NAME,
image_version: str = TGI_IMAGE_VERSION,
num_gpus: int = 1,
mem_gb: int = 8,
) -> specs.AppDef:
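    """Build a torchx ``AppDef`` that runs the TGI container.

    An absolute ``path_or_name`` is bind-mounted into the container at
    ``/data``; anything else is passed through as a Hub model id. With more
    than one GPU the server is launched in sharded mode. The TGI image's
    Docker entrypoint is ``text-generation-launcher``, so ``--model-id``
    and ``args`` are appended to it as the command line.
    """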
if os.path.isabs(path_or_name):
args = ["/data"]
mounts = [
specs.BindMount(path_or_name, "/data"),
]
else:
args = [path_or_name]
mounts = []

if num_gpus > 1:
args.extend(["--sharded", "true", "--num-shard", f"{num_gpus}"])

app_def = specs.AppDef(
name="generator",
roles=[
specs.Role(
name="tgi",
image=f"{image_name}:{image_version}",
entrypoint="--model-id",
args=args,
num_replicas=1,
resource=specs.Resource(cpu=1, gpu=num_gpus, memMB=mem_gb * 1024),
mounts=mounts,
),
],
)
return app_def
38 changes: 38 additions & 0 deletions examples/text_generation_with_tgi.py
@@ -0,0 +1,38 @@
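# Generate text for a BEIR query set with a locally served TGI model.
# Example invocation (the model id is illustrative):
#   python examples/text_generation_with_tgi.py mistralai/Mistral-7B-v0.1 --num-gpus 1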
import argparse

import crossfit as cf
from crossfit.dataset.load import load_dataset


def main():
args = parse_arguments()

dataset = load_dataset(args.dataset, tiny_sample=args.tiny_sample)
query = dataset.query.ddf()

with cf.HFGenerator(args.path_or_name, num_gpus=args.num_gpus) as generator:
results = generator.infer(query, col="text")

results.to_parquet(args.output_file)


def parse_arguments():
parser = argparse.ArgumentParser(description="Generate text using CrossFit")
parser.add_argument("path_or_name")
parser.add_argument(
"--dataset", default="beir/fiqa", help="Dataset to load (default: beir/fiqa)"
)
parser.add_argument("--tiny-sample", default=True, action="store_true", help="Use tiny sample dataset")
    parser.add_argument(
        "--num-gpus", type=int, default=1, help="Number of GPUs to use (default: 1)"
    )
parser.add_argument(
"--output-file",
default="generated_text.parquet",
help="Output Parquet file (default: generated_text.parquet)",
)
return parser.parse_args()


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion requirements/base.txt
@@ -11,4 +11,4 @@ typing_extensions
typing_utils
tqdm
rich
-pynvml>=11.0.0,<11.5
\ No newline at end of file
+pynvml>=11.0.0,<11.5
1 change: 1 addition & 0 deletions requirements/pytorch.txt
@@ -1,4 +1,5 @@
torch>=1.0
+torchx
transformers
curated-transformers
bitsandbytes