Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
e06220e
SSA loop iterators
ThrudPrimrose Mar 16, 2026
6e5f3cf
Refactor
ThrudPrimrose Mar 16, 2026
7720a18
Merge branch 'main' into ssa_loop_iterators
ThrudPrimrose Apr 15, 2026
8fe2dec
Rename to unique_loop_iterators; address review
ThrudPrimrose May 15, 2026
69ece11
Merge branch 'main' into ssa_loop_iterators
ThrudPrimrose May 15, 2026
33618d8
Address comments
ThrudPrimrose May 15, 2026
1c18af1
Drop _rename_symbol_by_name; replace_dict already handles it
ThrudPrimrose May 15, 2026
ab99a6e
Add large nested map/for/for/map stress test
ThrudPrimrose May 15, 2026
f7c5e49
Make docstrings pure ASCII
ThrudPrimrose May 15, 2026
b35f879
Fix Sphinx inline-literal backticks in docstrings
ThrudPrimrose May 15, 2026
5b0ec86
Drop single-backtick role; use double-backtick literal
ThrudPrimrose May 15, 2026
96e2cdc
Flank every double-backtick literal with spaces
ThrudPrimrose May 15, 2026
707f603
Reconcile backticks to repo/Sphinx convention (no forced pre-punctuat…
ThrudPrimrose May 15, 2026
0054bdb
Small refactor
ThrudPrimrose May 15, 2026
a85d03e
Small text fix2
ThrudPrimrose May 15, 2026
2f04f7b
Cleanup dead loop-var declaration when post-value epilogue is disabled
ThrudPrimrose May 20, 2026
968c873
test: port post-value epilogue test cleanup from yakup/dev
ThrudPrimrose May 20, 2026
3c32722
Fix dead-symbol cleanup gate: use ``used_symbols(False)`` not ``free_…
ThrudPrimrose May 20, 2026
82f5d9f
fix(unique-loop-iterators): skip already-unique iterators (idempotent…
ThrudPrimrose May 21, 2026
6d76968
UniqueLoopIterators: deterministic scan-seeded counter + duplicate di…
ThrudPrimrose May 22, 2026
3e4be70
refactor(unique-loop-iterators): accept assign_loop_iterator_post_val…
ThrudPrimrose May 22, 2026
aa4aab0
fix(unique-loop-iterators): seed the counter past MapEntry parameters…
ThrudPrimrose May 26, 2026
a3c7a67
UniqueLoopIterators: debloat docstrings + add map/loop name-conflict …
ThrudPrimrose May 31, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions dace/transformation/passes/ssa_loop_iterators.py
Comment thread
ThrudPrimrose marked this conversation as resolved.
Outdated
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved.
import dace

from dace.sdfg.state import LoopRegion
from dace.transformation import pass_pipeline as ppl
from dace.sdfg import utils as sdutil
from typing import Optional
import copy
from dace.sdfg.state import ControlFlowRegion
from dace.transformation.passes.analysis import loop_analysis
from dace.transformation.transformation import explicit_cf_compatible

import sympy as sp
from typing import Union


def replace_symbol_by_name(expr: sp.Basic, old_name: str, new: Union[str, sp.Basic]) -> sp.Basic:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we have symbolic replace and subs etc.
Why the new function?

"""
Replace all symbols in `expr` whose .name matches `old_name`,
regardless of assumptions, with `new`.
"""
if isinstance(new, str):
new = sp.Symbol(new)
repl = {s: new for s in expr.free_symbols if s.name == old_name}
if not repl:
return expr
return expr.subs(repl)


@dace.properties.make_properties
@explicit_cf_compatible
class SSALoopIterators(ppl.Pass):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same for this class name

loop_var_counter = 0
FOR_IT_NAME = "_it"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be a class variable or just a global constant

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this one is a global constant, I updated accordingly.


def modifies(self) -> ppl.Modifies:
return ppl.Modifies.AccessNodes | ppl.Modifies.Memlets

def should_reapply(self, modified: ppl.Modifies) -> bool:
return False

def _repl_recursive(self, cfg: ControlFlowRegion | dace.SDFG, loop_var: str, next_ssa_loop_var: str):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like dace.sdfg.replace.replace_dict would do a good job here, if applied selectively

# What about Nested SDFGs? Do we need to update symbol mapping?
cfg.replace_meta_accesses({loop_var: next_ssa_loop_var})
cfg.replace_dict({loop_var: next_ssa_loop_var})

for state in cfg.all_states():
for node in state.nodes():
# Update symbol mapping

if isinstance(node, dace.nodes.NestedSDFG):
inner_sdfg = node.sdfg
to_repl = str(loop_var) in node.symbol_mapping
if to_repl:
v = node.symbol_mapping.pop(str(loop_var))
v_symexpr = dace.symbolic.SymExpr(v)
Comment thread
ThrudPrimrose marked this conversation as resolved.
Outdated
node.symbol_mapping[str(next_ssa_loop_var)] = replace_symbol_by_name(
v_symexpr, loop_var, next_ssa_loop_var)

# Now we can replace what is inside
to_repl |= str(loop_var) in inner_sdfg.symbols
if to_repl:
self._repl_recursive(inner_sdfg, loop_var, next_ssa_loop_var)

def _apply_recursive(self, sdfg: dace.SDFG):
for cfg in sdfg.all_control_flow_regions():
if isinstance(cfg, LoopRegion):
loop_var = cfg.loop_variable
loop_end = f"({loop_analysis.get_loop_end(cfg)})" # Inclusive
next_ssa_loop_var = f"{SSALoopIterators.FOR_IT_NAME}_{SSALoopIterators.loop_var_counter}"
# Replace loop variable with next_ssa_loop_var in the loop body,
# and assign loop_var = loop_end at the end of the loop
self._repl_recursive(cfg, loop_var, next_ssa_loop_var)

# Assign to the variable after the loop end
parent_graph = cfg.parent_graph
parent_graph.add_state_after(cfg,
Comment thread
ThrudPrimrose marked this conversation as resolved.
Outdated
f"SSA_loop_var_reconstruction_{SSALoopIterators.loop_var_counter}",
assignments={loop_var: loop_end})

SSALoopIterators.loop_var_counter += 1

for state in sdfg.all_states():
for node in state.nodes():
if isinstance(node, dace.nodes.NestedSDFG):
self._apply_recursive(node.sdfg)

def apply_pass(self, sdfg: dace.SDFG, _) -> Optional[int]:
self._apply_recursive(sdfg)
164 changes: 164 additions & 0 deletions tests/passes/ssa_loop_iterators_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""
Unit tests for SSALoopIterators pass.
"""
import dace
import numpy as np
import pytest
from dace.sdfg.state import LoopRegion
from dace.transformation.passes.ssa_loop_iterators import SSALoopIterators
from dace.transformation.passes.analysis import loop_analysis


@dace.program
def foo(A: dace.float64[10, 10], idx: dace.int32[10, 10], B: dace.float64[5, 10]):
for i in range(5):
for j, k in dace.map[0:10, 0:10]:
A[j, k] = 1.1 * A[j, k] + 1.2 * B[i, idx[j, k]]


def test_nested_sdfg_symbol_mapping():
"""
The map inside the loop body becomes a nested SDFG.
The loop variable `i` must appear in the nested SDFG's symbol_mapping.
After SSALoopIterators, the symbol_mapping should reference the new
SSA name (_it_N), not the original `i`.
"""
SSALoopIterators.loop_var_counter = 0

sdfg = foo.to_sdfg(simplify=False)

# Before: confirm `i` is the loop variable and appears in a nested SDFG mapping
loops_before = [cfg for cfg in sdfg.all_control_flow_regions() if isinstance(cfg, LoopRegion)]
assert len(loops_before) == 1
assert loops_before[0].loop_variable == 'i'

found_i_in_mapping = False
for state in sdfg.all_states():
for node in state.nodes():
if isinstance(node, dace.nodes.NestedSDFG):
if 'i' in node.symbol_mapping:
found_i_in_mapping = True
assert found_i_in_mapping, "Expected 'i' in nested SDFG symbol_mapping before pass"

# Apply pass
SSALoopIterators().apply_pass(sdfg, None)
sdfg.validate()

# After: the nested SDFG symbol_mapping should have _it_0, not i
for state in sdfg.all_states():
for node in state.nodes():
if isinstance(node, dace.nodes.NestedSDFG):
assert 'i' not in node.symbol_mapping, \
f"Original loop var 'i' should not be in symbol_mapping, got {node.symbol_mapping}"
assert '_it_0' in node.symbol_mapping, \
f"SSA loop var '_it_0' should be in symbol_mapping, got {node.symbol_mapping}"

# Verify correctness
A = np.random.rand(10, 10)
idx = np.random.randint(0, 10, size=(10, 10), dtype=np.int32)
B = np.random.rand(5, 10)

A_ref = A.copy()
for i in range(5):
for j in range(10):
for k in range(10):
A_ref[j, k] = 1.1 * A_ref[j, k] + 1.2 * B[i, idx[j, k]]

csdfg = sdfg.compile()
csdfg(A=A, idx=idx, B=B)
assert np.allclose(A, A_ref), f"Max error: {np.max(np.abs(A - A_ref))}"


# ============================================================================
# Test 2: Loop variable used after the loop (reconstruction check)
# ============================================================================
@dace.program
def loop_var_used_after(A: dace.float64[10], B: dace.float64[10]):
for i in range(10):
A[i] = 2.0 * B[i]
# After the loop, i should be 9. The pass should insert
# an assignment i = loop_end - 1 so downstream usage is correct.


def test_loop_var_reconstruction():
"""
After SSALoopIterators, a reconstruction state should assign
the original loop variable to (loop_end - 1) so that any
downstream use of the variable sees the correct final value.
"""
SSALoopIterators.loop_var_counter = 0

sdfg = loop_var_used_after.to_sdfg(simplify=False)

SSALoopIterators().apply_pass(sdfg, None)
sdfg.validate()

# Check that a reconstruction state was added
reconstruction_states = [
s for s in sdfg.all_states() if hasattr(s, 'label') and 'SSA_loop_var_reconstruction' in s.label
]
assert len(reconstruction_states) == 1, f"Expected 1 reconstruction state, found {len(reconstruction_states)}"

# Check that assignment is correct
loops = [cfg for cfg in sdfg.all_control_flow_regions() if isinstance(cfg, LoopRegion)]
assert len(loops) == 1
loop = loops[0]

out_edges = loop.parent_graph.out_edges(loop)
assert len(out_edges) == 1

assignments = out_edges[0].data.assignments
assert 'i' in assignments, f"Expected assignment to 'i', got {assignments}"
assert str(
assignments['i']
) == f"({(str(loop_analysis.get_loop_end(loop)))})", f"Expected loop end assignment, got {assignments['i']}"

# Verify correctness
A = np.zeros(10)
B = np.random.rand(10)
csdfg = sdfg.compile()
csdfg(A=A, B=B)
assert np.allclose(A, 2.0 * B)


# ============================================================================
# Test 3: Nested loops — both variables should be renamed independently
# ============================================================================
@dace.program
def nested_loops(A: dace.float64[8, 8]):
for i in range(8):
for j in range(8):
A[i, j] = A[i, j] + 1.0


def test_nested_loops():
"""
Two nested LoopRegions with variables i and j.
Both should be renamed to distinct SSA names (_it_0, _it_1),
and both should get reconstruction states.
"""
SSALoopIterators.loop_var_counter = 0

sdfg = nested_loops.to_sdfg(simplify=False)

# Before: should have 2 loop regions
loops_before = [cfg for cfg in sdfg.all_control_flow_regions() if isinstance(cfg, LoopRegion)]
assert len(loops_before) == 2
loop_vars_before = {l.loop_variable for l in loops_before}
assert loop_vars_before == {'i', 'j'}

SSALoopIterators().apply_pass(sdfg, None)
sdfg.validate()

# Verify correctness
A = np.random.rand(8, 8)
A_ref = A.copy() + 1.0
csdfg = sdfg.compile()
csdfg(A=A)
assert np.allclose(A, A_ref), f"Max error: {np.max(np.abs(A - A_ref))}"


if __name__ == '__main__':
test_nested_sdfg_symbol_mapping()
test_loop_var_reconstruction()
test_nested_loops()
Loading