Update pl_examples
carmocca committed Aug 27, 2021
1 parent d07cee8 commit 6ab08a1
Showing 7 changed files with 28 additions and 15 deletions.
9 changes: 6 additions & 3 deletions pl_examples/basic_examples/autoencoder.py
@@ -109,9 +109,12 @@ def predict_dataloader(self):


 def cli_main():
-    cli = LightningCLI(LitAutoEncoder, MyDataModule, seed_everything_default=1234, save_config_overwrite=True)
-    cli.trainer.test(cli.model, datamodule=cli.datamodule)
-    predictions = cli.trainer.predict(cli.model, datamodule=cli.datamodule)
+    cli = LightningCLI(
+        LitAutoEncoder, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False
+    )
+    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
+    cli.trainer.test(ckpt_path="best")
+    predictions = cli.trainer.predict(ckpt_path="best")
     print(predictions[0])
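For context, the hunk above switches the example from the CLI's default subcommand behavior to instantiating it with run=False and driving fit/test/predict manually. A minimal standalone sketch of the resulting cli_main, assuming the LitAutoEncoder and MyDataModule classes defined earlier in this file and an import path typical of pytorch_lightning at the time of this commit:

    # Sketch of the run=False pattern introduced by this commit; the import path
    # is an assumption based on the pytorch_lightning version of this era.
    from pytorch_lightning.utilities.cli import LightningCLI

    def cli_main():
        cli = LightningCLI(
            LitAutoEncoder,                 # LightningModule defined in this example
            MyDataModule,                   # LightningDataModule defined in this example
            seed_everything_default=1234,
            save_config_overwrite=True,
            run=False,                      # parse the config but do not run a subcommand
        )
        cli.trainer.fit(cli.model, datamodule=cli.datamodule)
        # ckpt_path="best" reloads the best checkpoint written during fit()
        cli.trainer.test(ckpt_path="best")
        predictions = cli.trainer.predict(ckpt_path="best")
        print(predictions[0])

The same pattern is applied to the other basic examples below.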


7 changes: 4 additions & 3 deletions pl_examples/basic_examples/backbone_image_classifier.py
@@ -124,9 +124,10 @@ def predict_dataloader(self):


 def cli_main():
-    cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True)
-    cli.trainer.test(cli.model, datamodule=cli.datamodule)
-    predictions = cli.trainer.predict(cli.model, datamodule=cli.datamodule)
+    cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False)
+    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
+    cli.trainer.test(ckpt_path="best")
+    predictions = cli.trainer.predict(ckpt_path="best")
     print(predictions[0])


5 changes: 3 additions & 2 deletions pl_examples/basic_examples/dali_image_classifier.py
@@ -198,8 +198,9 @@ def cli_main():
     if not _DALI_AVAILABLE:
         return

-    cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True)
-    cli.trainer.test(cli.model, datamodule=cli.datamodule)
+    cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False)
+    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
+    cli.trainer.test(ckpt_path="best")


 if __name__ == "__main__":
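The DALI example only runs when the optional nvidia.dali dependency is present; the _DALI_AVAILABLE guard above short-circuits cli_main otherwise. A hedged sketch of how such a flag can be computed with the standard library (the actual helper used by pl_examples may differ):

    # Illustrative only: one way to define an availability flag such as
    # _DALI_AVAILABLE; pl_examples may use its own utility for this check.
    import importlib.util

    def _package_available(name: str) -> bool:
        try:
            return importlib.util.find_spec(name) is not None
        except ModuleNotFoundError:  # parent package is missing entirely
            return False

    _DALI_AVAILABLE = _package_available("nvidia.dali")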
7 changes: 5 additions & 2 deletions pl_examples/basic_examples/simple_image_classifier.py
@@ -72,8 +72,11 @@ def configure_optimizers(self):


 def cli_main():
-    cli = LightningCLI(LitClassifier, MNISTDataModule, seed_everything_default=1234, save_config_overwrite=True)
-    cli.trainer.test(cli.model, datamodule=cli.datamodule)
+    cli = LightningCLI(
+        LitClassifier, MNISTDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False
+    )
+    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
+    cli.trainer.test(ckpt_path="best")


 if __name__ == "__main__":
6 changes: 3 additions & 3 deletions pl_examples/domain_templates/computer_vision_fine_tuning.py
@@ -277,10 +277,10 @@ def add_arguments_to_parser(self, parser):
             }
         )

-    def instantiate_trainer(self):
-        finetuning_callback = MilestonesFinetuning(**self.config_init["finetuning"])
+    def instantiate_trainer(self, *args):
+        finetuning_callback = MilestonesFinetuning(**self._get(self.config_init, "finetuning"))
         self.trainer_defaults["callbacks"] = [finetuning_callback]
-        super().instantiate_trainer()
+        return super().instantiate_trainer(*args)


 def cli_main():
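The fine-tuning template subclasses LightningCLI, and this commit updates its instantiate_trainer override to the new base signature: it now accepts positional arguments and returns the trainer it creates. A hedged sketch of the surrounding class, assuming the MilestonesFinetuning callback and the "finetuning" config group registered in add_arguments_to_parser above:

    # Sketch of the override pattern shown in the hunk above; the class name is
    # illustrative, and self._get is assumed to resolve a group from config_init.
    class MyLightningCLI(LightningCLI):
        def instantiate_trainer(self, *args):
            # Build the fine-tuning callback from the instantiated config and
            # inject it as a trainer default before the Trainer is created.
            finetuning_callback = MilestonesFinetuning(**self._get(self.config_init, "finetuning"))
            self.trainer_defaults["callbacks"] = [finetuning_callback]
            # Delegate to the base implementation and return its Trainer.
            return super().instantiate_trainer(*args)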
7 changes: 6 additions & 1 deletion pl_examples/run_examples.sh
@@ -2,7 +2,12 @@
 set -ex

 dir_path=$(dirname "${BASH_SOURCE[0]}")
-args="--trainer.max_epochs=1 --data.batch_size=32 --trainer.limit_train_batches=2 --trainer.limit_val_batches=2"
+args="--trainer.max_epochs=1 " \
+  "--data.batch_size=32 " \
+  "--trainer.limit_train_batches=2 " \
+  "--trainer.limit_val_batches=2 " \
+  "--trainer.limit_test_batches=2 "\
+  "--trainer.limit_predict_batches=2"

 python "${dir_path}/basic_examples/simple_image_classifier.py" ${args} "$@"
 python "${dir_path}/basic_examples/backbone_image_classifier.py" ${args} "$@"
2 changes: 1 addition & 1 deletion tests/special_tests.sh
@@ -87,7 +87,7 @@ fi
 # report+="Ran\ttests/plugins/environments/torch_elastic_deadlock.py\n"

 # test that a user can manually launch individual processes
-args="--trainer.gpus 2 --trainer.accelerator ddp --trainer.fast_dev_run 1"
+args="--trainer.gpus 2 --trainer.accelerator ddp --trainer.max_epochs=1 --trainer.limit_train_batches=1 --trainer.limit_val_batches=1 --limit_test_batches=1"
 MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=1 python pl_examples/basic_examples/simple_image_classifier.py ${args} &
 MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=0 python pl_examples/basic_examples/simple_image_classifier.py ${args}
 report+="Ran\tmanual ddp launch test\n"
