Skip to content
28 changes: 20 additions & 8 deletions dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,17 @@ def remove_in_connector(self, connector_name: str):
:param connector_name: The name of the connector to remove.
:return: True if the operation was successful.
"""
if not connector_name:
warnings.warn(f'Tried to remove `{connector_name}` from the in-connectors of node {str(self)}',
stacklevel=1)
return False

if connector_name not in self.in_connectors:
return False

if connector_name in self.in_connectors:
connectors = self.in_connectors
del connectors[connector_name]
self.in_connectors = connectors
connectors = self.in_connectors
del connectors[connector_name]
self.in_connectors = connectors
return True

def remove_out_connector(self, connector_name: str):
Expand All @@ -205,11 +211,17 @@ def remove_out_connector(self, connector_name: str):
:param connector_name: The name of the connector to remove.
:return: True if the operation was successful.
"""
if not connector_name:
warnings.warn(f'Tried to remove `{connector_name}` from the out-connectors of node {str(self)}',
stacklevel=1)
return False

if connector_name not in self.out_connectors:
return False

if connector_name in self.out_connectors:
connectors = self.out_connectors
del connectors[connector_name]
self.out_connectors = connectors
connectors = self.out_connectors
del connectors[connector_name]
self.out_connectors = connectors
return True

def _next_connector_int(self) -> int:
Expand Down
95 changes: 59 additions & 36 deletions dace/sdfg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,51 +857,74 @@ def set_outer_subset(e: MultiConnectorEdge[dace.Memlet], new_subset: sbs.Subset)
return consolidated


def remove_edge_and_dangling_path(state: SDFGState, edge: MultiConnectorEdge):
def remove_edge_and_dangling_path(state: SDFGState, edge: MultiConnectorEdge) -> int:
"""
Removes an edge and all of its parent edges in a memlet path, cleaning
dangling connectors and isolated nodes resulting from the removal.
Removes an edge and all of its parent edges in a memlet path, including now
unused connectors. Furthermore, all nodes that become isolated are also
removed from the state.

:param state: The state in which the edge exists.
:param edge: The edge to remove.
"""
mtree = state.memlet_tree(edge)
inwards = (isinstance(edge.src, nd.EntryNode) or isinstance(edge.dst, nd.EntryNode))

if edge.data.is_empty():
state.remove_edge(edge)
if state.degree(edge.dst) == 0:
state.remove_node(edge.dst)
if state.degree(edge.src) == 0:
state.remove_node(edge.src)
return 1

# Traverse tree upwards, removing edges and connectors as necessary
curedge = mtree
while curedge is not None:
e = curedge.edge
state.remove_edge(e)
if inwards:
neighbors = [] if not e.src_conn else [
neighbor for neighbor in state.out_edges_by_connector(e.src, e.src_conn)
]
else:
neighbors = [] if not e.dst_conn else [
neighbor for neighbor in state.in_edges_by_connector(e.dst, e.dst_conn)
]
if len(neighbors) > 0: # There are still edges connected, leave as-is
break
mtree = state.memlet_tree(edge)
curr_tree = mtree
nb_removed_edges = 0
while curr_tree is not None:
curr_edge = curr_tree.edge
assert not curr_edge.data.is_empty()
state.remove_edge(curr_edge)
nb_removed_edges += 1

if curr_tree.downwards:
if state.degree(curr_edge.dst) == 0:
# If target node is isolated we can remove it.
state.remove_node(curr_edge.dst)
else:
# If the node is not isolated we must look at its connectors and clean them.
if isinstance(curr_edge.dst, nd.EntryNode) and curr_edge.dst_conn.startswith("IN_"):
curr_edge.dst.remove_out_connector("OUT_" + curr_edge.dst_conn[3:])
if curr_edge.dst_conn and len(list(state.in_edges_by_connector(curr_edge.dst,
curr_edge.dst_conn))) == 0:
curr_edge.dst.remove_in_connector(curr_edge.dst_conn)

# Remove connector and matching outer connector
if inwards:
if e.src_conn:
e.src.remove_out_connector(e.src_conn)
e.src.remove_in_connector('IN' + e.src_conn[3:])
else:
if e.dst_conn:
e.dst.remove_in_connector(e.dst_conn)
e.dst.remove_out_connector('OUT' + e.dst_conn[2:])
# There is a fan-out, i.e. the `curr_edge.src_conn` is still in use and we are done here.
if len(list(state.out_edges_by_connector(curr_edge.src, curr_edge.src_conn))) != 0:
return nb_removed_edges

