diff --git a/pixi.lock b/pixi.lock index e513fa8..208d2c0 100644 --- a/pixi.lock +++ b/pixi.lock @@ -12927,7 +12927,7 @@ packages: name: xmipy version: 1.5.0 path: . - sha256: 65f685f0744f869687a292d6c3f96b3f1e6a818f78c3f8fa38fc14d7597f8ba7 + sha256: d6029626bb13ed898f606e430303f475b1f80e106297c659c227bd2e7f79ff0e requires_dist: - bmipy - numpy diff --git a/pyproject.toml b/pyproject.toml index e697c73..5ec2c23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,7 @@ py310 = ["py310"] py39 = ["py39"] [tool.pytest.ini_options] -addopts = "-vs" +addopts = "-v" testpaths = ["tests"] [tool.mypy] @@ -103,7 +103,7 @@ warn_unused_configs = true warn_redundant_casts = true warn_unused_ignores = true strict_equality = true -strict_concatenate = true +extra_checks = true check_untyped_defs = true disallow_untyped_decorators = true disallow_any_generics = true diff --git a/tests/conftest.py b/tests/conftest.py index ef408c2..c2f82e4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -83,7 +83,11 @@ def flopy_dis(tmp_path): @pytest.fixture def flopy_dis_mf6(flopy_dis, modflow_lib_path, request): - mf6 = XmiWrapper(lib_path=modflow_lib_path, working_directory=flopy_dis.sim_path) + mf6 = XmiWrapper( + lib_path=modflow_lib_path, + working_directory=flopy_dis.sim_path, + logger_level="DEBUG", + ) # If initialized, call finalize() at end of use request.addfinalizer(mf6.__del__) diff --git a/tests/test_mf6_dis_bmi.py b/tests/test_mf6_dis_bmi.py index a3eb312..17fc1b8 100644 --- a/tests/test_mf6_dis_bmi.py +++ b/tests/test_mf6_dis_bmi.py @@ -5,6 +5,11 @@ from xmipy.errors import InputError +def mf6_version_tuple(version: str) -> tuple[int, int, int]: + """Convert a version string into tuple, removing ".dev0" if present.""" + return tuple(map(int, version.split(".")[:3])) + + @pytest.fixture def flopy_dis_idomain_mf6(flopy_dis_idomain, modflow_lib_path, request): mf6 = XmiWrapper( @@ -81,20 +86,20 @@ def test_update_and_get_current_time(flopy_dis_mf6): assert mf6.get_current_time() == 3.0 -def test_get_var_type_double(flopy_dis_mf6): - mf6 = flopy_dis_mf6[1] - mf6.initialize() - - head_tag = mf6.get_var_address("X", "SLN_1") - assert mf6.get_var_type(head_tag) == "DOUBLE (90)" - - -def test_get_var_type_int(flopy_dis_mf6): +@pytest.mark.parametrize( + "addr,expected", + [ + (("X", "SLN_1"), "DOUBLE (90)"), + (("IACTIVE", "SLN_1"), "INTEGER (90)"), + (("ENDOFSIMULATION", "TDIS"), "LOGICAL"), + ], +) +def test_get_var_type(flopy_dis_mf6, addr, expected): mf6 = flopy_dis_mf6[1] mf6.initialize() - iactive_tag = mf6.get_var_address("IACTIVE", "SLN_1") - assert mf6.get_var_type(iactive_tag) == "INTEGER (90)" + addr_tag = mf6.get_var_address(*addr) + assert mf6.get_var_type(addr_tag) == expected def test_get_var_string(flopy_dis_mf6): @@ -166,6 +171,25 @@ def test_get_value_ptr_scalar(flopy_dis_mf6): assert grid_id[0] == 1 +def test_get_value_ptr_scalar_bool(flopy_dis_mf6): + """Test get_value_ptr_scalar with bool (LOGICAL) type.""" + mf6 = flopy_dis_mf6[1] + if mf6_version_tuple(mf6.get_version()) < (6, 5, 0): + pytest.skip("modflow-6.5.0 or later needed") + mf6.initialize() + + # Get end of simulation data + eos_tag = mf6.get_var_address("ENDOFSIMULATION", "TDIS") + eos_value = mf6.get_value_ptr_scalar(eos_tag) + np.testing.assert_array_equal(eos_value, [False]) + + # Change value, then check to see if it changed in lib via get_value + eos_value[0] = True + np.testing.assert_array_equal(mf6.get_value(eos_tag), [True]) + eos_value[0] = False + np.testing.assert_array_equal(mf6.get_value(eos_tag), [False]) + + def test_get_var_grid(flopy_dis_mf6): flopy_dis, mf6 = flopy_dis_mf6 mf6.initialize() @@ -309,6 +333,26 @@ def test_get_value_int_scalar(flopy_dis_idomain_mf6): ) +def test_get_value_bool(flopy_dis_idomain_mf6): + """Test get_value with bool (LOGICAL) type.""" + mf6 = flopy_dis_idomain_mf6[1] + if mf6_version_tuple(mf6.get_version()) < (6, 5, 0): + pytest.skip("modflow-6.5.0 or later needed") + mf6.initialize() + + # get scalar variable: + eos_tag = mf6.get_var_address("ENDOFSIMULATION", "TDIS") + assert mf6.get_var_rank(eos_tag) == 0 + + # Run each time step and check value + end_time = mf6.get_end_time() + while mf6.get_current_time() < end_time: + np.testing.assert_array_equal(mf6.get_value(eos_tag), [False]) + mf6.update() + # Check that simulation has ended + np.testing.assert_array_equal(mf6.get_value(eos_tag), [True]) + + def test_get_value_at_indices(flopy_dis_idomain_mf6): """Expects to be implemented as soon as `get_value_at_indices` is implemented""" mf6 = flopy_dis_idomain_mf6[1] @@ -337,6 +381,26 @@ def test_set_value(flopy_dis_mf6): assert mf6.get_value(mxit_tag).tolist() == [999] +def test_set_value_bool(flopy_dis_idomain_mf6): + """Test set_value with bool (LOGICAL) type.""" + mf6 = flopy_dis_idomain_mf6[1] + if mf6_version_tuple(mf6.get_version()) < (6, 5, 0): + pytest.skip("modflow-6.5.0 or later needed") + mf6.initialize() + + # Toggle end of simulation data + eos_tag = mf6.get_var_address("ENDOFSIMULATION", "TDIS") + np.testing.assert_array_equal(mf6.get_value(eos_tag), [False]) + mf6.set_value(eos_tag, np.array([True])) + np.testing.assert_array_equal(mf6.get_value(eos_tag), [True]) + mf6.set_value(eos_tag, np.array([False])) + np.testing.assert_array_equal(mf6.get_value(eos_tag), [False]) + + # Check wrong dtype + with pytest.raises(InputError): + mf6.set_value(eos_tag, np.array([0])) + + def test_set_value_at_indices(flopy_dis_mf6): """Expects to be implemented as soon as `set_value_at_indices` is implemented""" mf6 = flopy_dis_mf6[1] diff --git a/xmipy/xmiwrapper.py b/xmipy/xmiwrapper.py index 71fb56c..6c272ce 100644 --- a/xmipy/xmiwrapper.py +++ b/xmipy/xmiwrapper.py @@ -5,6 +5,7 @@ CDLL, POINTER, byref, + c_bool, c_char, c_char_p, c_double, @@ -430,26 +431,25 @@ def get_value_ptr(self, name: str) -> NDArray[Any]: var_type = self.get_var_type(name) var_type_lower = var_type.lower() - shape_array = self.get_var_shape(name) - - # convert shape array to python tuple - shape_tuple = tuple(np.trim_zeros(shape_array)) - ndim = len(shape_tuple) - if var_type_lower.startswith("double"): - arraytype = np.ctypeslib.ndpointer( - dtype=np.float64, ndim=ndim, shape=shape_tuple, flags="C" - ) + dtype: Any = np.float64 elif var_type_lower.startswith("float"): - arraytype = np.ctypeslib.ndpointer( - dtype=np.float32, ndim=ndim, shape=shape_tuple, flags="C" - ) + dtype = np.float32 elif var_type_lower.startswith("int"): - arraytype = np.ctypeslib.ndpointer( - dtype=np.int32, ndim=ndim, shape=shape_tuple, flags="C" - ) + dtype = np.int32 + elif var_type_lower.startswith("logical"): + dtype = bool else: raise InputError(f"Unsupported value type {var_type!r}") + + # convert shape array to python tuple + shape_array = self.get_var_shape(name) + shape_tuple = tuple(np.trim_zeros(shape_array)) + ndim = len(shape_tuple) + + arraytype = np.ctypeslib.ndpointer( + dtype=dtype, ndim=ndim, shape=shape_tuple, flags="C" + ) values = arraytype() self._execute_function( self.lib.get_value_ptr, @@ -457,25 +457,22 @@ def get_value_ptr(self, name: str) -> NDArray[Any]: byref(values), detail="for variable " + name, ) - return values.contents + return values.contents # type: ignore def get_value_ptr_scalar(self, name: str) -> NDArray[Any]: var_type = self.get_var_type(name) var_type_lower = var_type.lower() if var_type_lower.startswith("double"): - arraytype = np.ctypeslib.ndpointer( - dtype=np.float64, ndim=1, shape=(1,), flags="C" - ) + dtype: Any = np.float64 elif var_type_lower.startswith("float"): - arraytype = np.ctypeslib.ndpointer( - dtype=np.float32, ndim=1, shape=(1,), flags="C" - ) + dtype = np.float32 elif var_type_lower.startswith("int"): - arraytype = np.ctypeslib.ndpointer( - dtype=np.int32, ndim=1, shape=(1,), flags="C" - ) + dtype = np.int32 + elif var_type_lower.startswith("logical"): + dtype = bool else: raise InputError(f"Unsupported value type {var_type!r}") + arraytype = np.ctypeslib.ndpointer(dtype=dtype, ndim=1, shape=(1,), flags="C") values = arraytype() self._execute_function( self.lib.get_value_ptr, @@ -483,7 +480,7 @@ def get_value_ptr_scalar(self, name: str) -> NDArray[Any]: byref(values), detail="for variable " + name, ) - return values.contents + return values.contents # type: ignore def get_value_at_indices( self, name: str, dest: NDArray[Any], inds: NDArray[np.int32] @@ -511,6 +508,14 @@ def set_value(self, name: str, values: NDArray[Any]) -> None: c_char_p(name.encode()), byref(values.ctypes.data_as(POINTER(c_int))), ) + elif var_type_lower.startswith("logical"): + if values.dtype != bool: + raise InputError("Array should have bool elements") + self._execute_function( + self.lib.set_value_bool, + c_char_p(name.encode()), + byref(values.ctypes.data_as(POINTER(c_bool))), + ) else: raise InputError("Unsupported value type")