-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added the basic skeleton of training code
- Loading branch information
Showing
10 changed files
with
214 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
data: | ||
csv_file: | ||
|
||
parameters: | ||
batch_size: 16 | ||
train_shuffle: True | ||
val_shuffle: False | ||
epochs: 50 | ||
early_stop: 5 | ||
max_length: 333 | ||
|
||
device: | ||
is_cuda: True | ||
|
||
loss: | ||
loss_option: 'bce_cross_entropy_loss' | ||
|
||
optimizer: | ||
choice: 'Adam' | ||
lr: 1e-3 | ||
gamma: 0.5 | ||
step_size: 15 | ||
scheduler: 'step_lr' | ||
mode: 'max' | ||
decay: 0.001 | ||
patience: 5 | ||
factor: 0.5 | ||
verbose: True | ||
|
||
model: | ||
option: 'LSTM_multi_layer_tone_transition_model' | ||
model_type: 'LSTM' | ||
embedding_dim: 512 | ||
n_hidden: 512 | ||
n_layers: 2 | ||
n_classes: 2 | ||
batch_first: True | ||
|
||
output: | ||
model_dir: '/data/digbose92/ads_complete_repo/ads_codes/model_files/recent_models/model_dir' | ||
log_dir: '/data/digbose92/ads_complete_repo/ads_codes/model_files/recent_models/log_dir' |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import numpy as np | ||
import torch | ||
import torch.nn.functional as F | ||
import torch.nn as nn | ||
import pandas as pd | ||
from collections import Counter | ||
import math | ||
|
||
#basic binary cross entropy loss | ||
def binary_cross_entropy_loss(device,pos_weights=None,reduction='mean'): | ||
loss=nn.BCEWithLogitsLoss(reduction='mean',pos_weight=pos_weights).to(device) | ||
return(loss) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import torch | ||
from transformers import AdamW, get_linear_schedule_with_warmup | ||
|
||
def optimizer_adam(model,lr,weight_decay=0): | ||
optim_set=torch.optim.Adam(model.parameters(),lr=lr,weight_decay=weight_decay) | ||
return(optim_set) | ||
|
||
def optimizer_adamW(model,lr,weight_decay): | ||
#optim_set=AdamW(model.parameters(),lr=lr) | ||
#default weight decay parameters added | ||
optim_set=AdamW(model.parameters(),lr=lr,weight_decay=weight_decay) | ||
return(optim_set) | ||
|
||
def linear_schedule_with_warmup(optimizer,num_warmup_steps,num_training_steps): | ||
scheduler = get_linear_schedule_with_warmup(optimizer, | ||
num_warmup_steps=num_warmup_steps, # Default value | ||
num_training_steps=num_training_steps) | ||
return(scheduler) | ||
|
||
def reduce_lr_on_plateau(optimizer,mode,patience): | ||
lr_scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,mode=mode,patience=patience) | ||
return(lr_scheduler) | ||
|
||
def steplr_scheduler(optimizer,step_size,gamma): | ||
scheduler=torch.optim.lr_scheduler.StepLR( | ||
optimizer=optimizer, | ||
step_size=step_size, | ||
gamma=gamma | ||
) | ||
return(scheduler) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import torch | ||
import torch.nn as nn | ||
import pandas as pd | ||
import os | ||
import sys | ||
import time | ||
import pickle | ||
#append path of datasets and models | ||
sys.path.append(os.path.join('..', 'datasets')) | ||
sys.path.append(os.path.join('..', 'models')) | ||
sys.path.append(os.path.join('..', 'configs')) | ||
sys.path.append(os.path.join('..', 'losses')) | ||
sys.path.append(os.path.join('..', 'optimizers')) | ||
sys.path.append(os.path.join('..', 'utils')) | ||
|
||
#import all libraries | ||
import random | ||
from ast import literal_eval | ||
import torch | ||
import yaml | ||
import torchvision.transforms as transforms | ||
import torchvision | ||
from torch.utils.data import Dataset, DataLoader | ||
from dataset import * | ||
from loss_functions import * | ||
from LSTM_models import * | ||
from optimizer import * | ||
from metrics import calculate_stats | ||
import torch.nn as nn | ||
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score | ||
from tqdm import tqdm | ||
from statistics import mean | ||
import argparse | ||
from log_file_generate import * | ||
from scipy.stats.stats import pearsonr | ||
import wandb | ||
|
||
#fix seed for reproducibility | ||
seed_value=123457 | ||
np.random.seed(seed_value) # cpu vars | ||
torch.manual_seed(seed_value) # cpu vars | ||
random.seed(seed_value) # Python | ||
torch.cuda.manual_seed(seed_value) | ||
torch.cuda.manual_seed_all(seed_value) | ||
torch.backends.cudnn.deterministic = True | ||
torch.backends.cudnn.benchmark = False | ||
|
||
def sort_batch(X, y, lengths): | ||
lengths, indx = lengths.sort(dim=0, descending=True) | ||
X = X[indx] | ||
y = y[indx] | ||
return X, y, lengths # transpose (batch x seq_length) to (seq_length x batch) | ||
|
||
def load_config(config_file): | ||
|
||
with open(config_file,'r') as f: | ||
config_data=yaml.safe_load(f) | ||
return(config_data) | ||
|
||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import os | ||
import logging | ||
def log(path, file): | ||
"""[Create a log file to record the experiment's logs] | ||
Arguments: | ||
path {string} -- path to the directory | ||
file {string} -- file name | ||
Returns: | ||
[func] -- [logger that record logs] | ||
""" | ||
|
||
# check if the file exist | ||
log_file = os.path.join(path, file) | ||
|
||
if not os.path.isfile(log_file): | ||
open(log_file, "w+").close() | ||
|
||
console_logging_format = "%(levelname)s %(message)s" | ||
file_logging_format = "%(levelname)s: %(asctime)s: %(message)s" | ||
|
||
# configure logger | ||
logging.basicConfig(level=logging.INFO, format=console_logging_format) | ||
logger = logging.getLogger() | ||
|
||
# create a file handler for output file | ||
handler = logging.FileHandler(log_file) | ||
|
||
# set the logging level for log file | ||
handler.setLevel(logging.INFO) | ||
|
||
# create a logging format | ||
formatter = logging.Formatter(file_logging_format) | ||
handler.setFormatter(formatter) | ||
|
||
# add the handlers to the logger | ||
logger.addHandler(handler) | ||
|
||
return logger |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import numpy as np | ||
from scipy import stats | ||
from sklearn import metrics | ||
import torch | ||
|
||
def d_prime(auc): | ||
standard_normal = stats.norm() | ||
d_prime = standard_normal.ppf(auc) * np.sqrt(2.0) | ||
return d_prime | ||
|
||
def calculate_stats(output, target): | ||
"""Calculate statistics including mAP, AUC, etc. | ||
Args: | ||
output: 2d array, (samples_num, classes_num) | ||
target: 2d array, (samples_num, classes_num) | ||
Returns: | ||
stats: list of statistic of each class. | ||
""" | ||
classes_num = target.shape[-1] | ||
stats = [] | ||
# Class-wise statistics | ||
ap_val_list=[] | ||
for k in range(classes_num): | ||
|
||
# Average precision | ||
avg_precision = metrics.average_precision_score( | ||
target[:, k], output[:, k], average=None) | ||
|
||
ap_val_list.append(avg_precision) | ||
|
||
return ap_val_list |