Skip to content

Commit

Permalink
Feat 29/lr scheduler train transform (#30)
Browse files Browse the repository at this point in the history
* ♻️ Update on nested config in each code for clarity

* 🎨 Update on train transforms

* 🔧 Update config structure on ASHA

* ✏️ Fix typos

* 🐛 Fix all the bugs related to lightning module hparams

* ✨ Add LR scheduler

* ✨ Add Albumentation transform

* 🎨 Make it possible to configure everything in the config.py file

* ⚡️ Make server to run concurrent trials

* 🚑 Forget to add weight decay in the optimizer

* ✨ Add a feature to run test from checkpoint_path

* 🎨 Fix config to update properly from args

* ⚡️ Change from Adam to AdamW

* ✨ Add a feature to use pretrained model or not
  • Loading branch information
Haneol-Kijm authored Sep 19, 2024
1 parent 452fe78 commit 345c23d
Show file tree
Hide file tree
Showing 10 changed files with 209 additions and 92 deletions.
3 changes: 2 additions & 1 deletion config/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .config_factory import get_config
from .config_factory import get_config
from .config import ModelConfig
43 changes: 26 additions & 17 deletions config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ class ModelConfig:
"""Model-related configuration."""
def __init__(self):
self.model_name = "ResNet18" # Baseline model
self.pretrained = False
# self.num_layers = ~
# self.num_heads = ~

Expand All @@ -21,19 +22,27 @@ class DatasetConfig:
"""Dataset-related configuration."""
def __init__(self):
self.data_path = "/data/ephemeral/home/data/"
# self.transform_mode = 'albumentation'
self.transform_type = 'albumentations'
self.num_workers = 3


class ExperimentConfig:
    """Experiment-related configuration (logging, hardware, scheduling, checkpointing)."""
    def __init__(self):
        self.save_dir = "/data/ephemeral/home/logs/"  # root directory for experiment logs
        self.num_gpus = 1
        self.num_workers = 6   # number of workers in scheduling
        self.num_samples = 10  # number of trials sampled by Ray Tune
        # self.checkpoint_interval = 5 # number of intervals to save checkpoint in pbt.
        self.ddp = False  # whether to use distributed data parallel training

        # Configs related with ASHA scheduler
        self.max_epochs = 100
        self.grace_period = 10     # minimum epochs before ASHA may stop a trial
        self.reduction_factor = 2  # rate at which ASHA halves surviving trials per rung
        self.brackets = 3          # number of ASHA brackets

        # Manual checkpoint load and test option
        self.checkpoint_path = None  # if set, run test from this checkpoint instead of training


class Config:
Expand All @@ -43,19 +52,19 @@ def __init__(self):
self.training = TrainingConfig()
self.dataset = DatasetConfig()
self.experiment = ExperimentConfig()

self.search_space = {
'batch_size': self.training.batch_size,
'lr': self.training.lr,
'weight_decay': self.training.weight_decay,
}

self.search_space = vars(self.training)


def flatten_to_dict(self):
flattened_dict = {}
for key, value in vars(self).items():
if key != 'search_space' and key != 'training' and hasattr(value, '__dict__'):
for subkey, subvalue in vars(value).items():
flattened_dict[f"{key}_{subkey}"] = subvalue
return flattened_dict
def update_from_args(self, args):
def update_config(obj, args):
for key, value in vars(obj).items():
if hasattr(value, '__dict__'):
update_config(value, args)
else:
for arg_key, arg_value in vars(args).items():
if key == arg_key.replace("-", "_") and arg_value is not None:
setattr(obj, key, arg_value)

update_config(self, args)

51 changes: 33 additions & 18 deletions dataset/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
# dataset/dataloader.py
import os

import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, SubsetRandomSampler

from .dataset import CustomDataset
from .transforms import get_transforms
from .transforms import TransformSelector


def get_dataloaders(data_path='/home/data/', batch_size=32, num_workers=1):
def get_dataloaders(data_path='/home/data/', transform_type='torchvision', batch_size=32, num_workers=1):
    """
    Returns train and validation data loaders.

    Args:
        data_path (str): Root directory containing ``train.csv`` and the ``train`` image folder.
        transform_type (str): Transform backend, 'torchvision' or 'albumentations'.
        batch_size (int): Batch size for both loaders.
        num_workers (int): Number of DataLoader worker processes.

    Returns:
        Tuple[DataLoader, DataLoader]: Train and validation data loaders.
    """
    info_df = pd.read_csv(os.path.join(data_path, 'train.csv'))
    data_path = os.path.join(data_path, 'train')

    # Stratified 80/20 split keeps the class balance in both subsets.
    train_df, val_df = train_test_split(
        info_df,
        test_size=0.2,
        stratify=info_df['target']
    )

    transform_selector = TransformSelector(transform_type=transform_type)

    # Different pipelines per split: augmentation on train only.
    train_transform = transform_selector.get_transforms(is_train=True)
    train_dataset = CustomDataset(data_path, train_df, transform=train_transform)

    val_transform = transform_selector.get_transforms(is_train=False)
    val_dataset = CustomDataset(data_path, val_df, transform=val_transform)

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              pin_memory=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers,
                            pin_memory=True)
    return train_loader, val_loader


