diff --git a/comps/finetuning/README.md b/comps/finetuning/README.md index 3e25d9957..fb46977de 100644 --- a/comps/finetuning/README.md +++ b/comps/finetuning/README.md @@ -99,6 +99,8 @@ For reranking and embedding models finetuning, the training file [toy_finetune_d ## 3.2 Create fine-tuning job +### 3.2.1 Instruction Tuning + After a training file like `alpaca_data.json` is uploaded, use the following command to launch a finetuning job using `meta-llama/Llama-2-7b-chat-hf` as base model: ```bash @@ -112,6 +114,8 @@ curl http://${your_ip}:8015/v1/fine_tuning/jobs \ }' ``` +### 3.2.2 Reranking Model Training + Use the following command to launch a finetuning job for reranking model finetuning, such as `BAAI/bge-reranker-large`: ```bash @@ -129,6 +133,46 @@ curl http://${your_ip}:8015/v1/fine_tuning/jobs \ }' ``` +### 3.2.3 Embedding Model Training + +Use the following command to launch a finetuning job for embedding model finetuning, such as `BAAI/bge-base-en-v1.5`: + +```bash +# create a finetuning job +curl http://${your_ip}:8015/v1/fine_tuning/jobs \ + -X POST \ + -H "Content-Type: application/json" \ + -d '{ + "training_file": "toy_finetune_data.jsonl", + "model": "BAAI/bge-base-en-v1.5", + "General":{ + "task":"embedding", + "lora_config":null + } + }' + + +# If training on Gaudi2, we need to set --padding "max_length" and the value of --query_max_len is same with --passage_max_len for static shape during training. For example: +curl http://${your_ip}:8015/v1/fine_tuning/jobs \ + -X POST \ + -H "Content-Type: application/json" \ + -d '{ + "training_file": "toy_finetune_data.jsonl", + "model": "BAAI/bge-base-en-v1.5", + "General":{ + "task":"embedding", + "lora_config":null + }, + "Dataset":{ + "query_max_len":128, + "passage_max_len":128, + "padding":"max_length" + } + }' + + +``` + ## 3.3 Manage fine-tuning job Below commands show how to list finetuning jobs, retrieve a finetuning job, cancel a finetuning job and list checkpoints of a finetuning job. @@ -149,4 +193,4 @@ curl http://${your_ip}:8015/v1/finetune/list_checkpoints -X POST -H "Content-Typ ## 🚀4. Descriptions for Finetuning parameters -We utilize [OpenAI finetuning parameters](https://platform.openai.com/docs/api-reference/fine-tuning) and extend it with more customizable parameters. +We utilize [OpenAI finetuning parameters](https://platform.openai.com/docs/api-reference/fine-tuning) and extend it with more customizable parameters, see the definitions at [finetune_config](https://github.com/opea-project/GenAIComps/blob/main/comps/finetuning/finetune_config.py). diff --git a/comps/finetuning/finetune_config.py b/comps/finetuning/finetune_config.py index 6271b618d..3accabfb3 100644 --- a/comps/finetuning/finetune_config.py +++ b/comps/finetuning/finetune_config.py @@ -5,7 +5,7 @@ from typing import List, Optional, Union -from pydantic import BaseModel, validator +from pydantic import BaseModel, Field, validator from comps.cores.proto.api_protocol import FineTuningJobsRequest @@ -74,13 +74,29 @@ class DatasetConfig(BaseModel): truncation_side: str = "right" max_seq_length: int = 512 truncation: bool = True - padding: bool = True + padding: Union[bool, str] = True mask_input: bool = True mask_response: bool = True data_preprocess_type: str = "neural_chat" max_train_samples: int = 0 max_eval_samples: int = 0 train_group_size: int = 8 + query_max_len: int = Field( + default=128, + description=( + "The maximum total input sequence length after tokenization for passage. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + ), + ) + passage_max_len: int = Field( + default=128, + description=( + "The maximum total input sequence length after tokenization for passage. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + ), + ) + query_instruction_for_retrieval: Optional[str] = Field(default=None, description="instruction for query") + passage_instruction_for_retrieval: Optional[str] = Field(default=None, description="instruction for passage") class RayResourceConfig(BaseModel): @@ -89,6 +105,14 @@ class RayResourceConfig(BaseModel): HPU: int = 0 +class EmbeddingTrainingConfig(BaseModel): + negatives_cross_device: bool = Field(default=False, description="share negatives across devices") + temperature: Optional[float] = Field(default=0.02) + sentence_pooling_method: str = Field(default="cls", description="the pooling method, should be cls or mean") + normalized: bool = Field(default=True) + use_inbatch_neg: bool = Field(default=True, description="use passages in the same batch as negatives") + + class TrainingConfig(BaseModel): optimizer: str = "adamw_torch" batch_size: int = 2 @@ -106,6 +130,7 @@ class TrainingConfig(BaseModel): gradient_accumulation_steps: int = 1 logging_steps: int = 10 deepspeed_config_file: str = "" + embedding_training_config: Optional[EmbeddingTrainingConfig] = EmbeddingTrainingConfig() @validator("device") def check_device(cls, v: str): diff --git a/comps/finetuning/llm_on_ray/finetune/data_process.py b/comps/finetuning/llm_on_ray/finetune/data_process.py index 38455e878..d85bf2bfa 100644 --- a/comps/finetuning/llm_on_ray/finetune/data_process.py +++ b/comps/finetuning/llm_on_ray/finetune/data_process.py @@ -246,3 +246,74 @@ def __call__(self, features) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.T if isinstance(features[0], list): features = sum(features, []) return super().__call__(features) + + +class TrainDatasetForEmbedding(Dataset): + def __init__(self, dataset, args, tokenizer): + self.dataset = dataset + self.tokenizer = tokenizer + self.args = args + self.total_len = len(self.dataset) + + def __len__(self): + return self.total_len + + def __getitem__(self, item) -> Tuple[str, List[str]]: + query = self.dataset[item]["query"] + if self.args["query_instruction_for_retrieval"] is not None: + query = self.args["query_instruction_for_retrieval"] + query + + passages = [] + + assert isinstance(self.dataset[item]["pos"], list) + pos = random.choice(self.dataset[item]["pos"]) + passages.append(pos) + + train_group_size = self.args.get("train_group_size", 8) + if len(self.dataset[item]["neg"]) < train_group_size - 1: + num = math.ceil((train_group_size - 1) / len(self.dataset[item]["neg"])) + negs = random.sample(self.dataset[item]["neg"] * num, train_group_size - 1) + else: + negs = random.sample(self.dataset[item]["neg"], train_group_size - 1) + passages.extend(negs) + + if self.args["passage_instruction_for_retrieval"] is not None: + passages = [self.args["passage_instruction_for_retrieval"] + p for p in passages] + return query, passages + + +@dataclass +class EmbedCollator(DataCollatorWithPadding): + """Wrapper that does conversion from List[Tuple[encode_qry, encode_psg]] to List[qry], List[psg] + and pass batch separately to the actual collator. + + Abstract out data detail for the model. + """ + + query_max_len: int = 32 + passage_max_len: int = 128 + + def __call__(self, features): + query = [f[0] for f in features] + passage = [f[1] for f in features] + + if isinstance(query[0], list): + query = sum(query, []) + if isinstance(passage[0], list): + passage = sum(passage, []) + + q_collated = self.tokenizer( + query, + padding=self.padding, + truncation=True, + max_length=self.query_max_len, + return_tensors="pt", + ) + d_collated = self.tokenizer( + passage, + padding=self.padding, + truncation=True, + max_length=self.passage_max_len, + return_tensors="pt", + ) + return {"query": q_collated, "passage": d_collated} diff --git a/comps/finetuning/llm_on_ray/finetune/finetune.py b/comps/finetuning/llm_on_ray/finetune/finetune.py index 2476e9638..c66cc7bbe 100644 --- a/comps/finetuning/llm_on_ray/finetune/finetune.py +++ b/comps/finetuning/llm_on_ray/finetune/finetune.py @@ -27,8 +27,14 @@ from comps import CustomLogger from comps.finetuning.finetune_config import FinetuneConfig from comps.finetuning.llm_on_ray import common -from comps.finetuning.llm_on_ray.finetune.data_process import DataProcessor, GroupCollator, TrainDatasetForCE -from comps.finetuning.llm_on_ray.finetune.modeling import CrossEncoder +from comps.finetuning.llm_on_ray.finetune.data_process import ( + DataProcessor, + EmbedCollator, + GroupCollator, + TrainDatasetForCE, + TrainDatasetForEmbedding, +) +from comps.finetuning.llm_on_ray.finetune.modeling import BiEncoderModel, CrossEncoder logger = CustomLogger("llm_on_ray/finetune") @@ -244,7 +250,8 @@ def group_texts(examples): dataset["train"] = TrainDatasetForCE(dataset["train"], config["Dataset"], tokenizer) return dataset elif task == "embedding": - pass + dataset["train"] = TrainDatasetForEmbedding(dataset["train"], config["Dataset"], tokenizer) + return dataset else: raise NotImplementedError(f"Unsupported task {task}, only support instruction_tuning, rerank, embedding now.") @@ -258,7 +265,12 @@ def prepare_data_collator(config: Dict, tokenizer): elif task == "rerank": return GroupCollator(tokenizer) elif task == "embedding": - pass + return EmbedCollator( + tokenizer=tokenizer, + padding=config["Dataset"]["padding"], + query_max_len=config["Dataset"]["query_max_len"], + passage_max_len=config["Dataset"]["passage_max_len"], + ) else: raise NotImplementedError(f"Unsupported task {task}, only support instruction_tuning, rerank, embedding now.") @@ -268,24 +280,36 @@ def load_model(config: Dict): model_dtype = convert_dtype(config["Training"].get("mixed_precision", "no")) model_config = config["General"].get("config", {}) task = config["General"].get("task", "instruction_tuning") - training_args = convert_to_training_args(TrainingArguments, config) if task == "instruction_tuning": model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype, **model_config) - lora_config = config["General"].get("lora_config", None) if lora_config: peft_config = LoraConfig(**lora_config) model = get_peft_model(model, peft_config) elif task == "rerank": model = CrossEncoder.from_pretrained( - config["Dataset"], - training_args, + config["Dataset"].get("train_group_size", 8), + config["Training"]["batch_size"], model_name, from_tf=bool(".ckpt" in model_name), config=model_config, ) elif task == "embedding": - pass + should_concat = False + if ( + config["Dataset"]["query_max_len"] == config["Dataset"]["passage_max_len"] + and config["Dataset"]["padding"] == "max_length" + ): + should_concat = True + if config["Training"]["device"] == "hpu" and not should_concat: + raise ValueError("please set query_max_len==passage_max_len and padding='max_length' for hpu.") + + if config["Training"].get("embedding_training_config", None) is not None: + model = BiEncoderModel( + model_name=model_name, should_concat=should_concat, **config["Training"]["embedding_training_config"] + ) + else: + model = BiEncoderModel(model_name=model_name, should_concat=should_concat) else: raise NotImplementedError(f"Unsupported task {task}, only support instruction_tuning, rerank, embedding now.") diff --git a/comps/finetuning/llm_on_ray/finetune/modeling.py b/comps/finetuning/llm_on_ray/finetune/modeling.py index 0a7e37af4..7a2884f3b 100644 --- a/comps/finetuning/llm_on_ray/finetune/modeling.py +++ b/comps/finetuning/llm_on_ray/finetune/modeling.py @@ -1,24 +1,29 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +from typing import Dict, Optional + import torch +import torch.distributed as dist from torch import nn -from transformers import AutoModelForSequenceClassification, PreTrainedModel, TrainingArguments -from transformers.modeling_outputs import SequenceClassifierOutput +from transformers import AutoModel, AutoModelForSequenceClassification, PreTrainedModel +from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput + +from comps import CustomLogger -from comps.finetuning.finetune_config import DatasetConfig +logger = CustomLogger("llm_on_ray/finetune/modeling") class CrossEncoder(PreTrainedModel): - def __init__(self, hf_model: PreTrainedModel, data_args: DatasetConfig, train_args: TrainingArguments): + def __init__(self, hf_model: PreTrainedModel, train_group_size: int, batch_size: int): super().__init__(hf_model.config) self.hf_model = hf_model - self.train_args = train_args - self.data_args = data_args + self.train_group_size = train_group_size + self.batch_size = batch_size self.cross_entropy = nn.CrossEntropyLoss(reduction="mean") - self.register_buffer("target_label", torch.zeros(self.train_args.per_device_train_batch_size, dtype=torch.long)) + self.register_buffer("target_label", torch.zeros(self.batch_size, dtype=torch.long)) def gradient_checkpointing_enable(self, **kwargs): self.hf_model.gradient_checkpointing_enable(**kwargs) @@ -28,7 +33,7 @@ def forward(self, **batch): logits = ranker_out.logits if self.training: - scores = logits.view(-1, self.data_args.get("train_group_size", 8)) + scores = logits.view(-1, self.train_group_size) loss = self.cross_entropy(scores, self.target_label[: scores.shape[0]]) return SequenceClassifierOutput( @@ -39,9 +44,9 @@ def forward(self, **batch): return ranker_out @classmethod - def from_pretrained(cls, data_args: DatasetConfig, train_args: TrainingArguments, *args, **kwargs): + def from_pretrained(cls, train_group_size: int, batch_size: int, *args, **kwargs): hf_model = AutoModelForSequenceClassification.from_pretrained(*args, **kwargs) - reranker = cls(hf_model, data_args, train_args) + reranker = cls(hf_model, train_group_size, batch_size) return reranker def save_pretrained(self, output_dir: str, **kwargs): @@ -49,3 +54,158 @@ def save_pretrained(self, output_dir: str, **kwargs): state_dict = type(state_dict)({k: v.clone().cpu() for k, v in state_dict.items()}) kwargs.pop("state_dict") self.hf_model.save_pretrained(output_dir, state_dict=state_dict, **kwargs) + + +class BiEncoderModel(nn.Module): + TRANSFORMER_CLS = AutoModel + + def __init__( + self, + model_name: str = None, + should_concat: bool = False, + normalized: bool = False, + sentence_pooling_method: str = "cls", + negatives_cross_device: bool = False, + temperature: float = 1.0, + use_inbatch_neg: bool = True, + ): + super().__init__() + self.model = AutoModel.from_pretrained(model_name, add_pooling_layer=False) + self.cross_entropy = nn.CrossEntropyLoss(reduction="mean") + + self.should_concat = should_concat + self.normalized = normalized + self.sentence_pooling_method = sentence_pooling_method + self.temperature = temperature + self.use_inbatch_neg = use_inbatch_neg + self.config = self.model.config + + if not normalized: + self.temperature = 1.0 + logger.info("reset temperature = 1.0 due to using inner product to compute similarity") + if normalized: + if self.temperature > 0.5: + raise ValueError( + "Temperature should be smaller than 1.0 when use cosine similarity (i.e., normalized=True). Recommend to set it 0.01-0.1" + ) + + self.negatives_cross_device = negatives_cross_device + if self.negatives_cross_device: + if not dist.is_initialized(): + raise ValueError("Distributed training has not been initialized for representation all gather.") + # logger.info("Run in a single GPU, set negatives_cross_device=False") + # self.negatives_cross_device = False + # else: + self.process_rank = dist.get_rank() + self.world_size = dist.get_world_size() + + def gradient_checkpointing_enable(self, **kwargs): + self.model.gradient_checkpointing_enable(**kwargs) + + def sentence_embedding(self, hidden_state, mask): + if self.sentence_pooling_method == "mean": + s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1) + d = mask.sum(axis=1, keepdim=True).float() + return s / d + elif self.sentence_pooling_method == "cls": + return hidden_state[:, 0] + + def encode(self, features): + if features is None: + return None + psg_out = self.model(**features, return_dict=True) + p_reps = self.sentence_embedding(psg_out.last_hidden_state, features["attention_mask"]) + if self.normalized: + p_reps = torch.nn.functional.normalize(p_reps, dim=-1) + return p_reps.contiguous() + + def encode_concat(self, query, passage): + if query is None or passage is None: + return None + + batch_size = query["input_ids"].size()[0] + + psg_out = self.model( + input_ids=torch.cat([query["input_ids"], passage["input_ids"]]), + attention_mask=torch.cat([query["attention_mask"], passage["attention_mask"]]), + return_dict=True, + ) + reps = self.sentence_embedding( + psg_out.last_hidden_state, torch.cat([query["attention_mask"], passage["attention_mask"]]) + ) + if self.normalized: + reps = torch.nn.functional.normalize(reps, dim=-1) + + q_reps = reps[:batch_size] + p_reps = reps[batch_size:] + + return q_reps.contiguous(), p_reps.contiguous() + + def compute_similarity(self, q_reps, p_reps): + if len(p_reps.size()) == 2: + return torch.matmul(q_reps, p_reps.transpose(0, 1)) + return torch.matmul(q_reps, p_reps.transpose(-2, -1)) + + def forward(self, query: Dict[str, torch.Tensor] = None, passage: Dict[str, torch.Tensor] = None): + if self.should_concat: + q_reps, p_reps = self.encode_concat(query, passage) + else: + q_reps = self.encode(query) + p_reps = self.encode(passage) + + if self.training: + if self.negatives_cross_device and self.use_inbatch_neg: + q_reps = self._dist_gather_tensor(q_reps) + p_reps = self._dist_gather_tensor(p_reps) + + group_size = p_reps.size(0) // q_reps.size(0) + if self.use_inbatch_neg: + scores = self.compute_similarity(q_reps, p_reps) / self.temperature # B B*G + scores = scores.view(q_reps.size(0), -1) + + target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long) + target = target * group_size + loss = self.compute_loss(scores, target) + else: + scores = ( + self.compute_similarity( + q_reps[ + :, + None, + :, + ], + p_reps.view(q_reps.size(0), group_size, -1), + ).squeeze(1) + / self.temperature + ) # B G + + scores = scores.view(q_reps.size(0), -1) + target = torch.zeros(scores.size(0), device=scores.device, dtype=torch.long) + loss = self.compute_loss(scores, target) + + else: + scores = self.compute_similarity(q_reps, p_reps) + loss = None + + return MaskedLMOutput(loss=loss, logits=None, hidden_states=None, attentions=None) + + def compute_loss(self, scores, target): + return self.cross_entropy(scores, target) + + def _dist_gather_tensor(self, t: Optional[torch.Tensor]): + if t is None: + return None + t = t.contiguous() + + all_tensors = [torch.empty_like(t) for _ in range(self.world_size)] + dist.all_gather(all_tensors, t) + + all_tensors[self.process_rank] = t + all_tensors = torch.cat(all_tensors, dim=0) + + return all_tensors + + def save(self, output_dir: str): + state_dict = self.model.state_dict() + state_dict = type(state_dict)({k: v.clone().cpu() for k, v in state_dict.items()}) + self.model.save_pretrained(output_dir, state_dict=state_dict) diff --git a/tests/test_finetuning_embedding_hpu.sh b/tests/test_finetuning_embedding_hpu.sh new file mode 100644 index 000000000..e080b0d13 --- /dev/null +++ b/tests/test_finetuning_embedding_hpu.sh @@ -0,0 +1,124 @@ +#!/bin/bash +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +set -x + +WORKPATH=$(dirname "$PWD") +LOG_PATH="$WORKPATH/tests" +ip_address=$(hostname -I | awk '{print $1}') +finetuning_service_port=8015 +ray_port=8265 + +function build_docker_images() { + cd $WORKPATH + echo $(pwd) + docker build -t opea/finetuning-gaudi:latest --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/finetuning/docker/Dockerfile_hpu . + if [ $? -ne 0 ]; then + echo "opea/finetuning-gaudi built fail" + exit 1 + else + echo "opea/finetuning-gaudi built successful" + fi +} + +function start_service() { + export no_proxy="localhost,127.0.0.1,"${ip_address} + docker run -d --name="finetuning-server" --runtime=habana -e HABANA_VISIBLE_DEVICES=all -p $finetuning_service_port:$finetuning_service_port -p $ray_port:$ray_port -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host -e https_proxy=$https_proxy -e http_proxy=$http_proxy -e no_proxy=$no_proxy -e HF_TOKEN=$HF_TOKEN opea/finetuning-gaudi:latest + sleep 1m +} + +function validate_microservice() { + cd $LOG_PATH + export no_proxy="localhost,127.0.0.1,"${ip_address} + + # test /v1/dataprep upload file + URL="http://${ip_address}:$finetuning_service_port/v1/files" + echo '{"query": "A girl with a blue tank top sitting watching three dogs.", "pos": ["A girl is wearing blue."], "neg": ["A girl is with three cats.", "The people are watching a funeral procession.", "The child is wearing black.", "Financing is an issue for us in public schools.", "Kids at a pool.", "It is calming to be assaulted.", "I face a serious problem at eighteen years old. "]}' >> $LOG_PATH/test_embed_data.json + echo '{"query": "A girl with a blue tank top sitting watching three dogs.", "pos": ["A girl is wearing blue."], "neg": ["A girl is with three cats.", "The people are watching a funeral procession.", "The child is wearing black.", "Financing is an issue for us in public schools.", "Kids at a pool.", "It is calming to be assaulted.", "I face a serious problem at eighteen years old. "]}' >> $LOG_PATH/test_embed_data.json + echo '{"query": "A girl with a blue tank top sitting watching three dogs.", "pos": ["A girl is wearing blue."], "neg": ["A girl is with three cats.", "The people are watching a funeral procession.", "The child is wearing black.", "Financing is an issue for us in public schools.", "Kids at a pool.", "It is calming to be assaulted.", "I face a serious problem at eighteen years old. "]}' >> $LOG_PATH/test_embed_data.json + echo '{"query": "A girl with a blue tank top sitting watching three dogs.", "pos": ["A girl is wearing blue."], "neg": ["A girl is with three cats.", "The people are watching a funeral procession.", "The child is wearing black.", "Financing is an issue for us in public schools.", "Kids at a pool.", "It is calming to be assaulted.", "I face a serious problem at eighteen years old. "]}' >> $LOG_PATH/test_embed_data.json + echo '{"query": "A girl with a blue tank top sitting watching three dogs.", "pos": ["A girl is wearing blue."], "neg": ["A girl is with three cats.", "The people are watching a funeral procession.", "The child is wearing black.", "Financing is an issue for us in public schools.", "Kids at a pool.", "It is calming to be assaulted.", "I face a serious problem at eighteen years old. "]}' >> $LOG_PATH/test_embed_data.json + echo '{"query": "A girl with a blue tank top sitting watching three dogs.", "pos": ["A girl is wearing blue."], "neg": ["A girl is with three cats.", "The people are watching a funeral procession.", "The child is wearing black.", "Financing is an issue for us in public schools.", "Kids at a pool.", "It is calming to be assaulted.", "I face a serious problem at eighteen years old. "]}' >> $LOG_PATH/test_embed_data.json + echo '{"query": "A girl with a blue tank top sitting watching three dogs.", "pos": ["A girl is wearing blue."], "neg": ["A girl is with three cats.", "The people are watching a funeral procession.", "The child is wearing black.", "Financing is an issue for us in public schools.", "Kids at a pool.", "It is calming to be assaulted.", "I face a serious problem at eighteen years old. "]}' >> $LOG_PATH/test_embed_data.json + echo '{"query": "A girl with a blue tank top sitting watching three dogs.", "pos": ["A girl is wearing blue."], "neg": ["A girl is with three cats.", "The people are watching a funeral procession.", "The child is wearing black.", "Financing is an issue for us in public schools.", "Kids at a pool.", "It is calming to be assaulted.", "I face a serious problem at eighteen years old. "]}' >> $LOG_PATH/test_embed_data.json + echo '{"query": "A girl with a blue tank top sitting watching three dogs.", "pos": ["A girl is wearing blue."], "neg": ["A girl is with three cats.", "The people are watching a funeral procession.", "The child is wearing black.", "Financing is an issue for us in public schools.", "Kids at a pool.", "It is calming to be assaulted.", "I face a serious problem at eighteen years old. "]}' >> $LOG_PATH/test_embed_data.json + echo '{"query": "A girl with a blue tank top sitting watching three dogs.", "pos": ["A girl is wearing blue."], "neg": ["A girl is with three cats.", "The people are watching a funeral procession.", "The child is wearing black.", "Financing is an issue for us in public schools.", "Kids at a pool.", "It is calming to be assaulted.", "I face a serious problem at eighteen years old. "]}' >> $LOG_PATH/test_embed_data.json + + HTTP_RESPONSE=$(curl --silent --write-out "HTTPSTATUS:%{http_code}" -X POST -F 'file=@./test_embed_data.json' -F purpose="fine-tune" -H 'Content-Type: multipart/form-data' "$URL") + HTTP_STATUS=$(echo $HTTP_RESPONSE | tr -d '\n' | sed -e 's/.*HTTPSTATUS://') + RESPONSE_BODY=$(echo $HTTP_RESPONSE | sed -e 's/HTTPSTATUS\:.*//g') + SERVICE_NAME="finetuning-server - upload - file" + + # Parse the JSON response + purpose=$(echo "$RESPONSE_BODY" | jq -r '.purpose') + filename=$(echo "$RESPONSE_BODY" | jq -r '.filename') + + # Define expected values + expected_purpose="fine-tune" + expected_filename="test_embed_data.json" + + if [ "$HTTP_STATUS" -ne "200" ]; then + echo "[ $SERVICE_NAME ] HTTP status is not 200. Received status was $HTTP_STATUS" + docker logs finetuning-server >> ${LOG_PATH}/finetuning-server_upload_file.log + exit 1 + else + echo "[ $SERVICE_NAME ] HTTP status is 200. Checking content..." + fi + # Check if the parsed values match the expected values + if [[ "$purpose" != "$expected_purpose" || "$filename" != "$expected_filename" ]]; then + echo "[ $SERVICE_NAME ] Content does not match the expected result: $RESPONSE_BODY" + docker logs finetuning-server >> ${LOG_PATH}/finetuning-server_upload_file.log + exit 1 + else + echo "[ $SERVICE_NAME ] Content is as expected." + fi + + # test /v1/fine_tuning/jobs + URL="http://${ip_address}:$finetuning_service_port/v1/fine_tuning/jobs" + HTTP_RESPONSE=$(curl --silent --write-out "HTTPSTATUS:%{http_code}" -X POST -H 'Content-Type: application/json' -d '{"training_file": "test_embed_data.json","model": "BAAI/bge-base-en-v1.5","General":{"task":"embedding","lora_cofig":null,"save_strategy":"epoch"},"Dataset":{"query_max_len":128,"passage_max_len":128,"padding":"max_length"},"Training":{"epochs":3}}' "$URL") + HTTP_STATUS=$(echo $HTTP_RESPONSE | tr -d '\n' | sed -e 's/.*HTTPSTATUS://') + RESPONSE_BODY=$(echo $HTTP_RESPONSE | sed -e 's/HTTPSTATUS\:.*//g') + SERVICE_NAME="finetuning-server - create finetuning job" + + if [ "$HTTP_STATUS" -ne "200" ]; then + echo "[ $SERVICE_NAME ] HTTP status is not 200. Received status was $HTTP_STATUS" + docker logs finetuning-server >> ${LOG_PATH}/finetuning-server_create.log + exit 1 + else + echo "[ $SERVICE_NAME ] HTTP status is 200. Checking content..." + fi + if [[ "$RESPONSE_BODY" != *'{"id":"ft-job'* ]]; then + echo "[ $SERVICE_NAME ] Content does not match the expected result: $RESPONSE_BODY" + docker logs finetuning-server >> ${LOG_PATH}/finetuning-server_create.log + exit 1 + else + echo "[ $SERVICE_NAME ] Content is as expected." + fi + + sleep 10m + + # get logs + docker logs finetuning-server >> ${LOG_PATH}/finetuning-server_create.log +} + +function stop_docker() { + cid=$(docker ps -aq --filter "name=finetuning-server*") + if [[ ! -z "$cid" ]]; then docker stop $cid && docker rm $cid && sleep 1s; fi +} + +function main() { + + stop_docker + + build_docker_images + start_service + + validate_microservice + + stop_docker + echo y | docker system prune + +} + +main