Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions Test/Passes/Combines/shift5_li_to_shift5i.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// RUN: veir-opt %s -p=riscv-combine | filecheck %s

"builtin.module"() ({
// riscv.sllw x (riscv.li imm) -> riscv.slliw x imm
"func.func"() <{function_type = (!riscv.reg) -> !riscv.reg}> ({
^bb(%a : !riscv.reg):
// CHECK: ^{{.*}}([[a:%.*]] : !riscv.reg):
%c = "riscv.li"() <{"value" = 1 : i64}>: () -> !riscv.reg
%r = "riscv.sllw"(%a, %c) : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK: [[r:%.*]] = "riscv.slliw"([[a]]) <{"value" = 1 : i64}> : (!riscv.reg) -> !riscv.reg
"func.return"(%r) : (!riscv.reg) -> ()
// CHECK: "func.return"([[r]]) : (!riscv.reg) -> ()
}) : () -> ()

// riscv.srlw x (riscv.li imm) -> riscv.srliw x imm
"func.func"() <{function_type = (!riscv.reg) -> !riscv.reg}> ({
^bb(%a : !riscv.reg):
// CHECK: ^{{.*}}([[a2:%.*]] : !riscv.reg):
%c = "riscv.li"() <{"value" = 7 : i64}>: () -> !riscv.reg
%r = "riscv.srlw"(%a, %c) : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK: [[r2:%.*]] = "riscv.srliw"([[a2]]) <{"value" = 7 : i64}> : (!riscv.reg) -> !riscv.reg
"func.return"(%r) : (!riscv.reg) -> ()
// CHECK: "func.return"([[r2]]) : (!riscv.reg) -> ()
}) : () -> ()

// riscv.sraw x (riscv.li imm) -> riscv.sraiw x imm (max in-range shamt)
"func.func"() <{function_type = (!riscv.reg) -> !riscv.reg}> ({
^bb(%a : !riscv.reg):
// CHECK: ^{{.*}}([[a3:%.*]] : !riscv.reg):
%c = "riscv.li"() <{"value" = 31 : i64}>: () -> !riscv.reg
%r = "riscv.sraw"(%a, %c) : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK: [[r3:%.*]] = "riscv.sraiw"([[a3]]) <{"value" = 31 : i64}> : (!riscv.reg) -> !riscv.reg
"func.return"(%r) : (!riscv.reg) -> ()
// CHECK: "func.return"([[r3]]) : (!riscv.reg) -> ()
}) : () -> ()

// riscv.rorw x (riscv.li imm) -> riscv.roriw x imm
"func.func"() <{function_type = (!riscv.reg) -> !riscv.reg}> ({
^bb(%a : !riscv.reg):
// CHECK: ^{{.*}}([[a4:%.*]] : !riscv.reg):
%c = "riscv.li"() <{"value" = 12 : i64}>: () -> !riscv.reg
%r = "riscv.rorw"(%a, %c) : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK: [[r4:%.*]] = "riscv.roriw"([[a4]]) <{"value" = 12 : i64}> : (!riscv.reg) -> !riscv.reg
"func.return"(%r) : (!riscv.reg) -> ()
// CHECK: "func.return"([[r4]]) : (!riscv.reg) -> ()
}) : () -> ()

// Non-commutative: li on the left is the shifted value, not the amount, so
// riscv.sllw (riscv.li imm) x must NOT fold.
"func.func"() <{function_type = (!riscv.reg) -> !riscv.reg}> ({
^bb(%a : !riscv.reg):
// CHECK: ^{{.*}}([[a5:%.*]] : !riscv.reg):
%c = "riscv.li"() <{"value" = 4 : i64}>: () -> !riscv.reg
%r = "riscv.sllw"(%c, %a) : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK: [[v:%.*]] = "riscv.li"() <{"value" = 4 : i64}> : () -> !riscv.reg
// CHECK: [[r5:%.*]] = "riscv.sllw"([[v]], [[a5]]) : (!riscv.reg, !riscv.reg) -> !riscv.reg
"func.return"(%r) : (!riscv.reg) -> ()
// CHECK: "func.return"([[r5]]) : (!riscv.reg) -> ()
}) : () -> ()

// Shift amount out of the unsigned 5-bit range (32): not folded.
"func.func"() <{function_type = (!riscv.reg) -> !riscv.reg}> ({
^bb(%a : !riscv.reg):
// CHECK: ^{{.*}}([[a6:%.*]] : !riscv.reg):
%c = "riscv.li"() <{"value" = 32 : i64}>: () -> !riscv.reg
%r = "riscv.srlw"(%a, %c) : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK: [[big:%.*]] = "riscv.li"() <{"value" = 32 : i64}> : () -> !riscv.reg
// CHECK: [[r6:%.*]] = "riscv.srlw"([[a6]], [[big]]) : (!riscv.reg, !riscv.reg) -> !riscv.reg
"func.return"(%r) : (!riscv.reg) -> ()
// CHECK: "func.return"([[r6]]) : (!riscv.reg) -> ()
}) : () -> ()

// Negative shift amount (-1): not folded.
"func.func"() <{function_type = (!riscv.reg) -> !riscv.reg}> ({
^bb(%a : !riscv.reg):
// CHECK: ^{{.*}}([[a7:%.*]] : !riscv.reg):
%c = "riscv.li"() <{"value" = -1 : i64}>: () -> !riscv.reg
%r = "riscv.sraw"(%a, %c) : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK: [[neg:%.*]] = "riscv.li"() <{"value" = -1 : i64}> : () -> !riscv.reg
// CHECK: [[r7:%.*]] = "riscv.sraw"([[a7]], [[neg]]) : (!riscv.reg, !riscv.reg) -> !riscv.reg
"func.return"(%r) : (!riscv.reg) -> ()
// CHECK: "func.return"([[r7]]) : (!riscv.reg) -> ()
}) : () -> ()
}) : () -> ()
116 changes: 116 additions & 0 deletions Test/Passes/Combines/shift6_li_to_shift6i.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// RUN: veir-opt %s -p=riscv-combine | filecheck %s

