1+ from collections import defaultdict
12from enum import StrEnum
23import itertools
34from typing import Generator
@@ -107,13 +108,15 @@ def __get_largest_merges__(
107108 dfta : DFTA [str , str ],
108109 state_to_letter : dict [str , tuple [str , bool ]],
109110 state_to_size : dict [str , int ],
110- merge_memory : dict [( str , str ) , bool ],
111+ merge_memory : dict [tuple [ str , str ] , bool ],
111112 largest_merge : dict [str , str ],
112- states_by_types : dict [str , list [str ]],
113+ states_by_types_and_letter : dict [tuple [ str , str ] , list [str ]],
113114) -> list [str ]:
114115 res = largest_merge .get (state , None )
115116 if res is None :
116- candidates = [S for S in states_by_types .values () if state in S ].pop (0 )
117+ candidates = [S for S in states_by_types_and_letter .values () if state in S ].pop (
118+ 0
119+ )
117120 out = []
118121 size = - 1
119122 for candidate in candidates :
@@ -135,9 +138,9 @@ def __all_sub_args__(
135138 dfta : DFTA [str , str ],
136139 state_to_letter : dict [str , tuple [str , bool ]],
137140 state_to_size : dict [str , int ],
138- merge_memory : dict [( str , str ) , bool ],
141+ merge_memory : dict [tuple [ str , str ] , bool ],
139142 largest_merge : dict [str , str ],
140- states_by_types : dict [str , list [str ]],
143+ states_by_types_and_letter : dict [tuple [ str , str ] , list [str ]],
141144) -> Generator [str , None , None ]:
142145 possibles = list (
143146 map (
@@ -148,7 +151,7 @@ def __all_sub_args__(
148151 state_to_size ,
149152 merge_memory ,
150153 largest_merge ,
151- states_by_types ,
154+ states_by_types_and_letter ,
152155 ),
153156 combi ,
154157 )
@@ -191,9 +194,9 @@ def is_allowed(
191194 dfta : DFTA [str , str | Program ],
192195 state_to_letter : dict [str , tuple [str , bool ]],
193196 state_to_size : dict [str , int ],
194- merge_memory : dict [( str , str ) , bool ],
197+ merge_memory : dict [tuple [ str , str ] , bool ],
195198 largest_merge : dict [str , str ],
196- states_by_types : dict [str , list [str ]],
199+ states_by_types_and_letter : dict [tuple [ str , str ] , list [str ]],
197200 ) -> bool :
198201 return all (
199202 (P , sub_args ) in new_dfta
@@ -204,7 +207,7 @@ def is_allowed(
204207 state_to_size ,
205208 merge_memory ,
206209 largest_merge ,
207- states_by_types ,
210+ states_by_types_and_letter ,
208211 )
209212 if sum (map (lambda x : state_to_size [x ], sub_args )) + 1
210213 <= max_size
@@ -248,7 +251,19 @@ def is_allowed(
248251 new_dfta .refresh_reversed_rules ()
249252 merge_memory = {}
250253 largest_merge = {}
251-
254+ states_by_types_and_letter = defaultdict (list )
255+ for t , states in states_by_types .items ():
256+ later = []
257+ for s in states :
258+ if state_to_letter [s ][1 ]:
259+ later .append (s )
260+ else :
261+ key = (t , state_to_letter [s ][0 ])
262+ states_by_types_and_letter [key ].append (s )
263+ for (tt , _ ), val in states_by_types_and_letter .items ():
264+ if tt == t :
265+ for x in later :
266+ val .append (x )
252267 update = lambda : 1
253268 if use_tqdm :
254269 pbar = tqdm (
@@ -279,14 +294,14 @@ def is_allowed(
279294 state_to_size ,
280295 merge_memory ,
281296 largest_merge ,
282- states_by_types ,
297+ states_by_types_and_letter ,
283298 ):
284299 assert key not in new_dfta .rules
285300 new_state = __find_merge__ (
286301 new_dfta ,
287302 P ,
288303 combi ,
289- states_by_types [ rtype ],
304+ states_by_types_and_letter [( rtype , P ) ],
290305 merge_memory ,
291306 state_to_letter ,
292307 state_to_size ,
0 commit comments