Sync eval changes in OLMo/ladder-1xC to here (#1)
* Sync eval changes in OLMo/ladder-1xC to here

* Lint

* Increase github workflow timeout to 15min

* Increase github workflow timeout to 30min
liujch1998 authored Dec 18, 2024
1 parent b2a5b4d commit 5f3db3c
Showing 87 changed files with 403 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
@@ -23,7 +23,7 @@ jobs:
   checks:
     name: ${{ matrix.task.name }}
     runs-on: [ubuntu-latest]
-    timeout-minutes: 5
+    timeout-minutes: 30
     strategy:
       fail-fast: false
       matrix:
104 changes: 86 additions & 18 deletions src/olmo_eval/metrics.py
@@ -19,12 +19,16 @@ def __init__(self, metric_type="acc") -> None:
         self.metric_type = metric_type
 
         self.add_state("loglikelihoods", default=[], dist_reduce_fx=None)
+        self.add_state("celosses", default=[], dist_reduce_fx=None)
+        self.add_state("bpbs", default=[], dist_reduce_fx=None)
         self.add_state("labels", default=[], dist_reduce_fx=None)
 
     def reset(
         self,
     ):
         self.loglikelihoods = []
+        self.celosses = []
+        self.bpbs = []
         self.labels = []
 
     def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=None):
@@ -46,6 +50,8 @@ def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=None):
             ]
 
             log_likelihood: torch.Tensor
+            celoss: torch.Tensor
+            bpb: torch.Tensor
             if self.metric_type == "pmi_dc":
                 assert dc_lm_logits is not None
                 # get domain conditional continuation logits: [cont_len, vocab]
@@ -58,19 +64,30 @@ def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=None):
                     torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
                     / torch.gather(dc_lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
                 )
+                celoss = -log_likelihood
+                bpb = -log_likelihood  # the normalization factors cancel out
             elif self.metric_type == "acc" or self.metric_type == "f1":
                 # gather log-probs at continuation token indices
                 log_likelihood = torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
-            elif self.metric_type == "len_norm" or self.metric_type == "ce_loss":
+                celoss = (
+                    -torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
+                    / batch["cont_str_len"][idx]
+                )
+                bpb = (
+                    -torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
+                    / batch["cont_byte_len"][idx]
+                    * LOG_2_OF_E
+                )
+            elif self.metric_type in ["len_norm", "ce_loss", "bpb"]:
                 log_likelihood = (
                     torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
                     / batch["cont_str_len"][idx]
                 )
-                if self.metric_type == "ce_loss":
-                    log_likelihood = -log_likelihood
-            elif self.metric_type == "bpb":
-                # bits per byte
-                log_likelihood = (
+                celoss = (
+                    -torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
+                    / batch["cont_str_len"][idx]
+                )
+                bpb = (
                     -torch.gather(lm_cont_logits, 1, cont_tokens.unsqueeze(-1)).sum()
                     / batch["cont_byte_len"][idx]
                     * LOG_2_OF_E
@@ -84,16 +101,26 @@ def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=None):
                     batch["continuation"][idx].device
                 )
             )
+            self.celosses.append(
+                torch.Tensor((doc_id, cont_id, celoss)).to(batch["continuation"][idx].device)
+            )
+            self.bpbs.append(
+                torch.Tensor((doc_id, cont_id, bpb)).to(batch["continuation"][idx].device)
+            )
             self.labels.append(
                 torch.LongTensor((doc_id, cont_id, batch["label_id"][idx])).to(
                     batch["label_id"][idx].device
                 )
             )
 
-    def compute(self) -> torch.Tensor:
+    def compute(self) -> Dict[str, torch.Tensor]:
+        # Task "suffix" -> tensor
+
         # states should have been synced from all accelerators at this point
         # account for duplicates here because of DistributedSampler compensating for drop_last=False
         loglikelihood_dict: Dict[int, Dict[int, float]] = {}
+        celoss_dict: Dict[int, Dict[int, float]] = {}
+        bpb_dict: Dict[int, Dict[int, float]] = {}
         label_dict = {}
 
         # collect labels
@@ -109,8 +136,29 @@ def compute(self) -> torch.Tensor:
             if int(cont_id.item()) not in loglikelihood_dict[int(doc_id.item())]:
                 loglikelihood_dict[int(doc_id.item())][int(cont_id.item())] = loglikelihood
 
+        # collect celosses
+        for doc_id, cont_id, celoss in self.celosses:
+            if int(doc_id.item()) not in celoss_dict:
+                celoss_dict[int(doc_id.item())] = {}
+
+            if int(cont_id.item()) not in celoss_dict[int(doc_id.item())]:
+                celoss_dict[int(doc_id.item())][int(cont_id.item())] = celoss
+
+        # collect bpbs
+        for doc_id, cont_id, bpb in self.bpbs:
+            if int(doc_id.item()) not in bpb_dict:
+                bpb_dict[int(doc_id.item())] = {}
+
+            if int(cont_id.item()) not in bpb_dict[int(doc_id.item())]:
+                bpb_dict[int(doc_id.item())][int(cont_id.item())] = bpb
+
         # compute acc
         correct = []
+        loglikelihood = []
+        celoss = []
+        bpb = []
+        soft_score = []
+        soft_log_score = []
         preds: Optional[List[float]] = None
         labels: Optional[List[int]] = None
         if self.metric_type == "f1":
@@ -121,37 +169,57 @@ def compute(self) -> torch.Tensor:
             # each doc_id might have a different number of continuation
             num_continuations = len(loglikelihood_dict[doc_id].keys())
             loglikelihoods = torch.tensor([-float("inf")] * num_continuations)
+            celosses = torch.tensor([float("inf")] * num_continuations)
+            bpbs = torch.tensor([float("inf")] * num_continuations)
 
             skip_document = False
             for cont_id in loglikelihood_dict[doc_id]:
                 try:
                     loglikelihoods[cont_id] = loglikelihood_dict[doc_id][cont_id]
+                    celosses[cont_id] = celoss_dict[doc_id][cont_id]
+                    bpbs[cont_id] = bpb_dict[doc_id][cont_id]
                 except IndexError:
                     # We didn't process all of the continuations, so skip this document.
                     skip_document = True
                     break
 
             if skip_document:
                 continue
-            if self.metric_type in ["ce_loss", "bpb"]:
-                correct.append(loglikelihoods[0])  # Only one answer is scored
-            else:
-                correct.append(
-                    1.0 if torch.argmax(loglikelihoods).item() == label_dict[doc_id] else 0.0
-                )
-
-            if self.metric_type == "f1":
+            if self.metric_type == "ce_loss":
+                celoss.append(celosses[0])  # Only one answer is scored
+            elif self.metric_type == "bpb":
+                bpb.append(bpbs[0])  # Only one answer is scored
+            elif self.metric_type == "f1":
                 assert preds is not None
                 assert labels is not None
                 preds.append(torch.argmax(loglikelihoods).item())
                 labels.append(label_dict[doc_id])
+            else:
+                correct.append(
+                    1.0 if torch.argmax(loglikelihoods).item() == label_dict[doc_id] else 0.0
+                )
+                celoss.append(celosses[label_dict[doc_id]].item())
+                bpb.append(bpbs[label_dict[doc_id]].item())
+                soft_score.append(torch.softmax(loglikelihoods, dim=0)[label_dict[doc_id]].item())
+                soft_log_score.append(
+                    torch.log_softmax(loglikelihoods, dim=0)[label_dict[doc_id]].item()
+                )
 
         if self.metric_type == "f1":
             assert preds is not None
             assert labels is not None
             # for NLI tasks, continuations are yes, no, neither, so idx=0 assigned to pos label
             score = f1_score(labels, preds, pos_label=0)
+            return {"f1": torch.tensor(score)}
+        elif self.metric_type == "ce_loss":
+            return {"ce_loss": torch.tensor(sum(celoss) / len(celoss))}
+        elif self.metric_type == "bpb":
+            return {"bpb": torch.tensor(sum(bpb) / len(bpb))}
         else:
-            score = sum(correct) / len(correct)
-
-        return torch.tensor(score)
+            return {
+                self.metric_type: torch.tensor(sum(correct) / len(correct)),
+                "ce_loss": torch.tensor(sum(celoss) / len(celoss)),
+                "bpb": torch.tensor(sum(bpb) / len(bpb)),
+                "soft": torch.tensor(sum(soft_score) / len(soft_score)),
+                "soft_log": torch.tensor(sum(soft_log_score) / len(soft_log_score)),
+            }
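
For readers skimming the diff, the two new quantities have a simple closed form: celoss is the negative sum of the continuation's token log-probabilities divided by its character count (nats per character), and bpb divides by the byte count and rescales by log2(e) to get bits per byte. The standalone sketch below (illustrative values only, not the repository's API) reproduces that arithmetic:

import math

import torch

LOG_2_OF_E = 1.4426950408889634  # log2(e); matches the conversion constant used in metrics.py

# Hypothetical per-token log-probs (natural log) of one scored continuation.
cont_logprobs = torch.tensor([-1.2, -0.7, -2.3])
cont_str_len = 11   # characters in the continuation string (made-up value)
cont_byte_len = 11  # bytes of its UTF-8 encoding (made-up value)

summed = cont_logprobs.sum()
celoss = -summed / cont_str_len             # cross-entropy in nats per character
bpb = -summed / cont_byte_len * LOG_2_OF_E  # bits per byte

# Same quantity via a change of logarithm base.
assert torch.isclose(bpb, -summed / cont_byte_len / math.log(2))

print(f"ce_loss={celoss.item():.4f} nats/char, bpb={bpb.item():.4f} bits/byte")

In the pmi_dc branch the normalization factors cancel, which is why the diff simply sets both celoss and bpb to -log_likelihood there.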
@@ -0,0 +1 @@
{"task_name": "arc_challenge:mc", "task_hash": "11abfade7ecce501f3e3e72c937e19cc", "task_config": {"task_name": "arc_challenge:mc", "task_core": "arc_challenge", "limit": 1000000, "split": "test", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_raw", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Challenge", "dataset_path": "ai2_arc", "dataset_name": "ARC-Challenge", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"regimes": ["OLMES-v0.1"], "alias": "arc_challenge:mc::olmes"}}, "current_date": "2024-11-18 22:05:58 UTC", "num_instances": 1172}
@@ -0,0 +1 @@
{"task_name": "arc_challenge", "task_hash": "b122d520ab0cf70114350ecf00c5c811", "task_config": {"task_name": "arc_challenge", "task_core": "arc_challenge", "limit": 1000000, "split": "test", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_per_char", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {"uncond_docid_offset": 1000000}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Challenge", "dataset_path": "ai2_arc", "dataset_name": "ARC-Challenge", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"regimes": ["OLMES-v0.1"], "alias": "arc_challenge:rc::olmes"}}, "current_date": "2024-11-18 21:50:18 UTC", "num_instances": 1172}
@@ -0,0 +1 @@
{"task_name": "arc_challenge:mc", "task_hash": "cf2769a2dc6cbea724ff477c3d2543a2", "task_config": {"task_name": "arc_challenge:mc", "task_core": "arc_challenge", "limit": 1000000, "split": "train", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_raw", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Challenge", "dataset_path": "ai2_arc", "dataset_name": "ARC-Challenge", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"regimes": ["OLMES-v0.1"], "alias": "arc_challenge:mc::olmes"}}, "current_date": "2024-11-18 22:05:40 UTC", "num_instances": 1119}
@@ -0,0 +1 @@
{"task_name": "arc_challenge", "task_hash": "9045ed0bd68a7e9ff34cf51ff24828bf", "task_config": {"task_name": "arc_challenge", "task_core": "arc_challenge", "limit": 1000000, "split": "train", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_per_char", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {"uncond_docid_offset": 1000000}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Challenge", "dataset_path": "ai2_arc", "dataset_name": "ARC-Challenge", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"regimes": ["OLMES-v0.1"], "alias": "arc_challenge:rc::olmes"}}, "current_date": "2024-11-18 22:05:31 UTC", "num_instances": 1119}
@@ -0,0 +1 @@
{"task_name": "arc_challenge:mc", "task_hash": "a673d7761ce3fc3d5061d72f76755971", "task_config": {"task_name": "arc_challenge:mc", "task_core": "arc_challenge", "limit": 1000000, "split": "validation", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_raw", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Challenge", "dataset_path": "ai2_arc", "dataset_name": "ARC-Challenge", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"regimes": ["OLMES-v0.1"], "alias": "arc_challenge:mc::olmes"}}, "current_date": "2024-11-18 22:05:49 UTC", "num_instances": 299}
@@ -0,0 +1 @@
{"task_name": "arc_challenge", "task_hash": "bd181c90c43b3ef799af2f300ea09cf1", "task_config": {"task_name": "arc_challenge", "task_core": "arc_challenge", "limit": 1000000, "split": "validation", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_per_char", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {"uncond_docid_offset": 1000000}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Challenge", "dataset_path": "ai2_arc", "dataset_name": "ARC-Challenge", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"regimes": ["OLMES-v0.1"], "alias": "arc_challenge:rc::olmes"}}, "current_date": "2024-11-18 21:45:07 UTC", "num_instances": 299}
@@ -0,0 +1 @@
{"task_name": "arc_easy:mc", "task_hash": "64250ca6fdf0f02e07b539e8efc04922", "task_config": {"task_name": "arc_easy:mc", "task_core": "arc_easy", "limit": 1000000, "split": "test", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_raw", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Easy", "dataset_path": "ai2_arc", "dataset_name": "ARC-Easy", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"description": "ARC-Easy (MC) using OLMES-v0.1", "regimes": ["OLMES-v0.1"], "alias": "arc_easy:mc::olmes"}}, "current_date": "2024-11-18 22:06:33 UTC", "num_instances": 2376}
@@ -0,0 +1 @@
{"task_name": "arc_easy", "task_hash": "ccbbd993c851d3300140d81ffec0e397", "task_config": {"task_name": "arc_easy", "task_core": "arc_easy", "limit": 1000000, "split": "test", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_per_char", "random_subsample_seed": 1234, "context_kwargs": {"description": null}, "generation_kwargs": {}, "metric_kwargs": {"uncond_docid_offset": 1000000}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Easy", "dataset_path": "ai2_arc", "dataset_name": "ARC-Easy", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"description": "ARC-Easy (RC) using OLMES-v0.1", "regimes": ["OLMES-v0.1"], "alias": "arc_easy:rc::olmes"}}, "current_date": "2024-11-18 21:50:27 UTC", "num_instances": 2376}
@@ -0,0 +1 @@
{"task_name": "arc_easy:mc", "task_hash": "afa7e96b485c4e4481b3b9b817faac36", "task_config": {"task_name": "arc_easy:mc", "task_core": "arc_easy", "limit": 1000000, "split": "train", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_raw", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Easy", "dataset_path": "ai2_arc", "dataset_name": "ARC-Easy", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"description": "ARC-Easy (MC) using OLMES-v0.1", "regimes": ["OLMES-v0.1"], "alias": "arc_easy:mc::olmes"}}, "current_date": "2024-11-18 22:06:16 UTC", "num_instances": 2251}
@@ -0,0 +1 @@
{"task_name": "arc_easy", "task_hash": "4a5241b308edb45d7b9eab594093c519", "task_config": {"task_name": "arc_easy", "task_core": "arc_easy", "limit": 1000000, "split": "train", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_per_char", "random_subsample_seed": 1234, "context_kwargs": {"description": null}, "generation_kwargs": {}, "metric_kwargs": {"uncond_docid_offset": 1000000}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Easy", "dataset_path": "ai2_arc", "dataset_name": "ARC-Easy", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"description": "ARC-Easy (RC) using OLMES-v0.1", "regimes": ["OLMES-v0.1"], "alias": "arc_easy:rc::olmes"}}, "current_date": "2024-11-18 22:06:07 UTC", "num_instances": 2251}
@@ -0,0 +1 @@
{"task_name": "arc_easy:mc", "task_hash": "443bd52f752399615d01c853a8d7386c", "task_config": {"task_name": "arc_easy:mc", "task_core": "arc_easy", "limit": 1000000, "split": "validation", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_raw", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Easy", "dataset_path": "ai2_arc", "dataset_name": "ARC-Easy", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"description": "ARC-Easy (MC) using OLMES-v0.1", "regimes": ["OLMES-v0.1"], "alias": "arc_easy:mc::olmes"}}, "current_date": "2024-11-18 22:06:24 UTC", "num_instances": 570}
@@ -0,0 +1 @@
{"task_name": "arc_easy", "task_hash": "0045e4f588a617cbe9ee5a4ae8ca1ce5", "task_config": {"task_name": "arc_easy", "task_core": "arc_easy", "limit": 1000000, "split": "validation", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_per_char", "random_subsample_seed": 1234, "context_kwargs": {"description": null}, "generation_kwargs": {}, "metric_kwargs": {"uncond_docid_offset": 1000000}, "native_id_field": "id", "fewshot_source": "OLMES:ARC-Easy", "dataset_path": "ai2_arc", "dataset_name": "ARC-Easy", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"description": "ARC-Easy (RC) using OLMES-v0.1", "regimes": ["OLMES-v0.1"], "alias": "arc_easy:rc::olmes"}}, "current_date": "2024-11-18 21:45:15 UTC", "num_instances": 570}
@@ -0,0 +1 @@
{"task_name": "boolq:mc", "task_hash": "a92ca849d7efd331110145eb71e4fc09", "task_config": {"task_name": "boolq:mc", "task_core": "boolq", "limit": 1000000, "split": "train", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_raw", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {"uncond_docid_offset": null}, "native_id_field": "idx", "fewshot_source": "OLMES:BoolQ", "dataset_path": "super_glue", "dataset_name": "boolq", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"regimes": ["OLMES-v0.1"], "alias": "boolq:mc::olmes"}}, "current_date": "2024-11-18 22:06:52 UTC", "num_instances": 9427}
@@ -0,0 +1 @@
{"task_name": "boolq", "task_hash": "ec8729b372d310aaf3a222f37a7af7b9", "task_config": {"task_name": "boolq", "task_core": "boolq", "limit": 1000000, "split": "train", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_raw", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {"uncond_docid_offset": null}, "native_id_field": "idx", "fewshot_source": "OLMES:BoolQ", "dataset_path": "super_glue", "dataset_name": "boolq", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"regimes": ["OLMES-v0.1"], "alias": "boolq:rc::olmes"}}, "current_date": "2024-11-18 22:06:41 UTC", "num_instances": 9427}
1 change: 1 addition & 0 deletions src/olmo_eval/oe_eval_tasks/boolq/val_mc_5shot/config.json
@@ -0,0 +1 @@
{"task_name": "boolq:mc", "task_hash": "d88f45757f4a8c3802b7274857894a90", "task_config": {"task_name": "boolq:mc", "task_core": "boolq", "limit": 1000000, "split": "validation", "num_shots": 5, "fewshot_seed": 1234, "primary_metric": "acc_raw", "random_subsample_seed": 1234, "context_kwargs": {}, "generation_kwargs": {}, "metric_kwargs": {"uncond_docid_offset": null}, "native_id_field": "idx", "fewshot_source": "OLMES:BoolQ", "dataset_path": "super_glue", "dataset_name": "boolq", "use_chat_format": null, "version": 0, "revision": null, "metadata": {"regimes": ["OLMES-v0.1"], "alias": "boolq:mc::olmes"}}, "current_date": "2024-11-18 22:07:01 UTC", "num_instances": 3270}
