diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index 6cf84b71eb..157bcfd379 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -192,11 +192,17 @@ def remove_in_connector(self, connector_name: str): :param connector_name: The name of the connector to remove. :return: True if the operation was successful. """ + if not connector_name: + warnings.warn(f'Tried to remove `{connector_name}` from the in-connectors of node {str(self)}', + stacklevel=1) + return False + + if connector_name not in self.in_connectors: + return False - if connector_name in self.in_connectors: - connectors = self.in_connectors - del connectors[connector_name] - self.in_connectors = connectors + connectors = self.in_connectors + del connectors[connector_name] + self.in_connectors = connectors return True def remove_out_connector(self, connector_name: str): @@ -205,11 +211,17 @@ def remove_out_connector(self, connector_name: str): :param connector_name: The name of the connector to remove. :return: True if the operation was successful. """ + if not connector_name: + warnings.warn(f'Tried to remove `{connector_name}` from the out-connectors of node {str(self)}', + stacklevel=1) + return False + + if connector_name not in self.out_connectors: + return False - if connector_name in self.out_connectors: - connectors = self.out_connectors - del connectors[connector_name] - self.out_connectors = connectors + connectors = self.out_connectors + del connectors[connector_name] + self.out_connectors = connectors return True def _next_connector_int(self) -> int: diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 782e98d40d..9e8eeda1bc 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -857,51 +857,74 @@ def set_outer_subset(e: MultiConnectorEdge[dace.Memlet], new_subset: sbs.Subset) return consolidated -def remove_edge_and_dangling_path(state: SDFGState, edge: MultiConnectorEdge): +def remove_edge_and_dangling_path(state: SDFGState, edge: MultiConnectorEdge) -> int: """ - Removes an edge and all of its parent edges in a memlet path, cleaning - dangling connectors and isolated nodes resulting from the removal. + Removes an edge and all of its parent edges in a memlet path, including now + unused connectors. Furthermore, all nodes that become isolated are also + removed from the state. :param state: The state in which the edge exists. :param edge: The edge to remove. """ - mtree = state.memlet_tree(edge) - inwards = (isinstance(edge.src, nd.EntryNode) or isinstance(edge.dst, nd.EntryNode)) + + if edge.data.is_empty(): + state.remove_edge(edge) + if state.degree(edge.dst) == 0: + state.remove_node(edge.dst) + if state.degree(edge.src) == 0: + state.remove_node(edge.src) + return 1 # Traverse tree upwards, removing edges and connectors as necessary - curedge = mtree - while curedge is not None: - e = curedge.edge - state.remove_edge(e) - if inwards: - neighbors = [] if not e.src_conn else [ - neighbor for neighbor in state.out_edges_by_connector(e.src, e.src_conn) - ] - else: - neighbors = [] if not e.dst_conn else [ - neighbor for neighbor in state.in_edges_by_connector(e.dst, e.dst_conn) - ] - if len(neighbors) > 0: # There are still edges connected, leave as-is - break + mtree = state.memlet_tree(edge) + curr_tree = mtree + nb_removed_edges = 0 + while curr_tree is not None: + curr_edge = curr_tree.edge + assert not curr_edge.data.is_empty() + state.remove_edge(curr_edge) + nb_removed_edges += 1 + + if curr_tree.downwards: + if state.degree(curr_edge.dst) == 0: + # If target node is isolated we can remove it. + state.remove_node(curr_edge.dst) + else: + # If the node is not isolated we must look at its connectors and clean them. + if isinstance(curr_edge.dst, nd.EntryNode) and curr_edge.dst_conn.startswith("IN_"): + curr_edge.dst.remove_out_connector("OUT_" + curr_edge.dst_conn[3:]) + if curr_edge.dst_conn and len(list(state.in_edges_by_connector(curr_edge.dst, + curr_edge.dst_conn))) == 0: + curr_edge.dst.remove_in_connector(curr_edge.dst_conn) - # Remove connector and matching outer connector - if inwards: - if e.src_conn: - e.src.remove_out_connector(e.src_conn) - e.src.remove_in_connector('IN' + e.src_conn[3:]) - else: - if e.dst_conn: - e.dst.remove_in_connector(e.dst_conn) - e.dst.remove_out_connector('OUT' + e.dst_conn[2:]) + # There is a fan-out, i.e. the `curr_edge.src_conn` is still in use and we are done here. + if len(list(state.out_edges_by_connector(curr_edge.src, curr_edge.src_conn))) != 0: + return nb_removed_edges - # Continue traversing upwards - curedge = curedge.parent - else: - # Check if an isolated node have been created at the root and remove - root_edge = mtree.root().edge - root_node: nd.Node = root_edge.src if inwards else root_edge.dst - if state.degree(root_node) == 0: - state.remove_node(root_node) + else: + if state.degree(curr_edge.src) == 0: + state.remove_node(curr_edge.src) + else: + if isinstance(curr_edge.src, nd.ExitNode) and curr_edge.src_conn.startswith("OUT_"): + curr_edge.src.remove_in_connector("IN_" + curr_edge.src_conn[4:]) + if curr_edge.src_conn and len(list(state.out_edges_by_connector(curr_edge.src, + curr_edge.src_conn))) == 0: + curr_edge.src.remove_out_connector(curr_edge.src_conn) + + # The connector might be collecting. + if len(list(state.in_edges_by_connector(curr_edge.dst, curr_edge.dst_conn))) != 0: + return nb_removed_edges + + # Continue traversing tree upwards + curr_tree = curr_tree.parent + + # Check if an isolated node have been created at the root and remove + root_edge = mtree.root().edge + root_node: nd.Node = root_edge.src if mtree.downwards else root_edge.dst + if state.degree(root_node) == 0: + state.remove_node(root_node) + + return nb_removed_edges def consolidate_edges( diff --git a/tests/sdfg/remove_edge_and_dangling_path_test.py b/tests/sdfg/remove_edge_and_dangling_path_test.py new file mode 100644 index 0000000000..07ccd5ac55 --- /dev/null +++ b/tests/sdfg/remove_edge_and_dangling_path_test.py @@ -0,0 +1,131 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. +import dace + + +def test_remove_edge_global_scope(): + sdfg = dace.SDFG("simple_edge_remover_test") + state = sdfg.add_state() + + sdfg.add_array("a", shape=(1, ), dtype=dace.float64, transient=False) + sdfg.add_array("b", shape=(1, ), dtype=dace.float64, transient=False) + + tlet = state.add_tasklet( + "comp", + inputs={"__in"}, + outputs={"__out"}, + code="__out = __in + 1.0", + ) + a = state.add_access("a") + b = state.add_access("b") + + up_edge = state.add_edge(a, None, tlet, "__in", dace.Memlet("a[0]")) + down_edge = state.add_edge(tlet, "__out", b, None, dace.Memlet("b[0]")) + sdfg.validate() + + # Now remove the edge using the function. Note that this makes the SDFG invalid, + # because the tasklet wants an input. We ignore that for now. + nb_removed_edges = dace.sdfg.utils.remove_edge_and_dangling_path(state, up_edge) + assert nb_removed_edges == 1 + assert a not in state.nodes() + assert tlet in state.nodes() + assert b in state.nodes() + assert sdfg.arrays.keys() == {"a", "b"} + assert len(tlet.in_connectors) == 0 + assert state.in_degree(tlet) == 0 + assert tlet.out_connectors.keys() == {"__out"} + assert state.out_degree(tlet) == 1 + + # If we now also delete the `down_edge` then the state will become empty. + nb_removed_edges = dace.sdfg.utils.remove_edge_and_dangling_path(state, down_edge) + assert nb_removed_edges == 1 + assert state.number_of_nodes() == 0 + + +def test_remove_edge_nested_scope(): + sdfg = dace.SDFG("nested_edge_remover_test") + state = sdfg.add_state() + + sdfg.add_array("a", shape=(10, 10), dtype=dace.float64, transient=False) + sdfg.add_array("b", shape=(10, 10, 2), dtype=dace.float64, transient=False) + + tlet = state.add_tasklet( + "comp", + inputs={"__in1", "__in2"}, + outputs={"__out1", "__out2"}, + code="__out1 = __in1 + 1.0\n__out2 = __in2 - 1.0", + ) + a, b = (state.add_access(name) for name in "ab") + me, mx = state.add_map("outer_map", ndrange={"__i": "0:10"}) + nme, nmx = state.add_map("nested_map", ndrange={"__j": "0:10"}) + + state.add_edge(a, None, me, "IN_a", dace.Memlet("a[0:10, 0:10]")) + me.add_scope_connectors("a") + state.add_edge(me, "OUT_a", nme, "IN_a1", dace.Memlet("a[__i, 0:10]")) + state.add_edge(me, "OUT_a", nme, "IN_a2", dace.Memlet("a[0:10, __i]")) + nme.add_scope_connectors("a1") + nme.add_scope_connectors("a2") + + up_edge1 = state.add_edge(nme, "OUT_a1", tlet, "__in1", dace.Memlet("a[__i, __j]")) + up_edge2 = state.add_edge(nme, "OUT_a2", tlet, "__in2", dace.Memlet("a[__j, __i]")) + + down_edge1 = state.add_edge(tlet, "__out1", nmx, "IN_b", dace.Memlet("b[__i, __j, 0]")) + down_edge2 = state.add_edge(tlet, "__out2", nmx, "IN_b", dace.Memlet("b[__j, __i, 1]")) + nmx.add_scope_connectors("b") + + state.add_edge(nmx, "OUT_b", mx, "IN_b", dace.Memlet("b[0:10, 0:10, 0:2]")) + mx.add_scope_connectors("b") + state.add_edge(mx, "OUT_b", b, None, dace.Memlet("b[0:10, 0:10, 0:2]")) + sdfg.validate() + + assert state.number_of_nodes() == 7 + + # Because of the fan out, the deletion will stop. + nb_rm_up_edge1 = dace.sdfg.utils.remove_edge_and_dangling_path(state, up_edge1) + assert nb_rm_up_edge1 == 2 + assert state.number_of_nodes() == 7 + assert set(tlet.in_connectors.keys()) == {"__in2"} + assert state.in_degree(tlet) == 1 + assert set(tlet.out_connectors.keys()) == {"__out1", "__out2"} + assert state.out_degree(tlet) == 2 + assert set(nme.out_connectors.keys()) == {"OUT_a2"} + assert set(nme.in_connectors.keys()) == {"IN_a2"} + assert state.out_degree(nme) == 1 + assert state.in_degree(nme) == 1 + assert state.out_degree(me) == 1 + assert state.in_degree(me) == 1 + assert set(me.out_connectors.keys()) == {"OUT_a"} + assert set(me.in_connectors.keys()) == {"IN_a"} + assert state.degree(a) == 1 + + # Now the deletion will go up and remove the map entries; which leads to a + # technical and functional invalid SDFG. + nb_rm_up_edge2 = dace.sdfg.utils.remove_edge_and_dangling_path(state, up_edge2) + assert nb_rm_up_edge2 == 3 + assert state.number_of_nodes() == 4 + assert state.number_of_edges() == 4 + assert {tlet, nmx, mx, b} == set(state.nodes()) + assert len(tlet.in_connectors) == 0 + assert set(tlet.out_connectors.keys()) == {"__out1", "__out2"} + + # The first down edge, will only delete one edge, due to the location of the fan + # in, which is a bit different compared to the location of the upper fan out. + nb_rm_down_edge1 = dace.sdfg.utils.remove_edge_and_dangling_path(state, down_edge1) + assert nb_rm_down_edge1 == 1 + assert state.number_of_nodes() == 4 + assert state.number_of_edges() == 3 + assert {tlet, nmx, mx, b} == set(state.nodes()) + assert len(tlet.in_connectors) == 0 + assert set(tlet.out_connectors.keys()) == {"__out2"} + assert len(nmx.in_connectors) == 1 + assert len(nmx.out_connectors) == 1 + assert len(mx.in_connectors) == 1 + assert len(mx.out_connectors) == 1 + + # This will remove all nodes. + nb_rm_down_edge2 = dace.sdfg.utils.remove_edge_and_dangling_path(state, down_edge2) + assert nb_rm_up_edge2 == 3 + assert state.number_of_nodes() == 0 + + +if __name__ == '__main__': + test_remove_edge_global_scope()