Polaris super consumer (#183)
* SpotDetection to Polaris app

* Create spot detection consumer

* Initial commit of super consumer

* Add output function for coords only

* Increase MAX_SCALE setting

* Update spot _consume tests

* Manual 2 model channel wrangling

* Switch output dimensions and refactor label to channels

* Remove label detection model

* Update model/app label and cap batch size

* Fix dimension handling in polaris consumer

* Spot predict params from settings.py

* Allow two cell assignments per spot

* Add Polaris test

* Pin jinja2 to <3.1 (#185)

* Add unit test for _add_images

* Add unit tests for _consume

* Add unit test for _analyze_images

* Patch get_image function

* Mock successful status

* Add cell culture and tissue tests

* Consolidate channel logic

* Move Polaris scale to settings

Co-authored-by: Morgan Schwartz <[email protected]>
elaubsch and msschwartz21 authored Apr 19, 2022
1 parent 94cdf9b commit 90970fc
Showing 11 changed files with 746 additions and 154 deletions.
2 changes: 2 additions & 0 deletions docs/rtd-requirements.txt
@@ -4,3 +4,5 @@ Sphinx==2.3.1
# mistune 2 and m2r 0.2.1 seem to not play well
# https://github.com/miyakogi/m2r/issues/66
mistune<2.0.0
# Sphinx relies on Jinja2 imports that are deprecated as of Jinja2 3.1
Jinja2<3.1
2 changes: 2 additions & 0 deletions redis_consumer/consumers/__init__.py
@@ -37,6 +37,7 @@
from redis_consumer.consumers.caliban_consumer import CalibanConsumer
from redis_consumer.consumers.mesmer_consumer import MesmerConsumer
from redis_consumer.consumers.polaris_consumer import PolarisConsumer
from redis_consumer.consumers.spot_consumer import SpotConsumer
# TODO: Import future custom Consumer classes.


@@ -49,6 +50,7 @@
'mesmer': MesmerConsumer,
'caliban': CalibanConsumer,
'polaris': PolarisConsumer,
'spot': SpotConsumer,
# TODO: Add future custom Consumer classes here.
}

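A note on the mapping above: the queue name handed to the consumer process selects the consumer class. Below is a minimal sketch of that lookup, assuming the mapping is exposed as consumers.CONSUMERS and using the constructor signature seen in the tests further down (redis_client, storage, queue); the helper name get_consumer is hypothetical.

from redis_consumer import consumers

def get_consumer(queue, redis_client, storage):
    """Return a consumer instance for the given queue name (e.g. 'polaris', 'spot')."""
    consumer_cls = consumers.CONSUMERS.get(queue)
    if consumer_cls is None:
        raise ValueError('No consumer registered for queue: {}'.format(queue))
    return consumer_cls(redis_client, storage, queue)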
4 changes: 2 additions & 2 deletions redis_consumer/consumers/base_consumer_test.py
@@ -471,12 +471,12 @@ def test_get_image_scale(self, mocker, redis_client):

# test scale provided is too large
with pytest.raises(ValueError):
scale = settings.MAX_SCALE + 0.1
scale = settings.MAX_SCALE + 0.05
consumer.get_image_scale(scale, image, 'some hash')

# test scale provided is too small
with pytest.raises(ValueError):
scale = settings.MIN_SCALE - 0.1
scale = settings.MIN_SCALE - 0.05
consumer.get_image_scale(scale, image, 'some hash')

def test_get_grpc_app(self, mocker, redis_client):
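The two assertions above pin the scale bounds checks. A minimal sketch of the behavior being exercised follows; it is an assumption, not the repository's implementation, and it only covers validation (the real consumer can also estimate the scale from the image).

from redis_consumer import settings

def check_image_scale(scale, default=1):
    # Empty or missing scale: fall back to a default value.
    if scale in (None, ''):
        return default
    scale = float(scale)
    if not settings.MIN_SCALE <= scale <= settings.MAX_SCALE:
        raise ValueError('Scale {} is outside [{}, {}]'.format(
            scale, settings.MIN_SCALE, settings.MAX_SCALE))
    return scale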
360 changes: 272 additions & 88 deletions redis_consumer/consumers/polaris_consumer.py

Large diffs are not rendered by default.

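Since the polaris_consumer.py diff is not rendered here, the shape of the new workflow can only be inferred from the tests below; the sketch that follows is that inference, not the committed code. The Redis hash carries 'input_file_name', 'segmentation_type' ('none', 'cell culture', or 'tissue'), and a comma-separated 'channels' string; _analyze_images returns spot coordinates plus optional segmentation masks, and _consume finishes by writing the final status back to the hash.

def polaris_flow_sketch(consumer, redis_hash, tempdir, fname):
    # _analyze_images runs spot detection and, depending on
    # 'segmentation_type', cell-culture or tissue segmentation on the image.
    results = consumer._analyze_images(redis_hash, tempdir, fname)
    coords = results.get('coords')              # numpy array of spot locations
    segmentation = results.get('segmentation')  # [] when segmentation_type is 'none'
    # _consume combines these outputs, saves them, and marks the hash with
    # consumer.final_status.
    return coords, segmentation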
154 changes: 133 additions & 21 deletions redis_consumer/consumers/polaris_consumer_test.py
@@ -30,19 +30,111 @@

import numpy as np

import pytest
import os
import random
import string
import tifffile
import uuid

from redis_consumer import consumers
from redis_consumer import settings
from redis_consumer.testing_utils import _get_image
from redis_consumer.testing_utils import Bunch
from redis_consumer.testing_utils import DummyStorage
from redis_consumer.testing_utils import redis_client


class TestPolarisConsumer(object):
# pylint: disable=R0201,W0621

def test__add_images(self, redis_client):
queue = 'polaris'
storage = DummyStorage()
consumer = consumers.PolarisConsumer(redis_client, storage, queue)

test_im = np.random.random(size=(1, 32, 32, 1))
test_im_name = 'test_im'

test_hvals = {'original_name': test_im_name}
uid = uuid.uuid4().hex

test_im_hash = consumer._add_images(test_hvals, uid, test_im, queue)
split_hash = test_im_hash.split(":")

assert split_hash[0] == queue
assert split_hash[1] == '{}-{}-{}-image.tif'.format(uid,
test_hvals.get('original_name'),
queue)

result = redis_client.hget(test_im_hash, 'status')
assert result == 'new'

def test__analyze_images(self, tmpdir, mocker, redis_client):
queue = 'polaris'
storage = DummyStorage()
consumer = consumers.PolarisConsumer(redis_client, storage, queue)

fname = 'file.tiff'
filepath = os.path.join(tmpdir, fname)
input_size = (1, 32, 32, 1)

# test successful workflow
def hget_successful_status(*_):
return consumer.final_status

def write_child_tiff(*_, **__):
letters = string.ascii_lowercase
name = ''.join(random.choice(letters) for i in range(12))
path = os.path.join(tmpdir, '{}.tiff'.format(name))
tifffile.imsave(path, _get_image(32, 32))
return [path]

mocker.patch.object(settings, 'INTERVAL', 0)
mocker.patch.object(redis_client, 'hget', hget_successful_status)
mocker.patch('redis_consumer.utils.iter_image_archive',
write_child_tiff)

tifffile.imsave(filepath, np.random.random(input_size))

# No segmentation
test_hash = 'test hash'
empty_data = {'input_file_name': 'file.tiff',
'segmentation_type': 'none',
'channels': '0,,'}
redis_client.hmset(test_hash, empty_data)
results = consumer._analyze_images(test_hash, tmpdir, fname)
coords, segmentation = results.get('coords'), results.get('segmentation')

assert isinstance(coords, np.ndarray)
assert isinstance(segmentation, list)

# Cell culture segmentation
test_hash = 'test hash'
empty_data = {'input_file_name': 'file.tiff',
'segmentation_type': 'cell culture',
'channels': '0,1,2'}
redis_client.hmset(test_hash, empty_data)
results = consumer._analyze_images(test_hash, tmpdir, fname)
coords, segmentation = results.get('coords'), results.get('segmentation')

assert isinstance(coords, np.ndarray)
assert isinstance(segmentation, list)
assert np.shape(segmentation)[1] == input_size[1]
assert np.shape(segmentation)[2] == input_size[2]

# Tissue segmentation
test_hash = 'test hash'
empty_data = {'input_file_name': 'file.tiff',
'segmentation_type': 'tissue',
'channels': '0,1,2'}
redis_client.hmset(test_hash, empty_data)
results = consumer._analyze_images(test_hash, tmpdir, fname)
coords, segmentation = results.get('coords'), results.get('segmentation')

assert isinstance(coords, np.ndarray)
assert isinstance(segmentation, list)
assert np.shape(segmentation)[1] == input_size[1]
assert np.shape(segmentation)[2] == input_size[2]

def test__consume_finished_status(self, redis_client):
queue = 'q'
storage = DummyStorage()
@@ -66,31 +158,51 @@ def test__consume_finished_status(self, redis_client):

def test__consume(self, mocker, redis_client):
# pylint: disable=W0613
queue = 'multiplex'
queue = 'polaris'
storage = DummyStorage()

consumer = consumers.PolarisConsumer(redis_client, storage, queue)

empty_data = {'input_file_name': 'file.tiff'}

output_shape = (1, 256, 256, 2)

mock_app = Bunch(
predict=lambda *x, **y: np.random.randint(1, 5, size=output_shape),
model_mpp=1,
model=Bunch(
get_batch_size=lambda *x: 1,
input_shape=(1, 32, 32, 1)
)
)

mocker.patch.object(consumer, 'get_grpc_app', lambda *x, **_: mock_app)
mocker.patch.object(consumer, 'get_image_scale', lambda *x, **_: 1)
mocker.patch.object(consumer, 'validate_model_input', lambda *x, **_: x[0])
mocker.patch.object(consumer, 'detect_dimension_order', lambda *x, **_: 'YXC')

# consume with cell culture segmentation and spot detection
empty_data = {'input_file_name': 'file.tiff',
'segmentation_type': 'cell culture'}
mocker.patch.object(consumer,
'_analyze_images',
lambda *x, **_: {'coords': np.random.randint(32, size=(1, 10, 2)),
'segmentation': np.random.random(size=(1, 32, 32, 1))
}
)
test_hash = 'some hash'
redis_client.hmset(test_hash, empty_data)
result = consumer._consume(test_hash)
assert result == consumer.final_status
result = redis_client.hget(test_hash, 'status')
assert result == consumer.final_status

# consume with tissue segmentation and spot detection
empty_data = {'input_file_name': 'file.tiff',
'segmentation_type': 'tissue'}
mocker.patch.object(consumer,
'_analyze_images',
lambda *x, **_: {'coords': np.random.randint(32, size=(1, 10, 2)),
'segmentation': np.random.random(size=(1, 32, 32, 1))
}
)
test_hash = 'another hash'
redis_client.hmset(test_hash, empty_data)
result = consumer._consume(test_hash)
assert result == consumer.final_status
result = redis_client.hget(test_hash, 'status')
assert result == consumer.final_status

# consume with spot detection only
empty_data = {'input_file_name': 'file.tiff',
'segmentation_type': 'none'}
mocker.patch.object(consumer,
'_analyze_images',
lambda *x, **_: {'coords': np.random.randint(32, size=(1, 10, 2)),
'segmentation': []})
test_hash = 'some other hash'
redis_client.hmset(test_hash, empty_data)
result = consumer._consume(test_hash)
assert result == consumer.final_status
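For reference, a Polaris job hash as the tests above populate it. The field names come straight from the tests; the hash key and the Redis connection details are placeholders.

import redis

client = redis.StrictRedis()  # connection parameters are an assumption
job = {
    'input_file_name': 'file.tiff',
    'segmentation_type': 'cell culture',  # or 'tissue' / 'none'
    'channels': '0,1,2',                  # '0,,' in the no-segmentation test
}
client.hmset('predict:example-hash', job)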
79 changes: 44 additions & 35 deletions redis_consumer/consumers/segmentation_consumer.py
@@ -42,8 +42,9 @@ class SegmentationConsumer(TensorFlowServingConsumer):
"""Consumes image files and uploads the results"""

def detect_label(self, image):
"""Send the image to the LABEL_DETECT_MODEL to detect the type of image
data. The model output is mapped with settings.MODEL_CHOICES.
""" DEPRECATED -- Send the image to the LABEL_DETECT_MODEL to
detect the type of image data. The model output is mapped with
settings.MODEL_CHOICES.
Args:
image (numpy.array): The image data.
Expand All @@ -68,7 +69,7 @@ def detect_label(self, image):
return int(detected_label)

def get_image_label(self, label, image, redis_hash):
"""Calculate label of image."""
""" DEPRACATED -- Calculate label of image."""
if not label:
# Detect scale of image (Default to 1)
label = self.detect_label(image)
@@ -104,13 +105,15 @@ def _consume(self, redis_hash):
# Load input image
fname = hvals.get('input_file_name')
image = self.download_image(fname)
image = np.squeeze(image)
image = np.expand_dims(image, axis=0) # add a batch dimension
if len(np.shape(image)) == 3:
image = np.expand_dims(image, axis=-1) # add a channel dimension

# Validate input image
if hvals.get('channels'):
channels = [int(c) for c in hvals.get('channels').split(',')]
else:
channels = None
rank = 4 # (b,x,y,c)
channel_axis = image.shape[1:].index(min(image.shape[1:])) + 1
if channel_axis != rank - 1:
image = np.rollaxis(image, 1, rank)

# Pre-process data before sending to the model
self.update_key(redis_hash, {
Expand All @@ -122,34 +125,40 @@ def _consume(self, redis_hash):
scale = hvals.get('scale', '')
scale = self.get_image_scale(scale, image, redis_hash)

label = hvals.get('label', '')
label = self.get_image_label(label, image, redis_hash)

# Grab appropriate model and application class
model = settings.MODEL_CHOICES[label]
app_cls = settings.APPLICATION_CHOICES[label]

model_name, model_version = model.split(':')

# detect dimension order and add to redis
dim_order = self.detect_dimension_order(image, model_name, model_version)
self.update_key(redis_hash, {
'dim_order': ','.join(dim_order)
})

# Validate input image
image = self.validate_model_input(image, model_name, model_version,
channels=channels)

# Send data to the model
self.update_key(redis_hash, {'status': 'predicting'})

app = self.get_grpc_app(model, app_cls)
# with new batching update in deepcell.applications,
# app.predict() cannot handle a batch_size of None.
batch_size = app.model.get_batch_size()
results = app.predict(image, batch_size=batch_size,
image_mpp=scale * app.model_mpp)
channels = hvals.get('channels').split(',') # ex: channels = ['0','1','2']

results = []
for i in range(len(channels)):
if channels[i]:
slice_image = image[..., int(channels[i])]
slice_image = np.expand_dims(slice_image, axis=-1)
# Grab appropriate model and application class
model = settings.MODEL_CHOICES[i]
app_cls = settings.APPLICATION_CHOICES[i]

model_name, model_version = model.split(':')

# detect dimension order and add to redis
dim_order = self.detect_dimension_order(slice_image, model_name, model_version)
self.update_key(redis_hash, {
'dim_order': ','.join(dim_order)
})

# Validate input image
slice_image = self.validate_model_input(slice_image, model_name, model_version)

# Send data to the model
self.update_key(redis_hash, {'status': 'predicting'})

app = self.get_grpc_app(model, app_cls)
# with new batching update in deepcell.applications,
# app.predict() cannot handle a batch_size of None.
batch_size = min(32, app.model.get_batch_size()) # TODO: raise max batch size
pred_results = app.predict(slice_image, batch_size=batch_size,
image_mpp=scale * app.model_mpp)

results.extend(pred_results)

# Save the post-processed results to a file
_ = timeit.default_timer()
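A worked example of the consolidated channel logic introduced above: the slot position in the comma-separated 'channels' string selects the model (settings.MODEL_CHOICES[i]), the slot's value selects which image channel to slice out, and empty slots are skipped. The snippet below is illustrative only and uses a dummy image.

import numpy as np

image = np.zeros((1, 32, 32, 3))   # (batch, y, x, channel)
channels = '2,,0'.split(',')       # slot 0 -> image channel 2, slot 1 skipped, slot 2 -> image channel 0

for i, c in enumerate(channels):
    if not c:
        continue  # empty slot: no model is run for this position
    slice_image = np.expand_dims(image[..., int(c)], axis=-1)
    print('slot {}: model settings.MODEL_CHOICES[{}] on image channel {}, input shape {}'.format(
        i, i, c, slice_image.shape))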
4 changes: 2 additions & 2 deletions redis_consumer/consumers/segmentation_consumer_test.py
@@ -125,7 +125,8 @@ def test__consume(self, mocker, redis_client):

consumer = consumers.SegmentationConsumer(redis_client, storage, queue)

empty_data = {'input_file_name': 'file.tiff'}
empty_data = {'input_file_name': 'file.tiff',
'channels': '0,'}

output_shape = (1, 32, 32, 1)

Expand All @@ -140,7 +141,6 @@ def test__consume(self, mocker, redis_client):

mocker.patch.object(consumer, 'get_grpc_app', lambda *x, **_: mock_app)
mocker.patch.object(consumer, 'get_image_scale', lambda *x, **_: 1)
mocker.patch.object(consumer, 'get_image_label', lambda *x, **_: 1)
mocker.patch.object(consumer, 'validate_model_input', lambda *x, **_: True)
mocker.patch.object(consumer, 'detect_dimension_order', lambda *x, **_: 'YXC')

