-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Change model output to dictionary #22
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider adding a check to your postprocessing functions to confirm that the two required keys are present in the dictionary and maybe that the shapes make sense. Let me know what you think.
def _run_model(self, | ||
image, | ||
batch_size=4, | ||
pad_mode='constant', | ||
preprocess_kwargs={}): | ||
"""Run the model to generate output probabilities on the data. | ||
|
||
Args: | ||
image (numpy.array): Image with shape ``[batch, x, y, channel]`` | ||
batch_size (int): Number of images to predict on per batch. | ||
pad_mode (str): The padding mode, one of "constant" or "reflect". | ||
preprocess_kwargs (dict): Keyword arguments to pass to | ||
the preprocessing function. | ||
|
||
Returns: | ||
numpy.array: Model outputs | ||
""" | ||
# Preprocess image if function is defined | ||
image = self._preprocess(image, **preprocess_kwargs) | ||
|
||
# Tile images, raises error if the image is not 4d | ||
tiles, tiles_info = self._tile_input(image, pad_mode=pad_mode) | ||
|
||
# Run images through model | ||
t = timeit.default_timer() | ||
output_tiles = self.model.predict(tiles, batch_size=batch_size) | ||
self.logger.debug('Model inference finished in %s s', | ||
timeit.default_timer() - t) | ||
|
||
# Untile images | ||
output_images = self._untile_output(output_tiles, tiles_info) | ||
|
||
# restructure outputs into a dict if function provided | ||
formatted_images = {name: pred for name, pred in zip(self.model.output_names, | ||
output_images)} | ||
|
||
return formatted_images | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did you decide to duplicate _run_model with the addition of lines 142-144 instead of just adding those lines after calling _run_model in _predict?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I replaced the line calling _format_model_output
because it didn't have a way to pass arguments in the _format_model_output_fn
. You're right that I could have just done this outside of _run_model but my thought was that the line I added serves the same purpose as the _format_model_output
call.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would try to avoid duplicating a core function like _run_model especially if there is a way to pass in a format function. It's going to make the spots repo harder to maintain in the long run if we ever do a major refactor of the deepcell application code. If you want keep the output names exposed, you could make it a parameter in a function format_spots_output
and then when you pass it in to the application to define _format_model_output_fn
you could first override it with self.model.output_names
, e.g.
def _format_spots_output(output):
return format_spots_otuput(output, output_names=self.model.output_names)
I'll let it go if you think its too complicated, but take a look first.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you think it's cleaner to move it out of _run_model
I can change it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would prefer having that change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
This PR makes the changes necessary to name the outputs of the
dot_net_2D
model after prediction in a dictionary as opposed to a list. I have also removed arguments for scaling from the application, which is not currently supported by the spots model.