11import itertools
22
33from grape import types
4- from grape .automaton import spec_manager
54from grape .automaton .tree_automaton import DFTA
65from grape .dsl import DSL
76from grape .program import Function , Primitive , Program , Variable
@@ -18,45 +17,71 @@ def __can_states_merge(
1817 reversed_rules : dict [tuple [str , tuple [str , ...]], str ],
1918 original : str ,
2019 candidate : str ,
20+ merge_memory : dict [(str , str ), bool ],
21+ state_to_letter : dict [str , tuple [str , bool ]],
2122) -> bool :
22- if __state2letter__ (candidate ) != __state2letter__ (original ) and not str (
23- __state2letter__ (candidate )
24- ).startswith ("var" ):
25- return False
26- for P1 , args1 in reversed_rules [original ]:
27- has_equivalent = False
28- for P2 , args2 in reversed_rules [candidate ]:
29- if all (
30- __can_states_merge (reversed_rules , arg1 , arg2 )
31- for arg1 , arg2 in zip (args1 , args2 )
32- ):
33- has_equivalent = True
34- break
35- if not has_equivalent :
23+ res = merge_memory .get ((original , candidate ))
24+ if res is None :
25+ lc = state_to_letter [candidate ]
26+ if lc [0 ] != state_to_letter [original ][0 ] and not lc [1 ]:
27+ merge_memory [(original , candidate )] = False
28+ merge_memory [(candidate , original )] = False
3629 return False
37- return True
30+ for P1 , args1 in reversed_rules [original ]:
31+ has_equivalent = False
32+ for P2 , args2 in reversed_rules [candidate ]:
33+ if all (
34+ __can_states_merge (
35+ reversed_rules , arg1 , arg2 , merge_memory , state_to_letter
36+ )
37+ for arg1 , arg2 in zip (args1 , args2 )
38+ if arg1 != arg2
39+ ):
40+ has_equivalent = True
41+ break
42+ if not has_equivalent :
43+ merge_memory [(original , candidate )] = False
44+ merge_memory [(candidate , original )] = False
45+ return False
46+ merge_memory [(original , candidate )] = True
47+ merge_memory [(candidate , original )] = True
48+ return True
49+ else :
50+ return res
3851
3952
4053def __find_merge__ (
41- dfta : DFTA [str , str ], P : str , args : tuple [str , ...], candidates : set [str ]
54+ dfta : DFTA [str , str ],
55+ P : str ,
56+ args : tuple [str , ...],
57+ candidates : set [str ],
58+ merge_memory : dict [(str , str ), bool ],
59+ state_to_letter : dict [str , tuple [str , bool ]],
60+ state_to_size : dict [str , int ],
4261) -> str | None :
4362 best_candidate = None
63+ size_best = - 1
4464 for candidate in candidates :
45- if __state2letter__ ( candidate ) != P and not str (
46- __state2letter__ ( candidate )
47- ). startswith ( "var" ) :
65+ if state_to_size [ candidate ] <= size_best :
66+ break
67+ elif state_to_letter [ candidate ][ 0 ] != P and not state_to_letter [ candidate ][ 1 ] :
4868 continue
4969 has_equivalent = False
5070 for P2 , args2 in dfta .reversed_rules [candidate ]:
5171 if all (
52- __can_states_merge (dfta .reversed_rules , arg1 , arg2 )
72+ __can_states_merge (
73+ dfta .reversed_rules , arg1 , arg2 , merge_memory , state_to_letter
74+ )
5375 for arg1 , arg2 in zip (args , args2 )
76+ if arg1 != arg2
5477 ):
5578 has_equivalent = True
5679 break
80+
5781 if has_equivalent and (
58- best_candidate is None or best_candidate . count ( " " ) < candidate . count ( " " )
82+ best_candidate is None or size_best < state_to_size [ candidate ]
5983 ):
84+ size_best = state_to_size [candidate ]
6085 best_candidate = candidate
6186 return best_candidate
6287
@@ -81,9 +106,17 @@ def add_loops(
81106 else :
82107 state_to_type = dsl .get_state_types (dfta )
83108 state_to_size = {s : s .count (" " ) + 1 for s in dfta .all_states }
109+ state_to_letter = {
110+ s : (__state2letter__ (s ), __state2letter__ (s ).startswith ("var" ))
111+ for s in dfta .all_states
112+ }
84113 max_size = max (state_to_size .values ())
85114 states_by_types = {
86- t : set (s for s , st in state_to_type .items () if st == t )
115+ t : sorted (
116+ [s for s , st in state_to_type .items () if st == t ],
117+ reverse = True ,
118+ key = lambda s : state_to_size [s ],
119+ )
87120 for t in set (state_to_type .values ())
88121 }
89122 added = True
@@ -102,13 +135,16 @@ def add_loops(
102135 virtual_vars .add (max_varno )
103136 dst = str (Variable (max_varno ))
104137 new_dfta .rules [(Variable (max_varno ), tuple ())] = dst
105- states_by_types [t ].add (dst )
138+ states_by_types [t ].append (dst )
106139 state_to_size [dst ] = 1
140+ state_to_letter [dst ] = (dst , True )
107141 max_varno += 1
108142 new_dfta .refresh_reversed_rules ()
143+ merge_memory = {}
109144 while added :
110145 added = False
111146 for P , (Ptype , _ ) in dsl .primitives .items ():
147+ rtype = types .return_type (dsl .get_type (P ))
112148 possibles = [states_by_types [arg_t ] for arg_t in types .arguments (Ptype )]
113149 for combi in itertools .product (* possibles ):
114150 key = (P , combi )
@@ -120,13 +156,18 @@ def add_loops(
120156 and max (args_size ) >= max_size - len (args_size ) + 1
121157 ):
122158 added = True
123- rtype = types . return_type ( dsl . get_type ( P ))
159+
124160 dst = Function (Primitive (P ), list (map (Primitive , combi )))
125161 new_state = __find_merge__ (
126- new_dfta , P , combi , states_by_types [rtype ]
162+ new_dfta ,
163+ P ,
164+ combi ,
165+ states_by_types [rtype ],
166+ merge_memory ,
167+ state_to_letter ,
168+ state_to_size ,
127169 ) or str (dst )
128170 new_dfta .rules [key ] = new_state
129- states_by_types [rtype ].add (new_state )
130171 assert new_state in state_to_size
131172 new_dfta .refresh_reversed_rules ()
132173
0 commit comments