-
Notifications
You must be signed in to change notification settings - Fork 157
Remove name conflicts in loop iterators #2326
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
e06220e
6e5f3cf
7720a18
8fe2dec
69ece11
33618d8
1c18af1
ab99a6e
f7c5e49
b35f879
5b0ec86
96e2cdc
707f603
0054bdb
a85d03e
2f04f7b
968c873
3c32722
82f5d9f
6d76968
3e4be70
aa4aab0
a3c7a67
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,89 @@ | ||
| # Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. | ||
| import dace | ||
|
|
||
| from dace.sdfg.state import LoopRegion | ||
| from dace.transformation import pass_pipeline as ppl | ||
| from dace.sdfg import utils as sdutil | ||
| from typing import Optional | ||
| import copy | ||
| from dace.sdfg.state import ControlFlowRegion | ||
| from dace.transformation.passes.analysis import loop_analysis | ||
| from dace.transformation.transformation import explicit_cf_compatible | ||
|
|
||
| import sympy as sp | ||
| from typing import Union | ||
|
|
||
|
|
||
| def replace_symbol_by_name(expr: sp.Basic, old_name: str, new: Union[str, sp.Basic]) -> sp.Basic: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we have symbolic replace and subs etc. |
||
| """ | ||
| Replace all symbols in `expr` whose .name matches `old_name`, | ||
| regardless of assumptions, with `new`. | ||
| """ | ||
| if isinstance(new, str): | ||
| new = sp.Symbol(new) | ||
| repl = {s: new for s in expr.free_symbols if s.name == old_name} | ||
| if not repl: | ||
| return expr | ||
| return expr.subs(repl) | ||
|
|
||
|
|
||
| @dace.properties.make_properties | ||
| @explicit_cf_compatible | ||
| class SSALoopIterators(ppl.Pass): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same for this class name |
||
| loop_var_counter = 0 | ||
| FOR_IT_NAME = "_it" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should this be a class variable or just a global constant
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess this one is a global constant, I updated accordingly. |
||
|
|
||
| def modifies(self) -> ppl.Modifies: | ||
| return ppl.Modifies.AccessNodes | ppl.Modifies.Memlets | ||
|
|
||
| def should_reapply(self, modified: ppl.Modifies) -> bool: | ||
| return False | ||
|
|
||
| def _repl_recursive(self, cfg: ControlFlowRegion | dace.SDFG, loop_var: str, next_ssa_loop_var: str): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel like dace.sdfg.replace.replace_dict would do a good job here, if applied selectively |
||
| # What about Nested SDFGs? Do we need to update symbol mapping? | ||
| cfg.replace_meta_accesses({loop_var: next_ssa_loop_var}) | ||
| cfg.replace_dict({loop_var: next_ssa_loop_var}) | ||
|
|
||
| for state in cfg.all_states(): | ||
| for node in state.nodes(): | ||
| # Update symbol mapping | ||
|
|
||
| if isinstance(node, dace.nodes.NestedSDFG): | ||
| inner_sdfg = node.sdfg | ||
| to_repl = str(loop_var) in node.symbol_mapping | ||
| if to_repl: | ||
| v = node.symbol_mapping.pop(str(loop_var)) | ||
| v_symexpr = dace.symbolic.SymExpr(v) | ||
|
ThrudPrimrose marked this conversation as resolved.
Outdated
|
||
| node.symbol_mapping[str(next_ssa_loop_var)] = replace_symbol_by_name( | ||
| v_symexpr, loop_var, next_ssa_loop_var) | ||
|
|
||
| # Now we can replace what is inside | ||
| to_repl |= str(loop_var) in inner_sdfg.symbols | ||
| if to_repl: | ||
| self._repl_recursive(inner_sdfg, loop_var, next_ssa_loop_var) | ||
|
|
||
| def _apply_recursive(self, sdfg: dace.SDFG): | ||
| for cfg in sdfg.all_control_flow_regions(): | ||
| if isinstance(cfg, LoopRegion): | ||
| loop_var = cfg.loop_variable | ||
| loop_end = f"({loop_analysis.get_loop_end(cfg)})" # Inclusive | ||
| next_ssa_loop_var = f"{SSALoopIterators.FOR_IT_NAME}_{SSALoopIterators.loop_var_counter}" | ||
| # Replace loop variable with next_ssa_loop_var in the loop body, | ||
| # and assign loop_var = loop_end at the end of the loop | ||
| self._repl_recursive(cfg, loop_var, next_ssa_loop_var) | ||
|
|
||
| # Assign to the variable after the loop end | ||
| parent_graph = cfg.parent_graph | ||
| parent_graph.add_state_after(cfg, | ||
|
ThrudPrimrose marked this conversation as resolved.
Outdated
|
||
| f"SSA_loop_var_reconstruction_{SSALoopIterators.loop_var_counter}", | ||
| assignments={loop_var: loop_end}) | ||
|
|
||
| SSALoopIterators.loop_var_counter += 1 | ||
|
|
||
| for state in sdfg.all_states(): | ||
| for node in state.nodes(): | ||
| if isinstance(node, dace.nodes.NestedSDFG): | ||
| self._apply_recursive(node.sdfg) | ||
|
|
||
| def apply_pass(self, sdfg: dace.SDFG, _) -> Optional[int]: | ||
| self._apply_recursive(sdfg) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,164 @@ | ||
| """ | ||
| Unit tests for SSALoopIterators pass. | ||
| """ | ||
| import dace | ||
| import numpy as np | ||
| import pytest | ||
| from dace.sdfg.state import LoopRegion | ||
| from dace.transformation.passes.ssa_loop_iterators import SSALoopIterators | ||
| from dace.transformation.passes.analysis import loop_analysis | ||
|
|
||
|
|
||
| @dace.program | ||
| def foo(A: dace.float64[10, 10], idx: dace.int32[10, 10], B: dace.float64[5, 10]): | ||
| for i in range(5): | ||
| for j, k in dace.map[0:10, 0:10]: | ||
| A[j, k] = 1.1 * A[j, k] + 1.2 * B[i, idx[j, k]] | ||
|
|
||
|
|
||
| def test_nested_sdfg_symbol_mapping(): | ||
| """ | ||
| The map inside the loop body becomes a nested SDFG. | ||
| The loop variable `i` must appear in the nested SDFG's symbol_mapping. | ||
| After SSALoopIterators, the symbol_mapping should reference the new | ||
| SSA name (_it_N), not the original `i`. | ||
| """ | ||
| SSALoopIterators.loop_var_counter = 0 | ||
|
|
||
| sdfg = foo.to_sdfg(simplify=False) | ||
|
|
||
| # Before: confirm `i` is the loop variable and appears in a nested SDFG mapping | ||
| loops_before = [cfg for cfg in sdfg.all_control_flow_regions() if isinstance(cfg, LoopRegion)] | ||
| assert len(loops_before) == 1 | ||
| assert loops_before[0].loop_variable == 'i' | ||
|
|
||
| found_i_in_mapping = False | ||
| for state in sdfg.all_states(): | ||
| for node in state.nodes(): | ||
| if isinstance(node, dace.nodes.NestedSDFG): | ||
| if 'i' in node.symbol_mapping: | ||
| found_i_in_mapping = True | ||
| assert found_i_in_mapping, "Expected 'i' in nested SDFG symbol_mapping before pass" | ||
|
|
||
| # Apply pass | ||
| SSALoopIterators().apply_pass(sdfg, None) | ||
| sdfg.validate() | ||
|
|
||
| # After: the nested SDFG symbol_mapping should have _it_0, not i | ||
| for state in sdfg.all_states(): | ||
| for node in state.nodes(): | ||
| if isinstance(node, dace.nodes.NestedSDFG): | ||
| assert 'i' not in node.symbol_mapping, \ | ||
| f"Original loop var 'i' should not be in symbol_mapping, got {node.symbol_mapping}" | ||
| assert '_it_0' in node.symbol_mapping, \ | ||
| f"SSA loop var '_it_0' should be in symbol_mapping, got {node.symbol_mapping}" | ||
|
|
||
| # Verify correctness | ||
| A = np.random.rand(10, 10) | ||
| idx = np.random.randint(0, 10, size=(10, 10), dtype=np.int32) | ||
| B = np.random.rand(5, 10) | ||
|
|
||
| A_ref = A.copy() | ||
| for i in range(5): | ||
| for j in range(10): | ||
| for k in range(10): | ||
| A_ref[j, k] = 1.1 * A_ref[j, k] + 1.2 * B[i, idx[j, k]] | ||
|
|
||
| csdfg = sdfg.compile() | ||
| csdfg(A=A, idx=idx, B=B) | ||
| assert np.allclose(A, A_ref), f"Max error: {np.max(np.abs(A - A_ref))}" | ||
|
|
||
|
|
||
| # ============================================================================ | ||
| # Test 2: Loop variable used after the loop (reconstruction check) | ||
| # ============================================================================ | ||
| @dace.program | ||
| def loop_var_used_after(A: dace.float64[10], B: dace.float64[10]): | ||
| for i in range(10): | ||
| A[i] = 2.0 * B[i] | ||
| # After the loop, i should be 9. The pass should insert | ||
| # an assignment i = loop_end - 1 so downstream usage is correct. | ||
|
|
||
|
|
||
| def test_loop_var_reconstruction(): | ||
| """ | ||
| After SSALoopIterators, a reconstruction state should assign | ||
| the original loop variable to (loop_end - 1) so that any | ||
| downstream use of the variable sees the correct final value. | ||
| """ | ||
| SSALoopIterators.loop_var_counter = 0 | ||
|
|
||
| sdfg = loop_var_used_after.to_sdfg(simplify=False) | ||
|
|
||
| SSALoopIterators().apply_pass(sdfg, None) | ||
| sdfg.validate() | ||
|
|
||
| # Check that a reconstruction state was added | ||
| reconstruction_states = [ | ||
| s for s in sdfg.all_states() if hasattr(s, 'label') and 'SSA_loop_var_reconstruction' in s.label | ||
| ] | ||
| assert len(reconstruction_states) == 1, f"Expected 1 reconstruction state, found {len(reconstruction_states)}" | ||
|
|
||
| # Check that assignment is correct | ||
| loops = [cfg for cfg in sdfg.all_control_flow_regions() if isinstance(cfg, LoopRegion)] | ||
| assert len(loops) == 1 | ||
| loop = loops[0] | ||
|
|
||
| out_edges = loop.parent_graph.out_edges(loop) | ||
| assert len(out_edges) == 1 | ||
|
|
||
| assignments = out_edges[0].data.assignments | ||
| assert 'i' in assignments, f"Expected assignment to 'i', got {assignments}" | ||
| assert str( | ||
| assignments['i'] | ||
| ) == f"({(str(loop_analysis.get_loop_end(loop)))})", f"Expected loop end assignment, got {assignments['i']}" | ||
|
|
||
| # Verify correctness | ||
| A = np.zeros(10) | ||
| B = np.random.rand(10) | ||
| csdfg = sdfg.compile() | ||
| csdfg(A=A, B=B) | ||
| assert np.allclose(A, 2.0 * B) | ||
|
|
||
|
|
||
| # ============================================================================ | ||
| # Test 3: Nested loops — both variables should be renamed independently | ||
| # ============================================================================ | ||
| @dace.program | ||
| def nested_loops(A: dace.float64[8, 8]): | ||
| for i in range(8): | ||
| for j in range(8): | ||
| A[i, j] = A[i, j] + 1.0 | ||
|
|
||
|
|
||
| def test_nested_loops(): | ||
| """ | ||
| Two nested LoopRegions with variables i and j. | ||
| Both should be renamed to distinct SSA names (_it_0, _it_1), | ||
| and both should get reconstruction states. | ||
| """ | ||
| SSALoopIterators.loop_var_counter = 0 | ||
|
|
||
| sdfg = nested_loops.to_sdfg(simplify=False) | ||
|
|
||
| # Before: should have 2 loop regions | ||
| loops_before = [cfg for cfg in sdfg.all_control_flow_regions() if isinstance(cfg, LoopRegion)] | ||
| assert len(loops_before) == 2 | ||
| loop_vars_before = {l.loop_variable for l in loops_before} | ||
| assert loop_vars_before == {'i', 'j'} | ||
|
|
||
| SSALoopIterators().apply_pass(sdfg, None) | ||
| sdfg.validate() | ||
|
|
||
| # Verify correctness | ||
| A = np.random.rand(8, 8) | ||
| A_ref = A.copy() + 1.0 | ||
| csdfg = sdfg.compile() | ||
| csdfg(A=A) | ||
| assert np.allclose(A, A_ref), f"Max error: {np.max(np.abs(A - A_ref))}" | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| test_nested_sdfg_symbol_mapping() | ||
| test_loop_var_reconstruction() | ||
| test_nested_loops() |
Uh oh!
There was an error while loading. Please reload this page.