Implement soft versions of accuracies
CodeCreator committed Nov 1, 2024
1 parent 6f60edd commit 99c0d80
Showing 2 changed files with 29 additions and 12 deletions.
olmo/eval/downstream.py (20 additions, 6 deletions)

@@ -95,7 +95,9 @@ def update(self, batch: Dict[str, Any], lm_logits: torch.Tensor, dc_lm_logits=No
                 torch.LongTensor((doc_id, cont_id, batch["label_id"][idx])).to(batch["label_id"][idx].device)
             )
 
-    def compute(self) -> torch.Tensor:
+    def compute(self) -> Dict[str, torch.Tensor]:
+        # Task "suffix" -> tensor
+
         # states should have been synced from all accelerators at this point
         # account for duplicates here because of DistributedSampler compensating for drop_last=False
         loglikelihood_dict: Dict[int, Dict[int, float]] = {}
@@ -116,6 +118,9 @@ def compute(self) -> torch.Tensor:
 
         # compute acc
         correct = []
+        soft_scores = []
+        soft_log_scores = []
+
         preds: Optional[List[float]] = None
         labels: Optional[List[int]] = None
         if self.metric_type == "f1":
@@ -140,14 +145,15 @@ def compute(self) -> torch.Tensor:
                     continue
                 if self.metric_type in ["ce_loss", "bpb"]:
                     correct.append(loglikelihoods[0])  # Only one answer is scored
-                else:
-                    correct.append(1.0 if torch.argmax(loglikelihoods).item() == label_dict[doc_id] else 0.0)
-
-                if self.metric_type == "f1":
+                elif self.metric_type == "f1":
                     assert preds is not None
                     assert labels is not None
                     preds.append(torch.argmax(loglikelihoods).item())
                     labels.append(label_dict[doc_id])
+                else:
+                    correct.append(1.0 if torch.argmax(loglikelihoods).item() == label_dict[doc_id] else 0.0)
+                    soft_scores.append(torch.softmax(loglikelihoods, dim=0)[label_dict[doc_id]].item())
+                    soft_log_scores.append(torch.log_softmax(loglikelihoods, dim=0)[label_dict[doc_id]].item())
 
         if self.metric_type == "f1":
             assert preds is not None
@@ -157,7 +163,15 @@ def compute(self) -> torch.Tensor:
         else:
             score = sum(correct) / len(correct)
 
-        return torch.tensor(score)
+        outputs = {
+            "": torch.tensor(score),
+        }
+
+        if soft_scores:
+            outputs["_soft"] = torch.tensor(sum(soft_scores) / len(soft_scores))
+            outputs["_soft_log"] = torch.tensor(sum(soft_log_scores) / len(soft_log_scores))
+
+        return outputs
 
 
 class ICLMultiChoiceTaskDataset(metaclass=abc.ABCMeta):
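For intuition, here is a minimal standalone sketch (not part of the commit) of how the hard, soft, and soft-log scores relate for a single multiple-choice document, given one log-likelihood per answer option; the option scores and gold index below are made up:

import torch

# Hypothetical per-option log-likelihoods and gold label for one document.
loglikelihoods = torch.tensor([-2.3, -0.7, -1.9])
gold = 1

# Hard accuracy: 1.0 iff the argmax option is the gold option.
hard = 1.0 if torch.argmax(loglikelihoods).item() == gold else 0.0

# Soft accuracy: probability the softmax over options assigns to the gold option.
soft = torch.softmax(loglikelihoods, dim=0)[gold].item()

# Soft-log accuracy: log-probability of the gold option.
soft_log = torch.log_softmax(loglikelihoods, dim=0)[gold].item()

print(hard, soft, soft_log)  # 1.0, ~0.67, ~-0.41 for these made-up scores

The committed code accumulates these per-document values in correct, soft_scores, and soft_log_scores and averages each over all scored documents.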
olmo/eval/evaluator.py (9 additions, 6 deletions)

@@ -29,11 +29,14 @@ def reset_metrics(self) -> None:
     def compute_metrics(self) -> Dict[str, float]:
         if self.type == EvaluatorType.downstream:
             assert isinstance(self.eval_metric, ICLMetric)
-            value = self.eval_metric.compute().item()
-            key = f"eval/downstream/{self.label}_{self.eval_metric.metric_type}"
-            if self.eval_metric.metric_type in ["ce_loss", "bpb"]:
-                key = key.replace("/downstream/", f"/downstream_{self.eval_metric.metric_type}/")
-            return {key: value}
+            suffix_to_value = self.eval_metric.compute()
+            outputs = {}
+            for suffix, value in suffix_to_value.items():
+                key = f"eval/downstream/{self.label}_{self.eval_metric.metric_type}{suffix}"
+                if self.eval_metric.metric_type in ["ce_loss", "bpb"]:
+                    key = key.replace("/downstream/", f"/downstream_{self.eval_metric.metric_type}/")
+                outputs[key] = value.item()
+            return outputs
         elif self.type == EvaluatorType.lm:
             # Metric(s) = cross entropy loss
             metrics: Dict[str, Metric]
@@ -52,7 +55,7 @@ def compute_metrics(self) -> Dict[str, float]:
                     # This can happen when the evaluator contains multiple tasks/datasets and we didn't
                     # get to this one within the current evaluation loop.
                     metric.update(0.0, 0.0)
-                loss = metric.compute()
+                loss = metric.compute()[""]  # always no suffix
                 if loss.isnan().item():
                     # This can happen when the evaluator contains multiple tasks/datasets and we didn't
                     # get to this one within the current evaluation loop.
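As an illustration of the resulting metric names (not part of the commit; the task label "hellaswag" and metric type "len_norm" are made-up examples), the suffix-keyed dict returned by ICLMetric.compute() surfaces as one logged metric per suffix:

# Hypothetical suffix-to-value dict, mimicking the new ICLMetric.compute() output
# (plain floats here, so no .item() call is needed in this sketch).
label, metric_type = "hellaswag", "len_norm"
suffix_to_value = {"": 0.61, "_soft": 0.48, "_soft_log": -0.95}

outputs = {}
for suffix, value in suffix_to_value.items():
    key = f"eval/downstream/{label}_{metric_type}{suffix}"
    if metric_type in ["ce_loss", "bpb"]:
        key = key.replace("/downstream/", f"/downstream_{metric_type}/")
    outputs[key] = value

print(outputs)
# {'eval/downstream/hellaswag_len_norm': 0.61,
#  'eval/downstream/hellaswag_len_norm_soft': 0.48,
#  'eval/downstream/hellaswag_len_norm_soft_log': -0.95}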
