diff --git a/docs/rtd-requirements.txt b/docs/rtd-requirements.txt index f58efbdf..25f53a6d 100644 --- a/docs/rtd-requirements.txt +++ b/docs/rtd-requirements.txt @@ -4,3 +4,5 @@ Sphinx==2.3.1 # mistune 2 and m2r 0.2.1 seem to not play well # https://github.com/miyakogi/m2r/issues/66 mistune<2.0.0 +# Jinja2 3.1 removed deprecated APIs that this Sphinx version still imports +Jinja2<3.1 diff --git a/redis_consumer/consumers/__init__.py b/redis_consumer/consumers/__init__.py index 89e83ae6..922efaf2 100644 --- a/redis_consumer/consumers/__init__.py +++ b/redis_consumer/consumers/__init__.py @@ -37,6 +37,7 @@ from redis_consumer.consumers.caliban_consumer import CalibanConsumer from redis_consumer.consumers.mesmer_consumer import MesmerConsumer from redis_consumer.consumers.polaris_consumer import PolarisConsumer +from redis_consumer.consumers.spot_consumer import SpotConsumer # TODO: Import future custom Consumer classes. @@ -49,6 +50,7 @@ 'mesmer': MesmerConsumer, 'caliban': CalibanConsumer, 'polaris': PolarisConsumer, + 'spot': SpotConsumer, # TODO: Add future custom Consumer classes here. } diff --git a/redis_consumer/consumers/base_consumer_test.py b/redis_consumer/consumers/base_consumer_test.py index 1ed7bb8a..242d7de7 100644 --- a/redis_consumer/consumers/base_consumer_test.py +++ b/redis_consumer/consumers/base_consumer_test.py @@ -471,12 +471,12 @@ def test_get_image_scale(self, mocker, redis_client): # test scale provided is too large with pytest.raises(ValueError): - scale = settings.MAX_SCALE + 0.1 + scale = settings.MAX_SCALE + 0.05 consumer.get_image_scale(scale, image, 'some hash') # test scale provided is too small with pytest.raises(ValueError): - scale = settings.MIN_SCALE - 0.1 + scale = settings.MIN_SCALE - 0.05 consumer.get_image_scale(scale, image, 'some hash') def test_get_grpc_app(self, mocker, redis_client): diff --git a/redis_consumer/consumers/polaris_consumer.py b/redis_consumer/consumers/polaris_consumer.py index 8f70673d..f48e36ea 100644 --- a/redis_consumer/consumers/polaris_consumer.py +++ b/redis_consumer/consumers/polaris_consumer.py @@ -23,19 +23,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""PolarisConsumer class for consuming SpotDetection jobs.""" +"""PolarisConsumer class for consuming singleplex FISH analysis jobs.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +import csv +import json import os import tempfile +import time import timeit +import uuid import matplotlib.pyplot as plt import numpy as np +import tifffile -from deepcell_spots.applications import SpotDetection +from deepcell_spots.singleplex import match_spots_to_cells from redis_consumer.consumers import TensorFlowServingConsumer from redis_consumer import settings @@ -43,11 +48,192 @@ class PolarisConsumer(TensorFlowServingConsumer): - """Consumes image files and uploads the results""" + """Consumes multichannel image files for singleplex FISH analysis, adds single + channel images to spot detection and segmentation queues, and uploads the results. + """ + + def _add_images(self, hvals, uid, image, queue, channels=''): + """ + Uploads image to a temporary directory and adds it to the redis queue + for analysis.
+ """ + with tempfile.TemporaryDirectory() as tempdir: + # Save and upload the spots image + image_fname = '{}-{}-{}-image.tif'.format( + uid, hvals.get('original_name'), queue) + image_local_path = os.path.join(tempdir, image_fname) + tifffile.imsave(image_local_path, image) + upload_file_name, upload_file_url = self.storage.upload( + image_local_path) + + self.logger.debug('Image shape: {}'.format(image.shape)) + # prepare hvals for this images's hash + current_timestamp = self.get_current_timestamp() + image_hvals = { + 'identity_upload': self.name, + 'input_file_name': upload_file_name, + 'original_name': image_fname, + 'status': 'new', + 'created_at': current_timestamp, + 'updated_at': current_timestamp, + 'url': upload_file_url, + 'channels': channels, + 'scale': settings.POLARIS_SCALE} # scaling not supported for spots model + + # make a hash for this frame + image_hash = '{prefix}:{file}:{hash}'.format( + prefix=queue, + file=image_fname, + hash=uuid.uuid4().hex) + + self.redis.hmset(image_hash, image_hvals) + self.redis.lpush(queue, image_hash) + self.logger.debug('Added new hash to %s queue `%s`: %s', + queue, image_hash, json.dumps(image_hvals, indent=4)) + + return(image_hash) + + def _analyze_images(self, redis_hash, subdir, fname): + """ + Given the upload location `input_file_name`, and the downloaded + location of the same file in subdir/fname, return the raw and annotated + data. + """ + hvals = self.redis.hgetall(redis_hash) + raw = utils.get_image(os.path.join(subdir, fname)) + + # remove the last dimensions added by `get_image` + tiff_stack = np.squeeze(raw) + + self.logger.debug('Got tiffstack shape %s.', tiff_stack.shape) - def save_output(self, coords, image, save_name): + # get segmentation type and channel order + if hvals.get('channels'): + channels = hvals.get('channels').split(',') + else: + channels = ['0'] + + self.logger.debug('Channels: {}'.format(channels)) + segmentation_type = hvals.get('segmentation_type') + + remaining_hashes = set() + uid = uuid.uuid4().hex + + self.logger.debug('Starting spot detection') + # get spots image and add to spot_detection queue + spots_image = tiff_stack[..., int(channels[0])] + self.logger.debug('Spot image size: {}'.format(spots_image.shape)) + spots_hash = self._add_images(hvals, uid, spots_image, queue='spot') + remaining_hashes.add(spots_hash) + + self.logger.debug('Starting segmentation') + if segmentation_type == 'cell culture': + if channels[1]: + # add channel 1 ind of tiff stack to nuclear queue + nuc_image = tiff_stack[..., int(channels[1])] + nuc_hash = self._add_images(hvals, uid, nuc_image, + queue='segmentation', channels='0,') + remaining_hashes.add(nuc_hash) + + if channels[2]: + # add channel 2 ind of tiff stack to segmentation queue + cyto_image = tiff_stack[..., int(channels[2])] + cyto_hash = self._add_images(hvals, uid, cyto_image, + queue='segmentation', channels=',0') + remaining_hashes.add(cyto_hash) + + elif segmentation_type == 'tissue': + # add ims 1 and 2 to mesmer queue + nuc_image = tiff_stack[..., int(channels[1])] + nuc_image = np.expand_dims(nuc_image, axis=-1) + cyto_image = tiff_stack[..., int(channels[2])] + cyto_image = np.expand_dims(cyto_image, axis=-1) + mesmer_image = np.concatenate((nuc_image, cyto_image), axis=-1) + mesmer_hash = self._add_images(hvals, uid, mesmer_image, queue='mesmer') + remaining_hashes.add(mesmer_hash) + + coords = [] + segmentation_results = [] + segmentation_dict = {} + while remaining_hashes: + finished_hashes = set() + for h in remaining_hashes: + status = 
self.redis.hget(h, 'status') + + self.logger.debug('Hash %s has status %s', + h, status) + + if status == self.failed_status: + # Analysis failed + reason = self.redis.hget(h, 'reason') + raise RuntimeError( + 'Analysis failed for image with hash: {} ' + 'for this reason: {}'.format(h, reason)) + + if status == self.final_status: + # Analysis finished + with tempfile.TemporaryDirectory() as tempdir: + out = self.redis.hget(h, 'output_file_name') + pred_zip = self.storage.download(out, tempdir) + pred_files = list(utils.iter_image_archive( + pred_zip, tempdir)) + + if 'spot' in h: + # handle spot detection results + for i, file in enumerate(pred_files): + if file.endswith('.npy'): + spots_pred = np.load(pred_files[i]) + coords.append(spots_pred) + elif 'mesmer' in h: + # handle tissue segmentation results + segmentation_stack = [] + for i, file in enumerate(pred_files): + seg_pred = utils.get_image(file) + seg_pred = np.squeeze(seg_pred) + segmentation_stack.append(seg_pred) + segmentation_stack = np.array(segmentation_stack) + segmentation_stack = np.moveaxis(segmentation_stack, 0, 2) + segmentation_results.append(segmentation_stack) + else: + # handle cell culture segmentation results + segmentation_stack = [] + for i, file in enumerate(pred_files): + seg_pred = utils.get_image(file) + seg_pred = np.squeeze(seg_pred) + segmentation_stack.append(seg_pred) + segmentation_stack = np.array(segmentation_stack) + segmentation_stack = np.moveaxis(segmentation_stack, 0, 2) + segmentation_uid = os.path.split(file)[1][:10] + if segmentation_uid in segmentation_dict.keys(): + segmentation_dict[segmentation_uid].append(segmentation_stack) + else: + segmentation_dict[segmentation_uid] = [segmentation_stack] + + finished_hashes.add(h) + + remaining_hashes -= finished_hashes + time.sleep(settings.INTERVAL) + + if segmentation_type == 'cell culture': + for key in segmentation_dict.keys(): + labeled_im = np.array(segmentation_dict[key]) + labeled_im = np.squeeze(labeled_im, -1) + labeled_im = np.moveaxis(labeled_im, 0, 2) + segmentation_results.append(labeled_im) + + return {'coords': np.array(coords), 'segmentation': segmentation_results} + + def save_output(self, res, hvals): """Save output in a zip file and upload it. 
Output includes predicted spot locations - plotted on original image as a .tiff file and coordinate spot locations as .npy file""" + and their assigned cell labels as a .csv file, along with the labeled segmentation + image as a .tif file. + """ + # Assign spots to cells + coords = np.array(res['coords']) + labeled_im = np.array(res['segmentation']) + fname = hvals.get('input_file_name') + save_name = hvals.get('original_name', fname) + with tempfile.TemporaryDirectory() as tempdir: # Save each result channel as an image file subdir = os.path.dirname(save_name.replace(tempdir, '')) @@ -55,32 +241,77 @@ def save_output(self, coords, image, save_name): outpaths = [] for i in range(len(coords)): - # Save image with plotted spot locations - img_name = '{}.tif'.format(i) + # Save labeled image + outpaths.extend(utils.save_numpy_array( + labeled_im[i], + name=str(name), + subdir=subdir, output_dir=tempdir)) + + # Save spot locations and assignments in .csv file + csv_name = '{}.csv'.format(i) if name: - img_name = '{}_{}'.format(name, img_name) + csv_name = '{}_{}'.format(name, csv_name) + csv_path = os.path.join(tempdir, subdir, csv_name) + if np.shape(labeled_im)[3] == 2: + csv_header = ['x', 'y', 'cellID0', 'cellID1'] + else: + csv_header = ['x', 'y', 'cellID0'] + with open(csv_path, 'w', newline='') as csv_file: + writer = csv.writer(csv_file, delimiter=',') + writer.writerow(csv_header) + for ii in range(len(coords[i])): + loc = coords[i][ii] + assignment0 = labeled_im[i, int(loc[0]), int(loc[1]), 0] + if np.shape(labeled_im)[3] == 2: + assignment1 = labeled_im[i, int(loc[0]), int(loc[1]), 1] + writer.writerow([loc[1], loc[0], int(assignment0), int(assignment1)]) + else: + writer.writerow([loc[1], loc[0], int(assignment0)]) + + outpaths.extend([csv_path]) + + # Save each prediction image as zip file + zip_file = utils.zip_files(outpaths, tempdir) - img_path = os.path.join(tempdir, subdir, img_name) + # Upload the zip file to cloud storage bucket + cleaned = zip_file.replace(tempdir, '') + subdir = os.path.dirname(utils.strip_bucket_path(cleaned)) + subdir = subdir if subdir else None + dest, output_url = self.storage.upload(zip_file, subdir=subdir) - fig = plt.figure() - plt.ioff() - plt.imshow(image[i], cmap='gray') - plt.scatter(coords[i][:, 1], coords[i][:, 0], edgecolors='r', facecolors='None') - plt.xticks([]) - plt.yticks([]) - plt.savefig(img_path) + return(dest, output_url) - # Save coordiates - coords_name = '{}.npy'.format(i) - if name: - coords_name = '{}_{}'.format(name, coords_name) + def save_coords(self, res, hvals): + """Save output in a zip file and upload it.
Output includes predicted spot locations + plotted on original image as a .tiff file and coordinate spot locations as a .csv file + """ + # Assign spots to cells + coords = np.array(res['coords']) + fname = hvals.get('input_file_name') + save_name = hvals.get('original_name', fname) - coords_path = os.path.join(tempdir, subdir, coords_name) + with tempfile.TemporaryDirectory() as tempdir: + # Save each result channel as an image file + subdir = os.path.dirname(save_name.replace(tempdir, '')) + name = os.path.splitext(os.path.basename(save_name))[0] - np.save(coords_path, coords[i]) + outpaths = [] + for i in range(len(coords)): + # Save image with plotted spot locations - outpaths.extend([img_path, coords_path]) - # outpaths.extend([coords_path]) + csv_name = '{}.csv'.format(i) + if name: + csv_name = '{}_{}'.format(name, csv_name) + csv_path = os.path.join(tempdir, subdir, csv_name) + csv_header = ['x', 'y'] + with open(csv_path, 'w', newline='') as csv_file: + writer = csv.writer(csv_file, delimiter=',') + writer.writerow(csv_header) + for ii in range(len(coords[i])): + loc = coords[i][ii] + writer.writerow([loc[1], loc[0]]) + + outpaths.extend([csv_path]) # Save each prediction image as zip file zip_file = utils.zip_files(outpaths, tempdir) @@ -91,7 +322,7 @@ def save_output(self, coords, image, save_name): subdir = subdir if subdir else None dest, output_url = self.storage.upload(zip_file, subdir=subdir) - return dest, output_url + return(dest, output_url) def _consume(self, redis_hash): start = timeit.default_timer() @@ -102,79 +333,32 @@ def _consume(self, redis_hash): redis_hash, hvals.get('status')) return hvals.get('status') - self.logger.debug('Found hash to process `%s` with status `%s`.', - redis_hash, hvals.get('status')) - self.update_key(redis_hash, { 'status': 'started', 'identity_started': self.name, }) - # Get model_name and version - model_name, model_version = settings.POLARIS_MODEL.split(':') - - _ = timeit.default_timer() - - # Load input image - fname = hvals.get('input_file_name') - image = self.download_image(fname) - - # squeeze extra dimension that is added by get_image - image = np.squeeze(image) - if image.ndim == 2: - # add in the batch and channel dims - image = np.expand_dims(image, axis=[0, -1]) - elif image.ndim == 3: - # check if batch first or last - if np.shape(image)[2] < np.shape(image)[1]: - image = np.rollaxis(image, 2, 0) - # add in the channel dim - image = np.expand_dims(image, axis=[-1]) - else: - raise ValueError('Image with {} shape was uploaded, but Polaris only ' - 'supports multi-batch or multi-channel images.'.format( - np.shape(image))) - - # Pre-process data before sending to the model - self.update_key(redis_hash, { - 'status': 'pre-processing', - 'download_time': timeit.default_timer() - _, - }) - - # detect dimension order and add to redis - dim_order = self.detect_dimension_order(image, model_name, model_version) - self.update_key(redis_hash, { - 'dim_order': ','.join(dim_order) - }) - - # Validate input image - if hvals.get('channels'): - channels = [int(c) for c in hvals.get('channels').split(',')] - else: - channels = None - - image = self.validate_model_input(image, model_name, model_version, - channels=channels) - - # Send data to the model - self.update_key(redis_hash, {'status': 'predicting'}) - - app = self.get_grpc_app(settings.POLARIS_MODEL, SpotDetection) - - # with new batching update in deepcell.applications, - # app.predict() cannot handle a batch_size of None. 
- batch_size = app.model.get_batch_size() - threshold = hvals.get('threshold', settings.POLARIS_THRESHOLD) - clip = hvals.get('clip', settings.POLARIS_CLIP) - results = app.predict(image, batch_size=batch_size, threshold=threshold, - clip=clip) + with tempfile.TemporaryDirectory() as tempdir: + # Pre-process data before sending to the model + fname = self.storage.download(hvals.get('input_file_name'), + tempdir) + self.update_key(redis_hash, { + 'status': 'predicting' + }) + res = self._analyze_images(redis_hash, tempdir, fname) + + self.logger.debug('Finished spot detection and segmentation.') + self.logger.debug('Coords shape: %s', np.shape(res['coords'])) + self.logger.debug('Segmentation result shape: %s', np.shape(res['segmentation'])) # Save the post-processed results to a file _ = timeit.default_timer() self.update_key(redis_hash, {'status': 'saving-results'}) - save_name = hvals.get('original_name', fname) - dest, output_url = self.save_output(results, image, save_name) + if hvals.get('segmentation_type') == 'none': + dest, output_url = self.save_coords(res, hvals) + else: + dest, output_url = self.save_output(res, hvals) # Update redis with the final results end = timeit.default_timer() diff --git a/redis_consumer/consumers/polaris_consumer_test.py b/redis_consumer/consumers/polaris_consumer_test.py index 328f11a1..867107f4 100644 --- a/redis_consumer/consumers/polaris_consumer_test.py +++ b/redis_consumer/consumers/polaris_consumer_test.py @@ -30,12 +30,15 @@ import numpy as np -import pytest +import os +import random +import string +import tifffile +import uuid from redis_consumer import consumers from redis_consumer import settings from redis_consumer.testing_utils import _get_image -from redis_consumer.testing_utils import Bunch from redis_consumer.testing_utils import DummyStorage from redis_consumer.testing_utils import redis_client @@ -43,6 +46,95 @@ class TestPolarisConsumer(object): # pylint: disable=R0201,W0621 + def test__add_images(self, redis_client): + queue = 'polaris' + storage = DummyStorage() + consumer = consumers.PolarisConsumer(redis_client, storage, queue) + + test_im = np.random.random(size=(1, 32, 32, 1)) + test_im_name = 'test_im' + + test_hvals = {'original_name': test_im_name} + uid = uuid.uuid4().hex + + test_im_hash = consumer._add_images(test_hvals, uid, test_im, queue) + split_hash = test_im_hash.split(":") + + assert split_hash[0] == queue + assert split_hash[1] == '{}-{}-{}-image.tif'.format(uid, + test_hvals.get('original_name'), + queue) + + result = redis_client.hget(test_im_hash, 'status') + assert result == 'new' + + def test__analyze_images(self, tmpdir, mocker, redis_client): + queue = 'polaris' + storage = DummyStorage() + consumer = consumers.PolarisConsumer(redis_client, storage, queue) + + fname = 'file.tiff' + filepath = os.path.join(tmpdir, fname) + input_size = (1, 32, 32, 1) + + # test successful workflow + def hget_successful_status(*_): + return consumer.final_status + + def write_child_tiff(*_, **__): + letters = string.ascii_lowercase + name = ''.join(random.choice(letters) for i in range(12)) + path = os.path.join(tmpdir, '{}.tiff'.format(name)) + tifffile.imsave(path, _get_image(32, 32)) + return [path] + + mocker.patch.object(settings, 'INTERVAL', 0) + mocker.patch.object(redis_client, 'hget', hget_successful_status) + mocker.patch('redis_consumer.utils.iter_image_archive', + write_child_tiff) + + tifffile.imsave(filepath, np.random.random(input_size)) + + # No segmentation + test_hash = 'test hash' + empty_data = 
{'input_file_name': 'file.tiff', + 'segmentation_type': 'none', + 'channels': '0,,'} + redis_client.hmset(test_hash, empty_data) + results = consumer._analyze_images(test_hash, tmpdir, fname) + coords, segmentation = results.get('coords'), results.get('segmentation') + + assert isinstance(coords, np.ndarray) + assert isinstance(segmentation, list) + + # Cell culture segmentation + test_hash = 'test hash' + empty_data = {'input_file_name': 'file.tiff', + 'segmentation_type': 'cell culture', + 'channels': '0,1,2'} + redis_client.hmset(test_hash, empty_data) + results = consumer._analyze_images(test_hash, tmpdir, fname) + coords, segmentation = results.get('coords'), results.get('segmentation') + + assert isinstance(coords, np.ndarray) + assert isinstance(segmentation, list) + assert np.shape(segmentation)[1] == input_size[1] + assert np.shape(segmentation)[2] == input_size[2] + + # Tissue segmentation + test_hash = 'test hash' + empty_data = {'input_file_name': 'file.tiff', + 'segmentation_type': 'tissue', + 'channels': '0,1,2'} + redis_client.hmset(test_hash, empty_data) + results = consumer._analyze_images(test_hash, tmpdir, fname) + coords, segmentation = results.get('coords'), results.get('segmentation') + + assert isinstance(coords, np.ndarray) + assert isinstance(segmentation, list) + assert np.shape(segmentation)[1] == input_size[1] + assert np.shape(segmentation)[2] == input_size[2] + def test__consume_finished_status(self, redis_client): queue = 'q' storage = DummyStorage() @@ -66,31 +158,51 @@ def test__consume_finished_status(self, redis_client): def test__consume(self, mocker, redis_client): # pylint: disable=W0613 - queue = 'multiplex' + queue = 'polaris' storage = DummyStorage() consumer = consumers.PolarisConsumer(redis_client, storage, queue) - empty_data = {'input_file_name': 'file.tiff'} - - output_shape = (1, 256, 256, 2) - - mock_app = Bunch( - predict=lambda *x, **y: np.random.randint(1, 5, size=output_shape), - model_mpp=1, - model=Bunch( - get_batch_size=lambda *x: 1, - input_shape=(1, 32, 32, 1) - ) - ) - - mocker.patch.object(consumer, 'get_grpc_app', lambda *x, **_: mock_app) - mocker.patch.object(consumer, 'get_image_scale', lambda *x, **_: 1) - mocker.patch.object(consumer, 'validate_model_input', lambda *x, **_: x[0]) - mocker.patch.object(consumer, 'detect_dimension_order', lambda *x, **_: 'YXC') - + # consume with cell culture segmentation and spot detection + empty_data = {'input_file_name': 'file.tiff', + 'segmentation_type': 'cell culture'} + mocker.patch.object(consumer, + '_analyze_images', + lambda *x, **_: {'coords': np.random.randint(32, size=(1, 10, 2)), + 'segmentation': np.random.random(size=(1, 32, 32, 1)) + } + ) test_hash = 'some hash' + redis_client.hmset(test_hash, empty_data) + result = consumer._consume(test_hash) + assert result == consumer.final_status + result = redis_client.hget(test_hash, 'status') + assert result == consumer.final_status + + # consume with tissue segmentation and spot detection + empty_data = {'input_file_name': 'file.tiff', + 'segmentation_type': 'tissue'} + mocker.patch.object(consumer, + '_analyze_images', + lambda *x, **_: {'coords': np.random.randint(32, size=(1, 10, 2)), + 'segmentation': np.random.random(size=(1, 32, 32, 1)) + } + ) + test_hash = 'another hash' + redis_client.hmset(test_hash, empty_data) + result = consumer._consume(test_hash) + assert result == consumer.final_status + result = redis_client.hget(test_hash, 'status') + assert result == consumer.final_status + # consume with spot detection only + 
empty_data = {'input_file_name': 'file.tiff', + 'segmentation_type': 'none'} + mocker.patch.object(consumer, + '_analyze_images', + lambda *x, **_: {'coords': np.random.randint(32, size=(1, 10, 2)), + 'segmentation': []}) + test_hash = 'some other hash' redis_client.hmset(test_hash, empty_data) result = consumer._consume(test_hash) assert result == consumer.final_status diff --git a/redis_consumer/consumers/segmentation_consumer.py b/redis_consumer/consumers/segmentation_consumer.py index 2b6d1cb2..7bc0f41b 100644 --- a/redis_consumer/consumers/segmentation_consumer.py +++ b/redis_consumer/consumers/segmentation_consumer.py @@ -42,8 +42,9 @@ class SegmentationConsumer(TensorFlowServingConsumer): """Consumes image files and uploads the results""" def detect_label(self, image): - """Send the image to the LABEL_DETECT_MODEL to detect the type of image - data. The model output is mapped with settings.MODEL_CHOICES. + """ DEPRECATED -- Send the image to the LABEL_DETECT_MODEL to + detect the type of image data. The model output is mapped with + settings.MODEL_CHOICES. Args: image (numpy.array): The image data. @@ -68,7 +69,7 @@ def detect_label(self, image): return int(detected_label) def get_image_label(self, label, image, redis_hash): - """Calculate label of image.""" + """ DEPRECATED -- Calculate label of image.""" if not label: # Detect scale of image (Default to 1) label = self.detect_label(image) @@ -104,13 +105,15 @@ def _consume(self, redis_hash): # Load input image fname = hvals.get('input_file_name') image = self.download_image(fname) + image = np.squeeze(image) image = np.expand_dims(image, axis=0) # add a batch dimension + if len(np.shape(image)) == 3: + image = np.expand_dims(image, axis=-1) # add a channel dimension - # Validate input image - if hvals.get('channels'): - channels = [int(c) for c in hvals.get('channels').split(',')] - else: - channels = None + rank = 4 # (b,x,y,c) + channel_axis = image.shape[1:].index(min(image.shape[1:])) + 1 + if channel_axis != rank - 1: + image = np.rollaxis(image, 1, rank) # Pre-process data before sending to the model self.update_key(redis_hash, { @@ -122,34 +125,40 @@ def _consume(self, redis_hash): scale = hvals.get('scale', '') scale = self.get_image_scale(scale, image, redis_hash) - label = hvals.get('label', '') - label = self.get_image_label(label, image, redis_hash) - - # Grap appropriate model and application class - model = settings.MODEL_CHOICES[label] - app_cls = settings.APPLICATION_CHOICES[label] - - model_name, model_version = model.split(':') - - # detect dimension order and add to redis - dim_order = self.detect_dimension_order(image, model_name, model_version) - self.update_key(redis_hash, { - 'dim_order': ','.join(dim_order) - }) - # Validate input image - image = self.validate_model_input(image, model_name, model_version, - channels=channels) - - # Send data to the model - self.update_key(redis_hash, {'status': 'predicting'}) - - app = self.get_grpc_app(model, app_cls) - # with new batching update in deepcell.applications, - # app.predict() cannot handle a batch_size of None.
- batch_size = app.model.get_batch_size() - results = app.predict(image, batch_size=batch_size, - image_mpp=scale * app.model_mpp) + channels = hvals.get('channels').split(',') # ex: channels = ['0','1','2'] + + results = [] + for i in range(len(channels)): + if channels[i]: + slice_image = image[..., int(channels[i])] + slice_image = np.expand_dims(slice_image, axis=-1) + # Grap appropriate model and application class + model = settings.MODEL_CHOICES[i] + app_cls = settings.APPLICATION_CHOICES[i] + + model_name, model_version = model.split(':') + + # detect dimension order and add to redis + dim_order = self.detect_dimension_order(slice_image, model_name, model_version) + self.update_key(redis_hash, { + 'dim_order': ','.join(dim_order) + }) + + # Validate input image + slice_image = self.validate_model_input(slice_image, model_name, model_version) + + # Send data to the model + self.update_key(redis_hash, {'status': 'predicting'}) + + app = self.get_grpc_app(model, app_cls) + # with new batching update in deepcell.applications, + # app.predict() cannot handle a batch_size of None. + batch_size = min(32, app.model.get_batch_size()) # TODO: raise max batch size + pred_results = app.predict(slice_image, batch_size=batch_size, + image_mpp=scale * app.model_mpp) + + results.extend(pred_results) # Save the post-processed results to a file _ = timeit.default_timer() diff --git a/redis_consumer/consumers/segmentation_consumer_test.py b/redis_consumer/consumers/segmentation_consumer_test.py index b085adaf..406855d4 100644 --- a/redis_consumer/consumers/segmentation_consumer_test.py +++ b/redis_consumer/consumers/segmentation_consumer_test.py @@ -125,7 +125,8 @@ def test__consume(self, mocker, redis_client): consumer = consumers.SegmentationConsumer(redis_client, storage, queue) - empty_data = {'input_file_name': 'file.tiff'} + empty_data = {'input_file_name': 'file.tiff', + 'channels': '0,'} output_shape = (1, 32, 32, 1) @@ -140,7 +141,6 @@ def test__consume(self, mocker, redis_client): mocker.patch.object(consumer, 'get_grpc_app', lambda *x, **_: mock_app) mocker.patch.object(consumer, 'get_image_scale', lambda *x, **_: 1) - mocker.patch.object(consumer, 'get_image_label', lambda *x, **_: 1) mocker.patch.object(consumer, 'validate_model_input', lambda *x, **_: True) mocker.patch.object(consumer, 'detect_dimension_order', lambda *x, **_: 'YXC') diff --git a/redis_consumer/consumers/spot_consumer.py b/redis_consumer/consumers/spot_consumer.py new file mode 100644 index 00000000..6c96b7e5 --- /dev/null +++ b/redis_consumer/consumers/spot_consumer.py @@ -0,0 +1,192 @@ +# Copyright 2016-2022 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/kiosk-redis-consumer/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. +# For any other use of the Work, including commercial use, please contact: +# vanvalenlab@gmail.com +# +# Neither the name of Caltech nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific +# prior written permission. 
+# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""SpotDetectionConsumer class for consuming spot detection jobs.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile +import timeit + +import matplotlib.pyplot as plt +import numpy as np + +from deepcell_spots.applications import SpotDetection + +from redis_consumer.consumers import TensorFlowServingConsumer +from redis_consumer import settings +from redis_consumer import utils + + +class SpotConsumer(TensorFlowServingConsumer): + """Consumes image files and uploads the results""" + + def save_output(self, coords, image, save_name): + """Save output in a zip file and upload it. Output includes predicted spot locations + plotted on original image as a .tiff file and coordinate spot locations as .npy file""" + with tempfile.TemporaryDirectory() as tempdir: + # Save each result channel as an image file + subdir = os.path.dirname(save_name.replace(tempdir, '')) + name = os.path.splitext(os.path.basename(save_name))[0] + + outpaths = [] + for i in range(len(coords)): + # Save image with plotted spot locations + img_name = '{}.tif'.format(i) + if name: + img_name = '{}_{}'.format(name, img_name) + + img_path = os.path.join(tempdir, subdir, img_name) + + fig = plt.figure() + plt.ioff() + plt.imshow(image[i], cmap='gray') + plt.scatter(coords[i][:, 1], coords[i][:, 0], c='m', s=4) + plt.xticks([]) + plt.yticks([]) + plt.savefig(img_path) + + # Save coordiates + coords_name = '{}.npy'.format(i) + if name: + coords_name = '{}_{}'.format(name, coords_name) + + coords_path = os.path.join(tempdir, subdir, coords_name) + + np.save(coords_path, coords[i]) + + outpaths.extend([img_path, coords_path]) + # outpaths.extend([coords_path]) + + # Save each prediction image as zip file + zip_file = utils.zip_files(outpaths, tempdir) + + # Upload the zip file to cloud storage bucket + cleaned = zip_file.replace(tempdir, '') + subdir = os.path.dirname(utils.strip_bucket_path(cleaned)) + subdir = subdir if subdir else None + dest, output_url = self.storage.upload(zip_file, subdir=subdir) + + return dest, output_url + + def _consume(self, redis_hash): + start = timeit.default_timer() + hvals = self.redis.hgetall(redis_hash) + + if hvals.get('status') in self.finished_statuses: + self.logger.warning('Found completed hash `%s` with status %s.', + redis_hash, hvals.get('status')) + return hvals.get('status') + + self.logger.debug('Found hash to process `%s` with status `%s`.', + redis_hash, hvals.get('status')) + + self.update_key(redis_hash, { + 'status': 'started', + 'identity_started': self.name, + }) + + # Get model_name and version + model_name, model_version = settings.POLARIS_MODEL.split(':') + + _ = timeit.default_timer() + + # Load input image + fname = hvals.get('input_file_name') + image = self.download_image(fname) + + # squeeze extra dimension that is added by get_image + image = np.squeeze(image) + if image.ndim == 2: + # add in the batch and channel dims + image = np.expand_dims(image, axis=[0, -1]) + elif image.ndim == 3: + # check if batch first or last + if np.shape(image)[2] < np.shape(image)[1]: + image = 
np.rollaxis(image, 2, 0) + # add in the channel dim + image = np.expand_dims(image, axis=[-1]) + else: + raise ValueError('Image with {} shape was uploaded, but Polaris only ' + 'supports multi-batch or multi-channel images.'.format( + np.shape(image))) + + # Pre-process data before sending to the model + self.update_key(redis_hash, { + 'status': 'pre-processing', + 'download_time': timeit.default_timer() - _, + }) + + # detect dimension order and add to redis + dim_order = self.detect_dimension_order(image, model_name, model_version) + self.update_key(redis_hash, { + 'dim_order': ','.join(dim_order) + }) + + # Validate input image + if hvals.get('channels'): + channels = [int(c) for c in hvals.get('channels').split(',')] + else: + channels = None + + image = self.validate_model_input(image, model_name, model_version, + channels=channels) + + # Send data to the model + self.update_key(redis_hash, {'status': 'predicting'}) + + app = self.get_grpc_app(settings.POLARIS_MODEL, SpotDetection) + + # with new batching update in deepcell.applications, + # app.predict() cannot handle a batch_size of None. + batch_size = app.model.get_batch_size() + threshold = settings.POLARIS_THRESHOLD + clip = settings.POLARIS_CLIP + self.logger.debug('Threshold: {}'.format(threshold)) + self.logger.debug('Clip: {}'.format(clip)) + results = app.predict(image, batch_size=batch_size, threshold=threshold, + clip=clip) + + # Save the post-processed results to a file + _ = timeit.default_timer() + self.update_key(redis_hash, {'status': 'saving-results'}) + + save_name = hvals.get('original_name', fname) + dest, output_url = self.save_output(results, image, save_name) + + # Update redis with the final results + end = timeit.default_timer() + self.update_key(redis_hash, { + 'status': self.final_status, + 'output_url': output_url, + 'upload_time': end - _, + 'output_file_name': dest, + 'total_jobs': 1, + 'total_time': end - start, + 'finished_at': self.get_current_timestamp() + }) + return self.final_status diff --git a/redis_consumer/consumers/spot_consumer_test.py b/redis_consumer/consumers/spot_consumer_test.py new file mode 100644 index 00000000..4ae2d6e7 --- /dev/null +++ b/redis_consumer/consumers/spot_consumer_test.py @@ -0,0 +1,91 @@ +# Copyright 2016-2022 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/kiosk-redis-consumer/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. +# For any other use of the Work, including commercial use, please contact: +# vanvalenlab@gmail.com +# +# Neither the name of Caltech nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific +# prior written permission. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Tests for SpotDetectionConsumer""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from redis_consumer import consumers +from redis_consumer.testing_utils import Bunch +from redis_consumer.testing_utils import DummyStorage +from redis_consumer.testing_utils import redis_client + + +class TestSpotConsumer(object): + # pylint: disable=R0201,W0621 + + def test__consume_finished_status(self, redis_client): + queue = 'q' + storage = DummyStorage() + + consumer = consumers.SpotConsumer(redis_client, storage, queue) + + empty_data = {'input_file_name': 'file.tiff'} + + test_hash = 0 + # test finished statuses are returned + for status in (consumer.failed_status, consumer.final_status): + test_hash += 1 + data = empty_data.copy() + data['status'] = status + redis_client.hmset(test_hash, data) + result = consumer._consume(test_hash) + assert result == status + result = redis_client.hget(test_hash, 'status') + assert result == status + test_hash += 1 + + def test__consume(self, mocker, redis_client): + # pylint: disable=W0613 + queue = 'spot' + storage = DummyStorage() + + consumer = consumers.SpotConsumer(redis_client, storage, queue) + empty_data = {'input_file_name': 'file.tiff'} + output_shape = (1, 32, 2) + + mock_app = Bunch( + predict=lambda *x, **y: np.random.randint(1, 5, size=output_shape), + model=Bunch( + get_batch_size=lambda *x: 1, + input_shape=(1, 32, 32, 1) + ) + ) + + mocker.patch.object(consumer, 'get_grpc_app', lambda *x, **_: mock_app) + mocker.patch.object(consumer, 'get_image_scale', lambda *x, **_: 1) + mocker.patch.object(consumer, 'validate_model_input', lambda *x, **_: x[0]) + mocker.patch.object(consumer, 'detect_dimension_order', lambda *x, **_: 'YXC') + + test_hash = 'some hash' + + redis_client.hmset(test_hash, empty_data) + result = consumer._consume(test_hash) + assert result == consumer.final_status + result = redis_client.hget(test_hash, 'status') + assert result == consumer.final_status diff --git a/redis_consumer/settings.py b/redis_consumer/settings.py index 407d4769..53e8c470 100644 --- a/redis_consumer/settings.py +++ b/redis_consumer/settings.py @@ -105,7 +105,7 @@ # Scale detection settings SCALE_DETECT_MODEL = config('SCALE_DETECT_MODEL', default='ScaleDetection:1') SCALE_DETECT_ENABLED = config('SCALE_DETECT_ENABLED', default=False, cast=bool) -MAX_SCALE = config('MAX_SCALE', default=3, cast=float) +MAX_SCALE = config('MAX_SCALE', default=10, cast=float) MIN_SCALE = config('MIN_SCALE', default=1 / MAX_SCALE, cast=float) # Type detection settings @@ -122,12 +122,13 @@ POLARIS_MODEL = config('POLARIS_MODEL', default='SpotDetection:3', cast=str) POLARIS_THRESHOLD = config('POLARIS_THRESHOLD', default=0.95, cast=float) POLARIS_CLIP = config('POLARIS_CLIP', default=False, cast=bool) +POLARIS_SCALE = config('POLARIS_SCALE', default=0.38, cast=float) # Set default models based on label type MODEL_CHOICES = { - 0: config('NUCLEAR_MODEL', default='NuclearSegmentation:0', cast=str), - 1: config('PHASE_MODEL', default='PhaseCytoSegmentation:0', cast=str), - 2: config('CYTOPLASM_MODEL', default='FluoCytoSegmentation:0', cast=str) + 0: config('NUCLEAR_MODEL', default='NuclearSegmentation:5', cast=str), + 1: config('CYTOPLASM_MODEL', default='CytoplasmSegmentation:4', cast=str), + 2: config('PHASE_MODEL', default='PhaseCytoSegmentation:0', cast=str) } APPLICATION_CHOICES = { diff --git 
a/requirements.txt b/requirements.txt index f221e552..20f402a0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ # deepcell packages deepcell-cpu~=0.11.0 +deepcell-spots~=0.2.0 deepcell-toolbox~=0.10.3 deepcell-tracking~=0.5.2 tensorflow-cpu~=2.5.2 @@ -7,8 +8,6 @@ tifffile>=2020.9.3 numpy>=1.16.6 matplotlib>=2.1.1 -git+git://github.com/vanvalenlab/deepcell-spots@f7749cf77d67a4bfd3a56a66b6488cb0feaffecf - # tensorflow-serving-apis and gRPC dependencies grpcio>=1.0,<2 dict-to-protobuf~=0.0.3.10
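The new PolarisConsumer fans work out over Redis: _analyze_images writes each requested channel to a temporary TIFF, pushes a child hash onto the spot, segmentation, or mesmer queue, and then polls those hashes until every child reaches a finished status. A minimal sketch of that pattern using plain redis-py calls is below; the helper names and the hard-coded 'done' / 'failed' strings are illustrative assumptions standing in for consumer.final_status and consumer.failed_status, not code from this patch.

    # Illustrative sketch only, not part of the patch. Assumes a reachable Redis
    # server and a client created with decode_responses=True so hget returns str.
    import time
    import uuid

    import redis

    FINAL_STATUS = 'done'     # stands in for consumer.final_status
    FAILED_STATUS = 'failed'  # stands in for consumer.failed_status

    def submit_child_job(client, queue, hvals):
        """Write a child hash and push it onto `queue`, as _add_images does."""
        child_hash = '{}:{}:{}'.format(queue, hvals['original_name'], uuid.uuid4().hex)
        client.hset(child_hash, mapping=hvals)  # the consumer itself uses the older hmset call
        client.lpush(queue, child_hash)
        return child_hash

    def wait_for_children(client, hashes, interval=1):
        """Poll child hashes until all finish, as the while loop in _analyze_images does."""
        remaining = set(hashes)
        while remaining:
            finished = set()
            for h in remaining:
                status = client.hget(h, 'status')
                if status == FAILED_STATUS:
                    reason = client.hget(h, 'reason')
                    raise RuntimeError('Child job {} failed: {}'.format(h, reason))
                if status == FINAL_STATUS:
                    finished.add(h)
            remaining -= finished
            time.sleep(interval)

    # Example usage:
    # client = redis.Redis(decode_responses=True)
    # child = submit_child_job(client, 'spot', {'original_name': 'im.tif', 'status': 'new'})
    # wait_for_children(client, [child], interval=0.5)

On the output side, save_output writes one .csv per batch image with header x, y, cellID0 (and cellID1 when the segmentation carries two label channels) alongside the labeled segmentation .tif, while save_coords writes only x and y. Reading the assignments back out of the downloaded zip is straightforward; the file name below is hypothetical.

    import csv

    # Hypothetical name of a file extracted from the results zip.
    with open('example_output_0.csv', newline='') as f:
        for row in csv.DictReader(f):
            x, y = float(row['x']), float(row['y'])
            cell = int(row['cellID0'])  # label of the cell containing the spot (0 is background in a label image)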