"builtin.module"() ({
// riscv.sll x (riscv.li imm) -> riscv.slli x imm
"func.func"() <{function_type = (!riscv.reg) -> !riscv.reg}> ({
^bb(%a : !riscv.reg):
// CHECK: ^{{.*}}([[a:%.*]] : !riscv.reg):
%c = "riscv.li"() <{"value" = 1 : i64}>: () -> !riscv.reg
%r = "riscv.sll"(%a, %c) : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK: [[r:%.*]] = "riscv.slli"([[a]]) <{"value" = 1 : i64}> : (!riscv.reg) -> !riscv.reg
"func.return"(%r) : (!riscv.reg) -> ()
// CHECK: "func.return"([[r]]) : (!riscv.reg) -> ()
}) : () -> ()

// riscv.srl x (riscv.li imm) -> riscv.srli x imm
"func.func"() <{function_type = (!riscv.reg) -> !riscv.reg}> ({
^bb(%a : !riscv.reg):
// CHECK: ^{{.*}}([[a2:%.*]] : !riscv.reg):
%c = "riscv.li"() <{"value" = 7 : i64}>: () -> !riscv.reg
%r = "riscv.srl"(%a, %c) : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK: [[r2:%.*]] = "riscv.srli"([[a2]]) <{"value" = 7 : i64}> : (!riscv.reg) -> !riscv.reg
"func.return"(%r) : (!riscv.reg) -> ()
// CHECK: "func.return"([[r2]]) : (!riscv.reg) -> ()
}) : () -> ()

// riscv.sra x (riscv.li imm) -> riscv.srai x imm (max in-range 6-bit shamt)
"func.func"() <{function_type = (!riscv.reg) -> !riscv.reg}> ({
^bb(%a : !riscv.reg):
// CHECK: ^{{.*}}([[a3:%.*]] : !riscv.reg):
%c = "riscv.li"() <{"value" = 63 : i64}>: () -> !riscv.reg
%r = "riscv.sra"(%a, %c) : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK: [[r3:%.*]] = "riscv.srai"([[a3]]) <{"value" = 63 : i64}> : (!riscv.reg) -> !riscv.reg
"func.return"(%r) : (!riscv.reg) -> ()
// CHECK: "func.return"([[r3]]) : (!riscv.reg) -> ()
}) : () -> ()

// riscv.ror x (riscv.li imm) -> riscv.rori x imm
"func.func"() <{function_type = (!riscv.reg) -> !riscv.reg}> ({
^bb(%a : !riscv.reg):
// CHECK: ^{{.*}}([[a4:%.*]] : !riscv.reg):
%c = "riscv.li"() <{"value" = 40 : i64}>: () -> !riscv.reg
%r = "riscv.ror"(%a, %c) : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK: [[r4:%.*]] = "riscv.rori"([[a4]]) <{"value" = 40 : i64}> : (!riscv.reg) -> !riscv.reg
"func.return"(%r) : (!riscv.reg) -> ()
// CHECK: "func.return"([[r4]]) : (!riscv.reg) -> ()
}) : () -> ()

// riscv.bclr x (riscv.li imm) -> riscv.bclri x imm
"func.func"() <{function_type = (!riscv.reg) -> !riscv.reg}> ({
^bb(%a : !riscv.reg):
// CHECK: ^{{.*}}([[a5:%.*]] : !riscv.reg):
%c = "riscv.li"() <{"value" = 5 : i64}>: () -> !riscv.reg
%r = "riscv.bclr"(%a, %c) : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK: [[r5:%.*]] = "riscv.bclri"([[a5]]) <{"value" = 5 : i64}> : (!riscv.reg) -> !riscv.reg
"func.return"(%r) : (!riscv.reg) -> ()
// CHECK: "func.return"([[r5]]) : (!riscv.reg) -> ()
}) : () -> ()

// riscv.bext x (riscv.li imm) -> riscv.bexti x imm
"func.func"() <{function_type = (!riscv.reg) -> !riscv.reg}> ({
^bb(%a : !riscv.reg):
// CHECK: ^{{.*}}([[a6:%.*]] : !riscv.reg):
%c = "riscv.li"() <{"value" = 17 : i64}>: () -> !riscv.reg
%r = "riscv.bext"(%a, %c) : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK: [[r6:%.*]] = "riscv.bexti"([[a6]]) <{"value" = 17 : i64}> : (!riscv.reg) -> !riscv.reg
"func.return"(%r) : (!riscv.reg) -> ()
// CHECK: "func.return"([[r6]]) : (!riscv.reg) -> ()
}) : () -> ()

// riscv.binv x (riscv.li imm) -> riscv.binvi x imm
"func.func"() <{function_type = (!riscv.reg) -> !riscv.reg}> ({
^bb(%a : !riscv.reg):
// CHECK: ^{{.*}}([[a7:%.*]] : !riscv.reg):
%c = "riscv.li"() <{"value" = 0 : i64}>: () -> !riscv.reg
%r = "riscv.binv"(%a, %c) : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK: [[r7:%.*]] = "riscv.binvi"([[a7]]) <{"value" = 0 : i64}> : (!riscv.reg) -> !riscv.reg
"func.return"(%r) : (!riscv.reg) -> ()
// CHECK: "func.return"([[r7]]) : (!riscv.reg) -> ()
}) : () -> ()

// riscv.bset x (riscv.li imm) -> riscv.bseti x imm
"func.func"() <{function_type = (!riscv.reg) -> !riscv.reg}> ({
^bb(%a : !riscv.reg):
// CHECK: ^{{.*}}([[a8:%.*]] : !riscv.reg):
%c = "riscv.li"() <{"value" = 62 : i64}>: () -> !riscv.reg
%r = "riscv.bset"(%a, %c) : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK: [[r8:%.*]] = "riscv.bseti"([[a8]]) <{"value" = 62 : i64}> : (!riscv.reg) -> !riscv.reg
"func.return"(%r) : (!riscv.reg) -> ()
// CHECK: "func.return"([[r8]]) : (!riscv.reg) -> ()
}) : () -> ()

// Non-commutative: li on the left is the shifted value, not the amount, so
// riscv.sll (riscv.li imm) x must NOT fold.
"func.func"() <{function_type = (!riscv.reg) -> !riscv.reg}> ({
^bb(%a : !riscv.reg):
// CHECK: ^{{.*}}([[a9:%.*]] : !riscv.reg):
%c = "riscv.li"() <{"value" = 3 : i64}>: () -> !riscv.reg
%r = "riscv.sll"(%c, %a) : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK: [[v:%.*]] = "riscv.li"() <{"value" = 3 : i64}> : () -> !riscv.reg
// CHECK: [[r9:%.*]] = "riscv.sll"([[v]], [[a9]]) : (!riscv.reg, !riscv.reg) -> !riscv.reg
"func.return"(%r) : (!riscv.reg) -> ()
// CHECK: "func.return"([[r9]]) : (!riscv.reg) -> ()
}) : () -> ()

// Shift amount out of the unsigned 6-bit range (64): not folded.
"func.func"() <{function_type = (!riscv.reg) -> !riscv.reg}> ({
^bb(%a : !riscv.reg):
// CHECK: ^{{.*}}([[a10:%.*]] : !riscv.reg):
%c = "riscv.li"() <{"value" = 64 : i64}>: () -> !riscv.reg
%r = "riscv.srl"(%a, %c) : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK: [[big:%.*]] = "riscv.li"() <{"value" = 64 : i64}> : () -> !riscv.reg
// CHECK: [[r10:%.*]] = "riscv.srl"([[a10]], [[big]]) : (!riscv.reg, !riscv.reg) -> !riscv.reg
"func.return"(%r) : (!riscv.reg) -> ()
// CHECK: "func.return"([[r10]]) : (!riscv.reg) -> ()
}) : () -> ()
}) : () -> ()
24 changes: 24 additions & 0 deletions Test/Passes/Combines/zextw_slli_to_slliuw.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: veir-opt %s -p=riscv-combine | filecheck %s

