Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync eval changes in OLMo/ladder-1xC to here #1

Merged
merged 4 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
checks:
name: ${{ matrix.task.name }}
runs-on: [ubuntu-latest]
timeout-minutes: 5
timeout-minutes: 15
strategy:
fail-fast: false
matrix:
Expand Down
104 changes: 86 additions & 18 deletions src/olmo_eval/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@ def __init__(self, metric_type="acc") -> None:
self.metric_type = metric_type

self.add_state("loglikelihoods", default=[], dist_reduce_fx=None)
self.add_state("celosses", default=[], dist_reduce_fx=None)
self.add_state("bpbs", default=[], dist_reduce_fx=None)
self.add_state("labels", default=[], dist_reduce_fx=None)

def reset(
self,
):
self.loglikelihoods = []
self.celosses = []
self.bpbs = []
self.labels = []

def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=None):
Expand All @@ -46,6 +50,8 @@ def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=No
]

log_likelihood: torch.Tensor
celoss: torch.Tensor
bpb: torch.Tensor
if self.metric_type == "pmi_dc":
assert dc_lm_logits is not None
# get domain conditional continuation logits: [cont_len, vocab]
Expand All @@ -58,19 +64,30 @@ def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=No
torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
/ torch.gather(dc_lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
)
celoss = -log_likelihood
bpb = -log_likelihood # the normalization factors cancel out
elif self.metric_type == "acc" or self.metric_type == "f1":
# gather log-probs at continuation token indices
log_likelihood = torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
elif self.metric_type == "len_norm" or self.metric_type == "ce_loss":
celoss = (
-torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
/ batch["cont_str_len"][idx]
)
bpb = (
-torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
/ batch["cont_byte_len"][idx]
* LOG_2_OF_E
)
elif self.metric_type in ["len_norm", "ce_loss", "bpb"]:
log_likelihood = (
torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
/ batch["cont_str_len"][idx]
)
if self.metric_type == "ce_loss":
log_likelihood = -log_likelihood
elif self.metric_type == "bpb":
# bits per byte
log_likelihood = (
celoss = (
-torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
/ batch["cont_str_len"][idx]
)
bpb = (
-torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
/ batch["cont_byte_len"][idx]
* LOG_2_OF_E
Expand All @@ -84,16 +101,26 @@ def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=No
batch["continuation"][idx].device
)
)
self.celosses.append(
torch.Tensor((doc_id, cont_id, celoss)).to(batch["continuation"][idx].device)
)
self.bpbs.append(
torch.Tensor((doc_id, cont_id, bpb)).to(batch["continuation"][idx].device)
)
self.labels.append(
torch.LongTensor((doc_id, cont_id, batch["label_id"][idx])).to(
batch["label_id"][idx].device
)
)

def compute(self) -> torch.Tensor:
def compute(self) -> Dict[str, torch.Tensor]:
# Task "suffix" -> tensor

# states should have been synced from all accelerators at this point
# account for duplicates here because of DistributedSampler compensating for drop_last=False
loglikelihood_dict: Dict[int, Dict[int, float]] = {}
celoss_dict: Dict[int, Dict[int, float]] = {}
bpb_dict: Dict[int, Dict[int, float]] = {}
label_dict = {}

# collect labels
Expand All @@ -109,8 +136,29 @@ def compute(self) -> torch.Tensor:
if int(cont_id.item()) not in loglikelihood_dict[int(doc_id.item())]:
loglikelihood_dict[int(doc_id.item())][int(cont_id.item())] = loglikelihood

# collect celosses
for doc_id, cont_id, celoss in self.celosses:
if int(doc_id.item()) not in celoss_dict:
celoss_dict[int(doc_id.item())] = {}

if int(cont_id.item()) not in celoss_dict[int(doc_id.item())]:
celoss_dict[int(doc_id.item())][int(cont_id.item())] = celoss

# collect bpbs
for doc_id, cont_id, bpb in self.bpbs:
if int(doc_id.item()) not in bpb_dict:
bpb_dict[int(doc_id.item())] = {}

if int(cont_id.item()) not in bpb_dict[int(doc_id.item())]:
bpb_dict[int(doc_id.item())][int(cont_id.item())] = bpb

