Commit
* ♻️ Refactor overall redundant code
* 🎨 Add config factory functionality
* ✏️ Remove minor comments
* 🐛 Fix val_loss error and add error-raising code when GPU is not detected
* 🎨 Fix param_space to have a more elegant representation
* 🎨 Improve the PBT config
* ♻️ Rename pytorch_lightning package to lightning
* 📝 Update the precise guide
* 🐛 Fix bug with module import
* 🙈 Add .gitignore file
* ✏️ Fix argparse argument to have dash
* ♻️ Fix the order of imports
* 🎨 Fix the old function of LightningModule
* ♻️ Refactor the code and make loading checkpoint clearer
* ♻️ Change the order of arguments for clarity
* 🐛 Add a default model case
* ✏️ Fix minor typo
* ✏️ Fix typo
* 🙈 Update to ignore lightning_logs
* 🐛 Make dataset read both train csv and test csv
* 🚑 Make a quick fix for saving prediction csv file
* 🎨 Add PredictionCallback
* ✏️ Fix typo
* ✏️ Fix typo
* 🐛 Try to fix callback error
* 🐛 Fix "Expected a parent" bug, see the solution in Lightning-AI/pytorch-lightning#17485 (comment)
* 🐛 Move prediction feature from trainer to callback
* 🙈 Update git not to read output csv file
* ⬆️ Rewrite the requirements.yml
1 parent 78314e6, commit 6eb5127
Showing 16 changed files with 595 additions and 302 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Generated and compiled files | ||
*.pyc | ||
*.o | ||
*.exe | ||
|
||
# Build and packaged files | ||
__pycache__ | ||
dist | ||
build | ||
egg-info | ||
|
||
# IDE and Editor files | ||
.vscode | ||
.idea | ||
.pycharm.debug | ||
.cache | ||
|
||
# Virtual environment files | ||
venv | ||
.env | ||
|
||
# Source code management files | ||
.svn | ||
.git | ||
|
||
# Graphical User Interface files | ||
.DS_Store | ||
Thumbs.db | ||
|
||
# Backup and temporary files | ||
*.bak | ||
*.tmp | ||
*.swp | ||
|
||
# Third-party libraries | ||
node_modules | ||
vendor | ||
|
||
# Log files | ||
*.log | ||
lightning_logs | ||
*.csv | ||
|
||
# Operating system files | ||
.AppleDouble | ||
.LSOverride | ||
Desktop.ini |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
from .config import Config | ||
from .config import Config | ||
from .custom_nn_config import CustomNNConfig | ||
from .config_factory import get_config |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,29 @@ | ||
# config.py | ||
import ray | ||
from ray import tune | ||
from utils.logger import WandbLogger | ||
# from utils.logger import WandbLogger | ||
|
||
class Config: | ||
    def __init__(self): | ||
        self.model_name = "ResNet18" # Baseline model | ||
        self.save_dir = "/home/logs/" | ||
        self.data_path = "/home/data/" | ||
        self.batch_size = tune.choice([32, 64, 128]) | ||
        self.max_epochs = 1 | ||
        self.lr = tune.uniform(0.001, 0.1) | ||
        self.weight_decay = tune.uniform(0.001, 0.1) | ||
        self.n_estimators = tune.randint(10, 100) | ||
        self.num_gpus = 1 | ||
        self.max_epochs = 100 | ||
        self.num_samples = 4 # number of workers in population-based training | ||
        self.num_workers = 2 # number of cpus workers in dataloader | ||
        self.checkpoint_interval = 5 # number of epoch | ||
        self.lr = tune.loguniform(0.001, 0.1) | ||
        self.weight_decay = tune.loguniform(0.001, 0.1) | ||
|
||
        self.search_space = { | ||
            'batch_size': self.batch_size, | ||
            'lr': self.lr, | ||
            'weight_decay': self.weight_decay, | ||
            'n_estimators': self.n_estimators, | ||
        } | ||
    def get_logger(self): | ||
        return WandbLogger(project_name=self.model_name, config={ | ||
            "save_dir": self.save_dir, | ||
            "batch_size": self.batch_size, | ||
            "max_epochs": self.max_epochs, | ||
            # add other config parameters as needed | ||
        }) | ||
    # def get_logger(self): | ||
    #     return WandbLogger(project_name=self.model_name, config={ | ||
    #         "save_dir": self.save_dir, | ||
    #         "batch_size": self.batch_size, | ||
    #         "max_epochs": self.max_epochs, | ||
    #         # add other config parameters as needed | ||
    #     }) |
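
The config diff above contains both `tune.uniform(0.001, 0.1)` and `tune.loguniform(0.001, 0.1)` for `lr` and `weight_decay`. As a rough illustration of the difference, the stdlib-only sketch below draws values uniformly in log space. It does not use Ray, and `sample_loguniform` is a hypothetical helper written for this example, not part of the commit.

```python
import math
import random


def sample_loguniform(low, high, rng):
    """Draw one value whose logarithm is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)  # fixed seed so the sketch is repeatable
samples = [sample_loguniform(0.001, 0.1, rng) for _ in range(10_000)]

# Share of draws in the lower decade [0.001, 0.01); expect roughly one half.
lower_decade = sum(s < 0.01 for s in samples) / len(samples)
print(f"share of samples below 0.01: {lower_decade:.2f}")
```

Under plain uniform sampling on the same range, only about 9% of draws (0.009 / 0.099) would fall below 0.01, so small learning rates would rarely be tried.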
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from .config import Config | ||
from .custom_nn_config import CustomNNConfig | ||
|
||
CONFIG_MAP = { | ||
'ResNet18': Config, | ||
'CustomNN': CustomNNConfig, | ||
} | ||
|
||
def get_config(model_name): | ||
    """ | ||
    Retrieves a config class based on the provided model name. | ||
    Args: | ||
        model_name (str): The name of the model. | ||
    Returns: | ||
        Config or subclass: The config class associated with the model name. | ||
    Raises: | ||
        ValueError: If the model name is not found in the config map. | ||
    """ | ||
    if model_name is None: | ||
        model_name = "ResNet18" | ||
    try: | ||
        return CONFIG_MAP[model_name] | ||
    except KeyError: | ||
        raise ValueError(f"Unsupported model name: {model_name}") |
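
A minimal, self-contained sketch of how the factory above might be used. The two config classes here are stand-ins (their bodies are assumptions, since `custom_nn_config.py` is not shown in this diff); only `CONFIG_MAP` and the lookup logic follow the code above.

```python
class Config:
    """Stand-in for the project's base config (assumed to hold ResNet18 defaults)."""
    model_name = "ResNet18"


class CustomNNConfig(Config):
    """Stand-in for the CustomNN config; subclassing Config is an assumption."""
    model_name = "CustomNN"


CONFIG_MAP = {"ResNet18": Config, "CustomNN": CustomNNConfig}


def get_config(model_name=None):
    """Return the config class for model_name, defaulting to ResNet18."""
    if model_name is None:
        model_name = "ResNet18"
    try:
        return CONFIG_MAP[model_name]
    except KeyError:
        raise ValueError(f"Unsupported model name: {model_name}") from None


# The factory returns a class, not an instance, so the caller instantiates it.
config = get_config("CustomNN")()
print(type(config).__name__)

# Names missing from CONFIG_MAP surface as ValueError rather than KeyError.
try:
    get_config("resnet50")
except ValueError as err:
    print(err)
```

Adding a model under this pattern means adding one entry to `CONFIG_MAP`; the lookup code stays unchanged.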
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .callbacks import PredictionCallback |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from lightning.pytorch.callbacks import Callback | ||
import numpy as np | ||
import pandas as pd | ||
from datetime import date | ||
|
||
class PredictionCallback(Callback): | ||
    def __init__(self, data_path, model_name): | ||
        self.predictions = [] | ||
        self.data_path = data_path | ||
        self.model_name = model_name | ||
|
||
    def on_test_batch_end(self, trainer, pl_module, outputs, *args, **kwargs): | ||
        self.predictions.extend(outputs.cpu().numpy()) | ||
|
||
    def on_test_end(self, *args, **kwargs): | ||
        predictions = np.array(self.predictions) | ||
        test_info = pd.read_csv(self.data_path) | ||
        test_info['target'] = predictions | ||
        test_info = test_info.reset_index().rename(columns={"index": "ID"}) | ||
        file_name = f"{self.model_name}_{date.today()}.csv" | ||
        test_info.to_csv(file_name, index=False, lineterminator='\n') | ||
        print("Output csv file successfully saved!!") |
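
The callback above accumulates per-batch outputs and writes them once testing ends. The sketch below mirrors that accumulate-then-write flow with the standard library only, so it runs without Lightning, NumPy, or pandas. `PredictionCollector` is a hypothetical stand-in, and it is simplified: it writes only `ID` and `target`, whereas the real callback also keeps the columns read from the test CSV.

```python
import csv
import io
from datetime import date


class PredictionCollector:
    """Hypothetical stand-in that mirrors the two hooks without importing Lightning."""

    def __init__(self, model_name):
        self.model_name = model_name
        self.predictions = []

    def on_test_batch_end(self, outputs):
        # Each call receives one batch of predictions; extend keeps the list flat.
        self.predictions.extend(outputs)

    def on_test_end(self, stream):
        # One row per prediction; ID is the row position, as reset_index() would give.
        writer = csv.writer(stream, lineterminator="\n")
        writer.writerow(["ID", "target"])
        for idx, target in enumerate(self.predictions):
            writer.writerow([idx, target])
        # Same date-stamped naming scheme as the callback above.
        return f"{self.model_name}_{date.today()}.csv"


collector = PredictionCollector("ResNet18")
for batch in ([0.1, 0.9], [0.4]):  # two simulated test batches
    collector.on_test_batch_end(batch)

buffer = io.StringIO()  # in-memory stream instead of a file on disk
file_name = collector.on_test_end(buffer)
print(file_name)
print(buffer.getvalue())
```

Because the list is only ever extended, reusing one collector for a second test run would append to the earlier predictions; the same holds for `self.predictions` in the callback above.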
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
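# Getting Started | ||
=============== | ||
|
||
This project requires certain dependencies to be installed in your Python environment. Follow the steps below to set up the environment and run `train.py`. | ||
|
||
## Step 1: Install the required dependencies | ||
First, install the required dependencies with pip. | ||
|
||
```pip install -U lightning "ray[data,train,tune,serve]" wandb``` | ||
This command installs the required dependencies, including PyTorch Lightning, Ray, and Weights & Biases. | ||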
|
||
## Step 2: Set up Weights & Biases | ||
Set up a Weights & Biases account and install the SDK. For details on setting up Weights & Biases, see [here](https://docs.wandb.ai/ko/quickstart). | ||
|
||
## Step 3: Run train.py | ||
Once the dependencies are installed, you can run `train.py` with Python. | ||
|
||
### Default settings | ||
If you pass no arguments, the script uses the default settings. | ||
|
||
```python train.py``` | ||
This command trains the model with the default settings. | ||
|
||
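### Custom settings | ||
Alternatively, you can override the default settings by passing arguments: | ||
|
||
```python train.py --model_name <model_name> --num_gpus <num_gpus> --smoke_test``` | ||
- --model_name: The name of the model to train. Check the model definitions for the list of available models. The default model is ResNet18. | ||
- --num_gpus: The number of GPUs to use for training. Use this when training in a multi-GPU environment. Default: 1 | ||
- --smoke_test: (Optional) Add this flag to run a quick smoke test that checks whether the training script works correctly. The smoke test runs the training script with a small batch size and a limited number of epochs. | ||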
|
||
Example: | ||
```python train.py --model_name resnet50 --num_gpus 2``` | ||
This command trains the ResNet-50 model on 2 GPUs. | ||
|
||
Run this command from the directory that contains `train.py`. | ||
|
||
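## Troubleshooting | ||
If you run into problems during installation or while running `train.py`, check the following. | ||
|
||
- Make sure the latest version of conda is installed. | ||
- Make sure the requirements.yaml file is in the correct location. | ||
- Make sure the conda environment is activated correctly. | ||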
|
||
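## Contributing | ||
If you would like to contribute to this project, fork the repository and submit a pull request. | ||
|
||
## Contact | ||
If you have questions or need help with the project, reach out via [enter contact information here]. |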
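
The README above describes three command-line arguments. Since `train.py` itself is not shown in this diff, the following is only a sketch of how such flags could be declared with `argparse`; the flag names and defaults are taken from the README, while `build_parser` and the help strings are assumptions. The commit message mentions changing the argparse arguments to use dashes, so the actual spelling in `train.py` may differ.

```python
import argparse


def build_parser():
    """Build a parser for the three flags the README describes."""
    parser = argparse.ArgumentParser(description="Train a model (illustrative sketch).")
    parser.add_argument("--model_name", type=str, default="ResNet18",
                        help="Name of the model to train.")
    parser.add_argument("--num_gpus", type=int, default=1,
                        help="Number of GPUs to use for training.")
    # store_true makes the flag optional: absent means False, present means True.
    parser.add_argument("--smoke_test", action="store_true",
                        help="Run a quick check with a small batch size and few epochs.")
    return parser


# No arguments: every option falls back to its default.
defaults = build_parser().parse_args([])
# Explicit arguments override the defaults.
custom = build_parser().parse_args(
    ["--model_name", "CustomNN", "--num_gpus", "2", "--smoke_test"]
)
print(defaults)
print(custom)
```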