diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py index bed6365d4a2..c5b92b6b01a 100644 --- a/torchgeo/datamodules/agrifieldnet.py +++ b/torchgeo/datamodules/agrifieldnet.py @@ -74,7 +74,11 @@ def setup(self, stage: str) -> None: if stage in ['fit']: self.train_batch_sampler = RandomBatchGeoSampler( - self.train_dataset, self.patch_size, self.batch_size, self.length + self.train_dataset, + self.patch_size, + self.batch_size, + self.length, + generator=generator, ) if stage in ['fit', 'validate']: self.val_sampler = GridGeoSampler( diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index 22726f74b2c..396ad0f0c7b 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -70,6 +70,7 @@ def __init__( length: int | None = None, roi: BoundingBox | None = None, units: Units = Units.PIXELS, + generator: torch.Generator | None = None, ) -> None: """Initialize a new Sampler instance. @@ -97,9 +98,11 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) units: defines if ``size`` is in pixel or CRS units + generator: random number generator """ super().__init__(dataset, roi) self.size = _to_tuple(size) + self.generator = generator if units == Units.PIXELS: self.size = (self.size[0] * self.res, self.size[1] * self.res) @@ -144,7 +147,9 @@ def __iter__(self) -> Iterator[list[BoundingBox]]: # Choose random indices within that tile batch = [] for _ in range(self.batch_size): - bounding_box = get_random_bounding_box(bounds, self.size, self.res) + bounding_box = get_random_bounding_box( + bounds, self.size, self.res, self.generator + ) batch.append(bounding_box) yield batch diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 094142cb647..ea943db3d53 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -5,6 +5,7 @@ import abc from collections.abc import Callable, Iterable, Iterator +from functools import partial import torch from rtree.index import Index, Property @@ -72,6 +73,7 @@ def __init__( length: int | None = None, roi: BoundingBox | None = None, units: Units = Units.PIXELS, + generator: torch.Generator | None = None, ) -> None: """Initialize a new Sampler instance. @@ -98,6 +100,8 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) units: defines if ``size`` is in pixel or CRS units + generator: The random generator used for sampling. + """ super().__init__(dataset, roi) self.size = _to_tuple(size) @@ -105,6 +109,7 @@ def __init__( if units == Units.PIXELS: self.size = (self.size[0] * self.res, self.size[1] * self.res) + self.generator = generator self.length = 0 self.hits = [] areas = [] @@ -142,7 +147,9 @@ def __iter__(self) -> Iterator[BoundingBox]: bounds = BoundingBox(*hit.bounds) # Choose a random index within that tile - bounding_box = get_random_bounding_box(bounds, self.size, self.res) + bounding_box = get_random_bounding_box( + bounds, self.size, self.res, self.generator + ) yield bounding_box @@ -270,7 +277,11 @@ class PreChippedGeoSampler(GeoSampler): """ def __init__( - self, dataset: GeoDataset, roi: BoundingBox | None = None, shuffle: bool = False + self, + dataset: GeoDataset, + roi: BoundingBox | None = None, + shuffle: bool = False, + generator: torch.Generator | None = None, ) -> None: """Initialize a new Sampler instance. @@ -281,9 +292,12 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) shuffle: if True, reshuffle data at every epoch + generator: The random number generator used in combination with shuffle. + """ super().__init__(dataset, roi) self.shuffle = shuffle + self.generator = generator self.hits = [] for hit in self.index.intersection(tuple(self.roi), objects=True): @@ -297,7 +311,7 @@ def __iter__(self) -> Iterator[BoundingBox]: """ generator: Callable[[int], Iterable[int]] = range if self.shuffle: - generator = torch.randperm + generator = partial(torch.randperm, generator=self.generator) for idx in generator(len(self)): yield BoundingBox(*self.hits[idx].bounds) diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index a1fca673a3a..258f74a5425 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -35,7 +35,10 @@ def _to_tuple(value: tuple[float, float] | float) -> tuple[float, float]: def get_random_bounding_box( - bounds: BoundingBox, size: tuple[float, float] | float, res: float + bounds: BoundingBox, + size: tuple[float, float] | float, + res: float, + generator: torch.Generator | None = None, ) -> BoundingBox: """Returns a random bounding box within a given bounding box. @@ -50,6 +53,7 @@ def get_random_bounding_box( bounds: the larger bounding box to sample from size: the size of the bounding box to sample res: the resolution of the image + generator: random number generator Returns: randomly sampled bounding box from the extent of the input @@ -64,8 +68,8 @@ def get_random_bounding_box( miny = bounds.miny # Use an integer multiple of res to avoid resampling - minx += int(torch.rand(1).item() * width) * res - miny += int(torch.rand(1).item() * height) * res + minx += int(torch.rand(1, generator=generator).item() * width) * res + miny += int(torch.rand(1, generator=generator).item() * height) * res maxx = minx + t_size[1] maxy = miny + t_size[0]