22from grape .dsl import DSL
33from grape .evaluator import Evaluator
44from grape .program import Function , Primitive , Program , Variable
5+ from grape .pruning .equivalence_class_manager import EquivalenceClassManager
56import grape .types as types
67
78T = TypeVar ("T" )
@@ -26,15 +27,17 @@ def __produce_all_variants__(
2627 current [i ] += 1
2728
2829
29- def get_rewrites (
30- dsl : DSL , primitive : str , swapped_indices : tuple [int , int ]
31- ) -> list [tuple [Program , Program , str ]]:
32- constraints : list [tuple [Program , Program , str ]] = []
33- stype = dsl [primitive ][0 ]
30+ def __add_rewrite__ (
31+ dsl : DSL ,
32+ primitive : str ,
33+ swapped_indices : tuple [int , int ],
34+ manager : EquivalenceClassManager ,
35+ ):
36+ stype = dsl .primitives [primitive ][0 ]
3437 args_type , _ = types .parse (stype )
3538 swapped_type = args_type [swapped_indices [0 ]]
3639
37- nargs = len (types .arguments (dsl .primitives [primitive ]))
40+ nargs = len (types .arguments (dsl .primitives [primitive ][ 0 ] ))
3841
3942 for p1 , (stype1 , _ ) in dsl .primitives .copy ().items ():
4043 args1 , rtype1 = types .parse (stype1 )
@@ -73,7 +76,7 @@ def get_rewrites(
7376 )
7477 equiv_to .arguments [swapped_indices [0 ]] = second_arg
7578 equiv_to .arguments [swapped_indices [1 ]] = first_arg
76- constraints . append (( deleted , equiv_to ) )
79+ manager . add_merge ( deleted , equiv_to )
7780 # Add additional constraint for variable type
7881
7982 second_arg = Variable (swapped_indices [0 ])
@@ -84,14 +87,11 @@ def get_rewrites(
8487 equiv_to = Function (Primitive (primitive ), [Variable (i ) for i in range (nargs )])
8588 equiv_to .arguments [swapped_indices [0 ]] = second_arg
8689 equiv_to .arguments [swapped_indices [1 ]] = first_arg
87- constraints .append ((deleted , equiv_to ))
88-
89- return constraints
90+ manager .add_merge (deleted , equiv_to )
9091
9192
9293def prune (
93- dsl : DSL ,
94- evaluator : Evaluator ,
94+ dsl : DSL , evaluator : Evaluator , manager : EquivalenceClassManager
9595) -> list [tuple [str , list [int ]]]:
9696 commutatives = []
9797 for prim , (stype , _ ) in dsl .primitives .items ():
@@ -117,8 +117,10 @@ def prune(
117117 continue
118118 variant = Function (Primitive (prim ), new_args )
119119 if evaluator .eval (variant , stype ) is not None :
120+ swapped = [i for i , x in enumerate (new_args ) if x .no != i ]
120121 commutatives .append (
121122 (prim , [i for i , x in enumerate (new_args ) if x .no != i ])
122123 )
124+ __add_rewrite__ (dsl , prim , (swapped [0 ], swapped [1 ]), manager )
123125 evaluator .clean_memoisation ()
124126 return commutatives
0 commit comments