diff --git a/dace/sdfg/performance_evaluation/assumptions.py b/dace/sdfg/performance_evaluation/assumptions.py index 1b1d37348b..9e140e0fd9 100644 --- a/dace/sdfg/performance_evaluation/assumptions.py +++ b/dace/sdfg/performance_evaluation/assumptions.py @@ -272,7 +272,7 @@ def parse_assumptions(assumptions, array_symbols): for sym, assum in condensed_assumptions.items(): i = 0 for g in assum.greater: - replacement_symbol = sp.Symbol(f'_p_{sym}', positive=True, integer=True) + replacement_symbol = sp.Symbol(f'_p_{sym}', nonnegative=True, integer=True) all_subs[i][0].update({sp.Symbol(sym): replacement_symbol + g}) all_subs[i][1].update({replacement_symbol: sp.Symbol(sym) - g}) i += 1 diff --git a/dace/sdfg/performance_evaluation/helpers.py b/dace/sdfg/performance_evaluation/helpers.py index ba7bfb84f2..5077579094 100644 --- a/dace/sdfg/performance_evaluation/helpers.py +++ b/dace/sdfg/performance_evaluation/helpers.py @@ -5,6 +5,10 @@ from collections import deque from typing import List, Dict, Set, Tuple, Optional, Union import networkx as nx +import re +import sympy as sp +from dace.sdfg.state import ControlFlowRegion +from dace.symbolic import pystr_to_symbolic NodeT = str EdgeT = Tuple[NodeT, NodeT] @@ -335,3 +339,174 @@ def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): # remove artificial end node sdfg_nx.remove_node(artificial_end_node) return nodes_oNodes_exits + +def get_legacy_loop_body(cfr, guard, tail, exits): + """ + Get all nodes in a legacy loop body. + A node is in the loop body if: + - It's reachable from guard + - It can reach tail + - It's not an exit node + """ + # Forward reachability from guard + forward_reachable = set() + queue = deque([guard]) + while queue: + node = queue.popleft() + if node in forward_reachable: + continue + forward_reachable.add(node) + for edge in cfr.out_edges(node): + queue.append(edge.dst) + + # Backward reachability to tail + backward_reachable = set() + queue = deque([tail]) + while queue: + node = queue.popleft() + if node in backward_reachable: + continue + backward_reachable.add(node) + for edge in cfr.in_edges(node): + queue.append(edge.src) + + # Loop body = (forward AND backward) - exits + loop_body = (forward_reachable & backward_reachable) - set(exits) + + return loop_body + +def get_legacy_loop_ranges(cfr: ControlFlowRegion) -> Dict[SDFGState, Tuple[sp.Expr, sp.Expr, sp.Expr, sp.Symbol]]: + """ + Builds a map from loop guard states to their loop variable and iteration + range, harvesting the annotations set by propagate_states / + _annotate_loop_ranges. + + Must be called AFTER propagate_states has been run on the SDFG. + + :param cfr: The ControlFlowRegion to inspect (only its direct nodes are + checked, not descendants, since control_flow_region_work_depth + is called recursively anyway). + :return: A dict mapping each legacy loop guard SDFGState to a tuple + (loop_var, start, stop, stride) + """ + #propagate_states(cfr) + result: Dict[SDFGState, Tuple[sp.Symbol, sp.Expr, sp.Expr, sp.Expr]] = {} + + for node in cfr.nodes(): + if not getattr(node, 'is_loop_guard', False): + continue + + itvar_str: str = node.itvar + loop_var: sp.Symbol = sp.Symbol(itvar_str) + + # guard.ranges[itvar] is a subsets.Range with one entry: [(start, stop, stride)] + rng = node.ranges[itvar_str][0] # -> (start, stop, stride) + start = sp.sympify(rng[0]) + stop = sp.sympify(rng[1]) + stride = sp.sympify(rng[2]) + + result[node] = (loop_var, start, stop, stride) + + return result + +def subs_till_fixed_point(expr:sp.Expr, symbol_map:Dict[sp.Expr, sp.Expr]): + """ + Takes a sympy expression and a symbol mapping and applies the mapping to the expression until a fixed point is reached + Needs the guarantee that the symbol mapping does not have cyclic dependencies. + + :param expr: Description + :param symbol_map: Description + :return: Description + """ + if not isinstance(expr, sp.Expr): + return expr + prev = None + curr = expr + while prev != curr: + prev = curr + curr = curr.subs(symbol_map) + return curr + +def get_static_symbols(sdfg: SDFG): + """ + Returns a mapping of symbols that are assigned exactly at one point in the sdfg. + + :param sdfg: The sdfg for which we want to find the static symbols and their corresponding assignment + :return: The mapping of the symbols to higher levels (iterated to a fixed point) + """ + + + patterns = [ + "dace.complex128", + "dace.float64", + "dace.float32", + "dace.int64", + "dace.int32", + "dace.int16", + "dace.uint32", + "dace.uint16", + "dace.uint8", + "float", + "int" + ] + + type_regex = re.compile("|".join(map(re.escape, patterns))) + static_symbol_mapping:Dict[sp.Symbol, sp.Expr] = {sp.Symbol(a): sp.Symbol(a) for a in sdfg.arg_names} + non_static_symbols = set() + for node, containing_state in sdfg.all_nodes_recursive(): + if isinstance(node, nodes.AccessNode): + + if containing_state.in_degree(node) == 1: + edge = containing_state.in_edges(node)[0] + source = edge.src + + if edge.data.volume == 1: + if isinstance(source, nodes.Tasklet): + tasklet = source + in_map = {} + out_map = {} + # Incoming edges: symbols feeding the tasklet + for e in containing_state.in_edges(tasklet): + if not isinstance(e.src, nodes.AccessNode): + continue + sym = str(e.src.data) + in_map[e.dst_conn] = sym + # Outgoing edges: symbols written by the tasklet + # Out edges should only be one, but for safety we iterate + for e in containing_state.out_edges(tasklet): + if not isinstance(e.dst, nodes.AccessNode): + continue + sym = sp.Symbol(e.dst.data) + out_map[e.src_conn] = sym + code = tasklet.code.as_string.strip() + # Expect a single assignment + lines = [l.strip() for l in code.splitlines() if l.strip()] + lhs, rhs = lines[0].split('=',1) + lhs = lhs.strip() + rhs = rhs.strip() + rhs = type_regex.sub("", rhs) + # Parse RHS using SymPy, with tasklet inputs substituted + lhs_sympy = pystr_to_symbolic(lhs) + lhs_sympy = lhs_sympy.subs(out_map) + + if not lhs_sympy in static_symbol_mapping.keys(): + try: + rhs_sympy = pystr_to_symbolic(rhs) + rhs_sympy = rhs_sympy.subs(in_map) + static_symbol_mapping[lhs_sympy] = rhs_sympy + except: + non_static_symbols.add(lhs_sympy) + else: + non_static_symbols.add(lhs_sympy) + + elif isinstance(source, nodes.AccessNode): + data_sym = sp.Symbol(source.data) + nd_sym = sp.Symbol(node.data) + if not data_sym in static_symbol_mapping.keys(): + static_symbol_mapping[data_sym] = nd_sym + else: + non_static_symbols.add(data_sym) + + static_symbol_mapping = {k: v for (k, v) in static_symbol_mapping.items() if k not in non_static_symbols} + static_symbol_mapping = {str(k): subs_till_fixed_point(v, static_symbol_mapping) for k,v in static_symbol_mapping.items()} + return static_symbol_mapping diff --git a/dace/sdfg/performance_evaluation/op_in_helpers.py b/dace/sdfg/performance_evaluation/op_in_helpers.py index b1e430f676..abe0e2d5a9 100644 --- a/dace/sdfg/performance_evaluation/op_in_helpers.py +++ b/dace/sdfg/performance_evaluation/op_in_helpers.py @@ -2,7 +2,7 @@ """ Contains class CacheLineTracker which keeps track of all arrays of an SDFG and their cache line position and class AccessStack which which corresponds to the stack used to compute the stack distance. Further, provides a curve fitting method and plotting function. """ - +from __future__ import annotations import warnings from dace.data import Array import sympy as sp @@ -11,7 +11,6 @@ import numpy as np from dace import symbol - class CacheLineTracker: """ A CacheLineTracker maps data container accesses to the corresponding accessed cache line. """ @@ -129,6 +128,12 @@ def copy(self): curr.next = Node(x) curr = curr.next return new_stack + + def replace_self(self, other:AccessStack): + self.top = other.top + self.num_calls = other.num_calls + self.lengh = other.length + self.C = other.C def plot(x, work_map, cache_misses, op_in_map, symbol_name, C, L, sympy_f, element, name): diff --git a/dace/sdfg/performance_evaluation/operational_intensity.py b/dace/sdfg/performance_evaluation/operational_intensity.py index ee9286a7c5..3996dffbd4 100644 --- a/dace/sdfg/performance_evaluation/operational_intensity.py +++ b/dace/sdfg/performance_evaluation/operational_intensity.py @@ -5,14 +5,16 @@ import argparse from collections import deque from dace.sdfg import nodes as nd -from dace import SDFG, SDFGState, dtypes +from dace import dtypes, SDFG +from dace.sdfg.state import SDFGState, ControlFlowRegion, LoopRegion, FunctionCallRegion, ConditionalBlock, ReturnBlock, ContinueBlock, BreakBlock from typing import Tuple, Dict import os import sympy as sp from copy import deepcopy from dace.symbolic import pystr_to_symbolic, SymExpr +import re -from dace.sdfg.performance_evaluation.helpers import get_uuid +from dace.sdfg.performance_evaluation.helpers import get_uuid, get_static_symbols, subs_till_fixed_point from dace.transformation.passes.symbol_ssa import StrictSymbolSSA from dace.transformation.pass_pipeline import FixedPointPipeline @@ -20,7 +22,9 @@ from dace.sdfg.performance_evaluation.op_in_helpers import CacheLineTracker, AccessStack, fit_curve, plot, compute_mape from dace.sdfg.performance_evaluation.work_depth import analyze_sdfg, get_tasklet_work +from dace.transformation.passes.analysis import loop_analysis +import traceback class SymbolRange(): """ Used to describe an SDFG symbol associated with a range (start, stop, step) of values. """ @@ -60,7 +64,7 @@ def update_map(op_in_map, uuid, new_misses, average=True): def calculate_op_in(op_in_map, work_map, stringify=False, assumptions={}): """ Calculates the operational intensity for each SDFG element from work and bytes loaded. """ for uuid in op_in_map: - work = work_map[uuid][0].subs(assumptions) + work = work_map[uuid].subs(assumptions) if work == 0 and op_in_map[uuid] == 0: op_in_map[uuid] = 0 elif work != 0 and op_in_map[uuid] == 0: @@ -136,8 +140,55 @@ def update_mapping(mapping, e): update[k] = pystr_to_symbolic(v).subs(mapping) mapping.update(update) - -def update_map_iterators(map, mapping): +def assignment_misses(edge, mapping, stack, clt, C, symbols, array_names): + # regex pattern to detect buffer name and index if applicable + pattern = re.compile(r""" + ^\s* + (?P[a-zA-Z_]\w*) # variable name + (?:\[ + (?P[^\[\]]+) # anything inside brackets (no nested []) + \])? + \s*$ +""", re.VERBOSE) + + misses = 0 + for lhs, rhs in edge.data.assignments.items(): + m_lhs = pattern.match(lhs) + m_rhs = pattern.match(rhs) + try: + lhs_name = m_lhs.group("name") + lhs_index = m_lhs.group("index") + if lhs_index and not lhs_index.isdigit(): + lhs_index = sp.Symbol(m_lhs.group("index")) + elif lhs_index and lhs_index.isdigit(): + lhs_index = sp.Expr(int(lhs_index)) + + rhs_name = m_rhs.group("name") + rhs_index = m_rhs.group("index") + if rhs_index and not rhs_index.isdigit(): + rhs_index = sp.Symbol(m_rhs.group("index")) + elif rhs_index and rhs_index.isdigit(): + lhs_index = sp.Expr(int(rhs_index)) + + if lhs_name in clt.array_info or (lhs_name in array_names and array_names[lhs_name] in clt.array_info): + line_id = clt.cache_line_id(lhs_name if lhs_name not in array_names else array_names[lhs_name], + ([lhs_index.subs(mapping)] if isinstance(lhs_index, sp.Expr) else []), mapping) + line_id = int(line_id.subs(symbols).subs(mapping) if isinstance(line_id, sp.Expr) else line_id) + dist = stack.touch(line_id) + misses += 1 if dist >= C or dist == -1 else 0 + + if rhs_name in clt.array_info or (rhs_name in array_names and array_names[rhs_name] in clt.array_info): + line_id = clt.cache_line_id(rhs_name if rhs_name not in array_names else array_names[rhs_name], + ([rhs_index.subs(mapping)] if isinstance(rhs_index, sp.Expr) else []), mapping) + line_id = int(line_id.subs(symbols).subs(mapping) if isinstance(line_id, sp.Expr) else line_id) + dist = stack.touch(line_id) + misses += 1 if dist >= C or dist == -1 else 0 + except Exception as e: + traceback.print_exc() # full stack trace + return misses + + +def update_map_iterators(map, mapping, symbols): # update the map params and return False # if all iterations exhausted, return True # always increase the last one. If it is exhausted, increase the next one and so forth @@ -145,23 +196,23 @@ def update_map_iterators(map, mapping): for p, range in zip(map.params[::-1], map.range[::-1]): # reversed order curr_value = mapping[p] if not isinstance(range[1], SymExpr): - if curr_value.subs(mapping) + range[2].subs(mapping) <= range[1].subs(mapping): + if curr_value.subs(symbols).subs(mapping) + range[2].subs(symbols).subs(mapping) <= range[1].subs(symbols).subs(mapping): # update this value and then we are done - mapping[p] = curr_value.subs(mapping) + range[2].subs(mapping) + mapping[p] = curr_value.subs(symbols).subs(mapping) + range[2].subs(symbols).subs(mapping) map_exhausted = False break else: # set current param to start again and continue - mapping[p] = range[0].subs(mapping) + mapping[p] = range[0].subs(symbols).subs(mapping) else: - if curr_value.subs(mapping) + range[2].subs(mapping) <= range[1].expr.subs(mapping): + if curr_value.subs(symbols).subs(mapping) + range[2].subs(symbols).subs(mapping) <= range[1].expr.subs(symbols).subs(mapping): # update this value and we done - mapping[p] = curr_value.subs(mapping) + range[2].subs(mapping) + mapping[p] = curr_value.subs(symbols).subs(mapping) + range[2].subs(symbols).subs(mapping) map_exhausted = False break else: # set current param to start again and continue - mapping[p] = range[0].subs(mapping) + mapping[p] = range[0].subs(symbols).subs(mapping) return map_exhausted @@ -174,15 +225,15 @@ def map_op_in(state: SDFGState, op_in_map: Dict[str, sp.Expr], entry, mapping, s map_misses = 0 while True: # do analysis of map contents - map_misses += scope_op_in(state, op_in_map, mapping, stack, clt, C, symbols, array_names, decided_branches, + map_misses += scope_misses(state, op_in_map, mapping, stack, clt, C, symbols, array_names, decided_branches, ask_user, entry) - if update_map_iterators(entry.map, mapping): + if update_map_iterators(entry.map, mapping,symbols): break return map_misses -def scope_op_in(state: SDFGState, +def scope_misses(state: SDFGState, op_in_map: Dict[str, sp.Expr], mapping, stack: AccessStack, @@ -225,14 +276,35 @@ def scope_op_in(state: SDFGState, elif isinstance(node, nd.Tasklet): tasklet_misses = 0 # analyze the memory accesses of this tasklet and whether they hit in cache or not - for e in state.in_edges(node) + state.out_edges(node): + for e in state.in_edges(node): + # Check if source node is just a transient node to map to correct cache line + data_node = e.src + data_node_in_edges = state.in_edges(data_node) + if len(data_node_in_edges) == 1 and isinstance(data_node_in_edges[0].src, nd.AccessNode): + e = data_node_in_edges[0] + if e.data.data in clt.array_info or (e.data.data in array_names and array_names[e.data.data] in clt.array_info): line_id = clt.cache_line_id( e.data.data if e.data.data not in array_names else array_names[e.data.data], [x[0].subs(mapping) for x in e.data.subset.ranges], mapping) + line_id = int(line_id.subs(symbols).subs(mapping)) + dist = stack.touch(line_id) + tasklet_misses += 1 if dist >= C or dist == -1 else 0 + + for e in state.out_edges(node): + # Check if destination node is just a transient node to map to correct cache line + data_node = e.dst + data_node_out_edges = state.out_edges(data_node) + if len(data_node_out_edges) == 1 and isinstance(data_node_out_edges[0].src, nd.AccessNode): + e = data_node_out_edges[0] - line_id = int(line_id.subs(mapping)) + if e.data.data in clt.array_info or (e.data.data in array_names + and array_names[e.data.data] in clt.array_info): + line_id = clt.cache_line_id( + e.data.data if e.data.data not in array_names else array_names[e.data.data], + [x[0].subs(mapping) for x in e.data.subset.ranges], mapping) + line_id = int(line_id.subs(symbols).subs(mapping)) dist = stack.touch(line_id) tasklet_misses += 1 if dist >= C or dist == -1 else 0 @@ -259,7 +331,7 @@ def scope_op_in(state: SDFGState, for e in state.out_edges(node): nested_array_names[e.src_conn] = e.data.data # Nested SDFGs are recursively analyzed first. - nsdfg_misses = sdfg_op_in(node.sdfg, op_in_map, mapping, stack, clt, C, nested_syms, nested_array_names, + nsdfg_misses = cfg_misses(node.sdfg, op_in_map, mapping, stack, clt, C, nested_syms, nested_array_names, decided_branches, ask_user) scope_misses += nsdfg_misses @@ -280,8 +352,109 @@ def scope_op_in(state: SDFGState, update_map(op_in_map, get_uuid(state), scope_misses, average=False) return scope_misses - -def sdfg_op_in(sdfg: SDFG, +def cfr_misses(cfr:ControlFlowRegion, + op_in_map: Dict[str, Tuple[sp.Expr, sp.Expr]], + mapping, + stack: AccessStack, + clt: CacheLineTracker, + C, + symbols, + array_names, + decided_branches, + ask_user, + start=None): + region_misses = 0 + if isinstance(cfr, SDFGState): + region_misses = scope_misses(cfr, op_in_map, mapping, stack, clt, C, symbols, array_names, decided_branches, + ask_user, None) + + elif isinstance(cfr, LoopRegion): + loop_var = cfr.loop_variable + loop_condition = pystr_to_symbolic(cfr.loop_condition.as_string) + start = loop_analysis.get_init_assignment(cfr).subs(mapping) + step = sp.sympify(loop_analysis.get_loop_stride(cfr)) + mapping[loop_var] = start.subs(mapping) + region_misses = 0 + while (loop_condition.subs(mapping) == True): + iter_misses = cfg_misses(cfr, op_in_map, mapping, stack, clt, C, symbols, array_names, decided_branches, ask_user, start=cfr.start_block,end=None) + mapping[loop_var] = mapping[loop_var] + step + region_misses += iter_misses + elif isinstance(cfr, ConditionalBlock): + true_branches = [] + possible_branches = [] + else_branch = None + + for cond, branch in cfr.branches: + if cond is None: + else_branch = branch + continue + + sym_cond = pystr_to_symbolic(cond.as_string) + res = sym_cond.subs(mapping) + + if res == True: + true_branches.append(branch) + elif res == False: + continue + else: + possible_branches.append(branch) + + ### if the branch is not decided by a true condition we + # 1- ask the userif he hasn't decided yet + # 2- take the one we took last time if he has decided + # 3- take the worst case if he opted not to decide + possibilities = true_branches + possible_branches + [else_branch] + if not true_branches and len(possible_branches)>0 and ask_user and (cfr not in decided_branches or decided_branches[cfr] not in possibilities): + possibilities = true_branches + possible_branches + ['else_branch'] + if len(possibilities)>1: + print(f'\n\nWhich branch to take at {cfr.name}') + for i in range(len(possibilities)): + print(f'({i}) for branch {possibilities[i] if possibilities[i] else "else_branch"}') + chosen = int(input('Choose an option from above: ')) + # if the user chooses one, we check only that branch + branches = [possibilities[chosen]] + if possibilities[chosen]: + # only store the decided branch if it is not the implicit else branch + decided_branches[cfr] = possibilities[chosen] + else: + branches = possibilities + elif true_branches: + # if we have true branches we take the first one + branches = [true_branches[0]] + else: + # else we check all possibilities and take the max + branches = possibilities + + + max_branch_misses = 0 + mapping_after_cond, stack_after_cond, decided_branches_after_cond = mapping, stack, decided_branches + for branch in branches: + if not branch: + # the implicit else branch has no misses + continue + # copy all data that must not be shared between branches + mapping_copy = deepcopy(mapping) + stack_copy = deepcopy(stack) + symbols_copy = deepcopy(symbols) + decided_branches_copy = deepcopy(decided_branches) + branch_misses = cfg_misses(branch, op_in_map, mapping_copy, stack_copy, clt, C, symbols_copy, array_names, decided_branches_copy, ask_user, branch.start_block, None) + + if branch_misses > max_branch_misses: + max_branch_misses = branch_misses + mapping_after_cond, stack_after_cond, decided_branches_after_cond = mapping_copy, stack_copy, decided_branches_copy + + mapping.update(mapping_after_cond) + stack.replace_self(stack_after_cond) + decided_branches.update(decided_branches_after_cond) + region_misses = max_branch_misses + elif isinstance(cfr, FunctionCallRegion): + region_misses = cfg_misses(cfr, op_in_map,mapping,stack,clt,C,symbols,array_names,decided_branches,ask_user,start=cfr.start_block,end=None) + elif isinstance(cfr, (ReturnBlock, ContinueBlock, BreakBlock)): + region_misses = 0 + + return region_misses + +def cfg_misses(cfg: ControlFlowRegion, op_in_map: Dict[str, Tuple[sp.Expr, sp.Expr]], mapping, stack: AccessStack, @@ -311,35 +484,38 @@ def sdfg_op_in(sdfg: SDFG, :param end: The end state of the SDFG traversal. If None, the whole SDFG is traversed. """ - if start is None: + if isinstance(cfg, SDFG) and start is None: # add this SDFG's arrays to the cache line tracker - for name, arr in sdfg.arrays.items(): + for name, arr in cfg.arrays.items(): if isinstance(arr, Array): if name in array_names: name = array_names[name] clt.add_array(name, arr, mapping) # start traversal at SDFG's start state - curr_state = sdfg.start_state + curr_state = cfg.start_block else: curr_state = start total_misses = 0 - # traverse this SDFG's states + # traverse this SDFG's ControlFlowRegions while True: - total_misses += scope_op_in(curr_state, op_in_map, mapping, stack, clt, C, symbols, array_names, - decided_branches, ask_user) - - if len(sdfg.out_edges(curr_state)) == 0: + + region_misses = cfr_misses(curr_state, op_in_map, mapping, stack, clt, C, symbols, array_names, + decided_branches, ask_user) + + total_misses += region_misses + if len(cfg.out_edges(curr_state)) == 0: # we reached an end state --> stop break else: # take first edge with True condition found = False - for e in sdfg.out_edges(curr_state): + for e in cfg.out_edges(curr_state): if e.data.is_unconditional() or e.data.condition_sympy().subs(mapping) == True: # save e's assignments in mapping and update curr_state # replace values first with mapping, then update mapping try: + total_misses += assignment_misses(e, mapping, stack, clt, C, symbols, array_names) update_mapping(mapping, e) except: print('\nWARNING: Uncommon assignment detected on InterstateEdge (e.g. bitwise operators).' @@ -351,7 +527,7 @@ def sdfg_op_in(sdfg: SDFG, if not found: # We need to check if we are in an implicit end state (i.e. all outgoing edge conditions evaluate to False) all_false = True - for e in sdfg.out_edges(curr_state): + for e in cfg.out_edges(curr_state): if e.data.condition_sympy().subs(mapping) != False: all_false = False if all_false: @@ -360,27 +536,28 @@ def sdfg_op_in(sdfg: SDFG, if curr_state in decided_branches: # if the user already decided this branch in a previous iteration, take the same branch again. e = decided_branches[curr_state] - + total_misses += assignment_misses(e, mapping, stack, clt, C, symbols, array_names) update_mapping(mapping, e) curr_state = e.dst else: # we cannot determine which branch to take --> check if both contain work - merge_state = find_merge_state(sdfg, curr_state) + merge_state = find_merge_state(cfg, curr_state) next_edge_candidates = [] - for e in sdfg.out_edges(curr_state): - states = find_states_between(sdfg, e.dst, merge_state) + for e in cfg.out_edges(curr_state): + states = find_states_between(cfg, e.dst, merge_state) curr_work = mem_accesses_on_path(states) if sp.sympify(curr_work).subs(mapping) > 0: next_edge_candidates.append(e) if len(next_edge_candidates) == 1: e = next_edge_candidates[0] + total_misses += assignment_misses(e, mapping, stack, clt, C, symbols, array_names) update_mapping(mapping, e) decided_branches[curr_state] = e curr_state = e.dst else: if ask_user: - edges = sdfg.out_edges(curr_state) + edges = cfg.out_edges(curr_state) print(f'\n\nWhich branch to take at {curr_state.name}') for i in range(len(edges)): print(f'({i}) for edge to state {edges[i].dst.name}') @@ -388,16 +565,19 @@ def sdfg_op_in(sdfg: SDFG, print('merge state is named ', merge_state) chosen = int(input('Choose an option from above: ')) e = edges[chosen] + total_misses += assignment_misses(e, mapping, stack, clt, C, symbols, array_names) update_mapping(mapping, e) decided_branches[curr_state] = e curr_state = e.dst print(2 * '\n') else: final_e = next_edge_candidates.pop() + final_misses = 0 for e in next_edge_candidates: # copy the state of the analysis curr_mapping = dict(mapping) + curr_misses += assignment_misses(e, mapping, stack, clt, C, symbols, array_names) update_mapping(curr_mapping, e) curr_stack = stack.copy() curr_clt = clt.copy() @@ -406,17 +586,18 @@ def sdfg_op_in(sdfg: SDFG, curr_state = e.dst # walk down this branch until merge_state - sdfg_op_in(sdfg, op_in_map, curr_mapping, curr_stack, curr_clt, C, curr_symbols, + cfg_misses(cfg, op_in_map, curr_mapping, curr_stack, curr_clt, C, curr_symbols, curr_array_names, decided_branches, ask_user, curr_state, merge_state) update_mapping(mapping, final_e) curr_state = final_e.dst if curr_state == end: break - + if end is None: # only update if we were actually analyzing a whole sdfg (not just start to end state) - update_map(op_in_map, get_uuid(sdfg), total_misses, average=False) + update_map(op_in_map, get_uuid(cfg), total_misses, average=False) + return total_misses @@ -427,7 +608,7 @@ def analyze_sdfg_op_in(sdfg: SDFG, assumptions, generate_plots=False, stringify=False, - test_set_size=3, + test_set_size=1, ask_user=False): """ Computes the operational intensity of the input SDFG. @@ -448,7 +629,7 @@ def analyze_sdfg_op_in(sdfg: SDFG, # from now on we take C as the number of lines that fit into cache C = C // L - + sdfg = deepcopy(sdfg) # apply SSA pass pipeline = FixedPointPipeline([StrictSymbolSSA()]) @@ -463,9 +644,10 @@ def analyze_sdfg_op_in(sdfg: SDFG, elif isinstance(assumptions[sym], str): range_symbol[sym] = SymbolRange(int(x) for x in assumptions[sym].split(',')) del assumptions[sym] - work_map = {} + assumptions_list = [f'{x}=={y}' for x, y in assumptions.items()] + analyze_sdfg(sdfg, work_map, get_tasklet_work, assumptions_list) if len(undefined_symbols) > 0: @@ -480,12 +662,17 @@ def analyze_sdfg_op_in(sdfg: SDFG, # all symbols are concretized --> run normal op_in analysis with concretized symbols sdfg.specialize(assumptions) mapping = {} + # add the static symbols to the map to allow for better analysis + static_symbols = get_static_symbols(sdfg) + mapping.update(static_symbols) + mapping.update(assumptions) + mapping = {k: subs_till_fixed_point(v, mapping) for k,v in mapping.items()} stack = AccessStack(C) clt = CacheLineTracker(L) - sdfg_op_in(sdfg, op_in_map, mapping, stack, clt, C, {}, {}, {}, ask_user) + total_misses = cfg_misses(sdfg, op_in_map, mapping, stack, clt, C, {}, {}, {}, ask_user) # compute bytes for k, v in op_in_map.items(): op_in_map[k] = v[0] / v[1] * L @@ -500,11 +687,12 @@ def analyze_sdfg_op_in(sdfg: SDFG, while True: new_val = False for sym, r in range_symbol.items(): + val = r.next() if val > -1: new_val = True assumptions[sym] = val - elif t < 3: + elif t < test_set_size: # now we sample test set t += 1 assumptions[sym] = r.max_value() + t * 3 @@ -512,12 +700,19 @@ def analyze_sdfg_op_in(sdfg: SDFG, if not new_val: break + r_sdfg = deepcopy(sdfg) + curr_op_in_map = {} mapping = {} + # add the static symbols to the map to allow for better analysis + static_symbols = get_static_symbols(r_sdfg) + mapping.update(static_symbols) mapping.update(assumptions) + mapping = {k: subs_till_fixed_point(v, mapping) for k,v in mapping.items()} + stack = AccessStack(C) clt = CacheLineTracker(L) - sdfg_op_in(sdfg, curr_op_in_map, mapping, stack, clt, C, {}, {}, {}, ask_user) + cfg_misses(r_sdfg, curr_op_in_map, mapping, stack, clt, C, {}, {}, {}, ask_user) # compute average cache misses for k, v in curr_op_in_map.items(): @@ -526,7 +721,7 @@ def analyze_sdfg_op_in(sdfg: SDFG, # save cache misses curr_cache_misses = dict(curr_op_in_map) - work_measurements.append(work_map[get_uuid(sdfg)][0].subs(assumptions)) + work_measurements.append(work_map[get_uuid(sdfg)].subs(assumptions)) # put curr values in cache_miss_measurements for k, v in curr_cache_misses.items(): if k in cache_miss_measurements: @@ -540,6 +735,7 @@ def analyze_sdfg_op_in(sdfg: SDFG, sympy_fs = {} for k, v in cache_miss_measurements.items(): + final_f, sympy_f, r_s = fit_curve(x_values[:-test_set_size], v[:-test_set_size], symbol_name) op_in_map[k] = sp.simplify(sympy_f * L) sympy_fs[k] = sympy_f @@ -562,6 +758,7 @@ def analyze_sdfg_op_in(sdfg: SDFG, if stringify: for k, v in op_in_map.items(): op_in_map[k] = str(v) + return op_in_map[get_uuid(sdfg)] ################################################################################ @@ -599,7 +796,6 @@ def main() -> None: assumptions[a] = int(b) else: assumptions[a] = b - print(assumptions) analyze_sdfg_op_in(sdfg, op_in_map, int(args.C), int(args.L), assumptions) result_whole_sdfg = op_in_map[get_uuid(sdfg)] diff --git a/dace/sdfg/performance_evaluation/work_depth.py b/dace/sdfg/performance_evaluation/work_depth.py index ecb447ba9d..288ee187e5 100644 --- a/dace/sdfg/performance_evaluation/work_depth.py +++ b/dace/sdfg/performance_evaluation/work_depth.py @@ -11,19 +11,23 @@ import os import sympy as sp from copy import deepcopy -from dace.libraries.blas import MatMul +from dace.libraries.blas import MatMul, Dot, Gemm, Gemv from dace.libraries.standard import Reduce, Transpose from dace.symbolic import pystr_to_symbolic import ast import astunparse import warnings +import re -from dace.sdfg.performance_evaluation.helpers import LoopExtractionError, get_uuid, find_loop_guards_tails_exits +from dace.sdfg.performance_evaluation.helpers import LoopExtractionError, get_uuid, find_loop_guards_tails_exits, get_legacy_loop_body, get_legacy_loop_ranges, get_static_symbols, subs_till_fixed_point from dace.sdfg.performance_evaluation.assumptions import parse_assumptions from dace.transformation.passes.symbol_ssa import StrictSymbolSSA from dace.transformation.pass_pipeline import FixedPointPipeline +from dace.transformation.passes.analysis import loop_analysis +from dace.sdfg.state import AbstractControlFlowRegion, ControlFlowRegion, LoopRegion, ConditionalBlock, ReturnBlock, ContinueBlock, BreakBlock +math_funcs = set() def get_array_size_symbols(sdfg): """ Returns all symbols that appear isolated in shapes of the SDFG's arrays. @@ -41,7 +45,6 @@ def get_array_size_symbols(sdfg): symbols.add(s) return symbols - def symeval(val, symbols): """ Takes a sympy expression and substitutes its symbols according to a dict { old_symbol: new_symbol}. @@ -53,8 +56,14 @@ def symeval(val, symbols): second_replacement = {pystr_to_symbolic('__REPLSYM_' + k): v for k, v in symbols.items()} return sp.simplify(val.subs(first_replacement).subs(second_replacement)) - def evaluate_symbols(base, new): + """Takes a base symbol mapping and a new one and adapts the new one to match the base one for symbols contained in it + + :param base: The base mapping + :param new: The mapping that gets adjusted + :return result: A new mapping that contains all mappings from new, but adjusted to transitively match to the mapping of base + """ + result = {} for k, v in new.items(): result[k] = symeval(pystr_to_symbolic(v), base) @@ -82,7 +91,7 @@ def count_depth_matmul(node, symbols, state): # optimal depth of a matrix multiplication is O(log(size of shared dimension)): A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') size_shared_dimension = symeval(A_memlet.data.subset.size()[-1], symbols) - return sp.log(size_shared_dimension) + return sp.Max(1, sp.log(sp.Max(1, size_shared_dimension), 2)) def count_work_reduce(node, symbols, state): @@ -99,22 +108,180 @@ def count_work_reduce(node, symbols, state): result = 0 return sp.sympify(result) - def count_depth_reduce(node, symbols, state): # optimal depth of reduction is log of the work - return sp.log(count_work_reduce(node, symbols, state)) + return sp.log(sp.Max(1, count_work_reduce(node, symbols, state)), 2) + +def count_work_dot(node, symbols, state): + X_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_x') + Y_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_y') + RES_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_result') + result = 2*symeval(X_memlet.data.subset.size()[-1], symbols)-1 + return sp.sympify(result) + +def count_depth_dot(node, symbols, state): + X_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_x') + Y_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_y') + RES_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_result') + # optimal depth for dot product is 1 for multiplications and logarithmic for additions + result = 1+sp.log(sp.Max(1, symeval(X_memlet.data.subset.size()[-1], symbols)), 2) + return sp.sympify(result) + +def count_work_gemm(node, symbols, state): + """ + Count work for GEMM operation: C = alpha * A @ B + beta * C + Work includes: + - Matrix multiplication: 2*M*N*K (multiply + add per element) + - Alpha scaling: M*N (if alpha != 1) + - Beta scaling + addition: 2*M*N (if beta != 0) + """ + A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') + B_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_b') + C_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_c') + + # Get dimensions + # Handle batch dimension if present + if len(C_memlet.data.subset) == 3: + batch = symeval(C_memlet.data.subset.size()[0], symbols) + M = symeval(C_memlet.data.subset.size()[1], symbols) + N = symeval(C_memlet.data.subset.size()[2], symbols) + else: + batch = 1 + M = symeval(C_memlet.data.subset.size()[-2], symbols) if len(C_memlet.data.subset.size()) >= 2 else 1 + N = symeval(C_memlet.data.subset.size()[-1], symbols) + + K = symeval(A_memlet.data.subset.size()[-1], symbols) + + # Core matrix multiplication: 2*M*N*K (multiply + add) + result = 2 * batch * M * N * K + + # Add work for alpha scaling if alpha != 1 + alpha = getattr(node, 'alpha', 1) + if alpha != 1: + result += batch * M * N # M*N multiplications by alpha + + # Add work for beta * C if beta != 0 + beta = getattr(node, 'beta', 0) + if beta != 0: + result += batch * M * N # M*N multiplications by beta + result += batch * M * N # M*N additions + + return sp.sympify(result) + + +def count_depth_gemm(node, symbols, state): + """ + Optimal depth for GEMM: log(K) for the reduction + constant for scaling/addition + """ + A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') + K = symeval(A_memlet.data.subset.size()[-1], symbols) + + # Depth is dominated by the reduction over K dimension + depth = sp.log(sp.Max(1, K), 2) + + # Add constant depth for alpha and beta operations + alpha = getattr(node, 'alpha', 1) + beta = getattr(node, 'beta', 0) + + if alpha != 1: + depth += 1 # One multiplication layer + if beta != 0: + depth += 2 # One multiplication + one addition layer + + return depth + + +def count_work_gemv(node, symbols, state): + """ + Count work for GEMV operation: y = alpha * A @ x + beta * y + Two variants: + - GEMV: y = alpha * A @ x + beta * y (A is MxN, x is N, y is M) + - GEMVT: y = alpha * A^T @ x + beta * y (A is MxN, x is M, y is N) + + Work includes: + - Matrix-vector multiplication: 2*M*N (multiply + add per element) + - Alpha scaling: M (if alpha != 1) + - Beta scaling + addition: 2*M (if beta != 0) + """ + A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_A') + x_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_x') + y_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_y') + + # Get dimensions from A matrix + A_shape = A_memlet.data.subset.size() + M = symeval(A_shape[-2], symbols) + N = symeval(A_shape[-1], symbols) + + # Check if transpose (GEMVT) + trans = getattr(node, 'transA', False) + + # Output size + output_size = N if trans else M + + # Core matrix-vector multiplication: 2*M*N (each output element needs N multiplies and N-1 adds) + result = 2 * M * N + + # Add work for alpha scaling if alpha != 1 + alpha = getattr(node, 'alpha', 1) + if alpha != 1: + result += output_size # output_size multiplications by alpha + + # Add work for beta * y if beta != 0 + beta = getattr(node, 'beta', 0) + if beta != 0: + result += output_size # output_size multiplications by beta + result += output_size # output_size additions + + return sp.sympify(result) + + +def count_depth_gemv(node, symbols, state): + """ + Optimal depth for GEMV: log(N) for the reduction + constant for scaling/addition + where N is the reduction dimension + """ + A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_A') + A_shape = A_memlet.data.subset.size() + M = symeval(A_shape[-2], symbols) + N = symeval(A_shape[-1], symbols) + + # Check if transpose + trans = getattr(node, 'transA', False) + + # Reduction dimension + reduction_dim = M if trans else N + + # Depth is dominated by the reduction + depth = sp.log(sp.Max(1, reduction_dim), 2) + + # Add constant depth for alpha and beta operations + alpha = getattr(node, 'alpha', 1) + beta = getattr(node, 'beta', 0) + + if alpha != 1: + depth += 1 # One multiplication layer + if beta != 0: + depth += 2 # One multiplication + one addition layer + + return sp.Max(1, depth) LIBNODES_TO_WORK = { MatMul: count_work_matmul, + Gemm: count_work_gemm, + Gemv: count_work_gemv, Transpose: lambda *args: 0, Reduce: count_work_reduce, + Dot: count_work_dot, } LIBNODES_TO_DEPTH = { MatMul: count_depth_matmul, + Gemm: count_depth_gemm, + Gemv: count_depth_gemv, Transpose: lambda *args: 0, Reduce: count_depth_reduce, + Dot: count_depth_dot, } PYFUNC_TO_ARITHMETICS = { @@ -437,19 +604,19 @@ def do_initial_subs(w, d, eq, subs1): return result -def sdfg_work_depth(sdfg: SDFG, - w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], - analyze_tasklet: Callable[[nd.Tasklet, SDFGState], Tuple[sp.Expr, sp.Expr]], - symbols: Dict[str, str], - equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], - subs1: Dict[str, sp.Expr], - detailed_analysis: bool = False) -> Tuple[sp.Expr, sp.Expr]: +def control_flow_region_work_depth(cfr: ControlFlowRegion, + w_d_map: Dict[str, Tuple[sp.Expr|List[Tuple[sp.Expr, sp.Expr]], sp.Expr|List[Tuple[sp.Expr, sp.Expr]]]], + analyze_tasklet: Callable[[nd.Tasklet, SDFGState], Tuple[sp.Expr, sp.Expr]], + symbols: Dict[str, str], + equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], + subs1: Dict[str, sp.Expr], + detailed_analysis: bool = False) -> Tuple[sp.Expr|List[Tuple[sp.Expr, sp.Expr]], sp.Expr|List[Tuple[sp.Expr, sp.Expr]]]: """ - Analyze the work and depth of a given SDFG. + Analyze the work and depth of a given (structured) ControlFLowRegion. First we determine the work and depth of each state. Then we break loops in the state machine, such that we get a DAG. Lastly, we compute the path with most work and the path with the most depth in order to get the total work depth. - :param sdfg: The SDFG to analyze. + :param cfr: The ControlFLowRegion to analyze. :param w_d_map: Dictionary which will save the result. :param analyze_tasklet: Function used to analyze tasklet nodes. :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. @@ -460,29 +627,155 @@ def sdfg_work_depth(sdfg: SDFG, :param subs1: First substitution dict for greater/lesser assumptions. :return: A tuple containing the work and depth of the SDFG. """ + # First determine the work and depth of each ControlFlowRegion individually. + # Keep track of the work and depth for each state in a dictionary + region_depths: Dict[AbstractControlFlowRegion, sp.Expr] = {} + region_works: Dict[AbstractControlFlowRegion, sp.Expr] = {} + for region in cfr.nodes(): + if isinstance(region, SDFGState): + #rename variable to make code more readable + state = region + state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, + detailed_analysis) + # Substitutions for state_work and state_depth already performed, but state.executions needs to be subs'd now. + state_work = sp.simplify(state_work.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + state_depth = sp.simplify(state_depth.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + + region_works[state], region_depths[state] = state_work, state_depth + w_d_map[get_uuid(state)] = (region_works[state], region_depths[state]) + elif isinstance(region, LoopRegion): + #rename variable to make code more readable + loop = region + fallback = False + + # We try to get a closed form solution for the work and depth of the loop. If this is not possible (e.g. because there are no static loop bounds), + # we fall back to just multiplying the work and depth of one loop iteration with the number of loop iterations. This can lead to incorrect results, + # since it does not take into account that work and depth of one loop iteration might be dependent on the loop variable. + loop_var = sp.Symbol(loop.loop_variable) + lower_bound = loop_analysis.get_init_assignment(loop) + upper_bound = loop_analysis.get_loop_end(loop) + step = sp.sympify(loop_analysis.get_loop_stride(loop)) + if any(v is None for v in (loop_var, lower_bound, upper_bound)): + print("Because the loop does not provide a loop variable and static bounds, we fell back to just using its number of iterations. Mind that this can affect the correctness of the expression ") + fallback = True + executions = loop.start_block.executions + executions = executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) + + # Recursively get the work and depth of the loop body + loop_work, loop_depth = control_flow_region_work_depth(loop, w_d_map, analyze_tasklet, symbols, + equality_subs, subs1, detailed_analysis) + + if not fallback: + # If static loop bounds are available, we can write the work and depth of the loop as a summation over the loop variable from the lower to the upper bound. + + # to ensure that the summation works properly, we need to make sure that the symbol that is used as loop varaible + # is the same as the ones used in the inner expression + for var in loop_work.free_symbols: + if var.name == loop_var.name: + loop_var = var + + #TEMPORARY FIX: with library nodes it can happen that we get two symbols (with the same name) + for var in loop_work.free_symbols: + if var.name == loop_var.name and not var == loop_var: + loop_work = loop_work.subs({var: loop_var}) + loop_depth = loop_depth.subs({var: loop_var}) + + # prepare loop bounds such that we can write the work as a nice summation from 0 to an upper bound + loop_var = loop_var.subs(subs1) + shifted_hi = (upper_bound-lower_bound)//step + shifted_hi = shifted_hi.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) + lower_bound = lower_bound.subs(subs1) if step.evalf()>0 else upper_bound.subs(subs1) + shifted_lo = sp.sympify(0) + step:sp.Expr = sp.Abs(step) + + loop_work = loop_work.subs({loop_var: (step*loop_var+lower_bound)}) + loop_depth = loop_depth.subs({loop_var: (step*loop_var+lower_bound)}) + + # write the work and depth of the loop as a sum of the work of one iteration over the number of loop iterations + # (we have cannot use a simple multiplication as work and depth of one loop iteration might be dependent on the loop variable) + loop_work = sp.Sum(loop_work, (loop_var, shifted_lo, shifted_hi)).doit() + loop_depth = sp.Sum(loop_depth, (loop_var, shifted_lo, shifted_hi)).doit() + + # Do equality subs + loop_work = sp.simplify(loop_work.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + loop_depth = sp.simplify(loop_depth.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + else: + loop_work = sp.simplify(loop_work.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + loop_depth = sp.simplify(loop_depth.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) - # First determine the work and depth of each state individually. - # Keep track of the work and depth for each state in a dictionary, where work and depth are multiplied by the number - # of times the state will be executed. - state_depths: Dict[SDFGState, sp.Expr] = {} - state_works: Dict[SDFGState, sp.Expr] = {} - for state in sdfg.nodes(): - state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, - detailed_analysis) - - # Substitutions for state_work and state_depth already performed, but state.executions needs to be subs'd now. - state_work = sp.simplify( - state_work.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) * - state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) - state_depth = sp.simplify( - state_depth.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) * - state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) - - state_works[state], state_depths[state] = state_work, state_depth - w_d_map[get_uuid(state)] = (state_works[state], state_depths[state]) + if executions != 0: + loop_work = loop_work*executions + loop_depth = loop_depth*executions + else: + exec_symbol = sp.Symbol(f'num_execs_{region.sdfg.cfg_id}_{region.sdfg.node_id(region)}', nonnegative=True) + loop_work = loop_work*exec_symbol + loop_depth = loop_depth*exec_symbol + + region_works[loop], region_depths[loop] = loop_work, loop_depth + w_d_map[get_uuid(loop)] = (region_works[loop], region_depths[loop]) + + elif isinstance(region, ConditionalBlock): + branch_conditions = {} + branch_works = {} + branch_depths = {} + for (condition, branch) in region.branches: + branch_conditions[branch] = ( + pystr_to_symbolic(condition.as_string) + if condition is not None else sp.sympify(True) + ) + branch_works[branch], branch_depths[branch] = control_flow_region_work_depth( + branch, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, detailed_analysis + ) + + if analyze_tasklet == get_tasklet_avg_par: + # For avg_par we want the branch minimising W/D (worst case parallelism). + # Build a Piecewise that selects work and depth together based on which + # branch has the smallest W/D ratio, so the pair stays consistent. + branches = list(branch_works.keys()) + + def avg_par_expr_val(w: sp.Expr, d: sp.Expr) -> sp.Expr: + """Return W/D, treating D=0 as infinite parallelism (never worst-case).""" + return sp.Piecewise((w / d, d > 0), (sp.oo, True)) + + # Start with the first branch as the running minimum + best_work = branch_works[branches[0]] + best_depth = branch_depths[branches[0]] + + for b in branches[1:]: + w_b = branch_works[b] + d_b = branch_depths[b] + ap_best = avg_par_expr_val(best_work, best_depth) + ap_b = avg_par_expr_val(w_b, d_b) + is_worse = sp.simplify(ap_b < ap_best) # lower ratio = worse parallelism + best_work = sp.Piecewise((w_b, is_worse), (best_work, True)) + best_depth = sp.Piecewise((d_b, is_worse), (best_depth, True)) + + region_works[region] = best_work + region_depths[region] = best_depth + elif not detailed_analysis: + region_works[region] = sp.Max(*branch_works.values()) + region_depths[region] = sp.Max(*branch_depths.values()) + else: + work_condition = list(zip(branch_works.values(), branch_conditions.values())) + depth_condition = list(zip(branch_depths.values(), branch_conditions.values())) + region_works[region] = sp.Piecewise(*work_condition) + region_depths[region] = sp.Piecewise(*depth_condition) + w_d_map[get_uuid(region)] = (region_works[region], region_depths[region]) + + elif isinstance(region, (ReturnBlock, ContinueBlock, BreakBlock)): + region_works[region], region_depths[region] = (sp.sympify(0), sp.sympify(0)) + w_d_map[get_uuid(region)] = (sp.sympify(0), sp.sympify(0)) + else: + function_work, function_depth = control_flow_region_work_depth(region, w_d_map, analyze_tasklet, symbols, + equality_subs, subs1, detailed_analysis) + function_work = sp.simplify(function_work.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + function_depth = sp.simplify(function_depth.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + + region_works[region], region_depths[region] = function_work, function_depth + w_d_map[get_uuid(region)] = (region_works[region], region_depths[region]) edge_w_d_map: Dict[Tuple[str, str], Tuple[sp.Expr, sp.Expr]] = {} - for isedge in sdfg.edges(): + for isedge in cfr.edges(): edge_work, edge_depth = sp.sympify(0), sp.sympify(0) if isedge.data.assignments: for v in isedge.data.assignments.values(): @@ -494,16 +787,18 @@ def sdfg_work_depth(sdfg: SDFG, # the guard, and instead places an edge between the last loop state and the exit state. # This transforms the state machine into a DAG. Hence, we can find the "heaviest" and "deepest" paths in linear time. # Additionally, construct a dummy exit state and connect every state that has no outgoing edges to it. - - # identify all loops in the SDFG + + + # ================================= Handling of Legacy Loops ================================= + # identify legacy loops (i.e. loops that aren't LoopRegions) in the CFR try: - nodes_oNodes_exits = find_loop_guards_tails_exits(sdfg._nx) + nodes_oNodes_exits = find_loop_guards_tails_exits(cfr._nx) except LoopExtractionError: # If loop detection fails, we cannot make proper propagation. print('Analysis failed since not all loops got detected. It may help to use more structured loop constructs.' + ' The analysis per state remains correct, but no SDFG-wide analysis can be performed.') sdfg_result = (sp.oo, sp.oo) - w_d_map[get_uuid(sdfg)] = sdfg_result + w_d_map[get_uuid(cfr)] = sdfg_result for k, (v_w, v_d) in w_d_map.items(): # The symeval replaces nested SDFG symbols with their global counterparts. @@ -511,6 +806,69 @@ def sdfg_work_depth(sdfg: SDFG, v_d = symeval(v_d, symbols) w_d_map[k] = (v_w, v_d) return sdfg_result + + legacy_loop_ranges = get_legacy_loop_ranges(cfr) + + for guard, tail, exits in nodes_oNodes_exits: + + if guard not in legacy_loop_ranges: + # propagate_states didn't recognise this as an annotated for-loop + # (e.g. while-loop or dynamic unbounded). Fall back to the + # execution-count symbol approach as before. + iter_sym = sp.Symbol( + f'num_execs_{guard.parent.cfg_id}_{guard.parent.node_id(guard)}', + nonnegative=True + ) + loop_body_nodes = get_legacy_loop_body(cfr, guard, tail, exits) + for node in loop_body_nodes: + if node is guard: + continue + node_uuid = get_uuid(node) + if node_uuid not in w_d_map: + continue + node_work, node_depth = w_d_map[node_uuid] + region_works[node] = sp.simplify(node_work * iter_sym) + region_depths[node] = sp.simplify(node_depth * iter_sym) + w_d_map[node_uuid] = (region_works[node], region_depths[node]) + continue + + print(f'Handling legacy loop with guard {guard} and tail {tail} and loops bounds {legacy_loop_ranges[guard]}') + loop_var, start, stop, stride = legacy_loop_ranges[guard] + loop_body_nodes = get_legacy_loop_body(cfr, guard, tail, exits) + + loop_var = loop_var.subs(subs1) + # Shift loop variable to start from 0 with step 1 + shifted_hi = (stop - start) // stride + shifted_hi = shifted_hi.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) + + lower_bound = start.subs(subs1) if stride.evalf() > 0 else stop.subs(subs1) + abs_stride = sp.Abs(stride) + + for node in loop_body_nodes: + if node is guard: + continue + node_uuid = get_uuid(node) + if node_uuid not in w_d_map: + continue + node_work = w_d_map[node_uuid][0] + node_depth = w_d_map[node_uuid][1] + # Substitute loop_var -> (abs_stride * loop_var + lower_bound), matching + # the LoopRegion branch exactly. + node_work = node_work.subs( {loop_var: abs_stride * loop_var + lower_bound}) + node_depth = node_depth.subs({loop_var: abs_stride * loop_var + lower_bound}) + + node_work = sp.Sum(node_work, (loop_var, sp.sympify(0), shifted_hi)).doit() + node_depth = sp.Sum(node_depth, (loop_var, sp.sympify(0), shifted_hi)).doit() + + node_work = sp.simplify(node_work.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + node_depth = sp.simplify(node_depth.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + + # Write the summed result back into ONLY the guard, so the downstream + # BFS traversal picks up one single node representing the whole loop. + region_works[node] = node_work + region_depths[node] = node_depth + w_d_map[get_uuid(node)] = (node_work, node_depth) + # Now we need to go over each triple (node, oNode, exits). For each triple, we # - remove edge (oNode, node), i.e. the backward edge @@ -519,39 +877,42 @@ def sdfg_work_depth(sdfg: SDFG, # - This ensures that every node with > 1 outgoing edge is a branch guard # - useful for detailed anaylsis. for node, oNode, exits in nodes_oNodes_exits: - sdfg.remove_edge(sdfg.edges_between(oNode, node)[0]) + cfr.remove_edge(cfr.edges_between(oNode, node)[0]) for e in exits: - if len(sdfg.edges_between(oNode, e)) == 0: + if len(cfr.edges_between(oNode, e)) == 0: # no edge there yet - sdfg.add_edge(oNode, e, InterstateEdge()) - if len(sdfg.edges_between(node, e)) > 0: + cfr.add_edge(oNode, e, InterstateEdge()) + if len(cfr.edges_between(node, e)) > 0: # edge present --> remove it - sdfg.remove_edge(sdfg.edges_between(node, e)[0]) + cfr.remove_edge(cfr.edges_between(node, e)[0]) + # ================================= End of Legacy Loop Handling ================================= + # add a dummy exit to the SDFG, such that each path ends there. - dummy_exit = sdfg.add_state('dummy_exit') - for state in sdfg.nodes(): - if len(sdfg.out_edges(state)) == 0 and state is not dummy_exit: - sdfg.add_edge(state, dummy_exit, InterstateEdge()) + dummy_exit = cfr.add_state('dummy_exit') + for region in cfr.nodes(): + if len(cfr.out_edges(region)) == 0 and region is not dummy_exit: + cfr.add_edge(region, dummy_exit, InterstateEdge()) # These two dicts save the current length of the "heaviest", resp. "deepest", paths at each state. - work_map: Dict[SDFGState, sp.Expr] = {} - depth_map: Dict[SDFGState, sp.Expr] = {} + work_map: Dict[AbstractControlFlowRegion, sp.Expr] = {} + depth_map: Dict[AbstractControlFlowRegion, sp.Expr] = {} # Keeps track of assignments done on InterstateEdges. - state_value_map: Dict[SDFGState, Dict[sp.Symbol, sp.Symbol]] = {} + region_value_map: Dict[AbstractControlFlowRegion, Dict[sp.Symbol, sp.Symbol]] = {} # The dummy state has 0 work and depth. - state_depths[dummy_exit] = sp.sympify(0) - state_works[dummy_exit] = sp.sympify(0) + region_depths[dummy_exit] = sp.sympify(0) + region_works[dummy_exit] = sp.sympify(0) # Perform a BFS traversal of the state machine and calculate the maximum work / depth at each state. Only advance to # the next state in the BFS if all incoming edges have been visited, to ensure the maximum work / depth expressions # have been calculated. traversal_q = deque() - traversal_q.append((sdfg.start_state, sp.sympify(0), sp.sympify(0), None, [], [], {})) + traversal_q.append((cfr.start_block, sp.sympify(0), sp.sympify(0), None, [], [], {})) visited = set() - + c = 0 while traversal_q: - state, depth, work, ie, condition_stack, common_subexpr_stack, value_map = traversal_q.popleft() + c += 1 + region, depth, work, ie, condition_stack, common_subexpr_stack, value_map = traversal_q.popleft() if ie is not None: visited.add(ie) @@ -560,43 +921,43 @@ def sdfg_work_depth(sdfg: SDFG, work += edge_w_d_map[edge_uid][0] depth += edge_w_d_map[edge_uid][1] - if state in state_value_map: + if region in region_value_map: # update value map: - update_value_map(state_value_map[state], value_map) + update_value_map(region_value_map[region], value_map) else: - state_value_map[state] = value_map + region_value_map[region] = value_map - value_map = {pystr_to_symbolic(k): pystr_to_symbolic(v) for k, v in state_value_map[state].items()} - n_depth = sp.simplify((depth + state_depths[state]).subs(value_map)) - n_work = sp.simplify((work + state_works[state]).subs(value_map)) + value_map = {pystr_to_symbolic(k): pystr_to_symbolic(v) for k, v in region_value_map[region].items()} + n_depth = sp.simplify((depth + region_depths[region]).subs(value_map)) + n_work = sp.simplify((work + region_works[region]).subs(value_map)) # If we are analysing average parallelism, we don't search "heaviest" and "deepest" paths separately, but we want one # single path with the least average parallelsim (of all paths with more than 0 work). if analyze_tasklet == get_tasklet_avg_par: - if state in depth_map: # this means we have already visited this state before + if region in depth_map: # this means we have already visited this region before cse = common_subexpr_stack.pop() # if current path has 0 depth (--> 0 work as well), we don't do anything. if n_depth != 0: # check if we need to update the work and depth of the current state # we update if avg parallelism of new incoming path is less than current avg parallelism - if depth_map[state] == 0: + if depth_map[region] == 0: # old value was divided by zero --> we take new value anyway - depth_map[state] = cse[1] + n_depth - work_map[state] = cse[0] + n_work + depth_map[region] = cse[1] + n_depth + work_map[region] = cse[0] + n_work else: - old_avg_par = (cse[0] + work_map[state]) / (cse[1] + depth_map[state]) + old_avg_par = (cse[0] + work_map[region]) / (cse[1] + depth_map[region]) new_avg_par = (cse[0] + n_work) / (cse[1] + n_depth) # we take either old work/depth or new work/depth (or both if we cannot determine which one is greater) - depth_map[state] = cse[1] + sp.Piecewise((n_depth, sp.simplify(new_avg_par < old_avg_par)), - (depth_map[state], True)) - work_map[state] = cse[0] + sp.Piecewise((n_work, sp.simplify(new_avg_par < old_avg_par)), - (work_map[state], True)) + depth_map[region] = cse[1] + sp.Piecewise((n_depth, sp.simplify(new_avg_par < old_avg_par)), + (depth_map[region], True)) + work_map[region] = cse[0] + sp.Piecewise((n_work, sp.simplify(new_avg_par < old_avg_par)), + (work_map[region], True)) else: - depth_map[state] = n_depth - work_map[state] = n_work + depth_map[region] = n_depth + work_map[region] = n_work else: # search heaviest and deepest path separately - if state in depth_map: # and consequently also in work_map + if region in depth_map: # and consequently also in work_map # This cse value would appear in both arguments of the Max. Hence, for performance reasons, # we pull it out of the Max expression. # Example: We do cse + Max(a, b) instead of Max(cse + a, cse + b). @@ -606,18 +967,18 @@ def sdfg_work_depth(sdfg: SDFG, if detailed_analysis: # This MAX should be covered in the more detailed analysis cond = condition_stack.pop() - work_map[state] = cse[0] + sp.Piecewise((work_map[state], sp.Not(cond)), (n_work, cond)) - depth_map[state] = cse[1] + sp.Piecewise((depth_map[state], sp.Not(cond)), (n_depth, cond)) + work_map[region] = cse[0] + sp.Piecewise((work_map[region], sp.Not(cond)), (n_work, cond)) + depth_map[region] = cse[1] + sp.Piecewise((depth_map[region], sp.Not(cond)), (n_depth, cond)) else: - work_map[state] = cse[0] + sp.Max(work_map[state], n_work) - depth_map[state] = cse[1] + sp.Max(depth_map[state], n_depth) + work_map[region] = cse[0] + sp.Max(work_map[region], n_work) + depth_map[region] = cse[1] + sp.Max(depth_map[region], n_depth) else: - depth_map[state] = n_depth - work_map[state] = n_work + depth_map[region] = n_depth + work_map[region] = n_work - out_edges = sdfg.out_edges(state) + out_edges = cfr.out_edges(region) # only advance after all incoming edges were visited (meaning that current work depth values of state are final). - if any(iedge not in visited for iedge in sdfg.in_edges(state)): + if any(iedge not in visited for iedge in cfr.in_edges(region)): pass else: for oedge in out_edges: @@ -628,9 +989,9 @@ def sdfg_work_depth(sdfg: SDFG, new_cond_stack.append(oedge.data.condition_sympy()) # same for common_subexr_stack new_cse_stack = list(common_subexpr_stack) - new_cse_stack.append((work_map[state], depth_map[state])) + new_cse_stack.append((work_map[region], depth_map[region])) # same for value_map - new_value_map = dict(state_value_map[state]) + new_value_map = dict(region_value_map[region]) new_value_map.update({ pystr_to_symbolic(k): pystr_to_symbolic(v).subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) @@ -644,7 +1005,7 @@ def sdfg_work_depth(sdfg: SDFG, pystr_to_symbolic(v).subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1) for k, v in oedge.data.assignments.items() }) - traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge, condition_stack, + traversal_q.append((oedge.dst, depth_map[region], work_map[region], oedge, condition_stack, common_subexpr_stack, value_map)) try: @@ -653,18 +1014,18 @@ def sdfg_work_depth(sdfg: SDFG, except KeyError: # If we get a KeyError above, this means that the traversal never reached the dummy_exit state. # This happens if the loops were not properly detected and broken. - raise LoopExtractionError( - 'Analysis failed, since not all loops got detected. It may help to use more structured loop constructs.') + raise RuntimeError("Analysis failed! The dummy exit state was never reached") - sdfg_result = (max_work, max_depth) - w_d_map[get_uuid(sdfg)] = sdfg_result + cfr_result = (max_work.simplify(), max_depth.simplify()) + w_d_map[get_uuid(cfr)] = cfr_result for k, (v_w, v_d) in w_d_map.items(): # The symeval replaces nested SDFG symbols with their global counterparts. v_w = symeval(v_w, symbols) v_d = symeval(v_d, symbols) w_d_map[k] = (v_w, v_d) - return sdfg_result + + return cfr_result def scope_work_depth( @@ -736,10 +1097,13 @@ def scope_work_depth( nested_syms.update(symbols) nested_syms.update(evaluate_symbols(symbols, node.symbol_mapping)) # Nested SDFGs are recursively analyzed first. - nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms, equality_subs, - subs1, detailed_analysis) + nsdfg_work, nsdfg_depth = control_flow_region_work_depth(node.sdfg, w_d_map, analyze_tasklet, {}, + equality_subs, {}, detailed_analysis) + + nsdfg_work, nsdfg_depth = nsdfg_work.subs(nested_syms), nsdfg_depth.subs(nested_syms) # We cannot use assumptions for nested sdfg analysis. It interfers with the global assumptions. We thus substitute afterwards nsdfg_work, nsdfg_depth = do_initial_subs(nsdfg_work, nsdfg_depth, equality_subs, subs1) + # add up work for whole state, but also save work for this nested SDFG in w_d_map work += nsdfg_work w_d_map[get_uuid(node, state)] = (nsdfg_work, nsdfg_depth) @@ -766,7 +1130,10 @@ def scope_work_depth( lib_node_depth = LIBNODES_TO_DEPTH[type(node)](node, symbols, state) except KeyError: top_level_sdfg = state.parent - top_level_sdfg.add_symbol(f'{node.name}_depth', dtypes.int64) + try: + top_level_sdfg.add_symbol(f'{node.name}_depth', dtypes.int64) + except FileExistsError: + pass lib_node_depth = sp.Symbol(f'{node.name}_depth', positive=True) lib_node_work, lib_node_depth = do_initial_subs(lib_node_work, lib_node_depth, equality_subs, subs1) work += lib_node_work @@ -895,7 +1262,7 @@ def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet, assumptions: List[str], - detailed_analysis: bool = False) -> None: + detailed_analysis: bool = False): """ Analyze a given SDFG. We can either analyze work, work and depth or average parallelism. @@ -909,18 +1276,17 @@ def analyze_sdfg(sdfg: SDFG, and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). """ - # deepcopy such that original sdfg not changed sdfg = deepcopy(sdfg) # apply SSA pass pipeline = FixedPointPipeline([StrictSymbolSSA()]) pipeline.apply_pass(sdfg, {}) - + static_symbol_mapping = get_static_symbols(sdfg) array_symbols = get_array_size_symbols(sdfg) # parse assumptions equality_subs, all_subs = parse_assumptions(assumptions if assumptions is not None else [], array_symbols) - + # Run state propagation for all SDFGs recursively. This is necessary to determine the number of times each state # will be executed, or to determine upper bounds for that number (such as in the case of branching) for sd in sdfg.all_sdfgs_recursive(): @@ -928,9 +1294,9 @@ def analyze_sdfg(sdfg: SDFG, # Analyze the work and depth of the SDFG. symbols = {} - sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, symbols, equality_subs, all_subs[0][0] if len(all_subs) > 0 else {}, - detailed_analysis) - + control_flow_region_work_depth(sdfg, w_d_map, analyze_tasklet, symbols, equality_subs, + all_subs[0][0] if len(all_subs) > 0 else {}, detailed_analysis) + for k, (v_w, v_d) in w_d_map.items(): # The symeval replaces nested SDFG symbols with their global counterparts. v_w, v_d = do_subs(v_w, v_d, all_subs) @@ -938,6 +1304,23 @@ def analyze_sdfg(sdfg: SDFG, v_d = symeval(v_d, symbols) w_d_map[k] = (v_w, v_d) + for k, v, in w_d_map.items(): + w_d_map[k] = ((v[0].subs(static_symbol_mapping).subs(equality_subs[1])),(v[1].subs(static_symbol_mapping).subs(equality_subs[1]))) + + if analyze_tasklet == get_tasklet_work_depth: + for k, v, in w_d_map.items(): + w_d_map[k] = ((sp.simplify(v[0])), (sp.simplify(v[1]))) + elif analyze_tasklet == get_tasklet_work: + for k, v, in w_d_map.items(): + w_d_map[k] = (sp.simplify(v[0])) + elif analyze_tasklet == get_tasklet_avg_par: + for k, v, in w_d_map.items(): + w_d_map[k] = (sp.simplify(v[0] / v[1]) if (v[1]) != 0 else 0) # work / depth = avg par + + result_whole_sdfg = w_d_map[get_uuid(sdfg)] + + return result_whole_sdfg + def do_subs(work, depth, all_subs): """ @@ -994,16 +1377,6 @@ def main() -> None: work_depth_map = {} analyze_sdfg(sdfg, work_depth_map, analyze_tasklet, args.assume, args.detailed) - if args.analyze == 'workDepth': - for k, v, in work_depth_map.items(): - work_depth_map[k] = (str(sp.simplify(v[0])), str(sp.simplify(v[1]))) - elif args.analyze == 'work': - for k, v, in work_depth_map.items(): - work_depth_map[k] = str(sp.simplify(v[0])) - elif args.analyze == 'avgPar': - for k, v, in work_depth_map.items(): - work_depth_map[k] = str(sp.simplify(v[0] / v[1]) if str(v[1]) != '0' else 0) # work / depth = avg par - result_whole_sdfg = work_depth_map[get_uuid(sdfg)] print(80 * '-') @@ -1018,4 +1391,4 @@ def main() -> None: if __name__ == '__main__': - main() + main() \ No newline at end of file diff --git a/tests/sdfg/operational_intensity_test.py b/tests/sdfg/operational_intensity_test.py index 36be0455db..71e610f386 100644 --- a/tests/sdfg/operational_intensity_test.py +++ b/tests/sdfg/operational_intensity_test.py @@ -49,14 +49,14 @@ def if_else(x: dc.int64[100], sum: dc.int64[1]): if x[0] > 3: for i in range(100): sum += x[i] - # no else --> simply analyze the ifs. if cache big enough, everything is reused + # no else --> simply analyze the ifs. if cache big enough, everything is reused; @dc.program def unaligned_for_loop(x: dc.float32[100], sum: dc.int64[1]): for i in range(17, 53): sum += x[i] - + # 36 = 144byte array elemets accessed 1 = 4byte scalar accessed 36 ops -> 64byte line size=> 3 lines + 1 line scalar => op in = 9/64 @dc.program def sequential_maps(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]): @@ -109,6 +109,16 @@ def reduction_library_node(x: dc.float64[N]): 'single_map16_even': (single_map16, 64 * 64, 64, { 'N': 512 }, 1 / 6), + 'single_for_loop': (single_for_loop, 64 * 64, 64, { + 'N': 512 + }, 1/16), + 'if_else': (if_else, 64 * 64, 64, { + 'N': 512 + }, 200/(14*64)) + # 14 cache misses, because DaCe introduces intermediate variable + , + 'unaligned_for_loop': (unaligned_for_loop, 64 * 64, 64, { + }, 9/64), # now num_elements_on_single_cache_line does not divie N anymore # -->513 work, 520 elements loaded --> 513 / (520*8*3) 'single_map64_uneven': (single_map64, 64 * 64, 64, { diff --git a/tests/sdfg/work_depth_test.py b/tests/sdfg/work_depth_test.py index d2323aeddd..d5ab798c83 100644 --- a/tests/sdfg/work_depth_test.py +++ b/tests/sdfg/work_depth_test.py @@ -1,6 +1,7 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Contains test cases for the work depth analysis. """ from typing import Dict, List, Tuple +from unittest import result import pytest import dace as dc @@ -180,6 +181,10 @@ def gemm_library_node(x: dc.float64[456, 200], y: dc.float64[200, 111], z: dc.fl def gemm_library_node_symbolic(x: dc.float64[M, K], y: dc.float64[K, N], z: dc.float64[M, N]): z[:] = x @ y +@dc.program +def loop_var_dependent_work(x: dc.float64[N], y: dc.float64[N], z:dc.float64[N]): + for i in range(1, N+1): + z[i-1] = np.dot(x[:i], y[:i]) #(sdfg, (expected_work, expected_depth)) work_depth_test_cases: Dict[str, Tuple[DaceProgram, Tuple[symbolic.SymbolicType, symbolic.SymbolicType]]] = { @@ -193,18 +198,25 @@ def gemm_library_node_symbolic(x: dc.float64[M, K], y: dc.float64[K, N], z: dc.f 'nested_if_else': (nested_if_else, (sp.Max(K, 3 * N, M + N), sp.Max(3, K, M + 1))), 'max_of_positive_symbols': (max_of_positive_symbol, (3 * N**2, 3 * N)), 'multiple_array_sizes': (multiple_array_sizes, (sp.Max(2 * K, 3 * N, 2 * M + 3), 5)), - 'unbounded_while_do': (unbounded_while_do, (sp.Symbol('num_execs_0_5') * N, sp.Symbol('num_execs_0_5'))), + 'unbounded_while_do': (unbounded_while_do, (sp.Symbol('num_execs_0_0') * N, sp.Symbol('num_execs_0_0'))), # We get this Max(1, num_execs), since it is a do-while loop, but the num_execs symbol does not capture this. - 'unbounded_nonnegify': (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_8') * N, 2 * sp.Symbol('num_execs_0_8'))), + 'unbounded_nonnegify': (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_0') * N, 2 * sp.Symbol('num_execs_0_0'))), 'break_for_loop': (break_for_loop, (N**2, N)), - 'break_while_loop': (break_while_loop, (sp.Symbol('num_execs_0_7') * N, sp.Symbol('num_execs_0_7'))), + 'break_while_loop': (break_while_loop, (sp.Symbol('num_execs_0_0') * N, sp.Symbol('num_execs_0_0'))), 'sequential_ifs': (sequntial_ifs, (sp.Max(N + 1, M) + sp.Max(N + 1, M + 1), sp.Max(1, M) + 1)), - 'reduction_library_node': (reduction_library_node, (456, sp.log(456))), - 'reduction_library_node_symbolic': (reduction_library_node_symbolic, (N, sp.log(N))), - 'gemm_library_node': (gemm_library_node, (2 * 456 * 200 * 111, sp.log(200))), - 'gemm_library_node_symbolic': (gemm_library_node_symbolic, (2 * M * K * N, sp.log(K))) + 'reduction_library_node': (reduction_library_node, (456, sp.log(456)/sp.log(2))), + 'reduction_library_node_symbolic': (reduction_library_node_symbolic, (N, sp.log(sp.Max(1, N))/sp.log(2))), + 'gemm_library_node': (gemm_library_node, (2 * 456 * 200 * 111, sp.log(200)/sp.log(2))), + 'gemm_library_node_symbolic': (gemm_library_node_symbolic, (2 * M * K * N, sp.Max(1, sp.log(sp.Max(1, K))/sp.log(2)))), + 'loop_var_dependent_work': (loop_var_dependent_work, (N**2, N+sp.Sum(sp.log(sp.Symbol("_p_i", nonnegative=True) + 1), (sp.Symbol("_p_i", nonnegative=True), 0, N - 1))/sp.log(2))) } +def standardize(expr): + new_expr = expr.replace( + lambda x: isinstance(x, sp.Symbol), + lambda x: sp.Symbol(x.name) + ) + return new_expr @pytest.mark.parametrize('test_name', list(work_depth_test_cases.keys())) def test_work_depth(test_name): @@ -220,21 +232,51 @@ def test_work_depth(test_name): sdfg.apply_transformations(MapExpansion) # NOTE: Until the W/D Analysis is changed to make use of the new blocks, inline control flow for the analysis. + # inline_control_flow_regions(sdfg) + # This is now not necessary anymore. Work-depth analysis still works for inlined SDFGs, but this imposes other names which fails test cases 'unbounded_while_do', 'unbounded_nonnegify', 'break_while_loop' + + #fg.all_sdfgs_recursive(): + # sd.using_explicit_control_flow = False + + analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth, [], False) + res = w_d_map[get_uuid(sdfg)] + # substitue each symbol without assumptions. + # We do this since sp.Symbol('N') == Sp.Symbol('N', positive=True) --> False. + res = (standardize(res[0]), standardize(res[1])) + correct = (standardize(sp.sympify(correct[0])), standardize(sp.sympify(correct[1]))) + # check result + assert res[0].expand() == correct[0].expand() + assert res[1].expand() == correct[1].expand() + +@pytest.mark.parametrize('test_name', list(work_depth_test_cases.keys())) +def test_work_depth_inlined(test_name): + if test_name in ['unbounded_while_do', 'unbounded_nonnegify', 'break_while_loop']: + pytest.skip('Different state naming when ControlFLowRegios are inlined') + + test, correct = work_depth_test_cases[test_name] + w_d_map: Dict[str, sp.Expr] = {} + sdfg = test.to_sdfg() + if 'nested_sdfg' in test.name: + sdfg.apply_transformations(NestSDFG) + if 'nested_maps' in test.name: + sdfg.apply_transformations(MapExpansion) + + # test inline_control_flow_regions(sdfg) + sdfg.save(test_name+".sdfg") for sd in sdfg.all_sdfgs_recursive(): sd.using_explicit_control_flow = False + analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth, [], False) res = w_d_map[get_uuid(sdfg)] # substitue each symbol without assumptions. # We do this since sp.Symbol('N') == Sp.Symbol('N', positive=True) --> False. - reps = {s: sp.Symbol(s.name) for s in (res[0].free_symbols | res[1].free_symbols)} - res = (res[0].subs(reps), res[1].subs(reps)) - reps = {s: sp.Symbol(s.name) for s in (sp.sympify(correct[0]).free_symbols | sp.sympify(correct[1]).free_symbols)} - correct = (sp.sympify(correct[0]).subs(reps), sp.sympify(correct[1]).subs(reps)) + res = (standardize(res[0]), standardize(res[1])) + correct = (standardize(sp.sympify(correct[0])), standardize(sp.sympify(correct[1]))) # check result - assert correct == res - + assert res[0].expand() == correct[0].expand() + assert res[1].expand() == correct[1].expand() #(sdfg, expected_avg_par) tests_cases_avg_par = { @@ -249,13 +291,12 @@ def test_work_depth(test_name): 'unbounded_nonnegify': (unbounded_nonnegify, N), 'break_for_loop': (break_for_loop, N), 'break_while_loop': (break_while_loop, N), - 'reduction_library_node': (reduction_library_node, 456 / sp.log(456)), - 'reduction_library_node_symbolic': (reduction_library_node_symbolic, N / sp.log(N)), - 'gemm_library_node': (gemm_library_node, 2 * 456 * 200 * 111 / sp.log(200)), - 'gemm_library_node_symbolic': (gemm_library_node_symbolic, 2 * M * K * N / sp.log(K)), + 'reduction_library_node': (reduction_library_node, 456 / (sp.log(456)/sp.log(2))), + 'reduction_library_node_symbolic': (reduction_library_node_symbolic, N*sp.log(2)/sp.log(sp.Max(1, N))), + 'gemm_library_node': (gemm_library_node, 2 * 456 * 200 * 111 / (sp.log(200)/sp.log(2))), + 'gemm_library_node_symbolic': (gemm_library_node_symbolic, 2*K*M*N/sp.Max(1, sp.log(sp.Max(1, K))/sp.log(2))), } - @pytest.mark.parametrize('test_name', list(tests_cases_avg_par.keys())) def test_avg_par(test_name: str): if (dc.Config.get_bool('optimizer', 'automatic_simplification') == False @@ -271,12 +312,41 @@ def test_avg_par(test_name: str): sdfg.apply_transformations(MapExpansion) # NOTE: Until the W/D Analysis is changed to make use of the new blocks, inline control flow for the analysis. + # inline_control_flow_regions(sdfg) + # This is now not necessary anymore. Work-depth analysis still works for inlined SDFGs, but this imposes other names which fails test cases 'unbounded_while_do', 'unbounded_nonnegify', 'break_while_loop' + #for sd in sdfg.all_sdfgs_recursive(): + # sd.using_explicit_control_flow = False + analyze_sdfg(sdfg, w_d_map, get_tasklet_avg_par, [], False) + res = w_d_map[get_uuid(sdfg)] + # substitue each symbol without assumptions. + # We do this since sp.Symbol('N') == Sp.Symbol('N', positive=True) --> False. + reps = {s: sp.Symbol(s.name) for s in res.free_symbols} + res = res.subs(reps) + reps = {s: sp.Symbol(s.name) for s in sp.sympify(correct).free_symbols} + correct = sp.sympify(correct).subs(reps) + # check result + assert res.expand() == correct.expand() + + +@pytest.mark.parametrize('test_name', list(tests_cases_avg_par.keys())) +def test_avg_par_inlined(test_name: str): + if test_name in ['unbounded_while_do', 'unbounded_nonnegify', 'break_while_loop']: + pytest.skip('Different state naming when ControlFLowRegios are inlined') + + test, correct = tests_cases_avg_par[test_name] + w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]] = {} + sdfg = test.to_sdfg() + if 'nested_sdfg' in test_name: + sdfg.apply_transformations(NestSDFG) + if 'nested_maps' in test_name: + sdfg.apply_transformations(MapExpansion) + inline_control_flow_regions(sdfg) + for sd in sdfg.all_sdfgs_recursive(): sd.using_explicit_control_flow = False - analyze_sdfg(sdfg, w_d_map, get_tasklet_avg_par, [], False) - res = w_d_map[get_uuid(sdfg)][0] / w_d_map[get_uuid(sdfg)][1] + res = w_d_map[get_uuid(sdfg)] # substitue each symbol without assumptions. # We do this since sp.Symbol('N') == Sp.Symbol('N', positive=True) --> False. reps = {s: sp.Symbol(s.name) for s in res.free_symbols} @@ -284,7 +354,7 @@ def test_avg_par(test_name: str): reps = {s: sp.Symbol(s.name) for s in sp.sympify(correct).free_symbols} correct = sp.sympify(correct).subs(reps) # check result - assert correct == res + assert res.expand() == correct.expand() x, y, z, a = sp.symbols('x y z a') @@ -326,12 +396,10 @@ def test_assumption_system_contradictions(assumptions): with raises(ContradictingAssumptions): parse_assumptions(assumptions, set()) - def test_depth_counter_vs_work_counter(): """ Test that the DepthCounter correctly computes depth (longest chain of dependent operations) which can differ from work (total number of operations). - Depth measures the critical path through the expression tree, while work measures the total number of operations. """