Skip to content

Commit

Permalink
Fix bug in ChipClassificationSource.__getitem__() when bbox is specified. (#2193)
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH authored Jul 10, 2024
1 parent 116e788 commit 9f7237d
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -173,54 +173,80 @@ def infer_cells(self, cells: Iterable[Box] | None = None
) -> ChipClassificationLabels:
"""Infer labels for a list of cells.
Only cells whose labels are not already known are inferred.
Cells are assumed to be in ``bbox`` coords as opposed to global coords
and are converted to global coords before inference. The returned
labels are in global coords. Only cells whose
labels are not already known are inferred.
Args:
cells: Cells whose labels are to be inferred. Defaults to ``None``.
cells: Cells (in ``bbox`` coords) whose labels are to be inferred.
If ``None``, cells are assumed to be sliding windows of size
and stride ``cell_sz`` (specified in
:class:`.ChipClassificationLabelSourceConfig`).
Defaults to ``None``.
Returns:
ChipClassificationLabels: labels
Labels (in global coords).
"""
cfg = self.cfg
if cells is None:
if cfg.cell_sz is None:
cell_sz = self.cfg.cell_sz
if cell_sz is None:
raise ValueError('cell_sz is not set.')
cells = self.extent.get_windows(cfg.cell_sz, cfg.cell_sz)
else:
cells = [cell.to_global_coords(self.bbox) for cell in cells]
cells = self.extent.get_windows(cell_sz, cell_sz)
cells = [cell.to_global_coords(self.bbox) for cell in cells]
labels = self._infer_cells(cells)
return labels

def _infer_cells(self, cells: Iterable[Box]) -> ChipClassificationLabels:
    """Infer labels for a list of cells.

    Cells are assumed to be in global coords as opposed to ``bbox`` coords.
    Only cells whose labels are not already known are inferred; labels of
    cells already present in ``self.labels`` are copied over unchanged.

    Args:
        cells: Cells (in global coords) whose labels are to be inferred.
            May be any iterable, including a one-shot iterator.

    Returns:
        Labels (in global coords).
    """
    cfg = self.cfg

    # Partition in a single pass: ``cells`` is only promised to be an
    # Iterable, so iterating it twice would see nothing the second time
    # if a generator is passed in.
    known_cells = []
    unknown_cells = []
    for cell in cells:
        if cell in self.labels:
            known_cells.append(cell)
        else:
            unknown_cells.append(cell)

    labels = infer_cells(
        cells=unknown_cells,
        labels_df=self.labels_df,
        ioa_thresh=cfg.ioa_thresh,
        use_intersection_over_cell=cfg.use_intersection_over_cell,
        pick_min_class_id=cfg.pick_min_class_id,
        background_class_id=cfg.background_class_id)

    # Carry over the class IDs that were already known.
    for cell in known_cells:
        class_id = self.labels.get_cell_class_id(cell)
        labels.set_cell(cell, class_id)

    return labels

def get_labels(self,
window: Box | None = None) -> ChipClassificationLabels:
"""Return label for a window, inferring it if not already known.
If window is ``None``, returns all labels.
"""
if window is None:
return self.labels
window = window.to_global_coords(self.bbox)
return self.labels.get_singleton_labels(window)
if window not in self.labels:
self.labels += self._infer_cells(cells=[window])
labels = self.labels.get_singleton_labels(window)
return labels

def __getitem__(self, key: Any) -> int:
"""Return label for a window, inferring it if it is not already known.
"""
"""Return class ID for a window, inferring it if not already known."""
if isinstance(key, Box):
window = key
window = window.to_global_coords(self.bbox)
if window not in self.labels:
self.labels += self.infer_cells(cells=[window])
return self.labels[window].class_id
self.labels += self._infer_cells(cells=[window])
class_id = self.labels[window].class_id
return class_id
else:
return super().__getitem__(key)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from collections.abc import Callable
import unittest
import os
from os.path import join

import geopandas as gpd
import numpy as np

from rastervision.pipeline.file_system import json_to_file, get_tmp_dir
from rastervision.core.box import Box
from rastervision.core.data import (
ClassConfig, ChipClassificationLabelSourceConfig,
GeoJSONVectorSourceConfig, ClassInferenceTransformerConfig,
BufferTransformerConfig)
BufferTransformerConfig, ChipClassificationLabelSource,
ChipClassificationLabelSourceConfig, ClassConfig,
ClassInferenceTransformerConfig, GeoJSONVectorSource,
GeoJSONVectorSourceConfig, IdentityCRSTransformer)
from rastervision.core.data.label_source.chip_classification_label_source \
import infer_cells
from rastervision.core.data.label_store.utils import boxes_to_geojson

