diff --git a/tokenlearn/train.py b/tokenlearn/train.py
index b1416f1..4d70d12 100644
--- a/tokenlearn/train.py
+++ b/tokenlearn/train.py
@@ -127,7 +127,6 @@ def train_supervised(  # noqa: C901
     batch_size: int = 256,
     project_name: str | None = None,
     config: dict | None = None,
-    save_dir: str = "saved_models",
     lr_scheduler_patience: int = 3,
     lr_scheduler_min_delta: float = 0.03,
     cosine_weight: float = 1.0,
@@ -148,7 +147,6 @@ def train_supervised(  # noqa: C901
     :param batch_size: The batch size.
     :param project_name: The name of the project for W&B.
     :param config: The configuration for W&B.
-    :param save_dir: The directory to save the model.
     :param lr_scheduler_patience: The patience for the learning rate scheduler.
     :param lr_scheduler_min_delta: The minimum delta for the learning rate scheduler.
     :param cosine_weight: The weight for the cosine loss.
@@ -304,7 +302,4 @@ def train_supervised(  # noqa: C901
 
     new_model = StaticModel(vectors=vectors, tokenizer=model.tokenizer, config=model.config)
 
-    # Save the best model based on training loss
-    new_model.save_pretrained(f"{save_dir}/best_model_train_loss_{lowest_loss:.4f}")
-
     return new_model, trainable_model
diff --git a/train.py b/train.py
index 8a78673..b56429f 100644
--- a/train.py
+++ b/train.py
@@ -48,7 +48,7 @@ def train_model(
     model_name: str, data_path: str, save_path: str, device: str = "cpu", random_embeddings: bool = False
 ) -> StaticModel:
     """
-    Train a tokenlearnn model.
+    Train a tokenlearn model.
 
     :param model_name: The sentence transformer model name for distillation.
     :param data_path: Path to the directory containing the dataset.
@@ -72,7 +72,7 @@ def train_model(
     train_data = TextDataset(train_txt, torch.from_numpy(train_vec), s.tokenizer)
 
     # Train the model
-    model, _ = train_supervised(train_data, s, device=device)
+    model, _ = train_supervised(train_dataset=train_data, model=s, device=device)
 
     # Save the trained model
     model.save_pretrained(save_path)
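
Usage after this change (a minimal sketch, not part of the patch): train_supervised no longer writes a checkpoint itself, so callers persist the returned model explicitly, as the updated train.py does above. The names train_data and s are taken from the hunks in train.py; the save path is illustrative.

    # Train, then save explicitly; the removed save_dir argument and the
    # automatic f"{save_dir}/best_model_train_loss_..." checkpoint are gone.
    model, trainable_model = train_supervised(
        train_dataset=train_data,  # TextDataset of texts and target vectors
        model=s,                   # the StaticModel being fine-tuned
        device="cpu",
    )
    model.save_pretrained("saved_models/tokenlearn_model")  # explicit save by the caller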