diff --git a/.github/workflows/check_formatting.yml b/.github/workflows/check_formatting.yml index bfdc1542..b19a31fd 100644 --- a/.github/workflows/check_formatting.yml +++ b/.github/workflows/check_formatting.yml @@ -19,4 +19,4 @@ jobs: shell: bash -l {0} run: mamba install --quiet --yes --file requirements.txt black && - black tobac --check --diff + black tobac --check --diff diff --git a/tobac/tests/test_convert.py b/tobac/tests/test_convert.py index d5b6bbe8..ca2228fd 100644 --- a/tobac/tests/test_convert.py +++ b/tobac/tests/test_convert.py @@ -174,8 +174,9 @@ def test_function_kwarg(test_input, kwarg=None): def test_function_tuple_output(test_input, kwarg=None): return (test_input, test_input) - decorated_function_kwarg = decorator(test_function_kwarg) - decorated_function_tuple = decorator(test_function_tuple_output) + decorator_i = decorator() + decorated_function_kwarg = decorator_i(test_function_kwarg) + decorated_function_tuple = decorator_i(test_function_tuple_output) if input_types[0] == xarray.DataArray: data = xarray.DataArray.from_iris(tobac.testing.make_simple_sample_data_2D()) @@ -227,7 +228,8 @@ def test_xarray_workflow(): data_xarray = xarray.DataArray.from_iris(deepcopy(data)) # Testing the get_spacings utility - get_spacings_xarray = xarray_to_iris(tobac.utils.get_spacings) + xarray_to_iris_i = xarray_to_iris() + get_spacings_xarray = xarray_to_iris_i(tobac.utils.get_spacings) dxy, dt = tobac.utils.get_spacings(data) dxy_xarray, dt_xarray = get_spacings_xarray(data_xarray) @@ -235,7 +237,7 @@ def test_xarray_workflow(): assert dt == dt_xarray # Testing feature detection - feature_detection_xarray = xarray_to_iris( + feature_detection_xarray = xarray_to_iris_i( tobac.feature_detection.feature_detection_multithreshold ) features = tobac.feature_detection.feature_detection_multithreshold( @@ -246,7 +248,7 @@ def test_xarray_workflow(): assert_frame_equal(features, features_xarray) # Testing the segmentation - segmentation_xarray = xarray_to_iris(tobac.segmentation.segmentation) + segmentation_xarray = xarray_to_iris_i(tobac.segmentation.segmentation) mask, features = tobac.segmentation.segmentation(features, data, dxy, threshold=1.0) mask_xarray, features_xarray = segmentation_xarray( features_xarray, data_xarray, dxy_xarray, threshold=1.0 @@ -255,7 +257,7 @@ def test_xarray_workflow(): assert (mask.data == mask_xarray.to_iris().data).all() # testing tracking - tracking_xarray = xarray_to_iris(tobac.tracking.linking_trackpy) + tracking_xarray = xarray_to_iris_i(tobac.tracking.linking_trackpy) track = tobac.tracking.linking_trackpy(features, data, dt, dxy, v_max=100.0) track_xarray = tracking_xarray( features_xarray, data_xarray, dt_xarray, dxy_xarray, v_max=100.0 diff --git a/tobac/utils/bulk_statistics.py b/tobac/utils/bulk_statistics.py index f27f9fd5..55fe17fc 100644 --- a/tobac/utils/bulk_statistics.py +++ b/tobac/utils/bulk_statistics.py @@ -147,7 +147,7 @@ def get_statistics( return features -@decorators.iris_to_xarray +@decorators.iris_to_xarray() def get_statistics_from_mask( features: pd.DataFrame, segmentation_mask: xr.DataArray, diff --git a/tobac/utils/decorators.py b/tobac/utils/decorators.py index 33b012f7..afe90d65 100644 --- a/tobac/utils/decorators.py +++ b/tobac/utils/decorators.py @@ -3,346 +3,430 @@ import functools import warnings +import iris.cube +import pandas as pd +import xarray as xr -def iris_to_xarray(func): - """Decorator that converts all input of a function that is in the form of - Iris cubes into xarray DataArrays and converts all outputs with type - xarray DataArrays back into Iris cubes. +def _conv_kwargs_iris_to_xarray(conv_kwargs: dict): + """ + Internal function to convert iris cube kwargs to xarray dataarrays Parameters ---------- - func : function - Function to be decorated + conv_kwargs : dict + Input kwargs to convert Returns ------- - wrapper : function - Function including decorator + dict + Output keyword arguments without any Iris Cubes """ + return { + key: xr.DataArray.from_iris(arg) if isinstance(arg, iris.cube.Cube) else arg + for key, arg in zip(conv_kwargs.keys(), conv_kwargs.values()) + } - import iris - import xarray - - @functools.wraps(func) - def wrapper(*args, **kwargs): - # print(kwargs) - if any([type(arg) == iris.cube.Cube for arg in args]) or any( - [type(arg) == iris.cube.Cube for arg in kwargs.values()] - ): - # print("converting iris to xarray and back") - args = tuple( - [ - xarray.DataArray.from_iris(arg) - if type(arg) == iris.cube.Cube - else arg - for arg in args - ] - ) - kwargs_new = dict( - zip( - kwargs.keys(), - [ - xarray.DataArray.from_iris(arg) - if type(arg) == iris.cube.Cube - else arg - for arg in kwargs.values() - ], - ) - ) - # print(args) - # print(kwargs) - output = func(*args, **kwargs_new) - if type(output) == tuple: - output = tuple( - [ - xarray.DataArray.to_iris(output_item) - if type(output_item) == xarray.DataArray - else output_item - for output_item in output - ] - ) - elif type(output) == xarray.DataArray: - output = xarray.DataArray.to_iris(output) - # if output is neither tuple nor an xr.DataArray - else: - output = func(*args, **kwargs) - else: - output = func(*args, **kwargs) - return output +def _conv_kwargs_irispandas_to_xarray(conv_kwargs: dict): + """ + Internal function to convert iris cube and pandas dataframe kwargs to xarray dataarrays - return wrapper + Parameters + ---------- + conv_kwargs : dict + Input kwargs to convert + Returns + ------- + dict + Output keyword arguments without any Iris Cubes or pandas dataframes -def xarray_to_iris(func): - """Decorator that converts all input of a function that is in the form of - xarray DataArrays into Iris cubes and converts all outputs with type - Iris cubes back into xarray DataArrays. + """ + return { + key: xr.DataArray.from_iris(arg) + if isinstance(arg, iris.cube.Cube) + else arg.to_xarray() + if isinstance(arg, pd.DataFrame) + else arg + for key, arg in zip(conv_kwargs.keys(), conv_kwargs.values()) + } + + +def _conv_kwargs_xarray_to_iris(conv_kwargs: dict): + """ + Internal function to convert xarray dataarray kwargs back to iris cubes Parameters ---------- - func : function - Function to be decorated. + conv_kwargs : dict + Input kwargs to convert Returns ------- - wrapper : function - Function including decorator. - - Examples - -------- - >>> segmentation_xarray = xarray_to_iris(segmentation) + dict + Output keyword arguments with all xarray dataarrays converted back to + iris cubes + """ + return { + key: xr.DataArray.to_iris(arg) if isinstance(arg, xr.DataArray) else arg + for key, arg in zip(conv_kwargs.keys(), conv_kwargs.values()) + } - This line creates a new function that can process xarray fields and - also outputs fields in xarray format, but otherwise works just like - the original function: - >>> mask_xarray, features = segmentation_xarray( - features, data_xarray, dxy, threshold - ) +def _conv_kwargs_xarray_to_irispandas(conv_kwargs: dict): """ + Internal function to convert xarray dataarrays back to iris cubes/pandas dataframes - import iris - import xarray - - @functools.wraps(func) - def wrapper(*args, **kwargs): - # print(args) - # print(kwargs) - if any([type(arg) == xarray.DataArray for arg in args]) or any( - [type(arg) == xarray.DataArray for arg in kwargs.values()] - ): - # print("converting xarray to iris and back") - args = tuple( - [ - xarray.DataArray.to_iris(arg) - if type(arg) == xarray.DataArray - else arg - for arg in args - ] - ) - if kwargs: - kwargs_new = dict( - zip( - kwargs.keys(), - [ - xarray.DataArray.to_iris(arg) - if type(arg) == xarray.DataArray - else arg - for arg in kwargs.values() - ], - ) - ) - else: - kwargs_new = kwargs - # print(args) + Parameters + ---------- + conv_kwargs : dict + Input kwargs to convert + + Returns + ------- + dict + Output keyword arguments with all xarray dataarrays converted back to + iris cubes + """ + return { + key: xr.DataArray.to_iris(arg) + if isinstance(arg, xr.DataArray) + else arg.to_dataframe() + if isinstance(arg, xr.Dataset) + else arg + for key, arg in zip(conv_kwargs.keys(), conv_kwargs.values()) + } + + +def iris_to_xarray(save_iris_info: bool = False): + def iris_to_xarray_i(func): + """Decorator that converts all input of a function that is in the form of + Iris cubes into xarray DataArrays and converts all outputs with type + xarray DataArrays back into Iris cubes. + + Parameters + ---------- + func : function + Function to be decorated + + Returns + ------- + wrapper : function + Function including decorator + """ + + import iris + import iris.cube + import xarray + + @functools.wraps(func) + def wrapper(*args, **kwargs): # print(kwargs) - output = func(*args, **kwargs_new) - if type(output) == tuple: - output = tuple( + + if save_iris_info: + if any([(type(arg) == iris.cube.Cube) for arg in args]) or any( + [(type(arg) == iris.cube.Cube) for arg in kwargs.values()] + ): + kwargs["converted_from_iris"] = True + else: + kwargs["converted_from_iris"] = False + + if any([type(arg) == iris.cube.Cube for arg in args]) or any( + [type(arg) == iris.cube.Cube for arg in kwargs.values()] + ): + # print("converting iris to xarray and back") + args = tuple( [ - xarray.DataArray.from_iris(output_item) - if type(output_item) == iris.cube.Cube - else output_item - for output_item in output + xarray.DataArray.from_iris(arg) + if type(arg) == iris.cube.Cube + else arg + for arg in args ] ) + kwargs_new = _conv_kwargs_iris_to_xarray(kwargs) + # print(args) + # print(kwargs) + output = func(*args, **kwargs_new) + if type(output) == tuple: + output = tuple( + [ + xarray.DataArray.to_iris(output_item) + if type(output_item) == xarray.DataArray + else output_item + for output_item in output + ] + ) + elif type(output) == xarray.DataArray: + output = xarray.DataArray.to_iris(output) + # if output is neither tuple nor an xr.DataArray + else: - if type(output) == iris.cube.Cube: - output = xarray.DataArray.from_iris(output) + output = func(*args, **kwargs) + return output - else: - output = func(*args, **kwargs) - # print(output) - return output + return wrapper - return wrapper + return iris_to_xarray_i -def irispandas_to_xarray(func): - """Decorator that converts all input of a function that is in the form of - Iris cubes/pandas Dataframes into xarray DataArrays/xarray Datasets and - converts all outputs with the type xarray DataArray/xarray Dataset - back into Iris cubes/pandas Dataframes. +def xarray_to_iris(): + def xarray_to_iris_i(func): + """Decorator that converts all input of a function that is in the form of + xarray DataArrays into Iris cubes and converts all outputs with type + Iris cubes back into xarray DataArrays. - Parameters - ---------- - func : function - Function to be decorated. + Parameters + ---------- + func : function + Function to be decorated. - Returns - ------- - wrapper : function - Function including decorator. - """ - import iris - import xarray - import pandas as pd - - @functools.wraps(func) - def wrapper(*args, **kwargs): - # print(kwargs) - if any( - [(type(arg) == iris.cube.Cube or type(arg) == pd.DataFrame) for arg in args] - ) or any( - [ - (type(arg) == iris.cube.Cube or type(arg) == pd.DataFrame) - for arg in kwargs.values() - ] - ): - # print("converting iris to xarray and back") - args = tuple( + Returns + ------- + wrapper : function + Function including decorator. + + Examples + -------- + >>> segmentation_xarray_conv = xarray_to_iris() + >>> segmentation_xarray = segmentation_xarray_conv(segmentation) + + This line creates a new function that can process xarray fields and + also outputs fields in xarray format, but otherwise works just like + the original function: + + >>> mask_xarray, features = segmentation_xarray( + features, data_xarray, dxy, threshold + ) + """ + + import iris + import xarray + + @functools.wraps(func) + def wrapper(*args, **kwargs): + # print(args) + # print(kwargs) + if any([type(arg) == xarray.DataArray for arg in args]) or any( + [type(arg) == xarray.DataArray for arg in kwargs.values()] + ): + # print("converting xarray to iris and back") + args = tuple( + [ + xarray.DataArray.to_iris(arg) + if type(arg) == xarray.DataArray + else arg + for arg in args + ] + ) + if kwargs: + kwargs_new = _conv_kwargs_xarray_to_iris(kwargs) + else: + kwargs_new = kwargs + # print(args) + # print(kwargs) + output = func(*args, **kwargs_new) + if type(output) == tuple: + output = tuple( + [ + xarray.DataArray.from_iris(output_item) + if type(output_item) == iris.cube.Cube + else output_item + for output_item in output + ] + ) + else: + if type(output) == iris.cube.Cube: + output = xarray.DataArray.from_iris(output) + + else: + output = func(*args, **kwargs) + # print(output) + return output + + return wrapper + + return xarray_to_iris_i + + +def irispandas_to_xarray(save_iris_info: bool = False): + def irispandas_to_xarray_i(func): + """Decorator that converts all input of a function that is in the form of + Iris cubes/pandas Dataframes into xarray DataArrays/xarray Datasets and + converts all outputs with the type xarray DataArray/xarray Dataset + back into Iris cubes/pandas Dataframes. + + Parameters + ---------- + func : function + Function to be decorated. + + Returns + ------- + wrapper : function + Function including decorator. + """ + import iris + import iris.cube + import xarray + import pandas as pd + + @functools.wraps(func) + def wrapper(*args, **kwargs): + # pass if we did an iris conversion. + if save_iris_info: + if any([(type(arg) == iris.cube.Cube) for arg in args]) or any( + [(type(arg) == iris.cube.Cube) for arg in kwargs.values()] + ): + kwargs["converted_from_iris"] = True + else: + kwargs["converted_from_iris"] = False + + # print(kwargs) + if any( [ - xarray.DataArray.from_iris(arg) - if type(arg) == iris.cube.Cube - else arg.to_xarray() - if type(arg) == pd.DataFrame - else arg + (type(arg) == iris.cube.Cube or type(arg) == pd.DataFrame) for arg in args ] - ) - kwargs = dict( - zip( - kwargs.keys(), + ) or any( + [ + (type(arg) == iris.cube.Cube or type(arg) == pd.DataFrame) + for arg in kwargs.values() + ] + ): + # print("converting iris to xarray and back") + args = tuple( [ xarray.DataArray.from_iris(arg) if type(arg) == iris.cube.Cube else arg.to_xarray() if type(arg) == pd.DataFrame else arg - for arg in kwargs.values() - ], - ) - ) - - output = func(*args, **kwargs) - if type(output) == tuple: - output = tuple( - [ - xarray.DataArray.to_iris(output_item) - if type(output_item) == xarray.DataArray - else output_item.to_dataframe() - if type(output_item) == xarray.Dataset - else output_item - for output_item in output + for arg in args ] ) + kwargs = _conv_kwargs_irispandas_to_xarray(kwargs) + + output = func(*args, **kwargs) + if type(output) == tuple: + output = tuple( + [ + xarray.DataArray.to_iris(output_item) + if type(output_item) == xarray.DataArray + else output_item.to_dataframe() + if type(output_item) == xarray.Dataset + else output_item + for output_item in output + ] + ) + else: + if type(output) == xarray.DataArray: + output = xarray.DataArray.to_iris(output) + elif type(output) == xarray.Dataset: + output = output.to_dataframe() + else: - if type(output) == xarray.DataArray: - output = xarray.DataArray.to_iris(output) - elif type(output) == xarray.Dataset: - output = output.to_dataframe() + output = func(*args, **kwargs) + return output - else: - output = func(*args, **kwargs) - return output + return wrapper - return wrapper + return irispandas_to_xarray_i -def xarray_to_irispandas(func): - """Decorator that converts all input of a function that is in the form of - DataArrays/xarray Datasets into xarray Iris cubes/pandas Dataframes and - converts all outputs with the type Iris cubes/pandas Dataframes back into - xarray DataArray/xarray Dataset. +def xarray_to_irispandas(): + def xarray_to_irispandas_i(func): + """Decorator that converts all input of a function that is in the form of + DataArrays/xarray Datasets into xarray Iris cubes/pandas Dataframes and + converts all outputs with the type Iris cubes/pandas Dataframes back into + xarray DataArray/xarray Dataset. - Parameters - ---------- - func : function - Function to be decorated. + Parameters + ---------- + func : function + Function to be decorated. - Returns - ------- - wrapper : function - Function including decorator. + Returns + ------- + wrapper : function + Function including decorator. - Examples - -------- - >>> linking_trackpy_xarray = xarray_to_irispandas( - linking_trackpy - ) + Examples + -------- + >>> linking_trackpy_xarray = xarray_to_irispandas( + linking_trackpy + ) - This line creates a new function that can process xarray inputs and - also outputs in xarray formats, but otherwise works just like the - original function: + This line creates a new function that can process xarray inputs and + also outputs in xarray formats, but otherwise works just like the + original function: - >>> track_xarray = linking_trackpy_xarray( - features_xarray, field_xarray, dt, dx - ) - """ - import iris - import xarray - import pandas as pd - - @functools.wraps(func) - def wrapper(*args, **kwargs): - # print(args) - # print(kwargs) - if any( - [ - (type(arg) == xarray.DataArray or type(arg) == xarray.Dataset) - for arg in args - ] - ) or any( - [ - (type(arg) == xarray.DataArray or type(arg) == xarray.Dataset) - for arg in kwargs.values() - ] - ): - # print("converting xarray to iris and back") - args = tuple( - [ - xarray.DataArray.to_iris(arg) - if type(arg) == xarray.DataArray - else arg.to_dataframe() - if type(arg) == xarray.Dataset - else arg - for arg in args - ] + >>> track_xarray = linking_trackpy_xarray( + features_xarray, field_xarray, dt, dx ) - if kwargs: - kwargs_new = dict( - zip( - kwargs.keys(), - [ - xarray.DataArray.to_iris(arg) - if type(arg) == xarray.DataArray - else arg.to_dataframe() - if type(arg) == xarray.Dataset - else arg - for arg in kwargs.values() - ], - ) - ) - else: - kwargs_new = kwargs + """ + import iris + import xarray + import pandas as pd + + @functools.wraps(func) + def wrapper(*args, **kwargs): # print(args) # print(kwargs) - output = func(*args, **kwargs_new) - if type(output) == tuple: - output = tuple( + if any( + [ + (type(arg) == xarray.DataArray or type(arg) == xarray.Dataset) + for arg in args + ] + ) or any( + [ + (type(arg) == xarray.DataArray or type(arg) == xarray.Dataset) + for arg in kwargs.values() + ] + ): + # print("converting xarray to iris and back") + args = tuple( [ - xarray.DataArray.from_iris(output_item) - if type(output_item) == iris.cube.Cube - else output_item.to_xarray() - if type(output_item) == pd.DataFrame - else output_item - for output_item in output + xarray.DataArray.to_iris(arg) + if type(arg) == xarray.DataArray + else arg.to_dataframe() + if type(arg) == xarray.Dataset + else arg + for arg in args ] ) + if kwargs: + kwargs_new = _conv_kwargs_xarray_to_irispandas(kwargs) + else: + kwargs_new = kwargs + # print(args) + # print(kwargs) + output = func(*args, **kwargs_new) + if type(output) == tuple: + output = tuple( + [ + xarray.DataArray.from_iris(output_item) + if type(output_item) == iris.cube.Cube + else output_item.to_xarray() + if type(output_item) == pd.DataFrame + else output_item + for output_item in output + ] + ) + else: + if type(output) == iris.cube.Cube: + output = xarray.DataArray.from_iris(output) + elif type(output) == pd.DataFrame: + output = output.to_xarray() + else: - if type(output) == iris.cube.Cube: - output = xarray.DataArray.from_iris(output) - elif type(output) == pd.DataFrame: - output = output.to_xarray() + output = func(*args, **kwargs) + # print(output) + return output - else: - output = func(*args, **kwargs) - # print(output) - return output + return wrapper - return wrapper + return xarray_to_irispandas_i def njit_if_available(func, **kwargs): diff --git a/tobac/utils/general.py b/tobac/utils/general.py index 44f0177a..66a31813 100644 --- a/tobac/utils/general.py +++ b/tobac/utils/general.py @@ -637,7 +637,7 @@ def combine_feature_dataframes( return combined_sorted -@internal_utils.irispandas_to_xarray +@internal_utils.irispandas_to_xarray() def transform_feature_points( features, new_dataset, diff --git a/tobac/utils/internal/basic.py b/tobac/utils/internal/basic.py index 8dc041de..0efd28a5 100644 --- a/tobac/utils/internal/basic.py +++ b/tobac/utils/internal/basic.py @@ -285,7 +285,7 @@ def find_axis_from_coord( raise ValueError("variable_arr must be Iris Cube or Xarray DataArray") -@irispandas_to_xarray +@irispandas_to_xarray() def detect_latlon_coord_name( in_dataset: Union[xr.DataArray, iris.cube.Cube], latitude_name: Union[str, None] = None,