from tests import data_file_path
from tests.core.data.mock_crs_transformer import DoubleCRSTransformer
Expand All @@ -30,6 +34,12 @@ def test_ensure_required_transformers(self):


class TestChipClassificationLabelSource(unittest.TestCase):
def assertNoError(self, fn: Callable, msg: str = ''):
    """Fail the test if calling ``fn`` raises an exception.

    Args:
        fn: Zero-argument callable to invoke.
        msg: Optional text to prefix the failure message with.
            Defaults to ``''``.
    """
    try:
        fn()
    except Exception as exc:
        # Include the triggering exception in the failure message so that
        # a failure is diagnosable even when no ``msg`` is supplied.
        detail = f'{type(exc).__name__}: {exc}'
        self.fail(f'{msg}: {detail}' if msg else detail)

def setUp(self):
self.crs_transformer = DoubleCRSTransformer()
self.geojson = {
Expand Down Expand Up @@ -76,7 +86,7 @@ def setUp(self):

self.file_name = 'labels.json'
self.tmp_dir = get_tmp_dir()
self.uri = os.path.join(self.tmp_dir.name, self.file_name)
self.uri = join(self.tmp_dir.name, self.file_name)
json_to_file(self.geojson, self.uri)

def tearDown(self):
Expand Down Expand Up @@ -292,17 +302,37 @@ def test_get_labels(self):
def test_getitem(self):
# Extent contains both boxes.
extent = Box.make_square(0, 0, 8)

config = ChipClassificationLabelSourceConfig(
vector_source=GeoJSONVectorSourceConfig(uris=self.uri))
source = config.build(self.class_config, self.crs_transformer, extent,
self.tmp_dir.name)
labels = source.get_labels()

label_source = config.build(self.class_config, self.crs_transformer,
extent, self.tmp_dir.name)
labels = label_source.get_labels()
cells = labels.get_cells()
self.assertEqual(len(cells), 2)
self.assertEqual(source[cells[0]], self.class_id1)
self.assertEqual(source[cells[1]], self.class_id2)
self.assertEqual(label_source[cells[0]], self.class_id1)
self.assertEqual(label_source[cells[1]], self.class_id2)

def test_getitem_and_get_labels_with_bbox(self):
    """Check ``__getitem__`` and ``get_labels`` on a source with a bbox.

    Builds labels for 10x10 windows over a 100x100 extent, then creates a
    lazy label source restricted to ``Box(25, 25, 50, 50)``. Indexing with
    a slice key must not raise, and a window passed to ``get_labels`` must
    come back as a cell offset by the bbox origin.
    """
    crs_tf = IdentityCRSTransformer()
    class_config = ClassConfig(names=['a', 'b', 'c'], null_class='c')

    # One random class ID per window; the assertions below do not depend
    # on the specific IDs drawn.
    full_extent = Box(0, 0, 100, 100)
    boxes = full_extent.get_windows(10, 10)
    num_classes = len(class_config)
    class_ids = np.random.randint(0, num_classes, size=len(boxes)).tolist()
    geojson = boxes_to_geojson(boxes, class_ids, crs_tf, class_config)

    ls_cfg = ChipClassificationLabelSourceConfig(
        background_class_id=class_config.null_class_id, infer_cells=True)
    bbox = Box(25, 25, 50, 50)

    with get_tmp_dir() as tmp_dir:
        labels_uri = join(tmp_dir, 'labels.json')
        json_to_file(geojson, labels_uri)
        vector_source = GeoJSONVectorSource(labels_uri, crs_tf)
        label_source = ChipClassificationLabelSource(
            ls_cfg, vector_source, bbox=bbox, lazy=True)

        self.assertNoError(lambda: label_source[:10, :10])

        window_labels = label_source.get_labels(Box(0, 0, 11, 11))
        expected_cells = [Box(25, 25, 36, 36)]
        self.assertListEqual(window_labels.get_cells(), expected_cells)


if __name__ == '__main__':
Expand Down

0 comments on commit 9f7237d

Please sign in to comment.