Feat 29/lr scheduler train transform #30

Merged 14 commits on Sep 19, 2024
Changes from all commits
config/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -1 +1,2 @@
from .config_factory import get_config
from .config_factory import get_config
from .config import ModelConfig
config/config.py (43 changes: 26 additions & 17 deletions)
@@ -4,6 +4,7 @@ class ModelConfig:
    """Model-related configuration."""
    def __init__(self):
        self.model_name = "ResNet18"  # Baseline model
        self.pretrained = False
        # self.num_layers = ~
        # self.num_heads = ~

@@ -21,19 +22,27 @@ class DatasetConfig:
    """Dataset-related configuration."""
    def __init__(self):
        self.data_path = "/data/ephemeral/home/data/"
        # self.transform_mode = 'albumentation'
        self.transform_type = 'albumentations'
        self.num_workers = 3


class ExperimentConfig:
    """Experiment-related configuration."""
    def __init__(self):
        self.save_dir = "/data/ephemeral/home/logs/"
        self.num_gpus = 1
        self.max_epochs = 100
        self.num_workers = 1  # number of workers in scheduling
        self.num_workers = 6  # number of workers in scheduling
        self.num_samples = 10  # number of samples in Ray Tune
        # self.checkpoint_interval = 5  # number of intervals to save checkpoint in pbt.
        self.ddp = False

        # Configs related to the ASHA scheduler
        self.max_epochs = 100
        self.grace_period = 10
        self.reduction_factor = 2
        self.brackets = 3

        # Manual checkpoint load and test option
        self.checkpoint_path = None


class Config:
@@ -43,19 +52,19 @@ def __init__(self):
        self.training = TrainingConfig()
        self.dataset = DatasetConfig()
        self.experiment = ExperimentConfig()

        self.search_space = {
            'batch_size': self.training.batch_size,
            'lr': self.training.lr,
            'weight_decay': self.training.weight_decay,
        }

        self.search_space = vars(self.training)


    def flatten_to_dict(self):
        flattened_dict = {}
        for key, value in vars(self).items():
            if key != 'search_space' and key != 'training' and hasattr(value, '__dict__'):
                for subkey, subvalue in vars(value).items():
                    flattened_dict[f"{key}_{subkey}"] = subvalue
        return flattened_dict

    def update_from_args(self, args):
        def update_config(obj, args):
            for key, value in vars(obj).items():
                if hasattr(value, '__dict__'):
                    update_config(value, args)
                else:
                    for arg_key, arg_value in vars(args).items():
                        if key == arg_key.replace("-", "_") and arg_value is not None:
                            setattr(obj, key, arg_value)

        update_config(self, args)
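For reference, a minimal usage sketch of the new update_from_args hook and the search_space = vars(self.training) change. It assumes TrainingConfig exposes lr, batch_size, and weight_decay (as the removed dict implies) and that Config is importable as config.config.Config; the argparse flags are illustrative.

```python
import argparse

from config.config import Config  # assumed import path, mirroring config/__init__.py

parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=None)
parser.add_argument("--batch_size", type=int, default=None)
args = parser.parse_args(["--lr", "3e-4"])  # illustrative CLI input

cfg = Config()
cfg.update_from_args(args)      # recursively copies matching, non-None args into the sub-configs

print(cfg.training.lr)          # 0.0003
# search_space is vars(self.training), i.e. a live view of TrainingConfig.__dict__,
# so the updated value shows up here as well:
print(cfg.search_space["lr"])   # 0.0003

# The new ASHA fields would map onto Ray Tune's scheduler roughly like this
# (sketch only; assumes ray[tune] is installed):
# from ray.tune.schedulers import ASHAScheduler
# scheduler = ASHAScheduler(max_t=cfg.experiment.max_epochs,
#                           grace_period=cfg.experiment.grace_period,
#                           reduction_factor=cfg.experiment.reduction_factor,
#                           brackets=cfg.experiment.brackets)
```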

dataset/dataloader.py (51 changes: 33 additions & 18 deletions)
@@ -1,10 +1,15 @@
# dataset/dataloader.py
import os

import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, SubsetRandomSampler

from .dataset import CustomDataset
from .transforms import get_transforms
from .transforms import TransformSelector


def get_dataloaders(data_path='/home/data/', batch_size=32, num_workers=1):
def get_dataloaders(data_path='/home/data/', transform_type='torchvision', batch_size=32, num_workers=1):
"""
Returns train and validation data loaders.

Expand All @@ -16,38 +21,45 @@ def get_dataloaders(data_path='/home/data/', batch_size=32, num_workers=1):
Returns:
Tuple[DataLoader, DataLoader]: Train and validation data loaders.
"""
    info_df = pd.read_csv(os.path.join(data_path, 'train.csv'))
    data_path = os.path.join(data_path, 'train')

    transform = get_transforms(mode='train')
    full_dataset = CustomDataset(data_path, mode='train', transform=transform)
    train_df, val_df = train_test_split(
        info_df,
        test_size=0.2,
        stratify=info_df['target']
    )

    transform_selector = TransformSelector(transform_type=transform_type)

    # Create indices for the train and validation sets
    indices = list(range(len(full_dataset)))
    split = int(0.8 * len(indices))
    train_indices = indices[:split]
    val_indices = indices[split:]
    train_transform = transform_selector.get_transforms(is_train=True)
    train_dataset = CustomDataset(data_path, train_df, transform=train_transform)

    # Create val dataset with validation transforms
    val_dataset = CustomDataset(data_path, mode='train', transform=transform)
    val_transform = transform_selector.get_transforms(is_train=False)
    val_dataset = CustomDataset(data_path, val_df, transform=val_transform)

    # Create samplers for the train and validation sets
    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)
    # train_sampler = SubsetRandomSampler(train_indices)
    # val_sampler = SubsetRandomSampler(val_indices)

    # Use DataCollator to create batches for both datasets
    train_loader = DataLoader(full_dataset,
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              sampler=train_sampler,
                              # sampler=train_sampler,
                              shuffle=True,
                              num_workers=num_workers,
                              pin_memory=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=batch_size,
                            sampler=val_sampler,
                            # sampler=val_sampler,
                            shuffle=False,
                            num_workers=num_workers,
                            pin_memory=True)
    return train_loader, val_loader


def get_test_loader(data_path='/home/data/', batch_size=32, num_workers=1):
def get_test_loader(data_path='/home/data/', transform_type='torchvision', batch_size=32, num_workers=1):
    """
    Returns a test data loader.

@@ -59,8 +71,11 @@ def get_test_loader(data_path='/home/data/', batch_size=32, num_workers=1):
    Returns:
        DataLoader: Test data loader.
    """
    transform = get_transforms(mode='test')
    test_dataset = CustomDataset(data_path, mode='test', transform=transform)
    transform_selector = TransformSelector(transform_type=transform_type)
    test_df = pd.read_csv(os.path.join(data_path, 'test.csv'))
    data_path = os.path.join(data_path, 'test')
    transform = transform_selector.get_transforms(is_train=False)
    test_dataset = CustomDataset(data_path, test_df, transform=transform, is_inference=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=batch_size,
                             num_workers=num_workers,
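A quick call-site sketch for the updated loader signatures. The paths and batch sizes are illustrative, and train.csv is assumed to contain image_path and target columns, as the stratified split above requires.

```python
from dataset.dataloader import get_dataloaders, get_test_loader

train_loader, val_loader = get_dataloaders(
    data_path="/data/ephemeral/home/data/",   # illustrative path
    transform_type="albumentations",          # or "torchvision"
    batch_size=32,
    num_workers=3,
)

test_loader = get_test_loader(
    data_path="/data/ephemeral/home/data/",
    transform_type="albumentations",
    batch_size=64,
    num_workers=6,
)
```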
dataset/dataset.py (26 changes: 12 additions & 14 deletions)
@@ -1,35 +1,33 @@
import os

import cv2
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    def __init__(self, data_path, mode='train', transform=None):
    def __init__(self, data_path, info_df, transform=None, is_inference: bool = False):
        """
        Initializes the CustomDataset class.

        Args:
        - data_path (str): The path to the dataset.
        - mode (str): The mode of the dataset, either 'train' or 'test' (default: 'train').
        - data_path (str): Path to the dataset.
        - info_df (pd.DataFrame): The index information dataframe.
        - transform (transforms.Compose): The transforms to be applied to the images (default: None).
        """
        self.data_path = os.path.join(data_path, f'{mode}')
        self.mode = mode
        self.data = pd.read_csv(os.path.join(data_path, f'{mode}.csv'))
        self.image_paths = self.data['image_path'].tolist()
        if mode == 'train':
            self.labels = self.data['target'].tolist()  # Read image paths from test.csv

        self.data_path = data_path
        self.info_df = info_df
        self.is_inference = is_inference
        self.transform = transform
        self.image_paths = self.info_df['image_path'].tolist()

        if not self.is_inference:
            self.labels = self.info_df['target'].tolist()  # Read the target labels from the dataframe


    def _load_image(self, image_path):
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)  # Read the image as a NumPy array in BGR color format
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert the BGR format to RGB
        # image = self.transform(image)  # Apply the configured image transform
        image = Image.fromarray(image)
        return image

    def _apply_transform(self, image):
@@ -52,7 +50,7 @@ def __getitem__(self, index):
        image = self._load_image(image_path)
        image = self._apply_transform(image)

        if self.mode != 'train':
        if self.is_inference:
            return image
        else:
            label = self.labels[index]
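The constructor contract changes from a mode string to an explicit dataframe plus an is_inference flag. A minimal construction sketch, with illustrative paths and the same column assumptions as the dataloader:

```python
import os

import pandas as pd

from dataset.dataset import CustomDataset
from dataset.transforms import TransformSelector

data_root = "/data/ephemeral/home/data/"   # illustrative path
selector = TransformSelector("albumentations")

# Training dataset: __getitem__ returns (image, label)
train_df = pd.read_csv(os.path.join(data_root, "train.csv"))
train_ds = CustomDataset(os.path.join(data_root, "train"), train_df,
                         transform=selector.get_transforms(is_train=True))

# Inference dataset: is_inference=True skips labels, so __getitem__ returns only the image
test_df = pd.read_csv(os.path.join(data_root, "test.csv"))
test_ds = CustomDataset(os.path.join(data_root, "test"), test_df,
                        transform=selector.get_transforms(is_train=False),
                        is_inference=True)
```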
dataset/transforms.py (93 changes: 89 additions & 4 deletions)
@@ -1,9 +1,94 @@
import numpy as np
import torch
from torchvision import transforms

def get_transforms(mode='basic'):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform

from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2


class TransformSelector:
    """
    Class for selecting the image transform library.
    """
    def __init__(self, transform_type: str):

        # Check that the requested transform library is supported
        if transform_type in ["torchvision", "albumentations"]:
            self.transform_type = transform_type

        else:
            raise ValueError("Unknown transformation library specified.")

    def get_transforms(self, is_train: bool):

        # Create the appropriate transform object for the selected library
        if self.transform_type == 'torchvision':
            transform = TorchvisionTransform(is_train=is_train)

        elif self.transform_type == 'albumentations':
            transform = AlbumentationsTransform(is_train=is_train)

        return transform

class TorchvisionTransform:
    def __init__(self, is_train: bool = True):
        # Common transforms: resize, tensor conversion, normalization
        common_transforms = [
            transforms.Resize((224, 224)),  # Resize the image to 224x224
            transforms.ToTensor(),  # Convert the image to a PyTorch tensor
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize
        ]

        if is_train:
            # Training transforms: add random horizontal flip, random rotation, color jitter
            self.transform = transforms.Compose(
                [
                    transforms.RandomHorizontalFlip(p=0.5),  # Flip the image horizontally with 50% probability
                    transforms.RandomRotation(15),  # Rotate by up to 15 degrees
                    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # Adjust brightness and contrast
                ] + common_transforms
            )
        else:
            # Validation/test transforms: apply only the common transforms
            self.transform = transforms.Compose(common_transforms)

    def __call__(self, image: np.ndarray) -> torch.Tensor:
        image = Image.fromarray(image)  # Convert the NumPy array to a PIL image

        transformed = self.transform(image)  # Apply the configured transforms

        return transformed  # Return the transformed image


class AlbumentationsTransform:
    def __init__(self, is_train: bool = True):
        # Common transforms: resize, normalization, tensor conversion
        common_transforms = [
            A.Resize(224, 224),  # Resize the image to 224x224
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize
            ToTensorV2()  # PyTorch tensor conversion provided by albumentations
        ]

        if is_train:
            # Training transforms: add random horizontal flip, random rotation, random brightness/contrast adjustment
            self.transform = A.Compose(
                [
                    A.HorizontalFlip(p=0.5),  # Flip the image horizontally with 50% probability
                    A.Rotate(limit=15),  # Rotate by up to 15 degrees
                    A.RandomBrightnessContrast(p=0.2),  # Randomly adjust brightness and contrast
                ] + common_transforms
            )
        else:
            # Validation/test transforms: apply only the common transforms
            self.transform = A.Compose(common_transforms)

    def __call__(self, image) -> torch.Tensor:
        # Check that the image is a NumPy array
        if not isinstance(image, np.ndarray):
            raise TypeError("Image should be a NumPy array (OpenCV format).")

        # Apply the transform to the image and return the result
        transformed = self.transform(image=image)  # Apply the configured transform

        return transformed['image']  # Return the transformed image tensor
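A small sanity-check sketch of the two transform backends. It assumes an H x W x 3 uint8 RGB array, i.e. the format produced by cv2.imread plus cvtColor in dataset.py:

```python
import numpy as np

from dataset.transforms import TransformSelector

rgb = np.random.randint(0, 256, size=(256, 256, 3), dtype=np.uint8)  # stand-in image

alb = TransformSelector("albumentations").get_transforms(is_train=True)
tv = TransformSelector("torchvision").get_transforms(is_train=False)

print(alb(rgb).shape)  # torch.Size([3, 224, 224]): Resize + Normalize + ToTensorV2
print(tv(rgb).shape)   # torch.Size([3, 224, 224]): Resize + ToTensor + Normalize
```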
engine/test_runner.py (7 changes: 6 additions & 1 deletion)
@@ -6,7 +6,12 @@

def run_test(config, ckpt_dir):
    # Call the test loader
    test_loader = get_test_loader(data_path=config.dataset.data_path, batch_size=64, num_workers=6)
    test_loader = get_test_loader(
        data_path=config.dataset.data_path,
        transform_type=config.dataset.transform_type,
        batch_size=64,
        num_workers=6
    )

    # Define the trainer for testing
    pred_callback = PredictionCallback(f"{config.dataset.data_path}/test.csv", ckpt_dir, config.model.model_name)
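A hypothetical call site, wiring run_test to the config package; the checkpoint directory is illustrative and Config() is used with its defaults:

```python
from config.config import Config          # assumed import path
from engine.test_runner import run_test

config = Config()                          # default values from the config classes above
run_test(config, ckpt_dir="/data/ephemeral/home/logs/")   # illustrative checkpoint directory
```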