Skip to content

Commit a96fc0a

Browse files
committed
add _DimType_co generic typevar to NamedArray
Also makes private typing names 'public' by removing leading _
1 parent c9ac506 commit a96fc0a

13 files changed

Lines changed: 492 additions & 423 deletions

xarray/core/common.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,14 +215,13 @@ def __iter__(self: Any) -> Iterator[Any]:
215215
return self._iter()
216216

217217
@overload
218-
def get_axis_num(self, dim: str) -> int: ... # type: ignore [overload-overlap]
218+
def get_axis_num(
219+
self, dim: Hashable
220+
) -> int: ... # put this first to match a single str
219221

220222
@overload
221223
def get_axis_num(self, dim: Iterable[Hashable]) -> tuple[int, ...]: ...
222224

223-
@overload
224-
def get_axis_num(self, dim: Hashable) -> int: ...
225-
226225
def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]:
227226
"""Return axis number(s) corresponding to dimension(s) in this array.
228227

xarray/core/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7254,7 +7254,7 @@ def _to_dataframe(self, ordered_dims: Mapping[Any, int]):
72547254
extension_array_df = pd.DataFrame(
72557255
{extension_array_column: extension_array},
72567256
index=pd.Index(index.array)
7257-
if isinstance(index, PandasExtensionArray) # type: ignore[redundant-expr]
7257+
if isinstance(index, PandasExtensionArray)
72587258
else index,
72597259
)
72607260
extension_array_df.index.name = self.variables[extension_array_column].dims[

xarray/core/indexing.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from xarray.core.indexes import Index
4343
from xarray.core.types import Self
4444
from xarray.core.variable import Variable
45-
from xarray.namedarray._typing import _Shape, duckarray
45+
from xarray.namedarray._typing import Shape, duckarray
4646
from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint
4747

4848
BasicIndexerType = int | np.integer | slice
@@ -730,7 +730,7 @@ def __init__(self, array: Any, key: ExplicitIndexer | None = None):
730730
self.array = as_indexable(array)
731731
self.key = key
732732

733-
shape: _Shape = ()
733+
shape: Shape = ()
734734
for size, k in zip(self.array.shape, self.key.tuple, strict=True):
735735
if isinstance(k, slice):
736736
shape += (len(range(*k.indices(size))),)
@@ -754,7 +754,7 @@ def _updated_key(self, new_key: ExplicitIndexer) -> BasicIndexer | OuterIndexer:
754754
return OuterIndexer(full_key_tuple)
755755

756756
@property
757-
def shape(self) -> _Shape:
757+
def shape(self) -> Shape:
758758
return self._shape
759759

760760
def get_duck_array(self):
@@ -836,7 +836,7 @@ def __init__(self, array: duckarray[Any, Any], key: ExplicitIndexer):
836836
self.array = as_indexable(array)
837837

838838
@property
839-
def shape(self) -> _Shape:
839+
def shape(self) -> Shape:
840840
return np.broadcast(*self.key.tuple).shape
841841

842842
def get_duck_array(self):
@@ -1025,7 +1025,7 @@ def as_indexable(array):
10251025

10261026

10271027
def _outer_to_vectorized_indexer(
1028-
indexer: BasicIndexer | OuterIndexer, shape: _Shape
1028+
indexer: BasicIndexer | OuterIndexer, shape: Shape
10291029
) -> VectorizedIndexer:
10301030
"""Convert an OuterIndexer into a vectorized indexer.
10311031
@@ -1061,7 +1061,7 @@ def _outer_to_vectorized_indexer(
10611061
return VectorizedIndexer(tuple(new_key))
10621062

10631063

1064-
def _outer_to_numpy_indexer(indexer: BasicIndexer | OuterIndexer, shape: _Shape):
1064+
def _outer_to_numpy_indexer(indexer: BasicIndexer | OuterIndexer, shape: Shape):
10651065
"""Convert an OuterIndexer into an indexer for NumPy.
10661066
10671067
Parameters
@@ -1085,7 +1085,7 @@ def _outer_to_numpy_indexer(indexer: BasicIndexer | OuterIndexer, shape: _Shape)
10851085
return _outer_to_vectorized_indexer(indexer, shape).tuple
10861086

10871087

1088-
def _combine_indexers(old_key, shape: _Shape, new_key) -> VectorizedIndexer:
1088+
def _combine_indexers(old_key, shape: Shape, new_key) -> VectorizedIndexer:
10891089
"""Combine two indexers.
10901090
10911091
Parameters
@@ -1127,7 +1127,7 @@ class IndexingSupport(enum.Enum):
11271127

11281128
def explicit_indexing_adapter(
11291129
key: ExplicitIndexer,
1130-
shape: _Shape,
1130+
shape: Shape,
11311131
indexing_support: IndexingSupport,
11321132
raw_indexing_method: Callable[..., Any],
11331133
) -> Any:
@@ -1163,7 +1163,7 @@ def explicit_indexing_adapter(
11631163

11641164
async def async_explicit_indexing_adapter(
11651165
key: ExplicitIndexer,
1166-
shape: _Shape,
1166+
shape: Shape,
11671167
indexing_support: IndexingSupport,
11681168
raw_indexing_method: Callable[..., Any],
11691169
) -> Any:
@@ -1197,7 +1197,7 @@ def set_with_indexer(indexable, indexer: ExplicitIndexer, value: Any) -> None:
11971197

11981198

11991199
def decompose_indexer(
1200-
indexer: ExplicitIndexer, shape: _Shape, indexing_support: IndexingSupport
1200+
indexer: ExplicitIndexer, shape: Shape, indexing_support: IndexingSupport
12011201
) -> tuple[ExplicitIndexer, ExplicitIndexer]:
12021202
if isinstance(indexer, VectorizedIndexer):
12031203
return _decompose_vectorized_indexer(indexer, shape, indexing_support)
@@ -1236,7 +1236,7 @@ def _decompose_slice(key: slice, size: int) -> tuple[slice, slice]:
12361236

12371237
def _decompose_vectorized_indexer(
12381238
indexer: VectorizedIndexer,
1239-
shape: _Shape,
1239+
shape: Shape,
12401240
indexing_support: IndexingSupport,
12411241
) -> tuple[ExplicitIndexer, ExplicitIndexer]:
12421242
"""
@@ -1318,7 +1318,7 @@ def _decompose_vectorized_indexer(
13181318

13191319
def _decompose_outer_indexer(
13201320
indexer: BasicIndexer | OuterIndexer,
1321-
shape: _Shape,
1321+
shape: Shape,
13221322
indexing_support: IndexingSupport,
13231323
) -> tuple[ExplicitIndexer, ExplicitIndexer]:
13241324
"""
@@ -1500,7 +1500,7 @@ def _arrayize_outer_indexer(indexer: OuterIndexer, shape) -> OuterIndexer:
15001500

15011501

15021502
def _arrayize_vectorized_indexer(
1503-
indexer: VectorizedIndexer, shape: _Shape
1503+
indexer: VectorizedIndexer, shape: Shape
15041504
) -> VectorizedIndexer:
15051505
"""Return an identical vindex but slices are replaced by arrays"""
15061506
slices = [v for v in indexer.tuple if isinstance(v, slice)]
@@ -1564,7 +1564,7 @@ def _masked_result_drop_slice(key, data: duckarray[Any, Any] | None = None):
15641564

15651565

15661566
def create_mask(
1567-
indexer: ExplicitIndexer, shape: _Shape, data: duckarray[Any, Any] | None = None
1567+
indexer: ExplicitIndexer, shape: Shape, data: duckarray[Any, Any] | None = None
15681568
):
15691569
"""Create a mask for indexing with a fill-value.
15701570
@@ -1975,7 +1975,7 @@ def get_duck_array(self) -> np.ndarray | PandasExtensionArray:
19751975
return np.asarray(self)
19761976

19771977
@property
1978-
def shape(self) -> _Shape:
1978+
def shape(self) -> Shape:
19791979
return (len(self.array),)
19801980

19811981
def _convert_scalar(self, item) -> np.ndarray:

xarray/namedarray/_aggregations.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,20 @@
55
from __future__ import annotations
66

77
from collections.abc import Callable, Sequence
8-
from typing import Any
8+
from typing import Any, Generic
99

1010
from xarray.core import duck_array_ops
11-
from xarray.core.types import Dims, Self
11+
from xarray.core.types import Self
12+
from xarray.namedarray._typing import DimsLike, DimType_co
1213

1314

14-
class NamedArrayAggregations:
15+
class NamedArrayAggregations(Generic[DimType_co]):
1516
__slots__ = ()
1617

1718
def reduce(
1819
self,
1920
func: Callable[..., Any],
20-
dim: Dims = None,
21+
dim: DimsLike[DimType_co] = None,
2122
*,
2223
axis: int | Sequence[int] | None = None,
2324
keepdims: bool = False,
@@ -27,7 +28,7 @@ def reduce(
2728

2829
def count(
2930
self,
30-
dim: Dims = None,
31+
dim: DimsLike[DimType_co] = None,
3132
**kwargs: Any,
3233
) -> Self:
3334
"""
@@ -78,7 +79,7 @@ def count(
7879

7980
def all(
8081
self,
81-
dim: Dims = None,
82+
dim: DimsLike[DimType_co] = None,
8283
**kwargs: Any,
8384
) -> Self:
8485
"""
@@ -131,7 +132,7 @@ def all(
131132

132133
def any(
133134
self,
134-
dim: Dims = None,
135+
dim: DimsLike[DimType_co] = None,
135136
**kwargs: Any,
136137
) -> Self:
137138
"""
@@ -184,7 +185,7 @@ def any(
184185

185186
def max(
186187
self,
187-
dim: Dims = None,
188+
dim: DimsLike[DimType_co] = None,
188189
*,
189190
skipna: bool | None = None,
190191
**kwargs: Any,
@@ -249,7 +250,7 @@ def max(
249250

250251
def min(
251252
self,
252-
dim: Dims = None,
253+
dim: DimsLike[DimType_co] = None,
253254
*,
254255
skipna: bool | None = None,
255256
**kwargs: Any,
@@ -314,7 +315,7 @@ def min(
314315

315316
def mean(
316317
self,
317-
dim: Dims = None,
318+
dim: DimsLike[DimType_co] = None,
318319
*,
319320
skipna: bool | None = None,
320321
**kwargs: Any,
@@ -379,7 +380,7 @@ def mean(
379380

380381
def prod(
381382
self,
382-
dim: Dims = None,
383+
dim: DimsLike[DimType_co] = None,
383384
*,
384385
skipna: bool | None = None,
385386
min_count: int | None = None,
@@ -462,7 +463,7 @@ def prod(
462463

463464
def sum(
464465
self,
465-
dim: Dims = None,
466+
dim: DimsLike[DimType_co] = None,
466467
*,
467468
skipna: bool | None = None,
468469
min_count: int | None = None,
@@ -545,7 +546,7 @@ def sum(
545546

546547
def std(
547548
self,
548-
dim: Dims = None,
549+
dim: DimsLike[DimType_co] = None,
549550
*,
550551
skipna: bool | None = None,
551552
ddof: int = 0,
@@ -625,7 +626,7 @@ def std(
625626

626627
def var(
627628
self,
628-
dim: Dims = None,
629+
dim: DimsLike[DimType_co] = None,
629630
*,
630631
skipna: bool | None = None,
631632
ddof: int = 0,
@@ -705,7 +706,7 @@ def var(
705706

706707
def median(
707708
self,
708-
dim: Dims = None,
709+
dim: DimsLike[DimType_co] = None,
709710
*,
710711
skipna: bool | None = None,
711712
**kwargs: Any,
@@ -774,7 +775,7 @@ def median(
774775

775776
def cumsum(
776777
self,
777-
dim: Dims = None,
778+
dim: DimsLike[DimType_co] = None,
778779
*,
779780
skipna: bool | None = None,
780781
**kwargs: Any,
@@ -848,7 +849,7 @@ def cumsum(
848849

849850
def cumprod(
850851
self,
851-
dim: Dims = None,
852+
dim: DimsLike[DimType_co] = None,
852853
*,
853854
skipna: bool | None = None,
854855
**kwargs: Any,

0 commit comments

Comments
 (0)