def test_identify_feature_families():
    """Test tobac.utils.general.identify_feature_families.

    Two features are generated and combined into one dataframe. A fake
    segmentation field is built in which, at the first time index, the area
    labelled 1 (rows 5-14) directly borders the area labelled 2 (rows 15-39)
    over the same column range, so both features must end up in the same
    family.
    """
    orig_feat_df_1 = tb_test.generate_single_feature(
        10, 30, 10, max_h1=50, max_h2=50, feature_num=1
    )
    orig_feat_df_2 = tb_test.generate_single_feature(
        30, 30, 20, max_h1=50, max_h2=50, feature_num=2
    )

    orig_feat_df = tb_utils.combine_feature_dataframes(
        [orig_feat_df_1, orig_feat_df_2], renumber_features=False
    )

    # make fake segmentation: two touching areas at time index 0,
    # nothing segmented at time index 1
    test_arr = np.zeros((2, 50, 50), dtype=int)
    test_arr[0, 5:15, 20:40] = 1
    test_arr[0, 15:40, 20:40] = 2

    test_xr = xr.DataArray(data=test_arr, dims=["time", "hdim_1", "hdim_2"])

    out_df, out_grid = tb_utils.general.identify_feature_families(
        orig_feat_df, test_xr, return_grid=True, family_column_name="family"
    )

    # Both features must survive the merge; otherwise the np.all check below
    # could pass vacuously on an empty dataframe.
    assert set(out_df["feature"]) == {1, 2}
    # np.all (not np.unique) so that every row is actually checked.
    assert np.all(out_df["family"] == 1)
    assert np.all(out_grid[0, 5:15, 20:40] == 1)
    assert np.all(out_grid[0, 15:40, 20:40] == 1)
@decorators.iris_to_xarray
def identify_feature_families(
    feature_df: pd.DataFrame,
    in_segmentation: xr.DataArray,
    return_grid: bool = False,
    family_column_name: str = "feature_family_id",
    unsegmented_point_values: int = 0,
    below_threshold_values: int = -1,
) -> Union[tuple[pd.DataFrame, xr.DataArray], pd.DataFrame]:
    """
    Function to identify families/storm systems by identifying where segmentation touches.
    At a given time, segmentation areas are considered part of the same family if they
    touch at any point.

    Families are found independently for every index along the first
    dimension of `in_segmentation` by labelling connected regions of
    segmented points, so family numbering restarts at 1 for each time index
    and a family ID is only meaningful together with its time step.

    Parameters
    ----------
    feature_df: pd.DataFrame
        Input feature dataframe. Must contain a "feature" column, which is
        matched against the label values found in `in_segmentation`.
    in_segmentation: xr.DataArray
        Input segmentation. The first dimension is iterated over as time;
        the remaining dimensions (2 or 3) are treated as spatial.
    return_grid: bool
        Whether to return the segmentation grid showing families
    family_column_name: str
        The name in the output dataframe of the family ID
    unsegmented_point_values: int
        The value in the input segmentation for unsegmented but above threshold points
    below_threshold_values: int
        The value in the input segmentation for below threshold points

    Returns
    -------
    pd.DataFrame and xr.DataArray or pd.DataFrame
        Input dataframe with family IDs associated with each feature.
        Features that have no segmented area are dropped (inner merge).
        if return_grid is True, the segmentation grid showing families is
        also returned: a copy of `in_segmentation` whose values are the
        family label at each point and 0 where nothing is segmented.

    """

    # we need to label the data, but we currently label using skimage label, not dask label.

    seg_family_dict = dict()
    out_families = copy.deepcopy(in_segmentation)

    for time_index in range(in_segmentation.shape[0]):
        # np.array copies, so in_arr can be modified without touching the input.
        in_arr = np.array(in_segmentation.values[time_index])

        segmented_arr = np.logical_and(
            in_arr != unsegmented_point_values, in_arr != below_threshold_values
        )
        # These are our families
        family_labeled_data = skimage.measure.label(
            segmented_arr,
        )

        # Blank out everything that is not a segmented feature so that the
        # unsegmented/below-threshold marker values can never be picked up
        # by regionprops as if they were feature labels.
        in_arr[~segmented_arr] = 0

        # now we need to note feature->family relationship in the dataframe.
        # associate feature ID -> family ID
        for seg_area in skimage.measure.regionprops(in_arr):
            # Any point of the segmented area identifies its family; use the
            # first one. Indexing with a tuple works for both the 2D and the
            # 3D spatial case.
            seg_family_dict[seg_area.label] = family_labeled_data[
                tuple(seg_area.coords[0])
            ]

        # Store the family labels (not just the segmented/unsegmented mask).
        out_families[time_index] = family_labeled_data

    family_df = pd.DataFrame(
        {
            "feature": list(seg_family_dict.keys()),
            family_column_name: list(seg_family_dict.values()),
        }
    )
    out_df = feature_df.merge(family_df, on="feature", how="inner")

    if return_grid:
        return out_df, out_families
    return out_df