extend transfer-learning to include mobility #303

Merged · 1 commit · Aug 15, 2024

alphadia/transferlearning/train.py (164 additions, 0 deletions)
@@ -4,6 +4,7 @@
import pandas as pd
import torch
from alphabase.peptide.fragment import remove_unused_fragments
from alphabase.peptide.mobility import ccs_to_mobility_for_df, mobility_to_ccs_for_df
from peptdeep.model.charge import ChargeModelForModAASeq
from peptdeep.model.model_interface import CallbackHandler, LR_SchedulerInterface
from peptdeep.pretrained_models import ModelManager
@@ -896,3 +897,166 @@ def finetune_charge(self, psm_df: pd.DataFrame) -> pd.DataFrame:
metrics = test_metric_manager.get_stats()

return metrics

def _test_ccs(
self,
epoch: int,
epoch_loss: float,
test_df: pd.DataFrame,
metric_accumulator: MetricManager,
data_split: str,
) -> bool:
"""
Test the CCS model using the PSM dataframe and accumulate both the training loss and test metrics.

Parameters
----------
epoch : int
The current epoch number.
epoch_loss : float
The train loss value of the current epoch.
test_df : pd.DataFrame
The PSM dataframe.
metric_accumulator : MetricManager
The metric manager object.
        data_split : str
            The dataset split to test on, e.g. "validation" or "train".

        Returns
-------
bool
            Whether to continue training, based on the early stopping criterion.
"""
continue_training = True
if epoch % self.settings["test_interval"] == 0:
self.ccs_model.model.eval()

pred = self.ccs_model.predict(test_df)

test_input = {
"predicted": pred["ccs_pred"].values,
"target": test_df["ccs"].values,
}
val_metrics = metric_accumulator.calculate_test_metric(
test_input, epoch, data_split=data_split, property_name="ccs"
)
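            # epoch == -1 marks the pre-training evaluation triggered from
            # finetune_ccs; such runs (and non-validation splits) only log the
            # metrics below, without loss accumulation or early stopping.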
if epoch != -1 and data_split == "validation":
metric_accumulator.accumulate_metrics(
epoch,
metric=epoch_loss,
metric_name="l1_loss",
data_split="train",
property_name="ccs",
)
current_lr = self.ccs_model.optimizer.param_groups[0]["lr"]
metric_accumulator.accumulate_metrics(
epoch,
metric=current_lr,
metric_name="lr",
data_split="train",
property_name="ccs",
)
val_loss = val_metrics[val_metrics["metric_name"] == "l1_loss"][
"value"
].values[0]
continue_training = self.early_stopping.step(val_loss)
                logger.progress(
                    f" Epoch {epoch:<3} Lr: {current_lr:.5f} Training loss: {epoch_loss:.4f} Validation loss: {val_loss:.4f}"
                )
else:
logger.progress(
f" CCS model tested on {data_split} dataset with the following metrics:"
)
for i in range(len(val_metrics)):
logger.progress(
f" {val_metrics['metric_name'].values[i]:<30}: {val_metrics['value'].values[i]:.4f}"
)

self.ccs_model.model.train()

return continue_training

def finetune_ccs(self, psm_df: pd.DataFrame) -> pd.DataFrame:
"""
        Fine-tune the CCS model using the PSM dataframe.

Parameters
----------
psm_df : pd.DataFrame
The PSM dataframe.

Returns
-------
pd.DataFrame
            Accumulated metrics during the fine-tuning process.
"""
if "mobility" not in psm_df.columns and "ccs" not in psm_df.columns:
logger.error(
"Failed to finetune CCS model. PSM dataframe does not contain mobility or ccs columns."
)
return
if "ccs" not in psm_df.columns:
psm_df["ccs"] = mobility_to_ccs_for_df(psm_df, "mobility")
elif "mobility" not in psm_df.columns:
psm_df["mobility"] = ccs_to_mobility_for_df(psm_df, "ccs")

# Shuffle the psm_df and split it into train and test
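        # The validation fraction is renormalized by (1 - train_fraction) since
        # val_df is sampled from the rows left over after the train split, e.g.
        # with train=0.7 and validation=0.2, 0.2/0.3 of the remainder is drawn.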
train_df = psm_df.sample(frac=self.settings["train_fraction"])
val_df = psm_df.drop(train_df.index).sample(
frac=self.settings["validation_fraction"]
/ (1 - self.settings["train_fraction"])
)
test_df = psm_df.drop(train_df.index).drop(val_df.index)

# Create a test metric manager
test_metric_manager = MetricManager(
test_metrics=[
L1LossTestMetric(),
LinearRegressionTestMetric(),
AbsErrorPercentileTestMetric(95),
],
)
# Create a callback handler
callback_handler = CustomCallbackHandler(
self._test_ccs,
test_df=val_df,
metric_accumulator=test_metric_manager,
data_split="validation",
)
# Set the callback handler
self.ccs_model.set_callback_handler(callback_handler)

# Change the learning rate scheduler
self.ccs_model.set_lr_scheduler_class(CustomScheduler)

# Reset the early stopping
self.early_stopping.reset()

# Test the model before training
self._test_ccs(-1, 0, psm_df, test_metric_manager, data_split="all")
# Train the model
logger.progress(" Fine-tuning CCS model with the following settings:")
logger.info(
f" Train fraction: {self.settings['train_fraction']:3.2f} Train size: {len(train_df):<10}"
)
logger.info(
f" Validation fraction: {self.settings['validation_fraction']:3.2f} Validation size: {len(val_df):<10}"
)
logger.info(
f" Test fraction: {self.settings['test_fraction']:3.2f} Test size: {len(test_df):<10}"
)
self.ccs_model.model.train()
self.ccs_model.train(
train_df,
batch_size=self.settings["batch_size"],
epoch=self.settings["epochs"],
warmup_epoch=self.settings["warmup_epochs"],
            lr=self.settings["max_lr"],
)

self._test_ccs(
self.settings["epochs"], 0, test_df, test_metric_manager, data_split="test"
)

metrics = test_metric_manager.get_stats()

return metrics
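
Note: the EarlyStopping helper relied on above appears in this diff only through its reset() and step() calls. As a rough sketch (not the actual alphadia implementation, which may track extra state such as a minimum improvement margin), a patience-based version consistent with that interface could look like:

class EarlyStopping:
    """Illustrative patience-based early stopping on a validation loss."""

    def __init__(self, patience: int = 5):
        self.patience = patience
        self.reset()

    def reset(self) -> None:
        # Clear state before each new fine-tuning run.
        self.best_loss = float("inf")
        self.num_bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        # True -> keep training; False -> stop after `patience` epochs
        # without improvement.
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs < self.patience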
nbs/tutorial_nbs/finetuning.ipynb (60 additions, 12 deletions)
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -13,6 +13,7 @@
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from alphabase.spectral_library.base import SpecLibBase\n",
"from alphadia.workflow.reporting import *\n",
"from alphadia.transferlearning.train import *\n",
"\n",
"import seaborn as sns\n",
@@ -21,7 +22,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -31,7 +32,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -41,18 +42,70 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"transfer_lib.precursor_df = transfer_lib.precursor_df[~transfer_lib.precursor_df['mods'].str.contains('Dimethyl@C')]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"tune_mgr = FinetuneManager(\n",
" device=\"gpu\",\n",
" settings=settings)\n",
"tune_mgr.nce = 25\n",
"tune_mgr.instrument = 'Lumos'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Util function to plot the metrics"
"## CCS Fine-tuning\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"transfer_lib.precursor_df = transfer_lib.precursor_df.dropna(subset=['mobility']) # drop rows with na in the mobility column\n",
"transfer_lib.precursor_df = tune_mgr.predict_mobility(transfer_lib.precursor_df)\n",
"plt.scatter(transfer_lib.precursor_df['mobility'], transfer_lib.precursor_df['mobility_pred'], s=1, alpha=0.1)\n",
"plt.xlabel('mobility observed')\n",
"plt.ylabel('mobility predicted')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ccs_stats = tune_mgr.finetune_ccs(transfer_lib.precursor_df)\n",
"\n",
"transfer_lib.precursor_df = tune_mgr.ccs_model.predict(transfer_lib.precursor_df)\n",
"plt.scatter(transfer_lib.precursor_df['ccs'], transfer_lib.precursor_df['ccs_pred'], s=1, alpha=0.1)\n",
"plt.xlabel('ccs observed')\n",
"plt.ylabel('ccs predicted')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"g = sns.relplot(data=ccs_stats, x='epoch', y='value', hue='data_split', marker= 'o',dashes=False, col='metric_name', kind='line', col_wrap=2, facet_kws={'sharex': False, 'sharey': False, 'legend_out': False})\n",
"g.set_titles(\"{col_name}\")\n",
"g.legend.set_title('Data split')"
]
},
{
Expand All @@ -69,11 +122,6 @@
"outputs": [],
"source": [
"\n",
"tune_mgr = FinetuneManager(\n",
" device=\"gpu\",\n",
" settings=settings)\n",
"tune_mgr.nce = 25\n",
"tune_mgr.instrument = 'Lumos'\n",
"transfer_lib.precursor_df = tune_mgr.predict_rt(transfer_lib.precursor_df)\n",
"plt.scatter(transfer_lib.precursor_df['rt_norm'], transfer_lib.precursor_df['rt_norm_pred'], s=1, alpha=0.1)\n",
"plt.xlabel('RT observed')\n",
@@ -144,7 +192,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -154,7 +202,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [