Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add img registration #729

Closed
wants to merge 31 commits into from
Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
51e4dac
Update transform_img_registration.md
DannieSheng Jun 8, 2021
e85b2cc
Update img_registration.py
DannieSheng Jun 8, 2021
5fc1ea4
update with master
DannieSheng Jun 8, 2021
8b806cc
Update img_registration.py
DannieSheng Jun 16, 2021
ca2c0bc
Update img_registration.py
DannieSheng Jun 25, 2021
f8f0365
add `ImageRegistrator` to `__init__.py`
DannieSheng Jun 25, 2021
21c3157
Update documentation for `transform/img_registration`
DannieSheng Jun 25, 2021
272fc14
Add tests for `ImageRegistrator`
DannieSheng Jun 25, 2021
f2ed92b
small format changes for `transform/warp`
DannieSheng Jun 25, 2021
5d798c2
adding tests for `img_registration`
DannieSheng Jun 25, 2021
8967d20
Merge branch '4.x' into add_img_registration
nfahlgren Dec 17, 2021
b3411f5
import _find_closest_pt and remove code redundancy
HaleySchuhl Dec 17, 2021
f5119bb
remove extra imports
HaleySchuhl Dec 17, 2021
8d8166c
try to cut code redundancy
HaleySchuhl Dec 20, 2021
aafe960
typo
HaleySchuhl Dec 20, 2021
38f9ae7
update click event within class
HaleySchuhl Dec 20, 2021
bec1b65
var name update
HaleySchuhl Dec 20, 2021
fdbe6d9
undo since tests wont run
HaleySchuhl Dec 20, 2021
2a02047
refactor img_registration into annotate
HaleySchuhl Dec 21, 2021
1002c87
remove registrator from transform init
HaleySchuhl Dec 21, 2021
cfa1759
append annotate init file
HaleySchuhl Dec 21, 2021
05c53cb
clean up registrator code
HaleySchuhl Dec 21, 2021
726dda4
re-arrange tests
HaleySchuhl Dec 21, 2021
7645e18
Update transform_img_registration.md
HaleySchuhl Dec 21, 2021
0e404ce
other test update
HaleySchuhl Dec 21, 2021
842a90a
Merge branch '4.x' into add_img_registration
HaleySchuhl May 11, 2022
a96296c
Create test_img_registration.py
HaleySchuhl May 11, 2022
e4c3910
update docs page
HaleySchuhl Jun 24, 2022
7451cce
refactor var names to match inputs pattern
HaleySchuhl Jun 24, 2022
8feb07b
add internal docs
HaleySchuhl Jun 24, 2022
5a93d0d
Merge branch '4.x' into add_img_registration
HaleySchuhl Jun 28, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions docs/transform_img_registration.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
## Image Registration

`ImageRegistrator` is a class that registrater a target image based on user selected landmark pixels on reference and
target images.

*class* **plantcv.transform.ImageRegistrator(img_ref, img_tar, figsize=(12, 6))**

To initialize an instance of `ImageRegistrator` class, two required parameters are `img_ref` and `img_tar`, represent
for target image and reference image, respectively.

Another optional parameter is the desired figure size `figsize`, by default `figsize=(12,6)`.

### Attributes
**img_ref** (`ndarray`, datatype: uint8, required): input reference image.
**img_tar** (`ndarray`, datatype: uint8, required): input target image.
**points** (`list`): list of coordinates of selected pixels on reference image and target image.
**model** (`numpy.ndarray`): tranformation matrix of size (3,3) that register the target image to the reference image.
**img_registered** (`numpy.ndarray`, datatype: uint8, required): registered target image.

### Class methods
**display_coords()**

Display user selected coordinates for both reference image and target image

**regist()**

Register the target image to the reference image based on user selected landmark points.

**save_model(model_file="model")**

Save the transformation matrix used for image registration.

