Skip to content

Commit bee92e8

Browse files
committed
added feature to get equivalence classes
1 parent 0e1309f commit bee92e8

4 files changed

Lines changed: 82 additions & 14 deletions

File tree

grape/cli/prune.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from grape.cli import dsl_loader
77
from grape.program import Primitive, Variable
88
from grape.evaluator import Evaluator
9+
from grape.pruning.equivalence_class_manager import EquivalenceClassManager
910
from grape.pruning.obs_equiv_pruner import prune
1011

1112

@@ -78,6 +79,12 @@ def is_python_file(file_path: str) -> bool:
7879
type=str,
7980
help="your starting automaton file",
8081
)
82+
parser.add_argument(
83+
"--classes",
84+
type=str,
85+
default=None,
86+
help="save equivalence classes ina JSON file",
87+
)
8188

8289
return parser.parse_args()
8390

@@ -90,9 +97,11 @@ def main():
9097
inputs = sample_inputs(args.samples, sample_dict, equal_dict)
9198

9299
evaluator = Evaluator(dsl, inputs, equal_dict, skip_exceptions)
100+
manager = EquivalenceClassManager()
93101
grammar, type_req = prune(
94102
dsl,
95103
evaluator,
104+
manager,
96105
args.size,
97106
target_type,
98107
args.automaton,
@@ -110,6 +119,10 @@ def main():
110119

111120
dump_automaton_to_file(grammar, args.output)
112121

122+
if args.classes is not None:
123+
with open(args.classes, "w") as fd:
124+
fd.write(manager.to_json())
125+
113126

114127
if __name__ == "__main__":
115128
main()

grape/pruning/commutativity_pruner.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from grape.dsl import DSL
33
from grape.evaluator import Evaluator
44
from grape.program import Function, Primitive, Program, Variable
5+
from grape.pruning.equivalence_class_manager import EquivalenceClassManager
56
import grape.types as types
67

78
T = TypeVar("T")
@@ -26,15 +27,17 @@ def __produce_all_variants__(
2627
current[i] += 1
2728

2829

29-
def get_rewrites(
30-
dsl: DSL, primitive: str, swapped_indices: tuple[int, int]
31-
) -> list[tuple[Program, Program, str]]:
32-
constraints: list[tuple[Program, Program, str]] = []
33-
stype = dsl[primitive][0]
30+
def __add_rewrite__(
31+
dsl: DSL,
32+
primitive: str,
33+
swapped_indices: tuple[int, int],
34+
manager: EquivalenceClassManager,
35+
):
36+
stype = dsl.primitives[primitive][0]
3437
args_type, _ = types.parse(stype)
3538
swapped_type = args_type[swapped_indices[0]]
3639

37-
nargs = len(types.arguments(dsl.primitives[primitive]))
40+
nargs = len(types.arguments(dsl.primitives[primitive][0]))
3841

3942
for p1, (stype1, _) in dsl.primitives.copy().items():
4043
args1, rtype1 = types.parse(stype1)
@@ -73,7 +76,7 @@ def get_rewrites(
7376
)
7477
equiv_to.arguments[swapped_indices[0]] = second_arg
7578
equiv_to.arguments[swapped_indices[1]] = first_arg
76-
constraints.append((deleted, equiv_to))
79+
manager.add_merge(deleted, equiv_to)
7780
# Add additional constraint for variable type
7881

7982
second_arg = Variable(swapped_indices[0])
@@ -84,14 +87,11 @@ def get_rewrites(
8487
equiv_to = Function(Primitive(primitive), [Variable(i) for i in range(nargs)])
8588
equiv_to.arguments[swapped_indices[0]] = second_arg
8689
equiv_to.arguments[swapped_indices[1]] = first_arg
87-
constraints.append((deleted, equiv_to))
88-
89-
return constraints
90+
manager.add_merge(deleted, equiv_to)
9091

9192

9293
def prune(
93-
dsl: DSL,
94-
evaluator: Evaluator,
94+
dsl: DSL, evaluator: Evaluator, manager: EquivalenceClassManager
9595
) -> list[tuple[str, list[int]]]:
9696
commutatives = []
9797
for prim, (stype, _) in dsl.primitives.items():
@@ -117,8 +117,10 @@ def prune(
117117
continue
118118
variant = Function(Primitive(prim), new_args)
119119
if evaluator.eval(variant, stype) is not None:
120+
swapped = [i for i, x in enumerate(new_args) if x.no != i]
120121
commutatives.append(
121122
(prim, [i for i, x in enumerate(new_args) if x.no != i])
122123
)
124+
__add_rewrite__(dsl, prim, (swapped[0], swapped[1]), manager)
123125
evaluator.clean_memoisation()
124126
return commutatives
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import json
2+
from grape.program import Program
3+
4+
5+
class EquivalenceClassManager:
6+
def __init__(self):
7+
self.classes: dict[Program, set[Program]] = {}
8+
9+
def new_class(self, representative: Program):
10+
"""
11+
Create a new class of equivalence.
12+
Assumes class does not already exist.
13+
"""
14+
assert representative not in self.classes
15+
self.classes[representative] = set()
16+
17+
def add_to_class(self, program: Program, representative: Program):
18+
"""
19+
Add a program to an already existing equivalence class.
20+
Assumes class already exists.
21+
"""
22+
self.classes[representative].add(program)
23+
24+
def add_merge(self, program: Program, representative: Program):
25+
"""
26+
Add a program to an already existing equivalence class.
27+
Assumes nothing so it creates a new class if it does not exist.
28+
"""
29+
if representative not in self.classes:
30+
self.new_class(representative)
31+
self.add_to_class(program, representative)
32+
33+
def to_json(self) -> str:
34+
str_classes = sorted(
35+
[
36+
{"representative": str(key), "elements": list(map(str, value))}
37+
for key, value in self.classes.items()
38+
],
39+
key=lambda x: (x["representative"], len(x["elements"])),
40+
reverse=True,
41+
)
42+
return json.dumps(str_classes)

grape/pruning/obs_equiv_pruner.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616
from grape.automaton.tree_automaton import DFTA
1717
import grape.pruning.commutativity_pruner as commutativity_pruner
18+
from grape.pruning.equivalence_class_manager import EquivalenceClassManager
1819
import grape.types as types
1920

2021
from tqdm import tqdm
@@ -59,13 +60,14 @@ def __get_base_grammar__(
5960
has_base_grammar: bool,
6061
dsl: DSL,
6162
evaluator: Evaluator,
63+
manager: EquivalenceClassManager,
6264
max_size: int,
6365
base_automaton_file: str,
6466
type_req: str,
6567
):
6668
base_grammar = grammar_by_saturation(dsl, type_req)
6769
if not has_base_grammar:
68-
commutatives = commutativity_pruner.prune(dsl, evaluator)
70+
commutatives = commutativity_pruner.prune(dsl, evaluator, manager)
6971
grammar = grammar_by_saturation(
7072
dsl,
7173
type_req,
@@ -89,6 +91,7 @@ def __get_base_grammar__(
8991
def prune(
9092
dsl: DSL,
9193
evaluator: Evaluator,
94+
manager: EquivalenceClassManager,
9295
max_size: int,
9396
rtype: str | None,
9497
base_automaton_file: str,
@@ -102,7 +105,13 @@ def prune(
102105
base_automaton_file is None or len(base_automaton_file) == 0
103106
)
104107
grammar, base_expected_trees, enum_ntrees = __get_base_grammar__(
105-
has_base_grammar, dsl, evaluator, max_size, base_automaton_file, type_req
108+
has_base_grammar,
109+
dsl,
110+
evaluator,
111+
manager,
112+
max_size,
113+
base_automaton_file,
114+
type_req,
106115
)
107116
base_ntrees = sum(base_expected_trees.values())
108117

@@ -142,6 +151,8 @@ def estimate_total(size: int) -> tuple[int, float]:
142151
program = gen.send(should_keep)
143152
representative = evaluator.eval(program, type_req)
144153
should_keep = representative is None
154+
if not should_keep:
155+
manager.add_merge(program, representative)
145156
n += 1
146157
if n & 15 == 0:
147158
pbar.update(16)

0 commit comments

Comments
 (0)