Skip to content

Commit

Permalink
add an out of memory dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
bw4sz committed Oct 31, 2024
1 parent 099a1d8 commit fe0a420
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 12 deletions.
76 changes: 75 additions & 1 deletion deepforest/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from rasterio.windows import Window
from torchvision import transforms


def get_transform(augment):
"""Albumentations transformation of bounding boxs."""
if augment:
Expand Down Expand Up @@ -153,6 +152,7 @@ def __init__(self,
tile: an in memory numpy array.
patch_size (int): The size for the crops used to cut the input raster into smaller pieces. This is given in pixels, not any geographic unit.
patch_overlap (float): The horizontal and vertical overlap among patches
preload_images (bool): If true, the entire dataset is loaded into memory. This is useful for small datasets, but not recommended for large datasets since both the tile and the crops are stored in memory.
Returns:
ds: a pytorch dataset
Expand Down Expand Up @@ -248,3 +248,77 @@ def __getitem__(self, idx):
image = box

return image


class RasterDataset:
    """Dataset for predicting on windows read lazily from a raster on disk.

    Only one window is read from the raster per ``__getitem__`` call, so the
    whole tile never has to be held in memory. Windows are laid out row by
    row on a regular grid; the last window of every row/column is shifted
    back so that it ends flush with the raster edge, which guarantees the
    whole raster is covered.

    Args:
        raster_path (str): Path to raster file
        patch_size (int): Side length of the square windows, in pixels
        patch_overlap (float): Overlap between neighbouring windows as a
            fraction of patch_size, 0 <= patch_overlap < 1

    Raises:
        ValueError: If patch_size is not positive, or patch_overlap is
            outside [0, 1) or leaves a step of less than one pixel.
    """

    def __init__(self, raster_path, patch_size=1024, patch_overlap=0.1):
        # Validate before opening the raster so a bad argument never leaves
        # an open file handle behind.
        if patch_size < 1:
            raise ValueError(
                "patch_size must be a positive number of pixels, got {}".format(patch_size))
        if not 0 <= patch_overlap < 1:
            raise ValueError(
                "patch_overlap must be in the range [0, 1), got {}".format(patch_overlap))

        # Distance in pixels between the origins of two neighbouring windows
        step_size = int(patch_size * (1 - patch_overlap))
        if step_size < 1:
            raise ValueError(
                "patch_overlap {} is too large for patch_size {}: windows would not advance".format(
                    patch_overlap, patch_size))

        self.raster = rio.open(raster_path)
        self.patch_size = patch_size
        self.patch_overlap = patch_overlap
        self.step_size = step_size

        # Number of windows needed to cover each dimension
        self.n_windows_height = self._count_windows(self.raster.height, patch_size, step_size)
        self.n_windows_width = self._count_windows(self.raster.width, patch_size, step_size)

        # Total number of windows
        self.n_windows = self.n_windows_height * self.n_windows_width

    @staticmethod
    def _count_windows(length, patch_size, step_size):
        """Return how many windows are needed to cover `length` pixels."""
        if length <= patch_size:
            return 1
        # Ceiling division: a trailing partial step still needs one more
        # window, otherwise the far edge of the raster is never read.
        return -(-(length - patch_size) // step_size) + 1

    @staticmethod
    def _window_offset(index, length, patch_size, step_size):
        """Return the start pixel of window `index`, kept inside [0, length - patch_size]."""
        # min() pulls the last window back flush with the raster edge;
        # max() avoids a negative offset when the raster is smaller than a patch.
        return max(0, min(index * step_size, length - patch_size))

    def __len__(self):
        return self.n_windows

    def __getitem__(self, idx):
        """Get a window of the raster
        Args:
            idx (int): Index of window to get, counted row by row. Negative
                indices count from the end.
        Returns:
            crop: the window after ``preprocess.preprocess_image``. If the raster
                is smaller than patch_size in a dimension, the window is limited
                to the raster extent in that dimension.
        Raises:
            IndexError: If idx is outside the dataset.
        """
        if idx < 0:
            idx += self.n_windows
        # Raising IndexError is what lets a plain `for crop in ds` terminate
        if not 0 <= idx < self.n_windows:
            raise IndexError("window index out of range")

        # Position of the window in the grid
        row_idx, col_idx = divmod(idx, self.n_windows_width)

        row_off = self._window_offset(row_idx, self.raster.height, self.patch_size,
                                      self.step_size)
        col_off = self._window_offset(col_idx, self.raster.width, self.patch_size,
                                      self.step_size)

        # Never request more pixels than the raster has
        height = min(self.patch_size, self.raster.height)
        width = min(self.patch_size, self.raster.width)

        # Read only this window from disk
        window = self.raster.read(window=Window(col_off, row_off, width, height))

        # The raster is read bands-first; move the band axis last before preprocessing
        window = np.moveaxis(window, 0, -1)
        crop = preprocess.preprocess_image(window)

        return crop

    def close(self):
        """Close the raster dataset"""
        self.raster.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
43 changes: 34 additions & 9 deletions deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,12 +461,16 @@ def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1

return results

import memory_profiler

@memory_profiler.profile
def predict_tile(self,
raster_path=None,
image=None,
patch_size=400,
patch_overlap=0.05,
iou_threshold=0.15,
in_memory=True,
return_plot=False,
mosaic=True,
sigma=0.5,
Expand All @@ -489,6 +493,7 @@ def predict_tile(self,
iou_threshold: Minimum iou overlap among predictions between
windows to be suppressed.
Lower values suppress more boxes at edges.
in_memory: If true, the entire dataset is loaded into memory. This is useful for small datasets, but not recommended for large datasets since both the tile and the crops are stored in memory.
mosaic: Return a single prediction dataframe (True) or a tuple of image crops and predictions (False)
sigma: variance of Gaussian function used in Gaussian Soft NMS
thresh: the score thresh used to filter bboxes after soft-nms performed
Expand Down Expand Up @@ -529,15 +534,29 @@ def predict_tile(self,
"Both tile and tile_path are None. Either supply a path to a tile on disk, or read one into memory!"
)

if raster_path is None:
self.image = image
else:
self.image = rio.open(raster_path).read()
self.image = np.moveaxis(self.image, 0, 2)

ds = dataset.TileDataset(tile=self.image,
if in_memory:
if raster_path is None:
image = image
else:
image = rio.open(raster_path).read()
image = np.moveaxis(image, 0, 2)

ds = dataset.TileDataset(tile=image,
patch_overlap=patch_overlap,
patch_size=patch_size)
else:
if raster_path is None:
raise ValueError("raster_path is required if in_memory is False")

# Check for workers config when using out of memory dataset
if self.config["workers"] > 0:
raise ValueError("workers must be 0 when using out-of-memory dataset (in_memory=False). Set config['workers']=0 and recreate trainer self.create_trainer().")

ds = dataset.RasterDataset(raster_path=raster_path,
patch_overlap=patch_overlap,
patch_size=patch_size)

batched_results = self.trainer.predict(self, self.predict_dataloader(ds))

# Flatten list from batched prediction
Expand All @@ -564,7 +583,7 @@ def predict_tile(self,
if raster_path:
tile = rio.open(raster_path).read()
else:
tile = self.image
tile = image
drawn_plot = tile[:, :, ::-1]
drawn_plot = visualize.plot_predictions(tile,
results,
Expand All @@ -575,10 +594,16 @@ def predict_tile(self,
for df in results:
df["label"] = df.label.apply(lambda x: self.numeric_to_label_dict[x])

# TODO this is the 2nd time the crops are generated? Could be more efficient.
# TODO this is the 2nd time the crops are generated? Could be more efficient, but memory intensive
self.crops = []
if raster_path is None:
image = image
else:
image = rio.open(raster_path).read()
image = np.moveaxis(image, 0, 2)

for window in ds.windows:
crop = self.image[window.indices()]
crop = image[window.indices()]
self.crops.append(crop)

return list(zip(results, self.crops))
Expand Down
2 changes: 1 addition & 1 deletion deepforest_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Cpu workers for data loaders
# Dataloaders
workers: 1
workers: 0
devices: auto
accelerator: auto
batch_size: 1
Expand Down
67 changes: 66 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,12 +269,14 @@ def test_predict_tile_empty(raster_path):
predictions = m.predict_tile(raster_path=raster_path, patch_size=300, patch_overlap=0)
assert predictions is None

def test_predict_tile(m, raster_path):
@pytest.mark.parametrize("in_memory", [True, False])
def test_predict_tile(m, raster_path, in_memory):
m.create_model()
m.config["train"]["fast_dev_run"] = False
m.create_trainer()
prediction = m.predict_tile(raster_path=raster_path,
patch_size=300,
in_memory=in_memory,
patch_overlap=0.1)

assert isinstance(prediction, pd.DataFrame)
Expand All @@ -283,6 +285,11 @@ def test_predict_tile(m, raster_path):
}
assert not prediction.empty

# test equivalence for in_memory=True and False
def test_predict_tile_equivalence(m, raster_path):
    """predict_tile should return the same dataframe with in_memory=True and in_memory=False."""
    # Everything except the in_memory flag is shared between the two runs
    shared_kwargs = dict(raster_path=raster_path, patch_size=300, patch_overlap=0)

    from_memory = m.predict_tile(in_memory=True, **shared_kwargs)
    from_disk = m.predict_tile(in_memory=False, **shared_kwargs)

    assert from_memory.equals(from_disk)

@pytest.mark.parametrize("patch_overlap", [0.1, 0])
def test_predict_tile_from_array(m, patch_overlap, raster_path):
Expand Down Expand Up @@ -635,3 +642,61 @@ def test_predict_tile_with_crop_model(m, config):
"xmin", "ymin", "xmax", "ymax", "label", "score", "cropmodel_label", "geometry",
"cropmodel_score", "image_path"
}


def test_predict_tile_memory(m, raster_path):
    """Run predict_tile under memory_profiler and check it still returns predictions."""
    from memory_profiler import profile

    def predict_tile_wrapper():
        # The model and trainer are built inside the profiled call so their
        # allocations are part of the report
        m.create_model()
        m.config["train"]["fast_dev_run"] = False
        m.create_trainer()

        # Predict with standard parameters
        return m.predict_tile(raster_path=raster_path, patch_size=300, patch_overlap=0.1)

    # Equivalent to decorating the wrapper with @profile
    profiled_wrapper = profile(predict_tile_wrapper)
    result = profiled_wrapper()

    # The profiled call must still produce a usable prediction
    assert isinstance(result, pd.DataFrame)
    assert not result.empty


def test_raster_dataset(m, raster_path):
    """Test the RasterDataset class"""
    from deepforest.dataset import RasterDataset

    # RasterDataset takes patch_size / patch_overlap, not window_size / overlap
    patch_size = 256
    patch_overlap = 0.1

    # NOTE(review): the channel position after preprocess.preprocess_image is
    # not visible from this test, so either ordering is accepted - tighten
    # once confirmed. Assumes the test raster has 3 bands.
    expected_shapes = {(patch_size, patch_size, 3), (3, patch_size, patch_size)}

    with RasterDataset(raster_path, patch_size=patch_size,
                       patch_overlap=patch_overlap) as ds:
        # Test length
        assert len(ds) > 0

        # __getitem__ returns the preprocessed crop only, no window bounds
        crop = ds[0]

        # Check crop shape, type and value range
        assert tuple(crop.shape) in expected_shapes
        # String comparison works for both numpy and torch float32 dtypes
        assert "float32" in str(crop.dtype)
        assert 0 <= crop.min() <= crop.max() <= 1.0

        # Test getting last item
        last_crop = ds[len(ds) - 1]
        assert tuple(last_crop.shape) in expected_shapes

0 comments on commit fe0a420

Please sign in to comment.