Skip to content

Commit b180485

Browse files
committed
speed-up by memoization
1 parent 04c76fe commit b180485

1 file changed

Lines changed: 68 additions & 27 deletions

File tree

grape/automaton/loop_manager.py

Lines changed: 68 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import itertools
22

33
from grape import types
4-
from grape.automaton import spec_manager
54
from grape.automaton.tree_automaton import DFTA
65
from grape.dsl import DSL
76
from grape.program import Function, Primitive, Program, Variable
@@ -18,45 +17,71 @@ def __can_states_merge(
1817
reversed_rules: dict[tuple[str, tuple[str, ...]], str],
1918
original: str,
2019
candidate: str,
20+
merge_memory: dict[(str, str), bool],
21+
state_to_letter: dict[str, tuple[str, bool]],
2122
) -> bool:
22-
if __state2letter__(candidate) != __state2letter__(original) and not str(
23-
__state2letter__(candidate)
24-
).startswith("var"):
25-
return False
26-
for P1, args1 in reversed_rules[original]:
27-
has_equivalent = False
28-
for P2, args2 in reversed_rules[candidate]:
29-
if all(
30-
__can_states_merge(reversed_rules, arg1, arg2)
31-
for arg1, arg2 in zip(args1, args2)
32-
):
33-
has_equivalent = True
34-
break
35-
if not has_equivalent:
23+
res = merge_memory.get((original, candidate))
24+
if res is None:
25+
lc = state_to_letter[candidate]
26+
if lc[0] != state_to_letter[original][0] and not lc[1]:
27+
merge_memory[(original, candidate)] = False
28+
merge_memory[(candidate, original)] = False
3629
return False
37-
return True
30+
for P1, args1 in reversed_rules[original]:
31+
has_equivalent = False
32+
for P2, args2 in reversed_rules[candidate]:
33+
if all(
34+
__can_states_merge(
35+
reversed_rules, arg1, arg2, merge_memory, state_to_letter
36+
)
37+
for arg1, arg2 in zip(args1, args2)
38+
if arg1 != arg2
39+
):
40+
has_equivalent = True
41+
break
42+
if not has_equivalent:
43+
merge_memory[(original, candidate)] = False
44+
merge_memory[(candidate, original)] = False
45+
return False
46+
merge_memory[(original, candidate)] = True
47+
merge_memory[(candidate, original)] = True
48+
return True
49+
else:
50+
return res
3851

3952

4053
def __find_merge__(
41-
dfta: DFTA[str, str], P: str, args: tuple[str, ...], candidates: set[str]
54+
dfta: DFTA[str, str],
55+
P: str,
56+
args: tuple[str, ...],
57+
candidates: set[str],
58+
merge_memory: dict[(str, str), bool],
59+
state_to_letter: dict[str, tuple[str, bool]],
60+
state_to_size: dict[str, int],
4261
) -> str | None:
4362
best_candidate = None
63+
size_best = -1
4464
for candidate in candidates:
45-
if __state2letter__(candidate) != P and not str(
46-
__state2letter__(candidate)
47-
).startswith("var"):
65+
if state_to_size[candidate] <= size_best:
66+
break
67+
elif state_to_letter[candidate][0] != P and not state_to_letter[candidate][1]:
4868
continue
4969
has_equivalent = False
5070
for P2, args2 in dfta.reversed_rules[candidate]:
5171
if all(
52-
__can_states_merge(dfta.reversed_rules, arg1, arg2)
72+
__can_states_merge(
73+
dfta.reversed_rules, arg1, arg2, merge_memory, state_to_letter
74+
)
5375
for arg1, arg2 in zip(args, args2)
76+
if arg1 != arg2
5477
):
5578
has_equivalent = True
5679
break
80+
5781
if has_equivalent and (
58-
best_candidate is None or best_candidate.count(" ") < candidate.count(" ")
82+
best_candidate is None or size_best < state_to_size[candidate]
5983
):
84+
size_best = state_to_size[candidate]
6085
best_candidate = candidate
6186
return best_candidate
6287

@@ -81,9 +106,17 @@ def add_loops(
81106
else:
82107
state_to_type = dsl.get_state_types(dfta)
83108
state_to_size = {s: s.count(" ") + 1 for s in dfta.all_states}
109+
state_to_letter = {
110+
s: (__state2letter__(s), __state2letter__(s).startswith("var"))
111+
for s in dfta.all_states
112+
}
84113
max_size = max(state_to_size.values())
85114
states_by_types = {
86-
t: set(s for s, st in state_to_type.items() if st == t)
115+
t: sorted(
116+
[s for s, st in state_to_type.items() if st == t],
117+
reverse=True,
118+
key=lambda s: state_to_size[s],
119+
)
87120
for t in set(state_to_type.values())
88121
}
89122
added = True
@@ -102,13 +135,16 @@ def add_loops(
102135
virtual_vars.add(max_varno)
103136
dst = str(Variable(max_varno))
104137
new_dfta.rules[(Variable(max_varno), tuple())] = dst
105-
states_by_types[t].add(dst)
138+
states_by_types[t].append(dst)
106139
state_to_size[dst] = 1
140+
state_to_letter[dst] = (dst, True)
107141
max_varno += 1
108142
new_dfta.refresh_reversed_rules()
143+
merge_memory = {}
109144
while added:
110145
added = False
111146
for P, (Ptype, _) in dsl.primitives.items():
147+
rtype = types.return_type(dsl.get_type(P))
112148
possibles = [states_by_types[arg_t] for arg_t in types.arguments(Ptype)]
113149
for combi in itertools.product(*possibles):
114150
key = (P, combi)
@@ -120,13 +156,18 @@ def add_loops(
120156
and max(args_size) >= max_size - len(args_size) + 1
121157
):
122158
added = True
123-
rtype = types.return_type(dsl.get_type(P))
159+
124160
dst = Function(Primitive(P), list(map(Primitive, combi)))
125161
new_state = __find_merge__(
126-
new_dfta, P, combi, states_by_types[rtype]
162+
new_dfta,
163+
P,
164+
combi,
165+
states_by_types[rtype],
166+
merge_memory,
167+
state_to_letter,
168+
state_to_size,
127169
) or str(dst)
128170
new_dfta.rules[key] = new_state
129-
states_by_types[rtype].add(new_state)
130171
assert new_state in state_to_size
131172
new_dfta.refresh_reversed_rules()
132173

0 commit comments

Comments
 (0)