From 1ff81daf016a34fbd338a100dbe2fbb942071327 Mon Sep 17 00:00:00 2001
From: lkk <33276950+lkk12014402@users.noreply.github.com>
Date: Wed, 21 Aug 2024 10:22:25 +0800
Subject: [PATCH] update finetuning api with openai format. (#535)

* update finetuning api with openai format.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* update doc and use 8001 port.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: test
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: root
---
 comps/cores/mega/micro_service.py      |  38 +++--
 comps/cores/proto/api_protocol.py      | 222 +++++++++++++++++++++++++
 comps/finetuning/README.md             |  18 +-
 comps/finetuning/finetuning_service.py |  32 ++--
 comps/finetuning/handlers.py           |  49 ++++--
 comps/finetuning/models.py             |  53 ------
 6 files changed, 307 insertions(+), 105 deletions(-)
 delete mode 100644 comps/finetuning/models.py

diff --git a/comps/cores/mega/micro_service.py b/comps/cores/mega/micro_service.py
index e83a2836b..689fff9dd 100644
--- a/comps/cores/mega/micro_service.py
+++ b/comps/cores/mega/micro_service.py
@@ -3,7 +3,7 @@
 
 import asyncio
 import multiprocessing
-from typing import Any, Optional, Type
+from typing import Any, List, Optional, Type
 
 from ..proto.docarray import TextDoc
 from .constants import ServiceRoleType, ServiceType
@@ -154,25 +154,27 @@ def register_microservice(
     output_datatype: Type[Any] = TextDoc,
     provider: Optional[str] = None,
     provider_endpoint: Optional[str] = None,
+    methods: List[str] = ["POST"],
 ):
     def decorator(func):
-        micro_service = MicroService(
-            name=name,
-            service_role=service_role,
-            service_type=service_type,
-            protocol=protocol,
-            host=host,
-            port=port,
-            ssl_keyfile=ssl_keyfile,
-            ssl_certfile=ssl_certfile,
-            endpoint=endpoint,
-            input_datatype=input_datatype,
-            output_datatype=output_datatype,
-            provider=provider,
-            provider_endpoint=provider_endpoint,
-        )
-        micro_service.app.router.add_api_route(endpoint, func, methods=["POST"])
-        opea_microservices[name] = micro_service
+        if name not in opea_microservices:
+            micro_service = MicroService(
+                name=name,
+                service_role=service_role,
+                service_type=service_type,
+                protocol=protocol,
+                host=host,
+                port=port,
+                ssl_keyfile=ssl_keyfile,
+                ssl_certfile=ssl_certfile,
+                endpoint=endpoint,
+                input_datatype=input_datatype,
+                output_datatype=output_datatype,
+                provider=provider,
+                provider_endpoint=provider_endpoint,
+            )
+            opea_microservices[name] = micro_service
+        opea_microservices[name].app.router.add_api_route(endpoint, func, methods=methods)
         return func
 
     return decorator
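
The behavioral change worth calling out: `register_microservice` now reuses an existing `MicroService` when the same `name` is registered again, and attaches each new route with the given `methods`. A minimal sketch of the pattern this enables, assuming the decorator and registry are importable from `comps` as in the service file further below (the demo name, endpoint, and handlers are illustrative, not part of the patch):

```python
from comps import opea_microservices, register_microservice


@register_microservice(name="opea_service@demo", endpoint="/v1/items", host="0.0.0.0", port=8001)
def create_item(request: dict):
    # `methods` defaults to ["POST"], so it can be omitted for POST routes.
    return {"created": request}


@register_microservice(name="opea_service@demo", endpoint="/v1/items", host="0.0.0.0", port=8001, methods=["GET"])
def list_items():
    # Same service name: the existing MicroService is reused and only the GET
    # route is added, so both handlers share one app on port 8001.
    return {"items": []}


if __name__ == "__main__":
    opea_microservices["opea_service@demo"].start()
```

This is the same mechanism the updated `finetuning_service.py` below uses to hang four OpenAI-style routes off a single service.
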
diff --git a/comps/cores/proto/api_protocol.py b/comps/cores/proto/api_protocol.py
index 957fc9d95..773bf56f2 100644
--- a/comps/cores/proto/api_protocol.py
+++ b/comps/cores/proto/api_protocol.py
@@ -279,3 +279,225 @@ def check_requests(request) -> Optional[JSONResponse]:
         )
 
     return None
+
+
+class Hyperparameters(BaseModel):
+    batch_size: Optional[Union[Literal["auto"], int]] = "auto"
+    """Number of examples in each batch.
+
+    A larger batch size means that model parameters are updated less frequently, but with lower variance.
+    """
+
+    learning_rate_multiplier: Optional[Union[Literal["auto"], float]] = "auto"
+    """Scaling factor for the learning rate.
+
+    A smaller learning rate may be useful to avoid overfitting.
+    """
+
+    n_epochs: Optional[Union[Literal["auto"], int]] = "auto"
+    """The number of epochs to train the model for.
+
+    An epoch refers to one full cycle through the training dataset. "auto" decides
+    the optimal number of epochs based on the size of the dataset. If setting the
+    number manually, we support any number between 1 and 50 epochs.
+    """
+
+
+class FineTuningJobWandbIntegration(BaseModel):
+    project: str
+    """The name of the project that the new run will be created under."""
+
+    entity: Optional[str] = None
+    """The entity to use for the run.
+
+    This allows you to set the team or username of the WandB user that you would
+    like associated with the run. If not set, the default entity for the registered
+    WandB API key is used.
+    """
+
+    name: Optional[str] = None
+    """A display name to set for the run.
+
+    If not set, we will use the Job ID as the name.
+    """
+
+    tags: Optional[List[str]] = None
+    """A list of tags to be attached to the newly created run.
+
+    These tags are passed through directly to WandB. Some default tags are generated
+    by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}".
+    """
+
+
+class FineTuningJobWandbIntegrationObject(BaseModel):
+    type: Literal["wandb"]
+    """The type of the integration being enabled for the fine-tuning job."""
+
+    wandb: FineTuningJobWandbIntegration
+    """The settings for your integration with Weights and Biases.
+
+    This payload specifies the project that metrics will be sent to. Optionally, you
+    can set an explicit display name for your run, add tags to your run, and set a
+    default entity (team, username, etc) to be associated with your run.
+    """
+
+
+class FineTuningJobsRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/fine-tuning/create
+    model: str
+    """The name of the model to fine-tune."""
+
+    training_file: str
+    """The ID of an uploaded file that contains training data."""
+
+    hyperparameters: Optional[Hyperparameters] = None
+    """The hyperparameters used for the fine-tuning job."""
+
+    suffix: Optional[str] = None
+    """A string of up to 64 characters that will be added to your fine-tuned model name."""
+
+    validation_file: Optional[str] = None
+    """The ID of an uploaded file that contains validation data."""
+
+    integrations: Optional[List[FineTuningJobWandbIntegrationObject]] = None
+    """A list of integrations to enable for your fine-tuning job."""
+
+    seed: Optional[int] = None
+    """The seed controls the reproducibility of the job."""
+
+
+class Error(BaseModel):
+    code: str
+    """A machine-readable error code."""
+
+    message: str
+    """A human-readable error message."""
+
+    param: Optional[str] = None
+    """The parameter that was invalid, usually `training_file` or `validation_file`.
+
+    This field will be null if the failure was not parameter-specific.
+    """
+
+
+class FineTuningJob(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/fine-tuning/object
+    id: str
+    """The object identifier, which can be referenced in the API endpoints."""
+
+    created_at: int
+    """The Unix timestamp (in seconds) for when the fine-tuning job was created."""
+
+    error: Optional[Error] = None
+    """For fine-tuning jobs that have `failed`, this will contain more information on
+    the cause of the failure."""
+
+    fine_tuned_model: Optional[str] = None
+    """The name of the fine-tuned model that is being created.
+
+    The value will be null if the fine-tuning job is still running.
+    """
+
+    finished_at: Optional[int] = None
+    """The Unix timestamp (in seconds) for when the fine-tuning job was finished.
+
+    The value will be null if the fine-tuning job is still running.
+    """
+
+    hyperparameters: Hyperparameters
+    """The hyperparameters used for the fine-tuning job.
+
+    See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning)
+    for more details.
+    """
+
+    model: str
+    """The base model that is being fine-tuned."""
+
+    object: Literal["fine_tuning.job"] = "fine_tuning.job"
+    """The object type, which is always "fine_tuning.job"."""
+
+    organization_id: Optional[str] = None
+    """The organization that owns the fine-tuning job."""
+
+    result_files: Optional[List[str]] = None
+    """The compiled results file ID(s) for the fine-tuning job.
+
+    You can retrieve the results with the
+    [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents).
+    """
+
+    status: Literal["validating_files", "queued", "running", "succeeded", "failed", "cancelled"]
+    """The current status of the fine-tuning job, which can be either
+    `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`."""
+
+    trained_tokens: Optional[int] = None
+    """The total number of billable tokens processed by this fine-tuning job.
+
+    The value will be null if the fine-tuning job is still running.
+    """
+
+    training_file: str
+    """The file ID used for training.
+
+    You can retrieve the training data with the
+    [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents).
+    """
+
+    validation_file: Optional[str] = None
+    """The file ID used for validation.
+
+    You can retrieve the validation results with the
+    [Files API](https://platform.openai.com/docs/api-reference/files/retrieve-contents).
+    """
+
+    integrations: Optional[List[FineTuningJobWandbIntegrationObject]] = None
+    """A list of integrations to enable for this fine-tuning job."""
+
+    seed: Optional[int] = None
+    """The seed used for the fine-tuning job."""
+
+    estimated_finish: Optional[int] = None
+    """The Unix timestamp (in seconds) for when the fine-tuning job is estimated to
+    finish.
+
+    The value will be null if the fine-tuning job is not running.
+    """
+
+
+class FineTuningJobIDRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/fine-tuning/retrieve
+    # https://platform.openai.com/docs/api-reference/fine-tuning/cancel
+    fine_tuning_job_id: str
+    """The ID of the fine-tuning job."""
+
+
+class FineTuningJobListRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/fine-tuning/list
+    after: Optional[str] = None
+    """Identifier for the last job from the previous pagination request."""
+
+    limit: Optional[int] = 20
+    """Number of fine-tuning jobs to retrieve."""
+
+
+class FineTuningJobList(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/fine-tuning/list
+    object: str = "list"
+    """The object type, which is always "list".
+
+    This indicates that the returned data is a list of fine-tuning jobs.
+    """
+
+    data: List[FineTuningJob]
+    """A list containing FineTuningJob objects."""
+
+    has_more: bool
+    """Indicates whether there are more fine-tuning jobs beyond the current list.
+
+    If true, additional requests can be made to retrieve more jobs.
+    """
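
A quick sketch of how the new request models behave, under the assumption that the package is installed as `comps`; `model_dump_json` is the pydantic v2 spelling (use `.json()` on v1):

```python
from comps.cores.proto.api_protocol import FineTuningJobsRequest, Hyperparameters

req = FineTuningJobsRequest(
    model="meta-llama/Llama-2-7b-chat-hf",
    training_file="alpaca_data.json",
    # Only n_epochs is pinned; batch_size and learning_rate_multiplier stay "auto".
    hyperparameters=Hyperparameters(n_epochs=3),
)
assert req.hyperparameters.batch_size == "auto"
print(req.model_dump_json(exclude_none=True))  # pydantic v2; use req.json() on v1
```
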
+ """ diff --git a/comps/finetuning/README.md b/comps/finetuning/README.md index 432121d4d..e56de1e82 100644 --- a/comps/finetuning/README.md +++ b/comps/finetuning/README.md @@ -60,7 +60,7 @@ docker build -t opea/finetuning:latest --build-arg https_proxy=$https_proxy --bu Start docker container with below command: ```bash -docker run -d --name="finetuning-server" -p 8000:8000 --runtime=runc --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy opea/finetuning:latest +docker run -d --name="finetuning-server" -p 8001:8001 --runtime=runc --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy opea/finetuning:latest ``` # 🚀3. Consume Finetuning Service @@ -70,11 +70,25 @@ docker run -d --name="finetuning-server" -p 8000:8000 --runtime=runc --ipc=host Assuming a training file `alpaca_data.json` is uploaded, it can be downloaded in [here](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json), the following script launches a finetuning job using `meta-llama/Llama-2-7b-chat-hf` as base model: ```bash -curl http://${your_ip}:8000/v1/fine_tuning/jobs \ +# create a finetuning job +curl http://${your_ip}:8001/v1/fine_tuning/jobs \ -X POST \ -H "Content-Type: application/json" \ -d '{ "training_file": "alpaca_data.json", "model": "meta-llama/Llama-2-7b-chat-hf" }' + +# list finetuning jobs +curl http://${your_ip}:8001/v1/fine_tuning/jobs -X GET + +# retrieve one finetuning job +curl http://localhost:8001/v1/fine_tuning/jobs/retrieve -X POST -H "Content-Type: application/json" -d '{ + "fine_tuning_job_id": ${fine_tuning_job_id}}' + +# cancel one finetuning job + +curl http://localhost:8001/v1/fine_tuning/jobs/cancel -X POST -H "Content-Type: application/json" -d '{ + "fine_tuning_job_id": ${fine_tuning_job_id}}' + ``` diff --git a/comps/finetuning/finetuning_service.py b/comps/finetuning/finetuning_service.py index a1caba88d..2b0b3a91c 100644 --- a/comps/finetuning/finetuning_service.py +++ b/comps/finetuning/finetuning_service.py @@ -1,41 +1,45 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import uvicorn -from fastapi import BackgroundTasks, FastAPI +from fastapi import BackgroundTasks +from comps import opea_microservices, register_microservice +from comps.cores.proto.api_protocol import FineTuningJobIDRequest, FineTuningJobsRequest from comps.finetuning.handlers import ( handle_cancel_finetuning_job, handle_create_finetuning_jobs, handle_list_finetuning_jobs, handle_retrieve_finetuning_job, ) -from comps.finetuning.models import FineTuningJob, FineTuningJobList, FineTuningJobsRequest -app = FastAPI() - -@app.post("/v1/fine_tuning/jobs", response_model=FineTuningJob) +@register_microservice(name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs", host="0.0.0.0", port=8001) def create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks): return handle_create_finetuning_jobs(request, background_tasks) -@app.get("/v1/fine_tuning/jobs", response_model=FineTuningJobList) +@register_microservice( + name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs", host="0.0.0.0", port=8001, methods=["GET"] +) def list_finetuning_jobs(): return handle_list_finetuning_jobs() -@app.get("/v1/fine_tuning/jobs/{fine_tuning_job_id}", response_model=FineTuningJob) -def retrieve_finetuning_job(fine_tuning_job_id): - job = handle_retrieve_finetuning_job(fine_tuning_job_id) +@register_microservice( + name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs/retrieve", host="0.0.0.0", port=8001 +) 
diff --git a/comps/finetuning/finetuning_service.py b/comps/finetuning/finetuning_service.py
index a1caba88d..2b0b3a91c 100644
--- a/comps/finetuning/finetuning_service.py
+++ b/comps/finetuning/finetuning_service.py
@@ -1,41 +1,45 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-import uvicorn
-from fastapi import BackgroundTasks, FastAPI
+from fastapi import BackgroundTasks
 
+from comps import opea_microservices, register_microservice
+from comps.cores.proto.api_protocol import FineTuningJobIDRequest, FineTuningJobsRequest
 from comps.finetuning.handlers import (
     handle_cancel_finetuning_job,
     handle_create_finetuning_jobs,
     handle_list_finetuning_jobs,
     handle_retrieve_finetuning_job,
 )
-from comps.finetuning.models import FineTuningJob, FineTuningJobList, FineTuningJobsRequest
 
-app = FastAPI()
 
-
-@app.post("/v1/fine_tuning/jobs", response_model=FineTuningJob)
+@register_microservice(name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs", host="0.0.0.0", port=8001)
 def create_finetuning_jobs(request: FineTuningJobsRequest, background_tasks: BackgroundTasks):
     return handle_create_finetuning_jobs(request, background_tasks)
 
 
-@app.get("/v1/fine_tuning/jobs", response_model=FineTuningJobList)
+@register_microservice(
+    name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs", host="0.0.0.0", port=8001, methods=["GET"]
+)
 def list_finetuning_jobs():
     return handle_list_finetuning_jobs()
 
 
-@app.get("/v1/fine_tuning/jobs/{fine_tuning_job_id}", response_model=FineTuningJob)
-def retrieve_finetuning_job(fine_tuning_job_id):
-    job = handle_retrieve_finetuning_job(fine_tuning_job_id)
+@register_microservice(
+    name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs/retrieve", host="0.0.0.0", port=8001
+)
+def retrieve_finetuning_job(request: FineTuningJobIDRequest):
+    job = handle_retrieve_finetuning_job(request)
     return job
 
 
-@app.post("/v1/fine_tuning/jobs/{fine_tuning_job_id}/cancel", response_model=FineTuningJob)
-def cancel_finetuning_job(fine_tuning_job_id):
-    job = handle_cancel_finetuning_job(fine_tuning_job_id)
+@register_microservice(
+    name="opea_service@finetuning", endpoint="/v1/fine_tuning/jobs/cancel", host="0.0.0.0", port=8001
+)
+def cancel_finetuning_job(request: FineTuningJobIDRequest):
+    job = handle_cancel_finetuning_job(request)
     return job
 
 
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+    opea_microservices["opea_service@finetuning"].start()
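
A request that exercises the hyperparameter and seed plumbing added to `handlers.py` below might look like this (values illustrative; any hyperparameter left at "auto" keeps the value from the model's YAML config):

```python
import requests

payload = {
    "model": "meta-llama/Llama-2-7b-chat-hf",
    "training_file": "alpaca_data.json",
    "hyperparameters": {"n_epochs": 2, "batch_size": 8, "learning_rate_multiplier": 1e-4},
    "seed": 42,  # omit to let the server pick a random seed
}
job = requests.post("http://localhost:8001/v1/fine_tuning/jobs", json=payload).json()
print(job["id"], job["status"])
```
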
diff --git a/comps/finetuning/handlers.py b/comps/finetuning/handlers.py
index eaf241890..5b842dffb 100644
--- a/comps/finetuning/handlers.py
+++ b/comps/finetuning/handlers.py
@@ -11,8 +11,13 @@
 from pydantic_yaml import parse_yaml_raw_as, to_yaml_file
 from ray.job_submission import JobSubmissionClient
 
+from comps.cores.proto.api_protocol import (
+    FineTuningJob,
+    FineTuningJobIDRequest,
+    FineTuningJobList,
+    FineTuningJobsRequest,
+)
 from comps.finetuning.llm_on_ray.finetune.finetune_config import FinetuneConfig
-from comps.finetuning.models import FineTuningJob, FineTuningJobList, FineTuningJobsRequest
 
 MODEL_CONFIG_FILE_MAP = {
     "meta-llama/Llama-2-7b-chat-hf": "./models/llama-2-7b-chat-hf.yaml",
@@ -20,6 +25,12 @@
 }
 
 DATASET_BASE_PATH = "datasets"
+JOBS_PATH = "jobs"
+if not os.path.exists(DATASET_BASE_PATH):
+    os.mkdir(DATASET_BASE_PATH)
+
+if not os.path.exists(JOBS_PATH):
+    os.mkdir(JOBS_PATH)
 
 FineTuningJobID = str
 CHECK_JOB_STATUS_INTERVAL = 5  # Check every 5 secs
@@ -61,6 +72,17 @@ def handle_create_finetuning_jobs(request: FineTuningJobsRequest, background_tas
         finetune_config = parse_yaml_raw_as(FinetuneConfig, f)
 
     finetune_config.Dataset.train_file = train_file_path
+
+    if request.hyperparameters is not None:
+        if request.hyperparameters.n_epochs != "auto":
+            finetune_config.Training.epochs = request.hyperparameters.n_epochs
+
+        if request.hyperparameters.batch_size != "auto":
+            finetune_config.Training.batch_size = request.hyperparameters.batch_size
+
+        if request.hyperparameters.learning_rate_multiplier != "auto":
+            finetune_config.Training.learning_rate = request.hyperparameters.learning_rate_multiplier
+
     if os.getenv("HF_TOKEN", None):
         finetune_config.General.config.use_auth_token = os.getenv("HF_TOKEN", None)
@@ -75,11 +97,10 @@ def handle_create_finetuning_jobs(request: FineTuningJobsRequest, background_tas
             "learning_rate_multiplier": finetune_config.Training.learning_rate,
         },
         status="running",
-        # TODO: Add seed in finetune config
-        seed=random.randint(0, 1000),
+        seed=random.randint(0, 1000) if request.seed is None else request.seed,
     )
 
-    finetune_config_file = f"jobs/{job.id}.yaml"
+    finetune_config_file = f"{JOBS_PATH}/{job.id}.yaml"
     to_yaml_file(finetune_config_file, finetune_config)
 
     global ray_client
@@ -107,14 +128,18 @@ def handle_list_finetuning_jobs():
     return finetuning_jobs_list
 
 
-def handle_retrieve_finetuning_job(fine_tuning_job_id):
+def handle_retrieve_finetuning_job(request: FineTuningJobIDRequest):
+    fine_tuning_job_id = request.fine_tuning_job_id
+
     job = running_finetuning_jobs.get(fine_tuning_job_id)
     if job is None:
         raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")
     return job
 
 
-def handle_cancel_finetuning_job(fine_tuning_job_id):
+def handle_cancel_finetuning_job(request: FineTuningJobIDRequest):
+    fine_tuning_job_id = request.fine_tuning_job_id
+
     ray_job_id = finetuning_job_to_ray_job.get(fine_tuning_job_id)
     if ray_job_id is None:
         raise HTTPException(status_code=404, detail=f"Fine-tuning job '{fine_tuning_job_id}' not found!")
@@ -126,15 +151,3 @@ def handle_cancel_finetuning_job(fine_tuning_job_id):
     job = running_finetuning_jobs.get(fine_tuning_job_id)
     job.status = "cancelled"
     return job
-
-
-# def cancel_all_jobs():
-#     global ray_client
-#     ray_client = JobSubmissionClient() if ray_client is None else ray_client
-#     # stop all jobs
-#     for job_id in finetuning_job_to_ray_job.values():
-#         ray_client.stop_job(job_id)
-
-#     for job_id in running_finetuning_jobs:
-#         running_finetuning_jobs[job_id].status = "cancelled"
-#     return running_finetuning_jobs
diff --git a/comps/finetuning/models.py b/comps/finetuning/models.py
deleted file mode 100644
index f6757364d..000000000
--- a/comps/finetuning/models.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright (C) 2024 Intel Corporation
-# SPDX-License-Identifier: Apache-2.0
-
-from datetime import datetime
-from typing import List, Optional
-
-from pydantic import BaseModel
-
-
-class FineTuningJobsRequest(BaseModel):
-    training_file: str
-    model: str
-
-
-class Hyperparameters(BaseModel):
-    n_epochs: int
-    batch_size: int
-    learning_rate_multiplier: float
-
-
-class FineTuningJob(BaseModel):
-    object: str = "fine_tuning.job"  # Set as constant
-    id: str
-    model: str
-    created_at: int
-    finished_at: int = None
-    fine_tuned_model: str = None
-    organization_id: str = None
-    result_files: List[str] = None
-    status: str
-    validation_file: str = None
-    training_file: str
-    hyperparameters: Hyperparameters
-    trained_tokens: int = None
-    integrations: List[str] = []  # Empty list by default
-    seed: int
-    estimated_finish: int = 0  # Set default value to 0
-
-
-class FineTuningJobList(BaseModel):
-    object: str = "list"  # Set as constant
-    data: List[FineTuningJob]
-    has_more: bool
-
-
-class FineTuningJobEvent(BaseModel):
-    object: str = "fine_tuning.job.event"  # Set as constant
-    id: str
-    created_at: int
-    level: str
-    message: str
-    data: None = None  # No data expected for this event type, set to None
-    type: str = "message"  # Default event type is "message"
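
Finally, a small polling helper sketched under the same assumptions (`requests` installed, service on localhost:8001); the five-second interval mirrors `CHECK_JOB_STATUS_INTERVAL` in `handlers.py`:

```python
import time

import requests

BASE = "http://localhost:8001/v1/fine_tuning/jobs"


def wait_for_job(fine_tuning_job_id: str, poll_secs: int = 5) -> dict:
    """Poll the retrieve endpoint until the job reaches a terminal status."""
    while True:
        job = requests.post(f"{BASE}/retrieve", json={"fine_tuning_job_id": fine_tuning_job_id}).json()
        if job["status"] in ("succeeded", "failed", "cancelled"):
            return job
        time.sleep(poll_secs)
```
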