Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/layout_system_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ row_major = fx.make_ordered_layout((M, N), order=(1, 0)) # stride order: N-firs
# Identity layout / tensor
identity = fx.make_identity_layout((M, N))
id_tensor = fx.make_identity_tensor((M, N))

# Scaled-basis strides (CuTe E<I>): fx.E(mode, *, value=1)
# fx.E(0) -> 1E0, fx.E(0, value=2) -> 2E0, fx.E(0, 1) -> 1E0E1
basis_layout = fx.make_layout((4, 8), fx.make_stride(fx.E(0), fx.E(1))) # same as make_identity_layout((4, 8))
flat_basis = fx.make_basis_stride(1, (0, 1)) # !fly.int_tuple<(1E0, 1E1)>
```

---
Expand Down
6 changes: 6 additions & 0 deletions include/flydsl/Dialect/Fly/Utils/LayoutUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,12 @@ std::pair<IntTuple, IntTuple> compositionImpl(const IntTupleBuilder<IntTuple> &b
if (newShape.isStatic() && restShape.isStatic()) {
int64_t restShapeVal = builder.getStaticValue(restShape);
int64_t newShapeVal = builder.getStaticValue(newShape);
// A 0-extent mode here means the divisor does not tile the layout (e.g. an
// overlapping divisor whose complement is ill-formed). Reject it with a
// named assert rather than divide by zero (% below would SIGFPE).
assert(newShapeVal != 0 &&
"composition: divisor is not a tiling layout (a complement mode has 0 "
"extent); its strides must be ordered and divisible");
assert(restShapeVal % newShapeVal == 0);
}

Expand Down
29 changes: 27 additions & 2 deletions lib/Bindings/Python/FlyExtension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,35 @@ struct IntTupleAttrBuilder {
return IntTupleAttr::get(IntAttr::getStatic(ctx, cInt));
} else if (args.is_none()) {
return IntTupleAttr::getLeafNone(ctx);
} else if (nb::hasattr(args, "__fly_basis__") &&
PyObject_IsTrue(nb::object(args.attr("__fly_basis__")).ptr()) == 1) {
Comment on lines +74 to +75
// Scaled-basis stride leaf (e.g. fx.E(0) -> 1E0), duck-typed on a truthy
// __fly_basis__ (Python truthiness, so a falsy marker is not mistaken for a
// basis). The marker is public, so re-validate the payload here:
// accept any sequence/iterable of modes, and reject negative modes (the
// assembly format cannot round-trip a negative E<mode>).
int32_t value = nb::cast<int32_t>(args.attr("value"));
SmallVector<int32_t> modes;
nb::object modesObj = args.attr("modes");
for (auto mode : modesObj) {
int32_t m = nb::cast<int32_t>(mode);
if (m < 0) {
throw std::invalid_argument(
"basis mode must be a non-negative int (the IntTuple assembly format "
"cannot round-trip a negative E<mode>), got " +
std::to_string(m));
}
modes.push_back(m);
}
return IntTupleAttr::get(BasisAttr::get(IntAttr::getStatic(ctx, value), modes));
} else {
if (!nb::hasattr(args, "_CAPIPtr")) {
throw std::invalid_argument("Expected I32, got: " +
std::string(nb::str(nb::type_name(args)).c_str()));
// Report the instance's *type* name. nb::type_name expects a type object,
// so calling it on an arbitrary instance (e.g. an object with a falsy
// __fly_basis__ marker) reinterprets that instance as a PyTypeObject and
// segfaults instead of raising; Py_TYPE(...)->tp_name is always valid.
throw std::invalid_argument(std::string("Expected I32, got: ") +
Py_TYPE(args.ptr())->tp_name);
}
dyncElems.push_back(args);
return IntTupleAttr::get(IntAttr::getDynamic(ctx));
Expand Down
71 changes: 51 additions & 20 deletions lib/Dialect/Fly/Utils/IntTupleUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,58 +138,86 @@ IntTupleAttr IntTupleBuilder<IntTupleAttr>::mul(IntTupleAttr lhs, IntTupleAttr r
}
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::div(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
return IntTupleAttr::get(lhs.getLeafAsInt() / rhs.getLeafAsInt());
// A basis divisor has no quotient mode. This is reachable via complement() of a
// basis-strided (identity) layout of rank >= 2: complementImpl computes
// div(minStride, lastStride) with both leaves basis. Reject it with a named
// assert rather than miscompute a stride.
assert(lhs.isLeaf() && rhs.isLeafInt() &&
"div is undefined for a basis divisor; lower the identity layout first");
if (lhs.isLeafInt()) {
return IntTupleAttr::get(lhs.getLeafAsInt() / rhs.getLeafAsInt());
}
return IntTupleAttr::get(intSafeDiv(lhs.getLeafAsBasis(), rhs.getLeafAsInt()));
Comment on lines +145 to +150
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::mod(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
assert(lhs.isLeafInt() && rhs.isLeafInt() && "mod is defined only on integer leaves");
return IntTupleAttr::get(lhs.getLeafAsInt() % rhs.getLeafAsInt());
}

IntTupleAttr IntTupleBuilder<IntTupleAttr>::logicalAnd(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
assert(lhs.isLeafInt() && rhs.isLeafInt() && "logicalAnd requires integer leaves");
return IntTupleAttr::get(lhs.getLeafAsInt() && rhs.getLeafAsInt());
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::logicalOr(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
assert(lhs.isLeafInt() && rhs.isLeafInt() && "logicalOr requires integer leaves");
return IntTupleAttr::get(lhs.getLeafAsInt() || rhs.getLeafAsInt());
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::logicalNot(IntTupleAttr val) const {
assert(val.isLeafInt());
assert(val.isLeafInt() && "logicalNot requires an integer leaf");
return IntTupleAttr::get(!val.getLeafAsInt());
}

IntTupleAttr IntTupleBuilder<IntTupleAttr>::lt(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
// Ordering is defined only on integer leaves: basis directions are free-module
// generators with no total order, and strides are ordered via getStaticValue.
assert(lhs.isLeafInt() && rhs.isLeafInt() &&
"lt is defined only on integer leaves; basis directions have no order");
return IntTupleAttr::get(lhs.getLeafAsInt() < rhs.getLeafAsInt());
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::le(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
assert(lhs.isLeafInt() && rhs.isLeafInt() &&
"le is defined only on integer leaves; basis directions have no order");
return IntTupleAttr::get(lhs.getLeafAsInt() <= rhs.getLeafAsInt());
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::gt(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
assert(lhs.isLeafInt() && rhs.isLeafInt() &&
"gt is defined only on integer leaves; basis directions have no order");
return IntTupleAttr::get(lhs.getLeafAsInt() > rhs.getLeafAsInt());
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::ge(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
assert(lhs.isLeafInt() && rhs.isLeafInt() &&
"ge is defined only on integer leaves; basis directions have no order");
return IntTupleAttr::get(lhs.getLeafAsInt() >= rhs.getLeafAsInt());
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::eq(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
return IntTupleAttr::get(lhs.getLeafAsInt() == rhs.getLeafAsInt());
assert(lhs.isLeaf() && rhs.isLeaf());
if (lhs.isLeafInt() && rhs.isLeafInt()) {
return IntTupleAttr::get(lhs.getLeafAsInt() == rhs.getLeafAsInt());
}
if (lhs.isLeafBasis() && rhs.isLeafBasis()) {
return IntTupleAttr::get(lhs.getLeafAsBasis() == rhs.getLeafAsBasis());
}
// A scalar leaf and a basis monomial never coincide.
return IntTupleAttr::getLeafStatic(ctx, 0);
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::ne(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
return IntTupleAttr::get(lhs.getLeafAsInt() != rhs.getLeafAsInt());
assert(lhs.isLeaf() && rhs.isLeaf());
if (lhs.isLeafInt() && rhs.isLeafInt()) {
return IntTupleAttr::get(lhs.getLeafAsInt() != rhs.getLeafAsInt());
}
if (lhs.isLeafBasis() && rhs.isLeafBasis()) {
return IntTupleAttr::get(lhs.getLeafAsBasis() != rhs.getLeafAsBasis());
}
// A scalar leaf and a basis monomial never coincide.
return IntTupleAttr::getLeafStatic(ctx, 1);
}

IntTupleAttr IntTupleBuilder<IntTupleAttr>::min(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
assert(lhs.isLeafInt() && rhs.isLeafInt() && "min is defined only on integer leaves");
return IntTupleAttr::get(intMin(lhs.getLeafAsInt(), rhs.getLeafAsInt()));
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::max(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
assert(lhs.isLeafInt() && rhs.isLeafInt() && "max is defined only on integer leaves");
return IntTupleAttr::get(intMax(lhs.getLeafAsInt(), rhs.getLeafAsInt()));
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::safeDiv(IntTupleAttr lhs, IntTupleAttr rhs) const {
Expand All @@ -209,13 +237,14 @@ IntTupleAttr IntTupleBuilder<IntTupleAttr>::ceilDiv(IntTupleAttr lhs, IntTupleAt
}
}
IntTupleAttr IntTupleBuilder<IntTupleAttr>::shapeDiv(IntTupleAttr lhs, IntTupleAttr rhs) const {
assert(lhs.isLeafInt() && rhs.isLeafInt());
assert(lhs.isLeafInt() && rhs.isLeafInt() && "shapeDiv requires integer leaves");
return IntTupleAttr::get(intShapeDiv(lhs.getLeafAsInt(), rhs.getLeafAsInt()));
}

IntTupleAttr IntTupleBuilder<IntTupleAttr>::applySwizzle(IntTupleAttr v,
SwizzleAttr swizzle) const {
assert(v.isLeafInt() && "applySwizzle only supports leafInt IntTupleAttr");
assert(v.isLeafInt() &&
"applySwizzle is undefined on a basis leaf; lower the identity layout first");
return IntTupleAttr::get(intApplySwizzle(v.getLeafAsInt(), swizzle));
}

Expand All @@ -226,8 +255,10 @@ IntTupleAttr IntTupleBuilder<IntTupleAttr>::applyCoordSwizzle(IntTupleAttr coord
}
IntTupleAttr row = selectPath(*this, coord, swizzle.getModeRow());
IntTupleAttr col = selectPath(*this, coord, swizzle.getModeCol());
assert(row.isLeafInt() && "coord swizzle row mode must select a leaf int");
assert(col.isLeafInt() && "coord swizzle col mode must select a leaf int");
assert(row.isLeafInt() &&
"coord swizzle row mode must select an integer leaf (basis unsupported)");
assert(col.isLeafInt() &&
"coord swizzle col mode must select an integer leaf (basis unsupported)");

IntTupleAttr newCol =
IntTupleAttr::get(intApplyCoordSwizzle(row.getLeafAsInt(), col.getLeafAsInt(), swizzle));
Expand Down
73 changes: 73 additions & 0 deletions python/flydsl/expr/primitive.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 FlyDSL Project Contributors

import operator
from enum import IntEnum
from typing import overload

Expand Down Expand Up @@ -89,6 +90,8 @@
"make_int_tuple",
"make_shape",
"make_stride",
"E",
"make_basis_stride",
"make_coord",
"make_layout",
"make_layout_like",
Expand Down Expand Up @@ -356,6 +359,76 @@ def make_stride(*stride, loc=None, ip=None):
return fly.make_stride(IntTupleTy, dyncElems, loc=loc, ip=ip)


class _BasisElem:
"""A scaled-basis stride leaf (CuTe ScaledBasis / E<I>), e.g. ``fx.E(0)`` -> ``1E0``.

Duck-typed by the C++ int-tuple builder via the ``__fly_basis__`` marker. It is a
compile-time-only value (no runtime ``ir.Value`` operand), accepted anywhere a
stride entry is, e.g. ``make_stride(E(0), E(1))``.
"""

__fly_basis__ = True

def __init__(self, value, modes):
self.value = value
self.modes = modes


_INT32_MIN = -(2**31)
_INT32_MAX = 2**31 - 1


def _as_int32(name, x, *, nonneg=False):
"""Validate that ``x`` is an int32 (optionally non-negative) and return it as ``int``.

Catches bad inputs in Python with an actionable error instead of deferring to
the nanobind ``int32_t`` cast on the C++ side. Accepts any integer-like value
(Python ``int``, NumPy integer scalars, ...) via ``operator.index`` while
rejecting ``float``/``str``. ``bool`` is rejected even though it subclasses
``int`` -- ``E(True)`` is never intentional.
"""
if isinstance(x, bool):
raise TypeError(f"E {name} must be an int, got bool: {x!r}")
try:
xi = operator.index(x)
except TypeError:
raise TypeError(f"E {name} must be an int, got {type(x).__name__}: {x!r}") from None
lo = 0 if nonneg else _INT32_MIN
if not (lo <= xi <= _INT32_MAX):
kind = "a non-negative int32" if nonneg else "an int32"
raise ValueError(f"E {name} must be {kind} ({lo}..{_INT32_MAX}), got {xi}")
return xi


def E(*modes, value=1):
"""Build a scaled-basis stride leaf.

Examples:
E(0) -> 1E0
E(1) -> 1E1
E(0, value=2) -> 2E0
E(0, 1) -> 1E0E1 (a single leaf along modes 0 and 1)

Modes must be non-negative int32 values: the ``!fly.int_tuple`` assembly format
cannot round-trip a negative ``E<mode>`` (``1E-1`` fails to re-parse). The
coefficient ``value`` may be any int32 (``-2E0`` and ``0E0`` round-trip fine).
"""
if not modes:
raise ValueError("E requires at least one mode")
value = _as_int32("value", value)
modes = [_as_int32("mode", m, nonneg=True) for m in modes]
return _BasisElem(value, modes)


def make_basis_stride(value, modes):
"""Build a flat basis stride, one ``E`` leaf per mode.

``make_basis_stride(1, (0, 1))`` -> ``(1E0, 1E1)``, the stride
``make_identity_layout`` emits for that rank.
"""
return make_stride(*[E(m, value=value) for m in modes])


@traced_op
def make_coord(*coord, loc=None, ip=None):
"""Build a coordinate used for indexing / slicing a layout.
Expand Down
56 changes: 56 additions & 0 deletions tests/mlir/LayoutAlgebra/basis.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2025 FlyDSL Project Contributors
// RUN: %fly-opt %s | FileCheck %s

// Tests for BasisAttr (scaled-basis / CuTe E<I>) stride leaves flowing through
// the IntTupleBuilder ops: division of a basis coefficient, and a full identity
// layout (1E0, 1E1, ...) through logical_divide. The inferred result types below
// are checked by fly-opt's type inference (a wrong type fails to parse).

// -----

// div divides the scalar coefficient of each basis leaf, keeping its modes:
// (2E0, 8E1) / (2, 4) = (1E0, 2E1)
// CHECK-LABEL: @test_int_tuple_div_basis
func.func @test_int_tuple_div_basis() -> !fly.int_tuple<(1E0, 2E1)> {
%a = fly.static : !fly.int_tuple<(2E0, 8E1)>
%b = fly.static : !fly.int_tuple<(2, 4)>
// CHECK: fly.int_tuple_div(%{{.*}}, %{{.*}})
%result = fly.int_tuple_div(%a, %b) : (!fly.int_tuple<(2E0, 8E1)>, !fly.int_tuple<(2, 4)>) -> !fly.int_tuple<(1E0, 2E1)>
return %result : !fly.int_tuple<(1E0, 2E1)>
}

// -----

// logical_divide partitions a basis-strided identity layout; the algebra walks the
// basis strides (via complement/div) instead of asserting on a non-int leaf.
// CHECK-LABEL: @test_logical_divide_identity
func.func @test_logical_divide_identity() -> !fly.layout<((2, (2, 2)), 4) : ((1E0, (2E0, 1E1)), 2E1)> {
%s = fly.static : !fly.int_tuple<(4, 8)>
%id = fly.make_identity_layout(%s) : (!fly.int_tuple<(4, 8)>) -> !fly.layout<(4, 8) : (1E0, 1E1)>
%ds = fly.static : !fly.int_tuple<(2, 4)>
%dd = fly.static : !fly.int_tuple<(1, 2)>
%div = fly.make_layout(%ds, %dd) : (!fly.int_tuple<(2, 4)>, !fly.int_tuple<(1, 2)>) -> !fly.layout<(2, 4) : (1, 2)>
// CHECK: fly.logical_divide(%{{.*}}, %{{.*}})
%result = fly.logical_divide(%id, %div) : (!fly.layout<(4, 8) : (1E0, 1E1)>, !fly.layout<(2, 4) : (1, 2)>) -> !fly.layout<((2, (2, 2)), 4) : ((1E0, (2E0, 1E1)), 2E1)>
return %result : !fly.layout<((2, (2, 2)), 4) : ((1E0, (2E0, 1E1)), 2E1)>
}

// -----

// Rank-3 identity logical_divide with a valid (tiling) divisor walks the basis
// strides through complement/composition without crashing. A non-tiling divisor
// such as (2,4,2):(1,2,4) -- whose complement has a 0-extent mode -- is rejected
// by an assert in compositionImpl instead of dividing by zero (see issue #574).
// CHECK-LABEL: @test_logical_divide_identity_rank3
func.func @test_logical_divide_identity_rank3()
-> !fly.layout<((2, (2, 2), 2), (2, 2)) : ((1E0, (2E0, 1E1), 2E1), (4E1, 1E2))> {
%s = fly.static : !fly.int_tuple<(4, 8, 2)>
%id = fly.make_identity_layout(%s) : (!fly.int_tuple<(4, 8, 2)>) -> !fly.layout<(4, 8, 2) : (1E0, 1E1, 1E2)>
%ds = fly.static : !fly.int_tuple<(2, 4, 2)>
%dd = fly.static : !fly.int_tuple<(1, 2, 8)>
%div = fly.make_layout(%ds, %dd) : (!fly.int_tuple<(2, 4, 2)>, !fly.int_tuple<(1, 2, 8)>) -> !fly.layout<(2, 4, 2) : (1, 2, 8)>
// CHECK: fly.logical_divide(%{{.*}}, %{{.*}})
%result = fly.logical_divide(%id, %div) : (!fly.layout<(4, 8, 2) : (1E0, 1E1, 1E2)>, !fly.layout<(2, 4, 2) : (1, 2, 8)>) -> !fly.layout<((2, (2, 2), 2), (2, 2)) : ((1E0, (2E0, 1E1), 2E1), (4E1, 1E2))>
return %result : !fly.layout<((2, (2, 2), 2), (2, 2)) : ((1E0, (2E0, 1E1), 2E1), (4E1, 1E2))>
}
40 changes: 40 additions & 0 deletions tests/mlir/Transforms/layout_lowering.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,43 @@ func.func @test_get_leaves_dynamic_only(%x: i32, %y: i64) -> (i32, i64) {
%0:2 = fly.get_leaves(%t) {dynamicOnly = true} : (!fly.int_tuple<(4, ?, ?{i64})>) -> (i32, i64)
return %0#0, %0#1 : i32, i64
}

// -----

// === EqualOp Lowering: basis (scaled-basis / E<I>) leaves ===

// equal on identical basis strides folds to true.
// CHECK-LABEL: @test_equal_basis_same
func.func @test_equal_basis_same() -> i1 {
%a = fly.make_int_tuple() : () -> !fly.int_tuple<(1E0, 1E1)>
// CHECK: %[[T:.*]] = arith.constant true
// CHECK: return %[[T]]
%r = fly.equal(%a, %a) : (!fly.int_tuple<(1E0, 1E1)>, !fly.int_tuple<(1E0, 1E1)>) -> i1
return %r : i1
}

// -----

// equal on basis leaves with different modes folds to false (E0 != E1).
// CHECK-LABEL: @test_equal_basis_diff_modes
func.func @test_equal_basis_diff_modes() -> i1 {
%a = fly.make_int_tuple() : () -> !fly.int_tuple<(1E0)>
%b = fly.make_int_tuple() : () -> !fly.int_tuple<(1E1)>
// CHECK: %[[F:.*]] = arith.constant false
// CHECK: return %[[F]]
%r = fly.equal(%a, %b) : (!fly.int_tuple<(1E0)>, !fly.int_tuple<(1E1)>) -> i1
return %r : i1
}

// -----

// a basis monomial never equals a plain integer leaf.
// CHECK-LABEL: @test_equal_int_vs_basis
func.func @test_equal_int_vs_basis() -> i1 {
%a = fly.make_int_tuple() : () -> !fly.int_tuple<(1)>
%b = fly.make_int_tuple() : () -> !fly.int_tuple<(1E0)>
// CHECK: %[[F:.*]] = arith.constant false
// CHECK: return %[[F]]
%r = fly.equal(%a, %b) : (!fly.int_tuple<(1)>, !fly.int_tuple<(1E0)>) -> i1
return %r : i1
}
Loading