Skip to content

Commit

Permalink
fix: Removed save model call (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
Pringled authored Oct 30, 2024
1 parent ac9d3fe commit 71f91b4
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 7 deletions.
5 changes: 0 additions & 5 deletions tokenlearn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ def train_supervised( # noqa: C901
batch_size: int = 256,
project_name: str | None = None,
config: dict | None = None,
- save_dir: str = "saved_models",
lr_scheduler_patience: int = 3,
lr_scheduler_min_delta: float = 0.03,
cosine_weight: float = 1.0,
Expand All @@ -148,7 +147,6 @@ def train_supervised( # noqa: C901
:param batch_size: The batch size.
:param project_name: The name of the project for W&B.
:param config: The configuration for W&B.
- :param save_dir: The directory to save the model.
:param lr_scheduler_patience: The patience for the learning rate scheduler.
:param lr_scheduler_min_delta: The minimum delta for the learning rate scheduler.
:param cosine_weight: The weight for the cosine loss.
Expand Down Expand Up @@ -304,7 +302,4 @@ def train_supervised( # noqa: C901

new_model = StaticModel(vectors=vectors, tokenizer=model.tokenizer, config=model.config)

- # Save the best model based on training loss
- new_model.save_pretrained(f"{save_dir}/best_model_train_loss_{lowest_loss:.4f}")

return new_model, trainable_model
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def train_model(
model_name: str, data_path: str, save_path: str, device: str = "cpu", random_embeddings: bool = False
) -> StaticModel:
"""
- Train a tokenlearnn model.
+ Train a tokenlearn model.
:param model_name: The sentence transformer model name for distillation.
:param data_path: Path to the directory containing the dataset.
Expand All @@ -72,7 +72,7 @@ def train_model(
train_data = TextDataset(train_txt, torch.from_numpy(train_vec), s.tokenizer)

# Train the model
- model, _ = train_supervised(train_data, s, device=device)
+ model, _ = train_supervised(train_dataset=train_data, model=s, device=device)

# Save the trained model
model.save_pretrained(save_path)
Expand Down

0 comments on commit 71f91b4

Please sign in to comment.