-
Notifications
You must be signed in to change notification settings - Fork 157
Fixed remove_edge_and_dangling_path()
#2307
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 8 commits
daf8d08
b80ef26
ebc4fbd
2f7c88d
c97d11c
3c1b6f3
085b101
668bd6e
34760d1
aad715e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -857,51 +857,74 @@ def set_outer_subset(e: MultiConnectorEdge[dace.Memlet], new_subset: sbs.Subset) | |
| return consolidated | ||
|
|
||
|
|
||
| def remove_edge_and_dangling_path(state: SDFGState, edge: MultiConnectorEdge): | ||
| def remove_edge_and_dangling_path(state: SDFGState, edge: MultiConnectorEdge) -> int: | ||
| """ | ||
| Removes an edge and all of its parent edges in a memlet path, cleaning | ||
| dangling connectors and isolated nodes resulting from the removal. | ||
| Removes an edge and all of its parent edges in a memlet path, including now | ||
| unused connectors. Furthermore, all nodes that become isolated are also | ||
| removed from the state. | ||
|
|
||
| :param state: The state in which the edge exists. | ||
| :param edge: The edge to remove. | ||
| """ | ||
| mtree = state.memlet_tree(edge) | ||
| inwards = (isinstance(edge.src, nd.EntryNode) or isinstance(edge.dst, nd.EntryNode)) | ||
|
|
||
| if edge.data.is_empty(): | ||
| state.remove_edge(edge) | ||
| if state.degree(edge.dst) == 0: | ||
| state.remove_node(edge.dst) | ||
| if state.degree(edge.src) == 0: | ||
| state.remove_node(edge.src) | ||
| return 1 | ||
|
|
||
| # Traverse tree upwards, removing edges and connectors as necessary | ||
| curedge = mtree | ||
| while curedge is not None: | ||
| e = curedge.edge | ||
| state.remove_edge(e) | ||
| if inwards: | ||
| neighbors = [] if not e.src_conn else [ | ||
| neighbor for neighbor in state.out_edges_by_connector(e.src, e.src_conn) | ||
| ] | ||
| else: | ||
| neighbors = [] if not e.dst_conn else [ | ||
| neighbor for neighbor in state.in_edges_by_connector(e.dst, e.dst_conn) | ||
| ] | ||
| if len(neighbors) > 0: # There are still edges connected, leave as-is | ||
| break | ||
| mtree = state.memlet_tree(edge) | ||
| curr_tree = mtree | ||
| nb_removed_edges = 0 | ||
| while curr_tree is not None: | ||
| curr_edge = curr_tree.edge | ||
| assert not curr_edge.data.is_empty() | ||
| state.remove_edge(curr_edge) | ||
| nb_removed_edges += 1 | ||
|
|
||
| if curr_tree.downwards: | ||
| if state.degree(curr_edge.dst) == 0: | ||
| # If the edge is isolated we can remove it. | ||
|
philip-paul-mueller marked this conversation as resolved.
Outdated
|
||
| state.remove_node(curr_edge.dst) | ||
| else: | ||
| # If the node is not isolated we must look at its connectors and clean them. | ||
| if isinstance(curr_edge.dst, nd.EntryNode) and curr_edge.dst_conn.startswith("IN_"): | ||
| curr_edge.dst.remove_out_connector("OUT_" + curr_edge.dst_conn[3:]) | ||
| if curr_edge.dst_conn and len(list(state.in_edges_by_connector(curr_edge.dst, | ||
| curr_edge.dst_conn))) == 0: | ||
| curr_edge.dst.remove_in_connector(curr_edge.dst_conn) | ||
|
|
||
| # Remove connector and matching outer connector | ||
| if inwards: | ||
| if e.src_conn: | ||
| e.src.remove_out_connector(e.src_conn) | ||
| e.src.remove_in_connector('IN' + e.src_conn[3:]) | ||
| else: | ||
| if e.dst_conn: | ||
| e.dst.remove_in_connector(e.dst_conn) | ||
| e.dst.remove_out_connector('OUT' + e.dst_conn[2:]) | ||
| # There is a fan-out, i.e. the `curr_edge.src_conn` is still in use and we are done here. | ||
| if len(list(state.out_edges_by_connector(curr_edge.src, curr_edge.src_conn))) != 0: | ||
| return nb_removed_edges | ||
|
|
||
| # Continue traversing upwards | ||
| curedge = curedge.parent | ||
| else: | ||
| # Check if an isolated node have been created at the root and remove | ||
| root_edge = mtree.root().edge | ||
| root_node: nd.Node = root_edge.src if inwards else root_edge.dst | ||
| if state.degree(root_node) == 0: | ||
| state.remove_node(root_node) | ||
| else: | ||
| if state.degree(curr_edge.src) == 0: | ||
| state.remove_node(curr_edge.src) | ||
| else: | ||
| if isinstance(curr_edge.src, nd.ExitNode) and curr_edge.src_conn.startswith("OUT_"): | ||
| curr_edge.src.remove_in_connector("IN_" + curr_edge.src_conn[4:]) | ||
| if curr_edge.src_conn and len(list(state.out_edges_by_connector(curr_edge.src, | ||
| curr_edge.src_conn))) == 0: | ||
| curr_edge.src.remove_out_connector(curr_edge.src_conn) | ||
|
|
||
| # The connector might be collecting. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what case is this? I think it is generally disallowed in SDFG syntax
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I actually wish that it would be disallowed but it is permitted and as far as I can tell parts of the toolchain (I suspect The case essentially represents the following: So However, it is also allowed that they go to the same (in-)connector of This is actually currently happening in GT4Py and it also blocks some optimization.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, there is an exception for passthrough ( |
||
| if len(list(state.in_edges_by_connector(curr_edge.dst, curr_edge.dst_conn))) != 0: | ||
| return nb_removed_edges | ||
|
|
||
| # Continue traversing tree upwards | ||
| curr_tree = curr_tree.parent | ||
|
|
||
| # Check if an isolated node have been created at the root and remove | ||
| root_edge = mtree.root().edge | ||
| root_node: nd.Node = root_edge.src if mtree.downwards else root_edge.dst | ||
| if state.degree(root_node) == 0: | ||
| state.remove_node(root_node) | ||
|
|
||
| return nb_removed_edges | ||
|
|
||
|
|
||
| def consolidate_edges( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,131 @@ | ||
| # Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. | ||
| import dace | ||
|
|
||
|
|
||
| def test_remove_edge_global_scope(): | ||
| sdfg = dace.SDFG("simple_edge_remover_test") | ||
| state = sdfg.add_state() | ||
|
|
||
| sdfg.add_array("a", shape=(1, ), dtype=dace.float64, transient=False) | ||
| sdfg.add_array("b", shape=(1, ), dtype=dace.float64, transient=False) | ||
|
|
||
| tlet = state.add_tasklet( | ||
| "comp", | ||
| inputs={"__in"}, | ||
| outputs={"__out"}, | ||
| code="__out = __in + 1.0", | ||
| ) | ||
| a = state.add_access("a") | ||
| b = state.add_access("b") | ||
|
|
||
| up_edge = state.add_edge(a, None, tlet, "__in", dace.Memlet("a[0]")) | ||
| down_edge = state.add_edge(tlet, "__out", b, None, dace.Memlet("b[0]")) | ||
| sdfg.validate() | ||
|
|
||
| # Now remove the edge using the function. Note that this makes the SDFG invalid, | ||
| # because the tasklet wants an input. We ignore that for now. | ||
| nb_removed_edges = dace.sdfg.utils.remove_edge_and_dangling_path(state, up_edge) | ||
| assert nb_removed_edges == 1 | ||
| assert a not in state.nodes() | ||
| assert tlet in state.nodes() | ||
| assert b in state.nodes() | ||
| assert sdfg.arrays.keys() == {"a", "b"} | ||
| assert len(tlet.in_connectors) == 0 | ||
| assert state.in_degree(tlet) == 0 | ||
| assert tlet.out_connectors.keys() == {"__out"} | ||
| assert state.out_degree(tlet) == 1 | ||
|
|
||
| # If we now also delete the `down_edge` then the state will become empty. | ||
| nb_removed_edges = dace.sdfg.utils.remove_edge_and_dangling_path(state, down_edge) | ||
| assert nb_removed_edges == 1 | ||
| assert state.number_of_nodes() == 0 | ||
|
|
||
|
|
||
| def test_remove_edge_nested_scope(): | ||
| sdfg = dace.SDFG("nested_edge_remover_test") | ||
| state = sdfg.add_state() | ||
|
|
||
| sdfg.add_array("a", shape=(10, 10), dtype=dace.float64, transient=False) | ||
| sdfg.add_array("b", shape=(10, 10, 2), dtype=dace.float64, transient=False) | ||
|
|
||
| tlet = state.add_tasklet( | ||
| "comp", | ||
| inputs={"__in1", "__in2"}, | ||
| outputs={"__out1", "__out2"}, | ||
| code="__out1 = __in1 + 1.0\n__out2 = __in2 - 1.0", | ||
| ) | ||
| a, b = (state.add_access(name) for name in "ab") | ||
| me, mx = state.add_map("outer_map", ndrange={"__i": "0:10"}) | ||
| nme, nmx = state.add_map("nested_map", ndrange={"__j": "0:10"}) | ||
|
|
||
| state.add_edge(a, None, me, "IN_a", dace.Memlet("a[0:10, 0:10]")) | ||
| me.add_scope_connectors("a") | ||
| state.add_edge(me, "OUT_a", nme, "IN_a1", dace.Memlet("a[__i, 0:10]")) | ||
| state.add_edge(me, "OUT_a", nme, "IN_a2", dace.Memlet("a[0:10, __i]")) | ||
| nme.add_scope_connectors("a1") | ||
| nme.add_scope_connectors("a2") | ||
|
|
||
| up_edge1 = state.add_edge(nme, "OUT_a1", tlet, "__in1", dace.Memlet("a[__i, __j]")) | ||
| up_edge2 = state.add_edge(nme, "OUT_a2", tlet, "__in2", dace.Memlet("a[__j, __i]")) | ||
|
|
||
| down_edge1 = state.add_edge(tlet, "__out1", nmx, "IN_b", dace.Memlet("b[__i, __j, 0]")) | ||
| down_edge2 = state.add_edge(tlet, "__out2", nmx, "IN_b", dace.Memlet("b[__j, __i, 1]")) | ||
| nmx.add_scope_connectors("b") | ||
|
|
||
| state.add_edge(nmx, "OUT_b", mx, "IN_b", dace.Memlet("b[0:10, 0:10, 0:2]")) | ||
| mx.add_scope_connectors("b") | ||
| state.add_edge(mx, "OUT_b", b, None, dace.Memlet("b[0:10, 0:10, 0:2]")) | ||
| sdfg.validate() | ||
|
|
||
| assert state.number_of_nodes() == 7 | ||
|
|
||
| # Because of the fan out, the deletion will stop. | ||
| nb_rm_up_edge1 = dace.sdfg.utils.remove_edge_and_dangling_path(state, up_edge1) | ||
| assert nb_rm_up_edge1 == 2 | ||
| assert state.number_of_nodes() == 7 | ||
| assert set(tlet.in_connectors.keys()) == {"__in2"} | ||
| assert state.in_degree(tlet) == 1 | ||
| assert set(tlet.out_connectors.keys()) == {"__out1", "__out2"} | ||
| assert state.out_degree(tlet) == 2 | ||
| assert set(nme.out_connectors.keys()) == {"OUT_a2"} | ||
| assert set(nme.in_connectors.keys()) == {"IN_a2"} | ||
| assert state.out_degree(nme) == 1 | ||
| assert state.in_degree(nme) == 1 | ||
| assert state.out_degree(me) == 1 | ||
| assert state.in_degree(me) == 1 | ||
| assert set(me.out_connectors.keys()) == {"OUT_a"} | ||
| assert set(me.in_connectors.keys()) == {"IN_a"} | ||
| assert state.degree(a) == 1 | ||
|
|
||
| # Now the deletion will go up and remove the map entries; which leads to a | ||
| # technical and functional invalid SDFG. | ||
| nb_rm_up_edge2 = dace.sdfg.utils.remove_edge_and_dangling_path(state, up_edge2) | ||
| assert nb_rm_up_edge2 == 3 | ||
| assert state.number_of_nodes() == 4 | ||
| assert state.number_of_edges() == 4 | ||
| assert {tlet, nmx, mx, b} == set(state.nodes()) | ||
| assert len(tlet.in_connectors) == 0 | ||
| assert set(tlet.out_connectors.keys()) == {"__out1", "__out2"} | ||
|
|
||
| # The first down edge, will only delete one edge, due to the location of the fan | ||
| # in, which is a bit different compared to the location of the upper fan out. | ||
| nb_rm_down_edge1 = dace.sdfg.utils.remove_edge_and_dangling_path(state, down_edge1) | ||
| assert nb_rm_down_edge1 == 1 | ||
| assert state.number_of_nodes() == 4 | ||
| assert state.number_of_edges() == 3 | ||
| assert {tlet, nmx, mx, b} == set(state.nodes()) | ||
| assert len(tlet.in_connectors) == 0 | ||
| assert set(tlet.out_connectors.keys()) == {"__out2"} | ||
| assert len(nmx.in_connectors) == 1 | ||
| assert len(nmx.out_connectors) == 1 | ||
| assert len(mx.in_connectors) == 1 | ||
| assert len(mx.out_connectors) == 1 | ||
|
|
||
| # This will remove all nodes. | ||
| nb_rm_down_edge2 = dace.sdfg.utils.remove_edge_and_dangling_path(state, down_edge2) | ||
| assert nb_rm_up_edge2 == 3 | ||
| assert state.number_of_nodes() == 0 | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| test_remove_edge_global_scope() |
Uh oh!
There was an error while loading. Please reload this page.