```python

from plantcv import plantcv as pcv
# Initialize an image registrator
img_registrator = ImageRegistrator(img_ref, img_tar, figsize=(12, 6))

##
# collecting land mark points
##

# Display user selected coordinates on both reference image and target image
img_registrator.display_coords()

# Register the target image to the reference image based on the model calculated from selected points
img_registrator.regist()

# Save
img_registrator.save_model(model_file="my_model")

```

Reference image (a thermal image):

![thermal_ref](img/documentation_images/transform_img_registration/ref_therm.png)

Target image (a RGB image):

![thermal_ref](img/documentation_images/transform_img_registration/tar_rgb.png)

Overlay these two images:

![overlay](img/documentation_images/transform_img_registration/overlay_before.png)

Overlay two images after image registration:

![overlay_after](img/documentation_images/transform_img_registration/overlay_after.png)


Check out this video for how this interactive tool works!
<iframe src="https://player.vimeo.com/video/522809945" width="640" height="360" frameborder="0" allow="autoplay; fullscreen; picture-in-picture" allowfullscreen></iframe>


**Source Code:** [Here](https://github.com/danforthcenter/plantcv/blob/master/plantcv/plantcv/transform/img_registration.py)
3 changes: 2 additions & 1 deletion plantcv/plantcv/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from plantcv.plantcv.transform.resize import resize, resize_factor
from plantcv.plantcv.transform.warp import warp, warp_align
from plantcv.plantcv.transform.gamma_correct import gamma_correct
from plantcv.plantcv.transform.img_registration import ImageRegistrator

__all__ = ["get_color_matrix", "get_matrix_m", "calc_transformation_matrix", "apply_transformation_matrix",
"save_matrix", "load_matrix", "correct_color", "create_color_card_mask", "quick_color_check",
"find_color_card", "rescale", "nonuniform_illumination", "resize", "resize_factor",
"warp", "rotate", "warp", "warp_align", "gamma_correct"]
"warp", "rotate", "warp", "warp_align", "gamma_correct", "ImageRegistrator"]
120 changes: 120 additions & 0 deletions plantcv/plantcv/transform/img_registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Image Registration Based On User Selected Landmark Points

import cv2
from plantcv import plantcv as pcv
from scipy.spatial import distance
import numpy as np
import matplotlib.pyplot as plt
import pickle as pkl


def _find_closest(pt, pts):
""" Given coordinates of a point and a list of coordinates of a bunch of points,
find the point that has the smallest Euclidean to the given point

:param pt: (tuple) coordinates of a point
:param pts: (a list of tuples) coordinates of a list of points
:return: index of the closest point and the coordinates of that point
"""
dists = distance.cdist([pt], pts, 'euclidean')
idx = np.argmin(dists)
return idx, pts[idx]


class ImageRegistrator:
"""
An interactive tool that takes user selected landmark points to register two images
"""
def __init__(self, img_ref, img_tar, figsize=(12, 6), cmap='jet'):
self.img_ref = img_ref
self.img_tar = img_tar

self.fig, self.axes = plt.subplots(1, 2, figsize=figsize)
self.axes[0].text(0, -100,
'Collect points matching features between images. '
'Select location on reference image then target image. '
'\nPlease first click on the reference image, then on the same point on the target image.'
'\nPlease select at least 4 pairs of points.')

# assumption: any 3-d images whose 3rd dimension is 3 are rgb images
# This check to be replaced when image class implemented
dim_ref, dim_tar = img_ref.shape, img_tar.shape
if len(dim_ref) == 3 and dim_ref[-1] == 3:
self.axes[0].imshow(cv2.cvtColor(img_ref, cv2.COLOR_BGR2RGB))
else:
self.axes[0].imshow(img_ref, cmap=cmap)
self.axes[0].set_title('Reference Image')