# compute acc
correct = []
loglikelihood = []
celoss = []
bpb = []
soft_score = []
soft_log_score = []
preds: Optional[List[float]] = None
labels: Optional[List[int]] = None
if self.metric_type == "f1":
Expand All @@ -121,37 +169,57 @@ def compute(self) -> torch.Tensor:
# each doc_id might have a different number of continuation
num_continuations = len(loglikelihood_dict[doc_id].keys())
loglikelihoods = torch.tensor([-float("inf")] * num_continuations)
celosses = torch.tensor([float("inf")] * num_continuations)
bpbs = torch.tensor([float("inf")] * num_continuations)

skip_document = False
for cont_id in loglikelihood_dict[doc_id]:
try:
loglikelihoods[cont_id] = loglikelihood_dict[doc_id][cont_id]
celosses[cont_id] = celoss_dict[doc_id][cont_id]
bpbs[cont_id] = bpb_dict[doc_id][cont_id]
except IndexError:
# We didn't process all of the continuations, so skip this document.
skip_document = True
break

if skip_document:
continue
if self.metric_type in ["ce_loss", "bpb"]:
correct.append(loglikelihoods[0]) # Only one answer is scored
else:
correct.append(
1.0 if torch.argmax(loglikelihoods).item() == label_dict[doc_id] else 0.0
)

if self.metric_type == "f1":
if self.metric_type == "ce_loss":
celoss.append(celosses[0]) # Only one answer is scored
elif self.metric_type == "bpb":
bpb.append(bpbs[0]) # Only one answer is scored
elif self.metric_type == "f1":
assert preds is not None
assert labels is not None
preds.append(torch.argmax(loglikelihoods).item())
labels.append(label_dict[doc_id])
else:
correct.append(
1.0 if torch.argmax(loglikelihoods).item() == label_dict[doc_id] else 0.0
)
celoss.append(celosses[label_dict[doc_id]].item())
bpb.append(bpbs[label_dict[doc_id]].item())
soft_score.append(torch.softmax(loglikelihoods, dim=0)[label_dict[doc_id]].item())
soft_log_score.append(
torch.log_softmax(loglikelihoods, dim=0)[label_dict[doc_id]].item()
)

if self.metric_type == "f1":
assert preds is not None
assert labels is not None
# for NLI tasks, continuations are yes, no, neither, so idx=0 assigned to pos label
score = f1_score(labels, preds, pos_label=0)
return {"f1": torch.tensor(score)}
elif self.metric_type == "ce_loss":
return {"ce_loss": torch.tensor(sum(celoss) / len(celoss))}
elif self.metric_type == "bpb":
return {"bpb": torch.tensor(sum(bpb) / len(bpb))}
else:
score = sum(correct) / len(correct)

