Skip to content

Commit

Permalink
Merge pull request #24 from boostcampaitech7/feat-23/hotfix-data-path
Browse files Browse the repository at this point in the history
Feat 23/hotfix data path
  • Loading branch information
Haneol-Kijm authored Sep 17, 2024
2 parents acce91d + cb79b81 commit 0c60863
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
4 changes: 2 additions & 2 deletions config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ def __init__(self):
class DatasetConfig:
"""Dataset-related configuration."""
def __init__(self):
self.data_path = "/home/data/"
self.data_path = "/data/ephemeral/home/data/"
# self.transform_mode = 'albumentation'


class ExperimentConfig:
"""Experiment-related configuration."""
def __init__(self):
self.save_dir = "/home/logs/"
self.save_dir = "/data/ephemeral/home/logs/"
self.num_gpus = 1
self.max_epochs = 100
self.num_workers = 2 # number of cpus workers in dataloader
Expand Down
2 changes: 1 addition & 1 deletion engine/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def run_test(config, ckpt_dir):

# Define the trainer for testing
pred_callback = PredictionCallback(f"{config.dataset.data_path}/test.csv", ckpt_dir, config.model.model_name)
trainer_test = Trainer(callbacks=[pred_callback], logger=False)
trainer_test = Trainer(callbacks=[pred_callback], logger=False, enable_progress_bar=False,)
best_model = LightningModule.load_from_checkpoint(f"{ckpt_dir}/pltrainer.ckpt")
# Conduct testing with the loaded model
trainer_test.test(best_model, dataloaders=test_loader)
4 changes: 3 additions & 1 deletion engine/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ def train_func(config_dict): # Note that config_dict is dict here passed by pbt
accelerator='gpu',
devices=config_dict['experiment']['num_gpus'],
strategy='ddp',
logger=False,
callbacks=[TuneReportCheckpointCallback(
metrics={"val_loss": "val_loss", "val_acc": "val_acc"},
filename="pltrainer.ckpt", on="validation_end",
)],
enable_progress_bar=False,
)

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
Expand Down Expand Up @@ -72,7 +74,7 @@ def _define_run_config(self):
num_to_keep=4,
checkpoint_score_attribute="val_loss",
),
storage_path="/tmp/ray_results",
storage_path=f"{self.config.experiment.save_dir}/ray_results",
callbacks=[WandbLoggerCallback(project=self.config.model.model_name)],
verbose=1,
)
Expand Down

0 comments on commit 0c60863

Please sign in to comment.