Feat 27/decrease checkpoints (#28)
* ✨ Add custom checkpoint callback

* 🐛 Enable logging and checkpointing (no duplicates?)

* ✨ Change pbt, pb2 into ASHA and add the related logic

* ✏️ Forgot to erase

* ♻️ Change train_func as the static method of RayTuner

* ✨ Add ability to disable ddp

* ✏️ Fix typo

* ♻️ Refactor some nested dict

* ⚡️ Minor parameter adjustment
Haneol-Kijm authored Sep 19, 2024
1 parent 5286f36 commit 452fe78
Showing 6 changed files with 122 additions and 103 deletions.
49 changes: 15 additions & 34 deletions config/config.py
@@ -11,12 +11,9 @@ def __init__(self):
class TrainingConfig:
"""Training-related configuration."""
def __init__(self):
# self.batch_size = tune.choice([32, 64, 128])
# self.lr = tune.loguniform(0.001, 0.1)
# self.weight_decay = tune.loguniform(0.001, 0.1)
self.batch_size = 64
self.lr = (0.001, 0.1)
self.weight_decay = (0.001, 0.1)
self.batch_size = tune.choice([32, 64, 128])
self.lr = tune.loguniform(0.001, 0.1)
self.weight_decay = tune.loguniform(0.001, 0.1)
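
For reference, a minimal sketch of what these Ray Tune search-space primitives do (illustrative, not part of the diff):

    from ray import tune

    search_space = {
        "batch_size": tune.choice([32, 64, 128]),   # pick one of the listed values per trial
        "lr": tune.loguniform(0.001, 0.1),          # sample log-uniformly between the bounds
        "weight_decay": tune.loguniform(0.001, 0.1),
    }
    # Each trial receives one concrete sample, e.g. {"batch_size": 64, "lr": 0.0042, "weight_decay": 0.03}.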



@@ -33,9 +30,10 @@ def __init__(self):
self.save_dir = "/data/ephemeral/home/logs/"
self.num_gpus = 1
self.max_epochs = 100
self.num_workers = 2 # number of cpus workers in dataloader
self.num_samples = 3 # number of workers in population-based training(pbt)
self.checkpoint_interval = 5 # number of intervals to save checkpoint in pbt.
self.num_workers = 1 # number of Ray Train workers per trial (used by ScalingConfig)
self.num_samples = 10 # number of trials sampled by Ray Tune
# self.checkpoint_interval = 5 # number of intervals to save checkpoint in pbt.
self.ddp = False


class Config:
@@ -47,34 +45,17 @@ def __init__(self):
self.experiment = ExperimentConfig()

self.search_space = {
# 'batch_size': self.training.batch_size,
'batch_size': self.training.batch_size,
'lr': self.training.lr,
'weight_decay': self.training.weight_decay,
}

def flatten_to_dict(self):
return {
**vars(self.model),
**vars(self.training),
**self.search_space,
**vars(self.dataset),
**vars(self.experiment)
}

def to_nested_dict(self):
return {
'model': vars(self.model),
'training': vars(self.training),
'dataset': vars(self.dataset),
'experiment': vars(self.experiment),
**self.search_space
}

def to_nested_dict2(self):
return {
'model': vars(self.model),
'training': vars(self.training),
'dataset': vars(self.dataset),
'experiment': vars(self.experiment),
}
def flatten_to_dict(self):
flattened_dict = {}
for key, value in vars(self).items():
if key != 'search_space' and key != 'training' and hasattr(value, '__dict__'):
for subkey, subvalue in vars(value).items():
flattened_dict[f"{key}_{subkey}"] = subvalue
return flattened_dict
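
A rough illustration of what the rewritten flatten_to_dict returns (attribute values hypothetical); note that 'training' and 'search_space' are skipped on purpose, since those values are supplied per trial by the tuner:

    config = Config()
    config.flatten_to_dict()
    # -> {
    #      "model_model_name": "<model name>",        # from the model sub-config
    #      "dataset_data_path": "<path to data>",     # from the dataset sub-config
    #      "experiment_max_epochs": 100,              # from the experiment sub-config
    #      "experiment_num_samples": 10,
    #      ...
    #    }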

3 changes: 3 additions & 0 deletions engine/callbacks.py
@@ -1,9 +1,12 @@
import os
from datetime import datetime
from tempfile import TemporaryDirectory

import numpy as np
import pandas as pd
from lightning.pytorch.callbacks import Callback
from ray import train
from ray.train import Checkpoint

class PredictionCallback(Callback):
def __init__(self, data_path, ckpt_dir, model_name):
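
Given the new imports (TemporaryDirectory, ray.train, Checkpoint), the callback presumably hands checkpoints to Ray along these lines; a hedged sketch, not the exact callback body:

    from tempfile import TemporaryDirectory
    from ray import train
    from ray.train import Checkpoint

    with TemporaryDirectory() as tmpdir:
        # trainer.save_checkpoint(f"{tmpdir}/checkpoint.ckpt")   # persist the model state first
        train.report(
            {"val_loss": 0.42},                                  # illustrative metric value
            checkpoint=Checkpoint.from_directory(tmpdir),        # let Ray take ownership of the files
        )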
4 changes: 2 additions & 2 deletions engine/test_runner.py
@@ -10,7 +10,7 @@ def run_test(config, ckpt_dir):

# Define the trainer for testing
pred_callback = PredictionCallback(f"{config.dataset.data_path}/test.csv", ckpt_dir, config.model.model_name)
trainer_test = Trainer(callbacks=[pred_callback], logger=False, enable_progress_bar=False,)
best_model = LightningModule.load_from_checkpoint(f"{ckpt_dir}/pltrainer.ckpt")
trainer_test = Trainer(callbacks=[pred_callback], logger=False, enable_progress_bar=True,)
best_model = LightningModule.load_from_checkpoint(f"{ckpt_dir}/checkpoint.ckpt")
# Conduct testing with the loaded model
trainer_test.test(best_model, dataloaders=test_loader)
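
The new filename lines up with the checkpoint written by RayTrainReportCallback in tuner.py; a sketch of how ckpt_dir might be materialized from a Ray checkpoint before calling run_test (path hypothetical):

    from ray.train import Checkpoint

    best_checkpoint = Checkpoint.from_directory("/path/to/best_trial_checkpoint")  # hypothetical location
    with best_checkpoint.as_directory() as ckpt_dir:
        run_test(config, ckpt_dir)   # loads f"{ckpt_dir}/checkpoint.ckpt" inside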
162 changes: 98 additions & 64 deletions engine/tuner.py
@@ -2,46 +2,30 @@
from datetime import datetime
import ray
from ray import train, tune
# from ray.tune.schedulers import PopulationBasedTraining
from ray.tune.schedulers.pb2 import PB2
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback
from ray.tune.schedulers import ASHAScheduler
from ray.air.integrations.wandb import WandbLoggerCallback
from lightning import Trainer

from ray.train import RunConfig, ScalingConfig, CheckpointConfig
from ray.train.lightning import (
RayDDPStrategy,
RayLightningEnvironment,
RayTrainReportCallback,
prepare_trainer,
)
from dataset import get_dataloaders
from model import LightningModule


def train_func(config_dict): # Note that config_dict is a dict here, passed by the pbt scheduler
# Create the dataloaders
train_loader, val_loader = get_dataloaders(
data_path=config_dict['dataset']['data_path'],
batch_size=config_dict['training']['batch_size'],
num_workers=config_dict['experiment']['num_workers']
)
model = LightningModule(config_dict)
# model = LightningModule(config_dict)

trainer = Trainer(
max_epochs=config_dict['experiment']['max_epochs'],
accelerator='gpu',
devices=config_dict['experiment']['num_gpus'],
strategy='ddp',
logger=False,
callbacks=[TuneReportCheckpointCallback(
metrics={"val_loss": "val_loss", "val_acc": "val_acc"},
filename="pltrainer.ckpt", on="validation_end",
)],
enable_progress_bar=False,
)

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

from ray.train.torch import TorchTrainer

class RayTuner:
def __init__(self, config):
self.config = config # config is the Config object, composed of the four sub-config classes

# Define a TorchTrainer without hyper-parameters for Tuner
self.ray_trainer = TorchTrainer(
self._train_func,
train_loop_config=self.config.flatten_to_dict(),
scaling_config=self._define_scaling_config(),
run_config=self._define_run_config(),
)
def __enter__(self):
if ray.is_initialized():
ray.shutdown()
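
As wired here, train_loop_config starts from flatten_to_dict() and the Tuner below injects the sampled search-space values, so _train_func receives a flat dict roughly like the following (values illustrative, keys per flatten_to_dict):

    config_dict = {
        "dataset_data_path": "/path/to/data",   # hypothetical
        "model_model_name": "some_model",       # hypothetical
        "experiment_max_epochs": 100,
        "experiment_num_gpus": 1,
        # sampled per trial by Ray Tune:
        "batch_size": 64,
        "lr": 0.0042,
        "weight_decay": 0.03,
    }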
@@ -52,36 +36,39 @@ def __exit__(self, type, value, trace_back):
ray.shutdown()

def _define_scheduler(self):
# Define the population-based training scheduler
# pbt_scheduler = PopulationBasedTraining(
# time_attr="training_iteration",
# perturbation_interval=self.config.experiment.checkpoint_interval,
# metric="val_loss",
# mode="min",
# hyperparam_mutations=self.config.search_space,
# )
pbt_scheduler = PB2(
time_attr="training_iteration",
perturbation_interval=self.config.experiment.checkpoint_interval,
metric="val_loss",
mode="min",
hyperparam_bounds=self.config.search_space,
)
return pbt_scheduler
scheduler = ASHAScheduler(
max_t=self.config.experiment.max_epochs,
grace_period=10,
reduction_factor=2,
brackets=3,
)
return scheduler

def _define_tune_config(self):
tune_config = tune.TuneConfig(
scheduler=self._define_scheduler(),
metric="val_loss",
mode="min",
num_samples=self.config.experiment.num_samples,
scheduler=self._define_scheduler(),
)
return tune_config


def _define_scaling_config(self):
scaling_config = ScalingConfig(
num_workers=self.config.experiment.num_workers,
use_gpu=True,
resources_per_worker={
"CPU": 6/self.config.experiment.num_workers,
"GPU": 1/self.config.experiment.num_workers
},
)
return scaling_config
def _define_run_config(self):
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M")
run_config = train.RunConfig(
run_config = RunConfig(
name=f"{self.config.model.model_name}_tune_runs_{current_time}",
checkpoint_config=train.CheckpointConfig(
num_to_keep=10,
checkpoint_config=CheckpointConfig(
num_to_keep=2,
checkpoint_score_attribute="val_loss",
checkpoint_score_order="min",
),
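
For orientation, the resource arithmetic implied by the defaults above, assuming a single 6-CPU, 1-GPU machine (a reading of the code, not a documented guarantee):

    # num_workers = 1 -> resources_per_worker = {"CPU": 6/1, "GPU": 1/1} = {"CPU": 6.0, "GPU": 1.0}
    # num_workers = 2 -> each Ray Train worker requests {"CPU": 3.0, "GPU": 0.5}
    # Either way one trial reserves 6 CPUs and 1 GPU in total, so the 10 samples run
    # sequentially on a single-GPU node while ASHA stops the weaker trials early
    # (grace_period=10 epochs, reduction_factor=2, up to max_t=100 epochs).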
@@ -90,20 +77,67 @@ def _define_run_config(self):
verbose=1,
)
return run_config
def _define_pltrainer(self):
if self.config.experiment.ddp:
trainer = Trainer(
max_epochs=self.config.experiment.max_epochs,
devices='auto',
accelerator='auto',
strategy=RayDDPStrategy(),
callbacks=[RayTrainReportCallback()],
plugins=[RayLightningEnvironment()],
enable_progress_bar=False,
)

trainer = prepare_trainer(trainer)
else:
trainer = Trainer(
max_epochs=self.config.experiment.max_epochs,
devices=self.config.experiment.num_gpus,
accelerator='auto',
strategy='auto',
callbacks=[RayTrainReportCallback()],
enable_checkpointing=False,
enable_progress_bar=False,
)

return trainer

def _train_func(self, config_dict): # TODO: clean up the nested-dict handling now that this is a method
def flatten_to_nested(flattened_dict):
# Transform a flat dict with keys of the form f"{key}_{subkey}" back into a nested dict.
nested_dict = {'dataset': {}, 'model': {}, 'experiment': {}}
expected_keys = ['dataset', 'model', 'experiment']
for key, value in flattened_dict.items():
if "_" in key:
parts = key.split("_")
subkey = '_'.join(parts[1:])
if parts[0] in expected_keys:
nested_dict[parts[0]][subkey] = value
else:
nested_dict[key] = value
else:
nested_dict[key] = value
return nested_dict

config_dict = flatten_to_nested(config_dict)
# Create the dataloaders
train_loader, val_loader = get_dataloaders(
data_path=self.config.dataset.data_path,
batch_size=config_dict['batch_size'],
num_workers=2
)
model = LightningModule(config_dict)

trainer = self._define_pltrainer()

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

def tune_and_train(self):
param_space = self.config.to_nested_dict2()
tuner = tune.Tuner(
tune.with_resources(
train_func,
resources={
"cpu": 6/self.config.experiment.num_samples,
"gpu": 1/self.config.experiment.num_samples
}
),
param_space=param_space, # Hyperparameter search space
self.ray_trainer,
param_space={"train_loop_config": self.config.search_space}, # Hyperparameter search space
tune_config=self._define_tune_config(), # Tuner configuration
run_config=self._define_run_config(), # Run environment configuration
)
)
result_grid = tuner.fit() ## Actual training happens here
return result_grid
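
A possible way to consume the returned ResultGrid (a sketch; everything outside this diff is an assumption, including __enter__ returning the instance):

    with RayTuner(config) as tuner:      # assumes __enter__ returns self
        result_grid = tuner.tune_and_train()

    best = result_grid.get_best_result(metric="val_loss", mode="min")
    print(best.config)                   # winning hyperparameters
    print(best.checkpoint)               # Ray Checkpoint holding checkpoint.ckpt for run_test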
6 changes: 3 additions & 3 deletions model/lightning_module.py
@@ -45,7 +45,7 @@ def training_step(self, train_batch, batch_idx):
x, y = train_batch
output = self.forward(x)
loss = torch.nn.CrossEntropyLoss()(output, y)
self.log('train_loss', loss)
self.log('train_loss', loss, sync_dist=True)
return {'loss': loss}

def validation_step(self, val_batch, batch_idx):
@@ -64,8 +64,8 @@ def validation_step(self, val_batch, batch_idx):
loss = torch.nn.CrossEntropyLoss()(output, y)
_, predicted = torch.max(output, 1)
accuracy = (predicted == y).sum().item() / len(x)
self.log('val_loss', loss)
self.log('val_acc', accuracy)
self.log('val_loss', loss, sync_dist=True)
self.log('val_acc', accuracy, sync_dist=True)

def test_step(self, test_batch, batch_idx):
"""
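
For context, sync_dist=True makes Lightning all-reduce the logged scalar across processes, which matters once the DDP path in tuner.py is enabled; a minimal illustration:

    # Without sync_dist each DDP rank logs its own local loss; with it the value is
    # averaged across ranks before reaching the logger / Ray callbacks:
    self.log('val_loss', loss, sync_dist=True)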
1 change: 1 addition & 0 deletions train.py
@@ -31,6 +31,7 @@ def main(config):
parser.add_argument('--model-name', type=str, help='Name of the model to use.')
parser.add_argument('--num-gpus', type=int, help='Number of GPUs to use.')
parser.add_argument('--smoke-test', action='store_true', help='Perform a small trial to test the setup.')
parser.add_argument('--ddp', action='store_true', help='Use distributed data parallel training. Only use with multiple GPUs.')
args = parser.parse_args()

# Initialize and configure the model configuration object
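
A possible invocation exercising the new flag (model name and GPU count illustrative):

    # python train.py --model-name resnet18 --num-gpus 2 --ddp
    # args.ddp is then True, which presumably sets config.experiment.ddp before RayTuner runs.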
