1+ from collections import defaultdict
12import random
23from typing import Any , Callable , Generator , Optional
34from grape .dsl import DSL
@@ -22,8 +23,8 @@ def __init__(
2223 seed : int = 1 ,
2324 ):
2425 self .dsl = dsl
25- self .equiv_classes : dict [str , dict [Any , Program ]] = {}
26- self .memoization : dict [Program , dict [Any , Any ]] = {}
26+ self .equiv_classes : dict [str , dict [Any , Program ]] = defaultdict ( dict )
27+ self .memoization : dict [Program , dict [Any , Any ]] = defaultdict ( dict )
2728 self .rtypes : dict [str , str ] = {
2829 p : types .return_type (stype ) for p , (stype , _ ) in dsl .primitives .items ()
2930 }
@@ -58,14 +59,15 @@ def __gen_full_inputs__(self, type_req: str) -> None:
5859 self .full_inputs [type_req ] = list (elems )
5960
6061 def __return_type__ (self , program : Program , type_req : str ) -> str :
61- if isinstance (program , Variable ):
62- return types .arguments (type_req )[program .no ]
63- elif isinstance (program , Primitive ):
64- return self .rtypes [program .name ]
65- elif isinstance (program , Function ):
66- return self .rtypes [program .function .name ]
67- else :
68- raise ValueError
62+ match program :
63+ case Variable (no ):
64+ return types .arguments (type_req )[no ]
65+ case Primitive (name ):
66+ return self .rtypes [name ]
67+ case Function (func ):
68+ return self .rtypes [func .name ]
69+ case _:
70+ raise ValueError
6971
7072 def eval (self , program : Program , type_req : str ) -> Optional [Program ]:
7173 if program in self .memoization :
@@ -85,8 +87,6 @@ def eval(self, program: Program, type_req: str) -> Optional[Program]:
8587 # Check equivalence class
8688 rtype = self .__return_type__ (program , type_req )
8789 key = tuple (outs )
88- if rtype not in self .equiv_classes :
89- self .equiv_classes [rtype ] = {}
9090 representative = self .equiv_classes [rtype ].get (key , None )
9191 if representative is None :
9292 self .equiv_classes [rtype ][key ] = program
@@ -95,20 +95,20 @@ def eval(self, program: Program, type_req: str) -> Optional[Program]:
9595 return representative
9696
9797 def __eval__ (self , program : Program , full_input : tuple [Any , ...]) -> Any :
98- if program in self .memoization :
99- if full_input in self .memoization [program ]:
100- return self .memoization [program ][full_input ]
101- else :
102- self .memoization [program ] = {}
98+ mem = self .memoization [program ]
99+ if full_input in mem :
100+ return mem [full_input ]
103101 # Compute value
104102 out = None
105- if isinstance (program , Variable ):
106- out = full_input [program .no ]
107- elif isinstance (program , Primitive ):
108- out = self .dsl .semantic (program .name )
109- elif isinstance (program , Function ):
110- fun = self .__eval__ (program .function , full_input )
111- arg_vals = [self .__eval__ (arg , full_input ) for arg in program .arguments ]
112- out = fun (* arg_vals )
113- self .memoization [program ][full_input ] = out
103+ match program :
104+ case Variable (no ):
105+ out = full_input [no ]
106+ case Primitive (name ):
107+ out = self .dsl .semantic (name )
108+ case Function (func ):
109+ fun = self .dsl .semantic (func .name )
110+ arg_vals = [self .__eval__ (arg , full_input ) for arg in program .arguments ]
111+ out = fun (* arg_vals )
112+
113+ mem [full_input ] = out
114114 return out
0 commit comments