diff --git a/REQUIREMENTS.txt b/REQUIREMENTS.txt index b1c700a84..e645accc3 100644 --- a/REQUIREMENTS.txt +++ b/REQUIREMENTS.txt @@ -16,7 +16,7 @@ semantic_version # 0.16.[012] are excluded because https://github.com/scikit-image/scikit-image/pull/3984 introduced # a bug into max peak finder. 0.16.3 presumably will have the fix from # https://github.com/scikit-image/scikit-image/pull/4263. -scikit-image >= 0.14.0, != 0.16.0.*, != 0.16.1.*, != 0.16.2.*, != 0.17.1.*, != 0.17.2.* +scikit-image >= 0.14.0, != 0.16.0.*, != 0.16.1.*, != 0.16.2.*, != 0.17.1.*, != 0.17.2.*, < 0.19.0 scikit-learn scipy showit >= 1.1.4 @@ -26,4 +26,4 @@ tqdm trackpy validators xarray >= 0.14.1 -ipywidgets +ipywidgets \ No newline at end of file diff --git a/requirements/REQUIREMENTS-CI.txt b/requirements/REQUIREMENTS-CI.txt index 37450b48c..677bad25a 100644 --- a/requirements/REQUIREMENTS-CI.txt +++ b/requirements/REQUIREMENTS-CI.txt @@ -116,7 +116,7 @@ slicedimage==4.1.1 snowballstemmer==2.1.0 Sphinx==4.1.2 sphinx-autodoc-typehints==1.12.0 -sphinx-bootstrap-theme==0.7.1 +sphinx-bootstrap-theme==0.8.1 sphinx-gallery==0.9.0 sphinx-rtd-theme==0.5.2 sphinxcontrib-applehelp==1.0.2 @@ -147,5 +147,4 @@ wcwidth==0.2.5 webencodings==0.5.1 widgetsnbextension==3.5.1 xarray==0.19.0 -zipp==3.5.0 - +zipp==3.5.0 \ No newline at end of file diff --git a/requirements/REQUIREMENTS-CI.txt.in b/requirements/REQUIREMENTS-CI.txt.in index c53a1ea0b..0d16f555d 100644 --- a/requirements/REQUIREMENTS-CI.txt.in +++ b/requirements/REQUIREMENTS-CI.txt.in @@ -22,4 +22,4 @@ sphinx_bootstrap_theme sphinxcontrib-programoutput sphinx-gallery sphinx_rtd_theme -twine +twine \ No newline at end of file diff --git a/requirements/REQUIREMENTS-NAPARI-CI.txt b/requirements/REQUIREMENTS-NAPARI-CI.txt index a266ecb5e..4d7989afa 100644 --- a/requirements/REQUIREMENTS-NAPARI-CI.txt +++ b/requirements/REQUIREMENTS-NAPARI-CI.txt @@ -1,142 +1,160 @@ # You should not edit this file directly. 
Instead, you should edit one of the following files (requirements/REQUIREMENTS-NAPARI-CI.txt.in) and run make requirements/REQUIREMENTS-NAPARI-CI.txt alabaster==0.7.12 appdirs==1.4.4 -argon2-cffi==21.1.0 -attrs==21.2.0 -Babel==2.9.1 +argon2-cffi==21.3.0 +argon2-cffi-bindings==21.2.0 +asttokens==2.0.5 +attrs==21.4.0 +Babel==2.10.3 backcall==0.2.0 -bleach==4.1.0 -boto3==1.18.37 -botocore==1.21.37 +beautifulsoup4==4.11.1 +bleach==5.0.0 +boto3==1.24.14 +botocore==1.27.14 +build==0.8.0 cachey==0.2.1 -certifi==2021.5.30 -cffi==1.14.6 -charset-normalizer==2.0.4 -click==8.0.1 -cloudpickle==1.6.0 -cycler==0.10.0 -dask==2021.10.0 +certifi==2022.6.15 +cffi==1.15.0 +charset-normalizer==2.0.12 +click==8.1.3 +cloudpickle==2.1.0 +cycler==0.11.0 +dask==2022.6.0 dataclasses==0.6 -debugpy==1.4.1 +debugpy==1.6.0 decorator==4.4.2 defusedxml==0.7.1 -diskcache==5.2.1 -docstring-parser==0.10 -docutils==0.17.1 -entrypoints==0.3 -freetype-py==2.2.0 -fsspec==2021.8.1 -h5py==3.4.0 +diskcache==5.4.0 +docstring-parser==0.14.1 +docutils==0.18.1 +entrypoints==0.4 +executing==0.8.3 +fastjsonschema==2.15.3 +fonttools==4.33.3 +freetype-py==2.3.0 +fsspec==2022.5.0 +h5py==3.7.0 HeapDict==1.0.1 -hsluv==5.0.2 -idna==3.2 -imageio==2.9.0 -imagesize==1.2.0 +hsluv==5.0.3 +idna==3.3 +imageio==2.19.3 +imagesize==1.3.0 +importlib-metadata==4.11.4 iniconfig==1.1.1 -ipykernel==6.3.1 -ipython==7.27.0 +ipykernel==6.15.0 +ipython==8.4.0 ipython-genutils==0.2.0 -ipywidgets==7.6.4 -jedi==0.18.0 -Jinja2==3.0.1 -jmespath==0.10.0 -joblib==1.0.1 -jsonschema==3.2.0 -jupyter-client==7.0.2 -jupyter-core==4.7.1 -jupyterlab-pygments==0.1.2 -jupyterlab-widgets==1.0.1 -kiwisolver==1.3.2 -locket==0.2.1 -magicgui==0.2.10 -MarkupSafe==2.0.1 -matplotlib==3.4.3 +ipywidgets==7.7.1 +jedi==0.18.1 +Jinja2==3.1.2 +jmespath==1.0.1 +joblib==1.1.0 +jsonschema==4.6.0 +jupyter-client==7.3.4 +jupyter-core==4.10.0 +jupyterlab-pygments==0.2.2 +jupyterlab-widgets==1.1.1 +kiwisolver==1.4.3 +locket==1.0.0 +magicgui==0.5.1 +MarkupSafe==2.1.1 +matplotlib==3.5.2 matplotlib-inline==0.1.3 mistune==0.8.4 mpmath==1.2.1 -napari==0.4.11 +napari==0.4.16 napari-console==0.0.4 -napari-plugin-engine==0.1.9 -napari-svg==0.1.5 -nbclient==0.5.4 -nbconvert==6.1.0 -nbformat==5.1.3 -nest-asyncio==1.5.1 -networkx==2.6.2 -notebook==6.4.3 -numpy==1.21.2 -numpydoc==1.1.0 -packaging==21.0 -pandas==1.3.2 -pandocfilters==1.4.3 -parso==0.8.2 +napari-plugin-engine==0.2.0 +napari-svg==0.1.6 +nbclient==0.6.4 +nbconvert==6.5.0 +nbformat==5.4.0 +nest-asyncio==1.5.5 +networkx==2.8.4 +notebook==6.4.12 +npe2==0.5.0 +numpy==1.22.4 +numpydoc==1.4.0 +packaging==21.3 +pandas==1.4.2 +pandocfilters==1.5.0 +parso==0.8.3 partd==1.2.0 +pep517==0.12.0 pexpect==4.8.0 pickleshare==0.7.5 -Pillow==8.3.2 -Pint==0.17 +Pillow==9.1.1 +Pint==0.19.2 pluggy==1.0.0 -prometheus-client==0.11.0 -prompt-toolkit==3.0.20 -psutil==5.8.0 +prometheus-client==0.14.1 +prompt-toolkit==3.0.29 +psutil==5.9.1 +psygnal==0.3.5 ptyprocess==0.7.0 -py==1.10.0 -pycparser==2.20 -pydantic==1.8.2 -Pygments==2.10.0 -PyOpenGL==3.1.5 -pyparsing==2.4.7 +pure-eval==0.2.2 +py==1.11.0 +pycparser==2.21 +pydantic==1.9.1 +Pygments==2.12.0 +PyOpenGL==3.1.6 +pyparsing==3.0.9 PyQt5==5.14.2 -PyQt5-sip==12.9.0 -pyrsistent==0.18.0 -pytest==6.2.5 +PyQt5-sip==12.11.0 +pyrsistent==0.18.1 +pytest==7.1.2 pytest-qt==4.0.2 python-dateutil==2.8.2 -pytz==2021.1 -PyWavelets==1.1.1 -PyYAML==5.4.1 -pyzmq==22.2.1 -qtconsole==5.1.1 -QtPy==1.11.0 +pytomlpp==1.0.11 +pytz==2022.1 +PyWavelets==1.3.0 +PyYAML==6.0 +pyzmq==23.2.0 +qtconsole==5.3.1 +QtPy==2.1.0 read-roi==1.6.0 
regional==1.1.2 -requests==2.26.0 -s3transfer==0.5.0 +requests==2.28.0 +s3transfer==0.6.0 scikit-image==0.18.3 -scikit-learn==0.24.2 -scipy==1.7.1 -semantic-version==2.8.5 +scikit-learn==1.1.1 +scipy==1.8.1 +semantic-version==2.10.0 Send2Trash==1.8.0 -setuptools==56.0.0 +setuptools==58.1.0 showit==1.1.4 six==1.16.0 slicedimage==4.1.1 -snowballstemmer==2.1.0 -Sphinx==4.1.2 +snowballstemmer==2.2.0 +soupsieve==2.3.2.post1 +Sphinx==5.0.2 +sphinx-bootstrap-theme==0.8.1 sphinxcontrib-applehelp==1.0.2 sphinxcontrib-devhelp==1.0.2 sphinxcontrib-htmlhelp==2.0.0 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 -superqt==0.2.3 +stack-data==0.3.0 +superqt==0.3.2 sympy==1.5.1 -terminado==0.12.1 -testpath==0.5.0 -threadpoolctl==2.2.0 -tifffile==2021.8.30 -toml==0.10.2 -toolz==0.11.1 +terminado==0.15.0 +threadpoolctl==3.1.0 +tifffile==2022.5.4 +tinycss2==1.1.1 +tomli==2.0.1 +toolz==0.11.2 tornado==6.1 -tqdm==4.62.2 +tqdm==4.64.0 trackpy==0.5.0 -traitlets==5.1.0 -typing-extensions==3.10.0.2 -urllib3==1.26.6 -validators==0.18.2 -vispy==0.8.1 +traitlets==5.3.0 +typer==0.4.1 +typing_extensions==4.2.0 +urllib3==1.26.9 +validators==0.20.0 +vispy==0.10.0 wcwidth==0.2.5 webencodings==0.5.1 -widgetsnbextension==3.5.1 -wrapt==1.12.1 -xarray==0.19.0 +widgetsnbextension==3.6.1 +wrapt==1.14.1 +xarray==2022.3.0 +zipp==3.8.0 diff --git a/starfish/REQUIREMENTS-STRICT.txt b/starfish/REQUIREMENTS-STRICT.txt index d0d9d1d68..eda66e668 100644 --- a/starfish/REQUIREMENTS-STRICT.txt +++ b/starfish/REQUIREMENTS-STRICT.txt @@ -1,94 +1,105 @@ # You should not edit this file directly. Instead, you should edit one of the following files (REQUIREMENTS.txt) and run make starfish/REQUIREMENTS-STRICT.txt -argon2-cffi==21.1.0 -attrs==21.2.0 +argon2-cffi==21.3.0 +argon2-cffi-bindings==21.2.0 +asttokens==2.0.5 +attrs==21.4.0 backcall==0.2.0 -bleach==4.1.0 -boto3==1.18.37 -botocore==1.21.37 -certifi==2021.5.30 -cffi==1.14.6 -charset-normalizer==2.0.4 -click==8.0.1 -cycler==0.10.0 +beautifulsoup4==4.11.1 +bleach==5.0.0 +boto3==1.24.14 +botocore==1.27.14 +certifi==2022.6.15 +cffi==1.15.0 +charset-normalizer==2.0.12 +click==8.1.3 +cycler==0.11.0 dataclasses==0.6 -debugpy==1.4.1 +debugpy==1.6.0 decorator==4.4.2 defusedxml==0.7.1 -diskcache==5.2.1 -entrypoints==0.3 -h5py==3.4.0 -idna==3.2 -imageio==2.9.0 -ipykernel==6.3.1 -ipython==7.27.0 +diskcache==5.4.0 +entrypoints==0.4 +executing==0.8.3 +fastjsonschema==2.15.3 +fonttools==4.33.3 +h5py==3.7.0 +idna==3.3 +imageio==2.19.3 +ipykernel==6.15.0 +ipython==8.4.0 ipython-genutils==0.2.0 -ipywidgets==7.6.4 -jedi==0.18.0 -Jinja2==3.0.1 -jmespath==0.10.0 -joblib==1.0.1 -jsonschema==3.2.0 -jupyter-client==7.0.2 -jupyter-core==4.7.1 -jupyterlab-pygments==0.1.2 -jupyterlab-widgets==1.0.1 -kiwisolver==1.3.2 -MarkupSafe==2.0.1 -matplotlib==3.4.3 +ipywidgets==7.7.1 +jedi==0.18.1 +Jinja2==3.1.2 +jmespath==1.0.1 +joblib==1.1.0 +jsonschema==4.6.0 +jupyter-client==7.3.4 +jupyter-core==4.10.0 +jupyterlab-pygments==0.2.2 +jupyterlab-widgets==1.1.1 +kiwisolver==1.4.3 +MarkupSafe==2.1.1 +matplotlib==3.5.2 matplotlib-inline==0.1.3 mistune==0.8.4 mpmath==1.2.1 -nbclient==0.5.4 -nbconvert==6.1.0 -nbformat==5.1.3 -nest-asyncio==1.5.1 -networkx==2.6.2 -notebook==6.4.3 -numpy==1.21.2 -packaging==21.0 -pandas==1.3.2 -pandocfilters==1.4.3 -parso==0.8.2 +nbclient==0.6.4 +nbconvert==6.5.0 +nbformat==5.4.0 +nest-asyncio==1.5.5 +networkx==2.8.4 +notebook==6.4.12 +numpy==1.22.4 +packaging==21.3 +pandas==1.4.2 +pandocfilters==1.5.0 +parso==0.8.3 pexpect==4.8.0 pickleshare==0.7.5 
-Pillow==8.3.2 -prometheus-client==0.11.0 -prompt-toolkit==3.0.20 +Pillow==9.1.1 +prometheus-client==0.14.1 +prompt-toolkit==3.0.29 +psutil==5.9.1 ptyprocess==0.7.0 -pycparser==2.20 -Pygments==2.10.0 -pyparsing==2.4.7 -pyrsistent==0.18.0 +pure-eval==0.2.2 +pycparser==2.21 +Pygments==2.12.0 +pyparsing==3.0.9 +pyrsistent==0.18.1 python-dateutil==2.8.2 -pytz==2021.1 -PyWavelets==1.1.1 -PyYAML==5.4.1 -pyzmq==22.2.1 +pytz==2022.1 +PyWavelets==1.3.0 +PyYAML==6.0 +pyzmq==23.2.0 read-roi==1.6.0 regional==1.1.2 -requests==2.26.0 -s3transfer==0.5.0 +requests==2.28.0 +s3transfer==0.6.0 scikit-image==0.18.3 -scikit-learn==0.24.2 -scipy==1.7.1 -semantic-version==2.8.5 +scikit-learn==1.1.1 +scipy==1.8.1 +semantic-version==2.10.0 Send2Trash==1.8.0 -setuptools==56.0.0 +setuptools==58.1.0 showit==1.1.4 six==1.16.0 slicedimage==4.1.1 +soupsieve==2.3.2.post1 +sphinx-bootstrap-theme==0.8.1 +stack-data==0.3.0 sympy==1.5.1 -terminado==0.12.1 -testpath==0.5.0 -threadpoolctl==2.2.0 -tifffile==2021.8.30 +terminado==0.15.0 +threadpoolctl==3.1.0 +tifffile==2022.5.4 +tinycss2==1.1.1 tornado==6.1 -tqdm==4.62.2 +tqdm==4.64.0 trackpy==0.5.0 -traitlets==5.1.0 -urllib3==1.26.6 -validators==0.18.2 +traitlets==5.3.0 +urllib3==1.26.9 +validators==0.20.0 wcwidth==0.2.5 webencodings==0.5.1 -widgetsnbextension==3.5.1 -xarray==0.19.0 +widgetsnbextension==3.6.1 +xarray==2022.3.0 diff --git a/starfish/core/intensity_table/decoded_intensity_table.py b/starfish/core/intensity_table/decoded_intensity_table.py index 7a062ab8b..091086239 100644 --- a/starfish/core/intensity_table/decoded_intensity_table.py +++ b/starfish/core/intensity_table/decoded_intensity_table.py @@ -54,7 +54,6 @@ class DecodedIntensityTable(IntensityTable): * c (c) int64 0 1 2 * h (h) int64 0 1 2 3 target (features) object 08b1a822-a1b4-4e06-81ea-8a4bd2b004a9 ... - """ __slots__ = () @@ -65,7 +64,9 @@ def from_intensity_table( intensities: IntensityTable, targets: Tuple[str, np.ndarray], distances: Optional[Tuple[str, np.ndarray]] = None, - passes_threshold: Optional[Tuple[str, np.ndarray]] = None): + passes_threshold: Optional[Tuple[str, np.ndarray]] = None, + rounds_used: Optional[Tuple[str, np.ndarray]] = None): + """ Assign target values to intensities. @@ -80,6 +81,9 @@ def from_intensity_table( passes_threshold : Optional[Tuple[str, np.ndarray]] Corresponding array of boolean values indicating if each itensity passed given thresholds. + rounds_used: Optional[Tuple[str, np.ndarray]] + Corresponding array of integers indicated the number of rounds this + decoded intensity was found in Returns ------- @@ -92,6 +96,8 @@ def from_intensity_table( intensities[Features.DISTANCE] = distances if passes_threshold: intensities[Features.PASSES_THRESHOLDS] = passes_threshold + if rounds_used: + intensities['rounds_used'] = rounds_used return intensities def to_decoded_dataframe(self) -> DecodedSpots: @@ -120,7 +126,6 @@ def to_mermaid(self, filename: str) -> pd.DataFrame: Notes ------ See also https://github.com/JEFworks/MERmaid - """ # construct the MERMAID dataframe. 
As MERMAID adds support for non-categorical variables, # additional columns can be added here diff --git a/starfish/core/spots/DecodeSpots/__init__.py b/starfish/core/spots/DecodeSpots/__init__.py index 62803d464..5660f1339 100644 --- a/starfish/core/spots/DecodeSpots/__init__.py +++ b/starfish/core/spots/DecodeSpots/__init__.py @@ -1,4 +1,5 @@ from ._base import DecodeSpotsAlgorithm +from .check_all_decoder import CheckAll from .metric_decoder import MetricDistance from .per_round_max_channel_decoder import PerRoundMaxChannel from .simple_lookup_decoder import SimpleLookupDecoder diff --git a/starfish/core/spots/DecodeSpots/check_all_decoder.py b/starfish/core/spots/DecodeSpots/check_all_decoder.py new file mode 100644 index 000000000..77a3c8a05 --- /dev/null +++ b/starfish/core/spots/DecodeSpots/check_all_decoder.py @@ -0,0 +1,440 @@ +from collections import Counter +from copy import deepcopy +from typing import Any, Hashable, Mapping, Tuple + +import numpy as np +import pandas as pd + +from starfish.core.codebook.codebook import Codebook +from starfish.core.intensity_table.decoded_intensity_table import DecodedIntensityTable +from starfish.core.intensity_table.intensity_table import IntensityTable +from starfish.core.intensity_table.intensity_table_coordinates import \ + transfer_physical_coords_to_intensity_table +from starfish.core.types import SpotFindingResults +from starfish.types import Axes, Features +from ._base import DecodeSpotsAlgorithm +from .check_all_funcs import buildBarcodes, cleanup, createNeighborDict, createRefDicts, decoder, \ + distanceFilter, findNeighbors, removeUsedSpots +from .util import _merge_spots_by_round + +class CheckAll(DecodeSpotsAlgorithm): + """ + Decode spots by generating all possible combinations of neighboring spots to form barcodes + given a radius distance that spots may be from each other in order to form a barcode. Then + chooses the best set of nonoverlapping spot combinations by choosing the ones with the least + spatial variance of their spot coordinates, highest normalized intensity and are also found + to be best for multiple spots in the barcode (see algorithm below). Allows for one error + correction round (option for more may be added in the future). + + Two slightly different algorithms are used to balance the precision (proportion of targets that + represent true mRNA molecules) and recall (proportion of true mRNA molecules that are + recovered). They share mostly the same steps but two are switched between the different + versions. The following is for the "filter-first" version: + + 1. For each spot in each round, find all neighbors in other rounds that are within the search + radius + 2. For each spot in each round, build all possible full length barcodes based on the channel + labels of the spot's neighbors and itself + 3. Choose the "best" barcode of each spot's possible barcodes by calculating a score that is + based on minimizing the spatial variance and maximizing the intensities of the spots in the + barcode. Each spot is assigned a "best" barcode in this way. + 4. Drop "best" barcodes that don't have a matching target in the codebook + 5. Only keep barcodes/targets that were found as "best" using at least x of the spots that make + each up (x is determined by parameters) + 6. 
Find maximum independent set (approximation) of the spot combinations so no two barcodes use
+       the same spot
+
+    The other method (referred to here as "decode-first") is the same except steps 3 and 4 are
+    switched so that the minimum scoring barcode is chosen from the set of possible codes that
+    have a match to the codebook. The filter-first method returns fewer decoded targets (lower
+    recall) but has a lower false positive rate (higher precision), while the decode-first method
+    finds more targets (higher recall) at the cost of an increased false positive rate (lower
+    precision).
+
+    Decoding is run in multiple stages, with the parameters becoming less strict in later stages.
+    The high accuracy algorithm (filter-first) is always run first, followed by the low accuracy
+    method (decode-first), each with slightly different parameters based on the choice of the
+    "mode" parameter. After each decoding, the spots found to be in decoded barcodes are removed
+    from the original set of spots before they are decoded again with a new set of parameters. To
+    simplify the number of parameters to choose from, they have been sorted into three sets of
+    presets ("high", "medium", or "low" accuracy) determined by the "mode" parameter.
+
+    Decoding is also done multiple times at multiple search radius values that start at 0 and
+    increase incrementally until they reach the user-specified search radius. This allows high
+    confidence barcodes to be called first and makes things easier when later codes are called.
+
+    If error_rounds is set to 1 (currently cannot handle more than 1), after running all decodings
+    for barcodes that exactly match the codebook, another set of decodings will be run to find
+    barcodes that are missing a spot in exactly one round. If the codes in the codebook all have a
+    hamming distance of at least 2 from all other codes, each can still be uniquely identified
+    using a partial code with a single round dropped. Barcodes decoded with a partial code like
+    this are inherently less accurate, so an extra dimension called "rounds_used" is added to the
+    DecodedIntensityTable output that labels each decoded target with the number of rounds used to
+    decode it, allowing these less accurate codes to be easily separated from the high accuracy
+    set if desired.
+
+    Parameters
+    ----------
+    codebook : Codebook
+        Contains codes to decode IntensityTable
+    search_radius : float
+        Maximum allowed distance (in pixels) that spots in different rounds can be from each other
+        and still be allowed to be combined into a barcode together
+    error_rounds : int
+        Maximum hamming distance a barcode can be from its target in the codebook and still be
+        uniquely identified (i.e. number of error correction rounds in the experiment)
+    mode : string
+        One of three preset parameter sets. Choices are: "low", "med", or "high". Low accuracy
+        mode will return more decoded targets but at a cost to accuracy (high recall, low
+        precision), while the high accuracy version will find fewer false positives but also fewer
+        targets overall (high precision, low recall); medium is a balance between the two.
+    physical_coords : bool
+        Whether decoding should use physical distances from the original ImageStack that spot
+        finding was performed on. Should be used when the distance between z planes is much
+        greater than the distance between x and y pixels.
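+
+    Examples
+    --------
+    Illustrative sketch only (``codebook`` and ``spots`` are assumed to come from earlier
+    pipeline steps, e.g. a loaded Codebook and the SpotFindingResults returned by a spot
+    finder; parameter values are arbitrary)::
+
+        decoder = CheckAll(codebook, search_radius=2.45, error_rounds=1, mode='high')
+        decoded = decoder.run(spots, n_processes=4)
+        # keep only targets decoded with every round, dropping the less accurate partial codes
+        full_codes = decoded.where(decoded.rounds_used == decoded.sizes['r'], drop=True)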
+ """ + + def __init__( + self, + codebook: Codebook, + search_radius: float=3, + error_rounds: int=0, + mode='med', + physical_coords=False): + self.codebook = codebook + self.searchRadius = search_radius + self.errorRounds = error_rounds + self.mode = mode + self.physicalCoords = physical_coords + + # Error checking for some inputs + + # Check that codebook is the right class and not empty + if not isinstance(self.codebook, Codebook) or len(codebook) == 0: + raise ValueError( + 'codebook is either not a Codebook object or is empty') + # Check that error_rounds is either 0 or 1 + if self.errorRounds not in [0, 1]: + raise ValueError( + 'error_rounds can only take a value of 0 or 1') + # Return error if search radius is greater than 4.5 or negative + if self.searchRadius < 0 or self.searchRadius > 4.5: + raise ValueError( + 'search_radius must be positive w/ max value of 4.5') + + def run(self, + spots: SpotFindingResults, + n_processes: int=1, + *args) -> DecodedIntensityTable: + """ + Decode spots by finding the set of nonoverlapping barcodes that have the minimum spatial + variance within each barcode. + Parameters + ---------- + spots: SpotFindingResults + A Dict of tile indices and their corresponding measured spots + n_processes: int + Number of threads to run decoder in parallel with + Returns + ------- + DecodedIntensityTable : + IntensityTable decoded and appended with Features.TARGET values. + """ + + # Rename n_processes (trying to stay consistent between starFISH's _ variables and my + # camel case ones) + numJobs = n_processes + # Check that numJobs is a positive integer + if numJobs < 0 or not isinstance(numJobs, int): + raise ValueError( + 'n_process must be a positive integer') + + # Create dictionary where keys are round labels and the values are pandas dataframes + # containing information on the spots found in that round + spotTables = _merge_spots_by_round(spots) + + # Check that enough rounds have spots to make at least one barcode + spotsPerRound = [len(spotTables[r]) for r in range(len(spotTables))] + counter = Counter(spotsPerRound) + if counter[0] > self.errorRounds: + raise ValueError( + 'Not enough spots to form a barcode') + + # If using physical coordinates, extract z and xy scales and check that they are all > 0 + if self.physicalCoords: + physicalCoords = spots.physical_coord_ranges + if len(physicalCoords['z'].data) > 1: + zScale = physicalCoords['z'][1].data - physicalCoords['z'][0].data + else: + zScale = 1 + yScale = physicalCoords['y'][1].data - physicalCoords['y'][0].data + xScale = physicalCoords['x'][1].data - physicalCoords['x'][0].data + if xScale <= 0 or yScale <= 0 or zScale <= 0: + raise ValueError( + 'invalid physical coords') + + # Add one to channels labels (prevents collisions between hashes of barcodes later), adds + # unique spot_id column for each spot in each round, and scales the x, y, and z columns to + # the phsyical coordinates if specified + for r in spots.round_labels: + spotTables[r]['c'] += 1 + spotTables[r]['spot_id'] = range(1, len(spotTables[r]) + 1) + if self.physicalCoords: + spotTables[r]['z'] = spotTables[r]['z'] * zScale + spotTables[r]['y'] = spotTables[r]['y'] * yScale + spotTables[r]['x'] = spotTables[r]['x'] * xScale + + # Choose search radius set based on search_radius parameter and ability for spots to be + # neighbors across z slices. 
Each value in allSearchRadii represents an incremental + # increase in neighborhood size + set1 = False + zs = set() + for r in range(len(spotTables)): + zs.update(spotTables[r]['z']) + if self.physicalCoords: + if zScale < self.searchRadius and len(zs) > 1: + set1 = True + else: + if len(zs) > 1: + set1 = True + if set1: + allSearchRadii = np.array([0, 1.05, 1.5, 1.8, 2.05, 2.3, 2.45, 2.85, 3.05, 3.2, + 3.35, 3.5, 3.65, 3.75, 4.05, 4.15, 4.25, 4.4, 4.5]) + else: + allSearchRadii = np.array([0, 1.05, 1.5, 2.05, 2.3, 2.85, 3.05, 3.2, 3.65, 4.05, 4.15, + 4.25, 4.5]) + + maxRadii = allSearchRadii[(allSearchRadii - self.searchRadius) >= 0][0] + radiusSet = allSearchRadii[allSearchRadii <= maxRadii] + + # Calculate neighbors for each radius in the set (done only once and referred back to + # throughout decodings) + neighborsByRadius = {} + for searchRadius in radiusSet: + if self.physicalCoords: + searchRadius = round(searchRadius * xScale, 5) + neighborsByRadius[searchRadius] = findNeighbors(spotTables, searchRadius, numJobs) + + # Create reference dictionaries for spot channels, coordinates, raw intensities, and + # normalized intensities. Each is a dict w/ keys equal to the round labels and each + # value is a dict with spot IDs in that round as keys and their corresponding value + # (channel label, spatial coords, etc) + channelDict, spotCoords, spotIntensities, spotQualDict = createRefDicts(spotTables, numJobs) + + # Add spot quality (normalized spot intensity) tp spotTables + for r in range(len(spotTables)): + spotTables[r]['spot_quals'] = [spotQualDict[r][spot] for spot in + spotTables[r]['spot_id']] + + # Set list of round omission numbers to loop through + roundOmits = range(self.errorRounds + 1) + + # Set parameters according to presets (determined empirically). Strictness value determines + # the decoding method used and the allowed number of possible barcode choices (positive + # for filter-first, negative for decode-first). 
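+        # For example (illustrative): strictness=50 means the filter-first method is used and a
+        # seed spot is dropped as ambiguous if it has more than 50 possible barcodes, while
+        # strictness=-5 means the decode-first method is used and a seed spot is dropped if more
+        # than 5 of its possible barcodes have a match in the codebook.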
+ if self.mode == 'high': + strictnesses = [50, -1] + seedNumbers = [len(spotTables) - 1, len(spotTables)] + minDist = 3 + if self.errorRounds == 1: + strictnesses.append(1) + seedNumbers.append(len(spotTables) - 1) + elif self.mode == 'med': + strictnesses = [50, -5] + seedNumbers = [len(spotTables) - 1, len(spotTables)] + minDist = 3 + if self.errorRounds == 1: + strictnesses.append(5) + seedNumbers.append(len(spotTables) - 1) + elif self.mode == 'low': + strictnesses = [50, -100] + seedNumbers = [len(spotTables) - 1, len(spotTables) - 1] + minDist = 100 + if self.errorRounds == 1: + strictnesses.append(10) + seedNumbers.append(len(spotTables) - 1) + else: + raise ValueError( + 'Invalid mode choice ("high", "med", or "low")') + + # Decode for each round omission number, intensity cutoff, and then search radius + allCodes = pd.DataFrame() + for currentRoundOmitNum in roundOmits: + for s, strictness in enumerate(strictnesses): + + # Set seedNumber according to parameters for this strictness value + seedNumber = seedNumbers[s] + + # First decodes only the highest normalized intensity spots then adds in the rest + for intVal in range(50, -1, -50): + + # First check that there are enough spots left otherwise an error will occur + spotsPerRound = [len(spotTables[r]) for r in range(len(spotTables))] + counter = Counter(spotsPerRound) + condition3 = True if counter[0] > currentRoundOmitNum else False + if not condition3: + # Subset spots by intensity, start with top 50% then decode again with all + currentTables = {} + for r in range(len(spotTables)): + + if len(spotTables[r]) > 0: + lowerBound = np.percentile(spotTables[r]['spot_quals'], intVal) + currentTables[r] = spotTables[r][spotTables[r]['spot_quals'] + >= lowerBound] + else: + currentTables[r] = pd.DataFrame() + + # Decode each radius and remove spots found in each decoding before the next + for sr, searchRadius in enumerate(radiusSet): + + # Scale radius by xy scale if needed + if self.physicalCoords: + searchRadius = round(searchRadius * xScale, 5) + + # Only run partial codes for the final strictness and don't run full + # barcodes for the final strictness. Also don't run if there are not + # enough spots left. + condition1 = (currentRoundOmitNum == 1 and s != len(strictnesses) - 1) + condition2 = (len(roundOmits) > 1 and currentRoundOmitNum == 0 + and s == len(strictnesses) - 1) + + if condition1 or condition2 or condition3: + pass + else: + + # Creates neighbor dictionary for the current radius and current set of + # spots + neighborDict = createNeighborDict(currentTables, searchRadius, + neighborsByRadius) + + # Find best spot combination using each spot in each round as seed + decodedTables = {} + for r in range(len(spotTables)): + + if len(spotTables[r]) > 0: + + # roundData will carry the possible barcode info for each spot + # in the current round being examined + roundData = deepcopy(currentTables[r]) + + # Drop all but the spot_id column + roundData = roundData[['spot_id']] + + # From each spot's neighbors, create all possible combinations + # that would form a barocde with the correct number of rounds. 
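+                                # (illustration: in a three-round experiment, a seed spot in
+                                # round 0 with 2 neighbors in round 1 and 3 in round 2 yields up
+                                # to 2 * 3 = 6 candidate spot codes, more when a round is allowed
+                                # to be dropped)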
+ # Adds spot_codes column to roundData + + roundData = buildBarcodes(roundData, neighborDict, + currentRoundOmitNum, channelDict, + strictness, r, numJobs) + + # When strictness is positive the filter-first methods is used + # and distanceFilter is run first on all the potential barcodes + # to choose the one with the minimum score (based on spatial + # variance of the spots and their intensities) which are then + # matched to the codebook. Spots that have more possible + # barcodes to choose between than the current strictnessnumber + # are dropped as ambiguous. If strictness is negative, the + # decode-first method is run where all the possible barcodes + # are instead first matched to the codebook and then the lowest + # scoring decodable spot combination is chosen for each spot. + # Spots that have more decodable barcodes to choose from than + # the strictness value (absolute value) are dropped. + if strictness > 0: + + # Choose most likely combination of spots for each seed + # spot using their spatial variance and normalized intensity + # values. Adds distance column to roundData + roundData = distanceFilter(roundData, spotCoords, + spotQualDict, r, + currentRoundOmitNum, numJobs) + + # Match possible barcodes to codebook. Adds target column + # to roundData + roundData = decoder(roundData, self.codebook, channelDict, + strictness, currentRoundOmitNum, r, + numJobs) + + else: + + # Match possible barcodes to codebook. Adds target column + # to roundData + roundData = decoder(roundData, self.codebook, channelDict, + strictness, currentRoundOmitNum, r, + numJobs) + + # Choose most likely combination of spots for each seed + # spot using their spatial variance and normalized + # intensity values. Adds distance column to roundData + roundData = distanceFilter(roundData, spotCoords, + spotQualDict, r, + currentRoundOmitNum, numJobs) + + # Assign to DecodedTables dictionary + decodedTables[r] = roundData + + else: + decodedTables[r] = pd.DataFrame() + + # Turn spot table dictionary into single table, filter barcodes by + # the seed number, add additional information, and choose between + # barcodes that have overlapping spots + finalCodes = cleanup(decodedTables, spotCoords, channelDict, + strictness, currentRoundOmitNum, seedNumber) + + # Remove spots that have just been found to be in passing barcodes from + # neighborDict so they are not used for the next decoding round and + # filter codes whose distance value is above the minimum + if len(finalCodes) > 0: + finalCodes = finalCodes[finalCodes['distance'] <= minDist] + spotTables = removeUsedSpots(finalCodes, spotTables) + currentTables = removeUsedSpots(finalCodes, currentTables) + + # Append found codes to allCodes table + allCodes = allCodes.append(finalCodes).reset_index(drop=True) + + # Create and fill in intensity table + channels = spots.ch_labels + rounds = spots.round_labels + + # create empty IntensityTable filled with np.nan + data = np.full((len(allCodes), len(channels), len(rounds)), fill_value=np.nan) + dims = (Features.AXIS, Axes.CH.value, Axes.ROUND.value) + centers = allCodes['center'] + coords: Mapping[Hashable, Tuple[str, Any]] = { + Features.SPOT_RADIUS: (Features.AXIS, np.full(len(allCodes), 1)), + Axes.ZPLANE.value: (Features.AXIS, np.asarray([round(c[0]) for c in centers])), + Axes.Y.value: (Features.AXIS, np.asarray([round(c[1]) for c in centers])), + Axes.X.value: (Features.AXIS, np.asarray([round(c[2]) for c in centers])), + Features.SPOT_ID: (Features.AXIS, np.arange(len(allCodes))), + Features.AXIS: 
(Features.AXIS, np.arange(len(allCodes))), + Axes.ROUND.value: (Axes.ROUND.value, rounds), + Axes.CH.value: (Axes.CH.value, channels) + } + int_table = IntensityTable(data=data, dims=dims, coords=coords) + + # Fill in data values + table_codes = [] + for i in range(len(allCodes)): + code = [] + # ints = allCodes.loc[i, 'intensities'] + for j, ch in enumerate(allCodes.loc[i, 'best_barcodes']): + # If a round is not used, row will be all zeros + code.append(np.asarray([0 if k != ch - 1 else 1 for k in range(len(channels))])) + table_codes.append(np.asarray(code).T) + int_table.values = np.asarray(table_codes) + int_table = transfer_physical_coords_to_intensity_table(intensity_table=int_table, + spots=spots) + + # Validate results are correct shape + self.codebook._validate_decode_intensity_input_matches_codebook_shape(int_table) + + # Create DecodedIntensityTable + result = DecodedIntensityTable.from_intensity_table( + int_table, + targets=(Features.AXIS, allCodes['targets'].astype('U')), + distances=(Features.AXIS, allCodes["distance"]), + passes_threshold=(Features.AXIS, np.full(len(allCodes), True)), + rounds_used=(Features.AXIS, allCodes['rounds_used'])) + + return result diff --git a/starfish/core/spots/DecodeSpots/check_all_funcs.py b/starfish/core/spots/DecodeSpots/check_all_funcs.py new file mode 100644 index 000000000..3d36e42de --- /dev/null +++ b/starfish/core/spots/DecodeSpots/check_all_funcs.py @@ -0,0 +1,911 @@ +import typing +from collections import Counter, defaultdict +from concurrent.futures.process import ProcessPoolExecutor +from copy import deepcopy +from functools import partial +from itertools import chain, islice, permutations, product + +import numpy as np +import pandas as pd +from scipy.spatial import cKDTree + +from starfish.core.codebook.codebook import Codebook +from starfish.types import Axes + +def findNeighbors(spotTables: dict, + searchRadius: float, + numJobs: int) -> dict: + + ''' + Using scipy's cKDTree method, finds all neighbors within the seach radius between the spots in + each pair of rounds and stores the indices in a dictionary for later access. + Parameters + ---------- + spotTables : dict + Dictionary with round labels as keys and pandas dataframes containing spot information + for its key round as values (result of _merge_spots_by_round function) + searchRadius : float + Distance that spots can be from each other and still form a barcode + numJobs : int + Number of CPU threads to use in parallel + Returns + ------- + dict: a dictionary with the following structure: + {(round1, round2): index table showing neighbors between spots in round1 and round2 + where round1 != round2} + ''' + + allNeighborDict = {} + for r1 in range((len(spotTables))): + tree = cKDTree(spotTables[r1][['z', 'y', 'x']]) + for r2 in list(range((len(spotTables))))[r1 + 1:]: + allNeighborDict[(r1, r2)] = tree.query_ball_point(spotTables[r2][['z', 'y', 'x']], + searchRadius, workers=numJobs) + + return allNeighborDict + +def createNeighborDict(spotTables: dict, + searchRadius: float, + neighborsByRadius: dict) -> dict: + + ''' + Create dictionary of neighbors (within the search radius) in other rounds for each spot. 
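+
+    For example (illustrative values), a spot with ID 7 in round 0 whose neighbors within the
+    search radius are spots 2 and 9 in round 1 and spot 4 in round 2 would appear as:
+
+        neighborDict[0][7] == {0: [7], 1: [2, 9], 2: [4]}
+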
+ Parameters + ---------- + spotTables : dict + Dictionary with round labels as keys and pandas dataframes containing spot information + for its key round as values (result of _merge_spots_by_round function) + searchRadius : float + Distance that spots can be from each other and still form a barcode + neighborsByRadius : dict + Dictionary of outputs from findNeighbors() where each key is a radius and the value is + the findNeighbors dictionary + Returns + ------- + dict: a dictionary with the following structure + neighborDict[roundNum][spotID] = {0 : neighbors in round 0, 1: neighbors in round 1,etc} + ''' + + # Create empty neighbor dictionary + neighborDict = {} + spotIDs = {} + for r in spotTables: + if len(spotTables[r]) > 0: + spotIDs[r] = {idd: 0 for idd in spotTables[r]['spot_id']} + neighborDict[r] = {i: defaultdict(list, {r: [i]}) for i in spotTables[r]['spot_id']} + else: + neighborDict[r] = {} + + # Add neighbors in neighborsByRadius[searchRadius] but check to make sure that spot is still + # available before adding it + for r1 in range(len(spotTables)): + for r2 in list(range((len(spotTables))))[r1 + 1:]: + for j, neighbors in enumerate(neighborsByRadius[searchRadius][(r1, r2)]): + try: + spotIDs[r2][j + 1] + for neighbor in neighbors: + try: + spotIDs[r1][neighbor + 1] + neighborDict[r1][neighbor + 1][r2].append(j + 1) + neighborDict[r2][j + 1][r1].append(neighbor + 1) + except Exception: + pass + except Exception: + pass + return neighborDict + +def createRefDicts(spotTables: dict, numJobs: int) -> tuple: + + ''' + Create dictionaries with mapping from spot id (row index + 1) in spotTables to channel label, + spatial coordinates raw intensity and normalized intensity. + Parameters + ---------- + spotTables : dict + Dictionary with round labels as keys and pandas dataframes containing spot information + for its key round as values (result of _merge_spots_by_round function) + numJobs : int + Number of CPU threads to use in parallel + Returns + ------- + tuple : First object is the channel dictionary, second is the spatial coordinate dictionary, + the third object is the raw spot instensity dictionary, and the last object is the + normalized spot intensity dictionary + ''' + + # Create channel label and spatial coordinate dictionaries + channelDict = {} + spotCoords = {} + for r in [*spotTables]: + channelDict[r] = spotTables[r][['c', 'spot_id']].set_index('spot_id').to_dict()['c'] + channelDict[r][0] = 0 + tmpTable = spotTables[r][['z', 'y', 'x', 'spot_id']].set_index('spot_id') + spotCoords[r] = tmpTable.to_dict(orient='index') + for key in [*spotCoords[r]]: + spotCoords[r][key] = tuple(spotCoords[r][key].values()) + + # Create raw intensity dictionary + spotIntensities = {r: spotTables[r][['intensity', 'spot_id']].set_index('spot_id').to_dict() + ['intensity'] for r in [*spotTables]} + for r in [*spotTables]: + spotIntensities[r][0] = 0 + + # Create normalized intensity dictionary + spotQualDict = spotQuality(spotTables, spotCoords, spotIntensities, channelDict, numJobs) + + return channelDict, spotCoords, spotIntensities, spotQualDict + +def encodeSpots(spotCodes: list) -> list: + + ''' + For compressing spot ID codes into single integers. Saves memory. The number of digits in + each ID is counted and these integer lengths and concatenated into a string in the same + order as the IDs they correspond to. The IDs themselves are then converted to strings and + concatenated to this, also maintaining order. 
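+
+    For example (illustrative), the spot code (12, 5, 304) has ID string lengths 2, 1, and 3, so
+    it is compressed to the single integer 213125304 ('213' + '12' + '5' + '304');
+    decodeSpots([213125304], 3) recovers [(12, 5, 304)].
+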
+ Parameters + ---------- + spotCodes : list + List of spot codes (each a tuple of integers with length equal to the number of rounds) + Returns + ------- + list: List of compressed spot codes, one int per code + ''' + + strs = [list(map(str, code)) for code in spotCodes] + compressed = [int(''.join(map(str, map(len, intStr))) + ''.join(intStr)) for intStr in strs] + + return compressed + +def decodeSpots(compressed: list, roundNum: int) -> list: + + ''' + Reconverts compressed spot codes back into their roundNum length tupes of integers with + the same order and IDs as their original source. First roundNum values in the compressed + code will each correspond to the string length of each spot ID integer (as long as no round + has 10 billion or more spots). Can use these to determine how to split the rest of the string + to retrieve the original values in the correct order. + Parameters + ---------- + compressed : list + List of integer values corresponding to compressed spot codes + roundNum : int + The number of rounds in the experiment + Returns + ------- + list: List of recovered spot codes in their original tuple form + ''' + + strs = [str(intStr) for intStr in compressed] + idxs, nums = list(zip(*[(map(int, s[:roundNum]), [iter(s[roundNum:])] * roundNum) + for s in strs])) + decompressed = [tuple(int(''.join(islice(n, i))) for i, n in zip(idxs[j], nums[j])) + for j in range(len(idxs))] + return decompressed + +def spotQualityFunc(spots: list, + spotCoords: dict, + spotIntensities: dict, + spotTables: dict, + channelDict: dict, + r: int) -> list: + + ''' + Helper function for spotQuality to run in parallel + Parameters + ---------- + spots : list + List of spot IDs in the current round to calculate the normalized intensity of + spotCoords : dict + Spot ID to spatial coordinate dictionary + spotIntensities : dict + Spot ID to raw intensity dictionary + spotTables : dict + Dictionary containing spot info tables + channelDict : dict + Spot ID to channel label dictionary + r : int + Current round + Returns + ------- + list : list of normalized spot intensities of the input spot IDs + ''' + + # Find spots in the same neighborhood (same channel and z slice and less than 100 pixels away + # in either x or y direction) + neighborhood = 100 + quals = [] + for i, spot in enumerate(spots): + z, y, x = spotCoords[r][spot] + ch = channelDict[r][spot] + yMin = y - neighborhood if y - neighborhood >= 0 else 0 + yMax = y + neighborhood if y + neighborhood <= 2048 else 2048 + xMin = x - neighborhood if x - neighborhood >= 0 else 0 + xMax = x + neighborhood if x + neighborhood <= 2048 else 2048 + neighborInts = spotTables[r][(spotTables[r]['c'] == ch) + & (spotTables[r]['z'] == z) + & (spotTables[r]['y'] >= yMin) + & (spotTables[r]['y'] < yMax) + & (spotTables[r]['x'] >= xMin) + & (spotTables[r]['x'] < xMax)]['intensity'] + # If no neighbors drop requirement that they be within 100 pixels of each other + if len(neighborInts) == 1: + neighborInts = spotTables[r][(spotTables[r]['c'] == ch) + & (spotTables[r]['z'] == z)]['intensity'] + # If still no neighbors drop requirement that they be on the same z slice + if len(neighborInts) == 1: + neighborInts = spotTables[r][(spotTables[r]['c'] == ch)]['intensity'] + # Calculate the l2 norm of the neighbor's intensities and divide the spot's intensity by + # this value to get it's normalized intensity value + norm = np.linalg.norm(neighborInts) + quals.append(spotIntensities[r][spot] / norm) + + return quals + +def spotQuality(spotTables: dict, + spotCoords: dict, 
+ spotIntensities: dict, + channelDict: dict, + numJobs: int) -> dict: + + ''' + Creates dictionary mapping each spot ID to their normalized intensity value. Calculated as the + spot intensity value divided by the l2 norm of the intensities of all the spots in the same + neighborhood. + Parameters + ---------- + spotTables : dict + Dictionary containing spot info tables + spotCoords : dict + Spot ID to spatial coordinate dictionary + spotIntensities : dict + Spot ID to raw intensity dictionary + channelDict : dict + Spot ID to channel label dictionary + numJobs : int + Number of CPU threads to use in parallel + Returns + ------- + dict : dictionary mapping spot ID to it's normalized intensity value + ''' + + # Calculate normalize spot intensities for each spot in each round + spotQuals = {} # type: dict + for r in range(len(spotTables)): + roundSpots = spotTables[r]['spot_id'] + spotQuals[r] = {} + + # Calculates index ranges to chunk data by + ranges = [0] + for i in range(1, numJobs): + ranges.append(round((len(roundSpots) / numJobs) * i)) + ranges.append(len(roundSpots)) + chunkedSpots = [roundSpots[ranges[i]:ranges[i + 1]] for i in range(len(ranges[:-1]))] + + # Run in parallel + with ProcessPoolExecutor() as pool: + part = partial(spotQualityFunc, spotCoords=spotCoords, spotIntensities=spotIntensities, + spotTables=spotTables, channelDict=channelDict, r=r) + poolMap = pool.map(part, [subSpots for subSpots in chunkedSpots]) + results = [x for x in poolMap] + + # Extract results + for spot, qual in zip(roundSpots, list(chain(*results))): + spotQuals[r][spot] = qual + + return spotQuals + +def barcodeBuildFunc(allNeighbors: list, + channelDict: dict, + currentRound: int, + roundOmitNum: int, + roundNum: int) -> list: + + ''' + Subfunction to buildBarcodes that allows it to run in parallel chunks + Parameters + ---------- + allNeighbors : list + List of neighbor from which to build barcodes from + channelDict : dict + Dictionary mapping spot IDs to their channels labels + currentRound : int + The round that the spots being used for reference points are found in + roundOmitNum : int + Maximum hamming distance a barcode can be from it's target in the codebook and + still be uniquely identified (i.e. 
number of error correction rounds in each + the experiment) + roundNum : int + Total number of round in experiment + Returns + ------- + list : list of the possible spot codes + ''' + + # spotCodes are the ordered spot IDs of the spots making up each barcode while barcodes are + # the corresponding channel labels, need spotCodes so each barcode can have a unique + # identifier + allSpotCodes = [] + for neighbors in allNeighbors: + neighborLists = [neighbors[rnd] for rnd in range(roundNum)] + # Adds a 0 to each round of the neighbors dictionary (allows barcodes with dropped + # rounds to be created) + if roundOmitNum > 0: + [neighbors[rnd].append(0) for rnd in range(roundNum)] + # Creates all possible spot code combinations from neighbors + codes = list(product(*neighborLists)) + # Only save the ones with the correct number of dropped rounds + counters = [Counter(code) for code in codes] # type: typing.List[Counter] + spotCodes = [code for j, code in enumerate(codes) if counters[j][0] == roundOmitNum] + spotCodes = [code for code in spotCodes if code[currentRound] != 0] + + allSpotCodes.append(encodeSpots(spotCodes)) + + return allSpotCodes + +def buildBarcodes(roundData: pd.DataFrame, + neighborDict: dict, + roundOmitNum: int, + channelDict: dict, + strictness: int, + currentRound: int, + numJobs: int) -> pd.DataFrame: + + ''' + Builds possible barcodes for each seed spot from its neighbors. First checks that each spot has + enough neighbors in each round to form a barcode and, depending on the strictness value, drops + spots who have too many possible barcodes to choose from + Parameters + ---------- + roundData : dict + Spot data table for the current round + neighborDict : dict + Dictionary that contains all the neighbors for each spot in other rounds that are + within the search radius + roundOmitNum : int + Maximum hamming distance a barcode can be from it's target in the codebook and still + be uniquely identified (i.e. 
number of error correction rounds in each the experiment + channelDict : dict + Dictionary with mappings between spot IDs and their channel labels + strictness: int + Determines the number of possible codes a spot is allowed to have before it is dropped + as ambiguous (if it is positive) + currentRound : int + Current round to build barcodes for (same round that roundData is from) + numJobs : int + Number of CPU threads to use in parallel + Returns + ------- + pd.DataFrame : Copy of roundData with an additional column which lists all the possible spot + codes that could be made from each spot's neighbors for those spots that + passed the strictness requirement (if it is positive) + ''' + + # Only keep spots that have enough neighbors to form a barcode (determined by the total number + # of rounds and the number of rounds that can be omitted from each code) and if strictness is + # positive, drop spots that have more than the strictness value number of possible barcodes + roundNum = len(neighborDict) + if strictness > 0: + passed = [key for key in neighborDict[currentRound] if + len(neighborDict[currentRound][key]) >= roundNum - roundOmitNum + and np.prod([len(values) for values in + neighborDict[currentRound][key].values()]) <= strictness] + else: + passed = [key for key in neighborDict[currentRound] if + len(neighborDict[currentRound][key]) >= roundNum - roundOmitNum] + roundData = roundData[roundData['spot_id'].isin(passed)].reset_index(drop=True) + roundData['neighbors'] = [neighborDict[currentRound][p] for p in passed] + + # Find all possible barcodes for the spots in each round by splitting each round's spots into + # numJob chunks and constructing each chunks barcodes in parallel + + # Calculates index ranges to chunk data by + ranges = [0] + for i in range(1, numJobs + 1): + ranges.append(round((len(roundData) / numJobs) * i)) + chunkedNeighbors = [list(roundData['neighbors'])[ranges[i]: ranges[i + 1]] for i in + range(len(ranges[:-1]))] + + # Run in parallel + with ProcessPoolExecutor() as pool: + part = partial(barcodeBuildFunc, channelDict=channelDict, currentRound=currentRound, + roundOmitNum=roundOmitNum, roundNum=roundNum) + poolMap = pool.map(part, [chunkedNeighbors[i] for i in range(len(chunkedNeighbors))]) + results = [x for x in poolMap] + + # Drop unneeded columns (saves memory) + roundData = roundData.drop(['neighbors', 'spot_id'], axis=1) + + # Add possible spot codes to spot table (must chain results from different jobs together) + roundData['spot_codes'] = list(chain(*[job for job in results])) + + return roundData + +def generateRoundPermutations(size: int, roundOmitNum: int) -> list: + + ''' + Creates list of lists of logicals detailing the rounds to be used for decoding based on the + current roundOmitNum + Parameters + ---------- + size : int + Number of rounds in experiment + roundOmitNum: int + Number of rounds that can be dropped from each barcode + Returns + ------- + list : list of lists of logicals detailing the rounds to be used for decoding based on + the current roundOmitNum + ''' + + if roundOmitNum == 0: + return [tuple([True] * size)] + else: + return sorted(set(list(permutations([*([False] * roundOmitNum), + *([True] * (size - roundOmitNum))])))) + +def decodeFunc(data: pd.DataFrame, permutationCodes: dict) -> tuple: + + ''' + Subfunction for decoder that allows it to run in parallel chunks + Parameters + ---------- + data : pd.DataFrame + DataFrame with columns called 'barcodes' and 'spot_codes' + permutationCodes : dict + Dictionary containing 
barcode information for each roundPermutation + Returns + ------- + tuple : First element is a list of all decoded targets, second element is a list of all + decoded spot codes + ''' + + # Checks if each barcode is in the permutationsCodes dict, if it isn't, there is no match + allTargets = [] + allDecodedSpotCodes = [] + allBarcodes = list(data['barcodes']) + allSpotCodes = list(data['spot_codes']) + for i in range(len(allBarcodes)): + targets = [] + decodedSpotCodes = [] + for j, barcode in enumerate(allBarcodes[i]): + try: + # Try to assign target by using barcode as key in permutationsCodes dictionary for + # current set of rounds. If there is no barcode match, it will error and go to the + # except and if it succeeds it will add the data to the other lists for this barcode + targets.append(permutationCodes[barcode]) + decodedSpotCodes.append(allSpotCodes[i][j]) + except Exception: + pass + allTargets.append(targets) + allDecodedSpotCodes.append(decodedSpotCodes) + return (allTargets, allDecodedSpotCodes) + +def decoder(roundData: pd.DataFrame, + codebook: Codebook, + channelDict: dict, + strictness: int, + currentRoundOmitNum: int, + currentRound: int, + numJobs: int) -> pd.DataFrame: + + ''' + Function that takes spots tables with possible barcodes added and matches each to the codebook + to identify any matches. Matches are added to the spot tables and spots without any matches are + dropped + Parameters + ---------- + roundData : pd.DataFrane + Modified spot table containing all possible barcodes that can be made from each spot + for the current round + codebook : Codebook + starFISH Codebook object containg the barcode information for the experiment + channelDict : dict + Dictionary with mappings between spot IDs and their channel labels + strictness : int + Determines the number of target matching barcodes each spot is allowed before it is + dropped as ambiguous (if it is negative) + currentRoundOmitNum : int + Number of rounds that can be dropped from each barcode + currentRound : int + Current round being for which spots are being decoded + numJobs : int + Number of CPU threads to use in parallel + Returns + ------- + pd.DataFrane : Modified spot table with added columns with information on decodable + barcodes + ''' + + # Add barcodes column by mapping spotIDs in spot_codes to channel labels using channelDict + if strictness > 0: + roundData['barcodes'] = [[hash(tuple([channelDict[j][spot] for j, spot in + enumerate(code)]))] for code in roundData['spot_codes']] + roundData['spot_codes'] = [[codes] for codes in roundData['spot_codes']] + else: + barcodes = [] + for codes in roundData['spot_codes']: + barcodes.append([hash(tuple([channelDict[j][spot] for j, spot in enumerate(code)])) + for code in decodeSpots(codes, len(channelDict))]) + roundData['barcodes'] = barcodes + + # Create list of logical arrays corresponding to the round sets being used to decode + roundPermutations = generateRoundPermutations(codebook.sizes[Axes.ROUND], currentRoundOmitNum) + + # Create dictionary where the keys are the different round sets that can be used for decoding + # and the values are the modified codebooks corresponding to the rounds used + permCodeDict = {} + for currentRounds in roundPermutations: + codes = codebook.argmax(Axes.CH.value) + if currentRoundOmitNum > 0: + omittedRounds = np.argwhere(~np.asarray(currentRounds)) + codes.data[:, omittedRounds] = -1 + codes.data += 1 + roundDict = dict(zip([hash(tuple(code)) for code in codes.data], codes['target'].data)) + 
permCodeDict.update(roundDict) + + # Calculates index ranges to chunk data by and creates list of chunked data to loop through + ranges = [0] + for i in range(1, numJobs + 1): + ranges.append(round((len(roundData) / numJobs) * i)) + chunkedData = [] + for i in range(len(ranges[:-1])): + chunkedData.append(deepcopy(roundData[ranges[i]:ranges[i + 1]])) + + # Run in parallel + with ProcessPoolExecutor() as pool: + part = partial(decodeFunc, permutationCodes=permCodeDict) + poolMap = pool.map(part, [chunkedData[i] for i in range(len(chunkedData))]) + results = [x for x in poolMap] + + # Update table + roundData['targets'] = list(chain(*[job[0] for job in results])) + roundData['spot_codes'] = list(chain(*[job[1] for job in results])) + + roundData = roundData[[len(targets) > 0 for targets in + roundData['targets']]].reset_index(drop=True) + + if len(roundData) > 0: + if strictness < 0: + roundData = roundData[[len(targets) <= np.abs(strictness) for targets in + roundData['targets']]].reset_index(drop=True) + + roundData = roundData.drop(['barcodes'], axis=1) + + return roundData + +def distanceFunc(spotsAndTargets: list, + spotCoords: dict, + spotQualDict: dict, + currentRoundOmitNum: int) -> tuple: + + ''' + Subfunction for distanceFilter to allow it to run in parallel + Parameters + ---------- + subSpotCodes : list + Chunk of full list of spot codes for the current round to calculate the spatial + variance for + subSpotCodes : list + Chunk of full list of targets (0s if strictness is positive) associated with the + current set of spots whose spatial variance is being calculated + spotCoords : dict + Spot ID to spatial coordinate dictionary + spotQualDict : dict + Spot ID to normalized intensity value dictionary + currentRoundOmitNum : int + Number of rounds that can be dropped from each barcode + Returns + ------- + tuple: First object is the min scoring spot code for each spots, the second is the min + score for each spot, and the third is the min scoring target for each spot + ''' + + subSpotCodes = spotsAndTargets[0] + subTargets = spotsAndTargets[1] + + # Find minimum scoring combination of spots from set of possible combinations + constant = 2 + bestSpotCodes = [] + bestDistances = [] + bestTargets = [] + for i, codes in enumerate(subSpotCodes): + quals = [sum([spotQualDict[r][spot] for r, spot in enumerate(code) if spot != 0]) + for code in codes] + newQuals = np.asarray([-np.log(1 / (1 + (len(spotCoords) - currentRoundOmitNum - qual))) + for qual in quals]) + subCoords = [[spotCoords[r][spot] for r, spot in enumerate(code) if spot != 0] + for code in codes] + spaVars = [sum(np.var(np.asarray(coords), axis=0)) for coords in subCoords] + newSpaVars = np.asarray([-np.log(1 / (1 + spaVar)) for spaVar in spaVars]) + combined = newQuals + (newSpaVars * constant) + minInds = np.where(combined == min(combined))[0] + if len(minInds) == 1: + bestSpotCodes.append(codes[minInds[0]]) + bestDistances.append(combined[minInds[0]]) + bestTargets.append(subTargets[i][minInds[0]]) + else: + bestSpotCodes.append(-1) + bestDistances.append(-1) + bestTargets.append(-1) + + return (bestSpotCodes, bestDistances, bestTargets) + +def distanceFilter(roundData: pd.DataFrame, + spotCoords: dict, + spotQualDict: dict, + currentRound: int, + currentRoundOmitNum: int, + numJobs: int) -> pd.DataFrame: + + ''' + Function that chooses between the best barcode for each spot from the set of decodable barcodes. 
+ Does this by choosing the barcode with the least spatial variance and high intensity spots + according to this calculation: + Score = -log(1 / 1 + (numRounds - qualSum)) + (-log(1 / 1 + spaVar) * constant) + Where: + numRounds = number of rounds being used for decoding (total - currentRoundOmitNum) + qualSum = sum of normalized intensity values for the spots in the code + spaVar = spatial variance of spots in code, calculates as the sum of variances of the + values in each spatial dimension + constant = a constant that determines the balance between the score being more influenced + by spatial variance or intensity, set to 2 so spatial variance is the biggest + deciding factor but allows ties to be broken by intensity + Parameters + ---------- + roundData : pd.DataFrame + Modified spot table containing info on decodable barcodes for the spots in the current + round + spotCoords : dict + Spot ID to spatial coordinate dictionary + spotQualDict : dict + Spot ID to normalized intensity value dictionary + currentRound : int + Current round number to calculate distances for + currentRoundOmitNum : int + Number of rounds that can be dropped from each barcode + numJobs : int + Number of CPU threads to use in parallel + Returns + ------- + pd.DataFrame : Modified spot table with added columns to with info on the "best" barcode + found for each spot + ''' + + # Calculate the spatial variance for each decodable barcode for each spot in each round + if len(roundData) == 0: + return roundData + + if 'targets' in roundData.columns: + checkTargets = True + else: + checkTargets = False + + # Extract spot codes and targets + allSpotCodes = [decodeSpots(codes, len(spotCoords)) for codes in roundData['spot_codes']] + if checkTargets: + allTargets = roundData['targets'].tolist() + else: + allTargets = [[0 for code in codes] for codes in roundData['spot_codes']] + + # Find ranges to chunk data by + ranges = [0] + for i in range(1, numJobs): + ranges.append(round((len(roundData) / numJobs) * i)) + ranges.append(len(roundData)) + chunkedSpotCodes = [allSpotCodes[ranges[i]:ranges[i + 1]] for i in range(len(ranges[:-1]))] + chunkedTargets = [allTargets[ranges[i]:ranges[i + 1]] for i in range(len(ranges[:-1]))] + + # Run in parallel + with ProcessPoolExecutor() as pool: + part = partial(distanceFunc, spotCoords=spotCoords, spotQualDict=spotQualDict, + currentRoundOmitNum=currentRoundOmitNum) + poolMap = pool.map(part, [spotsAndTargets for spotsAndTargets in zip(chunkedSpotCodes, + chunkedTargets)]) + results = [x for x in poolMap] + + # Add distances to decodedTables as new column and replace spot_codes and targets column with + # only the min scoring values + roundData['spot_codes'] = list(chain(*[job[0] for job in results])) + roundData['distance'] = list(chain(*[job[1] for job in results])) + if checkTargets: + roundData['targets'] = list(chain(*[job[2] for job in results])) + + # Remove spots who had a tie between possible spot combinations + roundData = roundData[roundData['spot_codes'] != -1] + + return roundData + +def cleanup(bestPerSpotTables: dict, + spotCoords: dict, + channelDict: dict, + strictness: int, + currentRoundOmitNum: int, + seedNumber: int) -> pd.DataFrame: + + ''' + Function that combines all "best" codes for each spot in each round into a single table, + filters them by their frequency (with a user-defined threshold), chooses between overlapping + codes (using the same distance function as used earlier), and finally adds some additional + information to the final set of barcodes + 
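+
+    For example (illustrative), with five rounds and seedNumber=4, a barcode passes the frequency
+    filter only if at least four of the five spots that make it up independently selected it as
+    their "best" barcode when used as the seed spot in their own rounds.
+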
+def cleanup(bestPerSpotTables: dict,
+            spotCoords: dict,
+            channelDict: dict,
+            strictness: int,
+            currentRoundOmitNum: int,
+            seedNumber: int) -> pd.DataFrame:
+
+    '''
+    Function that combines all "best" codes for each spot in each round into a single table,
+    filters them by their frequency (with a user-defined threshold), chooses between overlapping
+    codes (using the same distance function as used earlier), and finally adds some additional
+    information to the final set of barcodes
+    Parameters
+    ----------
+    bestPerSpotTables : dict
+        Spot tables dictionary containing columns with information on the "best" barcode found
+        for each spot
+    spotCoords : dict
+        Dictionary containing spatial locations of spots
+    channelDict : dict
+        Dictionary with mapping between spot IDs and the channel labels
+    strictness : int
+        Parameter that determines how many possible barcodes each spot can have before it is
+        dropped as ambiguous
+    currentRoundOmitNum : int
+        Number of rounds that can be dropped from each barcode
+    seedNumber : int
+        A barcode must be chosen as "best" in this number of rounds to pass filters
+    Returns
+    -------
+    pd.DataFrame : Dataframe containing final set of codes that have passed all filters
+    '''
+
+    # Create merged spot results dataframe containing the passing barcodes found in all the rounds
+    mergedCodes = pd.DataFrame()
+    roundNum = len(bestPerSpotTables)
+    for r in range(roundNum):
+        if len(bestPerSpotTables[r]) != 0:
+            if strictness > 0:
+                spotCodes = bestPerSpotTables[r]['spot_codes']
+                targets = bestPerSpotTables[r]['targets']
+                # Turn each barcode and spot code into a tuple so they can be used as dictionary
+                # keys
+                bestPerSpotTables[r]['spot_codes'] = [tuple(spotCode[0]) for spotCode in spotCodes]
+                bestPerSpotTables[r]['targets'] = [target[0] for target in targets]
+            mergedCodes = mergedCodes.append(bestPerSpotTables[r])
+    mergedCodes = mergedCodes.reset_index(drop=True)
+
+    # If no codes return empty dataframe
+    if len(mergedCodes) == 0:
+        return pd.DataFrame()
+
+    # Only pass codes that are chosen as best for at least seedNumber of the spots that make
+    # them up
+    spotCodes = mergedCodes['spot_codes']
+    counts = defaultdict(int)  # type: dict
+    for code in spotCodes:
+        counts[code] += 1
+    passing = list(set(code for code in counts if counts[code] >= seedNumber))
+
+    passingCodes = mergedCodes[mergedCodes['spot_codes'].isin(passing)].reset_index(drop=True)
+    passingCodes = passingCodes.iloc[passingCodes['spot_codes'].drop_duplicates().index]
+    passingCodes = passingCodes.reset_index(drop=True)
+
+    # If no codes return empty dataframe
+    if len(passingCodes) == 0:
+        return pd.DataFrame()
+
+    # Need to find a maximum independent set of the spot codes, where each spot code is a node
+    # and there is an edge connecting two codes if they share at least one spot. Does this by
+    # eliminating nodes (spot codes) that have the most edges first; if there is a tie for the
+    # most edges, they are ordered by decreasing spatial variance of the spots that make them up
+    # (so codes are eliminated first according to how many other codes they share a spot with,
+    # with spatial variance used to break ties). Nodes are eliminated from the graph in this way
+    # until there are no more edges in the graph
+
+    # First prepare list of counters of the spot IDs for each round
+    spotCodes = passingCodes['spot_codes']
+    codeArray = np.asarray([np.asarray(code) for code in spotCodes])
+    counters = []  # type: typing.List[Counter]
+    for r in range(roundNum):
+        counters.append(Counter(codeArray[:, r]))
+        counters[-1][0] = 0
+
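+    # Illustrative note (hypothetical codes): with three passing codes A = (1, 2, 3),
+    # B = (1, 4, 5) and C = (6, 4, 7), A and B share spot 1 in round 0 and B and C share spot 4
+    # in round 1, so B has two edges while A and C have one each. B is eliminated first, which
+    # leaves A and C with no remaining edges, so both are kept.
+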
+    # Then create the collisionCounter dictionary, which has the number of edges for each code,
+    # and the collisions dictionary, which holds a list of the codes each code overlaps with.
+    # Any code with no overlaps is added to keep to save for later
+    collisionCounter = defaultdict(int)  # type: dict
+    collisions = defaultdict(list)
+    keep = []
+    for i, spotCode in enumerate(spotCodes):
+        collision = False
+        for r in range(roundNum):
+            if spotCode[r] != 0:
+                count = counters[r][spotCode[r]] - 1
+                if count > 0:
+                    collision = True
+                    collisionCounter[spotCode] += count
+                    collisions[spotCode].extend([spotCodes[ind[0]] for ind in
+                                                 np.argwhere(codeArray[:, r] == spotCode[r])
+                                                 if ind[0] != i])
+        if not collision:
+            keep.append(i)
+
+    # spotDict dictionary has mapping for codes to their index location in spotCodes and
+    # codeDistance has mapping for codes to their spatial variance value
+    spotDict = {code: i for i, code in enumerate(spotCodes)}
+    codeDistance = passingCodes.set_index('spot_codes')['distance'].to_dict()
+    while len(collisions):
+        # Gets all the codes that have the highest value for number of edges, and then sorts them
+        # by their spatial variance values in decreasing order
+        maxValue = max(collisionCounter.values())
+        maxCodes = [code for code in collisionCounter if collisionCounter[code] == maxValue]
+        distances = np.asarray([codeDistance[code] for code in maxCodes])
+        sortOrder = [item[1] for item in sorted(zip(distances, range(len(distances))),
+                                                reverse=True)]
+        maxCodes = [tuple(code) for code in np.asarray(maxCodes)[sortOrder]]
+
+        # For every maxCode, first check that it is still a maxCode (may change during this loop);
+        # if it is, then modify all the nodes that have an edge to it to have one less edge (if
+        # this causes that node to have no more edges then delete it from the graph and add it to
+        # the codes we keep), then delete the maxCode from the graph
+        for maxCode in maxCodes:
+            if collisionCounter[maxCode] == maxValue:
+                for code in collisions[maxCode]:
+                    if collisionCounter[code] == 1:
+                        del collisionCounter[code]
+                        del collisions[code]
+                        keep.append(spotDict[code])
+                    else:
+                        collisionCounter[code] -= 1
+                        collisions[code] = [c for c in collisions[code] if c != maxCode]
+
+                del collisionCounter[maxCode]
+                del collisions[maxCode]
+
+    # Only choose codes that we found to not have any edges in the graph
+    finalCodes = passingCodes.loc[keep].reset_index(drop=True)
+
+    if len(finalCodes) == 0:
+        return pd.DataFrame()
+
+    # Add barcode labels, spot coordinates, barcode center coordinates, and number of rounds used
+    # for each barcode to table
+    barcodes = []
+    allCoords = []
+    centers = []
+    roundsUsed = []
+    # intensities = []
+    for i in range(len(finalCodes)):
+        spotCode = finalCodes.iloc[i]['spot_codes']
+        barcodes.append([channelDict[j][spot] for j, spot in enumerate(spotCode)])
+        counter = Counter(spotCode)  # type: Counter
+        roundsUsed.append(roundNum - counter[0])
+        coords = np.asarray([spotCoords[j][spot] for j, spot in enumerate(spotCode) if spot != 0])
+        allCoords.append(coords)
+        coords = np.asarray([coord for coord in coords])
+        center = np.asarray(coords).mean(axis=0)
+        centers.append(tuple(center))
+        # intensities.append([spotIntensities[j][spot] for j,spot in enumerate(spotCode)])
+    finalCodes['best_barcodes'] = barcodes
+    finalCodes['coords'] = allCoords
+    finalCodes['center'] = centers
+    # finalCodes['intensities'] = intensities
+    finalCodes['rounds_used'] = roundsUsed
+
+    return finalCodes
+
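+# Illustrative note: the finalCodes table returned above carries the per-code columns built up
+# earlier (at least 'spot_codes' and 'distance', plus 'targets' where it was carried through)
+# together with the columns added in the loop above: 'best_barcodes' (channel label per round),
+# 'coords' (coordinates of the spots used), 'center' (their mean position) and 'rounds_used'
+# (roundNum minus the number of dropped rounds).
+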
+def removeUsedSpots(finalCodes: pd.DataFrame, spotTables: dict) -> dict:
+
+    '''
+    Remove spots found to be in barcodes for the current round omission number from the
+    spotTables so they are not used for the next round omission number
+    Parameters
+    ----------
+    finalCodes : pd.DataFrame
+        Dataframe containing final set of codes that have passed all filters
+    spotTables : dict
+        Dictionary of original data tables extracted from SpotFindingResults objects by the
+        _merge_spots_by_round() function
+    Returns
+    -------
+    dict : Modified version of spotTables with spots that have been used in the current round
+           omission removed
+    '''
+
+    # Remove used spots
+    for r in range(len(spotTables)):
+        if len(spotTables[r]) > 0:
+            usedSpots = set([passed[r] for passed in finalCodes['spot_codes']
+                             if passed[r] != 0])
+            spotTables[r] = spotTables[r][~spotTables[r]['spot_id'].isin(usedSpots)]
+            spotTables[r] = spotTables[r].reset_index(drop=True)
+            spotTables[r].index = range(1, len(spotTables[r]) + 1)
+
+    return spotTables
diff --git a/starfish/core/spots/DecodeSpots/test/test_check_all.py b/starfish/core/spots/DecodeSpots/test/test_check_all.py
new file mode 100644
index 000000000..04afaa323
--- /dev/null
+++ b/starfish/core/spots/DecodeSpots/test/test_check_all.py
@@ -0,0 +1,168 @@
+import random
+
+import numpy as np
+from scipy.ndimage import gaussian_filter
+
+from starfish import ImageStack
+from starfish.core.codebook.codebook import Codebook
+from starfish.core.spots.DecodeSpots.check_all_decoder import CheckAll
+from starfish.core.spots.FindSpots import BlobDetector
+
+def syntheticSeqfish(x, y, z, codebook, nSpots, jitter, error):
+    nRound = codebook.shape[1]
+    nChannel = codebook.shape[2]
+    img = np.zeros((nRound, nChannel, z, y, x), dtype=np.float32)
+
+    intCodes = np.argmax(codebook.data, axis=2)
+
+    targets = []
+    for _ in range(nSpots):
+        randx = random.choice(range(5, x - 5))
+        randy = random.choice(range(5, y - 5))
+        randz = random.choice(range(2, z - 2))
+        randCode = random.choice(range(len(codebook)))
+        targets.append((randCode, (randx, randy, randz)))
+        if jitter > 0:
+            randx += random.choice(range(jitter + 1)) * random.choice([1, -1])
+            randy += random.choice(range(jitter + 1)) * random.choice([1, -1])
+        if error:
+            skip = random.choice(range(nRound))
+        else:
+            skip = 100
+        for r, ch in enumerate(intCodes[randCode]):
+            if r != skip:
+                img[r, ch, randz, randy, randx] = 10
+
+    gaussian_filter(img, (0, 0, 0.5, 1.5, 1.5), output=img)
+
+    return ImageStack.from_numpy(img / img.max()), targets
+
+
+def seqfishCodebook(nRound, nChannel, nCodes):
+
+    def barcodeConv(lis, chs):
+        barcode = np.zeros((len(lis), chs))
+        for i in range(len(lis)):
+            barcode[i][lis[i]] = 1
+        return barcode
+
+    def incrBarcode(lis, chs):
+        currInd = len(lis) - 1
+        lis[currInd] += 1
+        while lis[currInd] == chs:
+            lis[currInd] = 0
+            currInd -= 1
+            lis[currInd] += 1
+        return lis
+
+    allCombo = np.zeros((nChannel ** nRound, nRound, nChannel))
+
+    barcode = [0] * nRound
+    for i in range(np.shape(allCombo)[0]):
+        allCombo[i] = barcodeConv(barcode, nChannel)
+        barcode = incrBarcode(barcode, nChannel)
+
+    hammingDistance = 1
+    blanks = []
+    i = 0
+    while i < len(allCombo):
+        blanks.append(allCombo[i])
+        j = i + 1
+        while j < len(allCombo):
+            if np.count_nonzero(~(allCombo[i] == allCombo[j])) / 2 <= hammingDistance:
+                allCombo = allCombo[[k for k in range(len(allCombo)) if k != j]]
+            else:
+                j += 1
+        i += 1
+
+    data = np.asarray(blanks)[random.sample(range(len(blanks)), nCodes)]
+
+    return Codebook.from_numpy(code_names=range(len(data)), n_round=nRound,
+                               n_channel=nChannel, data=data)
+
+def testExactMatches():
+
+    codebook = seqfishCodebook(5, 3, 20)
+
+    img, trueTargets = syntheticSeqfish(100, 100, 20, codebook, 5, 0, False)
+
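+    # 5 synthetic targets, each lit in all 5 rounds (no dropped rounds), so the detector below
+    # should find 5 * 5 = 25 spots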
+    bd = BlobDetector(min_sigma=1, max_sigma=4, num_sigma=30, threshold=.1, exclude_border=False)
+    spots = bd.run(image_stack=img)
+    assert spots.count_total_spots() == 5 * 5, 'Spot detector did not find all spots'
+
+    decoder = CheckAll(codebook=codebook, search_radius=1, error_rounds=0)
+    hits = decoder.run(spots=spots, n_processes=4)
+
+    testTargets = []
+    for i in range(len(hits)):
+        testTargets.append((int(hits[i]['target'].data),
+                            (int(hits[i]['x'].data), int(hits[i]['y'].data),
+                             int(hits[i]['z'].data))))
+
+    matches = 0
+    for true in trueTargets:
+        for test in testTargets:
+            if true[0] == test[0]:
+                if test[1][0] + 1 >= true[1][0] >= test[1][0] - 1 and \
+                        test[1][1] + 1 >= true[1][1] >= test[1][1] - 1:
+                    matches += 1
+
+    assert matches == len(trueTargets), 'Incorrect number of targets found'
+
+def testJitteredMatches():
+
+    codebook = seqfishCodebook(5, 3, 20)
+
+    img, trueTargets = syntheticSeqfish(100, 100, 20, codebook, 5, 2, False)
+
+    bd = BlobDetector(min_sigma=1, max_sigma=4, num_sigma=30, threshold=.1, exclude_border=False)
+    spots = bd.run(image_stack=img)
+    assert spots.count_total_spots() == 5 * 5, 'Spot detector did not find all spots'
+
+    decoder = CheckAll(codebook=codebook, search_radius=3, error_rounds=0)
+    hits = decoder.run(spots=spots, n_processes=4)
+
+    testTargets = []
+    for i in range(len(hits)):
+        testTargets.append((int(hits[i]['target'].data),
+                            (int(hits[i]['x'].data), int(hits[i]['y'].data),
+                             int(hits[i]['z'].data))))
+
+    matches = 0
+    for true in trueTargets:
+        for test in testTargets:
+            if true[0] == test[0]:
+                if test[1][0] + 3 >= true[1][0] >= test[1][0] - 3 and \
+                        test[1][1] + 3 >= true[1][1] >= test[1][1] - 3:
+                    matches += 1
+
+    assert matches == len(trueTargets), 'Incorrect number of targets found'
+
+def testErrorCorrection():
+
+    codebook = seqfishCodebook(5, 3, 20)
+
+    img, trueTargets = syntheticSeqfish(100, 100, 20, codebook, 5, 0, True)
+
+    bd = BlobDetector(min_sigma=1, max_sigma=4, num_sigma=10, threshold=.1, exclude_border=False)
+    spots = bd.run(image_stack=img)
+    assert spots.count_total_spots() == 4 * 5, 'Spot detector did not find all spots'
+
+    decoder = CheckAll(codebook=codebook, search_radius=1, error_rounds=1)
+    hits = decoder.run(spots=spots, n_processes=4)
+
+    testTargets = []
+    for i in range(len(hits)):
+        testTargets.append((int(str(hits[i]['target'].data).split('.')[0]),
+                            (int(hits[i]['x'].data), int(hits[i]['y'].data),
+                             int(hits[i]['z'].data))))
+
+    matches = 0
+    for true in trueTargets:
+        for test in testTargets:
+            if true[0] == test[0]:
+                if test[1][0] + 1 >= true[1][0] >= test[1][0] - 1 and \
+                        test[1][1] + 1 >= true[1][1] >= test[1][1] - 1:
+                    matches += 1
+
+    assert matches == len(trueTargets), 'Incorrect number of targets found'