From 453e5a9837624f37b23fd25b4827df9ad8e274fa Mon Sep 17 00:00:00 2001 From: Pringled Date: Fri, 14 Feb 2025 12:12:10 +0100 Subject: [PATCH 01/29] Added multilabel option to training --- model2vec/train/classifier.py | 188 ++++++++++++++++++++++------------ tests/conftest.py | 21 +++- tests/test_trainable.py | 5 +- 3 files changed, 144 insertions(+), 70 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index d90986a..abded95 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -3,12 +3,14 @@ import logging from collections import Counter from tempfile import TemporaryDirectory +from typing import cast import lightning as pl import numpy as np import torch from lightning.pytorch.callbacks import Callback, EarlyStopping from lightning.pytorch.utilities.types import OptimizerLRScheduler +from sklearn.metrics import jaccard_score from sklearn.model_selection import train_test_split from sklearn.neural_network import MLPClassifier from sklearn.pipeline import make_pipeline @@ -20,7 +22,6 @@ from model2vec.train.base import FinetunableStaticModel, TextDataset logger = logging.getLogger(__name__) - _RANDOM_SEED = 42 @@ -40,11 +41,13 @@ def __init__( self.hidden_dim = hidden_dim # Alias: Follows scikit-learn. Set to dummy classes self.classes_: list[str] = [str(x) for x in range(out_dim)] + # multilabel flag will be set based on the type of `y` passed to fit. + self.multilabel: bool = False super().__init__(vectors=vectors, out_dim=out_dim, pad_id=pad_id, tokenizer=tokenizer) @property def classes(self) -> list[str]: - """Return all clasess in the correct order.""" + """Return the classes in the correct order.""" return self.classes_ def construct_head(self) -> nn.Sequential: @@ -67,33 +70,53 @@ def construct_head(self) -> nn.Sequential: return nn.Sequential(*modules) def predict(self, X: list[str], show_progress_bar: bool = False, batch_size: int = 1024) -> np.ndarray: - """Predict a class for a set of texts.""" - pred: list[str] = [] + """ + Predict labels for a set of texts. + + In single-label mode, each prediction is a single class. + In multilabel mode, each prediction is a list of classes. + """ + pred = [] for batch in trange(0, len(X), batch_size, disable=not show_progress_bar): logits = self._predict_single_batch(X[batch : batch + batch_size]) - pred.extend([self.classes[idx] for idx in logits.argmax(1)]) - - return np.asarray(pred) + if self.multilabel: + probs = torch.sigmoid(logits) + for sample in probs: + sample_labels = [self.classes[i] for i, p in enumerate(sample) if p > 0.5] + # Fallback: if no label passes the threshold, choose the highest probability label. + if not sample_labels: + sample_labels = [self.classes[sample.argmax().item()]] + pred.append(sample_labels) + else: + pred.extend([self.classes[idx] for idx in logits.argmax(dim=1).tolist()]) + return np.array(pred, dtype=object) @torch.no_grad() def _predict_single_batch(self, X: list[str]) -> torch.Tensor: input_ids = self.tokenize(X) - vectors, _ = self.forward(input_ids) - return vectors + head_out, _ = self.forward(input_ids) + return head_out def predict_proba(self, X: list[str], show_progress_bar: bool = False, batch_size: int = 1024) -> np.ndarray: - """Predict the probability of each class.""" - pred: list[np.ndarray] = [] + """ + Predict probabilities for each class. + + In single-label mode, returns softmax probabilities. + In multilabel mode, returns sigmoid probabilities. 
+ """ + pred = [] for batch in trange(0, len(X), batch_size, disable=not show_progress_bar): logits = self._predict_single_batch(X[batch : batch + batch_size]) - pred.append(torch.softmax(logits, dim=1).numpy()) - - return np.concatenate(pred) + if self.multilabel: + pred.append(torch.sigmoid(logits).cpu().numpy()) + else: + pred.append(torch.softmax(logits, dim=1).cpu().numpy()) + return np.concatenate(pred, axis=0) def fit( self, X: list[str], - y: list[str], + y: list[str] | list[list[str]], learning_rate: float = 1e-3, batch_size: int | None = None, min_epochs: int | None = None, @@ -106,43 +129,44 @@ def fit( Fit a model. This function creates a Lightning Trainer object and fits the model to the data. - We use early stopping. After training, the weigths of the best model are loaded back into the model. + It supports both single-label and multi-label classification. + We use early stopping. After training, the weights of the best model are loaded back into the model. - This function seeds everything with a seed of 42, so the results are reproducible. - It also splits the data into a train and validation set, again with a random seed. + The function seeds everything with a seed of 42, so the results are reproducible. + It also splits the data into training and validation sets using a random seed. :param X: The texts to train on. - :param y: The labels to train on. + :param y: The labels to train on. If the first element is a list, multi-label classification is assumed. :param learning_rate: The learning rate. - :param batch_size: The batch size. - If this is None, a good batch size is chosen automatically. + :param batch_size: The batch size. If None, a good batch size is chosen automatically. :param min_epochs: The minimum number of epochs to train for. :param max_epochs: The maximum number of epochs to train for. - If this is -1, the model trains until early stopping is triggered. - :param early_stopping_patience: The patience for early stopping. - If this is None, early stopping is disabled. + If -1, training continues until early stopping is triggered. + :param early_stopping_patience: The patience for early stopping. If None, early stopping is disabled. :param test_size: The test size for the train-test split. - :param device: The device to train on. If this is "auto", the device is chosen automatically. + :param device: The device to train on. If "auto", the device is chosen automatically. :return: The fitted model. """ pl.seed_everything(_RANDOM_SEED) logger.info("Re-initializing model.") - self._initialize(y) + + # Determine whether we're in multilabel mode based on the first element of y. 
+ multilabel = isinstance(y[0], list) + self._initialize(y, multilabel=multilabel) train_texts, validation_texts, train_labels, validation_labels = self._train_test_split( - X, y, test_size=test_size + X, y, test_size=test_size, multilabel=multilabel ) if batch_size is None: - # Set to a multiple of 32 base_number = int(min(max(1, (len(train_texts) / 30) // 32), 16)) batch_size = int(base_number * 32) logger.info("Batch size automatically set to %d.", batch_size) logger.info("Preparing train dataset.") - train_dataset = self._prepare_dataset(train_texts, train_labels) + train_dataset = self._prepare_dataset(train_texts, train_labels, multilabel=multilabel) logger.info("Preparing validation dataset.") - val_dataset = self._prepare_dataset(validation_texts, validation_labels) + val_dataset = self._prepare_dataset(validation_texts, validation_labels, multilabel=multilabel) c = _ClassifierLightningModule(self, learning_rate=learning_rate) @@ -152,8 +176,7 @@ def fit( callback = EarlyStopping(monitor="val_accuracy", mode="max", patience=early_stopping_patience) callbacks.append(callback) - # If the dataset is small, we check the validation set every epoch. - # If the dataset is large, we check the validation set every 250 batches. + # Check validation frequency. if n_train_batches < 250: val_check_interval = None check_val_every_epoch = 1 @@ -186,45 +209,75 @@ def fit( self.load_state_dict(state_dict) self.eval() - return self - def _initialize(self, y: list[str]) -> None: - """Sets the out dimensionality, the classes and initializes the head.""" - classes = sorted(set(y)) - self.classes_ = classes - - if len(self.classes) != self.out_dim: - self.out_dim = len(self.classes) + def _initialize(self, y: list[str] | list[list[str]], multilabel: bool = False) -> None: + """ + Sets the output dimensionality, the classes, and initializes the head. + For multilabel classification, y is assumed to be a list of lists. + """ + self.multilabel = multilabel + if multilabel: + classes = sorted({label for sublist in y for label in sublist}) + else: + classes = sorted(set(cast(list[str], y))) + self.classes_ = classes + self.out_dim = len(self.classes_) # Update output dimension self.head = self.construct_head() self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=self.pad_id) self.w = self.construct_weights() self.train() - def _prepare_dataset(self, X: list[str], y: list[str], max_length: int = 512) -> TextDataset: - """Prepare a dataset.""" - # This is a speed optimization. - # assumes a mean token length of 10, which is really high, so safe. + def _prepare_dataset( + self, X: list[str], y: list[str] | list[list[str]], max_length: int = 512, multilabel: bool = False + ) -> TextDataset: + """ + Prepare a dataset. + + For multilabel classification, each target is converted into a multi-hot vector. 
+ """ truncate_length = max_length * 10 X = [x[:truncate_length] for x in X] tokenized: list[list[int]] = [ encoding.ids[:max_length] for encoding in self.tokenizer.encode_batch_fast(X, add_special_tokens=False) ] - labels_tensor = torch.Tensor([self.classes.index(label) for label in y]).long() - + if multilabel: + num_classes = len(self.classes) + label_list = [] + for sample_labels in y: + multi_hot = torch.zeros(num_classes, dtype=torch.float) + for label in sample_labels: + index = self.classes.index(label) + multi_hot[index] = 1.0 + label_list.append(multi_hot) + labels_tensor = torch.stack(label_list) + else: + labels_tensor = torch.tensor([self.classes.index(label) for label in cast(list[str], y)], dtype=torch.long) return TextDataset(tokenized, labels_tensor) @staticmethod def _train_test_split( - X: list[str], y: list[str], test_size: float - ) -> tuple[list[str], list[str], list[str], list[str]]: - """Split the data.""" - label_counts = Counter(y) - if min(label_counts.values()) < 2: - logger.info("Some classes have less than 2 samples. Stratification is disabled.") + X: list[str], + y: list[str] | list[list[str]], + test_size: float, + multilabel: bool = False, + ) -> tuple[list[str], list[str], list[str] | list[list[str]], list[str] | list[list[str]]]: + """ + Split the data. + + For single-label classification, stratification is attempted (if possible). + For multilabel classification, a random split is performed. + """ + if not multilabel: + label_counts = Counter(y) # type: ignore + if min(label_counts.values()) < 2: + logger.info("Some classes have less than 2 samples. Stratification is disabled.") + return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True) + return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True, stratify=y) + else: + # Multilabel classification does not support stratification. 
return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True) - return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True, stratify=y) def to_pipeline(self) -> StaticModelPipeline: """Convert the model to an sklearn pipeline.""" @@ -256,38 +309,46 @@ def to_pipeline(self) -> StaticModelPipeline: class _ClassifierLightningModule(pl.LightningModule): def __init__(self, model: StaticModelForClassification, learning_rate: float) -> None: - """Initialize the lightningmodule.""" + """Initialize the LightningModule.""" super().__init__() self.model = model self.learning_rate = learning_rate def forward(self, x: torch.Tensor) -> torch.Tensor: - """Simple forward pass.""" + """Forward pass.""" return self.model(x) def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: - """Simple training step using cross entropy loss.""" + """Training step using cross-entropy loss for single-label and binary cross-entropy for multilabel training.""" x, y = batch head_out, _ = self.model(x) - loss = nn.functional.cross_entropy(head_out, y).mean() - + if self.model.multilabel: + loss = nn.functional.binary_cross_entropy_with_logits(head_out, y.float()) + else: + loss = nn.functional.cross_entropy(head_out, y) self.log("train_loss", loss) return loss def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: - """Simple validation step using cross entropy loss and accuracy.""" + """Validation step computing loss and accuracy.""" x, y = batch head_out, _ = self.model(x) - loss = nn.functional.cross_entropy(head_out, y).mean() - accuracy = (head_out.argmax(1) == y).float().mean() - + if self.model.multilabel: + # Compute multi-label accuracy by checking if all labels are correct. + loss = nn.functional.binary_cross_entropy_with_logits(head_out, y.float()) + preds = (torch.sigmoid(head_out) > 0.5).float() + # Accuracy is defined as the Jaccard score averaged over samples. 
+ accuracy = jaccard_score(y.cpu(), preds.cpu(), average="samples") + else: + loss = nn.functional.cross_entropy(head_out, y) + accuracy = (head_out.argmax(dim=1) == y).float().mean() self.log("val_loss", loss) self.log("val_accuracy", accuracy, prog_bar=True) return loss def configure_optimizers(self) -> OptimizerLRScheduler: - """Simple Adam optimizer.""" + """Configure optimizer and learning rate scheduler.""" optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, @@ -299,5 +360,4 @@ def configure_optimizers(self) -> OptimizerLRScheduler: threshold=0.03, threshold_mode="rel", ) - return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}} diff --git a/tests/conftest.py b/tests/conftest.py index 6220329..e22fd56 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -86,13 +86,24 @@ def mock_inference_pipeline(mock_trained_pipeline: StaticModelForClassification) return mock_trained_pipeline.to_pipeline() -@pytest.fixture(scope="session") -def mock_trained_pipeline() -> StaticModelForClassification: +@pytest.fixture(params=[False, True], ids=["single_label", "multilabel"], scope="session") +def mock_trained_pipeline(request: pytest.FixtureRequest) -> StaticModelForClassification: """Mock staticmodelforclassification.""" tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer torch.random.manual_seed(42) vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12) - s = StaticModelForClassification(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu") - s.fit(["dog", "cat"], ["a", "b"], device="cpu") + model = StaticModelForClassification(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu") + X = ["dog", "cat"] + y: list[str] | list[list[str]] + + if request.param: + # Use multilabel targets. + + y = [["a", "b"], ["a"]] + else: + # Use single-label targets. + X = ["dog", "cat"] + y = ["a", "b"] + model.fit(X, y) - return s + return model diff --git a/tests/test_trainable.py b/tests/test_trainable.py index dc9bb81..c4fabc6 100644 --- a/tests/test_trainable.py +++ b/tests/test_trainable.py @@ -112,7 +112,10 @@ def test_textdataset_init_incorrect() -> None: def test_predict(mock_trained_pipeline: StaticModelForClassification) -> None: """Test the predict function.""" result = mock_trained_pipeline.predict(["dog cat", "dog"]).tolist() - assert result == ["b", "b"] + if mock_trained_pipeline.multilabel: + assert result == [["a", "b"], ["a", "b"]] + else: + assert result == ["b", "b"] def test_predict_proba(mock_trained_pipeline: StaticModelForClassification) -> None: From 0226494b3b473de4d5e7c3ded86f04799b5ce946 Mon Sep 17 00:00:00 2001 From: Pringled Date: Fri, 14 Feb 2025 12:12:28 +0100 Subject: [PATCH 02/29] Added multilabel option to training --- tests/conftest.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e22fd56..fb00cda 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -93,17 +93,16 @@ def mock_trained_pipeline(request: pytest.FixtureRequest) -> StaticModelForClass torch.random.manual_seed(42) vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12) model = StaticModelForClassification(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu") + X = ["dog", "cat"] y: list[str] | list[list[str]] - if request.param: # Use multilabel targets. 
- y = [["a", "b"], ["a"]] else: - # Use single-label targets. - X = ["dog", "cat"] + # Use singlelabel targets. y = ["a", "b"] + model.fit(X, y) return model From a22d61acbc9e5241a5f1b7bd435041aad414e78b Mon Sep 17 00:00:00 2001 From: Pringled Date: Fri, 14 Feb 2025 12:15:45 +0100 Subject: [PATCH 03/29] Added multilabel option to training --- model2vec/train/classifier.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index abded95..17d3865 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -94,8 +94,8 @@ def predict(self, X: list[str], show_progress_bar: bool = False, batch_size: int @torch.no_grad() def _predict_single_batch(self, X: list[str]) -> torch.Tensor: input_ids = self.tokenize(X) - head_out, _ = self.forward(input_ids) - return head_out + vectors, _ = self.forward(input_ids) + return vectors def predict_proba(self, X: list[str], show_progress_bar: bool = False, batch_size: int = 1024) -> np.ndarray: """ @@ -132,8 +132,8 @@ def fit( It supports both single-label and multi-label classification. We use early stopping. After training, the weights of the best model are loaded back into the model. - The function seeds everything with a seed of 42, so the results are reproducible. - It also splits the data into training and validation sets using a random seed. + This function seeds everything with a seed of 42, so the results are reproducible. + It also splits the data into a train and validation set, again with a random seed. :param X: The texts to train on. :param y: The labels to train on. If the first element is a list, multi-label classification is assumed. @@ -150,7 +150,7 @@ def fit( pl.seed_everything(_RANDOM_SEED) logger.info("Re-initializing model.") - # Determine whether we're in multilabel mode based on the first element of y. + # Determine whether the task is multilabel based on the type of y. multilabel = isinstance(y[0], list) self._initialize(y, multilabel=multilabel) @@ -176,7 +176,8 @@ def fit( callback = EarlyStopping(monitor="val_accuracy", mode="max", patience=early_stopping_patience) callbacks.append(callback) - # Check validation frequency. + # If the dataset is small, we check the validation set every epoch. + # If the dataset is large, we check the validation set every 250 batches. if n_train_batches < 250: val_check_interval = None check_val_every_epoch = 1 @@ -215,7 +216,8 @@ def _initialize(self, y: list[str] | list[list[str]], multilabel: bool = False) """ Sets the output dimensionality, the classes, and initializes the head. - For multilabel classification, y is assumed to be a list of lists. + :param y: The labels. + :param multilabel: Whether the task is multilabel. """ self.multilabel = multilabel if multilabel: @@ -233,10 +235,16 @@ def _prepare_dataset( self, X: list[str], y: list[str] | list[list[str]], max_length: int = 512, multilabel: bool = False ) -> TextDataset: """ - Prepare a dataset. + Prepare a dataset. For multilabel classification, each target is converted into a multi-hot vector. - For multilabel classification, each target is converted into a multi-hot vector. + :param X: The texts. + :param y: The labels. + :param max_length: The maximum length of the input. + :param multilabel: Whether the task is multilabel. + :return: A TextDataset. """ + # This is a speed optimization. + # assumes a mean token length of 10, which is really high, so safe. 
truncate_length = max_length * 10 X = [x[:truncate_length] for x in X] tokenized: list[list[int]] = [ @@ -270,7 +278,7 @@ def _train_test_split( For multilabel classification, a random split is performed. """ if not multilabel: - label_counts = Counter(y) # type: ignore + label_counts = Counter(y) if min(label_counts.values()) < 2: logger.info("Some classes have less than 2 samples. Stratification is disabled.") return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True) From 68a4ae4d3ddc2ee5161da78d9ae96cf6c213d19b Mon Sep 17 00:00:00 2001 From: Pringled Date: Fri, 14 Feb 2025 12:16:45 +0100 Subject: [PATCH 04/29] Added multilabel option to training --- model2vec/train/classifier.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index 17d3865..94bf011 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -323,7 +323,7 @@ def __init__(self, model: StaticModelForClassification, learning_rate: float) -> self.learning_rate = learning_rate def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass.""" + """Simple forward pass.""" return self.model(x) def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: @@ -342,10 +342,9 @@ def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: i x, y = batch head_out, _ = self.model(x) if self.model.multilabel: - # Compute multi-label accuracy by checking if all labels are correct. loss = nn.functional.binary_cross_entropy_with_logits(head_out, y.float()) preds = (torch.sigmoid(head_out) > 0.5).float() - # Accuracy is defined as the Jaccard score averaged over samples. + # Multilabel accuracy is defined as the Jaccard score averaged over samples. accuracy = jaccard_score(y.cpu(), preds.cpu(), average="samples") else: loss = nn.functional.cross_entropy(head_out, y) From 614069aa51ad5b16048b56548a9dfcd4ffb1578f Mon Sep 17 00:00:00 2001 From: Pringled Date: Fri, 14 Feb 2025 12:18:15 +0100 Subject: [PATCH 05/29] Added multilabel option to training --- model2vec/train/classifier.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index 94bf011..7682b1d 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -47,7 +47,7 @@ def __init__( @property def classes(self) -> list[str]: - """Return the classes in the correct order.""" + """Return all clasess in the correct order.""" return self.classes_ def construct_head(self) -> nn.Sequential: @@ -142,9 +142,10 @@ def fit( :param min_epochs: The minimum number of epochs to train for. :param max_epochs: The maximum number of epochs to train for. If -1, training continues until early stopping is triggered. - :param early_stopping_patience: The patience for early stopping. If None, early stopping is disabled. + :param early_stopping_patience: The patience for early stopping. + If this is None, early stopping is disabled. :param test_size: The test size for the train-test split. - :param device: The device to train on. If "auto", the device is chosen automatically. + :param device: The device to train on. If this is "auto", the device is chosen automatically. :return: The fitted model. 
""" pl.seed_everything(_RANDOM_SEED) From b50bc4a067fa1e90eddbf876c9f6e4ec49d8fc36 Mon Sep 17 00:00:00 2001 From: Pringled Date: Fri, 14 Feb 2025 12:19:50 +0100 Subject: [PATCH 06/29] Added multilabel option to training --- model2vec/train/classifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index 7682b1d..a9027e7 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -141,7 +141,7 @@ def fit( :param batch_size: The batch size. If None, a good batch size is chosen automatically. :param min_epochs: The minimum number of epochs to train for. :param max_epochs: The maximum number of epochs to train for. - If -1, training continues until early stopping is triggered. + If this is -1, the model trains until early stopping is triggered. :param early_stopping_patience: The patience for early stopping. If this is None, early stopping is disabled. :param test_size: The test size for the train-test split. From 6831bfe3d38d5067342a85768ccd58cdad1a5e0e Mon Sep 17 00:00:00 2001 From: Pringled Date: Fri, 14 Feb 2025 12:31:44 +0100 Subject: [PATCH 07/29] Added threshold to predict --- model2vec/train/classifier.py | 12 ++++++++++-- tests/conftest.py | 3 --- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index a9027e7..fbebd8d 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -69,12 +69,20 @@ def construct_head(self) -> nn.Sequential: return nn.Sequential(*modules) - def predict(self, X: list[str], show_progress_bar: bool = False, batch_size: int = 1024) -> np.ndarray: + def predict( + self, X: list[str], show_progress_bar: bool = False, batch_size: int = 1024, threshold: float = 0.5 + ) -> np.ndarray: """ Predict labels for a set of texts. In single-label mode, each prediction is a single class. In multilabel mode, each prediction is a list of classes. + + :param X: The texts to predict on. + :param show_progress_bar: Whether to show a progress bar. + :param batch_size: The batch size. + :param threshold: The threshold for multilabel classification. + :return: The predictions. """ pred = [] for batch in trange(0, len(X), batch_size, disable=not show_progress_bar): @@ -82,7 +90,7 @@ def predict(self, X: list[str], show_progress_bar: bool = False, batch_size: int if self.multilabel: probs = torch.sigmoid(logits) for sample in probs: - sample_labels = [self.classes[i] for i, p in enumerate(sample) if p > 0.5] + sample_labels = [self.classes[i] for i, p in enumerate(sample) if p > threshold] # Fallback: if no label passes the threshold, choose the highest probability label. 
if not sample_labels: sample_labels = [self.classes[sample.argmax().item()]] diff --git a/tests/conftest.py b/tests/conftest.py index fb00cda..2f25886 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,15 +5,12 @@ import numpy as np import pytest import torch -from sklearn.neural_network import MLPClassifier -from sklearn.pipeline import make_pipeline from tokenizers import Tokenizer from tokenizers.models import WordLevel from tokenizers.pre_tokenizers import Whitespace from transformers import AutoModel, AutoTokenizer from model2vec.inference import StaticModelPipeline -from model2vec.model import StaticModel from model2vec.train import StaticModelForClassification From 7bf46ea838ebc78d39a01c6ec6c344f30a805838 Mon Sep 17 00:00:00 2001 From: Pringled Date: Fri, 14 Feb 2025 13:07:42 +0100 Subject: [PATCH 08/29] Updated docs --- model2vec/train/README.md | 48 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/model2vec/train/README.md b/model2vec/train/README.md index 87365fc..4cfa4fc 100644 --- a/model2vec/train/README.md +++ b/model2vec/train/README.md @@ -2,6 +2,8 @@ Aside from [distillation](../../README.md#distillation), `model2vec` also supports training simple classifiers on top of static models, using [pytorch](https://pytorch.org/), [lightning](https://lightning.ai/) and [scikit-learn](https://scikit-learn.org/stable/index.html). +We support both single and multi-label classification, which work seamlessly based on the labels you provide. + # Installation To train, make sure you install the training extra: @@ -65,6 +67,52 @@ print(f"Took {int((perf_counter() - s) * 1000)} milliseconds for {len(test)} ins # Took 67 milliseconds for 2000 instances on CPU. ``` +## Multi-label classification + +Multi-label classification is supported out of the box. Just pass a list of multi-labels to the `fit` function, and a multi-label classifier will be trained. 
For example, the following code trains a multi-label classifier on the [go_emotions](https://huggingface.co/datasets/google-research-datasets/go_emotions) dataset: + +```python +from datasets import load_dataset +from model2vec.train import StaticModelForClassification + +# Initialize a classifier from a pre-trained model +classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M") + +# Load a multi-label dataset +ds = load_dataset("google-research-datasets/go_emotions") + +# Inspect some of the labels +print(ds["train"]["labels"][40:50]) +# [[0, 15], [15, 18], [16, 27], [27], [7, 13], [10], [20], [27], [27], [27]] + +# Train the classifier on text (X) and labels (y) +classifier.fit(ds["train"]["text"], ds["train"]["labels"]) +``` + +Then, we can evaluate the classifier: + +```python +from sklearn import metrics +from sklearn.preprocessing import MultiLabelBinarizer + +# Make predictions on the test set +predictions = classifier.predict(ds["test"]["text"]) + +# Evaluate the classifier +mlb = MultiLabelBinarizer(classes=classifier.classes) +y_true = mlb.fit_transform(ds["test"]["labels"]) +y_pred = mlb.transform(predictions) + +print(f"Accuracy: {metrics.accuracy_score(y_true, y_pred):.3f}") +print(f"Precision: {metrics.precision_score(y_true, y_pred, average='macro', zero_division=0):.3f}") +print(f"Recall: {metrics.recall_score(y_true, y_pred, average='macro', zero_division=0):.3f}") +print(f"F1: {metrics.f1_score(y_true, y_pred, average='macro', zero_division=0):.3f}") +# Accuracy: 0.488 +# Precision: 0.510 +# Recall: 0.372 +# F1: 0.412 +``` + # Persistence You can turn a classifier into a scikit-learn compatible pipeline, as follows: From d277e79c9074c4229628927171a68844eaebd593 Mon Sep 17 00:00:00 2001 From: Pringled Date: Fri, 14 Feb 2025 13:11:01 +0100 Subject: [PATCH 09/29] Updated docs --- model2vec/train/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/model2vec/train/README.md b/model2vec/train/README.md index 4cfa4fc..974855a 100644 --- a/model2vec/train/README.md +++ b/model2vec/train/README.md @@ -113,6 +113,8 @@ print(f"F1: {metrics.f1_score(y_true, y_pred, average='macro', zero_division=0): # F1: 0.412 ``` +The scores are competetive with the popular [roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) model, while our model is orders of magnitude faster. + # Persistence You can turn a classifier into a scikit-learn compatible pipeline, as follows: From d28b89591401a2a5127ec3b49c27164f65d911de Mon Sep 17 00:00:00 2001 From: Pringled Date: Fri, 14 Feb 2025 13:13:04 +0100 Subject: [PATCH 10/29] Removed fallback logic --- model2vec/train/classifier.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index fbebd8d..58b412b 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -91,9 +91,6 @@ def predict( probs = torch.sigmoid(logits) for sample in probs: sample_labels = [self.classes[i] for i, p in enumerate(sample) if p > threshold] - # Fallback: if no label passes the threshold, choose the highest probability label. 
- if not sample_labels: - sample_labels = [self.classes[sample.argmax().item()]] pred.append(sample_labels) else: pred.extend([self.classes[idx] for idx in logits.argmax(dim=1).tolist()]) From 327ecb10befd31ce8e36fadafd0d98bb43584a3b Mon Sep 17 00:00:00 2001 From: Pringled Date: Fri, 14 Feb 2025 13:15:39 +0100 Subject: [PATCH 11/29] Updated docs --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 944b978..83f806a 100644 --- a/README.md +++ b/README.md @@ -105,7 +105,7 @@ from model2vec.train import StaticModelForClassification # Initialize a classifier from a pre-trained model classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M") -# Load a dataset +# Load a dataset. Note: both single and multi-label classification datasets are supported ds = load_dataset("setfit/subj") # Train the classifier on text (X) and labels (y) From d38679fc6182db271b683f380ed559e1b8a3498d Mon Sep 17 00:00:00 2001 From: Pringled Date: Fri, 14 Feb 2025 13:23:44 +0100 Subject: [PATCH 12/29] Updated docs --- model2vec/train/README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/model2vec/train/README.md b/model2vec/train/README.md index 974855a..730b205 100644 --- a/model2vec/train/README.md +++ b/model2vec/train/README.md @@ -95,8 +95,8 @@ Then, we can evaluate the classifier: from sklearn import metrics from sklearn.preprocessing import MultiLabelBinarizer -# Make predictions on the test set -predictions = classifier.predict(ds["test"]["text"]) +# Make predictions on the test set with a threshold of 0.3 +predictions = classifier.predict(ds["test"]["text"], threshold=0.3) # Evaluate the classifier mlb = MultiLabelBinarizer(classes=classifier.classes) @@ -107,10 +107,10 @@ print(f"Accuracy: {metrics.accuracy_score(y_true, y_pred):.3f}") print(f"Precision: {metrics.precision_score(y_true, y_pred, average='macro', zero_division=0):.3f}") print(f"Recall: {metrics.recall_score(y_true, y_pred, average='macro', zero_division=0):.3f}") print(f"F1: {metrics.f1_score(y_true, y_pred, average='macro', zero_division=0):.3f}") -# Accuracy: 0.488 -# Precision: 0.510 -# Recall: 0.372 -# F1: 0.412 +# Accuracy: 0.410 +# Precision: 0.527 +# Recall: 0.410 +# F1: 0.439 ``` The scores are competetive with the popular [roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) model, while our model is orders of magnitude faster. From 6d80e9066a17fe7ba9378c806a77d1fb09c45b00 Mon Sep 17 00:00:00 2001 From: Pringled Date: Fri, 14 Feb 2025 15:16:21 +0100 Subject: [PATCH 13/29] Resolved feedback --- model2vec/train/README.md | 2 +- model2vec/train/classifier.py | 66 ++++++++++++++++++++--------------- tests/test_trainable.py | 4 +-- 3 files changed, 41 insertions(+), 31 deletions(-) diff --git a/model2vec/train/README.md b/model2vec/train/README.md index 730b205..d6e95be 100644 --- a/model2vec/train/README.md +++ b/model2vec/train/README.md @@ -69,7 +69,7 @@ print(f"Took {int((perf_counter() - s) * 1000)} milliseconds for {len(test)} ins ## Multi-label classification -Multi-label classification is supported out of the box. Just pass a list of multi-labels to the `fit` function, and a multi-label classifier will be trained. For example, the following code trains a multi-label classifier on the [go_emotions](https://huggingface.co/datasets/google-research-datasets/go_emotions) dataset: +Multi-label classification is supported out of the box. 
Just pass a list of lists to the `fit` function (e.g. `[[label1, label2], [label1, label3]]`), and a multi-label classifier will be trained. For example, the following code trains a multi-label classifier on the [go_emotions](https://huggingface.co/datasets/google-research-datasets/go_emotions) dataset: ```python from datasets import load_dataset diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index 58b412b..ba37ab1 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -2,8 +2,9 @@ import logging from collections import Counter +from itertools import chain from tempfile import TemporaryDirectory -from typing import cast +from typing import TypeVar, cast import lightning as pl import numpy as np @@ -24,6 +25,8 @@ logger = logging.getLogger(__name__) _RANDOM_SEED = 42 +LabelType = TypeVar("LabelType", list[str], list[list[str]]) + class StaticModelForClassification(FinetunableStaticModel): def __init__( @@ -121,7 +124,7 @@ def predict_proba(self, X: list[str], show_progress_bar: bool = False, batch_siz def fit( self, X: list[str], - y: list[str] | list[list[str]], + y: LabelType, learning_rate: float = 1e-3, batch_size: int | None = None, min_epochs: int | None = None, @@ -157,22 +160,25 @@ def fit( logger.info("Re-initializing model.") # Determine whether the task is multilabel based on the type of y. - multilabel = isinstance(y[0], list) - self._initialize(y, multilabel=multilabel) + + self._initialize(y) train_texts, validation_texts, train_labels, validation_labels = self._train_test_split( - X, y, test_size=test_size, multilabel=multilabel + X, + y, + test_size=test_size, ) if batch_size is None: + # Set to a multiple of 32 base_number = int(min(max(1, (len(train_texts) / 30) // 32), 16)) batch_size = int(base_number * 32) logger.info("Batch size automatically set to %d.", batch_size) logger.info("Preparing train dataset.") - train_dataset = self._prepare_dataset(train_texts, train_labels, multilabel=multilabel) + train_dataset = self._prepare_dataset(train_texts, train_labels) logger.info("Preparing validation dataset.") - val_dataset = self._prepare_dataset(validation_texts, validation_labels, multilabel=multilabel) + val_dataset = self._prepare_dataset(validation_texts, validation_labels) c = _ClassifierLightningModule(self, learning_rate=learning_rate) @@ -218,18 +224,28 @@ def fit( self.eval() return self - def _initialize(self, y: list[str] | list[list[str]], multilabel: bool = False) -> None: + def _initialize(self, y: LabelType) -> None: """ Sets the output dimensionality, the classes, and initializes the head. :param y: The labels. - :param multilabel: Whether the task is multilabel. + :raises ValueError: If the labels are inconsistent. """ + # Determine multilabel status by checking the type of each element in y. + if any(isinstance(label, (list, tuple)) for label in y): + if not all(isinstance(label, (list, tuple)) for label in y): + raise ValueError("Inconsistent label types in y. 
All labels must be either singular or list/tuple.") + multilabel = True + else: + multilabel = False + self.multilabel = multilabel if multilabel: - classes = sorted({label for sublist in y for label in sublist}) + # Flatten the labels + classes = sorted(set(chain.from_iterable(y))) else: classes = sorted(set(cast(list[str], y))) + self.classes_ = classes self.out_dim = len(self.classes_) # Update output dimension self.head = self.construct_head() @@ -237,16 +253,13 @@ def _initialize(self, y: list[str] | list[list[str]], multilabel: bool = False) self.w = self.construct_weights() self.train() - def _prepare_dataset( - self, X: list[str], y: list[str] | list[list[str]], max_length: int = 512, multilabel: bool = False - ) -> TextDataset: + def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) -> TextDataset: """ Prepare a dataset. For multilabel classification, each target is converted into a multi-hot vector. :param X: The texts. :param y: The labels. :param max_length: The maximum length of the input. - :param multilabel: Whether the task is multilabel. :return: A TextDataset. """ # This is a speed optimization. @@ -256,34 +269,31 @@ def _prepare_dataset( tokenized: list[list[int]] = [ encoding.ids[:max_length] for encoding in self.tokenizer.encode_batch_fast(X, add_special_tokens=False) ] - if multilabel: - num_classes = len(self.classes) - label_list = [] - for sample_labels in y: - multi_hot = torch.zeros(num_classes, dtype=torch.float) - for label in sample_labels: - index = self.classes.index(label) - multi_hot[index] = 1.0 - label_list.append(multi_hot) - labels_tensor = torch.stack(label_list) + if self.multilabel: + # Convert labels to multi-hot vectors + num_classes = len(self.classes_) + labels_tensor = torch.zeros(len(y), num_classes, dtype=torch.float) + mapping = {label: idx for idx, label in enumerate(self.classes_)} + for i, sample_labels in enumerate(y): + indices = [mapping[label] for label in sample_labels] + labels_tensor[i, indices] = 1.0 else: labels_tensor = torch.tensor([self.classes.index(label) for label in cast(list[str], y)], dtype=torch.long) return TextDataset(tokenized, labels_tensor) - @staticmethod def _train_test_split( + self, X: list[str], y: list[str] | list[list[str]], test_size: float, - multilabel: bool = False, - ) -> tuple[list[str], list[str], list[str] | list[list[str]], list[str] | list[list[str]]]: + ) -> tuple[list[str], list[str], LabelType, LabelType]: """ Split the data. For single-label classification, stratification is attempted (if possible). For multilabel classification, a random split is performed. """ - if not multilabel: + if not self.multilabel: label_counts = Counter(y) if min(label_counts.values()) < 2: logger.info("Some classes have less than 2 samples. 
Stratification is disabled.") diff --git a/tests/test_trainable.py b/tests/test_trainable.py index c4fabc6..3e76531 100644 --- a/tests/test_trainable.py +++ b/tests/test_trainable.py @@ -139,9 +139,9 @@ def test_convert_to_pipeline(mock_trained_pipeline: StaticModelForClassification assert np.allclose(p1, p2) -def test_train_test_split() -> None: +def test_train_test_split(mock_trained_pipeline: StaticModelForClassification) -> None: """Test the train test split function.""" - a, b, c, d = StaticModelForClassification._train_test_split(["0", "1", "2", "3"], ["1", "1", "0", "0"], 0.5) + a, b, c, d = mock_trained_pipeline._train_test_split(["0", "1", "2", "3"], ["1", "1", "0", "0"], 0.5) assert len(a) == 2 assert len(b) == 2 assert len(c) == len(a) From ad8ea8dd0f33b315ef9ce3b52deb8ae949bc7927 Mon Sep 17 00:00:00 2001 From: Thomas van Dongen Date: Fri, 14 Feb 2025 15:17:31 +0100 Subject: [PATCH 14/29] Update model2vec/train/README.md Co-authored-by: Stephan Tulkens --- model2vec/train/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model2vec/train/README.md b/model2vec/train/README.md index d6e95be..2c908ff 100644 --- a/model2vec/train/README.md +++ b/model2vec/train/README.md @@ -113,7 +113,7 @@ print(f"F1: {metrics.f1_score(y_true, y_pred, average='macro', zero_division=0): # F1: 0.439 ``` -The scores are competetive with the popular [roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) model, while our model is orders of magnitude faster. +The scores are competitive with the popular [roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) model, while our model is orders of magnitude faster. # Persistence From b3363ff78961b4f960d70a4cb3b8eaf695ad8869 Mon Sep 17 00:00:00 2001 From: Pringled Date: Fri, 14 Feb 2025 15:25:43 +0100 Subject: [PATCH 15/29] Resolved feedback --- model2vec/train/classifier.py | 58 ++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 18 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index ba37ab1..e6e65ab 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -231,20 +231,45 @@ def _initialize(self, y: LabelType) -> None: :param y: The labels. :raises ValueError: If the labels are inconsistent. """ - # Determine multilabel status by checking the type of each element in y. - if any(isinstance(label, (list, tuple)) for label in y): + if not y: + raise ValueError("y must not be empty") + + first_label = y[0] + if isinstance(first_label, str): + # Now we know y should be a list of strings. + if not all(isinstance(label, str) for label in y): + raise ValueError("Inconsistent label types in y. All labels must be strings.") + self.multilabel = False + y_single: list[str] = y # Now mypy knows this is a list of strings. + classes = sorted(set(y_single)) + elif isinstance(first_label, (list, tuple)): + # Now we know y should be a list of lists/tuples. if not all(isinstance(label, (list, tuple)) for label in y): - raise ValueError("Inconsistent label types in y. All labels must be either singular or list/tuple.") - multilabel = True + raise ValueError("Inconsistent label types in y. All labels must be lists or tuples.") + self.multilabel = True + y_multilabel: list[list[str]] = y # mypy now knows this is a list of lists. 
+ classes = sorted(set(chain.from_iterable(y_multilabel))) else: - multilabel = False - - self.multilabel = multilabel - if multilabel: - # Flatten the labels - classes = sorted(set(chain.from_iterable(y))) - else: - classes = sorted(set(cast(list[str], y))) + raise ValueError("Labels must be either strings or lists/tuples of strings.") + + # # Determine multilabel status by checking the type of each element in y. + # if any(isinstance(label, (list, tuple)) for label in y): + # if not all(isinstance(label, (list, tuple)) for label in y): + # raise ValueError("Inconsistent label types in y. All labels must be either singular or list/tuple.") + # self.multilabel = True + # y_multilabel: list[list[str]] = y + # classes = sorted(set(chain.from_iterable(y_multilabel))) + # else: + # self.multilabel = False + # y_single: list[str] = y + # classes = sorted(set(y_single)) + + # self.multilabel = multilabel + # if multilabel: + # # Flatten the labels + # classes = sorted(set(chain.from_iterable(y))) + # else: + # classes = sorted(set(cast(list[str], y))) self.classes_ = classes self.out_dim = len(self.classes_) # Update output dimension @@ -337,6 +362,7 @@ def __init__(self, model: StaticModelForClassification, learning_rate: float) -> super().__init__() self.model = model self.learning_rate = learning_rate + self.loss_function = nn.CrossEntropyLoss() if not model.multilabel else nn.BCEWithLogitsLoss() def forward(self, x: torch.Tensor) -> torch.Tensor: """Simple forward pass.""" @@ -346,10 +372,7 @@ def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int """Training step using cross-entropy loss for single-label and binary cross-entropy for multilabel training.""" x, y = batch head_out, _ = self.model(x) - if self.model.multilabel: - loss = nn.functional.binary_cross_entropy_with_logits(head_out, y.float()) - else: - loss = nn.functional.cross_entropy(head_out, y) + loss = self.loss_function(head_out, y) self.log("train_loss", loss) return loss @@ -357,13 +380,12 @@ def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: i """Validation step computing loss and accuracy.""" x, y = batch head_out, _ = self.model(x) + loss = self.loss_function(head_out, y) if self.model.multilabel: - loss = nn.functional.binary_cross_entropy_with_logits(head_out, y.float()) preds = (torch.sigmoid(head_out) > 0.5).float() # Multilabel accuracy is defined as the Jaccard score averaged over samples. accuracy = jaccard_score(y.cpu(), preds.cpu(), average="samples") else: - loss = nn.functional.cross_entropy(head_out, y) accuracy = (head_out.argmax(dim=1) == y).float().mean() self.log("val_loss", loss) self.log("val_accuracy", accuracy, prog_bar=True) From 15f48738fdb2d0c75bcc0eb87f0fa7b2f80021c4 Mon Sep 17 00:00:00 2001 From: Pringled Date: Fri, 14 Feb 2025 15:26:30 +0100 Subject: [PATCH 16/29] Resolved feedback --- model2vec/train/classifier.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index e6e65ab..5320408 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -231,18 +231,14 @@ def _initialize(self, y: LabelType) -> None: :param y: The labels. :raises ValueError: If the labels are inconsistent. """ - if not y: - raise ValueError("y must not be empty") - - first_label = y[0] - if isinstance(first_label, str): + if isinstance(y[0], str): # Now we know y should be a list of strings. 
if not all(isinstance(label, str) for label in y): raise ValueError("Inconsistent label types in y. All labels must be strings.") self.multilabel = False y_single: list[str] = y # Now mypy knows this is a list of strings. classes = sorted(set(y_single)) - elif isinstance(first_label, (list, tuple)): + elif isinstance(y[0], (list, tuple)): # Now we know y should be a list of lists/tuples. if not all(isinstance(label, (list, tuple)) for label in y): raise ValueError("Inconsistent label types in y. All labels must be lists or tuples.") From 06dc246372e7c07f5c573784579e27343cb54959 Mon Sep 17 00:00:00 2001 From: Pringled Date: Fri, 14 Feb 2025 15:27:49 +0100 Subject: [PATCH 17/29] Resolved feedback --- model2vec/train/classifier.py | 31 ++++++------------------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index 5320408..39238d4 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -232,41 +232,22 @@ def _initialize(self, y: LabelType) -> None: :raises ValueError: If the labels are inconsistent. """ if isinstance(y[0], str): - # Now we know y should be a list of strings. + # Check if all labels are strings. if not all(isinstance(label, str) for label in y): raise ValueError("Inconsistent label types in y. All labels must be strings.") self.multilabel = False - y_single: list[str] = y # Now mypy knows this is a list of strings. - classes = sorted(set(y_single)) + # y_single: list[str] = y + classes = sorted(set(y)) elif isinstance(y[0], (list, tuple)): - # Now we know y should be a list of lists/tuples. + # Check if all labels are lists or tuples. if not all(isinstance(label, (list, tuple)) for label in y): raise ValueError("Inconsistent label types in y. All labels must be lists or tuples.") self.multilabel = True - y_multilabel: list[list[str]] = y # mypy now knows this is a list of lists. - classes = sorted(set(chain.from_iterable(y_multilabel))) + # y_multilabel: list[list[str]] = y # mypy now knows this is a list of lists. + classes = sorted(set(chain.from_iterable(y))) else: raise ValueError("Labels must be either strings or lists/tuples of strings.") - # # Determine multilabel status by checking the type of each element in y. - # if any(isinstance(label, (list, tuple)) for label in y): - # if not all(isinstance(label, (list, tuple)) for label in y): - # raise ValueError("Inconsistent label types in y. 
All labels must be either singular or list/tuple.") - # self.multilabel = True - # y_multilabel: list[list[str]] = y - # classes = sorted(set(chain.from_iterable(y_multilabel))) - # else: - # self.multilabel = False - # y_single: list[str] = y - # classes = sorted(set(y_single)) - - # self.multilabel = multilabel - # if multilabel: - # # Flatten the labels - # classes = sorted(set(chain.from_iterable(y))) - # else: - # classes = sorted(set(cast(list[str], y))) - self.classes_ = classes self.out_dim = len(self.classes_) # Update output dimension self.head = self.construct_head() From 43de6dac8804afad6dcea1bf255cc120638a2732 Mon Sep 17 00:00:00 2001 From: Pringled Date: Fri, 14 Feb 2025 15:28:23 +0100 Subject: [PATCH 18/29] Resolved feedback --- model2vec/train/classifier.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index 39238d4..08f56dc 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -236,17 +236,13 @@ def _initialize(self, y: LabelType) -> None: if not all(isinstance(label, str) for label in y): raise ValueError("Inconsistent label types in y. All labels must be strings.") self.multilabel = False - # y_single: list[str] = y classes = sorted(set(y)) - elif isinstance(y[0], (list, tuple)): + else: # Check if all labels are lists or tuples. if not all(isinstance(label, (list, tuple)) for label in y): raise ValueError("Inconsistent label types in y. All labels must be lists or tuples.") self.multilabel = True - # y_multilabel: list[list[str]] = y # mypy now knows this is a list of lists. classes = sorted(set(chain.from_iterable(y))) - else: - raise ValueError("Labels must be either strings or lists/tuples of strings.") self.classes_ = classes self.out_dim = len(self.classes_) # Update output dimension From 8e944ab98d2f257b9d93e8172a7b0ed0e1c102a4 Mon Sep 17 00:00:00 2001 From: Stephan Tulkens Date: Sat, 15 Feb 2025 13:23:10 +0100 Subject: [PATCH 19/29] add multilabel targets, fix tests (#194) --- model2vec/inference/model.py | 56 +++++++++++++++++++++++++++++++---- model2vec/train/classifier.py | 5 ++-- tests/test_inference.py | 18 ++++++++--- uv.lock | 2 +- 4 files changed, 67 insertions(+), 14 deletions(-) diff --git a/model2vec/inference/model.py b/model2vec/inference/model.py index 5b08dad..6a2fe6a 100644 --- a/model2vec/inference/model.py +++ b/model2vec/inference/model.py @@ -7,6 +7,7 @@ import huggingface_hub import numpy as np import skops.io +from sklearn.neural_network import MLPClassifier from sklearn.pipeline import Pipeline from model2vec.hf_utils import _create_model_card @@ -21,6 +22,20 @@ def __init__(self, model: StaticModel, head: Pipeline) -> None: """Create a pipeline with a StaticModel encoder.""" self.model = model self.head = head + classifier = self.head[-1] + # Check if the classifier is a multilabel classifier. + # NOTE: this doesn't look robust, but it is. + # Different classifiers, such as OVR wrappers, support multilabel output natively, so we + # can just use predict. 
+ self.multilabel = False + if isinstance(classifier, MLPClassifier): + if classifier.out_activation_ == "logistic": + self.multilabel = True + + @property + def classes_(self) -> np.ndarray: + """The classes of the classifier.""" + return self.head.classes_ @classmethod def from_pretrained( @@ -60,7 +75,7 @@ def push_to_hub(self, repo_id: str, token: str | None = None, private: bool = Fa self.model.save_pretrained(temp_dir) push_folder_to_hub(Path(temp_dir), repo_id, private, token) - def _predict_and_coerce_to_2d( + def _encode_and_coerce_to_2d( self, X: list[str] | str, show_progress_bar: bool, @@ -69,7 +84,7 @@ def _predict_and_coerce_to_2d( use_multiprocessing: bool, multiprocessing_threshold: int, ) -> np.ndarray: - """Predict the labels of the input and coerce the output to a matrix.""" + """Encode the instances and coerce the output to a matrix.""" encoded = self.model.encode( X, show_progress_bar=show_progress_bar, @@ -91,9 +106,21 @@ def predict( batch_size: int = 1024, use_multiprocessing: bool = True, multiprocessing_threshold: int = 10_000, + threshold: float = 0.5, ) -> np.ndarray: - """Predict the labels of the input.""" - encoded = self._predict_and_coerce_to_2d( + """ + Predict the labels of the input. + + :param X: The input data to predict. Can be a list of strings or a single string. + :param show_progress_bar: Whether to display a progress bar during prediction. Defaults to False. + :param max_length: The maximum length of the input sequences. Defaults to 512. + :param batch_size: The batch size for prediction. Defaults to 1024. + :param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True. + :param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000. + :param threshold: The threshold for multilabel classification. Defaults to 0.5. Ignored if not multilabel. + :return: The predicted labels or probabilities. + """ + encoded = self._encode_and_coerce_to_2d( X, show_progress_bar=show_progress_bar, max_length=max_length, @@ -102,6 +129,13 @@ def predict( multiprocessing_threshold=multiprocessing_threshold, ) + if self.multilabel: + out_labels = [] + proba = self.head.predict_proba(encoded) + for vector in proba: + out_labels.append(self.classes_[vector > threshold]) + return np.asarray(out_labels) + return self.head.predict(encoded) def predict_proba( @@ -113,8 +147,18 @@ def predict_proba( use_multiprocessing: bool = True, multiprocessing_threshold: int = 10_000, ) -> np.ndarray: - """Predict the probabilities of the labels of the input.""" - encoded = self._predict_and_coerce_to_2d( + """ + Predict the labels of the input. + + :param X: The input data to predict. Can be a list of strings or a single string. + :param show_progress_bar: Whether to display a progress bar during prediction. Defaults to False. + :param max_length: The maximum length of the input sequences. Defaults to 512. + :param batch_size: The batch size for prediction. Defaults to 1024. + :param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True. + :param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000. + :return: The predicted labels or probabilities. 
+ """ + encoded = self._encode_and_coerce_to_2d( X, show_progress_bar=show_progress_bar, max_length=max_length, diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index 08f56dc..16a47eb 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -323,8 +323,8 @@ def to_pipeline(self) -> StaticModelPipeline: # To convert correctly, we need to set the outputs correctly, and fix the activation function. # Make sure n_outputs is set to > 1. mlp_head.n_outputs_ = self.out_dim - # Set to softmax - mlp_head.out_activation_ = "softmax" + # Set to softmax or sigmoid + mlp_head.out_activation_ = "logistic" if self.multilabel else "softmax" return StaticModelPipeline(static_model, converted) @@ -373,7 +373,6 @@ def configure_optimizers(self) -> OptimizerLRScheduler: mode="min", factor=0.5, patience=3, - verbose=True, min_lr=1e-6, threshold=0.03, threshold_mode="rel", diff --git a/tests/test_inference.py b/tests/test_inference.py index 9f4618d..9a12894 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -10,8 +10,13 @@ def test_init_predict(mock_inference_pipeline: StaticModelPipeline) -> None: """Test successful initialization of StaticModelPipeline.""" - assert mock_inference_pipeline.predict("dog").tolist() == ["b"] - assert mock_inference_pipeline.predict(["dog"]).tolist() == ["b"] + target: list[str] | list[list[str]] + if mock_inference_pipeline.multilabel: + target = [["a", "b"]] + else: + target = ["b"] + assert mock_inference_pipeline.predict("dog").tolist() == target + assert mock_inference_pipeline.predict(["dog"]).tolist() == target def test_init_predict_proba(mock_inference_pipeline: StaticModelPipeline) -> None: @@ -25,8 +30,13 @@ def test_roundtrip_save(mock_inference_pipeline: StaticModelPipeline) -> None: with TemporaryDirectory() as temp_dir: mock_inference_pipeline.save_pretrained(temp_dir) loaded = StaticModelPipeline.from_pretrained(temp_dir) - assert loaded.predict("dog") == ["b"] - assert loaded.predict(["dog"]) == ["b"] + target: list[str] | list[list[str]] + if mock_inference_pipeline.multilabel: + target = [["a", "b"]] + else: + target = ["b"] + assert loaded.predict("dog").tolist() == target + assert loaded.predict(["dog"]).tolist() == target assert loaded.predict_proba("dog").argmax() == 1 assert loaded.predict_proba(["dog"]).argmax(1).tolist() == [1] diff --git a/uv.lock b/uv.lock index f7d37b8..a046575 100644 --- a/uv.lock +++ b/uv.lock @@ -791,7 +791,7 @@ wheels = [ [[package]] name = "model2vec" -version = "0.3.8" +version = "0.4.0" source = { editable = "." 
} dependencies = [ { name = "jinja2" }, From 5c9d397f9e02271bc499e491baa490c4ad85780b Mon Sep 17 00:00:00 2001 From: Pringled Date: Sat, 15 Feb 2025 14:32:06 +0100 Subject: [PATCH 20/29] Fixed bug with array conversion --- model2vec/inference/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model2vec/inference/model.py b/model2vec/inference/model.py index 6a2fe6a..6f5a17b 100644 --- a/model2vec/inference/model.py +++ b/model2vec/inference/model.py @@ -134,7 +134,7 @@ def predict( proba = self.head.predict_proba(encoded) for vector in proba: out_labels.append(self.classes_[vector > threshold]) - return np.asarray(out_labels) + return np.asarray(out_labels, dtype=object) return self.head.predict(encoded) From 6a4f89bc52bf1dd666109c00689a8dca63680130 Mon Sep 17 00:00:00 2001 From: Pringled Date: Sat, 15 Feb 2025 15:20:28 +0100 Subject: [PATCH 21/29] Optimized inference performance --- model2vec/train/classifier.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index 16a47eb..14c85e0 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -92,9 +92,8 @@ def predict( logits = self._predict_single_batch(X[batch : batch + batch_size]) if self.multilabel: probs = torch.sigmoid(logits) - for sample in probs: - sample_labels = [self.classes[i] for i, p in enumerate(sample) if p > threshold] - pred.append(sample_labels) + mask = (probs > threshold).cpu().numpy() + pred.extend([np.array(self.classes)[np.flatnonzero(row)] for row in mask]) else: pred.extend([self.classes[idx] for idx in logits.argmax(dim=1).tolist()]) return np.array(pred, dtype=object) From 3609e621ee904f20cdb758a4f9791508dac808db Mon Sep 17 00:00:00 2001 From: Pringled Date: Sat, 15 Feb 2025 15:28:40 +0100 Subject: [PATCH 22/29] Changed classes to np array --- model2vec/train/classifier.py | 8 ++++---- tests/test_trainable.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index 14c85e0..afe1019 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -49,9 +49,9 @@ def __init__( super().__init__(vectors=vectors, out_dim=out_dim, pad_id=pad_id, tokenizer=tokenizer) @property - def classes(self) -> list[str]: + def classes(self) -> np.ndarray: """Return all clasess in the correct order.""" - return self.classes_ + return np.array(self.classes_) def construct_head(self) -> nn.Sequential: """Constructs a simple classifier head.""" @@ -93,7 +93,7 @@ def predict( if self.multilabel: probs = torch.sigmoid(logits) mask = (probs > threshold).cpu().numpy() - pred.extend([np.array(self.classes)[np.flatnonzero(row)] for row in mask]) + pred.extend([self.classes[np.flatnonzero(row)] for row in mask]) else: pred.extend([self.classes[idx] for idx in logits.argmax(dim=1).tolist()]) return np.array(pred, dtype=object) @@ -275,7 +275,7 @@ def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) -> indices = [mapping[label] for label in sample_labels] labels_tensor[i, indices] = 1.0 else: - labels_tensor = torch.tensor([self.classes.index(label) for label in cast(list[str], y)], dtype=torch.long) + labels_tensor = torch.tensor([self.classes_.index(label) for label in cast(list[str], y)], dtype=torch.long) return TextDataset(tokenized, labels_tensor) def _train_test_split( diff --git a/tests/test_trainable.py b/tests/test_trainable.py index 3e76531..2fd11e8 100644 --- 
a/tests/test_trainable.py +++ b/tests/test_trainable.py @@ -17,8 +17,8 @@ def test_init_predict(n_layers: int, mock_vectors: np.ndarray, mock_tokenizer: T s = StaticModelForClassification(vectors=vectors_torched, tokenizer=mock_tokenizer, n_layers=n_layers) assert s.vectors.shape == mock_vectors.shape assert s.w.shape[0] == mock_vectors.shape[0] - assert s.classes == s.classes_ - assert s.classes == ["0", "1"] + assert list(s.classes) == s.classes_ + assert list(s.classes) == ["0", "1"] head = s.construct_head() assert head[0].in_features == mock_vectors.shape[1] From b4df861a73d079e442a481f4f328756ece363c7b Mon Sep 17 00:00:00 2001 From: Pringled Date: Sun, 16 Feb 2025 13:16:34 +0100 Subject: [PATCH 23/29] Added int as possible label type --- model2vec/train/classifier.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index afe1019..792670b 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) _RANDOM_SEED = 42 -LabelType = TypeVar("LabelType", list[str], list[list[str]]) +LabelType = TypeVar("LabelType", list[str | int], list[list[str | int]]) class StaticModelForClassification(FinetunableStaticModel): @@ -230,18 +230,18 @@ def _initialize(self, y: LabelType) -> None: :param y: The labels. :raises ValueError: If the labels are inconsistent. """ - if isinstance(y[0], str): + if isinstance(y[0], (str, int)): # Check if all labels are strings. - if not all(isinstance(label, str) for label in y): + if not all(isinstance(label, (str, int)) for label in y): raise ValueError("Inconsistent label types in y. All labels must be strings.") self.multilabel = False - classes = sorted(set(y)) + classes = sorted({str(label) for label in y}) else: # Check if all labels are lists or tuples. if not all(isinstance(label, (list, tuple)) for label in y): raise ValueError("Inconsistent label types in y. All labels must be lists or tuples.") self.multilabel = True - classes = sorted(set(chain.from_iterable(y))) + classes = sorted({str(label) for label in chain.from_iterable(y)}) self.classes_ = classes self.out_dim = len(self.classes_) # Update output dimension @@ -258,6 +258,7 @@ def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) -> :param y: The labels. :param max_length: The maximum length of the input. :return: A TextDataset. + :raises ValueError: If the labels are inconsistent. """ # This is a speed optimization. # assumes a mean token length of 10, which is really high, so safe. 
@@ -272,7 +273,9 @@ def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) -> labels_tensor = torch.zeros(len(y), num_classes, dtype=torch.float) mapping = {label: idx for idx, label in enumerate(self.classes_)} for i, sample_labels in enumerate(y): - indices = [mapping[label] for label in sample_labels] + if not isinstance(sample_labels, (list, tuple)): + raise ValueError("For multilabel classification, each label should be a list or tuple.") + indices = [mapping[str(label)] for label in sample_labels] labels_tensor[i, indices] = 1.0 else: labels_tensor = torch.tensor([self.classes_.index(label) for label in cast(list[str], y)], dtype=torch.long) @@ -281,7 +284,7 @@ def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) -> def _train_test_split( self, X: list[str], - y: list[str] | list[list[str]], + y: LabelType, test_size: float, ) -> tuple[list[str], list[str], LabelType, LabelType]: """ From ba29febb24564a22b6179d96b292cb530c992028 Mon Sep 17 00:00:00 2001 From: Pringled Date: Sun, 16 Feb 2025 13:20:09 +0100 Subject: [PATCH 24/29] Added int as possible label type --- model2vec/train/classifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index 792670b..1a7a89e 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -278,7 +278,7 @@ def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) -> indices = [mapping[str(label)] for label in sample_labels] labels_tensor[i, indices] = 1.0 else: - labels_tensor = torch.tensor([self.classes_.index(label) for label in cast(list[str], y)], dtype=torch.long) + labels_tensor = torch.tensor([self.classes_.index(str(label)) for label in y], dtype=torch.long) return TextDataset(tokenized, labels_tensor) def _train_test_split( From 3dcddf57ea4724055058f217ea83d75c5b65f8ff Mon Sep 17 00:00:00 2001 From: Pringled Date: Sun, 16 Feb 2025 13:32:43 +0100 Subject: [PATCH 25/29] Use previous logic --- model2vec/train/classifier.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index 1a7a89e..afe1019 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) _RANDOM_SEED = 42 -LabelType = TypeVar("LabelType", list[str | int], list[list[str | int]]) +LabelType = TypeVar("LabelType", list[str], list[list[str]]) class StaticModelForClassification(FinetunableStaticModel): @@ -230,18 +230,18 @@ def _initialize(self, y: LabelType) -> None: :param y: The labels. :raises ValueError: If the labels are inconsistent. """ - if isinstance(y[0], (str, int)): + if isinstance(y[0], str): # Check if all labels are strings. - if not all(isinstance(label, (str, int)) for label in y): + if not all(isinstance(label, str) for label in y): raise ValueError("Inconsistent label types in y. All labels must be strings.") self.multilabel = False - classes = sorted({str(label) for label in y}) + classes = sorted(set(y)) else: # Check if all labels are lists or tuples. if not all(isinstance(label, (list, tuple)) for label in y): raise ValueError("Inconsistent label types in y. 
All labels must be lists or tuples.") self.multilabel = True - classes = sorted({str(label) for label in chain.from_iterable(y)}) + classes = sorted(set(chain.from_iterable(y))) self.classes_ = classes self.out_dim = len(self.classes_) # Update output dimension @@ -258,7 +258,6 @@ def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) -> :param y: The labels. :param max_length: The maximum length of the input. :return: A TextDataset. - :raises ValueError: If the labels are inconsistent. """ # This is a speed optimization. # assumes a mean token length of 10, which is really high, so safe. @@ -273,18 +272,16 @@ def _prepare_dataset(self, X: list[str], y: LabelType, max_length: int = 512) -> labels_tensor = torch.zeros(len(y), num_classes, dtype=torch.float) mapping = {label: idx for idx, label in enumerate(self.classes_)} for i, sample_labels in enumerate(y): - if not isinstance(sample_labels, (list, tuple)): - raise ValueError("For multilabel classification, each label should be a list or tuple.") - indices = [mapping[str(label)] for label in sample_labels] + indices = [mapping[label] for label in sample_labels] labels_tensor[i, indices] = 1.0 else: - labels_tensor = torch.tensor([self.classes_.index(str(label)) for label in y], dtype=torch.long) + labels_tensor = torch.tensor([self.classes_.index(label) for label in cast(list[str], y)], dtype=torch.long) return TextDataset(tokenized, labels_tensor) def _train_test_split( self, X: list[str], - y: LabelType, + y: list[str] | list[list[str]], test_size: float, ) -> tuple[list[str], list[str], LabelType, LabelType]: """ From eccec802c04bb772facb9b43ce242bfa80663c26 Mon Sep 17 00:00:00 2001 From: Pringled Date: Sun, 16 Feb 2025 13:35:33 +0100 Subject: [PATCH 26/29] Updated type check --- model2vec/train/classifier.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index afe1019..f86779f 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -230,9 +230,9 @@ def _initialize(self, y: LabelType) -> None: :param y: The labels. :raises ValueError: If the labels are inconsistent. """ - if isinstance(y[0], str): + if isinstance(y[0], (str, int)): # Check if all labels are strings. - if not all(isinstance(label, str) for label in y): + if not all(isinstance(label, (str | int)) for label in y): raise ValueError("Inconsistent label types in y. All labels must be strings.") self.multilabel = False classes = sorted(set(y)) From f9037d9266a3abe45607260017e12fc0e67e5834 Mon Sep 17 00:00:00 2001 From: Pringled Date: Sun, 16 Feb 2025 13:36:29 +0100 Subject: [PATCH 27/29] Updated type check --- model2vec/train/classifier.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index f86779f..93db91e 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -231,9 +231,9 @@ def _initialize(self, y: LabelType) -> None: :raises ValueError: If the labels are inconsistent. """ if isinstance(y[0], (str, int)): - # Check if all labels are strings. + # Check if all labels are strings or integers. if not all(isinstance(label, (str | int)) for label in y): - raise ValueError("Inconsistent label types in y. All labels must be strings.") + raise ValueError("Inconsistent label types in y. 
All labels must be strings or integers.") self.multilabel = False classes = sorted(set(y)) else: From 2dc5b17b63bc8daac365e9440a237808eb6c0d68 Mon Sep 17 00:00:00 2001 From: Pringled Date: Sun, 16 Feb 2025 13:46:22 +0100 Subject: [PATCH 28/29] Updated type check logic --- model2vec/train/classifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index 93db91e..a796648 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -232,7 +232,7 @@ def _initialize(self, y: LabelType) -> None: """ if isinstance(y[0], (str, int)): # Check if all labels are strings or integers. - if not all(isinstance(label, (str | int)) for label in y): + if not all(isinstance(label, (str, int)) for label in y): raise ValueError("Inconsistent label types in y. All labels must be strings or integers.") self.multilabel = False classes = sorted(set(y)) From 27f1b82ec9a7a68a7a6a4ed07da88be4ab2e8282 Mon Sep 17 00:00:00 2001 From: Pringled Date: Sun, 16 Feb 2025 14:18:27 +0100 Subject: [PATCH 29/29] Only return object type array for multilabel --- model2vec/train/classifier.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index a796648..213b75a 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -96,7 +96,11 @@ def predict( pred.extend([self.classes[np.flatnonzero(row)] for row in mask]) else: pred.extend([self.classes[idx] for idx in logits.argmax(dim=1).tolist()]) - return np.array(pred, dtype=object) + if self.multilabel: + # Return as object array to allow for lists of varying lengths. + return np.array(pred, dtype=object) + else: + return np.array(pred) @torch.no_grad() def _predict_single_batch(self, X: list[str]) -> torch.Tensor:
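
The series above settles on a single multilabel decision rule: sigmoid over the head's logits, a fixed probability threshold, and an object-dtype array so each sample can carry a different number of labels. The following standalone sketch replays that rule outside the library; the class names, logits, and threshold are invented for illustration, and nothing here imports model2vec.

    import numpy as np
    import torch

    classes = np.array(["sports", "politics", "tech"])             # invented label inventory
    logits = torch.tensor([[2.0, -1.0, 0.5], [-3.0, -2.0, -1.0]])  # invented head outputs
    threshold = 0.5

    probs = torch.sigmoid(logits)             # multilabel mode uses sigmoid, not softmax
    mask = (probs > threshold).cpu().numpy()  # boolean matrix of shape (n_samples, n_classes)
    pred = [classes[np.flatnonzero(row)] for row in mask]

    # dtype=object keeps the ragged per-sample label arrays intact;
    # a sample may match zero, one, or several classes.
    pred = np.array(pred, dtype=object)
    print(pred)  # e.g. [array(['sports', 'tech'], ...) array([], ...)]

This also makes patch 29 concrete: only the multilabel branch needs dtype=object, because single-label predictions all have the same shape and can remain a plain string array.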
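
On the training side, patches 23 through 28 iterate on how raw labels are validated, but the multilabel target that _prepare_dataset builds stays the same: a multi-hot float matrix with one row per sample and one column per class (the kind of target a binary cross-entropy style loss consumes, though the loss itself is not part of these patches). A minimal sketch of that construction, with invented labels and variable names, is:

    import torch

    classes_ = ["politics", "sports", "tech"]         # sorted label inventory, as in _initialize
    y = [["sports"], ["politics", "tech"], ["tech"]]  # one label list per sample (invented)

    mapping = {label: idx for idx, label in enumerate(classes_)}
    targets = torch.zeros(len(y), len(classes_), dtype=torch.float)
    for i, sample_labels in enumerate(y):
        # Set 1.0 in every column whose class appears in this sample's label list.
        targets[i, [mapping[label] for label in sample_labels]] = 1.0

    print(targets)
    # tensor([[0., 1., 0.],
    #         [1., 0., 1.],
    #         [0., 0., 1.]])

The single-label path keeps the original behaviour, one integer class index per sample, which is why the two branches in _prepare_dataset build tensors of different dtypes (float multi-hot versus long indices).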