From ac3327062048b9825c5853a81f0b2a5883997bef Mon Sep 17 00:00:00 2001
From: root <261295365@qq.com>
Date: Sun, 5 Jan 2025 15:15:59 +0000
Subject: [PATCH] PPOtask for mlora_train

---
 mlora/config/task.py                |  1 +
 mlora/executor/context/train.py     |  2 +-
 mlora/executor/task/ppo_task.py     | 33 +++++++++++++++++------------
 mlora/prompter/ppo_data_prompter.py |  9 +++-----
 4 files changed, 24 insertions(+), 21 deletions(-)

diff --git a/mlora/config/task.py b/mlora/config/task.py
index dd0e0042..f1b2a7ed 100644
--- a/mlora/config/task.py
+++ b/mlora/config/task.py
@@ -164,6 +164,7 @@ class PPOTaskConfig(TrainTaskConfig):
     critic_adapter_: AdapterConfig
     actor_adapter_: AdapterConfig
     kl_coefficient_: float
+    optim_num_: int
 
     __params_map: Dict[str, str] = {
         "gamma_": "gamma",
diff --git a/mlora/executor/context/train.py b/mlora/executor/context/train.py
index 0c1a5803..3fadcf3f 100644
--- a/mlora/executor/context/train.py
+++ b/mlora/executor/context/train.py
@@ -1,6 +1,6 @@
 from abc import abstractmethod
 from collections import OrderedDict
-from typing import Callable, Dict, List, Type
+from typing import Callable, Dict, List, Type, Any
 
 import torch
 
diff --git a/mlora/executor/task/ppo_task.py b/mlora/executor/task/ppo_task.py
index 8bbfbe67..d5270f06 100644
--- a/mlora/executor/task/ppo_task.py
+++ b/mlora/executor/task/ppo_task.py
@@ -9,6 +9,7 @@
 from torch.distributions import Categorical
 import json
 import numpy as np
+from functools import partial
 
 from mlora.config import PPOTaskConfig
 from mlora.executor.context import TRAINCONTEXT_CLASS, TrainTaskContext, TaskContext, INFERENCECONTEXT_CLASS
@@ -46,8 +47,7 @@ class PPOTask(TrainTask):
     now_optim_iter_num: int
     adv: torch.Tensor
     td_target: torch.Tensor
-    policy_tokens: torch.Tensor
-    reward_tokens: torch.Tensor
+    policy_tokens: list[list[int]]
 
     state_: Stage
     # 0: initial stage 1: decision stage 2: update state 3: iteration state
@@ -78,7 +78,8 @@ def prepare(self, linears_info: OrderedDict[str, LinearInfo],tokenizer: Tokenize
         self._pre_dataset()
         self.ppo_pre_context(linears_info)
 
-        LOSS_CLASS={"mse":self.ppo_mse,"adv_loss":self.ppo_adv_loss,"reward_loss":self.ppo_reward_loss}
+        LOSS_CLASS={"mse":partial(self.ppo_mse),"adv_loss":partial(self.ppo_adv_loss),
+                    "reward_loss":partial(self.ppo_reward_loss)}
         self.critic_context_.set_loss_fn(LOSS_CLASS[self.config_.critic_loss_type_])
         self.actor_context_.set_loss_fn(LOSS_CLASS[self.config_.actor_loss_type_])
         self.reward_context_.set_loss_fn(LOSS_CLASS[self.config_.reward_loss_type_])
@@ -113,7 +114,7 @@ def ppo_pre_context(self, linears_info: OrderedDict[str, LinearInfo]):
         )
         self._pre_ref_context(linears_info)
 
-    def ppo_mse(self, data: torch.Tensor,label: torch.Tensor):
+    def ppo_mse(self, data: torch.Tensor,label: torch.Tensor) -> torch.Tensor:
         return (data - label).pow(2).mean()
 
 
@@ -132,7 +133,7 @@ def ppo_adv_loss(self, prob: torch.Tensor, old_prob: torch.Tensor,
 
         return loss1
 
-    def ppo_reward_loss(self, reward_chosen: torch.Tensor, reward_reject: torch.Tensor)-> torch.tensor:
+    def ppo_reward_loss(self, reward_chosen: torch.Tensor, reward_reject: torch.Tensor)-> torch.Tensor:
         return reward_chosen-reward_reject
 
     @override
@@ -171,7 +172,7 @@ def stage_0(self, start_idx: int):
         data_idx_s = self.now_data_idx_
         data_idx_e = self.now_data_idx_ + self.config_.mini_batch_size_
         # get the train raw string
-        batch_strr = self.prompter_.generate_prompt(self.data_[data_idx_s:data_idx_e],"reward")
+        batch_strr = self.prompter_.generate_prompt(self.data_[data_idx_s:data_idx_e])
         reward_tokens = list(
             map(
                 lambda raw_str: self.tokenizer_.encode(
@@ -181,6 +182,8 @@ def stage_0(self, start_idx: int):
             )
         )
+        l=int(len(reward_tokens)/3)
+        reward_tokens=reward_tokens[l:]
 
         reward_start_idx=start_idx
         reward_end_idx=reward_start_idx+len(reward_tokens)
 
@@ -218,7 +221,7 @@ def stage_1(self, start_idx:int):
         data_idx_e = self.now_data_idx_ + self.config_.mini_batch_size_
 
         # get the train raw string
-        batch_str = self.prompter_.generate_prompt(self.data_[data_idx_s:data_idx_e],"instruction")
+        batch_str = self.prompter_.generate_prompt(self.data_[data_idx_s:data_idx_e])
         actor_tokens = list(
             map(
                 lambda raw_str: self.tokenizer_.encode(
@@ -227,7 +230,9 @@ def stage_1(self, start_idx:int):
                 batch_str,
             )
         )
-
+
+        l=int(len(actor_tokens)/3)
+        actor_tokens=actor_tokens[:l]
         batch_num=int(len(actor_tokens))
         generate_num=self.config_.generate_num_
         BOS=actor_tokens[0][0]
@@ -240,7 +245,7 @@ def stage_1(self, start_idx:int):
 
         self.state_=Stage.policy_training_decision
 
-    def stage_2(self, input: torch.tensor, actor_start_idx: int, actor_end_idx: int,
+    def stage_2(self, input: torch.Tensor, actor_start_idx: int, actor_end_idx: int,
                 deterministic: bool = False):
         critic_len=int(len(self.policy_tokens[-1]))
         if(self.idx==critic_len-1):
@@ -262,9 +267,9 @@ def stage_2(self, input: torch.Tensor, actor_start_idx: int, actor_end_idx: int,
         a=a.view(batch_num,-1)
 
         for i in range(batch_num):
-            self.policy_tokens[i].append(a[i].item())
+            self.policy_tokens[i].append(int(a[i].item()))
         for i in range(batch_num,2*batch_num):
-            self.policy_tokens[i][idx]=a[i-batch_num].item()
+            self.policy_tokens[i][idx]=int(a[i-batch_num].item())
         self.idx+=1
 
     def stage_3(self, start_idx:int):
@@ -279,7 +284,7 @@ def stage_3(self, start_idx:int):
         critic_len=int(len(self.policy_tokens[-1]))# the real critic's len
         reward_tokens = copy.deepcopy(self.policy_tokens[:batch_num])
         ref_tokens = copy.deepcopy(self.policy_tokens[:batch_num])
-        p_tokens=[]
+        p_tokens: List[List[int]] = []
         p_tokens.extend(reward_tokens)
         p_tokens.extend(ref_tokens)
         p_tokens.extend(self.policy_tokens)
@@ -295,7 +300,7 @@ def stage_3(self, start_idx:int):
 
         def loss_fn(
             input: torch.Tensor, _: torch.Tensor, __: torch.Tensor
-        )-> Optional[torch.tensor]:
+        )-> Optional[torch.Tensor]:
 
             if self.state_==Stage.policy_training_decision:
                 self.stage_2(input,actor_start_idx,actor_end_idx)
@@ -313,7 +318,7 @@ def loss_fn(
             log_prob = log_p.gather(-1, action).squeeze(-1)
             ref_log_prob = log_ref_p.gather(-1, action).squeeze(-1)
             r=-(log_prob-ref_log_prob)
-            r[:,-1]+=torch.tanh(input[reward_start_idx:reward_end_idx,-1]@PPOTask.reward_tensor.squeeze(dim=-1))
+            r[:,-1]+=torch.tanh(input[reward_start_idx:reward_end_idx,-1]@PPOTask.reward_tensor).squeeze(dim=-1)
             v=torch.tanh((input[critic_start_idx:critic_end_idx,1:critic_len]@PPOTask.critic_tensor).squeeze(dim=-1))
             v_=v.clone().detach()
 
diff --git a/mlora/prompter/ppo_data_prompter.py b/mlora/prompter/ppo_data_prompter.py
index f6309108..97f6a3fb 100644
--- a/mlora/prompter/ppo_data_prompter.py
+++ b/mlora/prompter/ppo_data_prompter.py
@@ -7,22 +7,20 @@ class PpoDataPrompter(Prompter):
     def __init__(self, template: str):
         super().__init__(template)
 
-    def __generate_prompt(self, data_point: Dict[str, str], optional: str) -> Tuple[str, str]:
+    def __generate_prompt(self, data_point: Dict[str, str], optional: str) -> str:
         data = self.template_.render(data_point=data_point, Optional=optional)
         return data
 
     @override
-    def generate_prompt(self, data_points: List[Dict[str, str]], optional: str) -> List[str]:
+    def generate_prompt(self, data_points: List[Dict[str, str]]) -> List[str]:
         instru_data = []
         chosen_data = []
         reject_data = []
         data=[]
         for data_point in data_points:
-            if optional=="instruction":
                 data_str = self.__generate_prompt(data_point,"instruction")
                 instru_data.append(data_str)
-            else:
                 chosen_str = self.__generate_prompt(data_point,"chosen")
                 reject_str = self.__generate_prompt(data_point,"reject")
                 chosen_data.append(chosen_str)
                 reject_data.append(reject_str)
@@ -31,5 +29,4 @@ def generate_prompt(self, data_points: List[Dict[str, str]], optional: str) -> L
         data.extend(instru_data)
         data.extend(chosen_data)
         data.extend(reject_data)
-        return data
-        
\ No newline at end of file
+        return data
\ No newline at end of file
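
Editor's note (not part of the patch): generate_prompt drops its "optional" argument because it now renders all three prompt variants for every data point and returns them concatenated as instruction prompts, then chosen prompts, then reject prompts. That ordering is why stage_0 keeps the last two thirds of the encoded batch (reward_tokens[l:]) and stage_1 keeps the first third (actor_tokens[:l]). A minimal sketch of that split, assuming prompts is the 3*N-element list returned for an N-item mini-batch; the helper name below is illustrative and does not exist in mLoRA:

    def split_ppo_prompts(prompts: list) -> tuple:
        # prompts = [instruction_0..N-1, chosen_0..N-1, reject_0..N-1]
        n = len(prompts) // 3
        actor_prompts = prompts[:n]    # instruction third, used by stage_1 for rollouts
        reward_prompts = prompts[n:]   # chosen/reject pairs, used by stage_0 for the reward model
        return actor_prompts, reward_prompts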