-
Notifications
You must be signed in to change notification settings - Fork 63
[Feat] Complete BasisAttr support in IntTupleBuilder (#574) #605
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
54148ed
ae79fc1
7eab7d4
7255fff
72cf4e8
b3c06cf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -138,58 +138,83 @@ IntTupleAttr IntTupleBuilder<IntTupleAttr>::mul(IntTupleAttr lhs, IntTupleAttr r | |
| } | ||
| } | ||
| IntTupleAttr IntTupleBuilder<IntTupleAttr>::div(IntTupleAttr lhs, IntTupleAttr rhs) const { | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt()); | ||
| return IntTupleAttr::get(lhs.getLeafAsInt() / rhs.getLeafAsInt()); | ||
| // A basis divisor has no quotient mode (reachable for rank>=3 identity layouts | ||
| // via complement), so reject it rather than miscompute a stride. | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The rank>=3 path described here does not currently reach this named assert. A direct user-level repro is
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch — dug in, and it's a bit different from what my comment implied, so I split it (7eab7d4):
The recoverable op-layer diagnostic (so a bad tiler surfaces as an |
||
| assert(lhs.isLeaf() && rhs.isLeafInt() && | ||
| "div is undefined for a basis divisor; lower the identity layout first"); | ||
| if (lhs.isLeafInt()) { | ||
| return IntTupleAttr::get(lhs.getLeafAsInt() / rhs.getLeafAsInt()); | ||
| } | ||
| return IntTupleAttr::get(intSafeDiv(lhs.getLeafAsBasis(), rhs.getLeafAsInt())); | ||
| } | ||
| IntTupleAttr IntTupleBuilder<IntTupleAttr>::mod(IntTupleAttr lhs, IntTupleAttr rhs) const { | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt()); | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt() && "mod is defined only on integer leaves"); | ||
| return IntTupleAttr::get(lhs.getLeafAsInt() % rhs.getLeafAsInt()); | ||
| } | ||
|
|
||
| IntTupleAttr IntTupleBuilder<IntTupleAttr>::logicalAnd(IntTupleAttr lhs, IntTupleAttr rhs) const { | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt()); | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt() && "logicalAnd requires integer leaves"); | ||
| return IntTupleAttr::get(lhs.getLeafAsInt() && rhs.getLeafAsInt()); | ||
| } | ||
| IntTupleAttr IntTupleBuilder<IntTupleAttr>::logicalOr(IntTupleAttr lhs, IntTupleAttr rhs) const { | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt()); | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt() && "logicalOr requires integer leaves"); | ||
| return IntTupleAttr::get(lhs.getLeafAsInt() || rhs.getLeafAsInt()); | ||
| } | ||
| IntTupleAttr IntTupleBuilder<IntTupleAttr>::logicalNot(IntTupleAttr val) const { | ||
| assert(val.isLeafInt()); | ||
| assert(val.isLeafInt() && "logicalNot requires an integer leaf"); | ||
| return IntTupleAttr::get(!val.getLeafAsInt()); | ||
| } | ||
|
|
||
| IntTupleAttr IntTupleBuilder<IntTupleAttr>::lt(IntTupleAttr lhs, IntTupleAttr rhs) const { | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt()); | ||
| // Ordering is defined only on integer leaves: basis directions are free-module | ||
| // generators with no total order, and strides are ordered via getStaticValue. | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt() && | ||
| "lt is defined only on integer leaves; basis directions have no order"); | ||
| return IntTupleAttr::get(lhs.getLeafAsInt() < rhs.getLeafAsInt()); | ||
| } | ||
| IntTupleAttr IntTupleBuilder<IntTupleAttr>::le(IntTupleAttr lhs, IntTupleAttr rhs) const { | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt()); | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt() && | ||
| "le is defined only on integer leaves; basis directions have no order"); | ||
| return IntTupleAttr::get(lhs.getLeafAsInt() <= rhs.getLeafAsInt()); | ||
| } | ||
| IntTupleAttr IntTupleBuilder<IntTupleAttr>::gt(IntTupleAttr lhs, IntTupleAttr rhs) const { | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt()); | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt() && | ||
| "gt is defined only on integer leaves; basis directions have no order"); | ||
| return IntTupleAttr::get(lhs.getLeafAsInt() > rhs.getLeafAsInt()); | ||
| } | ||
| IntTupleAttr IntTupleBuilder<IntTupleAttr>::ge(IntTupleAttr lhs, IntTupleAttr rhs) const { | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt()); | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt() && | ||
| "ge is defined only on integer leaves; basis directions have no order"); | ||
| return IntTupleAttr::get(lhs.getLeafAsInt() >= rhs.getLeafAsInt()); | ||
| } | ||
| IntTupleAttr IntTupleBuilder<IntTupleAttr>::eq(IntTupleAttr lhs, IntTupleAttr rhs) const { | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt()); | ||
| return IntTupleAttr::get(lhs.getLeafAsInt() == rhs.getLeafAsInt()); | ||
| assert(lhs.isLeaf() && rhs.isLeaf()); | ||
| if (lhs.isLeafInt() && rhs.isLeafInt()) { | ||
| return IntTupleAttr::get(lhs.getLeafAsInt() == rhs.getLeafAsInt()); | ||
| } | ||
| if (lhs.isLeafBasis() && rhs.isLeafBasis()) { | ||
| return IntTupleAttr::get(lhs.getLeafAsBasis() == rhs.getLeafAsBasis()); | ||
| } | ||
| // A scalar leaf and a basis monomial never coincide. | ||
| return materializeConstantLeaf(0); | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Switched both the |
||
| } | ||
| IntTupleAttr IntTupleBuilder<IntTupleAttr>::ne(IntTupleAttr lhs, IntTupleAttr rhs) const { | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt()); | ||
| return IntTupleAttr::get(lhs.getLeafAsInt() != rhs.getLeafAsInt()); | ||
| assert(lhs.isLeaf() && rhs.isLeaf()); | ||
| if (lhs.isLeafInt() && rhs.isLeafInt()) { | ||
| return IntTupleAttr::get(lhs.getLeafAsInt() != rhs.getLeafAsInt()); | ||
| } | ||
| if (lhs.isLeafBasis() && rhs.isLeafBasis()) { | ||
| return IntTupleAttr::get(lhs.getLeafAsBasis() != rhs.getLeafAsBasis()); | ||
| } | ||
| return materializeConstantLeaf(1); | ||
| } | ||
|
|
||
| IntTupleAttr IntTupleBuilder<IntTupleAttr>::min(IntTupleAttr lhs, IntTupleAttr rhs) const { | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt()); | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt() && "min is defined only on integer leaves"); | ||
| return IntTupleAttr::get(intMin(lhs.getLeafAsInt(), rhs.getLeafAsInt())); | ||
| } | ||
| IntTupleAttr IntTupleBuilder<IntTupleAttr>::max(IntTupleAttr lhs, IntTupleAttr rhs) const { | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt()); | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt() && "max is defined only on integer leaves"); | ||
| return IntTupleAttr::get(intMax(lhs.getLeafAsInt(), rhs.getLeafAsInt())); | ||
| } | ||
| IntTupleAttr IntTupleBuilder<IntTupleAttr>::safeDiv(IntTupleAttr lhs, IntTupleAttr rhs) const { | ||
|
|
@@ -209,13 +234,14 @@ IntTupleAttr IntTupleBuilder<IntTupleAttr>::ceilDiv(IntTupleAttr lhs, IntTupleAt | |
| } | ||
| } | ||
| IntTupleAttr IntTupleBuilder<IntTupleAttr>::shapeDiv(IntTupleAttr lhs, IntTupleAttr rhs) const { | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt()); | ||
| assert(lhs.isLeafInt() && rhs.isLeafInt() && "shapeDiv requires integer leaves"); | ||
| return IntTupleAttr::get(intShapeDiv(lhs.getLeafAsInt(), rhs.getLeafAsInt())); | ||
| } | ||
|
|
||
| IntTupleAttr IntTupleBuilder<IntTupleAttr>::applySwizzle(IntTupleAttr v, | ||
| SwizzleAttr swizzle) const { | ||
| assert(v.isLeafInt() && "applySwizzle only supports leafInt IntTupleAttr"); | ||
| assert(v.isLeafInt() && | ||
| "applySwizzle is undefined on a basis leaf; lower the identity layout first"); | ||
| return IntTupleAttr::get(intApplySwizzle(v.getLeafAsInt(), swizzle)); | ||
| } | ||
|
|
||
|
|
@@ -226,8 +252,10 @@ IntTupleAttr IntTupleBuilder<IntTupleAttr>::applyCoordSwizzle(IntTupleAttr coord | |
| } | ||
| IntTupleAttr row = selectPath(*this, coord, swizzle.getModeRow()); | ||
| IntTupleAttr col = selectPath(*this, coord, swizzle.getModeCol()); | ||
| assert(row.isLeafInt() && "coord swizzle row mode must select a leaf int"); | ||
| assert(col.isLeafInt() && "coord swizzle col mode must select a leaf int"); | ||
| assert(row.isLeafInt() && | ||
| "coord swizzle row mode must select an integer leaf (basis unsupported)"); | ||
| assert(col.isLeafInt() && | ||
| "coord swizzle col mode must select an integer leaf (basis unsupported)"); | ||
|
|
||
| IntTupleAttr newCol = | ||
| IntTupleAttr::get(intApplyCoordSwizzle(row.getLeafAsInt(), col.getLeafAsInt(), swizzle)); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -89,6 +89,8 @@ | |
| "make_int_tuple", | ||
| "make_shape", | ||
| "make_stride", | ||
| "E", | ||
| "make_basis_stride", | ||
| "make_coord", | ||
| "make_layout", | ||
| "make_layout_like", | ||
|
|
@@ -356,6 +358,44 @@ def make_stride(*stride, loc=None, ip=None): | |
| return fly.make_stride(IntTupleTy, dyncElems, loc=loc, ip=ip) | ||
|
|
||
|
|
||
| class _BasisElem: | ||
| """A scaled-basis stride leaf (CuTe ScaledBasis / E<I>), e.g. ``fx.E(0)`` -> ``1E0``. | ||
|
|
||
| Duck-typed by the C++ int-tuple builder via the ``__fly_basis__`` marker. It is a | ||
| compile-time-only value (no runtime ``ir.Value`` operand), accepted anywhere a | ||
| stride entry is, e.g. ``make_stride(E(0), E(1))``. | ||
| """ | ||
|
|
||
| __fly_basis__ = True | ||
|
|
||
| def __init__(self, value, modes): | ||
| self.value = value | ||
| self.modes = modes | ||
|
|
||
|
|
||
| def E(*modes, value=1): | ||
| """Build a scaled-basis stride leaf. | ||
|
|
||
| Examples: | ||
| E(0) -> 1E0 | ||
| E(1) -> 1E1 | ||
| E(0, value=2) -> 2E0 | ||
| E(0, 1) -> 1E0E1 (a single leaf along modes 0 and 1) | ||
| """ | ||
| if not modes: | ||
| raise ValueError("E requires at least one mode") | ||
| return _BasisElem(value, list(modes)) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done in 7eab7d4.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Confirmed bug:
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed in 7eab7d4. |
||
|
|
||
|
|
||
| def make_basis_stride(value, modes): | ||
| """Build a flat basis stride, one ``E`` leaf per mode. | ||
|
|
||
| ``make_basis_stride(1, (0, 1))`` -> ``(1E0, 1E1)``, the stride | ||
| ``make_identity_layout`` emits for that rank. | ||
| """ | ||
| return make_stride(*[E(m, value=value) for m in modes]) | ||
|
|
||
|
|
||
| @traced_op | ||
| def make_coord(*coord, loc=None, ip=None): | ||
| """Build a coordinate used for indexing / slicing a layout. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // Copyright (c) 2025 FlyDSL Project Contributors | ||
| // RUN: %fly-opt %s | FileCheck %s | ||
|
|
||
| // Tests for BasisAttr (scaled-basis / CuTe E<I>) stride leaves flowing through | ||
| // the IntTupleBuilder ops: division of a basis coefficient, and a full identity | ||
| // layout (1E0, 1E1, ...) through logical_divide. The inferred result types below | ||
| // are checked by fly-opt's type inference (a wrong type fails to parse). | ||
|
|
||
| // ----- | ||
|
|
||
| // div divides the scalar coefficient of each basis leaf, keeping its modes: | ||
| // (2E0, 8E1) / (2, 4) = (1E0, 2E1) | ||
| // CHECK-LABEL: @test_int_tuple_div_basis | ||
| func.func @test_int_tuple_div_basis() -> !fly.int_tuple<(1E0, 2E1)> { | ||
| %a = fly.static : !fly.int_tuple<(2E0, 8E1)> | ||
| %b = fly.static : !fly.int_tuple<(2, 4)> | ||
| // CHECK: fly.int_tuple_div(%{{.*}}, %{{.*}}) | ||
| %result = fly.int_tuple_div(%a, %b) : (!fly.int_tuple<(2E0, 8E1)>, !fly.int_tuple<(2, 4)>) -> !fly.int_tuple<(1E0, 2E1)> | ||
| return %result : !fly.int_tuple<(1E0, 2E1)> | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // logical_divide partitions a basis-strided identity layout; the algebra walks the | ||
| // basis strides (via complement/div) instead of asserting on a non-int leaf. | ||
| // CHECK-LABEL: @test_logical_divide_identity | ||
| func.func @test_logical_divide_identity() -> !fly.layout<((2, (2, 2)), 4) : ((1E0, (2E0, 1E1)), 2E1)> { | ||
| %s = fly.static : !fly.int_tuple<(4, 8)> | ||
| %id = fly.make_identity_layout(%s) : (!fly.int_tuple<(4, 8)>) -> !fly.layout<(4, 8) : (1E0, 1E1)> | ||
| %ds = fly.static : !fly.int_tuple<(2, 4)> | ||
| %dd = fly.static : !fly.int_tuple<(1, 2)> | ||
| %div = fly.make_layout(%ds, %dd) : (!fly.int_tuple<(2, 4)>, !fly.int_tuple<(1, 2)>) -> !fly.layout<(2, 4) : (1, 2)> | ||
| // CHECK: fly.logical_divide(%{{.*}}, %{{.*}}) | ||
| %result = fly.logical_divide(%id, %div) : (!fly.layout<(4, 8) : (1E0, 1E1)>, !fly.layout<(2, 4) : (1, 2)>) -> !fly.layout<((2, (2, 2)), 4) : ((1E0, (2E0, 1E1)), 2E1)> | ||
| return %result : !fly.layout<((2, (2, 2)), 4) : ((1E0, (2E0, 1E1)), 2E1)> | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -427,6 +427,41 @@ def build(): | |
| _build_and_verify("zipped_tiled_flat_product", build, [360]) | ||
|
|
||
|
|
||
| # ============================================================================== | ||
| # Basis strides (scaled-basis / CuTe E<I>) via fx.E | ||
| # ============================================================================== | ||
|
|
||
|
|
||
| def test_basis_stride_via_E(): | ||
| """fx.E builds the same basis-strided layout as make_identity_layout.""" | ||
|
|
||
| def build(): | ||
| fx.make_layout(fx.make_shape(4, 8), fx.make_stride(fx.E(0), fx.E(1))) | ||
| fx.make_identity_layout(fx.make_shape(4, 8)) | ||
| fx.make_basis_stride(1, (0, 1)) | ||
| fx.make_stride(fx.E(0, 1, value=2)) | ||
|
|
||
| def check(ir): | ||
| # fx.E(0), fx.E(1) produce a (1E0, 1E1) stride, identical to make_identity_layout. | ||
| assert "!fly.layout<(4,8):(1E0,1E1)>" in ir | ||
| # make_basis_stride(1, (0, 1)) yields the same flat basis stride. | ||
| assert "!fly.int_tuple<(1E0,1E1)>" in ir | ||
| # value and multi-mode forms: fx.E(0, 1, value=2) -> 2E0E1. | ||
| assert "!fly.int_tuple<(2E0E1)>" in ir | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done in 7eab7d4 — |
||
|
|
||
| _build_and_verify_ir("basis_stride_via_E", build, check) | ||
|
|
||
|
|
||
| def test_basis_identity_size(): | ||
| """A basis-strided identity layout lowers through the pipeline; size = 4*8 = 32.""" | ||
|
|
||
| def build(): | ||
| layout = fx.make_layout(fx.make_shape(4, 8), fx.make_stride(fx.E(0), fx.E(1))) | ||
| return [fx.size(layout)] | ||
|
|
||
| _build_and_verify("basis_identity_size", build, [32]) | ||
|
|
||
|
|
||
| # ============================================================================== | ||
| # Main | ||
| # ============================================================================== | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Both fixed in 7eab7d4. The branch now gates on a truthy
__fly_basis__viaPyObject_IsTrue, so an object that sets it toFalseis no longer mistaken for a basis leaf — and a truthy non-bool marker still works (a plainnb::cast<bool>throws on anything that isn't thePy_True/Py_Falsesingleton). Modes are read by iterating the object directly, so any sequence/iterable (tuple, list, …) is accepted rather than onlynb::list.