# Continue traversing upwards
curedge = curedge.parent
else:
# Check if an isolated node have been created at the root and remove
root_edge = mtree.root().edge
root_node: nd.Node = root_edge.src if inwards else root_edge.dst
if state.degree(root_node) == 0:
state.remove_node(root_node)
else:
if state.degree(curr_edge.src) == 0:
state.remove_node(curr_edge.src)
else:
if isinstance(curr_edge.src, nd.ExitNode) and curr_edge.src_conn.startswith("OUT_"):
curr_edge.src.remove_in_connector("IN_" + curr_edge.src_conn[4:])
if curr_edge.src_conn and len(list(state.out_edges_by_connector(curr_edge.src,
curr_edge.src_conn))) == 0:
curr_edge.src.remove_out_connector(curr_edge.src_conn)

# The connector might be collecting.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what case is this? I think it is generally disallowed in SDFG syntax

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually wish that it would be disallowed but it is permitted and as far as I can tell parts of the toolchain (I suspect ConsolidateEdge) actually create such cases.

The case essentially represents the following:

A: np.ndarray[N, 40] = ...
for i in dace.map[0:N]:
    A[i, 0] = some_awesome_boundary_value()
    A[i, 39] = another_awesome_boundary_value()

So MapExit will have two incoming edges.
However, they do not have to go to the distinct (in-)connectors of the MapExit node.
In case they go to distinct connectors then there are also two outgoing edges one with subset [0:N, 0] and the other with [0:N, 39].

However, it is also allowed that they go to the same (in-)connector of MapExit.
In that case MapExit will only have one outgoing edge, and it will have a subset of [0:N, 0:40], which grossly overestimate the range that is written.

This is actually currently happening in GT4Py and it also blocks some optimization.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, there is an exception for passthrough (IN_*/OUT_*) connectors. I was only thinking about tasklets/library nodes/nested SDFGs. Thank you for the clarification

if len(list(state.in_edges_by_connector(curr_edge.dst, curr_edge.dst_conn))) != 0:
return nb_removed_edges

# Continue traversing tree upwards
curr_tree = curr_tree.parent

# Check if an isolated node have been created at the root and remove
root_edge = mtree.root().edge
root_node: nd.Node = root_edge.src if mtree.downwards else root_edge.dst
if state.degree(root_node) == 0:
state.remove_node(root_node)

return nb_removed_edges