if len(dim_tar) == 3 and dim_tar[-1] == 3:
self.axes[1].imshow(cv2.cvtColor(img_tar, cv2.COLOR_BGR2RGB))
else:
self.axes[1].imshow(img_tar, cmap=cmap)
self.axes[1].set_title('Target Image')

# Set useblit=True on most backends for enhanced performance.
# cursor = Cursor(axes[0], horizOn=True, vertOn=True, useblit=True, color='red', linewidth=2)

self.points = [[], []]
self.events = []

# onclick = functools.partial(_onclick_, fig, axes, array_data, wvs)

self.fig.canvas.mpl_connect('button_press_event', self.onclick)

self.model = None
self.img_registered = None

def left_click(self, idx_ax, x, y):
self.axes[idx_ax].plot(x, y, 'x', c='red')
self.points[idx_ax].append((x, y))

def right_click(self, idx_ax, x, y):
idx_remove, _ = _find_closest((x, y), self.points[idx_ax])
# remove the last added point
# idx_remove = -1
self.points[idx_ax].pop(idx_remove)
axplots = self.axes[idx_ax].lines
self.axes[idx_ax].lines.remove(axplots[idx_remove])

def onclick(self, event):
self.events.append(event)

# collect points on reference image
if str(event.inaxes._subplotspec) == 'GridSpec(1, 2)[0:1, 0:1]':
# left click
if event.button == 1:
self.left_click(0, event.xdata, event.ydata)
# right click
else:
self.right_click(0, event.xdata, event.ydata)

# collect points on target image
elif str(event.inaxes._subplotspec) == 'GridSpec(1, 2)[0:1, 1:2]':
# left click
if event.button == 1:
self.left_click(1, event.xdata, event.ydata)

# right click
else:
self.right_click(1, event.xdata, event.ydata)
self.fig.canvas.draw()

def save_model(self, model_file="model"):
pkl.dump(self.model, open("{}.pkl".format(model_file), "wb"))

def display_coords(self):
print("\nCoordinates for selected reference points: ")
for point_ref in self.points[0]:
print("\n{}".format(point_ref))
print("\nCoordinates for selected target points: ")
for point_tar in self.points[1]:
print("\n{}".format(point_tar))

def regist(self):
# use warp function in plantcv
self.model, self.img_registered = pcv.transform.warp(self.img_tar,
self.img_ref,
self.points[1],
self.points[0],
method='ransac')

7 changes: 4 additions & 3 deletions plantcv/plantcv/transform/warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def warp(img, refimg, pts, refpts, method='default'):
fatal_error('Please provide same number of corresponding coordinates.')
if not(len(pts) >= 4 and len(refpts) >= 4):
fatal_error('Please provide at least 4 pairs of points!')
# convert coordinates to int if they are not int

# convert coordinates to int (if they are not int)
pts = [tuple(map(int, tup)) for tup in pts]
refpts = [tuple(map(int, tup)) for tup in refpts]

Expand All @@ -74,7 +75,7 @@ def warp(img, refimg, pts, refpts, method='default'):
rows_ref, cols_ref = shape_ref[0:2]
rows_img, _ = shape_img[0:2]

# convert list of tuples to array for cv2 functions
# convert from lists (of tuples) to arrays (for cv2 functions)
ptsarr = np.array(pts, dtype='float32')
refptsarr = np.array(refpts, dtype='float32')

Expand All @@ -101,7 +102,7 @@ def warp(img, refimg, pts, refpts, method='default'):
colors = color_palette(len(pts))

# convert image types to accepted ones for cv2.cvtColor, also compatible with colors generated by color_palette
# (color_palette generated colors that are in range of (0,255)
# (color_palette generated colors that are in range of (0,255))
img_ = _preprocess_img_dtype(img)
refimg_ = _preprocess_img_dtype(refimg)

Expand Down
69 changes: 69 additions & 0 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,14 @@ def test_plantcv_outputs_add_observation_invalid_type():
datatype=list, value=np.array([2]), label=[])


