Skip to content

Commit

Permalink
ENH: improve error messages when add_field received a malformed funct…
Browse files Browse the repository at this point in the history
…ion argument
  • Loading branch information
neutrinoceros committed May 12, 2022
1 parent c1433df commit c08c55d
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 8 deletions.
2 changes: 1 addition & 1 deletion nose_unit.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ nologcapture=1
verbosity=2
where=yt
with-timer=1
ignore-files=(test_load_errors.py|test_load_sample.py|test_commons.py|test_ambiguous_fields.py|test_field_access_pytest.py|test_save.py|test_line_annotation_unit.py|test_eps_writer.py|test_registration.py|test_invalid_origin.py|test_outputs_pytest\.py|test_normal_plot_api\.py|test_load_archive\.py|test_stream_particles\.py|test_file_sanitizer\.py)
ignore-files=(test_load_errors.py|test_load_sample.py|test_commons.py|test_ambiguous_fields.py|test_field_access_pytest.py|test_save.py|test_line_annotation_unit.py|test_eps_writer.py|test_registration.py|test_invalid_origin.py|test_outputs_pytest\.py|test_normal_plot_api\.py|test_load_archive\.py|test_stream_particles\.py|test_file_sanitizer\.py|test_add_field\.py)
exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF
1 change: 1 addition & 0 deletions tests/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ other_tests:
- "--ignore-files=test_outputs_pytest\\.py"
- "--ignore-files=test_normal_plot_api\\.py"
- "--ignore-file=test_file_sanitizer\\.py"
- "--ignore-file=test_add_field\\.py"
- "--exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF"
- "--exclude-test=yt.frontends.adaptahop.tests.test_outputs"
- "--exclude-test=yt.frontends.stream.tests.test_stream_particles.test_stream_non_cartesian_particles"
Expand Down
103 changes: 103 additions & 0 deletions yt/data_objects/tests/test_add_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from functools import partial

import pytest

from yt.testing import fake_random_ds


def test_add_field_lambda():
ds = fake_random_ds(16)

ds.add_field(
("bacon", "spam"),
lambda field, data: data["gas", "density"],
sampling_type="cell",
)


def test_add_field_partial():
ds = fake_random_ds(16)

def _spam(field, data, factor):
return factor * data["gas", "density"]

ds.add_field(
("bacon", "spam"),
partial(_spam, factor=1),
sampling_type="cell",
)


def test_add_field_arbitrary_callable():
ds = fake_random_ds(16)

class Spam:
def __call__(self, field, data):
return data["gas", "density"]

ds.add_field(("bacon", "spam"), Spam(), sampling_type="cell")


def test_add_field_uncallable():
ds = fake_random_ds(16)

class Spam:
pass

with pytest.raises(
TypeError, match=r"Expected a callable object, got .* with type .*"
):
ds.add_field(("bacon", "spam"), Spam(), sampling_type="cell")


def test_add_field_wrong_signature():
ds = fake_random_ds(16)

def _spam(data, field):
return data["gas", "density"]

with pytest.raises(
TypeError,
match=(
r"Received field function <function .*> with invalid signature\. "
r"Expected exactly 2 positional parameters \('field', 'data'\), got \('data', 'field'\)"
),
):
ds.add_field(("bacon", "spam"), _spam, sampling_type="cell")


def test_add_field_wrong_signature_lambda():
ds = fake_random_ds(16)

with pytest.raises(
TypeError,
match=(
r"Received field function <function .*> with invalid signature\. "
r"Expected exactly 2 positional parameters \('field', 'data'\), got \('data', 'field'\)"
),
):
ds.add_field(
("bacon", "spam"),
lambda data, field: data["gas", "density"],
sampling_type="cell",
)


def test_add_field_keyword_only():
ds = fake_random_ds(16)

def _spam(field, *, data):
return data["gas", "density"]

with pytest.raises(
TypeError,
match=(
r"Received field function .* with invalid signature\. "
r"field and data parameters must accept positional values \(they cannot be keyword-only\)"
),
):
ds.add_field(
("bacon", "spam"),
_spam,
sampling_type="cell",
)
4 changes: 3 additions & 1 deletion yt/fields/field_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def _reshape_vals(self, arr):
return arr.reshape(self.ActiveDimensions, order="C")

def __missing__(self, item):
from yt.fields.derived_field import NullFunc

if not isinstance(item, tuple):
field = ("unknown", item)
else:
Expand All @@ -115,7 +117,7 @@ def __missing__(self, item):
# Note that the *only* way this works is if we also fix our field
# dependencies during checking. Bug #627 talks about this.
item = self.ds._last_freq
if finfo is not None and finfo._function.__name__ != "NullFunc":
if finfo is not None and finfo._function is not NullFunc:
try:
for param, param_v in permute_params.items():
for v in param_v:
Expand Down
26 changes: 26 additions & 0 deletions yt/fields/field_info_container.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import inspect
from collections import defaultdict
from collections.abc import Callable
from numbers import Number as numeric_type
from typing import Optional, Tuple

Expand Down Expand Up @@ -355,6 +357,30 @@ def create_function(f):

return create_function

if not isinstance(function, Callable):
# this is compatible with lambdas and functools.partial objects
raise TypeError(
f"Expected a callable object, got {function} with type {type(function)}"
)

# lookup parameters that do not have default values
fparams = inspect.signature(function).parameters
nodefaults = tuple(p.name for p in fparams.values() if p.default is p.empty)
if nodefaults != ("field", "data"):
raise TypeError(
f"Received field function {function} with invalid signature. "
f"Expected exactly 2 positional parameters ('field', 'data'), got {nodefaults!r}"
)
if any(
fparams[name].kind == fparams[name].KEYWORD_ONLY
for name in ("field", "data")
):
raise TypeError(
f"Received field function {function} with invalid signature. "
"field and data parameters must accept positional values "
"(they cannot be keyword-only)"
)

if isinstance(name, tuple):
self[name] = DerivedField(name, sampling_type, function, **kwargs)
return
Expand Down
1 change: 0 additions & 1 deletion yt/frontends/amrvac/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ def _setup_velocity_fields(self, idust=None):
if not ("amrvac", "m%d%s" % (idir, dust_flag)) in self.field_list:
break
velocity_fn = functools.partial(_velocity, idir=idir, prefix=dust_label)
functools.update_wrapper(velocity_fn, _velocity)
self.add_field(
("gas", f"{dust_label}velocity_{alias}"),
function=velocity_fn,
Expand Down
4 changes: 2 additions & 2 deletions yt/visualization/volume_rendering/off_axis_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,8 @@ def off_axis_projection(
weightfield = ("index", "temp_weightfield")

def _make_wf(f, w):
def temp_weightfield(a, b):
tr = b[f].astype("float64") * b[w]
def temp_weightfield(field, data):
tr = data[f].astype("float64") * data[w]
return tr.d

return temp_weightfield
Expand Down
6 changes: 3 additions & 3 deletions yt/visualization/volume_rendering/old_camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -2066,9 +2066,9 @@ def __init__(
self.weightfield = ("index", "temp_weightfield_%u" % (id(self),))

def _make_wf(f, w):
def temp_weightfield(a, b):
tr = b[f].astype("float64") * b[w]
return b.apply_units(tr, a.units)
def temp_weightfield(field, data):
tr = data[f].astype("float64") * data[w]
return data.apply_units(tr, field.units)

return temp_weightfield

Expand Down

0 comments on commit c08c55d

Please sign in to comment.