Skip to content

Commit 04c76fe

Browse files
committed
idiomatic python => speed-up
1 parent 22ab85c commit 04c76fe

1 file changed

Lines changed: 26 additions & 26 deletions

File tree

grape/evaluator.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import defaultdict
12
import random
23
from typing import Any, Callable, Generator, Optional
34
from grape.dsl import DSL
@@ -22,8 +23,8 @@ def __init__(
2223
seed: int = 1,
2324
):
2425
self.dsl = dsl
25-
self.equiv_classes: dict[str, dict[Any, Program]] = {}
26-
self.memoization: dict[Program, dict[Any, Any]] = {}
26+
self.equiv_classes: dict[str, dict[Any, Program]] = defaultdict(dict)
27+
self.memoization: dict[Program, dict[Any, Any]] = defaultdict(dict)
2728
self.rtypes: dict[str, str] = {
2829
p: types.return_type(stype) for p, (stype, _) in dsl.primitives.items()
2930
}
@@ -58,14 +59,15 @@ def __gen_full_inputs__(self, type_req: str) -> None:
5859
self.full_inputs[type_req] = list(elems)
5960

6061
def __return_type__(self, program: Program, type_req: str) -> str:
61-
if isinstance(program, Variable):
62-
return types.arguments(type_req)[program.no]
63-
elif isinstance(program, Primitive):
64-
return self.rtypes[program.name]
65-
elif isinstance(program, Function):
66-
return self.rtypes[program.function.name]
67-
else:
68-
raise ValueError
62+
match program:
63+
case Variable(no):
64+
return types.arguments(type_req)[no]
65+
case Primitive(name):
66+
return self.rtypes[name]
67+
case Function(func):
68+
return self.rtypes[func.name]
69+
case _:
70+
raise ValueError
6971

7072
def eval(self, program: Program, type_req: str) -> Optional[Program]:
7173
if program in self.memoization:
@@ -85,8 +87,6 @@ def eval(self, program: Program, type_req: str) -> Optional[Program]:
8587
# Check equivalence class
8688
rtype = self.__return_type__(program, type_req)
8789
key = tuple(outs)
88-
if rtype not in self.equiv_classes:
89-
self.equiv_classes[rtype] = {}
9090
representative = self.equiv_classes[rtype].get(key, None)
9191
if representative is None:
9292
self.equiv_classes[rtype][key] = program
@@ -95,20 +95,20 @@ def eval(self, program: Program, type_req: str) -> Optional[Program]:
9595
return representative
9696

9797
def __eval__(self, program: Program, full_input: tuple[Any, ...]) -> Any:
98-
if program in self.memoization:
99-
if full_input in self.memoization[program]:
100-
return self.memoization[program][full_input]
101-
else:
102-
self.memoization[program] = {}
98+
mem = self.memoization[program]
99+
if full_input in mem:
100+
return mem[full_input]
103101
# Compute value
104102
out = None
105-
if isinstance(program, Variable):
106-
out = full_input[program.no]
107-
elif isinstance(program, Primitive):
108-
out = self.dsl.semantic(program.name)
109-
elif isinstance(program, Function):
110-
fun = self.__eval__(program.function, full_input)
111-
arg_vals = [self.__eval__(arg, full_input) for arg in program.arguments]
112-
out = fun(*arg_vals)
113-
self.memoization[program][full_input] = out
103+
match program:
104+
case Variable(no):
105+
out = full_input[no]
106+
case Primitive(name):
107+
out = self.dsl.semantic(name)
108+
case Function(func):
109+
fun = self.dsl.semantic(func.name)
110+
arg_vals = [self.__eval__(arg, full_input) for arg in program.arguments]
111+
out = fun(*arg_vals)
112+
113+
mem[full_input] = out
114114
return out

0 commit comments

Comments
 (0)