From 044814fd2080ad00afa92457fafe0c07e6f63631 Mon Sep 17 00:00:00 2001 From: SunXiaoye <31361630+JingofXin@users.noreply.github.com> Date: Fri, 22 Nov 2024 17:12:38 +0800 Subject: [PATCH] Add Training-Service for local sft and unify sft for local and online (#348) --- LazyLLM-Env | 2 +- lazyllm/cli/run.py | 15 +- lazyllm/components/finetune/alpacalora.py | 2 +- lazyllm/components/finetune/collie.py | 2 +- lazyllm/components/finetune/llamafactory.py | 13 +- .../stable_diffusion/stable_diffusion3.py | 2 +- lazyllm/components/text_to_speech/bark.py | 2 +- lazyllm/components/text_to_speech/chattts.py | 2 +- lazyllm/components/text_to_speech/musicgen.py | 2 +- lazyllm/configs.py | 2 + lazyllm/engine/lightengine.py | 260 ++++++++++- lazyllm/engine/scripts/__init__.py | 13 + .../engine/scripts/dataset_format/__init__.py | 0 .../scripts/dataset_format/pt_x2alpaca.py | 12 + .../scripts/dataset_format/sft_x2alpaca.py | 140 ++++++ lazyllm/launcher.py | 22 +- lazyllm/module/module.py | 31 +- lazyllm/module/onlineChatModule/glmModule.py | 137 +++++- .../onlineChatModule/onlineChatModuleBase.py | 26 ++ .../module/onlineChatModule/openaiModule.py | 125 +++++- lazyllm/module/onlineChatModule/qwenModule.py | 150 ++++++- lazyllm/module/utils.py | 175 ++++++++ lazyllm/tools/train_service/__init__.py | 0 lazyllm/tools/train_service/client.py | 406 ++++++++++++++++++ lazyllm/tools/train_service/serve.py | 393 +++++++++++++++++ lazyllm/tools/webpages/webmodule.py | 2 +- pyproject.toml | 3 +- requirements.full.txt | 1 + requirements.txt | 1 + .../standard_test/test_engine.py | 59 ++- tests/charge_tests/test_engine.py | 10 + 31 files changed, 1948 insertions(+), 62 deletions(-) create mode 100644 lazyllm/engine/scripts/__init__.py create mode 100644 lazyllm/engine/scripts/dataset_format/__init__.py create mode 100644 lazyllm/engine/scripts/dataset_format/pt_x2alpaca.py create mode 100644 lazyllm/engine/scripts/dataset_format/sft_x2alpaca.py create mode 100644 lazyllm/module/utils.py create mode 100644 lazyllm/tools/train_service/__init__.py create mode 100644 lazyllm/tools/train_service/client.py create mode 100644 lazyllm/tools/train_service/serve.py diff --git a/LazyLLM-Env b/LazyLLM-Env index 80b13f6a..dd8810c2 160000 --- a/LazyLLM-Env +++ b/LazyLLM-Env @@ -1 +1 @@ -Subproject commit 80b13f6a8eb049e3712b6d53350da54f4c9286b5 +Subproject commit dd8810c2382a1b5c071f5e3f5842c85666548408 diff --git a/lazyllm/cli/run.py b/lazyllm/cli/run.py index e44af14e..ce060d2c 100644 --- a/lazyllm/cli/run.py +++ b/lazyllm/cli/run.py @@ -1,7 +1,10 @@ import sys import argparse import json + +import lazyllm from lazyllm.engine.lightengine import LightEngine +from lazyllm.tools.train_service.serve import TrainServer # lazyllm run xx.json / xx.dsl / xx.lazyml # lazyllm run chatbot --model=xx --framework=xx --source=xx @@ -47,9 +50,17 @@ def graph(json_file): res = engine.run(eid, query) print(f'answer: {res}') +def training_service(): + train_server = TrainServer() + local_server = lazyllm.ServerModule(train_server, launcher=lazyllm.launcher.EmptyLauncher(sync=False)) + local_server.start() + local_server() + local_server.wait() + def run(commands): if not commands: - print('Usage:\n lazyllm run graph.json\n lazyllm run chatbot\n lazyllm run rag\n') + print('Usage:\n lazyllm run graph.json\n lazyllm run chatbot\n ' + 'lazyllm run rag\n lazyllm run training_service\n') parser = argparse.ArgumentParser(description='lazyllm deploy command') parser.add_argument('command', type=str, help='command') @@ -75,6 +86,8 @@ def run(commands): 
rag(llm, args.documents) elif args.command.endswith('.json'): graph(args.command) + elif args.command == 'training_service': + training_service() else: print('lazyllm run is not ready yet.') sys.exit(0) diff --git a/lazyllm/components/finetune/alpacalora.py b/lazyllm/components/finetune/alpacalora.py index d151bcac..ad3bd4e0 100644 --- a/lazyllm/components/finetune/alpacalora.py +++ b/lazyllm/components/finetune/alpacalora.py @@ -40,7 +40,7 @@ def __init__(self, **kw ): if not merge_path: - save_path = os.path.join(os.getcwd(), target_path) + save_path = os.path.join(lazyllm.config['train_target_root'], target_path) target_path, merge_path = os.path.join(save_path, "lazyllm_lora"), os.path.join(save_path, "lazyllm_merge") os.system(f'mkdir -p {target_path} {merge_path}') super().__init__( diff --git a/lazyllm/components/finetune/collie.py b/lazyllm/components/finetune/collie.py index ecadb92a..15790d84 100644 --- a/lazyllm/components/finetune/collie.py +++ b/lazyllm/components/finetune/collie.py @@ -39,7 +39,7 @@ def __init__(self, **kw ): if not merge_path: - save_path = os.path.join(os.getcwd(), target_path) + save_path = os.path.join(lazyllm.config['train_target_root'], target_path) target_path, merge_path = os.path.join(save_path, "lazyllm_lora"), os.path.join(save_path, "lazyllm_merge") os.system(f'mkdir -p {target_path} {merge_path}') super().__init__( diff --git a/lazyllm/components/finetune/llamafactory.py b/lazyllm/components/finetune/llamafactory.py index 9341a4d3..088cb48e 100644 --- a/lazyllm/components/finetune/llamafactory.py +++ b/lazyllm/components/finetune/llamafactory.py @@ -2,6 +2,8 @@ import yaml import json import tempfile +import random +from datetime import datetime import lazyllm from lazyllm import launchers, ArgsDict, thirdparty, CaseInsensitiveDict @@ -30,7 +32,7 @@ def __init__(self, if os.path.exists(defatult_path): base_model = defatult_path if not merge_path: - save_path = os.path.join(os.getcwd(), target_path) + save_path = os.path.join(lazyllm.config['train_target_root'], target_path) target_path, merge_path = os.path.join(save_path, "lazyllm_lora"), os.path.join(save_path, "lazyllm_merge") os.system(f'mkdir -p {target_path} {merge_path}') super().__init__( @@ -73,9 +75,10 @@ def __init__(self, self.export_dict['export_dir'] = merge_path self.export_dict['template'] = self.template_dict['template'] - self.temp_folder = os.path.join(os.getcwd(), '.temp') + self.temp_folder = os.path.join(lazyllm.config['temp_dir'], 'llamafactory_config') if not os.path.exists(self.temp_folder): os.makedirs(self.temp_folder) + self.log_file_path = None def get_template_name(self, base_model): try: @@ -144,8 +147,12 @@ def cmd(self, trainset, valset=None) -> str: updated_template_str = yaml.dump(dict(self.template_dict), default_flow_style=False) self.temp_yaml_file = self.build_temp_yaml(updated_template_str) + formatted_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + random_value = random.randint(1000, 9999) + self.log_file_path = f'{self.target_path}/train_log_{formatted_date}_{random_value}.log' + cmds = f'llamafactory-cli train {self.temp_yaml_file}' - cmds += f' 2>&1 | tee {self.target_path}/llm_$(date +"%Y-%m-%d_%H-%M-%S").log' + cmds += f' 2>&1 | tee {self.log_file_path}' if self.temp_export_yaml_file: cmds += f' && llamafactory-cli export {self.temp_export_yaml_file}' return cmds diff --git a/lazyllm/components/stable_diffusion/stable_diffusion3.py b/lazyllm/components/stable_diffusion/stable_diffusion3.py index a360adb3..685bcbc8 100644 --- 
a/lazyllm/components/stable_diffusion/stable_diffusion3.py +++ b/lazyllm/components/stable_diffusion/stable_diffusion3.py @@ -20,7 +20,7 @@ def __init__(self, base_sd, source=None, embed_batch_size=30, trust_remote_code= self.trust_remote_code = trust_remote_code self.sd = None self.init_flag = lazyllm.once_flag() - self.save_path = save_path if save_path else os.path.join(os.getcwd(), '.temp/sd3') + self.save_path = save_path or os.path.join(lazyllm.config['temp_dir'], 'sd3') if init: lazyllm.call_once(self.init_flag, self.load_sd) diff --git a/lazyllm/components/text_to_speech/bark.py b/lazyllm/components/text_to_speech/bark.py index 9b27d2b8..012951f1 100644 --- a/lazyllm/components/text_to_speech/bark.py +++ b/lazyllm/components/text_to_speech/bark.py @@ -17,7 +17,7 @@ def __init__(self, base_path, source=None, trust_remote_code=True, save_path=Non self.processor, self.bark = None, None self.init_flag = lazyllm.once_flag() self.device = 'cpu' - self.save_path = save_path if save_path else os.path.join(os.getcwd(), '.temp/bark') + self.save_path = save_path or os.path.join(lazyllm.config['temp_dir'], 'bark') if init: lazyllm.call_once(self.init_flag, self.load_bark) diff --git a/lazyllm/components/text_to_speech/chattts.py b/lazyllm/components/text_to_speech/chattts.py index b64160ee..15bf56a3 100644 --- a/lazyllm/components/text_to_speech/chattts.py +++ b/lazyllm/components/text_to_speech/chattts.py @@ -17,7 +17,7 @@ def __init__(self, base_path, source=None, save_path=None, init=False): self.init_flag = lazyllm.once_flag() self.device = 'cpu' self.seed = 1024 - self.save_path = save_path if save_path else os.path.join(os.getcwd(), '.temp/chattts') + self.save_path = save_path or os.path.join(lazyllm.config['temp_dir'], 'chattts') if init: lazyllm.call_once(self.init_flag, self.load_tts) diff --git a/lazyllm/components/text_to_speech/musicgen.py b/lazyllm/components/text_to_speech/musicgen.py index 261d28b4..0e10bfd2 100644 --- a/lazyllm/components/text_to_speech/musicgen.py +++ b/lazyllm/components/text_to_speech/musicgen.py @@ -14,7 +14,7 @@ def __init__(self, base_path, source=None, save_path=None, init=False): self.base_path = ModelManager(source).download(base_path) self.model = None self.init_flag = lazyllm.once_flag() - self.save_path = save_path if save_path else os.path.join(os.getcwd(), '.temp/musicgen') + self.save_path = save_path or os.path.join(lazyllm.config['temp_dir'], 'musicgen') if init: lazyllm.call_once(self.init_flag, self.load_tts) diff --git a/lazyllm/configs.py b/lazyllm/configs.py index b69f7f2f..d935bd95 100644 --- a/lazyllm/configs.py +++ b/lazyllm/configs.py @@ -91,4 +91,6 @@ def refresh(self, targets: Union[str, List[str]] = None) -> None: ).add('repr_ml', bool, False, 'REPR_USE_ML' ).add('rag_store', str, 'none', 'RAG_STORE' ).add('gpu_type', str, 'A100', 'GPU_TYPE' + ).add('train_target_root', str, os.path.join(os.getcwd(), 'save_ckpt'), 'TRAIN_TARGET_ROOT' + ).add('temp_dir', str, os.path.join(os.getcwd(), '.temp'), 'TEMP_DIR' ) diff --git a/lazyllm/engine/lightengine.py b/lazyllm/engine/lightengine.py index 4182bf09..b7fa0ca1 100644 --- a/lazyllm/engine/lightengine.py +++ b/lazyllm/engine/lightengine.py @@ -1,10 +1,14 @@ -from .engine import Engine, Node, ServerGraph -import lazyllm -from lazyllm import once_wrapper -from typing import List, Dict, Optional, Set, Union import copy import uuid +from urllib.parse import urlparse from contextlib import contextmanager +from typing import List, Dict, Optional, Set, Union + +import lazyllm +from lazyllm 
import once_wrapper
+from .engine import Engine, Node, ServerGraph
+from lazyllm.tools.train_service.serve import TrainServer
+from lazyllm.tools.train_service.client import LocalTrainClient, OnlineTrainClient
 
 
 @contextmanager
@@ -20,7 +24,7 @@ class LightEngine(Engine):
 
     _instance = None
 
-    def __new__(cls):
+    def __new__(cls, *args, **kwargs):
         if not LightEngine._instance:
             cls._instance = super().__new__(cls)
         return cls._instance
@@ -29,6 +33,252 @@ def __new__(cls):
     def __init__(self):
         super().__init__()
         self.node_graph: Set[str, List[str]] = dict()
+        self.online_train_client = OnlineTrainClient()
+
+    @once_wrapper
+    def launch_localllm_train_service(self):
+        train_server = TrainServer()
+        self._local_serve = lazyllm.ServerModule(train_server, launcher=lazyllm.launcher.EmptyLauncher(sync=False))
+        self._local_serve.start()()
+        parsed_url = urlparse(self._local_serve._url)
+        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+        self.local_train_client = LocalTrainClient(base_url)
+
+    # Local
+    def local_model_train(self, train_config, token):
+        """
+        Start a new training job on the LazyLLM training service.
+
+        This method sends a request to the LazyLLM API to launch a training job with the specified configuration.
+
+        Parameters:
+        - train_config (dict): A dictionary containing the training configuration details.
+        - token (str): The user group token required for authentication.
+
+        Returns:
+        - tuple: A tuple containing the job ID and the current status of the training job if the request is successful.
+        - tuple: A tuple containing `None` and an error message if the request fails.
+
+        The training configuration dictionary should include the following keys:
+        - finetune_model_name: The name of the model to be fine-tuned.
+        - base_model: The base model to use for training.
+        - data_path: The path to the training data.
+        - training_type: The type of training (e.g., 'sft').
+        - finetuning_type: The type of finetuning (e.g., 'lora').
+        - val_size: The ratio of the validation set to the training set.
+        - num_epochs: The number of training epochs.
+        - learning_rate: The learning rate for training.
+        - lr_scheduler_type: The type of learning rate scheduler.
+        - batch_size: The batch size for training.
+        - cutoff_len: The maximum sequence length for training.
+        - lora_r: The LoRA rank.
+        - lora_alpha: The LoRA alpha parameter.
+        - lora_rate: The parameter ratio for LoRA fine-tuning.
+        """
+        if not self.launch_localllm_train_service.flag:
+            raise RuntimeError("Please call the member function 'launch_localllm_train_service' "
+                               "of the LightEngine instance to start the training service.")
+        return self.local_train_client.train(train_config, token)
+
+    def local_model_cancel_training(self, token, job_id):
+        """
+        Cancel a training job on the LazyLLM training service.
+
+        This method sends a request to the LazyLLM API to cancel a specific training job.
+
+        Parameters:
+        - token (str): The user group token required for authentication.
+        - job_id (str): The unique identifier of the training job to be cancelled.
+
+        Returns:
+        - bool: True if the job was successfully cancelled; otherwise an error message is returned.
+        """
+        if not self.launch_localllm_train_service.flag:
+            raise RuntimeError("Please call the member function 'launch_localllm_train_service' "
+                               "of the LightEngine instance to start the training service.")
+        return self.local_train_client.cancel_training(token, job_id)
+
+    def local_model_get_training_status(self, token, job_id):
+        """
+        Retrieve the current status of a training job on the LazyLLM training service.
+
+        This method sends a request to the LazyLLM API to fetch the current status of a specific training job.
+
+        Parameters:
+        - token (str): The user group token required for authentication.
+        - job_id (str): The unique identifier of the training job for which to retrieve the status.
+
+        Returns:
+        - str: The current status of the training job if the request is successful.
+        - 'Invalid' (str): If the request fails or an error occurs.
+        """
+        if not self.launch_localllm_train_service.flag:
+            raise RuntimeError("Please call the member function 'launch_localllm_train_service' "
+                               "of the LightEngine instance to start the training service.")
+        return self.local_train_client.get_training_status(token, job_id)
+
+    def local_model_get_training_log(self, token, job_id):
+        """
+        Retrieve the log for the current training job on the LazyLLM training service.
+
+        This method sends a request to the LazyLLM API to fetch the log associated with a specific training job.
+
+        Parameters:
+        - token (str): The user group token required for authentication.
+        - job_id (str): The unique identifier of the training job for which to retrieve the log.
+
+        Returns:
+        - str: The log path if the request is successful.
+        - None: If the request fails or an error occurs.
+        """
+        if not self.launch_localllm_train_service.flag:
+            raise RuntimeError("Please call the member function 'launch_localllm_train_service' "
+                               "of the LightEngine instance to start the training service.")
+        return self.local_train_client.get_training_log(token, job_id)
+
+    def local_model_get_all_trained_models(self, token):
+        """
+        List all models, each with its job ID, model ID, and status, on the LazyLLM training service.
+
+        Parameters:
+        - token (str): The user group token required for authentication.
+
+        Returns:
+        - list of lists: Each sublist contains [job_id, model_name, status] for each trained model.
+        - None: If the request fails or an error occurs.
+        """
+        if not self.launch_localllm_train_service.flag:
+            raise RuntimeError("Please call the member function 'launch_localllm_train_service' "
+                               "of the LightEngine instance to start the training service.")
+        return self.local_train_client.get_all_trained_models(token)
+
+    def local_model_get_training_cost(self, token, job_id):
+        """
+        Retrieve the GPU usage time for a training job on the LazyLLM training service.
+
+        This method sends a request to the LazyLLM API to fetch the GPU usage time (in seconds)
+        for a specific training job.
+
+        Parameters:
+        - token (str): The user group token required for authentication.
+        - job_id (str): The unique identifier of the training job for which to retrieve the GPU usage time.
+
+        Returns:
+        - int: The GPU usage time in seconds if the request is successful.
+        - str: An error message if the request fails.
+        """
+        if not self.launch_localllm_train_service.flag:
+            raise RuntimeError("Please call the member function 'launch_localllm_train_service' "
+                               "of the LightEngine instance to start the training service.")
+        return self.local_train_client.get_training_cost(token, job_id)
+ """ + if not self.launch_localllm_train_service.flag: + raise RuntimeError("Please call the member function 'launch_localllm_train_service' " + "of the LightEngine instance to start the training service.") + return self.local_train_client.get_training_cost(token, job_id) + + # Online + def online_model_train(self, train_config, token, source): + """ + Initiates an online training task with the specified parameters and configurations. + + Args: + - train_config (dict): Configuration parameters for the training task. + - token (str): API-Key provided by the supplier, used for authentication. + - source (str): Specifies the supplier. Supported suppliers are 'glm' and 'qwen'. + + Returns: + - tuple: A tuple containing the Job-ID and its status if the training starts successfully. + If an error occurs, the Job-ID will be None, and the error message will be included. + + The training configuration dictionary should include the following keys: + - finetune_model_name: The name of the model to be fine-tuned. + - base_model: The base model to use for traning. + - data_path: The path to the training data. + - training_type: The type of training (e.g., 'sft'). + - finetuning_type: The type of finetuning (e.g., 'lora'). + - val_size: The ratio of validation data set to training data set. + - num_epochs: The number of training epochs. + - learning_rate: The learning rate for training. + - lr_scheduler_type: The type of learning rate scheduler. + - batch_size: The batch size for training. + - cutoff_len: The maximum sequence length for training. + - lora_r: The LoRA rank. + - lora_alpha: The LoRA alpha parameter. + - lora_rate: The parameter ratio for LoRA fine-tuning. + """ + return self.online_train_client.train(train_config, token, source) + + def online_model_cancel_training(self, token, job_id, source): + """ + Cancels an ongoing online training task by its Job-ID. + + Args: + - token (str): API-Key provided by the supplier, used for authentication. + - job_id (str): The unique identifier of the training job to be cancelled. + - source (str): Specifies the supplier. Supported suppliers are 'glm' and 'qwen'. + + Returns: + - bool or str: Returns True if the training task was successfully cancelled. If the cancellation fails, + it returns a string with the reason for the failure, including any final information about the task. + """ + return self.online_train_client.cancel_training(token, job_id, source) + + def online_model_get_training_status(self, token, job_id, source): + """ + Retrieves the current status of a training task by its Job-ID. + + Args: + - token (str): API-Key provided by the supplier, used for authentication. + - job_id (str): The unique identifier of the training job to query. + - source (str): Specifies the supplier. Supported suppliers are 'glm' and 'qwen'. + + Returns: + - str: A string representing the current status of the training task. This could be one of: + 'Pending', 'Running', 'Done', 'Cancelled', 'Failed', or 'Invalid' if the query could not be processed. + """ + return self.online_train_client.get_training_status(token, job_id, source) + + def online_model_get_training_log(self, token, job_id, source, target_path=None): + """ + Retrieves the training log for a specific training task by its Job-ID and saves it to a file. + + Args: + - token (str): API-Key provided by the supplier, used for authentication. + - job_id (str): The unique identifier of the training job for which to retrieve the log. + - source (str): Specifies the supplier. 
Supported suppliers are 'glm' and 'qwen'. + - target_path (str, optional): The path where the log file should be saved. If not provided, + the log will be saved to a temporary directory. + + Returns: + - str or None: The path to the saved log file if the log retrieval and saving was successful. + If an error occurs, None is returned. + """ + return self.online_train_client.get_training_log(token, job_id, source=source, target_path=target_path) + + def online_model_get_all_trained_models(self, token, source): + """ + Lists all model jobs with their corresponding job-id, model-id, and statuse for online training services. + + Args: + - token (str): API-Key provided by the supplier, used for authentication. + - source (str): Specifies the supplier. Supported suppliers are 'glm' and 'qwen'. + + Returns: + - list of lists: Each sublist contains [job_id, model_name, status] for each trained model. + - None: If the request fails or an error occurs. + """ + return self.online_train_client.get_all_trained_models(token, source) + + def online_model_get_training_cost(self, token, job_id, source): + """ + Retrieves the number of tokens consumed by an online traning task. + + Args: + - token (str): API-Key provided by the supplier, used for authentication. + - job_id (str): The unique identifier of the traning job for which to retrieve the token consumption. + - source (str): Specifies the supplier. Supported suppliers are 'glm' and 'qwen'. + + Returns: + - int or str: The number of tokens consumed by the traning task if the query is successful. + If an error occurs, a string containing the error message is returned. + """ + return self.online_train_client.get_training_cost(token, job_id, source) def build_node(self, node): if not isinstance(node, Node): diff --git a/lazyllm/engine/scripts/__init__.py b/lazyllm/engine/scripts/__init__.py new file mode 100644 index 00000000..3e7654ea --- /dev/null +++ b/lazyllm/engine/scripts/__init__.py @@ -0,0 +1,13 @@ +from .dataset_format.sft_x2alpaca import ( + csv2alpaca, + parquet2alpaca, + json2alpaca, + merge2alpaca +) + +__all__ = [ + 'csv2alpaca', + 'parquet2alpaca', + 'json2alpaca', + 'merge2alpaca', +] diff --git a/lazyllm/engine/scripts/dataset_format/__init__.py b/lazyllm/engine/scripts/dataset_format/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lazyllm/engine/scripts/dataset_format/pt_x2alpaca.py b/lazyllm/engine/scripts/dataset_format/pt_x2alpaca.py new file mode 100644 index 00000000..f5d3bb92 --- /dev/null +++ b/lazyllm/engine/scripts/dataset_format/pt_x2alpaca.py @@ -0,0 +1,12 @@ + +def txt2pretrain(dataset_path: str) -> str: + return dataset_path + +def csv2pretrain(dataset_path: str) -> str: + return dataset_path + +def parquet2pretrain(dataset_path: str) -> str: + return dataset_path + +def json2pretrain(dataset_path: str) -> str: + return dataset_path diff --git a/lazyllm/engine/scripts/dataset_format/sft_x2alpaca.py b/lazyllm/engine/scripts/dataset_format/sft_x2alpaca.py new file mode 100644 index 00000000..2157833e --- /dev/null +++ b/lazyllm/engine/scripts/dataset_format/sft_x2alpaca.py @@ -0,0 +1,140 @@ +import os +import csv +import json +import pandas as pd +from typing import List +from datetime import datetime +from datasets import load_dataset + +import lazyllm +from lazyllm.module.utils import openai2alpaca +from lazyllm.components.utils.file_operate import delete_old_files + +# origin_key: target_key: +default_mapping = {'instruction': 'instruction', 'input': 'input', 'output': 'output'} + +def 
+def csv2alpaca(dataset_path: str, header_mapping=None, target_path: str = None) -> str:
+    """
+    Convert a CSV file to an Alpaca-format JSON file with a custom header mapping.
+
+    :param dataset_path: Path of the CSV file to be converted.
+    :param header_mapping: A dictionary mapping original column names to Alpaca keys.
+        Default is None, in which case `default_mapping` is used.
+    :param target_path: The path of the folder where the converted files are stored.
+        The default is None, and the result will be stored under the working path
+        `.temp/dataset`.
+    """
+    save_dir = _build_target_dir(target_path)
+
+    mapping = header_mapping if header_mapping else default_mapping
+    data = []
+    with open(dataset_path, mode='r', encoding='utf-8') as csvfile:
+        reader = csv.DictReader(csvfile)
+        for row in reader:
+            renamed_row = {mapping.get(k, k): v for k, v in row.items() if k in mapping}
+            data.append(renamed_row)
+    file_name = os.path.basename(dataset_path)
+    base_name, _ = file_name.split('.')
+
+    res_path = _save_dataset(data, save_dir, base_name)
+    return res_path
+
+def parquet2alpaca(dataset_path: str, header_mapping=None, target_path: str = None) -> str:
+    """
+    Convert a Parquet file to an Alpaca-format JSON file with a custom header mapping.
+
+    :param dataset_path: Path of the Parquet file to be converted.
+    :param header_mapping: A dictionary mapping original column names to Alpaca keys.
+        Default is None, in which case `default_mapping` is used.
+    :param target_path: The path of the folder where the converted files are stored.
+        The default is None, and the result will be stored under the working path
+        `.temp/dataset`.
+    """
+    save_dir = _build_target_dir(target_path)
+
+    mapping = header_mapping if header_mapping else default_mapping
+    df = pd.read_parquet(dataset_path)
+
+    df = df.rename(columns=mapping)
+    df = df[[col for col in mapping.values()]]
+    data = df.to_dict(orient='records')
+
+    file_name = os.path.basename(dataset_path)
+    base_name, _ = file_name.split('.')
+
+    res_path = _save_dataset(data, save_dir, base_name)
+    return res_path
+
+def json2alpaca(dataset_path: str, header_mapping=None, target_path: str = None) -> str:
+    """
+    Convert a JSON dataset to an Alpaca-format JSON file with a custom header mapping.
+
+    :param dataset_path: Path of the JSON file to be converted.
+    :param header_mapping: A dictionary mapping original keys to Alpaca keys.
+        Default is None, in which case `default_mapping` is used.
+    :param target_path: The path of the folder where the converted files are stored.
+        The default is None, and the result will be stored under the working path
+        `.temp/dataset`.
+    """
+    save_dir = _build_target_dir(target_path)
+
+    mapping = header_mapping if header_mapping else default_mapping
+    dataset = load_dataset('json', data_files=dataset_path)
+
+    data_list = []
+    for row in dataset['train']:
+        renamed_row = {mapping.get(k, k): v for k, v in row.items() if k in mapping}
+        data_list.append(renamed_row)
+
+    file_name = os.path.basename(dataset_path)
+    base_name, _ = file_name.split('.')
+
+    res_path = _save_dataset(data_list, save_dir, base_name)
+    return res_path
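
An illustrative sketch of these converters together with `merge2alpaca` defined just below; the file paths and column names are hypothetical:

    from lazyllm.engine.scripts import csv2alpaca, merge2alpaca

    # Map the CSV's own column names onto the Alpaca keys (origin -> target).
    alpaca_path = csv2alpaca('data/qa.csv',
                             header_mapping={'question': 'instruction', 'answer': 'output'})
    # Merge the converted file with an existing Alpaca/OpenAI-format JSON file.
    merged_path = merge2alpaca([alpaca_path, 'data/extra.json'])
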
+
+def merge2alpaca(dataset_paths: List[str], target_path: str = None) -> str:
+    """
+    Merge multiple JSON files into a single JSON file formatted for Alpaca.
+    This function reads multiple JSON files (Alpaca or OpenAI format) and converts them to Alpaca format.
+    The merged file is saved to the specified target directory, or to a default temporary directory
+    if no target is provided.
+
+    :param dataset_paths: A list of paths to the JSON files to be merged.
+    :param target_path: The path of the folder where the merged file will be stored.
+        If `None`, the file will be stored under the working path `.temp/dataset`.
+    :return: The path to the merged dataset file.
+
+    Raises:
+        RuntimeError: If any of the provided file paths do not exist.
+    """
+    if isinstance(dataset_paths, str):
+        dataset_paths = [dataset_paths]
+    non_existent_files = [path for path in dataset_paths if not os.path.exists(path)]
+    if non_existent_files:
+        raise RuntimeError(f"These files do not exist: {non_existent_files}")
+    save_dir = _build_target_dir(target_path)
+
+    merge_list = []
+    for path in dataset_paths:
+        data = load_dataset('json', data_files=path)
+        if "messages" in data["train"][0]:
+            alpaca_data = openai2alpaca(data)
+            merge_list.extend(alpaca_data)
+        else:
+            merge_list.extend(data['train'].to_list())
+    res_path = _save_dataset(merge_list, save_dir, 'merge_dataset')
+    return res_path
+
+def _build_target_dir(target_path: str = None) -> str:
+    if target_path:
+        save_dir = target_path
+        if not os.path.exists(save_dir):
+            raise RuntimeError(f"The target_path at {save_dir} does not exist.")
+    else:
+        save_dir = os.path.join(lazyllm.config['temp_dir'], 'dataset')
+        if not os.path.exists(save_dir):
+            os.system(f'mkdir -p {save_dir}')
+        else:
+            delete_old_files(save_dir)
+    return save_dir
+
+def _save_dataset(data: list, save_dir: str, base_name: str) -> str:
+    time_stamp = datetime.now().strftime('%y%m%d%H%M%S%f')[:14]
+    output_json_path = os.path.join(save_dir, f'{base_name}_{time_stamp}.json')
+    with open(output_json_path, 'w', encoding='utf-8') as json_file:
+        json.dump(data, json_file, ensure_ascii=False, indent=4)
+    return output_json_path
diff --git a/lazyllm/launcher.py b/lazyllm/launcher.py
index 28a82dc7..c4fc9a2c 100644
--- a/lazyllm/launcher.py
+++ b/lazyllm/launcher.py
@@ -16,7 +16,7 @@
 import psutil
 
 import lazyllm
-from lazyllm import LazyLLMRegisterMetaClass, LazyLLMCMD, final, timeout, LOG
+from lazyllm import LazyLLMRegisterMetaClass, LazyLLMCMD, final, LOG
 
 class Status(Enum):
     TBSubmitted = 0,
@@ -125,10 +125,15 @@ def _start(self, *, fixed):
         if self.sync:
             self.ps.wait()
         else:
-            with timeout(3600, msg='Launch failed: No computing resources are available.'):
-                while self.status in (Status.TBSubmitted, Status.InQueue, Status.Pending):
-                    time.sleep(2)
             self.launcher.all_processes[self.launcher._id].append((self.jobid, self))
+            n = 0
+            while self.status in (Status.TBSubmitted, Status.InQueue, Status.Pending):
+                time.sleep(2)
+                n += 1
+                if n > 1800:  # 3600 s in total
+                    self.launcher.all_processes[self.launcher._id].pop()
+                    LOG.error('Launch failed: No computing resources are available.')
+                    break
 
     def restart(self, *, fixed=False):
         self.stop()
@@ -570,9 +575,12 @@ def stop(self):
             self._scancel_job(cmd)
             time.sleep(0.5)  # Avoid the execution of scancel and scontrol too close together.
 
-        with lazyllm.timeout(25):
-            while self.status not in (Status.Done, Status.Cancelled, Status.Failed):
-                time.sleep(1)
+        n = 0
+        while self.status not in (Status.Done, Status.Cancelled, Status.Failed):
+            time.sleep(1)
+            n += 1
+            if n > 25:
+                break
 
         if self.ps:
             self.ps.terminate()
diff --git a/lazyllm/module/module.py b/lazyllm/module/module.py
index 7742dfb3..3d5eada8 100644
--- a/lazyllm/module/module.py
+++ b/lazyllm/module/module.py
@@ -585,11 +585,13 @@ def __init__(self, base_model='', target_path='', stream=False, train=None, fine
         # TODO(wangzhihong): Update ModelDownloader to support async download, and move it to deploy.
# Then support Option for base_model self._base_model = ModelManager(lazyllm.config['model_source']).download(base_model) - self._target_path = target_path if target_path else os.path.join(os.getcwd(), 'save_ckpt') + self._target_path = os.path.join(lazyllm.config['train_target_root'], target_path) self._stream = stream self._father = [] self._launchers: Dict[str, Dict[str, Launcher]] = dict(default=dict(), manual=dict()) + self._delimiter = '-LazySplit-' self._deployer = None + self._file_name = None self._specific_target_path = None self._train, self._finetune = train, finetune self.deploy_method(deploy) @@ -627,26 +629,28 @@ def after_train(real_target_path): return real_target_path return Pipeline(*self._get_train_tasks_impl(), after_train) - def _trian(self, name: str, ngpus: int = 1, mode: str = None, batch_size: int = 16, - micro_batch_size: int = 2, num_epochs: int = 3, learning_rate: float = 5e-4, - lora_r: int = 8, lora_alpha: int = 32, lora_dropout: float = 0.05, **kw): + def _async_finetune(self, name: str, ngpus: int = 1, **kw): assert name and isinstance(name, str), 'Invalid name: {name}, expect a valid string' assert name not in self._launchers['manual'], 'Duplicate name: {name}' self._launchers['manual'][name] = kw['launcher'] = launchers.remote(sync=False, ngpus=ngpus) + self._set_file_name(name) - Pipeline(*self._get_train_tasks_impl( - mode=mode, batch_size=batch_size, micro_batch_size=micro_batch_size, num_epochs=num_epochs, - learning_rate=learning_rate, lora_r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, **kw))() + def after_train(real_target_path): + self._finetuned_model_path = real_target_path + return real_target_path + return Pipeline(*self._get_train_tasks_impl(mode='finetune', **kw), after_train)() def _get_all_finetuned_models(self): valid_paths = [] invalid_paths = [] for root, dirs, files in os.walk(self._target_path): if root.endswith('lazyllm_merge'): + model_path = os.path.abspath(root) + model_id = model_path.split(os.sep)[-2].split(self._delimiter)[0] if any(file.endswith(('.bin', '.safetensors')) for file in files): - valid_paths.append(os.path.abspath(root)) + valid_paths.append((model_id, model_path)) else: - invalid_paths.append(os.path.abspath(root)) + invalid_paths.append((model_id, model_path)) return valid_paths, invalid_paths def _set_specific_finetuned_model(self, model_path): @@ -723,13 +727,16 @@ def optimize_name(name): return name[:5] + '_' + name[-4:] return name base_model_name = optimize_name(base_model_name) + file_name = base_model_name if not self._file_name else self._file_name train_set_name = optimize_name(train_set_name) - target_path = os.path.join(self._target_path, - f"{base_model_name}-{train_set_name}-" + target_path = os.path.join(self._target_path, base_model_name, + f"{file_name}{self._delimiter}{train_set_name}{self._delimiter}" f"{datetime.now().strftime('%y%m%d%H%M%S%f')[:14]}") return target_path + def _set_file_name(self, name): + self._file_name = name class TrainableModule(UrlModule): builder_keys = _TrainableModuleImpl.builder_keys @@ -762,7 +769,7 @@ def stream(self): def stream(self, v: bool): self._stream = v - def get_all_finetuned_models(self): + def get_all_models(self): return self._impl._get_all_finetuned_models() def set_specific_finetuned_model(self, model_path): diff --git a/lazyllm/module/onlineChatModule/glmModule.py b/lazyllm/module/onlineChatModule/glmModule.py index 85ce4ab3..5cedf617 100644 --- a/lazyllm/module/onlineChatModule/glmModule.py +++ 
b/lazyllm/module/onlineChatModule/glmModule.py @@ -1,7 +1,8 @@ import json import os import requests -from typing import Tuple +from typing import Tuple, List + import lazyllm from .onlineChatModuleBase import OnlineChatModuleBase from .fileHandler import FileHandlerBase @@ -26,6 +27,25 @@ def __init__(self, return_trace=return_trace, **kwargs) FileHandlerBase.__init__(self) + self.default_train_data = { + "model": None, + "training_file": None, + "validation_file": None, + "extra_hyperparameters": { + "fine_tuning_method": None, # lora\full, default: lora, + "fine_tuning_parameters": { + "max_sequence_length": None # [1, 8192](int), default: 8192 + } + }, + "hyperparameters": { + "learning_rate_multiplier": 0.01, # (0,5] , default: 1.0 + "batch_size": None, # [1, 32], default: 8 + "n_epochs": 1, # [1, 10], default: 3 + }, + "suffix": None, + "request_id": None + } + self.fine_tuning_job_id = None def _get_system_prompt(self): return ("You are ChatGLM, an AI assistant developed based on a language model trained by Zhipu AI. " @@ -73,6 +93,18 @@ def _upload_train_file(self, train_file): self._dataHandler.close() return r.json()["id"] + def _update_kw(self, data, normal_config): + cur_data = self.default_train_data.copy() + cur_data.update(data) + + cur_data["extra_hyperparameters"]["fine_tuning_method"] = normal_config["finetuning_type"].strip().lower() + cur_data["extra_hyperparameters"]["fine_tuning_parameters"]["max_sequence_length"] = normal_config["cutoff_len"] + cur_data["hyperparameters"]["learning_rate_multiplier"] = normal_config["learning_rate"] + cur_data["hyperparameters"]["batch_size"] = normal_config["batch_size"] + cur_data["hyperparameters"]["n_epochs"] = normal_config["num_epochs"] + cur_data["suffix"] = normal_config["finetune_model_name"] + return cur_data + def _create_finetuning_job(self, train_model, train_file_id, **kw) -> Tuple[str, str]: url = os.path.join(self._base_url, "fine_tuning/jobs") headers = { @@ -84,17 +116,96 @@ def _create_finetuning_job(self, train_model, train_file_id, **kw) -> Tuple[str, "training_file": train_file_id } if len(kw) > 0: - data.update(kw) + if 'finetuning_type' in kw: + data = self._update_kw(data, kw) + else: + data.update(kw) with requests.post(url, headers=headers, json=data) as r: if r.status_code != 200: raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)])) fine_tuning_job_id = r.json()["id"] - status = r.json()["status"] + self.fine_tuning_job_id = fine_tuning_job_id + status = self._status_mapping(r.json()["status"]) return (fine_tuning_job_id, status) - def _query_finetuning_job(self, fine_tuning_job_id) -> Tuple[str, str]: + def _cancel_finetuning_job(self, fine_tuning_job_id=None): + if not fine_tuning_job_id and not self.fine_tuning_job_id: + return 'Invalid' + job_id = fine_tuning_job_id if fine_tuning_job_id else self.fine_tuning_job_id + fine_tune_url = os.path.join(self._base_url, f"fine_tuning/jobs/{job_id}/cancel") + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self._api_key}", + } + with requests.post(fine_tune_url, headers=headers) as r: + if r.status_code != 200: + raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)])) + status = r.json()['status'] + if status == 'cancelled': + return 'Cancelled' + else: + return f'JOB {job_id} status: {status}' + + def _query_finetuned_jobs(self): + fine_tune_url = os.path.join(self._base_url, "fine_tuning/jobs/") + headers = { + "Authorization": f"Bearer 
{self._api_key}" + } + with requests.get(fine_tune_url, headers=headers) as r: + if r.status_code != 200: + raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)])) + return r.json() + + def _get_finetuned_model_names(self) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]: + model_data = self._query_finetuned_jobs() + res = list() + for model in model_data['data']: + res.append([model['id'], model['fine_tuned_model'], self._status_mapping(model['status'])]) + return res + + def _status_mapping(self, status): + if status == 'succeeded': + return 'Done' + elif status == 'failed': + return 'Failed' + elif status == 'cancelled': + return 'Cancelled' + elif status == 'running': + return 'Running' + else: # create, validating_files, queued + return 'Pending' + + def _query_job_status(self, fine_tuning_job_id=None): + if not fine_tuning_job_id and not self.fine_tuning_job_id: + raise RuntimeError("No job ID specified. Please ensure that a valid 'fine_tuning_job_id' is " + "provided as an argument or started a training job.") + job_id = fine_tuning_job_id if fine_tuning_job_id else self.fine_tuning_job_id + _, status = self._query_finetuning_job(job_id) + return self._status_mapping(status) + + def _get_log(self, fine_tuning_job_id=None): + if not fine_tuning_job_id and not self.fine_tuning_job_id: + raise RuntimeError("No job ID specified. Please ensure that a valid 'fine_tuning_job_id' is " + "provided as an argument or started a training job.") + job_id = fine_tuning_job_id if fine_tuning_job_id else self.fine_tuning_job_id + fine_tune_url = os.path.join(self._base_url, f"fine_tuning/jobs/{job_id}/events") + headers = { + "Authorization": f"Bearer {self._api_key}" + } + with requests.get(fine_tune_url, headers=headers) as r: + if r.status_code != 200: + raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)])) + return job_id, r.json() + + def _get_curr_job_model_id(self): + if not self.fine_tuning_job_id: + return None, None + model_id, _ = self._query_finetuning_job(self.fine_tuning_job_id) + return self.fine_tuning_job_id, model_id + + def _query_finetuning_job_info(self, fine_tuning_job_id): fine_tune_url = os.path.join(self._base_url, f"fine_tuning/jobs/{fine_tuning_job_id}") headers = { "Authorization": f"Bearer {self._api_key}" @@ -102,12 +213,20 @@ def _query_finetuning_job(self, fine_tuning_job_id) -> Tuple[str, str]: with requests.get(fine_tune_url, headers=headers) as r: if r.status_code != 200: raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)])) + return r.json() - status = r.json()['status'] - fine_tuned_model = None - if status.lower() == "succeeded": - fine_tuned_model = r.json()["fine_tuned_model"] - return (fine_tuned_model, status) + def _query_finetuning_job(self, fine_tuning_job_id) -> Tuple[str, str]: + info = self._query_finetuning_job_info(fine_tuning_job_id) + status = info['status'] + fine_tuned_model = info["fine_tuned_model"] if 'fine_tuned_model' in info else None + return (fine_tuned_model, status) + + def _query_finetuning_cost(self, fine_tuning_job_id): + info = self._query_finetuning_job_info(fine_tuning_job_id) + if 'trained_tokens' in info and info['trained_tokens']: + return info['trained_tokens'] + else: + return None def _create_deployment(self) -> Tuple[str]: return (self._model_name, "RUNNING") diff --git a/lazyllm/module/onlineChatModule/onlineChatModuleBase.py b/lazyllm/module/onlineChatModule/onlineChatModuleBase.py index 
98b639cb..69cef08a 100644
--- a/lazyllm/module/onlineChatModule/onlineChatModuleBase.py
+++ b/lazyllm/module/onlineChatModule/onlineChatModuleBase.py
@@ -6,12 +6,15 @@
 import re
 from typing import Tuple, List, Dict, Union, Any
 import time
+
 import lazyllm
 from lazyllm import globals, FileSystemQueue
 from lazyllm.components.prompter import PrompterBase, ChatPrompter
 from lazyllm.components.formatter import FormatterBase, EmptyFormatter
+from lazyllm.components.utils.file_operate import delete_old_files
 from ..module import ModuleBase, Pipeline
 
+
 class OnlineChatModuleBase(ModuleBase):
 
     def __init__(self,
@@ -318,10 +321,33 @@ def _create_finetuning_job(self, train_model, train_file_id, **kw) -> Tuple[str,
     def _query_finetuning_job(self, fine_tuning_job_id) -> Tuple[str, str]:
         raise NotImplementedError(f"{self._model_series} not implemented _query_finetuning_job method in subclass")
 
+    def _query_finetuned_jobs(self) -> dict:
+        raise NotImplementedError(f"{self._model_series} not implemented _query_finetuned_jobs method in subclass")
+
+    def _get_finetuned_model_names(self) -> Tuple[List[str], List[str]]:
+        raise NotImplementedError(f"{self._model_series} not implemented _get_finetuned_model_names method in subclass")
+
     def set_train_tasks(self, train_file, **kw):
         self._train_file = train_file
         self._train_parameters = kw
 
+    def set_specific_finetuned_model(self, model_id):
+        valid_jobs, _ = self._get_finetuned_model_names()
+        valid_model_id = [model for _, model in valid_jobs]
+        if model_id in valid_model_id:
+            self._model_name = model_id
+            self._is_trained = True
+        else:
+            raise ValueError(f"Cannot find model({model_id}) in the finetuned model list: {valid_model_id}")
+
+    def _get_temp_save_dir_path(self):
+        save_dir = os.path.join(lazyllm.config['temp_dir'], 'online_model_sft_log')
+        if not os.path.exists(save_dir):
+            os.system(f'mkdir -p {save_dir}')
+        else:
+            delete_old_files(save_dir)
+        return save_dir
+
     def _get_train_tasks(self):
         if not self._model_name or not self._train_file:
             raise ValueError("train_model and train_file is required")
diff --git a/lazyllm/module/onlineChatModule/openaiModule.py b/lazyllm/module/onlineChatModule/openaiModule.py
index 70d499b5..45c94357 100644
--- a/lazyllm/module/onlineChatModule/openaiModule.py
+++ b/lazyllm/module/onlineChatModule/openaiModule.py
@@ -1,7 +1,7 @@
 import json
 import os
 import requests
-from typing import Tuple
+from typing import Tuple, List
 import lazyllm
 from .onlineChatModuleBase import OnlineChatModuleBase
 from .fileHandler import FileHandlerBase
@@ -28,6 +28,17 @@ def __init__(self,
                          return_trace=return_trace,
                          **kwargs)
         FileHandlerBase.__init__(self)
+        self.default_train_data = {
+            "model": "gpt-3.5-turbo-0613",
+            "training_file": None,
+            "validation_file": None,
+            "hyperparameters": {
+                "n_epochs": 1,
+                "batch_size": 16,
+                "learning_rate_multiplier": "1.6e-5",
+            }
+        }
+        self.fine_tuning_job_id = None
 
     def _get_system_prompt(self):
         return "You are ChatGPT, a large language model trained by OpenAI.You are a helpful assistant."
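
For reference, a hedged sketch of how these subclass hooks are exercised through the unified engine API; the API key and model names are placeholders, and 'qwen' could equally be 'glm':

    from lazyllm.engine.lightengine import LightEngine

    engine = LightEngine()
    config = {
        'finetune_model_name': 'my_sft_model',   # placeholder
        'base_model': 'qwen-turbo',              # placeholder
        'data_path': 'path/to/train.json',       # placeholder
        'finetuning_type': 'lora',
        'num_epochs': 1,
        'learning_rate': 1e-4,
        'batch_size': 16,
        'cutoff_len': 1024,
        'lora_r': 8,
        'lora_alpha': 32,
    }
    job_id, status = engine.online_model_train(config, token='sk-xxxx', source='qwen')
    print(engine.online_model_get_training_status('sk-xxxx', job_id, 'qwen'))
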
@@ -71,6 +82,17 @@ def _upload_train_file(self, train_file):
             self._dataHandler.close()
         return r.json()["id"]
 
+    def _update_kw(self, data, normal_config):
+        current_train_data = self.default_train_data.copy()
+        current_train_data.update(data)
+
+        current_train_data["hyperparameters"]["n_epochs"] = normal_config["num_epochs"]
+        current_train_data["hyperparameters"]["learning_rate_multiplier"] = str(normal_config["learning_rate"])
+        current_train_data["hyperparameters"]["batch_size"] = normal_config["batch_size"]
+        current_train_data["suffix"] = normal_config["finetune_model_name"]
+
+        return current_train_data
+
     def _create_finetuning_job(self, train_model, train_file_id, **kw) -> Tuple[str, str]:
         url = os.path.join(self._base_url, "fine_tuning/jobs")
         headers = {
@@ -82,17 +104,96 @@ def _create_finetuning_job(self, train_model, train_file_id, **kw) -> Tuple[str,
             "training_file": train_file_id
         }
         if len(kw) > 0:
-            data.update(kw)
+            if 'finetuning_type' in kw:
+                data = self._update_kw(data, kw)
+            else:
+                data.update(kw)
 
         with requests.post(url, headers=headers, json=data) as r:
             if r.status_code != 200:
                 raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)]))
 
             fine_tuning_job_id = r.json()["id"]
+            self.fine_tuning_job_id = fine_tuning_job_id
             status = r.json()["status"]
             return (fine_tuning_job_id, status)
 
-    def _query_finetuning_job(self, fine_tuning_job_id) -> Tuple[str, str]:
+    def _cancel_finetuning_job(self, fine_tuning_job_id=None):
+        if not fine_tuning_job_id and not self.fine_tuning_job_id:
+            return 'Invalid'
+        job_id = fine_tuning_job_id if fine_tuning_job_id else self.fine_tuning_job_id
+        fine_tune_url = os.path.join(self._base_url, f"fine_tuning/jobs/{job_id}/cancel")
+        headers = {
+            "Authorization": f"Bearer {self._api_key}"
+        }
+        with requests.post(fine_tune_url, headers=headers) as r:
+            if r.status_code != 200:
+                raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)]))
+            status = r.json()['status']
+            if status == 'cancelled':
+                return 'Cancelled'
+            else:
+                return f'JOB {job_id} status: {status}'
+
+    def _query_finetuned_jobs(self):
+        fine_tune_url = os.path.join(self._base_url, "fine_tuning/jobs")
+        headers = {
+            "Authorization": f"Bearer {self._api_key}",
+        }
+        with requests.get(fine_tune_url, headers=headers) as r:
+            if r.status_code != 200:
+                raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)]))
+            return r.json()
+
+    def _get_finetuned_model_names(self) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
+        model_data = self._query_finetuned_jobs()
+        res = list()
+        for model in model_data['data']:
+            status = 'Done' if 'successful' in model['message'] else 'Failed'
+            res.append([model['id'], model['fine_tuned_model'], status])
+        return res
+
+    def _status_mapping(self, status):
+        if status == 'succeeded':
+            return 'Done'
+        elif status == 'failed':
+            return 'Failed'
+        elif status == 'cancelled':
+            return 'Cancelled'
+        elif status == 'running':
+            return 'Running'
+        else:  # validating_files, queued
+            return 'Pending'
+
+    def _query_job_status(self, fine_tuning_job_id=None):
+        if not fine_tuning_job_id and not self.fine_tuning_job_id:
+            raise RuntimeError("No job ID specified.
Please ensure that a valid 'fine_tuning_job_id' is " + "provided as an argument or started a training job.") + job_id = fine_tuning_job_id if fine_tuning_job_id else self.fine_tuning_job_id + _, status = self._query_finetuning_job(job_id) + return self._status_mapping(status) + + def _get_log(self, fine_tuning_job_id=None): + if not fine_tuning_job_id and not self.fine_tuning_job_id: + raise RuntimeError("No job ID specified. Please ensure that a valid 'fine_tuning_job_id' is " + "provided as an argument or started a training job.") + job_id = fine_tuning_job_id if fine_tuning_job_id else self.fine_tuning_job_id + fine_tune_url = os.path.join(self._base_url, f"fine_tuning/jobs/{job_id}/events") + headers = { + "Authorization": f"Bearer {self._api_key}" + } + with requests.get(fine_tune_url, headers=headers) as r: + if r.status_code != 200: + raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)])) + return job_id, r.json() + + def _get_curr_job_model_id(self): + if not self.fine_tuning_job_id: + return None, None + model_id, _ = self._query_finetuning_job(self.fine_tuning_job_id) + return self.fine_tuning_job_id, model_id + + def _query_finetuning_job_info(self, fine_tuning_job_id): fine_tune_url = os.path.join(self._base_url, f"fine_tuning/jobs/{fine_tuning_job_id}") headers = { "Authorization": f"Bearer {self._api_key}" @@ -100,12 +201,20 @@ def _query_finetuning_job(self, fine_tuning_job_id) -> Tuple[str, str]: with requests.get(fine_tune_url, headers=headers) as r: if r.status_code != 200: raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)])) + return r.json() - status = r.json()['status'] - fine_tuned_model = None - if status.lower() == "succeeded": - fine_tuned_model = r.json()["fine_tuned_model"] - return (fine_tuned_model, status) + def _query_finetuning_job(self, fine_tuning_job_id) -> Tuple[str, str]: + info = self._query_finetuning_job_info(fine_tuning_job_id) + status = info['status'] + fine_tuned_model = info["fine_tuned_model"] if 'fine_tuned_model' in info else None + return (fine_tuned_model, status) + + def _query_finetuning_cost(self, fine_tuning_job_id): + info = self._query_finetuning_job_info(fine_tuning_job_id) + if 'trained_tokens' in info and info['trained_tokens']: + return info['trained_tokens'] + else: + return None def _create_deployment(self) -> Tuple[str, str]: return (self._model_name, "RUNNING") diff --git a/lazyllm/module/onlineChatModule/qwenModule.py b/lazyllm/module/onlineChatModule/qwenModule.py index bba605c8..03cb6a16 100644 --- a/lazyllm/module/onlineChatModule/qwenModule.py +++ b/lazyllm/module/onlineChatModule/qwenModule.py @@ -1,7 +1,7 @@ import json import os import requests -from typing import Tuple +from typing import Tuple, List import lazyllm from .onlineChatModuleBase import OnlineChatModuleBase from .fileHandler import FileHandlerBase @@ -34,6 +34,26 @@ def __init__(self, self._deploy_paramters = dict() if stream: self._model_optional_params['incremental_output'] = True + self.default_train_data = { + "model": "qwen-turbo", + "training_file_ids": None, + "validation_file_ids": None, + "training_type": "efficient_sft", # sft or efficient_sft + "hyper_parameters": { + "n_epochs": 1, + "batch_size": 16, + "learning_rate": "1.6e-5", + "split": 0.9, + "warmup_ratio": 0.0, + "eval_steps": 1, + "lr_scheduler_type": "linear", + "max_length": 2048, + "lora_rank": 8, + "lora_alpha": 32, + "lora_dropout": 0.1, + } + } + self.fine_tuning_job_id = None def 
_get_system_prompt(self): return ("You are a large-scale language model from Alibaba Cloud, " @@ -42,6 +62,9 @@ def _get_system_prompt(self): def _set_chat_url(self): self._url = os.path.join(self._base_url, 'compatible-mode/v1/chat/completions') + # def _set_chat_sft_url(self): + # self._url = os.path.join(self._base_url, ) + def _convert_file_format(self, filepath: str) -> None: with open(filepath, 'r', encoding='utf-8') as fr: dataset = [json.loads(line) for line in fr] @@ -88,6 +111,20 @@ def _upload_train_file(self, train_file): self._dataHandler.close() return r.json()['data']['uploaded_files'][0]["file_id"] + def _update_kw(self, data, normal_config): + current_train_data = self.default_train_data.copy() + current_train_data.update(data) + + current_train_data["hyper_parameters"]["n_epochs"] = normal_config["num_epochs"] + current_train_data["hyper_parameters"]["learning_rate"] = str(normal_config["learning_rate"]) + current_train_data["hyper_parameters"]["lr_scheduler_type"] = normal_config["lr_scheduler_type"] + current_train_data["hyper_parameters"]["batch_size"] = normal_config["batch_size"] + current_train_data["hyper_parameters"]["max_length"] = normal_config["cutoff_len"] + current_train_data["hyper_parameters"]["lora_rank"] = normal_config["lora_r"] + current_train_data["hyper_parameters"]["lora_alpha"] = normal_config["lora_alpha"] + + return current_train_data + def _create_finetuning_job(self, train_model, train_file_id, **kw) -> Tuple[str, str]: url = os.path.join(self._base_url, "api/v1/fine-tunes") headers = { @@ -100,16 +137,103 @@ def _create_finetuning_job(self, train_model, train_file_id, **kw) -> Tuple[str, } if "training_parameters" in kw.keys(): data.update(kw["training_parameters"]) + elif 'finetuning_type' in kw: + data = self._update_kw(data, kw) with requests.post(url, headers=headers, json=data) as r: if r.status_code != 200: raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)])) fine_tuning_job_id = r.json()["output"]["job_id"] + self.fine_tuning_job_id = fine_tuning_job_id status = r.json()["output"]["status"] return (fine_tuning_job_id, status) - def _query_finetuning_job(self, fine_tuning_job_id) -> Tuple[str, str]: + def _cancel_finetuning_job(self, fine_tuning_job_id=None): + if not fine_tuning_job_id and not self.fine_tuning_job_id: + return 'Invalid' + job_id = fine_tuning_job_id if fine_tuning_job_id else self.fine_tuning_job_id + fine_tune_url = os.path.join(self._base_url, f"api/v1/fine-tunes/{job_id}/cancel") + headers = { + "Authorization": f"Bearer {self._api_key}", + "Content-Type": "application/json" + } + with requests.post(fine_tune_url, headers=headers) as r: + if r.status_code != 200: + raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)])) + status = r.json()['output']['status'] + if status == 'success': + return 'Cancelled' + else: + return f'JOB {job_id} status: {status}' + + def _query_finetuned_jobs(self): + fine_tune_url = os.path.join(self._base_url, "api/v1/fine-tunes") + headers = { + "Authorization": f"Bearer {self._api_key}", + "Content-Type": "application/json" + } + with requests.get(fine_tune_url, headers=headers) as r: + if r.status_code != 200: + raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)])) + return r.json() + + def _get_finetuned_model_names(self) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]: + model_data = self._query_finetuned_jobs() + res = list() + if 'jobs' not in 
model_data['output']:
+            return res
+        for model in model_data['output']['jobs']:
+            status = self._status_mapping(model['status'])
+            if status == 'Done':
+                model_id = model['finetuned_output']
+            else:
+                model_id = model['model'] + '-' + model['job_id']
+            res.append([model['job_id'], model_id, status])
+        return res
+
+    def _status_mapping(self, status):
+        if status == 'SUCCEEDED':
+            return 'Done'
+        elif status == 'FAILED':
+            return 'Failed'
+        elif status in ('CANCELING', 'CANCELED'):
+            return 'Cancelled'
+        elif status == 'RUNNING':
+            return 'Running'
+        else:  # PENDING, QUEUING
+            return 'Pending'
+
+    def _query_job_status(self, fine_tuning_job_id=None):
+        if not fine_tuning_job_id and not self.fine_tuning_job_id:
+            raise RuntimeError("No job ID specified. Please ensure that a valid 'fine_tuning_job_id' is "
+                               "provided as an argument or started a training job.")
+        job_id = fine_tuning_job_id if fine_tuning_job_id else self.fine_tuning_job_id
+        _, status = self._query_finetuning_job(job_id)
+        return self._status_mapping(status)
+
+    def _get_log(self, fine_tuning_job_id=None):
+        if not fine_tuning_job_id and not self.fine_tuning_job_id:
+            raise RuntimeError("No job ID specified. Please ensure that a valid 'fine_tuning_job_id' is "
+                               "provided as an argument or started a training job.")
+        job_id = fine_tuning_job_id if fine_tuning_job_id else self.fine_tuning_job_id
+        fine_tune_url = os.path.join(self._base_url, f"api/v1/fine-tunes/{job_id}/logs")
+        headers = {
+            "Authorization": f"Bearer {self._api_key}",
+            "Content-Type": "application/json"
+        }
+        with requests.get(fine_tune_url, headers=headers) as r:
+            if r.status_code != 200:
+                raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)]))
+            return job_id, r.json()
+
+    def _get_curr_job_model_id(self):
+        if not self.fine_tuning_job_id:
+            return None, None
+        model_id, _ = self._query_finetuning_job(self.fine_tuning_job_id)
+        return self.fine_tuning_job_id, model_id
+
+    def _query_finetuning_job_info(self, fine_tuning_job_id):
         fine_tune_url = os.path.join(self._base_url, f"api/v1/fine-tunes/{fine_tuning_job_id}")
         headers = {
             "Authorization": f"Bearer {self._api_key}",
@@ -118,12 +242,24 @@ def _query_finetuning_job(self, fine_tuning_job_id) -> Tuple[str, str]:
         with requests.get(fine_tune_url, headers=headers) as r:
             if r.status_code != 200:
                 raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)]))
+            return r.json()['output']
 
-            status = r.json()["output"]['status']
-            fine_tuned_model = None
-            if status.lower() == "succeeded":
-                fine_tuned_model = r.json()["output"]["finetuned_output"]
-            return (fine_tuned_model, status)
+    def _query_finetuning_job(self, fine_tuning_job_id) -> Tuple[str, str]:
+        info = self._query_finetuning_job_info(fine_tuning_job_id)
+        status = info['status']
+        # For Qwen, only jobs with status 'SUCCEEDED' carry `finetuned_output`
+        if 'finetuned_output' in info:
+            fine_tuned_model = info["finetuned_output"]
+        else:
+            fine_tuned_model = info["model"] + '-' + info["job_id"]
+        return (fine_tuned_model, status)
+
+    def _query_finetuning_cost(self, fine_tuning_job_id):
+        info = self._query_finetuning_job_info(fine_tuning_job_id)
+        if 'usage' in info and info['usage']:
+            return info['usage']
+        else:
+            return None
 
     def set_deploy_parameters(self, **kw):
         self._deploy_paramters = kw
diff --git a/lazyllm/module/utils.py b/lazyllm/module/utils.py
new file mode 100644
index 00000000..cb72155c
--- /dev/null
+++ b/lazyllm/module/utils.py
@@ -0,0 +1,175 @@
+import os
+import json
+from datetime import datetime
+from dataclasses import dataclass, asdict
+
+import lazyllm
+from lazyllm.thirdparty import datasets
+from ..components.utils.file_operate import delete_old_files
+
+
+@dataclass
+class TrainConfig:
+    finetune_model_name: str = 'llm'
+    base_model: str = 'llm'
+    training_type: str = 'SFT'
+    finetuning_type: str = 'LoRA'
+    data_path: str = 'path/to/dataset'
+    val_size: float = 0.1
+    num_epochs: int = 1
+    learning_rate: float = 1e-4
+    lr_scheduler_type: str = 'cosine'
+    batch_size: int = 32
+    cutoff_len: int = 1024
+    lora_r: int = 8
+    lora_alpha: int = 32
+    lora_rate: float = 0.1
+
+def update_config(input_dict: dict, default_data: type) -> dict:
+    config = TrainConfig()
+    config_dict = asdict(config)
+    assert all([key in config_dict for key in input_dict.keys()]), \
+        f"The {input_dict.keys()} must be a subset of {config_dict.keys()}."
+    config_dict.update(input_dict)
+    return config_dict
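
A quick sketch of how the configuration helpers above are meant to be used; the override values are illustrative, and passing `TrainConfig` mirrors how the training service invokes `update_config`:

    from lazyllm.module.utils import TrainConfig, update_config

    # Lay user overrides on top of the TrainConfig defaults; unknown keys are rejected.
    config = update_config({'num_epochs': 3, 'learning_rate': 5e-5}, TrainConfig)
    print(config['batch_size'])  # -> 32, the TrainConfig default
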
+from dataclasses import dataclass, asdict
+
+import lazyllm
+from lazyllm.thirdparty import datasets
+from ..components.utils.file_operate import delete_old_files
+
+
+@dataclass
+class TrainConfig:
+    finetune_model_name: str = 'llm'
+    base_model: str = 'llm'
+    training_type: str = 'SFT'
+    finetuning_type: str = 'LoRA'
+    data_path: str = 'path/to/dataset'
+    val_size: float = 0.1
+    num_epochs: int = 1
+    learning_rate: float = 1e-4
+    lr_scheduler_type: str = 'cosine'
+    batch_size: int = 32
+    cutoff_len: int = 1024
+    lora_r: int = 8
+    lora_alpha: int = 32
+    lora_rate: float = 0.1
+
+def update_config(input_dict: dict, default_data: type) -> dict:
+    config = default_data()
+    config_dict = asdict(config)
+    assert all([key in config_dict for key in input_dict.keys()]), \
+        f"The keys {input_dict.keys()} must be a subset of {config_dict.keys()}."
+    config_dict.update(input_dict)
+    return config_dict
+
+INPUT_SPLIT = " ### input "
+
+def uniform_sft_dataset(dataset_path: str, target: str = 'alpaca') -> str:
+    '''
+    {origin_format}.{suffix} -> {target_format}, covering all 8 supported cases:
+    1. openai.json  -> alpaca: Conversion: openai2alpaca: json
+    2. openai.jsonl -> alpaca: Conversion: openai2alpaca: json
+    3. alpaca.json  -> alpaca: Keep: json
+    4. alpaca.jsonl -> alpaca: Restore: jsonl -> json
+    5. openai.json  -> openai: Restore: json -> jsonl
+    6. openai.jsonl -> openai: Keep: jsonl
+    7. alpaca.json  -> openai: Conversion: alpaca2openai: jsonl
+    8. alpaca.jsonl -> openai: Conversion: alpaca2openai: jsonl
+    Note: each target format maps to one suffix: {'openai': 'jsonl', 'alpaca': 'json'}
+    '''
+    assert os.path.exists(dataset_path), f"Path: {dataset_path} does not exist!"
+
+    data = datasets.load_dataset('json', data_files=dataset_path)
+    file_name = os.path.basename(dataset_path)
+    base_name, suffix = file_name.rsplit('.', 1)
+    assert suffix in ['json', 'jsonl']
+    target = target.strip().lower()
+    save_suffix = 'json'
+
+    # Get the format ('alpaca' or 'openai') of the original dataset
+    origin_format = 'alpaca'
+    if "messages" in data["train"][0]:
+        origin_format = 'openai'
+
+    # Verify that the dataset format is consistent with the target format
+    if origin_format == target:
+        if target == 'alpaca':
+            if suffix == 'json':
+                return dataset_path
+            else:
+                save_data = alpaca_filter_null(data)
+        else:
+            if suffix == 'jsonl':
+                return dataset_path
+            else:
+                save_suffix = 'jsonl'
+                save_data = data['train'].to_list()
+    else:
+        # The formats are inconsistent, so a conversion is required
+        if target == 'alpaca':
+            save_data = openai2alpaca(data)
+        elif target == 'openai':
+            save_data = alpaca2openai(data)
+            save_suffix = 'jsonl'
+        else:
+            raise ValueError(f"Not supported type: {target}")
+
+    return save_dataset(save_data, save_suffix, base_name + f'_{suffix}')
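+
+# Illustrative usage (hypothetical path; the function returns the input path when no
+# conversion is needed, otherwise the path of a freshly written temp file):
+#     path = uniform_sft_dataset('/data/chat_openai.jsonl', target='alpaca')
+#     # -> e.g. '<temp_dir>/dataset/chat_openai_jsonl_<timestamp>.json'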
+
+def save_json(data: list, output_json_path: str) -> None:
+    with open(output_json_path, 'w', encoding='utf-8') as json_file:
+        json.dump(data, json_file, ensure_ascii=False, indent=4)
+
+def save_jsonl(data: list, output_json_path: str) -> None:
+    with open(output_json_path, mode='w', encoding='utf-8') as json_file:
+        for row in data:
+            json_file.write(json.dumps(row, ensure_ascii=False) + '\n')
+
+def save_dataset(save_data: list, save_suffix='json', base_name='train_data') -> str:
+    directory = os.path.join(lazyllm.config['temp_dir'], 'dataset')
+    if not os.path.exists(directory):
+        os.makedirs(directory)
+    delete_old_files(directory)
+    time_stamp = datetime.now().strftime('%y%m%d%H%M%S%f')[:14]
+    output_json_path = os.path.join(directory, f'{base_name}_{time_stamp}.{save_suffix}')
+    if save_suffix == 'json':
+        save_json(save_data, output_json_path)
+    else:
+        save_jsonl(save_data, output_json_path)
+    return output_json_path
+
+def alpaca_filter_null(data) -> list:
+    res = []
+    for item in data["train"]:
+        alpaca_item = dict()
+        for key in item.keys():
+            if item[key]:
+                alpaca_item[key] = item[key]
+        res.append(alpaca_item)
+    return res
+
+def alpaca2openai(data) -> list:
+    res = []
+    for item in data["train"]:
+        openai_item = {"messages": []}
+        inp = item.get("input", "")
+        system = item.get("system", "")  # may be None or missing
+        historys = item.get("history", [])
+        if system:
+            openai_item["messages"].append({"role": "system", "content": system})
+        # History turns come first so the conversation stays in chronological order.
+        if historys:
+            for history in historys:
+                openai_item["messages"].append({"role": "user", "content": history[0]})
+                openai_item["messages"].append({"role": "assistant", "content": history[1]})
+        openai_item["messages"].extend([
+            {"role": "user", "content": item["instruction"] + (INPUT_SPLIT + inp if inp else "")},
+            {"role": "assistant", "content": item["output"]}
+        ])
+        res.append(openai_item)
+
+    return res
+
+def openai2alpaca(data) -> list:
+    res = []
+    for line in data["train"]:
+        chat = line["messages"]
+        system = ''
+        instructions = []
+        outputs = []
+        for item in chat:
+            if item["role"] == "system" and not system:
+                system = item["content"]
+            if item["role"] == "user":
+                instructions.append(item["content"])
+            if item["role"] == "assistant":
+                outputs.append(item["content"])
+        assert len(instructions) == len(outputs)
+        # The last user/assistant pair is the current turn; earlier pairs form the history.
+        history = [[x, y] for x, y in zip(instructions[:-1], outputs[:-1])]
+        instruction_input = instructions[-1].split(INPUT_SPLIT)
+        instruction = instruction_input[0]
+        inp = ''
+        if len(instruction_input) >= 2:
+            inp = instruction_input[-1]
+        output = outputs[-1]
+        alpaca_item = dict()
+        if system:
+            alpaca_item["system"] = system
+        alpaca_item["instruction"] = instruction
+        if inp or system:  # 'or system' works around a llama-factory quirk: a system field requires an input field
+            alpaca_item["input"] = inp
+        alpaca_item["output"] = output
+        if history:
+            alpaca_item["history"] = history
+        res.append(alpaca_item)
+    return res
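+
+# Round-trip sketch (illustrative, assuming well-formed records): converting
+# alpaca -> openai -> alpaca should preserve system/instruction/input/output/history.
+#     openai_rows = alpaca2openai(data)                  # data from datasets.load_dataset('json', ...)
+#     alpaca_rows = openai2alpaca({'train': openai_rows})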
diff --git a/lazyllm/tools/train_service/__init__.py b/lazyllm/tools/train_service/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/lazyllm/tools/train_service/client.py b/lazyllm/tools/train_service/client.py
new file mode 100644
index 00000000..a8d98d1e
--- /dev/null
+++ b/lazyllm/tools/train_service/client.py
@@ -0,0 +1,406 @@
+import os
+import json
+import requests
+from urllib.parse import urljoin
+
+import lazyllm
+from lazyllm.launcher import Status
+from lazyllm.module.utils import update_config, TrainConfig, uniform_sft_dataset
+
+
+class LocalTrainClient:
+
+    def __init__(self, url):
+        self.url = url
+
+    def uniform_status(self, status):
+        if status == 'Invalid':
+            res = 'Invalid'
+        elif Status[status] == Status.Done:
+            res = 'Done'
+        elif Status[status] == Status.Cancelled:
+            res = 'Cancelled'
+        elif Status[status] == Status.Failed:
+            res = 'Failed'
+        elif Status[status] == Status.Running:
+            res = 'Running'
+        else:  # TBSubmitted, InQueue, Pending
+            res = 'Pending'
+        return res
+
+    def train(self, train_config, token):
+        """
+        Start a new training job on the LazyLLM training service.
+
+        This method sends a request to the LazyLLM API to launch a training job with the specified configuration.
+
+        Parameters:
+        - train_config (dict): A dictionary containing the training configuration details.
+        - token (str): The user group token required for authentication.
+
+        Returns:
+        - tuple: A tuple containing the job ID and the current status of the training job if the request is successful.
+        - tuple: A tuple containing `None` and an error message if the request fails.
+
+        Raises:
+        - Exception: If an error occurs during the request, it will be logged.
+
+        The training configuration dictionary should include the following keys:
+        - finetune_model_name: The name of the model to be fine-tuned.
+        - base_model: The base model to use for training.
+        - data_path: The path to the training data.
+        - training_type: The type of training (e.g., 'sft').
+        - finetuning_type: The type of finetuning (e.g., 'lora').
+        - val_size: The ratio of the validation set to the training set.
+        - num_epochs: The number of training epochs.
+        - learning_rate: The learning rate for training.
+        - lr_scheduler_type: The type of learning rate scheduler.
+        - batch_size: The batch size for training.
+        - cutoff_len: The maximum sequence length for training.
+        - lora_r: The LoRA rank.
+        - lora_alpha: The LoRA alpha parameter.
+        """
+        url = urljoin(self.url, 'v1/fine_tuning/jobs')
+        headers = {
+            "Content-Type": "application/json",
+            "token": token,
+        }
+        train_config = update_config(train_config, TrainConfig)
+        data = {
+            'finetune_model_name': train_config['finetune_model_name'],
+            'base_model': train_config['base_model'],
+            'data_path': train_config['data_path'],
+            'hyperparameters': {
+                'stage': train_config['training_type'].strip().lower(),
+                'finetuning_type': train_config['finetuning_type'].strip().lower(),
+                'val_size': train_config['val_size'],
+                'num_train_epochs': train_config['num_epochs'],
+                'learning_rate': train_config['learning_rate'],
+                'lr_scheduler_type': train_config['lr_scheduler_type'],
+                'per_device_train_batch_size': train_config['batch_size'],
+                'cutoff_len': train_config['cutoff_len'],
+                'lora_r': train_config['lora_r'],
+                'lora_alpha': train_config['lora_alpha'],
+            }
+        }
+
+        try:
+            response = requests.post(url, headers=headers, json=data)
+            response.raise_for_status()
+            res = response.json()
+            return (res['job_id'], self.uniform_status(res['status']))
+        except Exception as e:
+            lazyllm.LOG.error(str(e))
+            return (None, str(e))
+
+    def cancel_training(self, token, job_id):
+        """
+        Cancel a training job on the LazyLLM training service.
+
+        This method sends a request to the LazyLLM API to cancel a specific training job.
+
+        Parameters:
+        - token (str): The user group token required for authentication.
+        - job_id (str): The unique identifier of the training job to be cancelled.
+
+        Returns:
+        - bool: True if the job was successfully cancelled, otherwise an error message is returned.
+
+        Raises:
+        - Exception: If an error occurs during the request, it will be logged and an error message will be returned.
+        """
+        url = urljoin(self.url, f'v1/fine_tuning/jobs/{job_id}/cancel')
+        headers = {
+            "token": token,
+        }
+        try:
+            response = requests.post(url, headers=headers)
+            response.raise_for_status()
+            status = response.json()['status']
+            if status == 'Cancelled':
+                return True
+            else:
+                return f"Failed to cancel task. Final status is {status}"
+        except Exception as e:
+            lazyllm.LOG.error(str(e))
+            return f"Failed to cancel task. Because: {str(e)}"
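+
+    # Minimal usage sketch (assumes a running training service; the URL and port are
+    # illustrative, not part of this module; 'default' is the built-in user group):
+    #     client = LocalTrainClient('http://localhost:20059/')
+    #     job_id, status = client.train({'finetune_model_name': 'my_model',
+    #                                    'base_model': 'qwen1.5-0.5b-chat',
+    #                                    'data_path': 'alpaca/alpaca_data_zh_128.json'},
+    #                                   token='default')
+    #     client.cancel_training('default', job_id)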
+
+    def get_training_cost(self, token, job_id):
+        """
+        Retrieve the GPU usage time for a training job on the LazyLLM training service.
+
+        This method sends a request to the LazyLLM API to fetch the GPU usage time (in seconds)
+        for a specific training job.
+
+        Parameters:
+        - token (str): The user group token required for authentication.
+        - job_id (str): The unique identifier of the training job for which to retrieve the GPU usage time.
+
+        Returns:
+        - int: The GPU usage time in seconds if the request is successful.
+        - str: An error message if the request fails.
+
+        Raises:
+        - Exception: If an error occurs during the request, it will be logged and an error message will be returned.
+        """
+        url = urljoin(self.url, f'v1/fine_tuning/jobs/{job_id}')
+        headers = {"token": token}
+        try:
+            response = requests.get(url, headers=headers)
+            response.raise_for_status()
+            return response.json()['cost']
+        except Exception as e:
+            error = f"Failed to get cost. Because: {str(e)}"
+            lazyllm.LOG.error(error)
+            return error
+
+    def get_training_status(self, token, job_id):
+        """
+        Retrieve the current status of a training job on the LazyLLM training service.
+
+        This method sends a request to the LazyLLM API to fetch the current status of a specific training job.
+
+        Parameters:
+        - token (str): The user group token required for authentication.
+        - job_id (str): The unique identifier of the training job for which to retrieve the status.
+
+        Returns:
+        - str: The current status of the training job if the request is successful.
+        - 'Invalid' (str): If the request fails or an error occurs.
+
+        Raises:
+        - Exception: If an error occurs during the request, it will be logged.
+        """
+        url = urljoin(self.url, f'v1/fine_tuning/jobs/{job_id}')
+        headers = {"token": token}
+        try:
+            response = requests.get(url, headers=headers)
+            response.raise_for_status()
+            status = self.uniform_status(response.json()['status'])
+        except Exception as e:
+            status = 'Invalid'
+            lazyllm.LOG.error(str(e))
+        return status
+
+    def get_training_log(self, token, job_id):
+        """
+        Retrieve the log for the current training job on the LazyLLM training service.
+
+        This method sends a request to the LazyLLM API to fetch the log associated with a specific training job.
+
+        Parameters:
+        - token (str): The user group token required for authentication.
+        - job_id (str): The unique identifier of the training job for which to retrieve the log.
+
+        Returns:
+        - str: The log content if the request is successful.
+        - None: If the request fails or an error occurs.
+
+        Raises:
+        - Exception: If an error occurs during the request, it will be logged.
+        """
+        url = urljoin(self.url, f'v1/fine_tuning/jobs/{job_id}/events')
+        headers = {"token": token}
+        try:
+            response = requests.get(url, headers=headers)
+            response.raise_for_status()
+            return response.json()['log']
+        except Exception as e:
+            lazyllm.LOG.error(f"Failed to get log. Because: {str(e)}")
+            return None
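+
+    # Polling sketch (illustrative): block until the job leaves the queue.
+    #     import time
+    #     while client.get_training_status(token, job_id) == 'Pending':
+    #         time.sleep(5)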
+ """ + url = urljoin(self.url, 'v1/fine_tuning/jobs') + headers = {"token": token} + try: + response = requests.get(url, headers=headers) + response.raise_for_status() + model_data = response.json() + res = list() + for job_id, job in model_data.items(): + res.append([job_id, job['fine_tuned_model'], job['status']]) + return res + except Exception as e: + lazyllm.LOG.error(f"Failed to get log. Because: {e}") + return None + +class OnlineTrainClient: + + def __init__(self): + pass + + def train(self, train_config, token, source): + """ + Initiates an online training task with the specified parameters and configurations. + + Args: + - train_config (dict): Configuration parameters for the training task. + - token (str): API-Key provided by the supplier, used for authentication. + - source (str): Specifies the supplier. Supported suppliers are 'glm' and 'qwen'. + + Returns: + - tuple: A tuple containing the Job-ID and its status if the training starts successfully. + If an error occurs, the Job-ID will be None, and the error message will be included. + + Raises: + - Exception: For any other errors that occur during the process, which will be logged and returned. + """ + try: + train_config = update_config(train_config, TrainConfig) + assert train_config['training_type'].lower() == 'sft', 'Only supported sft!' + + data_path = os.path.join(lazyllm.config['data_path'], train_config['data_path']) + data_path = uniform_sft_dataset(data_path, target='openai') + m = lazyllm.OnlineChatModule(model=train_config['base_model'], api_key=token, source=source) + + file_id = m._upload_train_file(train_file=data_path) + fine_tuning_job_id, status = m._create_finetuning_job(m._model_name, file_id, **train_config) + + return (fine_tuning_job_id, status) + except Exception as e: + lazyllm.LOG.error(str(e)) + return (None, str(e)) + + def get_all_trained_models(self, token, source): + """ + Lists all model jobs with their corresponding job-id, model-id, and statuse for online training services. + + Args: + - token (str): API-Key provided by the supplier, used for authentication. + - source (str): Specifies the supplier. Supported suppliers are 'glm' and 'qwen'. + + Returns: + - list of lists: Each sublist contains [job_id, model_name, status] for each trained model. + - None: If the request fails or an error occurs. + + Raises: + - Exception: If an error occurs during the request, it will be logged. + """ + try: + m = lazyllm.OnlineChatModule(source=source, api_key=token) + return m._get_finetuned_model_names() + except Exception as e: + lazyllm.LOG.error(str(e)) + return None + + def get_training_status(self, token, job_id, source): + """ + Retrieves the current status of a training task by its Job-ID. + + Args: + - token (str): API-Key provided by the supplier, used for authentication. + - job_id (str): The unique identifier of the training job to query. + - source (str): Specifies the supplier. Supported suppliers are 'glm' and 'qwen'. + + Returns: + - str: A string representing the current status of the training task. This could be one of: + 'Pending', 'Running', 'Done', 'Cancelled', 'Failed', or 'Invalid' if the query could not be processed. + + Raises: + - Exception: For any other errors that occur during the status query process, + which will be logged and returned as 'Invalid'. 
+ """ + try: + m = lazyllm.OnlineChatModule(source=source, api_key=token) + status = m._query_job_status(job_id) + except Exception as e: + status = 'Invalid' + lazyllm.LOG.error(e) + return status + + def cancel_training(self, token, job_id, source): + """ + Cancels an ongoing online training task by its Job-ID. + + Args: + - token (str): API-Key provided by the supplier, used for authentication. + - job_id (str): The unique identifier of the training job to be cancelled. + - source (str): Specifies the supplier. Supported suppliers are 'glm' and 'qwen'. + + Returns: + - bool or str: Returns True if the training task was successfully cancelled. If the cancellation fails, + it returns a string with the reason for the failure, including any final information about the task. + + Raises: + - Exception: For any other errors that occur during the cancellation process, + which will be logged and returned as a string. + """ + try: + m = lazyllm.OnlineChatModule(source=source, api_key=token) + res = m._cancel_finetuning_job(job_id) + if res == 'Cancelled': + return True + else: + return f"Failed to cancel task. Final info is {res}" + except Exception as e: + lazyllm.LOG.error(str(e)) + return f"Failed to cancel task. Because: {str(e)}" + + def get_training_log(self, token, job_id, source, target_path=None): + """ + Retrieves the training log for a specific training task by its Job-ID and saves it to a file. + + Args: + - token (str): API-Key provided by the supplier, used for authentication. + - job_id (str): The unique identifier of the training job for which to retrieve the log. + - source (str): Specifies the supplier. Supported suppliers are 'glm' and 'qwen'. + - target_path (str, optional): The path where the log file should be saved. If not provided, + the log will be saved to a temporary directory. + + Returns: + - str or None: The path to the saved log file if the log retrieval and saving was successful. + If an error occurs, None is returned. + + Raises: + - Exception: For any other errors that occur during the log retrieval and saving process, which will be logged. + """ + try: + m = lazyllm.OnlineChatModule(source=source, api_key=token) + file_name, log = m._get_log(job_id) + save_path = target_path if target_path else os.path.join(m._get_temp_save_dir_path(), f'{file_name}.log') + with open(save_path, 'w', encoding='utf-8') as log_file: + json.dump(log, log_file, indent=4, ensure_ascii=False) + return save_path + except Exception as e: + lazyllm.LOG.error(f"Failed to get log. Because: {e}") + return None + + def get_training_cost(self, token, job_id, source): + """ + Retrieves the number of tokens consumed by an online traning task. + + Args: + - token (str): API-Key provided by the supplier, used for authentication. + - job_id (str): The unique identifier of the traning job for which to retrieve the token consumption. + - source (str): Specifies the supplier. Supported suppliers are 'glm' and 'qwen'. + + Returns: + - int or str: The number of tokens consumed by the traning task if the query is successful. + If an error occurs, a string containing the error message is returned. + + Raises: + - Exception: For any other errors that occur during the token consumption query process, which will be logged. + """ + try: + m = lazyllm.OnlineChatModule(source=source, api_key=token) + res = m._query_finetuning_cost(job_id) + return res + except Exception as e: + error = f"Failed to get cost. 
diff --git a/lazyllm/tools/train_service/serve.py b/lazyllm/tools/train_service/serve.py
new file mode 100644
index 00000000..de0af21f
--- /dev/null
+++ b/lazyllm/tools/train_service/serve.py
@@ -0,0 +1,393 @@
+import os
+import time
+import uuid
+import copy
+import string
+import random
+import asyncio
+import threading
+from datetime import datetime
+from pydantic import BaseModel, Field
+from fastapi import HTTPException, Header
+from async_timeout import timeout
+
+import lazyllm
+from lazyllm.launcher import Status
+from lazyllm.module.utils import uniform_sft_dataset
+from lazyllm import FastapiApp as app
+
+
+class JobDescription(BaseModel):
+    finetune_model_name: str
+    base_model: str = Field(default="qwen1.5-0.5b-chat")
+    data_path: str = Field(default="alpaca/alpaca_data_zh_128.json")
+    hyperparameters: dict = Field(
+        default={
+            "stage": "sft",
+            "finetuning_type": "lora",
+            "val_size": 1,
+            "num_train_epochs": 1,
+            "learning_rate": 0.0001,
+            "lr_scheduler_type": "cosine",
+            "per_device_train_batch_size": 16,
+            "cutoff_len": 1024,
+            "lora_r": 8,
+            "lora_alpha": 32,
+        }
+    )
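+
+# Example request body for POST /v1/fine_tuning/jobs (values are illustrative;
+# omitted hyperparameters fall back to the defaults above):
+#     {
+#         "finetune_model_name": "my_model",
+#         "base_model": "qwen1.5-0.5b-chat",
+#         "data_path": "alpaca/alpaca_data_zh_128.json",
+#         "hyperparameters": {"finetuning_type": "lora", "num_train_epochs": 1}
+#     }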
+
+class TrainServer:
+
+    def __init__(self):
+        self._user_job_training_info = {'default': dict()}
+        self._active_job_trainings = dict()
+        self._info_lock = threading.Lock()
+        self._active_lock = threading.Lock()
+        self._time_format = '%y%m%d%H%M%S%f'
+        self._polling_thread = None
+
+    def __call__(self):
+        if not self._polling_thread:
+            self._polling_status_checker()
+
+    def __reduce__(self):
+        return (self.__class__, ())
+
+    def _update_dict(self, lock, dicts, k1, k2=None, dict_value=None):
+        with lock:
+            if k1 not in dicts:
+                dicts[k1] = {}
+            if k2 is None:
+                return
+            if k2 not in dicts[k1]:
+                dicts[k1][k2] = {}
+            if dict_value is None:
+                return
+            if isinstance(dict_value, tuple):  # for self._active_job_trainings
+                dicts[k1][k2] = dict_value
+            elif isinstance(dict_value, dict):  # for self._user_job_training_info
+                dicts[k1][k2].update(dict_value)
+            else:
+                raise RuntimeError('dict_value only supports dict and tuple')
+
+    def _read_dict(self, lock, dicts, k1=None, k2=None, vk=None, deepcopy=True):
+        with lock:
+            if k1 and k2 and vk:
+                return copy.deepcopy(dicts[k1][k2][vk]) if deepcopy else dicts[k1][k2][vk]
+            elif k1 and k2:
+                return copy.deepcopy(dicts[k1][k2]) if deepcopy else dicts[k1][k2]
+            elif k1:
+                return copy.deepcopy(dicts[k1]) if deepcopy else dicts[k1]
+            else:
+                raise RuntimeError('At least k1 must be specified.')
+
+    def _in_dict(self, lock, dicts, k1, k2=None, vk=None):
+        with lock:
+            if k1 not in dicts:
+                return False
+
+            if k2 is not None:
+                if k2 not in dicts[k1]:
+                    return False
+            else:
+                return True
+
+            if vk is not None:
+                if vk not in dicts[k1][k2]:
+                    return False
+            return True
+
+    def _pop_dict(self, lock, dicts, k1, k2=None, vk=None):
+        with lock:
+            if k1 and k2 and vk:
+                return dicts[k1][k2].pop(vk)
+            elif k1 and k2:
+                return dicts[k1].pop(k2)
+            elif k1:
+                return dicts.pop(k1)
+            else:
+                raise RuntimeError('At least k1 must be specified.')
+
+    def _update_user_job_training_info(self, token, job_id=None, dict_value=None):
+        self._update_dict(self._info_lock, self._user_job_training_info, token, job_id, dict_value)
+
+    def _update_active_job_trainings(self, token, job_id=None, dict_value=None):
+        self._update_dict(self._active_lock, self._active_job_trainings, token, job_id, dict_value)
+
+    def _read_user_job_training_info(self, token, job_id=None, key=None):
+        return self._read_dict(self._info_lock, self._user_job_training_info, token, job_id, key)
+
+    def _read_active_job_trainings(self, token, job_id=None):
+        return self._read_dict(self._active_lock, self._active_job_trainings, token, job_id, deepcopy=False)
+
+    def _in_user_job_training_info(self, token, job_id=None, key=None):
+        return self._in_dict(self._info_lock, self._user_job_training_info, token, job_id, key)
+
+    def _in_active_job_trainings(self, token, job_id=None):
+        return self._in_dict(self._active_lock, self._active_job_trainings, token, job_id)
+
+    def _pop_user_job_training_info(self, token, job_id=None, key=None):
+        return self._pop_dict(self._info_lock, self._user_job_training_info, token, job_id, key)
+
+    def _pop_active_job_trainings(self, token, job_id=None):
+        return self._pop_dict(self._active_lock, self._active_job_trainings, token, job_id)
+
+    def _update_status(self, token, job_id):
+        if not self._in_active_job_trainings(token, job_id):
+            return
+        # Get basic info
+        info = self._read_user_job_training_info(token, job_id)
+        save_path = info['fine_tuned_model']
+        log_path = info['log_path']
+
+        # Get status
+        m, _ = self._read_active_job_trainings(token, job_id)
+        status = m.status(info['model_id']).name
+
+        update = {'status': status}
+
+        # Some tasks are not yet running right after they are created.
+        if Status[status] == Status.Running and not info['started_at']:
+            update = {
+                'status': status,
+                'started_at': datetime.now().strftime(self._time_format),
+            }
+
+        # The storage path may not be available right after a task starts.
+        if not save_path:
+            update['fine_tuned_model'] = self._get_save_path(m)
+        if not log_path:
+            update['log_path'] = self._get_log_path(m)
+
+        # Update status
+        self._update_user_job_training_info(token, job_id, update)
+
+        # Pop and kill jobs with status: Done, Failed
+        if Status[status] in (Status.Done, Status.Failed):
+            m, _ = self._pop_active_job_trainings(token, job_id)
+            m.stop(info['model_id'])
+            if info['started_at'] and not info['cost']:
+                cost = (datetime.now() - datetime.strptime(info['started_at'], self._time_format)).total_seconds()
+                self._update_user_job_training_info(token, job_id, {'cost': cost})
+            return
+
+        create_time = datetime.strptime(info['created_at'], self._time_format)
+        delta_time = (datetime.now() - create_time).total_seconds()
+
+        # After more than 5 minutes, pop and kill jobs with status Cancelled. The grace period
+        # exists because just-started tasks may not report a reliable status yet.
+        if delta_time > 300 and Status[status] == Status.Cancelled:
+            m, _ = self._pop_active_job_trainings(token, job_id)
+            m.stop(info['model_id'])
+            if info['started_at'] and not info['cost']:
+                cost = (datetime.now() - datetime.strptime(info['started_at'], self._time_format)).total_seconds()
+                self._update_user_job_training_info(token, job_id, {'cost': cost})
+            return
+
+        # After more than 50 minutes, pop and kill jobs with status: TBSubmitted, InQueue, Pending
+        if delta_time > 3000 and Status[status] in (Status.TBSubmitted, Status.InQueue, Status.Pending):
+            m, _ = self._pop_active_job_trainings(token, job_id)
+            m.stop(info['model_id'])
+            return
+
+    def _get_save_path(self, model):
+        if not hasattr(model._impl, '_finetuned_model_path'):
+            return None
+        return model._impl._finetuned_model_path
+
+    def _get_log_path(self, model):
+        log_dir = self._get_save_path(model)
+        if not log_dir:
+            return None
+
+        parts = log_dir.split(os.sep)
+        if parts[-1].endswith('lazyllm_merge'):
+            parts[-1] = parts[-1].replace('lazyllm_merge', 'lazyllm_lora')
+        log_dir = os.sep.join(parts)
+
+        log_files_paths = []
+        for file in os.listdir(log_dir):
+            if file.endswith(".log") and file.startswith("train_log_"):
+                log_files_paths.append(os.path.join(log_dir, file))
+        if len(log_files_paths) == 0:
+            return None
+        assert len(log_files_paths) == 1
+        return log_files_paths[-1]
+
+    def _polling_status_checker(self, frequent=5):
+        def polling():
+            while True:
+                # Thread-safe access to two-level keys
+                with self._active_lock:
+                    loop_items = [(token, job_id) for token in self._active_job_trainings.keys()
+                                  for job_id in self._active_job_trainings[token]]
+                # Update the status of all jobs in sequence
+                for token, job_id in loop_items:
+                    self._update_status(token, job_id)
+                time.sleep(frequent)
+
+        self._polling_thread = threading.Thread(target=polling)
+        self._polling_thread.daemon = True
+        self._polling_thread.start()
+
+    async def authorize_current_user(self, Bearer: str = None):
+        if not self._in_user_job_training_info(Bearer):
+            raise HTTPException(
+                status_code=401,
+                detail="Invalid token",
+            )
+        return Bearer
+
+    @app.post("/v1/fine_tuning/jobs")
+    async def create_job(self, job: JobDescription, token: str = Header(None)):
+        # await self.authorize_current_user(token)
+        if not self._in_user_job_training_info(token):
+            self._update_user_job_training_info(token)
+        # Build the Job-ID:
+        create_time = datetime.now().strftime(self._time_format)
+        job_id = '-'.join(['ft', create_time, str(uuid.uuid4())[:5]])
+
+        # Build the Model-ID:
+        characters = string.ascii_letters + string.digits
+        random_string = ''.join(random.choices(characters, k=7))
+        model_id = job.finetune_model_name + '_' + random_string
+
+        # Build the checkpoint save dir:
+        # - env not set: (work/path + save_ckpt) + token + job_id;
+        # - env set: (train_target_root) + token + job_id;
+        save_root = os.path.join(lazyllm.config['train_target_root'], token, job_id)
+
+        # Add the launcher to the hyperparameters:
+        hyperparams = job.hyperparameters
+        hyperparams['launcher'] = lazyllm.launcher.RemoteLauncher(sync=False, ngpus=1)
+
+        # Unify the format of the training dataset:
+        job.data_path = os.path.join(lazyllm.config['data_path'], job.data_path)
+        job.data_path = uniform_sft_dataset(job.data_path, target='alpaca')
+
+        # Set params for the TrainableModule:
+        m = lazyllm.TrainableModule(job.base_model, save_root)\
+            .trainset(job.data_path)\
+            .finetune_method(lazyllm.finetune.llamafactory)
+
+        # Launch training:
+        thread = threading.Thread(target=m._impl._async_finetune, args=(model_id,), kwargs=hyperparams)
+        thread.start()
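+
+        # Fine-tuning runs in a background thread, so this request can return as soon
+        # as the job is registered; the polling checker keeps the status up to date.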
+        # Give the launch command up to 5 seconds to start; the status stays
+        # Cancelled until the training command has actually been submitted.
+        try:
+            async with timeout(5):
+                while m.status(model_id) == Status.Cancelled:
+                    await asyncio.sleep(1)
+        except asyncio.TimeoutError:
+            pass
+
+        # The paths may be invalid right after launch; they are refreshed on every status update.
+        save_path = self._get_save_path(m)
+        log_path = self._get_log_path(m)
+
+        # Save status
+        status = m.status(model_id).name
+        if Status[status] == Status.Running:
+            started_time = datetime.now().strftime(self._time_format)
+        else:
+            started_time = None
+        self._update_active_job_trainings(token, job_id, (m, thread))
+        self._update_user_job_training_info(token, job_id, {
+            "model_id": model_id,
+            "job_id": job_id,
+            "base_model": job.base_model,
+            "created_at": create_time,
+            "fine_tuned_model": save_path,
+            "status": status,
+            "data_path": job.data_path,
+            "hyperparameters": hyperparams,
+            "log_path": log_path,
+            "started_at": started_time,
+            "cost": None,
+        })
+
+        return {"job_id": job_id, 'status': status}
+
+    @app.post("/v1/fine_tuning/jobs/{job_id}/cancel")
+    async def cancel_job(self, job_id: str, token: str = Header(None)):
+        await self.authorize_current_user(token)
+        if not self._in_active_job_trainings(token, job_id):
+            raise HTTPException(status_code=404, detail="Job not found")
+
+        m, _ = self._pop_active_job_trainings(token, job_id)
+        info = self._read_user_job_training_info(token, job_id)
+        m.stop(info['model_id'])
+
+        total_sleep = 0
+        while m.status(info['model_id']) != Status.Cancelled:
+            time.sleep(1)
+            total_sleep += 1
+            if total_sleep > 10:
+                raise HTTPException(status_code=404, detail=f"Task {job_id} cancellation timed out.")
+
+        status = m.status(info['model_id']).name
+        update_dict = {'status': status}
+        if info['started_at'] and not info['cost']:
+            update_dict['cost'] = (datetime.now() - datetime.strptime(info['started_at'],
+                                                                      self._time_format)).total_seconds()
+        self._update_user_job_training_info(token, job_id, update_dict)
+
+        return {"status": status}
+
+    @app.get("/v1/fine_tuning/jobs")
+    async def list_jobs(self, token: str = Header(None)):
+        # await self.authorize_current_user(token)
+        if not self._in_user_job_training_info(token):
+            self._update_user_job_training_info(token)
+        save_root = os.path.join(lazyllm.config['train_target_root'], token)
+        server_running_dict = self._read_user_job_training_info(token)
+        m = lazyllm.TrainableModule('dummy', save_root)
+        valid_models, invalid_models = m.get_all_models()
+        for model_id, model_path in valid_models:
+            job_id = model_path[len(save_root):].lstrip(os.sep).split(os.sep)[0]
+            if job_id in server_running_dict and server_running_dict[job_id]['status'] != 'Done':
+                server_running_dict[job_id]['status'] = 'Done'
+                server_running_dict[job_id]['fine_tuned_model'] = model_path
+            elif job_id not in server_running_dict:
+                server_running_dict[job_id] = {
+                    'status': 'Done',
+                    'model_id': model_id,
+                    'fine_tuned_model': model_path,
+                }
+        for model_id, model_path in invalid_models:
+            job_id = model_path[len(save_root):].lstrip(os.sep).split(os.sep)[0]
+            if job_id in server_running_dict and server_running_dict[job_id]['status'] == 'Done':
+                server_running_dict[job_id]['status'] = 'Failed'
+                server_running_dict[job_id]['fine_tuned_model'] = model_path
+            elif job_id not in server_running_dict:
+                server_running_dict[job_id] = {
+                    'status': 'Failed',
+                    'model_id': model_id,
+                    'fine_tuned_model': model_path,
+                }
+        return server_running_dict
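+
+    # GET /v1/fine_tuning/jobs returns a mapping of job_id -> job info, e.g. (illustrative):
+    #     {"ft-241122103000123456-ab1cd": {"status": "Done", "model_id": "my_model_x1y2z3a",
+    #                                      "fine_tuned_model": "/path/to/lazyllm_merge"}}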
+
+    @app.get("/v1/fine_tuning/jobs/{job_id}")
+    async def get_job_info(self, job_id: str, token: str = Header(None)):
+        await self.authorize_current_user(token)
+        if not self._in_user_job_training_info(token, job_id):
+            raise HTTPException(status_code=404, detail="Job not found")
+
+        self._update_status(token, job_id)
+
+        return self._read_user_job_training_info(token, job_id)
+
+    @app.get("/v1/fine_tuning/jobs/{job_id}/events")
+    async def get_job_log(self, job_id: str, token: str = Header(None)):
+        await self.authorize_current_user(token)
+        if not self._in_user_job_training_info(token, job_id):
+            raise HTTPException(status_code=404, detail="Job not found")
+
+        self._update_status(token, job_id)
+        info = self._read_user_job_training_info(token, job_id)
+
+        if info['log_path']:
+            return {"log": info['log_path']}
+        else:
+            return {"log": 'invalid'}
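+
+# Smoke-test sketch with curl (URL, port and token are illustrative; every request
+# authenticates via the `token` header):
+#     curl -X POST http://localhost:20059/v1/fine_tuning/jobs \
+#          -H 'token: default' -H 'Content-Type: application/json' \
+#          -d '{"finetune_model_name": "my_model"}'
+#     curl http://localhost:20059/v1/fine_tuning/jobs -H 'token: default'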
diff --git a/lazyllm/tools/webpages/webmodule.py b/lazyllm/tools/webpages/webmodule.py
index 82fe2aa4..424dbeca 100644
--- a/lazyllm/tools/webpages/webmodule.py
+++ b/lazyllm/tools/webpages/webmodule.py
@@ -75,7 +75,7 @@ def _set_up_caching(self):
         if 'GRADIO_TEMP_DIR' in os.environ:
             cach_path = os.environ['GRADIO_TEMP_DIR']
         else:
-            cach_path = os.path.join(os.getcwd(), '.temp')
+            cach_path = os.path.join(lazyllm.config['temp_dir'], 'gradio_cache')
         os.environ['GRADIO_TEMP_DIR'] = cach_path
         if not os.path.exists(cach_path):
             os.makedirs(cach_path)
diff --git a/pyproject.toml b/pyproject.toml
index 2f2731b5..04f3d665 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,6 +45,7 @@ pypdf = "^5.0.0"
 pytest = "^8.3.3"
 numpy = "==1.26.4"
 pymilvus = "^2.4.8"
+async-timeout = "^5.0.1"
 redis = { version = ">=5.0.4", optional = true }
 huggingface-hub = { version = ">=0.23.1", optional = true }
 pandas = { version = ">=2.2.2", optional = true }
@@ -150,4 +151,4 @@ requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
 
 [tool.poetry.scripts]
-lazyllm = "lazyllm.cli.main:main"
\ No newline at end of file
+lazyllm = "lazyllm.cli.main:main"
diff --git a/requirements.full.txt b/requirements.full.txt
index 732bab12..029cf489 100644
--- a/requirements.full.txt
+++ b/requirements.full.txt
@@ -32,6 +32,7 @@ pypdf
 pytest
 numpy==1.26.4
 pymilvus
+async-timeout
 redis>=5.0.4
 huggingface-hub>=0.23.1
 pandas>=2.2.2
diff --git a/requirements.txt b/requirements.txt
index 4160246d..1422b468 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -32,3 +32,4 @@ pypdf
 pytest
 numpy==1.26.4
 pymilvus
+async-timeout
diff --git a/tests/advanced_tests/standard_test/test_engine.py b/tests/advanced_tests/standard_test/test_engine.py
index ebfed10c..e433459c 100644
--- a/tests/advanced_tests/standard_test/test_engine.py
+++ b/tests/advanced_tests/standard_test/test_engine.py
@@ -1,8 +1,11 @@
-import lazyllm
 import os
+import time
 import pytest
+
+import lazyllm
 from lazyllm.engine import LightEngine
 
+
 class TestEngine(object):
     # This test requires 4 GPUs and takes about 4 minutes to execute, skip this test to save time.
     def _test_vqa(self):
@@ -135,3 +138,57 @@ def test_stream_and_hostory(self):
         assert '一天' in stream_result and '小时' in stream_result
         assert '您好,我的答案是' in stream_result and '24' in stream_result
         assert '蓝鲸' in result and '水' in result
+
+    def test_engine_train_serve(self):
+        train_config = {
+            'finetune_model_name': 'my_super_model',
+            'base_model': 'qwen1.5-0.5b-chat',
+            'training_type': 'SFT',
+            'finetuning_type': 'LoRA',
+            'data_path': 'alpaca/alpaca_data_zh_128.json',
+            'val_size': 0.1,
+            'num_epochs': 1,
+            'learning_rate': 0.1,
+            'lr_scheduler_type': 'cosine',
+            'batch_size': 32,
+            'cutoff_len': 1024,
+            'lora_r': 8,
+            'lora_alpha': 32,
+            'lora_rate': 0.1,
+        }
+        engine = LightEngine()
+        engine.launch_localllm_train_service()
+
+        token = 'test'
+        job_id = None
+
+        # Launch training
+        res = engine.local_model_train(train_config, token=token)
+        job_id = res[0]
+        assert len(job_id) > 0
+        status = res[1]
+
+        n = 0
+        while status != 'Running':
+            time.sleep(1)
+            status = engine.local_model_get_training_status(token, job_id)
+            n += 1
+            assert n < 300, 'Launch training timeout.'
+
+        # After launch, train for 20s
+        time.sleep(20)
+
+        res = engine.local_model_cancel_training(token, job_id)
+        assert isinstance(res, bool)
+
+        res = engine.local_model_get_training_status(token, job_id)
+        assert res == 'Cancelled'
+
+        res = engine.local_model_get_training_log(token, job_id)
+        assert os.path.exists(res)
+
+        res = engine.local_model_get_all_trained_models(token)
+        assert len(res[0]) == 3
+
+        res = engine.local_model_get_training_cost(token, job_id)
+        assert res > 15
diff --git a/tests/charge_tests/test_engine.py b/tests/charge_tests/test_engine.py
index 26892784..b909e527 100644
--- a/tests/charge_tests/test_engine.py
+++ b/tests/charge_tests/test_engine.py
@@ -256,3 +256,13 @@ def test_stream_and_hostory(self):
         assert '一天' in stream_result and '小时' in stream_result
         assert '您好,我的答案是' in stream_result and '24' in stream_result
         assert '蓝鲸' in result and '水' in result
+
+    def test_engine_online_serve_train(self):
+        envs = ['glm_api_key', 'qwen_api_key']
+        sources = ['glm', 'qwen']
+        engine = LightEngine()
+
+        for env, source in zip(envs, sources):
+            token = lazyllm.config[env]
+            res = engine.online_model_get_all_trained_models(token, source=source)
+            assert isinstance(res, list)