Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: improve error messages when add_field receives a malformed function argument #3921

Merged
merged 5 commits into from
Jun 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nose_unit.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ nologcapture=1
verbosity=2
where=yt
with-timer=1
ignore-files=(test_load_errors.py|test_load_sample.py|test_commons.py|test_ambiguous_fields.py|test_field_access_pytest.py|test_save.py|test_line_annotation_unit.py|test_eps_writer.py|test_registration.py|test_invalid_origin.py|test_outputs_pytest\.py|test_normal_plot_api\.py|test_load_archive\.py|test_stream_particles\.py|test_file_sanitizer\.py|test_version\.py|\test_on_demand_imports\.py)
ignore-files=(test_load_errors.py|test_load_sample.py|test_commons.py|test_ambiguous_fields.py|test_field_access_pytest.py|test_save.py|test_line_annotation_unit.py|test_eps_writer.py|test_registration.py|test_invalid_origin.py|test_outputs_pytest\.py|test_normal_plot_api\.py|test_load_archive\.py|test_stream_particles\.py|test_file_sanitizer\.py|test_version\.py|\test_on_demand_imports\.py|test_add_field\.py)
exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF
1 change: 1 addition & 0 deletions tests/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ other_tests:
- "--ignore-files=test_normal_plot_api\\.py"
- "--ignore-file=test_file_sanitizer\\.py"
- "--ignore-files=test_version\\.py"
- "--ignore-file=test_add_field\\.py"
- "--exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF"
- "--exclude-test=yt.frontends.adaptahop.tests.test_outputs"
- "--exclude-test=yt.frontends.stream.tests.test_stream_particles.test_stream_non_cartesian_particles"
Expand Down
31 changes: 18 additions & 13 deletions yt/data_objects/selection_objects/data_selection_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from yt.fields.field_exceptions import NeedsGridType
from yt.funcs import fix_axis, is_sequence, iter_fields, validate_width_tuple
from yt.geometry.selection_routines import compose_selector
from yt.units import YTArray, dimensions as ytdims
from yt.units import YTArray
from yt.utilities.exceptions import (
GenerationInProgress,
YTBooleanObjectError,
Expand Down Expand Up @@ -242,22 +242,27 @@ def _generate_fields(self, fields_to_generate):
# field accesses
units = getattr(fd, "units", "")
if units == "":
dimensions = ytdims.dimensionless
sunits = ""
dimensions = 1
else:
dimensions = units.dimensions
units = str(
sunits = str(
units.get_base_equivalent(self.ds.unit_system.name)
)
if fi.dimensions != dimensions:
dimensions = units.dimensions

if fi.dimensions is None:
mylog.warning(
"Field %s was added without specifying units or dimensions, "
"auto setting units to %s",
fi.name,
sunits,
)
elif fi.dimensions != dimensions:
raise YTDimensionalityError(fi.dimensions, dimensions)
fi.units = units
fi.units = sunits
fi.dimensions = dimensions
self.field_data[field] = self.ds.arr(fd, units)
mylog.warning(
"Field %s was added without specifying units, "
"assuming units are %s",
fi.name,
units,
)

try:
fd.convert_to_units(fi.units)
except AttributeError:
Expand All @@ -266,7 +271,7 @@ def _generate_fields(self, fields_to_generate):
# supposed to be unitless
fd = self.ds.arr(fd, "")
if fi.units != "":
raise YTFieldUnitError(fi, fd.units)
raise YTFieldUnitError(fi, fd.units) from None
except UnitConversionError as e:
raise YTFieldUnitError(fi, fd.units) from e
except UnitParseError as e:
Expand Down
95 changes: 95 additions & 0 deletions yt/data_objects/tests/test_add_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from functools import partial

import pytest

from yt.testing import fake_random_ds


def test_add_field_lambda():
ds = fake_random_ds(16)

ds.add_field(
("gas", "spam"),
lambda field, data: data["gas", "density"],
sampling_type="cell",
)

# check access
ds.all_data()["gas", "spam"]


def test_add_field_partial():
ds = fake_random_ds(16)

def _spam(field, data, factor):
return factor * data["gas", "density"]

ds.add_field(
("gas", "spam"),
partial(_spam, factor=1),
sampling_type="cell",
)

# check access
ds.all_data()["gas", "spam"]


def test_add_field_arbitrary_callable():
ds = fake_random_ds(16)

class Spam:
def __call__(self, field, data):
return data["gas", "density"]

ds.add_field(("gas", "spam"), Spam(), sampling_type="cell")

# check access
ds.all_data()["gas", "spam"]


def test_add_field_uncallable():
ds = fake_random_ds(16)

class Spam:
pass

with pytest.raises(
TypeError, match=r"Expected a callable object, got .* with type .*"
):
ds.add_field(("bacon", "spam"), Spam(), sampling_type="cell")


def test_add_field_wrong_signature():
ds = fake_random_ds(16)

def _spam(data, field):
return data["gas", "density"]

with pytest.raises(
TypeError,
match=(
r"Received field function <function .*> with invalid signature\. "
r"Expected exactly 2 positional parameters \('field', 'data'\), got \('data', 'field'\)"
),
):
ds.add_field(("bacon", "spam"), _spam, sampling_type="cell")


def test_add_field_keyword_only():
ds = fake_random_ds(16)

def _spam(field, *, data):
return data["gas", "density"]

with pytest.raises(
TypeError,
match=(
r"Received field function .* with invalid signature\. "
r"Parameters 'field' and 'data' must accept positional values \(they cannot be keyword-only\)"
),
):
ds.add_field(
("bacon", "spam"),
_spam,
sampling_type="cell",
)
32 changes: 10 additions & 22 deletions yt/fields/derived_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,9 @@ class DerivedField:
arguments (field, data)
units : str
A plain text string encoding the unit, or a query to a unit system of
a dataset. Powers must be in python syntax (** instead of ^). If set
to "auto" the units will be inferred from the units of the return
value of the field function, and the dimensions keyword must also be
set (see below).
a dataset. Powers must be in Python syntax (** instead of ^). If set
to 'auto' or None (default), units will be inferred from the return value
of the field function.
take_log : bool
Describes whether the field should be logged
validators : list
Expand All @@ -94,8 +93,7 @@ class DerivedField:
fields or that get aliased to themselves, we can specify a different
desired output unit than the unit found on disk.
dimensions : str or object from yt.units.dimensions
The dimensions of the field, only needed if units="auto" and only used
for error checking.
The dimensions of the field, only used for error checking with units='auto'.
nodal_flag : array-like with three components
This describes how the field is centered within a cell. If nodal_flag
is [0, 0, 0], then the field is cell-centered. If any of the components
Expand Down Expand Up @@ -162,18 +160,10 @@ def __init__(

# handle units
self.units: Optional[Union[str, bytes, Unit]]
if units is None:
self.units = ""
if units in (None, "auto"):
self.units = None
elif isinstance(units, str):
if units.lower() == "auto":
if dimensions is None:
raise RuntimeError(
"To set units='auto', please specify the dimensions "
"of the field with dimensions=<dimensions of field>!"
)
self.units = None
else:
self.units = units
self.units = units
elif isinstance(units, Unit):
self.units = str(units)
elif isinstance(units, bytes):
Expand Down Expand Up @@ -332,8 +322,7 @@ def get_label(self, projected=False):

@property
def alias_field(self):
func_name = self._function.__name__
if func_name == "_TranslationFunc":
if getattr(self._function, "__name__", None) == "_TranslationFunc":
return True
return False

Expand All @@ -344,10 +333,9 @@ def alias_name(self):
return None

def __repr__(self):
func_name = self._function.__name__
if self._function == NullFunc:
if self._function is NullFunc:
s = "On-Disk Field "
elif func_name == "_TranslationFunc":
elif getattr(self._function, "__name__", None) == "_TranslationFunc":
s = f'Alias Field for "{self.alias_name}" '
else:
s = "Derived Field "
Expand Down
4 changes: 3 additions & 1 deletion yt/fields/field_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def _reshape_vals(self, arr):
return arr.reshape(self.ActiveDimensions, order="C")

def __missing__(self, item):
from yt.fields.derived_field import NullFunc

if not isinstance(item, tuple):
field = ("unknown", item)
else:
Expand All @@ -115,7 +117,7 @@ def __missing__(self, item):
# Note that the *only* way this works is if we also fix our field
# dependencies during checking. Bug #627 talks about this.
item = self.ds._last_freq
if finfo is not None and finfo._function.__name__ != "NullFunc":
if finfo is not None and finfo._function is not NullFunc:
try:
for param, param_v in permute_params.items():
for v in param_v:
Expand Down
54 changes: 44 additions & 10 deletions yt/fields/field_info_container.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import inspect
from collections import defaultdict
from collections.abc import Callable
from numbers import Number as numeric_type
from typing import Optional, Tuple

Expand Down Expand Up @@ -373,6 +375,30 @@ def create_function(f):

return create_function

if not isinstance(function, Callable):
# this is compatible with lambdas and functools.partial objects
raise TypeError(
f"Expected a callable object, got {function} with type {type(function)}"
)

# lookup parameters that do not have default values
fparams = inspect.signature(function).parameters
nodefaults = tuple(p.name for p in fparams.values() if p.default is p.empty)
if nodefaults != ("field", "data"):
raise TypeError(
f"Received field function {function} with invalid signature. "
f"Expected exactly 2 positional parameters ('field', 'data'), got {nodefaults!r}"
)
if any(
fparams[name].kind == fparams[name].KEYWORD_ONLY
for name in ("field", "data")
):
raise TypeError(
f"Received field function {function} with invalid signature. "
"Parameters 'field' and 'data' must accept positional values "
"(they cannot be keyword-only)"
)

if isinstance(name, tuple):
self[name] = DerivedField(name, sampling_type, function, **kwargs)
return
Expand Down Expand Up @@ -434,25 +460,25 @@ def __setitem__(self, key, value):

def alias(
self,
alias_name,
original_name,
units=None,
alias_name: Tuple[str, str],
original_name: Tuple[str, str],
units: Optional[str] = None,
deprecate: Optional[Tuple[str, str]] = None,
):
"""
Alias one field to another field.

Parameters
----------
alias_name : Tuple[str]
alias_name : Tuple[str, str]
The new field name.
original_name : Tuple[str]
original_name : Tuple[str, str]
The field to be aliased.
units : str
A plain text string encoding the unit. Powers must be in
python syntax (** instead of ^). If set to "auto" the units
will be inferred from the return value of the field function.
deprecate : Tuple[str], optional
deprecate : Tuple[str, str], optional
If this is set, then the tuple contains two string version
numbers: the first marking the version when the field was
deprecated, and the second marking when the field will be
Expand All @@ -463,11 +489,19 @@ def alias(
if units is None:
# We default to CGS here, but in principle, this can be pluggable
# as well.
u = Unit(self[original_name].units, registry=self.ds.unit_registry)
if u.dimensions is not dimensionless:
units = str(self.ds.unit_system[u.dimensions])

# self[original_name].units may be set to `None` at this point
# to signal that units should be autoset later
oru = self[original_name].units
if oru is None:
units = None
else:
units = self[original_name].units
u = Unit(oru, registry=self.ds.unit_registry)
if u.dimensions is not dimensionless:
units = str(self.ds.unit_system[u.dimensions])
else:
units = oru

self.field_aliases[alias_name] = original_name
function = TranslationFunc(original_name)
if deprecate is not None:
Expand Down
4 changes: 0 additions & 4 deletions yt/fields/tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,9 +311,6 @@ def density_alias(field, data):
def unitless_data(field, data):
return np.ones(data[("gas", "density")].shape)

ds.add_field(
("gas", "density_alias_no_units"), sampling_type="cell", function=density_alias
)
ds.add_field(
("gas", "density_alias_auto"),
sampling_type="cell",
Expand All @@ -340,7 +337,6 @@ def unitless_data(field, data):
units="auto",
dimensions="temperature",
)
assert_raises(YTFieldUnitError, get_data, ds, ("gas", "density_alias_no_units"))
assert_raises(YTFieldUnitError, get_data, ds, ("gas", "density_alias_wrong_units"))
assert_raises(
YTFieldUnitParseError, get_data, ds, ("gas", "density_alias_unparseable_units")
Expand Down
1 change: 0 additions & 1 deletion yt/frontends/amrvac/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ def _setup_velocity_fields(self, idust=None):
if not ("amrvac", "m%d%s" % (idir, dust_flag)) in self.field_list:
break
velocity_fn = functools.partial(_velocity, idir=idir, prefix=dust_label)
functools.update_wrapper(velocity_fn, _velocity)
self.add_field(
("gas", f"{dust_label}velocity_{alias}"),
function=velocity_fn,
Expand Down
1 change: 1 addition & 0 deletions yt/frontends/ramses/tests/test_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,7 @@ def dummy(field, data):
fname,
gen_dummy(ngz),
sampling_type="cell",
units="",
validators=[yt.ValidateSpatial(ghost_zones=ngz)],
)
fields.append(fname)
Expand Down
4 changes: 2 additions & 2 deletions yt/visualization/volume_rendering/off_axis_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,8 @@ def off_axis_projection(
weightfield = ("index", "temp_weightfield")

def _make_wf(f, w):
def temp_weightfield(a, b):
tr = b[f].astype("float64") * b[w]
def temp_weightfield(field, data):
tr = data[f].astype("float64") * data[w]
return tr.d

return temp_weightfield
Expand Down
Loading