diff --git a/nose_unit.cfg b/nose_unit.cfg index 58124fbd08c..0b29a9dc2a0 100644 --- a/nose_unit.cfg +++ b/nose_unit.cfg @@ -6,5 +6,5 @@ nologcapture=1 verbosity=2 where=yt with-timer=1 -ignore-files=(test_load_errors.py|test_load_sample.py|test_commons.py|test_ambiguous_fields.py|test_field_access_pytest.py|test_save.py|test_line_annotation_unit.py|test_eps_writer.py|test_registration.py|test_invalid_origin.py|test_outputs_pytest\.py|test_normal_plot_api\.py|test_load_archive\.py|test_stream_particles\.py|test_file_sanitizer\.py|test_version\.py|\test_on_demand_imports\.py) +ignore-files=(test_load_errors.py|test_load_sample.py|test_commons.py|test_ambiguous_fields.py|test_field_access_pytest.py|test_save.py|test_line_annotation_unit.py|test_eps_writer.py|test_registration.py|test_invalid_origin.py|test_outputs_pytest\.py|test_normal_plot_api\.py|test_load_archive\.py|test_stream_particles\.py|test_file_sanitizer\.py|test_version\.py|\test_on_demand_imports\.py|test_add_field\.py) exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF diff --git a/tests/tests.yaml b/tests/tests.yaml index 6028cb33506..af9a72d05d3 100644 --- a/tests/tests.yaml +++ b/tests/tests.yaml @@ -203,6 +203,7 @@ other_tests: - "--ignore-files=test_normal_plot_api\\.py" - "--ignore-file=test_file_sanitizer\\.py" - "--ignore-files=test_version\\.py" + - "--ignore-file=test_add_field\\.py" - "--exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF" - "--exclude-test=yt.frontends.adaptahop.tests.test_outputs" - "--exclude-test=yt.frontends.stream.tests.test_stream_particles.test_stream_non_cartesian_particles" diff --git a/yt/data_objects/selection_objects/data_selection_objects.py b/yt/data_objects/selection_objects/data_selection_objects.py index f7fd9d8e25f..a4cc4fe612e 100644 --- a/yt/data_objects/selection_objects/data_selection_objects.py +++ b/yt/data_objects/selection_objects/data_selection_objects.py @@ -15,7 +15,7 @@ from yt.fields.field_exceptions import NeedsGridType from yt.funcs import fix_axis, is_sequence, iter_fields, validate_width_tuple from yt.geometry.selection_routines import compose_selector -from yt.units import YTArray, dimensions as ytdims +from yt.units import YTArray from yt.utilities.exceptions import ( GenerationInProgress, YTBooleanObjectError, @@ -242,22 +242,27 @@ def _generate_fields(self, fields_to_generate): # field accesses units = getattr(fd, "units", "") if units == "": - dimensions = ytdims.dimensionless + sunits = "" + dimensions = 1 else: - dimensions = units.dimensions - units = str( + sunits = str( units.get_base_equivalent(self.ds.unit_system.name) ) - if fi.dimensions != dimensions: + dimensions = units.dimensions + + if fi.dimensions is None: + mylog.warning( + "Field %s was added without specifying units or dimensions, " + "auto setting units to %s", + fi.name, + sunits, + ) + elif fi.dimensions != dimensions: raise YTDimensionalityError(fi.dimensions, dimensions) - fi.units = units + fi.units = sunits + fi.dimensions = dimensions self.field_data[field] = self.ds.arr(fd, units) - mylog.warning( - "Field %s was added without specifying units, " - "assuming units are %s", - fi.name, - units, - ) + try: fd.convert_to_units(fi.units) except AttributeError: @@ -266,7 +271,7 @@ def _generate_fields(self, fields_to_generate): # supposed to be unitless fd = self.ds.arr(fd, "") if fi.units != "": - raise YTFieldUnitError(fi, fd.units) + raise YTFieldUnitError(fi, fd.units) from None except UnitConversionError as e: raise YTFieldUnitError(fi, fd.units) from e except UnitParseError as e: diff --git a/yt/data_objects/tests/test_add_field.py b/yt/data_objects/tests/test_add_field.py new file mode 100644 index 00000000000..74485357e98 --- /dev/null +++ b/yt/data_objects/tests/test_add_field.py @@ -0,0 +1,95 @@ +from functools import partial + +import pytest + +from yt.testing import fake_random_ds + + +def test_add_field_lambda(): + ds = fake_random_ds(16) + + ds.add_field( + ("gas", "spam"), + lambda field, data: data["gas", "density"], + sampling_type="cell", + ) + + # check access + ds.all_data()["gas", "spam"] + + +def test_add_field_partial(): + ds = fake_random_ds(16) + + def _spam(field, data, factor): + return factor * data["gas", "density"] + + ds.add_field( + ("gas", "spam"), + partial(_spam, factor=1), + sampling_type="cell", + ) + + # check access + ds.all_data()["gas", "spam"] + + +def test_add_field_arbitrary_callable(): + ds = fake_random_ds(16) + + class Spam: + def __call__(self, field, data): + return data["gas", "density"] + + ds.add_field(("gas", "spam"), Spam(), sampling_type="cell") + + # check access + ds.all_data()["gas", "spam"] + + +def test_add_field_uncallable(): + ds = fake_random_ds(16) + + class Spam: + pass + + with pytest.raises( + TypeError, match=r"Expected a callable object, got .* with type .*" + ): + ds.add_field(("bacon", "spam"), Spam(), sampling_type="cell") + + +def test_add_field_wrong_signature(): + ds = fake_random_ds(16) + + def _spam(data, field): + return data["gas", "density"] + + with pytest.raises( + TypeError, + match=( + r"Received field function with invalid signature\. " + r"Expected exactly 2 positional parameters \('field', 'data'\), got \('data', 'field'\)" + ), + ): + ds.add_field(("bacon", "spam"), _spam, sampling_type="cell") + + +def test_add_field_keyword_only(): + ds = fake_random_ds(16) + + def _spam(field, *, data): + return data["gas", "density"] + + with pytest.raises( + TypeError, + match=( + r"Received field function .* with invalid signature\. " + r"Parameters 'field' and 'data' must accept positional values \(they cannot be keyword-only\)" + ), + ): + ds.add_field( + ("bacon", "spam"), + _spam, + sampling_type="cell", + ) diff --git a/yt/fields/derived_field.py b/yt/fields/derived_field.py index 32909caf8e0..d6711d741d1 100644 --- a/yt/fields/derived_field.py +++ b/yt/fields/derived_field.py @@ -69,10 +69,9 @@ class DerivedField: arguments (field, data) units : str A plain text string encoding the unit, or a query to a unit system of - a dataset. Powers must be in python syntax (** instead of ^). If set - to "auto" the units will be inferred from the units of the return - value of the field function, and the dimensions keyword must also be - set (see below). + a dataset. Powers must be in Python syntax (** instead of ^). If set + to 'auto' or None (default), units will be inferred from the return value + of the field function. take_log : bool Describes whether the field should be logged validators : list @@ -94,8 +93,7 @@ class DerivedField: fields or that get aliased to themselves, we can specify a different desired output unit than the unit found on disk. dimensions : str or object from yt.units.dimensions - The dimensions of the field, only needed if units="auto" and only used - for error checking. + The dimensions of the field, only used for error checking with units='auto'. nodal_flag : array-like with three components This describes how the field is centered within a cell. If nodal_flag is [0, 0, 0], then the field is cell-centered. If any of the components @@ -162,18 +160,10 @@ def __init__( # handle units self.units: Optional[Union[str, bytes, Unit]] - if units is None: - self.units = "" + if units in (None, "auto"): + self.units = None elif isinstance(units, str): - if units.lower() == "auto": - if dimensions is None: - raise RuntimeError( - "To set units='auto', please specify the dimensions " - "of the field with dimensions=!" - ) - self.units = None - else: - self.units = units + self.units = units elif isinstance(units, Unit): self.units = str(units) elif isinstance(units, bytes): @@ -332,8 +322,7 @@ def get_label(self, projected=False): @property def alias_field(self): - func_name = self._function.__name__ - if func_name == "_TranslationFunc": + if getattr(self._function, "__name__", None) == "_TranslationFunc": return True return False @@ -344,10 +333,9 @@ def alias_name(self): return None def __repr__(self): - func_name = self._function.__name__ - if self._function == NullFunc: + if self._function is NullFunc: s = "On-Disk Field " - elif func_name == "_TranslationFunc": + elif getattr(self._function, "__name__", None) == "_TranslationFunc": s = f'Alias Field for "{self.alias_name}" ' else: s = "Derived Field " diff --git a/yt/fields/field_detector.py b/yt/fields/field_detector.py index bebdb6f7be3..ce5ee6c5221 100644 --- a/yt/fields/field_detector.py +++ b/yt/fields/field_detector.py @@ -101,6 +101,8 @@ def _reshape_vals(self, arr): return arr.reshape(self.ActiveDimensions, order="C") def __missing__(self, item): + from yt.fields.derived_field import NullFunc + if not isinstance(item, tuple): field = ("unknown", item) else: @@ -115,7 +117,7 @@ def __missing__(self, item): # Note that the *only* way this works is if we also fix our field # dependencies during checking. Bug #627 talks about this. item = self.ds._last_freq - if finfo is not None and finfo._function.__name__ != "NullFunc": + if finfo is not None and finfo._function is not NullFunc: try: for param, param_v in permute_params.items(): for v in param_v: diff --git a/yt/fields/field_info_container.py b/yt/fields/field_info_container.py index cf891eb24f6..5a73ea97084 100644 --- a/yt/fields/field_info_container.py +++ b/yt/fields/field_info_container.py @@ -1,4 +1,6 @@ +import inspect from collections import defaultdict +from collections.abc import Callable from numbers import Number as numeric_type from typing import Optional, Tuple @@ -373,6 +375,30 @@ def create_function(f): return create_function + if not isinstance(function, Callable): + # this is compatible with lambdas and functools.partial objects + raise TypeError( + f"Expected a callable object, got {function} with type {type(function)}" + ) + + # lookup parameters that do not have default values + fparams = inspect.signature(function).parameters + nodefaults = tuple(p.name for p in fparams.values() if p.default is p.empty) + if nodefaults != ("field", "data"): + raise TypeError( + f"Received field function {function} with invalid signature. " + f"Expected exactly 2 positional parameters ('field', 'data'), got {nodefaults!r}" + ) + if any( + fparams[name].kind == fparams[name].KEYWORD_ONLY + for name in ("field", "data") + ): + raise TypeError( + f"Received field function {function} with invalid signature. " + "Parameters 'field' and 'data' must accept positional values " + "(they cannot be keyword-only)" + ) + if isinstance(name, tuple): self[name] = DerivedField(name, sampling_type, function, **kwargs) return @@ -434,9 +460,9 @@ def __setitem__(self, key, value): def alias( self, - alias_name, - original_name, - units=None, + alias_name: Tuple[str, str], + original_name: Tuple[str, str], + units: Optional[str] = None, deprecate: Optional[Tuple[str, str]] = None, ): """ @@ -444,15 +470,15 @@ def alias( Parameters ---------- - alias_name : Tuple[str] + alias_name : Tuple[str, str] The new field name. - original_name : Tuple[str] + original_name : Tuple[str, str] The field to be aliased. units : str A plain text string encoding the unit. Powers must be in python syntax (** instead of ^). If set to "auto" the units will be inferred from the return value of the field function. - deprecate : Tuple[str], optional + deprecate : Tuple[str, str], optional If this is set, then the tuple contains two string version numbers: the first marking the version when the field was deprecated, and the second marking when the field will be @@ -463,11 +489,19 @@ def alias( if units is None: # We default to CGS here, but in principle, this can be pluggable # as well. - u = Unit(self[original_name].units, registry=self.ds.unit_registry) - if u.dimensions is not dimensionless: - units = str(self.ds.unit_system[u.dimensions]) + + # self[original_name].units may be set to `None` at this point + # to signal that units should be autoset later + oru = self[original_name].units + if oru is None: + units = None else: - units = self[original_name].units + u = Unit(oru, registry=self.ds.unit_registry) + if u.dimensions is not dimensionless: + units = str(self.ds.unit_system[u.dimensions]) + else: + units = oru + self.field_aliases[alias_name] = original_name function = TranslationFunc(original_name) if deprecate is not None: diff --git a/yt/fields/tests/test_fields.py b/yt/fields/tests/test_fields.py index 49f26ee2bea..7a6cc459a59 100644 --- a/yt/fields/tests/test_fields.py +++ b/yt/fields/tests/test_fields.py @@ -311,9 +311,6 @@ def density_alias(field, data): def unitless_data(field, data): return np.ones(data[("gas", "density")].shape) - ds.add_field( - ("gas", "density_alias_no_units"), sampling_type="cell", function=density_alias - ) ds.add_field( ("gas", "density_alias_auto"), sampling_type="cell", @@ -340,7 +337,6 @@ def unitless_data(field, data): units="auto", dimensions="temperature", ) - assert_raises(YTFieldUnitError, get_data, ds, ("gas", "density_alias_no_units")) assert_raises(YTFieldUnitError, get_data, ds, ("gas", "density_alias_wrong_units")) assert_raises( YTFieldUnitParseError, get_data, ds, ("gas", "density_alias_unparseable_units") diff --git a/yt/frontends/amrvac/fields.py b/yt/frontends/amrvac/fields.py index 523b356a916..fde7f0f1fdd 100644 --- a/yt/frontends/amrvac/fields.py +++ b/yt/frontends/amrvac/fields.py @@ -101,7 +101,6 @@ def _setup_velocity_fields(self, idust=None): if not ("amrvac", "m%d%s" % (idir, dust_flag)) in self.field_list: break velocity_fn = functools.partial(_velocity, idir=idir, prefix=dust_label) - functools.update_wrapper(velocity_fn, _velocity) self.add_field( ("gas", f"{dust_label}velocity_{alias}"), function=velocity_fn, diff --git a/yt/frontends/ramses/tests/test_outputs.py b/yt/frontends/ramses/tests/test_outputs.py index 25f089bbb43..28c440c5169 100644 --- a/yt/frontends/ramses/tests/test_outputs.py +++ b/yt/frontends/ramses/tests/test_outputs.py @@ -590,6 +590,7 @@ def dummy(field, data): fname, gen_dummy(ngz), sampling_type="cell", + units="", validators=[yt.ValidateSpatial(ghost_zones=ngz)], ) fields.append(fname) diff --git a/yt/visualization/volume_rendering/off_axis_projection.py b/yt/visualization/volume_rendering/off_axis_projection.py index 110a6f11fae..d3a4cc9fe3b 100644 --- a/yt/visualization/volume_rendering/off_axis_projection.py +++ b/yt/visualization/volume_rendering/off_axis_projection.py @@ -320,8 +320,8 @@ def off_axis_projection( weightfield = ("index", "temp_weightfield") def _make_wf(f, w): - def temp_weightfield(a, b): - tr = b[f].astype("float64") * b[w] + def temp_weightfield(field, data): + tr = data[f].astype("float64") * data[w] return tr.d return temp_weightfield diff --git a/yt/visualization/volume_rendering/old_camera.py b/yt/visualization/volume_rendering/old_camera.py index cc98517de42..d5c690081f7 100644 --- a/yt/visualization/volume_rendering/old_camera.py +++ b/yt/visualization/volume_rendering/old_camera.py @@ -2066,9 +2066,9 @@ def __init__( self.weightfield = ("index", "temp_weightfield_%u" % (id(self),)) def _make_wf(f, w): - def temp_weightfield(a, b): - tr = b[f].astype("float64") * b[w] - return b.apply_units(tr, a.units) + def temp_weightfield(field, data): + tr = data[f].astype("float64") * data[w] + return data.apply_units(tr, field.units) return temp_weightfield