From ca13e7f19ba230d796a002f93d8b388f153ea3eb Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Thu, 21 Nov 2024 22:28:44 +0000 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Update=20Changes=20from=20?= =?UTF-8?q?New=20Engine=20Design=20(#882)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add changes from New engine design #578. This will not only simplify the PR but also keep the main repo up to date. - Refactor `model_to` to `model_abc` - Instead of `on_gpu` use `device` as an input in line with `PyTorch`. - `infer_batch` uses `device` as an input instead of `on_gpu` --- .gitignore | 3 + tests/models/test_abc.py | 20 +++++- tests/models/test_arch_mapde.py | 4 +- tests/models/test_arch_micronet.py | 2 +- tests/models/test_arch_nuclick.py | 3 +- tests/models/test_arch_sccnn.py | 17 +++-- tests/models/test_arch_unet.py | 5 +- tests/models/test_arch_vanilla.py | 11 ++-- tests/models/test_feature_extractor.py | 5 +- tests/models/test_hovernet.py | 9 +-- tests/models/test_hovernetplus.py | 3 +- tests/models/test_multi_task_segmentor.py | 23 +++---- .../models/test_nucleus_instance_segmentor.py | 13 ++-- tests/models/test_patch_predictor.py | 44 ++++++------- tests/models/test_semantic_segmentation.py | 48 +++++++------- tests/test_annotation_stores.py | 11 +--- tests/test_annotation_tilerendering.py | 15 ++++- tests/test_init.py | 2 +- tests/test_utils.py | 22 +------ tests/test_wsimeta.py | 1 - tiatoolbox/annotation/storage.py | 16 ++++- tiatoolbox/cli/common.py | 26 ++++---- tiatoolbox/cli/nucleus_instance_segment.py | 8 +-- tiatoolbox/cli/patch_predictor.py | 8 +-- tiatoolbox/cli/semantic_segment.py | 8 +-- tiatoolbox/models/architecture/hovernet.py | 12 ++-- .../models/architecture/hovernetplus.py | 10 ++- tiatoolbox/models/architecture/mapde.py | 8 +-- tiatoolbox/models/architecture/micronet.py | 10 ++- tiatoolbox/models/architecture/nuclick.py | 11 ++-- tiatoolbox/models/architecture/sccnn.py | 9 +-- tiatoolbox/models/architecture/unet.py | 8 +-- tiatoolbox/models/architecture/utils.py | 8 +-- tiatoolbox/models/architecture/vanilla.py | 64 +++++++++---------- .../engine/nucleus_instance_segmentor.py | 8 +-- tiatoolbox/models/engine/patch_predictor.py | 51 +++++++++------ .../models/engine/semantic_segmentor.py | 34 ++++++---- tiatoolbox/models/models_abc.py | 38 +++++++++-- tiatoolbox/utils/misc.py | 19 ------ tiatoolbox/utils/visualization.py | 11 +++- tiatoolbox/visualization/bokeh_app/main.py | 3 +- whitelist.txt | 1 + 42 files changed, 342 insertions(+), 290 deletions(-) diff --git a/.gitignore b/.gitignore index 409fc1261..66c072da5 100644 --- a/.gitignore +++ b/.gitignore @@ -116,3 +116,6 @@ ENV/ # vim/vi generated *.swp + +# output zarr generated +*.zarr diff --git a/tests/models/test_abc.py b/tests/models/test_abc.py index d8af37193..f7a60e34c 100644 --- a/tests/models/test_abc.py +++ b/tests/models/test_abc.py @@ -6,14 +6,15 @@ import pytest import torch +import torchvision.models as torch_models from torch import nn -from tiatoolbox import rcParam +from tiatoolbox import rcParam, utils from tiatoolbox.models.architecture import ( fetch_pretrained_weights, get_pretrained_model, ) -from tiatoolbox.models.models_abc import ModelABC +from tiatoolbox.models.models_abc import ModelABC, model_to from tiatoolbox.utils import env_detection as toolbox_env if TYPE_CHECKING: @@ -149,3 +150,18 @@ def test_model_abc() -> None: weights_path = fetch_pretrained_weights("alexnet-kather100k") with pytest.raises(RuntimeError, match=r".*loading state_dict*"): _ = model.load_weights_from_file(weights_path) + + +def test_model_to() -> None: + """Test for placing model on device.""" + # Test on GPU + # no GPU on GitHub Actions so this will crash + if not utils.env_detection.has_gpu(): + model = torch_models.resnet18() + with pytest.raises((AssertionError, RuntimeError)): + _ = model_to(device="cuda", model=model) + + # Test on CPU + model = torch_models.resnet18() + model = model_to(device="cpu", model=model) + assert isinstance(model, nn.Module) diff --git a/tests/models/test_arch_mapde.py b/tests/models/test_arch_mapde.py index febcfbdec..61bfde817 100644 --- a/tests/models/test_arch_mapde.py +++ b/tests/models/test_arch_mapde.py @@ -45,7 +45,7 @@ def test_functionality(remote_sample: Callable) -> None: model = _load_mapde(name="mapde-conic") patch = model.preproc(patch) batch = torch.from_numpy(patch)[None] - model = model.to(select_device(on_gpu=ON_GPU)) - output = model.infer_batch(model, batch, on_gpu=ON_GPU) + model = model.to() + output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) output = model.postproc(output[0]) assert np.all(output[0:2] == [[19, 171], [53, 89]]) diff --git a/tests/models/test_arch_micronet.py b/tests/models/test_arch_micronet.py index cd4bd0833..e7aa23d5b 100644 --- a/tests/models/test_arch_micronet.py +++ b/tests/models/test_arch_micronet.py @@ -39,7 +39,7 @@ def test_functionality( model = model.to(map_location) pretrained = torch.load(weights_path, map_location=map_location) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=ON_GPU) + output = model.infer_batch(model, batch, device=map_location) output, _ = model.postproc(output[0]) assert np.max(np.unique(output)) == 46 diff --git a/tests/models/test_arch_nuclick.py b/tests/models/test_arch_nuclick.py index fda0c01a6..b84516125 100644 --- a/tests/models/test_arch_nuclick.py +++ b/tests/models/test_arch_nuclick.py @@ -10,6 +10,7 @@ from tiatoolbox.models import NuClick from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.utils import imread +from tiatoolbox.utils.misc import select_device ON_GPU = False @@ -53,7 +54,7 @@ def test_functional_nuclick( model = NuClick(num_input_channels=5, num_output_channels=1) pretrained = torch.load(weights_path, map_location="cpu") model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=ON_GPU) + output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) postproc_masks = model.postproc( output, do_reconstruction=True, diff --git a/tests/models/test_arch_sccnn.py b/tests/models/test_arch_sccnn.py index b3dd94e50..2729d2b3a 100644 --- a/tests/models/test_arch_sccnn.py +++ b/tests/models/test_arch_sccnn.py @@ -5,9 +5,10 @@ import numpy as np import torch -from tiatoolbox import utils from tiatoolbox.models import SCCNN from tiatoolbox.models.architecture import fetch_pretrained_weights +from tiatoolbox.utils import env_detection +from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import WSIReader @@ -15,7 +16,7 @@ def _load_sccnn(name: str) -> torch.nn.Module: """Loads SCCNN model with specified weights.""" model = SCCNN() weights_path = fetch_pretrained_weights(name) - map_location = utils.misc.select_device(on_gpu=utils.env_detection.has_gpu()) + map_location = select_device(on_gpu=env_detection.has_gpu()) pretrained = torch.load(weights_path, map_location=map_location) model.load_state_dict(pretrained) @@ -40,11 +41,19 @@ def test_functionality(remote_sample: Callable) -> None: ) batch = torch.from_numpy(patch)[None] model = _load_sccnn(name="sccnn-crchisto") - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch( + model, + batch, + device=select_device(on_gpu=env_detection.has_gpu()), + ) output = model.postproc(output[0]) assert np.all(output == [[8, 7]]) model = _load_sccnn(name="sccnn-conic") - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch( + model, + batch, + device=select_device(on_gpu=env_detection.has_gpu()), + ) output = model.postproc(output[0]) assert np.all(output == [[7, 8]]) diff --git a/tests/models/test_arch_unet.py b/tests/models/test_arch_unet.py index b0cbc6085..2ac231c7c 100644 --- a/tests/models/test_arch_unet.py +++ b/tests/models/test_arch_unet.py @@ -9,6 +9,7 @@ from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.models.architecture.unet import UNetModel +from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import WSIReader ON_GPU = False @@ -48,7 +49,7 @@ def test_functional_unet(remote_sample: Callable) -> None: model = UNetModel(3, 2, encoder="resnet50", decoder_block=[3]) pretrained = torch.load(pretrained_weights, map_location="cpu") model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=ON_GPU) + output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) _ = output[0] # run untrained network to test for architecture @@ -60,4 +61,4 @@ def test_functional_unet(remote_sample: Callable) -> None: encoder_levels=[32, 64], skip_type="concat", ) - _ = model.infer_batch(model, batch, on_gpu=ON_GPU) + _ = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) diff --git a/tests/models/test_arch_vanilla.py b/tests/models/test_arch_vanilla.py index 29c76ab4e..a87424dfd 100644 --- a/tests/models/test_arch_vanilla.py +++ b/tests/models/test_arch_vanilla.py @@ -5,10 +5,11 @@ import torch from tiatoolbox.models.architecture.vanilla import CNNModel, TimmModel -from tiatoolbox.utils.misc import model_to +from tiatoolbox.models.models_abc import model_to ON_GPU = False RNG = np.random.default_rng() # Numpy Random Generator +device = "cuda" if ON_GPU else "cpu" def test_functional() -> None: @@ -43,8 +44,8 @@ def test_functional() -> None: try: for backbone in backbones: model = CNNModel(backbone, num_classes=1) - model_ = model_to(on_gpu=ON_GPU, model=model) - model.infer_batch(model_, samples, on_gpu=ON_GPU) + model_ = model_to(device=device, model=model) + model.infer_batch(model_, samples, device=device) except ValueError as exc: msg = f"Model {backbone} failed." raise AssertionError(msg) from exc @@ -70,8 +71,8 @@ def test_timm_functional() -> None: try: for backbone in backbones: model = TimmModel(backbone=backbone, num_classes=1, pretrained=False) - model_ = model_to(on_gpu=ON_GPU, model=model) - model.infer_batch(model_, samples, on_gpu=ON_GPU) + model_ = model_to(device=device, model=model) + model.infer_batch(model_, samples, device=device) except ValueError as exc: msg = f"Model {backbone} failed." raise AssertionError(msg) from exc diff --git a/tests/models/test_feature_extractor.py b/tests/models/test_feature_extractor.py index 15468ab32..cd33f0a5a 100644 --- a/tests/models/test_feature_extractor.py +++ b/tests/models/test_feature_extractor.py @@ -14,6 +14,7 @@ IOSegmentorConfig, ) from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import WSIReader ON_GPU = not toolbox_env.running_on_ci() and toolbox_env.has_gpu() @@ -35,7 +36,7 @@ def test_engine(remote_sample: Callable, tmp_path: Path) -> None: output_list = extractor.predict( [mini_wsi_svs], mode="wsi", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ) @@ -82,7 +83,7 @@ def test_full_inference( [mini_wsi_svs], mode="wsi", ioconfig=ioconfig, - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ) diff --git a/tests/models/test_hovernet.py b/tests/models/test_hovernet.py index b2271ab4c..2567018b8 100644 --- a/tests/models/test_hovernet.py +++ b/tests/models/test_hovernet.py @@ -14,6 +14,7 @@ ResidualBlock, TFSamepaddingLayer, ) +from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import WSIReader @@ -34,7 +35,7 @@ def test_functionality(remote_sample: Callable) -> None: weights_path = fetch_pretrained_weights("hovernet_fast-pannuke") pretrained = torch.load(weights_path) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) output = [v[0] for v in output] output = model.postproc(output) assert len(output[1]) > 0, "Must have some nuclei." @@ -51,7 +52,7 @@ def test_functionality(remote_sample: Callable) -> None: weights_path = fetch_pretrained_weights("hovernet_fast-monusac") pretrained = torch.load(weights_path) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) output = [v[0] for v in output] output = model.postproc(output) assert len(output[1]) > 0, "Must have some nuclei." @@ -68,7 +69,7 @@ def test_functionality(remote_sample: Callable) -> None: weights_path = fetch_pretrained_weights("hovernet_original-consep") pretrained = torch.load(weights_path) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) output = [v[0] for v in output] output = model.postproc(output) assert len(output[1]) > 0, "Must have some nuclei." @@ -85,7 +86,7 @@ def test_functionality(remote_sample: Callable) -> None: weights_path = fetch_pretrained_weights("hovernet_original-kumar") pretrained = torch.load(weights_path) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) output = [v[0] for v in output] output = model.postproc(output) assert len(output[1]) > 0, "Must have some nuclei." diff --git a/tests/models/test_hovernetplus.py b/tests/models/test_hovernetplus.py index 96d0f9d23..1377fdd82 100644 --- a/tests/models/test_hovernetplus.py +++ b/tests/models/test_hovernetplus.py @@ -7,6 +7,7 @@ from tiatoolbox.models import HoVerNetPlus from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.utils import imread +from tiatoolbox.utils.misc import select_device from tiatoolbox.utils.transforms import imresize @@ -28,7 +29,7 @@ def test_functionality(remote_sample: Callable) -> None: weights_path = fetch_pretrained_weights("hovernetplus-oed") pretrained = torch.load(weights_path) model.load_state_dict(pretrained) - output = model.infer_batch(model, batch, on_gpu=False) + output = model.infer_batch(model, batch, device=select_device(on_gpu=False)) assert len(output) == 4, "Must contain predictions for: np, hv, tp and ls branches." output = [v[0] for v in output] output = model.postproc(output) diff --git a/tests/models/test_multi_task_segmentor.py b/tests/models/test_multi_task_segmentor.py index a7e76f719..1f135b303 100644 --- a/tests/models/test_multi_task_segmentor.py +++ b/tests/models/test_multi_task_segmentor.py @@ -17,6 +17,7 @@ from tiatoolbox.utils import env_detection as toolbox_env from tiatoolbox.utils import imwrite from tiatoolbox.utils.metrics import f1_detection +from tiatoolbox.utils.misc import select_device ON_GPU = toolbox_env.has_gpu() BATCH_SIZE = 1 if not ON_GPU else 8 # 16 @@ -64,7 +65,7 @@ def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None: output = multi_segmentor.predict( [mini_wsi_svs], mode="wsi", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ) @@ -83,7 +84,7 @@ def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None: output = multi_segmentor.predict( [mini_wsi_svs], mode="wsi", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ) @@ -117,7 +118,7 @@ def test_functionality_hovernetplus(remote_sample: Callable, tmp_path: Path) -> output = multi_segmentor.predict( [mini_wsi_svs], mode="wsi", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ) @@ -148,7 +149,7 @@ def test_functionality_hovernet(remote_sample: Callable, tmp_path: Path) -> None output = multi_segmentor.predict( [mini_wsi_svs], mode="wsi", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ) @@ -195,7 +196,7 @@ def test_masked_segmentor(remote_sample: Callable, tmp_path: Path) -> None: masks=[sample_wsi_msk], mode="wsi", ioconfig=ioconfig, - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ) @@ -230,7 +231,7 @@ def test_functionality_process_instance_predictions( output = semantic_segmentor.predict( [mini_wsi_svs], mode="wsi", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ) @@ -268,7 +269,7 @@ def test_empty_image(tmp_path: Path) -> None: _ = multi_segmentor.predict( [sample_patch_path], mode="tile", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ) @@ -284,7 +285,7 @@ def test_empty_image(tmp_path: Path) -> None: _ = multi_segmentor.predict( [sample_patch_path], mode="tile", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ) @@ -312,7 +313,7 @@ def test_empty_image(tmp_path: Path) -> None: _ = multi_segmentor.predict( [sample_patch_path], mode="tile", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ioconfig=bcc_wsi_ioconfig, @@ -361,7 +362,7 @@ def test_functionality_semantic(remote_sample: Callable, tmp_path: Path) -> None output = multi_segmentor.predict( [mini_wsi_svs], mode="wsi", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ioconfig=bcc_wsi_ioconfig, @@ -413,7 +414,7 @@ def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None: masks=[sample_wsi_msk], mode="wsi", ioconfig=ioconfig, - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ) diff --git a/tests/models/test_nucleus_instance_segmentor.py b/tests/models/test_nucleus_instance_segmentor.py index ff6b9a4cc..2956849fb 100644 --- a/tests/models/test_nucleus_instance_segmentor.py +++ b/tests/models/test_nucleus_instance_segmentor.py @@ -28,6 +28,7 @@ from tiatoolbox.utils import env_detection as toolbox_env from tiatoolbox.utils import imwrite from tiatoolbox.utils.metrics import f1_detection +from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import WSIReader ON_GPU = toolbox_env.has_gpu() @@ -278,7 +279,7 @@ def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None: masks=[sample_wsi_msk], mode="wsi", ioconfig=ioconfig, - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ) @@ -326,7 +327,7 @@ def test_functionality_ci(remote_sample: Callable, tmp_path: Path) -> None: [mini_wsi_svs], mode="wsi", ioconfig=ioconfig, - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ) @@ -373,7 +374,7 @@ def test_functionality_merge_tile_predictions_ci( output = semantic_segmentor.predict( [mini_wsi_svs], mode="wsi", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), ioconfig=ioconfig, crash_on_exception=True, save_dir=save_dir, @@ -453,7 +454,7 @@ def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None: output = inst_segmentor.predict( [mini_wsi_svs], mode="wsi", - on_gpu=True, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ) @@ -471,7 +472,7 @@ def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None: output = inst_segmentor.predict( [mini_wsi_svs], mode="wsi", - on_gpu=True, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ) @@ -496,7 +497,7 @@ def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None: output = semantic_segmentor.predict( [mini_wsi_svs], mode="wsi", - on_gpu=True, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ) diff --git a/tests/models/test_patch_predictor.py b/tests/models/test_patch_predictor.py index 5fd930138..913d63241 100644 --- a/tests/models/test_patch_predictor.py +++ b/tests/models/test_patch_predictor.py @@ -25,6 +25,7 @@ ) from tiatoolbox.utils import download_data, imread, imwrite from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import WSIReader ON_GPU = toolbox_env.has_gpu() @@ -547,7 +548,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None: [mini_wsi_svs], mode="wsi", save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), **_kwargs, ) shutil.rmtree(tmp_path / "dump", ignore_errors=True) @@ -563,7 +564,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None: ioconfig=ioconfig, mode="wsi", save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), ) shutil.rmtree(tmp_path / "dump", ignore_errors=True) @@ -571,7 +572,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None: [mini_wsi_svs], mode="wsi", save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), **kwargs, ) shutil.rmtree(tmp_path / "dump", ignore_errors=True) @@ -582,7 +583,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None: [mini_wsi_svs], patch_input_shape=(300, 300), mode="wsi", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), save_dir=f"{tmp_path}/dump", ) assert predictor._ioconfig.patch_input_shape == (300, 300) @@ -592,7 +593,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None: [mini_wsi_svs], stride_shape=(300, 300), mode="wsi", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), save_dir=f"{tmp_path}/dump", ) assert predictor._ioconfig.stride_shape == (300, 300) @@ -602,7 +603,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None: [mini_wsi_svs], resolution=1.99, mode="wsi", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), save_dir=f"{tmp_path}/dump", ) assert predictor._ioconfig.input_resolutions[0]["resolution"] == 1.99 @@ -612,7 +613,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None: [mini_wsi_svs], units="baseline", mode="wsi", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), save_dir=f"{tmp_path}/dump", ) assert predictor._ioconfig.input_resolutions[0]["units"] == "baseline" @@ -624,7 +625,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None: mode="wsi", merge_predictions=True, save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), ) shutil.rmtree(tmp_path / "dump", ignore_errors=True) @@ -643,7 +644,7 @@ def test_patch_predictor_api( # don't run test on GPU output = predictor.predict( inputs, - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), save_dir=save_dir_path, ) assert sorted(output.keys()) == ["predictions"] @@ -654,7 +655,7 @@ def test_patch_predictor_api( inputs, labels=[1, "a"], return_labels=True, - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), save_dir=save_dir_path, ) assert sorted(output.keys()) == sorted(["labels", "predictions"]) @@ -665,7 +666,7 @@ def test_patch_predictor_api( output = predictor.predict( inputs, return_probabilities=True, - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), save_dir=save_dir_path, ) assert sorted(output.keys()) == sorted(["predictions", "probabilities"]) @@ -677,7 +678,7 @@ def test_patch_predictor_api( return_probabilities=True, labels=[1, "a"], return_labels=True, - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), save_dir=save_dir_path, ) assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) @@ -687,7 +688,7 @@ def test_patch_predictor_api( # test saving output, should have no effect _ = predictor.predict( inputs, - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), save_dir="special_dir_not_exist", ) assert not Path.is_dir(Path("special_dir_not_exist")) @@ -721,7 +722,7 @@ def test_patch_predictor_api( return_probabilities=True, labels=[1, "a"], return_labels=True, - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), save_dir=save_dir_path, ) assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) @@ -751,7 +752,7 @@ def test_wsi_predictor_api( kwargs = { "return_probabilities": True, "return_labels": True, - "on_gpu": ON_GPU, + "device": select_device(on_gpu=ON_GPU), "patch_input_shape": patch_size, "stride_shape": patch_size, "resolution": 1.0, @@ -788,7 +789,7 @@ def test_wsi_predictor_api( kwargs = { "return_probabilities": True, "return_labels": True, - "on_gpu": ON_GPU, + "device": select_device(on_gpu=ON_GPU), "patch_input_shape": patch_size, "stride_shape": patch_size, "resolution": 0.5, @@ -903,7 +904,7 @@ def test_wsi_predictor_merge_predictions(sample_wsi_dict: dict) -> None: kwargs = { "return_probabilities": True, "return_labels": True, - "on_gpu": ON_GPU, + "device": select_device(on_gpu=ON_GPU), "patch_input_shape": np.array([224, 224]), "stride_shape": np.array([224, 224]), "resolution": 1.0, @@ -958,8 +959,7 @@ def _test_predictor_output( pretrained_model: str, probabilities_check: list | None = None, predictions_check: list | None = None, - *, - on_gpu: bool = ON_GPU, + device: str = select_device(on_gpu=ON_GPU), ) -> None: """Test the predictions of multiple models included in tiatoolbox.""" predictor = PatchPredictor( @@ -972,7 +972,7 @@ def _test_predictor_output( inputs, return_probabilities=True, return_labels=False, - on_gpu=on_gpu, + device=device, ) predictions = output["predictions"] probabilities = output["probabilities"] @@ -1025,7 +1025,7 @@ def test_patch_predictor_kather100k_output( pretrained_model, probabilities_check=expected_prob, predictions_check=[6, 3], - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), ) # only test 1 on travis to limit runtime if toolbox_env.running_on_ci(): @@ -1060,7 +1060,7 @@ def test_patch_predictor_pcam_output(sample_patch3: Path, sample_patch4: Path) - pretrained_model, probabilities_check=expected_prob, predictions_check=[1, 0], - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), ) # only test 1 on travis to limit runtime if toolbox_env.running_on_ci(): diff --git a/tests/models/test_semantic_segmentation.py b/tests/models/test_semantic_segmentation.py index 8fee41a9b..01776b800 100644 --- a/tests/models/test_semantic_segmentation.py +++ b/tests/models/test_semantic_segmentation.py @@ -32,6 +32,7 @@ from tiatoolbox.models.models_abc import ModelABC from tiatoolbox.utils import env_detection as toolbox_env from tiatoolbox.utils import imread, imwrite +from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import WSIReader ON_GPU = toolbox_env.has_gpu() @@ -70,12 +71,7 @@ def forward(self: _CNNTo1, img: np.ndarray) -> torch.Tensor: return self.conv(img) @staticmethod - def infer_batch( - model: nn.Module, - batch_data: torch.Tensor, - *, - on_gpu: bool, - ) -> list: + def infer_batch(model: nn.Module, batch_data: torch.Tensor, device: str) -> list: """Run inference on an input batch. Contains logic for forward operation as well as i/o @@ -85,10 +81,14 @@ def infer_batch( model (nn.Module): PyTorch defined model. batch_data (torch.Tensor): A batch of data generated by torch.utils.data.DataLoader. - on_gpu (bool): Whether to run inference on a GPU. + device (str): + :class:`torch.device` to run the model. + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. Default value is "cpu". """ - device = "cuda" if on_gpu else "cpu" + device = "cuda" if ON_GPU else "cpu" #### model.eval() # infer mode @@ -307,7 +307,7 @@ def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None: semantic_segmentor.predict( [mini_wsi_jpg], mode="tile", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ) @@ -325,7 +325,7 @@ def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None: semantic_segmentor.predict( [mini_wsi_jpg], mode="tile", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ) @@ -339,7 +339,7 @@ def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None: [mini_wsi_svs], patch_input_shape=(2048, 2048), mode="wsi", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir, ) @@ -350,7 +350,7 @@ def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None: [mini_wsi_svs], patch_input_shape=(2048, 2048), mode="wsi", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=False, save_dir=save_dir, ) @@ -494,7 +494,7 @@ def test_functional_segmentor( semantic_segmentor.predict( [mini_wsi_jpg], mode="tile", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), patch_input_shape=(512, 512), resolution=resolution, units="mpp", @@ -506,7 +506,7 @@ def test_functional_segmentor( semantic_segmentor.predict( [mini_wsi_jpg], mode="tile", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), patch_input_shape=(512, 512), resolution=1 / resolution, units="baseline", @@ -521,7 +521,7 @@ def test_functional_segmentor( semantic_segmentor.predict( [mini_wsi_jpg], mode="tile", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), patch_input_shape=(512, 512), patch_output_shape=(512, 512), stride_shape=(512, 512), @@ -552,7 +552,7 @@ def test_functional_segmentor( output_list = semantic_segmentor.predict( file_list, mode="tile", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), ioconfig=ioconfig, crash_on_exception=True, save_dir=f"{save_dir}/raw/", @@ -581,7 +581,7 @@ def test_functional_segmentor( [mini_wsi_svs], masks=[mini_wsi_msk], mode="wsi", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), ioconfig=ioconfig, crash_on_exception=True, save_dir=f"{save_dir}/raw/", @@ -605,7 +605,7 @@ def test_functional_segmentor( [mini_wsi_svs], masks=[mini_wsi_msk], mode="wsi", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), ioconfig=ioconfig, crash_on_exception=True, save_dir=f"{save_dir}/raw/", @@ -631,7 +631,7 @@ def __init__(self: XSegmentor) -> None: semantic_segmentor.predict( [mini_wsi_jpg], mode="tile", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), patch_input_shape=(1024, 1024), patch_output_shape=(512, 512), stride_shape=(256, 256), @@ -661,7 +661,7 @@ def test_functional_pretrained(remote_sample: Callable, tmp_path: Path) -> None: semantic_segmentor.predict( [mini_wsi_svs], mode="wsi", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=f"{save_dir}/raw/", ) @@ -672,7 +672,7 @@ def test_functional_pretrained(remote_sample: Callable, tmp_path: Path) -> None: semantic_segmentor.predict( [mini_wsi_jpg], mode="tile", - on_gpu=ON_GPU, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=f"{save_dir}/raw/", ) @@ -699,7 +699,7 @@ def test_behavior_tissue_mask_local(remote_sample: Callable, tmp_path: Path) -> semantic_segmentor.predict( [wsi_with_artifacts], mode="wsi", - on_gpu=True, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir / "raw", ) @@ -715,7 +715,7 @@ def test_behavior_tissue_mask_local(remote_sample: Callable, tmp_path: Path) -> semantic_segmentor.predict( [mini_wsi_jpg], mode="tile", - on_gpu=True, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=f"{save_dir}/raw/", ) @@ -738,7 +738,7 @@ def test_behavior_bcss_local(remote_sample: Callable, tmp_path: Path) -> None: semantic_segmentor.predict( [wsi_breast], mode="wsi", - on_gpu=True, + device=select_device(on_gpu=ON_GPU), crash_on_exception=True, save_dir=save_dir / "raw", ) diff --git a/tests/test_annotation_stores.py b/tests/test_annotation_stores.py index 01bbdac45..66c990161 100644 --- a/tests/test_annotation_stores.py +++ b/tests/test_annotation_stores.py @@ -53,14 +53,6 @@ FILLED_LEN = 2 * (GRID_SIZE[0] * GRID_SIZE[1]) RNG = np.random.default_rng(0) # Numpy Random Generator -# ---------------------------------------------------------------------- -# Resets -# ---------------------------------------------------------------------- - -# Reset filters in logger. -for filter_ in logger.filters: - logger.removeFilter(filter_) - # ---------------------------------------------------------------------- # Helper Functions # ---------------------------------------------------------------------- @@ -546,6 +538,9 @@ def test_sqlite_store_compile_options_missing_math( caplog: pytest.LogCaptureFixture, ) -> None: """Test that a warning is shown if the sqlite math module is missing.""" + # Reset filters in logger. + for filter_ in logger.filters[:]: + logger.removeFilter(filter_) monkeypatch.setattr( SQLiteStore, "compile_options", diff --git a/tests/test_annotation_tilerendering.py b/tests/test_annotation_tilerendering.py index 0734b9164..0ee34b17b 100644 --- a/tests/test_annotation_tilerendering.py +++ b/tests/test_annotation_tilerendering.py @@ -23,7 +23,7 @@ from tiatoolbox.annotation import Annotation, AnnotationStore, SQLiteStore from tiatoolbox.tools.pyramid import AnnotationTileGenerator from tiatoolbox.utils.env_detection import running_on_travis -from tiatoolbox.utils.visualization import AnnotationRenderer +from tiatoolbox.utils.visualization import AnnotationRenderer, _find_minimum_mpp_sf from tiatoolbox.wsicore import wsireader RNG = np.random.default_rng(0) # Numpy Random Generator @@ -462,6 +462,7 @@ def test_function_mapper(fill_store: Callable, tmp_path: Path) -> None: _, store = fill_store(SQLiteStore, tmp_path / "test.db") def color_fn(props: dict[str, str]) -> tuple[int, int, int]: + """Tests Red for cells, otherwise green.""" # simple test function that returns red for cells, otherwise green. if props["type"] == "cell": return 1, 0, 0 @@ -480,3 +481,15 @@ def color_fn(props: dict[str, str]) -> tuple[int, int, int]: assert num == 50 # expect 50 green objects _, num = label(np.array(thumb)[:, :, 2]) assert num == 0 # expect 0 blue objects + + +def test_minimum_mpp_sf() -> None: + """Test minimum mpp_sf.""" + mpp_sf = _find_minimum_mpp_sf((0.5, 0.5)) + assert mpp_sf == 1.0 + + mpp_sf = _find_minimum_mpp_sf((0.20, 0.20)) + assert mpp_sf == 0.20 / 0.25 + + mpp_sf = _find_minimum_mpp_sf(None) + assert mpp_sf == 1.0 diff --git a/tests/test_init.py b/tests/test_init.py index 509a9c49f..6d8ed8238 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -114,7 +114,7 @@ def test_duplicate_filter(caplog: pytest.LogCaptureFixture) -> None: logger.addFilter(duplicate_filter) # Reset filters in logger. - for filter_ in logger.filters: + for filter_ in logger.filters[:]: logger.removeFilter(filter_) for _ in range(2): diff --git a/tests/test_utils.py b/tests/test_utils.py index fe18e0d36..95e6ee520 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1336,24 +1336,6 @@ def test_select_device() -> None: assert device == "cpu" -def test_model_to() -> None: - """Test for placing model on device.""" - import torchvision.models as torch_models - from torch import nn - - # Test on GPU - # no GPU on Travis so this will crash - if not utils.env_detection.has_gpu(): - model = torch_models.resnet18() - with pytest.raises((AssertionError, RuntimeError)): - _ = misc.model_to(on_gpu=True, model=model) - - # Test on CPU - model = torch_models.resnet18() - model = misc.model_to(on_gpu=False, model=model) - assert isinstance(model, nn.Module) - - def test_save_as_json(tmp_path: Path) -> None: """Test save data to json.""" # This should be broken up into separate tests! @@ -1673,7 +1655,7 @@ def test_patch_pred_store() -> None: store = misc.dict_to_store(patch_output, (1.0, 1.0)) - # Check that its an SQLiteStore containing the expected annotations + # Check that it is an SQLiteStore containing the expected annotations assert isinstance(store, SQLiteStore) assert len(store) == 3 for annotation in store.values(): @@ -1700,7 +1682,7 @@ def test_patch_pred_store_cdict() -> None: class_dict = {0: "class0", 1: "class1"} store = misc.dict_to_store(patch_output, (1.0, 1.0), class_dict=class_dict) - # Check that its an SQLiteStore containing the expected annotations + # Check that it is an SQLiteStore containing the expected annotations assert isinstance(store, SQLiteStore) assert len(store) == 3 for annotation in store.values(): diff --git a/tests/test_wsimeta.py b/tests/test_wsimeta.py index bc3555e36..01b1cac8b 100644 --- a/tests/test_wsimeta.py +++ b/tests/test_wsimeta.py @@ -8,7 +8,6 @@ from tiatoolbox.wsicore import WSIMeta, wsimeta, wsireader -# noinspection PyTypeChecker def test_wsimeta_init_fail() -> None: """Test incorrect init for WSIMeta raises TypeError.""" with pytest.raises(TypeError): diff --git a/tiatoolbox/annotation/storage.py b/tiatoolbox/annotation/storage.py index 0cd476358..420e94085 100644 --- a/tiatoolbox/annotation/storage.py +++ b/tiatoolbox/annotation/storage.py @@ -2556,7 +2556,21 @@ def _unpack_wkb( cx: float, cy: float, ) -> bytes: - """Unpack WKB data.""" + """Return the geometry as bytes using WKB. + + Args: + data (bytes or str): + The WKB/WKT data to be unpacked. + cx (int): + The X coordinate of the centroid/representative point. + cy (float): + The Y coordinate of the centroid/representative point. + + Returns: + bytes: + The geometry as bytes. + + """ return ( self._decompress_data(data) if data diff --git a/tiatoolbox/cli/common.py b/tiatoolbox/cli/common.py index 18e731b4c..26f85625e 100644 --- a/tiatoolbox/cli/common.py +++ b/tiatoolbox/cli/common.py @@ -234,6 +234,18 @@ def cli_pretrained_weights( ) +def cli_device( + usage_help: str = "Select the device (cpu/cuda/mps) to use for inference.", + default: str = "cpu", +) -> Callable: + """Enables --pretrained-weights option for cli.""" + return click.option( + "--device", + help=add_default_to_usage_help(usage_help, default), + default=default, + ) + + def cli_return_probabilities( usage_help: str = "Whether to return raw model probabilities.", *, @@ -333,20 +345,6 @@ def cli_yaml_config_path( ) -def cli_on_gpu( - usage_help: str = "Run the model on GPU.", - *, - default: bool = False, -) -> Callable: - """Enables --on-gpu option for cli.""" - return click.option( - "--on-gpu", - type=bool, - default=default, - help=add_default_to_usage_help(usage_help, default), - ) - - def cli_num_loader_workers( usage_help: str = "Number of workers to load the data. Please note that they will " "also perform preprocessing.", diff --git a/tiatoolbox/cli/nucleus_instance_segment.py b/tiatoolbox/cli/nucleus_instance_segment.py index b38dcdaed..fdb4b95ca 100644 --- a/tiatoolbox/cli/nucleus_instance_segment.py +++ b/tiatoolbox/cli/nucleus_instance_segment.py @@ -7,13 +7,13 @@ from tiatoolbox.cli.common import ( cli_auto_generate_mask, cli_batch_size, + cli_device, cli_file_type, cli_img_input, cli_masks, cli_mode, cli_num_loader_workers, cli_num_postproc_workers, - cli_on_gpu, cli_output_path, cli_pretrained_model, cli_pretrained_weights, @@ -41,7 +41,7 @@ ) @cli_pretrained_model(default="hovernet_fast-pannuke") @cli_pretrained_weights(default=None) -@cli_on_gpu(default=False) +@cli_device(default="cpu") @cli_batch_size() @cli_masks(default=None) @cli_yaml_config_path(default=None) @@ -61,9 +61,9 @@ def nucleus_instance_segment( yaml_config_path: str, num_loader_workers: int, num_postproc_workers: int, + device: str, *, auto_generate_mask: bool, - on_gpu: bool, verbose: bool, ) -> None: """Process an image/directory of input images with a patch classification CNN.""" @@ -97,7 +97,7 @@ def nucleus_instance_segment( imgs=files_all, masks=masks_all, mode=mode, - on_gpu=on_gpu, + device=device, save_dir=output_path, ioconfig=ioconfig, ) diff --git a/tiatoolbox/cli/patch_predictor.py b/tiatoolbox/cli/patch_predictor.py index a97ecb571..069b6c367 100644 --- a/tiatoolbox/cli/patch_predictor.py +++ b/tiatoolbox/cli/patch_predictor.py @@ -6,13 +6,13 @@ from tiatoolbox.cli.common import ( cli_batch_size, + cli_device, cli_file_type, cli_img_input, cli_masks, cli_merge_predictions, cli_mode, cli_num_loader_workers, - cli_on_gpu, cli_output_path, cli_pretrained_model, cli_pretrained_weights, @@ -45,7 +45,7 @@ @cli_return_probabilities(default=False) @cli_merge_predictions(default=True) @cli_return_labels(default=True) -@cli_on_gpu(default=False) +@cli_device(default="cpu") @cli_batch_size(default=1) @cli_resolution(default=0.5) @cli_units(default="mpp") @@ -64,11 +64,11 @@ def patch_predictor( resolution: float, units: str, num_loader_workers: int, + device: str, *, return_probabilities: bool, return_labels: bool, merge_predictions: bool, - on_gpu: bool, verbose: bool, ) -> None: """Process an image/directory of input images with a patch classification CNN.""" @@ -100,7 +100,7 @@ def patch_predictor( return_labels=return_labels, resolution=resolution, units=units, - on_gpu=on_gpu, + device=device, save_dir=output_path, save_output=True, ) diff --git a/tiatoolbox/cli/semantic_segment.py b/tiatoolbox/cli/semantic_segment.py index 8947b2beb..cbfe18e58 100644 --- a/tiatoolbox/cli/semantic_segment.py +++ b/tiatoolbox/cli/semantic_segment.py @@ -6,12 +6,12 @@ from tiatoolbox.cli.common import ( cli_batch_size, + cli_device, cli_file_type, cli_img_input, cli_masks, cli_mode, cli_num_loader_workers, - cli_on_gpu, cli_output_path, cli_pretrained_model, cli_pretrained_weights, @@ -39,7 +39,7 @@ ) @cli_pretrained_model(default="fcn-tissue_mask") @cli_pretrained_weights(default=None) -@cli_on_gpu() +@cli_device() @cli_batch_size() @cli_masks(default=None) @cli_yaml_config_path() @@ -56,8 +56,8 @@ def semantic_segment( batch_size: int, yaml_config_path: str, num_loader_workers: int, + device: str, *, - on_gpu: bool, verbose: bool, ) -> None: """Process an image/directory of input images with a patch classification CNN.""" @@ -89,7 +89,7 @@ def semantic_segment( imgs=files_all, masks=masks_all, mode=mode, - on_gpu=on_gpu, + device=device, save_dir=output_path, ioconfig=ioconfig, ) diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py index 8f061d273..2853c4946 100644 --- a/tiatoolbox/models/architecture/hovernet.py +++ b/tiatoolbox/models/architecture/hovernet.py @@ -20,7 +20,6 @@ centre_crop_to_shape, ) from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc from tiatoolbox.utils.misc import get_bounding_box @@ -766,7 +765,7 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[np.ndarray, dict]: >>> pretrained = torch.load(weights_path) >>> model = HoVerNet(num_types=6, mode="fast") >>> model.load_state_dict(pretrained) - >>> output = model.infer_batch(model, batch, on_gpu=False) + >>> output = model.infer_batch(model, batch, device="cuda") >>> output = [v[0] for v in output] >>> output = model.postproc(output) @@ -785,7 +784,9 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[np.ndarray, dict]: return pred_inst, nuc_inst_info_dict @staticmethod - def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tuple: + def infer_batch( # skipcq: PYL-W0221 + model: nn.Module, batch_data: np.ndarray, *, device: str + ) -> tuple: """Run inference on an input batch. This contains logic for forward operation as well as batch i/o @@ -797,8 +798,8 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tu batch_data (ndarray): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: tuple: @@ -810,7 +811,6 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tu """ patch_imgs = batch_data - device = misc.select_device(on_gpu=on_gpu) patch_imgs_gpu = patch_imgs.to(device).type(torch.float32) # to NCHW patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py index 87db17295..700eb303f 100644 --- a/tiatoolbox/models/architecture/hovernetplus.py +++ b/tiatoolbox/models/architecture/hovernetplus.py @@ -13,7 +13,6 @@ from tiatoolbox.models.architecture.hovernet import HoVerNet from tiatoolbox.models.architecture.utils import UpSample2x -from tiatoolbox.utils import misc class HoVerNetPlus(HoVerNet): @@ -306,7 +305,7 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple: >>> pretrained = torch.load(weights_path) >>> model = HoVerNetPlus(num_types=3, num_layers=5) >>> model.load_state_dict(pretrained) - >>> output = model.infer_batch(model, batch, on_gpu=False) + >>> output = model.infer_batch(model, batch, device="cuda") >>> output = [v[0] for v in output] >>> output = model.postproc(output) @@ -325,7 +324,7 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple: return pred_inst, nuc_inst_info_dict, pred_layer, layer_info_dict @staticmethod - def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tuple: + def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> tuple: """Run inference on an input batch. This contains logic for forward operation as well as batch i/o @@ -337,13 +336,12 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tu batch_data (ndarray): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". """ patch_imgs = batch_data - device = misc.select_device(on_gpu=on_gpu) patch_imgs_gpu = patch_imgs.to(device).type(torch.float32) # to NCHW patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/architecture/mapde.py b/tiatoolbox/models/architecture/mapde.py index a7156531f..bbb468bb8 100644 --- a/tiatoolbox/models/architecture/mapde.py +++ b/tiatoolbox/models/architecture/mapde.py @@ -14,7 +14,6 @@ from skimage.feature import peak_local_max from tiatoolbox.models.architecture.micronet import MicroNet -from tiatoolbox.utils.misc import select_device class MapDe(MicroNet): @@ -259,7 +258,7 @@ def infer_batch( model: torch.nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str, ) -> list[np.ndarray]: """Run inference on an input batch. @@ -272,8 +271,8 @@ def infer_batch( batch_data (:class:`numpy.ndarray`): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: list(np.ndarray): @@ -282,7 +281,6 @@ def infer_batch( """ patch_imgs = batch_data - device = select_device(on_gpu=on_gpu) patch_imgs_gpu = patch_imgs.to(device).type(torch.float32) # to NCHW patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/architecture/micronet.py b/tiatoolbox/models/architecture/micronet.py index bfc62c8ab..6065fcd46 100644 --- a/tiatoolbox/models/architecture/micronet.py +++ b/tiatoolbox/models/architecture/micronet.py @@ -19,7 +19,6 @@ from tiatoolbox.models.architecture.hovernet import HoVerNet from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc def group1_forward_branch( @@ -625,11 +624,11 @@ def preproc(image: np.ndarray) -> np.ndarray: return np.transpose(image.numpy(), axes=(1, 2, 0)) @staticmethod - def infer_batch( + def infer_batch( # skipcq: PYL-W0221 model: torch.nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str, ) -> list[np.ndarray]: """Run inference on an input batch. @@ -642,8 +641,8 @@ def infer_batch( batch_data (:class:`torch.Tensor`): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: list(np.ndarray): @@ -652,7 +651,6 @@ def infer_batch( """ patch_imgs = batch_data - device = misc.select_device(on_gpu=on_gpu) patch_imgs_gpu = patch_imgs.to(device).type(torch.float32) # to NCHW patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/architecture/nuclick.py b/tiatoolbox/models/architecture/nuclick.py index 339777eb1..77f4ad993 100644 --- a/tiatoolbox/models/architecture/nuclick.py +++ b/tiatoolbox/models/architecture/nuclick.py @@ -22,7 +22,6 @@ from tiatoolbox import logger from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc if TYPE_CHECKING: # pragma: no cover from tiatoolbox.typing import IntPair @@ -647,7 +646,7 @@ def infer_batch( model: nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str, ) -> np.ndarray: """Run inference on an input batch. @@ -656,16 +655,16 @@ def infer_batch( Args: model (nn.Module): PyTorch defined model. - batch_data (torch.Tensor): a batch of data generated by - torch.utils.data.DataLoader. - on_gpu (bool): Whether to run inference on a GPU. + batch_data (torch.Tensor): + A batch of data generated by torch.utils.data.DataLoader. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: Pixel-wise nuclei prediction for each patch, shape: (no.patch, h, w). """ model.eval() - device = misc.select_device(on_gpu=on_gpu) # Assume batch_data is NCHW batch_data = batch_data.to(device).type(torch.float32) diff --git a/tiatoolbox/models/architecture/sccnn.py b/tiatoolbox/models/architecture/sccnn.py index 9941eabff..4da0f9dca 100644 --- a/tiatoolbox/models/architecture/sccnn.py +++ b/tiatoolbox/models/architecture/sccnn.py @@ -17,7 +17,6 @@ from torch import nn from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc class SCCNN(ModelABC): @@ -354,8 +353,7 @@ def postproc(self: SCCNN, prediction_map: np.ndarray) -> np.ndarray: def infer_batch( model: nn.Module, batch_data: np.ndarray | torch.Tensor, - *, - on_gpu: bool, + device: str, ) -> list[np.ndarray]: """Run inference on an input batch. @@ -368,8 +366,8 @@ def infer_batch( batch_data (:class:`numpy.ndarray` or :class:`torch.Tensor`): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: list of :class:`numpy.ndarray`: @@ -378,7 +376,6 @@ def infer_batch( """ patch_imgs = batch_data - device = misc.select_device(on_gpu=on_gpu) patch_imgs_gpu = patch_imgs.to(device).type(torch.float32) # to NCHW patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() diff --git a/tiatoolbox/models/architecture/unet.py b/tiatoolbox/models/architecture/unet.py index fe1a97cc9..6385e7587 100644 --- a/tiatoolbox/models/architecture/unet.py +++ b/tiatoolbox/models/architecture/unet.py @@ -12,7 +12,6 @@ from tiatoolbox.models.architecture.utils import UpSample2x, centre_crop from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils import misc class ResNetEncoder(ResNet): @@ -416,7 +415,7 @@ def infer_batch( model: nn.Module, batch_data: torch.Tensor, *, - on_gpu: bool, + device: str, ) -> list: """Run inference on an input batch. @@ -429,8 +428,8 @@ def infer_batch( batch_data (:class:`torch.Tensor`): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: list: @@ -439,7 +438,6 @@ def infer_batch( """ model.eval() - device = misc.select_device(on_gpu=on_gpu) #### imgs = batch_data diff --git a/tiatoolbox/models/architecture/utils.py b/tiatoolbox/models/architecture/utils.py index 9df4dd56f..2ec47d99d 100644 --- a/tiatoolbox/models/architecture/utils.py +++ b/tiatoolbox/models/architecture/utils.py @@ -3,7 +3,7 @@ from __future__ import annotations import sys -from typing import Callable, NoReturn +from typing import NoReturn import numpy as np import torch @@ -41,7 +41,7 @@ def compile_model( model: nn.Module | None = None, *, mode: str = "default", -) -> Callable: +) -> nn.Module: """A decorator to compile a model using torch-compile. Args: @@ -60,7 +60,7 @@ def compile_model( CUDA graphs Returns: - Callable: + torch.nn.Module: Compiled model. """ @@ -71,7 +71,7 @@ def compile_model( is_torch_compile_compatible() # This check will be removed when torch.compile is supported in Python 3.12+ - if sys.version_info >= (3, 12): # pragma: no cover + if sys.version_info > (3, 12): # pragma: no cover logger.warning( ("torch-compile is currently not supported in Python 3.12+. ",), ) diff --git a/tiatoolbox/models/architecture/vanilla.py b/tiatoolbox/models/architecture/vanilla.py index 4879ce04c..c7d3d1498 100644 --- a/tiatoolbox/models/architecture/vanilla.py +++ b/tiatoolbox/models/architecture/vanilla.py @@ -11,7 +11,6 @@ from torch import nn from tiatoolbox.models.models_abc import ModelABC -from tiatoolbox.utils.misc import select_device if TYPE_CHECKING: # pragma: no cover from torchvision.models import WeightsEnum @@ -149,9 +148,8 @@ def _postproc(image: np.ndarray) -> np.ndarray: def _infer_batch( model: nn.Module, batch_data: torch.Tensor, - *, - on_gpu: bool, -) -> np.ndarray: + device: str, +) -> dict[str, np.ndarray]: """Run inference on an input batch. Contains logic for forward operation as well as i/o aggregation. @@ -162,11 +160,11 @@ def _infer_batch( batch_data (torch.Tensor): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". """ - img_patches_device = batch_data.to(select_device(on_gpu=on_gpu)).type( + img_patches_device = batch_data.to(device=device).type( torch.float32, ) # to NCHW img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous() @@ -243,9 +241,8 @@ def postproc(image: np.ndarray) -> np.ndarray: def infer_batch( model: nn.Module, batch_data: torch.Tensor, - *, - on_gpu: bool, - ) -> np.ndarray: + device: str = "cpu", + ) -> dict[str, np.ndarray]: """Run inference on an input batch. Contains logic for forward operation as well as i/o aggregation. @@ -256,11 +253,11 @@ def infer_batch( batch_data (torch.Tensor): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". """ - return _infer_batch(model=model, batch_data=batch_data, on_gpu=on_gpu) + return _infer_batch(model=model, batch_data=batch_data, device=device) class TimmModel(ModelABC): @@ -339,9 +336,8 @@ def postproc(image: np.ndarray) -> np.ndarray: def infer_batch( model: nn.Module, batch_data: torch.Tensor, - *, - on_gpu: bool, - ) -> np.ndarray: + device: str, + ) -> dict[str, np.ndarray]: """Run inference on an input batch. Contains logic for forward operation as well as i/o aggregation. @@ -352,11 +348,11 @@ def infer_batch( batch_data (torch.Tensor): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". """ - return _infer_batch(model=model, batch_data=batch_data, on_gpu=on_gpu) + return _infer_batch(model=model, batch_data=batch_data, device=device) class CNNBackbone(ModelABC): @@ -425,9 +421,8 @@ def forward(self: CNNBackbone, imgs: torch.Tensor) -> torch.Tensor: def infer_batch( model: nn.Module, batch_data: torch.Tensor, - *, - on_gpu: bool, - ) -> list[np.ndarray]: + device: str, + ) -> list[dict[str, np.ndarray]]: """Run inference on an input batch. Contains logic for forward operation as well as i/o aggregation. @@ -438,15 +433,15 @@ def infer_batch( batch_data (torch.Tensor): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: - list[np.ndarray]: - list of numpy arrays. + list[dict[str, np.ndarray]]: + list of dictionary values with numpy arrays. """ - return [_infer_batch(model=model, batch_data=batch_data, on_gpu=on_gpu)] + return [_infer_batch(model=model, batch_data=batch_data, device=device)] class TimmBackbone(ModelABC): @@ -500,9 +495,8 @@ def forward(self: TimmBackbone, imgs: torch.Tensor) -> torch.Tensor: def infer_batch( model: nn.Module, batch_data: torch.Tensor, - *, - on_gpu: bool, - ) -> list[np.ndarray]: + device: str, + ) -> list[dict[str, np.ndarray]]: """Run inference on an input batch. Contains logic for forward operation as well as i/o aggregation. @@ -513,12 +507,12 @@ def infer_batch( batch_data (torch.Tensor): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". Returns: - list[np.ndarray]: - list of numpy arrays. + list[dict[str, np.ndarray]]: + list of dictionary values with numpy arrays. """ - return [_infer_batch(model=model, batch_data=batch_data, on_gpu=on_gpu)] + return [_infer_batch(model=model, batch_data=batch_data, device=device)] diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py index cc76b68a0..6649324b1 100644 --- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py +++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py @@ -450,7 +450,7 @@ def _get_tile_info( * ioconfig.patch_output_shape ).astype(np.int32) image_shape = np.array(image_shape) - (_, tile_outputs) = PatchExtractor.get_coordinates( + tile_outputs = PatchExtractor.get_coordinates( image_shape=image_shape, patch_input_shape=tile_shape, patch_output_shape=tile_shape, @@ -459,7 +459,7 @@ def _get_tile_info( # * === Now generating the flags to indicate which side should # * === be removed in postproc callback - boxes = tile_outputs + boxes = tile_outputs[1] # This saves computation time if the image is smaller than the expected tile if np.all(image_shape <= tile_shape): @@ -485,7 +485,7 @@ def unset_removal_flag(boxes: tuple, removal_flag: np.ndarray) -> np.ndarray: return removal_flag w, h = image_shape - boxes = tile_outputs + boxes = tile_outputs[1] # expand to full four corners boxes_br = boxes[:, 2:] boxes_tr = np.dstack([boxes[:, 2], boxes[:, 1]])[0] @@ -646,7 +646,7 @@ def _infer_once(self: NucleusInstanceSegmentor) -> list: sample_outputs = self.model.infer_batch( self._model, sample_datas, - on_gpu=self._on_gpu, + device=self._device, ) # repackage so that it's a N list, each contains # L x etc. output diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index da4420cb0..9989c313e 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -16,7 +16,8 @@ from tiatoolbox.models.architecture.utils import compile_model from tiatoolbox.models.dataset.classification import PatchDataset, WSIPatchDataset from tiatoolbox.models.engine.semantic_segmentor import IOSegmentorConfig -from tiatoolbox.utils import misc, save_as_json +from tiatoolbox.models.models_abc import model_to +from tiatoolbox.utils import save_as_json from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader if TYPE_CHECKING: # pragma: no cover @@ -383,11 +384,11 @@ def merge_predictions( def _predict_engine( self: PatchPredictor, dataset: torch.utils.data.Dataset, + device: str = "cpu", *, return_probabilities: bool = False, return_labels: bool = False, return_coordinates: bool = False, - on_gpu: bool = True, ) -> np.ndarray: """Make a prediction on a dataset. The dataset may be mutated. @@ -401,8 +402,11 @@ def _predict_engine( Whether to return labels. return_coordinates (bool): Whether to return patch coordinates. - on_gpu (bool): - Whether to run model on the GPU. + device (str): + :class:`torch.device` to run the model. + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. Default value is "cpu". Returns: :class:`numpy.ndarray`: @@ -430,7 +434,7 @@ def _predict_engine( ) # use external for testing - model = misc.model_to(model=self.model, on_gpu=on_gpu) + model = model_to(model=self.model, device=device) cum_output = { "probabilities": [], @@ -442,7 +446,7 @@ def _predict_engine( batch_output_probabilities = self.model.infer_batch( model, batch_data["image"], - on_gpu=on_gpu, + device=device, ) # We get the index of the class with the maximum probability batch_output_predictions = self.model.postproc_func( @@ -587,10 +591,10 @@ def _predict_patch( self: PatchPredictor, imgs: list | np.ndarray, labels: list, + device: str = "cpu", *, return_probabilities: bool, return_labels: bool, - on_gpu: bool, ) -> np.ndarray: """Process patch mode. @@ -609,8 +613,11 @@ def _predict_patch( Whether to return per-class probabilities. return_labels (bool): Whether to return the labels with the predictions. - on_gpu (bool): - Whether to run model on the GPU. + device (str): + :class:`torch.device` to run the model. + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. Default value is "cpu". Returns: :class:`numpy.ndarray`: @@ -635,7 +642,7 @@ def _predict_patch( return_probabilities=return_probabilities, return_labels=return_labels, return_coordinates=return_coordinates, - on_gpu=on_gpu, + device=device, ) def _predict_tile_wsi( # noqa: PLR0913 @@ -647,11 +654,11 @@ def _predict_tile_wsi( # noqa: PLR0913 ioconfig: IOPatchPredictorConfig, save_dir: str | Path, highest_input_resolution: list[dict], + device: str = "cpu", *, save_output: bool, return_probabilities: bool, merge_predictions: bool, - on_gpu: bool, ) -> list | dict: """Predict on Tile and WSIs. @@ -678,8 +685,11 @@ def _predict_tile_wsi( # noqa: PLR0913 `tile` or `wsi`. return_probabilities (bool): Whether to return per-class probabilities. - on_gpu (bool): - Whether to run model on the GPU. + device (str): + :class:`torch.device` to run the model. + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. Default value is "cpu". ioconfig (IOPatchPredictorConfig): Patch Predictor IO configuration.. merge_predictions (bool): @@ -747,7 +757,7 @@ def _predict_tile_wsi( # noqa: PLR0913 return_labels=False, return_probabilities=return_probabilities, return_coordinates=return_coordinates, - on_gpu=on_gpu, + device=device, ) output_model["label"] = img_label # add extra information useful for downstream analysis @@ -795,10 +805,10 @@ def predict( # noqa: PLR0913 stride_shape: tuple[int, int] | None = None, resolution: Resolution | None = None, units: Units = None, + device: str = "cpu", *, return_probabilities: bool = False, return_labels: bool = False, - on_gpu: bool = True, merge_predictions: bool = False, save_dir: str | Path | None = None, save_output: bool = False, @@ -830,8 +840,11 @@ def predict( # noqa: PLR0913 Whether to return per-class probabilities. return_labels (bool): Whether to return the labels with the predictions. - on_gpu (bool): - Whether to run model on the GPU. + device (str): + :class:`torch.device` to run the model. + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. Default value is "cpu". ioconfig (IOPatchPredictorConfig): Patch Predictor IO configuration. patch_input_shape (tuple): @@ -901,7 +914,7 @@ def predict( # noqa: PLR0913 labels, return_probabilities=return_probabilities, return_labels=return_labels, - on_gpu=on_gpu, + device=device, ) if not isinstance(imgs, list): @@ -948,7 +961,7 @@ def predict( # noqa: PLR0913 labels=labels, mode=mode, return_probabilities=return_probabilities, - on_gpu=on_gpu, + device=device, ioconfig=ioconfig, merge_predictions=merge_predictions, save_dir=save_dir, diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index 271d49150..029deebf9 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -20,9 +20,9 @@ from tiatoolbox import logger, rcParam from tiatoolbox.models.architecture import get_pretrained_model from tiatoolbox.models.architecture.utils import compile_model -from tiatoolbox.models.models_abc import IOConfigABC +from tiatoolbox.models.models_abc import IOConfigABC, model_to from tiatoolbox.tools.patchextraction import PatchExtractor -from tiatoolbox.utils import imread, misc +from tiatoolbox.utils import imread from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIMeta, WSIReader if TYPE_CHECKING: # pragma: no cover @@ -554,7 +554,7 @@ def __init__( self._cache_dir = None self._loader = None self._model = None - self._on_gpu = None + self._device = None self._mp_shared_space = None self._postproc_workers = None self.num_postproc_workers = num_postproc_workers @@ -818,7 +818,7 @@ def _predict_one_wsi( sample_outputs = self.model.infer_batch( self._model, sample_datas, - on_gpu=self._on_gpu, + device=self._device, ) # repackage so that it's an N list, each contains # L x etc. output @@ -1168,7 +1168,7 @@ def _memory_cleanup(self: SemanticSegmentor) -> None: self._cache_dir = None self._model = None self._loader = None - self._on_gpu = None + self._device = None self._futures = None self._mp_shared_space = None if self._postproc_workers is not None: @@ -1266,8 +1266,8 @@ def predict( # noqa: PLR0913 resolution: Resolution = 1.0, units: Units = "baseline", save_dir: str | Path | None = None, + device: str = "cpu", *, - on_gpu: bool = True, crash_on_exception: bool = False, ) -> list[tuple[Path, Path]]: """Make a prediction for a list of input data. @@ -1305,8 +1305,11 @@ def predict( # noqa: PLR0913 `stride_shape`, `resolution`, and `units` arguments are ignored. Otherwise, those arguments will be internally converted to a :class:`IOSegmentorConfig` object. - on_gpu (bool): - Whether to run the model on the GPU. + device (str): + :class:`torch.device` to run the model. + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. Default value is "cpu". patch_input_shape (tuple): Size of patches input to the model. The values are at requested read resolution and must be positive. @@ -1366,8 +1369,8 @@ def predict( # noqa: PLR0913 ) # use external for testing - self._on_gpu = on_gpu - self._model = misc.model_to(model=self.model, on_gpu=on_gpu) + self._device = device + self._model = model_to(model=self.model, device=device) # workers should be > 0 else Value Error will be thrown self._prepare_workers() @@ -1566,8 +1569,8 @@ def predict( # noqa: PLR0913 resolution: Resolution = 1.0, units: Units = "baseline", save_dir: str | Path | None = None, + device: str = "cpu", *, - on_gpu: bool = True, crash_on_exception: bool = False, ) -> list[tuple[Path, Path]]: """Make a prediction for a list of input data. @@ -1605,8 +1608,11 @@ def predict( # noqa: PLR0913 `stride_shape`, `resolution`, and `units` arguments are ignored. Otherwise, those arguments will be internally converted to a :class:`IOSegmentorConfig` object. - on_gpu (bool): - Whether to run the model on the GPU. + device (str): + :class:`torch.device` to run the model. + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. Default value is "cpu". patch_input_shape (IntPair): Size of patches input to the model. The values are at requested read resolution and must be positive. @@ -1662,7 +1668,7 @@ def predict( # noqa: PLR0913 imgs=imgs, masks=masks, mode=mode, - on_gpu=on_gpu, + device=device, ioconfig=ioconfig, patch_input_shape=patch_input_shape, patch_output_shape=patch_output_shape, diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index e16540c87..a3af4e7f0 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -39,6 +39,28 @@ def output_resolutions(self: IOConfigABC) -> None: raise NotImplementedError +def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module: + """Transfers model to specified device e.g., "cpu" or "cuda". + + Args: + model (torch.nn.Module): + PyTorch defined model. + device (str): + Transfers model to the specified device. Default is "cpu". + + Returns: + torch.nn.Module: + The model after being moved to specified device. + + """ + if device != "cpu": + # DataParallel work only for cuda + model = torch.nn.DataParallel(model) + + device = torch.device(device) + return model.to(device) + + class ModelABC(ABC, torch.nn.Module): """Abstract base class for models used in tiatoolbox.""" @@ -59,8 +81,7 @@ def forward(self: ModelABC, *args: tuple[Any, ...], **kwargs: dict) -> None: def infer_batch( model: torch.nn.Module, batch_data: np.ndarray, - *, - on_gpu: bool, + device: str, ) -> None: """Run inference on an input batch. @@ -72,8 +93,13 @@ def infer_batch( batch_data (np.ndarray): A batch of data generated by `torch.utils.data.DataLoader`. - on_gpu (bool): - Whether to run inference on a GPU. + device (str): + Transfers model to the specified device. Default is "cpu". + + Returns: + dict: + Returns a dictionary of predictions and other expected outputs + depending on the network architecture. """ ... # pragma: no cover @@ -106,7 +132,7 @@ def preproc_func(self: ModelABC, func: Callable) -> None: >>> # `func` is a user defined function >>> model = ModelABC() >>> model.preproc_func = func - >>> transformed_img = model.preproc_func(img) + >>> transformed_img = model.preproc_func(image=np.ndarray) """ if func is not None and not callable(func): @@ -137,7 +163,7 @@ def postproc_func(self: ModelABC, func: Callable) -> None: >>> # `func` is a user defined function >>> model = ModelABC() >>> model.postproc_func = func - >>> transformed_img = model.postproc_func(img) + >>> transformed_img = model.postproc_func(image=np.ndarray) """ if func is not None and not callable(func): diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 7239f0a8c..7c1c349e7 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -16,7 +16,6 @@ import numpy as np import pandas as pd import requests -import torch import yaml import zarr from filelock import FileLock @@ -878,24 +877,6 @@ def select_device(*, on_gpu: bool) -> str: return "cpu" -def model_to(model: torch.nn.Module, *, on_gpu: bool) -> torch.nn.Module: - """Transfers model to cpu/gpu. - - Args: - model (torch.nn.Module): PyTorch defined model. - on_gpu (bool): Transfers model to gpu if True otherwise to cpu. - - Returns: - torch.nn.Module: - The model after being moved to cpu/gpu. - """ - if on_gpu: # DataParallel work only for cuda - model = torch.nn.DataParallel(model) - return model.to("cuda") - - return model.to("cpu") - - def get_bounding_box(img: np.ndarray) -> np.ndarray: """Get bounding box coordinate information. diff --git a/tiatoolbox/utils/visualization.py b/tiatoolbox/utils/visualization.py index 142b6e061..817485711 100644 --- a/tiatoolbox/utils/visualization.py +++ b/tiatoolbox/utils/visualization.py @@ -559,6 +559,13 @@ def to_int_tuple(x: tuple[int, ...] | np.ndarray) -> tuple[int, ...]: return canvas +def _find_minimum_mpp_sf(mpp: tuple[float, float] | None) -> float: + """Calculates minimum mpp scale factor.""" + if mpp is not None: + return np.minimum(mpp[0] / 0.25, 1) + return 1.0 + + class AnnotationRenderer: """Renders AnnotationStore to a tile. @@ -971,9 +978,7 @@ def render_annotations( int((bounds[2] - bounds[0]) / scale), ] - mpp_sf = 1 - if self.info["mpp"] is not None: - mpp_sf = np.minimum(self.info["mpp"][0] / 0.25, 1) + mpp_sf = _find_minimum_mpp_sf(self.info["mpp"]) min_area = 0.0005 * (output_size[0] * output_size[1]) * (scale * mpp_sf) ** 2 diff --git a/tiatoolbox/visualization/bokeh_app/main.py b/tiatoolbox/visualization/bokeh_app/main.py index 6f9aff33d..4e4195558 100644 --- a/tiatoolbox/visualization/bokeh_app/main.py +++ b/tiatoolbox/visualization/bokeh_app/main.py @@ -69,6 +69,7 @@ NucleusInstanceSegmentor, ) from tiatoolbox.tools.pyramid import ZoomifyGenerator +from tiatoolbox.utils.misc import select_device from tiatoolbox.utils.visualization import random_colors from tiatoolbox.visualization.ui_utils import get_level_by_extent from tiatoolbox.wsicore.wsireader import WSIReader @@ -1237,7 +1238,7 @@ def segment_on_box() -> None: [tmp_mask_dir / "mask.png"], save_dir=tmp_save_dir / "hover_out", mode="wsi", - on_gpu=torch.cuda.is_available(), + device=select_device(on_gpu=torch.cuda.is_available()), crash_on_exception=True, ) diff --git a/whitelist.txt b/whitelist.txt index 07a1b13c3..d1e723f26 100644 --- a/whitelist.txt +++ b/whitelist.txt @@ -96,6 +96,7 @@ coord coords csv cuda +customizable cv2 dataframe dataset