From 328212f86606b67f513a500b773f52b417246e7f Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 13 Nov 2025 13:56:54 +0100 Subject: [PATCH 01/10] adapted work depth analysis to control flow regions --- .../performance_evaluation/assumptions.py | 2 +- .../sdfg/performance_evaluation/work_depth.py | 326 +++++++++++------- 2 files changed, 195 insertions(+), 133 deletions(-) diff --git a/dace/sdfg/performance_evaluation/assumptions.py b/dace/sdfg/performance_evaluation/assumptions.py index 1b1d37348b..9e140e0fd9 100644 --- a/dace/sdfg/performance_evaluation/assumptions.py +++ b/dace/sdfg/performance_evaluation/assumptions.py @@ -272,7 +272,7 @@ def parse_assumptions(assumptions, array_symbols): for sym, assum in condensed_assumptions.items(): i = 0 for g in assum.greater: - replacement_symbol = sp.Symbol(f'_p_{sym}', positive=True, integer=True) + replacement_symbol = sp.Symbol(f'_p_{sym}', nonnegative=True, integer=True) all_subs[i][0].update({sp.Symbol(sym): replacement_symbol + g}) all_subs[i][1].update({replacement_symbol: sp.Symbol(sym) - g}) i += 1 diff --git a/dace/sdfg/performance_evaluation/work_depth.py b/dace/sdfg/performance_evaluation/work_depth.py index e238437d7f..3fa45c44e3 100644 --- a/dace/sdfg/performance_evaluation/work_depth.py +++ b/dace/sdfg/performance_evaluation/work_depth.py @@ -11,19 +11,22 @@ import os import sympy as sp from copy import deepcopy -from dace.libraries.blas import MatMul +from dace.libraries.blas import MatMul, Dot from dace.libraries.standard import Reduce, Transpose from dace.symbolic import pystr_to_symbolic import ast import astunparse import warnings -from dace.sdfg.performance_evaluation.helpers import LoopExtractionError, get_uuid, find_loop_guards_tails_exits +from dace.sdfg.performance_evaluation.helpers import get_uuid from dace.sdfg.performance_evaluation.assumptions import parse_assumptions from dace.transformation.passes.symbol_ssa import StrictSymbolSSA from dace.transformation.pass_pipeline import FixedPointPipeline +from dace.transformation.passes.analysis import loop_analysis +from dace.sdfg.state import AbstractControlFlowRegion, ControlFlowRegion, LoopRegion, ConditionalBlock +math_funcs = set() def get_array_size_symbols(sdfg): """ Returns all symbols that appear isolated in shapes of the SDFG's arrays. @@ -55,6 +58,13 @@ def symeval(val, symbols): def evaluate_symbols(base, new): + """Takes a base symbol mapping and a new one and adapts the new one to match the base one for symbols contained in it + + :param base: The base mapping + :param new: The mapping that gets adjusted + :return result: A new mapping that contains all mappings from new, but adjusted to transitively match to the mapping of base + """ + result = {} for k, v in new.items(): result[k] = symeval(v, base) @@ -82,7 +92,7 @@ def count_depth_matmul(node, symbols, state): # optimal depth of a matrix multiplication is O(log(size of shared dimension)): A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') size_shared_dimension = symeval(A_memlet.data.subset.size()[-1], symbols) - return sp.log(size_shared_dimension) + return sp.log(sp.Max(1, size_shared_dimension), 2) def count_work_reduce(node, symbols, state): @@ -102,19 +112,37 @@ def count_work_reduce(node, symbols, state): def count_depth_reduce(node, symbols, state): # optimal depth of reduction is log of the work - return sp.log(count_work_reduce(node, symbols, state)) + return sp.log(sp.Max(1, count_work_reduce(node, symbols, state)), 2) + +def count_work_dot(node, symbols, state): + print("Dot product detected") + X_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_x') + Y_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_y') + RES_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_result') + print(X_memlet.data.subset.size()) + result = 2*symeval(X_memlet.data.subset.size()[-1], symbols)-1 + return sp.sympify(result) +def count_depth_dot(node, symbols, state): + X_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_x') + Y_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_y') + RES_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_result') + # optimal depth for dot product is 1 for multiplications and logarithmic for additions + result = 1+sp.log(sp.Max(1, symeval(X_memlet.data.subset.size()[-1], symbols)), 2) + return sp.sympify(result) LIBNODES_TO_WORK = { MatMul: count_work_matmul, Transpose: lambda *args: 0, Reduce: count_work_reduce, + Dot: count_work_dot, } LIBNODES_TO_DEPTH = { MatMul: count_depth_matmul, Transpose: lambda *args: 0, Reduce: count_depth_reduce, + Dot: count_depth_dot, } PYFUNC_TO_ARITHMETICS = { @@ -294,19 +322,19 @@ def do_initial_subs(w, d, eq, subs1): return result -def sdfg_work_depth(sdfg: SDFG, - w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], - analyze_tasklet, - symbols: Dict[str, str], - equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], - subs1: Dict[str, sp.Expr], - detailed_analysis: bool = False) -> Tuple[sp.Expr, sp.Expr]: +def control_flow_region_work_depth(cfr: ControlFlowRegion, + w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], + analyze_tasklet, + symbols: Dict[str, str], + equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], + subs1: Dict[str, sp.Expr], + detailed_analysis: bool = False) -> Tuple[sp.Expr, sp.Expr]: """ - Analyze the work and depth of a given SDFG. + Analyze the work and depth of a given (structured) ControlFLowRegion. First we determine the work and depth of each state. Then we break loops in the state machine, such that we get a DAG. Lastly, we compute the path with most work and the path with the most depth in order to get the total work depth. - :param sdfg: The SDFG to analyze. + :param cfr: The ControlFLowRegion to analyze. :param w_d_map: Dictionary which will save the result. :param analyze_tasklet: Function used to analyze tasklet nodes. :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. @@ -317,130 +345,156 @@ def sdfg_work_depth(sdfg: SDFG, :param subs1: First substitution dict for greater/lesser assumptions. :return: A tuple containing the work and depth of the SDFG. """ - # First determine the work and depth of each state individually. # Keep track of the work and depth for each state in a dictionary, where work and depth are multiplied by the number # of times the state will be executed. - state_depths: Dict[SDFGState, sp.Expr] = {} - state_works: Dict[SDFGState, sp.Expr] = {} - for state in sdfg.nodes(): - state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, - detailed_analysis) - - # Substitutions for state_work and state_depth already performed, but state.executions needs to be subs'd now. - state_work = sp.simplify( - state_work.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) * - state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) - state_depth = sp.simplify( - state_depth.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) * - state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) - - state_works[state], state_depths[state] = state_work, state_depth - w_d_map[get_uuid(state)] = (state_works[state], state_depths[state]) - - # Prepare the SDFG for a depth analysis by breaking loops. This removes the edge between the last loop state and - # the guard, and instead places an edge between the last loop state and the exit state. - # This transforms the state machine into a DAG. Hence, we can find the "heaviest" and "deepest" paths in linear time. - # Additionally, construct a dummy exit state and connect every state that has no outgoing edges to it. - - # identify all loops in the SDFG - try: - nodes_oNodes_exits = find_loop_guards_tails_exits(sdfg._nx) - except LoopExtractionError: - # If loop detection fails, we cannot make proper propagation. - print('Analysis failed since not all loops got detected. It may help to use more structured loop constructs.' + - ' The analysis per state remains correct, but no SDFG-wide analysis can be performed.') - sdfg_result = (sp.oo, sp.oo) - w_d_map[get_uuid(sdfg)] = sdfg_result - - for k, (v_w, v_d) in w_d_map.items(): - # The symeval replaces nested SDFG symbols with their global counterparts. - v_w = symeval(v_w, symbols) - v_d = symeval(v_d, symbols) - w_d_map[k] = (v_w, v_d) - return sdfg_result - - # Now we need to go over each triple (node, oNode, exits). For each triple, we - # - remove edge (oNode, node), i.e. the backward edge - # - for all exits e, add edge (oNode, e). This edge may already exist - # - remove edge from node to exit (if present, i.e. while-do loop) - # - This ensures that every node with > 1 outgoing edge is a branch guard - # - useful for detailed anaylsis. - for node, oNode, exits in nodes_oNodes_exits: - sdfg.remove_edge(sdfg.edges_between(oNode, node)[0]) - for e in exits: - if len(sdfg.edges_between(oNode, e)) == 0: - # no edge there yet - sdfg.add_edge(oNode, e, InterstateEdge()) - if len(sdfg.edges_between(node, e)) > 0: - # edge present --> remove it - sdfg.remove_edge(sdfg.edges_between(node, e)[0]) + region_depths: Dict[AbstractControlFlowRegion, sp.Expr] = {} + region_works: Dict[AbstractControlFlowRegion, sp.Expr] = {} + for region in cfr.nodes(): + if isinstance(region, SDFGState): + #rename variable to make code more readable + state = region + + state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, + detailed_analysis) + + # Substitutions for state_work and state_depth already performed, but state.executions needs to be subs'd now. + state_work = sp.simplify(state_work.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + state_depth = sp.simplify(state_depth.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + + region_works[state], region_depths[state] = state_work, state_depth + w_d_map[get_uuid(state)] = (region_works[state], region_depths[state]) + elif isinstance(region, LoopRegion): + #rename variable to make code more readable + loop = region + try: + loop_var = sp.sympify(loop.loop_variable) + lower_bound = loop_analysis.get_init_assignment(loop) + upper_bound = loop_analysis.get_loop_end(loop) + step = sp.sympify(loop_analysis.get_loop_stride(loop)) + except: + raise NotImplementedError("Only loops with constant step sizes and static bounds are supported") + loop_work, loop_depth = control_flow_region_work_depth(loop, w_d_map, analyze_tasklet, symbols, + equality_subs, subs1, detailed_analysis) + + # to ensure that the summation works properly, we need to make sure that the symbol that is used as loop varaible + # is the same as the ones used in the inner expression + for var in loop_work.free_symbols: + if var.name == loop_var.name: + loop_var = var + + + #TEMPORARY FIX: Because with library nodes it can happen that we get two symbols (with the same name) that correspond to the + for var in loop_work.free_symbols: + if var.name == loop_var.name and not var == loop_var: + print("scuffed variable bug detected") + loop_work = loop_work.subs({var: loop_var}) + loop_depth = loop_depth.subs({var: loop_var}) + + # prepare loop bounds to such that we can write the work as a nice summation from 0 to an upper bound + loop_var = loop_var.subs(subs1) + shifted_hi = (upper_bound-lower_bound)//step + shifted_hi = shifted_hi.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) + lower_bound = lower_bound.subs(subs1) if step.evalf()>0 else upper_bound.subs(subs1) + shifted_lo = sp.sympify(0) + step:sp.Expr = sp.Abs(step) + + # write the work and depth of the loop as a sum of the work of one iteration over the number of loop iterations + # (we have cannot use a simple multiplication as work and depth of one loop iteration might be dependent on the loop variable) + loop_work = sp.Sum(loop_work, (loop_var, shifted_lo, shifted_hi)) + loop_depth = sp.Sum(loop_depth, (loop_var, shifted_lo, shifted_hi)) + # Do equality subs + loop_work = sp.simplify(loop_work.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + loop_depth = sp.simplify(loop_depth.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + + region_works[loop], region_depths[loop] = loop_work, loop_depth + w_d_map[get_uuid(loop)] = (region_works[loop], region_depths[loop]) + elif isinstance(region, ConditionalBlock): + branch_conditions: Dict[AbstractControlFlowRegion, sp.Expr] = {} + branch_works: Dict[AbstractControlFlowRegion, sp.Expr] = {} + branch_depths: Dict[AbstractControlFlowRegion, sp.Expr] = {} + for (condition, branch) in region.branches: + branch_conditions[branch] = pystr_to_symbolic( + condition.as_string) if condition is not None else sp.sympify(True) + branch_works[branch], branch_depths[branch] = control_flow_region_work_depth( + branch, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, detailed_analysis) + if not detailed_analysis: + region_works[region] = sp.Max(*branch_works.values()) + region_depths[region] = sp.Max(*branch_depths.values()) + else: + work_condition = list(zip(branch_works.values(), branch_conditions.values())) + depth_condition = list(zip(branch_depths.values(), branch_conditions.values())) + region_works[region] = sp.Piecewise(*work_condition) + region_depths[region] = sp.Piecewise(*depth_condition) + else: + raise NotImplementedError("Work/Depth Analysis for Unstructured Control Flow is not supported (yet).") # add a dummy exit to the SDFG, such that each path ends there. - dummy_exit = sdfg.add_state('dummy_exit') - for state in sdfg.nodes(): - if len(sdfg.out_edges(state)) == 0 and state != dummy_exit: - sdfg.add_edge(state, dummy_exit, InterstateEdge()) + dummy_exit = cfr.add_state('dummy_exit') + for region in cfr.nodes(): + if len(cfr.out_edges(region)) == 0 and region != dummy_exit: + cfr.add_edge(region, dummy_exit, InterstateEdge()) # These two dicts save the current length of the "heaviest", resp. "deepest", paths at each state. - work_map: Dict[SDFGState, sp.Expr] = {} - depth_map: Dict[SDFGState, sp.Expr] = {} + work_map: Dict[AbstractControlFlowRegion, sp.Expr] = {} + depth_map: Dict[AbstractControlFlowRegion, sp.Expr] = {} # Keeps track of assignments done on InterstateEdges. - state_value_map: Dict[SDFGState, Dict[sp.Symbol, sp.Symbol]] = {} + region_value_map: Dict[AbstractControlFlowRegion, Dict[sp.Symbol, sp.Symbol]] = {} # The dummy state has 0 work and depth. - state_depths[dummy_exit] = sp.sympify(0) - state_works[dummy_exit] = sp.sympify(0) + region_depths[dummy_exit] = sp.sympify(0) + region_works[dummy_exit] = sp.sympify(0) # Perform a BFS traversal of the state machine and calculate the maximum work / depth at each state. Only advance to # the next state in the BFS if all incoming edges have been visited, to ensure the maximum work / depth expressions # have been calculated. traversal_q = deque() - traversal_q.append((sdfg.start_state, sp.sympify(0), sp.sympify(0), None, [], [], {})) + traversal_q.append((cfr.start_block, sp.sympify(0), sp.sympify(0), None, [], [], {})) visited = set() - + c = 0 while traversal_q: - state, depth, work, ie, condition_stack, common_subexpr_stack, value_map = traversal_q.popleft() + c += 1 + region, depth, work, ie, condition_stack, common_subexpr_stack, value_map = traversal_q.popleft() if ie is not None: visited.add(ie) - if state in state_value_map: + if region in region_value_map: # update value map: - update_value_map(state_value_map[state], value_map) + update_value_map(region_value_map[region], value_map) else: - state_value_map[state] = value_map + region_value_map[region] = value_map - value_map = {pystr_to_symbolic(k): pystr_to_symbolic(v) for k, v in state_value_map[state].items()} - n_depth = sp.simplify((depth + state_depths[state]).subs(value_map)) - n_work = sp.simplify((work + state_works[state]).subs(value_map)) + value_map = {pystr_to_symbolic(k): pystr_to_symbolic(v) for k, v in region_value_map[region].items()} + n_depth = sp.simplify((depth + region_depths[region]).subs(value_map)) + n_work = sp.simplify((work + region_works[region]).subs(value_map)) # If we are analysing average parallelism, we don't search "heaviest" and "deepest" paths separately, but we want one # single path with the least average parallelsim (of all paths with more than 0 work). if analyze_tasklet == get_tasklet_avg_par: - if state in depth_map: # this means we have already visited this state before + if region in depth_map: # this means we have already visited this region before cse = common_subexpr_stack.pop() # if current path has 0 depth (--> 0 work as well), we don't do anything. if n_depth != 0: # check if we need to update the work and depth of the current state # we update if avg parallelism of new incoming path is less than current avg parallelism - if depth_map[state] == 0: + if depth_map[region] == 0: # old value was divided by zero --> we take new value anyway - depth_map[state] = cse[1] + n_depth - work_map[state] = cse[0] + n_work + depth_map[region] = cse[1] + n_depth + work_map[region] = cse[0] + n_work else: - old_avg_par = (cse[0] + work_map[state]) / (cse[1] + depth_map[state]) + old_avg_par = (cse[0] + work_map[region]) / (cse[1] + depth_map[region]) new_avg_par = (cse[0] + n_work) / (cse[1] + n_depth) # we take either old work/depth or new work/depth (or both if we cannot determine which one is greater) - depth_map[state] = cse[1] + sp.Piecewise((n_depth, sp.simplify(new_avg_par < old_avg_par)), - (depth_map[state], True)) - work_map[state] = cse[0] + sp.Piecewise((n_work, sp.simplify(new_avg_par < old_avg_par)), - (work_map[state], True)) + depth_map[region] = cse[1] + sp.Piecewise((n_depth, sp.simplify(new_avg_par < old_avg_par)), + (depth_map[region], True)) + work_map[region] = cse[0] + sp.Piecewise((n_work, sp.simplify(new_avg_par < old_avg_par)), + (work_map[region], True)) else: - depth_map[state] = n_depth - work_map[state] = n_work + depth_map[region] = n_depth + work_map[region] = n_work else: # search heaviest and deepest path separately - if state in depth_map: # and consequently also in work_map + if region in depth_map: # and consequently also in work_map # This cse value would appear in both arguments of the Max. Hence, for performance reasons, # we pull it out of the Max expression. # Example: We do cse + Max(a, b) instead of Max(cse + a, cse + b). @@ -450,18 +504,18 @@ def sdfg_work_depth(sdfg: SDFG, if detailed_analysis: # This MAX should be covered in the more detailed analysis cond = condition_stack.pop() - work_map[state] = cse[0] + sp.Piecewise((work_map[state], sp.Not(cond)), (n_work, cond)) - depth_map[state] = cse[1] + sp.Piecewise((depth_map[state], sp.Not(cond)), (n_depth, cond)) + work_map[region] = cse[0] + sp.Piecewise((work_map[region], sp.Not(cond)), (n_work, cond)) + depth_map[region] = cse[1] + sp.Piecewise((depth_map[region], sp.Not(cond)), (n_depth, cond)) else: - work_map[state] = cse[0] + sp.Max(work_map[state], n_work) - depth_map[state] = cse[1] + sp.Max(depth_map[state], n_depth) + work_map[region] = cse[0] + sp.Max(work_map[region], n_work) + depth_map[region] = cse[1] + sp.Max(depth_map[region], n_depth) else: - depth_map[state] = n_depth - work_map[state] = n_work + depth_map[region] = n_depth + work_map[region] = n_work - out_edges = sdfg.out_edges(state) + out_edges = cfr.out_edges(region) # only advance after all incoming edges were visited (meaning that current work depth values of state are final). - if any(iedge not in visited for iedge in sdfg.in_edges(state)): + if any(iedge not in visited for iedge in cfr.in_edges(region)): pass else: for oedge in out_edges: @@ -472,9 +526,9 @@ def sdfg_work_depth(sdfg: SDFG, new_cond_stack.append(oedge.data.condition_sympy()) # same for common_subexr_stack new_cse_stack = list(common_subexpr_stack) - new_cse_stack.append((work_map[state], depth_map[state])) + new_cse_stack.append((work_map[region], depth_map[region])) # same for value_map - new_value_map = dict(state_value_map[state]) + new_value_map = dict(region_value_map[state]) new_value_map.update({ pystr_to_symbolic(k): pystr_to_symbolic(v).subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) @@ -488,7 +542,7 @@ def sdfg_work_depth(sdfg: SDFG, pystr_to_symbolic(v).subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) for k, v in oedge.data.assignments.items() }) - traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge, condition_stack, + traversal_q.append((oedge.dst, depth_map[region], work_map[region], oedge, condition_stack, common_subexpr_stack, value_map)) try: @@ -497,18 +551,18 @@ def sdfg_work_depth(sdfg: SDFG, except KeyError: # If we get a KeyError above, this means that the traversal never reached the dummy_exit state. # This happens if the loops were not properly detected and broken. - raise LoopExtractionError( - 'Analysis failed, since not all loops got detected. It may help to use more structured loop constructs.') + raise RuntimeError("Analysis failed! The dummy exit state was never reached") - sdfg_result = (max_work, max_depth) - w_d_map[get_uuid(sdfg)] = sdfg_result + cfr_result = (max_work, max_depth) + w_d_map[get_uuid(cfr)] = cfr_result for k, (v_w, v_d) in w_d_map.items(): # The symeval replaces nested SDFG symbols with their global counterparts. v_w = symeval(v_w, symbols) v_d = symeval(v_d, symbols) w_d_map[k] = (v_w, v_d) - return sdfg_result + + return cfr_result def scope_work_depth( @@ -580,8 +634,8 @@ def scope_work_depth( nested_syms.update(symbols) nested_syms.update(evaluate_symbols(symbols, node.symbol_mapping)) # Nested SDFGs are recursively analyzed first. - nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms, equality_subs, - subs1, detailed_analysis) + nsdfg_work, nsdfg_depth = control_flow_region_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms, + equality_subs, subs1, detailed_analysis) nsdfg_work, nsdfg_depth = do_initial_subs(nsdfg_work, nsdfg_depth, equality_subs, subs1) # add up work for whole state, but also save work for this nested SDFG in w_d_map @@ -610,7 +664,10 @@ def scope_work_depth( lib_node_depth = LIBNODES_TO_DEPTH[type(node)](node, symbols, state) except KeyError: top_level_sdfg = state.parent - top_level_sdfg.add_symbol(f'{node.name}_depth', dtypes.int64) + try: + top_level_sdfg.add_symbol(f'{node.name}_depth', dtypes.int64) + except FileExistsError: + pass lib_node_depth = sp.Symbol(f'{node.name}_depth', positive=True) lib_node_work, lib_node_depth = do_initial_subs(lib_node_work, lib_node_depth, equality_subs, subs1) work += lib_node_work @@ -706,6 +763,7 @@ def scope_work_depth( # summarise work / depth of the whole scope in the dictionary scope_result = (work, max_depth) w_d_map[get_uuid(state)] = scope_result + return scope_result @@ -739,7 +797,7 @@ def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet, assumptions: List[str], - detailed_analysis: bool = False) -> None: + detailed_analysis: bool = False): """ Analyze a given SDFG. We can either analyze work, work and depth or average parallelism. @@ -764,7 +822,7 @@ def analyze_sdfg(sdfg: SDFG, array_symbols = get_array_size_symbols(sdfg) # parse assumptions equality_subs, all_subs = parse_assumptions(assumptions if assumptions is not None else [], array_symbols) - + # Run state propagation for all SDFGs recursively. This is necessary to determine the number of times each state # will be executed, or to determine upper bounds for that number (such as in the case of branching) for sd in sdfg.all_sdfgs_recursive(): @@ -772,9 +830,9 @@ def analyze_sdfg(sdfg: SDFG, # Analyze the work and depth of the SDFG. symbols = {} - sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, symbols, equality_subs, all_subs[0][0] if len(all_subs) > 0 else {}, - detailed_analysis) - + control_flow_region_work_depth(sdfg, w_d_map, analyze_tasklet, symbols, equality_subs, + all_subs[0][0] if len(all_subs) > 0 else {}, detailed_analysis) + for k, (v_w, v_d) in w_d_map.items(): # The symeval replaces nested SDFG symbols with their global counterparts. v_w, v_d = do_subs(v_w, v_d, all_subs) @@ -782,6 +840,20 @@ def analyze_sdfg(sdfg: SDFG, v_d = symeval(v_d, symbols) w_d_map[k] = (v_w, v_d) + if analyze_tasklet == get_tasklet_work_depth: + for k, v, in w_d_map.items(): + w_d_map[k] = ((sp.simplify(v[0])), (sp.simplify(v[1]))) + elif analyze_tasklet == get_tasklet_work: + for k, v, in w_d_map.items(): + w_d_map[k] = (sp.simplify(v[0])) + elif analyze_tasklet == get_tasklet_avg_par: + for k, v, in w_d_map.items(): + w_d_map[k] = (sp.simplify(v[0] / v[1]) if (v[1]) != 0 else 0) # work / depth = avg par + + result_whole_sdfg = w_d_map[get_uuid(sdfg)] + + return result_whole_sdfg + def do_subs(work, depth, all_subs): """ @@ -838,16 +910,6 @@ def main() -> None: work_depth_map = {} analyze_sdfg(sdfg, work_depth_map, analyze_tasklet, args.assume, args.detailed) - if args.analyze == 'workDepth': - for k, v, in work_depth_map.items(): - work_depth_map[k] = (str(sp.simplify(v[0])), str(sp.simplify(v[1]))) - elif args.analyze == 'work': - for k, v, in work_depth_map.items(): - work_depth_map[k] = str(sp.simplify(v[0])) - elif args.analyze == 'avgPar': - for k, v, in work_depth_map.items(): - work_depth_map[k] = str(sp.simplify(v[0] / v[1]) if str(v[1]) != '0' else 0) # work / depth = avg par - result_whole_sdfg = work_depth_map[get_uuid(sdfg)] print(80 * '-') From 2be952ab5e55401c493bd1936df26e0ebf9a7dca Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 16 Feb 2026 17:30:32 +0100 Subject: [PATCH 02/10] use work depth from other branch --- .../sdfg/performance_evaluation/work_depth.py | 382 +++++++++++++++--- 1 file changed, 333 insertions(+), 49 deletions(-) diff --git a/dace/sdfg/performance_evaluation/work_depth.py b/dace/sdfg/performance_evaluation/work_depth.py index 3fa45c44e3..5c77e3444e 100644 --- a/dace/sdfg/performance_evaluation/work_depth.py +++ b/dace/sdfg/performance_evaluation/work_depth.py @@ -11,12 +11,13 @@ import os import sympy as sp from copy import deepcopy -from dace.libraries.blas import MatMul, Dot +from dace.libraries.blas import MatMul, Dot, Gemm, Gemv from dace.libraries.standard import Reduce, Transpose from dace.symbolic import pystr_to_symbolic import ast import astunparse import warnings +import re from dace.sdfg.performance_evaluation.helpers import get_uuid from dace.sdfg.performance_evaluation.assumptions import parse_assumptions @@ -24,7 +25,7 @@ from dace.transformation.pass_pipeline import FixedPointPipeline from dace.transformation.passes.analysis import loop_analysis -from dace.sdfg.state import AbstractControlFlowRegion, ControlFlowRegion, LoopRegion, ConditionalBlock +from dace.sdfg.state import AbstractControlFlowRegion, ControlFlowRegion, LoopRegion, ConditionalBlock, ReturnBlock, ContinueBlock, BreakBlock math_funcs = set() def get_array_size_symbols(sdfg): @@ -44,6 +45,110 @@ def get_array_size_symbols(sdfg): symbols.add(s) return symbols +def subs_till_fixed_point(expr:sp.Expr, symbol_map:Dict[sp.Expr, sp.Expr]): + """ + Takes a sympy expression and a symbol mapping and applies the mapping to the expression until a fixed point is reached + Needs the guarantee that the symbol mapping does not have cyclic dependencies. + + :param expr: Description + :param symbol_map: Description + :return: Description + """ + prev = None + curr = expr + while prev != curr: + prev = curr + curr = curr.subs(symbol_map) + return curr + +def get_static_symbols(sdfg: SDFG): + """ + Returns a mapping of symbols that are assigned exactly at one point in the sdfg. + + :param sdfg: The sdfg for which we want to find the static symbols and their corresponding assignment + :return: The mapping of the symbols to higher levels (iterated to a fixed point) + """ + + + patterns = [ + "dace.complex128", + "dace.float64", + "dace.float32", + "dace.int64", + "dace.int32", + "dace.int16", + "dace.uint32", + "dace.uint16", + "dace.uint8", + "float", + "int" + ] + + type_regex = re.compile("|".join(map(re.escape, patterns))) + static_symbol_mapping:Dict[sp.Symbol, sp.Expr] = {sp.Symbol(a): sp.Symbol(a) for a in sdfg.arg_names} + non_static_symbols = set() + for node, containing_state in sdfg.all_nodes_recursive(): + if isinstance(node, nd.AccessNode): + + if containing_state.in_degree(node) == 1: + edge = containing_state.in_edges(node)[0] + source = edge.src + + if edge.data.volume == 1: + if isinstance(source, nd.Tasklet): + tasklet = source + in_map = {} + out_map = {} + # Incoming edges: symbols feeding the tasklet + for e in containing_state.in_edges(tasklet): + if not isinstance(e.src, nd.AccessNode): + continue + sym = str(e.src.data) + in_map[e.dst_conn] = sym + # Outgoing edges: symbols written by the tasklet + # Out edges should only be one, but for safety we iterate + for e in containing_state.out_edges(tasklet): + if not isinstance(e.dst, nd.AccessNode): + continue + sym = sp.Symbol(e.dst.data) + out_map[e.src_conn] = sym + code = tasklet.code.as_string.strip() + # Expect a single assignment + lines = [l.strip() for l in code.splitlines() if l.strip()] + try: + lhs, rhs = lines[0].split('=',1) + except: + # Skip mapping for overly complex tasklet code + non_static_symbols.add(sp.Symbol(node.data)) + lhs = lhs.strip() + rhs = rhs.strip() + rhs = type_regex.sub("", rhs) + # Parse RHS using SymPy, with tasklet inputs substituted + lhs_sympy = pystr_to_symbolic(lhs) + lhs_sympy = lhs_sympy.subs(out_map) + + if not lhs_sympy in static_symbol_mapping.keys(): + try: + rhs_sympy = pystr_to_symbolic(rhs) + rhs_sympy = rhs_sympy.subs(in_map) + static_symbol_mapping[lhs_sympy] = rhs_sympy + except: + non_static_symbols.add(lhs_sympy) + else: + non_static_symbols.add(lhs_sympy) + + elif isinstance(source, nd.AccessNode): + data_sym = sp.Symbol(source.data) + nd_sym = sp.Symbol(node.data) + if not data_sym in static_symbol_mapping.keys(): + static_symbol_mapping[data_sym] = nd_sym + else: + non_static_symbols.add(data_sym) + + static_symbol_mapping = {k: v for (k, v) in static_symbol_mapping.items() if k not in non_static_symbols} + static_symbol_mapping = {k: subs_till_fixed_point(v, static_symbol_mapping) for k,v in static_symbol_mapping.items()} + return static_symbol_mapping + def symeval(val, symbols): """ @@ -92,7 +197,7 @@ def count_depth_matmul(node, symbols, state): # optimal depth of a matrix multiplication is O(log(size of shared dimension)): A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') size_shared_dimension = symeval(A_memlet.data.subset.size()[-1], symbols) - return sp.log(sp.Max(1, size_shared_dimension), 2) + return sp.Max(1, sp.log(sp.Max(1, size_shared_dimension), 2)) def count_work_reduce(node, symbols, state): @@ -115,11 +220,9 @@ def count_depth_reduce(node, symbols, state): return sp.log(sp.Max(1, count_work_reduce(node, symbols, state)), 2) def count_work_dot(node, symbols, state): - print("Dot product detected") X_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_x') Y_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_y') RES_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_result') - print(X_memlet.data.subset.size()) result = 2*symeval(X_memlet.data.subset.size()[-1], symbols)-1 return sp.sympify(result) @@ -131,8 +234,149 @@ def count_depth_dot(node, symbols, state): result = 1+sp.log(sp.Max(1, symeval(X_memlet.data.subset.size()[-1], symbols)), 2) return sp.sympify(result) +def count_work_gemm(node, symbols, state): + """ + Count work for GEMM operation: C = alpha * A @ B + beta * C + Work includes: + - Matrix multiplication: 2*M*N*K (multiply + add per element) + - Alpha scaling: M*N (if alpha != 1) + - Beta scaling + addition: 2*M*N (if beta != 0) + """ + A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') + B_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_b') + C_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_c') + + # Get dimensions + # Handle batch dimension if present + if len(C_memlet.data.subset) == 3: + batch = symeval(C_memlet.data.subset.size()[0], symbols) + M = symeval(C_memlet.data.subset.size()[1], symbols) + N = symeval(C_memlet.data.subset.size()[2], symbols) + else: + batch = 1 + M = symeval(C_memlet.data.subset.size()[-2], symbols) if len(C_memlet.data.subset.size()) >= 2 else 1 + N = symeval(C_memlet.data.subset.size()[-1], symbols) + + K = symeval(A_memlet.data.subset.size()[-1], symbols) + + # Core matrix multiplication: 2*M*N*K (multiply + add) + result = 2 * batch * M * N * K + + # Add work for alpha scaling if alpha != 1 + alpha = getattr(node, 'alpha', 1) + if alpha != 1: + result += batch * M * N # M*N multiplications by alpha + + # Add work for beta * C if beta != 0 + beta = getattr(node, 'beta', 0) + if beta != 0: + result += batch * M * N # M*N multiplications by beta + result += batch * M * N # M*N additions + + return sp.sympify(result) + + +def count_depth_gemm(node, symbols, state): + """ + Optimal depth for GEMM: log(K) for the reduction + constant for scaling/addition + """ + A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') + K = symeval(A_memlet.data.subset.size()[-1], symbols) + + # Depth is dominated by the reduction over K dimension + depth = sp.log(sp.Max(1, K), 2) + + # Add constant depth for alpha and beta operations + alpha = getattr(node, 'alpha', 1) + beta = getattr(node, 'beta', 0) + + if alpha != 1: + depth += 1 # One multiplication layer + if beta != 0: + depth += 2 # One multiplication + one addition layer + + return sp.Max(1, depth) + + +def count_work_gemv(node, symbols, state): + """ + Count work for GEMV operation: y = alpha * A @ x + beta * y + Two variants: + - GEMV: y = alpha * A @ x + beta * y (A is MxN, x is N, y is M) + - GEMVT: y = alpha * A^T @ x + beta * y (A is MxN, x is M, y is N) + + Work includes: + - Matrix-vector multiplication: 2*M*N (multiply + add per element) + - Alpha scaling: M (if alpha != 1) + - Beta scaling + addition: 2*M (if beta != 0) + """ + A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_A') + x_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_x') + y_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_y') + + # Get dimensions from A matrix + A_shape = A_memlet.data.subset.size() + M = symeval(A_shape[-2], symbols) + N = symeval(A_shape[-1], symbols) + + # Check if transpose (GEMVT) + trans = getattr(node, 'transA', False) + + # Output size + output_size = N if trans else M + + # Core matrix-vector multiplication: 2*M*N (each output element needs N multiplies and N-1 adds) + result = 2 * M * N + + # Add work for alpha scaling if alpha != 1 + alpha = getattr(node, 'alpha', 1) + if alpha != 1: + result += output_size # output_size multiplications by alpha + + # Add work for beta * y if beta != 0 + beta = getattr(node, 'beta', 0) + if beta != 0: + result += output_size # output_size multiplications by beta + result += output_size # output_size additions + + return sp.sympify(result) + + +def count_depth_gemv(node, symbols, state): + """ + Optimal depth for GEMV: log(N) for the reduction + constant for scaling/addition + where N is the reduction dimension + """ + A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_A') + A_shape = A_memlet.data.subset.size() + M = symeval(A_shape[-2], symbols) + N = symeval(A_shape[-1], symbols) + + # Check if transpose + trans = getattr(node, 'transA', False) + + # Reduction dimension + reduction_dim = M if trans else N + + # Depth is dominated by the reduction + depth = sp.log(sp.Max(1, reduction_dim), 2) + + # Add constant depth for alpha and beta operations + alpha = getattr(node, 'alpha', 1) + beta = getattr(node, 'beta', 0) + + if alpha != 1: + depth += 1 # One multiplication layer + if beta != 0: + depth += 2 # One multiplication + one addition layer + + return sp.Max(1, depth) + + LIBNODES_TO_WORK = { MatMul: count_work_matmul, + Gemm: count_work_gemm, + Gemv: count_work_gemv, Transpose: lambda *args: 0, Reduce: count_work_reduce, Dot: count_work_dot, @@ -140,6 +384,8 @@ def count_depth_dot(node, symbols, state): LIBNODES_TO_DEPTH = { MatMul: count_depth_matmul, + Gemm: count_depth_gemm, + Gemv: count_depth_gemv, Transpose: lambda *args: 0, Reduce: count_depth_reduce, Dot: count_depth_dot, @@ -180,7 +426,11 @@ def visit_BinOp(self, node): return self.generic_visit(node) def visit_UnaryOp(self, node): - self.count += 1 + if isinstance(node.op, (ast.USub, ast.UAdd)): + # Unary + / - are sign or no-op → don't add work + pass + else: + self.count += 1 return self.generic_visit(node) def visit_Call(self, node): @@ -367,48 +617,66 @@ def control_flow_region_work_depth(cfr: ControlFlowRegion, elif isinstance(region, LoopRegion): #rename variable to make code more readable loop = region - try: - loop_var = sp.sympify(loop.loop_variable) - lower_bound = loop_analysis.get_init_assignment(loop) - upper_bound = loop_analysis.get_loop_end(loop) - step = sp.sympify(loop_analysis.get_loop_stride(loop)) - except: - raise NotImplementedError("Only loops with constant step sizes and static bounds are supported") + fallback = False + loop_var = sp.Symbol(loop.loop_variable) + lower_bound = loop_analysis.get_init_assignment(loop) + upper_bound = loop_analysis.get_loop_end(loop) + step = sp.sympify(loop_analysis.get_loop_stride(loop)) + if any(v is None for v in (loop_var, lower_bound, upper_bound)): + print("Because the loop does not provide a loop variable and static bounds, we fell back to just using its number of iterations. Mind that this can affect the correctness of the expression ") + fallback = True + executions = loop.start_block.executions + executions = executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) + loop_work, loop_depth = control_flow_region_work_depth(loop, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, detailed_analysis) - - # to ensure that the summation works properly, we need to make sure that the symbol that is used as loop varaible - # is the same as the ones used in the inner expression - for var in loop_work.free_symbols: - if var.name == loop_var.name: - loop_var = var - - #TEMPORARY FIX: Because with library nodes it can happen that we get two symbols (with the same name) that correspond to the - for var in loop_work.free_symbols: - if var.name == loop_var.name and not var == loop_var: - print("scuffed variable bug detected") - loop_work = loop_work.subs({var: loop_var}) - loop_depth = loop_depth.subs({var: loop_var}) - - # prepare loop bounds to such that we can write the work as a nice summation from 0 to an upper bound - loop_var = loop_var.subs(subs1) - shifted_hi = (upper_bound-lower_bound)//step - shifted_hi = shifted_hi.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) - lower_bound = lower_bound.subs(subs1) if step.evalf()>0 else upper_bound.subs(subs1) - shifted_lo = sp.sympify(0) - step:sp.Expr = sp.Abs(step) - - # write the work and depth of the loop as a sum of the work of one iteration over the number of loop iterations - # (we have cannot use a simple multiplication as work and depth of one loop iteration might be dependent on the loop variable) - loop_work = sp.Sum(loop_work, (loop_var, shifted_lo, shifted_hi)) - loop_depth = sp.Sum(loop_depth, (loop_var, shifted_lo, shifted_hi)) - # Do equality subs - loop_work = sp.simplify(loop_work.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) - loop_depth = sp.simplify(loop_depth.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + if not fallback: + # to ensure that the summation works properly, we need to make sure that the symbol that is used as loop varaible + # is the same as the ones used in the inner expression + for var in loop_work.free_symbols: + if var.name == loop_var.name: + loop_var = var + + + #TEMPORARY FIX: with library nodes it can happen that we get two symbols (with the same name) + for var in loop_work.free_symbols: + if var.name == loop_var.name and not var == loop_var: + loop_work = loop_work.subs({var: loop_var}) + loop_depth = loop_depth.subs({var: loop_var}) + + # prepare loop bounds to such that we can write the work as a nice summation from 0 to an upper bound + loop_var = loop_var.subs(subs1) + shifted_hi = (upper_bound-lower_bound)//step + shifted_hi = shifted_hi.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) + lower_bound = lower_bound.subs(subs1) if step.evalf()>0 else upper_bound.subs(subs1) + shifted_lo = sp.sympify(0) + step:sp.Expr = sp.Abs(step) + + loop_work = loop_work.subs({loop_var: (step*loop_var+lower_bound)}) + loop_depth = loop_depth.subs({loop_var: (step*loop_var+lower_bound)}) + # write the work and depth of the loop as a sum of the work of one iteration over the number of loop iterations + # (we have cannot use a simple multiplication as work and depth of one loop iteration might be dependent on the loop variable) + loop_work = sp.Sum(loop_work, (loop_var, shifted_lo, shifted_hi)) + loop_depth = sp.Sum(loop_depth, (loop_var, shifted_lo, shifted_hi)) + # Do equality subs + loop_work = sp.simplify(loop_work.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + loop_depth = sp.simplify(loop_depth.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + else: + loop_work = sp.simplify(loop_work.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + loop_depth = sp.simplify(loop_depth.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + + if executions != 0: + loop_work = loop_work*executions + loop_depth = loop_depth*executions + else: + exec_symbol = sp.Symbol(f'num_executions_{loop.name}') + loop_work = loop_work*exec_symbol + loop_depth = loop_depth*exec_symbol region_works[loop], region_depths[loop] = loop_work, loop_depth w_d_map[get_uuid(loop)] = (region_works[loop], region_depths[loop]) + elif isinstance(region, ConditionalBlock): branch_conditions: Dict[AbstractControlFlowRegion, sp.Expr] = {} branch_works: Dict[AbstractControlFlowRegion, sp.Expr] = {} @@ -426,8 +694,19 @@ def control_flow_region_work_depth(cfr: ControlFlowRegion, depth_condition = list(zip(branch_depths.values(), branch_conditions.values())) region_works[region] = sp.Piecewise(*work_condition) region_depths[region] = sp.Piecewise(*depth_condition) + w_d_map[get_uuid(region)] = (region_works[region], region_depths[region]) + + elif isinstance(region, (ReturnBlock, ContinueBlock, BreakBlock)): + region_works[region], region_depths[region] = (sp.sympify(0), sp.sympify(0)) + w_d_map[get_uuid(region)] = (sp.sympify(0), sp.sympify(0)) else: - raise NotImplementedError("Work/Depth Analysis for Unstructured Control Flow is not supported (yet).") + function_work, function_depth = control_flow_region_work_depth(region, w_d_map, analyze_tasklet, symbols, + equality_subs, subs1, detailed_analysis) + function_work = sp.simplify(function_work.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + function_depth = sp.simplify(function_depth.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + + region_works[region], region_depths[region] = function_work, function_depth + w_d_map[get_uuid(region)] = (region_works[region], region_depths[region]) # add a dummy exit to the SDFG, such that each path ends there. dummy_exit = cfr.add_state('dummy_exit') @@ -528,7 +807,7 @@ def control_flow_region_work_depth(cfr: ControlFlowRegion, new_cse_stack = list(common_subexpr_stack) new_cse_stack.append((work_map[region], depth_map[region])) # same for value_map - new_value_map = dict(region_value_map[state]) + new_value_map = dict(region_value_map[region]) new_value_map.update({ pystr_to_symbolic(k): pystr_to_symbolic(v).subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) @@ -634,10 +913,13 @@ def scope_work_depth( nested_syms.update(symbols) nested_syms.update(evaluate_symbols(symbols, node.symbol_mapping)) # Nested SDFGs are recursively analyzed first. - nsdfg_work, nsdfg_depth = control_flow_region_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms, - equality_subs, subs1, detailed_analysis) + nsdfg_work, nsdfg_depth = control_flow_region_work_depth(node.sdfg, w_d_map, analyze_tasklet, {}, + equality_subs, {}, detailed_analysis) + + nsdfg_work, nsdfg_depth = nsdfg_work.subs(nested_syms), nsdfg_depth.subs(nested_syms) # We cannot use assumptions for nested sdfg analysis. It interfers with the global assumptions. We thus substitute afterwards nsdfg_work, nsdfg_depth = do_initial_subs(nsdfg_work, nsdfg_depth, equality_subs, subs1) + # add up work for whole state, but also save work for this nested SDFG in w_d_map work += nsdfg_work w_d_map[get_uuid(node, state)] = (nsdfg_work, nsdfg_depth) @@ -811,14 +1093,13 @@ def analyze_sdfg(sdfg: SDFG, and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). """ - # deepcopy such that original sdfg not changed sdfg = deepcopy(sdfg) # apply SSA pass pipeline = FixedPointPipeline([StrictSymbolSSA()]) pipeline.apply_pass(sdfg, {}) - + static_symbol_mapping = get_static_symbols(sdfg) array_symbols = get_array_size_symbols(sdfg) # parse assumptions equality_subs, all_subs = parse_assumptions(assumptions if assumptions is not None else [], array_symbols) @@ -840,6 +1121,9 @@ def analyze_sdfg(sdfg: SDFG, v_d = symeval(v_d, symbols) w_d_map[k] = (v_w, v_d) + for k, v, in w_d_map.items(): + w_d_map[k] = ((v[0].subs(static_symbol_mapping).subs(equality_subs[1])),(v[1].subs(static_symbol_mapping).subs(equality_subs[1]))) + if analyze_tasklet == get_tasklet_work_depth: for k, v, in w_d_map.items(): w_d_map[k] = ((sp.simplify(v[0])), (sp.simplify(v[1]))) @@ -924,4 +1208,4 @@ def main() -> None: if __name__ == '__main__': - main() + main() \ No newline at end of file From fc5151edf217dfe816b53deea859c0606189d72e Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 17 Mar 2026 21:10:01 +0100 Subject: [PATCH 03/10] fix all errors --- dace/sdfg/performance_evaluation/helpers.py | 72 +++ .../sdfg/performance_evaluation/work_depth.py | 419 +++++++++++++++--- tests/sdfg/work_depth_test.py | 65 +-- 3 files changed, 463 insertions(+), 93 deletions(-) diff --git a/dace/sdfg/performance_evaluation/helpers.py b/dace/sdfg/performance_evaluation/helpers.py index ba7bfb84f2..d730a84d88 100644 --- a/dace/sdfg/performance_evaluation/helpers.py +++ b/dace/sdfg/performance_evaluation/helpers.py @@ -5,6 +5,9 @@ from collections import deque from typing import List, Dict, Set, Tuple, Optional, Union import networkx as nx +import sympy as sp +from dace.sdfg.state import ControlFlowRegion +from dace.sdfg.propagation import propagate_states NodeT = str EdgeT = Tuple[NodeT, NodeT] @@ -335,3 +338,72 @@ def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): # remove artificial end node sdfg_nx.remove_node(artificial_end_node) return nodes_oNodes_exits + +def get_legacy_loop_body(cfr, guard, tail, exits): + """ + Get all nodes in a legacy loop body. + A node is in the loop body if: + - It's reachable from guard + - It can reach tail + - It's not an exit node + """ + # Forward reachability from guard + forward_reachable = set() + queue = deque([guard]) + while queue: + node = queue.popleft() + if node in forward_reachable: + continue + forward_reachable.add(node) + for edge in cfr.out_edges(node): + queue.append(edge.dst) + + # Backward reachability to tail + backward_reachable = set() + queue = deque([tail]) + while queue: + node = queue.popleft() + if node in backward_reachable: + continue + backward_reachable.add(node) + for edge in cfr.in_edges(node): + queue.append(edge.src) + + # Loop body = (forward AND backward) - exits + loop_body = (forward_reachable & backward_reachable) - set(exits) + + return loop_body + +def get_legacy_loop_ranges(cfr: ControlFlowRegion) -> Dict[SDFGState, Tuple[sp.Expr, sp.Expr, sp.Expr, sp.Symbol]]: + """ + Builds a map from loop guard states to their loop variable and iteration + range, harvesting the annotations set by propagate_states / + _annotate_loop_ranges. + + Must be called AFTER propagate_states has been run on the SDFG. + + :param cfr: The ControlFlowRegion to inspect (only its direct nodes are + checked, not descendants, since control_flow_region_work_depth + is called recursively anyway). + :return: A dict mapping each legacy loop guard SDFGState to a tuple + (loop_var, start, stop, stride) + """ + #propagate_states(cfr) + result: Dict[SDFGState, Tuple[sp.Symbol, sp.Expr, sp.Expr, sp.Expr]] = {} + + for node in cfr.nodes(): + if not getattr(node, 'is_loop_guard', False): + continue + + itvar_str: str = node.itvar + loop_var: sp.Symbol = sp.Symbol(itvar_str) + + # guard.ranges[itvar] is a subsets.Range with one entry: [(start, stop, stride)] + rng = node.ranges[itvar_str][0] # -> (start, stop, stride) + start = sp.sympify(rng[0]) + stop = sp.sympify(rng[1]) + stride = sp.sympify(rng[2]) + + result[node] = (loop_var, start, stop, stride) + + return result \ No newline at end of file diff --git a/dace/sdfg/performance_evaluation/work_depth.py b/dace/sdfg/performance_evaluation/work_depth.py index 5c77e3444e..bb3ffc9090 100644 --- a/dace/sdfg/performance_evaluation/work_depth.py +++ b/dace/sdfg/performance_evaluation/work_depth.py @@ -7,7 +7,7 @@ from dace.sdfg import nodes as nd, propagation, InterstateEdge from dace import SDFG, SDFGState, dtypes from dace.subsets import Range -from typing import List, Tuple, Dict +from typing import List, Tuple, Dict, Callable, Sequence, Union import os import sympy as sp from copy import deepcopy @@ -19,7 +19,7 @@ import warnings import re -from dace.sdfg.performance_evaluation.helpers import get_uuid +from dace.sdfg.performance_evaluation.helpers import LoopExtractionError, get_uuid, find_loop_guards_tails_exits, get_legacy_loop_body, get_legacy_loop_ranges from dace.sdfg.performance_evaluation.assumptions import parse_assumptions from dace.transformation.passes.symbol_ssa import StrictSymbolSSA from dace.transformation.pass_pipeline import FixedPointPipeline @@ -68,7 +68,6 @@ def get_static_symbols(sdfg: SDFG): :param sdfg: The sdfg for which we want to find the static symbols and their corresponding assignment :return: The mapping of the symbols to higher levels (iterated to a fixed point) """ - patterns = [ "dace.complex128", @@ -147,8 +146,7 @@ def get_static_symbols(sdfg: SDFG): static_symbol_mapping = {k: v for (k, v) in static_symbol_mapping.items() if k not in non_static_symbols} static_symbol_mapping = {k: subs_till_fixed_point(v, static_symbol_mapping) for k,v in static_symbol_mapping.items()} - return static_symbol_mapping - + return static_symbol_mapping def symeval(val, symbols): """ @@ -161,7 +159,6 @@ def symeval(val, symbols): second_replacement = {pystr_to_symbolic('__REPLSYM_' + k): v for k, v in symbols.items()} return sp.simplify(val.subs(first_replacement).subs(second_replacement)) - def evaluate_symbols(base, new): """Takes a base symbol mapping and a new one and adapts the new one to match the base one for symbols contained in it @@ -172,7 +169,7 @@ def evaluate_symbols(base, new): result = {} for k, v in new.items(): - result[k] = symeval(v, base) + result[k] = symeval(pystr_to_symbolic(v), base) return result @@ -214,7 +211,6 @@ def count_work_reduce(node, symbols, state): result = 0 return sp.sympify(result) - def count_depth_reduce(node, symbols, state): # optimal depth of reduction is log of the work return sp.log(sp.Max(1, count_work_reduce(node, symbols, state)), 2) @@ -295,7 +291,7 @@ def count_depth_gemm(node, symbols, state): if beta != 0: depth += 2 # One multiplication + one addition layer - return sp.Max(1, depth) + return depth def count_work_gemv(node, symbols, state): @@ -426,11 +422,7 @@ def visit_BinOp(self, node): return self.generic_visit(node) def visit_UnaryOp(self, node): - if isinstance(node.op, (ast.USub, ast.UAdd)): - # Unary + / - are sign or no-op → don't add work - pass - else: - self.count += 1 + self.count += 1 return self.generic_visit(node) def visit_Call(self, node): @@ -453,7 +445,7 @@ def visit_While(self, node): raise NotImplementedError -def count_arithmetic_ops_code(code): +def count_arithmetic_ops_code(code: Union[Sequence[ast.AST], str, ast.AST]) -> int: ctr = ArithmeticCounter() if isinstance(code, (tuple, list)): for stmt in code: @@ -466,20 +458,37 @@ def count_arithmetic_ops_code(code): class DepthCounter(ast.NodeVisitor): - # so far this is identical to the ArithmeticCounter above. + """ + Computes the depth (longest chain of dependent operations) of a Python AST expression. + + Unlike ArithmeticCounter which sums all operations, this computes the critical path + through the expression tree, tracking data dependencies across statements. + + For example: + - `a + b` has depth 1 + - `(a + b) * c` has depth 2 (add, then multiply) + - `(a + b) * (c + d)` has depth 2 (both adds can be parallel, then multiply) + - `a = x + y; b = a + z` has depth 2 (a has depth 1, b depends on a so depth 2) + - `a = x + y; b = z + w` has depth 1 (independent statements, parallel) + """ + def __init__(self): - self.count = 0 + # Track the depth at which each variable was last assigned + self.var_depths: Dict[str, int] = {} def visit_BinOp(self, node): if isinstance(node.op, ast.MatMult): raise NotImplementedError('MatMult op count requires shape ' 'inference') - self.count += 1 - return self.generic_visit(node) + # Depth is 1 (for this operation) + max depth of the two operands + left_depth = self.visit(node.left) + right_depth = self.visit(node.right) + return 1 + max(left_depth, right_depth) def visit_UnaryOp(self, node): - self.count += 1 - return self.generic_visit(node) + # Depth is 1 (for this operation) + depth of operand + operand_depth = self.visit(node.operand) + return 1 + operand_depth def visit_Call(self, node): fname = astunparse.unparse(node.func)[:-1] @@ -487,12 +496,54 @@ def visit_Call(self, node): print( 'WARNING: Unrecognized python function "%s". If this is a type conversion, like "dace.float64", then this is fine.' % fname) - return self.generic_visit(node) - self.count += PYFUNC_TO_ARITHMETICS[fname] - return self.generic_visit(node) + # Still need to visit arguments to get their depth + arg_depths = [self.visit(arg) for arg in node.args] + return max(arg_depths) if arg_depths else 0 + op_cost = PYFUNC_TO_ARITHMETICS[fname] + # Get the maximum depth among all arguments + arg_depths = [self.visit(arg) for arg in node.args] + max_arg_depth = max(arg_depths) if arg_depths else 0 + return op_cost + max_arg_depth def visit_AugAssign(self, node): - return self.visit_BinOp(node) + # e.g., x += expr is equivalent to x = x + expr + # Get the target variable name + target_name = None + if isinstance(node.target, ast.Name): + target_name = node.target.id + + # Get the current depth of the target variable (it's being read) + target_depth = self.visit(node.target) + # Get the depth of the value expression + value_depth = self.visit(node.value) + # The operation depth is 1 + max of target and value depths + result_depth = 1 + max(target_depth, value_depth) + + # Update the variable's depth + if target_name: + self.var_depths[target_name] = result_depth + + return result_depth + + def visit_Assign(self, node): + # Compute the depth of the value expression + value_depth = self.visit(node.value) + + # Update the depth of all target variables + for target in node.targets: + if isinstance(target, ast.Name): + self.var_depths[target.id] = value_depth + elif isinstance(target, ast.Tuple) or isinstance(target, ast.List): + # Handle tuple/list unpacking: a, b = ... + for elt in target.elts: + if isinstance(elt, ast.Name): + self.var_depths[elt.id] = value_depth + + return value_depth + + def visit_Expr(self, node): + # Expression statement, just propagate + return self.visit(node.value) def visit_For(self, node): raise NotImplementedError @@ -500,24 +551,109 @@ def visit_For(self, node): def visit_While(self, node): raise NotImplementedError + def visit_Name(self, node): + # Variable reference: return the tracked depth if known, else 0 + return self.var_depths.get(node.id, 0) + + def visit_Constant(self, node): + # Constants have no computational depth + return 0 + + def visit_Num(self, node): + # For older Python AST compatibility - numbers have no depth + return 0 + + def visit_Subscript(self, node): + # Array access - both the array and index computation contribute to depth + array_depth = self.visit(node.value) + index_depth = self.visit(node.slice) + return max(array_depth, index_depth) + + def visit_Index(self, node): + # For older Python AST compatibility + return self.visit(node.value) + + def visit_Tuple(self, node): + # Tuple elements can be computed in parallel + if node.elts: + return max(self.visit(e) for e in node.elts) + return 0 + + def visit_List(self, node): + # List elements can be computed in parallel + if node.elts: + return max(self.visit(e) for e in node.elts) + return 0 + + def visit_Compare(self, node): + # Comparisons: don't count as arithmetic depth (matching ArithmeticCounter) + # but still traverse to find any arithmetic in operands + depths = [self.visit(node.left)] + depths.extend(self.visit(c) for c in node.comparators) + return max(depths) + + def visit_BoolOp(self, node): + # Boolean operations (and, or): don't count as arithmetic depth + # (matching ArithmeticCounter), but traverse values + return max(self.visit(v) for v in node.values) + + def visit_IfExp(self, node): + # Ternary expression: test must be computed first, then one of body/orelse + test_depth = self.visit(node.test) + body_depth = self.visit(node.body) + orelse_depth = self.visit(node.orelse) + return test_depth + max(body_depth, orelse_depth) + + def visit_Slice(self, node): + # Slice: max depth of lower, upper, step + depths = [] + if node.lower: + depths.append(self.visit(node.lower)) + if node.upper: + depths.append(self.visit(node.upper)) + if node.step: + depths.append(self.visit(node.step)) + return max(depths) if depths else 0 + + def generic_visit(self, node): + # For unhandled nodes, try to get max depth from children + max_depth = 0 + for child in ast.iter_child_nodes(node): + child_depth = self.visit(child) + if isinstance(child_depth, int): + max_depth = max(max_depth, child_depth) + return max_depth + + +def count_depth_code(code: Union[Sequence[ast.AST], str, ast.AST]) -> int: + """ + Compute the depth (longest chain of dependent operations) of Python code. -def count_depth_code(code): - ctr = ArithmeticCounter() + Tracks data dependencies across statements via variable assignments. + For example: + - `a = x + y; b = a + z` has depth 2 (b depends on a) + - `a = x + y; b = z + w` has depth 1 (independent) + """ + ctr = DepthCounter() if isinstance(code, (tuple, list)): + if not code: + return 0 + # Process statements sequentially to track dependencies + max_depth = 0 for stmt in code: - ctr.visit(stmt) + stmt_depth = ctr.visit(stmt) + max_depth = max(max_depth, stmt_depth) + return max_depth elif isinstance(code, str): - ctr.visit(ast.parse(code)) + return ctr.visit(ast.parse(code)) else: - ctr.visit(code) - return ctr.count + return ctr.visit(code) -def tasklet_work(tasklet_node, state): +def tasklet_work(tasklet_node: nd.Tasklet, state: SDFGState): if tasklet_node.code.language == dtypes.Language.CPP: - # simplified work analysis for CPP tasklets. - for oedge in state.out_edges(tasklet_node): - return oedge.data.num_accesses + warnings.warn('Work of CPP tasklets cannot be exactly determined.') + return 1 elif tasklet_node.code.language == dtypes.Language.Python: return count_arithmetic_ops_code(tasklet_node.code.code) else: @@ -527,11 +663,10 @@ def tasklet_work(tasklet_node, state): return 1 -def tasklet_depth(tasklet_node, state): +def tasklet_depth(tasklet_node: nd.Tasklet, state: SDFGState): if tasklet_node.code.language == dtypes.Language.CPP: - # Depth == work for CPP tasklets. - for oedge in state.out_edges(tasklet_node): - return oedge.data.num_accesses + warnings.warn('Depth of CPP tasklets cannot be exactly determined.') + return 1 if tasklet_node.code.language == dtypes.Language.Python: return count_depth_code(tasklet_node.code.code) else: @@ -541,15 +676,15 @@ def tasklet_depth(tasklet_node, state): return 1 -def get_tasklet_work(node, state): +def get_tasklet_work(node: nd.Tasklet, state: SDFGState): return sp.sympify(tasklet_work(node, state)), sp.sympify(-1) -def get_tasklet_work_depth(node, state): +def get_tasklet_work_depth(node: nd.Tasklet, state: SDFGState): return sp.sympify(tasklet_work(node, state)), sp.sympify(tasklet_depth(node, state)) -def get_tasklet_avg_par(node, state): +def get_tasklet_avg_par(node: nd.Tasklet, state: SDFGState): return sp.sympify(tasklet_work(node, state)), sp.sympify(tasklet_depth(node, state)) @@ -573,12 +708,12 @@ def do_initial_subs(w, d, eq, subs1): def control_flow_region_work_depth(cfr: ControlFlowRegion, - w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], - analyze_tasklet, + w_d_map: Dict[str, Tuple[sp.Expr|List[Tuple[sp.Expr, sp.Expr]], sp.Expr|List[Tuple[sp.Expr, sp.Expr]]]], + analyze_tasklet: Callable[[nd.Tasklet, SDFGState], Tuple[sp.Expr, sp.Expr]], symbols: Dict[str, str], equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], subs1: Dict[str, sp.Expr], - detailed_analysis: bool = False) -> Tuple[sp.Expr, sp.Expr]: + detailed_analysis: bool = False) -> Tuple[sp.Expr|List[Tuple[sp.Expr, sp.Expr]], sp.Expr|List[Tuple[sp.Expr, sp.Expr]]]: """ Analyze the work and depth of a given (structured) ControlFLowRegion. First we determine the work and depth of each state. Then we break loops in the state machine, such that we get a DAG. @@ -595,19 +730,16 @@ def control_flow_region_work_depth(cfr: ControlFlowRegion, :param subs1: First substitution dict for greater/lesser assumptions. :return: A tuple containing the work and depth of the SDFG. """ - # First determine the work and depth of each state individually. - # Keep track of the work and depth for each state in a dictionary, where work and depth are multiplied by the number - # of times the state will be executed. + # First determine the work and depth of each ControlFlowRegion individually. + # Keep track of the work and depth for each state in a dictionary region_depths: Dict[AbstractControlFlowRegion, sp.Expr] = {} region_works: Dict[AbstractControlFlowRegion, sp.Expr] = {} for region in cfr.nodes(): if isinstance(region, SDFGState): #rename variable to make code more readable state = region - state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, detailed_analysis) - # Substitutions for state_work and state_depth already performed, but state.executions needs to be subs'd now. state_work = sp.simplify(state_work.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) state_depth = sp.simplify(state_depth.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) @@ -618,6 +750,10 @@ def control_flow_region_work_depth(cfr: ControlFlowRegion, #rename variable to make code more readable loop = region fallback = False + + # We try to get a closed form solution for the work and depth of the loop. If this is not possible (e.g. because there are no static loop bounds), + # we fall back to just multiplying the work and depth of one loop iteration with the number of loop iterations. This can lead to incorrect results, + # since it does not take into account that work and depth of one loop iteration might be dependent on the loop variable. loop_var = sp.Symbol(loop.loop_variable) lower_bound = loop_analysis.get_init_assignment(loop) upper_bound = loop_analysis.get_loop_end(loop) @@ -628,24 +764,26 @@ def control_flow_region_work_depth(cfr: ControlFlowRegion, executions = loop.start_block.executions executions = executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) + # Recursively get the work and depth of the loop body loop_work, loop_depth = control_flow_region_work_depth(loop, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, detailed_analysis) if not fallback: + # If static loop bounds are available, we can write the work and depth of the loop as a summation over the loop variable from the lower to the upper bound. + # to ensure that the summation works properly, we need to make sure that the symbol that is used as loop varaible # is the same as the ones used in the inner expression for var in loop_work.free_symbols: if var.name == loop_var.name: loop_var = var - #TEMPORARY FIX: with library nodes it can happen that we get two symbols (with the same name) for var in loop_work.free_symbols: if var.name == loop_var.name and not var == loop_var: loop_work = loop_work.subs({var: loop_var}) loop_depth = loop_depth.subs({var: loop_var}) - # prepare loop bounds to such that we can write the work as a nice summation from 0 to an upper bound + # prepare loop bounds such that we can write the work as a nice summation from 0 to an upper bound loop_var = loop_var.subs(subs1) shifted_hi = (upper_bound-lower_bound)//step shifted_hi = shifted_hi.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) @@ -655,10 +793,13 @@ def control_flow_region_work_depth(cfr: ControlFlowRegion, loop_work = loop_work.subs({loop_var: (step*loop_var+lower_bound)}) loop_depth = loop_depth.subs({loop_var: (step*loop_var+lower_bound)}) + # write the work and depth of the loop as a sum of the work of one iteration over the number of loop iterations # (we have cannot use a simple multiplication as work and depth of one loop iteration might be dependent on the loop variable) - loop_work = sp.Sum(loop_work, (loop_var, shifted_lo, shifted_hi)) - loop_depth = sp.Sum(loop_depth, (loop_var, shifted_lo, shifted_hi)) + loop_work = sp.Sum(loop_work, (loop_var, shifted_lo, shifted_hi)).doit() + loop_depth = sp.Sum(loop_depth, (loop_var, shifted_lo, shifted_hi)).doit() + + print("Loop work uneval: ", sp.Sum(loop_work, (loop_var, shifted_lo, shifted_hi))) # Do equality subs loop_work = sp.simplify(loop_work.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) loop_depth = sp.simplify(loop_depth.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) @@ -670,7 +811,7 @@ def control_flow_region_work_depth(cfr: ControlFlowRegion, loop_work = loop_work*executions loop_depth = loop_depth*executions else: - exec_symbol = sp.Symbol(f'num_executions_{loop.name}') + exec_symbol = sp.Symbol(f'num_execs_{region.sdfg.cfg_id}_{region.sdfg.node_id(region)}', nonnegative=True) loop_work = loop_work*exec_symbol loop_depth = loop_depth*exec_symbol @@ -678,16 +819,45 @@ def control_flow_region_work_depth(cfr: ControlFlowRegion, w_d_map[get_uuid(loop)] = (region_works[loop], region_depths[loop]) elif isinstance(region, ConditionalBlock): - branch_conditions: Dict[AbstractControlFlowRegion, sp.Expr] = {} - branch_works: Dict[AbstractControlFlowRegion, sp.Expr] = {} - branch_depths: Dict[AbstractControlFlowRegion, sp.Expr] = {} + branch_conditions = {} + branch_works = {} + branch_depths = {} for (condition, branch) in region.branches: - branch_conditions[branch] = pystr_to_symbolic( - condition.as_string) if condition is not None else sp.sympify(True) + branch_conditions[branch] = ( + pystr_to_symbolic(condition.as_string) + if condition is not None else sp.sympify(True) + ) branch_works[branch], branch_depths[branch] = control_flow_region_work_depth( - branch, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, detailed_analysis) - if not detailed_analysis: - region_works[region] = sp.Max(*branch_works.values()) + branch, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, detailed_analysis + ) + + if analyze_tasklet == get_tasklet_avg_par: + # For avg_par we want the branch minimising W/D (worst case parallelism). + # Build a Piecewise that selects work and depth together based on which + # branch has the smallest W/D ratio, so the pair stays consistent. + branches = list(branch_works.keys()) + + def avg_par_expr_val(w: sp.Expr, d: sp.Expr) -> sp.Expr: + """Return W/D, treating D=0 as infinite parallelism (never worst-case).""" + return sp.Piecewise((w / d, d > 0), (sp.oo, True)) + + # Start with the first branch as the running minimum + best_work = branch_works[branches[0]] + best_depth = branch_depths[branches[0]] + + for b in branches[1:]: + w_b = branch_works[b] + d_b = branch_depths[b] + ap_best = avg_par_expr_val(best_work, best_depth) + ap_b = avg_par_expr_val(w_b, d_b) + is_worse = sp.simplify(ap_b < ap_best) # lower ratio = worse parallelism + best_work = sp.Piecewise((w_b, is_worse), (best_work, True)) + best_depth = sp.Piecewise((d_b, is_worse), (best_depth, True)) + + region_works[region] = best_work + region_depths[region] = best_depth + elif not detailed_analysis: + region_works[region] = sp.Max(*branch_works.values()) region_depths[region] = sp.Max(*branch_depths.values()) else: work_condition = list(zip(branch_works.values(), branch_conditions.values())) @@ -708,10 +878,124 @@ def control_flow_region_work_depth(cfr: ControlFlowRegion, region_works[region], region_depths[region] = function_work, function_depth w_d_map[get_uuid(region)] = (region_works[region], region_depths[region]) + edge_w_d_map: Dict[Tuple[str, str], Tuple[sp.Expr, sp.Expr]] = {} + for isedge in cfr.edges(): + edge_work, edge_depth = sp.sympify(0), sp.sympify(0) + if isedge.data.assignments: + for v in isedge.data.assignments.values(): + edge_depth = sp.Max(edge_depth, count_depth_code(v)) + edge_work += count_arithmetic_ops_code(v) + edge_w_d_map[(get_uuid(isedge.src), get_uuid(isedge.dst))] = (edge_work, edge_depth) + + # Prepare the SDFG for a depth analysis by breaking loops. This removes the edge between the last loop state and + # the guard, and instead places an edge between the last loop state and the exit state. + # This transforms the state machine into a DAG. Hence, we can find the "heaviest" and "deepest" paths in linear time. + # Additionally, construct a dummy exit state and connect every state that has no outgoing edges to it. + + + # ================================= Handling of Legacy Loops ================================= + # identify legacy loops (i.e. loops that aren't LoopRegions) in the CFR + try: + nodes_oNodes_exits = find_loop_guards_tails_exits(cfr._nx) + except LoopExtractionError: + # If loop detection fails, we cannot make proper propagation. + print('Analysis failed since not all loops got detected. It may help to use more structured loop constructs.' + + ' The analysis per state remains correct, but no SDFG-wide analysis can be performed.') + sdfg_result = (sp.oo, sp.oo) + w_d_map[get_uuid(cfr)] = sdfg_result + + for k, (v_w, v_d) in w_d_map.items(): + # The symeval replaces nested SDFG symbols with their global counterparts. + v_w = symeval(v_w, symbols) + v_d = symeval(v_d, symbols) + w_d_map[k] = (v_w, v_d) + return sdfg_result + + legacy_loop_ranges = get_legacy_loop_ranges(cfr) + + for guard, tail, exits in nodes_oNodes_exits: + + if guard not in legacy_loop_ranges: + # propagate_states didn't recognise this as an annotated for-loop + # (e.g. while-loop or dynamic unbounded). Fall back to the + # execution-count symbol approach as before. + iter_sym = sp.Symbol( + f'num_execs_{guard.parent.cfg_id}_{guard.parent.node_id(guard)}', + nonnegative=True + ) + loop_body_nodes = get_legacy_loop_body(cfr, guard, tail, exits) + for node in loop_body_nodes: + if node is guard: + continue + node_uuid = get_uuid(node) + if node_uuid not in w_d_map: + continue + node_work, node_depth = w_d_map[node_uuid] + region_works[node] = sp.simplify(node_work * iter_sym) + region_depths[node] = sp.simplify(node_depth * iter_sym) + w_d_map[node_uuid] = (region_works[node], region_depths[node]) + continue + + print(f'Handling legacy loop with guard {guard} and tail {tail} and loops bounds {legacy_loop_ranges[guard]}') + loop_var, start, stop, stride = legacy_loop_ranges[guard] + loop_body_nodes = get_legacy_loop_body(cfr, guard, tail, exits) + + loop_var = loop_var.subs(subs1) + # Shift loop variable to start from 0 with step 1 + shifted_hi = (stop - start) // stride + shifted_hi = shifted_hi.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) + + lower_bound = start.subs(subs1) if stride.evalf() > 0 else stop.subs(subs1) + abs_stride = sp.Abs(stride) + + for node in loop_body_nodes: + if node is guard: + continue + node_uuid = get_uuid(node) + if node_uuid not in w_d_map: + continue + node_work = w_d_map[node_uuid][0] + node_depth = w_d_map[node_uuid][1] + # Substitute loop_var -> (abs_stride * loop_var + lower_bound), matching + # the LoopRegion branch exactly. + node_work = node_work.subs( {loop_var: abs_stride * loop_var + lower_bound}) + node_depth = node_depth.subs({loop_var: abs_stride * loop_var + lower_bound}) + + node_work = sp.Sum(node_work, (loop_var, sp.sympify(0), shifted_hi)).doit() + node_depth = sp.Sum(node_depth, (loop_var, sp.sympify(0), shifted_hi)).doit() + + node_work = sp.simplify(node_work.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + node_depth = sp.simplify(node_depth.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + + # Write the summed result back into ONLY the guard, so the downstream + # BFS traversal picks up one single node representing the whole loop. + region_works[node] = node_work + region_depths[node] = node_depth + w_d_map[get_uuid(node)] = (node_work, node_depth) + + + # Now we need to go over each triple (node, oNode, exits). For each triple, we + # - remove edge (oNode, node), i.e. the backward edge + # - for all exits e, add edge (oNode, e). This edge may already exist + # - remove edge from node to exit (if present, i.e. while-do loop) + # - This ensures that every node with > 1 outgoing edge is a branch guard + # - useful for detailed anaylsis. + for node, oNode, exits in nodes_oNodes_exits: + cfr.remove_edge(cfr.edges_between(oNode, node)[0]) + for e in exits: + if len(cfr.edges_between(oNode, e)) == 0: + # no edge there yet + cfr.add_edge(oNode, e, InterstateEdge()) + if len(cfr.edges_between(node, e)) > 0: + # edge present --> remove it + cfr.remove_edge(cfr.edges_between(node, e)[0]) + + # ================================= End of Legacy Loop Handling ================================= + # add a dummy exit to the SDFG, such that each path ends there. dummy_exit = cfr.add_state('dummy_exit') for region in cfr.nodes(): - if len(cfr.out_edges(region)) == 0 and region != dummy_exit: + if len(cfr.out_edges(region)) == 0 and region is not dummy_exit: cfr.add_edge(region, dummy_exit, InterstateEdge()) # These two dicts save the current length of the "heaviest", resp. "deepest", paths at each state. @@ -736,6 +1020,10 @@ def control_flow_region_work_depth(cfr: ControlFlowRegion, if ie is not None: visited.add(ie) + edge_uid = (get_uuid(ie.src), get_uuid(ie.dst)) + if edge_uid in edge_w_d_map: # Skip new edges we added to `dummy_exit` + work += edge_w_d_map[edge_uid][0] + depth += edge_w_d_map[edge_uid][1] if region in region_value_map: # update value map: @@ -832,7 +1120,7 @@ def control_flow_region_work_depth(cfr: ControlFlowRegion, # This happens if the loops were not properly detected and broken. raise RuntimeError("Analysis failed! The dummy exit state was never reached") - cfr_result = (max_work, max_depth) + cfr_result = (max_work.simplify(), max_depth.simplify()) w_d_map[get_uuid(cfr)] = cfr_result for k, (v_w, v_d) in w_d_map.items(): @@ -1045,7 +1333,6 @@ def scope_work_depth( # summarise work / depth of the whole scope in the dictionary scope_result = (work, max_depth) w_d_map[get_uuid(state)] = scope_result - return scope_result diff --git a/tests/sdfg/work_depth_test.py b/tests/sdfg/work_depth_test.py index 38843cc9fe..3d795dce0a 100644 --- a/tests/sdfg/work_depth_test.py +++ b/tests/sdfg/work_depth_test.py @@ -1,6 +1,7 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Contains test cases for the work depth analysis. """ from typing import Dict, List, Tuple +from unittest import result import pytest import dace as dc @@ -180,6 +181,10 @@ def gemm_library_node(x: dc.float64[456, 200], y: dc.float64[200, 111], z: dc.fl def gemm_library_node_symbolic(x: dc.float64[M, K], y: dc.float64[K, N], z: dc.float64[M, N]): z[:] = x @ y +@dc.program +def loop_var_dependent_work(x: dc.float64[N], y: dc.float64[N], z:dc.float64[N]): + for i in range(1, N+1): + z[i-1] = np.dot(x[:i], y[:i]) #(sdfg, (expected_work, expected_depth)) work_depth_test_cases: Dict[str, Tuple[DaceProgram, Tuple[symbolic.SymbolicType, symbolic.SymbolicType]]] = { @@ -193,18 +198,25 @@ def gemm_library_node_symbolic(x: dc.float64[M, K], y: dc.float64[K, N], z: dc.f 'nested_if_else': (nested_if_else, (sp.Max(K, 3 * N, M + N), sp.Max(3, K, M + 1))), 'max_of_positive_symbols': (max_of_positive_symbol, (3 * N**2, 3 * N)), 'multiple_array_sizes': (multiple_array_sizes, (sp.Max(2 * K, 3 * N, 2 * M + 3), 5)), - 'unbounded_while_do': (unbounded_while_do, (sp.Symbol('num_execs_0_5') * N, sp.Symbol('num_execs_0_5'))), + 'unbounded_while_do': (unbounded_while_do, (sp.Symbol('num_execs_0_0') * N, sp.Symbol('num_execs_0_0'))), # We get this Max(1, num_execs), since it is a do-while loop, but the num_execs symbol does not capture this. - 'unbounded_nonnegify': (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_8') * N, 2 * sp.Symbol('num_execs_0_8'))), + 'unbounded_nonnegify': (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_0') * N, 2 * sp.Symbol('num_execs_0_0'))), 'break_for_loop': (break_for_loop, (N**2, N)), - 'break_while_loop': (break_while_loop, (sp.Symbol('num_execs_0_7') * N, sp.Symbol('num_execs_0_7'))), + 'break_while_loop': (break_while_loop, (sp.Symbol('num_execs_0_0') * N, sp.Symbol('num_execs_0_0'))), 'sequential_ifs': (sequntial_ifs, (sp.Max(N + 1, M) + sp.Max(N + 1, M + 1), sp.Max(1, M) + 1)), - 'reduction_library_node': (reduction_library_node, (456, sp.log(456))), - 'reduction_library_node_symbolic': (reduction_library_node_symbolic, (N, sp.log(N))), - 'gemm_library_node': (gemm_library_node, (2 * 456 * 200 * 111, sp.log(200))), - 'gemm_library_node_symbolic': (gemm_library_node_symbolic, (2 * M * K * N, sp.log(K))) + 'reduction_library_node': (reduction_library_node, (456, sp.log(456)/sp.log(2))), + 'reduction_library_node_symbolic': (reduction_library_node_symbolic, (N, sp.log(sp.Max(1, N))/sp.log(2))), + 'gemm_library_node': (gemm_library_node, (2 * 456 * 200 * 111, sp.log(200)/sp.log(2))), + 'gemm_library_node_symbolic': (gemm_library_node_symbolic, (2 * M * K * N, sp.Max(1, sp.log(sp.Max(1, K))/sp.log(2)))), + 'loop_var_dependent_work': (loop_var_dependent_work, (N**2, N+sp.Sum(sp.log(sp.Symbol("_p_i", nonnegative=True) + 1), (sp.Symbol("_p_i", nonnegative=True), 0, N - 1))/sp.log(2))) } +def standardize(expr): + new_expr = expr.replace( + lambda x: isinstance(x, sp.Symbol), + lambda x: sp.Symbol(x.name) + ) + return new_expr @pytest.mark.parametrize('test_name', list(work_depth_test_cases.keys())) def test_work_depth(test_name): @@ -220,21 +232,21 @@ def test_work_depth(test_name): sdfg.apply_transformations(MapExpansion) # NOTE: Until the W/D Analysis is changed to make use of the new blocks, inline control flow for the analysis. - inline_control_flow_regions(sdfg) - for sd in sdfg.all_sdfgs_recursive(): - sd.using_explicit_control_flow = False + # inline_control_flow_regions(sdfg) + # This is now not necessary anymore. Work-depth analysis still works for inlined SDFGs, but this imposes other names which fails test cases 'unbounded_while_do', 'unbounded_nonnegify', 'break_while_loop' + + #fg.all_sdfgs_recursive(): + # sd.using_explicit_control_flow = False analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth, [], False) res = w_d_map[get_uuid(sdfg)] # substitue each symbol without assumptions. # We do this since sp.Symbol('N') == Sp.Symbol('N', positive=True) --> False. - reps = {s: sp.Symbol(s.name) for s in (res[0].free_symbols | res[1].free_symbols)} - res = (res[0].subs(reps), res[1].subs(reps)) - reps = {s: sp.Symbol(s.name) for s in (sp.sympify(correct[0]).free_symbols | sp.sympify(correct[1]).free_symbols)} - correct = (sp.sympify(correct[0]).subs(reps), sp.sympify(correct[1]).subs(reps)) + res = (standardize(res[0]), standardize(res[1])) + correct = (standardize(sp.sympify(correct[0])), standardize(sp.sympify(correct[1]))) # check result - assert correct == res - + assert res[0].expand() == correct[0].expand() + assert res[1].expand() == correct[1].expand() #(sdfg, expected_avg_par) tests_cases_avg_par = { @@ -249,13 +261,12 @@ def test_work_depth(test_name): 'unbounded_nonnegify': (unbounded_nonnegify, N), 'break_for_loop': (break_for_loop, N), 'break_while_loop': (break_while_loop, N), - 'reduction_library_node': (reduction_library_node, 456 / sp.log(456)), - 'reduction_library_node_symbolic': (reduction_library_node_symbolic, N / sp.log(N)), - 'gemm_library_node': (gemm_library_node, 2 * 456 * 200 * 111 / sp.log(200)), - 'gemm_library_node_symbolic': (gemm_library_node_symbolic, 2 * M * K * N / sp.log(K)), + 'reduction_library_node': (reduction_library_node, 456 / (sp.log(456)/sp.log(2))), + 'reduction_library_node_symbolic': (reduction_library_node_symbolic, N*sp.log(2)/sp.log(sp.Max(1, N))), + 'gemm_library_node': (gemm_library_node, 2 * 456 * 200 * 111 / (sp.log(200)/sp.log(2))), + 'gemm_library_node_symbolic': (gemm_library_node_symbolic, 2*K*M*N/sp.Max(1, sp.log(sp.Max(1, K))/sp.log(2))), } - @pytest.mark.parametrize('test_name', list(tests_cases_avg_par.keys())) def test_avg_par(test_name: str): if (dc.Config.get_bool('optimizer', 'automatic_simplification') == False @@ -271,12 +282,12 @@ def test_avg_par(test_name: str): sdfg.apply_transformations(MapExpansion) # NOTE: Until the W/D Analysis is changed to make use of the new blocks, inline control flow for the analysis. - inline_control_flow_regions(sdfg) - for sd in sdfg.all_sdfgs_recursive(): - sd.using_explicit_control_flow = False - + # inline_control_flow_regions(sdfg) + # This is now not necessary anymore. Work-depth analysis still works for inlined SDFGs, but this imposes other names which fails test cases 'unbounded_while_do', 'unbounded_nonnegify', 'break_while_loop' + #for sd in sdfg.all_sdfgs_recursive(): + # sd.using_explicit_control_flow = False analyze_sdfg(sdfg, w_d_map, get_tasklet_avg_par, [], False) - res = w_d_map[get_uuid(sdfg)][0] / w_d_map[get_uuid(sdfg)][1] + res = w_d_map[get_uuid(sdfg)] # substitue each symbol without assumptions. # We do this since sp.Symbol('N') == Sp.Symbol('N', positive=True) --> False. reps = {s: sp.Symbol(s.name) for s in res.free_symbols} @@ -284,7 +295,7 @@ def test_avg_par(test_name: str): reps = {s: sp.Symbol(s.name) for s in sp.sympify(correct).free_symbols} correct = sp.sympify(correct).subs(reps) # check result - assert correct == res + assert res.expand() == correct.expand() x, y, z, a = sp.symbols('x y z a') From 942a9578b13baf7221195637a1bddd657eae9de5 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 18 Mar 2026 10:24:20 +0100 Subject: [PATCH 04/10] reapply changes after merge --- .../sdfg/performance_evaluation/work_depth.py | 31 ++--- tests/sdfg/work_depth_test.py | 124 +----------------- 2 files changed, 12 insertions(+), 143 deletions(-) diff --git a/dace/sdfg/performance_evaluation/work_depth.py b/dace/sdfg/performance_evaluation/work_depth.py index 54bdce0523..bb3ffc9090 100644 --- a/dace/sdfg/performance_evaluation/work_depth.py +++ b/dace/sdfg/performance_evaluation/work_depth.py @@ -707,13 +707,13 @@ def do_initial_subs(w, d, eq, subs1): return result -def sdfg_work_depth(sdfg: SDFG, - w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], - analyze_tasklet, - symbols: Dict[str, str], - equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], - subs1: Dict[str, sp.Expr], - detailed_analysis: bool = False) -> Tuple[sp.Expr, sp.Expr]: +def control_flow_region_work_depth(cfr: ControlFlowRegion, + w_d_map: Dict[str, Tuple[sp.Expr|List[Tuple[sp.Expr, sp.Expr]], sp.Expr|List[Tuple[sp.Expr, sp.Expr]]]], + analyze_tasklet: Callable[[nd.Tasklet, SDFGState], Tuple[sp.Expr, sp.Expr]], + symbols: Dict[str, str], + equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], + subs1: Dict[str, sp.Expr], + detailed_analysis: bool = False) -> Tuple[sp.Expr|List[Tuple[sp.Expr, sp.Expr]], sp.Expr|List[Tuple[sp.Expr, sp.Expr]]]: """ Analyze the work and depth of a given (structured) ControlFLowRegion. First we determine the work and depth of each state. Then we break loops in the state machine, such that we get a DAG. @@ -887,15 +887,6 @@ def avg_par_expr_val(w: sp.Expr, d: sp.Expr) -> sp.Expr: edge_work += count_arithmetic_ops_code(v) edge_w_d_map[(get_uuid(isedge.src), get_uuid(isedge.dst))] = (edge_work, edge_depth) - edge_w_d_map: Dict[Tuple[str, str], Tuple[sp.Expr, sp.Expr]] = {} - for isedge in sdfg.edges(): - edge_work, edge_depth = sp.sympify(0), sp.sympify(0) - if isedge.data.assignments: - for v in isedge.data.assignments.values(): - edge_depth = sp.Max(edge_depth, count_depth_code(v)) - edge_work += count_arithmetic_ops_code(v) - edge_w_d_map[(get_uuid(isedge.src), get_uuid(isedge.dst))] = (edge_work, edge_depth) - # Prepare the SDFG for a depth analysis by breaking loops. This removes the edge between the last loop state and # the guard, and instead places an edge between the last loop state and the exit state. # This transforms the state machine into a DAG. Hence, we can find the "heaviest" and "deepest" paths in linear time. @@ -1002,10 +993,10 @@ def avg_par_expr_val(w: sp.Expr, d: sp.Expr) -> sp.Expr: # ================================= End of Legacy Loop Handling ================================= # add a dummy exit to the SDFG, such that each path ends there. - dummy_exit = sdfg.add_state('dummy_exit') - for state in sdfg.nodes(): - if len(sdfg.out_edges(state)) == 0 and state != dummy_exit: - sdfg.add_edge(state, dummy_exit, InterstateEdge()) + dummy_exit = cfr.add_state('dummy_exit') + for region in cfr.nodes(): + if len(cfr.out_edges(region)) == 0 and region is not dummy_exit: + cfr.add_edge(region, dummy_exit, InterstateEdge()) # These two dicts save the current length of the "heaviest", resp. "deepest", paths at each state. work_map: Dict[AbstractControlFlowRegion, sp.Expr] = {} diff --git a/tests/sdfg/work_depth_test.py b/tests/sdfg/work_depth_test.py index f10dd3e0cb..3d795dce0a 100644 --- a/tests/sdfg/work_depth_test.py +++ b/tests/sdfg/work_depth_test.py @@ -8,7 +8,7 @@ from dace import symbolic from dace.frontend.python.parser import DaceProgram from dace.sdfg.performance_evaluation.work_depth import (analyze_sdfg, get_tasklet_work_depth, get_tasklet_avg_par, - parse_assumptions, count_arithmetic_ops_code, count_depth_code) + parse_assumptions) from dace.sdfg.performance_evaluation.helpers import get_uuid from dace.sdfg.performance_evaluation.assumptions import ContradictingAssumptions import sympy as sp @@ -338,126 +338,6 @@ def test_assumption_system_contradictions(assumptions): parse_assumptions(assumptions, set()) -def test_depth_counter_vs_work_counter(): - """ - Test that the DepthCounter correctly computes depth (longest chain of dependent operations) - which can differ from work (total number of operations). - - Depth measures the critical path through the expression tree, while work measures - the total number of operations. - """ - # Test case 1: (a + b) * (c + d) - # Work = 3 (two additions + one multiplication) - # Depth = 2 (additions can be parallel, then multiplication) - code1 = "(a + b) * (c + d)" - assert count_arithmetic_ops_code(code1) == 3, "Work should be 3" - assert count_depth_code(code1) == 2, "Depth should be 2" - - # Test case 2: a + b + c + d (left-associative: ((a + b) + c) + d) - # Work = 3 (three additions) - # Depth = 3 (sequential chain of additions) - code2 = "a + b + c + d" - assert count_arithmetic_ops_code(code2) == 3, "Work should be 3" - assert count_depth_code(code2) == 3, "Depth should be 3" - - # Test case 3: (a + b) * (c + d) + (e + f) * (g + h) - # Work = 7 (4 additions + 2 multiplications + 1 addition) - # Depth = 3 (parallel adds, then parallel mults, then final add) - code3 = "(a + b) * (c + d) + (e + f) * (g + h)" - assert count_arithmetic_ops_code(code3) == 7, "Work should be 7" - assert count_depth_code(code3) == 3, "Depth should be 3" - - # Test case 4: Simple single operation - # Work = 1, Depth = 1 - code4 = "a + b" - assert count_arithmetic_ops_code(code4) == 1, "Work should be 1" - assert count_depth_code(code4) == 1, "Depth should be 1" - - # Test case 5: Unary operation with binary operation - # -a + b: Work = 2, Depth = 2 (unary then add, but unary on a, so depth is 1+1=2) - code5 = "-a + b" - assert count_arithmetic_ops_code(code5) == 2, "Work should be 2" - assert count_depth_code(code5) == 2, "Depth should be 2" - - # Test case 6: Function call with independent arguments - # max(a + b, c + d): Work = 2 (two adds, max is 0), Depth = 1 (parallel adds, max is 0) - code6 = "max(a + b, c + d)" - assert count_arithmetic_ops_code(code6) == 2, "Work should be 2" - assert count_depth_code(code6) == 1, "Depth should be 1" - - # Test case 7: Nested function calls - # sqrt(a + b): Work = 2 (add + sqrt), Depth = 2 (add then sqrt) - code7 = "sqrt(a + b)" - assert count_arithmetic_ops_code(code7) == 2, "Work should be 2" - assert count_depth_code(code7) == 2, "Depth should be 2" - - # Test case 8: AugAssign with parallel sub-expressions - # x += a * b + c * d: Work = 4 (2 mults + 1 add + 1 augassign), Depth = 3 - code8 = "x += a * b + c * d" - assert count_arithmetic_ops_code(code8) == 4, "Work should be 4" - assert count_depth_code(code8) == 3, "Depth should be 3" - - # Test case 9: Multiple independent statements (no data dependency) - # a = x + y; b = z + w --> Work = 2, Depth = 1 (parallel, no dependency) - code9 = """ -a = x + y -b = z + w -""" - assert count_arithmetic_ops_code(code9) == 2, "Work should be 2" - assert count_depth_code(code9) == 1, "Depth should be 1 (independent statements)" - - # Test case 10: Multiple statements WITH data dependency - # a = x + y; b = a + z --> Work = 2, Depth = 2 (b depends on a) - code10 = """ -a = x + y -b = a + z -""" - assert count_arithmetic_ops_code(code10) == 2, "Work should be 2" - assert count_depth_code(code10) == 2, "Depth should be 2 (b depends on a)" - - # Test case 11: Chain of 3 dependent statements - # a = x + y; b = a * 2; c = b + z --> Work = 3, Depth = 3 - code11 = """ -a = x + y -b = a * 2 -c = b + z -""" - assert count_arithmetic_ops_code(code11) == 3, "Work should be 3" - assert count_depth_code(code11) == 3, "Depth should be 3 (chain: a -> b -> c)" - - # Test case 12: Diamond dependency pattern - # a = x + y; b = a + 1; c = a + 2; d = b + c --> Work = 4, Depth = 3 - # a has depth 1, b and c both have depth 2 (depend on a), d has depth 3 - code12 = """ -a = x + y -b = a + 1 -c = a + 2 -d = b + c -""" - assert count_arithmetic_ops_code(code12) == 4, "Work should be 4" - assert count_depth_code(code12) == 3, "Depth should be 3 (diamond: a -> b,c -> d)" - - # Test case 13: AugAssign chain - # x += 1; x += 2; x += 3 --> Work = 3, Depth = 3 (each depends on previous x) - code13 = """ -x += 1 -x += 2 -x += 3 -""" - assert count_arithmetic_ops_code(code13) == 3, "Work should be 3" - assert count_depth_code(code13) == 3, "Depth should be 3 (augassign chain)" - - # Test case 14: Single complex statement with tree structure - # result = (a+b)*(c+d) + (e+f)*(g+h) + (i+j)*(k+l) - # The AST is left-associative: ((prod1 + prod2) + prod3) - # prod1 has depth 2, prod2 has depth 2, prod3 has depth 2 - # (prod1 + prod2) = max(2,2) + 1 = 3 - # ((prod1 + prod2) + prod3) = max(3, 2) + 1 = 4 - code14 = "(a+b)*(c+d) + (e+f)*(g+h) + (i+j)*(k+l)" - assert count_arithmetic_ops_code(code14) == 11, "Work should be 11" - assert count_depth_code(code14) == 4, "Depth should be 4" - - if __name__ == '__main__': for test_name in work_depth_test_cases.keys(): test_work_depth(test_name) @@ -470,5 +350,3 @@ def test_depth_counter_vs_work_counter(): for assumptions in tests_for_exception: test_assumption_system_contradictions(assumptions) - - test_depth_counter_vs_work_counter() From 79b5dc51e076fe00ddfa638e16348ca9f29700cd Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 18 Mar 2026 10:55:09 +0100 Subject: [PATCH 05/10] fix operational intensity simulation --- .../performance_evaluation/op_in_helpers.py | 9 +- .../operational_intensity.py | 312 +++++++++++++++--- 2 files changed, 277 insertions(+), 44 deletions(-) diff --git a/dace/sdfg/performance_evaluation/op_in_helpers.py b/dace/sdfg/performance_evaluation/op_in_helpers.py index b1e430f676..abe0e2d5a9 100644 --- a/dace/sdfg/performance_evaluation/op_in_helpers.py +++ b/dace/sdfg/performance_evaluation/op_in_helpers.py @@ -2,7 +2,7 @@ """ Contains class CacheLineTracker which keeps track of all arrays of an SDFG and their cache line position and class AccessStack which which corresponds to the stack used to compute the stack distance. Further, provides a curve fitting method and plotting function. """ - +from __future__ import annotations import warnings from dace.data import Array import sympy as sp @@ -11,7 +11,6 @@ import numpy as np from dace import symbol - class CacheLineTracker: """ A CacheLineTracker maps data container accesses to the corresponding accessed cache line. """ @@ -129,6 +128,12 @@ def copy(self): curr.next = Node(x) curr = curr.next return new_stack + + def replace_self(self, other:AccessStack): + self.top = other.top + self.num_calls = other.num_calls + self.lengh = other.length + self.C = other.C def plot(x, work_map, cache_misses, op_in_map, symbol_name, C, L, sympy_f, element, name): diff --git a/dace/sdfg/performance_evaluation/operational_intensity.py b/dace/sdfg/performance_evaluation/operational_intensity.py index ee9286a7c5..a713ab69eb 100644 --- a/dace/sdfg/performance_evaluation/operational_intensity.py +++ b/dace/sdfg/performance_evaluation/operational_intensity.py @@ -5,12 +5,14 @@ import argparse from collections import deque from dace.sdfg import nodes as nd -from dace import SDFG, SDFGState, dtypes +from dace import dtypes, SDFG +from dace.sdfg.state import SDFGState, ControlFlowRegion, LoopRegion, FunctionCallRegion, ConditionalBlock, ReturnBlock, ContinueBlock, BreakBlock from typing import Tuple, Dict import os import sympy as sp from copy import deepcopy from dace.symbolic import pystr_to_symbolic, SymExpr +import re from dace.sdfg.performance_evaluation.helpers import get_uuid from dace.transformation.passes.symbol_ssa import StrictSymbolSSA @@ -20,6 +22,110 @@ from dace.sdfg.performance_evaluation.op_in_helpers import CacheLineTracker, AccessStack, fit_curve, plot, compute_mape from dace.sdfg.performance_evaluation.work_depth import analyze_sdfg, get_tasklet_work +from dace.transformation.passes.analysis import loop_analysis + +def subs_till_fixed_point(expr:sp.Expr, symbol_map:Dict[sp.Expr, sp.Expr]): + """ + Takes a sympy expression and a symbol mapping and applies the mapping to the expression until a fixed point is reached + Needs the guarantee that the symbol mapping does not have cyclic dependencies. + + :param expr: Description + :param symbol_map: Description + :return: Description + """ + if not isinstance(expr, sp.Expr): + return expr + prev = None + curr = expr + while prev != curr: + prev = curr + curr = curr.subs(symbol_map) + return curr + +def get_static_symbols(sdfg: SDFG): + """ + Returns a mapping of symbols that are assigned exactly at one point in the sdfg. + + :param sdfg: The sdfg for which we want to find the static symbols and their corresponding assignment + :return: The mapping of the symbols to higher levels (iterated to a fixed point) + """ + + + patterns = [ + "dace.complex128", + "dace.float64", + "dace.float32", + "dace.int64", + "dace.int32", + "dace.int16", + "dace.uint32", + "dace.uint16", + "dace.uint8", + "float", + "int" + ] + + type_regex = re.compile("|".join(map(re.escape, patterns))) + static_symbol_mapping:Dict[sp.Symbol, sp.Expr] = {sp.Symbol(a): sp.Symbol(a) for a in sdfg.arg_names} + non_static_symbols = set() + for node, containing_state in sdfg.all_nodes_recursive(): + if isinstance(node, nd.AccessNode): + + if containing_state.in_degree(node) == 1: + edge = containing_state.in_edges(node)[0] + source = edge.src + + if edge.data.volume == 1: + if isinstance(source, nd.Tasklet): + tasklet = source + in_map = {} + out_map = {} + # Incoming edges: symbols feeding the tasklet + for e in containing_state.in_edges(tasklet): + if not isinstance(e.src, nd.AccessNode): + continue + sym = str(e.src.data) + in_map[e.dst_conn] = sym + # Outgoing edges: symbols written by the tasklet + # Out edges should only be one, but for safety we iterate + for e in containing_state.out_edges(tasklet): + if not isinstance(e.dst, nd.AccessNode): + continue + sym = sp.Symbol(e.dst.data) + out_map[e.src_conn] = sym + code = tasklet.code.as_string.strip() + # Expect a single assignment + lines = [l.strip() for l in code.splitlines() if l.strip()] + lhs, rhs = lines[0].split('=',1) + lhs = lhs.strip() + rhs = rhs.strip() + rhs = type_regex.sub("", rhs) + # Parse RHS using SymPy, with tasklet inputs substituted + lhs_sympy = pystr_to_symbolic(lhs) + lhs_sympy = lhs_sympy.subs(out_map) + + if not lhs_sympy in static_symbol_mapping.keys(): + try: + rhs_sympy = pystr_to_symbolic(rhs) + rhs_sympy = rhs_sympy.subs(in_map) + static_symbol_mapping[lhs_sympy] = rhs_sympy + except: + non_static_symbols.add(lhs_sympy) + else: + non_static_symbols.add(lhs_sympy) + + elif isinstance(source, nd.AccessNode): + data_sym = sp.Symbol(source.data) + nd_sym = sp.Symbol(node.data) + if not data_sym in static_symbol_mapping.keys(): + static_symbol_mapping[data_sym] = nd_sym + else: + non_static_symbols.add(data_sym) + + static_symbol_mapping = {k: v for (k, v) in static_symbol_mapping.items() if k not in non_static_symbols} + static_symbol_mapping = {str(k): subs_till_fixed_point(v, static_symbol_mapping) for k,v in static_symbol_mapping.items()} + return static_symbol_mapping + class SymbolRange(): """ Used to describe an SDFG symbol associated with a range (start, stop, step) of values. """ @@ -60,7 +166,7 @@ def update_map(op_in_map, uuid, new_misses, average=True): def calculate_op_in(op_in_map, work_map, stringify=False, assumptions={}): """ Calculates the operational intensity for each SDFG element from work and bytes loaded. """ for uuid in op_in_map: - work = work_map[uuid][0].subs(assumptions) + work = work_map[uuid].subs(assumptions) if work == 0 and op_in_map[uuid] == 0: op_in_map[uuid] = 0 elif work != 0 and op_in_map[uuid] == 0: @@ -124,8 +230,10 @@ def symeval(val, symbols): def evaluate_symbols(base, new): result = {} + print("Evaluate symbols called with base:", base, "new:", new) for k, v in new.items(): result[k] = symeval(v, base) + print("Result:", result) return result @@ -137,7 +245,7 @@ def update_mapping(mapping, e): mapping.update(update) -def update_map_iterators(map, mapping): +def update_map_iterators(map, mapping, symbols): # update the map params and return False # if all iterations exhausted, return True # always increase the last one. If it is exhausted, increase the next one and so forth @@ -145,23 +253,23 @@ def update_map_iterators(map, mapping): for p, range in zip(map.params[::-1], map.range[::-1]): # reversed order curr_value = mapping[p] if not isinstance(range[1], SymExpr): - if curr_value.subs(mapping) + range[2].subs(mapping) <= range[1].subs(mapping): + if curr_value.subs(symbols).subs(mapping) + range[2].subs(symbols).subs(mapping) <= range[1].subs(symbols).subs(mapping): # update this value and then we are done - mapping[p] = curr_value.subs(mapping) + range[2].subs(mapping) + mapping[p] = curr_value.subs(symbols).subs(mapping) + range[2].subs(symbols).subs(mapping) map_exhausted = False break else: # set current param to start again and continue - mapping[p] = range[0].subs(mapping) + mapping[p] = range[0].subs(symbols).subs(mapping) else: - if curr_value.subs(mapping) + range[2].subs(mapping) <= range[1].expr.subs(mapping): + if curr_value.subs(symbols).subs(mapping) + range[2].subs(symbols).subs(mapping) <= range[1].expr.subs(symbols).subs(mapping): # update this value and we done - mapping[p] = curr_value.subs(mapping) + range[2].subs(mapping) + mapping[p] = curr_value.subs(symbols).subs(mapping) + range[2].subs(symbols).subs(mapping) map_exhausted = False break else: # set current param to start again and continue - mapping[p] = range[0].subs(mapping) + mapping[p] = range[0].subs(symbols).subs(mapping) return map_exhausted @@ -174,15 +282,15 @@ def map_op_in(state: SDFGState, op_in_map: Dict[str, sp.Expr], entry, mapping, s map_misses = 0 while True: # do analysis of map contents - map_misses += scope_op_in(state, op_in_map, mapping, stack, clt, C, symbols, array_names, decided_branches, + map_misses += scope_misses(state, op_in_map, mapping, stack, clt, C, symbols, array_names, decided_branches, ask_user, entry) - if update_map_iterators(entry.map, mapping): + if update_map_iterators(entry.map, mapping,symbols): break return map_misses -def scope_op_in(state: SDFGState, +def scope_misses(state: SDFGState, op_in_map: Dict[str, sp.Expr], mapping, stack: AccessStack, @@ -231,8 +339,7 @@ def scope_op_in(state: SDFGState, line_id = clt.cache_line_id( e.data.data if e.data.data not in array_names else array_names[e.data.data], [x[0].subs(mapping) for x in e.data.subset.ranges], mapping) - - line_id = int(line_id.subs(mapping)) + line_id = int(line_id.subs(symbols).subs(mapping)) dist = stack.touch(line_id) tasklet_misses += 1 if dist >= C or dist == -1 else 0 @@ -259,7 +366,7 @@ def scope_op_in(state: SDFGState, for e in state.out_edges(node): nested_array_names[e.src_conn] = e.data.data # Nested SDFGs are recursively analyzed first. - nsdfg_misses = sdfg_op_in(node.sdfg, op_in_map, mapping, stack, clt, C, nested_syms, nested_array_names, + nsdfg_misses = cfg_misses(node.sdfg, op_in_map, mapping, stack, clt, C, nested_syms, nested_array_names, decided_branches, ask_user) scope_misses += nsdfg_misses @@ -280,8 +387,111 @@ def scope_op_in(state: SDFGState, update_map(op_in_map, get_uuid(state), scope_misses, average=False) return scope_misses - -def sdfg_op_in(sdfg: SDFG, +def cfr_misses(cfr:ControlFlowRegion, + op_in_map: Dict[str, Tuple[sp.Expr, sp.Expr]], + mapping, + stack: AccessStack, + clt: CacheLineTracker, + C, + symbols, + array_names, + decided_branches, + ask_user, + start=None): + region_misses = 0 + if isinstance(cfr, SDFGState): + region_misses = scope_misses(cfr, op_in_map, mapping, stack, clt, C, symbols, array_names, decided_branches, + ask_user, None) + + elif isinstance(cfr, LoopRegion): + loop_var = cfr.loop_variable + loop_condition = pystr_to_symbolic(cfr.loop_condition.as_string) + start = loop_analysis.get_init_assignment(cfr).subs(mapping) + step = sp.sympify(loop_analysis.get_loop_stride(cfr)) + mapping[loop_var] = start.subs(mapping) + region_misses = 0 + + while (loop_condition.subs(mapping) == True): + iter_misses = cfg_misses(cfr, op_in_map, mapping, stack, clt, C, symbols, array_names, decided_branches, ask_user, start=cfr.start_block,end=None) + mapping[loop_var] = mapping[loop_var] + step + region_misses += iter_misses + + elif isinstance(cfr, ConditionalBlock): + true_branches = [] + possible_branches = [] + else_branch = None + + for cond, branch in cfr.branches: + if cond is None: + else_branch = branch + continue + + sym_cond = pystr_to_symbolic(cond.as_string) + res = sym_cond.subs(mapping) + + if res == True: + true_branches.append(branch) + elif res == False: + continue + else: + possible_branches.append(branch) + + ### if the branch is not decided by a true condition we + # 1- ask the userif he hasn't decided yet + # 2- take the one we took last time if he has decided + # 3- take the worst case if he opted not to decide + possibilities = true_branches + possible_branches + [else_branch] + if not true_branches and len(possible_branches)>0 and ask_user and (cfr not in decided_branches or decided_branches[cfr] not in possibilities): + possibilities = true_branches + possible_branches + ['else_branch'] + if len(possibilities)>1: + print(f'\n\nWhich branch to take at {cfr.name}') + for i in range(len(possibilities)): + print(f'({i}) for branch {possibilities[i] if possibilities[i] else "else_branch"}') + chosen = int(input('Choose an option from above: ')) + # if the user chooses one, we check only that branch + branches = [possibilities[chosen]] + if possibilities[chosen]: + # only store the decided branch if it is not the implicit else branch + decided_branches[cfr] = possibilities[chosen] + else: + branches = possibilities + elif true_branches: + # if we have true branches we take the first one + branches = [true_branches[0]] + else: + # else we check all possibilities and take the max + branches = possibilities + + + max_branch_misses = 0 + mapping_after_cond, stack_after_cond, decided_branches_after_cond = mapping, stack, decided_branches + for branch in branches: + if not branch: + # the implicit else branch has no misses + continue + # copy all data that must not be shared between branches + mapping_copy = deepcopy(mapping) + stack_copy = deepcopy(stack) + symbols_copy = deepcopy(symbols) + decided_branches_copy = deepcopy(decided_branches) + branch_misses = cfg_misses(branch, op_in_map, mapping_copy, stack_copy, clt, C, symbols_copy, array_names, decided_branches_copy, ask_user, branch.start_block, None) + + if branch_misses > max_branch_misses: + max_branch_misses = branch_misses + mapping_after_cond, stack_after_cond, decided_branches_after_cond = mapping_copy, stack_copy, decided_branches_copy + + mapping.update(mapping_after_cond) + stack.replace_self(stack_after_cond) + decided_branches.update(decided_branches_after_cond) + region_misses = max_branch_misses + elif isinstance(cfr, FunctionCallRegion): + region_misses = cfg_misses(cfr, op_in_map,mapping,stack,clt,C,symbols,array_names,decided_branches,ask_user,start=cfr.start_block,end=None) + elif isinstance(cfr, (ReturnBlock, ContinueBlock, BreakBlock)): + region_misses = 0 + + return region_misses + +def cfg_misses(cfg: ControlFlowRegion, op_in_map: Dict[str, Tuple[sp.Expr, sp.Expr]], mapping, stack: AccessStack, @@ -311,31 +521,33 @@ def sdfg_op_in(sdfg: SDFG, :param end: The end state of the SDFG traversal. If None, the whole SDFG is traversed. """ - if start is None: + if isinstance(cfg, SDFG) and start is None: # add this SDFG's arrays to the cache line tracker - for name, arr in sdfg.arrays.items(): + for name, arr in cfg.arrays.items(): if isinstance(arr, Array): if name in array_names: name = array_names[name] clt.add_array(name, arr, mapping) # start traversal at SDFG's start state - curr_state = sdfg.start_state + curr_state = cfg.start_block else: curr_state = start total_misses = 0 - # traverse this SDFG's states + # traverse this SDFG's ControlFlowRegions while True: - total_misses += scope_op_in(curr_state, op_in_map, mapping, stack, clt, C, symbols, array_names, - decided_branches, ask_user) - - if len(sdfg.out_edges(curr_state)) == 0: + + region_misses = cfr_misses(curr_state, op_in_map, mapping, stack, clt, C, symbols, array_names, + decided_branches, ask_user) + + total_misses += region_misses + if len(cfg.out_edges(curr_state)) == 0: # we reached an end state --> stop break else: # take first edge with True condition found = False - for e in sdfg.out_edges(curr_state): + for e in cfg.out_edges(curr_state): if e.data.is_unconditional() or e.data.condition_sympy().subs(mapping) == True: # save e's assignments in mapping and update curr_state # replace values first with mapping, then update mapping @@ -351,7 +563,7 @@ def sdfg_op_in(sdfg: SDFG, if not found: # We need to check if we are in an implicit end state (i.e. all outgoing edge conditions evaluate to False) all_false = True - for e in sdfg.out_edges(curr_state): + for e in cfg.out_edges(curr_state): if e.data.condition_sympy().subs(mapping) != False: all_false = False if all_false: @@ -365,10 +577,10 @@ def sdfg_op_in(sdfg: SDFG, curr_state = e.dst else: # we cannot determine which branch to take --> check if both contain work - merge_state = find_merge_state(sdfg, curr_state) + merge_state = find_merge_state(cfg, curr_state) next_edge_candidates = [] - for e in sdfg.out_edges(curr_state): - states = find_states_between(sdfg, e.dst, merge_state) + for e in cfg.out_edges(curr_state): + states = find_states_between(cfg, e.dst, merge_state) curr_work = mem_accesses_on_path(states) if sp.sympify(curr_work).subs(mapping) > 0: next_edge_candidates.append(e) @@ -380,7 +592,7 @@ def sdfg_op_in(sdfg: SDFG, curr_state = e.dst else: if ask_user: - edges = sdfg.out_edges(curr_state) + edges = cfg.out_edges(curr_state) print(f'\n\nWhich branch to take at {curr_state.name}') for i in range(len(edges)): print(f'({i}) for edge to state {edges[i].dst.name}') @@ -406,17 +618,18 @@ def sdfg_op_in(sdfg: SDFG, curr_state = e.dst # walk down this branch until merge_state - sdfg_op_in(sdfg, op_in_map, curr_mapping, curr_stack, curr_clt, C, curr_symbols, + cfg_misses(cfg, op_in_map, curr_mapping, curr_stack, curr_clt, C, curr_symbols, curr_array_names, decided_branches, ask_user, curr_state, merge_state) update_mapping(mapping, final_e) curr_state = final_e.dst if curr_state == end: break - + if end is None: # only update if we were actually analyzing a whole sdfg (not just start to end state) - update_map(op_in_map, get_uuid(sdfg), total_misses, average=False) + update_map(op_in_map, get_uuid(cfg), total_misses, average=False) + return total_misses @@ -427,7 +640,7 @@ def analyze_sdfg_op_in(sdfg: SDFG, assumptions, generate_plots=False, stringify=False, - test_set_size=3, + test_set_size=1, ask_user=False): """ Computes the operational intensity of the input SDFG. @@ -448,7 +661,7 @@ def analyze_sdfg_op_in(sdfg: SDFG, # from now on we take C as the number of lines that fit into cache C = C // L - + sdfg = deepcopy(sdfg) # apply SSA pass pipeline = FixedPointPipeline([StrictSymbolSSA()]) @@ -463,9 +676,10 @@ def analyze_sdfg_op_in(sdfg: SDFG, elif isinstance(assumptions[sym], str): range_symbol[sym] = SymbolRange(int(x) for x in assumptions[sym].split(',')) del assumptions[sym] - work_map = {} + assumptions_list = [f'{x}=={y}' for x, y in assumptions.items()] + analyze_sdfg(sdfg, work_map, get_tasklet_work, assumptions_list) if len(undefined_symbols) > 0: @@ -480,12 +694,17 @@ def analyze_sdfg_op_in(sdfg: SDFG, # all symbols are concretized --> run normal op_in analysis with concretized symbols sdfg.specialize(assumptions) mapping = {} + # add the static symbols to the map to allow for better analysis + static_symbols = get_static_symbols(sdfg) + mapping.update(static_symbols) + mapping.update(assumptions) + mapping = {k: subs_till_fixed_point(v, mapping) for k,v in mapping.items()} stack = AccessStack(C) clt = CacheLineTracker(L) - sdfg_op_in(sdfg, op_in_map, mapping, stack, clt, C, {}, {}, {}, ask_user) + total_misses = cfg_misses(sdfg, op_in_map, mapping, stack, clt, C, {}, {}, {}, ask_user) # compute bytes for k, v in op_in_map.items(): op_in_map[k] = v[0] / v[1] * L @@ -500,11 +719,12 @@ def analyze_sdfg_op_in(sdfg: SDFG, while True: new_val = False for sym, r in range_symbol.items(): + val = r.next() if val > -1: new_val = True assumptions[sym] = val - elif t < 3: + elif t < test_set_size: # now we sample test set t += 1 assumptions[sym] = r.max_value() + t * 3 @@ -512,12 +732,19 @@ def analyze_sdfg_op_in(sdfg: SDFG, if not new_val: break + r_sdfg = deepcopy(sdfg) + curr_op_in_map = {} mapping = {} + # add the static symbols to the map to allow for better analysis + static_symbols = get_static_symbols(r_sdfg) + mapping.update(static_symbols) mapping.update(assumptions) + mapping = {k: subs_till_fixed_point(v, mapping) for k,v in mapping.items()} + stack = AccessStack(C) clt = CacheLineTracker(L) - sdfg_op_in(sdfg, curr_op_in_map, mapping, stack, clt, C, {}, {}, {}, ask_user) + cfg_misses(r_sdfg, curr_op_in_map, mapping, stack, clt, C, {}, {}, {}, ask_user) # compute average cache misses for k, v in curr_op_in_map.items(): @@ -526,7 +753,7 @@ def analyze_sdfg_op_in(sdfg: SDFG, # save cache misses curr_cache_misses = dict(curr_op_in_map) - work_measurements.append(work_map[get_uuid(sdfg)][0].subs(assumptions)) + work_measurements.append(work_map[get_uuid(sdfg)].subs(assumptions)) # put curr values in cache_miss_measurements for k, v in curr_cache_misses.items(): if k in cache_miss_measurements: @@ -540,6 +767,7 @@ def analyze_sdfg_op_in(sdfg: SDFG, sympy_fs = {} for k, v in cache_miss_measurements.items(): + final_f, sympy_f, r_s = fit_curve(x_values[:-test_set_size], v[:-test_set_size], symbol_name) op_in_map[k] = sp.simplify(sympy_f * L) sympy_fs[k] = sympy_f @@ -562,6 +790,7 @@ def analyze_sdfg_op_in(sdfg: SDFG, if stringify: for k, v in op_in_map.items(): op_in_map[k] = str(v) + return op_in_map[get_uuid(sdfg)] ################################################################################ @@ -599,7 +828,6 @@ def main() -> None: assumptions[a] = int(b) else: assumptions[a] = b - print(assumptions) analyze_sdfg_op_in(sdfg, op_in_map, int(args.C), int(args.L), assumptions) result_whole_sdfg = op_in_map[get_uuid(sdfg)] From e906254b7e4ca37d26174f3ce3dd720e17d635d1 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 27 Apr 2026 20:01:22 +0200 Subject: [PATCH 06/10] Add tests for inlined control flow regions --- tests/sdfg/work_depth_test.py | 59 +++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/tests/sdfg/work_depth_test.py b/tests/sdfg/work_depth_test.py index 3d795dce0a..bf188c3962 100644 --- a/tests/sdfg/work_depth_test.py +++ b/tests/sdfg/work_depth_test.py @@ -248,6 +248,36 @@ def test_work_depth(test_name): assert res[0].expand() == correct[0].expand() assert res[1].expand() == correct[1].expand() +@pytest.mark.parametrize('test_name', list(work_depth_test_cases.keys())) +def test_work_depth_inlined(test_name): + if test_name in ['unbounded_while_do', 'unbounded_nonnegify', 'break_while_loop']: + pytest.skip('Different state naming when ControlFLowRegios are inlined') + + test, correct = work_depth_test_cases[test_name] + w_d_map: Dict[str, sp.Expr] = {} + sdfg = test.to_sdfg() + if 'nested_sdfg' in test.name: + sdfg.apply_transformations(NestSDFG) + if 'nested_maps' in test.name: + sdfg.apply_transformations(MapExpansion) + + # test + inline_control_flow_regions(sdfg) + sdfg.save(test_name+".sdfg") + for sd in sdfg.all_sdfgs_recursive(): + sd.using_explicit_control_flow = False + + + analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth, [], False) + res = w_d_map[get_uuid(sdfg)] + # substitue each symbol without assumptions. + # We do this since sp.Symbol('N') == Sp.Symbol('N', positive=True) --> False. + res = (standardize(res[0]), standardize(res[1])) + correct = (standardize(sp.sympify(correct[0])), standardize(sp.sympify(correct[1]))) + # check result + assert res[0].expand() == correct[0].expand() + assert res[1].expand() == correct[1].expand() + #(sdfg, expected_avg_par) tests_cases_avg_par = { 'single_map': (single_map, N), @@ -298,6 +328,35 @@ def test_avg_par(test_name: str): assert res.expand() == correct.expand() +@pytest.mark.parametrize('test_name', list(tests_cases_avg_par.keys())) +def test_avg_par_inlined(test_name: str): + if test_name in ['unbounded_while_do', 'unbounded_nonnegify', 'break_while_loop']: + pytest.skip('Different state naming when ControlFLowRegios are inlined') + + test, correct = tests_cases_avg_par[test_name] + w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]] = {} + sdfg = test.to_sdfg() + if 'nested_sdfg' in test_name: + sdfg.apply_transformations(NestSDFG) + if 'nested_maps' in test_name: + sdfg.apply_transformations(MapExpansion) + + inline_control_flow_regions(sdfg) + + for sd in sdfg.all_sdfgs_recursive(): + sd.using_explicit_control_flow = False + analyze_sdfg(sdfg, w_d_map, get_tasklet_avg_par, [], False) + res = w_d_map[get_uuid(sdfg)] + # substitue each symbol without assumptions. + # We do this since sp.Symbol('N') == Sp.Symbol('N', positive=True) --> False. + reps = {s: sp.Symbol(s.name) for s in res.free_symbols} + res = res.subs(reps) + reps = {s: sp.Symbol(s.name) for s in sp.sympify(correct).free_symbols} + correct = sp.sympify(correct).subs(reps) + # check result + assert res.expand() == correct.expand() + + x, y, z, a = sp.symbols('x y z a') # (expr, assumptions, result) From 84af34c47934d301373c8bf6ae82574f842ae8f8 Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 28 Apr 2026 18:29:34 +0200 Subject: [PATCH 07/10] add back accidentally removed test --- tests/sdfg/work_depth_test.py | 120 +++++++++++++++++++++++++++++++++- 1 file changed, 119 insertions(+), 1 deletion(-) diff --git a/tests/sdfg/work_depth_test.py b/tests/sdfg/work_depth_test.py index bf188c3962..5920174c18 100644 --- a/tests/sdfg/work_depth_test.py +++ b/tests/sdfg/work_depth_test.py @@ -8,7 +8,7 @@ from dace import symbolic from dace.frontend.python.parser import DaceProgram from dace.sdfg.performance_evaluation.work_depth import (analyze_sdfg, get_tasklet_work_depth, get_tasklet_avg_par, - parse_assumptions) + parse_assumptions, count_arithmetic_ops_code, count_depth_code) from dace.sdfg.performance_evaluation.helpers import get_uuid from dace.sdfg.performance_evaluation.assumptions import ContradictingAssumptions import sympy as sp @@ -396,6 +396,124 @@ def test_assumption_system_contradictions(assumptions): with raises(ContradictingAssumptions): parse_assumptions(assumptions, set()) +def test_depth_counter_vs_work_counter(): + """ + Test that the DepthCounter correctly computes depth (longest chain of dependent operations) + which can differ from work (total number of operations). + Depth measures the critical path through the expression tree, while work measures + the total number of operations. + """ + # Test case 1: (a + b) * (c + d) + # Work = 3 (two additions + one multiplication) + # Depth = 2 (additions can be parallel, then multiplication) + code1 = "(a + b) * (c + d)" + assert count_arithmetic_ops_code(code1) == 3, "Work should be 3" + assert count_depth_code(code1) == 2, "Depth should be 2" + + # Test case 2: a + b + c + d (left-associative: ((a + b) + c) + d) + # Work = 3 (three additions) + # Depth = 3 (sequential chain of additions) + code2 = "a + b + c + d" + assert count_arithmetic_ops_code(code2) == 3, "Work should be 3" + assert count_depth_code(code2) == 3, "Depth should be 3" + + # Test case 3: (a + b) * (c + d) + (e + f) * (g + h) + # Work = 7 (4 additions + 2 multiplications + 1 addition) + # Depth = 3 (parallel adds, then parallel mults, then final add) + code3 = "(a + b) * (c + d) + (e + f) * (g + h)" + assert count_arithmetic_ops_code(code3) == 7, "Work should be 7" + assert count_depth_code(code3) == 3, "Depth should be 3" + + # Test case 4: Simple single operation + # Work = 1, Depth = 1 + code4 = "a + b" + assert count_arithmetic_ops_code(code4) == 1, "Work should be 1" + assert count_depth_code(code4) == 1, "Depth should be 1" + + # Test case 5: Unary operation with binary operation + # -a + b: Work = 2, Depth = 2 (unary then add, but unary on a, so depth is 1+1=2) + code5 = "-a + b" + assert count_arithmetic_ops_code(code5) == 2, "Work should be 2" + assert count_depth_code(code5) == 2, "Depth should be 2" + + # Test case 6: Function call with independent arguments + # max(a + b, c + d): Work = 2 (two adds, max is 0), Depth = 1 (parallel adds, max is 0) + code6 = "max(a + b, c + d)" + assert count_arithmetic_ops_code(code6) == 2, "Work should be 2" + assert count_depth_code(code6) == 1, "Depth should be 1" + + # Test case 7: Nested function calls + # sqrt(a + b): Work = 2 (add + sqrt), Depth = 2 (add then sqrt) + code7 = "sqrt(a + b)" + assert count_arithmetic_ops_code(code7) == 2, "Work should be 2" + assert count_depth_code(code7) == 2, "Depth should be 2" + + # Test case 8: AugAssign with parallel sub-expressions + # x += a * b + c * d: Work = 4 (2 mults + 1 add + 1 augassign), Depth = 3 + code8 = "x += a * b + c * d" + assert count_arithmetic_ops_code(code8) == 4, "Work should be 4" + assert count_depth_code(code8) == 3, "Depth should be 3" + + # Test case 9: Multiple independent statements (no data dependency) + # a = x + y; b = z + w --> Work = 2, Depth = 1 (parallel, no dependency) + code9 = """ +a = x + y +b = z + w +""" + assert count_arithmetic_ops_code(code9) == 2, "Work should be 2" + assert count_depth_code(code9) == 1, "Depth should be 1 (independent statements)" + + # Test case 10: Multiple statements WITH data dependency + # a = x + y; b = a + z --> Work = 2, Depth = 2 (b depends on a) + code10 = """ +a = x + y +b = a + z +""" + assert count_arithmetic_ops_code(code10) == 2, "Work should be 2" + assert count_depth_code(code10) == 2, "Depth should be 2 (b depends on a)" + + # Test case 11: Chain of 3 dependent statements + # a = x + y; b = a * 2; c = b + z --> Work = 3, Depth = 3 + code11 = """ +a = x + y +b = a * 2 +c = b + z +""" + assert count_arithmetic_ops_code(code11) == 3, "Work should be 3" + assert count_depth_code(code11) == 3, "Depth should be 3 (chain: a -> b -> c)" + + # Test case 12: Diamond dependency pattern + # a = x + y; b = a + 1; c = a + 2; d = b + c --> Work = 4, Depth = 3 + # a has depth 1, b and c both have depth 2 (depend on a), d has depth 3 + code12 = """ +a = x + y +b = a + 1 +c = a + 2 +d = b + c +""" + assert count_arithmetic_ops_code(code12) == 4, "Work should be 4" + assert count_depth_code(code12) == 3, "Depth should be 3 (diamond: a -> b,c -> d)" + + # Test case 13: AugAssign chain + # x += 1; x += 2; x += 3 --> Work = 3, Depth = 3 (each depends on previous x) + code13 = """ +x += 1 +x += 2 +x += 3 +""" + assert count_arithmetic_ops_code(code13) == 3, "Work should be 3" + assert count_depth_code(code13) == 3, "Depth should be 3 (augassign chain)" + + # Test case 14: Single complex statement with tree structure + # result = (a+b)*(c+d) + (e+f)*(g+h) + (i+j)*(k+l) + # The AST is left-associative: ((prod1 + prod2) + prod3) + # prod1 has depth 2, prod2 has depth 2, prod3 has depth 2 + # (prod1 + prod2) = max(2,2) + 1 = 3 + # ((prod1 + prod2) + prod3) = max(3, 2) + 1 = 4 + code14 = "(a+b)*(c+d) + (e+f)*(g+h) + (i+j)*(k+l)" + assert count_arithmetic_ops_code(code14) == 11, "Work should be 11" + assert count_depth_code(code14) == 4, "Depth should be 4" + if __name__ == '__main__': for test_name in work_depth_test_cases.keys(): From 12ae75616b12cd5e4d109d3d72be32c4dd6d3e63 Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 28 Apr 2026 18:31:13 +0200 Subject: [PATCH 08/10] add back accidentally removed test --- tests/sdfg/work_depth_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/sdfg/work_depth_test.py b/tests/sdfg/work_depth_test.py index 5920174c18..d5ab798c83 100644 --- a/tests/sdfg/work_depth_test.py +++ b/tests/sdfg/work_depth_test.py @@ -527,3 +527,5 @@ def test_depth_counter_vs_work_counter(): for assumptions in tests_for_exception: test_assumption_system_contradictions(assumptions) + + test_depth_counter_vs_work_counter() From 5b9bc376b178e9e922cc49f0aea2613dd1c41ba5 Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 28 Apr 2026 21:12:04 +0200 Subject: [PATCH 09/10] move static symbol mapping to helpers --- dace/sdfg/performance_evaluation/helpers.py | 107 +++++++++++++++++- .../operational_intensity.py | 105 +---------------- .../sdfg/performance_evaluation/work_depth.py | 105 +---------------- 3 files changed, 107 insertions(+), 210 deletions(-) diff --git a/dace/sdfg/performance_evaluation/helpers.py b/dace/sdfg/performance_evaluation/helpers.py index d730a84d88..5077579094 100644 --- a/dace/sdfg/performance_evaluation/helpers.py +++ b/dace/sdfg/performance_evaluation/helpers.py @@ -5,9 +5,10 @@ from collections import deque from typing import List, Dict, Set, Tuple, Optional, Union import networkx as nx +import re import sympy as sp from dace.sdfg.state import ControlFlowRegion -from dace.sdfg.propagation import propagate_states +from dace.symbolic import pystr_to_symbolic NodeT = str EdgeT = Tuple[NodeT, NodeT] @@ -406,4 +407,106 @@ def get_legacy_loop_ranges(cfr: ControlFlowRegion) -> Dict[SDFGState, Tuple[sp.E result[node] = (loop_var, start, stop, stride) - return result \ No newline at end of file + return result + +def subs_till_fixed_point(expr:sp.Expr, symbol_map:Dict[sp.Expr, sp.Expr]): + """ + Takes a sympy expression and a symbol mapping and applies the mapping to the expression until a fixed point is reached + Needs the guarantee that the symbol mapping does not have cyclic dependencies. + + :param expr: Description + :param symbol_map: Description + :return: Description + """ + if not isinstance(expr, sp.Expr): + return expr + prev = None + curr = expr + while prev != curr: + prev = curr + curr = curr.subs(symbol_map) + return curr + +def get_static_symbols(sdfg: SDFG): + """ + Returns a mapping of symbols that are assigned exactly at one point in the sdfg. + + :param sdfg: The sdfg for which we want to find the static symbols and their corresponding assignment + :return: The mapping of the symbols to higher levels (iterated to a fixed point) + """ + + + patterns = [ + "dace.complex128", + "dace.float64", + "dace.float32", + "dace.int64", + "dace.int32", + "dace.int16", + "dace.uint32", + "dace.uint16", + "dace.uint8", + "float", + "int" + ] + + type_regex = re.compile("|".join(map(re.escape, patterns))) + static_symbol_mapping:Dict[sp.Symbol, sp.Expr] = {sp.Symbol(a): sp.Symbol(a) for a in sdfg.arg_names} + non_static_symbols = set() + for node, containing_state in sdfg.all_nodes_recursive(): + if isinstance(node, nodes.AccessNode): + + if containing_state.in_degree(node) == 1: + edge = containing_state.in_edges(node)[0] + source = edge.src + + if edge.data.volume == 1: + if isinstance(source, nodes.Tasklet): + tasklet = source + in_map = {} + out_map = {} + # Incoming edges: symbols feeding the tasklet + for e in containing_state.in_edges(tasklet): + if not isinstance(e.src, nodes.AccessNode): + continue + sym = str(e.src.data) + in_map[e.dst_conn] = sym + # Outgoing edges: symbols written by the tasklet + # Out edges should only be one, but for safety we iterate + for e in containing_state.out_edges(tasklet): + if not isinstance(e.dst, nodes.AccessNode): + continue + sym = sp.Symbol(e.dst.data) + out_map[e.src_conn] = sym + code = tasklet.code.as_string.strip() + # Expect a single assignment + lines = [l.strip() for l in code.splitlines() if l.strip()] + lhs, rhs = lines[0].split('=',1) + lhs = lhs.strip() + rhs = rhs.strip() + rhs = type_regex.sub("", rhs) + # Parse RHS using SymPy, with tasklet inputs substituted + lhs_sympy = pystr_to_symbolic(lhs) + lhs_sympy = lhs_sympy.subs(out_map) + + if not lhs_sympy in static_symbol_mapping.keys(): + try: + rhs_sympy = pystr_to_symbolic(rhs) + rhs_sympy = rhs_sympy.subs(in_map) + static_symbol_mapping[lhs_sympy] = rhs_sympy + except: + non_static_symbols.add(lhs_sympy) + else: + non_static_symbols.add(lhs_sympy) + + elif isinstance(source, nodes.AccessNode): + data_sym = sp.Symbol(source.data) + nd_sym = sp.Symbol(node.data) + if not data_sym in static_symbol_mapping.keys(): + static_symbol_mapping[data_sym] = nd_sym + else: + non_static_symbols.add(data_sym) + + static_symbol_mapping = {k: v for (k, v) in static_symbol_mapping.items() if k not in non_static_symbols} + static_symbol_mapping = {str(k): subs_till_fixed_point(v, static_symbol_mapping) for k,v in static_symbol_mapping.items()} + return static_symbol_mapping diff --git a/dace/sdfg/performance_evaluation/operational_intensity.py b/dace/sdfg/performance_evaluation/operational_intensity.py index a713ab69eb..a25b361c9f 100644 --- a/dace/sdfg/performance_evaluation/operational_intensity.py +++ b/dace/sdfg/performance_evaluation/operational_intensity.py @@ -14,7 +14,7 @@ from dace.symbolic import pystr_to_symbolic, SymExpr import re -from dace.sdfg.performance_evaluation.helpers import get_uuid +from dace.sdfg.performance_evaluation.helpers import get_uuid, get_static_symbols, subs_till_fixed_point from dace.transformation.passes.symbol_ssa import StrictSymbolSSA from dace.transformation.pass_pipeline import FixedPointPipeline @@ -24,109 +24,6 @@ from dace.transformation.passes.analysis import loop_analysis -def subs_till_fixed_point(expr:sp.Expr, symbol_map:Dict[sp.Expr, sp.Expr]): - """ - Takes a sympy expression and a symbol mapping and applies the mapping to the expression until a fixed point is reached - Needs the guarantee that the symbol mapping does not have cyclic dependencies. - - :param expr: Description - :param symbol_map: Description - :return: Description - """ - if not isinstance(expr, sp.Expr): - return expr - prev = None - curr = expr - while prev != curr: - prev = curr - curr = curr.subs(symbol_map) - return curr - -def get_static_symbols(sdfg: SDFG): - """ - Returns a mapping of symbols that are assigned exactly at one point in the sdfg. - - :param sdfg: The sdfg for which we want to find the static symbols and their corresponding assignment - :return: The mapping of the symbols to higher levels (iterated to a fixed point) - """ - - - patterns = [ - "dace.complex128", - "dace.float64", - "dace.float32", - "dace.int64", - "dace.int32", - "dace.int16", - "dace.uint32", - "dace.uint16", - "dace.uint8", - "float", - "int" - ] - - type_regex = re.compile("|".join(map(re.escape, patterns))) - static_symbol_mapping:Dict[sp.Symbol, sp.Expr] = {sp.Symbol(a): sp.Symbol(a) for a in sdfg.arg_names} - non_static_symbols = set() - for node, containing_state in sdfg.all_nodes_recursive(): - if isinstance(node, nd.AccessNode): - - if containing_state.in_degree(node) == 1: - edge = containing_state.in_edges(node)[0] - source = edge.src - - if edge.data.volume == 1: - if isinstance(source, nd.Tasklet): - tasklet = source - in_map = {} - out_map = {} - # Incoming edges: symbols feeding the tasklet - for e in containing_state.in_edges(tasklet): - if not isinstance(e.src, nd.AccessNode): - continue - sym = str(e.src.data) - in_map[e.dst_conn] = sym - # Outgoing edges: symbols written by the tasklet - # Out edges should only be one, but for safety we iterate - for e in containing_state.out_edges(tasklet): - if not isinstance(e.dst, nd.AccessNode): - continue - sym = sp.Symbol(e.dst.data) - out_map[e.src_conn] = sym - code = tasklet.code.as_string.strip() - # Expect a single assignment - lines = [l.strip() for l in code.splitlines() if l.strip()] - lhs, rhs = lines[0].split('=',1) - lhs = lhs.strip() - rhs = rhs.strip() - rhs = type_regex.sub("", rhs) - # Parse RHS using SymPy, with tasklet inputs substituted - lhs_sympy = pystr_to_symbolic(lhs) - lhs_sympy = lhs_sympy.subs(out_map) - - if not lhs_sympy in static_symbol_mapping.keys(): - try: - rhs_sympy = pystr_to_symbolic(rhs) - rhs_sympy = rhs_sympy.subs(in_map) - static_symbol_mapping[lhs_sympy] = rhs_sympy - except: - non_static_symbols.add(lhs_sympy) - else: - non_static_symbols.add(lhs_sympy) - - elif isinstance(source, nd.AccessNode): - data_sym = sp.Symbol(source.data) - nd_sym = sp.Symbol(node.data) - if not data_sym in static_symbol_mapping.keys(): - static_symbol_mapping[data_sym] = nd_sym - else: - non_static_symbols.add(data_sym) - - static_symbol_mapping = {k: v for (k, v) in static_symbol_mapping.items() if k not in non_static_symbols} - static_symbol_mapping = {str(k): subs_till_fixed_point(v, static_symbol_mapping) for k,v in static_symbol_mapping.items()} - return static_symbol_mapping - - class SymbolRange(): """ Used to describe an SDFG symbol associated with a range (start, stop, step) of values. """ diff --git a/dace/sdfg/performance_evaluation/work_depth.py b/dace/sdfg/performance_evaluation/work_depth.py index bb3ffc9090..354f6d76a8 100644 --- a/dace/sdfg/performance_evaluation/work_depth.py +++ b/dace/sdfg/performance_evaluation/work_depth.py @@ -19,7 +19,7 @@ import warnings import re -from dace.sdfg.performance_evaluation.helpers import LoopExtractionError, get_uuid, find_loop_guards_tails_exits, get_legacy_loop_body, get_legacy_loop_ranges +from dace.sdfg.performance_evaluation.helpers import LoopExtractionError, get_uuid, find_loop_guards_tails_exits, get_legacy_loop_body, get_legacy_loop_ranges, get_static_symbols, subs_till_fixed_point from dace.sdfg.performance_evaluation.assumptions import parse_assumptions from dace.transformation.passes.symbol_ssa import StrictSymbolSSA from dace.transformation.pass_pipeline import FixedPointPipeline @@ -45,109 +45,6 @@ def get_array_size_symbols(sdfg): symbols.add(s) return symbols -def subs_till_fixed_point(expr:sp.Expr, symbol_map:Dict[sp.Expr, sp.Expr]): - """ - Takes a sympy expression and a symbol mapping and applies the mapping to the expression until a fixed point is reached - Needs the guarantee that the symbol mapping does not have cyclic dependencies. - - :param expr: Description - :param symbol_map: Description - :return: Description - """ - prev = None - curr = expr - while prev != curr: - prev = curr - curr = curr.subs(symbol_map) - return curr - -def get_static_symbols(sdfg: SDFG): - """ - Returns a mapping of symbols that are assigned exactly at one point in the sdfg. - - :param sdfg: The sdfg for which we want to find the static symbols and their corresponding assignment - :return: The mapping of the symbols to higher levels (iterated to a fixed point) - """ - - patterns = [ - "dace.complex128", - "dace.float64", - "dace.float32", - "dace.int64", - "dace.int32", - "dace.int16", - "dace.uint32", - "dace.uint16", - "dace.uint8", - "float", - "int" - ] - - type_regex = re.compile("|".join(map(re.escape, patterns))) - static_symbol_mapping:Dict[sp.Symbol, sp.Expr] = {sp.Symbol(a): sp.Symbol(a) for a in sdfg.arg_names} - non_static_symbols = set() - for node, containing_state in sdfg.all_nodes_recursive(): - if isinstance(node, nd.AccessNode): - - if containing_state.in_degree(node) == 1: - edge = containing_state.in_edges(node)[0] - source = edge.src - - if edge.data.volume == 1: - if isinstance(source, nd.Tasklet): - tasklet = source - in_map = {} - out_map = {} - # Incoming edges: symbols feeding the tasklet - for e in containing_state.in_edges(tasklet): - if not isinstance(e.src, nd.AccessNode): - continue - sym = str(e.src.data) - in_map[e.dst_conn] = sym - # Outgoing edges: symbols written by the tasklet - # Out edges should only be one, but for safety we iterate - for e in containing_state.out_edges(tasklet): - if not isinstance(e.dst, nd.AccessNode): - continue - sym = sp.Symbol(e.dst.data) - out_map[e.src_conn] = sym - code = tasklet.code.as_string.strip() - # Expect a single assignment - lines = [l.strip() for l in code.splitlines() if l.strip()] - try: - lhs, rhs = lines[0].split('=',1) - except: - # Skip mapping for overly complex tasklet code - non_static_symbols.add(sp.Symbol(node.data)) - lhs = lhs.strip() - rhs = rhs.strip() - rhs = type_regex.sub("", rhs) - # Parse RHS using SymPy, with tasklet inputs substituted - lhs_sympy = pystr_to_symbolic(lhs) - lhs_sympy = lhs_sympy.subs(out_map) - - if not lhs_sympy in static_symbol_mapping.keys(): - try: - rhs_sympy = pystr_to_symbolic(rhs) - rhs_sympy = rhs_sympy.subs(in_map) - static_symbol_mapping[lhs_sympy] = rhs_sympy - except: - non_static_symbols.add(lhs_sympy) - else: - non_static_symbols.add(lhs_sympy) - - elif isinstance(source, nd.AccessNode): - data_sym = sp.Symbol(source.data) - nd_sym = sp.Symbol(node.data) - if not data_sym in static_symbol_mapping.keys(): - static_symbol_mapping[data_sym] = nd_sym - else: - non_static_symbols.add(data_sym) - - static_symbol_mapping = {k: v for (k, v) in static_symbol_mapping.items() if k not in non_static_symbols} - static_symbol_mapping = {k: subs_till_fixed_point(v, static_symbol_mapping) for k,v in static_symbol_mapping.items()} - return static_symbol_mapping - def symeval(val, symbols): """ Takes a sympy expression and substitutes its symbols according to a dict { old_symbol: new_symbol}. From e52fdcbe03843c8f068a1f7a36535af321f44372 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 29 Apr 2026 00:14:23 +0200 Subject: [PATCH 10/10] fixes --- .../operational_intensity.py | 83 +++++++++++++++++-- .../sdfg/performance_evaluation/work_depth.py | 1 - tests/sdfg/operational_intensity_test.py | 14 +++- 3 files changed, 89 insertions(+), 9 deletions(-) diff --git a/dace/sdfg/performance_evaluation/operational_intensity.py b/dace/sdfg/performance_evaluation/operational_intensity.py index a25b361c9f..3996dffbd4 100644 --- a/dace/sdfg/performance_evaluation/operational_intensity.py +++ b/dace/sdfg/performance_evaluation/operational_intensity.py @@ -24,6 +24,7 @@ from dace.transformation.passes.analysis import loop_analysis +import traceback class SymbolRange(): """ Used to describe an SDFG symbol associated with a range (start, stop, step) of values. """ @@ -127,10 +128,8 @@ def symeval(val, symbols): def evaluate_symbols(base, new): result = {} - print("Evaluate symbols called with base:", base, "new:", new) for k, v in new.items(): result[k] = symeval(v, base) - print("Result:", result) return result @@ -141,6 +140,53 @@ def update_mapping(mapping, e): update[k] = pystr_to_symbolic(v).subs(mapping) mapping.update(update) +def assignment_misses(edge, mapping, stack, clt, C, symbols, array_names): + # regex pattern to detect buffer name and index if applicable + pattern = re.compile(r""" + ^\s* + (?P[a-zA-Z_]\w*) # variable name + (?:\[ + (?P[^\[\]]+) # anything inside brackets (no nested []) + \])? + \s*$ +""", re.VERBOSE) + + misses = 0 + for lhs, rhs in edge.data.assignments.items(): + m_lhs = pattern.match(lhs) + m_rhs = pattern.match(rhs) + try: + lhs_name = m_lhs.group("name") + lhs_index = m_lhs.group("index") + if lhs_index and not lhs_index.isdigit(): + lhs_index = sp.Symbol(m_lhs.group("index")) + elif lhs_index and lhs_index.isdigit(): + lhs_index = sp.Expr(int(lhs_index)) + + rhs_name = m_rhs.group("name") + rhs_index = m_rhs.group("index") + if rhs_index and not rhs_index.isdigit(): + rhs_index = sp.Symbol(m_rhs.group("index")) + elif rhs_index and rhs_index.isdigit(): + lhs_index = sp.Expr(int(rhs_index)) + + if lhs_name in clt.array_info or (lhs_name in array_names and array_names[lhs_name] in clt.array_info): + line_id = clt.cache_line_id(lhs_name if lhs_name not in array_names else array_names[lhs_name], + ([lhs_index.subs(mapping)] if isinstance(lhs_index, sp.Expr) else []), mapping) + line_id = int(line_id.subs(symbols).subs(mapping) if isinstance(line_id, sp.Expr) else line_id) + dist = stack.touch(line_id) + misses += 1 if dist >= C or dist == -1 else 0 + + if rhs_name in clt.array_info or (rhs_name in array_names and array_names[rhs_name] in clt.array_info): + line_id = clt.cache_line_id(rhs_name if rhs_name not in array_names else array_names[rhs_name], + ([rhs_index.subs(mapping)] if isinstance(rhs_index, sp.Expr) else []), mapping) + line_id = int(line_id.subs(symbols).subs(mapping) if isinstance(line_id, sp.Expr) else line_id) + dist = stack.touch(line_id) + misses += 1 if dist >= C or dist == -1 else 0 + except Exception as e: + traceback.print_exc() # full stack trace + return misses + def update_map_iterators(map, mapping, symbols): # update the map params and return False @@ -230,7 +276,29 @@ def scope_misses(state: SDFGState, elif isinstance(node, nd.Tasklet): tasklet_misses = 0 # analyze the memory accesses of this tasklet and whether they hit in cache or not - for e in state.in_edges(node) + state.out_edges(node): + for e in state.in_edges(node): + # Check if source node is just a transient node to map to correct cache line + data_node = e.src + data_node_in_edges = state.in_edges(data_node) + if len(data_node_in_edges) == 1 and isinstance(data_node_in_edges[0].src, nd.AccessNode): + e = data_node_in_edges[0] + + if e.data.data in clt.array_info or (e.data.data in array_names + and array_names[e.data.data] in clt.array_info): + line_id = clt.cache_line_id( + e.data.data if e.data.data not in array_names else array_names[e.data.data], + [x[0].subs(mapping) for x in e.data.subset.ranges], mapping) + line_id = int(line_id.subs(symbols).subs(mapping)) + dist = stack.touch(line_id) + tasklet_misses += 1 if dist >= C or dist == -1 else 0 + + for e in state.out_edges(node): + # Check if destination node is just a transient node to map to correct cache line + data_node = e.dst + data_node_out_edges = state.out_edges(data_node) + if len(data_node_out_edges) == 1 and isinstance(data_node_out_edges[0].src, nd.AccessNode): + e = data_node_out_edges[0] + if e.data.data in clt.array_info or (e.data.data in array_names and array_names[e.data.data] in clt.array_info): line_id = clt.cache_line_id( @@ -307,12 +375,10 @@ def cfr_misses(cfr:ControlFlowRegion, step = sp.sympify(loop_analysis.get_loop_stride(cfr)) mapping[loop_var] = start.subs(mapping) region_misses = 0 - while (loop_condition.subs(mapping) == True): iter_misses = cfg_misses(cfr, op_in_map, mapping, stack, clt, C, symbols, array_names, decided_branches, ask_user, start=cfr.start_block,end=None) mapping[loop_var] = mapping[loop_var] + step region_misses += iter_misses - elif isinstance(cfr, ConditionalBlock): true_branches = [] possible_branches = [] @@ -449,6 +515,7 @@ def cfg_misses(cfg: ControlFlowRegion, # save e's assignments in mapping and update curr_state # replace values first with mapping, then update mapping try: + total_misses += assignment_misses(e, mapping, stack, clt, C, symbols, array_names) update_mapping(mapping, e) except: print('\nWARNING: Uncommon assignment detected on InterstateEdge (e.g. bitwise operators).' @@ -469,7 +536,7 @@ def cfg_misses(cfg: ControlFlowRegion, if curr_state in decided_branches: # if the user already decided this branch in a previous iteration, take the same branch again. e = decided_branches[curr_state] - + total_misses += assignment_misses(e, mapping, stack, clt, C, symbols, array_names) update_mapping(mapping, e) curr_state = e.dst else: @@ -484,6 +551,7 @@ def cfg_misses(cfg: ControlFlowRegion, if len(next_edge_candidates) == 1: e = next_edge_candidates[0] + total_misses += assignment_misses(e, mapping, stack, clt, C, symbols, array_names) update_mapping(mapping, e) decided_branches[curr_state] = e curr_state = e.dst @@ -497,16 +565,19 @@ def cfg_misses(cfg: ControlFlowRegion, print('merge state is named ', merge_state) chosen = int(input('Choose an option from above: ')) e = edges[chosen] + total_misses += assignment_misses(e, mapping, stack, clt, C, symbols, array_names) update_mapping(mapping, e) decided_branches[curr_state] = e curr_state = e.dst print(2 * '\n') else: final_e = next_edge_candidates.pop() + final_misses = 0 for e in next_edge_candidates: # copy the state of the analysis curr_mapping = dict(mapping) + curr_misses += assignment_misses(e, mapping, stack, clt, C, symbols, array_names) update_mapping(curr_mapping, e) curr_stack = stack.copy() curr_clt = clt.copy() diff --git a/dace/sdfg/performance_evaluation/work_depth.py b/dace/sdfg/performance_evaluation/work_depth.py index 354f6d76a8..288ee187e5 100644 --- a/dace/sdfg/performance_evaluation/work_depth.py +++ b/dace/sdfg/performance_evaluation/work_depth.py @@ -696,7 +696,6 @@ def control_flow_region_work_depth(cfr: ControlFlowRegion, loop_work = sp.Sum(loop_work, (loop_var, shifted_lo, shifted_hi)).doit() loop_depth = sp.Sum(loop_depth, (loop_var, shifted_lo, shifted_hi)).doit() - print("Loop work uneval: ", sp.Sum(loop_work, (loop_var, shifted_lo, shifted_hi))) # Do equality subs loop_work = sp.simplify(loop_work.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) loop_depth = sp.simplify(loop_depth.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) diff --git a/tests/sdfg/operational_intensity_test.py b/tests/sdfg/operational_intensity_test.py index 36be0455db..71e610f386 100644 --- a/tests/sdfg/operational_intensity_test.py +++ b/tests/sdfg/operational_intensity_test.py @@ -49,14 +49,14 @@ def if_else(x: dc.int64[100], sum: dc.int64[1]): if x[0] > 3: for i in range(100): sum += x[i] - # no else --> simply analyze the ifs. if cache big enough, everything is reused + # no else --> simply analyze the ifs. if cache big enough, everything is reused; @dc.program def unaligned_for_loop(x: dc.float32[100], sum: dc.int64[1]): for i in range(17, 53): sum += x[i] - + # 36 = 144byte array elemets accessed 1 = 4byte scalar accessed 36 ops -> 64byte line size=> 3 lines + 1 line scalar => op in = 9/64 @dc.program def sequential_maps(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]): @@ -109,6 +109,16 @@ def reduction_library_node(x: dc.float64[N]): 'single_map16_even': (single_map16, 64 * 64, 64, { 'N': 512 }, 1 / 6), + 'single_for_loop': (single_for_loop, 64 * 64, 64, { + 'N': 512 + }, 1/16), + 'if_else': (if_else, 64 * 64, 64, { + 'N': 512 + }, 200/(14*64)) + # 14 cache misses, because DaCe introduces intermediate variable + , + 'unaligned_for_loop': (unaligned_for_loop, 64 * 64, 64, { + }, 9/64), # now num_elements_on_single_cache_line does not divie N anymore # -->513 work, 520 elements loaded --> 513 / (520*8*3) 'single_map64_uneven': (single_map64, 64 * 64, 64, {