Skip to content

Commit

Permalink
Merge pull request #148 from mcflugen/mcflugen/absolute-import-mobility
Browse files Browse the repository at this point in the history
Absolute import of the deltametrics.mobility module
  • Loading branch information
mcflugen authored Dec 3, 2024
2 parents bab4c53 + d29c913 commit 929eddb
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 45 deletions.
87 changes: 47 additions & 40 deletions tests/test_mobility.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@

import deltametrics as dm
from deltametrics import cube
from deltametrics import mobility as mob
from deltametrics.sample_data import _get_rcm8_path
from deltametrics.mobility import _calculate_temporal_linear_slope
from deltametrics.mobility import calculate_channel_abandonment
from deltametrics.mobility import calculate_channel_decay
from deltametrics.mobility import calculate_channelized_response_variance
from deltametrics.mobility import calculate_planform_overlap
from deltametrics.mobility import calculate_reworking_fraction
from deltametrics.mobility import channel_presence
from deltametrics.mobility import check_inputs


rcm8_path = _get_rcm8_path()
Expand Down Expand Up @@ -62,7 +69,7 @@ def test_check_input_list_of_mask():
"""Test that a deltametrics.mask.BaseMask type can be used."""
# call checker function
assert isinstance(chmask, list)
chmap, landmap, basevalues, time_window, dim0 = mob.check_inputs(
chmap, landmap, basevalues, time_window, dim0 = check_inputs(
chmask, basevalues_idx=[0], window_idx=1, landmap=landmask)
# assert types
assert dim0 == 'time'
Expand All @@ -78,15 +85,15 @@ def test_check_input_single_mask_error():
"""Test that a deltametrics.mask.BaseMask type can be used."""
# call checker function
with pytest.raises(TypeError, match=r'Cannot input a Mask .*'):
chmap, landmap, basevalues, time_window, dim0 = mob.check_inputs(
chmap, landmap, basevalues, time_window, dim0 = check_inputs(
chmask[0], basevalues_idx=[0], window_idx=1,
landmap=landmask[0])


def test_check_xarrays():
"""Test that an xarray.DataArray can be used as an input."""
# call checker function
chmap, landmap, basevalues, time_window, dim0 = mob.check_inputs(
chmap, landmap, basevalues, time_window, dim0 = check_inputs(
ch_xarr, basevalues_idx=[0], window_idx=1,
landmap=land_xarr)
# assert types
Expand All @@ -100,7 +107,7 @@ def test_check_xarrays():
def test_check_list_ndarrays():
"""Test that a list of numpy.ndarray can be used as an input."""
# call checker function
chmap, landmap, basevalues, time_window, dim0 = mob.check_inputs(
chmap, landmap, basevalues, time_window, dim0 = check_inputs(
ch_arr_list, basevalues_idx=[0], window_idx=1,
landmap=land_arr_list)
# assert types
Expand All @@ -114,7 +121,7 @@ def test_check_list_ndarrays():
def test_check_ndarrays():
"""Test that a numpy.ndarray can be used as an input."""
# call checker function
chmap, landmap, basevalues, time_window, dim0 = mob.check_inputs(
chmap, landmap, basevalues, time_window, dim0 = check_inputs(
ch_arr, basevalues_idx=[0], window_idx=1,
landmap=land_arr)
# assert types
Expand All @@ -128,7 +135,7 @@ def test_check_ndarrays():
def test_check_basevalues_window():
"""Test that basevalues and window inputs work."""
# call checker function
chmap, landmap, basevalues, time_window, dim0 = mob.check_inputs(
chmap, landmap, basevalues, time_window, dim0 = check_inputs(
ch_arr, basevalues=[0], window=1, landmap=land_arr)
# assert types
assert dim0 == 'time'
Expand All @@ -141,7 +148,7 @@ def test_check_basevalues_window():
def test_check_input_nolandmask():
"""Test that the check input can run without a landmap."""
# call checker function
chmap, landmap, basevalues, time_window, dim0 = mob.check_inputs(
chmap, landmap, basevalues, time_window, dim0 = check_inputs(
chmask, basevalues_idx=[0], window_idx=1)
# assert types
assert dim0 == 'time'
Expand All @@ -158,7 +165,7 @@ def test_check_input_notbinary_chmap():
ch_nonbin[0, 1, 1] = 1
ch_nonbin[0, 1, 2] = 2
with pytest.raises(ValueError):
mob.check_inputs(ch_nonbin, basevalues_idx=[0], window_idx=1)
check_inputs(ch_nonbin, basevalues_idx=[0], window_idx=1)


@pytest.mark.xfail(reason='Removed binary check - to be added back later.')
Expand All @@ -169,104 +176,104 @@ def test_check_input_notbinary_landmap():
land_nonbin[0, 1, 2] = 2
ch_bin = np.zeros_like(land_nonbin)
with pytest.raises(ValueError):
mob.check_inputs(ch_bin, basevalues_idx=[0], window_idx=1,
check_inputs(ch_bin, basevalues_idx=[0], window_idx=1,
landmap=land_nonbin)


def test_check_input_invalid_chmap():
"""Test that an invalid channel map input will throw an error."""
with pytest.raises(TypeError):
mob.check_inputs('invalid', basevalues_idx=[0], window_idx=1,
check_inputs('invalid', basevalues_idx=[0], window_idx=1,
landmap=landmask)


def test_check_input_invalid_landmap():
"""Test that an invalid landmap will throw an error."""
with pytest.raises(TypeError):
mob.check_inputs(chmask, basevalues_idx=[0], window_idx=1,
check_inputs(chmask, basevalues_idx=[0], window_idx=1,
landmap='invalid')


def test_check_input_invalid_basevalues():
"""Test that a non-listable basevalues throws an error."""
with pytest.raises(TypeError):
mob.check_inputs(chmask, basevalues=0, window_idx='invalid')
check_inputs(chmask, basevalues=0, window_idx='invalid')


def test_check_input_invalid_basevalues_idx():
"""Test that a non-listable basevalues_idx throws an error."""
with pytest.raises(TypeError):
mob.check_inputs(chmask, basevalues_idx=0, window_idx='invalid')
check_inputs(chmask, basevalues_idx=0, window_idx='invalid')


def test_check_no_basevalues_error():
"""No basevalues will throw an error."""
with pytest.raises(ValueError):
mob.check_inputs(chmask, window_idx='invalid')
check_inputs(chmask, window_idx='invalid')


def test_check_input_invalid_time_window():
"""Test that a non-valid time_window throws an error."""
with pytest.raises(TypeError):
mob.check_inputs(chmask, basevalues_idx=[0], window='invalid')
check_inputs(chmask, basevalues_idx=[0], window='invalid')


def test_check_input_invalid_time_window_idx():
"""Test that a non-valid time_window_idx throws an error."""
with pytest.raises(TypeError):
mob.check_inputs(chmask, basevalues_idx=[0], window_idx='invalid')
check_inputs(chmask, basevalues_idx=[0], window_idx='invalid')


def test_check_no_time_window():
"""Test that no time_window throws an error."""
with pytest.raises(ValueError):
mob.check_inputs(chmask, basevalues_idx=[0])
check_inputs(chmask, basevalues_idx=[0])


def test_check_input_2dchanmask():
"""Test that an unexpected channel mask shape throws an error."""
with pytest.raises(TypeError):
mob.check_inputs(np.ones((5,)), basevalues_idx=[0], window_idx=1)
check_inputs(np.ones((5,)), basevalues_idx=[0], window_idx=1)


def test_check_input_diff_shapes():
"""Test that differently shaped channel and land masks throw an error."""
with pytest.raises(ValueError):
mob.check_inputs(chmask, basevalues_idx=[0], window_idx=1,
check_inputs(chmask, basevalues_idx=[0], window_idx=1,
landmap=np.ones((3, 3, 3)))


def test_check_input_1dlandmask():
"""Test a 1d landmask that will throw an error."""
with pytest.raises(TypeError):
mob.check_inputs(chmask, basevalues_idx=[0], window_idx=1,
check_inputs(chmask, basevalues_idx=[0], window_idx=1,
landmap=np.ones((10, 1)))


def test_check_input_exceedmaxvals():
"""Test a basevalue + time window combo that exceeds time indices."""
with pytest.raises(ValueError):
mob.check_inputs(chmask, basevalues_idx=[0], window_idx=100)
check_inputs(chmask, basevalues_idx=[0], window_idx=100)


def test_check_input_invalid_list():
"""Test a wrong list."""
with pytest.raises(TypeError):
mob.check_inputs(['str', 5, 1.], basevalues_idx=[0], window_idx=1)
check_inputs(['str', 5, 1.], basevalues_idx=[0], window_idx=1)


def test_check_input_list_wrong_shape():
"""Test list with wrongly shaped arrays."""
in_list = [np.zeros((5, 2, 1, 1)), np.zeros((5, 2, 2, 2))]
with pytest.raises(ValueError):
mob.check_inputs(in_list, basevalues_idx=[0], window_idx=100)
check_inputs(in_list, basevalues_idx=[0], window_idx=100)


@pytest.mark.xfail(
reason='Removed this functionality - do we want to blindly expand dims?')
def test_check_input_castlandmap():
"""Test ability to cast a 2D landmask to match 3D channelmap."""
chmap, landmap, bv, tw, dim0 = mob.check_inputs(
chmap, landmap, bv, tw, dim0 = check_inputs(
chmask, basevalues_idx=[0], window_idx=1,
landmap=landmask[0].mask[:, :])
assert np.shape(chmap) == np.shape(landmap)
Expand Down Expand Up @@ -312,37 +319,37 @@ def test_check_input_castlandmap():

def test_dry_decay():
"""Test dry fraction decay."""
dryfrac = mob.calculate_channel_decay(
dryfrac = calculate_channel_decay(
chmap, fsurf, basevalues_idx=basevalue, window_idx=time_window)
assert np.all(dryfrac == np.array([[0.75, 0.6875, 0.625, 0.5625, 0.5]]))


def test_planform_olap():
"""Test channel planform overlap."""
ophi = mob.calculate_planform_overlap(
ophi = calculate_planform_overlap(
chmap, fsurf, basevalues_idx=basevalue, window_idx=time_window)
assert pytest.approx(ophi.values) == np.array([[1., 0.66666667, 0.33333333,
0., -0.33333333]])


def test_reworking():
"""Test reworking index."""
fr = mob.calculate_reworking_fraction(
fr = calculate_reworking_fraction(
chmap, fsurf, basevalues_idx=basevalue, window_idx=time_window)
assert pytest.approx(fr.values) == np.array([[0., 0.08333333, 0.16666667,
0.25, 0.33333333]])


def test_channel_abandon():
"""Test channel abandonment function."""
ch_abandon = mob.calculate_channel_abandonment(
ch_abandon = calculate_channel_abandonment(
chmap, basevalues_idx=basevalue, window_idx=time_window)
assert np.all(ch_abandon == np.array([[0., 0.25, 0.5, 0.75, 1.]]))


def test_channel_presence():
"""Test channel presence with a regular array."""
chan_presence = mob.channel_presence(chmap)
chan_presence = channel_presence(chmap)
assert np.all(chan_presence == np.array([[0., 0.8, 0.2, 0.],
[0., 0.6, 0.4, 0.],
[0., 0.4, 0.6, 0.],
Expand All @@ -351,7 +358,7 @@ def test_channel_presence():

def test_channel_presence_xarray():
"""Test channel presence with an xarray."""
chan_presence = mob.channel_presence(chmap_xr)
chan_presence = channel_presence(chmap_xr)
assert np.all(chan_presence == np.array([[0., 0.8, 0.2, 0.],
[0., 0.6, 0.4, 0.],
[0., 0.4, 0.6, 0.],
Expand All @@ -360,7 +367,7 @@ def test_channel_presence_xarray():

def test_channel_presence_xarray_list():
"""Test channel presence with a list of xarrays."""
chan_presence = mob.channel_presence(chmap_xr_list)
chan_presence = channel_presence(chmap_xr_list)
assert np.all(chan_presence == np.array([[0., 0.8, 0.2, 0.],
[0., 0.6, 0.4, 0.],
[0., 0.4, 0.6, 0.],
Expand All @@ -369,7 +376,7 @@ def test_channel_presence_xarray_list():

def test_channel_presence_array_list():
"""Test channel presence with a list of arrays."""
chan_presence = mob.channel_presence(chmap_list)
chan_presence = channel_presence(chmap_list)
assert np.all(chan_presence == np.array([[0., 0.8, 0.2, 0.],
[0., 0.6, 0.4, 0.],
[0., 0.4, 0.6, 0.],
Expand All @@ -379,33 +386,33 @@ def test_channel_presence_array_list():
def test_invalid_list_channel_presence():
"""Test an invalid list."""
with pytest.raises(ValueError):
mob.channel_presence(['in', 'valid', 'list'])
channel_presence(['in', 'valid', 'list'])


def test_invalid_type_channel_presence():
"""Test an invalid input typing."""
with pytest.raises(TypeError):
mob.channel_presence('invalid type input')
channel_presence('invalid type input')


def test_calculate_noslopes():
"""Test calculate slopes w/ constants."""
arr = np.zeros((5, 3, 2))
slopes = mob._calculate_temporal_linear_slope(arr)
slopes = _calculate_temporal_linear_slope(arr)
assert np.all(slopes == 0.0)


def test_calculate_neg_slopes():
"""Test calculate slopes w/ negative values."""
arr = np.linspace(np.ones((5, 3)), np.zeros((5, 3)), 5)
slopes = mob._calculate_temporal_linear_slope(arr)
slopes = _calculate_temporal_linear_slope(arr)
assert np.all(slopes < 0.0)


def test_calculate_pos_slopes():
"""Test calculate slopes w/ positive values."""
arr = np.linspace(np.zeros((3, 2)), np.ones((3, 2)), 5)
slopes = mob._calculate_temporal_linear_slope(arr)
slopes = _calculate_temporal_linear_slope(arr)
assert np.all(slopes > 0.0)


Expand All @@ -416,10 +423,10 @@ def test_calculate_crv():
arr[0, :, :] = 1.0
arr[1, 0, :] = -5.0
# function
crv_mag, slopes, dir_crv = mob.calculate_channelized_response_variance(
crv_mag, slopes, dir_crv = calculate_channelized_response_variance(
arr, threshold=0.0)
assert np.all(crv_mag == np.abs(dir_crv))
assert np.all(slopes[0, :] < 0.0)
assert np.all(slopes[1:, :] > 0.0)
assert np.all(dir_crv[0, :] < 0.0)
assert np.all(dir_crv[1:, :] > 0.0)
assert np.all(dir_crv[1:, :] > 0.0)
9 changes: 4 additions & 5 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
import numpy as np

from deltametrics import utils
from deltametrics import mobility as mob
from deltametrics.mobility import calculate_channel_abandonment
from deltametrics import sample_data


class TestNoStratigraphyError:

def test_needs_obj_argument(self):
Expand Down Expand Up @@ -123,7 +122,7 @@ def test_bad_inputs(self):

def test_linear_fit():
"""Test linear curve fitting."""
ch_abandon = mob.calculate_channel_abandonment(
ch_abandon = calculate_channel_abandonment(
chmap, basevalues_idx=basevalue, window_idx=time_window)
yfit, popts, cov, err = utils.curve_fit(ch_abandon, fit='linear')
assert pytest.approx(yfit) == np.array([4.76315477e-24, 2.50000000e-01,
Expand All @@ -137,7 +136,7 @@ def test_linear_fit():

def test_harmonic_fit():
"""Test harmonic curve fitting."""
ch_abandon = mob.calculate_channel_abandonment(
ch_abandon = calculate_channel_abandonment(
chmap, basevalues_idx=basevalue, window_idx=time_window)
yfit, popts, cov, err = utils.curve_fit(ch_abandon, fit='harmonic')
assert pytest.approx(yfit) == np.array([-0.25986438, 0.41294455,
Expand All @@ -151,7 +150,7 @@ def test_harmonic_fit():

def test_invalid_fit():
"""Test invalid fit parameter."""
ch_abandon = mob.calculate_channel_abandonment(
ch_abandon = calculate_channel_abandonment(
chmap, basevalues_idx=basevalue, window_idx=time_window)
with pytest.raises(ValueError):
utils.curve_fit(ch_abandon, fit='invalid')
Expand Down

0 comments on commit 929eddb

Please sign in to comment.