diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 30c7e5ede1..a4ab566eec 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -871,10 +871,9 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> foast.Call: new_type = new_type_constructor.type.definition.returns - return_type = type_info.apply_to_primitive_constituents( - lambda primitive_type: with_altered_scalar_kind(primitive_type, new_type.kind), - value.type, - ) + return_type = type_info.tree_map_type( + lambda primitive_type: with_altered_scalar_kind(primitive_type, new_type.kind) + )(value.type) assert isinstance(return_type, (ts.TupleType, ts.ScalarType, ts.FieldType)) return foast.Call( @@ -962,7 +961,7 @@ def deduce_return_type( return_type = ts.FieldType(dims=return_dims, dtype=t_dtype) return return_type - return deduce_return_type(true_branch, false_branch) # type: ignore[return-value] + return deduce_return_type(true_branch, false_branch) def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: mask_type, true_branch_type, false_branch_type = (arg.type for arg in node.args) diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index 1f77224885..414f1552c4 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -19,26 +19,8 @@ from gt4py.next.type_system import type_info -def _tree_map_type_constructor( - value: ts.CollectionTypeSpecT, - elems: NestedTuple[ts.DataType], -) -> ts.CollectionTypeSpecT: - return ( - ts.NamedCollectionType( - keys=value.keys, original_python_type=value.original_python_type, types=list(elems) - ) - if isinstance(value, ts.NamedCollectionType) - else ts.TupleType(types=list(elems)) # type: ignore[return-value] - ) - - -# TODO: Replace all occurrences of `apply_to_primitive_constituents` with this function, -# which also works with NamedCollections. -tree_map_type = functools.partial( - utils.tree_map, - collection_type=ts.COLLECTION_TYPE_SPECS, - result_collection_constructor=_tree_map_type_constructor, -) +_tree_map_type_constructor = type_info._tree_map_type_constructor +tree_map_type = type_info.tree_map_type named_collections_to_tuple_types = cast( Callable[..., ts.TupleType], diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index ffad69c921..a59eaa7e28 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -164,7 +164,7 @@ def fuse_as_fieldop( assert isinstance(arg.type, ts.TypeSpec) assert not isinstance( - type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type), + type_info.tree_map_type(type_info.extract_dtype)(arg.type), ts.ListType, ) new_args = _merge_arguments(new_args, {stencil_param.id: arg}) @@ -215,7 +215,7 @@ def _arg_inline_predicate(node: itir.Expr, shifts: set[tuple[itir.OffsetLiteral, # applied fieldop with list return type must always be inlined as no backend supports this type_inference.reinfer(node) assert isinstance(node.type, ts.TypeSpec) - dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, node.type) + dtype = type_info.tree_map_type(type_info.extract_dtype)(node.type) if isinstance(dtype, ts.ListType): return True # only accessed at the center location diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index d16749bfac..10062c2ff3 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -169,24 +169,24 @@ def _transform_by_pattern( if extracted_fields: tmp_stmts: list[itir.Stmt] = [] + into_tuple = lambda _, elts: tuple(elts) # noqa: E731 + # for each extracted expression generate: # - one or more `Temporary` declarations (depending on whether the expression is a field # or a tuple thereof) # - one `SetAt` statement that materializes the expression into the temporary for tmp_sym, tmp_expr in extracted_fields.items(): assert isinstance(tmp_expr.type, ts.TypeSpec) - tmp_names: str | tuple[str | tuple, ...] = type_info.apply_to_primitive_constituents( + tmp_names: str | tuple[str | tuple, ...] = type_info.tree_map_type( lambda x: next(uids["__tmp"]), - tmp_expr.type, - tuple_constructor=lambda *elements: tuple(elements), - ) + result_collection_constructor=into_tuple, + )(tmp_expr.type) tmp_dtypes: ( ts.ScalarType | ts.ListType | tuple[ts.ScalarType | ts.ListType | tuple, ...] - ) = type_info.apply_to_primitive_constituents( + ) = type_info.tree_map_type( type_info.extract_dtype, - tmp_expr.type, - tuple_constructor=lambda *elements: elements, - ) + result_collection_constructor=into_tuple, + )(tmp_expr.type) tmp_domains: SymbolicDomain | tuple[SymbolicDomain | tuple, ...] = tmp_expr.annex.domain @@ -214,12 +214,11 @@ def get_domain( # The following propagates the domains to the tuple structure of `tmp_expr.type`. # `tmp_domains` might not have this structure because domain inference was not able to infer the tuple structure. - tmp_domains = type_info.apply_to_primitive_constituents( + tmp_domains = type_info.tree_map_type( get_domain, - tmp_expr.type, + result_collection_constructor=into_tuple, with_path_arg=True, - tuple_constructor=lambda *elements: tuple(elements), - ) + )(tmp_expr.type) declarations.extend( itir.Temporary(id=tmp_name, domain=domain.as_expr(), dtype=dtype) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 0923471b3f..e9a300a24b 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -424,10 +424,11 @@ def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.Tup assert isinstance(domain, ts.DomainType) assert domain.dims != "unknown" assert node.dtype - return type_info.apply_to_primitive_constituents( - lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), - node.dtype, + result = type_info.tree_map_type(lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype))( + node.dtype ) + assert isinstance(result, (ts.FieldType, ts.TupleType)) + return result def visit_IfStmt(self, node: itir.IfStmt, *, ctx) -> None: cond = self.visit(node.cond, ctx=ctx) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 16d5da7e3b..49e435bf8b 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -382,9 +382,7 @@ def _convert_as_fieldop_input_to_iterator( Convert a field operation input into an iterator type, preserving its dimensions and data type. """ input_dims = _collect_and_check_dimensions(input_) - element_type: ts.DataType = type_info.apply_to_primitive_constituents( - type_info.extract_dtype, input_ - ) + element_type: ts.DataType = type_info.tree_map_type(type_info.extract_dtype)(input_) return it_ts.IteratorType( position_dims=domain.dims, defined_dims=input_dims, element_type=element_type @@ -436,9 +434,7 @@ def _canonicalize_nb_fields( ) case ts.FieldType(): input_dims = _collect_and_check_dimensions(input_) - element_type: ts.DataType = type_info.apply_to_primitive_constituents( - type_info.extract_dtype, input_ - ) + element_type: ts.DataType = type_info.tree_map_type(type_info.extract_dtype)(input_) defined_dims = [] neighbor_dim = None for dim in input_dims: @@ -535,7 +531,7 @@ def applied_as_fieldop( # For each stencil parameter all locations it is `deref`ed on # see :func:`gt4py.next.iterator.transforms.trace_stencil`. shift_sequences_per_param: list[set[tuple[itir.OffsetLiteral, ...]]] | None, - ) -> ts.FieldType | ts.DeferredType: + ) -> ts.FieldType | ts.TupleType | ts.DeferredType: if any( isinstance(el, ts.DeferredType) for f in fields @@ -578,13 +574,11 @@ def applied_as_fieldop( assert isinstance(stencil_return, ts.DataType) - return type_info.apply_to_primitive_constituents( - lambda el_type: ts.FieldType( - dims=domain.dims, - dtype=el_type, - ), - stencil_return, - ) + result = type_info.tree_map_type( + lambda el_type: ts.FieldType(dims=domain.dims, dtype=el_type) + )(stencil_return) + assert isinstance(result, (ts.FieldType, ts.TupleType)) + return result return applied_as_fieldop diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index ecd8ed88ed..eabe48131e 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -235,7 +235,7 @@ def _process_elements( obj: Expr, type_: ts.TypeSpec, *, - tuple_constructor: Callable[..., Expr] = lambda _, *elements: FunCall( + tuple_constructor: Callable[..., Expr] = lambda _, elements: FunCall( fun=SymRef(id="make_tuple"), args=list(elements) ), ) -> Expr: @@ -264,13 +264,11 @@ def _gen_constituent_expr(el_type: ts.ScalarType | ts.FieldType, path: tuple[int ) return process_func(el, el_type) - result = type_info.apply_to_primitive_constituents( + return type_info.tree_map_type( _gen_constituent_expr, - type_, + result_collection_constructor=tuple_constructor, with_path_arg=True, - tuple_constructor=tuple_constructor, - ) - return result + )(type_) @dataclasses.dataclass(frozen=True) @@ -523,7 +521,7 @@ def check_el_type(el_expr: Expr, el_type: ts.ScalarType | ts.FieldType) -> Expr: check_el_type, lowered_output, node.type, - tuple_constructor=lambda *elements: SidComposite(values=list(elements)), + tuple_constructor=lambda _, elements: SidComposite(values=list(elements)), ) assert isinstance(lowered_output_as_sid, (SidComposite, SymRef)) @@ -615,7 +613,7 @@ def convert_el_to_sid(el_expr: Expr, el_type: ts.ScalarType | ts.FieldType) -> E convert_el_to_sid, lowered_input, input_.type, - tuple_constructor=lambda *elements: SidComposite(values=list(elements)), + tuple_constructor=lambda _, elements: SidComposite(values=list(elements)), ) lowered_inputs.append(lowered_input_as_sid) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index eb70d15947..6737b72fae 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -9,23 +9,12 @@ import functools import types from collections.abc import Callable, Iterator -from typing import ( - Any, - Generic, - Literal, - Protocol, - Sequence, - Type, - TypeGuard, - TypeVar, - cast, - overload, -) +from typing import Any, Literal, Sequence, Type, TypeGuard, TypeVar, cast, overload import numpy as np from gt4py.eve import extended_typing as xtyping, utils -from gt4py.next import common +from gt4py.next import common, utils as next_utils from gt4py.next.iterator.type_system import type_specifications as it_ts from gt4py.next.type_system import type_specifications as ts @@ -143,63 +132,53 @@ def constituents_yielder( return utils.xiter(constituents_yielder(symbol_type, ())) # type: ignore[return-value] # why resolved to XIterable[object]? -_R = TypeVar("_R", covariant=True) _T = TypeVar("_T") +_C = TypeVar("_C") -class TupleConstructorType(Protocol, Generic[_R]): - def __call__(self, *args: Any) -> _R: ... +def _tree_map_type_constructor( + value: ts.CollectionTypeSpecT, + elems: xtyping.NestedTuple[ts.DataType], +) -> ts.CollectionTypeSpecT: + return ( + ts.NamedCollectionType( + keys=value.keys, original_python_type=value.original_python_type, types=list(elems) + ) + if isinstance(value, ts.NamedCollectionType) + else ts.TupleType(types=list(elems)) # type: ignore[return-value] + ) + + +@overload +def tree_map_type( + fun: Callable[..., _T], *, with_path_arg: bool = ..., unpack: bool = ... +) -> Callable[..., _T | ts.CollectionTypeSpec]: ... -def apply_to_primitive_constituents( +@overload +def tree_map_type( fun: Callable[..., _T], - *symbol_types: ts.TypeSpec, - with_path_arg: bool = False, - tuple_constructor: TupleConstructorType[_R] = lambda *elements: ts.TupleType(types=[*elements]), # type: ignore[assignment] # probably related to https://github.com/python/mypy/issues/10854 - _path: tuple[int, ...] = (), -) -> _T | _R: - """ - Apply function to all primitive constituents of a type. + *, + result_collection_constructor: Callable[..., _C], + with_path_arg: bool = ..., + unpack: bool = ..., +) -> Callable[..., _T | _C]: ... - >>> int_type = ts.ScalarType(kind=ts.ScalarKind.INT64) - >>> tuple_type = ts.TupleType(types=[int_type, int_type]) - >>> print( - ... apply_to_primitive_constituents( - ... lambda primitive_type: ts.FieldType(dims=[], dtype=primitive_type), - ... tuple_type, - ... ) - ... ) - tuple[Field[[], int64], Field[[], int64]] - >>> apply_to_primitive_constituents( - ... lambda primitive_type, path: (path, primitive_type), - ... tuple_type, - ... with_path_arg=True, - ... tuple_constructor=lambda *elements: dict(elements), - ... ) - {(0,): ScalarType(kind=, shape=None), (1,): ScalarType(kind=, shape=None)} - """ - if isinstance(symbol_types[0], ts.TupleType): - assert all(isinstance(symbol_type, ts.TupleType) for symbol_type in symbol_types) - - return tuple_constructor( - *[ - apply_to_primitive_constituents( - fun, - *el_types, - _path=(*_path, i), - with_path_arg=with_path_arg, - tuple_constructor=tuple_constructor, - ) - for i, el_types in enumerate( - zip(*(symbol_type.types for symbol_type in symbol_types)) # type: ignore[attr-defined] # ensured by assert above - ) - ] - ) - if with_path_arg: - return fun(*symbol_types, path=_path) - else: - return fun(*symbol_types) +def tree_map_type( + fun: Callable[..., _T], + *, + result_collection_constructor: Callable[..., Any] = _tree_map_type_constructor, + with_path_arg: bool = False, + unpack: bool = False, +) -> Callable[..., Any]: + return next_utils.tree_map( + fun, + collection_type=ts.COLLECTION_TYPE_SPECS, + result_collection_constructor=result_collection_constructor, + with_path_arg=with_path_arg, + unpack=unpack, + ) def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType | ts.ListType: