Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/layout_system_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ row_major = fx.make_ordered_layout((M, N), order=(1, 0)) # stride order: N-firs
# Identity layout / tensor
identity = fx.make_identity_layout((M, N))
id_tensor = fx.make_identity_tensor((M, N))

# Scaled-basis strides (CuTe E<I>): fx.E(mode, *, value=1)
# fx.E(0) -> 1E0, fx.E(0, value=2) -> 2E0, fx.E(0, 1) -> 1E0E1
basis_layout = fx.make_layout((4, 8), fx.make_stride(fx.E(0), fx.E(1))) # same as make_identity_layout((4, 8))
flat_basis = fx.make_basis_stride(1, (0, 1)) # !fly.int_tuple<(1E0, 1E1)>
```

---
Expand Down
8 changes: 8 additions & 0 deletions lib/Bindings/Python/FlyExtension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ struct IntTupleAttrBuilder {
return IntTupleAttr::get(IntAttr::getStatic(ctx, cInt));
} else if (args.is_none()) {
return IntTupleAttr::getLeafNone(ctx);
} else if (nb::hasattr(args, "__fly_basis__")) {
// Scaled-basis stride leaf (e.g. fx.E(0) -> 1E0), duck-typed on __fly_basis__.
int32_t value = nb::cast<int32_t>(args.attr("value"));
SmallVector<int32_t> modes;
for (auto mode : nb::cast<nb::list>(args.attr("modes"))) {
modes.push_back(nb::cast<int32_t>(mode));
}
return IntTupleAttr::get(BasisAttr::get(IntAttr::getStatic(ctx, value), modes));

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both fixed in 7eab7d4. The branch now gates on a truthy __fly_basis__ via PyObject_IsTrue, so an object that sets it to False is no longer mistaken for a basis leaf — and a truthy non-bool marker still works (a plain nb::cast<bool> throws on anything that isn't the Py_True/Py_False singleton). Modes are read by iterating the object directly, so any sequence/iterable (tuple, list, …) is accepted rather than only nb::list.

} else {
if (!nb::hasattr(args, "_CAPIPtr")) {
throw std::invalid_argument("Expected I32, got: " +
Expand Down
68 changes: 48 additions & 20 deletions lib/Dialect/Fly/Utils/IntTupleUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,58 +138,83 @@ IntTupleAttr IntTupleBuilder<IntTupleAttr>::mul(IntTupleAttr lhs, IntTupleAttr r
}
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::div(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
return IntTupleAttr::get(lhs.getLeafAsInt() / rhs.getLeafAsInt());
// A basis divisor has no quotient mode (reachable for rank>=3 identity layouts
// via complement), so reject it rather than miscompute a stride.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rank>=3 path described here does not currently reach this named assert. A direct user-level repro is fx.logical_divide(fx.make_identity_layout((4, 8, 2)), fx.make_layout((2, 4, 2), fx.make_stride(1, 2, 4))); it SIGFPEs during LogicalDivideOp::inferReturnTypes inside compositionImpl before producing the promised div(Basis, Basis) rejection. If rank-3 identity divide is intentionally unsupported, please add a regression for this case and guard it before the algebra can hit the integer %/division-by-zero path.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — dug in, and it's a bit different from what my comment implied, so I split it (7eab7d4):

  1. The div() comment was imprecise. div(Basis, Basis) is reachable, but via complement() of a basis-strided identity layout of rank ≥ 2 — complement(make_identity_layout((4,8))) and (4,8,2) both hit the named assert. It is not reached through logical_divide. Reworded the comment to say so.

  2. The SIGFPE is a separate, pre-existing bug and isn't basis-specific. (2,4,2):(1,2,4) is a non-tiling (overlapping) divisor: its complement has a 0-extent mode, and compositionImpl then does 0 % 0. The same divisor SIGFPEs on a non-coalescible integer layout (4,8,2):(1,100,7) too — the identity case only reaches it because basis strides never coalesce to rank-1, whereas a compact integer layout like (1,4,32) coalesces away and dodges the loop. A rank-3 identity divide with a valid tiler, e.g. (2,4,2):(1,2,8), works fine.

  3. Guarded it. compositionImpl now asserts newShapeVal != 0 with a named message before the %, so a non-tiling divisor aborts with "divisor is not a tiling layout" instead of a SIGFPE. Added the rank-3 valid-divisor case to basis.mlir as a positive regression.

The recoverable op-layer diagnostic (so a bad tiler surfaces as an expected-error rather than an assert) needs a Location, which IntTupleBuilder doesn't carry — that's the #583 op-gate work, consistent with this PR's deferral of located diagnostics.

assert(lhs.isLeaf() && rhs.isLeafInt() &&
"div is undefined for a basis divisor; lower the identity layout first");
if (lhs.isLeafInt()) {
return IntTupleAttr::get(lhs.getLeafAsInt() / rhs.getLeafAsInt());
}
return IntTupleAttr::get(intSafeDiv(lhs.getLeafAsBasis(), rhs.getLeafAsInt()));
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::mod(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
assert(lhs.isLeafInt() && rhs.isLeafInt() && "mod is defined only on integer leaves");
return IntTupleAttr::get(lhs.getLeafAsInt() % rhs.getLeafAsInt());
}

IntTupleAttr IntTupleBuilder<IntTupleAttr>::logicalAnd(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
assert(lhs.isLeafInt() && rhs.isLeafInt() && "logicalAnd requires integer leaves");
return IntTupleAttr::get(lhs.getLeafAsInt() && rhs.getLeafAsInt());
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::logicalOr(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
assert(lhs.isLeafInt() && rhs.isLeafInt() && "logicalOr requires integer leaves");
return IntTupleAttr::get(lhs.getLeafAsInt() || rhs.getLeafAsInt());
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::logicalNot(IntTupleAttr val) const {
assert(val.isLeafInt());
assert(val.isLeafInt() && "logicalNot requires an integer leaf");
return IntTupleAttr::get(!val.getLeafAsInt());
}

IntTupleAttr IntTupleBuilder<IntTupleAttr>::lt(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
// Ordering is defined only on integer leaves: basis directions are free-module
// generators with no total order, and strides are ordered via getStaticValue.
assert(lhs.isLeafInt() && rhs.isLeafInt() &&
"lt is defined only on integer leaves; basis directions have no order");
return IntTupleAttr::get(lhs.getLeafAsInt() < rhs.getLeafAsInt());
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::le(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
assert(lhs.isLeafInt() && rhs.isLeafInt() &&
"le is defined only on integer leaves; basis directions have no order");
return IntTupleAttr::get(lhs.getLeafAsInt() <= rhs.getLeafAsInt());
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::gt(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
assert(lhs.isLeafInt() && rhs.isLeafInt() &&
"gt is defined only on integer leaves; basis directions have no order");
return IntTupleAttr::get(lhs.getLeafAsInt() > rhs.getLeafAsInt());
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::ge(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
assert(lhs.isLeafInt() && rhs.isLeafInt() &&
"ge is defined only on integer leaves; basis directions have no order");
return IntTupleAttr::get(lhs.getLeafAsInt() >= rhs.getLeafAsInt());
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::eq(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
return IntTupleAttr::get(lhs.getLeafAsInt() == rhs.getLeafAsInt());
assert(lhs.isLeaf() && rhs.isLeaf());
if (lhs.isLeafInt() && rhs.isLeafInt()) {
return IntTupleAttr::get(lhs.getLeafAsInt() == rhs.getLeafAsInt());
}
if (lhs.isLeafBasis() && rhs.isLeafBasis()) {
return IntTupleAttr::get(lhs.getLeafAsBasis() == rhs.getLeafAsBasis());
}
// A scalar leaf and a basis monomial never coincide.
return materializeConstantLeaf(0);

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switched both the eq and ne mismatch branches to IntTupleAttr::getLeafStatic(ctx, 0/1) in 7eab7d4 — an explicit static leaf, consistent with the sibling branches and free of the builder method. It's the same attribute materializeConstantLeaf(0/1) produced, just constructed directly.

}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::ne(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
return IntTupleAttr::get(lhs.getLeafAsInt() != rhs.getLeafAsInt());
assert(lhs.isLeaf() && rhs.isLeaf());
if (lhs.isLeafInt() && rhs.isLeafInt()) {
return IntTupleAttr::get(lhs.getLeafAsInt() != rhs.getLeafAsInt());
}
if (lhs.isLeafBasis() && rhs.isLeafBasis()) {
return IntTupleAttr::get(lhs.getLeafAsBasis() != rhs.getLeafAsBasis());
}
return materializeConstantLeaf(1);
}

IntTupleAttr IntTupleBuilder<IntTupleAttr>::min(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
assert(lhs.isLeafInt() && rhs.isLeafInt() && "min is defined only on integer leaves");
return IntTupleAttr::get(intMin(lhs.getLeafAsInt(), rhs.getLeafAsInt()));
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::max(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
assert(lhs.isLeafInt() && rhs.isLeafInt() && "max is defined only on integer leaves");
return IntTupleAttr::get(intMax(lhs.getLeafAsInt(), rhs.getLeafAsInt()));
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::safeDiv(IntTupleAttr lhs, IntTupleAttr rhs) const {
Expand All @@ -209,13 +234,14 @@ IntTupleAttr IntTupleBuilder<IntTupleAttr>::ceilDiv(IntTupleAttr lhs, IntTupleAt
}
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::shapeDiv(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
assert(lhs.isLeafInt() && rhs.isLeafInt() && "shapeDiv requires integer leaves");
return IntTupleAttr::get(intShapeDiv(lhs.getLeafAsInt(), rhs.getLeafAsInt()));
}

IntTupleAttr IntTupleBuilder<IntTupleAttr>::applySwizzle(IntTupleAttr v,
SwizzleAttr swizzle) const {
assert(v.isLeafInt() && "applySwizzle only supports leafInt IntTupleAttr");
assert(v.isLeafInt() &&
"applySwizzle is undefined on a basis leaf; lower the identity layout first");
return IntTupleAttr::get(intApplySwizzle(v.getLeafAsInt(), swizzle));
}

Expand All @@ -226,8 +252,10 @@ IntTupleAttr IntTupleBuilder<IntTupleAttr>::applyCoordSwizzle(IntTupleAttr coord
}
IntTupleAttr row = selectPath(*this, coord, swizzle.getModeRow());
IntTupleAttr col = selectPath(*this, coord, swizzle.getModeCol());
assert(row.isLeafInt() && "coord swizzle row mode must select a leaf int");
assert(col.isLeafInt() && "coord swizzle col mode must select a leaf int");
assert(row.isLeafInt() &&
"coord swizzle row mode must select an integer leaf (basis unsupported)");
assert(col.isLeafInt() &&
"coord swizzle col mode must select an integer leaf (basis unsupported)");

IntTupleAttr newCol =
IntTupleAttr::get(intApplyCoordSwizzle(row.getLeafAsInt(), col.getLeafAsInt(), swizzle));
Expand Down
40 changes: 40 additions & 0 deletions python/flydsl/expr/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@
"make_int_tuple",
"make_shape",
"make_stride",
"E",
"make_basis_stride",
"make_coord",
"make_layout",
"make_layout_like",
Expand Down Expand Up @@ -356,6 +358,44 @@ def make_stride(*stride, loc=None, ip=None):
return fly.make_stride(IntTupleTy, dyncElems, loc=loc, ip=ip)


class _BasisElem:
"""A scaled-basis stride leaf (CuTe ScaledBasis / E<I>), e.g. ``fx.E(0)`` -> ``1E0``.

Duck-typed by the C++ int-tuple builder via the ``__fly_basis__`` marker. It is a
compile-time-only value (no runtime ``ir.Value`` operand), accepted anywhere a
stride entry is, e.g. ``make_stride(E(0), E(1))``.
"""

__fly_basis__ = True

def __init__(self, value, modes):
self.value = value
self.modes = modes


def E(*modes, value=1):
"""Build a scaled-basis stride leaf.

Examples:
E(0) -> 1E0
E(1) -> 1E1
E(0, value=2) -> 2E0
E(0, 1) -> 1E0E1 (a single leaf along modes 0 and 1)
"""
if not modes:
raise ValueError("E requires at least one mode")
return _BasisElem(value, list(modes))

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 7eab7d4. E now validates in Python via _as_int32: value and every mode must be integer-like (Python int or a NumPy integer scalar, via operator.index) and within int32 range, raising a clear TypeError/ValueError instead of deferring to the nanobind cast. Covered by the new test_basis_E_rejects_invalid_modes.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confirmed bug: E accepts negative modes, but the BasisAttr assembly parser cannot round-trip them. fx.make_stride(fx.E(-1)) prints !fly.int_tuple<(1E-1)>, and feeding that IR back to fly-opt fails to parse at E-1 (expected ')'). Please validate modes here as non-negative int32 values before constructing _BasisElem (and mirror that in the binding if the duck-typed __fly_basis__ hook is meant to be public).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 7eab7d4. E now rejects negative modes (_as_int32(..., nonneg=True)), and the public __fly_basis__ binding hook mirrors it — a negative mode there throws a clear error before constructing the BasisAttr. value stays sign-unrestricted since -2E0/0E0 round-trip fine; only the mode index breaks the assembly format. Regression added in test_basis_E_rejects_invalid_modes.



def make_basis_stride(value, modes):
"""Build a flat basis stride, one ``E`` leaf per mode.

``make_basis_stride(1, (0, 1))`` -> ``(1E0, 1E1)``, the stride
``make_identity_layout`` emits for that rank.
"""
return make_stride(*[E(m, value=value) for m in modes])


@traced_op
def make_coord(*coord, loc=None, ip=None):
"""Build a coordinate used for indexing / slicing a layout.
Expand Down
37 changes: 37 additions & 0 deletions tests/mlir/LayoutAlgebra/basis.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2025 FlyDSL Project Contributors
// RUN: %fly-opt %s | FileCheck %s

// Tests for BasisAttr (scaled-basis / CuTe E<I>) stride leaves flowing through
// the IntTupleBuilder ops: division of a basis coefficient, and a full identity
// layout (1E0, 1E1, ...) through logical_divide. The inferred result types below
// are checked by fly-opt's type inference (a wrong type fails to parse).

// -----

// div divides the scalar coefficient of each basis leaf, keeping its modes:
// (2E0, 8E1) / (2, 4) = (1E0, 2E1)
// CHECK-LABEL: @test_int_tuple_div_basis
func.func @test_int_tuple_div_basis() -> !fly.int_tuple<(1E0, 2E1)> {
%a = fly.static : !fly.int_tuple<(2E0, 8E1)>
%b = fly.static : !fly.int_tuple<(2, 4)>
// CHECK: fly.int_tuple_div(%{{.*}}, %{{.*}})
%result = fly.int_tuple_div(%a, %b) : (!fly.int_tuple<(2E0, 8E1)>, !fly.int_tuple<(2, 4)>) -> !fly.int_tuple<(1E0, 2E1)>
return %result : !fly.int_tuple<(1E0, 2E1)>
}

// -----

// logical_divide partitions a basis-strided identity layout; the algebra walks the
// basis strides (via complement/div) instead of asserting on a non-int leaf.
// CHECK-LABEL: @test_logical_divide_identity
func.func @test_logical_divide_identity() -> !fly.layout<((2, (2, 2)), 4) : ((1E0, (2E0, 1E1)), 2E1)> {
%s = fly.static : !fly.int_tuple<(4, 8)>
%id = fly.make_identity_layout(%s) : (!fly.int_tuple<(4, 8)>) -> !fly.layout<(4, 8) : (1E0, 1E1)>
%ds = fly.static : !fly.int_tuple<(2, 4)>
%dd = fly.static : !fly.int_tuple<(1, 2)>
%div = fly.make_layout(%ds, %dd) : (!fly.int_tuple<(2, 4)>, !fly.int_tuple<(1, 2)>) -> !fly.layout<(2, 4) : (1, 2)>
// CHECK: fly.logical_divide(%{{.*}}, %{{.*}})
%result = fly.logical_divide(%id, %div) : (!fly.layout<(4, 8) : (1E0, 1E1)>, !fly.layout<(2, 4) : (1, 2)>) -> !fly.layout<((2, (2, 2)), 4) : ((1E0, (2E0, 1E1)), 2E1)>
return %result : !fly.layout<((2, (2, 2)), 4) : ((1E0, (2E0, 1E1)), 2E1)>
}
40 changes: 40 additions & 0 deletions tests/mlir/Transforms/layout_lowering.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,43 @@ func.func @test_get_leaves_dynamic_only(%x: i32, %y: i64) -> (i32, i64) {
%0:2 = fly.get_leaves(%t) {dynamicOnly = true} : (!fly.int_tuple<(4, ?, ?{i64})>) -> (i32, i64)
return %0#0, %0#1 : i32, i64
}

// -----

// === EqualOp Lowering: basis (scaled-basis / E<I>) leaves ===

// equal on identical basis strides folds to true.
// CHECK-LABEL: @test_equal_basis_same
func.func @test_equal_basis_same() -> i1 {
%a = fly.make_int_tuple() : () -> !fly.int_tuple<(1E0, 1E1)>
// CHECK: %[[T:.*]] = arith.constant true
// CHECK: return %[[T]]
%r = fly.equal(%a, %a) : (!fly.int_tuple<(1E0, 1E1)>, !fly.int_tuple<(1E0, 1E1)>) -> i1
return %r : i1
}

// -----

// equal on basis leaves with different modes folds to false (E0 != E1).
// CHECK-LABEL: @test_equal_basis_diff_modes
func.func @test_equal_basis_diff_modes() -> i1 {
%a = fly.make_int_tuple() : () -> !fly.int_tuple<(1E0)>
%b = fly.make_int_tuple() : () -> !fly.int_tuple<(1E1)>
// CHECK: %[[F:.*]] = arith.constant false
// CHECK: return %[[F]]
%r = fly.equal(%a, %b) : (!fly.int_tuple<(1E0)>, !fly.int_tuple<(1E1)>) -> i1
return %r : i1
}

// -----

// a basis monomial never equals a plain integer leaf.
// CHECK-LABEL: @test_equal_int_vs_basis
func.func @test_equal_int_vs_basis() -> i1 {
%a = fly.make_int_tuple() : () -> !fly.int_tuple<(1)>
%b = fly.make_int_tuple() : () -> !fly.int_tuple<(1E0)>
// CHECK: %[[F:.*]] = arith.constant false
// CHECK: return %[[F]]
%r = fly.equal(%a, %b) : (!fly.int_tuple<(1)>, !fly.int_tuple<(1E0)>) -> i1
return %r : i1
}
35 changes: 35 additions & 0 deletions tests/unit/test_layout_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,41 @@ def build():
_build_and_verify("zipped_tiled_flat_product", build, [360])


# ==============================================================================
# Basis strides (scaled-basis / CuTe E<I>) via fx.E
# ==============================================================================


def test_basis_stride_via_E():
"""fx.E builds the same basis-strided layout as make_identity_layout."""

def build():
fx.make_layout(fx.make_shape(4, 8), fx.make_stride(fx.E(0), fx.E(1)))
fx.make_identity_layout(fx.make_shape(4, 8))
fx.make_basis_stride(1, (0, 1))
fx.make_stride(fx.E(0, 1, value=2))

def check(ir):
# fx.E(0), fx.E(1) produce a (1E0, 1E1) stride, identical to make_identity_layout.
assert "!fly.layout<(4,8):(1E0,1E1)>" in ir
# make_basis_stride(1, (0, 1)) yields the same flat basis stride.
assert "!fly.int_tuple<(1E0,1E1)>" in ir
# value and multi-mode forms: fx.E(0, 1, value=2) -> 2E0E1.
assert "!fly.int_tuple<(2E0E1)>" in ir

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 7eab7d4check() now matches against "".join(ir.split()), so the assertions no longer depend on the printer's spacing around commas/colons.


_build_and_verify_ir("basis_stride_via_E", build, check)


def test_basis_identity_size():
"""A basis-strided identity layout lowers through the pipeline; size = 4*8 = 32."""

def build():
layout = fx.make_layout(fx.make_shape(4, 8), fx.make_stride(fx.E(0), fx.E(1)))
return [fx.size(layout)]

_build_and_verify("basis_identity_size", build, [32])


# ==============================================================================
# Main
# ==============================================================================
Expand Down