Skip to content

Commit

Permalink
Add classification cfg checks for labels
Browse files Browse the repository at this point in the history
  • Loading branch information
maxjeblick authored Nov 14, 2023
1 parent 75acecf commit 92b402f
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from dataclasses import dataclass, field
from typing import Any, Tuple
from typing import Any, Dict, List, Tuple

import llm_studio.src.datasets.text_causal_classification_ds
import llm_studio.src.plots.text_causal_classification_modeling_plots
Expand Down Expand Up @@ -188,3 +188,28 @@ def __post_init__(self):
),
allow_custom=True,
)

def check(self) -> Dict[str, List]:
errors: Dict[str, List] = {"title": [], "message": []}

if self.training.loss_function == "CrossEntropyLoss":
if self.dataset.num_classes == 1:
errors["title"] += ["CrossEntropyLoss requires num_classes > 1"]
errors["message"] += [
"CrossEntropyLoss requires num_classes > 1, "
"but num_classes is set to 1."
]
elif self.training.loss_function == "BinaryCrossEntropyLoss":
if self.dataset.num_classes != 1:
errors["title"] += ["BinaryCrossEntropyLoss requires num_classes == 1"]
errors["message"] += [
"BinaryCrossEntropyLoss requires num_classes == 1, "
"but num_classes is set to {}.".format(self.dataset.num_classes)
]
if self.dataset.parent_id_column not in ["None", None]:
errors["title"] += ["Parent ID column is not supported for classification"]
errors["message"] += [
"Parent ID column is not supported for classification datasets."
]

return errors
18 changes: 18 additions & 0 deletions llm_studio/src/datasets/text_causal_classification_ds.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import Any, Dict

import numpy as np
Expand All @@ -9,12 +10,15 @@
)
from llm_studio.src.utils.exceptions import LLMDataException

logger = logging.getLogger(__name__)


class CustomDataset(TextCausalLanguageModelingCustomDataset):
def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
super().__init__(df=df, cfg=cfg, mode=mode)
check_for_non_int_answers(cfg, df)
self.answers_int = df[cfg.dataset.answer_column].astype(int).values.tolist()

if 1 < cfg.dataset.num_classes <= max(self.answers_int):
raise LLMDataException(
"Number of classes is smaller than max label "
Expand All @@ -25,6 +29,20 @@ def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
"For binary classification, max label should be 1 but is "
f"{max(self.answers_int)}."
)
if min(self.answers_int) < 0:
raise LLMDataException(
"Labels should be non-negative but min label is "
f"{min(self.answers_int)}."
)
if (
min(self.answers_int) != 0
or max(self.answers_int) != len(set(self.answers_int)) - 1
):
logger.warning(
"Labels should start at 0 and be continuous but are "
f"{sorted(set(self.answers_int))}."
)

if cfg.dataset.parent_id_column != "None":
raise LLMDataException(
"Parent ID column is not supported for classification datasets."
Expand Down

0 comments on commit 92b402f

Please sign in to comment.