return torch.tensor(score)
return {
self.metric_type: torch.tensor(sum(correct) / len(correct)),
"ce_loss": torch.tensor(sum(celoss) / len(celoss)),
"bpb": torch.tensor(sum(bpb) / len(bpb)),
"soft": torch.tensor(sum(soft_score) / len(soft_score)),
"soft_log": torch.tensor(sum(soft_log_score) / len(soft_log_score)),
}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This new, richer metric output (along with the corresponding changes in olmo-core's evaluator_callback.py) will change what gets recorded for the "acc" and "len_norm" metric types. Do we know that this won't interfere with other parts of the setup (logging, etc.)?

In general, though, it's definitely the right thing to do to compute these in a single go rather than as separate versions of the same task (in fact, "acc" and "len_norm" could also be computed together). The one downside is that it becomes less clear which one is the "official" metric for each task.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, that makes a lot of sense! I have updated my changes in olmo-core so that it displays the mapped metric-type text.

Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"task_name": "arc_challenge:mc", "task_hash": "11abfade7ecce501f3e3e72c937e19cc", "task_config": {"task_name": "arc_challenge:mc", "task_core": "arc_challenge", "limit": 1000000, "split": "test", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_raw", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Challenge", "dataset_path": "ai2_arc", "dataset_name": "ARC-Challenge", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"regimes": ["OLMES-v0.1"], "alias": "arc_challenge:mc::olmes"}}, "current_date": "2024-11-18 22:05:58 UTC", "num_instances": 1172}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"task_name": "arc_challenge", "task_hash": "b122d520ab0cf70114350ecf00c5c811", "task_config": {"task_name": "arc_challenge", "task_core": "arc_challenge", "limit": 1000000, "split": "test", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_per_char", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {"uncond_docid_offset": 1000000}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Challenge", "dataset_path": "ai2_arc", "dataset_name": "ARC-Challenge", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"regimes": ["OLMES-v0.1"], "alias": "arc_challenge:rc::olmes"}}, "current_date": "2024-11-18 21:50:18 UTC", "num_instances": 1172}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"task_name": "arc_challenge:mc", "task_hash": "cf2769a2dc6cbea724ff477c3d2543a2", "task_config": {"task_name": "arc_challenge:mc", "task_core": "arc_challenge", "limit": 1000000, "split": "train", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_raw", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Challenge", "dataset_path": "ai2_arc", "dataset_name": "ARC-Challenge", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"regimes": ["OLMES-v0.1"], "alias": "arc_challenge:mc::olmes"}}, "current_date": "2024-11-18 22:05:40 UTC", "num_instances": 1119}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"task_name": "arc_challenge", "task_hash": "9045ed0bd68a7e9ff34cf51ff24828bf", "task_config": {"task_name": "arc_challenge", "task_core": "arc_challenge", "limit": 1000000, "split": "train", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_per_char", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {"uncond_docid_offset": 1000000}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Challenge", "dataset_path": "ai2_arc", "dataset_name": "ARC-Challenge", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"regimes": ["OLMES-v0.1"], "alias": "arc_challenge:rc::olmes"}}, "current_date": "2024-11-18 22:05:31 UTC", "num_instances": 1119}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"task_name": "arc_challenge:mc", "task_hash": "a673d7761ce3fc3d5061d72f76755971", "task_config": {"task_name": "arc_challenge:mc", "task_core": "arc_challenge", "limit": 1000000, "split": "validation", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_raw", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Challenge", "dataset_path": "ai2_arc", "dataset_name": "ARC-Challenge", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"regimes": ["OLMES-v0.1"], "alias": "arc_challenge:mc::olmes"}}, "current_date": "2024-11-18 22:05:49 UTC", "num_instances": 299}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"task_name": "arc_challenge", "task_hash": "bd181c90c43b3ef799af2f300ea09cf1", "task_config": {"task_name": "arc_challenge", "task_core": "arc_challenge", "limit": 1000000, "split": "validation", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_per_char", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {"uncond_docid_offset": 1000000}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Challenge", "dataset_path": "ai2_arc", "dataset_name": "ARC-Challenge", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"regimes": ["OLMES-v0.1"], "alias": "arc_challenge:rc::olmes"}}, "current_date": "2024-11-18 21:45:07 UTC", "num_instances": 299}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"task_name": "arc_easy:mc", "task_hash": "64250ca6fdf0f02e07b539e8efc04922", "task_config": {"task_name": "arc_easy:mc", "task_core": "arc_easy", "limit": 1000000, "split": "test", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_raw", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Easy", "dataset_path": "ai2_arc", "dataset_name": "ARC-Easy", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"description": "ARC-Easy (MC) using OLMES-v0.1", "regimes": ["OLMES-v0.1"], "alias": "arc_easy:mc::olmes"}}, "current_date": "2024-11-18 22:06:33 UTC", "num_instances": 2376}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"task_name": "arc_easy", "task_hash": "ccbbd993c851d3300140d81ffec0e397", "task_config": {"task_name": "arc_easy", "task_core": "arc_easy", "limit": 1000000, "split": "test", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_per_char", "random_subsample_seed": 1234, "context_kwargs": {"description": null}, "generation_kwargs": {}, "metric_kwargs": {"uncond_docid_offset": 1000000}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Easy", "dataset_path": "ai2_arc", "dataset_name": "ARC-Easy", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"description": "ARC-Easy (RC) using OLMES-v0.1", "regimes": ["OLMES-v0.1"], "alias": "arc_easy:rc::olmes"}}, "current_date": "2024-11-18 21:50:27 UTC", "num_instances": 2376}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"task_name": "arc_easy:mc", "task_hash": "afa7e96b485c4e4481b3b9b817faac36", "task_config": {"task_name": "arc_easy:mc", "task_core": "arc_easy", "limit": 1000000, "split": "train", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_raw", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Easy", "dataset_path": "ai2_arc", "dataset_name": "ARC-Easy", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"description": "ARC-Easy (MC) using OLMES-v0.1", "regimes": ["OLMES-v0.1"], "alias": "arc_easy:mc::olmes"}}, "current_date": "2024-11-18 22:06:16 UTC", "num_instances": 2251}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"task_name": "arc_easy", "task_hash": "4a5241b308edb45d7b9eab594093c519", "task_config": {"task_name": "arc_easy", "task_core": "arc_easy", "limit": 1000000, "split": "train", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_per_char", "random_subsample_seed": 1234, "context_kwargs": {"description": null}, "generation_kwargs": {}, "metric_kwargs": {"uncond_docid_offset": 1000000}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Easy", "dataset_path": "ai2_arc", "dataset_name": "ARC-Easy", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"description": "ARC-Easy (RC) using OLMES-v0.1", "regimes": ["OLMES-v0.1"], "alias": "arc_easy:rc::olmes"}}, "current_date": "2024-11-18 22:06:07 UTC", "num_instances": 2251}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"task_name": "arc_easy:mc", "task_hash": "443bd52f752399615d01c853a8d7386c", "task_config": {"task_name": "arc_easy:mc", "task_core": "arc_easy", "limit": 1000000, "split": "validation", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_raw", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Easy", "dataset_path": "ai2_arc", "dataset_name": "ARC-Easy", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"description": "ARC-Easy (MC) using OLMES-v0.1", "regimes": ["OLMES-v0.1"], "alias": "arc_easy:mc::olmes"}}, "current_date": "2024-11-18 22:06:24 UTC", "num_instances": 570}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"task_name": "arc_easy", "task_hash": "0045e4f588a617cbe9ee5a4ae8ca1ce5", "task_config": {"task_name": "arc_easy", "task_core": "arc_easy", "limit": 1000000, "split": "validation", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_per_char", "random_subsample_seed": 1234, "context_kwargs": {"description": null}, "generation_kwargs": {}, "metric_kwargs": {"uncond_docid_offset": 1000000}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Easy", "dataset_path": "ai2_arc", "dataset_name": "ARC-Easy", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"description": "ARC-Easy (RC) using OLMES-v0.1", "regimes": ["OLMES-v0.1"], "alias": "arc_easy:rc::olmes"}}, "current_date": "2024-11-18 21:45:15 UTC", "num_instances": 570}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"task_name": "boolq:mc", "task_hash": "a92ca849d7efd331110145eb71e4fc09", "task_config": {"task_name": "boolq:mc", "task_core": "boolq", "limit": 1000000, "split": "train", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_raw", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {"uncond_docid_offset": null}, "native_id_field": "idx", "fewshot_source": "OLMES:BoolQ", "dataset_path": "super_glue", "dataset_name": "boolq", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"regimes": ["OLMES-v0.1"], "alias": "boolq:mc::olmes"}}, "current_date": "2024-11-18 22:06:52 UTC", "num_instances": 9427}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"task_name": "boolq", "task_hash": "ec8729b372d310aaf3a222f37a7af7b9", "task_config": {"task_name": "boolq", "task_core": "boolq", "limit": 1000000, "split": "train", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_raw", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {"uncond_docid_offset": null}, "native_id_field": "idx", "fewshot_source": "OLMES:BoolQ", "dataset_path": "super_glue", "dataset_name": "boolq", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"regimes": ["OLMES-v0.1"], "alias": "boolq:rc::olmes"}}, "current_date": "2024-11-18 22:06:41 UTC", "num_instances": 9427}
Binary file not shown.
1 change: 1 addition & 0 deletions src/olmo_eval/oe_eval_tasks/boolq/val_mc_5shot/config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"task_name": "boolq:mc", "task_hash": "d88f45757f4a8c3802b7274857894a90", "task_config": {"task_name": "boolq:mc", "task_core": "boolq", "limit": 1000000, "split": "validation", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_raw", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {"uncond_docid_offset": null}, "native_id_field": "idx", "fewshot_source": "OLMES:BoolQ", "dataset_path": "super_glue", "dataset_name": "boolq", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"regimes": ["OLMES-v0.1"], "alias": "boolq:mc::olmes"}}, "current_date": "2024-11-18 22:07:01 UTC", "num_instances": 3270}
Binary file not shown.
Loading
Loading