Add multi-batch prediction for the segmentation and Polaris consumers (#189)

* Begin adding tests for different segmentation data inputs

* Add image_dimensions_to_bxyc function

* Raise ValueError if dim mismatch

* Overload save_output method

* Add dimension_order to test input

* Modify DummyStorage image dims

* Add dim order test cases

* Add dim order to hvals

* Fix base consumer tests

* Change prediction results shape

* Remove expand dims for TIFF

* Remove squeeze after get_image

* Lower default TF max batch size and add comment

* Allow multibatch Polaris jobs

* Raise error for channel dim mismatch

* Remove comment

Co-authored-by: Morgan Schwartz <[email protected]>

* Fix Polaris tests

Co-authored-by: msschwartz21 <[email protected]>
elaubsch and msschwartz21 authored May 2, 2022
1 parent 44e5775 commit e61342a
Showing 10 changed files with 158 additions and 60 deletions.
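
The thread running through these changes is a new `dimension_order` entry in each Redis job hash: instead of guessing axis order from the array shape, the consumers now read it explicitly and normalize images to BXYC. As a rough sketch before the per-file diffs (the field values here are hypothetical, not taken from this commit), a segmentation job hash now carries something like:

    # Hypothetical job hash; 'dimension_order' and the channel-count check
    # against 'channels' are what this commit adds or tightens.
    hvals = {
        'input_file_name': 'example.tiff',  # illustrative name
        'channels': '0,1,',                 # nuclear at index 0, cytoplasm at index 1
        'dimension_order': 'BXYC',          # batch, x, y, channel
        'scale': '1',
    }
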
10 changes: 5 additions & 5 deletions redis_consumer/consumers/base_consumer_test.py
@@ -539,7 +539,7 @@ def get_file_from_hash(redis_hash, _):

def test__upload_archived_images(self, mocker, redis_client):
N = 3
storage = DummyStorage(num=N)
storage = DummyStorage(batch=N, img_h=300, img_w=300)
consumer = consumers.ZipFileConsumer(redis_client, storage, 'predict')
# mocker.patch.object(consumer.storage, 'download')
hvalues = {'input_file_name': 'test.zip', 'children': 'none'}
@@ -550,7 +550,7 @@ def test__upload_archived_images(self, mocker, redis_client):
def test__upload_finished_children(self, mocker, redis_client):
finished_children = ['predict:1.tiff', 'predict:2.zip', '']
N = 3
storage = DummyStorage(num=N)
storage = DummyStorage(batch=N, img_h=300, img_w=300)
consumer = consumers.ZipFileConsumer(redis_client, storage, 'predict')
mocker.patch.object(consumer, '_get_output_file_name', lambda x: x)

@@ -595,7 +595,7 @@ def test__get_output_file_name(self, mocker, redis_client):

def test__parse_failures(self, mocker, redis_client):
N = 3
storage = DummyStorage(num=N)
storage = DummyStorage(batch=N, img_h=300, img_w=300)

keys = [str(x) for x in range(4)]
consumer = consumers.ZipFileConsumer(redis_client, storage, 'predict')
@@ -619,7 +619,7 @@ def test__cleanup(self, mocker, redis_client):
queue = 'predict'
done = [str(i) for i in range(N)]
failed = [str(i) for i in range(N + 1, N * 2)]
storage = DummyStorage(num=N)
storage = DummyStorage(batch=N, img_h=300, img_w=300)
consumer = consumers.ZipFileConsumer(redis_client, storage, queue)

redis_hash = 'some job hash'
@@ -642,7 +642,7 @@

def test__consume(self, mocker, redis_client):
N = 3
storage = DummyStorage(num=N)
storage = DummyStorage(batch=N, img_h=300, img_w=300)
children = list('abcdefg')
queue = 'q'
test_hash = 0
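
The DummyStorage calls above replace the old `num=N` argument with explicit `batch`, `img_h`, and `img_w` (and, in the segmentation tests further down, `channel`). The helper's implementation is not part of this diff; what follows is only a minimal sketch of a stub compatible with both call styles, offered as an assumption:

    import numpy as np

    class DummyStorage(object):
        """Hypothetical sketch of the test stub (the real helper is in the test utils).
        A None batch/channel is read as 'this axis is absent'."""

        def __init__(self, batch=None, img_h=300, img_w=300, channel=1):
            dims = (batch, img_h, img_w, channel)
            self.shape = tuple(d for d in dims if d is not None)

        def random_image(self):
            # illustrative helper: an image of the configured shape
            return np.random.random(self.shape)
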
18 changes: 11 additions & 7 deletions redis_consumer/consumers/caliban_consumer.py
@@ -86,17 +86,20 @@ def _load_data(self, redis_hash, subdir, fname):
raise ValueError('_load_data takes in only .tiff, .trk, or .trks')

# push a key per frame and let ImageFileConsumers segment
raw = utils.get_image(os.path.join(subdir, fname))
tiff_stack = utils.get_image(os.path.join(subdir, fname))

# remove the last dimensions added by `get_image`
tiff_stack = np.squeeze(raw, -1)
if len(tiff_stack.shape) != 3:
raise ValueError('This tiff file has shape {}, which is not 3 '
'dimensions. Tracking can only be done on images '
'with 3 dimensions, (time, width, height)'.format(
tiff_stack.shape))

num_frames = len(tiff_stack)

if num_frames > settings.TF_MAX_BATCH_SIZE:
raise ValueError('This file has {} frames. Maximum allowed number of frames '
'is {}.'.format(num_frames, settings.TF_MAX_BATCH_SIZE))

hash_to_frame = {}
remaining_hashes = set()
frames = {}
@@ -116,7 +119,7 @@ def _load_data(self, redis_hash, subdir, fname):
segment_fname = '{}-{}-tracking-frame-{}.tif'.format(
uid, hvalues.get('original_name'), i)
segment_local_path = os.path.join(tempdir, segment_fname)
tifffile.imsave(segment_local_path, img)
tifffile.imsave(segment_local_path, np.squeeze(img))
upload_file_name, upload_file_url = self.storage.upload(
segment_local_path)

@@ -130,7 +133,8 @@ def _load_data(self, redis_hash, subdir, fname):
'created_at': current_timestamp,
'updated_at': current_timestamp,
'url': upload_file_url,
'channels': '0,,', # encodes that images are nuclear
'channels': '0,,', # encodes that images are nuclear,
'dimension_order': 'XY'
}

# make a hash for this frame
@@ -191,11 +195,11 @@ def _load_data(self, redis_hash, subdir, fname):
labels = [frames[i] for i in range(num_frames)]

# Cast y to int to avoid issues during fourier transform/drift correction
y = np.array(labels, dtype='uint16')
y = np.expand_dims(np.array(labels, dtype='uint16'), axis=-1)
# TODO: Why is there an extra dimension?
# Not a problem in tests, only with application based results.
# Issue with batch dimension from outputs?
y = y[:, 0] if y.shape[1] == 1 else y
# y = y[:, 0] if y.shape[1] == 1 else y
return {'X': np.expand_dims(tiff_stack, axis=-1), 'y': y}

def _consume(self, redis_hash):
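
Net effect of the caliban changes: `_load_data` rejects stacks larger than the TF Serving batch limit up front and squeezes each frame down to 2D before uploading it for per-frame segmentation. A standalone sketch of that splitting logic, assuming numpy and tifffile (the function and path names are illustrative, not from the repo):

    import numpy as np
    import tifffile

    TF_MAX_BATCH_SIZE = 64  # mirrors the new default in settings.py

    def split_stack(tiff_stack):
        """Split a (time, width, height) stack into one 2D TIFF per frame."""
        if tiff_stack.ndim != 3:
            raise ValueError('This tiff file has shape {}, which is not 3 '
                             'dimensions'.format(tiff_stack.shape))
        num_frames = len(tiff_stack)
        if num_frames > TF_MAX_BATCH_SIZE:
            raise ValueError('This file has {} frames. Maximum allowed '
                             'is {}.'.format(num_frames, TF_MAX_BATCH_SIZE))
        paths = []
        for i, img in enumerate(tiff_stack):
            path = 'tracking-frame-{}.tif'.format(i)
            tifffile.imsave(path, np.squeeze(img))  # drop singleton axes
            paths.append(path)
        return paths
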
26 changes: 18 additions & 8 deletions redis_consumer/consumers/polaris_consumer.py
@@ -78,7 +78,8 @@ def _add_images(self, hvals, uid, image, queue, channels=''):
'updated_at': current_timestamp,
'url': upload_file_url,
'channels': channels,
'scale': settings.POLARIS_SCALE} # scaling not supported for spots model
'scale': settings.POLARIS_SCALE, # scaling not supported for spots model
'dimension_order': 'BXY'}

# make a hash for this frame
image_hash = '{prefix}:{file}:{hash}'.format(
@@ -100,10 +101,14 @@ def _analyze_images(self, redis_hash, subdir, fname):
data.
"""
hvals = self.redis.hgetall(redis_hash)
raw = utils.get_image(os.path.join(subdir, fname))
tiff_stack = utils.get_image(os.path.join(subdir, fname))

# remove the last dimensions added by `get_image`
tiff_stack = np.squeeze(raw)
channels = hvals.get('channels').split(',') # ex: channels = ['0','1','2']
filled_channels = [c for c in channels if c]
if len(filled_channels) > np.shape(tiff_stack)[-1]:
raise ValueError('Input image has {} channels but {} channels were specified '
'for segmentation'.format(np.shape(tiff_stack)[-1],
len(filled_channels)))

self.logger.debug('Got tiffstack shape %s.', tiff_stack.shape)

@@ -131,13 +136,19 @@
if channels[1]:
# add channel 1 ind of tiff stack to nuclear queue
nuc_image = tiff_stack[..., int(channels[1])]
# add batch dimension if it doesn't exist
if len(np.shape(nuc_image)) == 2:
nuc_image = np.expand_dims(nuc_image, axis=0)
nuc_hash = self._add_images(hvals, uid, nuc_image,
queue='segmentation', channels='0,')
remaining_hashes.add(nuc_hash)

if channels[2]:
# add channel 2 ind of tiff stack to segmentation queue
cyto_image = tiff_stack[..., int(channels[2])]
# add batch dimension if it doesn't exist
if len(np.shape(cyto_image)) == 2:
cyto_image = np.expand_dims(cyto_image, axis=0)
cyto_hash = self._add_images(hvals, uid, cyto_image,
queue='segmentation', channels=',0')
remaining_hashes.add(cyto_hash)
@@ -217,9 +228,8 @@ def _analyze_images(self, redis_hash, subdir, fname):
if segmentation_type == 'cell culture':
for key in segmentation_dict.keys():
labeled_im = np.array(segmentation_dict[key])
labeled_im = np.squeeze(labeled_im, -1)
labeled_im = np.moveaxis(labeled_im, 0, 2)
segmentation_results.append(labeled_im)
labeled_im = np.swapaxes(labeled_im, 0, -1) # c,x,y,b to b,x,y,c
segmentation_results.extend(labeled_im)

return {'coords': np.array(coords), 'segmentation': segmentation_results}

@@ -244,7 +254,7 @@ def save_output(self, res, hvals):
# Save labeled image
outpaths.extend(utils.save_numpy_array(
labeled_im[i],
name=str(name),
name=str(name) + '_batch_{}'.format(i),
subdir=subdir, output_dir=tempdir))

# Save spot locations and assignments in .csv file
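
In short, the Polaris consumer now validates the requested channels against the image's channel axis, promotes 2D channel slices to a single-image batch before queueing them, and flips cell-culture segmentation output from channel-first to batch-first. A compact sketch of those shape rules, assuming numpy (the helper name is illustrative):

    import numpy as np

    def prepare_channels(tiff_stack, channels='0,1,'):
        """Validate the channel spec and return batch-ready slices."""
        filled = [c for c in channels.split(',') if c]
        if len(filled) > tiff_stack.shape[-1]:
            raise ValueError('Input image has {} channels but {} channels '
                             'were specified'.format(tiff_stack.shape[-1],
                                                     len(filled)))
        slices = []
        for c in filled:
            image = tiff_stack[..., int(c)]
            if image.ndim == 2:  # add batch dimension if it doesn't exist
                image = np.expand_dims(image, axis=0)
            slices.append(image)
        return slices

    # segmentation results come back (c, x, y, b); swapaxes makes them (b, x, y, c)
    labeled = np.swapaxes(np.zeros((1, 32, 32, 2)), 0, -1)  # -> shape (2, 32, 32, 1)
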
2 changes: 1 addition & 1 deletion redis_consumer/consumers/polaris_consumer_test.py
@@ -75,7 +75,7 @@ def test__analyze_images(self, tmpdir, mocker, redis_client):

fname = 'file.tiff'
filepath = os.path.join(tmpdir, fname)
input_size = (1, 32, 32, 1)
input_size = (1, 32, 32, 3)

# test successful workflow
def hget_successful_status(*_):
82 changes: 63 additions & 19 deletions redis_consumer/consumers/segmentation_consumer.py
@@ -28,6 +28,8 @@
from __future__ import division
from __future__ import print_function

import os
import tempfile
import timeit

import numpy as np
@@ -36,6 +38,7 @@

from redis_consumer.consumers import TensorFlowServingConsumer
from redis_consumer import settings
from redis_consumer import utils


class SegmentationConsumer(TensorFlowServingConsumer):
Expand Down Expand Up @@ -69,7 +72,7 @@ def detect_label(self, image):
return int(detected_label)

def get_image_label(self, label, image, redis_hash):
""" DEPRACATED -- Calculate label of image."""
""" DEPRECATED -- Calculate label of image."""
if not label:
# Detect scale of image (Default to 1)
label = self.detect_label(image)
@@ -83,6 +86,52 @@

return label

def image_dimensions_to_bxyc(self, dim_order, image):
"""Modifies image dimensions to be BXYC."""

if len(np.shape(image)) != len(dim_order):
raise ValueError('Input dimension order was {} but input '
'image has shape {}'.format(dim_order, np.shape(image)))

if dim_order == 'XYB':
image = np.moveaxis(image, -1, 0)
elif dim_order == 'CXY':
image = np.moveaxis(image, 0, -1)
elif dim_order == 'CXYB':
image = np.swapaxes(image, 0, -1)

if 'B' not in dim_order:
image = np.expand_dims(image, axis=0)
if 'C' not in dim_order:
image = np.expand_dims(image, axis=-1)

return(image)

def save_output(self, image, save_name):
"""Save output images into a zip file and upload it."""
with tempfile.TemporaryDirectory() as tempdir:
# Save each result channel as an image file
subdir = os.path.dirname(save_name.replace(tempdir, ''))
name = os.path.splitext(os.path.basename(save_name))[0]

outpaths = []
for i, img in enumerate(image):
outpaths.extend(utils.save_numpy_array(
img,
name=str(name) + '_batch_{}'.format(i),
subdir=subdir, output_dir=tempdir))

# Save each prediction image as zip file
zip_file = utils.zip_files(outpaths, tempdir)

# Upload the zip file to cloud storage bucket
cleaned = zip_file.replace(tempdir, '')
subdir = os.path.dirname(utils.strip_bucket_path(cleaned))
subdir = subdir if subdir else None
dest, output_url = self.storage.upload(zip_file, subdir=subdir)

return dest, output_url

def _consume(self, redis_hash):
start = timeit.default_timer()
hvals = self.redis.hgetall(redis_hash)
@@ -105,15 +154,16 @@ def _consume(self, redis_hash):
# Load input image
fname = hvals.get('input_file_name')
image = self.download_image(fname)
image = np.squeeze(image)
image = np.expand_dims(image, axis=0) # add a batch dimension
if len(np.shape(image)) == 3:
image = np.expand_dims(image, axis=-1) # add a channel dimension
dim_order = hvals.get('dimension_order')

rank = 4 # (b,x,y,c)
channel_axis = image.shape[1:].index(min(image.shape[1:])) + 1
if channel_axis != rank - 1:
image = np.rollaxis(image, 1, rank)
# Modify image dimensions to be BXYC
image = self.image_dimensions_to_bxyc(dim_order, image)

channels = hvals.get('channels').split(',') # ex: channels = ['0','1','2']
filled_channels = [c for c in channels if c]
if len(filled_channels) > np.shape(image)[3]:
raise ValueError('Input image has {} channels but {} channels were specified '
'for segmentation'.format(np.shape(image)[3], len(filled_channels)))

# Pre-process data before sending to the model
self.update_key(redis_hash, {
@@ -125,9 +175,6 @@ def _consume(self, redis_hash):
scale = hvals.get('scale', '')
scale = self.get_image_scale(scale, image, redis_hash)

# Validate input image
channels = hvals.get('channels').split(',') # ex: channels = ['0','1','2']

results = []
for i in range(len(channels)):
if channels[i]:

model_name, model_version = model.split(':')

# detect dimension order and add to redis
dim_order = self.detect_dimension_order(slice_image, model_name, model_version)
self.update_key(redis_hash, {
'dim_order': ','.join(dim_order)
})

# Validate input image
slice_image = self.validate_model_input(slice_image, model_name, model_version)

pred_results = app.predict(slice_image, batch_size=batch_size,
image_mpp=scale * app.model_mpp)

results.extend(pred_results)
results.append(pred_results)

results = np.squeeze(np.array(results), axis=-1) # c,b,x,y,1 to c,b,x,y
results = np.moveaxis(results, 0, -1) # c,b,x,y to b,x,y,c

# Save the post-processed results to a file
_ = timeit.default_timer()
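
The heart of this file is the new `image_dimensions_to_bxyc`, which turns whatever axis order the job declared into the BXYC layout the models expect. A self-contained mirror of its behavior (reimplemented here purely for illustration, matching the method in the diff above):

    import numpy as np

    def to_bxyc(dim_order, image):
        """Illustrative mirror of SegmentationConsumer.image_dimensions_to_bxyc."""
        if image.ndim != len(dim_order):
            raise ValueError('Input dimension order was {} but input image '
                             'has shape {}'.format(dim_order, image.shape))
        if dim_order == 'XYB':
            image = np.moveaxis(image, -1, 0)
        elif dim_order == 'CXY':
            image = np.moveaxis(image, 0, -1)
        elif dim_order == 'CXYB':
            image = np.swapaxes(image, 0, -1)
        if 'B' not in dim_order:
            image = np.expand_dims(image, axis=0)
        if 'C' not in dim_order:
            image = np.expand_dims(image, axis=-1)
        return image

    assert to_bxyc('XY', np.zeros((32, 32))).shape == (1, 32, 32, 1)
    assert to_bxyc('XYB', np.zeros((32, 32, 4))).shape == (4, 32, 32, 1)
    assert to_bxyc('CXY', np.zeros((2, 32, 32))).shape == (1, 32, 32, 2)
    assert to_bxyc('CXYB', np.zeros((2, 32, 32, 4))).shape == (4, 32, 32, 2)
    assert to_bxyc('BXYC', np.zeros((4, 32, 32, 2))).shape == (4, 32, 32, 2)
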
35 changes: 29 additions & 6 deletions redis_consumer/consumers/segmentation_consumer_test.py
@@ -118,31 +118,54 @@ def test__consume_finished_status(self, redis_client):
assert result == status
test_hash += 1

def test__consume(self, mocker, redis_client):
@pytest.mark.parametrize(
'shape,channels,dim_order',
[
pytest.param((None, 32, 32, None), '0,,', 'XY', id='xy-nuc'),
pytest.param((None, 32, 32, None), ',0,', 'XY', id='xy-cyto'),
pytest.param((1, 32, 32, None), '0,,', 'BXY', id='bxy-nuc'),
pytest.param((1, 32, 32, None), ',0,', 'BXY', id='bxy-cyto'),
pytest.param((None, 32, 32, 1), '0,,', 'XYB', id='xyb-nuc'),
pytest.param((None, 32, 32, 1), '0,,', 'XYC', id='xyc-nuc'),
pytest.param((None, 32, 32, 1), ',0,', 'XYC', id='xyc-cyto'),
pytest.param((1, 32, 32, None), '0,,', 'CXY', id='cxy-nuc'),
pytest.param((1, 32, 32, 1), '0,,', 'BXYC', id='bxyc-nuc'),
pytest.param((1, 32, 32, 1), ',0,', 'BXYC', id='bxyc-cyto'),
pytest.param((1, 32, 32, 2), '0,1,', 'BXYC', id='bxyc-nuc-cyto'),
pytest.param((1, 32, 32, 2), '1,0,', 'BXYC', id='bxyc-cyto-nuc'),
pytest.param((2, 32, 32, 2), '0,1,', 'BXYC', id='bxyc-nuc-cyto-multibatch'),
pytest.param((2, 32, 32, 1), '0,1,', 'CXYB', id='cxyb-nuc-cyto')
]
)
def test__consume(self, mocker, redis_client, shape, channels, dim_order):
# pylint: disable=W0613
queue = 'predict'
storage = DummyStorage()
storage = DummyStorage(batch=shape[0],
img_h=shape[1],
img_w=shape[2],
channel=shape[3])

consumer = consumers.SegmentationConsumer(redis_client, storage, queue)

empty_data = {'input_file_name': 'file.tiff',
'channels': '0,'}
'channels': channels,
'dimension_order': dim_order}

output_shape = (1, 32, 32, 1)
output_shape = (1, shape[1], shape[2], 1)

mock_app = Bunch(
predict=lambda *x, **y: np.random.randint(1, 5, size=output_shape),
model_mpp=1,
model=Bunch(
get_batch_size=lambda *x: 1,
input_shape=(1, 32, 32, 1)
input_shape=shape
)
)

mocker.patch.object(consumer, 'get_grpc_app', lambda *x, **_: mock_app)
mocker.patch.object(consumer, 'get_image_scale', lambda *x, **_: 1)
mocker.patch.object(consumer, 'validate_model_input', lambda *x, **_: True)
mocker.patch.object(consumer, 'detect_dimension_order', lambda *x, **_: 'YXC')
mocker.patch.object(consumer, 'detect_dimension_order', lambda *x, **_: dim_order)

test_hash = 'some hash'

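
One convention worth calling out in the parametrization above: a None in the shape tuple appears to mean that axis is absent from the generated test image, which is how one table covers 2D, 3D, and 4D inputs. A tiny sketch of that reading (an assumption about the DummyStorage fixture, which is not shown in this diff):

    shape = (None, 32, 32, None)  # the 'xy-nuc' case
    actual = tuple(d for d in shape if d is not None)
    assert actual == (32, 32)     # matches dimension_order 'XY'
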
3 changes: 2 additions & 1 deletion redis_consumer/settings.py
@@ -55,7 +55,8 @@
TF_HOST = config('TF_HOST', default='tf-serving')
TF_PORT = config('TF_PORT', default=8500, cast=int)
# maximum batch allowed by TensorFlow Serving
TF_MAX_BATCH_SIZE = config('TF_MAX_BATCH_SIZE', default=128, cast=int)
# must be manually matched to the helmfile for the TF-serving pod
TF_MAX_BATCH_SIZE = config('TF_MAX_BATCH_SIZE', default=64, cast=int)
# minimum expected model size, dynamically change batches proportionately.
TF_MIN_MODEL_SIZE = config('TF_MIN_MODEL_SIZE', default=128, cast=int)

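
Since TF_MAX_BATCH_SIZE is read from the environment via decouple, the new default of 64 only applies when the variable is unset, and the comment warns it must agree with the batch limit configured for the TF Serving deployment itself. A short sketch of overriding it (the deployment detail is an assumption taken from the comment, not shown in this diff):

    # settings.py reads the value exactly as in the diff above:
    from decouple import config
    TF_MAX_BATCH_SIZE = config('TF_MAX_BATCH_SIZE', default=64, cast=int)

    # To override, export the variable before starting the consumer:
    #   export TF_MAX_BATCH_SIZE=32
    # and keep it matched to the max batch size configured for the
    # tf-serving pod in its helmfile.
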