PPOTask for mlora_train
ck-gyj committed Jan 5, 2025
1 parent 3e4daec commit ac33270
Showing 4 changed files with 24 additions and 21 deletions.
1 change: 1 addition & 0 deletions mlora/config/task.py
@@ -164,6 +164,7 @@ class PPOTaskConfig(TrainTaskConfig):
critic_adapter_: AdapterConfig
actor_adapter_: AdapterConfig
kl_coefficient_: float
+ optim_num_: int

__params_map: Dict[str, str] = {
"gamma_": "gamma",
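Note: the hunk is cut off above, but following the existing attribute-to-config-key convention ("gamma_": "gamma"), the new field presumably also needs a matching __params_map entry so it can be read from the task config. A rough sketch under that assumption (the key name and the meaning in the comment are guesses, not part of this diff):

    # Hypothetical sketch of the attribute -> config-key mapping the new field would need.
    params_map: dict = {
        "gamma_": "gamma",          # existing entry shown above
        "optim_num_": "optim_num",  # assumed key; presumably the number of optimisation passes per PPO update
    }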
2 changes: 1 addition & 1 deletion mlora/executor/context/train.py
@@ -1,6 +1,6 @@
from abc import abstractmethod
from collections import OrderedDict
- from typing import Callable, Dict, List, Type
+ from typing import Callable, Dict, List, Type, Any

import torch

33 changes: 19 additions & 14 deletions mlora/executor/task/ppo_task.py
@@ -9,6 +9,7 @@
from torch.distributions import Categorical
import json
import numpy as np
+ from functools import partial

from mlora.config import PPOTaskConfig
from mlora.executor.context import TRAINCONTEXT_CLASS, TrainTaskContext, TaskContext, INFERENCECONTEXT_CLASS
@@ -46,8 +47,7 @@ class PPOTask(TrainTask):
now_optim_iter_num: int
adv: torch.Tensor
td_target: torch.Tensor
- policy_tokens: torch.Tensor
- reward_tokens: torch.Tensor
+ policy_tokens: list[list[int]]
state_: Stage # 0: initial stage 1: decision stage 2: update state 3: iteration state
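The comment above is the only description of the task's state machine in this diff; the code later switches on Stage.policy_training_decision. A minimal sketch of what the enum presumably looks like (only policy_training_decision appears in this commit, the other member names are illustrative):

    from enum import Enum, auto

    class Stage(Enum):
        # Member names other than policy_training_decision are illustrative guesses.
        initial = auto()                    # 0: initial stage
        policy_training_decision = auto()   # 1: decision stage (checked in loss_fn below)
        update = auto()                     # 2: update stage
        iteration = auto()                  # 3: iteration stage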


@@ -78,7 +78,8 @@ def prepare(self, linears_info: OrderedDict[str, LinearInfo],tokenizer: Tokenize
self._pre_dataset()
self.ppo_pre_context(linears_info)

- LOSS_CLASS={"mse":self.ppo_mse,"adv_loss":self.ppo_adv_loss,"reward_loss":self.ppo_reward_loss}
+ LOSS_CLASS={"mse":partial(self.ppo_mse),"adv_loss":partial(self.ppo_adv_loss),
+             "reward_loss":partial(self.ppo_reward_loss)}
self.critic_context_.set_loss_fn(LOSS_CLASS[self.config_.critic_loss_type_])
self.actor_context_.set_loss_fn(LOSS_CLASS[self.config_.actor_loss_type_])
self.reward_context_.set_loss_fn(LOSS_CLASS[self.config_.reward_loss_type_])
@@ -113,7 +114,7 @@ def ppo_pre_context(self, linears_info: OrderedDict[str, LinearInfo]):
)
self._pre_ref_context(linears_info)

- def ppo_mse(self, data: torch.Tensor,label: torch.Tensor):
+ def ppo_mse(self, data: torch.Tensor,label: torch.Tensor) -> torch.Tensor:
return (data - label).pow(2).mean()


@@ -132,7 +133,7 @@ def ppo_adv_loss(self, prob: torch.Tensor, old_prob: torch.Tensor,

return loss1

- def ppo_reward_loss(self, reward_chosen: torch.Tensor, reward_reject: torch.Tensor)-> torch.tensor:
+ def ppo_reward_loss(self, reward_chosen: torch.Tensor, reward_reject: torch.Tensor)-> torch.Tensor:
return reward_chosen-reward_reject

@override
@@ -171,7 +172,7 @@ def stage_0(self, start_idx: int):
data_idx_s = self.now_data_idx_
data_idx_e = self.now_data_idx_ + self.config_.mini_batch_size_
# get the train raw string
- batch_strr = self.prompter_.generate_prompt(self.data_[data_idx_s:data_idx_e],"reward")
+ batch_strr = self.prompter_.generate_prompt(self.data_[data_idx_s:data_idx_e])
reward_tokens = list(
map(
lambda raw_str: self.tokenizer_.encode(
@@ -181,6 +182,8 @@
)
)

+ l=int(len(reward_tokens)/3)
+ reward_tokens=reward_tokens[l:]
reward_start_idx=start_idx
reward_end_idx=reward_start_idx+len(reward_tokens)

@@ -218,7 +221,7 @@ def stage_1(self, start_idx:int):
data_idx_e = self.now_data_idx_ + self.config_.mini_batch_size_

# get the train raw string
- batch_str = self.prompter_.generate_prompt(self.data_[data_idx_s:data_idx_e],"instruction")
+ batch_str = self.prompter_.generate_prompt(self.data_[data_idx_s:data_idx_e])
actor_tokens = list(
map(
lambda raw_str: self.tokenizer_.encode(
@@ -227,7 +230,9 @@
batch_str,
)
)


+ l=int(len(actor_tokens)/3)
+ actor_tokens=actor_tokens[:l]
batch_num=int(len(actor_tokens))
generate_num=self.config_.generate_num_
BOS=actor_tokens[0][0]
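The two new slices above mirror each other: stage_0 now drops the first third of the encoded prompts (reward_tokens[l:]) and stage_1 keeps only the first third (actor_tokens[:l]). This lines up with the prompter change below, which always returns instruction, chosen and reject prompts concatenated in that order, so the assumed layout is (a sketch with made-up strings, not code from the diff):

    # Assumed ordering of PpoDataPrompter.generate_prompt output (three equal groups):
    prompts = ["inst_1", "inst_2", "chosen_1", "chosen_2", "reject_1", "reject_2"]  # illustrative
    n = len(prompts) // 3
    actor_prompts = prompts[:n]     # first third: instruction prompts -> stage_1 (policy rollout)
    reward_prompts = prompts[n:]    # last two thirds: chosen + reject pairs -> stage_0 (reward model)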
@@ -240,7 +245,7 @@

self.state_=Stage.policy_training_decision

- def stage_2(self, input: torch.tensor, actor_start_idx: int, actor_end_idx: int,
+ def stage_2(self, input: torch.Tensor, actor_start_idx: int, actor_end_idx: int,
deterministic: bool = False):
critic_len=int(len(self.policy_tokens[-1]))
if(self.idx==critic_len-1):
@@ -262,9 +267,9 @@ def stage_2(self, input: torch.Tensor, actor_start_idx: int, actor_end_idx: int,
a=a.view(batch_num,-1)

for i in range(batch_num):
- self.policy_tokens[i].append(a[i].item())
+ self.policy_tokens[i].append(int(a[i].item()))
for i in range(batch_num,2*batch_num):
- self.policy_tokens[i][idx]=a[i-batch_num].item()
+ self.policy_tokens[i][idx]=int(a[i-batch_num].item())
self.idx+=1

def stage_3(self, start_idx:int):
@@ -279,7 +284,7 @@ def stage_3(self, start_idx:int):
critic_len=int(len(self.policy_tokens[-1]))# the real critic's len
reward_tokens = copy.deepcopy(self.policy_tokens[:batch_num])
ref_tokens = copy.deepcopy(self.policy_tokens[:batch_num])
- p_tokens=[]
+ p_tokens: List[List[int]] = []
p_tokens.extend(reward_tokens)
p_tokens.extend(ref_tokens)
p_tokens.extend(self.policy_tokens)
@@ -295,7 +300,7 @@

def loss_fn(
input: torch.Tensor, _: torch.Tensor, __: torch.Tensor
- )-> Optional[torch.tensor]:
+ )-> Optional[torch.Tensor]:

if self.state_==Stage.policy_training_decision:
self.stage_2(input,actor_start_idx,actor_end_idx)
@@ -313,7 +318,7 @@ def loss_fn(
log_prob = log_p.gather(-1, action).squeeze(-1)
ref_log_prob = log_ref_p.gather(-1, action).squeeze(-1)
r=-(log_prob-ref_log_prob)
- r[:,-1]+=torch.tanh(input[reward_start_idx:reward_end_idx,-1]@PPOTask.reward_tensor.squeeze(dim=-1))
+ r[:,-1]+=torch.tanh(input[reward_start_idx:reward_end_idx,-1]@PPOTask.reward_tensor).squeeze(dim=-1)

v=torch.tanh((input[critic_start_idx:critic_end_idx,1:critic_len]@PPOTask.critic_tensor).squeeze(dim=-1))
v_=v.clone().detach()
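For context, the corrected line sits in the per-token reward computation: r is the negative log-probability gap between the trained policy and the frozen reference model, and a learned scalar reward is added only on the final token. A self-contained restatement with toy shapes (hidden size and head layout are assumptions, not shown in this diff):

    import torch

    batch, seq, hidden = 2, 5, 8                             # toy shapes for illustration
    log_prob = torch.randn(batch, seq)                       # policy log-probs of the sampled tokens
    ref_log_prob = torch.randn(batch, seq)                   # reference-model log-probs of the same tokens
    last_hidden = torch.randn(batch, hidden)                 # final-token hidden states
    reward_head = torch.randn(hidden, 1)                     # stands in for PPOTask.reward_tensor (assumed (hidden, 1))

    r = -(log_prob - ref_log_prob)                           # per-token penalty for drifting from the reference model
    scalar_reward = torch.tanh(last_hidden @ reward_head)    # (batch, 1), squashed into (-1, 1)
    r[:, -1] += scalar_reward.squeeze(dim=-1)                # learned reward added on the last token only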
9 changes: 3 additions & 6 deletions mlora/prompter/ppo_data_prompter.py
@@ -7,22 +7,20 @@ class PpoDataPrompter(Prompter):
def __init__(self, template: str):
super().__init__(template)

- def __generate_prompt(self, data_point: Dict[str, str], optional: str) -> Tuple[str, str]:
+ def __generate_prompt(self, data_point: Dict[str, str], optional: str) -> str:
data = self.template_.render(data_point=data_point, Optional=optional)
return data

@override
- def generate_prompt(self, data_points: List[Dict[str, str]], optional: str) -> List[str]:
+ def generate_prompt(self, data_points: List[Dict[str, str]]) -> List[str]:
instru_data = []
chosen_data = []
reject_data = []
data=[]

for data_point in data_points:
- if optional=="instruction":
data_str = self.__generate_prompt(data_point,"instruction")
instru_data.append(data_str)
- else:
chosen_str = self.__generate_prompt(data_point,"chosen")
reject_str = self.__generate_prompt(data_point,"reject")
chosen_data.append(chosen_str)
@@ -31,5 +29,4 @@ def generate_prompt(self, data_points: List[Dict[str, str]]) -> L
data.extend(instru_data)
data.extend(chosen_data)
data.extend(reject_data)
- return data
-
+ return data
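With the optional argument removed, a single call now always yields all three prompt groups in a fixed order, which is what the new slicing in ppo_task.py relies on. A hedged usage sketch (the import path follows the file path in the diff header; the constructor argument and data-point keys are placeholders, not taken from this diff):

    from mlora.prompter.ppo_data_prompter import PpoDataPrompter  # module path assumed from the file path above

    prompter = PpoDataPrompter(template="path/to/ppo_template")         # placeholder template argument
    batch = [{"instruction": "...", "chosen": "...", "reject": "..."}]  # assumed field names
    prompts = prompter.generate_prompt(batch)
    # prompts == instruction prompts + chosen prompts + reject prompts,
    # i.e. len(prompts) == 3 * len(batch), consumed a third at a time by PPOTask.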
