
Feature/track raw score #63

Merged
merged 27 commits into from
Nov 9, 2023
Changes from 12 commits
Commits
27 commits
1ba9777
defining rewardresult dataclass and reward event
isabella618033 Nov 7, 2023
22a201f
moved event addition into reward model apply function
isabella618033 Nov 7, 2023
270fb93
clean up relevance
isabella618033 Nov 7, 2023
e0b0f9d
apply to blacklist
isabella618033 Nov 7, 2023
0aeddc2
fixes
isabella618033 Nov 7, 2023
7330314
changed get_reward returns for all
isabella618033 Nov 8, 2023
e1023a5
added BaseRewardEvent
isabella618033 Nov 8, 2023
6022706
update event schema
isabella618033 Nov 8, 2023
e374246
black format
isabella618033 Nov 8, 2023
d76bfd2
fix mock
isabella618033 Nov 8, 2023
c9bbae2
get rewards -> List[BaseRewardEvent]
isabella618033 Nov 8, 2023
21bbd62
schema update
isabella618033 Nov 8, 2023
1d033af
black formatting
isabella618033 Nov 8, 2023
9b318f0
Merge branch 'features/ngram-blacklist' into feature/track_raw_score
isabella618033 Nov 8, 2023
27badfc
black formatting
isabella618033 Nov 8, 2023
abd0f17
retain comments
isabella618033 Nov 8, 2023
00fe827
retain comments
isabella618033 Nov 8, 2023
6b02563
retain comments
isabella618033 Nov 8, 2023
2794c32
retain comments
isabella618033 Nov 8, 2023
3531257
retain comments
isabella618033 Nov 8, 2023
fda08af
retain comments
isabella618033 Nov 8, 2023
c68a22a
retain comments
isabella618033 Nov 8, 2023
820e6a5
retain comments
isabella618033 Nov 8, 2023
b300460
fixes
isabella618033 Nov 9, 2023
93dc591
Merge branch 'feature/track_raw_score' of https://github.com/opentens…
isabella618033 Nov 9, 2023
00a201b
black format
isabella618033 Nov 9, 2023
07882b7
black formatted
isabella618033 Nov 9, 2023
34 changes: 34 additions & 0 deletions prompting/validators/event.py
@@ -47,13 +47,25 @@ class EventSchema:
List[float]
] # Output vector of the dahoas reward model
blacklist_filter: Optional[List[float]] # Output vector of the blacklist filter
blacklist_filter_matched_ngram: Optional[
List[str]
] # Output vector of the blacklist filter
blacklist_filter_significance_score: Optional[
List[float]
] # Output vector of the blacklist filter
nsfw_filter: Optional[List[float]] # Output vector of the nsfw filter
reciprocate_reward_model: Optional[
List[float]
] # Output vector of the reciprocate reward model
diversity_reward_model: Optional[
List[float]
] # Output vector of the diversity reward model
diversity_reward_model_historic: Optional[
List[float]
] # Output vector of the diversity reward model
diversity_reward_model_batch: Optional[
List[float]
] # Output vector of the diversity reward model
dpo_reward_model: Optional[List[float]] # Output vector of the dpo reward model
rlhf_reward_model: Optional[List[float]] # Output vector of the rlhf reward model
prompt_reward_model: Optional[
@@ -68,6 +80,7 @@ class EventSchema:
List[float]
] # Output vector of the dahoas reward model
nsfw_filter_normalized: Optional[List[float]] # Output vector of the nsfw filter
nsfw_filter_score: Optional[List[float]] # Output vector of the nsfw filter
reciprocate_reward_model_normalized: Optional[
List[float]
] # Output vector of the reciprocate reward model
@@ -86,6 +99,12 @@ class EventSchema:
relevance_filter_normalized: Optional[
List[float]
] # Output vector of the relevance scoring reward model
relevance_filter_bert_score: Optional[
List[float]
] # Output vector of the relevance scoring reward model
relevance_filter_mpnet_score: Optional[
List[float]
] # Output vector of the relevance scoring reward model
task_validator_filter_normalized: Optional[List[float]]

# Weights data
@@ -106,6 +125,8 @@ def from_dict(event_dict: dict, disable_log_rewards: bool) -> "EventSchema":
RewardModelType.reciprocate.value
),
"diversity_reward_model": event_dict.get(RewardModelType.diversity.value),
"diversity_reward_model_historic": event_dict.get(RewardModelType.diversity.value + '_historic'),
"diversity_reward_model_batch": event_dict.get(RewardModelType.diversity.value + '_batch'),
"dpo_reward_model": event_dict.get(RewardModelType.dpo.value),
"rlhf_reward_model": event_dict.get(RewardModelType.rlhf.value),
"prompt_reward_model": event_dict.get(RewardModelType.prompt.value),
@@ -136,6 +157,19 @@ def from_dict(event_dict: dict, disable_log_rewards: bool) -> "EventSchema":
"prompt_reward_model_normalized": event_dict.get(
RewardModelType.prompt.value + "_normalized"
),
"blacklist_filter_matched_ngram": event_dict.get(
RewardModelType.blacklist.value + "_matched_ngram"
),
"blacklist_filter_significance_score": event_dict.get(
RewardModelType.blacklist.value + "_significance_score"
),
"relevance_filter_bert_score": event_dict.get(
RewardModelType.relevance.value + "_bert_score"
),
"relevance_filter_mpnet_score": event_dict.get(
RewardModelType.relevance.value + "_mpnet_score"
),
"nsfw_filter_score": event_dict.get(RewardModelType.nsfw.value + "_score"),
}

# Logs warning that expected data was not set properly
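
The from_dict additions above read the new raw scores under suffixed keys such as the blacklist value plus "_matched_ngram". A small sketch of how a reward or masking function would need to key its event dict for these fields to be picked up; the RewardModelType values here are assumptions inferred from the schema field names, not confirmed by this diff:

from enum import Enum


class RewardModelType(Enum):
    # Assumed enum values, chosen to match the EventSchema field names above.
    blacklist = "blacklist_filter"
    relevance = "relevance_filter"
    nsfw = "nsfw_filter"


def raw_score_keys() -> list:
    # Keys a reward/masking function would emit so that EventSchema.from_dict
    # can map them onto the new raw-score columns.
    return [
        RewardModelType.blacklist.value + "_matched_ngram",
        RewardModelType.blacklist.value + "_significance_score",
        RewardModelType.relevance.value + "_bert_score",
        RewardModelType.relevance.value + "_mpnet_score",
        RewardModelType.nsfw.value + "_score",
    ]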
16 changes: 9 additions & 7 deletions prompting/validators/forward.py
@@ -99,7 +99,9 @@ async def run_step(
)

# Update blacklist with completions so that n-gram filtering can be applied
self.blacklist.add([response.completion for response in responses if response.completion])
self.blacklist.add(
[response.completion for response in responses if response.completion]
)

# Restrict the format of acceptable followup completions.
for response in responses:
@@ -122,19 +124,19 @@
self.device
)
for weight_i, reward_fn_i in zip(self.reward_weights, self.reward_functions):
reward_i, reward_i_normalized = reward_fn_i.apply(prompt, responses, name)
reward_i_normalized, reward_event = reward_fn_i.apply(prompt, responses, name)
rewards += weight_i * reward_i_normalized.to(self.device)
if not self.config.neuron.disable_log_rewards:
event[reward_fn_i.name] = reward_i.tolist()
event[reward_fn_i.name + "_normalized"] = reward_i_normalized.tolist()
event = {**event, **reward_event}
bt.logging.trace(str(reward_fn_i.name), reward_i_normalized.tolist())

for masking_fn_i in self.masking_functions:
mask_i, mask_i_normalized = masking_fn_i.apply(base_prompt, responses, name)
mask_i_normalized, reward_event = masking_fn_i.apply(
base_prompt, responses, name
)
rewards *= mask_i_normalized.to(self.device) # includes diversity
if not self.config.neuron.disable_log_rewards:
event[masking_fn_i.name] = mask_i.tolist()
event[masking_fn_i.name + "_normalized"] = mask_i_normalized.tolist()
event = {**event, **reward_event}
bt.logging.trace(str(masking_fn_i.name), mask_i_normalized.tolist())

# Train the gating model based on the predicted scores and the actual rewards.
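
With this change, run_step expects every reward and masking function to return the normalized reward tensor together with a flat dict of raw scores, which it merges into the logged event via event = {**event, **reward_event}. A sketch of that contract, not the repository's actual implementation, assuming responses are plain strings for brevity (the real code passes response objects):

from typing import Dict, List, Tuple

import torch


class SketchRewardModel:
    name = "sketch_reward_model"

    def apply(
        self, prompt: str, responses: List[str], name: str
    ) -> Tuple[torch.FloatTensor, Dict[str, list]]:
        # Stand-in raw scores; a real model would score prompt/response pairs.
        raw = torch.rand(len(responses))
        normalized = (raw - raw.min()) / (raw.max() - raw.min() + 1e-8)
        # Flat dict keyed by the model name, ready to be merged into `event`.
        reward_event = {
            self.name: raw.tolist(),
            self.name + "_normalized": normalized.tolist(),
        }
        return normalized, reward_event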
2 changes: 1 addition & 1 deletion prompting/validators/mock.py
@@ -59,7 +59,7 @@ def __init__(self, mock_name: str = "MockReward"):

def apply(self, prompt: str, completion: List[str], name: str) -> torch.FloatTensor:
mock_reward = torch.tensor([1 for _ in completion], dtype=torch.float32)
return mock_reward, mock_reward
return mock_reward, {}

def reset(self):
return self
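
With the change above, the mock returns an empty raw-score dict instead of duplicating the reward tensor. A hypothetical quick check of the updated contract (the class name MockRewardModel is assumed, not shown in this hunk):

mock = MockRewardModel("MockReward")
rewards, reward_event = mock.apply("some prompt", ["a", "b", "c"], "followup")
assert rewards.tolist() == [1.0, 1.0, 1.0]
assert reward_event == {}  # the mock logs no raw sub-scores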