diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 0f5d7cf13d..6a16cb72ef 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -1427,20 +1427,17 @@ def _used_symbols_internal(self, free_syms=free_syms, used_before_assignment=used_before_assignment, with_contents=with_contents) - # Expand array-descriptor stride/shape/offset symbols into the free - # set. Without this, a ``ConditionalBlock`` guard or memlet subset - # referencing ``A[i, j]`` leaves the symbols used in ``A`` 's strides - # out of the computed free-symbol set, causing - # ``generate_nsdfg_header`` to emit a nested function signature - # missing those symbols, ceating an invalid SDFG. + # A used array needs its stride/shape/offset symbols in the free set, but a + # merely-declared one must not leak its shape symbol into the signature + # (issue #2382). ``read_and_write_sets`` already reports exactly the arrays + # that are used -- read or written, including those referenced only by a + # code-block guard/condition -- so expand the extent symbols of those alone. res_free, res_defined, res_before = result if with_contents: - for desc in self.arrays.values(): - res_free |= {str(s) for s in desc.used_symbols(all_symbols)} - # Don't drag in symbols that are genuinely defined inside this - # SDFG (e.g., LoopRegion loop variables); keep only the ones - # outside ``defined_syms``. - res_free -= res_defined + read_set, write_set = self.read_and_write_sets() + for name in (read_set | write_set) & self.arrays.keys(): + res_free |= {str(s) for s in self.arrays[name].used_symbols(all_symbols)} + res_free -= res_defined # drop symbols defined inside (e.g. loop vars) return res_free, res_defined, res_before def get_all_toplevel_symbols(self) -> Set[str]: diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index c3596e8f4f..5d7309a95f 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -921,16 +921,20 @@ def unordered_arglist(self, } if top_source_edge.src.data not in descs else {}) elif isinstance(edge.dst, nd.ExitNode) and isinstance(edge.src, (nd.AccessNode, nd.CodeNode)): - # Same case as above, but for outgoing Memlets. - # NOTE: We have to use a memlet tree here, because the data could potentially - # go to multiple sources. We have to do it this way, because if we would call - # `memlet_tree()` here, then we would just get the edge back. + # Same case as above, but for outgoing Memlets. The Memlet leaving the + # scope may be source-relative (naming the inner transient rather than + # the external array), so resolve the written array from the memlet + # tree's root -- the outermost-scope node, i.e. the destination the + # data fans out to (fall back to the Memlet's data otherwise). additional_descs = {} connector_to_look = "OUT_" + edge.dst_conn[3:] for oedge in self.graph.out_edges_by_connector(edge.dst, connector_to_look): - if ((not oedge.data.is_empty()) and (oedge.data.data not in descs) - and (oedge.data.data not in additional_descs)): - additional_descs[oedge.data.data] = sdfg.arrays[oedge.data.data] + if oedge.data.is_empty(): + continue + root_dst = self.graph.memlet_tree(oedge).root().edge.dst + dst_name = root_dst.data if isinstance(root_dst, nd.AccessNode) else oedge.data.data + if dst_name not in descs and dst_name not in additional_descs: + additional_descs[dst_name] = sdfg.arrays[dst_name] else: # Case is ignored. @@ -1643,6 +1647,20 @@ def symbols_defined_at(self, node: nd.Node) -> Dict[str, dtypes.typeclass]: for e in sdfg.edges(): symbols.update(e.data.new_symbols(sdfg, symbols)) + # Add the loop variables of the control-flow loops enclosing this state, + # outermost first. Without this only global, inter-state-edge and dataflow-scope + # (map) symbols are seen; a node inside a LoopRegion must also see the loop + # variable as defined -- e.g. so memlet propagation keeps a ``jk``-indexed + # nested-SDFG access parametric instead of widening it to the whole array. + enclosing_loops = [] + cfg = self.parent_graph + while cfg is not None: + if isinstance(cfg, LoopRegion) and cfg.loop_variable: + enclosing_loops.append(cfg) + cfg = cfg.parent_graph + for loop in reversed(enclosing_loops): + symbols.update(loop.new_symbols(symbols)) + # Find scopes this node is situated in sdict = self.scope_dict() scope_list = [] diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index 888c7e77c9..116a2412be 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -889,20 +889,27 @@ def validate_state(state: 'dace.sdfg.SDFGState', ) # Verify that source and destination subsets contain the same - # number of elements + # number of elements. Only AccessNode endpoints expose a ``.data`` + # descriptor whose ``veclen`` participates in this check; scope + # nodes (NestedSDFG, MapEntry/Exit, ConsumeEntry/Exit) route data + # through connectors and contribute ``veclen = 1`` to the count. if not e.data.allow_oob and e.data.other_subset is not None and not ( (isinstance(src_node, nd.AccessNode) and isinstance(sdfg.arrays[src_node.data], dt.Stream)) or (isinstance(dst_node, nd.AccessNode) and isinstance(sdfg.arrays[dst_node.data], dt.Stream))): - src_expr = (e.data.src_subset.num_elements() * sdfg.arrays[src_node.data].veclen) - dst_expr = (e.data.dst_subset.num_elements() * sdfg.arrays[dst_node.data].veclen) + src_veclen = sdfg.arrays[src_node.data].veclen if isinstance(src_node, nd.AccessNode) else 1 + dst_veclen = sdfg.arrays[dst_node.data].veclen if isinstance(dst_node, nd.AccessNode) else 1 + src_expr = e.data.src_subset.num_elements() * src_veclen + dst_expr = e.data.dst_subset.num_elements() * dst_veclen if symbolic.inequal_symbols(src_expr, dst_expr): error = InvalidSDFGEdgeError('Dimensionality mismatch between src/dst subsets', sdfg, state_id, eid) # NOTE: Make an exception for Views and reference sets from dace.sdfg import utils - if (isinstance(sdfg.arrays[src_node.data], dt.View) and utils.get_view_edge(state, src_node) is e): + if (isinstance(src_node, nd.AccessNode) and isinstance(sdfg.arrays[src_node.data], dt.View) + and utils.get_view_edge(state, src_node) is e): warnings.warn(error.message) continue - if (isinstance(sdfg.arrays[dst_node.data], dt.View) and utils.get_view_edge(state, dst_node) is e): + if (isinstance(dst_node, nd.AccessNode) and isinstance(sdfg.arrays[dst_node.data], dt.View) + and utils.get_view_edge(state, dst_node) is e): warnings.warn(error.message) continue if e.dst_conn == 'set': diff --git a/dace/symbolic.py b/dace/symbolic.py index 74600aa84f..f0418ccb08 100644 --- a/dace/symbolic.py +++ b/dace/symbolic.py @@ -3,6 +3,7 @@ import contextlib from collections import Counter from functools import lru_cache +import math import sympy import pickle import re @@ -564,11 +565,15 @@ def _typed_constant_suffix(dtype: dtypes.typeclass) -> str: def _format_float(value: float) -> str: - # Shortest round-trip form, keeping one fractional digit (5.0, not 5 or 5.000...). - s = f'{float(value):.15g}' + # ``repr`` for finite Python floats is the shortest decimal that + # round-trips through ``float()`` -- guaranteed idempotent under + # save->load->save and at most 17 significant digits for fp64. + f = float(value) + s = repr(f) if 'e' in s or 'E' in s: return s if '.' not in s: + # Keep one fractional digit so an integer-valued float stays floating-point. return s + '.0' int_part, frac_part = s.split('.') return f'{int_part}.{frac_part.rstrip("0") or "0"}' @@ -985,6 +990,14 @@ def sympy_numeric_fix(expr): """ Fix for printing out integers as floats with ".00000000". Converts the float constants in a given expression to integers. """ if not isinstance(expr, sympy.Basic) or isinstance(expr, sympy.Number): + # Preserve a finite float -- sympy.Float, Python float, or numpy float -- + # so an integer-valued float like 1.0 stays 1.0 and is never collapsed to + # int 1 below: that would change its type (min(x, 1) mixes double and int) + # and round-trip through SymPy as a mistyped Min/Max. 0.0 was only spared by + # the ``expr != 0`` clause; a Python float 1.0 hit the int() collapse. + # Non-finite values (+-1.8e308 -> inf) fall through to the overflow path. + if isinstance(expr, (sympy.Float, float, numpy.floating)) and math.isfinite(float(expr)): + return expr if isinstance(expr, sympy.Float) else sympy.Float(expr) try: # NOTE: If expr is ~ 1.8e308, i.e. infinity, `numpy.int64(expr)` # will throw OverflowError (which we want). @@ -1970,7 +1983,13 @@ def _print_Integer(self, expr): return super()._print_Integer(expr) def _print_Float(self, expr): - return _format_float(float(sympy_numeric_fix(expr))) + nf = sympy_numeric_fix(expr) + if not math.isfinite(float(nf)): + # The value exceeds a C double (e.g. Fortran ``HUGE``, just over the max): + # let sympy print its own shortest decimal instead of overflowing through + # ``float()`` to a spurious ``inf`` (which would then render as ``inf.0``). + return super()._print_Float(nf) + return _format_float(float(nf)) def _print_Add(self, expr): flat_args = [] @@ -2044,7 +2063,11 @@ def _serialize_symbolic_uncached(expr: Union[SymbolicType, int, float, numpy.num if isinstance(expr, int) and not isinstance(expr, bool): return str(expr) if isinstance(expr, float): - return sympy.printing.str.sstr(expr) + # Route through the shared formatter so a Python float reaches the same + # repr-based shortest-round-trip path the sympy.Float branch uses below. + # Otherwise sympy's default sstr emits a 15-sig-digit form that fails + # the SDFG save->load->save equality check. + return _format_float(expr) if isinstance(expr, sympy.Basic): return DaceSympySerializer().doprint(expr) return str(expr) @@ -2223,10 +2246,14 @@ def __init__(self, arrays, cpp_mode=False, *args, **kwargs): self._settings['full_prec'] = False def _print_Float(self, expr): + # Shortest round-tripping form, always keeping one fractional digit so an + # integer-valued float stays floating-point (``5.0``, not ``5``). nf = sympy_numeric_fix(expr) - if isinstance(nf, int) or nf != expr: - return self._print(nf) - return super()._print_Float(expr) + if not math.isfinite(float(nf)): + # Exceeds a C double (e.g. Fortran ``HUGE``): keep sympy's shortest + # decimal rather than overflowing to ``inf`` (rendered as ``inf.0``). + return super()._print_Float(nf) + return _format_float(float(nf)) def _print_TypedConstant(self, expr): value = self._print(expr.value) @@ -2287,6 +2314,44 @@ def _print_Function(self, expr): def _print_Mod(self, expr): return '((%s) %% (%s))' % (self._print(expr.args[0]), self._print(expr.args[1])) + def _print_floor(self, expr): + """sympy ``floor(...)`` printer. + + sympy's ``//`` operator on symbolic integers (e.g. ``(LEN - 1) // 8``) + simplifies to ``floor(LEN/8 - 1/8)`` where ``1/8`` becomes a + ``Rational(1, 8)``. Without this override the printer emits a literal + ``floor(LEN/8 - 1/8)`` which in C++ collapses ``1 / 8`` to ``0`` via + integer division, so the floor argument becomes ``LEN/8`` instead of + ``(LEN - 1) / 8`` -- the loop bound silently overshoots by one. + + Recombine: if the floor argument is an addition of fractions with a + common denominator, reassemble the numerator and emit a single + ``((numerator) / (denominator))`` integer division. Otherwise fall + through to the math-library ``floor(...)`` call. + """ + if not self.cpp_mode: + return super()._print_Function(expr) if hasattr(super(), "_print_Function") else super()._print_floor(expr) + arg = expr.args[0] + # Try to combine to a single ``Rational(num, den)``: when arg is + # ``a/b + c/d + ...`` sympy's ``.together()`` rewrites to a + # single fraction over a common denominator. ``as_numer_denom`` + # then splits it cleanly. + try: + arg_together = arg.together() + num, den = arg_together.as_numer_denom() + except Exception: + num, den = None, None + if num is not None and den is not None: + try: + den_int = int(den) + except (TypeError, ValueError): + den_int = None + if den_int is not None and den_int != 1 and den_int != 0: + return '((%s) / (%s))' % (self._print(num), self._print(den)) + # Fallback: pure-real floor (e.g. ``floor(sin(x))``); emit the + # math-library call. + return 'floor(%s)' % self._print(arg) + def _print_Equality(self, expr): return '((%s) == (%s))' % (self._print(expr.args[0]), self._print(expr.args[1])) diff --git a/dace/transformation/dataflow/map_fusion_vertical.py b/dace/transformation/dataflow/map_fusion_vertical.py index 89883425cb..65db2c7a3c 100644 --- a/dace/transformation/dataflow/map_fusion_vertical.py +++ b/dace/transformation/dataflow/map_fusion_vertical.py @@ -279,8 +279,109 @@ def can_be_applied( if not (exclusive_outputs or shared_outputs): return False + # NOTE: NestedSDFGs in the producer's body whose InOut connectors match + # an intermediate's data name are handled by ``_split_inout_for_intermediate`` + # below in ``apply`` -- we split the connector inside the NestedSDFG + # (rename the inner read-side accesses to a fresh array bound to a new + # input connector) so the standard rename machinery can rewire the + # output-only connector without producing the mismatched-InOut + # validation error. v1 splits only when every inner read AN of the + # InOut name has ``in_degree == 0`` (the clean one-RMW-tasklet shape) + # and refuses otherwise. + intermediate_names = { + e.dst.data + for e in (exclusive_outputs | shared_outputs) if isinstance(e.dst, nodes.AccessNode) + } + if intermediate_names: + first_scope = graph.scope_subgraph(first_map_entry, include_entry=False, include_exit=False) + for inner in first_scope.nodes(): + if not isinstance(inner, nodes.NestedSDFG): + continue + inout_conns = set(inner.in_connectors) & set(inner.out_connectors) + for name in inout_conns & intermediate_names: + if not self._inout_split_is_safe(inner, name): + return False + return True + @staticmethod + def _inout_split_is_safe(nsdfg: nodes.NestedSDFG, name: str) -> bool: + """``True`` iff every inner AccessNode of ``name`` is either a pure + read source (``in_degree == 0`` and ``out_degree > 0``) or a pure + write sink (``in_degree > 0`` and ``out_degree == 0``). Mixed-mode + accesses (read AN whose downstream is also written in the same state, + e.g. ``a -> ... -> a`` chains within one state) would require more + elaborate use-def analysis than v1 handles, so we refuse those. + """ + inner_sdfg = nsdfg.sdfg + if inner_sdfg is None: + return False + for state in inner_sdfg.all_states(): + for n in state.nodes(): + if not isinstance(n, nodes.AccessNode) or n.data != name: + continue + in_d, out_d = state.in_degree(n), state.out_degree(n) + if not ((in_d == 0 and out_d > 0) or (in_d > 0 and out_d == 0)): + return False + return True + + @staticmethod + def _split_inout_for_intermediate(graph: dace.SDFGState, sdfg: dace.SDFG, first_map_entry: nodes.MapEntry, + intermediate_names: Set[str]) -> None: + """For each NestedSDFG inside ``first_map_entry``'s scope whose InOut + connectors include any of ``intermediate_names``, split the connector: + + 1. Allocate a fresh inner array ``__map_fusion_split_`` (same + shape / dtype as the original ``name`` inside the NestedSDFG). + 2. Rename every inner read-side AccessNode of ``name`` (``in_degree == + 0``) to the fresh name; rename the memlet ``data`` on its outgoing + edges. + 3. Drop the InOut input connector ``name`` from the outer NestedSDFG + node and add the fresh ``__map_fusion_split_`` input + connector with the same dtype. + 4. Redirect the outer input edge feeding the old ``name`` input + connector to the new connector. + + After this rewrite the NestedSDFG's ``name`` connector is OUTPUT-ONLY + and the standard MapFusion rename machinery can rename it to + ``__map_fusion_`` without producing an InOut-mismatch. + """ + first_scope = graph.scope_subgraph(first_map_entry, include_entry=False, include_exit=False) + for inner in list(first_scope.nodes()): + if not isinstance(inner, nodes.NestedSDFG): + continue + inout_conns = set(inner.in_connectors) & set(inner.out_connectors) + for orig in sorted(inout_conns & intermediate_names): + inner_sdfg = inner.sdfg + if inner_sdfg is None or orig not in inner_sdfg.arrays: + continue + # 1. Fresh inner array. + new_name = f"__map_fusion_split_{orig}" + while new_name in inner_sdfg.arrays: + new_name += "_" + new_desc = copy.deepcopy(inner_sdfg.arrays[orig]) + inner_sdfg._arrays[new_name] = new_desc + # 2. Rename inner read-side accesses + their out-edge memlets. + for st in inner_sdfg.all_states(): + for n in list(st.nodes()): + if not isinstance(n, nodes.AccessNode) or n.data != orig: + continue + if st.in_degree(n) > 0: + continue + n.data = new_name + for e in st.out_edges(n): + if e.data is not None and e.data.data == orig: + e.data.data = new_name + # 3. Replace the InOut input connector on the NestedSDFG node. + in_type = inner.in_connectors.get(orig) + inner.remove_in_connector(orig) + inner.add_in_connector(new_name, in_type) + # 4. Redirect the outer edge feeding the old input connector. + for e in list(graph.in_edges(inner)): + if e.dst_conn == orig: + graph.add_edge(e.src, e.src_conn, inner, new_name, e.data) + graph.remove_edge(e) + def apply( self, graph: Union[dace.SDFGState, dace.SDFG], @@ -326,6 +427,16 @@ def apply( pure_outputs, exclusive_outputs, shared_outputs = output_partition assert (not self.require_exclusive_intermediates) or (len(shared_outputs) == 0) + # If any intermediate's data name is shared with an InOut connector of a + # NestedSDFG in the producer's body, split the connector so the standard + # rename machinery below produces a valid InOut-free shape. + intermediate_names = { + e.dst.data + for e in (exclusive_outputs | shared_outputs) if isinstance(e.dst, nodes.AccessNode) + } + if intermediate_names: + self._split_inout_for_intermediate(graph, sdfg, first_map_entry, intermediate_names) + # Now perform the actual rewiring, we handle each partition separately. if len(exclusive_outputs) != 0: self.handle_intermediate_set( diff --git a/dace/transformation/dataflow/redundant_array_copying.py b/dace/transformation/dataflow/redundant_array_copying.py index cda12f04cc..8832f78eea 100644 --- a/dace/transformation/dataflow/redundant_array_copying.py +++ b/dace/transformation/dataflow/redundant_array_copying.py @@ -1,6 +1,7 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. """ Contains redundant array removal transformations. """ +from dace import subsets from dace.sdfg import nodes from dace.sdfg import utils as sdutil from dace.sdfg.sdfg import SDFG @@ -8,9 +9,36 @@ from dace.transformation import transformation as pm +def _is_full_copy(graph: SDFGState, edge, src_desc, dst_desc) -> bool: + """Whether ``edge`` copies the whole source array onto the whole destination array. + + A side whose subset is ``None`` is treated as covering the full extent of + that side's descriptor; the edge is a full identity iff both effective + subsets equal ``Range.from_array()``. + """ + src_subset = edge.data.get_src_subset(edge, graph) + dst_subset = edge.data.get_dst_subset(edge, graph) + src_ok = src_subset is None or src_subset == subsets.Range.from_array(src_desc) + dst_ok = dst_subset is None or dst_subset == subsets.Range.from_array(dst_desc) + return src_ok and dst_ok + + +def _shapes_match(a, b) -> bool: + """Whether two shape tuples have the same rank and equal symbolic extents.""" + return len(a) == len(b) and all(x == y for x, y in zip(a, b)) + + class RedundantArrayCopyingIn(pm.SingleStateTransformation): - """ Implements the redundant array removal transformation. Removes the first and second access nodeds - in pattern A -> B -> A + """Fold an ``A -> B -> C`` chain of full identity copies into writers-of-``A`` writing straight to ``C``. + + Matches three sequential AccessNodes where ``A`` and ``B`` are transient. + ``apply`` removes ``A`` and ``B`` and redirects every writer of ``A`` onto + ``C``, renaming the memlet data so the redirected edges describe ``C``. + + The fold is only sound when ``B`` has exactly one consumer (``C``), ``A`` + and ``C`` share rank, shape and storage, and both copy edges are full + identity. A partial copy in the chain would be silently widened to a full + one by the rename and corrupt the region of ``C`` the chain never wrote. """ in_array = pm.PatternNode(nodes.AccessNode) @@ -25,23 +53,44 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): in_array = self.in_array med_array = self.med_array out_array = self.out_array + in_desc = in_array.desc(sdfg) + med_desc = med_array.desc(sdfg) + out_desc = out_array.desc(sdfg) - # Safety first (could be relaxed) - if not (graph.out_degree(in_array) == 1 and graph.in_degree(med_array) == 1 and graph.out_degree(med_array)): + # Degree gates: ``in`` and ``med`` are about to be removed, so each + # must have a single outgoing edge into the chain. ``med`` having any + # other consumer would leave it dangling without a source after the fold. + if graph.out_degree(in_array) != 1: + return False + if graph.in_degree(med_array) != 1 or graph.out_degree(med_array) != 1: return False - # Make sure that the removal candidates are transient - if not (in_array.desc(sdfg).transient and med_array.desc(sdfg).transient): + # ``in`` and ``med`` are the nodes the fold deletes; only transients + # can be deleted without losing externally-visible storage. + if not (in_desc.transient and med_desc.transient): return False - # Make sure that both arrays are using the same storage location - if in_array.desc(sdfg).storage != out_array.desc(sdfg).storage: + # The redirected writers of ``in`` keep their original subsets and end + # up writing to ``out``; this is only meaningful if ``in`` and ``out`` + # share storage location (so the writers can address it the same way) + # and the three arrays share rank and shape (so the subsets remain + # valid). ``med`` may live on a different device (this is exactly the + # CPU-GPU-CPU staging chain the pass is designed to short-circuit). + if in_desc.storage != out_desc.storage: + return False + if not (_shapes_match(in_desc.shape, med_desc.shape) and _shapes_match(in_desc.shape, out_desc.shape)): return False - # Only apply if arrays are of same shape (no need to modify memlet subset) - if len(in_array.desc(sdfg).shape) != len(out_array.desc(sdfg).shape) or any( - i != o for i, o in zip(in_array.desc(sdfg).shape, - out_array.desc(sdfg).shape)): + # The two copy edges in the chain must be full identity copies. A + # partial copy here would be silently widened to a full one when the + # writers of ``in`` are redirected onto ``out``, corrupting the region + # of ``out`` the chain never actually wrote. The unique edges are + # guaranteed by the degree gates above; fetch the first. + in_med = graph.edges_between(in_array, med_array)[0] + med_out = graph.edges_between(med_array, out_array)[0] + if not _is_full_copy(graph, in_med, in_desc, med_desc): + return False + if not _is_full_copy(graph, med_out, med_desc, out_desc): return False return True diff --git a/dace/transformation/dataflow/trivial_tasklet_elimination.py b/dace/transformation/dataflow/trivial_tasklet_elimination.py index 21eb50f796..a86f9ddf34 100644 --- a/dace/transformation/dataflow/trivial_tasklet_elimination.py +++ b/dace/transformation/dataflow/trivial_tasklet_elimination.py @@ -73,6 +73,16 @@ def apply(self, graph, sdfg): out_edge = graph.edges_between(tasklet, write)[0] graph.remove_edge(in_edge) graph.remove_edge(out_edge) - out_edge.data.other_subset = in_edge.data.subset - graph.add_edge(read, in_edge.src_conn, write, out_edge.dst_conn, out_edge.data) + if self.expr_index == 1: + # Source is a MapEntry: the surviving edge leaves the map's + # ``OUT_`` connector, so its memlet must keep the read-side + # data and subset (e.g. an offset access ``a[i + k]``) and carry + # the write subset in ``other_subset``. Reusing the write memlet + # here would strand the read offset in ``other_subset`` and drop + # it when the map is later re-lowered to a loop. + in_edge.data.other_subset = out_edge.data.subset + graph.add_edge(read, in_edge.src_conn, write, out_edge.dst_conn, in_edge.data) + else: + out_edge.data.other_subset = in_edge.data.subset + graph.add_edge(read, in_edge.src_conn, write, out_edge.dst_conn, out_edge.data) graph.remove_node(tasklet) diff --git a/dace/transformation/dataflow/wcr_conversion.py b/dace/transformation/dataflow/wcr_conversion.py index 937d51288f..f765426262 100644 --- a/dace/transformation/dataflow/wcr_conversion.py +++ b/dace/transformation/dataflow/wcr_conversion.py @@ -16,26 +16,45 @@ class AugAssignToWCR(transformation.SingleStateTransformation): """ Converts an augmented assignment ("a += b", "a = a + b") into a tasklet with a write-conflict resolution. + + A third pattern handles the *copy-wrapped* read-modify-write shape where + the accumulator slice is materialized into a scalar transient before the + combining tasklet and copied back after it + (``arr[S] -> copy_in -> tasklet -> copy_out -> arr[S]``). Those + materialization copies cannot be folded away by the redundant-array + passes because ``arr`` is both read and written in the same state; + recognising the shape directly is what lets loop-carried reductions + become WCR writes and so parallelize via ``LoopToMap``. """ input = transformation.PatternNode(nodes.AccessNode) tasklet = transformation.PatternNode(nodes.Tasklet) output = transformation.PatternNode(nodes.AccessNode) map_entry = transformation.PatternNode(nodes.MapEntry) map_exit = transformation.PatternNode(nodes.MapExit) + copy_in = transformation.PatternNode(nodes.AccessNode) + copy_out = transformation.PatternNode(nodes.AccessNode) _EXPRESSIONS = ['+', '-', '*', '^', '%'] #, '/'] _FUNCTIONS = ['min', 'max'] _EXPR_MAP = {'-': ('+', '-({expr})'), '/': ('*', '((decltype({expr}))1)/({expr})')} _PYOP_MAP = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.BitXor: '^', ast.Mod: '%', ast.Div: '/'} + # Order-independent combines accepted for the copy-wrapped RMW pattern. + # Subtraction is admitted only with the accumulator on the left (checked + # at match time): ``a - b1 - b2 == a - (b1 + b2)`` is order-independent. + _RMW_BINOPS = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*'} @classmethod def expressions(cls): return [ sdutil.node_path_graph(cls.input, cls.tasklet, cls.output), - sdutil.node_path_graph(cls.input, cls.map_entry, cls.tasklet, cls.map_exit, cls.output) + sdutil.node_path_graph(cls.input, cls.map_entry, cls.tasklet, cls.map_exit, cls.output), + sdutil.node_path_graph(cls.input, cls.copy_in, cls.tasklet, cls.copy_out, cls.output), ] def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + if expr_index == 2: + return self._can_be_applied_rmw_copy(graph, sdfg) + inarr = self.input tasklet = self.tasklet outarr = self.output @@ -141,6 +160,9 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return False def apply(self, state: SDFGState, sdfg: SDFG): + if self.expr_index == 2: + return self._apply_rmw_copy(state, sdfg) + input: nodes.AccessNode = self.input tasklet: nodes.Tasklet = self.tasklet output: nodes.AccessNode = self.output @@ -269,6 +291,139 @@ def apply(self, state: SDFGState, sdfg: SDFG): # At this point we are leading to an access node again and can # traverse further up + def _classify_rmw_rhs(self, rhs, acc_conn, tasklet): + """Classify the combining tasklet's RHS as an order-independent reduction + of the accumulator (read on connector ``acc_conn``) with one other input. + + :param rhs: the RHS AST node of the tasklet's single assignment. + :param acc_conn: the input connector carrying the loaded accumulator. + :param tasklet: the combining tasklet (for its input-connector set). + :returns: ``(op, other_operand_ast, acc_on_left)`` where ``op`` is the + WCR operator symbol / function name, or ``(None, None, None)`` + if the RHS is not such a combine. + """ + in_conns = set(tasklet.in_connectors) + if isinstance(rhs, ast.BinOp) and type(rhs.op) in self._RMW_BINOPS: + op = self._RMW_BINOPS[type(rhs.op)] + left, right = rhs.left, rhs.right + if (isinstance(left, ast.Name) and left.id == acc_conn and isinstance(right, ast.Name) + and right.id in in_conns and right.id != acc_conn): + return op, right, True + if (isinstance(right, ast.Name) and right.id == acc_conn and isinstance(left, ast.Name) + and left.id in in_conns and left.id != acc_conn): + return op, left, False + elif (isinstance(rhs, ast.Call) and isinstance(rhs.func, ast.Name) and rhs.func.id in self._FUNCTIONS + and len(rhs.args) == 2 and all(isinstance(a, ast.Name) for a in rhs.args)): + a0, a1 = rhs.args + if a0.id == acc_conn and a1.id in in_conns and a1.id != acc_conn: + return rhs.func.id, a1, True + if a1.id == acc_conn and a0.id in in_conns and a0.id != acc_conn: + return rhs.func.id, a0, False + return None, None, None + + def _rmw_copy_edges(self, graph): + """Return the four spine edges ``(load, ine, oute, store)`` of the + copy-wrapped RMW, or ``None`` if the spine is not a clean single path.""" + load = graph.edges_between(self.input, self.copy_in) + ine = graph.edges_between(self.copy_in, self.tasklet) + oute = graph.edges_between(self.tasklet, self.copy_out) + store = graph.edges_between(self.copy_out, self.output) + if len(load) != 1 or len(ine) != 1 or len(oute) != 1 or len(store) != 1: + return None + return load[0], ine[0], oute[0], store[0] + + def _can_be_applied_rmw_copy(self, graph, sdfg): + """Match ``arr[S] -> copy_in -> tasklet -> copy_out -> arr[S]`` where the + tasklet combines the loaded accumulator with one other input via an + order-independent reduction. The copy nodes must be private single-use + transients and the load / store must hit the same accumulator slice.""" + inp, cin, tlet, cout, out = (self.input, self.copy_in, self.tasklet, self.copy_out, self.output) + if inp.data != out.data: + return False + # Only free RMWs: an enclosing map index would mean disjoint writes + # (no conflict, hence no reduction to resolve). + if graph.entry_node(tlet) is not None: + return False + # copy_in / copy_out must be private single-use transients. + for node in (cin, cout): + desc = sdfg.arrays.get(node.data) + if desc is None or not desc.transient: + return False + if graph.in_degree(cin) != 1 or graph.out_degree(cin) != 1: + return False + if graph.in_degree(cout) != 1 or graph.out_degree(cout) != 1: + return False + + edges = self._rmw_copy_edges(graph) + if edges is None: + return False + load, ine, oute, store = edges + if load.data.wcr is not None or store.data.wcr is not None: + return False + # Same accumulator slice loaded and stored. + acc_subset = store.data.get_dst_subset(store, graph) + load_subset = load.data.get_src_subset(load, graph) + if acc_subset is None or load_subset is None or acc_subset != load_subset: + return False + + # The tasklet must be a single Python assignment with exactly two data + # inputs (accumulator + increment) and one data output. + if tlet.language is not dtypes.Language.Python or len(tlet.code.code) != 1: + return False + node = tlet.code.code[0] + if (not isinstance(node, ast.Assign) or len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name) + or node.targets[0].id != oute.src_conn): + return False + data_in = [e for e in graph.in_edges(tlet) if e.data is not None and not e.data.is_empty()] + data_out = [e for e in graph.out_edges(tlet) if e.data is not None and not e.data.is_empty()] + if len(data_in) != 2 or len(data_out) != 1: + return False + + op, _, acc_on_left = self._classify_rmw_rhs(node.value, ine.dst_conn, tlet) + if op is None: + return False + if op == '-' and not acc_on_left: + return False + return True + + def _apply_rmw_copy(self, state: SDFGState, sdfg: SDFG): + """Rewrite the copy-wrapped RMW into a WCR write: drop the accumulator + load, emit only the increment from the tasklet, and write it straight + into the accumulator slice with the reduction WCR (the scalar copy-out + transient is removed).""" + inp, cin, tlet, cout, out = (self.input, self.copy_in, self.tasklet, self.copy_out, self.output) + load, ine, oute, store = self._rmw_copy_edges(state) + + node = tlet.code.code[0] + op, other_ast, _ = self._classify_rmw_rhs(node.value, ine.dst_conn, tlet) + + # The tasklet now emits only the increment (accumulator operand dropped). + tlet.code.code = [ast.copy_location(ast.Assign(targets=node.targets, value=other_ast), node)] + + # Write the increment straight into the accumulator with the WCR, + # bypassing the scalar copy-out transient. + acc_subset = store.data.get_dst_subset(store, state) + wcr = f'lambda a,b: {op}(a, b)' if op in self._FUNCTIONS else f'lambda a,b: a {op} b' + state.remove_edge(oute) + state.remove_edge(store) + state.add_edge(tlet, oute.src_conn, out, store.dst_conn, Memlet(data=out.data, subset=acc_subset, wcr=wcr)) + if state.degree(cout) == 0: + state.remove_node(cout) + + # Drop the accumulator load path (input -> copy_in -> tasklet); the + # WCR now supplies the previous accumulator value at write time. + acc_conn = ine.dst_conn + state.remove_edge(ine) + state.remove_edge(load) + if acc_conn in tlet.in_connectors: + tlet.remove_in_connector(acc_conn) + if state.degree(cin) == 0: + state.remove_node(cin) + if state.degree(inp) == 0: + state.remove_node(inp) + + propagate_memlets_state(sdfg, state) + def isolate_tasklet( self, state: SDFGState, @@ -285,7 +440,7 @@ def isolate_tasklet( for e in state.memlet_path(edge): nodes_to_move.add(e.src) orig_edges.add(e) - if isinstance(e.src, nodes.AccessNode) and isinstance(e.src.desc(sdfg), data.View): + if isinstance(e.src, nodes.AccessNode) and isinstance(e.src.desc(state.sdfg), data.View): assert state.in_degree(e.src) > 0 view_edges = sdutil.get_all_view_edges(state, e.src) for edge in view_edges: diff --git a/dace/transformation/interstate/condition_fusion.py b/dace/transformation/interstate/condition_fusion.py index 0b595dabaa..728f30ada1 100644 --- a/dace/transformation/interstate/condition_fusion.py +++ b/dace/transformation/interstate/condition_fusion.py @@ -4,7 +4,7 @@ from dace import sdfg as sd, properties from dace.properties import CodeBlock from dace.sdfg import utils as sdutil -from dace.sdfg.state import ControlFlowRegion, ConditionalBlock +from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion, ConditionalBlock from dace.transformation import transformation as xf @@ -191,10 +191,20 @@ def fuse_consecutive_conditions(self, sdfg: sd.SDFG, cblck1: ConditionalBlock, c for j, node in enumerate(cfg.nodes()): node.label = f"{node.label}_{j}" - # Fix SDFG parents + # Fix SDFG parents. ``set_nested_sdfg_parent_references`` walks every + # NestedSDFG and sets ``node.sdfg.parent_sdfg`` on the *inner* SDFGs; + # the follow-up loop repairs ``.sdfg`` on the OUTER container blocks + # (``SDFGState`` / ``ControlFlowRegion`` / ``ConditionalBlock``) whose + # ``.sdfg`` attribute names the containing SDFG. The + # ``ControlFlowBlock`` isinstance check replaces the previous + # ``hasattr(node, "sdfg")`` -- ``hasattr`` also matched ``NestedSDFG`` + # nodes, whose ``.sdfg`` is the *inner* SDFG (an + # ``SDFGReferenceProperty`` with a setter), so the assignment + # overwrote the inner-SDFG slot with the outer container and produced + # a graph cycle that infinite-recurses ``all_nodes_recursive``. sdutil.set_nested_sdfg_parent_references(sdfg) for node, parent in sdfg.all_nodes_recursive(): - if hasattr(node, "sdfg"): + if isinstance(node, ControlFlowBlock): node.sdfg = parent.sdfg def fuse_nested_conditions(self, sdfg: sd.SDFG, cblck1: ConditionalBlock): @@ -277,8 +287,12 @@ def fuse_nested_conditions(self, sdfg: sd.SDFG, cblck1: ConditionalBlock): for j, node in enumerate(cfg.nodes()): node.label = f"{node.label}_{j}" - # Fix SDFG parents + # Fix SDFG parents. The ``ControlFlowBlock`` isinstance check (not + # ``hasattr(node, "sdfg")``) is required so ``NestedSDFG`` nodes -- + # whose ``.sdfg`` is the *inner* SDFG -- are skipped. Writing the + # outer SDFG into a NestedSDFG's inner-SDFG slot creates a graph + # cycle that infinite-recurses ``all_nodes_recursive``. sdutil.set_nested_sdfg_parent_references(sdfg) for node, parent in sdfg.all_nodes_recursive(): - if hasattr(node, "sdfg"): + if isinstance(node, ControlFlowBlock): node.sdfg = parent.sdfg diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 7ee9584843..c11ab2d616 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -39,6 +39,69 @@ def _check_range(subset, a, itersym, b, step): return found +def _nested_writes_iter_indexed(nsdfg_node, conn, itersym, a, b, step) -> bool: + """Whether every write to ``conn``'s array *inside* ``nsdfg_node`` is + indexed by the (mapped) iteration variable. + + A loop body that is a ``NestedSDFG`` propagates a whole-array external + write memlet (the union over the loop), which hides a per-iteration + write. This looks past the connector: the inner write subsets are + rewritten through the node's ``symbol_mapping`` into the outer iteration + symbol and each must match the same ``a*i+b`` pattern + :func:`_check_range` enforces. Conservative: requires at least one inner + write to the array and that *all* of them pass (nested ``NestedSDFG`` s + are checked recursively, composing the symbol maps). + + Example -- ``for i: if c: b[i] = a[i] + 1`` after a + ``LoopToMap -> MapToForLoop`` round-trip (the guard forced a + ``NestedSDFG`` body):: + + for i in 0:N: # the loop being re-parallelized + state: + a ──► [ NestedSDFG loop_body ] ──► b + symbol_mapping {i: i, N: N} + │ external write connector memlet: b[0:N] + │ (correct union over the loop -- has no `i`, + │ so _check_range(b[0:N]) FAILS -> refuse) + └─ inner: if (c): + b[i] = a[i] + 1.0 ◄── real per- + iteration write + + _nested_writes_iter_indexed walks inside loop_body, finds the inner + write ``b[i]``, maps it through symbol_mapping ({i: i}) to the outer + ``b[i]``, and _check_range matches ``1*i + 0`` -> independence + proven -> LoopToMap fires (the round-trip recovers the map). + + :param nsdfg_node: The ``NestedSDFG`` node feeding the outer write. + :param conn: The output connector (== inner array name) being written. + :param itersym: The outer loop iteration symbol. + :returns: ``True`` iff every inner write to ``conn`` is iter-indexed. + """ + repl = {symbolic.symbol(k): symbolic.pystr_to_symbolic(str(v)) for k, v in nsdfg_node.symbol_mapping.items()} + found = False + for state in nsdfg_node.sdfg.all_states(): + for dn in state.data_nodes(): + if dn.data != conn or state.in_degree(dn) == 0: + continue + for e in state.in_edges(dn): + if e.data is None or e.data.wcr is not None: + return False + if isinstance(e.src, nodes.NestedSDFG): + if not _nested_writes_iter_indexed(e.src, e.src_conn, itersym, a, b, step): + return False + found = True + continue + dst_subset = e.data.get_dst_subset(e, state) + if dst_subset is None: + return False + outer = copy.deepcopy(dst_subset) + outer.replace(repl) + if not _check_range(outer, a, itersym, b, step): + return False + found = True + return found + + def _dependent_indices(itervar: str, subset: subsets.Subset) -> Set[int]: """ Finds the indices or ranges of a subset that depend on the iteration variable. Returns their index in the subset's indices/ranges list. @@ -56,6 +119,72 @@ def _sanitize_by_index(indices: Set[int], subset: subsets.Subset) -> subsets.Ran return subsets.Range([t for i, t in enumerate(subset.ndrange()) if i in indices]) +def _affine_coeffs(expr, itersym): + """ Return ``(a, b)`` with ``expr == a*itersym + b``, or ``None`` if + ``expr`` is not affine in ``itersym``. + """ + e = sp.expand(symbolic.pystr_to_symbolic(expr)) + a = e.coeff(itersym, 1) + b = e.coeff(itersym, 0) + if sp.simplify(e - (a * itersym + b)) != 0: + return None + return a, b + + +def _dim_provably_disjoint(idx1, idx2, itersym) -> bool: + """ True iff ``idx1`` at any iteration can never equal ``idx2`` at any + iteration, for any integer iterations and any loop bounds. + + Uses the linear-Diophantine solvability criterion: ``a1*i1 + b1 == + a2*i2 + b2`` has an integer solution iff ``gcd(a1, a2)`` divides + ``b2 - b1``. If it does not, the accesses never alias. + """ + f1 = _affine_coeffs(idx1, itersym) + f2 = _affine_coeffs(idx2, itersym) + if f1 is None or f2 is None: + return False + a1, b1 = f1 + a2, b2 = f2 + diff = sp.simplify(b2 - b1) + if not (a1.is_Integer and a2.is_Integer): + return False + if a1 == 0 and a2 == 0: + return diff.is_number and diff != 0 + g = sp.igcd(int(a1), int(a2)) + if g == 0: + return diff.is_number and diff != 0 + if not diff.is_number: + return False + if not diff.is_Integer: + return True + return sp.Integer(diff) % g != 0 + + +def _writes_may_overlap(m1: memlet.Memlet, m2: memlet.Memlet, itersym) -> bool: + """ Conservatively decide whether two write memlets to the same container + can address the same element on different loop iterations. Returns + ``False`` only if some subset dimension is provably disjoint (the + multidimensional element can then never coincide). + """ + nd1 = list(m1.subset.ndrange()) + nd2 = list(m2.subset.ndrange()) + if len(nd1) != len(nd2): + return True + for (b1, e1, _), (b2, e2, _) in zip(nd1, nd2): + if b1 != e1 or b2 != e2: # non-point range dimension: cannot decide here + continue + # Both writes index this dimension by the same injective function of the + # iteration variable: a collision there forces the two iterations equal, + # so the writes can only coincide within one iteration (ordered by program + # order in the map body), never across distinct iterations. + coeffs = _affine_coeffs(b1, itersym) + if coeffs is not None and coeffs[0] != 0 and sp.simplify(b1 - b2) == 0: + return False + if _dim_provably_disjoint(b1, b2, itersym): + return False + return True + + @properties.make_properties @xf.explicit_cf_compatible class LoopToMap(xf.MultiStateTransformation): @@ -71,13 +200,19 @@ def expressions(cls): return [sdutil.node_path_graph(cls.loop)] def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + + def refuse(reason: str) -> bool: + """Refuse the match. The reason is dropped; transformation diagnostics live + in the upstream pipeline driver, not in transformation bodies.""" + return False + # If loop information cannot be determined, fail. start = loop_analysis.get_init_assignment(self.loop) end = loop_analysis.get_loop_end(self.loop) step = loop_analysis.get_loop_stride(self.loop) itervar = self.loop.loop_variable if start is None or end is None or step is None or itervar is None: - return False + return refuse(f"loop information incomplete - start={start}, end={end}, step={step}, itervar={itervar}") sset = {} sset.update(sdfg.symbols) @@ -86,17 +221,37 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # We may only convert something to map if the bounds are all integer-derived types. Otherwise most map schedules # except for sequential would be invalid. if not t in dtypes.INTEGER_TYPES: - return False + return refuse(f"loop bounds are not integer types - result_type={t}") # Loops containing break, continue, or returns may not be turned into a map. for blk in self.loop.all_control_flow_blocks(): if isinstance(blk, (BreakBlock, ContinueBlock, ReturnBlock)): - return False + if not permissive: + return refuse(f"loop body contains a {type(blk).__name__}") # We cannot handle symbols read from data containers unless they are scalar. for expr in (start, end, step): if symbolic.contains_sympy_functions(expr): - return False + return refuse(f"bound expression reads a non-scalar data container - expr={expr}") + + # Refuse when the loop's range (start/end/step) references a symbol + # that the loop body itself defines via an interstate-edge + # assignment. After conversion the body moves into a new + # ``loop_body`` NestedSDFG and the assignment goes with it, but the + # Map's range stays at the outer scope; the range then references a + # symbol only defined inside the new NSDFG, producing a + # ``Missing symbols on nested SDFG`` validation failure downstream. + range_syms: Set[str] = set() + for expr in (start, end, step): + try: + range_syms |= {str(s) for s in expr.free_symbols} + except AttributeError: + pass + body_assigned_syms: Set[str] = set() + for e in self.loop.all_interstate_edges(): + body_assigned_syms.update(e.data.assignments.keys()) + if range_syms & body_assigned_syms: + return refuse(f"loop range references symbol(s) {range_syms & body_assigned_syms} assigned inside the body") _, write_set = self.loop.read_and_write_sets() loop_states = set(self.loop.all_states()) @@ -105,7 +260,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Cannot have StructView in loop body for loop_state in loop_states: if [n for n in loop_state.data_nodes() if isinstance(n.desc(sdfg), dt.StructureView)]: - return False + return refuse(f"loop body contains a StructureView in state {loop_state}") # Collect symbol reads and writes from inter-state assignments in_order_loop_blocks = list( @@ -129,7 +284,8 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): if not k in fsyms: assigned_symbols.add(k) if assigned_symbols & used_before_assignment: - return False + return refuse("carried symbol dependency - " + f"{assigned_symbols & used_before_assignment} read before being assigned") symbols_that_may_be_used |= e.data.assignments.keys() @@ -166,7 +322,8 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # cannot race with another iteration's write. dst_subset = e.data.get_dst_subset(e, state) if not (dst_subset and _check_range(dst_subset, a, itersym, b, step)): - return False + return refuse(f"dynamic write to {dn.data} is not indexed by the iteration variable " + f"- dst_subset={dst_subset}") if e.data is None: continue @@ -176,12 +333,37 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # variable. The iteration variable must be used. if e.data.wcr is None: dst_subset = e.data.get_dst_subset(e, state) - if not (dst_subset and _check_range(dst_subset, a, itersym, b, step)) and not permissive: - return False + ok = bool(dst_subset) and _check_range(dst_subset, a, itersym, b, step) + # A NestedSDFG loop body propagates a whole-array + # external write memlet that hides an inner + # per-iteration write; look past the connector. + if not ok and isinstance(e.src, nodes.NestedSDFG): + ok = _nested_writes_iter_indexed(e.src, e.src_conn, itersym, a, b, step) + if not ok and not permissive: + return refuse(f"write to {dn.data} is not uniquely indexed by the iteration variable " + f"(needs an a*i+b subset) - dst_subset={dst_subset}") # End of check write_memlets[dn.data].append(e.data) + # Two writes with distinct affine subscripts into the same container can + # hit the same element on different iterations even when each is + # individually injective in the iteration variable (e.g. ``A[5*i]`` and + # ``A[3*i]`` collide at ``A[15]``). Parallelizing then reorders the + # colliding writes. Allow the pair only if some dimension is provably + # disjoint for all iterations (e.g. ``A[2*i]`` vs ``A[2*i+1]``). + for data, mmlts in write_memlets.items(): + distinct: Dict[str, memlet.Memlet] = {} + for m in mmlts: + if m.wcr is None: + distinct.setdefault(str(m.subset), m) + reps = list(distinct.values()) + for x in range(len(reps)): + for y in range(x + 1, len(reps)): + if _writes_may_overlap(reps[x], reps[y], itersym) and not permissive: + return refuse(f"writes {reps[x].subset} and {reps[y].subset} to {data} " + "may overlap across iterations") + # After looping over relevant writes, consider reads that may overlap for state in loop_states: for dn in state.data_nodes(): @@ -198,7 +380,8 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): src_subset = e.data.get_src_subset(e, state) if not self.test_read_memlet(sdfg, state, e, itersym, itervar, start, end, step, write_memlets, e.data, src_subset): - return False + return refuse(f"read-after-write conflict on {data} within the loop body " + f"- src_subset={src_subset}") # Consider reads in inter-state edges (could be in assignments or in condition) isread_set: Set[memlet.Memlet] = set() @@ -208,7 +391,8 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): if mmlt.data in write_memlets: if not self.test_read_memlet(sdfg, None, None, itersym, itervar, start, end, step, write_memlets, mmlt, mmlt.subset): - return False + return refuse(f"read-after-write conflict on {mmlt.data} via an inter-state edge " + f"- subset={mmlt.subset}") # Check that the iteration variable and other symbols are not used on other edges or blocks before they are # reassigned. @@ -218,7 +402,8 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): reassigned_symbols: Set[str] = None for oe in graph.out_edges(self.loop): if symbols_that_may_be_used & oe.data.read_symbols(): - return False + return refuse("loop-defined symbol(s) used after the loop on its outgoing edge - " + f"{symbols_that_may_be_used & oe.data.read_symbols()}") # Check for symbols that are set by all outgoing edges # TODO: Handle case of subset of out_edges if reassigned_symbols is None: @@ -238,13 +423,15 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Check state contents if symbols_that_may_be_used & block.free_symbols: - return False + return refuse(f"loop-defined symbol(s) used after the loop in block {block} - " + f"{symbols_that_may_be_used & block.free_symbols}") # Check inter-state edges reassigned_symbols = None for e in block.parent_graph.out_edges(block): if symbols_that_may_be_used & e.data.read_symbols(): - return False + return refuse("loop-defined symbol(s) used after the loop on an inter-state edge - " + f"{symbols_that_may_be_used & e.data.read_symbols()}") # Check for symbols that are set by all outgoing edges # TODO: Handle case of subset of out_edges @@ -274,7 +461,17 @@ def test_read_memlet(self, sdfg: SDFG, state: SDFGState, edge: gr.MultiConnector # If pointers are involved, give up return False if not _check_range(src_subset, a, itersym, b, step): - return False + # ``_check_range`` only accepts reads that MOVE with the iteration + # (some dimension ``a*i + b``, ``|a| >= 1``). A read that uses the + # iteration symbol but does not match that affine form is + # conservatively a conflict. But a loop-INVARIANT read (no iteration + # symbol at all) is only a conflict if it actually overlaps a write: + # ``a[0]`` is safe when the loop writes ``a[1:N]`` (the post-peel + # ``a[i] = a[0] + b[i]`` remainder), and is a real read-after-write + # only when it overlaps the write (``a[0]`` vs ``a[0:N]``). Defer + # both to the propagated-overlap check below. + if itersym in src_subset.free_symbols: + return False # Always use the source data container for the memlet test if state is not None and edge is not None: @@ -477,7 +674,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): for s, m in sdfg.parent_nsdfg_node.symbol_mapping.items(): if s not in cnode.symbol_mapping: cnode.symbol_mapping[s] = symbolic.pystr_to_symbolic(s) - nsdfg.add_symbol(s, sdfg.symbols[s]) + nsdfg.symbols[s] = sdfg.symbols[s] for name in read_set: r = body.add_read(name) body.add_edge(r, None, cnode, name, memlet.Memlet.from_array(name, sdfg.arrays[name])) diff --git a/dace/transformation/passes/symbol_propagation.py b/dace/transformation/passes/symbol_propagation.py index 5a4856707b..2eecc672f0 100644 --- a/dace/transformation/passes/symbol_propagation.py +++ b/dace/transformation/passes/symbol_propagation.py @@ -15,6 +15,58 @@ from dace.symbolic import pystr_to_symbolic, scalars +def _free_symbols(value) -> Set[str]: + """Free symbol names of an interstate-edge assignment value (RHS). + + :param value: The assignment RHS (a string), or ``None``. + :returns: The set of free symbol names; empty for ``None`` / unparseable. + """ + if value is None: + return set() + try: + return {str(s) for s in pystr_to_symbolic(value).free_symbols} + except Exception: + return set() + + +def _resolve(value, table: Dict[str, Any]): + """Substitute known symbol values from ``table`` into an assignment RHS. + + Interstate-edge assignments are simultaneous, so the RHSes of one edge are + resolved against the PRE-edge (incoming) symbol table -- a swap + ``{tx: y, ty: x}`` resolves ``tx`` to the old value of ``y`` and ``ty`` to + the old value of ``x``. Resolving here (rather than leaving raw + ``tx: 'y'`` strings to be chained later) collapses symbol-to-symbol chains + to constants/expressions up front, so a cyclic dependency (``x: tx, + tx: y, y: ty, ty: x``) never forms a substitution cycle. + + :param value: The assignment RHS (a string), or ``None``. + :param table: Known ``{symbol: value-string-or-None}`` mapping. + :returns: The resolved RHS string (or ``None`` for ``None`` input). + """ + if value is None: + return None + # Leave array-access values (``tbl[i]``) untouched: parsing them through + # sympy would turn ``tbl[i]`` into ``tbl(i)`` and lose the ``[`` that the + # downstream filter uses to drop non-propagatable nested-array accesses. + if "[" in value or "]" in value: + return value + try: + expr = pystr_to_symbolic(value) + repl = {} + for s in expr.free_symbols: + name = str(s) + known = table.get(name) + # Skip substituting array-access values for the same reason. + if known is not None and "[" not in known and "]" not in known: + repl[s] = pystr_to_symbolic(known) + if repl: + expr = expr.subs(repl) + return str(expr) + except Exception: + return value + + @dataclass(unsafe_hash=True) @properties.make_properties @transformation.explicit_cf_compatible @@ -66,10 +118,111 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[str]]: changed = True out_syms[cfg_blk] = new_out_syms - # Update symbols in the cfg_blk + # Update symbols in the cfg_blk, accumulating the symbols actually + # propagated (eliminated from a block / edge). The pipeline treats a + # non-None return as "this pass modified the SDFG" and a None return as + # "no change" (see ``Pipeline.apply_pass``); returning an honest set + # lets a FixedPointPipeline such as ``SimplifyPass`` converge instead of + # re-running this pass forever on a no-op. + propagated: Set[str] = set() for cfg_blk, parent in all_cfg_blks.items(): - self._update_syms(cfg_blk, parent, in_syms, out_syms) - return set() + propagated |= self._update_syms(cfg_blk, parent, in_syms, out_syms) + # Substitution leaves the *defining* iedge assignment (``k_plus_1 = klev + 1``) + # in place even after every consumer has been rewritten to use the resolved + # value. Sweep those dead assignments to a fixed point so the pass output + # is canonical (e.g. cloudsc's 346 bound-symbol ``+1`` assignments disappear + # once their downstream uses are gone). + eliminated = self._eliminate_dead_iedge_assignments(sdfg) + if eliminated: + propagated |= eliminated + return propagated if propagated else None + + def _eliminate_dead_iedge_assignments(self, sdfg: SDFG) -> Set[str]: + """Drop interstate-edge assignments whose LHS is no longer referenced anywhere. + + After :meth:`_update_syms` rewrites every use site in the dataflow graph, the + defining iedge assignment ``X = expr`` becomes dead -- *unless* ``X`` is still + referenced by an array descriptor's shape / strides / offset. The IR-level + ``replace_dict`` does not reach into descriptors, so those references survive + propagation and pin the iedge alive. Before deciding an iedge is dead, we + substitute ``X -> expr`` into the owning SDFG's descriptors (a semantic no-op + since the symbol IS that expression by construction), then sweep. + + Iterates to a fixed point so chained assignments (``a = klev + 1; b = a + 1``) + unravel from the leaves inward. + + :param sdfg: The SDFG to clean up. + :returns: The set of LHS names that were removed (empty if no change). + """ + removed: Set[str] = set() + while True: + this_round = self._eliminate_round(sdfg) + if not this_round: + break + removed |= this_round + return removed + + def _eliminate_round(self, sdfg: SDFG) -> Set[str]: + """One sweep of dead-iedge elimination across ``sdfg`` and its nested SDFGs. + + Substitutes propagatable iedge LHSes into the SDFG's descriptors first, since + ``SDFGState.free_symbols`` pulls array-shape symbols via the access nodes that + read those arrays (state.py:709). Without the descriptor substitution, the + symbol still reads as "used in IR" and the iedge never gets eliminated. + """ + eliminated: Set[str] = set() + for sd in sdfg.all_sdfgs_recursive(): + # Gather candidate substitutions: a symbol is propagatable if every iedge + # binding it agrees on the same RHS (no per-edge disagreement -> ambiguous) + # and the RHS does not self-reference (a self-reference like ``i = i + 1`` + # marks a loop-carried iter, which cannot be substituted out). + bindings: Dict[str, Optional[str]] = {} + for e in sd.all_interstate_edges(): + for lhs, rhs in e.data.assignments.items(): + if rhs is None or lhs in _free_symbols(rhs): + bindings[lhs] = None + continue + if lhs not in bindings: + bindings[lhs] = rhs + elif bindings[lhs] is not None and bindings[lhs] != rhs: + bindings[lhs] = None + safe_subs = {sym: rhs for sym, rhs in bindings.items() if rhs is not None} + + # Substitute every propagatable LHS into the SDFG's descriptors. Symbols + # whose value was already substituted everywhere will have no live shape + # reference after this; the dead-iedge sweep below then drops them. + if safe_subs: + sd.replace_dict(safe_subs, replace_keys=False, replace_in_graph=False) + + # Now compute the IR-level used set; this no longer includes the symbols + # that have been folded into descriptors. + used_in_ir: Set[str] = set() + for blk in sd.all_control_flow_blocks(): + used_in_ir |= {str(s) for s in blk.free_symbols} + for e in sd.all_interstate_edges(): + for rhs in e.data.assignments.values(): + used_in_ir |= _free_symbols(rhs) + if e.data.condition is not None: + try: + used_in_ir |= {str(s) for s in e.data.condition.get_free_symbols()} + except Exception: + pass + + sd_eliminated: Set[str] = set() + for e in sd.all_interstate_edges(): + for lhs in list(e.data.assignments.keys()): + if lhs not in used_in_ir: + del e.data.assignments[lhs] + sd_eliminated.add(lhs) + # Drop the now-orphaned declarations so nested-SDFG validation does not + # demand the symbol from outside. + if sd_eliminated: + still_bound = {k for ie in sd.all_interstate_edges() for k in ie.data.assignments.keys()} + for name in sd_eliminated: + if (name in sd.symbols and name not in still_bound and name not in used_in_ir): + del sd.symbols[name] + eliminated |= sd_eliminated + return eliminated # Given a cfg_blk, builds the incoming set of symbols def _get_in_syms( @@ -84,7 +237,46 @@ def _get_in_syms( new_in_syms = {} for i, edge in enumerate(parent.in_edges(cfg_blk)): sym_table = copy.deepcopy(out_syms[edge.src]) - sym_table.update(edge.data.assignments) + # Resolve this edge's RHSes against the PRE-edge table (simultaneous + # assignment semantics), then apply -- collapsing symbol chains and + # breaking cyclic dependencies instead of storing raw chained strings. + # A resolved value that references ANY symbol assigned on this SAME + # edge must stay LIVE (None) rather than be propagated: the edge's + # assignments fire simultaneously and rebind those symbols at + # runtime, so the resolved value (computed from the OLD values) + # differs from what a downstream use -- which sees the NEW values -- + # would compute. Propagating it would double-count the rebinding + # (e.g. on ``{m: t, n: t + 1}`` with ``t = m + 2``: ``m`` resolves to + # ``m + 2`` (self-ref) and ``n`` to ``m + 3`` (reads reassigned + # ``m``); both must stay live so ``B[m]`` / ``B[n]`` read the edge's + # outputs, not a re-applied expression). + edge_keys = set(edge.data.assignments.keys()) + # Resolve this edge's RHSes against the (un-invalidated) PRE-edge + # table: the assignments read their incoming values, and a carried + # value such as ``k = j + 1`` still denotes the pre-edge ``j`` here. + resolved = {} + for k, v in edge.data.assignments.items(): + rv = _resolve(v, sym_table) + if rv is not None and (_free_symbols(rv) & edge_keys): + rv = None + resolved[k] = rv + # A value CARRIED IN from the predecessor (already in ``sym_table``, + # not (re)assigned on this edge) that references a symbol this edge + # reassigns is now STALE for the downstream block: it was computed + # from that symbol's pre-edge value, but the edge rebinds the + # symbol, so the block past this edge sees the NEW value. + # Propagating the carried value would read the wrong (old) value + # (e.g. carrying ``k = j + 1`` past an edge ``j = j + 2`` would make + # ``c[k]`` read ``c[j + 1]`` against the reassigned ``j``, an + # off-by-two). Invalidate such entries (-> None / live) -- the same + # simultaneity rule the per-edge guard above applies to the edge's + # own assignments. Done AFTER resolving so the edge's own RHSes still + # see the carried pre-edge values. + for sym in list(sym_table.keys()): + val = sym_table[sym] + if sym not in edge_keys and val is not None and (_free_symbols(val) & edge_keys): + sym_table[sym] = None + sym_table.update(resolved) # Filter out symbols containing arrays accesses as they cannot be safely propagated (nested array accesses are not supported) sym_table = {k: v for k, v in sym_table.items() if v is None or ("[" not in v and "]" not in v)} @@ -111,8 +303,16 @@ def _get_in_syms( # Ignore SDFGs as nested SDFGs have symbol mappings if (parent.start_block == cfg_blk and not isinstance(parent, SDFG)) or (isinstance(parent, ConditionalBlock) and cfg_blk in parent.sub_regions()): - assert new_in_syms == {} - new_in_syms = in_syms[parent] + # A start / branch region normally has no in-edges, so the + # edge-accumulated table is empty and it inherits the parent's + # incoming symbols. On some cross-CFG shapes the block can already + # carry edge-accumulated symbols; combine conservatively + # (disagreements -> None) rather than assert emptiness, which + # crashed on those shapes. + if new_in_syms: + self._combine_syms(new_in_syms, in_syms[parent]) + else: + new_in_syms = in_syms[parent] # For LoopRegions, remove loop carried variables from the incoming symbols if isinstance(parent, LoopRegion): @@ -169,6 +369,17 @@ def _get_out_syms( self._combine_syms(new_out_syms, out_syms[n]) return new_out_syms + def _block_free_symbols(self, cfg_blk: ControlFlowBlock, parent: ControlFlowRegion) -> Set[str]: + """Names of symbols read by ``cfg_blk`` and by its outgoing edges. + + :param cfg_blk: The block to inspect. + :param parent: The block's parent region (for its out-edges). + :returns: The set of free-symbol names. + """ + free = {str(s) for s in cfg_blk.free_symbols} + free |= {str(s) for edge in parent.out_edges(cfg_blk) for s in edge.data.free_symbols} + return free + # Given a cfg_blk, updates the symbols in the cfg_blk def _update_syms( self, @@ -176,7 +387,7 @@ def _update_syms( parent: ControlFlowRegion, in_syms: Dict[ControlFlowBlock, Dict[str, Any]], out_syms: Dict[ControlFlowBlock, Dict[str, Any]], - ) -> None: + ) -> Set[str]: new_in_syms = copy.deepcopy(in_syms[cfg_blk]) new_out_syms = copy.deepcopy(out_syms[cfg_blk]) @@ -184,8 +395,26 @@ def _update_syms( new_in_syms = {sym: val for sym, val in new_in_syms.items() if val is not None} new_out_syms = {sym: val for sym, val in new_out_syms.items() if val is not None} + # Symbols this block could propagate, and the symbols it reads before + # substitution -- their set difference after substitution is what was + # actually eliminated (returned so the pipeline knows what changed). + candidates = set(new_in_syms) | set(new_out_syms) + if not candidates: + return set() + free_before = self._block_free_symbols(cfg_blk, parent) + + # Iteration cap: each pass resolves at least one more substitution + # level, so a legitimate (acyclic) substitution chain converges within + # ``#symbols`` passes. A CYCLIC value dependency (e.g. a swap + # ``x: tx, tx: y, y: ty, ty: x``) would otherwise oscillate the free- + # symbol set forever; the cap guarantees termination (leaving the + # cyclic symbols un-substituted, which is conservative and correct). + max_iters = len(new_in_syms) + len(new_out_syms) + 2 + changed = True - while changed: + iters = 0 + while changed and iters < max_iters: + iters += 1 changed = False free_sym = cfg_blk.free_symbols free_edge_sym = set([sym for edge in parent.out_edges(cfg_blk) for sym in edge.data.free_symbols]) @@ -201,15 +430,31 @@ def _update_syms( # Don't replace, as the nested CFBGs should inherit the symbols from their parent pass - # Also replace all symbols in the outgoing edges with their values + # Also replace all symbols in the outgoing edges with their values. + # Interstate-edge assignments are SIMULTANEOUS: a symbol read in an + # assignment RHS denotes its INCOMING value, not the value being + # assigned on the same edge. Substituting a propagated value that + # references a symbol which is itself a KEY on this edge would make + # the RHS read the edge's outgoing value -- a same-edge read-write + # race that validation rejects (e.g. substituting ``anext -> a + b`` + # into ``{b: a, a: anext}`` yields ``{b: a, a: a + b}``). Drop such + # colliding substitutions for that edge. for edge in parent.out_edges(cfg_blk): - edge.data.replace_dict(new_out_syms, replace_keys=False) + edge_keys = set(edge.data.assignments.keys()) + if edge_keys: + edge_subs = {s: v for s, v in new_out_syms.items() if not (_free_symbols(v) & edge_keys)} + else: + edge_subs = new_out_syms + edge.data.replace_dict(edge_subs, replace_keys=False) # Check if the symbols have changed new_free_edge_sym = set([sym for edge in parent.out_edges(cfg_blk) for sym in edge.data.free_symbols]) if free_sym != cfg_blk.free_symbols or free_edge_sym != new_free_edge_sym: changed = True + # The candidate symbols that are no longer read here were propagated. + return candidates & (free_before - self._block_free_symbols(cfg_blk, parent)) + # Combines two symbol dictionaries, setting the value to None if they don't agree. Directly modifies sym1 def _combine_syms(self, sym1: Dict[str, Any], sym2: Dict[str, Any]) -> None: for sym, val in sym2.items(): diff --git a/tests/passes/symbol_propagation_hard_test.py b/tests/passes/symbol_propagation_hard_test.py new file mode 100644 index 0000000000..38d0d57417 --- /dev/null +++ b/tests/passes/symbol_propagation_hard_test.py @@ -0,0 +1,1767 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. +""" +Hard / adversarial unit tests for the :class:`SymbolPropagation` pass. + +The pass propagates symbols that were assigned to a single value forward through the +SDFG, substituting them into downstream blocks and edges to reduce the symbol count. +Its load-bearing assumption is that symbols only change on ``InterStateEdge`` +assignments. These tests stress that assumption with the patterns that tend to break +big real-world SDFGs: + +* chained inter-dependent symbols feeding array indices, +* conditional (branch-divergent) symbol values feeding indirection, +* interstate-edge conditions that themselves read a propagated symbol, +* loop-carried index symbols (in both ``LoopRegion`` loops and ``dace.map`` maps), +* double indirection / gather (a symbol read out of an array, then used as an index), +* the same symbol name reused with different values in sibling scopes. + +Every test builds a small deterministic SDFG, computes the reference result *before* +the pass, applies :meth:`SymbolPropagation.apply_pass`, asserts the SDFG is still valid, +re-runs it, and checks the results match. A test that exposes a genuine pass limitation +is marked :func:`pytest.mark.xfail` with a precise reason. +""" + +import numpy as np +import pytest + +import dace +from dace.properties import CodeBlock +from dace.sdfg.state import LoopRegion, ConditionalBlock, ControlFlowRegion +from dace.transformation.passes import SymbolPropagation + +# --------------------------------------------------------------------------- +# Python-frontend kernels (must be module-level: the frontend reads source). +# --------------------------------------------------------------------------- + + +@dace.program +def chained_index_range(B: dace.float64[64], C: dace.float64[1], idx: dace.int64): + """ + Chained inter-dependent index symbols feeding an array access (range form). + + ``idx2 = idx + 1; idx3 = idx2 + 1; C = B[idx3]`` expressed so the symbols + chain across interstate edges. + + :param B: Source array to gather from. + :param C: One-element output. + :param idx: Base index symbol. + """ + idx2 = idx + 1 + idx3 = idx2 + 1 + C[0] = B[idx3] + + +@dace.program +def chained_index_deep(B: dace.float64[64], C: dace.float64[1], idx: dace.int64): + """ + Longer chain of inter-dependent index symbols feeding an array access. + + :param B: Source array to gather from. + :param C: One-element output. + :param idx: Base index symbol. + """ + a = idx + 1 + b = a + 2 + c = b - 1 + d = c + a + C[0] = B[d] + + +@dace.program +def cond_index_diverge(A: dace.int64[1], B: dace.float64[64], C: dace.float64[1], idx: dace.int64): + """ + Conditional symbol assignment feeding indirection (the divergent case). + + ``idx3`` takes a different value on each branch, so no single value may be + propagated past the join point. + + :param A: One-element selector array (branch taken on ``A[0] > 0``). + :param B: Source array to gather from. + :param C: One-element output. + :param idx: Base index symbol. + """ + idx2 = idx + 1 + if A[0] > 0: + idx3 = idx2 + 1 + C[0] = B[idx3] + else: + idx3 = idx2 + 4 + C[0] = B[idx3] + + +@dace.program +def cond_index_diverge_join(A: dace.int64[1], B: dace.float64[64], C: dace.float64[1], idx: dace.int64): + """ + Branch-divergent symbol that is *used after* the join, not inside the branches. + + The use of ``idx3`` happens after the if/else merge, so the pass must not + have propagated either branch's value. + + :param A: One-element selector array. + :param B: Source array to gather from. + :param C: One-element output. + :param idx: Base index symbol. + """ + idx2 = idx + 1 + if A[0] > 0: + idx3 = idx2 + 1 + else: + idx3 = idx2 + 4 + C[0] = B[idx3] + + +@dace.program +def nested_cond_index(A: dace.int64[1], D: dace.int64[1], B: dace.float64[64], C: dace.float64[1], idx: dace.int64): + """ + Nested conditionals (``if A: if D: ...``) each diverging an index symbol. + + :param A: Outer selector array. + :param D: Inner selector array. + :param B: Source array to gather from. + :param C: One-element output. + :param idx: Base index symbol. + """ + idx2 = idx + 1 + if A[0] > 0: + if D[0] > 0: + idx3 = idx2 + 1 + else: + idx3 = idx2 + 2 + C[0] = B[idx3] + else: + idx3 = idx2 + 8 + C[0] = B[idx3] + + +@dace.program +def loop_carried_range(B: dace.float64[16], C: dace.float64[16], step: dace.int64): + """ + Loop-carried index symbol updated each iteration (range form). + + :param B: Source array (length >= 16). + :param C: Output array. + :param step: Constant increment applied to the carried index each iteration. + """ + idx = 0 + for i in range(16): + C[i] = B[idx % 16] + idx = idx + step + + +@dace.program +def map_chained_index(B: dace.float64[16, 64], C: dace.float64[16], idx: dace.int64): + """ + Chained inter-dependent index symbols inside a ``dace.map`` body. + + :param B: 2D source array. + :param C: Output, one element per map iteration. + :param idx: Base index symbol shared by all lanes. + """ + idx2 = idx + 1 + idx3 = idx2 + 1 + for i in dace.map[0:16]: + C[i] = B[i, idx3] + + +@dace.program +def sibling_scopes_reuse(B: dace.float64[64], C: dace.float64[2], idx: dace.int64): + """ + The same symbol name reused with different values in two sibling scopes. + + Two independent (sequential) range loops each define ``k`` from a different + expression; there must be no cross-contamination of propagated values. + + :param B: Source array. + :param C: Two-element output (one per sibling scope). + :param idx: Base index symbol. + """ + for _ in range(1): + k = idx + 1 + C[0] = B[k] + for _ in range(1): + k = idx + 5 + C[1] = B[k] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _all_assignment_values(sdfg: dace.SDFG) -> list: + """ + Collects every interstate-edge assignment value (RHS) in the SDFG, recursively. + + :param sdfg: The SDFG to scan. + :returns: A list of assignment RHS strings. + """ + vals = [] + for edge, _ in sdfg.all_edges_recursive(): + data = getattr(edge, "data", None) + if data is not None and hasattr(data, "assignments"): + vals.extend(str(v) for v in data.assignments.values()) + return vals + + +# --------------------------------------------------------------------------- +# Pattern 1: chained inter-dependent index symbols (value-preserving) +# --------------------------------------------------------------------------- + + +def test_chained_index_range_frontend(): + """Chained ``idx -> idx2 -> idx3`` indices via the Python frontend (range).""" + rng = np.random.default_rng(1) + B = rng.random(64) + sdfg = chained_index_range.to_sdfg(simplify=False) + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for base in (0, 5, 30): + expected = np.array([B[base + 2]]) + got = np.zeros(1) + sdfg(B=B.copy(), C=got, idx=base) + assert np.allclose(got, expected) + + +def test_chained_index_deep_frontend(): + """A 4-link symbol chain feeding an index; value must be unchanged by the pass.""" + rng = np.random.default_rng(2) + B = rng.random(64) + sdfg = chained_index_deep.to_sdfg(simplify=False) + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for base in (0, 3, 10): + # a=base+1, b=a+2, c=b-1, d=c+a -> d = (base+2) + (base+1) = 2*base + 3 + expected = np.array([B[2 * base + 3]]) + got = np.zeros(1) + sdfg(B=B.copy(), C=got, idx=base) + assert np.allclose(got, expected) + + +def test_chained_index_api(): + """ + SDFG-API build of a chained-index gather, comparing pre- vs post-pass results. + + ``i2 = i1 + 1; i3 = i2 + 2; out = B[i3]`` over plain states. + """ + sdfg = dace.SDFG("chained_api") + sdfg.add_array("B", [64], dace.float64) + sdfg.add_array("C", [1], dace.float64) + sdfg.add_symbol("i1", dace.int64) + sdfg.add_symbol("i2", dace.int64) + sdfg.add_symbol("i3", dace.int64) + + s0 = sdfg.add_state("s0", is_start_block=True) + s1 = sdfg.add_state("s1") + s2 = sdfg.add_state("s2") + sdfg.add_edge(s0, s1, dace.InterstateEdge(assignments={"i2": "i1 + 1"})) + sdfg.add_edge(s1, s2, dace.InterstateEdge(assignments={"i3": "i2 + 2"})) + tasklet = s2.add_tasklet("g", {"inp"}, {"out"}, "out = inp") + rd = s2.add_access("B") + wr = s2.add_access("C") + s2.add_edge(rd, None, tasklet, "inp", dace.Memlet("B[i3]")) + s2.add_edge(tasklet, "out", wr, None, dace.Memlet("C[0]")) + sdfg.validate() + + rng = np.random.default_rng(3) + B = rng.random(64) + expected = {} + for base in (0, 4, 20): + out = np.zeros(1) + sdfg(B=B.copy(), C=out, i1=base) + expected[base] = out.copy() + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for base in (0, 4, 20): + out = np.zeros(1) + sdfg(B=B.copy(), C=out, i1=base) + assert np.allclose(out, expected[base]) + + +# --------------------------------------------------------------------------- +# Pattern 2: conditional (branch-divergent) symbol feeding indirection +# --------------------------------------------------------------------------- + + +def test_cond_index_diverge_frontend(): + """Different ``idx3`` per branch, each used *inside* its branch (frontend).""" + rng = np.random.default_rng(10) + B = rng.random(64) + sdfg = cond_index_diverge.to_sdfg(simplify=False) + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for sel, base in ((1, 5), (0, 5), (1, 30), (0, 10)): + A = np.array([sel], dtype=np.int64) + # idx3 = idx2 + (1 if taken else 4), idx2 = idx + 1. + offset = 1 if sel > 0 else 4 + expected = np.array([B[base + 1 + offset]]) + got = np.zeros(1) + sdfg(A=A, B=B.copy(), C=got, idx=base) + assert np.allclose(got, expected) + + +def test_cond_index_diverge_join_frontend(): + """ + Divergent ``idx3`` *used after* the join point (frontend). + + This is the canonical failure mode: the pass must not propagate either + branch's value of ``idx3`` past the merge. + """ + rng = np.random.default_rng(11) + B = rng.random(64) + sdfg = cond_index_diverge_join.to_sdfg(simplify=False) + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for sel, base in ((1, 7), (0, 7), (1, 25), (0, 3)): + A = np.array([sel], dtype=np.int64) + offset = 1 if sel > 0 else 4 + expected = np.array([B[base + 1 + offset]]) + got = np.zeros(1) + sdfg(A=A, B=B.copy(), C=got, idx=base) + assert np.allclose(got, expected) + + +def test_nested_cond_index_frontend(): + """Nested ``if A: if D:`` each diverging the index symbol (frontend).""" + rng = np.random.default_rng(12) + B = rng.random(64) + sdfg = nested_cond_index.to_sdfg(simplify=False) + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for a_sel, d_sel, base in ((1, 1, 5), (1, 0, 5), (0, 1, 5), (0, 0, 12)): + A = np.array([a_sel], dtype=np.int64) + D = np.array([d_sel], dtype=np.int64) + if a_sel > 0: + offset = 1 if d_sel > 0 else 2 + else: + offset = 8 + expected = np.array([B[base + 1 + offset]]) + got = np.zeros(1) + sdfg(A=A, D=D, B=B.copy(), C=got, idx=base) + assert np.allclose(got, expected) + + +def test_cond_index_diverge_join_api_conditionalblock(): + """ + SDFG-API ``ConditionalBlock`` variant of the divergent-join pattern. + + ``v2 = v1 + 1`` upstream; the conditional sets ``v3`` differently on each + branch (no else needed: an implicit else means the merge is non-uniform too, + but here both branches assign ``v3`` to distinct values). ``v3`` is used in + a gather *after* the conditional, so its value must remain branch-dependent. + """ + sdfg = dace.SDFG("cond_join_api") + sdfg.add_array("B", [64], dace.float64) + sdfg.add_array("C", [1], dace.float64) + sdfg.add_symbol("v1", dace.int64) + sdfg.add_symbol("v2", dace.int64) + sdfg.add_symbol("v3", dace.int64) + sdfg.add_symbol("sel", dace.int64) + + pre = sdfg.add_state("pre", is_start_block=True) + cond = ConditionalBlock("cond", sdfg) + sdfg.add_node(cond) + # Set v2 on the edge into the conditional; the branch picks v3 from it. + sdfg.add_edge(pre, cond, dace.InterstateEdge(assignments={"v2": "v1 + 1"})) + + then_region = ControlFlowRegion("then", sdfg) + t0 = then_region.add_state("t0", is_start_block=True) + t1 = then_region.add_state("t1") + then_region.add_edge(t0, t1, dace.InterstateEdge(assignments={"v3": "v2 + 1"})) + cond.add_branch(CodeBlock("sel > 0"), then_region) + + else_region = ControlFlowRegion("else", sdfg) + e0 = else_region.add_state("e0", is_start_block=True) + e1 = else_region.add_state("e1") + else_region.add_edge(e0, e1, dace.InterstateEdge(assignments={"v3": "v2 + 4"})) + cond.add_branch(None, else_region) + + post = sdfg.add_state("post") + sdfg.add_edge(cond, post, dace.InterstateEdge()) + tasklet = post.add_tasklet("g", {"inp"}, {"out"}, "out = inp") + rd = post.add_access("B") + wr = post.add_access("C") + post.add_edge(rd, None, tasklet, "inp", dace.Memlet("B[v3]")) + post.add_edge(tasklet, "out", wr, None, dace.Memlet("C[0]")) + sdfg.validate() + + rng = np.random.default_rng(13) + B = rng.random(64) + + def run(): + results = {} + for sel, base in ((1, 6), (0, 6), (1, 20), (0, 2)): + out = np.zeros(1) + sdfg(B=B.copy(), C=out, v1=base, sel=sel) + results[(sel, base)] = out.copy() + return results + + expected = run() + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + got = run() + for key in expected: + assert np.allclose(got[key], expected[key]), key + + +# --------------------------------------------------------------------------- +# Pattern 3: interstate-edge condition that itself reads a propagated symbol +# --------------------------------------------------------------------------- + + +def test_condition_reads_propagated_symbol(): + """ + An interstate-edge condition ``idx3 > K`` where ``idx3`` is symbol-assigned upstream. + + The pass may substitute ``idx3`` into the condition; the resulting branch + selection must stay correct. + """ + K = 10 + sdfg = dace.SDFG("cond_on_symbol") + sdfg.add_array("B", [64], dace.float64) + sdfg.add_array("C", [1], dace.float64) + sdfg.add_symbol("base", dace.int64) + sdfg.add_symbol("idx2", dace.int64) + sdfg.add_symbol("idx3", dace.int64) + + s0 = sdfg.add_state("s0", is_start_block=True) + guard = sdfg.add_state("guard") + big = sdfg.add_state("big") + small = sdfg.add_state("small") + sink = sdfg.add_state("sink") + + sdfg.add_edge(s0, guard, dace.InterstateEdge(assignments={"idx2": "base + 1"})) + sdfg.add_edge(guard, big, dace.InterstateEdge(assignments={"idx3": "idx2 + 2"}, condition=f"idx2 + 2 > {K}")) + sdfg.add_edge(guard, small, dace.InterstateEdge(assignments={"idx3": "idx2 + 2"}, condition=f"idx2 + 2 <= {K}")) + + # big writes B[idx3], small writes B[0] + tb = big.add_tasklet("g", {"inp"}, {"out"}, "out = inp") + big.add_edge(big.add_access("B"), None, tb, "inp", dace.Memlet("B[idx3]")) + big.add_edge(tb, "out", big.add_access("C"), None, dace.Memlet("C[0]")) + + ts = small.add_tasklet("g", {"inp"}, {"out"}, "out = inp") + small.add_edge(small.add_access("B"), None, ts, "inp", dace.Memlet("B[0]")) + small.add_edge(ts, "out", small.add_access("C"), None, dace.Memlet("C[0]")) + + sdfg.add_edge(big, sink, dace.InterstateEdge()) + sdfg.add_edge(small, sink, dace.InterstateEdge()) + sdfg.validate() + + rng = np.random.default_rng(20) + B = rng.random(64) + + def oracle(base): + idx3 = base + 1 + 2 + return B[idx3] if idx3 > K else B[0] + + bases = (0, 6, 7, 8, 30) + expected = {b: oracle(b) for b in bases} + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for b in bases: + out = np.zeros(1) + sdfg(B=B.copy(), C=out, base=b) + assert np.allclose(out[0], expected[b]), b + + +def test_condition_reads_chained_symbol_loopregion(): + """ + A ``LoopRegion`` whose condition reads a symbol assigned outside the loop. + + ``limit = base + cnt`` is set before the loop and the loop runs ``i < limit``. + The pass may propagate ``limit`` into the loop condition; the trip count must + stay correct. + """ + sdfg = dace.SDFG("cond_loop") + sdfg.add_array("C", [32], dace.float64) + sdfg.add_symbol("base", dace.int64) + sdfg.add_symbol("cnt", dace.int64) + sdfg.add_symbol("limit", dace.int64) + + init = sdfg.add_state("init", is_start_block=True) + pre = sdfg.add_state("pre") + sdfg.add_edge(init, pre, dace.InterstateEdge(assignments={"cnt": "3"})) + loop = LoopRegion("loop", "i < limit", "i", "i = 0", "i = i + 1") + sdfg.add_node(loop) + sdfg.add_edge(pre, loop, dace.InterstateEdge(assignments={"limit": "base + cnt"})) + + body = loop.add_state("body", is_start_block=True) + tk = body.add_tasklet("w", {}, {"out"}, "out = 1.0") + body.add_edge(tk, "out", body.add_access("C"), None, dace.Memlet("C[i]")) + + end = sdfg.add_state("end") + sdfg.add_edge(loop, end, dace.InterstateEdge()) + sdfg.validate() + + def oracle(base): + limit = base + 3 + out = np.zeros(32) + out[:limit] = 1.0 + return out + + bases = (2, 5, 10) + expected = {b: oracle(b) for b in bases} + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for b in bases: + out = np.zeros(32) + sdfg(C=out, base=b) + assert np.allclose(out, expected[b]), b + + +# --------------------------------------------------------------------------- +# Pattern 4: loop-carried index symbols (LoopRegion loop AND dace.map) +# --------------------------------------------------------------------------- + + +def test_loop_carried_range_frontend(): + """Loop-carried ``idx = idx + step`` indexing an array (range / LoopRegion).""" + rng = np.random.default_rng(30) + B = rng.random(16) + sdfg = loop_carried_range.to_sdfg(simplify=False) + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for step in (1, 2, 3): + expected = np.array([B[(i * step) % 16] for i in range(16)]) + got = np.zeros(16) + sdfg(B=B.copy(), C=got, step=step) + assert np.allclose(got, expected), step + + +def test_loop_carried_api_loopregion(): + """ + SDFG-API ``LoopRegion`` with a loop-carried symbol updated on a body edge. + + ``acc`` accumulates ``acc = acc + i`` across iterations and is stored into + ``C[i]``. The carried value must not be replaced by a stale constant. + """ + sdfg = dace.SDFG("carried_api") + sdfg.add_array("C", [16], dace.int64) + sdfg.add_symbol("acc", dace.int64) + + init = sdfg.add_state("init", is_start_block=True) + loop = LoopRegion("loop", "i < 16", "i", "i = 0", "i = i + 1") + sdfg.add_node(loop) + sdfg.add_edge(init, loop, dace.InterstateEdge(assignments={"acc": "0"})) + + body = loop.add_state("body", is_start_block=True) + upd = loop.add_state("upd") + tk = body.add_tasklet("w", {}, {"out"}, "out = acc") + body.add_edge(tk, "out", body.add_access("C"), None, dace.Memlet("C[i]")) + loop.add_edge(body, upd, dace.InterstateEdge(assignments={"acc": "acc + i"})) + + end = sdfg.add_state("end") + sdfg.add_edge(loop, end, dace.InterstateEdge()) + sdfg.validate() + + expected = np.zeros(16, dtype=np.int64) + acc = 0 + for i in range(16): + expected[i] = acc + acc = acc + i + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + got = np.zeros(16, dtype=np.int64) + sdfg(C=got) + assert np.array_equal(got, expected) + + +def test_map_chained_index_frontend(): + """Chained index symbols feeding a ``dace.map`` body access.""" + rng = np.random.default_rng(31) + B = rng.random((16, 64)) + sdfg = map_chained_index.to_sdfg(simplify=False) + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for base in (0, 5, 40): + expected = B[:, base + 2].copy() + got = np.zeros(16) + sdfg(B=B.copy(), C=got, idx=base) + assert np.allclose(got, expected), base + + +def test_loop_then_map_chained_index_api(): + """ + A ``LoopRegion`` that derives an index symbol, then a map body consuming it. + + Exercises both loop and map in one SDFG: the loop computes ``picked`` (a + loop-carried symbol's final value), then a map writes ``C[i] = B[i, picked]``. + """ + sdfg = dace.SDFG("loop_then_map") + sdfg.add_array("B", [8, 32], dace.float64) + sdfg.add_array("C", [8], dace.float64) + sdfg.add_symbol("picked", dace.int64) + sdfg.add_symbol("colbase", dace.int64) + + init = sdfg.add_state("init", is_start_block=True) + loop = LoopRegion("loop", "i < 3", "i", "i = 0", "i = i + 1") + sdfg.add_node(loop) + sdfg.add_edge(init, loop, dace.InterstateEdge(assignments={"picked": "colbase"})) + lb = loop.add_state("lb", is_start_block=True) + lu = loop.add_state("lu") + loop.add_edge(lb, lu, dace.InterstateEdge(assignments={"picked": "picked + 1"})) + + consume = sdfg.add_state("consume") + sdfg.add_edge(loop, consume, dace.InterstateEdge()) + me, mx = consume.add_map("m", dict(j="0:8")) + tk = consume.add_tasklet("g", {"inp"}, {"out"}, "out = inp") + rd = consume.add_access("B") + wr = consume.add_access("C") + consume.add_memlet_path(rd, me, tk, dst_conn="inp", memlet=dace.Memlet("B[j, picked]")) + consume.add_memlet_path(tk, mx, wr, src_conn="out", memlet=dace.Memlet("C[j]")) + sdfg.validate() + + rng = np.random.default_rng(32) + B = rng.random((8, 32)) + + def oracle(colbase): + picked = colbase + 3 # loop runs i=0,1,2 -> +1 thrice + return B[:, picked].copy() + + bases = (0, 5, 20) + expected = {b: oracle(b) for b in bases} + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for b in bases: + got = np.zeros(8) + sdfg(B=B.copy(), C=got, colbase=b) + assert np.allclose(got, expected[b]), b + + +# --------------------------------------------------------------------------- +# Pattern 5: double indirection / gather (symbol read from array, used as index) +# --------------------------------------------------------------------------- + + +def test_gather_symbol_from_array_api(): + """ + ``j = tbl[i]; out = B[j]`` modeled with an interstate-edge array read. + + The pass explicitly refuses to propagate symbol values containing ``[``/``]`` + (array accesses), so ``j`` stays as a symbol and the gather must be preserved. + """ + sdfg = dace.SDFG("gather_api") + sdfg.add_array("tbl", [8], dace.int64) + sdfg.add_array("B", [64], dace.float64) + sdfg.add_array("C", [1], dace.float64) + sdfg.add_symbol("i", dace.int64) + sdfg.add_symbol("j", dace.int64) + + s0 = sdfg.add_state("s0", is_start_block=True) + s1 = sdfg.add_state("s1") + sdfg.add_edge(s0, s1, dace.InterstateEdge(assignments={"j": "tbl[i]"})) + tk = s1.add_tasklet("g", {"inp"}, {"out"}, "out = inp") + s1.add_edge(s1.add_access("B"), None, tk, "inp", dace.Memlet("B[j]")) + s1.add_edge(tk, "out", s1.add_access("C"), None, dace.Memlet("C[0]")) + sdfg.validate() + + rng = np.random.default_rng(40) + B = rng.random(64) + tbl = rng.integers(0, 64, size=8).astype(np.int64) + + expected = {} + for i in range(8): + out = np.zeros(1) + sdfg(tbl=tbl.copy(), B=B.copy(), C=out, i=i) + expected[i] = out.copy() + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + # The array-read assignment must survive (not be folded to a constant). + assert any("tbl[" in v for v in _all_assignment_values(sdfg)) + for i in range(8): + out = np.zeros(1) + sdfg(tbl=tbl.copy(), B=B.copy(), C=out, i=i) + assert np.allclose(out, expected[i]), i + + +def test_gather_chained_through_array_and_offset_api(): + """ + Double indirection with an arithmetic symbol layered on the array read. + + ``j = tbl[i]; k = j + off; out = B[k]`` — ``j`` (array-read) is not + propagated, but ``k = j + off`` is a pure symbol expression that *can* be + propagated; results must be unchanged either way. + """ + sdfg = dace.SDFG("gather_offset_api") + sdfg.add_array("tbl", [8], dace.int64) + sdfg.add_array("B", [128], dace.float64) + sdfg.add_array("C", [1], dace.float64) + sdfg.add_symbol("i", dace.int64) + sdfg.add_symbol("off", dace.int64) + sdfg.add_symbol("j", dace.int64) + sdfg.add_symbol("k", dace.int64) + + s0 = sdfg.add_state("s0", is_start_block=True) + s1 = sdfg.add_state("s1") + s2 = sdfg.add_state("s2") + sdfg.add_edge(s0, s1, dace.InterstateEdge(assignments={"j": "tbl[i]"})) + sdfg.add_edge(s1, s2, dace.InterstateEdge(assignments={"k": "j + off"})) + tk = s2.add_tasklet("g", {"inp"}, {"out"}, "out = inp") + s2.add_edge(s2.add_access("B"), None, tk, "inp", dace.Memlet("B[k]")) + s2.add_edge(tk, "out", s2.add_access("C"), None, dace.Memlet("C[0]")) + sdfg.validate() + + rng = np.random.default_rng(41) + B = rng.random(128) + tbl = rng.integers(0, 60, size=8).astype(np.int64) + + expected = {} + for i, off in ((0, 0), (3, 5), (7, 10)): + out = np.zeros(1) + sdfg(tbl=tbl.copy(), B=B.copy(), C=out, i=i, off=off) + expected[(i, off)] = out.copy() + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for i, off in ((0, 0), (3, 5), (7, 10)): + out = np.zeros(1) + sdfg(tbl=tbl.copy(), B=B.copy(), C=out, i=i, off=off) + assert np.allclose(out, expected[(i, off)]), (i, off) + + +def test_gather_per_iteration_loopregion(): + """ + Per-iteration gather inside a ``LoopRegion`` (``j = tbl[i]; C[i] = B[j]``). + + The index symbol is re-read from the table every iteration; the pass must + not hoist or freeze it. + """ + sdfg = dace.SDFG("gather_loop") + sdfg.add_array("tbl", [16], dace.int64) + sdfg.add_array("B", [64], dace.float64) + sdfg.add_array("C", [16], dace.float64) + sdfg.add_symbol("j", dace.int64) + + init = sdfg.add_state("init", is_start_block=True) + loop = LoopRegion("loop", "i < 16", "i", "i = 0", "i = i + 1") + sdfg.add_node(loop) + sdfg.add_edge(init, loop, dace.InterstateEdge()) + + read = loop.add_state("read", is_start_block=True) + use = loop.add_state("use") + loop.add_edge(read, use, dace.InterstateEdge(assignments={"j": "tbl[i]"})) + tk = use.add_tasklet("g", {"inp"}, {"out"}, "out = inp") + use.add_edge(use.add_access("B"), None, tk, "inp", dace.Memlet("B[j]")) + use.add_edge(tk, "out", use.add_access("C"), None, dace.Memlet("C[i]")) + + end = sdfg.add_state("end") + sdfg.add_edge(loop, end, dace.InterstateEdge()) + sdfg.validate() + + rng = np.random.default_rng(42) + B = rng.random(64) + tbl = rng.integers(0, 64, size=16).astype(np.int64) + expected = B[tbl] + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + got = np.zeros(16) + sdfg(tbl=tbl.copy(), B=B.copy(), C=got) + assert np.allclose(got, expected) + + +# --------------------------------------------------------------------------- +# Pattern 6: same symbol reused with different values in sibling scopes +# --------------------------------------------------------------------------- + + +def test_sibling_scopes_reuse_frontend(): + """The same symbol name with distinct values in two sibling range loops.""" + rng = np.random.default_rng(50) + B = rng.random(64) + sdfg = sibling_scopes_reuse.to_sdfg(simplify=False) + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for base in (0, 10, 30): + expected = np.array([B[base + 1], B[base + 5]]) + got = np.zeros(2) + sdfg(B=B.copy(), C=got, idx=base) + assert np.allclose(got, expected), base + + +def test_sibling_scopes_reuse_api(): + """ + SDFG-API: ``k`` defined twice with different values on disjoint sequential paths. + + Two sequential single-state scopes redefine ``k``; the second must win for + the second access and the first for the first access (no contamination). + """ + sdfg = dace.SDFG("siblings_api") + sdfg.add_array("B", [64], dace.float64) + sdfg.add_array("C", [2], dace.float64) + sdfg.add_symbol("base", dace.int64) + sdfg.add_symbol("k", dace.int64) + + s0 = sdfg.add_state("s0", is_start_block=True) + a = sdfg.add_state("a") + mid = sdfg.add_state("mid") + b = sdfg.add_state("b") + + sdfg.add_edge(s0, a, dace.InterstateEdge(assignments={"k": "base + 1"})) + ta = a.add_tasklet("g", {"inp"}, {"out"}, "out = inp") + a.add_edge(a.add_access("B"), None, ta, "inp", dace.Memlet("B[k]")) + a.add_edge(ta, "out", a.add_access("C"), None, dace.Memlet("C[0]")) + + sdfg.add_edge(a, mid, dace.InterstateEdge(assignments={"k": "base + 5"})) + sdfg.add_edge(mid, b, dace.InterstateEdge()) + tb = b.add_tasklet("g", {"inp"}, {"out"}, "out = inp") + b.add_edge(b.add_access("B"), None, tb, "inp", dace.Memlet("B[k]")) + b.add_edge(tb, "out", b.add_access("C"), None, dace.Memlet("C[1]")) + sdfg.validate() + + rng = np.random.default_rng(51) + B = rng.random(64) + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for base in (0, 8, 40): + expected = np.array([B[base + 1], B[base + 5]]) + got = np.zeros(2) + sdfg(B=B.copy(), C=got, base=base) + assert np.allclose(got, expected), base + + +# --------------------------------------------------------------------------- +# Pattern: merge of equal vs. unequal values (the join correctness boundary) +# --------------------------------------------------------------------------- + + +def test_branch_uniform_value_may_propagate_api(): + """ + Both branches assign the *same* value to a symbol used after the join. + + A correct pass is free to propagate the (uniform) value past the merge; the + point of the test is that doing so must not change the result. + """ + sdfg = dace.SDFG("uniform_join") + sdfg.add_array("B", [64], dace.float64) + sdfg.add_array("C", [1], dace.float64) + sdfg.add_symbol("base", dace.int64) + sdfg.add_symbol("v", dace.int64) + sdfg.add_symbol("sel", dace.int64) + + pre = sdfg.add_state("pre", is_start_block=True) + cond = ConditionalBlock("cond", sdfg) + sdfg.add_node(cond) + sdfg.add_edge(pre, cond, dace.InterstateEdge()) + + then_region = ControlFlowRegion("then", sdfg) + t0 = then_region.add_state("t0", is_start_block=True) + t1 = then_region.add_state("t1") + then_region.add_edge(t0, t1, dace.InterstateEdge(assignments={"v": "base + 3"})) + cond.add_branch(CodeBlock("sel > 0"), then_region) + + else_region = ControlFlowRegion("else", sdfg) + e0 = else_region.add_state("e0", is_start_block=True) + e1 = else_region.add_state("e1") + else_region.add_edge(e0, e1, dace.InterstateEdge(assignments={"v": "base + 3"})) + cond.add_branch(None, else_region) + + post = sdfg.add_state("post") + sdfg.add_edge(cond, post, dace.InterstateEdge()) + tk = post.add_tasklet("g", {"inp"}, {"out"}, "out = inp") + post.add_edge(post.add_access("B"), None, tk, "inp", dace.Memlet("B[v]")) + post.add_edge(tk, "out", post.add_access("C"), None, dace.Memlet("C[0]")) + sdfg.validate() + + rng = np.random.default_rng(60) + B = rng.random(64) + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for sel, base in ((1, 4), (0, 4), (1, 20)): + expected = np.array([B[base + 3]]) + got = np.zeros(1) + sdfg(B=B.copy(), C=got, base=base, sel=sel) + assert np.allclose(got, expected), (sel, base) + + +def test_no_else_branch_implicit_merge_api(): + """ + Conditional with only a ``then`` branch (implicit else) feeding a later use. + + The implicit-else path leaves the symbol at its incoming value, so the merged + value past the conditional is non-uniform and must not be propagated. + """ + sdfg = dace.SDFG("implicit_else") + sdfg.add_array("B", [64], dace.float64) + sdfg.add_array("C", [1], dace.float64) + sdfg.add_symbol("base", dace.int64) + sdfg.add_symbol("v", dace.int64) + sdfg.add_symbol("sel", dace.int64) + + pre = sdfg.add_state("pre", is_start_block=True) + cond = ConditionalBlock("cond", sdfg) + sdfg.add_node(cond) + # v defaults to base on the way in; the then-branch overwrites it. + sdfg.add_edge(pre, cond, dace.InterstateEdge(assignments={"v": "base"})) + + then_region = ControlFlowRegion("then", sdfg) + t0 = then_region.add_state("t0", is_start_block=True) + t1 = then_region.add_state("t1") + then_region.add_edge(t0, t1, dace.InterstateEdge(assignments={"v": "base + 7"})) + cond.add_branch(CodeBlock("sel > 0"), then_region) + + post = sdfg.add_state("post") + sdfg.add_edge(cond, post, dace.InterstateEdge()) + tk = post.add_tasklet("g", {"inp"}, {"out"}, "out = inp") + post.add_edge(post.add_access("B"), None, tk, "inp", dace.Memlet("B[v]")) + post.add_edge(tk, "out", post.add_access("C"), None, dace.Memlet("C[0]")) + sdfg.validate() + + rng = np.random.default_rng(61) + B = rng.random(64) + + def oracle(sel, base): + v = base + 7 if sel > 0 else base + return B[v] + + cases = ((1, 4), (0, 4), (1, 20), (0, 30)) + expected = {c: oracle(*c) for c in cases} + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for c in cases: + got = np.zeros(1) + sdfg(B=B.copy(), C=got, base=c[1], sel=c[0]) + assert np.allclose(got[0], expected[c]), c + + +# --------------------------------------------------------------------------- +# Pattern: mutually inter-dependent symbols updated together (loop-carried pair) +# --------------------------------------------------------------------------- + + +def test_interdependent_pair_loop_api(): + """ + Two mutually inter-dependent loop-carried symbols updated on one edge. + + ``a, b`` co-evolve (``a' = a + b``, ``b' = a``); ``a`` indexes the output. + Both must be treated as loop-carried (not propagated as constants). + + Regression for the same-edge race bug: SymbolPropagation must not + substitute ``anext -> a + b`` into the ``{b: a, a: anext}`` edge (that + would make ``a`` both read and written on one edge). Fixed by the + per-edge self-collision guard in ``_update_syms``. + """ + sdfg = dace.SDFG("pair_loop") + sdfg.add_array("C", [10], dace.int64) + sdfg.add_symbol("a", dace.int64) + sdfg.add_symbol("b", dace.int64) + sdfg.add_symbol("anext", dace.int64) + sdfg.add_symbol("bnext", dace.int64) + + init = sdfg.add_state("init", is_start_block=True) + loop = LoopRegion("loop", "i < 10", "i", "i = 0", "i = i + 1") + sdfg.add_node(loop) + sdfg.add_edge(init, loop, dace.InterstateEdge(assignments={"a": "0", "b": "1"})) + + body = loop.add_state("body", is_start_block=True) + upd = loop.add_state("upd") + tk = body.add_tasklet("w", {}, {"out"}, "out = a") + body.add_edge(tk, "out", body.add_access("C"), None, dace.Memlet("C[i]")) + # Capture BOTH new values into temps first (reads only old a, b -- not + # co-assigned), then assign a, b from the temps (reads only anext, bnext -- + # not co-assigned). Both edges are valid simultaneous assignments (no RHS + # reads a key written on the same edge). + mid = loop.add_state("mid") + loop.add_edge(body, mid, dace.InterstateEdge(assignments={"anext": "a + b", "bnext": "a"})) + loop.add_edge(mid, upd, dace.InterstateEdge(assignments={"a": "anext", "b": "bnext"})) + + end = sdfg.add_state("end") + sdfg.add_edge(loop, end, dace.InterstateEdge()) + sdfg.validate() + + expected = np.zeros(10, dtype=np.int64) + a, b = 0, 1 + for i in range(10): + expected[i] = a + anext = a + b + b = a + a = anext + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + got = np.zeros(10, dtype=np.int64) + sdfg(C=got) + assert np.array_equal(got, expected) + + +# =========================================================================== +# APPENDED: same-edge multi-assignment race / ordering hazards. +# +# These target the confirmed defect in ``SymbolPropagation._update_syms``: +# the pass substitutes propagated symbol VALUES into an outgoing interstate +# edge's assignment RHSes without checking whether that edge's own assignment +# LHS keys collide with the substitution's free symbols. When they do, a single +# variable becomes both read and written on the same (simultaneous-assignment) +# edge, which ``sdfg.validate()`` rejects as a race condition. Adjacent +# propagation hazards (ordering, self-reference, loop/conditional condition use, +# diamond merges, index chains) are exercised alongside. +# +# Every test below builds a VALID SDFG (asserted before the pass), computes a +# reference, applies the pass, re-validates, and re-checks values. A genuine bug +# surfaces as a clean validate-failure or a value mismatch; none are marked xfail +# here (the parent triages genuine-bug vs test-artifact). +# =========================================================================== + +# --------------------------------------------------------------------------- +# Module-level frontend kernels for the appended tests (unique names). +# --------------------------------------------------------------------------- + + +@dace.program +def selfref_counter_range(C: dace.int64[16], start: dace.int64, step: dace.int64): + """ + Self-referential loop-carried counter feeding the stored value (range form). + + ``cnt = cnt + step`` updates on the loop's back/update edge while ``cnt``'s + upstream value is a candidate for propagation into that same update edge. + + :param C: Output array, one element per iteration. + :param start: Initial counter value. + :param step: Per-iteration increment. + """ + cnt = start + for i in range(16): + C[i] = cnt + cnt = cnt + step + + +# --------------------------------------------------------------------------- +# Pattern A: same-edge multi-assignment race -- swap {x: y, y: expr_using_x} +# --------------------------------------------------------------------------- + + +def test_swap_pair_with_upstream_temp_api(): + """ + Edge ``{x: ty, y: tx}`` where ``tx = x`` and ``ty = y`` are assigned upstream. + + The swap edge is built VALID: each RHS is a fresh temp (``ty``/``tx``), not a + key co-assigned on the same edge, so the pre-pass SDFG validates. ``out_syms`` + at the swapping block carries ``tx -> x`` and ``ty -> y``. The pass may + substitute both into ``{x: ty, y: tx}``, yielding ``{x: y, y: x}`` -- now + ``x`` and ``y`` are each read AND written on the same simultaneous-assignment + edge: a race the input never had. A correct pass keeps it valid and + value-preserving (a clean swap). + """ + sdfg = dace.SDFG("swap_pair_api") + sdfg.add_array("C", [10], dace.int64) + sdfg.add_symbol("x", dace.int64) + sdfg.add_symbol("y", dace.int64) + sdfg.add_symbol("tx", dace.int64) + sdfg.add_symbol("ty", dace.int64) + + init = sdfg.add_state("init", is_start_block=True) + loop = LoopRegion("loop", "i < 10", "i", "i = 0", "i = i + 1") + sdfg.add_node(loop) + sdfg.add_edge(init, loop, dace.InterstateEdge(assignments={"x": "1", "y": "4"})) + + body = loop.add_state("body", is_start_block=True) + mid = loop.add_state("mid") + upd = loop.add_state("upd") + tk = body.add_tasklet("w", {}, {"out"}, "out = x") + body.add_edge(tk, "out", body.add_access("C"), None, dace.Memlet("C[i]")) + # Capture old x, y into temps, then assign crosswise (valid simultaneous swap). + loop.add_edge(body, mid, dace.InterstateEdge(assignments={"tx": "x", "ty": "y"})) + loop.add_edge(mid, upd, dace.InterstateEdge(assignments={"x": "ty", "y": "tx"})) + + end = sdfg.add_state("end") + sdfg.add_edge(loop, end, dace.InterstateEdge()) + sdfg.validate() + + expected = np.zeros(10, dtype=np.int64) + x, y = 1, 4 + for i in range(10): + expected[i] = x + x, y = y, x + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + got = np.zeros(10, dtype=np.int64) + sdfg(C=got) + assert np.array_equal(got, expected) + + +def test_two_keys_share_upstream_temp_api(): + """ + Edge ``{a: t, b: t}`` where ``t = a + 1`` upstream (reintroduces ``a`` next to its write). + + Substituting ``t -> a + 1`` into ``{a: t, b: t}`` yields ``{a: a + 1, b: a + 1}``: + ``a`` is now read AND written on the same edge -- a race. The pre-pass SDFG + uses the upstream temp legitimately and is value-preserving. + """ + sdfg = dace.SDFG("two_keys_temp_api") + sdfg.add_array("C", [8], dace.int64) + sdfg.add_symbol("a", dace.int64) + sdfg.add_symbol("b", dace.int64) + sdfg.add_symbol("t", dace.int64) + + init = sdfg.add_state("init", is_start_block=True) + loop = LoopRegion("loop", "i < 8", "i", "i = 0", "i = i + 1") + sdfg.add_node(loop) + sdfg.add_edge(init, loop, dace.InterstateEdge(assignments={"a": "0", "b": "0"})) + + body = loop.add_state("body", is_start_block=True) + mid = loop.add_state("mid") + upd = loop.add_state("upd") + tk = body.add_tasklet("w", {}, {"out"}, "out = a + b") + body.add_edge(tk, "out", body.add_access("C"), None, dace.Memlet("C[i]")) + loop.add_edge(body, mid, dace.InterstateEdge(assignments={"t": "a + 1"})) + loop.add_edge(mid, upd, dace.InterstateEdge(assignments={"a": "t", "b": "t"})) + + end = sdfg.add_state("end") + sdfg.add_edge(loop, end, dace.InterstateEdge()) + sdfg.validate() + + expected = np.zeros(8, dtype=np.int64) + a, b = 0, 0 + for i in range(8): + expected[i] = a + b + t = a + 1 + a = t + b = t + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + got = np.zeros(8, dtype=np.int64) + sdfg(C=got) + assert np.array_equal(got, expected) + + +def test_three_cycle_rotation_via_temps_api(): + """ + Rotation ``{p: tq, q: tr, r: tp}`` where ``tp=p, tq=q, tr=r`` are captured upstream. + + A 3-cycle rotation expressed VALIDLY: each RHS is a fresh capture temp, so no + key is read on the same edge that writes it (pre-pass valid). The pass carries + ``tp -> p``, ``tq -> q``, ``tr -> r`` in ``out_syms`` and may substitute them + into the rotation edge, producing ``{p: q, q: r, r: p}`` -- every key now read + and written on the same simultaneous-assignment edge: a 3-way race the input + never had. A correct pass keeps it valid and preserves the rotation. + """ + sdfg = dace.SDFG("three_cycle_api") + sdfg.add_array("C", [9], dace.int64) + sdfg.add_symbol("p", dace.int64) + sdfg.add_symbol("q", dace.int64) + sdfg.add_symbol("r", dace.int64) + sdfg.add_symbol("tp", dace.int64) + sdfg.add_symbol("tq", dace.int64) + sdfg.add_symbol("tr", dace.int64) + + init = sdfg.add_state("init", is_start_block=True) + loop = LoopRegion("loop", "i < 9", "i", "i = 0", "i = i + 1") + sdfg.add_node(loop) + sdfg.add_edge(init, loop, dace.InterstateEdge(assignments={"p": "1", "q": "2", "r": "3"})) + + body = loop.add_state("body", is_start_block=True) + cap = loop.add_state("cap") + upd = loop.add_state("upd") + tk = body.add_tasklet("w", {}, {"out"}, "out = p") + body.add_edge(tk, "out", body.add_access("C"), None, dace.Memlet("C[i]")) + loop.add_edge(body, cap, dace.InterstateEdge(assignments={"tp": "p", "tq": "q", "tr": "r"})) + loop.add_edge(cap, upd, dace.InterstateEdge(assignments={"p": "tq", "q": "tr", "r": "tp"})) + + end = sdfg.add_state("end") + sdfg.add_edge(loop, end, dace.InterstateEdge()) + sdfg.validate() + + expected = np.zeros(9, dtype=np.int64) + p, q, r = 1, 2, 3 + for i in range(9): + expected[i] = p + p, q, r = q, r, p + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + got = np.zeros(9, dtype=np.int64) + sdfg(C=got) + assert np.array_equal(got, expected) + + +def test_swap_via_temps_acyclic_api(): + """ + Acyclic swap ``{x: tx, y: ty}`` where ``tx = y`` and ``ty = x`` upstream. + + Plain-state (no loop) variant of the swap-via-temps race: the swap edge is + valid pre-pass (RHSes are capture temps, not co-assigned keys). Propagating + ``tx -> y`` and ``ty -> x`` into ``{x: tx, y: ty}`` yields ``{x: y, y: x}`` -- + a same-edge read-write race on both ``x`` and ``y``. The post-swap values + index ``B``, so any corruption is observable. Differs from the loop variant by + exercising the acyclic ``out_syms`` propagation path. + """ + sdfg = dace.SDFG("swap_two_temps_api") + sdfg.add_array("B", [128], dace.float64) + sdfg.add_array("C", [2], dace.float64) + sdfg.add_symbol("x", dace.int64) + sdfg.add_symbol("y", dace.int64) + sdfg.add_symbol("tx", dace.int64) + sdfg.add_symbol("ty", dace.int64) + + s0 = sdfg.add_state("s0", is_start_block=True) + cap = sdfg.add_state("cap") + upd = sdfg.add_state("upd") + use = sdfg.add_state("use") + sdfg.add_edge(s0, cap, dace.InterstateEdge(assignments={"x": "10", "y": "40"})) + # Capture both old values first, then assign from the captures (true swap). + sdfg.add_edge(cap, upd, dace.InterstateEdge(assignments={"tx": "y", "ty": "x"})) + sdfg.add_edge(upd, use, dace.InterstateEdge(assignments={"x": "tx", "y": "ty"})) + + t0 = use.add_tasklet("g0", {"inp"}, {"out"}, "out = inp") + use.add_edge(use.add_access("B"), None, t0, "inp", dace.Memlet("B[x]")) + use.add_edge(t0, "out", use.add_access("C"), None, dace.Memlet("C[0]")) + t1 = use.add_tasklet("g1", {"inp"}, {"out"}, "out = inp") + use.add_edge(use.add_access("B"), None, t1, "inp", dace.Memlet("B[y]")) + use.add_edge(t1, "out", use.add_access("C"), None, dace.Memlet("C[1]")) + sdfg.validate() + + rng = np.random.default_rng(72) + B = rng.random(128) + # After the swap: x == 40, y == 10. + expected = np.array([B[40], B[10]]) + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + got = np.zeros(2) + sdfg(B=B.copy(), C=got) + assert np.allclose(got, expected) + + +# --------------------------------------------------------------------------- +# Pattern B: substitution into a non-loop multi-assignment edge (ordering) +# --------------------------------------------------------------------------- + + +def test_multi_assign_temp_substitution_acyclic_api(): + """ + Acyclic ``{m: t, n: t + 1}`` edge where ``t = m + 2`` is assigned upstream. + + No loop: ``s0 -> s1`` assigns ``t = m + 2``; ``s1 -> s2`` assigns ``m`` and + ``n`` both in terms of ``t``. Substituting ``t -> m + 2`` reintroduces ``m`` + on the same edge that also writes ``m`` (race). ``s2`` uses both ``m`` and + ``n`` as indices, so any ordering corruption shows up in the gathered values. + """ + sdfg = dace.SDFG("multi_assign_acyclic_api") + sdfg.add_array("B", [128], dace.float64) + sdfg.add_array("C", [2], dace.float64) + sdfg.add_symbol("m", dace.int64) + sdfg.add_symbol("n", dace.int64) + sdfg.add_symbol("t", dace.int64) + + s0 = sdfg.add_state("s0", is_start_block=True) + s1 = sdfg.add_state("s1") + s2 = sdfg.add_state("s2") + sdfg.add_edge(s0, s1, dace.InterstateEdge(assignments={"t": "m + 2"})) + sdfg.add_edge(s1, s2, dace.InterstateEdge(assignments={"m": "t", "n": "t + 1"})) + + t0 = s2.add_tasklet("g0", {"inp"}, {"out"}, "out = inp") + s2.add_edge(s2.add_access("B"), None, t0, "inp", dace.Memlet("B[m]")) + s2.add_edge(t0, "out", s2.add_access("C"), None, dace.Memlet("C[0]")) + t1 = s2.add_tasklet("g1", {"inp"}, {"out"}, "out = inp") + s2.add_edge(s2.add_access("B"), None, t1, "inp", dace.Memlet("B[n]")) + s2.add_edge(t1, "out", s2.add_access("C"), None, dace.Memlet("C[1]")) + sdfg.validate() + + rng = np.random.default_rng(70) + B = rng.random(128) + + def oracle(m0): + t = m0 + 2 + m, n = t, t + 1 + return np.array([B[m], B[n]]) + + bases = (0, 5, 40) + expected = {b: oracle(b) for b in bases} + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for b in bases: + got = np.zeros(2) + sdfg(B=B.copy(), C=got, m=b) + assert np.allclose(got, expected[b]), b + + +def test_chained_simultaneous_feeds_index_api(): + """ + Edge ``{idx: tmp, s: base + 10}`` then ``B[idx]``, with ``tmp = base + s`` upstream. + + The racing edge is built VALID: ``idx`` is assigned from a fresh temp ``tmp`` + (not the co-assigned key ``s``), so the pre-pass SDFG validates. ``out_syms`` + carries ``tmp -> base + s`` and ``s -> base`` from upstream. If the pass + substitutes ``tmp -> base + s`` into ``idx: tmp`` it reintroduces ``s`` on the + very edge that simultaneously writes ``s`` (``s: base + 10``) -- a same-edge + read-write race. ``idx`` indexes ``B`` after, so an ordering/race corruption + is observable in the gathered value (correct ``idx == 2*base``). + """ + sdfg = dace.SDFG("chained_simul_index_api") + sdfg.add_array("B", [256], dace.float64) + sdfg.add_array("C", [1], dace.float64) + sdfg.add_symbol("base", dace.int64) + sdfg.add_symbol("s", dace.int64) + sdfg.add_symbol("tmp", dace.int64) + sdfg.add_symbol("idx", dace.int64) + + s0 = sdfg.add_state("s0", is_start_block=True) + s1 = sdfg.add_state("s1") + s2 = sdfg.add_state("s2") + s3 = sdfg.add_state("s3") + sdfg.add_edge(s0, s1, dace.InterstateEdge(assignments={"s": "base"})) + # tmp captures (base + old s); idx reads tmp while s is simultaneously bumped. + sdfg.add_edge(s1, s2, dace.InterstateEdge(assignments={"tmp": "base + s"})) + sdfg.add_edge(s2, s3, dace.InterstateEdge(assignments={"idx": "tmp", "s": "base + 10"})) + tk = s3.add_tasklet("g", {"inp"}, {"out"}, "out = inp") + s3.add_edge(s3.add_access("B"), None, tk, "inp", dace.Memlet("B[idx]")) + s3.add_edge(tk, "out", s3.add_access("C"), None, dace.Memlet("C[0]")) + sdfg.validate() + + rng = np.random.default_rng(71) + B = rng.random(256) + + def oracle(base): + s = base + tmp = base + s # base + old s == 2*base + idx = tmp + return B[idx] + + bases = (0, 5, 20, 50) + expected = {b: oracle(b) for b in bases} + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for b in bases: + got = np.zeros(1) + sdfg(B=B.copy(), C=got, base=b) + assert np.allclose(got[0], expected[b]), b + + +# --------------------------------------------------------------------------- +# Pattern C: self-referential propagation across edges +# --------------------------------------------------------------------------- + + +def test_selfref_counter_range_frontend(): + """Self-referential ``cnt = cnt + step`` loop-carried counter (frontend).""" + rng = np.random.default_rng(80) + sdfg = selfref_counter_range.to_sdfg(simplify=False) + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for start, step in ((0, 1), (3, 2), (10, 5)): + expected = np.array([start + i * step for i in range(16)], dtype=np.int64) + got = np.zeros(16, dtype=np.int64) + sdfg(C=got, start=start, step=step) + assert np.array_equal(got, expected), (start, step) + + +def test_selfref_with_upstream_alias_api(): + """ + Self-referential ``cnt = cnt + d`` where ``d = cnt`` is assigned just upstream. + + On the update edge, ``cnt`` is written. The upstream edge assigns ``d = cnt``, + so ``out_syms`` at the block before the update carries ``d -> cnt``. If the + pass substitutes ``d`` into a ``cnt = cnt + d`` style edge it would read + ``cnt`` twice while writing it. Here we keep ``d`` and ``cnt`` updates on + separate edges so the pre-pass SDFG doubles ``cnt`` each iteration validly. + """ + sdfg = dace.SDFG("selfref_alias_api") + sdfg.add_array("C", [12], dace.int64) + sdfg.add_symbol("cnt", dace.int64) + sdfg.add_symbol("d", dace.int64) + + init = sdfg.add_state("init", is_start_block=True) + loop = LoopRegion("loop", "i < 12", "i", "i = 0", "i = i + 1") + sdfg.add_node(loop) + sdfg.add_edge(init, loop, dace.InterstateEdge(assignments={"cnt": "1"})) + + body = loop.add_state("body", is_start_block=True) + cap = loop.add_state("cap") + upd = loop.add_state("upd") + tk = body.add_tasklet("w", {}, {"out"}, "out = cnt") + body.add_edge(tk, "out", body.add_access("C"), None, dace.Memlet("C[i]")) + loop.add_edge(body, cap, dace.InterstateEdge(assignments={"d": "cnt"})) + loop.add_edge(cap, upd, dace.InterstateEdge(assignments={"cnt": "cnt + d"})) + + end = sdfg.add_state("end") + sdfg.add_edge(loop, end, dace.InterstateEdge()) + sdfg.validate() + + expected = np.zeros(12, dtype=np.int64) + cnt = 1 + for i in range(12): + expected[i] = cnt + d = cnt + cnt = cnt + d # doubles each iteration + # Pre-pass validity check only requires cnt + d not collide on one edge. + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + got = np.zeros(12, dtype=np.int64) + sdfg(C=got) + assert np.array_equal(got, expected) + + +# --------------------------------------------------------------------------- +# Pattern D: propagation into a LoopRegion update edge / condition +# --------------------------------------------------------------------------- + + +def test_loop_update_reads_propagated_symbol_api(): + """ + A ``LoopRegion`` whose ``i = i + step`` update reads a propagated symbol. + + ``step = stride`` is assigned on the edge into the loop (so it is a candidate + for propagation), and the loop's update expression references ``step``. The + pass may fold ``step`` into the update; the trip pattern must stay correct. + """ + sdfg = dace.SDFG("loop_update_step_api") + sdfg.add_array("C", [32], dace.int64) + sdfg.add_symbol("stride", dace.int64) + sdfg.add_symbol("step", dace.int64) + + init = sdfg.add_state("init", is_start_block=True) + loop = LoopRegion("loop", "i < 32", "i", "i = 0", "i = i + step") + sdfg.add_node(loop) + sdfg.add_edge(init, loop, dace.InterstateEdge(assignments={"step": "stride"})) + + body = loop.add_state("body", is_start_block=True) + tk = body.add_tasklet("w", {}, {"out"}, "out = 1") + body.add_edge(tk, "out", body.add_access("C"), None, dace.Memlet("C[i]")) + + end = sdfg.add_state("end") + sdfg.add_edge(loop, end, dace.InterstateEdge()) + sdfg.validate() + + def oracle(stride): + out = np.zeros(32, dtype=np.int64) + i = 0 + while i < 32: + out[i] = 1 + i += stride + return out + + strides = (1, 2, 3, 4) + expected = {s: oracle(s) for s in strides} + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for s in strides: + got = np.zeros(32, dtype=np.int64) + sdfg(C=got, stride=s) + assert np.array_equal(got, expected[s]), s + + +def test_loop_condition_reads_simultaneously_assigned_symbol_api(): + """ + Loop condition ``i < lim`` where ``lim`` is set alongside another symbol on the in-edge. + + The edge into the loop assigns ``{lim: base + extra, off: base}`` simultaneously, + and ``extra = 4`` is assigned upstream (propagatable into ``lim``). The body + writes ``C[i] = off``. The pass may fold ``extra`` and propagate ``lim`` / + ``off`` -- the trip count and stored value must both stay correct. + """ + sdfg = dace.SDFG("loop_cond_simul_api") + sdfg.add_array("C", [40], dace.int64) + sdfg.add_symbol("base", dace.int64) + sdfg.add_symbol("extra", dace.int64) + sdfg.add_symbol("lim", dace.int64) + sdfg.add_symbol("off", dace.int64) + + init = sdfg.add_state("init", is_start_block=True) + pre = sdfg.add_state("pre") + sdfg.add_edge(init, pre, dace.InterstateEdge(assignments={"extra": "4"})) + loop = LoopRegion("loop", "i < lim", "i", "i = 0", "i = i + 1") + sdfg.add_node(loop) + sdfg.add_edge(pre, loop, dace.InterstateEdge(assignments={"lim": "base + extra", "off": "base"})) + + body = loop.add_state("body", is_start_block=True) + tk = body.add_tasklet("w", {}, {"out"}, "out = off") + body.add_edge(tk, "out", body.add_access("C"), None, dace.Memlet("C[i]")) + + end = sdfg.add_state("end") + sdfg.add_edge(loop, end, dace.InterstateEdge()) + sdfg.validate() + + def oracle(base): + lim = base + 4 + off = base + out = np.zeros(40, dtype=np.int64) + out[:lim] = off + return out + + bases = (3, 6, 10, 20) + expected = {b: oracle(b) for b in bases} + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for b in bases: + got = np.zeros(40, dtype=np.int64) + sdfg(C=got, base=b) + assert np.array_equal(got, expected[b]), b + + +# --------------------------------------------------------------------------- +# Pattern E: ConditionalBlock branch condition reading a co-assigned symbol +# --------------------------------------------------------------------------- + + +def test_branch_condition_reads_coassigned_symbol_api(): + """ + Branch condition ``pick > 0`` where ``pick`` and ``v`` are co-assigned on the in-edge. + + The edge into the ``ConditionalBlock`` assigns ``{pick: base - thr, v: base}`` + simultaneously, with ``thr = 5`` propagatable from upstream. The branch + selection depends on ``pick`` and the chosen branch indexes ``B`` with ``v`` + (then-branch) or a constant (else). Folding/propagating the co-assigned + symbols must not change which branch runs or the gathered value. + """ + sdfg = dace.SDFG("branch_cond_coassign_api") + sdfg.add_array("B", [128], dace.float64) + sdfg.add_array("C", [1], dace.float64) + sdfg.add_symbol("base", dace.int64) + sdfg.add_symbol("thr", dace.int64) + sdfg.add_symbol("pick", dace.int64) + sdfg.add_symbol("v", dace.int64) + + init = sdfg.add_state("init", is_start_block=True) + pre = sdfg.add_state("pre") + sdfg.add_edge(init, pre, dace.InterstateEdge(assignments={"thr": "5"})) + + cond = ConditionalBlock("cond", sdfg) + sdfg.add_node(cond) + sdfg.add_edge(pre, cond, dace.InterstateEdge(assignments={"pick": "base - thr", "v": "base"})) + + then_region = ControlFlowRegion("then", sdfg) + t0 = then_region.add_state("t0", is_start_block=True) + tk_t = t0.add_tasklet("g", {"inp"}, {"out"}, "out = inp") + t0.add_edge(t0.add_access("B"), None, tk_t, "inp", dace.Memlet("B[v]")) + t0.add_edge(tk_t, "out", t0.add_access("C"), None, dace.Memlet("C[0]")) + cond.add_branch(CodeBlock("pick > 0"), then_region) + + else_region = ControlFlowRegion("else", sdfg) + e0 = else_region.add_state("e0", is_start_block=True) + tk_e = e0.add_tasklet("g", {"inp"}, {"out"}, "out = inp") + e0.add_edge(e0.add_access("B"), None, tk_e, "inp", dace.Memlet("B[0]")) + e0.add_edge(tk_e, "out", e0.add_access("C"), None, dace.Memlet("C[0]")) + cond.add_branch(None, else_region) + + post = sdfg.add_state("post") + sdfg.add_edge(cond, post, dace.InterstateEdge()) + sdfg.validate() + + rng = np.random.default_rng(90) + B = rng.random(128) + + def oracle(base): + pick = base - 5 + v = base + return B[v] if pick > 0 else B[0] + + bases = (2, 5, 6, 20) + expected = {b: oracle(b) for b in bases} + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for b in bases: + got = np.zeros(1) + sdfg(B=B.copy(), C=got, base=b) + assert np.allclose(got[0], expected[b]), b + + +# --------------------------------------------------------------------------- +# Pattern F: diamond merge where both branches reduce to the same value +# --------------------------------------------------------------------------- + + +def test_diamond_merge_equal_via_propagation_api(): + """ + Two branches assign ``m`` to syntactically different but value-equal expressions. + + Upstream ``half = base`` is assigned. The then-branch sets ``m = base + base`` + and the else-branch sets ``m = half + base``; both equal ``2 * base`` once + ``half`` is propagated. A correct pass may collapse the merge to a uniform + value, but the post-pass result for ``B[m]`` must be unchanged regardless of + branch. Guards against over-propagation that picks one branch's syntactic form + and drops the other's data dependence. + """ + sdfg = dace.SDFG("diamond_equal_api") + sdfg.add_array("B", [256], dace.float64) + sdfg.add_array("C", [1], dace.float64) + sdfg.add_symbol("base", dace.int64) + sdfg.add_symbol("half", dace.int64) + sdfg.add_symbol("m", dace.int64) + sdfg.add_symbol("sel", dace.int64) + + init = sdfg.add_state("init", is_start_block=True) + pre = sdfg.add_state("pre") + sdfg.add_edge(init, pre, dace.InterstateEdge(assignments={"half": "base"})) + + cond = ConditionalBlock("cond", sdfg) + sdfg.add_node(cond) + sdfg.add_edge(pre, cond, dace.InterstateEdge()) + + then_region = ControlFlowRegion("then", sdfg) + t0 = then_region.add_state("t0", is_start_block=True) + t1 = then_region.add_state("t1") + then_region.add_edge(t0, t1, dace.InterstateEdge(assignments={"m": "base + base"})) + cond.add_branch(CodeBlock("sel > 0"), then_region) + + else_region = ControlFlowRegion("else", sdfg) + e0 = else_region.add_state("e0", is_start_block=True) + e1 = else_region.add_state("e1") + else_region.add_edge(e0, e1, dace.InterstateEdge(assignments={"m": "half + base"})) + cond.add_branch(None, else_region) + + post = sdfg.add_state("post") + sdfg.add_edge(cond, post, dace.InterstateEdge()) + tk = post.add_tasklet("g", {"inp"}, {"out"}, "out = inp") + post.add_edge(post.add_access("B"), None, tk, "inp", dace.Memlet("B[m]")) + post.add_edge(tk, "out", post.add_access("C"), None, dace.Memlet("C[0]")) + sdfg.validate() + + rng = np.random.default_rng(91) + B = rng.random(256) + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for sel, base in ((1, 7), (0, 7), (1, 50), (0, 100)): + expected = np.array([B[2 * base]]) + got = np.zeros(1) + sdfg(B=B.copy(), C=got, base=base, sel=sel) + assert np.allclose(got, expected), (sel, base) + + +def test_diamond_merge_unequal_must_not_propagate_api(): + """ + Diamond where branches assign ``m`` to genuinely different values. + + Then-branch ``m = base + 1``, else-branch ``m = base + 9``; the join is + non-uniform, so the pass must not propagate either value past the merge. The + later ``B[m]`` access must reflect the branch actually taken. Companion to + the equal-value diamond above (the join correctness boundary). + """ + sdfg = dace.SDFG("diamond_unequal_api") + sdfg.add_array("B", [128], dace.float64) + sdfg.add_array("C", [1], dace.float64) + sdfg.add_symbol("base", dace.int64) + sdfg.add_symbol("m", dace.int64) + sdfg.add_symbol("sel", dace.int64) + + pre = sdfg.add_state("pre", is_start_block=True) + cond = ConditionalBlock("cond", sdfg) + sdfg.add_node(cond) + sdfg.add_edge(pre, cond, dace.InterstateEdge()) + + then_region = ControlFlowRegion("then", sdfg) + t0 = then_region.add_state("t0", is_start_block=True) + t1 = then_region.add_state("t1") + then_region.add_edge(t0, t1, dace.InterstateEdge(assignments={"m": "base + 1"})) + cond.add_branch(CodeBlock("sel > 0"), then_region) + + else_region = ControlFlowRegion("else", sdfg) + e0 = else_region.add_state("e0", is_start_block=True) + e1 = else_region.add_state("e1") + else_region.add_edge(e0, e1, dace.InterstateEdge(assignments={"m": "base + 9"})) + cond.add_branch(None, else_region) + + post = sdfg.add_state("post") + sdfg.add_edge(cond, post, dace.InterstateEdge()) + tk = post.add_tasklet("g", {"inp"}, {"out"}, "out = inp") + post.add_edge(post.add_access("B"), None, tk, "inp", dace.Memlet("B[m]")) + post.add_edge(tk, "out", post.add_access("C"), None, dace.Memlet("C[0]")) + sdfg.validate() + + rng = np.random.default_rng(92) + B = rng.random(128) + + def oracle(sel, base): + m = base + 1 if sel > 0 else base + 9 + return B[m] + + cases = ((1, 4), (0, 4), (1, 30), (0, 30)) + expected = {c: oracle(*c) for c in cases} + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for c in cases: + got = np.zeros(1) + sdfg(B=B.copy(), C=got, base=c[1], sel=c[0]) + assert np.allclose(got[0], expected[c]), c + + +# --------------------------------------------------------------------------- +# Pattern G: chained simultaneous assignments feeding an index, then B[idx] +# --------------------------------------------------------------------------- + + +def test_simultaneous_index_pair_then_use_api(): + """ + Edge ``{lo: clo, hi: chi}`` (index swap via capture temps) with both used as indices. + + Upstream ``{lo: base, hi: base + 20}`` then capture ``{clo: hi, chi: lo}``; + the swap edge ``{lo: clo, hi: chi}`` is VALID (RHSes are capture temps, not + co-assigned keys), so the pre-pass SDFG validates. ``out_syms`` carries + ``clo -> hi`` and ``chi -> lo``. Substituting them into the swap edge yields + ``{lo: hi, hi: lo}`` -- a same-edge read-write race on both ``lo`` and ``hi``. + The consuming state reads ``B[lo]`` and ``B[hi]`` after the swap, so any + corruption from a mishandled substitution is observable. + """ + sdfg = dace.SDFG("simul_index_pair_api") + sdfg.add_array("B", [128], dace.float64) + sdfg.add_array("C", [2], dace.float64) + sdfg.add_symbol("base", dace.int64) + sdfg.add_symbol("lo", dace.int64) + sdfg.add_symbol("hi", dace.int64) + sdfg.add_symbol("clo", dace.int64) + sdfg.add_symbol("chi", dace.int64) + + s0 = sdfg.add_state("s0", is_start_block=True) + s1 = sdfg.add_state("s1") + s2 = sdfg.add_state("s2") + s3 = sdfg.add_state("s3") + sdfg.add_edge(s0, s1, dace.InterstateEdge(assignments={"lo": "base", "hi": "base + 20"})) + # Capture the crossed values, then assign from captures (valid simultaneous swap). + sdfg.add_edge(s1, s2, dace.InterstateEdge(assignments={"clo": "hi", "chi": "lo"})) + sdfg.add_edge(s2, s3, dace.InterstateEdge(assignments={"lo": "clo", "hi": "chi"})) + + t0 = s3.add_tasklet("g0", {"inp"}, {"out"}, "out = inp") + s3.add_edge(s3.add_access("B"), None, t0, "inp", dace.Memlet("B[lo]")) + s3.add_edge(t0, "out", s3.add_access("C"), None, dace.Memlet("C[0]")) + t1 = s3.add_tasklet("g1", {"inp"}, {"out"}, "out = inp") + s3.add_edge(s3.add_access("B"), None, t1, "inp", dace.Memlet("B[hi]")) + s3.add_edge(t1, "out", s3.add_access("C"), None, dace.Memlet("C[1]")) + sdfg.validate() + + rng = np.random.default_rng(93) + B = rng.random(128) + + def oracle(base): + lo, hi = base, base + 20 + lo, hi = hi, lo # simultaneous swap via captures + return np.array([B[lo], B[hi]]) + + bases = (0, 5, 30, 90) + expected = {b: oracle(b) for b in bases} + + SymbolPropagation().apply_pass(sdfg, {}) + sdfg.validate() + for b in bases: + got = np.zeros(2) + sdfg(B=B.copy(), C=got, base=b) + assert np.allclose(got, expected[b]), b + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/passes/symbol_propagation_test.py b/tests/passes/symbol_propagation_test.py index e085b20ae5..2a445a4447 100644 --- a/tests/passes/symbol_propagation_test.py +++ b/tests/passes/symbol_propagation_test.py @@ -6,7 +6,7 @@ from dace.properties import CodeBlock from dace.sdfg.state import LoopRegion, ConditionalBlock, ControlFlowRegion from dace.transformation.interstate import LoopToMap -from dace.transformation.passes import SymbolPropagation +from dace.transformation.passes import SymbolPropagation, ScalarToSymbolPromotion def _count_loops(sdfg: dace.SDFG): @@ -235,9 +235,19 @@ def test_deeply_nested_sdfg(): SymbolPropagation().apply_pass(sdfg1, {}) sdfg1.validate() - # No assignment should have been changed - assert edge1.data.assignments["v"] == "a" - assert edge4.data.assignments["c"] == "v+1" + # The outer iedge ``v = a`` was the only binding of ``v``; with propagation reaching + # the NSDFG ``symbol_mapping`` (``{"v": "v"}`` -> ``{"v": "a"}``) the binding is + # dead and gets swept, taking the now-unused ``v`` declaration with it so the + # nested chain remains self-consistent. Same fate for the inner ``c = v+1``: its + # destination state has no readers of ``c``, so the binding + the ``c`` declaration + # both drop. + assert "v" not in edge1.data.assignments, ( + f"propagation should have substituted v->a everywhere and dropped the dead binding; " + f"got {dict(edge1.data.assignments)}") + assert "v" not in sdfg1.symbols, "declaration of v should be removed with its binding" + assert "c" not in edge4.data.assignments, ( + f"unused c=v+1 binding should be swept; got {dict(edge4.data.assignments)}") + assert "c" not in sdfg4.symbols, "declaration of c should be removed with its binding" def test_scalars(): @@ -273,6 +283,200 @@ def test_scalars(): assert A[0] == 5 +def test_cloudsc_kidia_kfdia_promote_then_propagate(): + """CloudSC subset: scalar arguments ``kidia`` / ``kfdia`` used as the + inclusive horizontal loop bound ``range(kidia, kfdia + 1)`` across several + level nests (the ``DO JK=1,KLEV; DO JL=KIDIA,KFDIA`` shape). ``simplify`` + promotes ``kfdia + 1`` to per-nest symbols ``kfdia_plus_1_N = kfdia + 1``. + + SymbolPropagation alone does NOT fold them: ``kfdia`` is a non-transient + scalar ARGUMENT, so values referencing it are (correctly) skipped by the + scalar filter -- the pass is a no-op (``apply_pass`` returns ``None``). + Promoting the scalar arguments to symbols first with + ``ScalarToSymbolPromotion(transients_only=False)`` makes ``kfdia`` a symbol, + after which SymbolPropagation folds ``kfdia_plus_1 -> (kfdia + 1)``. The + scalar-skip filter itself is unchanged (genuine scalars are still skipped -- + see ``test_scalars``). Value-preserving throughout.""" + klev, klon = dace.symbol('klev'), dace.symbol('klon') + + @dace.program + def cloudsc_kidia_kfdia(pt: dace.float64[klev, klon], ptend: dace.float64[klev, klon], kidia: dace.int32, + kfdia: dace.int32): + for jk in range(klev): + for jl in range(kidia, kfdia + 1): + ptend[jk, jl] = pt[jk, jl] * 2.0 + for jk in range(klev): + for jl in range(kidia, kfdia + 1): + ptend[jk, jl] = ptend[jk, jl] + 3.0 + + def _kfdia_plus1_syms(g): + return {k for e in g.all_interstate_edges() for k in e.data.assignments if k.startswith('kfdia_plus_1')} + + nlev, nlon = 5, 8 + rng = np.random.default_rng(0) + pt = rng.standard_normal((nlev, nlon)) + + # Reference (un-promoted) output. + ref = cloudsc_kidia_kfdia.to_sdfg(simplify=True) + ref_out = np.zeros((nlev, nlon)) + ref(pt=pt.copy(), ptend=ref_out, kidia=0, kfdia=nlon - 1, klev=nlev, klon=nlon) + + # (1) Without promotion: kfdia is a scalar argument -> symprop is a no-op. + sdfg = cloudsc_kidia_kfdia.to_sdfg(simplify=True) + assert _kfdia_plus1_syms(sdfg), 'simplify should promote kfdia + 1 to kfdia_plus_1 symbols' + assert isinstance(sdfg.arrays.get('kfdia'), dace.data.Scalar) + assert SymbolPropagation().apply_pass(sdfg, {}) is None, \ + 'symprop must skip values referencing the scalar argument kfdia (no-op)' + + # (2) Promote the scalar arguments to symbols first, then symprop folds them. + sdfg2 = cloudsc_kidia_kfdia.to_sdfg(simplify=True) + s2s = ScalarToSymbolPromotion() + s2s.transients_only = False + promoted = s2s.apply_pass(sdfg2, {}) + assert promoted and {'kidia', 'kfdia'} <= promoted, f'expected kidia/kfdia promoted, got {promoted}' + assert 'kfdia' in sdfg2.symbols and 'kfdia' not in sdfg2.arrays + + ret = SymbolPropagation().apply_pass(sdfg2, {}) + assert ret is not None and any(s.startswith('kfdia_plus_1') for s in ret), \ + f'after promotion symprop must fold kfdia_plus_1 -> (kfdia + 1); propagated={ret}' + sdfg2.validate() + + # Value-preserving (kidia/kfdia are now symbols). + out2 = np.zeros((nlev, nlon)) + sdfg2(pt=pt.copy(), ptend=out2, kidia=0, kfdia=nlon - 1, klev=nlev, klon=nlon) + assert np.allclose(out2, ref_out) + assert np.allclose(out2, pt * 2.0 + 3.0) + + +_SP_N = dace.symbol("_SP_N") + + +@dace.program +def _carried_index_kernel(a: dace.float64[_SP_N], b: dace.float64[_SP_N], c: dace.float64[_SP_N], + d: dace.float64[_SP_N]): + j = -1 + for i in range(_SP_N // 2): + k = j + 1 + a[i] = b[k] - d[i] + j = k + 1 + b[k] = a[i] + c[k] + + +def test_carried_index_symbol_not_propagated_stale(): + """Reproducer (TSVC s128): a loop-carried index ``k = j + 1`` must not be + propagated into a downstream block as ``j + 1`` once the loop has reassigned + ``j = k + 1``. There the live ``j`` is two ahead, so the stale expression is an + off-by-two on ``b[k]`` / ``c[k]``. SymbolPropagation must keep ``k`` live; this + checks the propagated SDFG still matches the un-propagated reference.""" + import copy + n = 64 + rng = np.random.default_rng(0) + base = {name: rng.random(n) for name in "abcd"} + + ref = _carried_index_kernel.to_sdfg(simplify=True) + cand = copy.deepcopy(ref) + SymbolPropagation().apply_pass(cand, {}) + cand.validate() + + ra = {name: arr.copy() for name, arr in base.items()} + ref(**ra, _SP_N=n) + ca = {name: arr.copy() for name, arr in base.items()} + cand(**ca, _SP_N=n) + for name in "abcd": + assert np.allclose(ra[name], ca[name]), f"{name}: SymbolPropagation changed the result" + + +def test_dead_iedge_assignment_eliminated_after_substitution(): + """A bound-symbol shorthand iedge assignment (``k_plus_1 = klev + 1``) survived + symbol_propagation: its uses got substituted to ``klev + 1`` but the defining + assignment was left in place. The fix sweeps such dead assignments to a fixed + point at the end of the pass; nothing references ``k_plus_1`` after the + substitution, so the iedge ends with an empty ``assignments`` dict. + """ + sdfg = dace.SDFG('dead_iedge_repro') + sdfg.add_array('out', [16], dace.float64) + sdfg.add_symbol('klev', dace.int32) + s1 = sdfg.add_state('s1', is_start_block=True) + s2 = sdfg.add_state('s2') + sdfg.add_edge(s1, s2, dace.InterstateEdge(assignments={'k_plus_1': '(klev + 1)'})) + + t = s2.add_tasklet('t', {}, {'_o'}, '_o = 1.0') + w = s2.add_write('out') + s2.add_edge(t, '_o', w, None, dace.Memlet(data='out', subset='k_plus_1')) + sdfg.validate() + + res = SymbolPropagation().apply_pass(sdfg, {}) + assert res == {'k_plus_1'}, f'expected k_plus_1 to be reported propagated; got {res}' + + surviving = [(lhs, rhs) for e in sdfg.all_interstate_edges() for lhs, rhs in e.data.assignments.items()] + assert surviving == [], f'dead k_plus_1 assignment must be eliminated; got {surviving}' + + # The substitution must reach the memlet: the write to s2's ``out`` now indexes + # ``klev + 1`` directly, not via the shorthand symbol. + seen = [] + for st in sdfg.states(): + for e in st.edges(): + if e.data is not None and e.data.data == 'out': + seen.append(str(e.data.subset)) + assert 'klev + 1' in seen, f'expected memlet subset to be substituted to klev+1; got {seen}' + + +def test_dead_iedge_chain_unravels_to_fixed_point(): + """Chained shorthands (``a = klev + 1; b = a; c = b``) must all be eliminated + once their uses are substituted -- the cleanup sweep iterates to a fixed point.""" + sdfg = dace.SDFG('chain_repro') + sdfg.add_array('out', [16], dace.float64) + sdfg.add_symbol('klev', dace.int32) + s1 = sdfg.add_state('s1', is_start_block=True) + s2 = sdfg.add_state('s2') + s3 = sdfg.add_state('s3') + s4 = sdfg.add_state('s4') + sdfg.add_edge(s1, s2, dace.InterstateEdge(assignments={'a': '(klev + 1)'})) + sdfg.add_edge(s2, s3, dace.InterstateEdge(assignments={'b': 'a'})) + sdfg.add_edge(s3, s4, dace.InterstateEdge(assignments={'c': 'b'})) + + t = s4.add_tasklet('t', {}, {'_o'}, '_o = 2.0') + w = s4.add_write('out') + s4.add_edge(t, '_o', w, None, dace.Memlet(data='out', subset='c')) + sdfg.validate() + + SymbolPropagation().apply_pass(sdfg, {}) + + surviving = [(lhs, rhs) for e in sdfg.all_interstate_edges() for lhs, rhs in e.data.assignments.items()] + assert surviving == [], f'every link of the dead chain must be eliminated; got {surviving}' + + +def test_dead_iedge_with_array_shape_substituted_into_descriptor(): + """A symbol referenced *only* by an array descriptor's shape (cloudsc's + ``[0:kfdia_plus_1, 0:klon]`` pattern) used to keep the defining iedge alive + because the IR-level ``replace_dict`` does not reach into descriptors. The + fix substitutes the symbol into descriptors as a final step before + elimination, so the array shape becomes ``kfdia + 1`` directly and the + iedge drops.""" + sdfg = dace.SDFG('array_shape_repro') + sdfg.add_symbol('klev', dace.int32) + sdfg.add_symbol('k_plus_1', dace.int32) + sdfg.add_array('out', ['k_plus_1'], dace.float64) + s1 = sdfg.add_state('s1', is_start_block=True) + s2 = sdfg.add_state('s2') + sdfg.add_edge(s1, s2, dace.InterstateEdge(assignments={'k_plus_1': '(klev + 1)'})) + + t = s2.add_tasklet('t', {}, {'_o'}, '_o = 3.0') + w = s2.add_write('out') + s2.add_edge(t, '_o', w, None, dace.Memlet(data='out', subset='0')) + sdfg.validate() + + SymbolPropagation().apply_pass(sdfg, {}) + + surviving = [(lhs, rhs) for e in sdfg.all_interstate_edges() for lhs, rhs in e.data.assignments.items()] + assert surviving == [], (f'k_plus_1 should have been substituted into the array shape and the ' + f'binding dropped; got {surviving}') + shape_str = ', '.join(str(s) for s in sdfg.arrays['out'].shape) + assert 'klev' in shape_str and 'k_plus_1' not in shape_str, ( + f'array shape must read klev + 1 directly; got {shape_str}') + assert 'k_plus_1' not in sdfg.symbols, 'declaration of k_plus_1 should be removed with its binding' + + if __name__ == "__main__": test_loop_carried_symbol() test_nested_loop_carried_symbol() @@ -281,3 +485,8 @@ def test_scalars(): test_multiple_edge_assignments() test_deeply_nested_sdfg() test_scalars() + test_cloudsc_kidia_kfdia_promote_then_propagate() + test_carried_index_symbol_not_propagated_stale() + test_dead_iedge_assignment_eliminated_after_substitution() + test_dead_iedge_chain_unravels_to_fixed_point() + test_dead_iedge_preserved_when_lhs_still_used() diff --git a/tests/sdfg/exit_arglist_test.py b/tests/sdfg/exit_arglist_test.py new file mode 100644 index 0000000000..6a1c911d6b --- /dev/null +++ b/tests/sdfg/exit_arglist_test.py @@ -0,0 +1,67 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. +"""Regression tests for `DataflowGraphView.unordered_arglist`'s +`AccessNode/CodeNode -> ExitNode` branch. + +The branch collects the external arrays a map writes through its MapExit. +It iterates the matching out-edges of the MapExit and, for each one, has to +name the actual destination array. Pre-fix it trusted `oedge.data.data`; an +outgoing memlet that still names an inner transient (instead of the outer +destination) then dropped the real array -- and its shape/stride symbols -- +from the arglist, and codegen later emitted a kernel signature that +referenced an undeclared identifier. + +The fix resolves the destination from the memlet path's terminal AccessNode +(equivalently, the memlet tree's root in this branch). +""" +import dace +from dace.sdfg import nodes as nd + + +def _build_src_relative_exit_sdfg() -> dace.SDFG: + """Construct the pathological shape: outgoing memlet from MapExit to + AccessNode('D') has `data='tmp'` (the inner transient), not `data='D'`. + """ + sdfg = dace.SDFG('exit_write_src_rel') + sdfg.add_array('D', [10], dace.float64) + sdfg.add_scalar('tmp', dace.float64, transient=True) + + state = sdfg.add_state('s', is_start_block=True) + me, mx = state.add_map('m', dict(i='0:10')) + t = state.add_tasklet('w', set(), {'_o'}, '_o = 1.0') + tmp = state.add_access('tmp') + d_an = state.add_write('D') + + state.add_edge(me, None, t, None, dace.Memlet()) + state.add_edge(t, '_o', tmp, None, dace.Memlet('tmp[0]')) + # Source-relative memlet INSIDE the map: names the inner transient. + state.add_edge(tmp, None, mx, 'IN_D', dace.Memlet(data='tmp', subset='0', other_subset='i')) + # Source-relative memlet OUTSIDE the map: still names the inner transient + # (this is the pathological field -- the fix tolerates it instead of + # propagating 'tmp' into the arglist). + state.add_edge(mx, 'OUT_D', d_an, None, dace.Memlet(data='tmp', subset='0')) + mx.add_in_connector('IN_D', force=True) + mx.add_out_connector('OUT_D', force=True) + sdfg.validate() + return sdfg + + +def test_arglist_resolves_outer_destination_from_source_relative_outgoing_memlet(): + """`unordered_arglist` must surface the OUTER destination array (``D``) + even when the outgoing memlet from MapExit names an inner transient + (``tmp``). Pre-fix it returned 'tmp' in the arglist and dropped 'D' -- + a downstream codegen would then emit a kernel signature that references + 'D' without declaring it (`identifier "D" is undefined`). + """ + sdfg = _build_src_relative_exit_sdfg() + state = next(iter(sdfg.states())) + me = next(n for n in state.nodes() if isinstance(n, nd.MapEntry)) + + arglist = state.scope_subgraph(me).arglist() + assert 'D' in arglist, f"outer destination 'D' missing from arglist: {sorted(arglist.keys())}" + # The inner transient must NOT appear in the kernel arglist (it lives + # inside the scope; arglist exposes only externally-visible arrays). + assert 'tmp' not in arglist, f"inner transient 'tmp' leaked into arglist: {sorted(arglist.keys())}" + + +if __name__ == '__main__': + test_arglist_resolves_outer_destination_from_source_relative_outgoing_memlet() diff --git a/tests/sdfg/free_symbols_test.py b/tests/sdfg/free_symbols_test.py index b0a59fb3af..90cc7d63a8 100644 --- a/tests/sdfg/free_symbols_test.py +++ b/tests/sdfg/free_symbols_test.py @@ -129,6 +129,113 @@ def test_nested_sdfg_free_symbols(): assert 'k' not in inner_sdfg.free_symbols +def _build_with_optional_unused_array(create_unused_transient: bool) -> dace.SDFG: + """ + Builds the issue #2382 reproducer SDFG: two used arrays plus an optional + transient array whose shape depends on the symbol ``x_shape`` but that is + never read, written, or allocated. + + :param create_unused_transient: If True, declare the unused ``x`` array. + :returns: The constructed SDFG. + """ + sdfg = dace.SDFG('unused_transient') + state = sdfg.add_state() + sdfg.add_array('a', (10, ), dace.float64, transient=False) + sdfg.add_array('b', (10, ), dace.float64, transient=False) + sdfg.add_symbol('x_shape', dace.int32) + if create_unused_transient: + sdfg.add_array('x', ('x_shape', ), dace.float32, transient=True) + state.add_mapped_tasklet('map', {'__i': '0:10'}, {'__in': dace.Memlet('a[__i]')}, + '__out = __in + 1.90', {'__out': dace.Memlet('b[__i]')}, + external_edges=True) + return sdfg + + +def test_unused_array_does_not_leak_shape_symbol(): + """ + Regression test for issue #2382: the shape symbol of an array that is + merely declared (never read, written, or allocated) must not leak into the + SDFG signature. Declaring the unused transient ``x`` must not change the + set of arguments needed to invoke the SDFG. + """ + without = _build_with_optional_unused_array(False) + with_unused = _build_with_optional_unused_array(True) + + # The unused array's shape symbol must not be treated as a used argument. + assert 'x_shape' not in without.used_symbols(all_symbols=False) + assert 'x_shape' not in with_unused.used_symbols(all_symbols=False) + + # Declaring the unused array must not perturb the signature at all. + assert 'x_shape' not in with_unused.arglist() + assert list(without.arglist().keys()) == list(with_unused.arglist().keys()) + assert without.signature_arglist() == with_unused.signature_arglist() + assert without.init_signature() == with_unused.init_signature() + assert 'x_shape' not in with_unused.init_signature() + + +def test_used_codeblock_array_keeps_shape_symbol(): + """ + The shape/stride symbols of an array that *is* used must be preserved even + when the only reference is in a control-flow code block. Here a + ``ConditionalBlock`` guard indexes a 2D array whose stride uses the free + symbol ``S``; ``S`` must remain in the SDFG's used symbols so that codegen + declares it. + """ + from dace.properties import CodeBlock + from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, LoopRegion + + sdfg = dace.SDFG('used_codeblock_array') + sdfg.add_symbol('S', dace.int32) + sdfg.add_array('A', (10, 10), dace.int32, strides=(1, dace.symbol('S'))) + sdfg.add_scalar('acc', dace.int32, transient=True) + + loop = LoopRegion('loop', condition_expr='k < 5', loop_var='k', initialize_expr='k = 0', update_expr='k = k + 1') + sdfg.add_node(loop, is_start_block=True) + + cb = ConditionalBlock('cb') + loop.add_node(cb, is_start_block=True) + branch = ControlFlowRegion('branch', sdfg=sdfg) + cb.add_branch(CodeBlock('A[0, k] == 1'), branch) + + set_one = branch.add_state('set_one', is_start_block=True) + t1 = set_one.add_tasklet('t_set', {}, {'o'}, 'o = 1') + set_one.add_edge(t1, 'o', set_one.add_write('acc'), None, dace.Memlet('acc[0]')) + + sdfg.validate() + + # ``A`` is referenced only in the conditional guard, but it is genuinely + # used; its stride symbol ``S`` must therefore be kept. + assert 'S' in sdfg.used_symbols(all_symbols=False) + assert 'S' in sdfg.init_signature() + + +def test_used_array_keeps_symbolic_extent(): + """ + An array that is used only through a map memlet (no top-level access node + and no code-block reference) must still contribute its symbolic shape and + stride symbols to the SDFG signature. This guards against the fix for + issue #2382 being too aggressive and dropping a genuinely needed extent + symbol. + """ + n = dace.symbol('n') + s = dace.symbol('s') + + sdfg = dace.SDFG('used_via_map') + sdfg.add_array('a', (n, ), dace.float64, strides=(s, ), transient=False) + sdfg.add_array('b', (n, ), dace.float64, transient=False) + state = sdfg.add_state() + state.add_mapped_tasklet('m', {'__i': '0:n'}, {'__in': dace.Memlet('a[__i]')}, + '__out = __in + 1.0', {'__out': dace.Memlet('b[__i]')}, + external_edges=True) + sdfg.validate() + + used = sdfg.used_symbols(all_symbols=False) + assert 'n' in used + assert 's' in used + assert 'n' in sdfg.arglist() + assert 's' in sdfg.arglist() + + if __name__ == '__main__': test_single_state() test_state_subgraph() @@ -136,3 +243,6 @@ def test_nested_sdfg_free_symbols(): test_constants() test_interstate_edge_symbols() test_nested_sdfg_free_symbols() + test_unused_array_does_not_leak_shape_symbol() + test_used_codeblock_array_keeps_shape_symbol() + test_used_array_keeps_symbolic_extent() diff --git a/tests/sdfg/symbols_defined_at_test.py b/tests/sdfg/symbols_defined_at_test.py new file mode 100644 index 0000000000..91dcab9d3f --- /dev/null +++ b/tests/sdfg/symbols_defined_at_test.py @@ -0,0 +1,257 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. +"""Tests for :meth:`dace.sdfg.state.SDFGState.symbols_defined_at`. + +Locks down which symbols are reported as "defined" at a given node. Memlet +propagation across ``NestedSDFG`` boundaries reads this set to decide which +symbols may appear in the propagated outer memlet -- any inner subset symbol +that is NOT in this set falls back to the array's full extent (see +``propagate_memlets_nested_sdfg`` in ``dace/sdfg/propagation.py``). Missing +symbols here therefore widen propagation and hide loop-carried dependencies +downstream (the cloudsc ``for_1133`` / ``for_430`` shape). + +The test file exercises both contributors that ``symbols_defined_at`` walks: + +* enclosing ``LoopRegion`` loop variables (the for-loop side) +* enclosing dataflow scope nodes' new symbols (the map side, including + ``MapEntry``'s ``Map`` parameters AND any non-pass-through input connector + that supplies a dynamic Map range or a scope-local parameter) +""" +import dace +import pytest + +from dace.sdfg.state import LoopRegion + +N = dace.symbol('N') + + +def test_global_sdfg_symbol_visible(): + """Global ``SDFG.symbols`` entries reach every node.""" + sdfg = dace.SDFG('g') + sdfg.add_symbol('K', dace.int32) + sdfg.add_array('A', [10], dace.float64) + s = sdfg.add_state('s') + t = s.add_tasklet('t', {}, {'o'}, 'o = K') + w = s.add_write('A') + s.add_edge(t, 'o', w, None, dace.Memlet('A[0]')) + + syms = s.symbols_defined_at(t) + assert 'K' in syms + + +def test_enclosing_loop_region_var_visible_at_state_node(): + """The loop variable of an enclosing ``LoopRegion`` is reported as defined.""" + sdfg = dace.SDFG('one_loop') + sdfg.add_array('A', [N], dace.float64) + loop = LoopRegion('L', 'i < N', 'i', 'i = 0', 'i = i + 1') + sdfg.add_node(loop) + s = loop.add_state('s', is_start_block=True) + t = s.add_tasklet('t', {}, {'o'}, 'o = i') + w = s.add_write('A') + s.add_edge(t, 'o', w, None, dace.Memlet('A[i]')) + + syms = s.symbols_defined_at(t) + assert 'i' in syms, 'enclosing LoopRegion loop variable must be visible' + + +def test_triple_nested_loop_regions_all_visible(): + """Every enclosing loop variable up the parent-region chain is visible.""" + sdfg = dace.SDFG('triple') + sdfg.add_array('A', [10], dace.float64) + L1 = LoopRegion('L1', 'a < 10', 'a', 'a = 0', 'a = a + 1') + sdfg.add_node(L1) + L2 = LoopRegion('L2', 'b < 10', 'b', 'b = 0', 'b = b + 1') + L1.add_node(L2) + L3 = LoopRegion('L3', 'c < 10', 'c', 'c = 0', 'c = c + 1') + L2.add_node(L3) + s = L3.add_state('s', is_start_block=True) + t = s.add_tasklet('t', {}, {'o'}, 'o = a + b + c') + w = s.add_write('A') + s.add_edge(t, 'o', w, None, dace.Memlet('A[a]')) + + syms = s.symbols_defined_at(t) + assert {'a', 'b', 'c'} <= set(syms), f'all enclosing loop vars must be visible; got {sorted(syms)}' + + +def test_map_iteration_var_visible_inside_scope(): + """The ``MapEntry``'s map parameter is visible at every node inside the scope.""" + sdfg = dace.SDFG('m') + sdfg.add_array('A', [10], dace.float64) + s = sdfg.add_state('s') + me, mx = s.add_map('m', {'i': '0:10'}) + me.add_in_connector('IN_A') + me.add_out_connector('OUT_A') + A_r = s.add_read('A') + A_w = s.add_write('A') + t = s.add_tasklet('t', {'a'}, {'o'}, 'o = a') + s.add_edge(A_r, None, me, 'IN_A', dace.Memlet('A[0:10]')) + s.add_edge(me, 'OUT_A', t, 'a', dace.Memlet('A[i]')) + mx.add_in_connector('IN_A') + mx.add_out_connector('OUT_A') + s.add_edge(t, 'o', mx, 'IN_A', dace.Memlet('A[i]')) + s.add_edge(mx, 'OUT_A', A_w, None, dace.Memlet('A[0:10]')) + + syms = s.symbols_defined_at(t) + assert 'i' in syms, 'Map iteration variable must be visible inside the scope' + + +def test_nested_map_iteration_vars_both_visible(): + """Both outer and inner ``Map`` parameters are visible at the innermost node.""" + sdfg = dace.SDFG('nested_m') + sdfg.add_array('A', [10, 10], dace.float64) + s = sdfg.add_state('s') + + me_o, mx_o = s.add_map('mo', {'i': '0:10'}) + me_i, mx_i = s.add_map('mi', {'j': '0:10'}) + me_o.add_in_connector('IN_A') + me_o.add_out_connector('OUT_A') + me_i.add_in_connector('IN_A') + me_i.add_out_connector('OUT_A') + A_r = s.add_read('A') + A_w = s.add_write('A') + t = s.add_tasklet('t', {'a'}, {'o'}, 'o = a') + + s.add_edge(A_r, None, me_o, 'IN_A', dace.Memlet('A[0:10, 0:10]')) + s.add_edge(me_o, 'OUT_A', me_i, 'IN_A', dace.Memlet('A[i, 0:10]')) + s.add_edge(me_i, 'OUT_A', t, 'a', dace.Memlet('A[i, j]')) + mx_o.add_in_connector('IN_A') + mx_o.add_out_connector('OUT_A') + mx_i.add_in_connector('IN_A') + mx_i.add_out_connector('OUT_A') + s.add_edge(t, 'o', mx_i, 'IN_A', dace.Memlet('A[i, j]')) + s.add_edge(mx_i, 'OUT_A', mx_o, 'IN_A', dace.Memlet('A[i, 0:10]')) + s.add_edge(mx_o, 'OUT_A', A_w, None, dace.Memlet('A[0:10, 0:10]')) + + syms = s.symbols_defined_at(t) + assert {'i', 'j'} <= set(syms), f'both enclosing Map iter vars must be visible; got {sorted(syms)}' + + +def test_dynamic_non_passthrough_map_connector_visible(): + """A ``MapEntry`` input connector that does NOT start with ``IN_`` supplies + a dynamic Map-range / scope-local parameter (e.g. a runtime upper bound). + Its name must be reported as defined for nodes inside the scope so memlets + written in terms of that parameter survive ``NSDFG``-boundary propagation + without being widened to the array extent.""" + sdfg = dace.SDFG('dyn') + sdfg.add_array('A', [100], dace.int32) + sdfg.add_array('N_arr', [1], dace.int32) + sdfg.add_array('out', [100], dace.int32) + s = sdfg.add_state('s') + + # Map range parameterized by a dynamic input connector ``N_arr_val``. + me, mx = s.add_map('m', {'i': '0:N_arr_val'}) + me.add_in_connector('N_arr_val') + me.add_in_connector('IN_A') + me.add_out_connector('OUT_A') + N_read = s.add_read('N_arr') + s.add_edge(N_read, None, me, 'N_arr_val', dace.Memlet('N_arr[0]')) + A_read = s.add_read('A') + s.add_edge(A_read, None, me, 'IN_A', dace.Memlet('A[0:100]')) + + t = s.add_tasklet('t', {'a'}, {'o'}, 'o = a') + s.add_edge(me, 'OUT_A', t, 'a', dace.Memlet('A[i]')) + mx.add_in_connector('IN_o') + mx.add_out_connector('OUT_o') + s.add_edge(t, 'o', mx, 'IN_o', dace.Memlet('out[i]')) + out_w = s.add_write('out') + s.add_edge(mx, 'OUT_o', out_w, None, dace.Memlet('out[0:100]')) + + syms = s.symbols_defined_at(t) + assert 'i' in syms, 'Map iter var must be visible' + assert 'N_arr_val' in syms, ('non-pass-through Map in-connector must be reported as defined inside the scope; ' + 'without this, memlets that reference the dynamic-range parameter would widen ' + 'to the array extent when propagated outward.') + + +def test_loop_region_and_map_combined_visible(): + """Stacked enclosing scopes (``LoopRegion`` outside, ``Map`` inside): a + node inside the Map sees BOTH the LoopRegion's loop variable and the + Map's iter variable. This is the cloudsc-style stack -- outer scan loop + over levels, inner Map over columns.""" + sdfg = dace.SDFG('stack') + sdfg.add_array('A', [10, 20], dace.float64) + loop = LoopRegion('lk', 'jk < 10', 'jk', 'jk = 0', 'jk = jk + 1') + sdfg.add_node(loop) + s = loop.add_state('s', is_start_block=True) + me, mx = s.add_map('m', {'jl': '0:20'}) + me.add_in_connector('IN_A') + me.add_out_connector('OUT_A') + t = s.add_tasklet('t', {'a'}, {'o'}, 'o = a') + A_r = s.add_read('A') + A_w = s.add_write('A') + s.add_edge(A_r, None, me, 'IN_A', dace.Memlet('A[0:10, 0:20]')) + s.add_edge(me, 'OUT_A', t, 'a', dace.Memlet('A[jk, jl]')) + mx.add_in_connector('IN_A') + mx.add_out_connector('OUT_A') + s.add_edge(t, 'o', mx, 'IN_A', dace.Memlet('A[jk, jl]')) + s.add_edge(mx, 'OUT_A', A_w, None, dace.Memlet('A[0:10, 0:20]')) + + syms = s.symbols_defined_at(t) + assert 'jk' in syms, 'enclosing LoopRegion loop variable visible across Map scope' + assert 'jl' in syms, 'Map iter variable visible inside its own scope' + + +def test_nsdfg_inside_map_inside_loop_region_propagation_endpoint(): + """The exact contract that ``propagate_memlets_nested_sdfg`` relies on: + the symbols available at the ``NestedSDFG`` node include the enclosing + ``LoopRegion``'s loop variable. Without ``jk`` reported here, any subset + of the form ``arr[jk, ...]`` inside the nested SDFG widens to the array + extent on propagation out -- the cloudsc ``for_1133`` failure mode. + """ + sdfg = dace.SDFG('cloudsc_shape') + sdfg.add_symbol('K', dace.int32) + sdfg.add_array('A', [10, 20], dace.float64) + + loop = LoopRegion('lk', 'jk < K', 'jk', 'jk = 0', 'jk = jk + 1') + sdfg.add_node(loop) + s = loop.add_state('s', is_start_block=True) + me, mx = s.add_map('m', {'jl': '0:20'}) + me.add_in_connector('IN_A') + me.add_out_connector('OUT_A') + mx.add_in_connector('IN_A') + mx.add_out_connector('OUT_A') + + inner = dace.SDFG('inner') + inner.add_symbol('jk', dace.int32) + inner.add_symbol('jl', dace.int32) + inner.add_array('A', [10, 20], dace.float64) + si = inner.add_state('si') + t_inner = si.add_tasklet('t', {'a'}, {'o'}, 'o = a') + A_ri = si.add_read('A') + A_wi = si.add_write('A') + si.add_edge(A_ri, None, t_inner, 'a', dace.Memlet('A[jk, jl]')) + si.add_edge(t_inner, 'o', A_wi, None, dace.Memlet('A[jk, jl]')) + + nsdfg = s.add_nested_sdfg(inner, inputs={'A'}, outputs={'A'}, symbol_mapping={'jk': 'jk', 'jl': 'jl'}) + A_r = s.add_read('A') + A_w = s.add_write('A') + s.add_edge(A_r, None, me, 'IN_A', dace.Memlet('A[0:10, 0:20]')) + s.add_edge(me, 'OUT_A', nsdfg, 'A', dace.Memlet('A[0:10, 0:20]')) + s.add_edge(nsdfg, 'A', mx, 'IN_A', dace.Memlet('A[0:10, 0:20]')) + s.add_edge(mx, 'OUT_A', A_w, None, dace.Memlet('A[0:10, 0:20]')) + + syms = s.symbols_defined_at(nsdfg) + assert 'jk' in syms, 'enclosing LoopRegion var jk must be defined at the NSDFG node' + assert 'jl' in syms, 'enclosing Map iter var jl must be defined at the NSDFG node' + assert 'K' in syms, 'SDFG-global symbol K must be defined' + + +def test_outside_any_scope_is_empty_of_local_scope_symbols(): + """At a state node that sits OUTSIDE any Map scope and outside any + LoopRegion, no scope-local symbols are reported (only SDFG-global ones).""" + sdfg = dace.SDFG('flat') + sdfg.add_symbol('K', dace.int32) + sdfg.add_array('A', [10], dace.float64) + s = sdfg.add_state('s') + t = s.add_tasklet('t', {}, {'o'}, 'o = K') + w = s.add_write('A') + s.add_edge(t, 'o', w, None, dace.Memlet('A[0]')) + + syms = s.symbols_defined_at(t) + assert 'K' in syms + # No scope-local names from a hypothetical outer scope. + assert 'i' not in syms and 'j' not in syms + + +if __name__ == '__main__': + import sys + sys.exit(pytest.main([__file__, '-v'])) diff --git a/tests/sdfg/validation/subset_size_test.py b/tests/sdfg/validation/subset_size_test.py index 7c63f578b0..cce8a8a867 100644 --- a/tests/sdfg/validation/subset_size_test.py +++ b/tests/sdfg/validation/subset_size_test.py @@ -84,6 +84,36 @@ def test_an_to_an_memlet_with_negative_size(): sdfg.validate() +def test_veclen_lookup_guarded_on_non_accessnode_endpoint(): + """``validate_state``'s dimensionality-mismatch check used to dereference + ``sdfg.arrays[src_node.data].veclen`` unconditionally; scope nodes + (NestedSDFG / MapEntry / MapExit / ConsumeEntry / ConsumeExit) don't expose + ``.data`` and crashed with ``AttributeError`` whenever the edge carried + both ``src_subset`` and ``dst_subset`` (``other_subset is not None``). + + Build a NestedSDFG-output -> AccessNode edge with a reshape memlet (which + sets the other-subset) and assert validation reaches a verdict instead of + blowing up. Pre-fix this raised ``AttributeError: 'NestedSDFG' object has + no attribute 'data'``.""" + sdfg = dace.SDFG('veclen_guard_nsdfg_to_an') + sdfg.add_array('Y', [10], dace.float64) + state = sdfg.add_state('main', is_start_block=True) + + inner = dace.SDFG('inner') + inner.add_array('out_inner', [5], dace.float64) + istate = inner.add_state('s', is_start_block=True) + t = istate.add_tasklet('w', set(), {'_o'}, '_o = 1.0') + iw = istate.add_write('out_inner') + istate.add_edge(t, '_o', iw, None, dace.Memlet('out_inner[0]')) + + nsdfg = state.add_nested_sdfg(inner, {}, {'out_inner'}) + y = state.add_write('Y') + state.add_edge(nsdfg, 'out_inner', y, None, dace.Memlet('Y[0:5] -> [0:5]')) + + sdfg.validate() + + if __name__ == "__main__": test_an_to_an_memlet_with_zero_size() test_an_to_an_memlet_with_negative_size() + test_veclen_lookup_guarded_on_non_accessnode_endpoint() diff --git a/tests/symbolic_print_test.py b/tests/symbolic_print_test.py new file mode 100644 index 0000000000..66c88c97e5 --- /dev/null +++ b/tests/symbolic_print_test.py @@ -0,0 +1,149 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. +import pytest +import sympy +from dace.symbolic import sympy_numeric_fix, pystr_to_symbolic, symstr + + +def test_float_zero_stays_float(): + """sympy.Float(0.0) must not be demoted to int(0).""" + result = sympy_numeric_fix(sympy.Float(0.0)) + assert isinstance(result, sympy.Float), \ + f"Float(0.0) demoted to {type(result).__name__}" + assert float(result) == 0.0 + + +def test_float_one_stays_float(): + """sympy.Float(1.0) must not be demoted to int(1).""" + result = sympy_numeric_fix(sympy.Float(1.0)) + assert isinstance(result, sympy.Float) + assert float(result) == 1.0 + + +def test_float_five_stays_float(): + """sympy.Float(5.0) must not be demoted to int(5).""" + result = sympy_numeric_fix(sympy.Float(5.0)) + assert isinstance(result, sympy.Float) + assert float(result) == 5.0 + + +def test_fractional_float_preserved(): + """sympy.Float(0.7) must stay Float(0.7).""" + result = sympy_numeric_fix(sympy.Float(0.7)) + assert isinstance(result, sympy.Float) + assert abs(float(result) - 0.7) < 1e-15 + + +def test_float_prints_clean(): + """5.0 should print as '5.0', not '5.00000000000000'.""" + result = sympy_numeric_fix(sympy.Float(5.0)) + s = symstr(result) + assert s == '5.0', f"Expected '5.0', got '{s}'" + + +def test_huge_python_int_becomes_oo(): + """Python int beyond float64 range must map to sympy.oo. + Original comment: int(1.8e308) == expr is True because Python + has variable-bit integers, but numpy.float64() overflows.""" + result = sympy_numeric_fix(10**309) + assert result == sympy.oo + + +def test_huge_negative_python_int_becomes_neg_oo(): + """Negative Python int beyond float64 range must map to -sympy.oo.""" + result = sympy_numeric_fix(-(10**309)) + assert result == -sympy.oo + + +def test_max_float_literal_roundtrip(): + """Parsing 'max(a, 0.0)' and printing via symstr must preserve '0.0', + not demote to '0'. Demotion causes Max(a, 0) in C++ codegen.""" + expr = pystr_to_symbolic("max(a, 0.0)") + result = symstr(expr, cpp_mode=True) + + # Must contain 0.0 (the float literal), not bare 0 (int literal) + assert "0.0" in result, f"Float literal 0.0 was demoted to int: '{result}'. " + + +def test_max_float_literal_not_int(): + """Complement: ensure the printed string does NOT match 'Max(a, 0)' + where 0 is an integer literal (no decimal point).""" + expr = pystr_to_symbolic("max(a, 0.0)") + result = symstr(expr, cpp_mode=True) + + # Strip spaces for robust matching + clean = result.replace(" ", "") + # Should not end with ,0) + assert not clean.endswith(",0)"), f"Got integer literal in Max call: '{result}'" + + +def test_max_int_literal_stays_int(): + """Parsing 'max(a, 0)' with an explicit int should keep it as int. + This is the correct behavior when the user wrote 0, not 0.0.""" + expr = pystr_to_symbolic("max(a, 0)") + result = symstr(expr, cpp_mode=True) + + # This one SHOULD have bare 0, not 0.0 + clean = result.replace(" ", "") + assert "0.0" not in clean, f"Integer literal 0 was promoted to float: '{result}'" + + +def test_cpp_floor_of_fraction_difference_recombines_to_integer_division(): + """sympy normalises ``(LEN - 1) // 8`` on integer symbols to + ``floor(LEN/8 - 1/8)``. The C++ printer must recombine the common-denominator + sum and emit a single ``((LEN - 1) / (8))`` integer division -- not + ``floor(LEN/8 - 1/8)`` (which in C++ collapses ``1/8`` to ``0`` and + overshoots the loop bound) and not the literal ``LEN/8 - 1/8`` string.""" + from dace.symbolic import DaceSympyPrinter + LEN = sympy.Symbol('LEN', integer=True) + expr = (LEN - 1) // 8 + out = DaceSympyPrinter(arrays={}, cpp_mode=True).doprint(expr) + + clean = out.replace(' ', '') + assert 'floor' not in clean, f'C++ printer must not emit floor(...); got {out!r}' + assert '1/8' not in clean, f'literal Rational 1/8 leaked into C++ output: {out!r}' + assert 'LEN-1' in clean, f'expected combined numerator (LEN - 1); got {out!r}' + + +@pytest.mark.parametrize( + 'value', + [ + 1.0 / 21.0, # FFT ifft factor; needs 17 sig digits to round-trip + 0.1 + 0.2, + 1e-300, + 1e300, + 1.7976931348623157e308, # near DBL_MAX + 1.0, + 5.0, + 3.14, + -0.0476190476190476, + 1234567890.1234567, + ]) +def test_format_float_is_idempotent_under_parse_and_reformat(value): + """The float-to-string serializer must be idempotent: ``f -> str -> f -> str`` + yields the same string as ``f -> str``. Otherwise SDFG save -> load -> save + fails the round-trip equality check the framework runs on every serialization + (e.g. ``tests/library/fft_test.py::test_ifft[backward]`` regressed because + ``1/21`` was emitted as 17 digits in one save and 15 in the next).""" + from dace.symbolic import _format_float + s1 = _format_float(value) + s2 = _format_float(float(s1)) + assert s1 == s2, f'_format_float not idempotent: {value!r} -> {s1!r} -> {s2!r}' + assert float(s1) == float(value), (f'_format_float loses precision for {value!r}: parsed back as {float(s1)!r}') + + +@pytest.mark.parametrize('value', [1.0 / 21.0, 0.1 + 0.2, 1e-300, 1e300, 5.0, 3.14]) +def test_serialize_symbolic_float_path_is_idempotent(value): + """``serialize_symbolic`` dispatches on type: ``isinstance(expr, float)`` is a + distinct branch from ``isinstance(expr, sympy.Basic)``. Both must produce the + SAME 17-sig-digit shortest-round-trip form -- otherwise a SymbolicProperty + that was set as a Python float gets a 15-digit string on save 1 (sympy's + default sstr) and a 17-digit string on save 2 (DaceSympySerializer), breaking + the SDFG save -> load -> save equality check (e.g. FFT/IFFT + ``factor = 1/21`` regressed this way). + """ + from dace.symbolic import serialize_symbolic, deserialize_symbolic + s1 = serialize_symbolic(value) + loaded = deserialize_symbolic(s1) + s2 = serialize_symbolic(loaded) + assert s1 == s2, (f'serialize_symbolic not idempotent across the float/sympy.Basic branches: ' + f'save 1 (float)={s1!r}, save 2 (sympy.Float)={s2!r}') diff --git a/tests/symbolic_roundtrip_test.py b/tests/symbolic_roundtrip_test.py index 0458ef92ca..d5ac9a8312 100644 --- a/tests/symbolic_roundtrip_test.py +++ b/tests/symbolic_roundtrip_test.py @@ -182,6 +182,25 @@ def test_float_precision_preserved(): assert pystr_to_symbolic(huge) not in (sympy.oo, -sympy.oo) +def test_integer_valued_float_not_collapsed(): + # An integer-valued float (e.g. the ``1.0`` clamp in ``min(x, 1.0)``) keeps its + # ``.0`` and never collapses to an int: collapsing would mix int and float in a + # Min/Max and silently truncate the result after a serialization round-trip. + for src in ('0.0', '1.0', '2.0', '5.0', '100.0'): + assert _roundtrip(src) == src + # Genuine integers are left untouched. + for src in ('0', '1', '2', '42'): + assert _roundtrip(src) == src + # ``sympy_numeric_fix`` preserves a Python/numpy float as a sympy Float (the int + # collapse only ever applied to non-float inputs). + assert isinstance(symbolic.sympy_numeric_fix(1.0), sympy.Float) + assert isinstance(symbolic.sympy_numeric_fix(2.0), sympy.Float) + assert isinstance(symbolic.sympy_numeric_fix(1), int) + # Inside a Min the float clamp survives the parse -> print round-trip. + assert _idempotent('Min(x, 1.0)') + assert '1.0' in _roundtrip('Min(x, 1.0)') + + def test_infinity_roundtrip(): assert pystr_to_symbolic('inf') == sympy.oo assert pystr_to_symbolic('-inf') == -sympy.oo diff --git a/tests/transformations/interstate/loop_to_map_test.py b/tests/transformations/interstate/loop_to_map_test.py index 27f90c55c6..675902c632 100644 --- a/tests/transformations/interstate/loop_to_map_test.py +++ b/tests/transformations/interstate/loop_to_map_test.py @@ -940,6 +940,110 @@ def test_dynamic_write_slab_separated_by_iteration_var(): assert not any(isinstance(n, LoopRegion) for n, _ in sdfg.all_nodes_recursive()) +def test_loop_to_map_with_loop_invariant_if(): + """A loop whose body is guarded by a loop-invariant condition is + parallelizable: ``for i: if c: b[i]=a[i]+1`` -> one map, value-preserving + for the guard taken and not-taken. (The guard becomes a per-iteration + ``NestedSDFG``-wrapped conditional inside the map.)""" + N = dace.symbol('N') + + @dace.program + def loop_invariant_if(a: dace.float64[N], b: dace.float64[N], c: dace.int32[1]): + for i in range(N): + if c[0] > 0: + b[i] = a[i] + 1.0 + + for cv in (1, 0): + sdfg = loop_invariant_if.to_sdfg(simplify=True) + assert sdfg.apply_transformations_repeated(LoopToMap) == 1 + sdfg.validate() + assert not any(isinstance(n, LoopRegion) for n, _ in sdfg.all_nodes_recursive()) + n = 12 + a = np.random.rand(n) + b = np.zeros(n) + sdfg(a=a.copy(), b=b, c=np.array([cv], np.int32), N=n) + assert np.allclose(b, a + 1.0 if cv > 0 else 0.0), f"mismatch c={cv}" + + +def test_loop_to_map_round_trip_through_nested_sdfg_recovers_map(): + """Parallelizing a loop-invariant-guarded loop, de-parallelizing it, then + re-parallelizing recovers the map. The ``LoopToMap->MapToForLoop`` round- + trip propagates a whole-array external write memlet (``b[i]`` -> ``b[0:N]``) + on the ``NestedSDFG`` the guard forces; the write-pattern check now looks + *past* the connector at the inner per-iteration write + (:func:`_nested_writes_iter_indexed`), so independence is still proven.""" + from dace.transformation.dataflow.map_for_loop import MapToForLoop + from dace.transformation.passes.pattern_matching import PatternMatchAndApplyRepeated + N = dace.symbol('N') + + @dace.program + def loop_invariant_if(a: dace.float64[N], b: dace.float64[N], c: dace.int32[1]): + for i in range(N): + if c[0] > 0: + b[i] = a[i] + 1.0 + + sdfg = loop_invariant_if.to_sdfg(simplify=True) + assert sdfg.apply_transformations_repeated(LoopToMap) == 1, "first LoopToMap must fire" + PatternMatchAndApplyRepeated([MapToForLoop()]).apply_pass(sdfg, {}) + sdfg.validate() + assert sdfg.apply_transformations_repeated(LoopToMap) == 1, \ + "re-parallelize must recover the map after the round-trip" + + +def test_refuse_when_body_assigns_loop_range_symbol(): + """A loop whose range expression reads a symbol that the loop body + re-assigns via an interstate-edge assignment must NOT be converted to + a Map: the assignment would move into the new ``loop_body`` NestedSDFG + with the body, but the Map's range stays at the outer scope -- so the + range ends up referencing a symbol defined only inside the new NSDFG + (``Missing symbols on nested SDFG: ['KP1']`` at validation time). + + This pattern shows up in the canonicalized cloudsc SDFG (the + ``kfdia_plus_1_N = kfdia + 1`` interstate-edge assignment ends up + inside a loop whose condition reads ``kfdia_plus_1_N``). The check is + a structural-soundness gate -- a strictly additive ``return False`` + that leaves the loop as a ``LoopRegion`` (sequential codegen still + handles it cleanly). + """ + from dace.memlet import Memlet + + sdfg = dace.SDFG("refuse_body_assigns_loop_range_symbol") + sdfg.add_symbol("K", dace.int32) + sdfg.add_symbol("KP1", dace.int32) + sdfg.add_array("a", (dace.symbol("K"), ), dace.float32) + + init = sdfg.add_state("init", is_start_block=True) + # ``KP1`` defined before the loop -- the body's interstate-edge + # ``KP1 = K + 1`` is a redundant re-assignment that nonetheless + # creates the bad-pattern shape. + sdfg.add_edge(init, init, dace.InterstateEdge(assignments={})) # placeholder no-op + + loop = LoopRegion("for_j", condition_expr="j < KP1", loop_var="j", initialize_expr="j = 0", update_expr="j = j + 1") + sdfg.add_node(loop) + sdfg.add_edge(init, loop, dace.InterstateEdge(assignments={"KP1": "K + 1"})) + + # Body: a single state with a write to ``a[j]``. Critically, an + # interstate edge inside the loop reassigns ``KP1`` (loop-invariant + # value, but structurally placed inside the loop body). + body = loop.add_state("body", is_start_block=True) + after = loop.add_state("after") + loop.add_edge(body, after, dace.InterstateEdge(assignments={"KP1": "K + 1"})) + t = body.add_tasklet("t", {}, {"o"}, "o = 1.0") + w = body.add_write("a") + body.add_edge(t, "o", w, None, Memlet("a[j]")) + + sdfg.validate() + + # The loop's range reads ``KP1``; the body has an interstate edge + # assigning ``KP1``. ``LoopToMap.can_be_applied`` must refuse. + applied = sdfg.apply_transformations_repeated(LoopToMap) + assert applied == 0, "LoopToMap must refuse to convert a loop whose range reads a body-assigned symbol" + + # The loop remains as a LoopRegion; SDFG stays valid. + assert any(isinstance(c, LoopRegion) for c in sdfg.all_control_flow_regions(recursive=True)) + sdfg.validate() + + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -982,3 +1086,4 @@ def test_dynamic_write_slab_separated_by_iteration_var(): test_nested_sdfg_nested_loop() test_stride_symbol_propagated_to_nested_sdfg() test_dynamic_write_slab_separated_by_iteration_var() + test_refuse_when_body_assigns_loop_range_symbol() diff --git a/tests/transformations/loop_to_map_disjoint_writes_test.py b/tests/transformations/loop_to_map_disjoint_writes_test.py new file mode 100644 index 0000000000..a75531b451 --- /dev/null +++ b/tests/transformations/loop_to_map_disjoint_writes_test.py @@ -0,0 +1,306 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +""" Tests that ``LoopToMap`` parallelizes a loop only when its writes are + provably non-overlapping across iterations. + + Each iteration of a loop must write disjoint locations for the loop to be a + valid map. Two affine write subscripts ``a1*i + b1`` and ``a2*i + b2`` into + the same container collide on some pair of iterations if and only if + ``gcd(a1, a2)`` divides ``b2 - b1``; otherwise they are provably disjoint + for any iteration range. +""" +import copy + +import numpy as np +import pytest + +import dace +from dace.sdfg import nodes +from dace.transformation.interstate import LoopToMap + +N = dace.symbol('N') +M = dace.symbol('M') + + +def _has_map(sdfg: dace.SDFG) -> bool: + """ True if any state in the (recursively expanded) SDFG contains a Map. """ + return any(isinstance(n, nodes.MapEntry) for n, _ in sdfg.all_nodes_recursive()) + + +@dace.program +def overlapping_writes(A: dace.int64[5 * N]): + for i in range(N): + A[5 * i] = 1 + A[3 * i] = 2 + + +@dace.program +def injective_write(A: dace.int64[2 * N]): + for i in range(N): + A[2 * i] = i + + +@dace.program +def disjoint_stride_writes(A: dace.int64[2 * N]): + for i in range(N): + A[2 * i] = 1 + A[2 * i + 1] = 2 + + +@dace.program +def shifted_writes(A: dace.int64[N + 1]): + for i in range(N): + A[i] = 1 + A[i + 1] = 2 + + +@dace.program +def disjoint_outer_dim(B: dace.int64[2 * N, 4]): + for i in range(N): + B[2 * i, :] = 1 + B[2 * i + 1, :] = 2 + + +def _applies(program) -> int: + sdfg = program.to_sdfg(simplify=False) + return sdfg.apply_transformations_repeated(LoopToMap) + + +def test_rejects_overlapping_writes(): + """ ``A[5*i]`` and ``A[3*i]`` collide at ``A[15]`` (i=3 and i=5). """ + sdfg = overlapping_writes.to_sdfg(simplify=False) + assert sdfg.apply_transformations_repeated(LoopToMap) == 0 + + n = 64 + a = np.full(5 * n, -1, dtype=np.int64) + sdfg(A=a, N=n) + ref = np.full(5 * n, -1, dtype=np.int64) + for i in range(n): + ref[5 * i] = 1 + ref[3 * i] = 2 + assert np.array_equal(a, ref) + + +def test_accepts_injective_write(): + """ A single ``a*i + b`` write is injective in ``i`` and parallelizable. """ + sdfg = injective_write.to_sdfg(simplify=False) + assert sdfg.apply_transformations_repeated(LoopToMap) >= 1 + + n = 64 + a = np.full(2 * n, -1, dtype=np.int64) + sdfg(A=a, N=n) + ref = np.full(2 * n, -1, dtype=np.int64) + for i in range(n): + ref[2 * i] = i + assert np.array_equal(a, ref) + + +def test_accepts_disjoint_strides(): + """ ``A[2*i]`` (even) and ``A[2*i+1]`` (odd) never collide, for any range. """ + sdfg = disjoint_stride_writes.to_sdfg(simplify=False) + assert sdfg.apply_transformations_repeated(LoopToMap) >= 1 + + n = 64 + a = np.full(2 * n, -1, dtype=np.int64) + sdfg(A=a, N=n) + ref = np.full(2 * n, -1, dtype=np.int64) + for i in range(n): + ref[2 * i] = 1 + ref[2 * i + 1] = 2 + assert np.array_equal(a, ref) + + +def test_rejects_shifted_writes(): + """ ``A[i]`` and ``A[i+1]`` collide between consecutive iterations. """ + assert _applies(shifted_writes) == 0 + + +def test_accepts_disjoint_outer_dimension(): + """ A provably disjoint leading dimension makes the whole access disjoint. """ + sdfg = disjoint_outer_dim.to_sdfg(simplify=False) + assert sdfg.apply_transformations_repeated(LoopToMap) >= 1 + + n = 32 + b = np.full((2 * n, 4), -1, dtype=np.int64) + sdfg(B=b, N=n) + ref = np.full((2 * n, 4), -1, dtype=np.int64) + for i in range(n): + ref[2 * i, :] = 1 + ref[2 * i + 1, :] = 2 + assert np.array_equal(b, ref) + + +# --------------------------------------------------------------------------- +# Indirect / nonlinear subscripts must NOT be certified disjoint by the new +# affine fast path. ``idx`` is not known to be a permutation, and ``i*i`` / +# ``i % k`` fall outside the affine ``a*i + b`` model, so the loop must stay +# sequential (no Map) regardless of the gcd-disjointness reasoning. +# --------------------------------------------------------------------------- + + +@dace.program +def indirect_write_vs_affine(A: dace.int64[5 * N], idx: dace.int64[N]): + for i in range(N): + A[idx[i]] = 1 + A[5 * i] = 2 + + +@dace.program +def indirect_read_vs_affine(A: dace.int64[3 * N], idx: dace.int64[N], out: dace.int64[N]): + for i in range(N): + out[i] = A[idx[i]] + A[3 * i] = i + + +@dace.program +def nonlinear_square_write(A: dace.int64[N * N]): + for i in range(N): + A[i * i] = 1 + A[i] = 2 + + +@dace.program +def nonlinear_mod_write(A: dace.int64[N]): + for i in range(N): + A[i % 4] = 1 + A[i] = 2 + + +def test_rejects_indirect_write_vs_affine(): + """ ``A[idx[i]]`` could equal ``A[5*i]``; ``idx`` is not a known + permutation, so this is a possible cross-iteration write-write + dependence and the loop must stay sequential. """ + sdfg = indirect_write_vs_affine.to_sdfg(simplify=False) + ref_sdfg = copy.deepcopy(sdfg) + assert sdfg.apply_transformations_repeated(LoopToMap) == 0 + assert not _has_map(sdfg) + + n = 32 + # Adversarial index: many entries alias the affine ``5*i`` targets. + idx = (np.arange(n) % 5).astype(np.int64) + + a = np.full(5 * n, -1, dtype=np.int64) + a_ref = a.copy() + sdfg(A=a, idx=idx, N=n) + ref_sdfg(A=a_ref, idx=idx.copy(), N=n) + assert np.array_equal(a, a_ref) + + +def test_rejects_indirect_read_vs_affine(): + """ ``... = A[idx[i]]`` reads while ``A[3*i]`` writes the same container; + a RAW/WAR hazard the affine model cannot rule out. """ + sdfg = indirect_read_vs_affine.to_sdfg(simplify=False) + ref_sdfg = copy.deepcopy(sdfg) + assert sdfg.apply_transformations_repeated(LoopToMap) == 0 + assert not _has_map(sdfg) + + n = 24 + idx = (np.arange(n) % 3 * 3).astype(np.int64) + A = np.arange(3 * n, dtype=np.int64) + A_ref = A.copy() + out = np.full(n, -1, dtype=np.int64) + out_ref = out.copy() + sdfg(A=A, idx=idx, out=out, N=n) + ref_sdfg(A=A_ref, idx=idx.copy(), out=out_ref, N=n) + assert np.array_equal(out, out_ref) + assert np.array_equal(A, A_ref) + + +def test_rejects_nonlinear_square_write(): + """ ``A[i*i]`` is nonlinear in ``i`` and outside the affine model; with a + second write to ``A`` the loop must stay sequential. """ + sdfg = nonlinear_square_write.to_sdfg(simplify=False) + ref_sdfg = copy.deepcopy(sdfg) + assert sdfg.apply_transformations_repeated(LoopToMap) == 0 + assert not _has_map(sdfg) + + n = 16 + a = np.full(n * n, -1, dtype=np.int64) + a_ref = a.copy() + sdfg(A=a, N=n) + ref_sdfg(A=a_ref, N=n) + assert np.array_equal(a, a_ref) + + +def test_rejects_nonlinear_mod_write(): + """ ``A[i % 4]`` is nonlinear in ``i``; with a second write to ``A`` the + loop must stay sequential. """ + sdfg = nonlinear_mod_write.to_sdfg(simplify=False) + ref_sdfg = copy.deepcopy(sdfg) + assert sdfg.apply_transformations_repeated(LoopToMap) == 0 + assert not _has_map(sdfg) + + n = 16 + a = np.full(n, -1, dtype=np.int64) + a_ref = a.copy() + sdfg(A=a, N=n) + ref_sdfg(A=a_ref, N=n) + assert np.array_equal(a, a_ref) + + +# --------------------------------------------------------------------------- +# A dimension that both writes index by the same injective function of the loop +# variable pins any collision to a single iteration, so the writes are disjoint +# across iterations even when the *other* indices are opaque symbols the affine +# model cannot certify. This is the CloudSC scatter pattern +# ``zsolqa[0, imelt, i]`` / ``zsolqa[imelt, 0, i]`` (``i`` the parallel column). +# --------------------------------------------------------------------------- + + +@dace.program +def shared_iteration_dim(A: dace.int64[8, 8, N]): + for i in range(N): + A[0, M, i] = 1 + A[M, 0, i] = 2 + + +@dace.program +def shared_constant_dim_shifted(A: dace.int64[8, N + 1]): + for i in range(N): + A[M, i] = 1 + A[M, i + 1] = 2 + + +def test_accepts_shared_iteration_dimension(): + """ ``A[0, M, i]`` and ``A[M, 0, i]`` share the loop variable ``i`` in their + last dimension, so each iteration owns column ``i`` and they never + collide across iterations -- parallelizable despite the opaque ``M``. """ + sdfg = shared_iteration_dim.to_sdfg(simplify=False) + assert sdfg.apply_transformations_repeated(LoopToMap) >= 1 + assert _has_map(sdfg) + + n, m = 16, 3 + a = np.full((8, 8, n), -1, dtype=np.int64) + sdfg(A=a, N=n, M=m) + ref = np.full((8, 8, n), -1, dtype=np.int64) + for i in range(n): + ref[0, m, i] = 1 + ref[m, 0, i] = 2 + assert np.array_equal(a, ref) + + +def test_rejects_shared_constant_dimension_with_shift(): + """ Guard: the shared dimension here is the *constant* ``M`` (no dependence + on ``i``), so it does not pin iterations together; ``A[M, i]`` and + ``A[M, i+1]`` still collide between consecutive iterations. """ + assert _applies(shared_constant_dim_shifted) == 0 + + +def test_positive_control_disjoint_strides_becomes_map(): + """ Positive control: ``A[2*i]`` / ``A[2*i+1]`` are affine and + gcd-disjoint, so the new fast path SHOULD parallelize (Map present). """ + sdfg = disjoint_stride_writes.to_sdfg(simplify=False) + assert sdfg.apply_transformations_repeated(LoopToMap) >= 1 + assert _has_map(sdfg) + + n = 64 + a = np.full(2 * n, -1, dtype=np.int64) + sdfg(A=a, N=n) + ref = np.full(2 * n, -1, dtype=np.int64) + for i in range(n): + ref[2 * i] = 1 + ref[2 * i + 1] = 2 + assert np.array_equal(a, ref) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/transformations/map_fusion_vertical_test.py b/tests/transformations/map_fusion_vertical_test.py index b8776562fb..e6c4a34f64 100644 --- a/tests/transformations/map_fusion_vertical_test.py +++ b/tests/transformations/map_fusion_vertical_test.py @@ -3214,6 +3214,109 @@ def test_map_fusion_stable_label(forward_fusion: bool): assert all(np.allclose(ref[k], res[k]) for k in ref) +def test_map_fusion_inout_connector_intermediate_rename_consistency(): + """Pins the InOut-connector structural split that ``MapFusionVertical`` + performs when the fusion intermediate's data name matches an InOut + connector of a NestedSDFG inside the producer map's body. Without the + split, the rename ``inter_name -> __map_fusion_`` would + produce a NestedSDFG whose InOut connector references the original + outer data on the input edge and the renamed transient on the output + edge -- a validation error (`Inout connector X is connected to + different input ({'X'}) and output ({'__map_fusion_X'}) arrays`). + + Reproducer minimised from TSVC s221:: + + for i in range(1, N): + a[i] = a[i] + c[i] * d[i] # NestedSDFG InOut on 'a' (in-place RMW) + b[i] = b[i-1] + a[i] + d[i] # consumer of 'a', emits a Scan + apply map + + The fix in ``MapFusionVertical.apply`` runs ``_split_inout_for_intermediate`` + before the standard rename machinery: it allocates a fresh inner array + ``__map_fusion_split_`` inside the NestedSDFG, renames every inner + read-side AccessNode of ```` to the fresh name, adds a new input + connector with the fresh name, and redirects the outer input edge to it. + The NestedSDFG's ```` connector becomes output-only; the standard + rename machinery can then rename it to ``__map_fusion_`` cleanly. + + Test asserts (post-fix): fusion APPLIES, validation passes, the + NestedSDFG's InOut overlap is empty, and the numerical outputs match + the pre-fuse reference oracle bit-exact. + """ + sdfg = dace.SDFG('mf_inout_repro') + sdfg.add_array('a', [10], dace.float64, transient=False) + sdfg.add_array('b', [10], dace.float64, transient=False) + sdfg.add_array('c', [10], dace.float64, transient=False) + + state = sdfg.add_state('main') + + # Map 1's body: an InOut-on-`a` NestedSDFG performing `a[i] = a[i] + c[i]`. + inner = dace.SDFG('inner') + inner.add_array('a', [1], dace.float64, transient=False) + inner.add_array('c', [1], dace.float64, transient=False) + ist = inner.add_state('istate') + a_r = ist.add_read('a') + c_r = ist.add_read('c') + a_w = ist.add_write('a') + t = ist.add_tasklet('upd', {'_a', '_c'}, {'_o'}, '_o = _a + _c') + ist.add_edge(a_r, None, t, '_a', dace.Memlet(data='a', subset='0')) + ist.add_edge(c_r, None, t, '_c', dace.Memlet(data='c', subset='0')) + ist.add_edge(t, '_o', a_w, None, dace.Memlet(data='a', subset='0')) + + me1, mx1 = state.add_map('m1', {'i': '0:10'}) + nsdfg = state.add_nested_sdfg(inner, inputs={'a', 'c'}, outputs={'a'}) # 'a' is InOut + + a_src = state.add_read('a') + c_src = state.add_read('c') + a_inter = state.add_access('a') # producer output, consumer input + + state.add_memlet_path(a_src, me1, nsdfg, dst_conn='a', memlet=dace.Memlet(data='a', subset='i')) + state.add_memlet_path(c_src, me1, nsdfg, dst_conn='c', memlet=dace.Memlet(data='c', subset='i')) + state.add_memlet_path(nsdfg, mx1, a_inter, src_conn='a', memlet=dace.Memlet(data='a', subset='i')) + + # Map 2: simple consumer `b[i] = a[i] + 1`. + me2, mx2 = state.add_map('m2', {'i': '0:10'}) + t2 = state.add_tasklet('cons', {'_a'}, {'_b'}, '_b = _a + 1.0') + b_dst = state.add_write('b') + state.add_memlet_path(a_inter, me2, t2, dst_conn='_a', memlet=dace.Memlet(data='a', subset='i')) + state.add_memlet_path(t2, mx2, b_dst, src_conn='_b', memlet=dace.Memlet(data='b', subset='i')) + + sdfg.validate() + + # Numeric oracle: capture pre-fuse output, then optionally re-run post-fuse. + rng = np.random.default_rng(221) + a_in = rng.standard_normal(10) + c_in = rng.standard_normal(10) + ref = {'a': a_in.copy(), 'b': np.zeros(10), 'c': c_in.copy()} + res = {'a': a_in.copy(), 'b': np.zeros(10), 'c': c_in.copy()} + sdfg_ref = copy.deepcopy(sdfg) + sdfg_ref(**ref) + + # Apply MapFusionVertical (the actual production pass) and ensure validation holds. + from dace.transformation.passes.pattern_matching import PatternMatchAndApplyRepeated + res_apply = PatternMatchAndApplyRepeated([MapFusionVertical()]).apply_pass(sdfg, {}) + sdfg.validate() + + # The structural split must let the fusion APPLY (refusal was the pre-fix + # behavior; pinning the split here prevents a future regression that + # reintroduces a refusal-only path on this shape). + assert res_apply is not None and 'MapFusionVertical' in res_apply, ( + 'MapFusionVertical must apply on the InOut shape via the structural split, ' + f'not refuse; got result={res_apply!r}.') + + # After the split + fusion the NestedSDFG's InOut overlap must be empty + # (the connector that was InOut is now output-only). + nsdfgs = [n for st in sdfg.states() for n in st.nodes() if isinstance(n, nodes.NestedSDFG)] + assert nsdfgs, 'expected at least one NestedSDFG to survive the fusion' + for n in nsdfgs: + inout = set(n.in_connectors) & set(n.out_connectors) + assert not inout, f'NestedSDFG {n.label} still has InOut overlap after split: {sorted(inout)}' + + # Numerically, the fused SDFG produces the same outputs as the pre-fuse oracle. + sdfg(**res) + assert np.allclose(ref['a'], res['a']), 'a-array semantics broken by fusion (InOut split desync)' + assert np.allclose(ref['b'], res['b']), 'b-array semantics broken by fusion' + + def test_map_fusion_is_deprecated() -> None: with pytest.deprecated_call(match="MapFusion is deprecated"): MapFusion() diff --git a/tests/transformations/redundant_copy_test.py b/tests/transformations/redundant_copy_test.py index 00ca729f7d..eea356482b 100644 --- a/tests/transformations/redundant_copy_test.py +++ b/tests/transformations/redundant_copy_test.py @@ -308,6 +308,58 @@ def test_in(): assert np.allclose(A_arr, D_arr.T) +def test_in_failure_partial_copy(): + """RedundantArrayCopyingIn must refuse a chain whose copies are not full + identity copies: ``A -> B -> C -> D`` where only ``B[0:2]`` flows through. + Collapsing the writers of ``B`` straight onto ``D`` would silently widen the + partial copy to a full one and corrupt the region the chain never wrote.""" + + def build(): + sdfg = dace.SDFG('rcin_failure_partial_copy') + state = sdfg.add_state() + sdfg.add_array('A', [4], dace.float64) + sdfg.add_transient('B', [4], dace.float64) + sdfg.add_transient('C', [4], dace.float64) + sdfg.add_array('D', [4], dace.float64) + A, B, C, D = (state.add_access(x) for x in 'ABCD') + state.add_nedge(A, B, dace.Memlet('A[0:4] -> [0:4]')) # B = A (full) + state.add_nedge(B, C, dace.Memlet('B[0:2] -> [0:2]')) # C[0:2] = B[0:2] (partial) + state.add_nedge(C, D, dace.Memlet('C[0:2] -> [0:2]')) # D[0:2] = C[0:2] (partial) + sdfg.validate() + return sdfg + + sdfg = build() + applied = sdfg.apply_transformations_repeated(RedundantArrayCopyingIn) + assert applied == 0 + + A_arr = np.arange(1, 5, dtype=np.float64) + D_arr = np.zeros(4, dtype=np.float64) + sdfg(A=A_arr, D=D_arr) + assert np.allclose(D_arr, [1.0, 2.0, 0.0, 0.0]) + + +def test_in_failure_extra_consumer(): + """RedundantArrayCopyingIn must refuse when the middle array feeds a second + consumer: ``apply`` removes the middle node, which would orphan that + consumer's source.""" + sdfg = dace.SDFG('rcin_failure_extra_consumer') + state = sdfg.add_state() + sdfg.add_array('A', [4], dace.float64) + sdfg.add_transient('B', [4], dace.float64) + sdfg.add_transient('C', [4], dace.float64) + sdfg.add_array('D', [4], dace.float64) + sdfg.add_array('E', [4], dace.float64) + A, B, C, D, E = (state.add_access(x) for x in 'ABCDE') + state.add_nedge(A, B, dace.Memlet('A[0:4] -> [0:4]')) + state.add_nedge(B, C, dace.Memlet('B[0:4] -> [0:4]')) + state.add_nedge(C, D, dace.Memlet('C[0:4] -> [0:4]')) + state.add_nedge(C, E, dace.Memlet('C[0:4] -> [0:4]')) # second consumer of C + sdfg.validate() + + assert sdfg.apply_transformations_repeated(RedundantArrayCopyingIn) == 0 + assert C in state.nodes() + + def test_view_array_array(): sdfg = dace.SDFG('redarrtest') sdfg.add_view('v', [2, 10], dace.float64) diff --git a/tests/transformations/trivial_tasklet_elimination_test.py b/tests/transformations/trivial_tasklet_elimination_test.py index 6169289543..2c1a0db1c4 100644 --- a/tests/transformations/trivial_tasklet_elimination_test.py +++ b/tests/transformations/trivial_tasklet_elimination_test.py @@ -1,5 +1,6 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import dace +from dace.sdfg import nodes from dace.transformation.dataflow.trivial_tasklet_elimination import TrivialTaskletElimination N = 10 @@ -122,7 +123,49 @@ def test_trivial_tasklet_with_implicit_cast(): assert count == 0 +def test_trivial_tasklet_map_source_preserves_offset_subset(): + """When the eliminated copy tasklet's source is a ``MapEntry``, the + surviving edge must describe the read data and keep its offset subset. + + The edge leaves the map's ``OUT_a`` connector, so its memlet must have + ``data == 'a'`` and the offset subset ``a[i + 1]``. Reusing the write + memlet left ``data == 'a_idx'`` with the read offset stranded in + ``other_subset`` -- an orientation inconsistent with the connector. It + still validates and runs at this point, but a later re-lowering that reads + ``.subset`` (e.g. ``MapToForLoop``) then drops the offset (``[0]``). + """ + sym_n = dace.symbol('N') + sdfg = dace.SDFG('tte_map_offset') + sdfg.add_array('a', (sym_n, ), dace.float64) + sdfg.add_array('b', (sym_n, ), dace.float64) + sdfg.add_scalar('a_idx', dace.float64, transient=True) + st = sdfg.add_state() + a, b, aidx = st.add_access('a'), st.add_access('b'), st.add_access('a_idx') + me, mx = st.add_map('m', dict(i='0:N-1')) + copy_tasklet = st.add_tasklet('copy', {'inp'}, {'out'}, 'out = inp') + mult = st.add_tasklet('mult', {'inp'}, {'out'}, 'out = inp * 2.0') + st.add_memlet_path(a, me, copy_tasklet, dst_conn='inp', memlet=dace.Memlet('a[i + 1]')) + st.add_edge(copy_tasklet, 'out', aidx, None, dace.Memlet('a_idx[0]')) + st.add_edge(aidx, None, mult, 'inp', dace.Memlet('a_idx[0]')) + st.add_memlet_path(mult, mx, b, src_conn='out', memlet=dace.Memlet('b[i]')) + sdfg.validate() + + assert sdfg.apply_transformations_repeated(TrivialTaskletElimination) == 1 + + surviving = [ + e for st in sdfg.states() for e in st.edges() + if isinstance(e.src, nodes.MapEntry) and isinstance(e.dst, nodes.AccessNode) and e.dst.data == 'a_idx' + ] + assert len(surviving) == 1 + memlet = surviving[0].data + # The edge leaves OUT_a, so it must describe ``a`` and keep the offset. + assert memlet.data == 'a', f"surviving edge must describe the read data 'a', got {memlet.data!r}" + assert 'i' in {str(s) for s in memlet.subset.free_symbols}, \ + f"per-iteration offset lost from the subset: {memlet.subset}" + + if __name__ == '__main__': test_trivial_tasklet() test_trivial_tasklet_with_map() test_trivial_tasklet_with_implicit_cast() + test_trivial_tasklet_map_source_preserves_offset_subset() diff --git a/tests/transformations/wcr_conversion_test.py b/tests/transformations/wcr_conversion_test.py index fd37f13c3f..8465a63565 100644 --- a/tests/transformations/wcr_conversion_test.py +++ b/tests/transformations/wcr_conversion_test.py @@ -264,3 +264,120 @@ def sdfg_aug_assign_same_inconns(A: dace.float64[32]): applied = sdfg.apply_transformations_repeated(AugAssignToWCR, permissive=True) assert applied == 1 + + +def _build_copy_wrapped_rmw(op_code: str, op_wcr: str, n: int = 6): + """Build an SDFG whose loop body is the *copy-wrapped* RMW the array + frontends emit: the accumulator slice ``A[0]`` is materialized into a scalar + transient, combined with the per-iteration increment ``B[j]`` in a tasklet, + and copied back into ``A[0]`` -- ``A[0] -> a_in -> tasklet -> a_sum -> + A[0]``. ``op_code`` is the tasklet RHS (``__in1 __in2``); ``op_wcr`` is + the numpy reduction used to build the oracle.""" + sdfg = dace.SDFG(f'copy_wrapped_rmw_{op_wcr}') + sdfg.add_array('A', [2], dace.float64) + sdfg.add_array('B', [n], dace.float64) + sdfg.add_scalar('a_in', dace.float64, transient=True) + sdfg.add_scalar('b_in', dace.float64, transient=True) + sdfg.add_scalar('a_sum', dace.float64, transient=True) + + body = sdfg.add_state('body') + a_r = body.add_read('A') + a_in = body.add_access('a_in') + b_r = body.add_read('B') + b_in = body.add_access('b_in') + tasklet = body.add_tasklet('combine', {'__in1', '__in2'}, {'__out'}, f'__out = {op_code}') + a_sum = body.add_access('a_sum') + a_w = body.add_write('A') + + body.add_edge(a_r, None, a_in, None, dace.Memlet('A[0]')) # accumulator load copy + body.add_edge(a_in, None, tasklet, '__in1', dace.Memlet('a_in[0]')) + body.add_edge(b_r, None, b_in, None, dace.Memlet('B[j]')) + body.add_edge(b_in, None, tasklet, '__in2', dace.Memlet('b_in[0]')) + body.add_edge(tasklet, '__out', a_sum, None, dace.Memlet('a_sum[0]')) + body.add_edge(a_sum, None, a_w, None, dace.Memlet('A[0]')) # accumulator store copy + + before = sdfg.add_state('before', is_start_block=True) + after = sdfg.add_state('after') + sdfg.add_loop(before, body, after, 'j', '0', 'j < %d' % n, 'j + 1') + sdfg.reset_cfg_list() + return sdfg + + +def test_aug_assign_copy_wrapped_rmw_match(): + """The copy-wrapped RMW is recognised and rewritten to a WCR write: the + accumulator load is dropped, the tasklet emits only the increment, and the + write into ``A[0]`` carries the reduction WCR.""" + from dace.sdfg import nodes + sdfg = _build_copy_wrapped_rmw('__in1 + __in2', 'sum') + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + sdfg.validate() + + body = next(s for s in sdfg.all_states() if s.label == 'body') + wcr_writes = [ + e for e in body.edges() if isinstance(e.dst, nodes.AccessNode) and e.dst.data == 'A' and e.data.wcr is not None + ] + assert len(wcr_writes) == 1 + assert 'a + b' in wcr_writes[0].data.wcr + # The accumulator is no longer loaded inside the body. + assert not any(isinstance(n, nodes.AccessNode) and n.data == 'A' and body.out_degree(n) > 0 for n in body.nodes()) + + +def test_aug_assign_copy_wrapped_rmw_value_and_parallelize(): + """The rewrite is value-preserving and the now-WCR loop parallelizes via + LoopToMap (the accumulator write is no longer iteration-indexed but is + conflict-resolved).""" + import numpy as np + from dace.transformation.interstate import LoopToMap + from dace.sdfg.state import LoopRegion + + rng = np.random.default_rng(0) + n = 6 + + def run(sdfg): + A = np.array([3.0, 99.0], dtype=np.float64) + B = rng.random(n) + sdfg(A=A, B=B.copy()) + return A, B + + ref_sdfg = _build_copy_wrapped_rmw('__in1 + __in2', 'sum') + A_ref, B = run(ref_sdfg) + assert np.allclose(A_ref[0], 3.0 + B.sum(), rtol=1e-15, atol=1e-15) + + cand = _build_copy_wrapped_rmw('__in1 + __in2', 'sum') + assert cand.apply_transformations_repeated(AugAssignToWCR) == 1 + n_l2m = cand.apply_transformations_repeated(LoopToMap) + assert n_l2m == 1, 'WCR accumulator loop should parallelize' + assert not [r for r in cand.all_control_flow_regions() if isinstance(r, LoopRegion) and r.loop_variable] + + A = np.array([3.0, 99.0], dtype=np.float64) + cand(A=A, B=B.copy()) + assert np.allclose(A, A_ref, rtol=1e-12, atol=1e-12) + + +def test_aug_assign_copy_wrapped_rmw_max(): + """max-reduction copy-wrapped RMW lifts to a ``max`` WCR.""" + import numpy as np + sdfg = _build_copy_wrapped_rmw('max(__in1, __in2)', 'max') + assert sdfg.apply_transformations_repeated(AugAssignToWCR) == 1 + sdfg.validate() + A = np.array([0.5, 0.0], dtype=np.float64) + B = np.array([0.1, 0.9, 0.3, 0.2, 0.7, 0.4], dtype=np.float64) + sdfg(A=A, B=B.copy()) + assert np.allclose(A[0], max(0.5, B.max())) + + +def test_aug_assign_copy_wrapped_rmw_subtract_left_only(): + """Subtraction lifts only with the accumulator on the left (``a - b``); + ``b - a`` is not an order-independent reduction and must be refused.""" + import numpy as np + sdfg = _build_copy_wrapped_rmw('__in1 - __in2', 'sub') # acc on left -> OK + assert sdfg.apply_transformations_repeated(AugAssignToWCR) == 1 + A = np.array([10.0, 0.0], dtype=np.float64) + B = np.array([1.0, 2.0, 0.5, 1.5, 0.0, 1.0], dtype=np.float64) + sdfg(A=A, B=B.copy()) + assert np.allclose(A[0], 10.0 - B.sum()) + + refused = _build_copy_wrapped_rmw('__in2 - __in1', 'rsub') # acc on right -> refuse + assert refused.apply_transformations_repeated(AugAssignToWCR) == 0