def consolidate_edges(
Expand Down
131 changes: 131 additions & 0 deletions tests/sdfg/remove_edge_and_dangling_path_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved.
import dace


def test_remove_edge_global_scope():
sdfg = dace.SDFG("simple_edge_remover_test")
state = sdfg.add_state()

sdfg.add_array("a", shape=(1, ), dtype=dace.float64, transient=False)
sdfg.add_array("b", shape=(1, ), dtype=dace.float64, transient=False)

tlet = state.add_tasklet(
"comp",
inputs={"__in"},
outputs={"__out"},
code="__out = __in + 1.0",
)
a = state.add_access("a")
b = state.add_access("b")

up_edge = state.add_edge(a, None, tlet, "__in", dace.Memlet("a[0]"))
down_edge = state.add_edge(tlet, "__out", b, None, dace.Memlet("b[0]"))
sdfg.validate()

# Now remove the edge using the function. Note that this makes the SDFG invalid,
# because the tasklet wants an input. We ignore that for now.
nb_removed_edges = dace.sdfg.utils.remove_edge_and_dangling_path(state, up_edge)
assert nb_removed_edges == 1
assert a not in state.nodes()
assert tlet in state.nodes()
assert b in state.nodes()
assert sdfg.arrays.keys() == {"a", "b"}
assert len(tlet.in_connectors) == 0
assert state.in_degree(tlet) == 0
assert tlet.out_connectors.keys() == {"__out"}
assert state.out_degree(tlet) == 1

# If we now also delete the `down_edge` then the state will become empty.
nb_removed_edges = dace.sdfg.utils.remove_edge_and_dangling_path(state, down_edge)
assert nb_removed_edges == 1
assert state.number_of_nodes() == 0


def test_remove_edge_nested_scope():
sdfg = dace.SDFG("nested_edge_remover_test")
state = sdfg.add_state()

sdfg.add_array("a", shape=(10, 10), dtype=dace.float64, transient=False)
sdfg.add_array("b", shape=(10, 10, 2), dtype=dace.float64, transient=False)

tlet = state.add_tasklet(
"comp",
inputs={"__in1", "__in2"},
outputs={"__out1", "__out2"},
code="__out1 = __in1 + 1.0\n__out2 = __in2 - 1.0",
)
a, b = (state.add_access(name) for name in "ab")
me, mx = state.add_map("outer_map", ndrange={"__i": "0:10"})
nme, nmx = state.add_map("nested_map", ndrange={"__j": "0:10"})

state.add_edge(a, None, me, "IN_a", dace.Memlet("a[0:10, 0:10]"))
me.add_scope_connectors("a")
state.add_edge(me, "OUT_a", nme, "IN_a1", dace.Memlet("a[__i, 0:10]"))
state.add_edge(me, "OUT_a", nme, "IN_a2", dace.Memlet("a[0:10, __i]"))
nme.add_scope_connectors("a1")
nme.add_scope_connectors("a2")

up_edge1 = state.add_edge(nme, "OUT_a1", tlet, "__in1", dace.Memlet("a[__i, __j]"))
up_edge2 = state.add_edge(nme, "OUT_a2", tlet, "__in2", dace.Memlet("a[__j, __i]"))

down_edge1 = state.add_edge(tlet, "__out1", nmx, "IN_b", dace.Memlet("b[__i, __j, 0]"))
down_edge2 = state.add_edge(tlet, "__out2", nmx, "IN_b", dace.Memlet("b[__j, __i, 1]"))
nmx.add_scope_connectors("b")

state.add_edge(nmx, "OUT_b", mx, "IN_b", dace.Memlet("b[0:10, 0:10, 0:2]"))
mx.add_scope_connectors("b")
state.add_edge(mx, "OUT_b", b, None, dace.Memlet("b[0:10, 0:10, 0:2]"))
sdfg.validate()

assert state.number_of_nodes() == 7

# Because of the fan out, the deletion will stop.
nb_rm_up_edge1 = dace.sdfg.utils.remove_edge_and_dangling_path(state, up_edge1)
assert nb_rm_up_edge1 == 2
assert state.number_of_nodes() == 7
assert set(tlet.in_connectors.keys()) == {"__in2"}
assert state.in_degree(tlet) == 1
assert set(tlet.out_connectors.keys()) == {"__out1", "__out2"}
assert state.out_degree(tlet) == 2
assert set(nme.out_connectors.keys()) == {"OUT_a2"}
assert set(nme.in_connectors.keys()) == {"IN_a2"}
assert state.out_degree(nme) == 1
assert state.in_degree(nme) == 1
assert state.out_degree(me) == 1
assert state.in_degree(me) == 1
assert set(me.out_connectors.keys()) == {"OUT_a"}
assert set(me.in_connectors.keys()) == {"IN_a"}
assert state.degree(a) == 1

# Now the deletion will go up and remove the map entries; which leads to a
# technical and functional invalid SDFG.
nb_rm_up_edge2 = dace.sdfg.utils.remove_edge_and_dangling_path(state, up_edge2)
assert nb_rm_up_edge2 == 3
assert state.number_of_nodes() == 4
assert state.number_of_edges() == 4
assert {tlet, nmx, mx, b} == set(state.nodes())
assert len(tlet.in_connectors) == 0
assert set(tlet.out_connectors.keys()) == {"__out1", "__out2"}

# The first down edge, will only delete one edge, due to the location of the fan
# in, which is a bit different compared to the location of the upper fan out.
nb_rm_down_edge1 = dace.sdfg.utils.remove_edge_and_dangling_path(state, down_edge1)
assert nb_rm_down_edge1 == 1
assert state.number_of_nodes() == 4
assert state.number_of_edges() == 3
assert {tlet, nmx, mx, b} == set(state.nodes())
assert len(tlet.in_connectors) == 0
assert set(tlet.out_connectors.keys()) == {"__out2"}
assert len(nmx.in_connectors) == 1
assert len(nmx.out_connectors) == 1
assert len(mx.in_connectors) == 1
assert len(mx.out_connectors) == 1

# This will remove all nodes.
nb_rm_down_edge2 = dace.sdfg.utils.remove_edge_and_dangling_path(state, down_edge2)
assert nb_rm_up_edge2 == 3
assert state.number_of_nodes() == 0


if __name__ == '__main__':
test_remove_edge_global_scope()