Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xarray/computation/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _maybe_null_out(result, axis, mask, min_count=1):
dtype, fill_value = dtypes.maybe_promote(result.dtype)
result = where(null_mask, fill_value, astype(result, dtype))

elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES:
elif hasattr(result, "dtype") and result.dtype.kind not in "mM":
null_mask = mask.size - duck_array_ops.sum(mask)
result = where(null_mask < min_count, np.nan, result)

Expand Down
16 changes: 11 additions & 5 deletions xarray/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from xarray.compat import array_api_compat, npcompat
from xarray.compat.npcompat import HAS_STRING_DTYPE
from xarray.core import utils
from xarray.core.types import PDDatetimeUnitOptions

if TYPE_CHECKING:
from typing import Any
Expand Down Expand Up @@ -88,7 +89,11 @@ def maybe_promote(dtype: T_dtype) -> tuple[T_dtype, Any]:
# See https://github.com/numpy/numpy/issues/10685
# np.timedelta64 is a subclass of np.integer
# Check np.timedelta64 before np.integer
fill_value = np.timedelta64("NaT")
unit, _ = np.datetime_data(dtype)
# np.datetime_data returns a generic str for the unit so we need to
# cast it to a valid time unit for mypy purposes.
unit = cast(PDDatetimeUnitOptions, unit)
fill_value = np.timedelta64("NaT", unit)
dtype_ = dtype
elif isdtype(dtype, "integral"):
dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64
Expand All @@ -97,8 +102,12 @@ def maybe_promote(dtype: T_dtype) -> tuple[T_dtype, Any]:
dtype_ = dtype
fill_value = np.nan + np.nan * 1j
elif np.issubdtype(dtype, np.datetime64):
unit, _ = np.datetime_data(dtype)
# np.datetime_data returns a generic str for the unit so we need to
# cast it to a valid time unit for mypy purposes.
unit = cast(PDDatetimeUnitOptions, unit)
dtype_ = dtype
fill_value = np.datetime64("NaT")
fill_value = np.datetime64("NaT", unit)
else:
dtype_ = object
fill_value = np.nan
Expand All @@ -108,9 +117,6 @@ def maybe_promote(dtype: T_dtype) -> tuple[T_dtype, Any]:
return dtype_out, fill_value


NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype}


def get_fill_value(dtype):
"""Return an appropriate fill value for this dtype.

Expand Down
17 changes: 13 additions & 4 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,15 @@ def _determine_cmap_params(
else:
mpl = attempt_import("matplotlib")

if plot_data.dtype.kind == "m":
unit, _ = np.datetime_data(plot_data.dtype)
zero = np.timedelta64(0, unit)
elif plot_data.dtype.kind == "M":
unit, _ = np.datetime_data(plot_data.dtype)
zero = np.datetime64(0, unit)
else:
zero = 0.0

if isinstance(levels, Iterable):
levels = sorted(levels)

Expand All @@ -197,15 +206,15 @@ def _determine_cmap_params(
# Handle all-NaN input data gracefully
if calc_data.size == 0:
# Arbitrary default for when all values are NaN
calc_data = np.array(0.0)
calc_data = np.array(zero)

# Setting center=False prevents a divergent cmap
possibly_divergent = center is not False

# Set center to 0 so math below makes sense but remember its state
center_is_none = False
if center is None:
center = 0
center = zero
center_is_none = True

# Setting both vmin and vmax prevents a divergent cmap
Expand Down Expand Up @@ -240,10 +249,10 @@ def _determine_cmap_params(

if possibly_divergent:
levels_are_divergent = (
isinstance(levels, Iterable) and levels[0] * levels[-1] < 0
isinstance(levels, Iterable) and levels[0] * levels[-1] < zero
)
# kwargs not specific about divergent or not: infer defaults from data
divergent = (vmin < 0 < vmax) or not center_is_none or levels_are_divergent
divergent = (vmin < zero < vmax) or not center_is_none or levels_are_divergent
else:
divergent = False

Expand Down
10 changes: 2 additions & 8 deletions xarray/tests/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ def test_inf(obj) -> None:
("I", (np.float64, "nan")), # dtype('uint32')
("l", (np.float64, "nan")), # dtype('int64')
("L", (np.float64, "nan")), # dtype('uint64')
("m", (np.timedelta64, "NaT")), # dtype('<m8')
("M", (np.datetime64, "NaT")), # dtype('<M8')
("<m8[ns]", (np.dtype("<m8[ns]"), "NaT")), # dtype('<m8[ns]')
("<M8[ns]", (np.dtype("<M8[ns]"), "NaT")), # dtype('<M8[ns]')
("O", (np.dtype("O"), "nan")), # dtype('O')
("p", (np.float64, "nan")), # dtype('int64')
("P", (np.float64, "nan")), # dtype('uint64')
Expand All @@ -123,12 +123,6 @@ def test_maybe_promote(kind, expected) -> None:
assert str(actual[1]) == expected[1]


def test_nat_types_membership() -> None:
assert np.datetime64("NaT").dtype in dtypes.NAT_TYPES
assert np.timedelta64("NaT").dtype in dtypes.NAT_TYPES
assert np.float64 not in dtypes.NAT_TYPES


@pytest.mark.parametrize(
["dtype", "kinds", "xp", "expected"],
(
Expand Down
Loading