Black formatting
calebrob6 committed Feb 21, 2024
1 parent 907c846 commit a555656
Showing 1 changed file with 19 additions and 21 deletions.
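Note: the change here is formatting only; the commit title indicates it was produced with Black, which has been able to format .ipynb files directly since release 21.x when installed with the jupyter extra (pip install "black[jupyter]"). The snippet below is a minimal sketch, not part of this commit, showing Black's Python API reproducing the import rewrite from the first hunk; black.format_str and black.Mode are existing Black APIs, and the default 88-character line length is what triggers the wrapping.

```python
# Minimal sketch (not from this commit): reproduce the import reformatting
# from the first hunk with Black's Python API. Assumes `black` is installed.
import black

# The pre-commit, single-line import (93 characters, which is over Black's
# default 88-character line length).
src = (
    "from torchmetrics.classification import "
    "Accuracy, FBetaScore, Precision, Recall, JaccardIndex\n"
)

# Black wraps the over-long import into a parenthesized, one-name-per-line
# form with a trailing comma -- the same lines added in this commit.
print(black.format_str(src, mode=black.Mode()))
```

Running black docs/tutorials/custom_segmentation_trainer.ipynb (with the jupyter extra installed) applies the same rules to every code cell in place, which is presumably how the change in this commit was produced.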
40 changes: 19 additions & 21 deletions docs/tutorials/custom_segmentation_trainer.ipynb
@@ -58,7 +58,13 @@
 "from torchgeo.trainers import SemanticSegmentationTask\n",
 "from torchgeo.datamodules import LandCoverAIDataModule\n",
 "from torchmetrics import MetricCollection\n",
-"from torchmetrics.classification import Accuracy, FBetaScore, Precision, Recall, JaccardIndex\n",
+"from torchmetrics.classification import (\n",
+"    Accuracy,\n",
+"    FBetaScore,\n",
+"    Precision,\n",
+"    Recall,\n",
+"    JaccardIndex,\n",
+")\n",
 "\n",
 "import lightning.pytorch as pl\n",
 "from lightning.pytorch.callbacks import ModelCheckpoint\n",
@@ -70,6 +76,7 @@
 "# Get rid of the pesky warning raised by kornia\n",
 "# UserWarning: Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.\n",
 "import warnings\n",
+"\n",
 "warnings.filterwarnings(\"ignore\", category=UserWarning, module=\"torch.nn.functional\")"
 ]
 },
@@ -99,7 +106,9 @@
 "    def __init__(self, *args, tmax=50, eta_min=1e-6, **kwargs) -> None:\n",
 "        super().__init__()\n",
 "\n",
-"    def configure_optimizers(self) -> \"lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig\":\n",
+"    def configure_optimizers(\n",
+"        self,\n",
+"    ) -> \"lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig\":\n",
 "        \"\"\"Initialize the optimizer and learning rate scheduler.\n",
 "\n",
 "        Returns:\n",
@@ -121,19 +130,13 @@
 "        self.train_metrics = MetricCollection(\n",
 "            {\n",
 "                \"OverallAccuracy\": Accuracy(\n",
-"                    task=\"multiclass\",\n",
-"                    num_classes=num_classes,\n",
-"                    average=\"micro\",\n",
+"                    task=\"multiclass\", num_classes=num_classes, average=\"micro\"\n",
 "                ),\n",
 "                \"OverallPrecision\": Precision(\n",
-"                    task=\"multiclass\",\n",
-"                    num_classes=num_classes,\n",
-"                    average=\"micro\",\n",
+"                    task=\"multiclass\", num_classes=num_classes, average=\"micro\"\n",
 "                ),\n",
 "                \"OverallRecall\": Recall(\n",
-"                    task=\"multiclass\",\n",
-"                    num_classes=num_classes,\n",
-"                    average=\"micro\",\n",
+"                    task=\"multiclass\", num_classes=num_classes, average=\"micro\"\n",
 "                ),\n",
 "                \"OverallF1Score\": FBetaScore(\n",
 "                    task=\"multiclass\",\n",
@@ -142,10 +145,8 @@
 "                    average=\"micro\",\n",
 "                ),\n",
 "                \"MeanIoU\": JaccardIndex(\n",
-"                    num_classes=num_classes,\n",
-"                    task=\"multiclass\",\n",
-"                    average=\"macro\",\n",
-"                )\n",
+"                    num_classes=num_classes, task=\"multiclass\", average=\"macro\"\n",
+"                ),\n",
 "            },\n",
 "            prefix=\"train_\",\n",
 "        )\n",
@@ -165,7 +166,7 @@
 "\n",
 "    def on_train_epoch_start(self) -> None:\n",
 "        \"\"\"Log the learning rate at the start of each training epoch.\"\"\"\n",
-"        lr = self.optimizers().param_groups[0]['lr']\n",
+"        lr = self.optimizers().param_groups[0][\"lr\"]\n",
 "        self.logger.experiment.add_scalar(\"lr\", lr, self.current_epoch)"
 ]
 },
@@ -187,7 +188,7 @@
 "dm = LandCoverAIDataModule(\n",
 "    root=\"/home/calebrobinson/ssdshared/torchgeo-datasets/LandCoverAI\",\n",
 "    batch_size=64,\n",
-"    num_workers=8\n",
+"    num_workers=8,\n",
 ")"
 ]
 },
@@ -229,10 +230,7 @@
 "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\"\n",
 "\n",
 "trainer = pl.Trainer(\n",
-"    accelerator=accelerator,\n",
-"    min_epochs=150,\n",
-"    max_epochs=300,\n",
-"    log_every_n_steps=50,\n",
+"    accelerator=accelerator, min_epochs=150, max_epochs=300, log_every_n_steps=50\n",
 ")"
 ]
 },