"builtin.module"() ({
// riscv.slli (riscv.zextw x) shamt -> riscv.slliuw x shamt
"func.func"() <{function_type = (!riscv.reg) -> !riscv.reg}> ({
^bb(%a : !riscv.reg):
// CHECK: ^{{.*}}([[a:%.*]] : !riscv.reg):
%z = "riscv.zextw"(%a) : (!riscv.reg) -> !riscv.reg
%r = "riscv.slli"(%z) <{"value" = 3 : i64}> : (!riscv.reg) -> !riscv.reg
// CHECK: [[r:%.*]] = "riscv.slliuw"([[a]]) <{"value" = 3 : i64}> : (!riscv.reg) -> !riscv.reg
"func.return"(%r) : (!riscv.reg) -> ()
// CHECK: "func.return"([[r]]) : (!riscv.reg) -> ()
}) : () -> ()

// A plain riscv.slli (operand not a zextw) must NOT fold.
"func.func"() <{function_type = (!riscv.reg) -> !riscv.reg}> ({
^bb(%a : !riscv.reg):
// CHECK: ^{{.*}}([[a2:%.*]] : !riscv.reg):
%r = "riscv.slli"(%a) <{"value" = 3 : i64}> : (!riscv.reg) -> !riscv.reg
// CHECK: [[r2:%.*]] = "riscv.slli"([[a2]]) <{"value" = 3 : i64}> : (!riscv.reg) -> !riscv.reg
"func.return"(%r) : (!riscv.reg) -> ()
// CHECK: "func.return"([[r2]]) : (!riscv.reg) -> ()
}) : () -> ()
}) : () -> ()
102 changes: 101 additions & 1 deletion Veir/Passes/Combines/Combine.lean
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,111 @@ def fold_add_li_to_addi (rewriter: PatternRewriter OpCode) (op: OperationPtr)
#[] #[] imm (some $ .before op) sorry (by simp) (by simp) sorry
rewriter.replaceOp op addiOp sorry sorry sorry sorry sorry

