Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,10 +871,9 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> foast.Call:

new_type = new_type_constructor.type.definition.returns

return_type = type_info.apply_to_primitive_constituents(
lambda primitive_type: with_altered_scalar_kind(primitive_type, new_type.kind),
value.type,
)
return_type = type_info.tree_map_type(
lambda primitive_type: with_altered_scalar_kind(primitive_type, new_type.kind)
)(value.type)
assert isinstance(return_type, (ts.TupleType, ts.ScalarType, ts.FieldType))

return foast.Call(
Expand Down Expand Up @@ -962,7 +961,7 @@ def deduce_return_type(
return_type = ts.FieldType(dims=return_dims, dtype=t_dtype)
return return_type

return deduce_return_type(true_branch, false_branch) # type: ignore[return-value]
return deduce_return_type(true_branch, false_branch)

def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
mask_type, true_branch_type, false_branch_type = (arg.type for arg in node.args)
Expand Down
22 changes: 2 additions & 20 deletions src/gt4py/next/ffront/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,8 @@
from gt4py.next.type_system import type_info


def _tree_map_type_constructor(
value: ts.CollectionTypeSpecT,
elems: NestedTuple[ts.DataType],
) -> ts.CollectionTypeSpecT:
return (
ts.NamedCollectionType(
keys=value.keys, original_python_type=value.original_python_type, types=list(elems)
)
if isinstance(value, ts.NamedCollectionType)
else ts.TupleType(types=list(elems)) # type: ignore[return-value]
)


# TODO: Replace all occurrences of `apply_to_primitive_constituents` with this function,
# which also works with NamedCollections.
tree_map_type = functools.partial(
utils.tree_map,
collection_type=ts.COLLECTION_TYPE_SPECS,
result_collection_constructor=_tree_map_type_constructor,
)
_tree_map_type_constructor = type_info._tree_map_type_constructor
tree_map_type = type_info.tree_map_type

named_collections_to_tuple_types = cast(
Callable[..., ts.TupleType],
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/fuse_as_fieldop.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def fuse_as_fieldop(
assert isinstance(arg.type, ts.TypeSpec)

assert not isinstance(
type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type),
type_info.tree_map_type(type_info.extract_dtype)(arg.type),
ts.ListType,
)
new_args = _merge_arguments(new_args, {stencil_param.id: arg})
Expand Down Expand Up @@ -215,7 +215,7 @@ def _arg_inline_predicate(node: itir.Expr, shifts: set[tuple[itir.OffsetLiteral,
# applied fieldop with list return type must always be inlined as no backend supports this
type_inference.reinfer(node)
assert isinstance(node.type, ts.TypeSpec)
dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, node.type)
dtype = type_info.tree_map_type(type_info.extract_dtype)(node.type)
if isinstance(dtype, ts.ListType):
return True
# only accessed at the center location
Expand Down
23 changes: 11 additions & 12 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,24 +169,24 @@ def _transform_by_pattern(
if extracted_fields:
tmp_stmts: list[itir.Stmt] = []

into_tuple = lambda _, elts: tuple(elts) # noqa: E731

# for each extracted expression generate:
# - one or more `Temporary` declarations (depending on whether the expression is a field
# or a tuple thereof)
# - one `SetAt` statement that materializes the expression into the temporary
for tmp_sym, tmp_expr in extracted_fields.items():
assert isinstance(tmp_expr.type, ts.TypeSpec)
tmp_names: str | tuple[str | tuple, ...] = type_info.apply_to_primitive_constituents(
tmp_names: str | tuple[str | tuple, ...] = type_info.tree_map_type(
lambda x: next(uids["__tmp"]),
tmp_expr.type,
tuple_constructor=lambda *elements: tuple(elements),
)
result_collection_constructor=into_tuple,
)(tmp_expr.type)
tmp_dtypes: (
ts.ScalarType | ts.ListType | tuple[ts.ScalarType | ts.ListType | tuple, ...]
) = type_info.apply_to_primitive_constituents(
) = type_info.tree_map_type(
type_info.extract_dtype,
tmp_expr.type,
tuple_constructor=lambda *elements: elements,
)
result_collection_constructor=into_tuple,
)(tmp_expr.type)

tmp_domains: SymbolicDomain | tuple[SymbolicDomain | tuple, ...] = tmp_expr.annex.domain

Expand Down Expand Up @@ -214,12 +214,11 @@ def get_domain(

# The following propagates the domains to the tuple structure of `tmp_expr.type`.
# `tmp_domains` might not have this structure because domain inference was not able to infer the tuple structure.
tmp_domains = type_info.apply_to_primitive_constituents(
tmp_domains = type_info.tree_map_type(
get_domain,
tmp_expr.type,
result_collection_constructor=into_tuple,
with_path_arg=True,
tuple_constructor=lambda *elements: tuple(elements),
)
)(tmp_expr.type)

declarations.extend(
itir.Temporary(id=tmp_name, domain=domain.as_expr(), dtype=dtype)
Expand Down
7 changes: 4 additions & 3 deletions src/gt4py/next/iterator/type_system/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,10 +424,11 @@ def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.Tup
assert isinstance(domain, ts.DomainType)
assert domain.dims != "unknown"
assert node.dtype
return type_info.apply_to_primitive_constituents(
lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype),
node.dtype,
result = type_info.tree_map_type(lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype))(
node.dtype
)
assert isinstance(result, (ts.FieldType, ts.TupleType))
return result

def visit_IfStmt(self, node: itir.IfStmt, *, ctx) -> None:
cond = self.visit(node.cond, ctx=ctx)
Expand Down
22 changes: 8 additions & 14 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,7 @@ def _convert_as_fieldop_input_to_iterator(
Convert a field operation input into an iterator type, preserving its dimensions and data type.
"""
input_dims = _collect_and_check_dimensions(input_)
element_type: ts.DataType = type_info.apply_to_primitive_constituents(
type_info.extract_dtype, input_
)
element_type: ts.DataType = type_info.tree_map_type(type_info.extract_dtype)(input_)

return it_ts.IteratorType(
position_dims=domain.dims, defined_dims=input_dims, element_type=element_type
Expand Down Expand Up @@ -436,9 +434,7 @@ def _canonicalize_nb_fields(
)
case ts.FieldType():
input_dims = _collect_and_check_dimensions(input_)
element_type: ts.DataType = type_info.apply_to_primitive_constituents(
type_info.extract_dtype, input_
)
element_type: ts.DataType = type_info.tree_map_type(type_info.extract_dtype)(input_)
defined_dims = []
neighbor_dim = None
for dim in input_dims:
Expand Down Expand Up @@ -535,7 +531,7 @@ def applied_as_fieldop(
# For each stencil parameter all locations it is `deref`ed on
# see :func:`gt4py.next.iterator.transforms.trace_stencil`.
shift_sequences_per_param: list[set[tuple[itir.OffsetLiteral, ...]]] | None,
) -> ts.FieldType | ts.DeferredType:
) -> ts.FieldType | ts.TupleType | ts.DeferredType:
if any(
isinstance(el, ts.DeferredType)
for f in fields
Expand Down Expand Up @@ -578,13 +574,11 @@ def applied_as_fieldop(

assert isinstance(stencil_return, ts.DataType)

return type_info.apply_to_primitive_constituents(
lambda el_type: ts.FieldType(
dims=domain.dims,
dtype=el_type,
),
stencil_return,
)
result = type_info.tree_map_type(
lambda el_type: ts.FieldType(dims=domain.dims, dtype=el_type)
)(stencil_return)
assert isinstance(result, (ts.FieldType, ts.TupleType))
return result

return applied_as_fieldop

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def _process_elements(
obj: Expr,
type_: ts.TypeSpec,
*,
tuple_constructor: Callable[..., Expr] = lambda _, *elements: FunCall(
tuple_constructor: Callable[..., Expr] = lambda _, elements: FunCall(
fun=SymRef(id="make_tuple"), args=list(elements)
),
) -> Expr:
Expand Down Expand Up @@ -264,13 +264,11 @@ def _gen_constituent_expr(el_type: ts.ScalarType | ts.FieldType, path: tuple[int
)
return process_func(el, el_type)

result = type_info.apply_to_primitive_constituents(
return type_info.tree_map_type(
_gen_constituent_expr,
type_,
result_collection_constructor=tuple_constructor,
with_path_arg=True,
tuple_constructor=tuple_constructor,
)
return result
)(type_)


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -523,7 +521,7 @@ def check_el_type(el_expr: Expr, el_type: ts.ScalarType | ts.FieldType) -> Expr:
check_el_type,
lowered_output,
node.type,
tuple_constructor=lambda *elements: SidComposite(values=list(elements)),
tuple_constructor=lambda _, elements: SidComposite(values=list(elements)),
)

assert isinstance(lowered_output_as_sid, (SidComposite, SymRef))
Expand Down Expand Up @@ -615,7 +613,7 @@ def convert_el_to_sid(el_expr: Expr, el_type: ts.ScalarType | ts.FieldType) -> E
convert_el_to_sid,
lowered_input,
input_.type,
tuple_constructor=lambda *elements: SidComposite(values=list(elements)),
tuple_constructor=lambda _, elements: SidComposite(values=list(elements)),
)

lowered_inputs.append(lowered_input_as_sid)
Expand Down
103 changes: 41 additions & 62 deletions src/gt4py/next/type_system/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,12 @@
import functools
import types
from collections.abc import Callable, Iterator
from typing import (
Any,
Generic,
Literal,
Protocol,
Sequence,
Type,
TypeGuard,
TypeVar,
cast,
overload,
)
from typing import Any, Literal, Sequence, Type, TypeGuard, TypeVar, cast, overload

import numpy as np

from gt4py.eve import extended_typing as xtyping, utils
from gt4py.next import common
from gt4py.next import common, utils as next_utils
from gt4py.next.iterator.type_system import type_specifications as it_ts
from gt4py.next.type_system import type_specifications as ts

Expand Down Expand Up @@ -143,63 +132,53 @@ def constituents_yielder(
return utils.xiter(constituents_yielder(symbol_type, ())) # type: ignore[return-value] # why resolved to XIterable[object]?


_R = TypeVar("_R", covariant=True)
_T = TypeVar("_T")
_C = TypeVar("_C")


class TupleConstructorType(Protocol, Generic[_R]):
def __call__(self, *args: Any) -> _R: ...
def _tree_map_type_constructor(
value: ts.CollectionTypeSpecT,
elems: xtyping.NestedTuple[ts.DataType],
) -> ts.CollectionTypeSpecT:
return (
ts.NamedCollectionType(
keys=value.keys, original_python_type=value.original_python_type, types=list(elems)
)
if isinstance(value, ts.NamedCollectionType)
else ts.TupleType(types=list(elems)) # type: ignore[return-value]
)


@overload
def tree_map_type(
fun: Callable[..., _T], *, with_path_arg: bool = ..., unpack: bool = ...
) -> Callable[..., _T | ts.CollectionTypeSpec]: ...


def apply_to_primitive_constituents(
@overload
def tree_map_type(
fun: Callable[..., _T],
*symbol_types: ts.TypeSpec,
with_path_arg: bool = False,
tuple_constructor: TupleConstructorType[_R] = lambda *elements: ts.TupleType(types=[*elements]), # type: ignore[assignment] # probably related to https://github.com/python/mypy/issues/10854
_path: tuple[int, ...] = (),
) -> _T | _R:
"""
Apply function to all primitive constituents of a type.
*,
result_collection_constructor: Callable[..., _C],
with_path_arg: bool = ...,
unpack: bool = ...,
) -> Callable[..., _T | _C]: ...

>>> int_type = ts.ScalarType(kind=ts.ScalarKind.INT64)
>>> tuple_type = ts.TupleType(types=[int_type, int_type])
>>> print(
... apply_to_primitive_constituents(
... lambda primitive_type: ts.FieldType(dims=[], dtype=primitive_type),
... tuple_type,
... )
... )
tuple[Field[[], int64], Field[[], int64]]

>>> apply_to_primitive_constituents(
... lambda primitive_type, path: (path, primitive_type),
... tuple_type,
... with_path_arg=True,
... tuple_constructor=lambda *elements: dict(elements),
... )
{(0,): ScalarType(kind=<ScalarKind.INT64: 8>, shape=None), (1,): ScalarType(kind=<ScalarKind.INT64: 8>, shape=None)}
"""
if isinstance(symbol_types[0], ts.TupleType):
assert all(isinstance(symbol_type, ts.TupleType) for symbol_type in symbol_types)

return tuple_constructor(
*[
apply_to_primitive_constituents(
fun,
*el_types,
_path=(*_path, i),
with_path_arg=with_path_arg,
tuple_constructor=tuple_constructor,
)
for i, el_types in enumerate(
zip(*(symbol_type.types for symbol_type in symbol_types)) # type: ignore[attr-defined] # ensured by assert above
)
]
)
if with_path_arg:
return fun(*symbol_types, path=_path)
else:
return fun(*symbol_types)
def tree_map_type(
fun: Callable[..., _T],
*,
result_collection_constructor: Callable[..., Any] = _tree_map_type_constructor,
with_path_arg: bool = False,
unpack: bool = False,
) -> Callable[..., Any]:
return next_utils.tree_map(
fun,
collection_type=ts.COLLECTION_TYPE_SPECS,
result_collection_constructor=result_collection_constructor,
with_path_arg=with_path_arg,
unpack=unpack,
)


def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType | ts.ListType:
Expand Down