From c08c55d24017a0d15ae29e4b5fcdebc1c0ea488f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Wed, 11 May 2022 12:39:34 +0200 Subject: [PATCH] ENH: improve error messages when add_field received a malformed function argument --- nose_unit.cfg | 2 +- tests/tests.yaml | 1 + yt/data_objects/tests/test_add_field.py | 103 ++++++++++++++++++ yt/fields/field_detector.py | 4 +- yt/fields/field_info_container.py | 26 +++++ yt/frontends/amrvac/fields.py | 1 - .../volume_rendering/off_axis_projection.py | 4 +- .../volume_rendering/old_camera.py | 6 +- 8 files changed, 139 insertions(+), 8 deletions(-) create mode 100644 yt/data_objects/tests/test_add_field.py diff --git a/nose_unit.cfg b/nose_unit.cfg index f67d56d6e4c..c717986eed9 100644 --- a/nose_unit.cfg +++ b/nose_unit.cfg @@ -6,5 +6,5 @@ nologcapture=1 verbosity=2 where=yt with-timer=1 -ignore-files=(test_load_errors.py|test_load_sample.py|test_commons.py|test_ambiguous_fields.py|test_field_access_pytest.py|test_save.py|test_line_annotation_unit.py|test_eps_writer.py|test_registration.py|test_invalid_origin.py|test_outputs_pytest\.py|test_normal_plot_api\.py|test_load_archive\.py|test_stream_particles\.py|test_file_sanitizer\.py) +ignore-files=(test_load_errors.py|test_load_sample.py|test_commons.py|test_ambiguous_fields.py|test_field_access_pytest.py|test_save.py|test_line_annotation_unit.py|test_eps_writer.py|test_registration.py|test_invalid_origin.py|test_outputs_pytest\.py|test_normal_plot_api\.py|test_load_archive\.py|test_stream_particles\.py|test_file_sanitizer\.py|test_add_field\.py) exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF diff --git a/tests/tests.yaml b/tests/tests.yaml index fcd810a3590..befaae62d00 100644 --- a/tests/tests.yaml +++ b/tests/tests.yaml @@ -201,6 +201,7 @@ other_tests: - "--ignore-files=test_outputs_pytest\\.py" - "--ignore-files=test_normal_plot_api\\.py" - "--ignore-file=test_file_sanitizer\\.py" + - "--ignore-file=test_add_field\\.py" - "--exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF" - "--exclude-test=yt.frontends.adaptahop.tests.test_outputs" - "--exclude-test=yt.frontends.stream.tests.test_stream_particles.test_stream_non_cartesian_particles" diff --git a/yt/data_objects/tests/test_add_field.py b/yt/data_objects/tests/test_add_field.py new file mode 100644 index 00000000000..f6cc5dafde5 --- /dev/null +++ b/yt/data_objects/tests/test_add_field.py @@ -0,0 +1,103 @@ +from functools import partial + +import pytest + +from yt.testing import fake_random_ds + + +def test_add_field_lambda(): + ds = fake_random_ds(16) + + ds.add_field( + ("bacon", "spam"), + lambda field, data: data["gas", "density"], + sampling_type="cell", + ) + + +def test_add_field_partial(): + ds = fake_random_ds(16) + + def _spam(field, data, factor): + return factor * data["gas", "density"] + + ds.add_field( + ("bacon", "spam"), + partial(_spam, factor=1), + sampling_type="cell", + ) + + +def test_add_field_arbitrary_callable(): + ds = fake_random_ds(16) + + class Spam: + def __call__(self, field, data): + return data["gas", "density"] + + ds.add_field(("bacon", "spam"), Spam(), sampling_type="cell") + + +def test_add_field_uncallable(): + ds = fake_random_ds(16) + + class Spam: + pass + + with pytest.raises( + TypeError, match=r"Expected a callable object, got .* with type .*" + ): + ds.add_field(("bacon", "spam"), Spam(), sampling_type="cell") + + +def test_add_field_wrong_signature(): + ds = fake_random_ds(16) + + def _spam(data, field): + return data["gas", "density"] + + with pytest.raises( + TypeError, + match=( + r"Received field function with invalid signature\. " + r"Expected exactly 2 positional parameters \('field', 'data'\), got \('data', 'field'\)" + ), + ): + ds.add_field(("bacon", "spam"), _spam, sampling_type="cell") + + +def test_add_field_wrong_signature_lambda(): + ds = fake_random_ds(16) + + with pytest.raises( + TypeError, + match=( + r"Received field function with invalid signature\. " + r"Expected exactly 2 positional parameters \('field', 'data'\), got \('data', 'field'\)" + ), + ): + ds.add_field( + ("bacon", "spam"), + lambda data, field: data["gas", "density"], + sampling_type="cell", + ) + + +def test_add_field_keyword_only(): + ds = fake_random_ds(16) + + def _spam(field, *, data): + return data["gas", "density"] + + with pytest.raises( + TypeError, + match=( + r"Received field function .* with invalid signature\. " + r"field and data parameters must accept positional values \(they cannot be keyword-only\)" + ), + ): + ds.add_field( + ("bacon", "spam"), + _spam, + sampling_type="cell", + ) diff --git a/yt/fields/field_detector.py b/yt/fields/field_detector.py index bebdb6f7be3..ce5ee6c5221 100644 --- a/yt/fields/field_detector.py +++ b/yt/fields/field_detector.py @@ -101,6 +101,8 @@ def _reshape_vals(self, arr): return arr.reshape(self.ActiveDimensions, order="C") def __missing__(self, item): + from yt.fields.derived_field import NullFunc + if not isinstance(item, tuple): field = ("unknown", item) else: @@ -115,7 +117,7 @@ def __missing__(self, item): # Note that the *only* way this works is if we also fix our field # dependencies during checking. Bug #627 talks about this. item = self.ds._last_freq - if finfo is not None and finfo._function.__name__ != "NullFunc": + if finfo is not None and finfo._function is not NullFunc: try: for param, param_v in permute_params.items(): for v in param_v: diff --git a/yt/fields/field_info_container.py b/yt/fields/field_info_container.py index fd0c318d6f3..3f63e65da7e 100644 --- a/yt/fields/field_info_container.py +++ b/yt/fields/field_info_container.py @@ -1,4 +1,6 @@ +import inspect from collections import defaultdict +from collections.abc import Callable from numbers import Number as numeric_type from typing import Optional, Tuple @@ -355,6 +357,30 @@ def create_function(f): return create_function + if not isinstance(function, Callable): + # this is compatible with lambdas and functools.partial objects + raise TypeError( + f"Expected a callable object, got {function} with type {type(function)}" + ) + + # lookup parameters that do not have default values + fparams = inspect.signature(function).parameters + nodefaults = tuple(p.name for p in fparams.values() if p.default is p.empty) + if nodefaults != ("field", "data"): + raise TypeError( + f"Received field function {function} with invalid signature. " + f"Expected exactly 2 positional parameters ('field', 'data'), got {nodefaults!r}" + ) + if any( + fparams[name].kind == fparams[name].KEYWORD_ONLY + for name in ("field", "data") + ): + raise TypeError( + f"Received field function {function} with invalid signature. " + "field and data parameters must accept positional values " + "(they cannot be keyword-only)" + ) + if isinstance(name, tuple): self[name] = DerivedField(name, sampling_type, function, **kwargs) return diff --git a/yt/frontends/amrvac/fields.py b/yt/frontends/amrvac/fields.py index 523b356a916..fde7f0f1fdd 100644 --- a/yt/frontends/amrvac/fields.py +++ b/yt/frontends/amrvac/fields.py @@ -101,7 +101,6 @@ def _setup_velocity_fields(self, idust=None): if not ("amrvac", "m%d%s" % (idir, dust_flag)) in self.field_list: break velocity_fn = functools.partial(_velocity, idir=idir, prefix=dust_label) - functools.update_wrapper(velocity_fn, _velocity) self.add_field( ("gas", f"{dust_label}velocity_{alias}"), function=velocity_fn, diff --git a/yt/visualization/volume_rendering/off_axis_projection.py b/yt/visualization/volume_rendering/off_axis_projection.py index 110a6f11fae..d3a4cc9fe3b 100644 --- a/yt/visualization/volume_rendering/off_axis_projection.py +++ b/yt/visualization/volume_rendering/off_axis_projection.py @@ -320,8 +320,8 @@ def off_axis_projection( weightfield = ("index", "temp_weightfield") def _make_wf(f, w): - def temp_weightfield(a, b): - tr = b[f].astype("float64") * b[w] + def temp_weightfield(field, data): + tr = data[f].astype("float64") * data[w] return tr.d return temp_weightfield diff --git a/yt/visualization/volume_rendering/old_camera.py b/yt/visualization/volume_rendering/old_camera.py index cc98517de42..d5c690081f7 100644 --- a/yt/visualization/volume_rendering/old_camera.py +++ b/yt/visualization/volume_rendering/old_camera.py @@ -2066,9 +2066,9 @@ def __init__( self.weightfield = ("index", "temp_weightfield_%u" % (id(self),)) def _make_wf(f, w): - def temp_weightfield(a, b): - tr = b[f].astype("float64") * b[w] - return b.apply_units(tr, a.units) + def temp_weightfield(field, data): + tr = data[f].astype("float64") * data[w] + return data.apply_units(tr, field.units) return temp_weightfield