diff --git a/xarray/computation/nanops.py b/xarray/computation/nanops.py index a28078540bb..c5a9d9bcc04 100644 --- a/xarray/computation/nanops.py +++ b/xarray/computation/nanops.py @@ -29,7 +29,9 @@ def _maybe_null_out(result, axis, mask, min_count=1): dtype, fill_value = dtypes.maybe_promote(result.dtype) result = where(null_mask, fill_value, astype(result, dtype)) - elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES: + elif (dtype := getattr(result, "dtype", None)) and getattr( + dtype, "kind", None + ) not in {"m", "M"}: null_mask = mask.size - duck_array_ops.sum(mask) result = where(null_mask < min_count, np.nan, result) diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index 40334d58e68..bb2fe26d727 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -10,6 +10,7 @@ from xarray.compat import array_api_compat, npcompat from xarray.compat.npcompat import HAS_STRING_DTYPE from xarray.core import utils +from xarray.core.types import PDDatetimeUnitOptions if TYPE_CHECKING: from typing import Any @@ -88,7 +89,11 @@ def maybe_promote(dtype: T_dtype) -> tuple[T_dtype, Any]: # See https://github.com/numpy/numpy/issues/10685 # np.timedelta64 is a subclass of np.integer # Check np.timedelta64 before np.integer - fill_value = np.timedelta64("NaT") + unit, _ = np.datetime_data(dtype) + # np.datetime_data returns a generic str for the unit so we need to + # cast it to a valid time unit for mypy purposes. + unit = cast(PDDatetimeUnitOptions, unit) + fill_value = np.timedelta64("NaT", unit) dtype_ = dtype elif isdtype(dtype, "integral"): dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64 @@ -97,8 +102,12 @@ def maybe_promote(dtype: T_dtype) -> tuple[T_dtype, Any]: dtype_ = dtype fill_value = np.nan + np.nan * 1j elif np.issubdtype(dtype, np.datetime64): + unit, _ = np.datetime_data(dtype) + # np.datetime_data returns a generic str for the unit so we need to + # cast it to a valid time unit for mypy purposes. + unit = cast(PDDatetimeUnitOptions, unit) dtype_ = dtype - fill_value = np.datetime64("NaT") + fill_value = np.datetime64("NaT", unit) else: dtype_ = object fill_value = np.nan @@ -108,9 +117,6 @@ def maybe_promote(dtype: T_dtype) -> tuple[T_dtype, Any]: return dtype_out, fill_value -NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype} - - def get_fill_value(dtype): """Return an appropriate fill value for this dtype. diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 167718fe9d8..59d1c5d3c57 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -189,6 +189,15 @@ def _determine_cmap_params( else: mpl = attempt_import("matplotlib") + if plot_data.dtype.kind == "m": + unit, _ = np.datetime_data(plot_data.dtype) + zero = np.timedelta64(0, unit) + elif plot_data.dtype.kind == "M": + unit, _ = np.datetime_data(plot_data.dtype) + zero = np.datetime64(0, unit) + else: + zero = 0.0 + if isinstance(levels, Iterable): levels = sorted(levels) @@ -197,7 +206,7 @@ def _determine_cmap_params( # Handle all-NaN input data gracefully if calc_data.size == 0: # Arbitrary default for when all values are NaN - calc_data = np.array(0.0) + calc_data = np.array(zero) # Setting center=False prevents a divergent cmap possibly_divergent = center is not False @@ -205,7 +214,7 @@ def _determine_cmap_params( # Set center to 0 so math below makes sense but remember its state center_is_none = False if center is None: - center = 0 + center = zero center_is_none = True # Setting both vmin and vmax prevents a divergent cmap @@ -240,10 +249,10 @@ def _determine_cmap_params( if possibly_divergent: levels_are_divergent = ( - isinstance(levels, Iterable) and levels[0] * levels[-1] < 0 + isinstance(levels, Iterable) and levels[0] * levels[-1] < zero ) # kwargs not specific about divergent or not: infer defaults from data - divergent = (vmin < 0 < vmax) or not center_is_none or levels_are_divergent + divergent = (vmin < zero < vmax) or not center_is_none or levels_are_divergent else: divergent = False diff --git a/xarray/tests/test_dtypes.py b/xarray/tests/test_dtypes.py index a1bdd9082ef..4ed66509725 100644 --- a/xarray/tests/test_dtypes.py +++ b/xarray/tests/test_dtypes.py @@ -102,8 +102,8 @@ def test_inf(obj) -> None: ("I", (np.float64, "nan")), # dtype('uint32') ("l", (np.float64, "nan")), # dtype('int64') ("L", (np.float64, "nan")), # dtype('uint64') - ("m", (np.timedelta64, "NaT")), # dtype(' None: assert str(actual[1]) == expected[1] -def test_nat_types_membership() -> None: - assert np.datetime64("NaT").dtype in dtypes.NAT_TYPES - assert np.timedelta64("NaT").dtype in dtypes.NAT_TYPES - assert np.float64 not in dtypes.NAT_TYPES - - @pytest.mark.parametrize( ["dtype", "kinds", "xp", "expected"], (