From cd90f5b11f66d0c53cb6d8be1887d8638044d10d Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 14 May 2026 12:00:57 -0400 Subject: [PATCH 01/43] Fix `ReplaceAxisSymbol` and keep it to Tasklets -> `ReplaceAxisSymbolInTasklet` Move re-usable functions into `tree_common_op` --- .../dace/stree/optimizations/axis_merge.py | 59 ++++++------------- .../stree/optimizations/tree_common_op.py | 14 ++++- 2 files changed, 32 insertions(+), 41 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 0f97468c..94920954 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +import itertools import dace from dace.properties import CodeBlock @@ -14,6 +15,8 @@ detect_cycle, list_index, swap_node_position_in_tree, + is_axis_map, + is_axis_for ) from ndsl.logging import ndsl_log @@ -22,16 +25,6 @@ PUSH_IFSCOPE_DOWNWARD = False # Crashing the overall stree - bad algorithmics -def _is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool: - """Returns true if node is a map over the given axis.""" - map_parameter = node.node.map.params - return len(map_parameter) == 1 and map_parameter[0].startswith(axis.as_str()) - - -def _is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool: - return node.loop.loop_variable.startswith(axis.as_str()) - - def _both_same_single_axis_maps( first: tn.MapScope, second: tn.MapScope, axis: AxisIterator ) -> bool: @@ -39,8 +32,8 @@ def _both_same_single_axis_maps( ( len(first.node.map.params) == 1 and len(second.node.map.params) == 1 ) # Single axis - and _is_axis_map(first, axis) # Correct axis in first map - and _is_axis_map(second, axis) # Correct axis in second map + and is_axis_map(first, axis) # Correct axis in first map + and is_axis_map(second, axis) # Correct axis in second map ) @@ -109,26 +102,10 @@ def _last_node(nodes: 
list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode) -> b return list_index(nodes, node) >= len(nodes) - 1 -class ReplaceAxisSymbol(tn.ScheduleNodeVisitor): +class ReplaceAxisSymbolInTasklet(tn.ScheduleNodeVisitor): def __init__(self, axis: AxisIterator) -> None: self._axis = axis - def visit_MapScope( - self, - map_scope: tn.MapScope, - axis_replacements: dict[str, str] | None = None, - ) -> None: - if axis_replacements is None: - axis_replacements = {} - - for index, param in enumerate(map_scope.node.params): - if param in axis_replacements: - map_scope.node.params[index] = axis_replacements[param] - - # visit children - for child in map_scope.children: - self.visit(child, axis_replacements=axis_replacements) - def visit_TaskletNode( self, node: tn.TaskletNode, @@ -138,11 +115,13 @@ def visit_TaskletNode( # Noop if there are no replacements to do. return - for memlets in node.in_memlets.values(): - memlets.replace(axis_replacements) - for memlets in node.out_memlets.values(): - memlets.replace(axis_replacements) - + # Dev NOTE: We directly replace the memlet.subset because the `memlet.replace` + # function sometimes doesn't work + for memlet in itertools.chain(node.in_memlets.values(), node.out_memlets.values()): + if memlet.subset is not None: + memlet.subset.replace(axis_replacements) + if memlet.other_subset is not None: + memlet.other_subset.replace(axis_replacements) class CartesianAxisMerge(tn.ScheduleNodeTransformer): """Merge a cartesian axis if they are contiguous in code-flow. 
@@ -197,7 +176,7 @@ def _merge_node( def _for_merge(self, the_for_scope: tn.ForScope) -> int: merged = 0 - if _is_axis_for(the_for_scope, self.axis): + if is_axis_for(the_for_scope, self.axis): # TODO: if the for scope is on a cartesian axis it can be # merged with other for scope going in the same direction pass @@ -206,7 +185,7 @@ def _for_merge(self, the_for_scope: tn.ForScope) -> int: if ( len(the_for_scope.children) == 1 and isinstance(the_for_scope.children[0], tn.MapScope) - and _is_axis_map(the_for_scope.children[0], self.axis) + and is_axis_map(the_for_scope.children[0], self.axis) ): swap_node_position_in_tree(the_for_scope, the_for_scope.children[0]) merged += 1 @@ -327,7 +306,7 @@ def _map_overcompute_merge( # End of nodes OR # Not the right axis # --> recurse - if _last_node(nodes, the_map) or not _is_axis_map(the_map, self.axis): + if _last_node(nodes, the_map) or not is_axis_map(the_map, self.axis): merged = 0 for child in the_map.children: merged += self._merge_node(child, the_map.children) @@ -384,9 +363,9 @@ def _map_overcompute_merge( # K-maps use unique iterators (i.e. every k-map iterates over `k__[0-9]*`). # After merge, we need to replace the axis symbols of the second map's children # with the axis symbol of the first map. 
- if next_node.node.map.params[0] != the_map.node.map.params[0]: - replacements = {next_node.node.map.params[0]: the_map.node.map.params[0]} - ReplaceAxisSymbol(self.axis).visit( + if second_map.node.map.params[0] != first_map.node.map.params[0]: + replacements = {second_map.node.map.params[0]: first_map.node.map.params[0]} + ReplaceAxisSymbolInTasklet(self.axis).visit( first_map, axis_replacements=replacements ) diff --git a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py b/ndsl/dsl/dace/stree/optimizations/tree_common_op.py index 1253ba81..e243ede3 100644 --- a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py +++ b/ndsl/dsl/dace/stree/optimizations/tree_common_op.py @@ -1,7 +1,7 @@ from typing import Collection import dace.sdfg.analysis.schedule_tree.treenodes as tn - +from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator def swap_node_position_in_tree( top_node: tn.ScheduleTreeScope, child_node: tn.ScheduleTreeScope @@ -51,3 +51,15 @@ def list_index( """Check if node is in list with "is" operator.""" # compare with "is" to get memory comparison. 
".index()" uses value comparison return next(index for index, element in enumerate(collection) if element is node) + + +def is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool: + """Returns true if node is a Map over the given axis.""" + map_parameter = node.node.map.params + return len(map_parameter) == 1 and map_parameter[0].startswith(axis.as_str()) + + +def is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool: + """Returns true if node is a For over the given axis.""" + return node.loop.loop_variable.startswith(axis.as_str()) + From 2c8f74e6d17ff634f1f7ffc2329a3c66d866749f Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 14 May 2026 12:02:56 -0400 Subject: [PATCH 02/43] Add `TreeOptimizationStatistics` to capture the results of the opt at a glance --- .../dace/stree/optimizations/axis_merge.py | 9 +- .../dace/stree/optimizations/statistics.py | 94 +++++++++++++++++++ .../stree/optimizations/tree_common_op.py | 3 +- ndsl/dsl/dace/stree/pipeline.py | 8 ++ 4 files changed, 110 insertions(+), 4 deletions(-) create mode 100644 ndsl/dsl/dace/stree/optimizations/statistics.py diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 94920954..abacd0b4 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -13,10 +13,10 @@ ) from ndsl.dsl.dace.stree.optimizations.tree_common_op import ( detect_cycle, + is_axis_for, + is_axis_map, list_index, swap_node_position_in_tree, - is_axis_map, - is_axis_for ) from ndsl.logging import ndsl_log @@ -117,12 +117,15 @@ def visit_TaskletNode( # Dev NOTE: We directly replace the memlet.subset because the `memlet.replace` # function sometimes doesn't work - for memlet in itertools.chain(node.in_memlets.values(), node.out_memlets.values()): + for memlet in itertools.chain( + node.in_memlets.values(), node.out_memlets.values() + ): if memlet.subset is not None: 
memlet.subset.replace(axis_replacements) if memlet.other_subset is not None: memlet.other_subset.replace(axis_replacements) + class CartesianAxisMerge(tn.ScheduleNodeTransformer): """Merge a cartesian axis if they are contiguous in code-flow. diff --git a/ndsl/dsl/dace/stree/optimizations/statistics.py b/ndsl/dsl/dace/stree/optimizations/statistics.py new file mode 100644 index 00000000..1a3ce593 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/statistics.py @@ -0,0 +1,94 @@ +import dataclasses + +import dace +import dace.sdfg.analysis.schedule_tree.treenodes as stree + +from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator +from ndsl.dsl.dace.stree.optimizations.tree_common_op import is_axis_for, is_axis_map + + +class CountCartesianLoops(stree.ScheduleNodeVisitor): + def __init__(self) -> None: + super().__init__() + self._maps = [0, 0, 0] + self._fors = [0, 0, 0] + + def visit_MapScope(self, node: stree.MapScope) -> None: + for axis in AxisIterator: + if is_axis_map(node, axis): + self._maps[axis.as_cartesian_index()] += 1 + + self.visit(node.children) + + def visit_ForScope(self, node: stree.ForScope) -> None: + for axis in AxisIterator: + if is_axis_for(node, axis): + self._fors[axis.as_cartesian_index()] += 1 + + self.visit(node.children) + + +class CountTransient(stree.ScheduleNodeVisitor): + def __init__(self) -> None: + super().__init__() + self._counts = [0, 0, 0, 0] + + def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + for data in node.containers.values(): + non_atomic_dims_count = sum(1 for x in data.shape if x != 1) + if isinstance(data, dace.data.Array) and data.transient: + if non_atomic_dims_count == 1: + self._counts[0] += 1 + elif non_atomic_dims_count == 2: + self._counts[1] += 1 + elif non_atomic_dims_count == 3: + self._counts[2] += 1 + else: + self._counts[3] += 1 + + +class TreeOptimizationStatistics: + """Capture basic statistics on the schedule tree optimization actions""" + + 
@dataclasses.dataclass + class Record: + """Private record of a state of a tree""" + + cartesian_maps: list[int] = dataclasses.field(default_factory=lambda: [0, 0, 0]) + cartesian_fors: list[int] = dataclasses.field(default_factory=lambda: [0, 0, 0]) + transients: list[int] = dataclasses.field(default_factory=lambda: [0, 0, 0, 0]) + + def __init__(self) -> None: + self._original_record = TreeOptimizationStatistics.Record() + self._optimized_record = TreeOptimizationStatistics.Record() + + def _record( + self, + record: Record, + tree_root: stree.ScheduleTreeRoot, + ) -> None: + """Record the state of a tree""" + c = CountCartesianLoops() + c.visit(tree_root) + record.cartesian_fors = c._fors + record.cartesian_maps = c._maps + + c = CountTransient() + c.visit(tree_root) + record.transients = c._counts + + def original(self, tree_root: stree.ScheduleTreeRoot) -> None: + """Record the original state of the tree, before optimization""" + self._record(self._original_record, tree_root) + + def optimized(self, tree_root: stree.ScheduleTreeRoot) -> None: + """Record the state of the tree after optimization""" + self._record(self._optimized_record, tree_root) + + def report(self) -> str: + """Craft a concise string reporting on the statistics""" + msg = "Tree optimization:\n" + msg += f" Cartesian maps [I, J, K]: {self._original_record.cartesian_maps} -> {self._optimized_record.cartesian_maps}\n" + msg += f" Cartesian fors [I, J, K]: {self._original_record.cartesian_fors} -> {self._optimized_record.cartesian_fors}\n" + msg += f" Transients [1D, 2D, 3D, 4D+]: {self._original_record.transients} -> {self._optimized_record.transients}\n" + return msg diff --git a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py b/ndsl/dsl/dace/stree/optimizations/tree_common_op.py index e243ede3..4748a901 100644 --- a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py +++ b/ndsl/dsl/dace/stree/optimizations/tree_common_op.py @@ -1,8 +1,10 @@ from typing import Collection import 
dace.sdfg.analysis.schedule_tree.treenodes as tn + from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator + def swap_node_position_in_tree( top_node: tn.ScheduleTreeScope, child_node: tn.ScheduleTreeScope ) -> None: @@ -62,4 +64,3 @@ def is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool: def is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool: """Returns true if node is a For over the given axis.""" return node.loop.loop_variable.startswith(axis.as_str()) - diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index e4307ddf..44e933bc 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -3,6 +3,7 @@ import dace.sdfg.analysis.schedule_tree.treenodes as stree from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge +from ndsl.dsl.dace.stree.optimizations.statistics import TreeOptimizationStatistics from ndsl.logging import ndsl_log_on_rank_0 @@ -30,6 +31,9 @@ def run( stree: stree.ScheduleTreeRoot, verbose: bool = False, ) -> stree.ScheduleTreeRoot: + tree_stats = TreeOptimizationStatistics() + tree_stats.original(stree) + for i, p in enumerate(self.passes): if verbose: path = self.cache_directory / f"pass{i}_{p}.txt" @@ -41,6 +45,10 @@ def run( with open(path, "w+") as f: f.write(stree.as_string()) + tree_stats.optimized(stree) + if verbose: + ndsl_log_on_rank_0.info(tree_stats.report()) + return stree From 8b49e3b550ceed51c4e7063932803df1e0e91d96 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 14 May 2026 15:28:09 -0400 Subject: [PATCH 03/43] Add a master `CartesianMerge` bringing everything axis merge, refactor around --- ndsl/dsl/dace/orchestration.py | 50 +++----------- ndsl/dsl/dace/stree/optimizations/__init__.py | 3 +- .../dace/stree/optimizations/axis_merge.py | 65 +++++++++++-------- .../replace_symbol_in_tasklet.py | 29 +++++++++ 4 files changed, 79 insertions(+), 68 deletions(-) create mode 100644 
ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index a81b3744..933a379b 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -24,7 +24,6 @@ import ndsl.dsl.dace.replacements # noqa # We load in the DaCe replacements from ndsl.comm.mpi import MPI -from ndsl.config import BackendLoopOrder from ndsl.dsl.dace.build import get_sdfg_path, write_build_info from ndsl.dsl.dace.dace_config import ( DEACTIVATE_DISTRIBUTED_DACE_COMPILE, @@ -40,8 +39,7 @@ ) from ndsl.dsl.dace.stree import CPUPipeline from ndsl.dsl.dace.stree.optimizations import ( - AxisIterator, - CartesianAxisMerge, + CartesianMerge, CartesianRefineTransients, CleanUpScheduleTree, ) @@ -198,44 +196,14 @@ def _build_sdfg( f.write(stree.as_string()) with DaCeProgress(config, "Schedule Tree: optimization"): - passes = [] - if backend_name.loop_order == BackendLoopOrder.IJK: - passes.extend( - [ - CleanUpScheduleTree(), - CartesianAxisMerge(AxisIterator._I), - CartesianAxisMerge(AxisIterator._J), - CartesianAxisMerge(AxisIterator._K), - CartesianRefineTransients(backend_name), - ] - ) - elif backend_name.loop_order == BackendLoopOrder.KJI: - passes.extend( - [ - CleanUpScheduleTree(), - CartesianAxisMerge(AxisIterator._K), - CartesianAxisMerge(AxisIterator._J), - CartesianAxisMerge(AxisIterator._I), - CartesianRefineTransients(backend_name), - ] - ) - elif backend_name.loop_order == BackendLoopOrder.KIJ: - passes.extend( - [ - CleanUpScheduleTree(), - CartesianAxisMerge(AxisIterator._K), - CartesianAxisMerge(AxisIterator._I), - CartesianAxisMerge(AxisIterator._J), - CartesianRefineTransients(backend_name), - ] - ) - else: - raise NotImplementedError( - f"Loop order {backend_name.loop_order} has no schedule tree pipeline" - ) - CPUPipeline(passes=passes, cache_directory=Path(sdfg.build_folder)).run( - stree, verbose=config.verbose_schedule_tree_optimizations - ) + CPUPipeline( + passes=[ + 
CleanUpScheduleTree(), + CartesianMerge(backend_name), + CartesianRefineTransients(backend_name), + ], + cache_directory=Path(sdfg.build_folder), + ).run(stree, verbose=config.verbose_schedule_tree_optimizations) if config.verbose_orchestration: with open( os.path.abspath(f"{sdfg.build_folder}/03-post_opt.stree.txt"), diff --git a/ndsl/dsl/dace/stree/optimizations/__init__.py b/ndsl/dsl/dace/stree/optimizations/__init__.py index 73497f93..b08d6839 100644 --- a/ndsl/dsl/dace/stree/optimizations/__init__.py +++ b/ndsl/dsl/dace/stree/optimizations/__init__.py @@ -1,4 +1,4 @@ -from .axis_merge import AxisIterator, CartesianAxisMerge +from .axis_merge import AxisIterator, CartesianAxisMerge, CartesianMerge from .clean_tree import CleanUpScheduleTree from .refine_transients import CartesianRefineTransients @@ -6,6 +6,7 @@ __all__ = [ "AxisIterator", "CartesianAxisMerge", + "CartesianMerge", "CartesianRefineTransients", "CleanUpScheduleTree", ] diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index abacd0b4..2cbf8995 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -1,16 +1,19 @@ from __future__ import annotations import copy -import itertools import dace from dace.properties import CodeBlock from dace.sdfg.analysis.schedule_tree import treenodes as tn +from ndsl.config import Backend, BackendLoopOrder from ndsl.dsl.dace.stree.optimizations.memlet_helpers import ( AxisIterator, no_data_dependencies_on_cartesian_axis, ) +from ndsl.dsl.dace.stree.optimizations.replace_symbol_in_tasklet import ( + ReplaceAxisSymbolInTasklet, +) from ndsl.dsl.dace.stree.optimizations.tree_common_op import ( detect_cycle, is_axis_for, @@ -102,30 +105,6 @@ def _last_node(nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode) -> b return list_index(nodes, node) >= len(nodes) - 1 -class ReplaceAxisSymbolInTasklet(tn.ScheduleNodeVisitor): - def __init__(self, 
axis: AxisIterator) -> None: - self._axis = axis - - def visit_TaskletNode( - self, - node: tn.TaskletNode, - axis_replacements: dict[str, str] | None = None, - ) -> None: - if not axis_replacements: - # Noop if there are no replacements to do. - return - - # Dev NOTE: We directly replace the memlet.subset because the `memlet.replace` - # function sometimes doesn't work - for memlet in itertools.chain( - node.in_memlets.values(), node.out_memlets.values() - ): - if memlet.subset is not None: - memlet.subset.replace(axis_replacements) - if memlet.other_subset is not None: - memlet.other_subset.replace(axis_replacements) - - class CartesianAxisMerge(tn.ScheduleNodeTransformer): """Merge a cartesian axis if they are contiguous in code-flow. @@ -368,7 +347,7 @@ def _map_overcompute_merge( # with the axis symbol of the first map. if second_map.node.map.params[0] != first_map.node.map.params[0]: replacements = {second_map.node.map.params[0]: first_map.node.map.params[0]} - ReplaceAxisSymbolInTasklet(self.axis).visit( + ReplaceAxisSymbolInTasklet().visit( first_map, axis_replacements=replacements ) @@ -447,3 +426,37 @@ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: ndsl_log.debug( f"🚀 {self}: {overall_merged} maps merged in {passes_apply} passes" ) + + +class CartesianMerge(tn.ScheduleNodeTransformer): + """Merge Cartesian axis loops""" + + def __init__(self, backend: Backend, *, eager: bool = True) -> None: + self._backend = backend + self.eager = eager + + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: + if self._backend.loop_order == BackendLoopOrder.IJK: + CartesianAxisMerge(AxisIterator._I).visit(node) + CartesianAxisMerge(AxisIterator._J).visit(node) + CartesianAxisMerge(AxisIterator._K).visit(node) + elif self._backend.loop_order == BackendLoopOrder.IKJ: + CartesianAxisMerge(AxisIterator._I).visit(node) + CartesianAxisMerge(AxisIterator._K).visit(node) + CartesianAxisMerge(AxisIterator._J).visit(node) + elif 
self._backend.loop_order == BackendLoopOrder.JIK: + CartesianAxisMerge(AxisIterator._J).visit(node) + CartesianAxisMerge(AxisIterator._I).visit(node) + CartesianAxisMerge(AxisIterator._K).visit(node) + elif self._backend.loop_order == BackendLoopOrder.JKI: + CartesianAxisMerge(AxisIterator._J).visit(node) + CartesianAxisMerge(AxisIterator._K).visit(node) + CartesianAxisMerge(AxisIterator._I).visit(node) + elif self._backend.loop_order == BackendLoopOrder.KIJ: + CartesianAxisMerge(AxisIterator._K).visit(node) + CartesianAxisMerge(AxisIterator._I).visit(node) + CartesianAxisMerge(AxisIterator._J).visit(node) + elif self._backend.loop_order == BackendLoopOrder.KJI: + CartesianAxisMerge(AxisIterator._K).visit(node) + CartesianAxisMerge(AxisIterator._J).visit(node) + CartesianAxisMerge(AxisIterator._I).visit(node) diff --git a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py new file mode 100644 index 00000000..41c03ada --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import itertools + +from dace.sdfg.analysis.schedule_tree import treenodes as tn + + +class ReplaceAxisSymbolInTasklet(tn.ScheduleNodeVisitor): + def __init__(self) -> None: + pass + + def visit_TaskletNode( + self, + node: tn.TaskletNode, + axis_replacements: dict[str, str] | None = None, + ) -> None: + if not axis_replacements: + # Noop if there are no replacements to do. 
+ return + + # Dev NOTE: We directly replace the memlet.subset because the `memlet.replace` + # function sometimes doesn't work + for memlet in itertools.chain( + node.in_memlets.values(), node.out_memlets.values() + ): + if memlet.subset is not None: + memlet.subset.replace(axis_replacements) + if memlet.other_subset is not None: + memlet.other_subset.replace(axis_replacements) From c8d05af66d6ae94c5a31a0e235de3d07f36a3c4a Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 15 May 2026 08:55:36 -0400 Subject: [PATCH 04/43] Move helpers into `common` and break them by type Move pipeline defaults inside the Pipeline itself and have orchestration call default Mockup of passes required for merging to behave --- ndsl/dsl/dace/orchestration.py | 11 +-- ndsl/dsl/dace/stree/optimizations/__init__.py | 13 ++- .../dace/stree/optimizations/axis_merge.py | 28 +++---- .../stree/optimizations/cartesian_merge.py | 52 ++++++++++++ .../stree/optimizations/common/__init__.py | 22 ++++++ .../dace/stree/optimizations/common/loops.py | 14 ++++ .../{memlet_helpers.py => common/memlet.py} | 0 .../{tree_common_op.py => common/topology.py} | 17 ++-- .../optimizations/offgrid_conditionals.py | 79 +++++++++++++++++++ .../stree/optimizations/refine_transients.py | 2 +- .../dace/stree/optimizations/remove_loops.py | 22 ++++++ .../dace/stree/optimizations/statistics.py | 7 +- ndsl/dsl/dace/stree/pipeline.py | 20 ++++- 13 files changed, 241 insertions(+), 46 deletions(-) create mode 100644 ndsl/dsl/dace/stree/optimizations/cartesian_merge.py create mode 100644 ndsl/dsl/dace/stree/optimizations/common/__init__.py create mode 100644 ndsl/dsl/dace/stree/optimizations/common/loops.py rename ndsl/dsl/dace/stree/optimizations/{memlet_helpers.py => common/memlet.py} (100%) rename ndsl/dsl/dace/stree/optimizations/{tree_common_op.py => common/topology.py} (77%) create mode 100644 ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py create mode 100644 
ndsl/dsl/dace/stree/optimizations/remove_loops.py diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 933a379b..e31298d7 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -38,11 +38,6 @@ sdfg_nan_checker, ) from ndsl.dsl.dace.stree import CPUPipeline -from ndsl.dsl.dace.stree.optimizations import ( - CartesianMerge, - CartesianRefineTransients, - CleanUpScheduleTree, -) from ndsl.dsl.dace.utils import ( DaCeProgress, memory_static_analysis, @@ -197,11 +192,7 @@ def _build_sdfg( with DaCeProgress(config, "Schedule Tree: optimization"): CPUPipeline( - passes=[ - CleanUpScheduleTree(), - CartesianMerge(backend_name), - CartesianRefineTransients(backend_name), - ], + backend=backend_name, cache_directory=Path(sdfg.build_folder), ).run(stree, verbose=config.verbose_schedule_tree_optimizations) if config.verbose_orchestration: diff --git a/ndsl/dsl/dace/stree/optimizations/__init__.py b/ndsl/dsl/dace/stree/optimizations/__init__.py index b08d6839..21dcaa72 100644 --- a/ndsl/dsl/dace/stree/optimizations/__init__.py +++ b/ndsl/dsl/dace/stree/optimizations/__init__.py @@ -1,6 +1,13 @@ -from .axis_merge import AxisIterator, CartesianAxisMerge, CartesianMerge +from .axis_merge import AxisIterator, CartesianAxisMerge +from .cartesian_merge import CartesianMerge from .clean_tree import CleanUpScheduleTree +from .offgrid_conditionals import ( + ExtractOffgridConditionals, + InlineOffgridConditionals, + MergeConditionals, +) from .refine_transients import CartesianRefineTransients +from .remove_loops import InlineVertical2DWrite __all__ = [ @@ -9,4 +16,8 @@ "CartesianMerge", "CartesianRefineTransients", "CleanUpScheduleTree", + "InlineVertical2DWrite", + "ExtractOffgridConditionals", + "InlineOffgridConditionals", + "MergeConditionals", ] diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 2cbf8995..9ac339e3 100644 --- 
a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -7,20 +7,20 @@ from dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl.config import Backend, BackendLoopOrder -from ndsl.dsl.dace.stree.optimizations.memlet_helpers import ( +from ndsl.dsl.dace.stree.optimizations.common import ( AxisIterator, - no_data_dependencies_on_cartesian_axis, -) -from ndsl.dsl.dace.stree.optimizations.replace_symbol_in_tasklet import ( - ReplaceAxisSymbolInTasklet, -) -from ndsl.dsl.dace.stree.optimizations.tree_common_op import ( detect_cycle, + get_next_node, is_axis_for, is_axis_map, + last_node, list_index, + no_data_dependencies_on_cartesian_axis, swap_node_position_in_tree, ) +from ndsl.dsl.dace.stree.optimizations.replace_symbol_in_tasklet import ( + ReplaceAxisSymbolInTasklet, +) from ndsl.logging import ndsl_log @@ -95,16 +95,6 @@ def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope: return node -def _get_next_node( - nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode -) -> tn.ScheduleTreeNode: - return nodes[list_index(nodes, node) + 1] - - -def _last_node(nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode) -> bool: - return list_index(nodes, node) >= len(nodes) - 1 - - class CartesianAxisMerge(tn.ScheduleNodeTransformer): """Merge a cartesian axis if they are contiguous in code-flow. 
@@ -288,13 +278,13 @@ def _map_overcompute_merge( # End of nodes OR # Not the right axis # --> recurse - if _last_node(nodes, the_map) or not is_axis_map(the_map, self.axis): + if last_node(nodes, the_map) or not is_axis_map(the_map, self.axis): merged = 0 for child in the_map.children: merged += self._merge_node(child, the_map.children) return merged - next_node = _get_next_node(nodes, the_map) + next_node = get_next_node(nodes, the_map) # Next node is not a MapScope - no merge if not isinstance(next_node, tn.MapScope): diff --git a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py new file mode 100644 index 00000000..779d2900 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from dace.sdfg.analysis.schedule_tree import treenodes as tn + +from ndsl.config import Backend, BackendLoopOrder +from ndsl.dsl.dace.stree.optimizations import ( + CartesianAxisMerge, + ExtractOffgridConditionals, + InlineOffgridConditionals, + MergeConditionals, +) +from ndsl.dsl.dace.stree.optimizations.common import AxisIterator + + +class CartesianMerge(tn.ScheduleNodeTransformer): + """Merge Cartesian computation blocks""" + + def __init__(self, backend: Backend, *, eager: bool = True) -> None: + self._backend = backend + self.eager = eager + + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: + InlineOffgridConditionals().visit(node) + MergeConditionals().visit(node) + + if self._backend.loop_order == BackendLoopOrder.IJK: + CartesianAxisMerge(AxisIterator._I).visit(node) + CartesianAxisMerge(AxisIterator._J).visit(node) + CartesianAxisMerge(AxisIterator._K).visit(node) + elif self._backend.loop_order == BackendLoopOrder.IKJ: + CartesianAxisMerge(AxisIterator._I).visit(node) + CartesianAxisMerge(AxisIterator._K).visit(node) + CartesianAxisMerge(AxisIterator._J).visit(node) + elif self._backend.loop_order == BackendLoopOrder.JIK: + 
CartesianAxisMerge(AxisIterator._J).visit(node) + CartesianAxisMerge(AxisIterator._I).visit(node) + CartesianAxisMerge(AxisIterator._K).visit(node) + elif self._backend.loop_order == BackendLoopOrder.JKI: + CartesianAxisMerge(AxisIterator._J).visit(node) + CartesianAxisMerge(AxisIterator._K).visit(node) + CartesianAxisMerge(AxisIterator._I).visit(node) + elif self._backend.loop_order == BackendLoopOrder.KIJ: + CartesianAxisMerge(AxisIterator._K).visit(node) + CartesianAxisMerge(AxisIterator._I).visit(node) + CartesianAxisMerge(AxisIterator._J).visit(node) + elif self._backend.loop_order == BackendLoopOrder.KJI: + CartesianAxisMerge(AxisIterator._K).visit(node) + CartesianAxisMerge(AxisIterator._J).visit(node) + CartesianAxisMerge(AxisIterator._I).visit(node) + + ExtractOffgridConditionals().visit(node) + MergeConditionals().visit(node) diff --git a/ndsl/dsl/dace/stree/optimizations/common/__init__.py b/ndsl/dsl/dace/stree/optimizations/common/__init__.py new file mode 100644 index 00000000..a4a64bc4 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/common/__init__.py @@ -0,0 +1,22 @@ +from .memlet import AxisIterator, no_data_dependencies_on_cartesian_axis # isort: skip +from .loops import is_axis_for, is_axis_map +from .topology import ( + detect_cycle, + get_next_node, + last_node, + list_index, + swap_node_position_in_tree, +) + + +__all__ = [ + "AxisIterator", + "no_data_dependencies_on_cartesian_axis", + "is_axis_map", + "is_axis_for", + "get_next_node", + "last_node", + "swap_node_position_in_tree", + "detect_cycle", + "list_index", +] diff --git a/ndsl/dsl/dace/stree/optimizations/common/loops.py b/ndsl/dsl/dace/stree/optimizations/common/loops.py new file mode 100644 index 00000000..83a91280 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/common/loops.py @@ -0,0 +1,14 @@ +import dace.sdfg.analysis.schedule_tree.treenodes as tn + +from ndsl.dsl.dace.stree.optimizations.common import AxisIterator + + +def is_axis_map(node: tn.MapScope, axis: 
AxisIterator) -> bool: + """Returns true if node is a Map over the given axis.""" + map_parameter = node.node.map.params + return len(map_parameter) == 1 and map_parameter[0].startswith(axis.as_str()) + + +def is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool: + """Returns true if node is a For over the given axis.""" + return node.loop.loop_variable.startswith(axis.as_str()) diff --git a/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py b/ndsl/dsl/dace/stree/optimizations/common/memlet.py similarity index 100% rename from ndsl/dsl/dace/stree/optimizations/memlet_helpers.py rename to ndsl/dsl/dace/stree/optimizations/common/memlet.py diff --git a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py b/ndsl/dsl/dace/stree/optimizations/common/topology.py similarity index 77% rename from ndsl/dsl/dace/stree/optimizations/tree_common_op.py rename to ndsl/dsl/dace/stree/optimizations/common/topology.py index 4748a901..27edf8fe 100644 --- a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py +++ b/ndsl/dsl/dace/stree/optimizations/common/topology.py @@ -2,8 +2,6 @@ import dace.sdfg.analysis.schedule_tree.treenodes as tn -from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator - def swap_node_position_in_tree( top_node: tn.ScheduleTreeScope, child_node: tn.ScheduleTreeScope @@ -55,12 +53,13 @@ def list_index( return next(index for index, element in enumerate(collection) if element is node) -def is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool: - """Returns true if node is a Map over the given axis.""" - map_parameter = node.node.map.params - return len(map_parameter) == 1 and map_parameter[0].startswith(axis.as_str()) +def get_next_node( + nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode +) -> tn.ScheduleTreeNode: + """Get next node in the children from given node""" + return nodes[list_index(nodes, node) + 1] -def is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool: - """Returns true if node is a For over the given 
axis.""" - return node.loop.loop_variable.startswith(axis.as_str()) +def last_node(nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode) -> bool: + """Test for last node of list""" + return list_index(nodes, node) >= len(nodes) - 1 diff --git a/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py b/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py new file mode 100644 index 00000000..93d6ab20 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from dace.sdfg.analysis.schedule_tree import treenodes as tn + + +class InlineOffgridConditionals(tn.ScheduleNodeTransformer): + """Push offgrid conditional inside their cartesian block, + duplicating the conditional if needed + + Turning: + ``` + if a_flag == 0 + map i, j, k + [ops...] + map i, j, k + [ops...] + ``` + into + ``` + map i,j, k + if a_flag == 0 + [ops...] + map i,j, k + if a_flag == 0 + [ops...] + ``` + """ + + def __init__(self) -> None: + pass + + def __str__(self) -> str: + return "InlineOffgridConditionals" + + +class ExtractOffgridConditionals(tn.ScheduleNodeTransformer): + """Push offgrid conditional outside of their cartesian block + + Reverse transform from InlineOffgridConditionals + """ + + def __init__(self) -> None: + pass + + def __str__(self) -> str: + return "ExtractOffgridConditionals" + + +class MergeConditionals(tn.ScheduleNodeTransformer): + """Merge consecutive and equal conditionals + + Turning: + ``` + if a_flag == 0 + map i, j, k + [ops...] + if a_flag == 0 + map i, j, k + [ops...] + ``` + into + ``` + if a_flag == 0 + map i, j, k + [ops...] + map i, j, k + [ops...] + ``` + + Outside of user code, vombination of ExtractOffgridConditionals, + InlineOffgridConditionals and CartesianMapMerge can lead to this + pattern. 
+ """ + + def __init__(self) -> None: + pass + + def __str__(self) -> str: + return "MergeConditionals" diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index 7b788e32..bb066d4c 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -4,7 +4,7 @@ import dace.sdfg.analysis.schedule_tree.treenodes as stree from ndsl.config import Backend, BackendFramework -from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator +from ndsl.dsl.dace.stree.optimizations.common import AxisIterator from ndsl.logging import ndsl_log diff --git a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py new file mode 100644 index 00000000..f3ba21a3 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py @@ -0,0 +1,22 @@ +from dace.sdfg.analysis.schedule_tree import treenodes as tn + + +class InlineVertical2DWrite(tn.ScheduleNodeTransformer): + """Inline K index value for 2D write vertical while removing for loop. 
+ + Transforming: + ``` + for __k = 0; __k < 1; __k = __k + 1: + map __j, __i: + field[__i, __j] = tasklet(field_in[__i, __j, __k]) + ``` + + Into + ``` + map __j, __i: + field[__i, __j] = tasklet(field_in[__i, __j, 0]) + ``` + """ + + def __init__(self) -> None: + super().__init__() diff --git a/ndsl/dsl/dace/stree/optimizations/statistics.py b/ndsl/dsl/dace/stree/optimizations/statistics.py index 1a3ce593..ebef36fe 100644 --- a/ndsl/dsl/dace/stree/optimizations/statistics.py +++ b/ndsl/dsl/dace/stree/optimizations/statistics.py @@ -3,8 +3,11 @@ import dace import dace.sdfg.analysis.schedule_tree.treenodes as stree -from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator -from ndsl.dsl.dace.stree.optimizations.tree_common_op import is_axis_for, is_axis_map +from ndsl.dsl.dace.stree.optimizations.common import ( + AxisIterator, + is_axis_for, + is_axis_map, +) class CountCartesianLoops(stree.ScheduleNodeVisitor): diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index 44e933bc..ad52ca9c 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -2,7 +2,13 @@ import dace.sdfg.analysis.schedule_tree.treenodes as stree -from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge +from ndsl import Backend +from ndsl.dsl.dace.stree.optimizations import ( + CartesianMerge, + CartesianRefineTransients, + CleanUpScheduleTree, + InlineVertical2DWrite, +) from ndsl.dsl.dace.stree.optimizations.statistics import TreeOptimizationStatistics from ndsl.logging import ndsl_log_on_rank_0 @@ -55,14 +61,20 @@ def run( class CPUPipeline(StreePipeline): def __init__( self, + backend: Backend, *, passes: list[stree.ScheduleNodeTransformer] | None = None, cache_directory: Path | None = None, ) -> None: + if passes is None: + passes = [ + CleanUpScheduleTree(), + InlineVertical2DWrite(), + CartesianMerge(backend), + CartesianRefineTransients(backend), + ] super().__init__( - passes=( - passes if 
passes is not None else [CartesianAxisMerge(AxisIterator._K)] - ), + passes=passes, cache_directory=cache_directory, ) From 20665a8aedc97c5cabc034d97e5a6fdb4ebbeb35 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 15 May 2026 09:04:22 -0400 Subject: [PATCH 05/43] Fix imports --- ndsl/dsl/dace/stree/optimizations/cartesian_merge.py | 6 +++--- ndsl/dsl/dace/stree/optimizations/remove_loops.py | 3 +++ .../dace/stree/optimizations/replace_symbol_in_tasklet.py | 3 +++ ndsl/dsl/dace/stree/optimizations/specialize_maps.py | 3 +++ 4 files changed, 12 insertions(+), 3 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py index 779d2900..d8ab5043 100644 --- a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py @@ -3,13 +3,13 @@ from dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl.config import Backend, BackendLoopOrder -from ndsl.dsl.dace.stree.optimizations import ( - CartesianAxisMerge, +from ndsl.dsl.dace.stree.optimizations.axis_merge import CartesianAxisMerge +from ndsl.dsl.dace.stree.optimizations.common import AxisIterator +from ndsl.dsl.dace.stree.optimizations.offgrid_conditionals import ( ExtractOffgridConditionals, InlineOffgridConditionals, MergeConditionals, ) -from ndsl.dsl.dace.stree.optimizations.common import AxisIterator class CartesianMerge(tn.ScheduleNodeTransformer): diff --git a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py index f3ba21a3..0bffe7ab 100644 --- a/ndsl/dsl/dace/stree/optimizations/remove_loops.py +++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py @@ -20,3 +20,6 @@ class InlineVertical2DWrite(tn.ScheduleNodeTransformer): def __init__(self) -> None: super().__init__() + + def __str__(self) -> str: + return "InlineVertical2DWrite" diff --git a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py 
b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py index 41c03ada..afb150e0 100644 --- a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py +++ b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py @@ -27,3 +27,6 @@ def visit_TaskletNode( memlet.subset.replace(axis_replacements) if memlet.other_subset is not None: memlet.other_subset.replace(axis_replacements) + + def __str__(self) -> str: + return "ReplaceAxisSymbolInTasklet" diff --git a/ndsl/dsl/dace/stree/optimizations/specialize_maps.py b/ndsl/dsl/dace/stree/optimizations/specialize_maps.py index 2583ec2d..e2409e1a 100644 --- a/ndsl/dsl/dace/stree/optimizations/specialize_maps.py +++ b/ndsl/dsl/dace/stree/optimizations/specialize_maps.py @@ -19,3 +19,6 @@ def visit_MapScope(self, node: stree.MapScope) -> None: node.node.map.range = sbs.Range(dims) self.visit(node.children) + + def __str__(self) -> str: + return "SpecializeCartesianMaps" From c8a225e2492427634a9742df593f6a3ae65a1f0b Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 15 May 2026 15:23:44 -0400 Subject: [PATCH 06/43] `InlineVertical2DWrite` + utest --- ndsl/dsl/dace/stree/optimizations/__init__.py | 3 +- .../stree/optimizations/cartesian_merge.py | 3 + .../stree/optimizations/common/__init__.py | 2 + .../dace/stree/optimizations/common/memlet.py | 3 + .../stree/optimizations/common/topology.py | 16 +++ .../stree/optimizations/refine_transients.py | 6 +- .../dace/stree/optimizations/remove_loops.py | 50 +++++++ .../stree/optimizations/test_remove_loops.py | 132 ++++++++++++++++++ 8 files changed, 208 insertions(+), 7 deletions(-) create mode 100644 tests/dsl/dace/stree/optimizations/test_remove_loops.py diff --git a/ndsl/dsl/dace/stree/optimizations/__init__.py b/ndsl/dsl/dace/stree/optimizations/__init__.py index 21dcaa72..8cd77f55 100644 --- a/ndsl/dsl/dace/stree/optimizations/__init__.py +++ b/ndsl/dsl/dace/stree/optimizations/__init__.py @@ -1,4 +1,4 @@ -from .axis_merge import AxisIterator, 
CartesianAxisMerge +from .axis_merge import CartesianAxisMerge from .cartesian_merge import CartesianMerge from .clean_tree import CleanUpScheduleTree from .offgrid_conditionals import ( @@ -11,7 +11,6 @@ __all__ = [ - "AxisIterator", "CartesianAxisMerge", "CartesianMerge", "CartesianRefineTransients", diff --git a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py index d8ab5043..398a2103 100644 --- a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py @@ -19,6 +19,9 @@ def __init__(self, backend: Backend, *, eager: bool = True) -> None: self._backend = backend self.eager = eager + def __str__(self) -> str: + return "CartesianMerge" + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: InlineOffgridConditionals().visit(node) MergeConditionals().visit(node) diff --git a/ndsl/dsl/dace/stree/optimizations/common/__init__.py b/ndsl/dsl/dace/stree/optimizations/common/__init__.py index a4a64bc4..c76887fb 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/__init__.py +++ b/ndsl/dsl/dace/stree/optimizations/common/__init__.py @@ -5,6 +5,7 @@ get_next_node, last_node, list_index, + reparent_scope_node, swap_node_position_in_tree, ) @@ -19,4 +20,5 @@ "swap_node_position_in_tree", "detect_cycle", "list_index", + "reparent_scope_node", ] diff --git a/ndsl/dsl/dace/stree/optimizations/common/memlet.py b/ndsl/dsl/dace/stree/optimizations/common/memlet.py index 75f68143..b1540c98 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/memlet.py +++ b/ndsl/dsl/dace/stree/optimizations/common/memlet.py @@ -17,6 +17,9 @@ def as_str(self) -> str: def as_cartesian_index(self) -> int: return self.value[1] + def is_equal(self, other: str) -> bool: + return other.startswith(self.as_str()) + def no_data_dependencies_on_cartesian_axis( first: stree.MapScope, diff --git a/ndsl/dsl/dace/stree/optimizations/common/topology.py 
b/ndsl/dsl/dace/stree/optimizations/common/topology.py index 27edf8fe..fe878522 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/topology.py +++ b/ndsl/dsl/dace/stree/optimizations/common/topology.py @@ -3,6 +3,22 @@ import dace.sdfg.analysis.schedule_tree.treenodes as tn +def reparent_scope_node( + original_parent: tn.ScheduleTreeScope, + new_parent: tn.ScheduleTreeNode, + prepend: bool = True, +) -> None: + """Re-parent children between two scope nodes""" + + for child in original_parent.children: + child.parent = new_parent + + if prepend: + new_parent.children = [*original_parent.children, *new_parent.children] + else: + new_parent.children = [*new_parent.children, *original_parent.children] + + def swap_node_position_in_tree( top_node: tn.ScheduleTreeScope, child_node: tn.ScheduleTreeScope ) -> None: diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index bb066d4c..8da15f8f 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -38,11 +38,7 @@ def _reduce_cartesian_axis_size_to_1( # Assume 3D cartesian! 
if len(transient_data.shape) < 3: - warnings.warn( - f"Potential non-3D array: {transient_data}, skipping.", - UserWarning, - stacklevel=2, - ) + ndsl_log.debug(f"Potential non-3D array: {transient_data}, skipping.") return False read_write_range: dace.subsets.Range = dace.subsets.union( diff --git a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py index 0bffe7ab..58b20d23 100644 --- a/ndsl/dsl/dace/stree/optimizations/remove_loops.py +++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py @@ -1,5 +1,13 @@ +import ast + from dace.sdfg.analysis.schedule_tree import treenodes as tn +from ndsl import ndsl_log +from ndsl.dsl.dace.stree.optimizations.common import AxisIterator, reparent_scope_node +from ndsl.dsl.dace.stree.optimizations.replace_symbol_in_tasklet import ( + ReplaceAxisSymbolInTasklet, +) + class InlineVertical2DWrite(tn.ScheduleNodeTransformer): """Inline K index value for 2D write vertical while removing for loop. @@ -20,6 +28,48 @@ class InlineVertical2DWrite(tn.ScheduleNodeTransformer): def __init__(self) -> None: super().__init__() + self._for_scope_removed = 0 def __str__(self) -> str: return "InlineVertical2DWrite" + + def visit_ForScope(self, the_for: tn.ForScope) -> tn.ForScope | tn.ScheduleTreeNode: + if ( + AxisIterator._K.is_equal(the_for.loop.loop_variable) + and the_for.loop.executions == 1 + and the_for.parent + ): + # Retrieve init value by executing the code and replace usage of it + # If the code cannot be executed (no-literal variable part of the op, etc.) 
+ # we will _not_ inline + try: + exec(ast.unparse(the_for.loop.init_statement.code[0])) + except Exception as _: + return the_for + init_value = locals()[the_for.loop.loop_variable] + ReplaceAxisSymbolInTasklet().visit( + the_for, axis_replacements={the_for.loop.loop_variable: str(init_value)} + ) + + # Prepend children of the ForScope to parent + # the_for.parent.children = [*the_for.children, *the_for.parent.children] + reparent_scope_node(the_for, the_for.parent) + + # Remove ForScope + the_for.parent.children.remove(the_for) + self._for_scope_removed += 1 + assert len(the_for.children) > 0 + return the_for.parent.children[0] + + return the_for + + def visit_ScheduleTreeRoot( + self, the_root: tn.ScheduleTreeRoot + ) -> tn.ScheduleTreeRoot: + + for child in the_root.children: + self.visit(child) + + ndsl_log.debug(f"🚀 {self}: {self._for_scope_removed} inlined") + + return the_root diff --git a/tests/dsl/dace/stree/optimizations/test_remove_loops.py b/tests/dsl/dace/stree/optimizations/test_remove_loops.py new file mode 100644 index 00000000..07611767 --- /dev/null +++ b/tests/dsl/dace/stree/optimizations/test_remove_loops.py @@ -0,0 +1,132 @@ +from typing import TypeAlias + +import pytest +from dace import nodes +from dace.sdfg.state import LoopRegion + +from ndsl import QuantityFactory, StencilFactory, orchestrate +from ndsl.boilerplate import get_factories_single_tile +from ndsl.config import Backend +from ndsl.constants import I_DIM, J_DIM, K_DIM, Float +from ndsl.dsl.gt4py import FORWARD, computation, interval +from ndsl.dsl.typing import FloatField, FloatFieldIJ +from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge + + +def stencil_simple_2D_write(in_field: FloatField, out_fieldIJ: FloatFieldIJ) -> None: + with computation(FORWARD), interval(0, 1): + out_fieldIJ = in_field + + +def stencil_2D_write_at_K(in_field: FloatField, out_fieldIJ: FloatFieldIJ) -> None: + with computation(FORWARD), interval(-1, None): + out_fieldIJ = in_field + + 
+class OrchestratedCode: + def __init__( + self, + stencil_factory: StencilFactory, + quantity_factory: QuantityFactory, + ) -> None: + orchestratable_methods = [ + "write_at_0", + "write_at_top", + ] + for method in orchestratable_methods: + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + method_to_orchestrate=method, + ) + + self.stencil_simple_2D_write = stencil_factory.from_dims_halo( + func=stencil_simple_2D_write, + compute_dims=[I_DIM, J_DIM, K_DIM], + ) + self.stencil_2D_write_at_K = stencil_factory.from_dims_halo( + func=stencil_2D_write_at_K, + compute_dims=[I_DIM, J_DIM, K_DIM], + ) + + def write_at_0( + self, + in_field: FloatField, + out_field: FloatFieldIJ, + ) -> None: + self.stencil_simple_2D_write(in_field, out_field) + + def write_at_top( + self, + in_field: FloatField, + out_field: FloatFieldIJ, + ) -> None: + self.stencil_2D_write_at_K(in_field, out_field) + + +Factories: TypeAlias = tuple[StencilFactory, QuantityFactory] + + +class TestStree2DWriteInline: + @pytest.fixture(params=[Backend("orch:dace:cpu:IJK"), Backend("orch:dace:cpu:KJI")]) + def factories(self, request) -> Factories: + domain = (3, 3, 4) + return get_factories_single_tile( + domain[0], domain[1], domain[2], 0, backend=request.param + ) + + @pytest.fixture + def code(self, factories: Factories) -> OrchestratedCode: + return OrchestratedCode(*factories) + + def test_common_2D_write( + self, code: OrchestratedCode, factories: Factories + ) -> None: + stencil_factory, quantity_factory = factories + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM], "") + in_qty.field[:, :, 0] = Float(32.0) + + with StreeOptimization(): + code.write_at_0(in_qty, out_qty) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + all_loop_region = [ + (me, state) + for me, state in 
precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, LoopRegion) + ] + + assert len(all_maps) == 2 + assert len(all_loop_region) == 0 + assert (out_qty.field[:] == Float(32.0)).all() + + def test_2D_write_K_top(self, code: OrchestratedCode, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM], "") + in_qty.field[:, :, -1] = Float(32.0) + + with StreeOptimization(): + code.write_at_top(in_qty, out_qty) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + all_loop_region = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, LoopRegion) + ] + + assert len(all_maps) == 2 + assert len(all_loop_region) == 0 + assert (out_qty.field[:] == Float(32.0)).all() From 73f5609d355b9adf78eb175d1002decee2df97d2 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 15 May 2026 15:59:20 -0400 Subject: [PATCH 07/43] Fix InlineVertical2DWrite --- .../dace/stree/optimizations/remove_loops.py | 16 ++++--- .../stree/optimizations/test_remove_loops.py | 45 +++++++++++++++++-- 2 files changed, 50 insertions(+), 11 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py index 58b20d23..8eff3e51 100644 --- a/ndsl/dsl/dace/stree/optimizations/remove_loops.py +++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py @@ -34,19 +34,21 @@ def __str__(self) -> str: return "InlineVertical2DWrite" def visit_ForScope(self, the_for: tn.ForScope) -> tn.ForScope | tn.ScheduleTreeNode: - if ( - AxisIterator._K.is_equal(the_for.loop.loop_variable) - and the_for.loop.executions == 1 - and the_for.parent - ): - # Retrieve init value by executing the code and replace usage of it + if 
AxisIterator._K.is_equal(the_for.loop.loop_variable) and the_for.parent: + # Retrieve init/bound value by executing the code and replace usage of it # If the code cannot be executed (no-literal variable part of the op, etc.) # we will _not_ inline try: exec(ast.unparse(the_for.loop.init_statement.code[0])) + init_value = locals()[the_for.loop.loop_variable] + bound_value = eval( + ast.unparse(the_for.loop.loop_condition.code[0].value.comparators) + ) except Exception as _: return the_for - init_value = locals()[the_for.loop.loop_variable] + if abs(bound_value - init_value) != 1: + return the_for + ReplaceAxisSymbolInTasklet().visit( the_for, axis_replacements={the_for.loop.loop_variable: str(init_value)} ) diff --git a/tests/dsl/dace/stree/optimizations/test_remove_loops.py b/tests/dsl/dace/stree/optimizations/test_remove_loops.py index 07611767..012e88b9 100644 --- a/tests/dsl/dace/stree/optimizations/test_remove_loops.py +++ b/tests/dsl/dace/stree/optimizations/test_remove_loops.py @@ -23,16 +23,18 @@ def stencil_2D_write_at_K(in_field: FloatField, out_fieldIJ: FloatFieldIJ) -> No out_fieldIJ = in_field +def stencil_forward_at_K(in_field: FloatField, out_field: FloatField) -> None: + with computation(FORWARD), interval(...): + out_field = in_field + + class OrchestratedCode: def __init__( self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory, ) -> None: - orchestratable_methods = [ - "write_at_0", - "write_at_top", - ] + orchestratable_methods = ["write_at_0", "write_at_top", "do_not_inline"] for method in orchestratable_methods: orchestrate( obj=self, @@ -48,6 +50,10 @@ def __init__( func=stencil_2D_write_at_K, compute_dims=[I_DIM, J_DIM, K_DIM], ) + self.stencil_do_not_inline = stencil_factory.from_dims_halo( + func=stencil_forward_at_K, + compute_dims=[I_DIM, J_DIM, K_DIM], + ) def write_at_0( self, @@ -63,6 +69,13 @@ def write_at_top( ) -> None: self.stencil_2D_write_at_K(in_field, out_field) + def do_not_inline( + self, + in_field: 
FloatField, + out_field: FloatField, + ) -> None: + self.stencil_do_not_inline(in_field, out_field) + Factories: TypeAlias = tuple[StencilFactory, QuantityFactory] @@ -130,3 +143,27 @@ def test_2D_write_K_top(self, code: OrchestratedCode, factories: Factories) -> N assert len(all_maps) == 2 assert len(all_loop_region) == 0 assert (out_qty.field[:] == Float(32.0)).all() + + def test_do_not_inline(self, code: OrchestratedCode, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + with StreeOptimization(): + code.do_not_inline(in_qty, out_qty) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + all_loop_region = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, LoopRegion) + ] + + assert len(all_maps) == 2 + assert len(all_loop_region) == 1 + assert (out_qty.field[:] == Float(1)).all() From fc1ecb10838849d4179b387e06849b04d303df08 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 18 May 2026 11:42:44 +0200 Subject: [PATCH 08/43] cleanup --- ndsl/dsl/dace/stree/optimizations/axis_merge.py | 4 +--- .../dsl/dace/stree/optimizations/cartesian_merge.py | 2 -- ndsl/dsl/dace/stree/optimizations/clean_tree.py | 5 ++--- .../stree/optimizations/offgrid_conditionals.py | 13 +------------ .../dace/stree/optimizations/refine_transients.py | 4 ++-- .../optimizations/replace_symbol_in_tasklet.py | 5 ----- 6 files changed, 6 insertions(+), 27 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 9ac339e3..0c924123 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -1,11 +1,10 @@ -from __future__ import 
annotations - import copy import dace from dace.properties import CodeBlock from dace.sdfg.analysis.schedule_tree import treenodes as tn +from ndsl import ndsl_log from ndsl.config import Backend, BackendLoopOrder from ndsl.dsl.dace.stree.optimizations.common import ( AxisIterator, @@ -21,7 +20,6 @@ from ndsl.dsl.dace.stree.optimizations.replace_symbol_in_tasklet import ( ReplaceAxisSymbolInTasklet, ) -from ndsl.logging import ndsl_log # Buggy passes that should work diff --git a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py index 398a2103..0521309f 100644 --- a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl.config import Backend, BackendLoopOrder diff --git a/ndsl/dsl/dace/stree/optimizations/clean_tree.py b/ndsl/dsl/dace/stree/optimizations/clean_tree.py index 0da456de..8f04882a 100644 --- a/ndsl/dsl/dace/stree/optimizations/clean_tree.py +++ b/ndsl/dsl/dace/stree/optimizations/clean_tree.py @@ -1,14 +1,13 @@ -from __future__ import annotations - from dace.sdfg.analysis.schedule_tree import treenodes as tn -from ndsl.logging import ndsl_log +from ndsl import ndsl_log class CleanUpScheduleTree(tn.ScheduleNodeTransformer): """Remove `StateBoundary` nodes from children of ScheduleTreeScopes.""" def __init__(self) -> None: + super().__init__() self._removed_state_boundaries = 0 def __str__(self) -> str: diff --git a/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py b/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py index 93d6ab20..d677c203 100644 --- a/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py +++ b/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from dace.sdfg.analysis.schedule_tree import treenodes as tn @@ -26,9 +24,6 @@ 
class InlineOffgridConditionals(tn.ScheduleNodeTransformer): ``` """ - def __init__(self) -> None: - pass - def __str__(self) -> str: return "InlineOffgridConditionals" @@ -39,9 +34,6 @@ class ExtractOffgridConditionals(tn.ScheduleNodeTransformer): Reverse transform from InlineOffgridConditionals """ - def __init__(self) -> None: - pass - def __str__(self) -> str: return "ExtractOffgridConditionals" @@ -67,13 +59,10 @@ class MergeConditionals(tn.ScheduleNodeTransformer): [ops...] ``` - Outside of user code, vombination of ExtractOffgridConditionals, + Outside of user code, combination of ExtractOffgridConditionals, InlineOffgridConditionals and CartesianMapMerge can lead to this pattern. """ - def __init__(self) -> None: - pass - def __str__(self) -> str: return "MergeConditionals" diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index 8da15f8f..cd8e2703 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -34,7 +34,7 @@ def _reduce_cartesian_axis_size_to_1( are atomic""" # Dev Note: Better dataflow analysis would look at exactly - # what's goin on here! + # what's going on here! # Assume 3D cartesian! if len(transient_data.shape) < 3: @@ -206,7 +206,7 @@ class CartesianRefineTransients(stree.ScheduleNodeTransformer): cartesian axis) it will reduce that axis to 1 if all access are atomic (exactly _one_ element of the array is ever worked on in a single loop) - It will refuse to merge if the transient is used in multiple loops of for - a given axis - irrigardless of it's access pattern (e.g. even if it could be + a given axis - regardless of it's access pattern (e.g. even if it could be refine because it's always written first.) 
It should but cannot do/will bug if: diff --git a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py index afb150e0..398ce203 100644 --- a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py +++ b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py @@ -1,14 +1,9 @@ -from __future__ import annotations - import itertools from dace.sdfg.analysis.schedule_tree import treenodes as tn class ReplaceAxisSymbolInTasklet(tn.ScheduleNodeVisitor): - def __init__(self) -> None: - pass - def visit_TaskletNode( self, node: tn.TaskletNode, From d7e40aa81096c87efdfc04998bd3136a55c5c031 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 18 May 2026 11:43:35 +0200 Subject: [PATCH 09/43] fix symbol replacement Use symbols in the replacement directory. Update DaCe to a version that doesn't re-initialize the symbols. And fix the test failure in python 3.13. --- external/dace | 2 +- ndsl/dsl/dace/stree/optimizations/axis_merge.py | 6 +++++- .../dace/stree/optimizations/remove_loops.py | 17 ++++++++++++++--- .../optimizations/replace_symbol_in_tasklet.py | 5 +---- .../stree/optimizations/test_remove_loops.py | 5 +++-- 5 files changed, 24 insertions(+), 11 deletions(-) diff --git a/external/dace b/external/dace index d5fbadb6..f271b30b 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit d5fbadb626389e425fac5ed93d2a880811eca41f +Subproject commit f271b30bb983559306342ce2ff98c69e6662bb32 diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 0c924123..227f5614 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -334,7 +334,11 @@ def _map_overcompute_merge( # After merge, we need to replace the axis symbols of the second map's children # with the axis symbol of the first map. 
if second_map.node.map.params[0] != first_map.node.map.params[0]: - replacements = {second_map.node.map.params[0]: first_map.node.map.params[0]} + replacements = { + dace.symbol(second_map.node.map.params[0]): dace.symbol( + first_map.node.map.params[0] + ) + } ReplaceAxisSymbolInTasklet().visit( first_map, axis_replacements=replacements ) diff --git a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py index 8eff3e51..43e9c15e 100644 --- a/ndsl/dsl/dace/stree/optimizations/remove_loops.py +++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py @@ -1,5 +1,7 @@ import ast +from typing import Any +import dace from dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl import ndsl_log @@ -39,8 +41,14 @@ def visit_ForScope(self, the_for: tn.ForScope) -> tn.ForScope | tn.ScheduleTreeN # If the code cannot be executed (no-literal variable part of the op, etc.) # we will _not_ inline try: - exec(ast.unparse(the_for.loop.init_statement.code[0])) - init_value = locals()[the_for.loop.loop_variable] + exec_locals: dict[str, Any] = {} + exec_globals: dict[str, Any] = {} + exec( + ast.unparse(the_for.loop.init_statement.code[0]), + exec_globals, + exec_locals, + ) + init_value = exec_locals[the_for.loop.loop_variable] bound_value = eval( ast.unparse(the_for.loop.loop_condition.code[0].value.comparators) ) @@ -50,7 +58,10 @@ def visit_ForScope(self, the_for: tn.ForScope) -> tn.ForScope | tn.ScheduleTreeN return the_for ReplaceAxisSymbolInTasklet().visit( - the_for, axis_replacements={the_for.loop.loop_variable: str(init_value)} + the_for, + axis_replacements={ + dace.symbol(the_for.loop.loop_variable): str(init_value) + }, ) # Prepend children of the ForScope to parent diff --git a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py index 398ce203..1020affe 100644 --- a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py +++ 
b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py @@ -18,10 +18,7 @@ def visit_TaskletNode( for memlet in itertools.chain( node.in_memlets.values(), node.out_memlets.values() ): - if memlet.subset is not None: - memlet.subset.replace(axis_replacements) - if memlet.other_subset is not None: - memlet.other_subset.replace(axis_replacements) + memlet.replace(axis_replacements) def __str__(self) -> str: return "ReplaceAxisSymbolInTasklet" diff --git a/tests/dsl/dace/stree/optimizations/test_remove_loops.py b/tests/dsl/dace/stree/optimizations/test_remove_loops.py index 012e88b9..06cbe9fe 100644 --- a/tests/dsl/dace/stree/optimizations/test_remove_loops.py +++ b/tests/dsl/dace/stree/optimizations/test_remove_loops.py @@ -81,11 +81,12 @@ def do_not_inline( class TestStree2DWriteInline: - @pytest.fixture(params=[Backend("orch:dace:cpu:IJK"), Backend("orch:dace:cpu:KJI")]) + @pytest.fixture(params=["orch:dace:cpu:IJK", "orch:dace:cpu:KJI"]) def factories(self, request) -> Factories: + domain = (3, 3, 4) return get_factories_single_tile( - domain[0], domain[1], domain[2], 0, backend=request.param + domain[0], domain[1], domain[2], 0, backend=Backend(request.param) ) @pytest.fixture From 55ad8fa8db6b8acc050b3d1c82f862e96918fd27 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 18 May 2026 15:11:43 +0200 Subject: [PATCH 10/43] update gt4py (log10 and friends) --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index c7d162cc..9fbba0a0 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit c7d162cccc35cb2d1aaa79f5ad12222f617803ac +Subproject commit 9fbba0a07232cd8765123bdd226ea3c26cf768a8 From 1c9bb5f567831e824d86c778c7392d2cfaff8b95 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 18 May 2026 15:07:44 +0200 Subject: [PATCH 11/43] more cleanup (all minor nothing fancy) --- .../dace/stree/optimizations/axis_merge.py | 43 ++----------------- 
.../stree/optimizations/cartesian_merge.py | 4 +- .../stree/optimizations/common/topology.py | 1 + .../replace_symbol_in_tasklet.py | 2 - 4 files changed, 6 insertions(+), 44 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 227f5614..e699dddc 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -5,7 +5,6 @@ from dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl import ndsl_log -from ndsl.config import Backend, BackendLoopOrder from ndsl.dsl.dace.stree.optimizations.common import ( AxisIterator, detect_cycle, @@ -98,22 +97,20 @@ class CartesianAxisMerge(tn.ScheduleNodeTransformer): Can do: - merge a given axis with the next maps at the same recursion level - - can overcompute (eager) to allow for more merging at the cost of an if + - does overcompute to allow for more merging at the cost of an if It expects: - All Maps and ForLoop are on a single axis - but doesn't check for it. 
Args: axis: AxisIterator to be merged - eager: overcompute with a conditional guard """ - def __init__(self, axis: AxisIterator, *, eager: bool = True) -> None: + def __init__(self, axis: AxisIterator) -> None: self.axis = axis - self.eager = eager def __str__(self) -> str: - return f"CartesianAxisMerge_{self.axis.name}_{'eager' if self.eager else ''}" + return f"CartesianAxisMerge_{self.axis.name}" def _merge_node( self, node: tn.ScheduleTreeNode, nodes: list[tn.ScheduleTreeNode] @@ -418,37 +415,3 @@ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: ndsl_log.debug( f"🚀 {self}: {overall_merged} maps merged in {passes_apply} passes" ) - - -class CartesianMerge(tn.ScheduleNodeTransformer): - """Merge Cartesian axis loops""" - - def __init__(self, backend: Backend, *, eager: bool = True) -> None: - self._backend = backend - self.eager = eager - - def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: - if self._backend.loop_order == BackendLoopOrder.IJK: - CartesianAxisMerge(AxisIterator._I).visit(node) - CartesianAxisMerge(AxisIterator._J).visit(node) - CartesianAxisMerge(AxisIterator._K).visit(node) - elif self._backend.loop_order == BackendLoopOrder.IKJ: - CartesianAxisMerge(AxisIterator._I).visit(node) - CartesianAxisMerge(AxisIterator._K).visit(node) - CartesianAxisMerge(AxisIterator._J).visit(node) - elif self._backend.loop_order == BackendLoopOrder.JIK: - CartesianAxisMerge(AxisIterator._J).visit(node) - CartesianAxisMerge(AxisIterator._I).visit(node) - CartesianAxisMerge(AxisIterator._K).visit(node) - elif self._backend.loop_order == BackendLoopOrder.JKI: - CartesianAxisMerge(AxisIterator._J).visit(node) - CartesianAxisMerge(AxisIterator._K).visit(node) - CartesianAxisMerge(AxisIterator._I).visit(node) - elif self._backend.loop_order == BackendLoopOrder.KIJ: - CartesianAxisMerge(AxisIterator._K).visit(node) - CartesianAxisMerge(AxisIterator._I).visit(node) - CartesianAxisMerge(AxisIterator._J).visit(node) - elif 
self._backend.loop_order == BackendLoopOrder.KJI: - CartesianAxisMerge(AxisIterator._K).visit(node) - CartesianAxisMerge(AxisIterator._J).visit(node) - CartesianAxisMerge(AxisIterator._I).visit(node) diff --git a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py index 0521309f..d52403f7 100644 --- a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py @@ -13,9 +13,9 @@ class CartesianMerge(tn.ScheduleNodeTransformer): """Merge Cartesian computation blocks""" - def __init__(self, backend: Backend, *, eager: bool = True) -> None: + def __init__(self, backend: Backend) -> None: + super().__init__() self._backend = backend - self.eager = eager def __str__(self) -> str: return "CartesianMerge" diff --git a/ndsl/dsl/dace/stree/optimizations/common/topology.py b/ndsl/dsl/dace/stree/optimizations/common/topology.py index fe878522..e81df22a 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/topology.py +++ b/ndsl/dsl/dace/stree/optimizations/common/topology.py @@ -6,6 +6,7 @@ def reparent_scope_node( original_parent: tn.ScheduleTreeScope, new_parent: tn.ScheduleTreeNode, + *, prepend: bool = True, ) -> None: """Re-parent children between two scope nodes""" diff --git a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py index 1020affe..7dcb7bae 100644 --- a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py +++ b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py @@ -13,8 +13,6 @@ def visit_TaskletNode( # Noop if there are no replacements to do. 
return - # Dev NOTE: We directly replace the memlet.subset because the `memlet.replace` - # function sometimes doesn't work for memlet in itertools.chain( node.in_memlets.values(), node.out_memlets.values() ): From 36204b00eebab3ee9ca539dbb01d2c2034a0d2ad Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 20 May 2026 15:35:28 +0200 Subject: [PATCH 12/43] Add support for InlineOffgridConditionals --- external/dace | 2 +- .../dace/stree/optimizations/axis_merge.py | 24 +-- .../stree/optimizations/cartesian_merge.py | 48 +++--- .../dace/stree/optimizations/clean_tree.py | 4 +- .../optimizations/offgrid_conditionals.py | 119 +++++++++---- ndsl/dsl/dace/stree/pipeline.py | 1 + .../test_offgrid_conditionals.py | 158 ++++++++++++++++++ 7 files changed, 288 insertions(+), 68 deletions(-) create mode 100644 tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py diff --git a/external/dace b/external/dace index f271b30b..ec81b1a0 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit f271b30bb983559306342ce2ff98c69e6662bb32 +Subproject commit ec81b1a0c2a872da8dd315378ff6a9ac67d5458b diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index e699dddc..3ad5b377 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -75,20 +75,20 @@ def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope: all_children_are_maps = all( [isinstance(child, tn.MapScope) for child in node.children] ) - if not all_children_are_maps: - if self._merged_range != self._original_range: - if_scope = tn.IfScope( - condition=self._execution_condition(), - children=node.children, - parent=node, - ) - # Re-parent to IF - for child in node.children: - child.parent = if_scope - node.children = [if_scope] + if all_children_are_maps: + node.children = self.visit(node.children) return node - node.children = self.visit(node.children) + if self._merged_range 
!= self._original_range: + if_scope = tn.IfScope( + condition=self._execution_condition(), + children=node.children, + parent=node, + ) + # Re-parent to IF + for child in node.children: + child.parent = if_scope + node.children = [if_scope] return node diff --git a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py index d52403f7..1dd64458 100644 --- a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py @@ -21,33 +21,31 @@ def __str__(self) -> str: return "CartesianMerge" def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: - InlineOffgridConditionals().visit(node) + for axis in self._backend_order(): + InlineOffgridConditionals(axis).visit(node) MergeConditionals().visit(node) - if self._backend.loop_order == BackendLoopOrder.IJK: - CartesianAxisMerge(AxisIterator._I).visit(node) - CartesianAxisMerge(AxisIterator._J).visit(node) - CartesianAxisMerge(AxisIterator._K).visit(node) - elif self._backend.loop_order == BackendLoopOrder.IKJ: - CartesianAxisMerge(AxisIterator._I).visit(node) - CartesianAxisMerge(AxisIterator._K).visit(node) - CartesianAxisMerge(AxisIterator._J).visit(node) - elif self._backend.loop_order == BackendLoopOrder.JIK: - CartesianAxisMerge(AxisIterator._J).visit(node) - CartesianAxisMerge(AxisIterator._I).visit(node) - CartesianAxisMerge(AxisIterator._K).visit(node) - elif self._backend.loop_order == BackendLoopOrder.JKI: - CartesianAxisMerge(AxisIterator._J).visit(node) - CartesianAxisMerge(AxisIterator._K).visit(node) - CartesianAxisMerge(AxisIterator._I).visit(node) - elif self._backend.loop_order == BackendLoopOrder.KIJ: - CartesianAxisMerge(AxisIterator._K).visit(node) - CartesianAxisMerge(AxisIterator._I).visit(node) - CartesianAxisMerge(AxisIterator._J).visit(node) - elif self._backend.loop_order == BackendLoopOrder.KJI: - CartesianAxisMerge(AxisIterator._K).visit(node) - 
CartesianAxisMerge(AxisIterator._J).visit(node) - CartesianAxisMerge(AxisIterator._I).visit(node) + for axis in self._backend_order(): + CartesianAxisMerge(axis).visit(node) ExtractOffgridConditionals().visit(node) MergeConditionals().visit(node) + + def _backend_order(self) -> tuple[AxisIterator, AxisIterator, AxisIterator]: + if self._backend.loop_order == BackendLoopOrder.IJK: + return (AxisIterator._I, AxisIterator._J, AxisIterator._K) + + if self._backend.loop_order == BackendLoopOrder.IKJ: + return (AxisIterator._I, AxisIterator._K, AxisIterator._J) + + if self._backend.loop_order == BackendLoopOrder.JIK: + return (AxisIterator._J, AxisIterator._I, AxisIterator._K) + + if self._backend.loop_order == BackendLoopOrder.JKI: + return (AxisIterator._J, AxisIterator._K, AxisIterator._I) + + if self._backend.loop_order == BackendLoopOrder.KIJ: + return (AxisIterator._K, AxisIterator._I, AxisIterator._J) + + assert self._backend.loop_order == BackendLoopOrder.KJI + return (AxisIterator._K, AxisIterator._J, AxisIterator._I) diff --git a/ndsl/dsl/dace/stree/optimizations/clean_tree.py b/ndsl/dsl/dace/stree/optimizations/clean_tree.py index 8f04882a..7d3b5558 100644 --- a/ndsl/dsl/dace/stree/optimizations/clean_tree.py +++ b/ndsl/dsl/dace/stree/optimizations/clean_tree.py @@ -49,12 +49,13 @@ def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope: def visit_IfScope(self, node: tn.IfScope) -> tn.IfScope: self._remove_state_boundaries_from_children(node) + for child in node.children: self.visit(child) return node - def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> tn.ScheduleTreeRoot: self._removed_state_boundaries = 0 self._remove_state_boundaries_from_children(node) @@ -63,3 +64,4 @@ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: self.visit(child) ndsl_log.debug(f"{self}: removed {self._removed_state_boundaries} nodes") + return node diff --git 
a/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py b/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py index d677c203..de4c21a2 100644 --- a/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py +++ b/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py @@ -1,37 +1,97 @@ from dace.sdfg.analysis.schedule_tree import treenodes as tn +from ndsl import ndsl_log +from ndsl.dsl.dace.stree.optimizations.common import ( + AxisIterator, + get_next_node, + is_axis_map, + last_node, + list_index, +) -class InlineOffgridConditionals(tn.ScheduleNodeTransformer): - """Push offgrid conditional inside their cartesian block, - duplicating the conditional if needed + +class InlineOffgridConditionals(tn.ScheduleNodeVisitor): + """ + Push offgrid conditional inside their cartesian block, duplicating the + conditional if needed. Turning: ``` - if a_flag == 0 - map i, j, k - [ops...] - map i, j, k - [ops...] + if a_flag == 0: + map i, j, k: + ... + map i, j, k: + ... ``` into ``` - map i,j, k - if a_flag == 0 - [ops...] - map i,j, k - if a_flag == 0 - [ops...] + map i, j, k: + if a_flag == 0: + ... + map i, j, k: + if a_flag == 0: + ... ``` """ + _axis: AxisIterator + + def __init__(self, axis: AxisIterator) -> None: + super().__init__() + self._axis = axis + def __str__(self) -> str: - return "InlineOffgridConditionals" + return f"InlineOffgridConditionals_{self._axis}" + + def visit_IfScope(self, node: tn.IfScope) -> None: + assert node.parent is not None # just to keep pyright happy + + # For now, skip in case there's an `elif` or `else` following. 
+ if not last_node(node.parent.children, node): + next_node = get_next_node(node.parent.children, node) + if isinstance(next_node, (tn.ElifScope, tn.ElseScope)): + ndsl_log.debug( + "Can't handle conditionals with `elif` and `else` blocks yet :(" + ) + return + + if not all( + [ + isinstance(child, tn.MapScope) and is_axis_map(child, self._axis) + for child in node.children + ] + ): + return + + # If all children are maps over the correct axis, move the if inside. + new_nodes: list[tn.MapScope] = [] + + for child in node.children: + assert isinstance( + child, tn.MapScope + ) # otherwise the condition above is wrong + + if_scope = tn.IfScope( + condition=node.condition, children=child.children, parent=child + ) + + for map_child in child.children: + map_child.parent = if_scope # re-parent to new if_scope + + child.children = [if_scope] + child.parent = node.parent # re-parent to parent of old if_scope + new_nodes.append(child) + + insert_at = list_index(node.parent.children, node) + node.parent.children[insert_at:insert_at] = new_nodes + node.parent.children.remove(node) class ExtractOffgridConditionals(tn.ScheduleNodeTransformer): - """Push offgrid conditional outside of their cartesian block + """ + Push offgrid conditional outside of their cartesian block. - Reverse transform from InlineOffgridConditionals + This is the inverse transform of InlineOffgridConditionals. """ def __str__(self) -> str: @@ -39,24 +99,25 @@ def __str__(self) -> str: class MergeConditionals(tn.ScheduleNodeTransformer): - """Merge consecutive and equal conditionals + """ + Merge consecutive and equal conditionals. Turning: ``` - if a_flag == 0 - map i, j, k - [ops...] - if a_flag == 0 - map i, j, k - [ops...] + if a_flag == 0: + map i, j, k: + ... + if a_flag == 0: + map i, j, k: + ... ``` into ``` - if a_flag == 0 - map i, j, k - [ops...] - map i, j, k - [ops...] + if a_flag == 0: + map i, j, k: + ... + map i, j, k: + ... 
``` Outside of user code, combination of ExtractOffgridConditionals, diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index ad52ca9c..13e30974 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -52,6 +52,7 @@ def run( f.write(stree.as_string()) tree_stats.optimized(stree) + if verbose: ndsl_log_on_rank_0.info(tree_stats.report()) diff --git a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py new file mode 100644 index 00000000..7e4ceb4f --- /dev/null +++ b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py @@ -0,0 +1,158 @@ +from typing import TypeAlias + +import pytest +from dace import nodes + +from ndsl import ( + Backend, + NDSLRuntime, + QuantityFactory, + StencilFactory, + orchestrate, + stencils, +) +from ndsl.boilerplate import get_factories_single_tile +from ndsl.constants import I_DIM, J_DIM, K_DIM +from ndsl.dsl.typing import FloatField +from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge + + +class OrchestratedCode(NDSLRuntime): + def __init__(self, stencil_factory: StencilFactory) -> None: + super().__init__(stencil_factory) + + methods_to_orchestrate = [ + "happy_case", + "happy_case_2", + "blocked_by_else", + "blocked_by_other_nodes", + ] + + for method in methods_to_orchestrate: + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + method_to_orchestrate=method, + ) + + self._copy_stencil = stencil_factory.from_dims_halo( + func=stencils.copy, compute_dims=[I_DIM, J_DIM, K_DIM] + ) + + def happy_case(self, in_field: FloatField, out_field: FloatField) -> None: + if in_field[0, 0, 0] > 0: + self._copy_stencil(in_field, out_field) + self._copy_stencil(in_field, out_field) + + def happy_case_2(self, in_field: FloatField, out_field: FloatField) -> None: + if not in_field[0, 0, 0] > 0: + self._copy_stencil(in_field, out_field) + self._copy_stencil(in_field, 
out_field) + + def blocked_by_else(self, in_field: FloatField, out_field: FloatField) -> None: + self._copy_stencil(in_field, out_field) + + if in_field[0, 0, 0] > 0: + self._copy_stencil(in_field, out_field) + else: + self._copy_stencil(out_field, in_field) + + def blocked_by_other_nodes( + self, in_field: FloatField, out_field: FloatField + ) -> None: + if in_field[0, 0, 0] > 0: + in_field[:] = 42.0 + self._copy_stencil(in_field, out_field) + self._copy_stencil(in_field, out_field) + + +Factories: TypeAlias = tuple[StencilFactory, QuantityFactory] + + +class TestStreeInlineOffgridConditionals: + @pytest.fixture(params=["orch:dace:cpu:IJK", "orch:dace:cpu:KJI"]) + def factories(self, request) -> Factories: + domain = (3, 3, 4) + return get_factories_single_tile( + domain[0], domain[1], domain[2], 0, backend=Backend(request.param) + ) + + def test_happy_case(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + + code = OrchestratedCode(stencil_factory) + in_quantity = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_quantity = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + with StreeOptimization(): + code.happy_case(in_quantity, out_quantity) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + assert precompiled_sdfg.sdfg + + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + assert len(all_maps) == 3 + + def test_happy_case_2(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + + code = OrchestratedCode(stencil_factory) + in_quantity = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_quantity = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + with StreeOptimization(): + code.happy_case_2(in_quantity, out_quantity) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + assert precompiled_sdfg.sdfg + + all_maps = [ + (me, state) + for me, state in 
precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + assert len(all_maps) == 3 + + def test_blocked_by_else(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + + code = OrchestratedCode(stencil_factory) + in_quantity = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_quantity = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + with StreeOptimization(): + code.blocked_by_else(in_quantity, out_quantity) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + assert precompiled_sdfg.sdfg + + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + assert len(all_maps) == 9 + + def test_blocked_by_other_nodes(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + + code = OrchestratedCode(stencil_factory) + in_quantity = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_quantity = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + with StreeOptimization(): + code.blocked_by_other_nodes(in_quantity, out_quantity) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + assert precompiled_sdfg.sdfg + + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + assert len(all_maps) == 6 From 689ab89a4b046af4860f386cbb892d959d30a9eb Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 20 May 2026 15:51:21 +0200 Subject: [PATCH 13/43] fixup: temp fix for test of InlineOffgridConditionals --- .../stree/optimizations/test_offgrid_conditionals.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py index 7e4ceb4f..77ba3591 100644 --- a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py +++ 
b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py @@ -155,4 +155,13 @@ def test_blocked_by_other_nodes(self, factories: Factories) -> None: for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 6 + + # ⚠️ Dev note: + # This should be just `assert len(all_maps) == 6`, but currently, the K-loops + # can't merge because the K-iterators are different. To be fixed (and simplified + # here) with a subsequent commit. + assert ( + len(all_maps) == 6 + if stencil_factory.backend == Backend("orch:dace:cpu:IJK") + else 9 + ) From c263116d813bf5c7eaa95b08194e11faa46f08f9 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 20 May 2026 15:54:39 +0200 Subject: [PATCH 14/43] cleanup: remove old "push if down" codepath This has been replaced with `InlineOffgridConditionals` pass --- .../dace/stree/optimizations/axis_merge.py | 80 ------------------- 1 file changed, 80 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 3ad5b377..44f6577f 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -21,10 +21,6 @@ ) -# Buggy passes that should work -PUSH_IFSCOPE_DOWNWARD = False # Crashing the overall stree - bad algorithmics - - def _both_same_single_axis_maps( first: tn.MapScope, second: tn.MapScope, axis: AxisIterator ) -> bool: @@ -125,9 +121,6 @@ def _merge_node( if isinstance(node, tn.MapScope): return self._map_overcompute_merge(node, nodes) - if PUSH_IFSCOPE_DOWNWARD and isinstance(node, tn.IfScope): - return self._push_ifelse_down(node, nodes) - if isinstance(node, tn.ForScope): return self._for_merge(node) @@ -194,79 +187,6 @@ def _push_tasklet_down( return merged - def _push_ifelse_down( - self, the_if: tn.IfScope, nodes: list[tn.ScheduleTreeNode] - ) -> int: - merged = 0 - - # Recurse down if/else/elif - if_index = list_index(nodes, the_if) - if 
len(the_if.children) != 0: - merged += self._merge_node(the_if.children[0], the_if.children) - for else_index in range(if_index + 1, len(nodes)): - else_node = nodes[else_index] - if else_index < len(nodes) and ( - isinstance(else_node, tn.ElseScope) - or isinstance(else_node, tn.ElifScope) - ): - merged += self._merge_node(else_node, else_node.children) - else: - break - - # Look at swapping if/else/elif first map w/ control flow - - # Gather all first maps - if they do not exists, get out - all_maps = [] - if isinstance(the_if.children[0], tn.MapScope): - all_maps.append(the_if.children[0]) - else: - return merged - for else_index in range(if_index + 1, len(nodes)): - else_node = nodes[else_index] - if else_index < len(nodes) and ( - isinstance(else_node, tn.ElseScope) - or isinstance(else_node, tn.ElifScope) - ): - if isinstance(else_node.children[0], tn.MapScope): - all_maps.append(else_node.children[0]) - else: - return merged - - else: - break - - # Check for mergeability - if len(all_maps) > 1: - the_map = all_maps[0] - for _map in all_maps[1:]: - if not _can_merge_axis_maps(the_map, _map, self.axis): - return merged - - # We are good to go - swap it all - inner_if_map = the_if.children[0] - - # Swap IF & maps - if_index = list_index(nodes, the_if) - swap_node_position_in_tree(the_if, inner_if_map) - - # Swap ELIF/ELSE & maps - for else_index in range(if_index + 1, len(nodes)): - if else_index < len(nodes) and ( - isinstance(nodes[else_index], tn.ElseScope) - or isinstance(nodes[else_index], tn.ElifScope) - ): - swap_node_position_in_tree( - nodes[else_index], nodes[else_index].children[0] - ) - else: - break - - # Merge the Maps - assert isinstance(nodes[if_index], tn.MapScope) - merged += self._map_overcompute_merge(nodes[if_index], nodes) - - return merged - def _map_overcompute_merge( self, the_map: tn.MapScope, nodes: list[tn.ScheduleTreeNode] ) -> int: From 7d6ecc1756a4ca00417ee146430b05fd30e1191a Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: 
Wed, 20 May 2026 12:27:44 -0400 Subject: [PATCH 15/43] Normalize cartesian index during data dependency check --- .../dace/stree/optimizations/common/memlet.py | 26 ++++++++++++--- tests/dsl/dace/stree/common/__init__.py | 0 tests/dsl/dace/stree/common/test_memlet.py | 32 +++++++++++++++++++ 3 files changed, 53 insertions(+), 5 deletions(-) create mode 100644 tests/dsl/dace/stree/common/__init__.py create mode 100644 tests/dsl/dace/stree/common/test_memlet.py diff --git a/ndsl/dsl/dace/stree/optimizations/common/memlet.py b/ndsl/dsl/dace/stree/optimizations/common/memlet.py index b1540c98..97d99b68 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/memlet.py +++ b/ndsl/dsl/dace/stree/optimizations/common/memlet.py @@ -2,6 +2,7 @@ import dace.sdfg.analysis.schedule_tree.treenodes as stree from dace.memlet import Memlet +from dace.symbolic import symbol from ndsl.logging import ndsl_log @@ -21,6 +22,15 @@ def is_equal(self, other: str) -> bool: return other.startswith(self.as_str()) +def normalize_cartesian_indexation(index: symbol, axis: AxisIterator) -> symbol: + """Return a normalize indexation symbol for cartesian indexation.""" + rename_maps = {} + for symb in index.free_symbols: + if symb.name.startswith(axis.as_str()): + rename_maps[symb] = symbol(axis.as_str()) + return index.subs(rename_maps) + + def no_data_dependencies_on_cartesian_axis( first: stree.MapScope, second: stree.MapScope, @@ -36,20 +46,26 @@ def no_data_dependencies_on_cartesian_axis( for write in write_collector.out_memlets: # TODO: this can be optimized to allow non-overlapping intervals and such in the future - if write.subset.dims() <= axis.as_cartesian_index(): + axis_index = axis.as_cartesian_index() + + if write.subset.dims() <= axis_index: # Dimension does not exist continue - previous_axis_index = write.subset[axis.as_cartesian_index()][0] + previous_axis_index = normalize_cartesian_indexation( + write.subset[axis_index][0], axis + ) for read in read_collector.in_memlets: if
write.data == read.data: - if previous_axis_index != read.subset[axis.as_cartesian_index()][0]: + if previous_axis_index != normalize_cartesian_indexation( + read.subset[axis_index][0], axis + ): ndsl_log.debug( f"[{axis.name} Merge] Found read after write conflict " f"for {write.data} " f"w/ different offset to {axis.name} (" - f"write at {write.subset[axis.as_cartesian_index()][0]}, " - f"read at {read.subset[axis.as_cartesian_index()][0]})" + f"write at {write.subset[axis_index][0]}, " + f"read at {read.subset[axis_index][0]})" ) return False return True diff --git a/tests/dsl/dace/stree/common/__init__.py b/tests/dsl/dace/stree/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/dsl/dace/stree/common/test_memlet.py b/tests/dsl/dace/stree/common/test_memlet.py new file mode 100644 index 00000000..44fe15e1 --- /dev/null +++ b/tests/dsl/dace/stree/common/test_memlet.py @@ -0,0 +1,32 @@ +from dace.symbolic import symbol + +from ndsl.dsl.dace.stree.optimizations.common import AxisIterator +from ndsl.dsl.dace.stree.optimizations.common.memlet import ( + normalize_cartesian_indexation, +) + + +def test_normalize_cartesian_index(): + # Case of __k_id(node) - original case + original_symbol = symbol("__k_12345678789") + norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K) + + assert norm_symbol == symbol("__k") + + # Case of offset + original_symbol = 1 + symbol("__k_12345678789") + norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K) + + assert norm_symbol == symbol("__k") + 1 + + # Case of no-op (with offset) + original_symbol = 1 + symbol("__k") + norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K) + + assert norm_symbol == symbol("__k") + 1 + + # Case of index named with _k - so not a cartesian axis + original_symbol = 1 + symbol("_kindex") + norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K) + + assert norm_symbol == 
symbol("_kindex") + 1 From de03d3480038d990f2bcbee2a5d62ec72efa12e0 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 20 May 2026 16:18:25 -0400 Subject: [PATCH 16/43] Update tests --- tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py index 77ba3591..5699407e 100644 --- a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py +++ b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py @@ -161,7 +161,7 @@ def test_blocked_by_other_nodes(self, factories: Factories) -> None: # can't merge because the K-iterators are different. To be fixed (and simplified # here) with a subsequent commit. assert ( - len(all_maps) == 6 + len(all_maps) == 5 if stencil_factory.backend == Backend("orch:dace:cpu:IJK") else 9 ) From ff5722770caf6a919e09c969e3f085fa32263f59 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 20 May 2026 17:12:39 -0400 Subject: [PATCH 17/43] ReplaceAxisSymbolInTasklet -> ReplaceAxisSymbol + debug of it's usage --- .../dace/stree/optimizations/axis_merge.py | 9 +++---- .../dace/stree/optimizations/remove_loops.py | 11 +++----- .../replace_symbol_in_tasklet.py | 26 ++++++++++++++----- 3 files changed, 26 insertions(+), 20 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 44f6577f..3f10f122 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -17,7 +17,7 @@ swap_node_position_in_tree, ) from ndsl.dsl.dace.stree.optimizations.replace_symbol_in_tasklet import ( - ReplaceAxisSymbolInTasklet, + ReplaceAxisSymbol, ) @@ -211,7 +211,6 @@ def _map_overcompute_merge( # Over compute to merge: # - force-merge by expanding the ranges - # - then, guard children to only run in their 
respective range first_range = the_map.node.map.range second_range = next_node.node.map.range merged_range = dace.subsets.Range( @@ -224,7 +223,7 @@ def _map_overcompute_merge( ] ) - # push IfScope down if children are just maps + # - then, guard children to only run in their respective range axis_as_str = the_map.node.params[0] first_map = InsertOvercomputationGuard( axis_as_str, merged_range=merged_range, original_range=first_range @@ -256,9 +255,7 @@ def _map_overcompute_merge( first_map.node.map.params[0] ) } - ReplaceAxisSymbolInTasklet().visit( - first_map, axis_replacements=replacements - ) + ReplaceAxisSymbol(replacements).visit(first_map) # delete now-merged second_map del nodes[list_index(nodes, next_node)] diff --git a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py index 43e9c15e..41c7aee9 100644 --- a/ndsl/dsl/dace/stree/optimizations/remove_loops.py +++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py @@ -7,7 +7,7 @@ from ndsl import ndsl_log from ndsl.dsl.dace.stree.optimizations.common import AxisIterator, reparent_scope_node from ndsl.dsl.dace.stree.optimizations.replace_symbol_in_tasklet import ( - ReplaceAxisSymbolInTasklet, + ReplaceAxisSymbol, ) @@ -57,12 +57,9 @@ def visit_ForScope(self, the_for: tn.ForScope) -> tn.ForScope | tn.ScheduleTreeN if abs(bound_value - init_value) != 1: return the_for - ReplaceAxisSymbolInTasklet().visit( - the_for, - axis_replacements={ - dace.symbol(the_for.loop.loop_variable): str(init_value) - }, - ) + ReplaceAxisSymbol( + {dace.symbol(the_for.loop.loop_variable): str(init_value)} + ).visit(the_for) # Prepend children of the ForScope to parent # the_for.parent.children = [*the_for.children, *the_for.parent.children] diff --git a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py index 7dcb7bae..a64c6e67 100644 --- a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py 
+++ b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py @@ -1,22 +1,34 @@ import itertools from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.symbolic import symbol -class ReplaceAxisSymbolInTasklet(tn.ScheduleNodeVisitor): +class ReplaceAxisSymbol(tn.ScheduleNodeVisitor): + def __init__(self, axis_replacements: dict[str | symbol, str | symbol]) -> None: + self._axis_replacements = axis_replacements + def visit_TaskletNode( self, node: tn.TaskletNode, - axis_replacements: dict[str, str] | None = None, ) -> None: - if not axis_replacements: - # Noop if there are no replacements to do. - return - for memlet in itertools.chain( node.in_memlets.values(), node.out_memlets.values() ): - memlet.replace(axis_replacements) + memlet.replace(self._axis_replacements) + + def visit_IfScope( + self, + node: tn.IfScope, + ) -> None: + if self._axis_replacements: + for old, new in self._axis_replacements.items(): + node.condition.as_string = node.condition.as_string.replace( + str(old), str(new) + ) + + for child in node.children: + self.visit(child) def __str__(self) -> str: return "ReplaceAxisSymbolInTasklet" From 94b2e9928d5e89b5e8804cdae044783cecb5d9bd Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 21 May 2026 09:59:03 +0200 Subject: [PATCH 18/43] fix unit test by hardinging detection of "our" loops --- .../dace/stree/optimizations/common/loops.py | 13 +++- .../dace/stree/optimizations/remove_loops.py | 7 +- .../replace_symbol_in_tasklet.py | 21 ++---- tests/dsl/dace/stree/common/test_loops.py | 69 +++++++++++++++++++ .../test_offgrid_conditionals.py | 6 +- 5 files changed, 92 insertions(+), 24 deletions(-) create mode 100644 tests/dsl/dace/stree/common/test_loops.py diff --git a/ndsl/dsl/dace/stree/optimizations/common/loops.py b/ndsl/dsl/dace/stree/optimizations/common/loops.py index 83a91280..35e33b8c 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/loops.py +++ b/ndsl/dsl/dace/stree/optimizations/common/loops.py @@ -6,9 
+6,18 @@ def is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool: """Returns true if node is a Map over the given axis.""" map_parameter = node.node.map.params - return len(map_parameter) == 1 and map_parameter[0].startswith(axis.as_str()) + if len(map_parameter) != 1: + return False + + if axis == AxisIterator._K: + return map_parameter[0].startswith(axis.as_str()) + + return map_parameter[0] == axis.as_str() def is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool: """Returns true if node is a For over the given axis.""" - return node.loop.loop_variable.startswith(axis.as_str()) + if axis == AxisIterator._K: + return node.loop.loop_variable.startswith(axis.as_str()) + + return node.loop.loop_variable == axis.as_str() diff --git a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py index 41c7aee9..54ac6d5d 100644 --- a/ndsl/dsl/dace/stree/optimizations/remove_loops.py +++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py @@ -30,7 +30,7 @@ class InlineVertical2DWrite(tn.ScheduleNodeTransformer): def __init__(self) -> None: super().__init__() - self._for_scope_removed = 0 + self._for_scopes_removed = 0 def __str__(self) -> str: return "InlineVertical2DWrite" @@ -67,7 +67,7 @@ def visit_ForScope(self, the_for: tn.ForScope) -> tn.ForScope | tn.ScheduleTreeN # Remove ForScope the_for.parent.children.remove(the_for) - self._for_scope_removed += 1 + self._for_scopes_removed += 1 assert len(the_for.children) > 0 return the_for.parent.children[0] @@ -76,10 +76,11 @@ def visit_ForScope(self, the_for: tn.ForScope) -> tn.ForScope | tn.ScheduleTreeN def visit_ScheduleTreeRoot( self, the_root: tn.ScheduleTreeRoot ) -> tn.ScheduleTreeRoot: + self._for_scopes_removed = 0 for child in the_root.children: self.visit(child) - ndsl_log.debug(f"🚀 {self}: {self._for_scope_removed} inlined") + ndsl_log.debug(f"🚀 {self}: {self._for_scopes_removed} inlined") return the_root diff --git 
a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py index a64c6e67..dbc26eb0 100644 --- a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py +++ b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py @@ -8,27 +8,20 @@ class ReplaceAxisSymbol(tn.ScheduleNodeVisitor): def __init__(self, axis_replacements: dict[str | symbol, str | symbol]) -> None: self._axis_replacements = axis_replacements - def visit_TaskletNode( - self, - node: tn.TaskletNode, - ) -> None: + def visit_TaskletNode(self, node: tn.TaskletNode) -> None: for memlet in itertools.chain( node.in_memlets.values(), node.out_memlets.values() ): memlet.replace(self._axis_replacements) - def visit_IfScope( - self, - node: tn.IfScope, - ) -> None: - if self._axis_replacements: - for old, new in self._axis_replacements.items(): - node.condition.as_string = node.condition.as_string.replace( - str(old), str(new) - ) + def visit_IfScope(self, node: tn.IfScope) -> None: + for old, new in self._axis_replacements.items(): + node.condition.as_string = node.condition.as_string.replace( + str(old), str(new) + ) for child in node.children: self.visit(child) def __str__(self) -> str: - return "ReplaceAxisSymbolInTasklet" + return "ReplaceAxisSymbol" diff --git a/tests/dsl/dace/stree/common/test_loops.py b/tests/dsl/dace/stree/common/test_loops.py new file mode 100644 index 00000000..3020f551 --- /dev/null +++ b/tests/dsl/dace/stree/common/test_loops.py @@ -0,0 +1,69 @@ +from dace.sdfg import nodes +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.sdfg.state import LoopRegion + +from ndsl.dsl.dace.stree.optimizations.common import ( + AxisIterator, + is_axis_for, + is_axis_map, +) + + +def test_is_axis_map_multiple_params() -> None: + node = tn.MapScope( + node=nodes.MapEntry( + nodes.Map("map_ij", ["__i", "__j"], [(0, 3, 1), (0, 4, 1)]) + ), + children=[], + ) + assert not is_axis_map(node, 
AxisIterator._I) + assert not is_axis_map(node, AxisIterator._J) + + +def test_is_axis_map_I() -> None: + node = tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_i", ["__i"], [(0, 3, 1)])), children=[] + ) + assert is_axis_map(node, AxisIterator._I) + + +def test_is_axis_map_not_I() -> None: + node = tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_other_i", ["__i0"], [(0, 3, 1)])), + children=[], + ) + assert not is_axis_map(node, AxisIterator._I) + + +def test_is_axis_map_K() -> None: + node = tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_k", ["__k_1234"], [(0, 3, 1)])), children=[] + ) + assert is_axis_map(node, AxisIterator._K) + + +def test_is_axis_map_wrong_iterator() -> None: + node = tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_i", ["__i"], [(0, 3, 1)])), children=[] + ) + assert not is_axis_map(node, AxisIterator._J) + + +def test_is_axis_for_k() -> None: + node = tn.ForScope(loop=LoopRegion("for_k", loop_var="__k_1234"), children=[]) + assert is_axis_for(node, AxisIterator._K) + + +def test_is_axis_for_wrong_iterator() -> None: + node = tn.ForScope(loop=LoopRegion("for_k", loop_var="__k_1234"), children=[]) + assert not is_axis_for(node, AxisIterator._I) + + +def test_is_axis_for_i() -> None: + node = tn.ForScope(loop=LoopRegion("for_i", loop_var="__i"), children=[]) + assert is_axis_for(node, AxisIterator._I) + + +def test_is_axis_for_not_i() -> None: + node = tn.ForScope(loop=LoopRegion("for_i", loop_var="__i0"), children=[]) + assert not is_axis_for(node, AxisIterator._I) diff --git a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py index 5699407e..f897173f 100644 --- a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py +++ b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py @@ -160,8 +160,4 @@ def test_blocked_by_other_nodes(self, factories: Factories) -> None: # This should be just `assert len(all_maps) == 6`, but currently, the 
K-loops # can't merge because the K-iterators are different. To be fixed (and simplified # here) with a subsequent commit. - assert ( - len(all_maps) == 5 - if stencil_factory.backend == Backend("orch:dace:cpu:IJK") - else 9 - ) + assert len(all_maps) == 9 From 3a5057719d69780e20671cc380ea72324056419e Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 21 May 2026 11:56:27 +0200 Subject: [PATCH 19/43] unrelated cleanup: fix/assert type issues --- ndsl/dsl/dace/stree/optimizations/common/topology.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ndsl/dsl/dace/stree/optimizations/common/topology.py b/ndsl/dsl/dace/stree/optimizations/common/topology.py index e81df22a..fa06f3db 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/topology.py +++ b/ndsl/dsl/dace/stree/optimizations/common/topology.py @@ -5,7 +5,7 @@ def reparent_scope_node( original_parent: tn.ScheduleTreeScope, - new_parent: tn.ScheduleTreeNode, + new_parent: tn.ScheduleTreeScope, *, prepend: bool = True, ) -> None: @@ -26,6 +26,7 @@ def swap_node_position_in_tree( """Top node becomes child, child becomes top node.""" # Ensue parent/children relationship is valid tn.validate_children_and_parents_align(top_node) + assert top_node.parent is not None # Take refs before swap top_children = top_node.parent.children From d6824f3c3f63c722cab4e876d25b7fff7912809b Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 21 May 2026 15:36:09 +0200 Subject: [PATCH 20/43] Changes to `InlineVertical2DWrite` --- .../dace/stree/optimizations/common/loops.py | 10 +- .../dace/stree/optimizations/common/memlet.py | 5 +- .../dace/stree/optimizations/remove_loops.py | 93 +++++++------ .../stree/optimizations/test_remove_loops.py | 129 +++++++++++++++--- 4 files changed, 166 insertions(+), 71 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/common/loops.py b/ndsl/dsl/dace/stree/optimizations/common/loops.py index 35e33b8c..5d414915 100644 --- 
a/ndsl/dsl/dace/stree/optimizations/common/loops.py +++ b/ndsl/dsl/dace/stree/optimizations/common/loops.py @@ -9,15 +9,9 @@ def is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool: if len(map_parameter) != 1: return False - if axis == AxisIterator._K: - return map_parameter[0].startswith(axis.as_str()) - - return map_parameter[0] == axis.as_str() + return axis.is_equal(map_parameter[0]) def is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool: """Returns true if node is a For over the given axis.""" - if axis == AxisIterator._K: - return node.loop.loop_variable.startswith(axis.as_str()) - - return node.loop.loop_variable == axis.as_str() + return axis.is_equal(node.loop.loop_variable) diff --git a/ndsl/dsl/dace/stree/optimizations/common/memlet.py b/ndsl/dsl/dace/stree/optimizations/common/memlet.py index 97d99b68..61f64225 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/memlet.py +++ b/ndsl/dsl/dace/stree/optimizations/common/memlet.py @@ -19,7 +19,10 @@ def as_cartesian_index(self) -> int: return self.value[1] def is_equal(self, other: str) -> bool: - return other.startswith(self.as_str()) + if self == AxisIterator._K: + return other.startswith(self.as_str()) + + return other == self.as_str() def normalize_cartesian_indexation(index: symbol, axis: AxisIterator) -> symbol: diff --git a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py index 54ac6d5d..c02f8af5 100644 --- a/ndsl/dsl/dace/stree/optimizations/remove_loops.py +++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py @@ -5,13 +5,17 @@ from dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl import ndsl_log -from ndsl.dsl.dace.stree.optimizations.common import AxisIterator, reparent_scope_node +from ndsl.dsl.dace.stree.optimizations.common import ( + AxisIterator, + is_axis_for, + list_index, +) from ndsl.dsl.dace.stree.optimizations.replace_symbol_in_tasklet import ( ReplaceAxisSymbol, ) -class 
InlineVertical2DWrite(tn.ScheduleNodeTransformer): +class InlineVertical2DWrite(tn.ScheduleNodeVisitor): """Inline K index value for 2D write vertical while removing for loop. Transforming: @@ -35,52 +39,51 @@ def __init__(self) -> None: def __str__(self) -> str: return "InlineVertical2DWrite" - def visit_ForScope(self, the_for: tn.ForScope) -> tn.ForScope | tn.ScheduleTreeNode: - if AxisIterator._K.is_equal(the_for.loop.loop_variable) and the_for.parent: - # Retrieve init/bound value by executing the code and replace usage of it - # If the code cannot be executed (no-literal variable part of the op, etc.) - # we will _not_ inline - try: - exec_locals: dict[str, Any] = {} - exec_globals: dict[str, Any] = {} - exec( - ast.unparse(the_for.loop.init_statement.code[0]), - exec_globals, - exec_locals, - ) - init_value = exec_locals[the_for.loop.loop_variable] - bound_value = eval( - ast.unparse(the_for.loop.loop_condition.code[0].value.comparators) - ) - except Exception as _: - return the_for - if abs(bound_value - init_value) != 1: - return the_for - - ReplaceAxisSymbol( - {dace.symbol(the_for.loop.loop_variable): str(init_value)} - ).visit(the_for) - - # Prepend children of the ForScope to parent - # the_for.parent.children = [*the_for.children, *the_for.parent.children] - reparent_scope_node(the_for, the_for.parent) - - # Remove ForScope - the_for.parent.children.remove(the_for) - self._for_scopes_removed += 1 - assert len(the_for.children) > 0 - return the_for.parent.children[0] - - return the_for - - def visit_ScheduleTreeRoot( - self, the_root: tn.ScheduleTreeRoot - ) -> tn.ScheduleTreeRoot: + def visit_ForScope(self, the_for: tn.ForScope) -> None: + if not is_axis_for(the_for, AxisIterator._K): + return + + assert the_for.parent is not None # just to keep pyright happy + + # Retrieve init/bound value by executing the code and replace usage of it + # If the code cannot be executed (no-literal variable part of the op, etc.) 
+ # we will _not_ inline + try: + exec_locals: dict[str, Any] = {} + exec_globals: dict[str, Any] = {} + exec( + ast.unparse(the_for.loop.init_statement.code[0]), + exec_globals, + exec_locals, + ) + init_value = exec_locals[the_for.loop.loop_variable] + bound_value = eval( + ast.unparse(the_for.loop.loop_condition.code[0].value.comparators) + ) + except Exception as _: + return + if abs(bound_value - init_value) != 1: + return + + ReplaceAxisSymbol( + {dace.symbol(the_for.loop.loop_variable): str(init_value)} + ).visit(the_for) + + # Insert children of the ForScope to parent + insert_at = list_index(the_for.parent.children, the_for) + for child in the_for.children: + child.parent = the_for.parent + the_for.parent.children[insert_at:insert_at] = the_for.children + + # Remove ForScope + the_for.parent.children.remove(the_for) + self._for_scopes_removed += 1 + assert len(the_for.children) > 0 + + def visit_ScheduleTreeRoot(self, the_root: tn.ScheduleTreeRoot) -> None: self._for_scopes_removed = 0 for child in the_root.children: self.visit(child) ndsl_log.debug(f"🚀 {self}: {self._for_scopes_removed} inlined") - - return the_root diff --git a/tests/dsl/dace/stree/optimizations/test_remove_loops.py b/tests/dsl/dace/stree/optimizations/test_remove_loops.py index 06cbe9fe..4af0474a 100644 --- a/tests/dsl/dace/stree/optimizations/test_remove_loops.py +++ b/tests/dsl/dace/stree/optimizations/test_remove_loops.py @@ -6,10 +6,11 @@ from ndsl import QuantityFactory, StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile -from ndsl.config import Backend +from ndsl.config import Backend, BackendLoopOrder from ndsl.constants import I_DIM, J_DIM, K_DIM, Float from ndsl.dsl.gt4py import FORWARD, computation, interval from ndsl.dsl.typing import FloatField, FloatFieldIJ +from ndsl.stencils import copy from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge @@ -18,6 +19,14 @@ def stencil_simple_2D_write(in_field: FloatField, out_fieldIJ: 
FloatFieldIJ) -> out_fieldIJ = in_field +def stencil_multiple_2D_write( + in_field: FloatField, out_fieldIJ: FloatFieldIJ, out_fieldIJ_2: FloatFieldIJ +) -> None: + with computation(FORWARD), interval(0, 1): + out_fieldIJ = in_field + out_fieldIJ_2 = in_field + 1.0 + + def stencil_2D_write_at_K(in_field: FloatField, out_fieldIJ: FloatFieldIJ) -> None: with computation(FORWARD), interval(-1, None): out_fieldIJ = in_field @@ -29,13 +38,15 @@ def stencil_forward_at_K(in_field: FloatField, out_field: FloatField) -> None: class OrchestratedCode: - def __init__( - self, - stencil_factory: StencilFactory, - quantity_factory: QuantityFactory, - ) -> None: - orchestratable_methods = ["write_at_0", "write_at_top", "do_not_inline"] - for method in orchestratable_methods: + def __init__(self, stencil_factory: StencilFactory) -> None: + methods_to_orchestrate = [ + "write_at_0", + "write_at_top", + "do_not_inline", + "combined_stencils", + "multiple_statements", + ] + for method in methods_to_orchestrate: orchestrate( obj=self, config=stencil_factory.config.dace_config, @@ -54,6 +65,14 @@ def __init__( func=stencil_forward_at_K, compute_dims=[I_DIM, J_DIM, K_DIM], ) + self.stencil_copy = stencil_factory.from_dims_halo( + func=copy, + compute_dims=[I_DIM, J_DIM, K_DIM], + ) + self.stencil_multiple_2D_write = stencil_factory.from_dims_halo( + func=stencil_multiple_2D_write, + compute_dims=[I_DIM, J_DIM, K_DIM], + ) def write_at_0( self, @@ -76,6 +95,18 @@ def do_not_inline( ) -> None: self.stencil_do_not_inline(in_field, out_field) + def combined_stencils( + self, field: FloatField, field2: FloatField, fieldIJ: FloatFieldIJ + ) -> None: + self.stencil_copy(field, field2) + self.stencil_simple_2D_write(field2, fieldIJ) + + def multiple_statements( + self, in_field: FloatField, out_field: FloatFieldIJ, out_field2: FloatFieldIJ + ) -> None: + self.stencil_copy(in_field, in_field) + self.stencil_multiple_2D_write(in_field, out_field, out_field2) + Factories: TypeAlias = 
tuple[StencilFactory, QuantityFactory] @@ -89,14 +120,10 @@ def factories(self, request) -> Factories: domain[0], domain[1], domain[2], 0, backend=Backend(request.param) ) - @pytest.fixture - def code(self, factories: Factories) -> OrchestratedCode: - return OrchestratedCode(*factories) - - def test_common_2D_write( - self, code: OrchestratedCode, factories: Factories - ) -> None: + def test_common_2D_write(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM], "") in_qty.field[:, :, 0] = Float(32.0) @@ -120,8 +147,10 @@ def test_common_2D_write( assert len(all_loop_region) == 0 assert (out_qty.field[:] == Float(32.0)).all() - def test_2D_write_K_top(self, code: OrchestratedCode, factories: Factories) -> None: + def test_2D_write_K_top(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM], "") in_qty.field[:, :, -1] = Float(32.0) @@ -145,8 +174,10 @@ def test_2D_write_K_top(self, code: OrchestratedCode, factories: Factories) -> N assert len(all_loop_region) == 0 assert (out_qty.field[:] == Float(32.0)).all() - def test_do_not_inline(self, code: OrchestratedCode, factories: Factories) -> None: + def test_do_not_inline(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") @@ -168,3 +199,67 @@ def test_do_not_inline(self, code: OrchestratedCode, factories: Factories) -> No assert len(all_maps) == 2 assert len(all_loop_region) == 1 assert (out_qty.field[:] == Float(1)).all() + + def test_combined_stencils(self, factories: 
Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + field = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + field_2 = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + field_IJ = quantity_factory.zeros([I_DIM, J_DIM], "") + + with StreeOptimization(): + code.combined_stencils(field, field_2, field_IJ) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + all_loop_region = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, LoopRegion) + ] + + assert ( + len(all_maps) == 3 + if stencil_factory.backend.loop_order == BackendLoopOrder.IJK + else 5 + ) + assert len(all_loop_region) == 0 + assert (field_IJ.field[:] == Float(1)).all() + + def test_multiple_statements(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + field = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + field_IJ = quantity_factory.zeros([I_DIM, J_DIM], "") + field_IJ_2 = quantity_factory.zeros([I_DIM, J_DIM], "") + + field.field[:, :, 0] = Float(42.0) + with StreeOptimization(): + code.multiple_statements(field, field_IJ, field_IJ_2) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + all_loop_region = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, LoopRegion) + ] + + assert ( + len(all_maps) == 3 + if stencil_factory.backend.loop_order == BackendLoopOrder.IJK + else 5 + ) + assert len(all_loop_region) == 0 + assert (field_IJ.field[:] == Float(42.0)).all() + assert (field_IJ_2.field[:] == Float(43.0)).all() From 454fb44867fa21a95018344be8e0136ef4c5419e Mon Sep 17 00:00:00 2001 From: 
Roman Cattaneo Date: Fri, 22 May 2026 14:13:11 +0200 Subject: [PATCH 21/43] dace update: connect source/sink nodes with empty memlets --- external/dace | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/dace b/external/dace index ec81b1a0..44657753 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit ec81b1a0c2a872da8dd315378ff6a9ac67d5458b +Subproject commit 44657753cef3c0ce3ef9deef9d0c81e0e7314b1e From 9ba2664b62c1cdfc23790cd175bb35260ec54dc9 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 22 May 2026 16:55:01 +0200 Subject: [PATCH 22/43] dace update: support for self-assigning copy nodes --- external/dace | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/dace b/external/dace index 44657753..99a9360d 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit 44657753cef3c0ce3ef9deef9d0c81e0e7314b1e +Subproject commit 99a9360d35f458c328b204860d59a365522484ab From f8798a08a469de196643f06992859fc464b270d3 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 22 May 2026 14:20:40 -0400 Subject: [PATCH 23/43] GPU tree orchestration pipeline - Local are no longer transient on GPU - RefineTransients is deactivated --- ndsl/dsl/dace/orchestration.py | 35 +++++++++++++++++++++++++-------- ndsl/dsl/dace/stree/pipeline.py | 21 +++++++++++++++++--- ndsl/quantity/local.py | 3 ++- 3 files changed, 47 insertions(+), 12 deletions(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index e31298d7..5b1de8e8 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -6,19 +6,20 @@ from pathlib import Path from typing import Any -from dace import SDFG, CompiledSDFG +from dace import SDFG, CompiledSDFG, DeviceType from dace import compiletime as DaceCompiletime from dace import dtypes from dace import method as dace_method from dace import nodes from dace import program as dace_program from dace.dtypes import DeviceType as 
DaceDeviceType +from dace.dtypes import ScheduleType from dace.dtypes import StorageType as DaceStorageType from dace.frontend.python.common import SDFGConvertible from dace.frontend.python.parser import DaceProgram from dace.sdfg.analysis.schedule_tree import treenodes as tn from dace.transformation.auto.auto_optimize import make_transients_persistent -from dace.transformation.dataflow import MapExpansion +from dace.transformation.dataflow import MapCollapse, MapExpansion from dace.transformation.helpers import get_parent_map from gt4py import storage as gt_storage @@ -37,7 +38,7 @@ negative_qtracers_checker, sdfg_nan_checker, ) -from ndsl.dsl.dace.stree import CPUPipeline +from ndsl.dsl.dace.stree import CPUPipeline, GPUPipeline from ndsl.dsl.dace.utils import ( DaCeProgress, memory_static_analysis, @@ -181,7 +182,18 @@ def _build_sdfg( # Here be 🐉 - but tests exists in test_optimization.py with DaCeProgress(config, "Schedule Tree: generate from SDFG"): # Break all loops into uni-dimensional loops to simplify optimizations - sdfg.apply_transformations_repeated(MapExpansion, validate=True) + sdfg.apply_transformations_repeated( + MapExpansion, + options={ + "inner_schedule": ( + ScheduleType.GPU_Device + if device_type is DeviceType.GPU + else ScheduleType.Default + ) + }, + validate=True, + print_report=True, + ) stree = sdfg.as_schedule_tree() if config.verbose_orchestration: with open( @@ -191,10 +203,16 @@ def _build_sdfg( f.write(stree.as_string()) with DaCeProgress(config, "Schedule Tree: optimization"): - CPUPipeline( - backend=backend_name, - cache_directory=Path(sdfg.build_folder), - ).run(stree, verbose=config.verbose_schedule_tree_optimizations) + if device_type == device_type.CPU: + CPUPipeline( + backend=backend_name, + cache_directory=Path(sdfg.build_folder), + ).run(stree, verbose=config.verbose_schedule_tree_optimizations) + elif device_type == DeviceType.GPU: + GPUPipeline( + backend=backend_name, + cache_directory=Path(sdfg.build_folder), + 
).run(stree, verbose=config.verbose_schedule_tree_optimizations) if config.verbose_orchestration: with open( os.path.abspath(f"{sdfg.build_folder}/03-post_opt.stree.txt"), @@ -209,6 +227,7 @@ def _build_sdfg( os.path.abspath(f"{sdfg.build_folder}/04-from_stree.sdfgz"), compress=True, ) + sdfg.apply_transformations_repeated(MapCollapse) # Make the transients array persistents if config.is_gpu_backend(): diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index 13e30974..9833b2eb 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -17,7 +17,7 @@ class StreePipeline: def __init__( self, *, - passes: list[stree.ScheduleNodeTransformer], + passes: list[stree.ScheduleNodeVisitor], cache_directory: Path | None = None, ) -> None: if cache_directory is None: @@ -64,7 +64,7 @@ def __init__( self, backend: Backend, *, - passes: list[stree.ScheduleNodeTransformer] | None = None, + passes: list[stree.ScheduleNodeVisitor] | None = None, cache_directory: Path | None = None, ) -> None: if passes is None: @@ -83,9 +83,24 @@ def __init__( class GPUPipeline(StreePipeline): def __init__( self, - passes: list[stree.ScheduleNodeTransformer] | None = None, + backend: Backend, + *, + passes: list[stree.ScheduleNodeVisitor] | None = None, cache_directory: Path | None = None, ) -> None: + if passes is None: + passes = [ + CleanUpScheduleTree(), + InlineVertical2DWrite(), + CartesianMerge(backend), + # 🐞 Transient refine can't be used + # because of bugs transients showing in code generation + # CartesianRefineTransients(backend), + ] + super().__init__( + passes=passes, + cache_directory=cache_directory, + ) super().__init__( passes=passes if passes is not None else [], cache_directory=cache_directory, diff --git a/ndsl/quantity/local.py b/ndsl/quantity/local.py index f69480a7..37aee3eb 100644 --- a/ndsl/quantity/local.py +++ b/ndsl/quantity/local.py @@ -31,6 +31,7 @@ def __init__( # Initialize memory to obviously wrong value - 
Local should _not_ be expected # to be zero'ed. data[:] = 123456789 + self._on_gpu = backend.is_gpu_backend() super().__init__( data, @@ -45,5 +46,5 @@ def __init__( def __descriptor__(self) -> Any: """Locals uses `Quantity.__descriptor__` and flag itself as transient.""" data = dace.data.create_datadescriptor(self._data) - data.transient = True + data.transient = True if not self._on_gpu else False return data From 89294d2b5e520771a7f8af1d708ac36c181208f7 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 22 May 2026 14:23:32 -0400 Subject: [PATCH 24/43] Add scalarized array to tree statistics --- .../dsl/dace/stree/optimizations/statistics.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/statistics.py b/ndsl/dsl/dace/stree/optimizations/statistics.py index ebef36fe..6e5fe3af 100644 --- a/ndsl/dsl/dace/stree/optimizations/statistics.py +++ b/ndsl/dsl/dace/stree/optimizations/statistics.py @@ -34,20 +34,22 @@ def visit_ForScope(self, node: stree.ForScope) -> None: class CountTransient(stree.ScheduleNodeVisitor): def __init__(self) -> None: super().__init__() - self._counts = [0, 0, 0, 0] + self._counts = [0, 0, 0, 0, 0] def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: for data in node.containers.values(): non_atomic_dims_count = sum(1 for x in data.shape if x != 1) if isinstance(data, dace.data.Array) and data.transient: - if non_atomic_dims_count == 1: + if non_atomic_dims_count == 0: self._counts[0] += 1 - elif non_atomic_dims_count == 2: + elif non_atomic_dims_count == 1: self._counts[1] += 1 - elif non_atomic_dims_count == 3: + elif non_atomic_dims_count == 2: self._counts[2] += 1 - else: + elif non_atomic_dims_count == 3: self._counts[3] += 1 + else: + self._counts[4] += 1 class TreeOptimizationStatistics: @@ -59,7 +61,9 @@ class Record: cartesian_maps: list[int] = dataclasses.field(default_factory=lambda: [0, 0, 0]) cartesian_fors: list[int] = 
dataclasses.field(default_factory=lambda: [0, 0, 0]) - transients: list[int] = dataclasses.field(default_factory=lambda: [0, 0, 0, 0]) + transients: list[int] = dataclasses.field( + default_factory=lambda: [0, 0, 0, 0, 0] + ) def __init__(self) -> None: self._original_record = TreeOptimizationStatistics.Record() @@ -93,5 +97,5 @@ def report(self) -> str: msg = "Tree optimization:\n" msg += f" Cartesian maps [I, J, K]: {self._original_record.cartesian_maps} -> {self._optimized_record.cartesian_maps}\n" msg += f" Cartesian fors [I, J, K]: {self._original_record.cartesian_fors} -> {self._optimized_record.cartesian_fors}\n" - msg += f" Transients [1D, 2D, 3D, 4D+]: {self._original_record.transients} -> {self._optimized_record.transients}\n" + msg += f" Transients [Scalarized Array, 1D, 2D, 3D, 4D+]: {self._original_record.transients} -> {self._optimized_record.transients}\n" return msg From 456b5fbda02c22ea5d8735ee00ee630dbd460cb5 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 22 May 2026 17:49:35 -0400 Subject: [PATCH 25/43] Replace `AxisSymbol` in "masklet as well + rename file --- ndsl/dsl/dace/stree/optimizations/axis_merge.py | 2 +- ndsl/dsl/dace/stree/optimizations/remove_loops.py | 2 +- ...replace_symbol_in_tasklet.py => replace_axis_symbol.py} | 7 +++++++ 3 files changed, 9 insertions(+), 2 deletions(-) rename ndsl/dsl/dace/stree/optimizations/{replace_symbol_in_tasklet.py => replace_axis_symbol.py} (77%) diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 3f10f122..e0867ede 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -16,7 +16,7 @@ no_data_dependencies_on_cartesian_axis, swap_node_position_in_tree, ) -from ndsl.dsl.dace.stree.optimizations.replace_symbol_in_tasklet import ( +from ndsl.dsl.dace.stree.optimizations.replace_axis_symbol import ( ReplaceAxisSymbol, ) diff --git 
a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py index c02f8af5..76b6dc54 100644 --- a/ndsl/dsl/dace/stree/optimizations/remove_loops.py +++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py @@ -10,7 +10,7 @@ is_axis_for, list_index, ) -from ndsl.dsl.dace.stree.optimizations.replace_symbol_in_tasklet import ( +from ndsl.dsl.dace.stree.optimizations.replace_axis_symbol import ( ReplaceAxisSymbol, ) diff --git a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py b/ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py similarity index 77% rename from ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py rename to ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py index dbc26eb0..b7f3d548 100644 --- a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py +++ b/ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py @@ -14,6 +14,13 @@ def visit_TaskletNode(self, node: tn.TaskletNode) -> None: ): memlet.replace(self._axis_replacements) + if node.node.label.startswith("masklet"): + for old, new in self._axis_replacements.items(): + node.node.code.as_string = node.node.code.as_string.replace( + str(old), str(new) + ) + + def visit_IfScope(self, node: tn.IfScope) -> None: for old, new in self._axis_replacements.items(): node.condition.as_string = node.condition.as_string.replace( From 0aaa78d69070dbe601f1fd79e531671f8674b818 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 22 May 2026 17:49:53 -0400 Subject: [PATCH 26/43] Deactivate `InlineVertical2DWrite` for now --- ndsl/dsl/dace/stree/pipeline.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index 9833b2eb..ed3d4d77 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -70,7 +70,8 @@ def __init__( if passes is None: passes = [ CleanUpScheduleTree(), - InlineVertical2DWrite(), + # TODO: Is it 
safe? Deactivate for now + # InlineVertical2DWrite(), CartesianMerge(backend), CartesianRefineTransients(backend), ] @@ -91,7 +92,8 @@ def __init__( if passes is None: passes = [ CleanUpScheduleTree(), - InlineVertical2DWrite(), + # TODO: Is it safe? Deactivate for now + # InlineVertical2DWrite(), CartesianMerge(backend), # 🐞 Transient refine can't be used # because of bugs transients showing in code generation From 634a097197072d6ba4e654337b733be4abda92ae Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 22 May 2026 17:50:57 -0400 Subject: [PATCH 27/43] Lint --- ndsl/dsl/dace/stree/optimizations/axis_merge.py | 4 +--- ndsl/dsl/dace/stree/optimizations/remove_loops.py | 4 +--- ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py | 1 - ndsl/dsl/dace/stree/pipeline.py | 1 - 4 files changed, 2 insertions(+), 8 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index e0867ede..d875082e 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -16,9 +16,7 @@ no_data_dependencies_on_cartesian_axis, swap_node_position_in_tree, ) -from ndsl.dsl.dace.stree.optimizations.replace_axis_symbol import ( - ReplaceAxisSymbol, -) +from ndsl.dsl.dace.stree.optimizations.replace_axis_symbol import ReplaceAxisSymbol def _both_same_single_axis_maps( diff --git a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py index 76b6dc54..89716404 100644 --- a/ndsl/dsl/dace/stree/optimizations/remove_loops.py +++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py @@ -10,9 +10,7 @@ is_axis_for, list_index, ) -from ndsl.dsl.dace.stree.optimizations.replace_axis_symbol import ( - ReplaceAxisSymbol, -) +from ndsl.dsl.dace.stree.optimizations.replace_axis_symbol import ReplaceAxisSymbol class InlineVertical2DWrite(tn.ScheduleNodeVisitor): diff --git 
a/ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py b/ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py index b7f3d548..c04c2fc5 100644 --- a/ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py +++ b/ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py @@ -19,7 +19,6 @@ def visit_TaskletNode(self, node: tn.TaskletNode) -> None: node.node.code.as_string = node.node.code.as_string.replace( str(old), str(new) ) - def visit_IfScope(self, node: tn.IfScope) -> None: for old, new in self._axis_replacements.items(): diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index ed3d4d77..cb3a2ec5 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -7,7 +7,6 @@ CartesianMerge, CartesianRefineTransients, CleanUpScheduleTree, - InlineVertical2DWrite, ) from ndsl.dsl.dace.stree.optimizations.statistics import TreeOptimizationStatistics from ndsl.logging import ndsl_log_on_rank_0 From 02102af6ac7c3fa421ebf5cfc81a5902c0d11ee2 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 26 May 2026 10:23:40 +0200 Subject: [PATCH 28/43] Fix tests after collapsing maps / fix non-cartesian loop inline Also adds infrastructure to override the orchestration pipeline in tests (used to allow testing `InlineVertical2Dwrite`). 
--- ndsl/dsl/dace/orchestration.py | 39 ++++++++---- .../dace/stree/optimizations/axis_merge.py | 2 +- ndsl/dsl/dace/stree/pipeline.py | 12 ++-- .../dace/stree/optimizations/test_merge.py | 20 +++---- .../test_offgrid_conditionals.py | 14 ++--- .../stree/optimizations/test_remove_loops.py | 60 +++++++++++++++---- tests/dsl/dace/stree/sdfg_stree_tools.py | 8 +++ 7 files changed, 107 insertions(+), 48 deletions(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 5b1de8e8..c88212c2 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -24,6 +24,7 @@ from gt4py import storage as gt_storage import ndsl.dsl.dace.replacements # noqa # We load in the DaCe replacements +from ndsl import Backend from ndsl.comm.mpi import MPI from ndsl.dsl.dace.build import get_sdfg_path, write_build_info from ndsl.dsl.dace.dace_config import ( @@ -39,6 +40,7 @@ sdfg_nan_checker, ) from ndsl.dsl.dace.stree import CPUPipeline, GPUPipeline +from ndsl.dsl.dace.stree.pipeline import StreePipeline from ndsl.dsl.dace.utils import ( DaCeProgress, memory_static_analysis, @@ -54,6 +56,8 @@ ) """INTERNAL: Developer flag to turn the untested schedule tree roundtrip optimizer.""" +_INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES: list[tn.ScheduleNodeVisitor] | None = None + def dace_inhibitor(func: Callable) -> Callable: """Triggers callback generation wrapping `func` while doing DaCe parsing.""" @@ -143,6 +147,24 @@ def _tree_as_sdfg(stree: tn.ScheduleTreeRoot) -> SDFG: return stree.as_sdfg(skip={"ScalarToSymbolPromotion", "ControlFlowRaising"}) +def _optimization_pipeline( + device_type: DeviceType, + backend: Backend, + *, + passes: list[tn.ScheduleNodeVisitor] | None = None, + cache_directory: Path | None = None, +) -> StreePipeline: + if device_type == device_type.CPU: + return CPUPipeline(backend, passes=passes, cache_directory=cache_directory) + + if device_type == DeviceType.GPU: + return GPUPipeline(backend, passes=passes, 
cache_directory=cache_directory) + + raise ValueError( + f"Unknown device type `{device_type}`, expected {DeviceType.CPU} or {DeviceType.GPU}." + ) + + def _build_sdfg( dace_program: DaceProgram, sdfg: SDFG, config: DaceConfig, args: Any, kwargs: Any ) -> None: @@ -203,16 +225,13 @@ def _build_sdfg( f.write(stree.as_string()) with DaCeProgress(config, "Schedule Tree: optimization"): - if device_type == device_type.CPU: - CPUPipeline( - backend=backend_name, - cache_directory=Path(sdfg.build_folder), - ).run(stree, verbose=config.verbose_schedule_tree_optimizations) - elif device_type == DeviceType.GPU: - GPUPipeline( - backend=backend_name, - cache_directory=Path(sdfg.build_folder), - ).run(stree, verbose=config.verbose_schedule_tree_optimizations) + pipeline = _optimization_pipeline( + device_type, + backend_name, + cache_directory=Path(sdfg.build_folder), + passes=_INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES, + ) + pipeline.run(stree, verbose=config.verbose_schedule_tree_optimizations) if config.verbose_orchestration: with open( os.path.abspath(f"{sdfg.build_folder}/03-post_opt.stree.txt"), diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index d875082e..0c5a476f 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -134,7 +134,7 @@ def _merge_node( def _for_merge(self, the_for_scope: tn.ForScope) -> int: merged = 0 - if is_axis_for(the_for_scope, self.axis): + if is_axis_for(the_for_scope, AxisIterator._K): # TODO: if the for scope is on a cartesian axis it can be # merged with other for scope going in the same direction pass diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index cb3a2ec5..f44a399e 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -1,6 +1,6 @@ from pathlib import Path -import dace.sdfg.analysis.schedule_tree.treenodes as stree +from 
dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl import Backend from ndsl.dsl.dace.stree.optimizations import ( @@ -16,7 +16,7 @@ class StreePipeline: def __init__( self, *, - passes: list[stree.ScheduleNodeVisitor], + passes: list[tn.ScheduleNodeVisitor], cache_directory: Path | None = None, ) -> None: if cache_directory is None: @@ -33,9 +33,9 @@ def __repr__(self) -> str: def run( self, - stree: stree.ScheduleTreeRoot, + stree: tn.ScheduleTreeRoot, verbose: bool = False, - ) -> stree.ScheduleTreeRoot: + ) -> tn.ScheduleTreeRoot: tree_stats = TreeOptimizationStatistics() tree_stats.original(stree) @@ -63,7 +63,7 @@ def __init__( self, backend: Backend, *, - passes: list[stree.ScheduleNodeVisitor] | None = None, + passes: list[tn.ScheduleNodeVisitor] | None = None, cache_directory: Path | None = None, ) -> None: if passes is None: @@ -85,7 +85,7 @@ def __init__( self, backend: Backend, *, - passes: list[stree.ScheduleNodeVisitor] | None = None, + passes: list[tn.ScheduleNodeVisitor] | None = None, cache_directory: Path | None = None, ) -> None: if passes is None: diff --git a/tests/dsl/dace/stree/optimizations/test_merge.py b/tests/dsl/dace/stree/optimizations/test_merge.py index d57e758a..2e76029c 100644 --- a/tests/dsl/dace/stree/optimizations/test_merge.py +++ b/tests/dsl/dace/stree/optimizations/test_merge.py @@ -160,7 +160,7 @@ def test_trivial_merge(self, code: OrchestratedCode, factories: Factories) -> No if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 + assert len(all_maps) == 1 # all merged and collapsed assert (out_qty.field[:] == 2).all() def test_missing_merge_of_forscope_and_map( @@ -179,7 +179,7 @@ def test_missing_merge_of_forscope_and_map( for map_entry, _ in sdfg.all_nodes_recursive() if isinstance(map_entry, nodes.MapEntry) ] - assert len(all_maps) == 4 # 2 IJ + 2 Ks + assert len(all_maps) == 3 # 1 IJ + 2 Ks all_loops = [ loop for loop, _ in sdfg.all_nodes_recursive() @@ -203,7 +203,7 @@ def 
test_overcompute_merge( for me, state in sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 # All maps merged + assert len(all_maps) == 1 # All maps merged and collapsed def test_block_merge_when_dependencies_are_found( self, code: OrchestratedCode, factories: Factories @@ -222,7 +222,7 @@ def test_block_merge_when_dependencies_are_found( for me, state in sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 4 # 2 IJ + 2 Ks (un-merged) + assert len(all_maps) == 3 # 1 IJ + 2 Ks (un-merged) def test_push_non_cartesian_for( self, code: OrchestratedCode, factories: Factories @@ -242,7 +242,7 @@ def test_push_non_cartesian_for( for me, state in sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 # All merged + assert len(all_maps) == 1 # All merged & collapsed for_loops = [ node for node, _ in sdfg.all_nodes_recursive() @@ -278,7 +278,7 @@ def test_trivial_merge(self, code: OrchestratedCode, factories: Factories) -> No if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 + assert len(all_maps) == 1 # all maps merged and collapsed assert (out_qty.field[:] == 2).all() def test_missing_merge_of_forscope_and_map( @@ -298,7 +298,7 @@ def test_missing_merge_of_forscope_and_map( for map_entry, _ in sdfg.all_nodes_recursive() if isinstance(map_entry, nodes.MapEntry) ] - assert len(all_maps) == 8 # 2 KJI (all maps) + 1 for scope + assert len(all_maps) == 3 # 2 KJI (all maps) + 1 JI all_loops = [ loop for loop, _ in sdfg.all_nodes_recursive() @@ -323,7 +323,7 @@ def test_overcompute_merge( for me, state in sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 # All maps merged + assert len(all_maps) == 1 # All maps merged & collapsed def test_block_merge_when_dependencies_are_found( self, code: OrchestratedCode, factories: Factories @@ -342,7 +342,7 @@ def test_block_merge_when_dependencies_are_found( for me, state in 
sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 6 # 2 * KJI + assert len(all_maps) == 2 # 2 * KJI def test_push_non_cartesian_for( self, code: OrchestratedCode, factories: Factories @@ -362,7 +362,7 @@ def test_push_non_cartesian_for( for me, state in sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 # All merged + assert len(all_maps) == 1 # All merged and collapsed for_loops = [ node for node, _ in sdfg.all_nodes_recursive() diff --git a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py index f897173f..f58c90f7 100644 --- a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py +++ b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py @@ -88,14 +88,13 @@ def test_happy_case(self, factories: Factories) -> None: code.happy_case(in_quantity, out_quantity) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) - assert precompiled_sdfg.sdfg all_maps = [ (me, state) for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 + assert len(all_maps) == 1 # all merged and collapsed def test_happy_case_2(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories @@ -108,14 +107,13 @@ def test_happy_case_2(self, factories: Factories) -> None: code.happy_case_2(in_quantity, out_quantity) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) - assert precompiled_sdfg.sdfg all_maps = [ (me, state) for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 + assert len(all_maps) == 1 # all merged and collapsed def test_blocked_by_else(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories @@ -128,14 +126,13 @@ def test_blocked_by_else(self, factories: Factories) -> None: code.blocked_by_else(in_quantity, 
out_quantity) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) - assert precompiled_sdfg.sdfg all_maps = [ (me, state) for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 9 + assert len(all_maps) == 3 # 3 * IJK/KJI def test_blocked_by_other_nodes(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories @@ -148,7 +145,6 @@ def test_blocked_by_other_nodes(self, factories: Factories) -> None: code.blocked_by_other_nodes(in_quantity, out_quantity) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) - assert precompiled_sdfg.sdfg all_maps = [ (me, state) @@ -157,7 +153,7 @@ def test_blocked_by_other_nodes(self, factories: Factories) -> None: ] # ⚠️ Dev note: - # This should be just `assert len(all_maps) == 6`, but currently, the K-loops + # This should be just `assert len(all_maps) == 2`, but currently, the K-loops # can't merge because the K-iterators are different. To be fixed (and simplified # here) with a subsequent commit. 
- assert len(all_maps) == 9 + assert len(all_maps) == 3 diff --git a/tests/dsl/dace/stree/optimizations/test_remove_loops.py b/tests/dsl/dace/stree/optimizations/test_remove_loops.py index 4af0474a..da38e890 100644 --- a/tests/dsl/dace/stree/optimizations/test_remove_loops.py +++ b/tests/dsl/dace/stree/optimizations/test_remove_loops.py @@ -8,6 +8,12 @@ from ndsl.boilerplate import get_factories_single_tile from ndsl.config import Backend, BackendLoopOrder from ndsl.constants import I_DIM, J_DIM, K_DIM, Float +from ndsl.dsl.dace.stree.optimizations import InlineVertical2DWrite +from ndsl.dsl.dace.stree.pipeline import ( + CartesianMerge, + CartesianRefineTransients, + CleanUpScheduleTree, +) from ndsl.dsl.gt4py import FORWARD, computation, interval from ndsl.dsl.typing import FloatField, FloatFieldIJ from ndsl.stencils import copy @@ -123,12 +129,18 @@ def factories(self, request) -> Factories: def test_common_2D_write(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories code = OrchestratedCode(stencil_factory) + pipeline = [ + CleanUpScheduleTree(), + InlineVertical2DWrite(), + CartesianMerge(stencil_factory.backend), + CartesianRefineTransients(stencil_factory.backend), + ] in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM], "") in_qty.field[:, :, 0] = Float(32.0) - with StreeOptimization(): + with StreeOptimization(passes=pipeline): code.write_at_0(in_qty, out_qty) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -143,19 +155,25 @@ def test_common_2D_write(self, factories: Factories) -> None: if isinstance(me, LoopRegion) ] - assert len(all_maps) == 2 + assert len(all_maps) == 1 # IJ/JI collapsed assert len(all_loop_region) == 0 assert (out_qty.field[:] == Float(32.0)).all() def test_2D_write_K_top(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories code = OrchestratedCode(stencil_factory) + pipeline = [ + CleanUpScheduleTree(), + 
InlineVertical2DWrite(), + CartesianMerge(stencil_factory.backend), + CartesianRefineTransients(stencil_factory.backend), + ] in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM], "") in_qty.field[:, :, -1] = Float(32.0) - with StreeOptimization(): + with StreeOptimization(passes=pipeline): code.write_at_top(in_qty, out_qty) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -170,18 +188,24 @@ def test_2D_write_K_top(self, factories: Factories) -> None: if isinstance(me, LoopRegion) ] - assert len(all_maps) == 2 + assert len(all_maps) == 1 # IJ/JI collapsed assert len(all_loop_region) == 0 assert (out_qty.field[:] == Float(32.0)).all() def test_do_not_inline(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories code = OrchestratedCode(stencil_factory) + pipeline = [ + CleanUpScheduleTree(), + InlineVertical2DWrite(), + CartesianMerge(stencil_factory.backend), + CartesianRefineTransients(stencil_factory.backend), + ] in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): + with StreeOptimization(passes=pipeline): code.do_not_inline(in_qty, out_qty) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -196,19 +220,25 @@ def test_do_not_inline(self, factories: Factories) -> None: if isinstance(me, LoopRegion) ] - assert len(all_maps) == 2 + assert len(all_maps) == 1 # IJ/JI collapsed assert len(all_loop_region) == 1 assert (out_qty.field[:] == Float(1)).all() def test_combined_stencils(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories code = OrchestratedCode(stencil_factory) + pipeline = [ + CleanUpScheduleTree(), + InlineVertical2DWrite(), + CartesianMerge(stencil_factory.backend), + CartesianRefineTransients(stencil_factory.backend), + ] field = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") field_2 = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") 
field_IJ = quantity_factory.zeros([I_DIM, J_DIM], "") - with StreeOptimization(): + with StreeOptimization(passes=pipeline): code.combined_stencils(field, field_2, field_IJ) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -224,9 +254,9 @@ def test_combined_stencils(self, factories: Factories) -> None: ] assert ( - len(all_maps) == 3 + len(all_maps) == 2 # IJ + K if stencil_factory.backend.loop_order == BackendLoopOrder.IJK - else 5 + else 2 # KJI + JI ) assert len(all_loop_region) == 0 assert (field_IJ.field[:] == Float(1)).all() @@ -234,13 +264,19 @@ def test_combined_stencils(self, factories: Factories) -> None: def test_multiple_statements(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories code = OrchestratedCode(stencil_factory) + pipeline = [ + CleanUpScheduleTree(), + InlineVertical2DWrite(), + CartesianMerge(stencil_factory.backend), + CartesianRefineTransients(stencil_factory.backend), + ] field = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") field_IJ = quantity_factory.zeros([I_DIM, J_DIM], "") field_IJ_2 = quantity_factory.zeros([I_DIM, J_DIM], "") field.field[:, :, 0] = Float(42.0) - with StreeOptimization(): + with StreeOptimization(passes=pipeline): code.multiple_statements(field, field_IJ, field_IJ_2) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -256,9 +292,9 @@ def test_multiple_statements(self, factories: Factories) -> None: ] assert ( - len(all_maps) == 3 + len(all_maps) == 2 # IJ + K if stencil_factory.backend.loop_order == BackendLoopOrder.IJK - else 5 + else 2 # KJI + JI ) assert len(all_loop_region) == 0 assert (field_IJ.field[:] == Float(42.0)).all() diff --git a/tests/dsl/dace/stree/sdfg_stree_tools.py b/tests/dsl/dace/stree/sdfg_stree_tools.py index 6c664205..b913a134 100644 --- a/tests/dsl/dace/stree/sdfg_stree_tools.py +++ b/tests/dsl/dace/stree/sdfg_stree_tools.py @@ -1,6 +1,7 @@ from types import TracebackType import dace +from dace.sdfg.analysis.schedule_tree import treenodes as 
tn import ndsl.dsl.dace.orchestration as orch from ndsl import StencilFactory @@ -21,8 +22,14 @@ def get_SDFG_and_purge(stencil_factory: StencilFactory) -> dace.CompiledSDFG: class StreeOptimization: + def __init__(self, *, passes: list[tn.ScheduleNodeVisitor] | None = None) -> None: + self.passes = passes + def __enter__(self) -> None: + self.original_passes = orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES = self.passes def __exit__( self, @@ -31,3 +38,4 @@ def __exit__( exc_tb: TracebackType | None, ) -> None: orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES = self.original_passes From 43674afc9507194ba07bf925974731f9d331e3bd Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 28 May 2026 17:43:21 +0200 Subject: [PATCH 29/43] fixes to run GFLD_1M with orch:dace:cpu:KJI backend --- external/gt4py | 2 +- ndsl/dsl/stencil.py | 14 +- .../dsl/dace/stree/optimizations/__init__.py | 6 + .../dace/stree/optimizations/test_merge.py | 6 +- .../test_offgrid_conditionals.py | 17 +- .../stree/optimizations/test_remove_loops.py | 10 +- tests/dsl/orchestration/test_boundaries_k.py | 196 ++++++++++++++++++ 7 files changed, 219 insertions(+), 32 deletions(-) create mode 100644 tests/dsl/orchestration/test_boundaries_k.py diff --git a/external/gt4py b/external/gt4py index 7ba05d5d..08100b85 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 7ba05d5dc03c3d140c9074bc2b5f8e8027832842 +Subproject commit 08100b8505a6ce655a8b71043da514f3a6b8634a diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index a00adebb..26f3dce4 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -881,6 +881,8 @@ def _origin_from_dims(self, dims: Iterable[str]) -> list[int]: return_origin.append(self.origin[1]) elif dim in K_DIMS: return_origin.append(self.origin[2]) + else: + raise ValueError(f"Unknown 
dimension '{dim}'.") return return_origin def _domain_from_dims(self, dimensions: Iterable[str]) -> list[int]: @@ -888,16 +890,18 @@ def _domain_from_dims(self, dimensions: Iterable[str]) -> list[int]: for dimension in dimensions: if dimension == I_DIM: result.append(self.domain[0]) - if dimension == I_INTERFACE_DIM: + elif dimension == I_INTERFACE_DIM: result.append(self.domain[0] + 1) - if dimension == J_DIM: + elif dimension == J_DIM: result.append(self.domain[1]) - if dimension == J_INTERFACE_DIM: + elif dimension == J_INTERFACE_DIM: result.append(self.domain[1] + 1) - if dimension == K_DIM: + elif dimension == K_DIM: result.append(self.domain[2]) - if dimension == K_INTERFACE_DIM: + elif dimension == K_INTERFACE_DIM: result.append(self.domain[2] + 1) + else: + raise ValueError(f"Unknown dimension '{dimension}'.") return result def get_shape( diff --git a/tests/dsl/dace/stree/optimizations/__init__.py b/tests/dsl/dace/stree/optimizations/__init__.py index e69de29b..e0e56d60 100644 --- a/tests/dsl/dace/stree/optimizations/__init__.py +++ b/tests/dsl/dace/stree/optimizations/__init__.py @@ -0,0 +1,6 @@ +from typing import TypeAlias + +from ndsl import QuantityFactory, StencilFactory + + +Factories: TypeAlias = tuple[StencilFactory, QuantityFactory] diff --git a/tests/dsl/dace/stree/optimizations/test_merge.py b/tests/dsl/dace/stree/optimizations/test_merge.py index 2e76029c..1a9ed508 100644 --- a/tests/dsl/dace/stree/optimizations/test_merge.py +++ b/tests/dsl/dace/stree/optimizations/test_merge.py @@ -1,5 +1,3 @@ -from typing import TypeAlias - import dace import pytest from dace import nodes @@ -13,6 +11,7 @@ from ndsl.dsl.gt4py import FORWARD, PARALLEL, K, computation, interval from ndsl.dsl.typing import FloatField from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge +from tests.dsl.dace.stree.optimizations import Factories def stencil(in_field: FloatField, out_field: FloatField) -> None: @@ -130,9 +129,6 @@ def push_non_cartesian_for( 
self.stencil(in_field, out_field) -Factories: TypeAlias = tuple[StencilFactory, QuantityFactory] - - class TestStreeMergeMapsIJK: @pytest.fixture def factories(self) -> Factories: diff --git a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py index f58c90f7..fcfe33bc 100644 --- a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py +++ b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py @@ -1,20 +1,12 @@ -from typing import TypeAlias - import pytest from dace import nodes -from ndsl import ( - Backend, - NDSLRuntime, - QuantityFactory, - StencilFactory, - orchestrate, - stencils, -) +from ndsl import Backend, NDSLRuntime, StencilFactory, orchestrate, stencils from ndsl.boilerplate import get_factories_single_tile from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.typing import FloatField from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge +from tests.dsl.dace.stree.optimizations import Factories class OrchestratedCode(NDSLRuntime): @@ -66,12 +58,9 @@ def blocked_by_other_nodes( self._copy_stencil(in_field, out_field) -Factories: TypeAlias = tuple[StencilFactory, QuantityFactory] - - class TestStreeInlineOffgridConditionals: @pytest.fixture(params=["orch:dace:cpu:IJK", "orch:dace:cpu:KJI"]) - def factories(self, request) -> Factories: + def factories(self, request: pytest.FixtureRequest) -> Factories: domain = (3, 3, 4) return get_factories_single_tile( domain[0], domain[1], domain[2], 0, backend=Backend(request.param) diff --git a/tests/dsl/dace/stree/optimizations/test_remove_loops.py b/tests/dsl/dace/stree/optimizations/test_remove_loops.py index da38e890..9469f204 100644 --- a/tests/dsl/dace/stree/optimizations/test_remove_loops.py +++ b/tests/dsl/dace/stree/optimizations/test_remove_loops.py @@ -1,10 +1,8 @@ -from typing import TypeAlias - import pytest from dace import nodes from dace.sdfg.state import LoopRegion -from ndsl 
import QuantityFactory, StencilFactory, orchestrate +from ndsl import StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile from ndsl.config import Backend, BackendLoopOrder from ndsl.constants import I_DIM, J_DIM, K_DIM, Float @@ -18,6 +16,7 @@ from ndsl.dsl.typing import FloatField, FloatFieldIJ from ndsl.stencils import copy from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge +from tests.dsl.dace.stree.optimizations import Factories def stencil_simple_2D_write(in_field: FloatField, out_fieldIJ: FloatFieldIJ) -> None: @@ -114,12 +113,9 @@ def multiple_statements( self.stencil_multiple_2D_write(in_field, out_field, out_field2) -Factories: TypeAlias = tuple[StencilFactory, QuantityFactory] - - class TestStree2DWriteInline: @pytest.fixture(params=["orch:dace:cpu:IJK", "orch:dace:cpu:KJI"]) - def factories(self, request) -> Factories: + def factories(self, request: pytest.FixtureRequest) -> Factories: domain = (3, 3, 4) return get_factories_single_tile( diff --git a/tests/dsl/orchestration/test_boundaries_k.py b/tests/dsl/orchestration/test_boundaries_k.py new file mode 100644 index 00000000..80ea4a84 --- /dev/null +++ b/tests/dsl/orchestration/test_boundaries_k.py @@ -0,0 +1,196 @@ +import numpy as np +import pytest + +from ndsl import Backend, NDSLRuntime, StencilFactory, orchestrate +from ndsl.boilerplate import get_factories_single_tile +from ndsl.constants import I_DIM, J_DIM, K_DIM, K_INTERFACE_DIM +from ndsl.dsl.gt4py import BACKWARD, FORWARD, computation, interval +from ndsl.dsl.typing import FloatField +from tests.dsl.dace.stree.optimizations import Factories + + +def accumulate_down(in_field: FloatField, out_field: FloatField) -> None: # type: ignore + with computation(BACKWARD): + # handle top layer separately + with interval(-1, None): + out_field = in_field + + # accumulate "downwards" + with interval(0, -1): + out_field = out_field[0, 0, 1] + in_field + + +def 
accumulate_down_from_interface_field(interface_field: FloatField, out_field: FloatField) -> None: # type: ignore + with computation(BACKWARD): + # handle top layer separately + with interval(-1, None): + out_field = interface_field + interface_field[0, 0, 1] + + # accumulate "downwards" + with interval(0, -1): + out_field = out_field[0, 0, 1] + interface_field + + +def accumulate_on_interface(interface_field: FloatField, out_field: FloatField) -> None: # type: ignore + with computation(BACKWARD): + # handle top layer separately + with interval(-2, -1): + out_field = interface_field + interface_field[0, 0, 1] + + # accumulate "downwards" + with interval(0, -2): + out_field = out_field[0, 0, 1] + interface_field + + +def accumulate_up(in_field: FloatField, out_field: FloatField) -> None: # type: ignore + with computation(FORWARD): + # handle bottom layer separately + with interval(0, 1): + out_field = in_field + + # accumulate "upwards" + with interval(1, None): + out_field = out_field[0, 0, -1] + in_field + + +def accumulate_up_interface(in_field: FloatField, interface_field: FloatField) -> None: # type: ignore + with computation(FORWARD): + # handle bottom layer separately + with interval(0, 1): + interface_field = in_field + + # accumulate "upwards" + with interval(1, None): + interface_field = interface_field[0, 0, -1] + in_field[0, 0, -1] + + +class OrchestratedCode(NDSLRuntime): + def __init__(self, stencil_factory: StencilFactory) -> None: + super().__init__(stencil_factory) + + methods_to_orchestrate = [ + "accumulate_down", + "accumulate_down_from_interface_field", + "accumulate_on_interface", + "accumulate_up", + "accumulate_up_interface", + ] + + for method in methods_to_orchestrate: + orchestrate( + obj=self, + method_to_orchestrate=method, + config=stencil_factory.config.dace_config, + ) + + self._accumulate_down = stencil_factory.from_dims_halo( + func=accumulate_down, compute_dims=(I_DIM, J_DIM, K_DIM) + ) + + self._accumulate_down_from_interface_field 
= stencil_factory.from_dims_halo( + func=accumulate_down_from_interface_field, + compute_dims=(I_DIM, J_DIM, K_DIM), + ) + + self._accumulate_on_interface = stencil_factory.from_dims_halo( + func=accumulate_on_interface, compute_dims=(I_DIM, J_DIM, K_INTERFACE_DIM) + ) + + self._accumulate_up = stencil_factory.from_dims_halo( + func=accumulate_up, compute_dims=(I_DIM, J_DIM, K_DIM) + ) + + self._accumulate_up_interface = stencil_factory.from_dims_halo( + func=accumulate_up_interface, compute_dims=(I_DIM, J_DIM, K_INTERFACE_DIM) + ) + + def accumulate_down(self, in_field: FloatField, out_field: FloatField) -> None: # type: ignore + self._accumulate_down(in_field, out_field) + + def accumulate_down_from_interface_field(self, interface_field: FloatField, out_field: FloatField) -> None: # type: ignore + self._accumulate_down_from_interface_field(interface_field, out_field) + + def accumulate_on_interface(self, interface_field: FloatField, out_field: FloatField) -> None: # type: ignore + self._accumulate_on_interface(interface_field, out_field) + + def accumulate_up(self, in_field: FloatField, out_field: FloatField) -> None: # type: ignore + self._accumulate_up(in_field, out_field) + + def accumulate_up_interface(self, in_field: FloatField, interface_field: FloatField) -> None: # type: ignore + self._accumulate_up_interface(in_field, interface_field) + + +class TestBoundariesK: + @pytest.fixture( + params=[ + "orch:dace:cpu:IJK", + "orch:dace:cpu:KJI", + "st:dace:cpu:IJK", + "st:dace:cpu:KJI", + ] + ) + def factories(self, request: pytest.FixtureRequest) -> Factories: + domain = (3, 4, 5) + return get_factories_single_tile( + nx=domain[0], + ny=domain[1], + nz=domain[2], + nhalo=0, + backend=Backend(request.param), + ) + + def test_accumulate_down(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + in_field = quantity_factory.ones((I_DIM, J_DIM, K_DIM), units="") + out_field = 
quantity_factory.zeros((I_DIM, J_DIM, K_DIM), units="") + + code.accumulate_down(in_field, out_field) + assert np.array_equal(out_field.field[0, 0, :], [5, 4, 3, 2, 1]) + + def test_accumulate_interface_field(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + interface_field = quantity_factory.ones( + (I_DIM, J_DIM, K_INTERFACE_DIM), units="" + ) + out_field = quantity_factory.zeros((I_DIM, J_DIM, K_DIM), units="") + + code.accumulate_down_from_interface_field(interface_field, out_field) + assert np.array_equal(out_field.field[0, 0, :], [6, 5, 4, 3, 2]) + + def test_accumulate_interface_domain(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + interface_field = quantity_factory.ones( + (I_DIM, J_DIM, K_INTERFACE_DIM), units="" + ) + out_field = quantity_factory.zeros((I_DIM, J_DIM, K_DIM), units="") + + code.accumulate_on_interface(interface_field, out_field) + assert np.array_equal(out_field.field[0, 0, :], [6, 5, 4, 3, 2]) + + def test_accumulate_up(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + in_field = quantity_factory.ones((I_DIM, J_DIM, K_DIM), units="") + out_field = quantity_factory.zeros((I_DIM, J_DIM, K_DIM), units="") + + code.accumulate_up(in_field, out_field) + assert np.array_equal(out_field.field[0, 0, :], [1, 2, 3, 4, 5]) + + def test_accumulate_up_interface(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + in_field = quantity_factory.ones((I_DIM, J_DIM, K_DIM), units="") + interface_field = quantity_factory.zeros( + (I_DIM, J_DIM, K_INTERFACE_DIM), units="" + ) + + code.accumulate_up_interface(in_field, interface_field) + assert np.array_equal(interface_field.field[0, 0, :], [1, 2, 3, 4, 5, 6]) From 
1ea1314dc0a956cd66609fefcae48cea864abb1e Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 29 May 2026 09:05:02 +0200 Subject: [PATCH 30/43] ci: gt4py update (restore temp dace working branch) --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index 08100b85..d19bf894 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 08100b8505a6ce655a8b71043da514f3a6b8634a +Subproject commit d19bf894f2361c26e5030facd2d06a19ea2af157 From fa60dccfe6d33a15ba2f9a60012358481c0f01ab Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 29 May 2026 10:27:36 +0200 Subject: [PATCH 31/43] remove extra `f` in result report header --- ndsl/stencils/testing/test_translate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/stencils/testing/test_translate.py b/ndsl/stencils/testing/test_translate.py index af9f8ae9..266e811a 100644 --- a/ndsl/stencils/testing/test_translate.py +++ b/ndsl/stencils/testing/test_translate.py @@ -469,7 +469,7 @@ def _report_results( os.makedirs(detail_dir, exist_ok=True) # Summary - header = f"{savepoint_name} w/ f{backend.as_humanly_readable()}" + header = f"{savepoint_name} w/ {backend.as_humanly_readable()}" lines = [] for varname, metric in results.items(): lines.append(f"{varname}: {metric.one_line_report()}") From 160923f14056f20016afdc3044c175c3ac27f83c Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 1 Jun 2026 16:24:50 +0200 Subject: [PATCH 32/43] unrelated dace/gt4py update: just test fixes and a typo --- external/dace | 2 +- external/gt4py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/external/dace b/external/dace index 99a9360d..4da9d096 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit 99a9360d35f458c328b204860d59a365522484ab +Subproject commit 4da9d096ed3454ffa6dcb7b5233c281dc90696c2 diff --git a/external/gt4py b/external/gt4py index d19bf894..210fcbd8 160000 --- 
a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit d19bf894f2361c26e5030facd2d06a19ea2af157 +Subproject commit 210fcbd8c78800bf26421fac3c49c5b22e59d4e4 From c2bd78d5fd7763d6a047882e0248210f78f222fc Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 2 Jun 2026 11:34:31 -0400 Subject: [PATCH 33/43] Expose `gpu:IJK` backends to NDSL --- ndsl/config/backend.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py index 2807cf6a..605b86d7 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -52,6 +52,8 @@ class BackendLoopOrder(Enum): "orch:dace:cpu:KJI": "dace:cpu_KJI", "st:dace:gpu:KJI": "dace:gpu", "orch:dace:gpu:KJI": "dace:gpu", + "st:dace:gpu:IJK": "dace:gpu_IJK", + "orch:dace:gpu:IJK": "dace:gpu_IJK", } """Internal: match the NDSL backend names with the GT4Py names""" From eaaa0cc0a579d1546ed8ed57099cf88857292e0b Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 3 Jun 2026 11:48:07 +0200 Subject: [PATCH 34/43] Disable DaceConfig.from_dict() as it is incomplete The function creates an inconsistent DaceConfig by creating a config first and then tampering with some properties without re-evaluating computed properties. In particular `code_path`, `do_compile` and distributed caches are potentially out of sync with the layout information. 
--- ndsl/dsl/dace/dace_config.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 62b679a5..c53ebc85 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -166,8 +166,8 @@ def __init__( Args: communicator: used for setting the distributed caches backend: string for the backend - tile_nx: x/y domain size for a single time - tile_nz: z domain size for a single time + tile_nx: x/y domain size for a single tile + tile_nz: z domain size for a single tile orchestration: orchestration mode from DaCeOrchestration time: trigger performance collection, available to user with `performance_collector` @@ -412,4 +412,11 @@ def from_dict(cls, data: dict) -> Self: config.rank_size = data["rank_size"] config.layout = data["layout"] config.tile_resolution = data["tile_resolution"] - return config + # TODO + # Computed properties like `self.code_path` and `self.do_compile` + # aren't updated. + # We also don't `set_distributed_caches()` based on that updated + # information. + raise NotImplementedError( + "Implementation of `DaceConfig.from_dict()` is incomplete." + ) From 749f8a3c5387dd934e6c42ac301bf274e5d5f639 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 3 Jun 2026 11:54:44 +0200 Subject: [PATCH 35/43] readability of cache location code --- ndsl/dsl/caches/cache_location.py | 64 ++++++++++++++++--------------- ndsl/dsl/caches/codepath.py | 2 + 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/ndsl/dsl/caches/cache_location.py b/ndsl/dsl/caches/cache_location.py index 87d608dd..d4313815 100644 --- a/ndsl/dsl/caches/cache_location.py +++ b/ndsl/dsl/caches/cache_location.py @@ -7,46 +7,48 @@ def identify_code_path( partitioner: Partitioner, single_code_path: bool, ) -> FV3CodePath: - """Determine which code path your rank will hit. + """ + Determine which code path your rank will hit. 
- If single_code_path is True, single_code_path is True, - only one code path exists (case of doubly periodic grid). + If single_code_path is True, only one code path exists, + e.g. in case of a doubly periodic grid. If single_code_path is False, we are in the case of the - cube-sphere and we will look at our position on the tile.""" + cube-sphere and we will look at our position on the tile. + """ # Doubly-periodic or single tile grid - if single_code_path: + if single_code_path or partitioner.layout == (1, 1): return FV3CodePath.All # Cube-sphere - if partitioner.layout == (1, 1): - return FV3CodePath.All - elif partitioner.layout[0] == 1 or partitioner.layout[1] == 1: + if partitioner.layout[0] <= 1 or partitioner.layout[1] <= 1: raise NotImplementedError( - f"Build for layout {partitioner.layout} is not handled" + f"Build for layout {partitioner.layout} is not handled." ) - else: - if partitioner.tile.on_tile_bottom(rank): - if partitioner.tile.on_tile_left(rank): - return FV3CodePath.BottomLeft - if partitioner.tile.on_tile_right(rank): - return FV3CodePath.BottomRight - else: - return FV3CodePath.Bottom - if partitioner.tile.on_tile_top(rank): - if partitioner.tile.on_tile_left(rank): - return FV3CodePath.TopLeft - if partitioner.tile.on_tile_right(rank): - return FV3CodePath.TopRight - else: - return FV3CodePath.Top - else: - if partitioner.tile.on_tile_left(rank): - return FV3CodePath.Left - if partitioner.tile.on_tile_right(rank): - return FV3CodePath.Right - else: - return FV3CodePath.Center + + # Bottom row + if partitioner.tile.on_tile_bottom(rank): + if partitioner.tile.on_tile_left(rank): + return FV3CodePath.BottomLeft + if partitioner.tile.on_tile_right(rank): + return FV3CodePath.BottomRight + return FV3CodePath.Bottom + + # Top row + if partitioner.tile.on_tile_top(rank): + if partitioner.tile.on_tile_left(rank): + return FV3CodePath.TopLeft + if partitioner.tile.on_tile_right(rank): + return FV3CodePath.TopRight + return FV3CodePath.Top + + # 
Left & right column with corners already handled + if partitioner.tile.on_tile_left(rank): + return FV3CodePath.Left + if partitioner.tile.on_tile_right(rank): + return FV3CodePath.Right + + return FV3CodePath.Center def get_cache_fullpath(code_path: FV3CodePath) -> str: diff --git a/ndsl/dsl/caches/codepath.py b/ndsl/dsl/caches/codepath.py index 61591ccf..3d90a9e2 100644 --- a/ndsl/dsl/caches/codepath.py +++ b/ndsl/dsl/caches/codepath.py @@ -3,10 +3,12 @@ class FV3CodePath(enum.Enum): """Enum listing all possible code paths on a cube sphere. + For any layout the cube sphere has up to 9 different code paths depending on the positioning of the rank on the tile and which of the edge/corner cases it has to handle, as well as the possibility for all boundary computations in the 1x1 layout case. + Since the framework inlines code to optimize, we _cannot_ pre-suppose which code being kept and/or ejected. This enum serves as the ground truth to map rank to the proper generated code. From 0d860bf4a4f0c65d36ab893dd6b6117184b71ba1 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 4 Jun 2026 14:46:32 +0200 Subject: [PATCH 36/43] translate tests: fix crash in reporting when comparing scalars UW translate test compares a scalar value (`dotransport`) as part of the translate test. 
Doing so trips reporting and this change makes it work again \o/ --- ndsl/testing/comparison.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/testing/comparison.py b/ndsl/testing/comparison.py index 3acfd723..e7fc93eb 100644 --- a/ndsl/testing/comparison.py +++ b/ndsl/testing/comparison.py @@ -339,7 +339,7 @@ def one_line_report(self) -> str: return f"❌ Numerical failures: {failed_indices}/{all_indices} failed - metric: {metric_thresholds}" def report(self, file_path: str | None = None) -> list[str]: - failed_indices = np.logical_not(self.success).nonzero() + failed_indices = np.atleast_1d(np.logical_not(self.success)).nonzero() # List all errors to terminal and file bad_indices_count = len(failed_indices[0]) if self.changing_column_map is not None: From 5ee3bb9217f29f7cd23d1ae95dc14109867c8ba4 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 5 Jun 2026 12:20:10 -0400 Subject: [PATCH 37/43] Weaken the cube-sphere communicator hard ranks limit. We need "at least" not "exactly" the rank number --- ndsl/comm/communicator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index 65d72018..abb70ec8 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -786,7 +786,7 @@ def __init__( "Communicator needs to be instantiated with communication subsystem" f" derived from `comm_abc.Comm`, got {type(comm)}." 
) - if comm.Get_size() != partitioner.total_ranks: + if comm.Get_size() < partitioner.total_ranks: raise ValueError( f"was given a partitioner for {partitioner.total_ranks} ranks but a " f"comm object with only {comm.Get_size()} ranks, are we running " From da82edefd73728cea16c6288297422e7f6bd7792 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 5 Jun 2026 14:03:22 -0400 Subject: [PATCH 38/43] Adjust `cflags` format read for orchestrated compile Protect `performance_timer` for `time==False` and add external setup --- ndsl/dsl/dace/dace_config.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index c53ebc85..579ddcb0 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -181,16 +181,10 @@ def __init__( # ToDo: DaceConfig becomes a bit more than a read-only config # with this. Should be refactored into a DaceExecutor carrying a config self.loaded_dace_executables: DaceExecutables = {} - self.performance_collector = ( - PerformanceCollector( - "InternalOrchestrationTimer", - comm=( - LocalComm(0, 6, {}) if communicator is None else communicator.comm - ), - ) - if time - else NullPerformanceCollector() - ) + if not time: + self.performance_collector = NullPerformanceCollector() + else: + self.set_timer(communicator.comm if communicator else None) # Temporary. This is a bit too out of the ordinary for the common user. 
# We should refactor the architecture to allow for a `gtc:orchestrated:dace:X` @@ -264,11 +258,12 @@ def __init__( march_option = "-mcpu=native" if is_arm_neoverse else "-march=native" # Removed --fast-math gpu_config = gpu_configuration(GT4PY_COMPILE_OPT_LEVEL) + gpu_cflags = " ".join(gpu_config.gpu_compile_flags).strip() dace.config.Config.set( "compiler", "cuda", "args", - value=f"-std=c++14 -Xcompiler -fPIC -O{optimization_level} -Xcompiler {march_option} {gpu_config.gpu_compile_flags}", + value=f"-std=c++14 -Xcompiler -fPIC -O{optimization_level} -Xcompiler {march_option} {gpu_cflags}", ) cuda_sm = cp.cuda.Device(0).compute_capability if cp else 60 @@ -420,3 +415,12 @@ def from_dict(cls, data: dict) -> Self: raise NotImplementedError( "Implementation of `DaceConfig.from_dict()` is incomplete." ) + + def set_timer(self, comm): + """Set timer on configuration externally""" + # TODO: this absolutely should not be a on a Configuration object + # and even less setup outside. Madness, we have lost our ways... 
+ self.performance_collector = PerformanceCollector( + "InternalOrchestrationTimer", + comm=(LocalComm(0, 6, {}) if comm is None else comm), + ) From 64dd47c5022663105923d9345bd7529bbd79dd34 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 5 Jun 2026 14:05:44 -0400 Subject: [PATCH 39/43] Lint --- ndsl/dsl/dace/dace_config.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 579ddcb0..0059ff39 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -10,6 +10,7 @@ from gt4py.cartesian.utils.compiler import cxx_compiler_defaults, gpu_configuration from ndsl import LocalComm +from ndsl.comm import Comm from ndsl.comm.communicator import Communicator from ndsl.comm.partitioner import Partitioner from ndsl.config import Backend @@ -17,7 +18,11 @@ from ndsl.dsl.caches.cache_location import identify_code_path from ndsl.dsl.caches.codepath import FV3CodePath from ndsl.optional_imports import cupy as cp -from ndsl.performance.collector import NullPerformanceCollector, PerformanceCollector +from ndsl.performance.collector import ( + AbstractPerformanceCollector, + NullPerformanceCollector, + PerformanceCollector, +) if TYPE_CHECKING: @@ -182,7 +187,9 @@ def __init__( # with this. Should be refactored into a DaceExecutor carrying a config self.loaded_dace_executables: DaceExecutables = {} if not time: - self.performance_collector = NullPerformanceCollector() + self.performance_collector: AbstractPerformanceCollector = ( + NullPerformanceCollector() + ) else: self.set_timer(communicator.comm if communicator else None) @@ -416,7 +423,7 @@ def from_dict(cls, data: dict) -> Self: "Implementation of `DaceConfig.from_dict()` is incomplete." 
) - def set_timer(self, comm): + def set_timer(self, comm: Comm | None) -> None: """Set timer on configuration externally""" # TODO: this absolutely should not be a on a Configuration object # and even less setup outside. Madness, we have lost our ways... From 26fb0ef2f60168371fc401520ecdd15c1a1842ae Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sat, 6 Jun 2026 17:13:31 -0400 Subject: [PATCH 40/43] Introduce hardware configuration good defaults --- ndsl/dsl/dace/dace_config.py | 22 ++++-- ndsl/dsl/dace/hardware_config.py | 112 +++++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 7 deletions(-) create mode 100644 ndsl/dsl/dace/hardware_config.py diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 0059ff39..f73bc249 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -17,6 +17,7 @@ from ndsl.dsl import NDSL_GLOBAL_PRECISION from ndsl.dsl.caches.cache_location import identify_code_path from ndsl.dsl.caches.codepath import FV3CodePath +from ndsl.dsl.dace.hardware_config import get_gpu_hardware_defaults from ndsl.optional_imports import cupy as cp from ndsl.performance.collector import ( AbstractPerformanceCollector, @@ -273,14 +274,21 @@ def __init__( value=f"-std=c++14 -Xcompiler -fPIC -O{optimization_level} -Xcompiler {march_option} {gpu_cflags}", ) - cuda_sm = cp.cuda.Device(0).compute_capability if cp else 60 - dace.config.Config.set("compiler", "cuda", "cuda_arch", value=f"{cuda_sm}") - # Block size/thread count is defaulted to an average value for recent - # hardware (Pascal and upward). The problem of setting an optimized - # block/thread is both hardware and problem dependant. Fine tuners - # available in DaCe should be relied on for further tuning of this value. 
+ # Target compilation for hardware micro-code capacities + gpu_defaults = get_gpu_hardware_defaults() dace.config.Config.set( - "compiler", "cuda", "default_block_size", value="64,8,1" + "compiler", + "cuda", + "cuda_arch", + value=f"{gpu_defaults.compute_capability}", + ) + + # Default block size for kernels launch + dace.config.Config.set( + "compiler", + "cuda", + "default_block_size", + value=str(gpu_defaults.block_size)[1:-1], ) # Potentially buggy - deactivate dace.config.Config.set( diff --git a/ndsl/dsl/dace/hardware_config.py b/ndsl/dsl/dace/hardware_config.py new file mode 100644 index 00000000..ca28ac3b --- /dev/null +++ b/ndsl/dsl/dace/hardware_config.py @@ -0,0 +1,112 @@ +import dataclasses +import os +import sys + +from ndsl import ndsl_log +from ndsl.optional_imports import cupy as cp + + +# Taken straight out of https://pcisig.com/membership/member-companies +_VENDOR_PCI_SIGNAURES = { + 0x10DE: "Nvidia", + 0x1002: "AMD", + 0x8086: "Intel", + 0x0: "Unknown", +} + +# Cached copy of the hardware default +_GPU_HARDWARE_DEFAULTS = None + + +def _get_vendor() -> str: + """Retrieve vendor using the current device PCI id to query the PCI vendor + from the kernel logs + + ⚠️ Only works on Linux - kicks back to "Unknwon" in other cases + """ + if not sys.platform.startswith("linux"): + return _VENDOR_PCI_SIGNAURES[0x0] + + pci_device_id = cp.cuda.runtime.deviceGetPCIBusId(0) + dev_path = f"/sys/bus/pci/devices/{pci_device_id}" + if not os.path.exists(dev_path): + return "Unknown" + + with open(os.path.join(dev_path, "vendor"), "r") as f: + vendor_str = f.read().strip().replace("0x", "") + vendor_id = int(vendor_str, 16) + + if vendor_id not in _VENDOR_PCI_SIGNAURES: + ndsl_log.error(f"Unknown GPU vendor with PCI-SIG ID of {vendor_id:#X}") + return "Unknown" + return _VENDOR_PCI_SIGNAURES[int(vendor_str, 16)] + + +@dataclasses.dataclass +class GPUHardwareDefaults: + """Compute defaults for common GPUs""" + + vendor: str + block_size: list[int] = 
dataclasses.field(default_factory=list) + compute_capability: int = -1 # Nvidia specific + + +def get_gpu_hardware_defaults() -> GPUHardwareDefaults: + """Retrieve default values for GPU computation configuration""" + global _GPU_HARDWARE_DEFAULTS + if _GPU_HARDWARE_DEFAULTS is not None: + return _GPU_HARDWARE_DEFAULTS # type: ignore[unreachable] + + if not cp: + raise ModuleNotFoundError("Cupy must be installed to read hardware defaults") + if not cp.cuda.is_available(): + raise RuntimeError("No device available for hardware defaults read") + + # Who goes there + vendor = _get_vendor() + if vendor == "Nvidia": + compute_capability = int(cp.cuda.Device(0).compute_capability) + # Default block size based on compute capability + if compute_capability > 80: + # Covers: + # - Blackwell (100+) + # - Hopper (90-100) + # - Ampere (80-90) + block_sizes = [128, 1, 1] + elif compute_capability > 60: + # Covers: + # - Volta (70-80) + # - Pascal (60-70) + block_sizes = [64, 8, 1] + else: + # For older hardware - we default to the safe warp-size since + # the dawn of GPGPU on Nvidia hardware + block_sizes = [32, 1, 1] + + _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults( + vendor=vendor, + block_size=block_sizes, + compute_capability=compute_capability, + ) + elif vendor == "AMD": + _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults( + vendor=vendor, block_size=[64, 1, 1] # Default RDNA architectue is Wave64 + ) + elif vendor == "Intel": + _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults( + vendor=vendor, + block_size=[32, 1, 1], # Intel can run 8, 16 or 32 - but SIMD betters in 32 + ) + else: + _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults( + vendor=vendor, + block_size=[ + 8, + 1, + 1, + ], # Smaller common denominator of massively parallel hardware + ) + + ndsl_log.info(f"GPU vendor detected: {_GPU_HARDWARE_DEFAULTS.vendor}") + + return _GPU_HARDWARE_DEFAULTS From 7bdd3fa39f1317955824be4016ab9bf1b7139e1f Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sat, 6 Jun 2026 17:15:06 -0400 
Subject: [PATCH 41/43] Fix double load for compiling rank Split Simplify2 pass into a GPU centric with block_size on maps & apply_gpu_xform Remove useless code - legacy code bleed Verbose the steps better --- ndsl/dsl/dace/orchestration.py | 112 +++++++++++++++++++-------------- 1 file changed, 65 insertions(+), 47 deletions(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index c88212c2..cf64fa8a 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -33,6 +33,7 @@ DaCeOrchestration, ) from ndsl.dsl.dace.dace_executable import DaceExecutable +from ndsl.dsl.dace.hardware_config import get_gpu_hardware_defaults from ndsl.dsl.dace.labeler import set_label from ndsl.dsl.dace.sdfg_debug_passes import ( negative_delp_checker, @@ -248,43 +249,52 @@ def _build_sdfg( ) sdfg.apply_transformations_repeated(MapCollapse) - # Make the transients array persistents - if config.is_gpu_backend(): - # TODO - # The following should happen on the stree level - _to_gpu(sdfg) - - sdfg.apply_gpu_transformations() + with DaCeProgress(config, "Make transient persistents"): + # Make the transients array persistents + if config.is_gpu_backend(): + # TODO + # The following should happen on the stree level + _to_gpu(sdfg) + make_transients_persistent(sdfg=sdfg, device=device_type) - make_transients_persistent(sdfg=sdfg, device=device_type) + # Upload args to device + _upload_to_device(list(args) + list(kwargs.values())) + else: + # TODO + # The following should happen on the stree level + for _sd, _aname, arr in sdfg.arrays_recursive(): + if arr.shape == (1,): + arr.storage = DaceStorageType.Register + make_transients_persistent(sdfg=sdfg, device=device_type) - # Upload args to device - _upload_to_device(list(args) + list(kwargs.values())) + if config.is_gpu_backend(): + with DaCeProgress(config, "Apply GPU transformations"): + # Set block size on GPU maps + gpu_defaults = get_gpu_hardware_defaults() + for me, _state in 
sdfg.all_nodes_recursive(): + if ( + isinstance(me, nodes.MapEntry) + and me.map.schedule == ScheduleType.GPU_Device + ): + if me.map.gpu_block_size is None: + me.map.gpu_block_size = gpu_defaults.block_size + # Apply common GPU transforms (includes a simplify) + sdfg.apply_gpu_transformations() + if config.verbose_orchestration: + sdfg.save( + os.path.abspath( + f"{sdfg.build_folder}/05-apply_gpu_xforms.sdfgz" + ), + compress=True, + ) else: - # TODO - # The following should happen on the stree level - for _sd, _aname, arr in sdfg.arrays_recursive(): - if arr.shape == (1,): - arr.storage = DaceStorageType.Register - make_transients_persistent(sdfg=sdfg, device=device_type) - - # Build non-constants & non-transients from the sdfg_kwargs - sdfg_kwargs = dace_program._create_sdfg_args(sdfg, args, kwargs) - for k in dace_program.constant_args: - if k in sdfg_kwargs: - del sdfg_kwargs[k] - sdfg_kwargs = {k: v for k, v in sdfg_kwargs.items() if v is not None} - for k, tup in dace_program.resolver.closure_arrays.items(): - if k in sdfg_kwargs and tup[1].transient: - del sdfg_kwargs[k] - - with DaCeProgress(config, "Simplify (2)"): - _simplify(sdfg) - if config.verbose_orchestration: - sdfg.save( - os.path.abspath(f"{sdfg.build_folder}/05-simplify_2.sdfgz"), - compress=True, - ) + with DaCeProgress(config, "Simplify (2)"): + _simplify(sdfg) + if config.verbose_orchestration: + sdfg.save( + os.path.abspath(f"{sdfg.build_folder}/05-simplify_2.sdfgz"), + compress=True, + ) # Move all memory that can be into a pool to lower memory pressure for GPU # We skip this memory optimization for CPU because we don't have a memory # pool available yet (DaCe v1) @@ -313,7 +323,12 @@ def _build_sdfg( # Compile with DaCeProgress(config, "Codegen & compile"): - sdfg.compile() + compiled_sdfg = sdfg.compile() + config.loaded_dace_executables[dace_program] = DaceExecutable( + compiled_sdfg=compiled_sdfg, + arguments={}, + arguments_hash=0, + ) # Printing analysis of the compiled SDFG with 
DaCeProgress(config, "Build finished. Running memory static analysis"): @@ -352,18 +367,21 @@ def _build_sdfg( ) MPI.COMM_WORLD.Barrier() - with DaCeProgress(config, "Loading"): - sdfg_path = get_sdfg_path(dace_program.name, config, override_run_only=True) - if sdfg_path is None: - raise ValueError("Couldn't load SDFG post build") - compiledSDFG, _ = dace_program.load_precompiled_sdfg( - sdfg_path, *args, **kwargs - ) - config.loaded_dace_executables[dace_program] = DaceExecutable( - compiled_sdfg=compiledSDFG, - arguments={}, - arguments_hash=0, - ) + if not is_compiling: + with DaCeProgress(config, "Loading"): + sdfg_path = get_sdfg_path( + dace_program.name, config, override_run_only=True + ) + if sdfg_path is None: + raise ValueError("Couldn't load SDFG post build") + compiledSDFG, _ = dace_program.load_precompiled_sdfg( + sdfg_path, *args, **kwargs + ) + config.loaded_dace_executables[dace_program] = DaceExecutable( + compiled_sdfg=compiledSDFG, + arguments={}, + arguments_hash=0, + ) def _call_sdfg( From 0fcd9bd0e5894c8995cfbe170b9489a057eb91f7 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sat, 6 Jun 2026 17:46:03 -0400 Subject: [PATCH 42/43] Hardware default: gives back default when no `cp` instead of raising --- ndsl/dsl/dace/hardware_config.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/ndsl/dsl/dace/hardware_config.py b/ndsl/dsl/dace/hardware_config.py index ca28ac3b..ebcdbfee 100644 --- a/ndsl/dsl/dace/hardware_config.py +++ b/ndsl/dsl/dace/hardware_config.py @@ -57,10 +57,17 @@ def get_gpu_hardware_defaults() -> GPUHardwareDefaults: if _GPU_HARDWARE_DEFAULTS is not None: return _GPU_HARDWARE_DEFAULTS # type: ignore[unreachable] - if not cp: - raise ModuleNotFoundError("Cupy must be installed to read hardware defaults") - if not cp.cuda.is_available(): - raise RuntimeError("No device available for hardware defaults read") + if not cp or not cp.cuda.is_available(): + ndsl_log.warning("No cupy - defaulting 
for GPU hardware") + _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults( + vendor="Unknown", + block_size=[ + 8, + 1, + 1, + ], # Smaller common denominator of massively parallel hardware + ) + return _GPU_HARDWARE_DEFAULTS # Who goes there vendor = _get_vendor() From d843a2cd7785a3a4584c42c82720ebcc11fce365 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sun, 7 Jun 2026 16:49:24 -0400 Subject: [PATCH 43/43] Orch: always collapse maps to maximize the kernel parallel basis --- ndsl/dsl/dace/orchestration.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index cf64fa8a..07cdeeb5 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -247,7 +247,11 @@ def _build_sdfg( os.path.abspath(f"{sdfg.build_folder}/04-from_stree.sdfgz"), compress=True, ) - sdfg.apply_transformations_repeated(MapCollapse) + + # We want all maps properly collapse to make sure the codegen will see nD parallel + # axis as a single kernelizable map + with DaCeProgress(config, "Collapse maps"): + sdfg.apply_transformations_repeated(MapCollapse) with DaCeProgress(config, "Make transient persistents"): # Make the transients array persistents