diff --git a/pipeline/eval/eval.py b/pipeline/eval/eval.py index 31c5e43d7..dc7e7cdc7 100755 --- a/pipeline/eval/eval.py +++ b/pipeline/eval/eval.py @@ -356,7 +356,9 @@ def main(args_list: Optional[list[str]] = None) -> None: # Allow publishing metrics as a table on existing runs (i.e. previous trainings) wandb.open(resume=True) logger.info(f"Publishing metrics to Weight & Biases ({wandb.extra_kwargs})") - metric = metric_from_tc_context(chrf=chrf_details["score"], bleu=bleu_details["score"]) + metric = metric_from_tc_context( + chrf=chrf_details["score"], bleu=bleu_details["score"], comet=comet_score + ) wandb.handle_metrics(metrics=[metric]) wandb.close() diff --git a/tracking/translations_parser/data.py b/tracking/translations_parser/data.py index 0cc5068c9..44128b8fb 100644 --- a/tracking/translations_parser/data.py +++ b/tracking/translations_parser/data.py @@ -56,6 +56,7 @@ class Metric: # Scores chrf: float bleu_detok: float + comet: float | None = None # optional @classmethod def from_file( @@ -79,10 +80,14 @@ def from_file( values.append(float(line)) except ValueError: continue - assert len(values) == 2, "file must contain exactly 2 float values" + assert len(values) in (2, 3), "file must contain 2 or 3 lines with a float value" except Exception as e: raise ValueError(f"Metrics file could not be parsed: {e}") - bleu_detok, chrf = values + if len(values) == 2: + bleu_detok, chrf = values + comet = None + if len(values) == 3: + bleu_detok, chrf, comet = values if importer is None: _, importer, dataset, augmentation = parse_task_label(metrics_file.stem) return cls( @@ -91,6 +96,7 @@ def from_file( augmentation=augmentation, chrf=chrf, bleu_detok=bleu_detok, + comet=comet, ) @classmethod diff --git a/tracking/translations_parser/publishers.py b/tracking/translations_parser/publishers.py index 20807d9aa..fbf642427 100644 --- a/tracking/translations_parser/publishers.py +++ b/tracking/translations_parser/publishers.py @@ -176,7 +176,11 @@ def handle_metrics(self, metrics: Sequence[Metric]) -> None: title: wandb.plot.bar( wandb.Table( columns=["Metric", "Value"], - data=[[key, getattr(metric, key)] for key in ("bleu_detok", "chrf")], + data=[ + [key, getattr(metric, key)] + for key in ("bleu_detok", "chrf", "comet") + if getattr(metric, key) is not None + ], ), "Metric", "Value", diff --git a/tracking/translations_parser/utils.py b/tracking/translations_parser/utils.py index ebfeab75c..01569ddfa 100644 --- a/tracking/translations_parser/utils.py +++ b/tracking/translations_parser/utils.py @@ -164,7 +164,7 @@ def build_task_name(task: dict): return prefix, label.model -def metric_from_tc_context(chrf: float, bleu: float): +def metric_from_tc_context(chrf: float, bleu: float, comet: float): """ Find the various names needed to build a metric directly from a Taskcluster task """ @@ -185,4 +185,5 @@ def metric_from_tc_context(chrf: float, bleu: float): augmentation=parsed.augmentation, chrf=chrf, bleu_detok=bleu, + comet=comet, )