def get_test_loader(data_path='/home/data/', batch_size=32, num_workers=1):
def get_test_loader(data_path='/home/data/', transform_type='torchvision', batch_size=32, num_workers=1):
"""
Returns a test data loader.
Expand All @@ -59,8 +71,11 @@ def get_test_loader(data_path='/home/data/', batch_size=32, num_workers=1):
Returns:
DataLoader: Test data loader.
"""
transform = get_transforms(mode='test')
test_dataset = CustomDataset(data_path, mode='test', transform=transform)
transform_selector = TransformSelector(transform_type=transform_type)
test_df = pd.read_csv(os.path.join(data_path, 'test.csv'))
data_path = os.path.join(data_path, 'test')
transform = transform_selector.get_transforms(is_train=False)
test_dataset = CustomDataset(data_path, test_df, transform=transform, is_inference=True)
test_loader = DataLoader(test_dataset,
batch_size=batch_size,
num_workers=num_workers,
Expand Down
26 changes: 12 additions & 14 deletions dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,33 @@
import os

import cv2
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class CustomDataset(Dataset):
def __init__(self, data_path, mode='train', transform=None):
def __init__(self, data_path, info_df, transform=None, is_inference: bool = False):
"""
Initializes the CustomDataset class.
Args:
- data_path (str): The path to the dataset.
- mode (str): The mode of the dataset, either 'train' or 'test' (default: 'train').
- data_path (str): Path to the dataset.
- info_df (str): The index information dataframe.
- transform (transforms.Compose): The transforms to be applied to the images (default: None).
"""
self.data_path = os.path.join(data_path, f'{mode}')
self.mode = mode
self.data = pd.read_csv(os.path.join(data_path, f'{mode}.csv'))
self.image_paths = self.data['image_path'].tolist()
if mode == 'train':
self.labels = self.data['target'].tolist() # Read image paths from test.csv

self.data_path = data_path
self.info_df = info_df
self.is_inference = is_inference
self.transform = transform
self.image_paths = self.info_df['image_path'].tolist()

if not self.is_inference:
self.labels = self.info_df['target'].tolist() # Read image paths from test.csv


    def _load_image(self, image_path):
        """Load an image from disk and return it as an RGB PIL image."""
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)  # read the image as a numpy array in BGR color format
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # convert BGR to RGB
        # image = self.transform(image)  # apply the configured image transform
        image = Image.fromarray(image)  # wrap as a PIL image
        return image

