Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ dependencies = [
'ninja>=1.11',
'numpy>=2.0.0 ; python_version < "3.14"',
'numpy>=2.3.2 ; python_version >= "3.14"', # first version with pre-built wheel for Python 3.14
'ordered-set>=4.1.0',
Comment thread
edopao marked this conversation as resolved.
Comment thread
edopao marked this conversation as resolved.
'packaging>=20.0',
Comment thread
edopao marked this conversation as resolved.
Comment thread
edopao marked this conversation as resolved.
'pybind11>=3.0.3',
'setuptools>=77.0.3',
Expand Down Expand Up @@ -478,6 +479,7 @@ url = 'https://gridtools.github.io/pypi/'
# dace = {index = "gridtools"}
[tool.uv.sources]
atlas4py = {index = "test.pypi"}
dace = {git = "https://github.com/GridTools/dace", branch = "dev-sorted_sets"}
Comment thread
edopao marked this conversation as resolved.
Outdated

# -- versioningit --
[tool.versioningit]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1000,8 +1000,8 @@ def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExp
nsdfg_symbols_mapping["__cond"] = condition_value.value
nsdfg_node = self.state.add_nested_sdfg(
nsdfg,
inputs=sorted(used_connectivities | input_memlets.keys()),
outputs=sorted(outputs),
inputs={key: None for key in sorted(used_connectivities | input_memlets.keys())},
Comment thread
edopao marked this conversation as resolved.
outputs={key: None for key in sorted(outputs)},
symbol_mapping=nsdfg_symbols_mapping,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -766,8 +766,8 @@ def add_nested_sdfg(

nsdfg_node = outer_ctx.state.add_nested_sdfg(
inner_ctx.sdfg,
inputs=input_memlets.keys(),
outputs=lambda_outputs,
inputs={key: None for key in sorted(input_memlets.keys())},
outputs={key: None for key in sorted(lambda_outputs)},
symbol_mapping=nsdfg_symbols_mapping,
debuginfo=gtir_to_sdfg_utils.debug_info(node, default=outer_ctx.sdfg.debuginfo),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from dace.transformation.auto import auto_optimize as dace_aoptimize
from dace.transformation.passes import analysis as dace_analysis

from gt4py.next import common as gtx_common
from gt4py.next import common as gtx_common, utils as gtx_utils
from gt4py.next.program_processors.runners.dace import (
library_nodes as gtx_library_nodes,
transformations as gtx_transformations,
Expand Down Expand Up @@ -234,6 +234,7 @@ def gt_auto_optimize(
Something along the line "Fuse if operational intensity goes up, but
not if we have too much internal space (register pressure).
"""
uids = gtx_utils.IDGeneratorPool()
device = dace.DeviceType.GPU if gpu else dace.DeviceType.CPU
optimization_hooks = optimization_hooks or {}

Expand Down Expand Up @@ -333,6 +334,7 @@ def gt_auto_optimize(
scan_loop_unrolling_factor=scan_loop_unrolling_factor,
fuse_tasklets=fuse_tasklets,
validate_all=validate_all,
uids=uids,
)

# Configure the Maps:
Expand Down Expand Up @@ -690,6 +692,7 @@ def _gt_auto_process_dataflow_inside_maps(
scan_loop_unrolling_factor: int,
fuse_tasklets: bool,
validate_all: bool,
uids: gtx_utils.IDGeneratorPool,
) -> dace.SDFG:
"""Optimizes the dataflow inside the top level Maps of the SDFG inplace.

Expand Down Expand Up @@ -735,7 +738,7 @@ def _gt_auto_process_dataflow_inside_maps(
# Constants (tasklets are needed to write them into a variable) should not be
# arguments to a kernel but be present inside the body.
sdfg.apply_transformations_once_everywhere(
gtx_transformations.GT4PyMoveTaskletIntoMap,
gtx_transformations.GT4PyMoveTaskletIntoMap(uids=uids),
validate=False,
validate_all=validate_all,
)
Expand All @@ -762,7 +765,7 @@ def _gt_auto_process_dataflow_inside_maps(
# Make sure that this runs before MoveDataflowIntoIfBody because atm it doesn't handle
# NestedSDFGs inside the ConditionalBlocks it fuses.
sdfg.apply_transformations_repeated(
gtx_transformations.FuseHorizontalConditionBlocks(),
gtx_transformations.FuseHorizontalConditionBlocks(uids=uids),
validate=False,
validate_all=validate_all,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dace.sdfg import graph as dace_graph, nodes as dace_nodes
from dace.transformation import helpers as dace_helpers

from gt4py.next import utils as gtx_utils
from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations


Expand Down Expand Up @@ -50,6 +51,11 @@ class FuseHorizontalConditionBlocks(dace_transformation.SingleStateTransformatio
conditional_access_node = dace_transformation.PatternNode(dace_nodes.AccessNode)
nsdfg_a = dace_transformation.PatternNode(dace_nodes.NestedSDFG)
nsdfg_b = dace_transformation.PatternNode(dace_nodes.NestedSDFG)
uids = dace_properties.Property(dtype=gtx_utils.IDGeneratorPool)

def __init__(self, *args: Any, uids: gtx_utils.IDGeneratorPool, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.uids = uids

# The fusion of the two conditional blocks can happen in any order. To avoid any indeterminism distinguish which one is the fused and which one is the extended conditional block which will include the fused one.
@staticmethod
Expand Down Expand Up @@ -293,7 +299,7 @@ def apply(
for data_name, data_desc in fused_conditional_block.sdfg.arrays.items():
if data_name == "__cond":
continue
new_data_name = gtx_transformations.utils.unique_name(data_name) + "_from_cb_fusion"
new_data_name = next(self.uids[f"{data_name}_cb_fusion"])
data_desc_renamed = copy.deepcopy(data_desc)
second_arrays_rename_map[data_name] = (
nested_sdfg_of_extended_conditional_block.sdfg.add_datadesc(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import copy
import warnings
from typing import Final, Iterable, Optional, Sequence, TypeAlias
Expand All @@ -14,6 +16,7 @@
import sympy
from dace import data as dace_data, subsets as dace_sbs, symbolic as dace_sym
from dace.sdfg import graph as dace_graph, nodes as dace_nodes
from ordered_set import OrderedSet

Comment thread
edopao marked this conversation as resolved.
from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations

Expand Down Expand Up @@ -73,7 +76,7 @@ def perform_dataflow_inlining(
sdfg: dace.SDFG,
state: dace.SDFGState,
edge: dace_graph.MultiConnectorEdge[dace.Memlet],
nodes_to_inline: set[dace_nodes.Node],
nodes_to_inline: OrderedSet[dace_nodes.Node],
inline_spec: InlineSpec,
) -> Optional[tuple[dace_nodes.NestedSDFG, dace_nodes.AccessNode]]:
"""Performs the second step, i.e. the actual inlining, of the dataflow.
Expand Down Expand Up @@ -134,7 +137,7 @@ def find_nodes_to_inline(
sdfg: dace.SDFG,
state: dace.SDFGState,
edge: dace_graph.MultiConnectorEdge[dace.Memlet],
) -> Optional[tuple[set[dace_nodes.Node], InlineSpec]]:
) -> Optional[tuple[OrderedSet[dace_nodes.Node], InlineSpec]]:
"""First step of dataflow inlining, computing the inline specification.

The inline specification describes how the inlining of dataflow has to be done.
Expand Down Expand Up @@ -477,7 +480,7 @@ def _insert_nested_sdfg(
def _populate_nested_sdfg(
sdfg: dace.SDFG,
state: dace.SDFGState,
nodes_to_replicate: set[dace_nodes.Node],
nodes_to_replicate: OrderedSet[dace_nodes.Node],
first_map_exit: dace_nodes.MapExit,
exchange_subset: dace_sbs.Range,
intermediate_node: dace_nodes.AccessNode,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def _new_name(old_name: str) -> str:
elif isinstance(node, dace_nodes.NestedSDFG):
node_ = graph.add_nested_sdfg(
sdfg=copy.deepcopy(node.sdfg),
inputs=set(node.in_connectors.keys()),
outputs=set(node.out_connectors.keys()),
inputs={k: None for k in node.in_connectors.keys()},
outputs={k: None for k in node.out_connectors.keys()},
Comment thread
edopao marked this conversation as resolved.
Outdated
symbol_mapping=node.symbol_mapping.copy(),
Comment thread
edopao marked this conversation as resolved.
debuginfo=copy.copy(node.debuginfo),
)
Expand Down Expand Up @@ -224,7 +224,7 @@ def split_overlapping_map_range(

first_map_splitted_dict = {}
second_map_splitted_dict = {}
for param in first_map_params:
for param in sorted(first_map_params):
first_map_range = first_map_dict[param]
second_map_range = second_map_dict[param]
if (step := first_map_range[2]) != second_map_range[2]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import collections
import copy
import functools
Expand All @@ -24,6 +26,7 @@
type_inference as dace_type_inference,
utils as dace_sutils,
)
from ordered_set import OrderedSet

Comment thread
edopao marked this conversation as resolved.
from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations

Expand Down Expand Up @@ -717,8 +720,8 @@ def _filter_relocatable_dataflow(
sdfg: dace.SDFG,
state: dace.SDFGState,
if_block: dace_nodes.NestedSDFG,
raw_relocatable_dataflow: dict[str, set[dace_nodes.Node]],
non_relocatable_dataflow: dict[str, set[dace_nodes.Node]],
raw_relocatable_dataflow: dict[str, OrderedSet[dace_nodes.Node]],
non_relocatable_dataflow: dict[str, OrderedSet[dace_nodes.Node]],
connector_usage_location: dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]],
enclosing_map: dace_nodes.MapEntry,
) -> set[dace_nodes.Node]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,10 @@ def modifies(self) -> dace_ppl.Modifies:
def should_reapply(self, modified: dace_ppl.Modifies) -> bool:
return modified & (dace_ppl.Modifies.Memlets | dace_ppl.Modifies.AccessNodes)

def depends_on(self) -> set[type[dace_transformation.Pass]]:
return {
def depends_on(self) -> list[type[dace_transformation.Pass]]:
return [
dace_transformation.passes.FindAccessStates,
}
]

def apply_pass(
self, sdfg: dace.SDFG, pipeline_results: dict[str, Any]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import collections
import copy
import uuid
import warnings
from typing import Any, Iterable, Optional, TypeAlias

Expand All @@ -29,6 +28,7 @@
passes as dace_passes,
)

from gt4py.next import utils as gtx_utils
from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations


Expand Down Expand Up @@ -551,11 +551,11 @@ def modifies(self) -> dace_ppl.Modifies:
def should_reapply(self, modified: dace_ppl.Modifies) -> bool:
return modified & (dace_ppl.Modifies.Memlets | dace_ppl.Modifies.AccessNodes)

def depends_on(self) -> set[type[dace_transformation.Pass]]:
return {
def depends_on(self) -> list[type[dace_transformation.Pass]]:
return [
dace_transformation.passes.StateReachability,
dace_transformation.passes.FindAccessStates,
}
]

def apply_pass(
self, sdfg: dace.SDFG, pipeline_results: dict[str, Any]
Expand Down Expand Up @@ -933,13 +933,16 @@ class GT4PyMoveTaskletIntoMap(dace_transformation.SingleStateTransformation):
tasklet = dace_transformation.PatternNode(dace_nodes.Tasklet)
access_node = dace_transformation.PatternNode(dace_nodes.AccessNode)
map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry)
uids = dace_properties.Property(dtype=gtx_utils.IDGeneratorPool)

def __init__(
self,
*args: Any,
uids: gtx_utils.IDGeneratorPool,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.uids = uids

@classmethod
def expressions(cls) -> Any:
Expand Down Expand Up @@ -1008,7 +1011,7 @@ def apply(
# This is the tasklet that we will put inside the map, note we have to do it
# this way to avoid some name clash stuff.
inner_tasklet: dace_nodes.Tasklet = graph.add_tasklet(
name=f"{tasklet.label}__clone_{str(uuid.uuid1()).replace('-', '_')}",
name=next(self.uids[f"{tasklet.label}__clone"]),
outputs=tasklet.out_connectors.keys(),
inputs=set(),
Comment thread
edopao marked this conversation as resolved.
code=tasklet.code,
Expand Down Expand Up @@ -1157,8 +1160,8 @@ def __init__(
def expressions(cls) -> Any:
return [dace.sdfg.utils.node_path_graph(cls.map_exit, cls.tmp_ac, cls.glob_ac)]

def depends_on(self) -> set[type[dace_transformation.Pass]]:
return {dace_transformation.passes.ConsolidateEdges}
def depends_on(self) -> list[type[dace_transformation.Pass]]:
return [dace_transformation.passes.ConsolidateEdges]

def can_be_applied(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import warnings
from typing import Any, Iterable, Optional

Expand All @@ -17,6 +19,7 @@
)
from dace.sdfg import graph as dace_graph, nodes as dace_nodes
from dace.transformation.passes import analysis as dace_analysis
from ordered_set import OrderedSet

Comment thread
edopao marked this conversation as resolved.
from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations
from gt4py.next.program_processors.runners.dace.transformations import (
Expand Down Expand Up @@ -148,7 +151,7 @@ def _apply_split_access_node_non_recursive(
class SplitAccessNode(dace_transformation.SingleStateTransformation):
"""The transformation will split an AccessNode into multiple ones.

If there is no interesection between a write and different reads,
If there is no intersection between a write and different reads,
i.e. if every read to the AccessNode can be satisfied by a single
write to the AccessNode and the AccessNode is only used at one
location, then the node is split.
Expand Down Expand Up @@ -307,7 +310,7 @@ def apply(
def _find_edge_reassignment(
self,
state: dace.SDFGState,
) -> dict[dace_graph.MultiConnectorEdge, set[dace_graph.MultiConnectorEdge]] | None:
) -> dict[dace_graph.MultiConnectorEdge, OrderedSet[dace_graph.MultiConnectorEdge]] | None:
"""Determine how the edges should be distributed to the fragments.

The current implementation defines the fragments, i.e. the pieces into
Expand Down Expand Up @@ -335,14 +338,14 @@ def _find_edge_reassignment(
# generate the data for a consumer). This is hard to handle, but should
# be implemented at some point.
edge_reassignments: dict[
dace_graph.MultiConnectorEdge, set[dace_graph.MultiConnectorEdge]
dace_graph.MultiConnectorEdge, OrderedSet[dace_graph.MultiConnectorEdge]
] = {}
for iedge in state.in_edges(access_node):
if iedge.data.dst_subset is None:
return None # TODO(phimuell): Lift this.
if iedge.data.wcr is not None:
return None
edge_reassignments[iedge] = set()
edge_reassignments[iedge] = OrderedSet()

# Now match the outgoing edges to their incoming producers.
for oedge in state.out_edges(access_node):
Expand Down Expand Up @@ -426,7 +429,9 @@ def _check_split_constraints(
self,
state: dace.SDFGState,
sdfg: dace.SDFG,
edge_reassignments: dict[dace_graph.MultiConnectorEdge, set[dace_graph.MultiConnectorEdge]],
edge_reassignments: dict[
dace_graph.MultiConnectorEdge, OrderedSet[dace_graph.MultiConnectorEdge]
],
) -> bool:
"""Checks if the decomposition results in a valid SDFG.

Expand Down
Loading