/--
Introduce a non-commutative immediate instruction, written as a parameterized
`LocalRewritePattern` so it can be connected to the
`LocalRewritePattern.PreservesSemantics` framework.

`riscv.src rs1 (riscv.li imm) -> riscv.dst rs1 imm`, only when the immediate
lies in `[lo, hi]`. Shifts/rotates/single-bit ops are not commutative, so the
immediate is only matched on the second operand. `dst`'s properties must be
`RISCVImmediateProperties` (proof `h`), used to carry the matched immediate over.

The `match hbinop : … with` binds the match equation so that
`matchRiscvBinop_reg_inBounds` discharges `createOp`'s operand-in-bounds
obligation — no `sorry`. The generic `RewritePattern.fromLocalRewrite` driver
performs the insertion, result replacement, and erasure.

The shift-amount/bit-index width is just the `[lo, hi]` range, so the imm5 and
imm6 families are the same rewrite at different bounds (`fold_shift5_li`,
`fold_shift6_li` below).
-/
def fold_binop_li (src dst : Riscv) (h : Riscv.propertiesOf dst = RISCVImmediateProperties)
(lo hi : Int) : LocalRewritePattern OpCode := fun ctx op =>
match hbinop : matchRiscvBinop src op ctx with
| none => some (ctx, none)
| some (reg, rhs) =>
match matchLi rhs ctx with
| none => some (ctx, none)
| some imm =>
if imm.value.value < lo || imm.value.value > hi then some (ctx, none)
else do
let (ctx, newOp) ← WfRewriter.createOp ctx (.riscv dst) #[RegisterType.mk] #[reg]
#[] #[] (cast h.symm imm) none
(by intro oper hmem
simp only [Array.mem_singleton] at hmem
subst hmem
exact matchRiscvBinop_reg_inBounds hbinop)
return (ctx, some (#[newOp], #[newOp.getResult 0]))

/-- imm5 (word shifts/rotates): `src rs1 (li imm) -> dst rs1 imm` for `imm ∈ [0,31]`.
Covers: sllw→slliw, srlw→srliw, sraw→sraiw, rorw→roriw. -/
def fold_shift5_li (src dst : Riscv) (h : Riscv.propertiesOf dst = RISCVImmediateProperties) :
LocalRewritePattern OpCode := fold_binop_li src dst h 0 31

/-- imm6 (full-width shifts/rotates and single-bit ops): `imm ∈ [0,63]`.
Covers: sll→slli, srl→srli, sra→srai, ror→rori, bclr→bclri, bext→bexti,
binv→binvi, bset→bseti. -/
def fold_shift6_li (src dst : Riscv) (h : Riscv.propertiesOf dst = RISCVImmediateProperties) :
LocalRewritePattern OpCode := fold_binop_li src dst h 0 63

def fold_sllw_li_to_slliw := fold_shift5_li .sllw .slliw rfl
def fold_srlw_li_to_srliw := fold_shift5_li .srlw .srliw rfl
def fold_sraw_li_to_sraiw := fold_shift5_li .sraw .sraiw rfl
def fold_rorw_li_to_roriw := fold_shift5_li .rorw .roriw rfl

def fold_sll_li_to_slli := fold_shift6_li .sll .slli rfl
def fold_srl_li_to_srli := fold_shift6_li .srl .srli rfl
def fold_sra_li_to_srai := fold_shift6_li .sra .srai rfl
def fold_ror_li_to_rori := fold_shift6_li .ror .rori rfl
def fold_bclr_li_to_bclri := fold_shift6_li .bclr .bclri rfl
def fold_bext_li_to_bexti := fold_shift6_li .bext .bexti rfl
def fold_binv_li_to_binvi := fold_shift6_li .binv .binvi rfl
def fold_bset_li_to_bseti := fold_shift6_li .bset .bseti rfl

/--
Contract a zero-extend-word feeding a shift-left-immediate into slli.uw:
riscv.slli (riscv.zextw x) shamt -> riscv.slliuw x shamt
Both compute `zeroExtend64(low32(x)) <<< shamt`, so this is an exact rewrite.
`slli` and `slliuw` share the same unsigned 6-bit shift-amount field, so the
immediate is carried over unchanged. Written as a `LocalRewritePattern`; the
`match hzext : … with` binds the match equation so that `matchZextw_inBounds`
discharges `createOp`'s operand obligation — no `sorry`.
-/
def fold_zextw_slli_to_slliuw : LocalRewritePattern OpCode := fun ctx op =>
match matchOp op ctx (.riscv .slli) 1 with
| none => some (ctx, none)
| some (operands, shamt) =>
match hzext : matchZextw operands[0]! ctx with
| none => some (ctx, none)
| some x => do
let (ctx, newOp) ← WfRewriter.createOp ctx (.riscv .slliuw) #[RegisterType.mk] #[x]
#[] #[] shamt none
(by intro oper hmem
simp only [Array.mem_singleton] at hmem
subst hmem
exact matchZextw_inBounds hzext)
return (ctx, some (#[newOp], #[newOp.getResult 0]))

/-! # Pass implementation -/

def Combine.impl (ctx : WfIRContext OpCode) (op : OperationPtr) (_ : op.InBounds ctx.raw) :
ExceptT String IO (WfIRContext OpCode) := do
let pattern := RewritePattern.GreedyRewritePattern #[right_identity_zero_add, fold_add_li_to_addi]
let pattern := RewritePattern.GreedyRewritePattern #[right_identity_zero_add,
fold_add_li_to_addi,
RewritePattern.fromLocalRewrite fold_sllw_li_to_slliw,
RewritePattern.fromLocalRewrite fold_srlw_li_to_srliw,
RewritePattern.fromLocalRewrite fold_sraw_li_to_sraiw,
RewritePattern.fromLocalRewrite fold_rorw_li_to_roriw,
RewritePattern.fromLocalRewrite fold_sll_li_to_slli,
RewritePattern.fromLocalRewrite fold_srl_li_to_srli,
RewritePattern.fromLocalRewrite fold_sra_li_to_srai,
RewritePattern.fromLocalRewrite fold_ror_li_to_rori,
RewritePattern.fromLocalRewrite fold_bclr_li_to_bclri,
RewritePattern.fromLocalRewrite fold_bext_li_to_bexti,
RewritePattern.fromLocalRewrite fold_binv_li_to_binvi,
RewritePattern.fromLocalRewrite fold_bset_li_to_bseti,
RewritePattern.fromLocalRewrite fold_zextw_slli_to_slliuw]
match RewritePattern.applyInContext pattern ctx with
| none => throw "Error while applying pattern rewrites"
| some ctx => pure ctx
Expand Down
Loading
Loading