aligned with partial original implementation
TeodorPoncu committed Sep 14, 2022
1 parent aa95139 commit 1fddecc
Showing 5 changed files with 280 additions and 59 deletions.
6 changes: 4 additions & 2 deletions references/classification/presets.py
@@ -13,14 +13,16 @@ def __init__(
interpolation=InterpolationMode.BILINEAR,
hflip_prob=0.5,
auto_augment_policy=None,
policy_magnitude=9,
random_erase_prob=0.0,
center_crop=False,
):
trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] if center_crop else [transforms.CenterCrop(crop_size)]
if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
if auto_augment_policy == "ra":
trans.append(autoaugment.RandAugment(interpolation=interpolation))
trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=policy_magnitude))
elif auto_augment_policy == "ta_wide":
trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
elif auto_augment_policy == "augmix":
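For context, a minimal usage sketch of the updated ClassificationPresetTrain signature (not part of this commit; the argument values are placeholders, and the real call site is load_data in train.py):

```python
# Hypothetical construction of the training preset with the two new arguments
# from this commit (policy_magnitude, center_crop); values are illustrative only.
from torchvision.transforms.functional import InterpolationMode

import presets  # references/classification/presets.py

train_transform = presets.ClassificationPresetTrain(
    crop_size=224,
    interpolation=InterpolationMode.BILINEAR,
    auto_augment_policy="ra",   # routes to autoaugment.RandAugment in __init__
    policy_magnitude=9,         # new: forwarded as RandAugment's magnitude
    random_erase_prob=0.1,
    center_crop=False,          # new flag; see the conditional at the top of __init__
)
```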
122 changes: 122 additions & 0 deletions references/classification/run_with_submitit.py
@@ -0,0 +1,122 @@
import argparse
import os
import uuid
from pathlib import Path

import train
import submitit


def parse_args():
train_parser = train.get_args_parser(add_help=False)
parser = argparse.ArgumentParser("Submitit for train", parents=[train_parser], add_help=True)
parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
parser.add_argument("--nodes", default=1, type=int, help="Number of nodes to request")
parser.add_argument("--timeout", default=60*24*30, type=int, help="Duration of the job")
parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")
parser.add_argument("--partition", default="train", type=str, help="the partition (default train).")
return parser.parse_args()


def get_shared_folder() -> Path:
user = os.getenv("USER")
path = "/data/checkpoints"
if Path(path).is_dir():
p = Path(f"{path}/{user}/experiments")
p.mkdir(exist_ok=True)
return p
raise RuntimeError("No shared folder available")


def get_init_file_folder() -> Path:
user = os.getenv("USER")
path = "/shared"
if Path(path).is_dir():
p = Path(f"{path}/{user}")
p.mkdir(exist_ok=True)
return p
raise RuntimeError("No shared folder available")


def get_init_file():
# Init file must not exist, but its parent dir must exist.
os.makedirs(str(get_init_file_folder()), exist_ok=True)
init_file = get_init_file_folder() / f"{uuid.uuid4().hex}_init"
if init_file.exists():
os.remove(str(init_file))
return init_file


class Trainer(object):
def __init__(self, args):
self.args = args

def __call__(self):
import train

self._setup_gpu_args()
train.main(self.args)

def checkpoint(self):
import os
import submitit
from pathlib import Path

self.args.dist_url = get_init_file().as_uri()
checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth")
if os.path.exists(checkpoint_file):
self.args.resume = checkpoint_file
print("Requeuing ", self.args)
empty_trainer = type(self)(self.args)
return submitit.helpers.DelayedSubmission(empty_trainer)

def _setup_gpu_args(self):
import submitit
from pathlib import Path

job_env = submitit.JobEnvironment()
self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
self.args.gpu = job_env.local_rank
self.args.rank = job_env.global_rank
self.args.world_size = job_env.num_tasks
print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")


def main():
args = parse_args()
if args.job_dir == "":
args.job_dir = get_shared_folder() / "%j"

# Note that the folder will depend on the job_id, to easily track experiments
executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=300)

# cluster setup is defined by environment variables
num_gpus_per_node = args.ngpus
nodes = args.nodes
timeout_min = args.timeout

executor.update_parameters(
#mem_gb=96 * num_gpus_per_node, # 768GB per machine
gpus_per_node=num_gpus_per_node,
tasks_per_node=num_gpus_per_node, # one task per GPU
cpus_per_task=12, # 96 cpus per machine
nodes=nodes,
timeout_min=timeout_min, # max is 60 * 72
slurm_partition=args.partition,
slurm_signal_delay_s=120,
)


executor.update_parameters(name="torchvision")

args.dist_url = get_init_file().as_uri()
args.output_dir = args.job_dir

trainer = Trainer(args)
job = executor.submit(trainer)

print("Submitted job_id:", job.job_id)


if __name__ == "__main__":
main()
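The requeue logic in Trainer.checkpoint relies on submitit's preemption contract. A stand-alone sketch of that contract (a toy task stands in for Trainer; only the submitit calls themselves appear in the file above):

```python
# Minimal sketch of the submitit requeue contract used by Trainer above.
# ToyTask and the parameter values are illustrative placeholders.
import submitit


class ToyTask:
    def __call__(self):
        ...  # long-running work; SLURM may preempt or time the job out

    def checkpoint(self):
        # Invoked by submitit on preemption/timeout: returning a DelayedSubmission
        # wrapping a fresh callable requeues the job with the updated arguments.
        return submitit.helpers.DelayedSubmission(ToyTask())


executor = submitit.AutoExecutor(folder="submitit_logs/%j")
executor.update_parameters(timeout_min=60, slurm_partition="train", gpus_per_node=1, tasks_per_node=1)
job = executor.submit(ToyTask())
print("Submitted job_id:", job.job_id)
```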
47 changes: 36 additions & 11 deletions references/classification/train.py
@@ -1,5 +1,6 @@
import datetime
import os
import random
import time
import warnings

@@ -15,7 +16,7 @@
from torchvision.transforms.functional import InterpolationMode


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None, scheduler=None):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
@@ -43,6 +44,9 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args
if args.clip_grad_norm is not None:
nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
optimizer.step()

if scheduler is not None and args.lr_step_every_batch:
scheduler.step()

if model_ema and i % args.model_ema_steps == 0:
model_ema.update_parameters(model)
@@ -113,7 +117,7 @@ def _get_cache_path(filepath):
def load_data(traindir, valdir, args):
# Data loading code
print("Loading data")
val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
val_resize_size, val_crop_size, train_crop_size, center_crop, policy_magnitude = args.val_resize_size, args.val_crop_size, args.train_crop_size, args.train_center_crop, args.policy_magnitude
interpolation = InterpolationMode(args.interpolation)

print("Loading training data")
@@ -129,10 +133,12 @@ def load_data(traindir, valdir, args):
dataset = torchvision.datasets.ImageFolder(
traindir,
presets.ClassificationPresetTrain(
center_crop=center_crop,
crop_size=train_crop_size,
interpolation=interpolation,
auto_augment_policy=auto_augment_policy,
random_erase_prob=random_erase_prob,
policy_magnitude=policy_magnitude,
),
)
if args.cache_dataset:
@@ -182,7 +188,12 @@ def load_data(traindir, valdir, args):
def main(args):
if args.output_dir:
utils.mkdir(args.output_dir)


if args.seed is None:
# randomly choose a seed
args.seed = random.randint(0, 2 ** 32)
utils.set_seed(args.seed)

utils.init_distributed_mode(args)
print(args)

@@ -261,13 +272,21 @@ def main(args):
raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")

scaler = torch.cuda.amp.GradScaler() if args.amp else None

batches_per_epoch = len(data_loader)
warmup_iters = args.lr_warmup_epochs
total_iters = args.epochs

if args.lr_step_every_batch:
warmup_iters *= batches_per_epoch
total_iters *= batches_per_epoch

args.lr_scheduler = args.lr_scheduler.lower()
if args.lr_scheduler == "steplr":
main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
elif args.lr_scheduler == "cosineannealinglr":
main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
optimizer, T_max=total_iters - warmup_iters, eta_min=args.lr_min
)
elif args.lr_scheduler == "exponentiallr":
main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
@@ -280,18 +299,18 @@ def main(args):
if args.lr_warmup_epochs > 0:
if args.lr_warmup_method == "linear":
warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
)
elif args.lr_warmup_method == "constant":
warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
)
else:
raise RuntimeError(
f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
)
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
)
else:
lr_scheduler = main_lr_scheduler
@@ -341,8 +360,9 @@ def main(args):
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
lr_scheduler.step()
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler, lr_scheduler)
if not args.lr_step_every_batch:
lr_scheduler.step()
evaluate(model, criterion, data_loader_test, device=device)
if model_ema:
evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
@@ -371,7 +391,7 @@ def get_args_parser(add_help=True):

parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)

parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
parser.add_argument("--data-path", default="/datasets01_ontap/imagenet_full_size/061417/", type=str, help="dataset path")
parser.add_argument("--model", default="resnet18", type=str, help="model name")
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
parser.add_argument(
@@ -425,6 +445,7 @@ def get_args_parser(add_help=True):
parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
parser.add_argument("--lr-step-every-batch", action="store_true", help="decrease lr every step-size batches", default=False)
parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
@@ -448,6 +469,7 @@ def get_args_parser(add_help=True):
action="store_true",
)
parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
parser.add_argument("--policy-magnitude", default=9, type=int, help="magnitude of auto augment policy")
parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")

# Mixed precision training parameters
@@ -486,13 +508,16 @@ def get_args_parser(add_help=True):
parser.add_argument(
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
)
parser.add_argument(
"--train-center-crop", action="store_true", help="use center crop instead of random crop for training (default: False)"
)
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
parser.add_argument(
"--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

parser.add_argument("--seed", default=None, type=int, help="the seed for randomness (default: None). A `None` value means a seed will be randomly generated")
return parser


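To make the new --lr-step-every-batch flag concrete, here is a self-contained sketch of the batch-granular scheduling wired up above (the epoch count, warmup length, batches_per_epoch and the model are placeholders, not values from the commit):

```python
# Illustrative per-batch LR stepping, mirroring the warmup_iters/total_iters
# scaling and the scheduler handling added to train.py in this commit.
import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

epochs, lr_warmup_epochs, batches_per_epoch = 10, 2, 100
lr_step_every_batch = True

scale = batches_per_epoch if lr_step_every_batch else 1
warmup_iters = lr_warmup_epochs * scale
total_iters = epochs * scale

warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_iters)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_iters - warmup_iters, eta_min=0.0)
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_iters])

for epoch in range(epochs):
    for batch in range(batches_per_epoch):
        optimizer.step()          # stands in for the real forward/backward/step
        if lr_step_every_batch:
            scheduler.step()      # stepped inside train_one_epoch via the new scheduler argument
    if not lr_step_every_batch:
        scheduler.step()          # original per-epoch behaviour
```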
12 changes: 12 additions & 0 deletions references/classification/utils.py
@@ -9,6 +9,8 @@

import torch
import torch.distributed as dist
import numpy as np
import random


class SmoothedValue:
@@ -463,3 +465,13 @@ def _add_params(module, prefix=""):
if len(params[key]) > 0:
param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
return param_groups

def set_seed(seed: int):
"""
Function for setting all the RNGs to the same seed
"""
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
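A short usage sketch for the new helper; the cuDNN flags at the end are optional extras that set_seed does not touch and that are not part of this commit:

```python
# Hypothetical usage of the new utils.set_seed helper.
import torch
import utils  # references/classification/utils.py

utils.set_seed(42)  # seeds torch, torch.cuda, numpy and random with the same value

# set_seed alone does not make cuDNN deterministic; for stricter reproducibility
# these flags can be set as well (not part of this commit, and typically slower):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
```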