Skip to content

Commit

Permalink
refactoring allen smFish with new spot finding
Browse files Browse the repository at this point in the history
  • Loading branch information
Shannon Axelrod committed Oct 1, 2019
1 parent 5996a09 commit 3ad1a4e
Show file tree
Hide file tree
Showing 8 changed files with 274 additions and 31 deletions.
23 changes: 14 additions & 9 deletions notebooks/py/smFISH.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@

import starfish
import starfish.data
from starfish import FieldOfView, IntensityTable
from starfish import FieldOfView, DecodedIntensityTable
from starfish.types import TraceBuildingStrategies

# equivalent to %gui qt
ipython = get_ipython()
Expand Down Expand Up @@ -72,7 +73,7 @@
# EPY: END markdown

# EPY: START code
tlmpf = starfish.spots.DetectSpots.TrackpyLocalMaxPeakFinder(
tlmpf = starfish.spots.FindSpots.TrackpyLocalMaxPeakFinder(
spot_diameter=5, # must be odd integer
min_mass=0.02,
max_size=2, # this is max radius
Expand All @@ -96,6 +97,7 @@
import sys
print = partial(print, file=sys.stderr)


def processing_pipeline(
experiment: starfish.Experiment,
fov_name: str,
Expand Down Expand Up @@ -123,6 +125,11 @@ def processing_pipeline(
print("Loading images...")
images = enumerate(experiment[fov_name].get_images(FieldOfView.PRIMARY_IMAGES))

decoder = starfish.spots.DecodeSpots.PerRoundMaxChannel(
codebook=codebook,
trace_building_strategy=TraceBuildingStrategies.SEQUENTIAL
)

for image_number, primary_image in images:
print(f"Filtering image {image_number}...")
filter_kwargs = dict(
Expand All @@ -140,15 +147,13 @@ def processing_pipeline(
clip2.run(primary_image, **filter_kwargs)

print("Calling spots...")
spot_attributes = tlmpf.run(primary_image)
all_intensities.append(spot_attributes)
spots = tlmpf.run(primary_image)
decoded_intensities = decoder.run(spots)
all_intensities.append(decoded_intensities)

spot_attributes = IntensityTable.concatenate_intensity_tables(all_intensities)

print("Decoding spots...")
decoded = codebook.decode_per_round_max(spot_attributes)
decoded = DecodedIntensityTable.concatenate_intensity_tables(all_intensities)
decoded = decoded[decoded["total_intensity"] > .025]

print("Processing complete.")

return primary_image, decoded
Expand Down
23 changes: 14 additions & 9 deletions notebooks/smFISH.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
"\n",
"import starfish\n",
"import starfish.data\n",
"from starfish import FieldOfView, IntensityTable\n",
"from starfish import FieldOfView, DecodedIntensityTable\n",
"from starfish.types import TraceBuildingStrategies\n",
"\n",
"# equivalent to %gui qt\n",
"ipython = get_ipython()\n",
Expand Down Expand Up @@ -98,7 +99,7 @@
"metadata": {},
"outputs": [],
"source": [
"tlmpf = starfish.spots.DetectSpots.TrackpyLocalMaxPeakFinder(\n",
"tlmpf = starfish.spots.FindSpots.TrackpyLocalMaxPeakFinder(\n",
" spot_diameter=5, # must be odd integer\n",
" min_mass=0.02,\n",
" max_size=2, # this is max radius\n",
Expand Down Expand Up @@ -130,6 +131,7 @@
"import sys\n",
"print = partial(print, file=sys.stderr)\n",
"\n",
"\n",
"def processing_pipeline(\n",
" experiment: starfish.Experiment,\n",
" fov_name: str,\n",
Expand Down Expand Up @@ -157,6 +159,11 @@
" print(\"Loading images...\")\n",
" images = enumerate(experiment[fov_name].get_images(FieldOfView.PRIMARY_IMAGES))\n",
"\n",
" decoder = starfish.spots.DecodeSpots.PerRoundMaxChannel(\n",
" codebook=codebook,\n",
" trace_building_strategy=TraceBuildingStrategies.SEQUENTIAL\n",
" )\n",
"\n",
" for image_number, primary_image in images:\n",
" print(f\"Filtering image {image_number}...\")\n",
" filter_kwargs = dict(\n",
Expand All @@ -174,15 +181,13 @@
" clip2.run(primary_image, **filter_kwargs)\n",
"\n",
" print(\"Calling spots...\")\n",
" spot_attributes = tlmpf.run(primary_image)\n",
" all_intensities.append(spot_attributes)\n",
" spots = tlmpf.run(primary_image)\n",
" decoded_intensities = decoder.run(spots)\n",
" all_intensities.append(decoded_intensities)\n",
"\n",
" spot_attributes = IntensityTable.concatenate_intensity_tables(all_intensities)\n",
"\n",
" print(\"Decoding spots...\")\n",
" decoded = codebook.decode_per_round_max(spot_attributes)\n",
" decoded = DecodedIntensityTable.concatenate_intensity_tables(all_intensities)\n",
" decoded = decoded[decoded[\"total_intensity\"] > .025]\n",
" \n",
"\n",
" print(\"Processing complete.\")\n",
"\n",
" return primary_image, decoded"
Expand Down
44 changes: 41 additions & 3 deletions starfish/core/spots/DecodeSpots/trace_builders.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Callable, Mapping

import pandas as pd

from starfish.core.intensity_table.intensity_table import IntensityTable
from starfish.core.types import Axes, Features, SpotFindingResults, TraceBuildingStrategies
from starfish.core.types import Axes, Features, SpotAttributes, SpotFindingResults, \
TraceBuildingStrategies
from .util import _build_intensity_table, _match_spots, _merge_spots_by_round


def build_spot_traces_exact_match(spot_results: SpotFindingResults, **kwargs):
def build_spot_traces_exact_match(spot_results: SpotFindingResults, **kwargs) -> IntensityTable:
"""
Combines spots found in matching x/y positions across rounds and channels of
an ImageStack into traces represented as an IntensityTable.
Expand All @@ -30,6 +33,40 @@ def build_spot_traces_exact_match(spot_results: SpotFindingResults, **kwargs):
return intensity_table


def build_traces_sequential(spot_results: SpotFindingResults, **kwargs) -> IntensityTable:
    """
    Build spot traces without merging across channels and imaging rounds.

    Every spot found in any (round, channel) tile becomes its own feature whose only
    populated entry is the intensity at the round and channel it was found in. Used for
    sequential methods like smFISH.

    Parameters
    ----------
    spot_results : SpotFindingResults
        Spots found across rounds/channels of an ImageStack
    **kwargs
        Accepted for signature compatibility with the other trace builders; unused here.

    Returns
    -------
    IntensityTable :
        concatenated input SpotAttributes, converted to an IntensityTable object
    """
    channel_labels = spot_results.ch_labels
    round_labels = spot_results.round_labels

    # Stack the per-(round, channel) spot tables into one table with a row per spot.
    combined_spots = pd.concat(
        [spot_attributes.data for spot_attributes in spot_results.values()],
        sort=True,
    )
    intensity_table = IntensityTable.zeros(
        SpotAttributes(combined_spots), channel_labels, round_labels,
    )

    # Lazily walk every spot together with the (round, channel) it came from.
    # NOTE(review): assumes items() iterates in the same order as values() above, so that
    # the running feature index lines up with the concatenated table -- confirm.
    located_spots = (
        (round_label, channel_label, spot)
        for (round_label, channel_label), spot_attributes in spot_results.items()
        for _, spot in spot_attributes.data.iterrows()
    )
    for feature_index, (round_label, channel_label, spot) in enumerate(located_spots):
        position = dict(features=feature_index, c=channel_label, r=round_label)
        intensity_table.loc[position] = spot[Features.INTENSITY]
    return intensity_table


def build_traces_nearest_neighbors(spot_results: SpotFindingResults, anchor_round: int=1,
search_radius: int=3):
"""
Expand Down Expand Up @@ -65,5 +102,6 @@ def build_traces_nearest_neighbors(spot_results: SpotFindingResults, anchor_roun

trace_builders: Mapping[TraceBuildingStrategies, Callable] = {
TraceBuildingStrategies.EXACT_MATCH: build_spot_traces_exact_match,
TraceBuildingStrategies.NEAREST_NEIGHBOR: build_traces_nearest_neighbors
TraceBuildingStrategies.NEAREST_NEIGHBOR: build_traces_nearest_neighbors,
TraceBuildingStrategies.SEQUENTIAL: build_traces_sequential
}
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ def run(
n_processes : Optional[int] = None,
Number of processes to devote to spot finding.
"""
import warnings  # function-scope import: the file's top-level imports are not visible here

# Emit the deprecation notice. Merely constructing a DeprecationWarning instance (the
# previous behaviour) neither raises nor reports anything, so users never saw the message.
warnings.warn(
    "Starfish is embarking on a SpotFinding data structures refactor "
    "(See https://github.com/spacetx/starfish/issues/1514) This version of "
    "TrackpyLocalMaxPeakFinder will soon be deleted. To find and decode your "
    "spots please instead use FindSpots.TrackpyLocalMaxPeakFinder then "
    "DecodeSpots.PerRoundMaxChannel with the parameter "
    "TraceBuildingStrategies.SEQUENTIAL. See example in smFISH.py",
    DeprecationWarning,
    stacklevel=2,  # attribute the warning to the caller rather than this wrapper
)
intensity_table = detect_spots(
data_stack=primary_image,
spot_finding_method=self.image_to_spots,
Expand Down
1 change: 1 addition & 0 deletions starfish/core/spots/FindSpots/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ._base import FindSpotsAlgorithm
from .blob import BlobDetector
from .trackpy_local_max_peak_finder import TrackpyLocalMaxPeakFinder

# autodoc's automodule directive only captures the modules explicitly listed in __all__.
all_filters = {
Expand Down
179 changes: 179 additions & 0 deletions starfish/core/spots/FindSpots/trackpy_local_max_peak_finder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import warnings
from functools import partial
from typing import Optional, Tuple, Union

import numpy as np
import xarray as xr
from trackpy import locate

from starfish.core.image.Filter.util import determine_axes_to_group_by
from starfish.core.imagestack.imagestack import ImageStack
from starfish.core.spots.FindSpots import spot_finding_utils
from starfish.core.types import Axes, SpotAttributes, SpotFindingResults
from ._base import FindSpotsAlgorithm


class TrackpyLocalMaxPeakFinder(FindSpotsAlgorithm):
    """
    Find spots using a local max peak finding algorithm.

    This is a wrapper for :code:`trackpy.locate`, which implements a version of the
    `Crocker-Grier <crocker_grier>`_ algorithm.

    .. _crocker_grier: https://physics.nyu.edu/grierlab/methods3c/

    Parameters
    ----------
    spot_diameter : odd integer or tuple of odd integers.
        This may be a single number or a tuple giving the feature's extent in each dimension,
        useful when the dimensions do not have equal resolution (e.g. confocal microscopy).
        The tuple order is the same as the image shape, conventionally (z, y, x) or (y, x).
        The number(s) must be odd integers. When in doubt, round up.
    min_mass : float
        The minimum integrated brightness. This is a crucial parameter for eliminating spurious
        features. Recommended minimum values are 100 for integer images and 1 for float images.
        Required by this wrapper; pass 0 for no filtering.
    max_size : float
        maximum radius-of-gyration of brightness. Required by this wrapper.
    separation : float or tuple
        Minimum separation between features. May be a tuple, see spot_diameter for details.
        Required by this wrapper.
    percentile : float
        Features must have a peak brighter than pixels in this percentile. This helps eliminate
        spurious peaks. (default = 0)
    noise_size : float or tuple
        Width of Gaussian blurring kernel, in pixels. May be a tuple, see spot_diameter for
        details. (default = (1, 1, 1))
    smoothing_size : float or tuple
        The size of the sides of the square kernel used in boxcar (rolling average) smoothing,
        in pixels. May be a tuple, making the kernel rectangular. (default = None, which is
        passed through to trackpy unchanged)
    threshold : float
        Clip bandpass result below this value. Thresholding is done on the already
        background-subtracted image. (default = None, which is passed through to trackpy
        unchanged)
    preprocess : boolean
        Set to True to turn on bandpass preprocessing. (default = False)
    max_iterations : integer
        Max number of loops to refine the center of mass. (default = 10)
    measurement_type : str ['max', 'mean']
        name of the function used to calculate the intensity for each identified spot area
        (default = 'max')
    is_volume : bool
        if True, run the algorithm on 3d volumes of the provided stack (default = False)
    verbose : bool
        If True, report the percentage completed during processing (default = False).
        NOTE(review): the value is stored but not read by any method defined in this class.

    Notes
    -----
    See also trackpy locate: http://soft-matter.github.io/trackpy/dev/generated/trackpy.locate.html
    """

    def __init__(
            self, spot_diameter, min_mass, max_size, separation, percentile=0,
            noise_size: Tuple[int, int, int] = (1, 1, 1), smoothing_size=None, threshold=None,
            preprocess: bool = False, max_iterations: int = 10, measurement_type: str = 'max',
            is_volume: bool = False, verbose=False) -> None:

        # Attribute names mirror the keyword names of trackpy.locate.
        self.diameter = spot_diameter
        self.minmass = min_mass
        self.maxsize = max_size
        self.separation = separation
        self.noise_size = noise_size
        self.smoothing_size = smoothing_size
        self.percentile = percentile
        self.threshold = threshold
        # _get_measurement_function is not defined in this class; presumably it is
        # inherited from FindSpotsAlgorithm -- confirm.
        self.measurement_function = self._get_measurement_function(measurement_type)
        self.preprocess = preprocess
        self.max_iterations = max_iterations
        self.is_volume = is_volume
        self.verbose = verbose

    def image_to_spots(self, data_image: Union[np.ndarray, xr.DataArray]) -> SpotAttributes:
        """
        Locate spots in a single image with :code:`trackpy.locate`.

        Parameters
        ----------
        data_image : Union[np.ndarray, xr.DataArray]
            image containing spots to be detected. When the image is three-dimensional the
            resulting table carries a leading 'z' column; otherwise only 'y' and 'x' are
            reported.

        Returns
        -------
        SpotAttributes :
            spot attributes table for all detected spots
        """
        data_image = np.asarray(data_image)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', FutureWarning)  # trackpy numpy indexing warning
            warnings.simplefilter('ignore', UserWarning)  # yielded if black images
            attributes = locate(
                data_image,
                diameter=self.diameter,
                minmass=self.minmass,
                maxsize=self.maxsize,
                separation=self.separation,
                noise_size=self.noise_size,
                smoothing_size=self.smoothing_size,
                threshold=self.threshold,
                percentile=self.percentile,
                preprocess=self.preprocess,
                max_iterations=self.max_iterations,
            )

        # when zero spots are detected, 'ep' is missing from the trackpy locate results.
        if attributes.shape[0] == 0:
            attributes['ep'] = []

        # TODO ambrosejcarr: data should always be at least pseudo-3d, this may not be necessary
        # TODO ambrosejcarr: this is where max vs. sum vs. mean would be parametrized.
        # here, total_intensity = sum, intensity = max
        new_colnames = [
            'y', 'x', 'total_intensity', 'radius', 'eccentricity', 'intensity', 'raw_mass', 'ep'
        ]
        # The columns are renamed positionally, so a 3d input needs the extra leading 'z'.
        if data_image.ndim == 3:
            attributes.columns = ['z'] + new_colnames
        else:
            attributes.columns = new_colnames

        attributes['spot_id'] = np.arange(attributes.shape[0])
        return SpotAttributes(attributes)

    def run(
            self,
            primary_image: ImageStack,
            reference_image: Optional[ImageStack] = None,
            n_processes: Optional[int] = None,
            *args,
    ) -> SpotFindingResults:
        """
        Find spots.

        Parameters
        ----------
        primary_image : ImageStack
            ImageStack where we find the spots in.
        reference_image : Optional[ImageStack]
            (Optional) a reference image. If provided, spots will be found in this image, and then
            the locations that correspond to these spots will be measured across each channel.
        n_processes : Optional[int] = None,
            Number of processes to devote to spot finding. Only used when no reference image
            is provided.

        Returns
        -------
        SpotFindingResults :
            the spots found (or measured) in primary_image
        """
        # NOTE(review): extra positional args are bound *ahead of* the image argument, but
        # image_to_spots only accepts the image; passing any *args looks like it would raise
        # a TypeError when the partial is called -- confirm intent before relying on *args.
        spot_finding_method = partial(self.image_to_spots, *args)
        # Compare against None explicitly so the branch does not depend on ImageStack's
        # truthiness.
        if reference_image is not None:
            data_image = reference_image._squeezed_numpy(*{Axes.ROUND, Axes.CH})
            reference_spots = spot_finding_method(data_image)
            results = spot_finding_utils.measure_spot_intensities(
                primary_image,
                reference_spots,
                measurement_function=self.measurement_function)
        else:
            spot_attributes_list = primary_image.transform(
                func=spot_finding_method,
                group_by=determine_axes_to_group_by(self.is_volume),
                n_processes=n_processes
            )
            results = SpotFindingResults(imagestack_coords=primary_image.xarray.coords,
                                         log=primary_image.log,
                                         spot_attributes_list=spot_attributes_list)
        return results
1 change: 1 addition & 0 deletions starfish/core/types/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,4 @@ class TraceBuildingStrategies(AugmentedEnum):
"""
EXACT_MATCH = 'exact_match'
NEAREST_NEIGHBOR = 'nearest_neighbor'
SEQUENTIAL = 'sequential'
Loading

0 comments on commit 3ad1a4e

Please sign in to comment.