-
Notifications
You must be signed in to change notification settings - Fork 476
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
963: Add tests and documentation with improvement of downcast type compatibility (part of #845) r=hgrecco a=jthielen As a part of #845, this PR adds tests for downcast type compatibility with Sparse's `COO` and NumPy's `MaskedArray`, along with more careful handling of downcast types throughout the library. Also included is new documentation on array type compatibility, including the type casting hierarchy digraph by @shoyer and @crusaderky. While this PR doesn't fully bring Pint's downcast type compatibility to a completed state, I think this gets it "good enough" for the upcoming release, and the remaining issues are fairly well defined: - MaskedArray non-commutativity (#633 / numpy/numpy#15200) - Dask compatibility (#883) - Addition of CuPy tests (no issue on issue tracker yet) Because of that, I think this can close #845, but if @hgrecco you want that kept open until the above items are resolved, let me know. - [x] Closes #37; Closes #845 - [x] Executed ``black -t py36 . && isort -rc . && flake8`` with no errors - [x] The change is fully covered by automated unit tests - [x] Documented in docs/ as appropriate - [x] Added an entry to the CHANGES file Co-authored-by: Jon Thielen <[email protected]>
Showing
9 changed files
with
1,027 additions
and
244 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import pytest | ||
|
||
from pint import UnitRegistry | ||
|
||
# Conditionally import NumPy and any upcast type libraries | ||
np = pytest.importorskip("numpy", reason="NumPy is not available") | ||
sparse = pytest.importorskip("sparse", reason="sparse is not available") | ||
|
||
# Set up unit registry and sample | ||
ureg = UnitRegistry(force_ndarray_like=True) | ||
q_base = (np.arange(25).reshape(5, 5).T + 1) * ureg.kg | ||
|
||
|
||
# Define identity function for use in tests | ||
def identity(x): | ||
return x | ||
|
||
|
||
@pytest.fixture(params=["sparse", "masked_array"]) | ||
def array(request): | ||
"""Generate 5x5 arrays of given type for tests.""" | ||
if request.param == "sparse": | ||
# Create sample sparse COO as a permutation matrix. | ||
coords = [[0, 1, 2, 3, 4], [1, 3, 0, 2, 4]] | ||
data = [1.0] * 5 | ||
return sparse.COO(coords, data, shape=(5, 5)) | ||
elif request.param == "masked_array": | ||
# Create sample masked array as an upper triangular matrix. | ||
return np.ma.masked_array( | ||
np.arange(25, dtype=np.float).reshape((5, 5)), | ||
mask=np.logical_not(np.triu(np.ones((5, 5)))), | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"op, magnitude_op, unit_op", | ||
[ | ||
pytest.param(identity, identity, identity, id="identity"), | ||
pytest.param( | ||
lambda x: x + 1 * ureg.m, lambda x: x + 1, identity, id="addition" | ||
), | ||
pytest.param( | ||
lambda x: x - 20 * ureg.cm, lambda x: x - 0.2, identity, id="subtraction" | ||
), | ||
pytest.param( | ||
lambda x: x * (2 * ureg.s), | ||
lambda x: 2 * x, | ||
lambda u: u * ureg.s, | ||
id="multiplication", | ||
), | ||
pytest.param( | ||
lambda x: x / (1 * ureg.s), identity, lambda u: u / ureg.s, id="division" | ||
), | ||
pytest.param(lambda x: x ** 2, lambda x: x ** 2, lambda u: u ** 2, id="square"), | ||
pytest.param(lambda x: x.T, lambda x: x.T, identity, id="transpose"), | ||
pytest.param(np.mean, np.mean, identity, id="mean ufunc"), | ||
pytest.param(np.sum, np.sum, identity, id="sum ufunc"), | ||
pytest.param(np.sqrt, np.sqrt, lambda u: u ** 0.5, id="sqrt ufunc"), | ||
pytest.param( | ||
lambda x: np.reshape(x, 25), | ||
lambda x: np.reshape(x, 25), | ||
identity, | ||
id="reshape function", | ||
), | ||
pytest.param(np.amax, np.amax, identity, id="amax function"), | ||
], | ||
) | ||
def test_univariate_op_consistency(op, magnitude_op, unit_op, array): | ||
q = ureg.Quantity(array, "meter") | ||
res = op(q) | ||
assert np.all(res.magnitude == magnitude_op(array)) # Magnitude check | ||
assert res.units == unit_op(q.units) # Unit check | ||
assert q.magnitude is array # Immutability check | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"op, unit", | ||
[ | ||
pytest.param(lambda x, y: x * y, ureg("kg m"), id="multiplication"), | ||
pytest.param(lambda x, y: x / y, ureg("m / kg"), id="division"), | ||
pytest.param(np.multiply, ureg("kg m"), id="multiply ufunc"), | ||
], | ||
) | ||
def test_bivariate_op_consistency(op, unit, array): | ||
q = ureg.Quantity(array, "meter") | ||
res = op(q, q_base) | ||
assert np.all(res.magnitude == op(array, q_base.magnitude)) # Magnitude check | ||
assert res.units == unit # Unit check | ||
assert q.magnitude is array # Immutability check | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"op", | ||
[ | ||
pytest.param( | ||
lambda a, u: a * u, | ||
id="array-first", | ||
marks=pytest.mark.xfail(reason="upstream issue numpy/numpy#15200"), | ||
), | ||
pytest.param(lambda a, u: u * a, id="unit-first"), | ||
], | ||
) | ||
@pytest.mark.parametrize( | ||
"unit", | ||
[pytest.param(ureg.m, id="Unit"), pytest.param(ureg("meter"), id="Quantity")], | ||
) | ||
def test_array_quantity_creation_by_multiplication(op, unit, array): | ||
assert type(op(array, unit)) == ureg.Quantity |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters