diff --git a/docs/transform_img_registration.md b/docs/transform_img_registration.md new file mode 100644 index 000000000..74930d7ba --- /dev/null +++ b/docs/transform_img_registration.md @@ -0,0 +1,75 @@ +## Image Registration + +`ImageRegistrator` is a class that transforms a target image based on corresponding, user-selected landmark points on the reference and +target images. + +*class* **plantcv.annotate.ImageRegistrator(ref_img, target_img, figsize=(12, 6))** + +To initialize an instance of the `ImageRegistrator` class, the two required parameters are `ref_img` and `target_img`, which represent +the reference image and the target image, respectively. + +Another optional parameter is the desired figure size `figsize`, by default `figsize=(12,6)`. + +### Attributes +**ref_img** (`ndarray`, datatype: uint8, required): input reference image. +**target_img** (`ndarray`, datatype: uint8, required): input target image. +**points** (`list`): list of coordinates of selected pixels on reference image and target image. +**model** (`numpy.ndarray`): transformation matrix of size (3,3) that registers the target image to the reference image. +**img_registered** (`numpy.ndarray`, datatype: uint8, required): registered target image. + +### Class methods +**display_coords()** +
Display user-selected coordinates for both the reference image and the target image. + +**regist()** + +Register the target image to the reference image based on user-selected landmark points. + +**save_model(model_file="model")** + +Save the transformation matrix used for image registration. 
+ +```python + +from plantcv import plantcv as pcv +# Initialize an image registrator +img_registrator = pcv.annotate.ImageRegistrator(ref_img, target_img, figsize=(12, 6)) + +## +# collecting landmark points +## + +# Display user selected coordinates on both reference image and target image +img_registrator.display_coords() + +# Register the target image to the reference image based on the model calculated from selected points +img_registrator.regist() + +# Save +img_registrator.save_model(model_file="my_model") + +``` + +Reference image (a thermal image): + +![thermal_ref](img/documentation_images/transform_img_registration/ref_therm.png) + +Target image (an RGB image): + +![rgb_target](img/documentation_images/transform_img_registration/tar_rgb.png) + +Overlay these two images: + +![overlay](img/documentation_images/transform_img_registration/overlay_before.png) + +Overlay two images after image registration: + +![overlay_after](img/documentation_images/transform_img_registration/overlay_after.png) + + +Check out this video for how this interactive tool works! 
+ + + +**Source Code:** [Here](https://github.com/danforthcenter/plantcv/blob/master/plantcv/plantcv/annotate/img_registration.py) diff --git a/plantcv/plantcv/annotate/__init__.py b/plantcv/plantcv/annotate/__init__.py index e69de29bb..7b1c3cecc 100644 --- a/plantcv/plantcv/annotate/__init__.py +++ b/plantcv/plantcv/annotate/__init__.py @@ -0,0 +1,3 @@ +from plantcv.plantcv.annotate.img_registration import ImageRegistrator + +__all__ = ["ImageRegistrator"] diff --git a/plantcv/plantcv/annotate/img_registration.py b/plantcv/plantcv/annotate/img_registration.py new file mode 100644 index 000000000..896af8d20 --- /dev/null +++ b/plantcv/plantcv/annotate/img_registration.py @@ -0,0 +1,113 @@ +# Image Registration Based On User Selected Landmark Points + +import cv2 +from plantcv import plantcv as pcv +from plantcv.plantcv.annotate.points import _find_closest_pt +import matplotlib.pyplot as plt +import pickle as pkl + + +class ImageRegistrator: + """ + An interactive tool that takes user selected landmark points to register two images + """ + def __init__(self, ref_img, target_img, figsize=(12, 6), cmap='jet'): + """Initialize parameters. + + Keyword arguments/parameters: + ref_img = Reference image + target_img = Target image + figsize = desired figure size (default: figsize=(12,6)) + cmap = colormap used to display grayscale (non-RGB) images (default: cmap='jet') + + :param ref_img: image data + :param target_img: image data + :param figsize: tuple + :param cmap: str + """ + self.img_ref = ref_img + self.img_tar = target_img + + self.fig, self.axes = plt.subplots(1, 2, figsize=figsize) + self.axes[0].text(0, -100, + 'Collect points matching features between images. ' + 'Select location on reference image then target image. ' + '\nPlease first click on the reference image, then on the same point on the target image.' 
+ '\nPlease select at least 4 pairs of points.') + + # assumption: any 3-d images whose 3rd dimension is 3 are rgb images + # This check to be replaced when image class implemented + dim_ref, dim_tar = ref_img.shape, target_img.shape + if len(dim_ref) == 3 and dim_ref[-1] == 3: + self.axes[0].imshow(cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB)) + else: + self.axes[0].imshow(ref_img, cmap=cmap) + self.axes[0].set_title('Reference Image') + + if len(dim_tar) == 3 and dim_tar[-1] == 3: + self.axes[1].imshow(cv2.cvtColor(target_img, cv2.COLOR_BGR2RGB)) + else: + self.axes[1].imshow(target_img, cmap=cmap) + self.axes[1].set_title('Target Image') + + # Set useblit=True on most backends for enhanced performance. + # cursor = Cursor(axes[0], horizOn=True, vertOn=True, useblit=True, color='red', linewidth=2) + + self.points = [[], []] + self.events = [] + + # onclick = functools.partial(_onclick_, fig, axes, array_data, wvs) + + self.fig.canvas.mpl_connect('button_press_event', self.onclick) + + self.model = None + self.img_registered = None + + def left_click(self, idx_ax, x, y): + self.axes[idx_ax].plot(x, y, 'x', c='red') + self.points[idx_ax].append((x, y)) + + def right_click(self, idx_ax, x, y): + idx_remove, _ = _find_closest_pt((x, y), self.points[idx_ax]) + self.points[idx_ax].pop(idx_remove) + axplots = self.axes[idx_ax].lines + self.axes[idx_ax].lines.remove(axplots[idx_remove]) + + def onclick(self, event): + self.events.append(event) + + # collect points on reference image + if str(event.inaxes._subplotspec) == 'GridSpec(1, 2)[0:1, 0:1]': + # left click + if event.button == 1: + self.left_click(0, event.xdata, event.ydata) + # right click + else: + self.right_click(0, event.xdata, event.ydata) + + # collect points on target image + elif str(event.inaxes._subplotspec) == 'GridSpec(1, 2)[0:1, 1:2]': + if event.button == 1: + self.left_click(1, event.xdata, event.ydata) + else: + self.right_click(1, event.xdata, event.ydata) + self.fig.canvas.draw() + + def 
save_model(self, model_file="model"): + pkl.dump(self.model, open("{}.pkl".format(model_file), "wb")) + + def display_coords(self): + print("\nCoordinates for selected reference points: ") + for point_ref in self.points[0]: + print("\n{}".format(point_ref)) + print("\nCoordinates for selected target points: ") + for point_tar in self.points[1]: + print("\n{}".format(point_tar)) + + def regist(self): + # use warp function in plantcv + self.model, self.img_registered = pcv.transform.warp(self.img_tar, + self.img_ref, + self.points[1], + self.points[0], + method='ransac') diff --git a/plantcv/plantcv/transform/warp.py b/plantcv/plantcv/transform/warp.py index fb2dd9253..0d7bd830a 100644 --- a/plantcv/plantcv/transform/warp.py +++ b/plantcv/plantcv/transform/warp.py @@ -58,7 +58,8 @@ def warp(img, refimg, pts, refpts, method='default'): fatal_error('Please provide same number of corresponding coordinates.') if not(len(pts) >= 4 and len(refpts) >= 4): fatal_error('Please provide at least 4 pairs of points!') - # convert coordinates to int if they are not int + + # convert coordinates to int (if they are not int) pts = [tuple(map(int, tup)) for tup in pts] refpts = [tuple(map(int, tup)) for tup in refpts] @@ -73,7 +74,7 @@ def warp(img, refimg, pts, refpts, method='default'): rows_ref, cols_ref = shape_ref[0:2] rows_img, _ = shape_img[0:2] - # convert list of tuples to array for cv2 functions + # convert from lists (of tuples) to arrays (for cv2 functions) ptsarr = np.array(pts, dtype='float32') refptsarr = np.array(refpts, dtype='float32') @@ -130,21 +131,58 @@ def warp(img, refimg, pts, refpts, method='default'): markerSize=params.marker_size * res_ratio_r, thickness=params.line_thickness * res_ratio_r) else: - cv2.drawMarker(refimg_marked, pt, color=colors[i], markerType=cv2.MARKER_TRIANGLE_UP, - markerSize=params.marker_size * res_ratio_r, - thickness=params.line_thickness * res_ratio_r) - - debug_mode = params.debug - params.debug = None - - # make sure the input 
image for "overlay_two_imgs" is of dtype "uint8" such that it would be acceptable for - # overlay_two_imgs (cv2.cvtColor) - img_blend = overlay_two_imgs(_preprocess_img_dtype(warped_img), refimg_) - params.debug = debug_mode - - _debug(visual=img_marked, filename=os.path.join(params.debug_outdir, str(params.device) + "_img-to-warp.png")) - _debug(visual=refimg_marked, filename=os.path.join(params.debug_outdir, str(params.device) + "_img-ref.png")) - _debug(visual=img_blend, filename=os.path.join(params.debug_outdir, str(params.device) + "_warp_overlay.png")) + res_ratio_r = int(np.ceil(rows_ref / rows_img)) + res_ratio_i = 1 + # marker colors + colors = color_palette(len(pts)) + + # convert image types to accepted ones for cv2.cvtColor, also compatible with colors generated by color_palette + # (color_palette generated colors that are in range of (0,255)) + img_ = _preprocess_img_dtype(img) + refimg_ = _preprocess_img_dtype(refimg) + + # rgb image for colored markers on img + img_marked = img_.copy() + # convert to RGB image if not + if len(shape_img) == 2: + img_marked = cv2.cvtColor(img_marked, cv2.COLOR_GRAY2RGB) + + for i, pt in enumerate(pts): + if status[i][0] == 1: + cv2.drawMarker(img_marked, pt, color=colors[i], markerType=cv2.MARKER_CROSS, + markerSize=params.marker_size * res_ratio_i, + thickness=params.line_thickness * res_ratio_i) + else: + cv2.drawMarker(img_marked, pt, color=colors[i], markerType=cv2.MARKER_TRIANGLE_UP, + markerSize=params.marker_size * res_ratio_i, + thickness=params.line_thickness * res_ratio_i) + + # rgb image for colored markers on refimg + refimg_marked = refimg_.copy() + if len(shape_ref) == 2: + refimg_marked = cv2.cvtColor(refimg_marked, cv2.COLOR_GRAY2RGB) + + for i, pt in enumerate(refpts): + if status[i][0] == 1: + cv2.drawMarker(refimg_marked, pt, color=colors[i], markerType=cv2.MARKER_CROSS, + markerSize=params.marker_size * res_ratio_r, + thickness=params.line_thickness * res_ratio_r) + else: + 
cv2.drawMarker(refimg_marked, pt, color=colors[i], markerType=cv2.MARKER_TRIANGLE_UP, + markerSize=params.marker_size * res_ratio_r, + thickness=params.line_thickness * res_ratio_r) + + debug_mode = params.debug + params.debug = None + + # make sure the input image for "overlay_two_imgs" is of dtype "uint8" such that it would be acceptable for + # overlay_two_imgs (cv2.cvtColor) + img_blend = overlay_two_imgs(_preprocess_img_dtype(warped_img), refimg_) + params.debug = debug_mode + + _debug(visual=img_marked, filename=os.path.join(params.debug_outdir, str(params.device) + "_img-to-warp.png")) + _debug(visual=refimg_marked, filename=os.path.join(params.debug_outdir, str(params.device) + "_img-ref.png")) + _debug(visual=img_blend, filename=os.path.join(params.debug_outdir, str(params.device) + "_warp_overlay.png")) return warped_img, mat diff --git a/tests/plantcv/annotate/test_img_registration.py b/tests/plantcv/annotate/test_img_registration.py new file mode 100644 index 000000000..308e37650 --- /dev/null +++ b/tests/plantcv/annotate/test_img_registration.py @@ -0,0 +1,60 @@ +import os +import matplotlib +from plantcv.plantcv.annotate import ImageRegistrator + +def test_plantcv_transform_img_registration_gray(transform_test_data): + # generate fake testing images + img_ref = transform_test_data.create_test_img((12, 10)) + img_tar = transform_test_data.create_test_img((12, 10)) + img_registrator = ImageRegistrator(img_ref, img_tar) + assert len(img_registrator.events) == 0 + + +def test_plantcv_transform_img_registration(transform_test_data, tmpdir): + # Create a test tmp directory + cache_dir = tmpdir.mkdir("sub") + # generate fake testing images + img_ref = transform_test_data.create_test_img((12, 10, 3)) + img_tar = transform_test_data.create_test_img((12, 10, 3)) + img_registrator = ImageRegistrator(img_ref, img_tar) + + + # create mock events: left click + for pt in [(0, 0), (1, 0), (0, 3), (4, 4), (3,2)]: + # left click, left axis + e1 = 
matplotlib.backend_bases.MouseEvent(name="button_press_event", canvas=img_registrator.fig.canvas, + x=0, y=0, button=1) + e1.inaxes = img_registrator.axes[0] + e1.inaxes._subplotspec = matplotlib.gridspec.SubplotSpec(matplotlib.gridspec.GridSpec(1,2,img_registrator.fig), 0) + e1.xdata, e1.ydata = pt + img_registrator.onclick(e1) + # left click, right axis + e2 = matplotlib.backend_bases.MouseEvent(name="button_press_event", canvas=img_registrator.fig.canvas, + x=0, y=0, button=1) + e2.inaxes = img_registrator.axes[0] + e2.inaxes._subplotspec = matplotlib.gridspec.SubplotSpec(matplotlib.gridspec.GridSpec(1,2,img_registrator.fig), 1) + e2.xdata, e2.ydata = pt + img_registrator.onclick(e2) + # right click, left axis + e1_ = matplotlib.backend_bases.MouseEvent(name="button_press_event", canvas=img_registrator.fig.canvas, + x=0, y=0, button=3) + e1_.inaxes = img_registrator.axes[0] + e1_.inaxes._subplotspec = matplotlib.gridspec.SubplotSpec(matplotlib.gridspec.GridSpec(1, 2, img_registrator.fig), 0) + e1_.xdata, e1_.ydata = 3, 2 + img_registrator.onclick(e1_) + # right click, right axis + e2_ = matplotlib.backend_bases.MouseEvent(name="button_press_event", canvas=img_registrator.fig.canvas, + x=0, y=0, button=3) + e2_.inaxes = img_registrator.axes[0] + e2_.inaxes._subplotspec = matplotlib.gridspec.SubplotSpec(matplotlib.gridspec.GridSpec(1, 2, img_registrator.fig), 1) + e2_.xdata, e2_.ydata = 3, 2 + img_registrator.onclick(e2_) + img_registrator.regist() + img_registrator.display_coords() + img_registrator.save_model(model_file=os.path.join(cache_dir, "model")) + + assert img_registrator.model is not None and len(img_registrator.points[0]) == 4 and len(img_registrator.points[1]) == 4 and os.path.isfile(os.path.join(cache_dir, "model.pkl")) + assert img_registrator.model is not None \ + and len(img_registrator.points[0]) == 4 \ + and len(img_registrator.points[1]) == 4 \ + and os.path.isfile(os.path.join(cache_dir, "model.pkl"))