Now reasoner branch reward #181

Merged
merged 10 commits on Jan 30, 2025
conf/rl_gsm8k.yaml: 5 additions, 0 deletions
@@ -41,6 +41,11 @@ task_template: |-
 overflow_reward: 0
 max_prompt_length: 1024
 
+rewards:
+  unparsable: 0
+  wrong_answer: 0
+  correct_answer: 1
+
 vllm_config:
   vllm_kwargs:
     --download-dir: /mnt/llmd/base_models/
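
The new rewards block makes the three outcome rewards configurable instead of hard-coding them in orchestrate_rl.py. Below is a minimal sketch, assuming OmegaConf (the library behind the DictConfig objects used elsewhere in this example), of how the block reads once loaded; the inline YAML is an excerpt, not the full config file:

from omegaconf import OmegaConf

# Illustrative excerpt of conf/rl_gsm8k.yaml, not the whole file.
cfg = OmegaConf.create(
    "rewards:\n"
    "  unparsable: 0\n"
    "  wrong_answer: 0\n"
    "  correct_answer: 1\n"
)

# Dotted access mirrors how the orchestrate_rl.py hunks further down read these values.
assert cfg.rewards.unparsable == 0
assert cfg.rewards.wrong_answer == 0
assert cfg.rewards.correct_answer == 1
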
examples/rl_gsm8k/deepseek_math_eval/process_utils.py: 3 additions, 0 deletions
@@ -40,6 +40,9 @@ def process_gsm8k_test(item):

 def process_math_test(item):
     question = item["problem"]
+    if "subject" in item and "type" not in item:
+        item["type"] = item["subject"]
+
     try:
         answer = extract_math_answer(question, item["solution"], task="cot")
     except Exception:
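
The guard above appears to exist because HuggingFaceH4/MATH-500 items name the category field "subject", while the hendrycks/competition_math processing path expects "type". A hedged illustration with a made-up item (field values are not copied from the dataset):

# Hypothetical MATH-500-style item.
item = {
    "problem": "What is $1 + 2 + \\dots + 10$?",
    "solution": "The sum is $\\frac{10 \\cdot 11}{2} = \\boxed{55}$.",
    "subject": "Algebra",  # MATH-500 naming
    "level": 1,
}

# Same normalization as in the hunk above.
if "subject" in item and "type" not in item:
    item["type"] = item["subject"]

assert item["type"] == "Algebra"
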
examples/rl_gsm8k/orchestrate_rl.py: 13 additions, 9 deletions
@@ -21,7 +21,7 @@

 import wandb
 from tapeagents.agent import Agent
-from tapeagents.core import LLMCall, LLMOutputParsingFailureAction, StepMetadata, TrainingText
+from tapeagents.core import LLMCall, StepMetadata, TrainingText
 from tapeagents.finetune.data import MASKED_TOKEN_ID
 from tapeagents.finetune.logging_ import flatten_dict_config, init_wandb
 from tapeagents.llms import TrainableLLM
@@ -38,8 +38,10 @@
 def load_datasets(cfg: DictConfig) -> Tuple[list, list]:
     match cfg.dataset_name:
         case "math":
-            train_dataset_long_name = test_dataset_long_name = "hendrycks/competition_math"
+            train_dataset_long_name = "hendrycks/competition_math"
+            test_dataset_long_name = "HuggingFaceH4/MATH-500"
             process_fn = process_math_test
+            test_builder_config = "default"
             builder_config = "main"
         case "gsm8k":
             train_dataset_long_name = test_dataset_long_name = "openai/gsm8k"
@@ -53,8 +55,9 @@ def load_datasets(cfg: DictConfig) -> Tuple[list, list]:
         case _:
             raise ValueError(f"Unknown dataset: {cfg.dataset_name}")
 
+    test_builder_config = test_builder_config or builder_config
     train_dataset = load_dataset(train_dataset_long_name, builder_config, split="train", trust_remote_code=True)
-    test_dataset = load_dataset(test_dataset_long_name, builder_config, split="test", trust_remote_code=True)
+    test_dataset = load_dataset(test_dataset_long_name, test_builder_config, split="test", trust_remote_code=True)
     train_samples = [
         process_fn(s) for s in tqdm(train_dataset, desc="Processing train samples") if process_fn(s) is not None
     ]
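
For cfg.dataset_name == "math", the code above now resolves to roughly the following pair of calls; this is a sketch derived only from the hunks above, using the datasets library already imported in this module:

from datasets import load_dataset

# Train split still comes from the original competition MATH dataset.
train_dataset = load_dataset(
    "hendrycks/competition_math", "main", split="train", trust_remote_code=True
)
# Test split now comes from the 500-problem MATH-500 subset.
test_dataset = load_dataset(
    "HuggingFaceH4/MATH-500", "default", split="test", trust_remote_code=True
)
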
@@ -145,10 +148,11 @@ def extract_tape_training_samples(
         case _:
             raise ValueError(f"Unknown dataset: {cfg.dataset_name}")
 
-    if any([isinstance(step, LLMOutputParsingFailureAction) for step in new_tape.steps]):
-        # LLM produced a step that was unparsable. Negative reward.
-        no_error, reward, success = 0, -1, 0
+    if "\\boxed" not in new_tape.steps[-1].reasoning:
+        # LLM did not respect the formatting
+        no_error, success, reward = 0, 0, cfg.rewards.unparsable
     else:
+        # LLM did respect the formatting
         no_error = 1
         prediction = extract_fn(new_tape.steps[0].task, new_tape.steps[-1].reasoning, "cot")  # type: ignore
         answer = new_tape.steps[0].metadata.other["value"]
@@ -159,10 +163,10 @@ def extract_tape_training_samples(
             }
         ):
             # Correct answer
-            reward, success = 1, 1
+            reward, success = cfg.rewards.correct_answer, 1
         else:
             # Incorrect answer or no answer
-            reward, success = 0, 0
+            reward, success = cfg.rewards.wrong_answer, 0
 
     training_samples: list[TrainingText] = []
     # For each LLM interaction in the tape:
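
Taken together, the two hunks above implement a three-way outcome mapping driven by the new rewards config. A minimal self-contained sketch (the helper name and its arguments are hypothetical, not code from the PR), using the default values from the conf/rl_gsm8k.yaml hunk:

from types import SimpleNamespace

def outcome_reward(has_boxed_answer: bool, is_correct: bool, rewards) -> float:
    # Condensed restatement of the branching above.
    if not has_boxed_answer:
        return rewards.unparsable      # formatting not respected
    if is_correct:
        return rewards.correct_answer  # well-formatted and correct
    return rewards.wrong_answer        # well-formatted but wrong

# Defaults from conf/rl_gsm8k.yaml.
defaults = SimpleNamespace(unparsable=0, wrong_answer=0, correct_answer=1)
assert outcome_reward(False, False, defaults) == 0
assert outcome_reward(True, False, defaults) == 0
assert outcome_reward(True, True, defaults) == 1
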
@@ -194,7 +198,7 @@ def extract_tape_training_samples(

         # check if the last produced token is the end of sequence token
         overflow = False if input_ids[-1] == agent.llm.tokenizer.eos_token_id else True
-        trace.reward = cfg.overflow_reward if overflow else reward
+        trace.reward = cfg.rewards.unparsable if overflow else reward
         overflows.append(overflow)
         trace.logprobs = [lp.logprob for lp in llm_call.logprobs if lp.generated]
         trace.group_id = new_tape.metadata.parent_id
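
In the last hunk, a generation whose final token is not the tokenizer's end-of-sequence token is treated as an overflow and is now scored with cfg.rewards.unparsable rather than the separate cfg.overflow_reward setting. A self-contained toy illustration of the check (token ids and reward values are made up):

eos_token_id = 2
input_ids = [464, 3280, 318, 2534]  # generation cut off before emitting EOS
unparsable_reward, reward = 0, 1

overflow = input_ids[-1] != eos_token_id
final_reward = unparsable_reward if overflow else reward
assert overflow and final_reward == 0
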