def test_plantcv_transform_warp_align():
print('dd')
cache_dir = os.path.join(TEST_TMPDIR, "test_plantcv_warp_align")
os.mkdir(cache_dir)
pcv.params.debug_outdir = cache_dir
img = cv2.imread(os.path.join(TEST_DATA, TEST_INPUT_COLOR))


def test_plantcv_outputs_save_results_json_newfile(tmpdir):
# Create a test tmp directory
cache_dir = tmpdir.mkdir("sub")
Expand Down Expand Up @@ -5783,6 +5791,67 @@ def test_plantcv_transform_nonuniform_illumination_gray():
pcv.params.debug = "print"
corrected = pcv.transform.nonuniform_illumination(img=gray_img, ksize=11)
assert np.shape(corrected) == np.shape(gray_img)


def test_plantcv_transform_img_registration(tmpdir):
# Create a test tmp directory
cache_dir = tmpdir.mkdir("sub")
# generate fake testing images
img_ref = create_test_img((12, 10, 3))
img_tar = create_test_img((12, 10, 3))
img_registrator = pcv.transform.ImageRegistrator(img_ref, img_tar)

# create mock events: left click
for pt in [(0, 0), (1, 0), (0, 3), (4, 4), (3,2)]:
# left click, left axis
e1 = matplotlib.backend_bases.MouseEvent(name="button_press_event", canvas=img_registrator.fig.canvas,
x=0, y=0, button=1)
e1.inaxes = img_registrator.axes[0]
e1.inaxes._subplotspec = matplotlib.gridspec.SubplotSpec(matplotlib.gridspec.GridSpec(1,2,img_registrator.fig), 0)
e1.xdata, e1.ydata = pt
img_registrator.onclick(e1)

# left click, right axis
e2 = matplotlib.backend_bases.MouseEvent(name="button_press_event", canvas=img_registrator.fig.canvas,
x=0, y=0, button=1)
e2.inaxes = img_registrator.axes[0]
e2.inaxes._subplotspec = matplotlib.gridspec.SubplotSpec(matplotlib.gridspec.GridSpec(1,2,img_registrator.fig), 1)
e2.xdata, e2.ydata = pt
img_registrator.onclick(e2)
# right click, left axis
e1_ = matplotlib.backend_bases.MouseEvent(name="button_press_event", canvas=img_registrator.fig.canvas,
x=0, y=0, button=3)
e1_.inaxes = img_registrator.axes[0]
e1_.inaxes._subplotspec = matplotlib.gridspec.SubplotSpec(matplotlib.gridspec.GridSpec(1, 2, img_registrator.fig), 0)
e1_.xdata, e1_.ydata = 3, 2
img_registrator.onclick(e1_)

# right click, right axis
e2_ = matplotlib.backend_bases.MouseEvent(name="button_press_event", canvas=img_registrator.fig.canvas,
x=0, y=0, button=3)
e2_.inaxes = img_registrator.axes[0]
e2_.inaxes._subplotspec = matplotlib.gridspec.SubplotSpec(matplotlib.gridspec.GridSpec(1, 2, img_registrator.fig), 1)
e2_.xdata, e2_.ydata = 3, 2
img_registrator.onclick(e2_)


img_registrator.regist()
img_registrator.display_coords()
img_registrator.save_model(model_file=os.path.join(cache_dir, "model"))

assert img_registrator.model is not None \
and len(img_registrator.points[0]) == 4 \
and len(img_registrator.points[1]) == 4 \
and os.path.isfile(os.path.join(cache_dir, "model.pkl"))


def test_plantcv_transform_img_registration_gray():
# generate fake testing images
img_ref = create_test_img((12, 10))
img_tar = create_test_img((12, 10))
img_registrator = pcv.transform.ImageRegistrator(img_ref, img_tar)
assert len(img_registrator.events) == 0



def test_plantcv_transform_warp_default():
Expand Down