Commit 1bbbffb (parent bf1909f): 68 changed files with 5,403 additions and 1,046 deletions.
@@ -0,0 +1,6 @@
#!/bin/bash

python train.py -c config_0.json
python train.py -c config_1.json
python train.py -c config_2.json
python train.py -c config_3.json
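The shell script above simply launches train.py once for each of four JSON configs. train.py itself is not part of this excerpt, so the following is only a hedged sketch of how such an entry point might parse the -c flag and load the config; all names here are hypothetical.

import argparse
import json

# Hypothetical sketch only: the real train.py is not shown in this commit.
parser = argparse.ArgumentParser(description='training entry point (sketch)')
parser.add_argument('-c', '--config', required=True,
                    help='path to a JSON config file, e.g. config_0.json')
args = parser.parse_args()

with open(args.config) as f:
    config = json.load(f)

# the model, data loader, optimizer and trainer would be built from `config` here
print('loaded config:', config.get('name'))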
@@ -0,0 +1,3 @@
from .base_data_loader import *
from .base_trainer import *
from .base_model import *
@@ -0,0 +1,70 @@
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler
import copy


# FIXME: too many memory allocations here ... replace the deepcopy with something cheaper
class BaseDataLoader(DataLoader):
    """
    Base class for all data loaders
    """
    def __init__(self, data_set, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
        self.validation_split = validation_split
        self.shuffle = shuffle

        self.batch_idx = 0
        self.n_samples = len(data_set)

        self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)

        self.init_kwargs = {
            'dataset': data_set,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'collate_fn': collate_fn,
            'num_workers': num_workers
        }
        self.valid_init_kwargs = {
            'dataset': copy.deepcopy(data_set),
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'collate_fn': collate_fn,
            'num_workers': num_workers
        }
        super().__init__(sampler=self.sampler, **self.init_kwargs)

    def _split_sampler(self, split):
        if split == 0.0:
            return None, None

        idx_full = np.arange(self.n_samples)

        np.random.seed(0)
        np.random.shuffle(idx_full)

        if isinstance(split, int):
            assert split > 0
            assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
            len_valid = split
        else:
            len_valid = int(self.n_samples * split)

        valid_idx = idx_full[0:len_valid]
        train_idx = np.delete(idx_full, np.arange(0, len_valid))

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)
        print("!" * 10, valid_sampler)
        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def split_validation(self):
        if self.valid_sampler is None:
            return None
        else:
            return DataLoader(sampler=self.valid_sampler, **self.valid_init_kwargs)
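A minimal usage sketch of BaseDataLoader follows. It assumes a toy TensorDataset standing in for the project's actual dataset class (not part of this commit) and shows how split_validation() hands back a second DataLoader driven by the validation sampler.

import torch
from torch.utils.data import TensorDataset

# assumption: a toy dataset stands in for the project's CustomDataset
toy_set = TensorDataset(torch.randn(100, 3), torch.randint(0, 18, (100,)))

loader = BaseDataLoader(toy_set, batch_size=16, shuffle=True,
                        validation_split=0.1, num_workers=0)
valid_loader = loader.split_validation()

print(len(loader.sampler), 'training samples')       # 90
print(len(valid_loader.sampler), 'validation samples')  # 10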
@@ -0,0 +1,25 @@
import torch.nn as nn
import numpy as np
from abc import abstractmethod


class BaseModel(nn.Module):
    """
    Base class for all models
    """
    @abstractmethod
    def forward(self, *inputs):
        """
        Forward pass logic

        :return: Model output
        """
        raise NotImplementedError

    def __str__(self):
        """
        Model prints with number of trainable parameters
        """
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        return super().__str__() + '\nTrainable parameters: {}'.format(params)
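To illustrate what BaseModel provides, here is a minimal hypothetical subclass (not part of the commit): any model that implements forward gets the trainable-parameter count appended to its printed representation.

import torch.nn as nn

# hypothetical example model, not part of the commit
class TinyModel(BaseModel):
    def __init__(self, n_classes=18):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, n_classes))

    def forward(self, x):
        return self.net(x)

model = TinyModel()
print(model)  # module summary followed by "Trainable parameters: 55314"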
@@ -0,0 +1,173 @@
import copy  # needed by copy.deepcopy() in train()

import numpy as np  # needed by np.zeros() in train()
import torch
from torch.utils.data import DataLoader  # needed by the per-fold loaders in train()
from abc import abstractmethod
from numpy import inf
from logger import TensorboardWriter
from sklearn.model_selection import KFold, StratifiedKFold


class BaseTrainer:
    """
    Base class for all trainers
    """
    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        self.config = config
        self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

        self.model = model
        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)
            if self.early_stop <= 0:
                self.early_stop = inf

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

    @abstractmethod
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def train(self):
        """
        Full training logic
        """
        # NOTE: `dataset`, `train_ages` and `transform_val` are not defined in this file;
        # they are expected to be provided elsewhere (e.g. by the concrete trainer).
        stratified_kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=None)
        k_idx = 1
        for train_index, validate_index in stratified_kfold.split(np.zeros(len(train_ages)), train_ages):
            print(f'## Stratified_K-Fold :: {k_idx}')
            k_idx += 1
            train_dataset = torch.utils.data.dataset.Subset(dataset, train_index)
            valid_dataset = torch.utils.data.dataset.Subset(dataset, validate_index)
            valid_dataset = copy.deepcopy(valid_dataset)
            valid_dataset.dataset.transform = transform_val

            train_loader = DataLoader(train_dataset,
                                      batch_size=4,
                                      shuffle=True,
                                      num_workers=0,
                                      drop_last=True
                                      )

            val_loader = DataLoader(valid_dataset,
                                    batch_size=4,
                                    shuffle=True,
                                    num_workers=0
                                    )

            not_improved_count = 0
            for epoch in range(self.start_epoch, self.epochs + 1):
                result = self._train_epoch(epoch)

                # save logged informations into log dict
                log = {'epoch': epoch}
                log.update(result)

                # print logged informations to the screen
                for key, value in log.items():
                    self.logger.info('    {:15s}: {}'.format(str(key), value))

                # evaluate model performance according to configured metric, save best checkpoint as model_best
                best = False
                if self.mnt_mode != 'off':
                    try:
                        # check whether model performance improved or not, according to specified metric (mnt_metric)
                        improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                                   (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                    except KeyError:
                        self.logger.warning("Warning: Metric '{}' is not found. "
                                            "Model performance monitoring is disabled.".format(self.mnt_metric))
                        self.mnt_mode = 'off'
                        improved = False

                    if improved:
                        self.mnt_best = log[self.mnt_metric]
                        not_improved_count = 0
                        best = True
                    else:
                        not_improved_count += 1

                    if not_improved_count > self.early_stop:
                        self.logger.info("Validation performance didn't improve for {} epochs. "
                                         "Training stops.".format(self.early_stop))
                        break

                if epoch % self.save_period == 0 or best:  # save from the 3rd epoch onward
                    self._save_checkpoint(epoch, save_best=best)

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param log: logging information of the epoch
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            'config': self.config
        }
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best-epoch{}.pth'.format(epoch))
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")
        else:
            filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
            torch.save(state, filename)
            self.logger.info("Saving checkpoint: {} ...".format(filename))

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        # load architecture params from checkpoint.
        if checkpoint['config']['arch'] != self.config['arch']:
            self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
                                "checkpoint. This may yield an exception while state_dict is being loaded.")
        self.model.load_state_dict(checkpoint['state_dict'])

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
            self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
                                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
@@ -0,0 +1,78 @@
{
    "name": "Maske",
    "n_gpu": 1,
    "test_name": "NO_resize_Base_C400_200_TransForm",
    "data_set": {
        "type": "CustomDataset",
        "args": {
            "dir_path": "/opt/ml/input/data/train/"
        }
    },
    "set_transform": {
        "type": "NO_resize_Base_C400_200_TransForm",
        "args": {
            "mean": [0.560, 0.524, 0.501],
            "std": [0.233, 0.243, 0.245],
            "resize": [512, 384],
            "use_rand_aug": false
        },
        "cut_mix": false,
        "mix_up": false
    },
    "data_loader": {
        "type": "StratifiedKFold",
        "args": {
            "batch_size": 32,
            "shuffle": false,
            "validation_split": 0.1,
            "num_workers": 0
        }
    },
    "module": {
        "type": "CustomModel",
        "args": {
        }
    },
    "optimizer": {
        "type": "AdamP",
        "args": {
            "lr": 3e-4,
            "betas": [0.9, 0.999],
            "weight_decay": 1e-5
        }
    },
    "lr_scheduler": {
        "type": "ReduceLROnPlateau",
        "args": {
            "mode": "min",
            "factor": 0.05,
            "patience": 2
        }
    },
    "set_loss": {
        "type": "FocalLoss",
        "args": {
            "gamma": 5
        }
    },
    "metrics": [
        "accuracy",
        "top_10_acc",
        "f1_score"
    ],
    "trainer": {
        "type": "multi_label",
        "epochs": 3,
        "save_dir": "logs/",
        "save_period": 1,
        "verbosity": 2,
        "monitor": "min val_loss",
        "early_stop": 1,
        "tensorboard": false,
        "beta": 0.8
    },
    "test": {
        "path": "/opt/ml/code/src/logs/log/Maske/0902_013705/model_best-epoch3.pth"
    }
}
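The trainer block ties back to BaseTrainer: "monitor": "min val_loss" is split into a mode and a metric key, and "early_stop": 1 stops training after one epoch without improvement. A small hedged sketch of how those fields are consumed when the JSON is loaded directly; the filename config_0.json is an assumption.

import json
from numpy import inf

# assumption: the config above is saved as config_0.json
with open('config_0.json') as f:
    cfg = json.load(f)

mnt_mode, mnt_metric = cfg['trainer']['monitor'].split()  # 'min', 'val_loss'
mnt_best = inf if mnt_mode == 'min' else -inf
early_stop = cfg['trainer'].get('early_stop', inf)

print(mnt_mode, mnt_metric, mnt_best, early_stop)  # min val_loss inf 1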