Merge pull request #63 from opentensor/feature/track_raw_score
Feature/track raw score
p-ferreira authored Nov 9, 2023
2 parents 1d8c3bc + 07882b7 commit beab14a
Showing 14 changed files with 278 additions and 114 deletions.
43 changes: 42 additions & 1 deletion prompting/validators/event.py
@@ -48,13 +48,25 @@ class EventSchema:
         List[float]
     ]  # Output vector of the dahoas reward model
     blacklist_filter: Optional[List[float]]  # Output vector of the blacklist filter
+    blacklist_filter_matched_ngram: Optional[
+        List[str]
+    ]  # Output vector of the blacklist filter
+    blacklist_filter_significance_score: Optional[
+        List[float]
+    ]  # Output vector of the blacklist filter
     nsfw_filter: Optional[List[float]]  # Output vector of the nsfw filter
     reciprocate_reward_model: Optional[
         List[float]
     ]  # Output vector of the reciprocate reward model
     diversity_reward_model: Optional[
         List[float]
     ]  # Output vector of the diversity reward model
+    diversity_reward_model_historic: Optional[
+        List[float]
+    ]  # Output vector of the diversity reward model
+    diversity_reward_model_batch: Optional[
+        List[float]
+    ]  # Output vector of the diversity reward model
     dpo_reward_model: Optional[List[float]]  # Output vector of the dpo reward model
     rlhf_reward_model: Optional[List[float]]  # Output vector of the rlhf reward model
     prompt_reward_model: Optional[
@@ -65,6 +77,7 @@ class EventSchema:
         List[float]
     ]  # Output vector of the dahoas reward model
     nsfw_filter_normalized: Optional[List[float]]  # Output vector of the nsfw filter
+    nsfw_filter_score: Optional[List[float]]  # Output vector of the nsfw filter
     reciprocate_reward_model_normalized: Optional[
         List[float]
     ]  # Output vector of the reciprocate reward model
@@ -80,7 +93,16 @@ class EventSchema:
     prompt_reward_model_normalized: Optional[
         List[float]
     ]  # Output vector of the prompt reward model
-    relevance_filter_normalized: Optional[List[float]]

+    relevance_filter_normalized: Optional[
+        List[float]
+    ]  # Output vector of the relevance scoring reward model
+    relevance_filter_bert_score: Optional[
+        List[float]
+    ]  # Output vector of the relevance scoring reward model
+    relevance_filter_mpnet_score: Optional[
+        List[float]
+    ]  # Output vector of the relevance scoring reward model
     # TODO: Add comments
     task_validation_penalty_raw: Optional[List[float]]
     task_validation_penalty_adjusted: Optional[List[float]]
@@ -109,6 +131,12 @@ def from_dict(event_dict: dict, disable_log_rewards: bool) -> "EventSchema":
                 RewardModelType.reciprocate.value
             ),
             "diversity_reward_model": event_dict.get(RewardModelType.diversity.value),
+            "diversity_reward_model_historic": event_dict.get(
+                RewardModelType.diversity.value + "_historic"
+            ),
+            "diversity_reward_model_batch": event_dict.get(
+                RewardModelType.diversity.value + "_batch"
+            ),
             "dpo_reward_model": event_dict.get(RewardModelType.dpo.value),
             "rlhf_reward_model": event_dict.get(RewardModelType.rlhf.value),
             "prompt_reward_model": event_dict.get(RewardModelType.prompt.value),
@@ -136,6 +164,19 @@ def from_dict(event_dict: dict, disable_log_rewards: bool) -> "EventSchema":
             "prompt_reward_model_normalized": event_dict.get(
                 RewardModelType.prompt.value + "_normalized"
             ),
+            "blacklist_filter_matched_ngram": event_dict.get(
+                RewardModelType.blacklist.value + "_matched_ngram"
+            ),
+            "blacklist_filter_significance_score": event_dict.get(
+                RewardModelType.blacklist.value + "_significance_score"
+            ),
+            "relevance_filter_bert_score": event_dict.get(
+                RewardModelType.relevance.value + "_bert_score"
+            ),
+            "relevance_filter_mpnet_score": event_dict.get(
+                RewardModelType.relevance.value + "_mpnet_score"
+            ),
+            "nsfw_filter_score": event_dict.get(RewardModelType.nsfw.value + "_score"),
         }
         penalties = {
             "task_validation_penalty_raw": event_dict.get(
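
The two hunks above follow a flat key convention: every raw score column is looked up under the reward model's RewardModelType value plus a suffix. A minimal illustrative sketch of the resulting column names, using only the enum values that appear in the config.py hunk below (the import path mirrors the file paths in this diff):

from prompting.validators.reward.config import RewardModelType

# Raw-score columns read by EventSchema.from_dict, built as "<enum value>" + "<suffix>".
raw_score_columns = [
    RewardModelType.blacklist.value + "_matched_ngram",       # "blacklist_filter_matched_ngram"
    RewardModelType.blacklist.value + "_significance_score",  # "blacklist_filter_significance_score"
    RewardModelType.relevance.value + "_bert_score",          # "relevance_filter_bert_score"
    RewardModelType.relevance.value + "_mpnet_score",         # "relevance_filter_mpnet_score"
    RewardModelType.nsfw.value + "_score",                    # "nsfw_filter_score"
]
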
13 changes: 4 additions & 9 deletions prompting/validators/forward.py
@@ -123,23 +123,21 @@ async def run_step(self, task: Task, k: int, timeout: float, exclude: list = [])
         self.device
     )
     for weight_i, reward_fn_i in zip(self.reward_weights, self.reward_functions):
-        reward_i, reward_i_normalized = reward_fn_i.apply(
+        reward_i_normalized, reward_event = reward_fn_i.apply(
             task.base_text, responses, task_name
         )
         rewards += weight_i * reward_i_normalized.to(self.device)
         if not self.config.neuron.disable_log_rewards:
-            event[reward_fn_i.name] = reward_i.tolist()
-            event[reward_fn_i.name + "_normalized"] = reward_i_normalized.tolist()
+            event = {**event, **reward_event}
         bt.logging.trace(str(reward_fn_i.name), reward_i_normalized.tolist())

     for masking_fn_i in self.masking_functions:
-        mask_i, mask_i_normalized = masking_fn_i.apply(
+        mask_i_normalized, reward_event = masking_fn_i.apply(
             task.base_text, responses, task_name
         )
         rewards *= mask_i_normalized.to(self.device)  # includes diversity
         if not self.config.neuron.disable_log_rewards:
-            event[masking_fn_i.name] = mask_i.tolist()
-            event[masking_fn_i.name + "_normalized"] = mask_i_normalized.tolist()
+            event = {**event, **reward_event}
         bt.logging.trace(str(masking_fn_i.name), mask_i_normalized.tolist())

     for penalty_fn_i in self.penalty_functions:
@@ -277,6 +275,3 @@ async def forward(self):
         )

         exclude += qa_event["uids"]

-        self.blacklist.question_blacklist.append(qg_event["best"])
-        self.blacklist.answer_blacklist.append(qa_event["best"])
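
The run_step changes above redefine the apply() contract: each reward or masking function now returns a tuple of (normalized reward tensor, reward-event dict), and run_step merges that dict straight into the step's event log instead of hand-assembling the raw and normalized columns. A hedged sketch of a stub that satisfies the contract; ConstantReward and its column names are made up, and the real columns come from BaseRewardModel.apply, which this PR changes but which is not in the loaded portion of the diff:

import torch

class ConstantReward:
    """Toy reward function obeying the new apply() contract used by run_step."""

    name = "constant_reward"  # hypothetical name, used as the column prefix

    def apply(self, prompt: str, responses: list, name: str):
        normalized = torch.ones(len(responses), dtype=torch.float32)
        reward_event = {
            self.name: normalized.tolist(),                  # raw scores
            self.name + "_normalized": normalized.tolist(),  # normalized scores
        }
        # run_step folds this dict into its log: event = {**event, **reward_event}
        return normalized, reward_event
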
2 changes: 1 addition & 1 deletion prompting/validators/mock.py
@@ -62,7 +62,7 @@ def __init__(self, mock_name: str = "MockReward"):

     def apply(self, prompt: str, completion: List[str], name: str) -> torch.FloatTensor:
         mock_reward = torch.tensor([1 for _ in completion], dtype=torch.float32)
-        return mock_reward, mock_reward
+        return mock_reward, {}

     def reset(self):
         return self
39 changes: 28 additions & 11 deletions prompting/validators/reward/blacklist.py
@@ -20,13 +20,22 @@
 import torch
 import math
 from fuzzywuzzy import fuzz
-from typing import List
+from typing import List, Union
 from .config import RewardModelType
-from .reward import BaseRewardModel
+from .reward import BaseRewardModel, BaseRewardEvent
 from transformers import BertTokenizer
+from dataclasses import dataclass


 # TODO: Use CLI arguments to set blacklist values: the most important being the boundary value and max_size


+@dataclass
+class BlacklistRewardEvent(BaseRewardEvent):
+    matched_ngram: str = None
+    significance_score: float = None


 class Blacklist(BaseRewardModel):
     @property
     def name(self) -> str:
@@ -263,7 +272,7 @@ def set_counter_to_half(self):
         }
         self._last_update = 0

-    def reward(self, prompt: str, completion: str, name: str) -> float:
+    def reward(self, prompt: str, completion: str, name: str) -> BlacklistRewardEvent:
         """Reward function for blacklist reward model. Returns 1 if completion contains an n-gram with significance above the boundary, 0 otherwise.

@@ -275,8 +284,11 @@ def reward(self, prompt: str, completion: str, name: str) -> float:
             float: Reward value {0,1}
         """

+        reward_event = BlacklistRewardEvent()
+
         if completion in prompt:
-            return 0.0
+            reward_event.reward = 0.0
+            return reward_event

         # Get significance scores
         scores = self.get_significance()
@@ -288,17 +300,22 @@ def reward(self, prompt: str, completion: str, name: str) -> float:
                 and fuzz.partial_ratio(ngram, completion.lower())
                 > self.partial_ratio_boundary
             ):
-                return 0
+                reward_event.reward = 0
+                reward_event.matched_ngram = ngram
+                reward_event.significance_score = score
+                return reward_event

-        return 1
+        reward_event.reward = 1
+        return reward_event

     def get_rewards(
         self, prompt: str, completions: List[str], name: str
-    ) -> torch.FloatTensor:
-        return torch.tensor(
-            [self.reward(prompt, completion, name) for completion in completions],
-            dtype=torch.float32,
-        )
+    ) -> List[BlacklistRewardEvent]:
+        # Get all the reward results.
+        reward_events = [
+            self.reward(prompt, completion, name) for completion in completions
+        ]
+        return reward_events

     def normalize_rewards(self, rewards: torch.FloatTensor) -> torch.FloatTensor:
         return rewards
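
get_rewards now hands back one BlacklistRewardEvent per completion instead of a bare tensor, so the matched n-gram and its significance score survive alongside the 0/1 reward. How those events become the parallel blacklist_filter_* columns in the event log lives in reward.py, which is not in the loaded part of this diff; the following is a hypothetical stand-alone sketch of that collapse step, with a simplified stand-in dataclass and made-up values:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class BlacklistRewardEvent:  # simplified stand-in for the class added above
    reward: Optional[float] = None
    matched_ngram: Optional[str] = None
    significance_score: Optional[float] = None

def collapse(events: List[BlacklistRewardEvent]) -> dict:
    """Hypothetical helper: per-completion events -> parallel log columns."""
    return {
        "blacklist_filter": [e.reward for e in events],
        "blacklist_filter_matched_ngram": [e.matched_ngram for e in events],
        "blacklist_filter_significance_score": [e.significance_score for e in events],
    }

events = [
    BlacklistRewardEvent(reward=1.0),
    BlacklistRewardEvent(reward=0.0, matched_ngram="quick brown fox", significance_score=2.7),
]
print(collapse(events)["blacklist_filter_matched_ngram"])  # [None, 'quick brown fox']
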
2 changes: 2 additions & 0 deletions prompting/validators/reward/config.py
@@ -29,6 +29,8 @@ class RewardModelType(Enum):
     blacklist = "blacklist_filter"
     nsfw = "nsfw_filter"
     relevance = "relevance_filter"
+    relevance_bert = "relevance_bert"
+    relevance_mpnet = "relevance_mpnet"
     task_validator = "task_validator_filter"
     keyword_match = "keyword_match_penalty"

31 changes: 19 additions & 12 deletions prompting/validators/reward/dahoas.py
@@ -18,9 +18,9 @@

 import os
 import torch
-from typing import List
+from typing import List, Union
 from .config import RewardModelType
-from .reward import BaseRewardModel
+from .reward import BaseRewardModel, BaseRewardEvent
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig


@@ -63,10 +63,14 @@ def __init__(self, path: str, device: str):
         self.tokenizer.pad_token = self.tokenizer.eos_token
         self.PAD_ID = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0]

-    def reward(self, prompt: str, completion: str, name: str) -> float:
+    def reward(self, prompt: str, completion: str, name: str) -> BaseRewardEvent:
+        reward_event = BaseRewardEvent()
+
         def reward_fn(samples):
             if samples is None:
-                return 0
+                reward_event.reward = 0
+                return reward_event
+
             scores_list = []
             batch_size = 1
             for i in range(0, len(samples), batch_size):
@@ -92,21 +96,24 @@ def reward_fn(samples):
                     attention_mask=attn_masks.to(self.device),
                 )
                 scores_list.append(sub_scores["chosen_end_scores"])
-            scores = torch.cat(scores_list, dim=0).mean().item()
-            return scores
+            score = torch.cat(scores_list, dim=0).mean().item()
+            return score

         with torch.no_grad():
             combined_reward = reward_fn(prompt + completion)
             independent_reward = reward_fn(completion)
-            return float((combined_reward - independent_reward).item())
+            reward_event.reward = float((combined_reward - independent_reward).item())
+            return reward_event

     def get_rewards(
         self, prompt: str, completions: List[str], name: str
-    ) -> torch.FloatTensor:
-        return torch.tensor(
-            [self.reward(prompt, completion, name) for completion in completions],
-            dtype=torch.float32,
-        ).to(self.device)
+    ) -> List[BaseRewardEvent]:
+        # Get all the reward results.
+        reward_events = [
+            self.reward(prompt, completion, name) for completion in completions
+        ]
+
+        return reward_events

     def forward(
         self,
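
The same event-based pattern applies here: reward() wraps its score in a BaseRewardEvent and get_rewards() returns a plain list of events rather than a FloatTensor, so downstream normalization has to pull the raw rewards back out. A sketch of that conversion, assuming only that BaseRewardEvent exposes a reward attribute (as the assignments above imply); the helper name is hypothetical and the real logic sits in the apply() wrapper that is not in the loaded portion of this diff:

import torch

def events_to_tensor(reward_events, device: str = "cpu") -> torch.FloatTensor:
    # Hypothetical helper: collect the raw rewards so normalize_rewards() can run on a tensor.
    return torch.tensor(
        [event.reward for event in reward_events], dtype=torch.float32
    ).to(device)
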
28 changes: 20 additions & 8 deletions prompting/validators/reward/diversity.py
@@ -18,11 +18,11 @@

 import torch
 import torch.nn.functional as F
-from typing import List
+from typing import List, Union
 from .config import RewardModelType
-from .reward import BaseRewardModel
+from .reward import BaseRewardModel, BaseRewardEvent
 from transformers import AutoTokenizer, AutoModel

+from dataclasses import dataclass
 from torchmetrics.functional import pairwise_cosine_similarity


@@ -48,6 +48,12 @@ def mean_pooling(model_output, attention_mask):
     )


+@dataclass
+class DiversityRewardEvent(BaseRewardEvent):
+    historic: float = None
+    batch: float = None


 class DiversityRewardModel(BaseRewardModel):
     diversity_model_path = "sentence-transformers/all-mpnet-base-v2"

@@ -155,10 +161,10 @@ def regularise(rewards):

     def get_rewards(
         self, prompt: str, completions: List[str], name: str
-    ) -> torch.FloatTensor:
+    ) -> List[DiversityRewardEvent]:
         # Check if completions are empty, return 0 if so
         if len(completions) == 0:
-            return torch.tensor([]).to(self.device)
+            return torch.tensor([]).to(self.device), None

         # Get embeddings for all completions.
         embeddings = self.get_embeddings(completions)
@@ -171,11 +177,17 @@ def get_rewards(

         self.update_historic_embeddings(embeddings)

-        # Return all
+        reward_events = []
         if historic_rewards != None:
-            return batch_rewards * historic_rewards
+            for b, h in zip(batch_rewards.tolist(), historic_rewards.tolist()):
+                reward_events.append(
+                    DiversityRewardEvent(reward=b * h, batch=b, historic=h)
+                )
         else:
-            return batch_rewards
+            for b in batch_rewards.tolist():
+                reward_events.append(DiversityRewardEvent(reward=b, batch=b))
+
+        return reward_events

     def normalize_rewards(self, raw_rewards: torch.FloatTensor) -> torch.FloatTensor:
         # Applies binarization on the rewards.
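
Each DiversityRewardEvent now keeps the batch and historic components separately, while the logged reward stays their product (or just the batch score when no historic embeddings exist yet). A small illustrative check with made-up scores, using a stand-alone stub of the dataclass defined above (the real class inherits its reward field from BaseRewardEvent):

from dataclasses import dataclass

@dataclass
class DiversityRewardEvent:  # stand-in for the class added in this file
    reward: float = None
    historic: float = None
    batch: float = None

batch_rewards = [0.9, 0.4]     # diversity within the current batch of completions
historic_rewards = [1.0, 0.5]  # diversity against previously seen completions

events = [
    DiversityRewardEvent(reward=b * h, batch=b, historic=h)
    for b, h in zip(batch_rewards, historic_rewards)
]
assert [e.reward for e in events] == [0.9, 0.2]
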