diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index 8386161bf29..ed752d3461f 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -1848,9 +1848,10 @@ def _center_pixels(x): # missing data transparent. We therefore add an alpha channel if # there isn't one, and set it to transparent where data is masked. if z.shape[-1] == 3: - alpha = np.ma.ones(z.shape[:2] + (1,), dtype=z.dtype) + safe_dtype = np.promote_types(z.dtype, np.uint8) + alpha = np.ma.ones(z.shape[:2] + (1,), dtype=safe_dtype) if np.issubdtype(z.dtype, np.integer): - alpha *= 255 + alpha[:] = 255 z = np.ma.concatenate((z, alpha), axis=2) else: z = z.copy() diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 6f983a121fe..cf5ab89caad 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2028,15 +2028,17 @@ def test_normalize_rgb_one_arg_error(self) -> None: for vmin2, vmax2 in ((-1.2, -1), (2, 2.1)): da.plot.imshow(vmin=vmin2, vmax=vmax2) - def test_imshow_rgb_values_in_valid_range(self) -> None: - da = DataArray(np.arange(75, dtype="uint8").reshape((5, 5, 3))) + @pytest.mark.parametrize("dtype", [np.uint8, np.int8, np.int16]) + def test_imshow_rgb_values_in_valid_range(self, dtype) -> None: + da = DataArray(np.arange(75, dtype=dtype).reshape((5, 5, 3))) _, ax = plt.subplots() out = da.plot.imshow(ax=ax).get_array() assert out is not None - dtype = out.dtype - assert dtype is not None - assert dtype == np.uint8 + actual_dtype = out.dtype + assert actual_dtype is not None + assert actual_dtype == np.uint8 assert (out[..., :3] == da.values).all() # Compare without added alpha + assert (out[..., -1] == 255).all() # Compare alpha @pytest.mark.filterwarnings("ignore:Several dimensions of this array") def test_regression_rgb_imshow_dim_size_one(self) -> None: