Skip to content

Commit ea17a43

Browse files
committed
uniformized DFTA letter type
1 parent 6cee2e7 commit ea17a43

1 file changed

Lines changed: 14 additions & 16 deletions

File tree

grape/automaton/loop_manager.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def __all_sub_args__(
158158

159159

160160
def add_loops(
161-
dfta: DFTA[str, str],
161+
dfta: DFTA[str, Program | str],
162162
dsl: DSL,
163163
algorithm: LoopingAlgorithm = LoopingAlgorithm.OBSERVATIONAL_EQUIVALENCE,
164164
) -> DFTA[str, Program]:
@@ -200,6 +200,7 @@ def is_allowed(
200200
<= max_size
201201
)
202202

203+
dfta = dfta.map_alphabet(str)
203204
state_to_type = dsl.get_state_types(dfta)
204205
state_to_size = {s: s.count(" ") + 1 for s in dfta.all_states}
205206
state_to_letter = {
@@ -229,7 +230,7 @@ def is_allowed(
229230
if all(not state_to_letter[s][1] for s in states):
230231
virtual_vars.add(max_varno)
231232
dst = str(Variable(max_varno))
232-
new_dfta.rules[(Variable(max_varno), tuple())] = dst
233+
new_dfta.rules[(dst, tuple())] = dst
233234
states_by_types[t].append(dst)
234235
state_to_size[dst] = 1
235236
state_to_letter[dst] = (dst, True)
@@ -243,20 +244,17 @@ def is_allowed(
243244
for combi in itertools.product(*possibles):
244245
key = (P, combi)
245246
dst_size = sum(map(lambda x: state_to_size[x], combi)) + 1
246-
if dst_size > max_size:
247+
if dst_size > max_size and is_allowed(
248+
P,
249+
combi,
250+
dfta,
251+
state_to_letter,
252+
state_to_size,
253+
merge_memory,
254+
largest_merge,
255+
states_by_types,
256+
):
247257
assert key not in new_dfta.rules
248-
dst = Function(Primitive(P), list(map(Primitive, combi)))
249-
if not is_allowed(
250-
P,
251-
combi,
252-
dfta,
253-
state_to_letter,
254-
state_to_size,
255-
merge_memory,
256-
largest_merge,
257-
states_by_types,
258-
):
259-
continue
260258
new_state = __find_merge__(
261259
new_dfta,
262260
P,
@@ -270,7 +268,7 @@ def is_allowed(
270268
new_dfta.rules[key] = new_state
271269

272270
for no in virtual_vars:
273-
dst = Variable(no)
271+
dst = str(Variable(no))
274272
del new_dfta.rules[(dst, tuple())]
275273

276274
new_dfta.reduce()

0 commit comments

Comments
 (0)