-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feat 29/lr scheduler train transform (#30)
* ♻️ Update on nested config in each code for clarity * 🎨 Update on train transforms * 🔧 Update config structure on AHSA * ✏️ Fix typos * 🐛 Fix all the bugs related with lightning module hparams * ✨ Add LR scheduler * ✨ Add Albumentation transform * 🎨 Make it available to config everything in the config.py file * ⚡️ Make server to run concurrent trials * 🚑 Forget to add weight decay in the optimizer * ✨ Add a feature to run test from checkpoint_path * 🎨 Fix config to update properly from args * ⚡️ Change from Adam to AdamW * ✨ Add a feature to use pretrained model or not
- Loading branch information
1 parent
452fe78
commit 345c23d
Showing
10 changed files
with
209 additions
and
92 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from .config_factory import get_config | ||
from .config_factory import get_config | ||
from .config import ModelConfig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,94 @@ | ||
import numpy as np | ||
import torch | ||
from torchvision import transforms | ||
|
||
def get_transforms(mode='basic'): | ||
transform = transforms.Compose([ | ||
from PIL import Image | ||
import albumentations as A | ||
from albumentations.pytorch import ToTensorV2 | ||
|
||
|
||
class TransformSelector: | ||
""" | ||
이미지 변환 라이브러리를 선택하기 위한 클래스. | ||
""" | ||
def __init__(self, transform_type: str): | ||
|
||
# 지원하는 변환 라이브러리인지 확인 | ||
if transform_type in ["torchvision", "albumentations"]: | ||
self.transform_type = transform_type | ||
|
||
else: | ||
raise ValueError("Unknown transformation library specified.") | ||
|
||
def get_transforms(self, is_train: bool): | ||
|
||
# 선택된 라이브러리에 따라 적절한 변환 객체를 생성 | ||
if self.transform_type == 'torchvision': | ||
transform = TorchvisionTransform(is_train=is_train) | ||
|
||
elif self.transform_type == 'albumentations': | ||
transform = AlbumentationsTransform(is_train=is_train) | ||
|
||
return transform | ||
|
||
class TorchvisionTransform: | ||
def __init__(self, is_train: bool = True): | ||
# 공통 변환 설정: 이미지 리사이즈, 텐서 변환, 정규화 | ||
common_transforms = [ | ||
transforms.Resize((224, 224)), # 이미지를 224x224 크기로 리사이즈 | ||
transforms.ToTensor(), # 이미지를 PyTorch 텐서로 변환 | ||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 정규화 | ||
]) | ||
return transform | ||
] | ||
|
||
if is_train: | ||
# 훈련용 변환: 랜덤 수평 뒤집기, 랜덤 회전, 색상 조정 추가 | ||
self.transform = transforms.Compose( | ||
[ | ||
transforms.RandomHorizontalFlip(p=0.5), # 50% 확률로 이미지를 수평 뒤집기 | ||
transforms.RandomRotation(15), # 최대 15도 회전 | ||
transforms.ColorJitter(brightness=0.2, contrast=0.2), # 밝기 및 대비 조정 | ||
] + common_transforms | ||
) | ||
else: | ||
# 검증/테스트용 변환: 공통 변환만 적용 | ||
self.transform = transforms.Compose(common_transforms) | ||
|
||
def __call__(self, image: np.ndarray) -> torch.Tensor: | ||
image = Image.fromarray(image) # numpy 배열을 PIL 이미지로 변환 | ||
|
||
transformed = self.transform(image) # 설정된 변환을 적용 | ||
|
||
return transformed # 변환된 이미지 반환 | ||
|
||
|
||
class AlbumentationsTransform: | ||
def __init__(self, is_train: bool = True): | ||
# 공통 변환 설정: 이미지 리사이즈, 정규화, 텐서 변환 | ||
common_transforms = [ | ||
A.Resize(224, 224), # 이미지를 224x224 크기로 리사이즈 | ||
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # 정규화 | ||
ToTensorV2() # albumentations에서 제공하는 PyTorch 텐서 변환 | ||
] | ||
|
||
if is_train: | ||
# 훈련용 변환: 랜덤 수평 뒤집기, 랜덤 회전, 랜덤 밝기 및 대비 조정 추가 | ||
self.transform = A.Compose( | ||
[ | ||
A.HorizontalFlip(p=0.5), # 50% 확률로 이미지를 수평 뒤집기 | ||
A.Rotate(limit=15), # 최대 15도 회전 | ||
A.RandomBrightnessContrast(p=0.2), # 밝기 및 대비 무작위 조정 | ||
] + common_transforms | ||
) | ||
else: | ||
# 검증/테스트용 변환: 공통 변환만 적용 | ||
self.transform = A.Compose(common_transforms) | ||
|
||
def __call__(self, image) -> torch.Tensor: | ||
# 이미지가 NumPy 배열인지 확인 | ||
if not isinstance(image, np.ndarray): | ||
raise TypeError("Image should be a NumPy array (OpenCV format).") | ||
|
||
# 이미지에 변환 적용 및 결과 반환 | ||
transformed = self.transform(image=image) # 이미지에 설정된 변환을 적용 | ||
|
||
return transformed['image'] # 변환된 이미지의 텐서를 반환 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.