Merge pull request #1 from boostcampaitech2/mingu

add src final code

deokgu1994 authored Sep 4, 2021
2 parents bf1909f + 1bbbffb commit 9c25136
Showing 68 changed files with 5,403 additions and 1,046 deletions.
2 changes: 0 additions & 2 deletions .prettierignore

This file was deleted.

28 changes: 0 additions & 28 deletions README.md

This file was deleted.

6 changes: 6 additions & 0 deletions auto.sh
@@ -0,0 +1,6 @@
#!/bin/bash
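# Run the four training jobs back to back, one per experiment config.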

python train.py -c config_0.json
python train.py -c config_1.json
python train.py -c config_2.json
python train.py -c config_3.json
3 changes: 3 additions & 0 deletions base/__init__.py
@@ -0,0 +1,3 @@
from .base_data_loader import *
from .base_trainer import *
from .base_model import *
70 changes: 70 additions & 0 deletions base/base_data_loader.py
@@ -0,0 +1,70 @@
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler
import copy


# FIXME: this allocates far too much memory ... replace deepcopy with something lighter
class BaseDataLoader(DataLoader):
"""
Base class for all data loaders
"""
def __init__(self, data_set, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
self.validation_split = validation_split
self.shuffle = shuffle

self.batch_idx = 0
self.n_samples = len(data_set)

self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)

self.init_kwargs = {
'dataset': data_set,
'batch_size': batch_size,
'shuffle': self.shuffle,
'collate_fn': collate_fn,
'num_workers': num_workers
}
self.valid_init_kwargs = {
'dataset': copy.deepcopy(data_set),
'batch_size': batch_size,
'shuffle': self.shuffle,
'collate_fn': collate_fn,
'num_workers': num_workers
}
super().__init__(sampler=self.sampler, **self.init_kwargs)

def _split_sampler(self, split):
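        # `split` may be a float fraction (e.g. 0.1) or an absolute number of
        # validation samples; 0.0 disables the validation split entirely.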
if split == 0.0:
return None, None

idx_full = np.arange(self.n_samples)

np.random.seed(0)
np.random.shuffle(idx_full)

if isinstance(split, int):
assert split > 0
assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
len_valid = split
else:
len_valid = int(self.n_samples * split)

valid_idx = idx_full[0:len_valid]
train_idx = np.delete(idx_full, np.arange(0, len_valid))

train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
print("!" *10 , valid_sampler)
# turn off shuffle option which is mutually exclusive with sampler
self.shuffle = False
self.n_samples = len(train_idx)

return train_sampler, valid_sampler

def split_validation(self):
if self.valid_sampler is None:
return None
else:
return DataLoader(sampler=self.valid_sampler, **self.valid_init_kwargs)
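
For context, a minimal usage sketch of BaseDataLoader (the MNIST dataset and the hyper-parameters here are illustrative assumptions, not part of this diff):

from torchvision import datasets, transforms

# hypothetical example: wrap MNIST in the BaseDataLoader defined above
data_set = datasets.MNIST('data/', train=True, download=True,
                          transform=transforms.ToTensor())
loader = BaseDataLoader(data_set, batch_size=32, shuffle=True,
                        validation_split=0.1, num_workers=2)
valid_loader = loader.split_validation()   # DataLoader over the held-out 10%

for images, labels in loader:              # iterates only over the remaining 90%
    pass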
25 changes: 25 additions & 0 deletions base/base_model.py
@@ -0,0 +1,25 @@
import torch.nn as nn
import numpy as np
from abc import abstractmethod


class BaseModel(nn.Module):
"""
Base class for all models
"""
@abstractmethod
def forward(self, *inputs):
"""
Forward pass logic
:return: Model output
"""
raise NotImplementedError

def __str__(self):
"""
Model prints with number of trainable parameters
"""
model_parameters = filter(lambda p: p.requires_grad, self.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
return super().__str__() + '\nTrainable parameters: {}'.format(params)
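
A minimal sketch of a concrete model built on BaseModel (the two-layer MNIST architecture is an assumption for illustration):

import torch.nn as nn
import torch.nn.functional as F

class MnistModel(BaseModel):
    def __init__(self, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = x.view(x.size(0), -1)       # flatten to (batch, 784)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

print(MnistModel())  # __str__ above also reports the trainable-parameter count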
173 changes: 173 additions & 0 deletions base/base_trainer.py
@@ -0,0 +1,173 @@
import copy

import numpy as np
import torch
from abc import abstractmethod
from numpy import inf
from torch.utils.data import DataLoader
from logger import TensorboardWriter
from sklearn.model_selection import KFold, StratifiedKFold

class BaseTrainer:
"""
Base class for all trainers
"""
def __init__(self, model, criterion, metric_ftns, optimizer, config):
self.config = config
self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

self.model = model
self.criterion = criterion
self.metric_ftns = metric_ftns
self.optimizer = optimizer

cfg_trainer = config['trainer']
self.epochs = cfg_trainer['epochs']
self.save_period = cfg_trainer['save_period']
self.monitor = cfg_trainer.get('monitor', 'off')

# configuration to monitor model performance and save best
if self.monitor == 'off':
self.mnt_mode = 'off'
self.mnt_best = 0
else:
self.mnt_mode, self.mnt_metric = self.monitor.split()
assert self.mnt_mode in ['min', 'max']

self.mnt_best = inf if self.mnt_mode == 'min' else -inf
self.early_stop = cfg_trainer.get('early_stop', inf)
if self.early_stop <= 0:
self.early_stop = inf

self.start_epoch = 1

self.checkpoint_dir = config.save_dir

if config.resume is not None:
self._resume_checkpoint(config.resume)

@abstractmethod
def _train_epoch(self, epoch):
"""
Training logic for an epoch
:param epoch: Current epoch number
"""
raise NotImplementedError

def train(self):
"""
Full training logic
"""
        # NOTE: `dataset`, `train_ages`, and `transform_val` are not defined in this
        # base class; they are expected to be provided by the concrete trainer/config.
        stratified_kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=None)
k_idx = 1
for train_index, validate_index in stratified_kfold.split(np.zeros(len(train_ages)), train_ages):
print(f'## Stratified_K-Fold :: {k_idx}')
k_idx += 1
train_dataset = torch.utils.data.dataset.Subset(dataset, train_index)
valid_dataset = torch.utils.data.dataset.Subset(dataset, validate_index)
valid_dataset = copy.deepcopy(valid_dataset)
valid_dataset.dataset.transform = transform_val

train_loader = DataLoader(train_dataset,
batch_size=4,
shuffle=True,
num_workers=0,
drop_last=True
)

            val_loader = DataLoader(valid_dataset,
                                    batch_size=4,
                                    shuffle=True,
                                    num_workers=0)

not_improved_count = 0
for epoch in range(self.start_epoch, self.epochs + 1):
result = self._train_epoch(epoch)

# save logged informations into log dict
log = {'epoch': epoch}
log.update(result)

# print logged informations to the screen
for key, value in log.items():
self.logger.info(' {:15s}: {}'.format(str(key), value))

# evaluate model performance according to configured metric, save best checkpoint as model_best
best = False
if self.mnt_mode != 'off':
try:
# check whether model performance improved or not, according to specified metric(mnt_metric)
improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
(self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
except KeyError:
self.logger.warning("Warning: Metric '{}' is not found. "
"Model performance monitoring is disabled.".format(self.mnt_metric))
self.mnt_mode = 'off'
improved = False

if improved:
self.mnt_best = log[self.mnt_metric]
not_improved_count = 0
best = True
else:
not_improved_count += 1

if not_improved_count > self.early_stop:
self.logger.info("Validation performance didn\'t improve for {} epochs. "
"Training stops.".format(self.early_stop))
break

                if epoch % self.save_period == 0 or best:  # e.g. with save_period=3, saves from the 3rd epoch onward
self._save_checkpoint(epoch, save_best=best)

def _save_checkpoint(self, epoch, save_best=False):
"""
Saving checkpoints
:param epoch: current epoch number
:param log: logging information of the epoch
:param save_best: if True, rename the saved checkpoint to 'model_best.pth'
"""
arch = type(self.model).__name__
state = {
'arch': arch,
'epoch': epoch,
'state_dict': self.model.state_dict(),
'optimizer': self.optimizer.state_dict(),
'monitor_best': self.mnt_best,
'config': self.config
}
if save_best:
best_path = str(self.checkpoint_dir / 'model_best-epoch{}.pth'.format(epoch))
torch.save(state, best_path)
self.logger.info("Saving current best: model_best.pth ...")
else:
filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
torch.save(state, filename)
self.logger.info("Saving checkpoint: {} ...".format(filename))


def _resume_checkpoint(self, resume_path):
"""
Resume from saved checkpoints
:param resume_path: Checkpoint path to be resumed
"""
resume_path = str(resume_path)
self.logger.info("Loading checkpoint: {} ...".format(resume_path))
checkpoint = torch.load(resume_path)
self.start_epoch = checkpoint['epoch'] + 1
self.mnt_best = checkpoint['monitor_best']

# load architecture params from checkpoint.
if checkpoint['config']['arch'] != self.config['arch']:
self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
"checkpoint. This may yield an exception while state_dict is being loaded.")
self.model.load_state_dict(checkpoint['state_dict'])

# load optimizer state from checkpoint only when optimizer type is not changed.
if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
"Optimizer parameters not being resumed.")
else:
self.optimizer.load_state_dict(checkpoint['optimizer'])

self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
78 changes: 78 additions & 0 deletions config_example.json
@@ -0,0 +1,78 @@
{
"name": "Maske",
"n_gpu": 1,
"test_name": "NO_resize_Base_C400_200_TransForm",
"data_set": {
"type": "CustomDataset",
"args": {
"dir_path": "/opt/ml/input/data/train/"
}
},
"set_transform": {
"type": "NO_resize_Base_C400_200_TransForm",
"args": {
"mean": [0.560, 0.524, 0.501],
"std": [0.233, 0.243, 0.245],
"resize": [512, 384],
"use_rand_aug": false
},
"cut_mix": false,
"mix_up": false
},
"data_loader":{
"type":"StratifiedKFold",
"args": {
"batch_size": 32,
"shuffle": false,
"validation_split": 0.1,
"num_workers": 0
}
},
"module": {
"type": "CustomModel",
"args": {
}
},
"optimizer": {
"type": "AdamP",
"args": {
"lr": 3e-4,
"betas": [0.9, 0.999],
"weight_decay": 1e-5
}
},
"lr_scheduler": {
"type": "ReduceLROnPlateau",
"args": {
"mode": "min",
"factor": 0.05,
"patience": 2
}
},
"set_loss": {
"type": "FocalLoss",
"args": {
"gamma" : 5
}
},
"metrics": [
"accuracy",
"top_10_acc",
"f1_score"
],
"trainer": {
"type": "multi_label",
"epochs": 3,
"save_dir": "logs/",
"save_period": 1,
"verbosity": 2,
"monitor": "min val_loss",
"early_stop": 1,
"tensorboard": false,
"beta": 0.8
},
"test":{
"path": "/opt/ml/code/src/logs/log/Maske/0902_013705/model_best-epoch3.pth"
}
}
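
For reference, the file is plain JSON, so the blocks above can be inspected with the standard library (a sketch; the project's own config parser may layer extra behaviour, such as save-dir handling, on top):

import json

with open('config_example.json') as f:
    config = json.load(f)

print(config['optimizer']['type'])        # AdamP
print(config['optimizer']['args']['lr'])  # 0.0003
print(config['trainer']['monitor'])       # min val_loss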