def _apply_transform(self, image):
Expand All @@ -52,7 +50,7 @@ def __getitem__(self, index):
image = self._load_image(image_path)
image = self._apply_transform(image)

if self.mode != 'train':
if self.is_inference:
return image
else:
label = self.labels[index]
Expand Down
93 changes: 89 additions & 4 deletions dataset/transforms.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,94 @@
import numpy as np
import torch
from torchvision import transforms

def get_transforms(mode='basic'):
transform = transforms.Compose([
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2


class TransformSelector:
    """Chooses which image-transform backend (torchvision or albumentations) to build."""

    _SUPPORTED_LIBRARIES = ("torchvision", "albumentations")

    def __init__(self, transform_type: str):
        # Validate the backend name up front so later dispatch is always safe.
        if transform_type not in self._SUPPORTED_LIBRARIES:
            raise ValueError("Unknown transformation library specified.")
        self.transform_type = transform_type

    def get_transforms(self, is_train: bool):
        """Return a transform object for the configured backend.

        Args:
            is_train (bool): Whether to include training-time augmentation.
        """
        if self.transform_type == 'torchvision':
            return TorchvisionTransform(is_train=is_train)
        # __init__ guarantees the only other possibility is 'albumentations'.
        return AlbumentationsTransform(is_train=is_train)

class TorchvisionTransform:
    """Torchvision-based transform pipeline: resize/tensor/normalize always,
    plus flip/rotation/color-jitter augmentation when is_train is True."""

    def __init__(self, is_train: bool = True):
        # Shared tail of the pipeline: resize to 224x224, convert to tensor,
        # normalize with the standard ImageNet statistics.
        common_transforms = [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]

        if is_train:
            # Training: prepend random horizontal flip, up-to-15-degree rotation,
            # and brightness/contrast jitter.
            self.transform = transforms.Compose(
                [
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomRotation(15),
                    transforms.ColorJitter(brightness=0.2, contrast=0.2),
                ] + common_transforms
            )
        else:
            # Validation/test: deterministic common pipeline only.
            self.transform = transforms.Compose(common_transforms)

    def __call__(self, image: np.ndarray) -> torch.Tensor:
        image = Image.fromarray(image)  # numpy array -> PIL image, as torchvision expects

        transformed = self.transform(image)

        return transformed


class AlbumentationsTransform:
    """Albumentations-based transform pipeline: resize/normalize/tensor always,
    plus flip/rotation/brightness-contrast augmentation when is_train is True."""

    def __init__(self, is_train: bool = True):
        # Deterministic tail applied in every mode: resize, ImageNet
        # normalization, then conversion to a PyTorch tensor.
        base_transforms = [
            A.Resize(224, 224),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ]

        augmentations = []
        if is_train:
            # Training-only randomness: horizontal flip, small rotation,
            # brightness/contrast jitter.
            augmentations = [
                A.HorizontalFlip(p=0.5),
                A.Rotate(limit=15),
                A.RandomBrightnessContrast(p=0.2),
            ]

        self.transform = A.Compose(augmentations + base_transforms)

    def __call__(self, image) -> torch.Tensor:
        # Albumentations operates on numpy arrays (OpenCV convention).
        if not isinstance(image, np.ndarray):
            raise TypeError("Image should be a NumPy array (OpenCV format).")

        result = self.transform(image=image)

        return result['image']
7 changes: 6 additions & 1 deletion engine/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@

def run_test(config, ckpt_dir):
# Call the test loader
test_loader = get_test_loader(data_path=config.dataset.data_path, batch_size=64, num_workers=6)
test_loader = get_test_loader(
data_path=config.dataset.data_path,
transform_type=config.dataset.transform_type,
batch_size=64,
num_workers=6
)

# Define the trainer for testing
pred_callback = PredictionCallback(f"{config.dataset.data_path}/test.csv", ckpt_dir, config.model.model_name)
Expand Down
Loading

0 comments on commit 345c23d

Please sign in to comment.