diff --git a/src/anomalib/cli/cli.py b/src/anomalib/cli/cli.py index 9184367c40..0a31e3840d 100644 --- a/src/anomalib/cli/cli.py +++ b/src/anomalib/cli/cli.py @@ -26,6 +26,7 @@ _LIGHTNING_AVAILABLE = True try: from lightning.pytorch import Trainer + from lightning.pytorch.core.datamodule import LightningDataModule from torch.utils.data import DataLoader, Dataset from anomalib.data import AnomalibDataModule, AnomalibDataset @@ -296,6 +297,9 @@ def instantiate_classes(self) -> None: # the minor change here is that engine is instantiated instead of trainer self.config_init = self.parser.instantiate_classes(self.config) self.datamodule = self._get(self.config_init, "data") + if isinstance(self.datamodule, Dataset): + kwargs = {f"{self.config.subcommand}_dataset": self.datamodule} + self.datamodule = LightningDataModule.from_datasets(**kwargs) self.model = self._get(self.config_init, "model") self._configure_optimizers_method_to_model() self.instantiate_engine() diff --git a/src/anomalib/engine/engine.py b/src/anomalib/engine/engine.py index 73e4d664d9..358f34ed28 100644 --- a/src/anomalib/engine/engine.py +++ b/src/anomalib/engine/engine.py @@ -308,7 +308,7 @@ def _setup_anomalib_callbacks(self) -> None: def _should_run_validation( self, model: AnomalyModule, - dataloaders: EVAL_DATALOADERS | AnomalibDataModule | None, + dataloaders: EVAL_DATALOADERS | None, datamodule: AnomalibDataModule | None, ckpt_path: str | None, ) -> bool: @@ -326,7 +326,7 @@ def _should_run_validation( Args: model (AnomalyModule): Model passed to the entrypoint. - dataloaders (EVAL_DATALOADERS | AnomalibDataModule | None): Dataloaders passed to the entrypoint. + dataloaders (EVAL_DATALOADERS | None): Dataloaders passed to the entrypoint. datamodule (AnomalibDataModule | None): Lightning datamodule passed to the entrypoint. ckpt_path (str | None): Checkpoint path passed to the entrypoint. @@ -348,7 +348,7 @@ def _should_run_validation( def fit( self, model: AnomalyModule, - train_dataloaders: TRAIN_DATALOADERS | AnomalibDataModule | None = None, + train_dataloaders: TRAIN_DATALOADERS | None = None, val_dataloaders: EVAL_DATALOADERS | None = None, datamodule: AnomalibDataModule | None = None, ckpt_path: str | None = None, @@ -357,7 +357,7 @@ def fit( Args: model (AnomalyModule): Model to be trained. - train_dataloaders (TRAIN_DATALOADERS | AnomalibDataModule | None, optional): Train dataloaders. + train_dataloaders (TRAIN_DATALOADERS | None, optional): Train dataloaders. Defaults to None. val_dataloaders (EVAL_DATALOADERS | None, optional): Validation dataloaders. Defaults to None. @@ -392,7 +392,7 @@ def fit( def validate( self, model: AnomalyModule | None = None, - dataloaders: EVAL_DATALOADERS | AnomalibDataModule | None = None, + dataloaders: EVAL_DATALOADERS | None = None, ckpt_path: str | None = None, verbose: bool = True, datamodule: AnomalibDataModule | None = None, @@ -402,7 +402,7 @@ def validate( Args: model (AnomalyModule | None, optional): Model to be validated. Defaults to None. - dataloaders (EVAL_DATALOADERS | AnomalibDataModule | None, optional): Dataloaders to be used for + dataloaders (EVAL_DATALOADERS | None, optional): Dataloaders to be used for validation. Defaults to None. ckpt_path (str | None, optional): Checkpoint path. If provided, the model will be loaded from this path. @@ -439,7 +439,7 @@ def validate( def test( self, model: AnomalyModule | None = None, - dataloaders: EVAL_DATALOADERS | AnomalibDataModule | None = None, + dataloaders: EVAL_DATALOADERS | None = None, ckpt_path: str | None = None, verbose: bool = True, datamodule: AnomalibDataModule | None = None, @@ -453,7 +453,7 @@ def test( model (AnomalyModule | None, optional): The model to be tested. Defaults to None. - dataloaders (EVAL_DATALOADERS | AnomalibDataModule | None, optional): + dataloaders (EVAL_DATALOADERS | None, optional): An iterable or collection of iterables specifying test samples. Defaults to None. ckpt_path (str | None, optional): @@ -526,13 +526,12 @@ def test( self.trainer.validate(model, dataloaders, None, verbose=False, datamodule=datamodule) return self.trainer.test(model, dataloaders, ckpt_path, verbose, datamodule) - # TODO(ashwinvaidya17): revisit typing of data args - # https://github.com/openvinotoolkit/anomalib/issues/1638 def predict( self, model: AnomalyModule | None = None, - dataloaders: EVAL_DATALOADERS | AnomalibDataModule | None = None, - datamodule: AnomalibDataModule | Dataset | PredictDataset | None = None, + dataloaders: EVAL_DATALOADERS | None = None, + datamodule: AnomalibDataModule | None = None, + dataset: Dataset | PredictDataset | None = None, return_predictions: bool | None = None, ckpt_path: str | None = None, ) -> _PREDICT_OUTPUT | None: @@ -545,7 +544,7 @@ def predict( model (AnomalyModule | None, optional): Model to be used for prediction. Defaults to None. - dataloaders (EVAL_DATALOADERS | AnomalibDataModule | None, optional): + dataloaders (EVAL_DATALOADERS | None, optional): An iterable or collection of iterables specifying predict samples. Defaults to None. datamodule (AnomalibDataModule | None, optional): @@ -553,6 +552,9 @@ def predict( the :class:`~lightning.pytorch.core.hooks.DataHooks.predict_dataloader` hook. The datamodule can also be a dataset that will be wrapped in a torch Dataloader. Defaults to None. + dataset (Dataset | PredictDataset | None, optional): + A :class:`~torch.utils.data.Dataset` or :class:`~anomalib.data.PredictDataset` that will be used + to create a dataloader. Defaults to None. return_predictions (bool | None, optional): Whether to return predictions. ``True`` by default except when an accelerator that spawns processes is used (not supported). @@ -598,9 +600,8 @@ def predict( logger.warning("ckpt_path is not provided. Model weights will not be loaded.") # Handle the instance when a dataset is passed to the predict method - if datamodule is not None and isinstance(datamodule, Dataset): - dataloader = DataLoader(datamodule) - datamodule = None + if dataset is not None: + dataloader = DataLoader(dataset) if dataloaders is None: dataloaders = dataloader elif isinstance(dataloaders, DataLoader): @@ -628,7 +629,7 @@ def predict( def train( self, model: AnomalyModule, - train_dataloaders: TRAIN_DATALOADERS | AnomalibDataModule | None = None, + train_dataloaders: TRAIN_DATALOADERS | None = None, val_dataloaders: EVAL_DATALOADERS | None = None, test_dataloaders: EVAL_DATALOADERS | None = None, datamodule: AnomalibDataModule | None = None, @@ -638,7 +639,7 @@ def train( Args: model (AnomalyModule): Model to be trained. - train_dataloaders (TRAIN_DATALOADERS | AnomalibDataModule | None, optional): Train dataloaders. + train_dataloaders (TRAIN_DATALOADERS | None, optional): Train dataloaders. Defaults to None. val_dataloaders (EVAL_DATALOADERS | None, optional): Validation dataloaders. Defaults to None.