Skip to content

Commit

Permalink
overwrite the provide function in RandomLocation
Browse files Browse the repository at this point in the history
no need for supporting skip
  • Loading branch information
pattonw committed May 15, 2024
1 parent d02b696 commit 675cf43
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 15 deletions.
41 changes: 35 additions & 6 deletions gunpowder/nodes/random_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from gunpowder.array import Array
from gunpowder.array_spec import ArraySpec
from .batch_filter import BatchFilter
from gunpowder.profiling import Timing

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -210,6 +211,33 @@ def prepare(self, request):

return request

def provide(self, request):

timing_prepare = Timing(self, "prepare")
timing_prepare.start()

downstream_request = request.copy()

self.prepare(request)

self.remove_provided(request)

timing_prepare.stop()

batch = self.get_upstream_provider().request_batch(request)

timing_process = Timing(self, "process")
timing_process.start()

self.process(batch, downstream_request)

timing_process.stop()

batch.profiling_stats.add(timing_prepare)
batch.profiling_stats.add(timing_process)

return batch

def process(self, batch, request):
if self.random_shift_key is not None:
batch[self.random_shift_key] = Array(
Expand Down Expand Up @@ -429,13 +457,14 @@ def __select_random_location_with_points(

# count all points inside the shifted ROI
points = self.__get_points_in_roi(request_points_roi.shift(random_shift))
assert (
point in points
), "Requested batch to contain point %s, but got points " "%s" % (
point,
points,
assert point in points, (
"Requested batch to contain point %s, but got points "
"%s"
% (
point,
points,
)
)
num_points = len(points)

return random_shift

Expand Down
18 changes: 9 additions & 9 deletions tests/cases/random_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,16 @@ def test_output():
a = ArrayKey("A")
b = ArrayKey("B")
random_shift_key = ArrayKey("RANDOM_SHIFT")
source_a = ExampleSourceRandomLocation(a)
source_b = ExampleSourceRandomLocation(b)

pipeline = (
(source_a, source_b)
(ExampleSourceRandomLocation(a), ExampleSourceRandomLocation(b))
+ MergeProvider()
+ CustomRandomLocation(a, random_store_key=random_shift_key)
+ CustomRandomLocation(a, random_shift_key=random_shift_key)
)
pipeline_no_random = (source_a, source_b) + MergeProvider()
pipeline_no_random = (
ExampleSourceRandomLocation(a),
ExampleSourceRandomLocation(b),
) + MergeProvider()

with build(pipeline), build(pipeline_no_random):
sums = set()
Expand All @@ -95,8 +96,7 @@ def test_output():
),
b: ArraySpec(
roi=Roi(batch[random_shift_key].data, (20, 20, 20))
),
random_shift_key: ArraySpec(nonspatial=True),
)
}
)
)
Expand All @@ -106,8 +106,8 @@ def test_output():
sums.add(batch[a].data.sum())

# Request a ROI with the same shape as the entire ROI
full_roi_a = Roi((0, 0, 0), source_a.roi.shape)
full_roi_b = Roi((0, 0, 0), source_b.roi.shape)
full_roi_a = Roi((0, 0, 0), ExampleSourceRandomLocation(a).roi.shape)
full_roi_b = Roi((0, 0, 0), ExampleSourceRandomLocation(b).roi.shape)
batch = pipeline.request_batch(
BatchRequest(
{a: ArraySpec(roi=full_roi_a), b: ArraySpec(roi=full_roi_b)}
Expand Down

0 comments on commit 675cf43

Please sign in to comment.