From 54148edaad60991156fe24dbf3daf2e7d7c8655d Mon Sep 17 00:00:00 2001 From: Jin Pan Date: Tue, 2 Jun 2026 05:29:14 +0000 Subject: [PATCH 1/3] [Feat] Complete BasisAttr support in IntTupleBuilder (#574) Extend IntTupleBuilder to accept scaled-basis (BasisAttr / CuTe E) stride leaves where the algebra is well-defined, and add a Python construction surface for them. - div(Basis, Int): divide the coefficient, keep modes (reuse intSafeDiv). - eq / ne: compare basis leaves by (coefficient, modes); a basis monomial never equals a plain integer leaf. - Remaining ops (mod, lt/le/gt/ge, min/max, shapeDiv, logical*, swizzle) stay integer-only -- basis is either ill-posed there or structurally unreachable -- but now carry precise assert messages instead of bare leaf-int asserts. - Python: fx.E(mode, *, value=1) and fx.make_basis_stride(value, modes), wired through the int-tuple builder via a __fly_basis__ marker. div(Basis, Basis) is reachable for rank>=3 identity layouts via complement; it has no quotient mode, so it is rejected with a named assert rather than miscomputing a stride. Located op-layer diagnostics are out of scope here (IntTupleBuilder carries no Location) and belong to the sibling issue #583. Tests: tests/mlir/LayoutAlgebra/basis.mlir (div + identity logical_divide), equal folding in tests/mlir/Transforms/layout_lowering.mlir, and Python surface + pipeline tests in tests/unit/test_layout_algebra.py. Co-Authored-By: Claude Opus 4.8 (1M context) --- docs/layout_system_guide.md | 5 ++ lib/Bindings/Python/FlyExtension.cpp | 8 +++ lib/Dialect/Fly/Utils/IntTupleUtils.cpp | 68 +++++++++++++++------- python/flydsl/expr/primitive.py | 40 +++++++++++++ tests/mlir/LayoutAlgebra/basis.mlir | 37 ++++++++++++ tests/mlir/Transforms/layout_lowering.mlir | 40 +++++++++++++ tests/unit/test_layout_algebra.py | 35 +++++++++++ 7 files changed, 213 insertions(+), 20 deletions(-) create mode 100644 tests/mlir/LayoutAlgebra/basis.mlir diff --git a/docs/layout_system_guide.md b/docs/layout_system_guide.md index 5224f2b83..b622e2d85 100644 --- a/docs/layout_system_guide.md +++ b/docs/layout_system_guide.md @@ -102,6 +102,11 @@ row_major = fx.make_ordered_layout((M, N), order=(1, 0)) # stride order: N-firs # Identity layout / tensor identity = fx.make_identity_layout((M, N)) id_tensor = fx.make_identity_tensor((M, N)) + +# Scaled-basis strides (CuTe E): fx.E(mode, *, value=1) +# fx.E(0) -> 1E0, fx.E(0, value=2) -> 2E0, fx.E(0, 1) -> 1E0E1 +basis_layout = fx.make_layout((4, 8), fx.make_stride(fx.E(0), fx.E(1))) # same as make_identity_layout((4, 8)) +flat_basis = fx.make_basis_stride(1, (0, 1)) # !fly.int_tuple<(1E0, 1E1)> ``` --- diff --git a/lib/Bindings/Python/FlyExtension.cpp b/lib/Bindings/Python/FlyExtension.cpp index aead80cfc..f68947ecc 100644 --- a/lib/Bindings/Python/FlyExtension.cpp +++ b/lib/Bindings/Python/FlyExtension.cpp @@ -71,6 +71,14 @@ struct IntTupleAttrBuilder { return IntTupleAttr::get(IntAttr::getStatic(ctx, cInt)); } else if (args.is_none()) { return IntTupleAttr::getLeafNone(ctx); + } else if (nb::hasattr(args, "__fly_basis__")) { + // Scaled-basis stride leaf (e.g. fx.E(0) -> 1E0), duck-typed on __fly_basis__. + int32_t value = nb::cast(args.attr("value")); + SmallVector modes; + for (auto mode : nb::cast(args.attr("modes"))) { + modes.push_back(nb::cast(mode)); + } + return IntTupleAttr::get(BasisAttr::get(IntAttr::getStatic(ctx, value), modes)); } else { if (!nb::hasattr(args, "_CAPIPtr")) { throw std::invalid_argument("Expected I32, got: " + diff --git a/lib/Dialect/Fly/Utils/IntTupleUtils.cpp b/lib/Dialect/Fly/Utils/IntTupleUtils.cpp index 2027f1aa8..d874ead25 100644 --- a/lib/Dialect/Fly/Utils/IntTupleUtils.cpp +++ b/lib/Dialect/Fly/Utils/IntTupleUtils.cpp @@ -138,58 +138,83 @@ IntTupleAttr IntTupleBuilder::mul(IntTupleAttr lhs, IntTupleAttr r } } IntTupleAttr IntTupleBuilder::div(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); - return IntTupleAttr::get(lhs.getLeafAsInt() / rhs.getLeafAsInt()); + // A basis divisor has no quotient mode (reachable for rank>=3 identity layouts + // via complement), so reject it rather than miscompute a stride. + assert(lhs.isLeaf() && rhs.isLeafInt() && + "div is undefined for a basis divisor; lower the identity layout first"); + if (lhs.isLeafInt()) { + return IntTupleAttr::get(lhs.getLeafAsInt() / rhs.getLeafAsInt()); + } + return IntTupleAttr::get(intSafeDiv(lhs.getLeafAsBasis(), rhs.getLeafAsInt())); } IntTupleAttr IntTupleBuilder::mod(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); + assert(lhs.isLeafInt() && rhs.isLeafInt() && "mod is defined only on integer leaves"); return IntTupleAttr::get(lhs.getLeafAsInt() % rhs.getLeafAsInt()); } IntTupleAttr IntTupleBuilder::logicalAnd(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); + assert(lhs.isLeafInt() && rhs.isLeafInt() && "logicalAnd requires integer leaves"); return IntTupleAttr::get(lhs.getLeafAsInt() && rhs.getLeafAsInt()); } IntTupleAttr IntTupleBuilder::logicalOr(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); + assert(lhs.isLeafInt() && rhs.isLeafInt() && "logicalOr requires integer leaves"); return IntTupleAttr::get(lhs.getLeafAsInt() || rhs.getLeafAsInt()); } IntTupleAttr IntTupleBuilder::logicalNot(IntTupleAttr val) const { - assert(val.isLeafInt()); + assert(val.isLeafInt() && "logicalNot requires an integer leaf"); return IntTupleAttr::get(!val.getLeafAsInt()); } IntTupleAttr IntTupleBuilder::lt(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); + // Ordering is defined only on integer leaves: basis directions are free-module + // generators with no total order, and strides are ordered via getStaticValue. + assert(lhs.isLeafInt() && rhs.isLeafInt() && + "lt is defined only on integer leaves; basis directions have no order"); return IntTupleAttr::get(lhs.getLeafAsInt() < rhs.getLeafAsInt()); } IntTupleAttr IntTupleBuilder::le(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); + assert(lhs.isLeafInt() && rhs.isLeafInt() && + "le is defined only on integer leaves; basis directions have no order"); return IntTupleAttr::get(lhs.getLeafAsInt() <= rhs.getLeafAsInt()); } IntTupleAttr IntTupleBuilder::gt(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); + assert(lhs.isLeafInt() && rhs.isLeafInt() && + "gt is defined only on integer leaves; basis directions have no order"); return IntTupleAttr::get(lhs.getLeafAsInt() > rhs.getLeafAsInt()); } IntTupleAttr IntTupleBuilder::ge(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); + assert(lhs.isLeafInt() && rhs.isLeafInt() && + "ge is defined only on integer leaves; basis directions have no order"); return IntTupleAttr::get(lhs.getLeafAsInt() >= rhs.getLeafAsInt()); } IntTupleAttr IntTupleBuilder::eq(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); - return IntTupleAttr::get(lhs.getLeafAsInt() == rhs.getLeafAsInt()); + assert(lhs.isLeaf() && rhs.isLeaf()); + if (lhs.isLeafInt() && rhs.isLeafInt()) { + return IntTupleAttr::get(lhs.getLeafAsInt() == rhs.getLeafAsInt()); + } + if (lhs.isLeafBasis() && rhs.isLeafBasis()) { + return IntTupleAttr::get(lhs.getLeafAsBasis() == rhs.getLeafAsBasis()); + } + // A scalar leaf and a basis monomial never coincide. + return materializeConstantLeaf(0); } IntTupleAttr IntTupleBuilder::ne(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); - return IntTupleAttr::get(lhs.getLeafAsInt() != rhs.getLeafAsInt()); + assert(lhs.isLeaf() && rhs.isLeaf()); + if (lhs.isLeafInt() && rhs.isLeafInt()) { + return IntTupleAttr::get(lhs.getLeafAsInt() != rhs.getLeafAsInt()); + } + if (lhs.isLeafBasis() && rhs.isLeafBasis()) { + return IntTupleAttr::get(lhs.getLeafAsBasis() != rhs.getLeafAsBasis()); + } + return materializeConstantLeaf(1); } IntTupleAttr IntTupleBuilder::min(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); + assert(lhs.isLeafInt() && rhs.isLeafInt() && "min is defined only on integer leaves"); return IntTupleAttr::get(intMin(lhs.getLeafAsInt(), rhs.getLeafAsInt())); } IntTupleAttr IntTupleBuilder::max(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); + assert(lhs.isLeafInt() && rhs.isLeafInt() && "max is defined only on integer leaves"); return IntTupleAttr::get(intMax(lhs.getLeafAsInt(), rhs.getLeafAsInt())); } IntTupleAttr IntTupleBuilder::safeDiv(IntTupleAttr lhs, IntTupleAttr rhs) const { @@ -209,13 +234,14 @@ IntTupleAttr IntTupleBuilder::ceilDiv(IntTupleAttr lhs, IntTupleAt } } IntTupleAttr IntTupleBuilder::shapeDiv(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); + assert(lhs.isLeafInt() && rhs.isLeafInt() && "shapeDiv requires integer leaves"); return IntTupleAttr::get(intShapeDiv(lhs.getLeafAsInt(), rhs.getLeafAsInt())); } IntTupleAttr IntTupleBuilder::applySwizzle(IntTupleAttr v, SwizzleAttr swizzle) const { - assert(v.isLeafInt() && "applySwizzle only supports leafInt IntTupleAttr"); + assert(v.isLeafInt() && + "applySwizzle is undefined on a basis leaf; lower the identity layout first"); return IntTupleAttr::get(intApplySwizzle(v.getLeafAsInt(), swizzle)); } @@ -226,8 +252,10 @@ IntTupleAttr IntTupleBuilder::applyCoordSwizzle(IntTupleAttr coord } IntTupleAttr row = selectPath(*this, coord, swizzle.getModeRow()); IntTupleAttr col = selectPath(*this, coord, swizzle.getModeCol()); - assert(row.isLeafInt() && "coord swizzle row mode must select a leaf int"); - assert(col.isLeafInt() && "coord swizzle col mode must select a leaf int"); + assert(row.isLeafInt() && + "coord swizzle row mode must select an integer leaf (basis unsupported)"); + assert(col.isLeafInt() && + "coord swizzle col mode must select an integer leaf (basis unsupported)"); IntTupleAttr newCol = IntTupleAttr::get(intApplyCoordSwizzle(row.getLeafAsInt(), col.getLeafAsInt(), swizzle)); diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index 83349dcbb..41be852dd 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -89,6 +89,8 @@ "make_int_tuple", "make_shape", "make_stride", + "E", + "make_basis_stride", "make_coord", "make_layout", "make_layout_like", @@ -356,6 +358,44 @@ def make_stride(*stride, loc=None, ip=None): return fly.make_stride(IntTupleTy, dyncElems, loc=loc, ip=ip) +class _BasisElem: + """A scaled-basis stride leaf (CuTe ScaledBasis / E), e.g. ``fx.E(0)`` -> ``1E0``. + + Duck-typed by the C++ int-tuple builder via the ``__fly_basis__`` marker. It is a + compile-time-only value (no runtime ``ir.Value`` operand), accepted anywhere a + stride entry is, e.g. ``make_stride(E(0), E(1))``. + """ + + __fly_basis__ = True + + def __init__(self, value, modes): + self.value = value + self.modes = modes + + +def E(*modes, value=1): + """Build a scaled-basis stride leaf. + + Examples: + E(0) -> 1E0 + E(1) -> 1E1 + E(0, value=2) -> 2E0 + E(0, 1) -> 1E0E1 (a single leaf along modes 0 and 1) + """ + if not modes: + raise ValueError("E requires at least one mode") + return _BasisElem(value, list(modes)) + + +def make_basis_stride(value, modes): + """Build a flat basis stride, one ``E`` leaf per mode. + + ``make_basis_stride(1, (0, 1))`` -> ``(1E0, 1E1)``, the stride + ``make_identity_layout`` emits for that rank. + """ + return make_stride(*[E(m, value=value) for m in modes]) + + @traced_op def make_coord(*coord, loc=None, ip=None): """Build a coordinate used for indexing / slicing a layout. diff --git a/tests/mlir/LayoutAlgebra/basis.mlir b/tests/mlir/LayoutAlgebra/basis.mlir new file mode 100644 index 000000000..db4e48bfc --- /dev/null +++ b/tests/mlir/LayoutAlgebra/basis.mlir @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2025 FlyDSL Project Contributors +// RUN: %fly-opt %s | FileCheck %s + +// Tests for BasisAttr (scaled-basis / CuTe E) stride leaves flowing through +// the IntTupleBuilder ops: division of a basis coefficient, and a full identity +// layout (1E0, 1E1, ...) through logical_divide. The inferred result types below +// are checked by fly-opt's type inference (a wrong type fails to parse). + +// ----- + +// div divides the scalar coefficient of each basis leaf, keeping its modes: +// (2E0, 8E1) / (2, 4) = (1E0, 2E1) +// CHECK-LABEL: @test_int_tuple_div_basis +func.func @test_int_tuple_div_basis() -> !fly.int_tuple<(1E0, 2E1)> { + %a = fly.static : !fly.int_tuple<(2E0, 8E1)> + %b = fly.static : !fly.int_tuple<(2, 4)> + // CHECK: fly.int_tuple_div(%{{.*}}, %{{.*}}) + %result = fly.int_tuple_div(%a, %b) : (!fly.int_tuple<(2E0, 8E1)>, !fly.int_tuple<(2, 4)>) -> !fly.int_tuple<(1E0, 2E1)> + return %result : !fly.int_tuple<(1E0, 2E1)> +} + +// ----- + +// logical_divide partitions a basis-strided identity layout; the algebra walks the +// basis strides (via complement/div) instead of asserting on a non-int leaf. +// CHECK-LABEL: @test_logical_divide_identity +func.func @test_logical_divide_identity() -> !fly.layout<((2, (2, 2)), 4) : ((1E0, (2E0, 1E1)), 2E1)> { + %s = fly.static : !fly.int_tuple<(4, 8)> + %id = fly.make_identity_layout(%s) : (!fly.int_tuple<(4, 8)>) -> !fly.layout<(4, 8) : (1E0, 1E1)> + %ds = fly.static : !fly.int_tuple<(2, 4)> + %dd = fly.static : !fly.int_tuple<(1, 2)> + %div = fly.make_layout(%ds, %dd) : (!fly.int_tuple<(2, 4)>, !fly.int_tuple<(1, 2)>) -> !fly.layout<(2, 4) : (1, 2)> + // CHECK: fly.logical_divide(%{{.*}}, %{{.*}}) + %result = fly.logical_divide(%id, %div) : (!fly.layout<(4, 8) : (1E0, 1E1)>, !fly.layout<(2, 4) : (1, 2)>) -> !fly.layout<((2, (2, 2)), 4) : ((1E0, (2E0, 1E1)), 2E1)> + return %result : !fly.layout<((2, (2, 2)), 4) : ((1E0, (2E0, 1E1)), 2E1)> +} diff --git a/tests/mlir/Transforms/layout_lowering.mlir b/tests/mlir/Transforms/layout_lowering.mlir index f27685ece..03e14a731 100644 --- a/tests/mlir/Transforms/layout_lowering.mlir +++ b/tests/mlir/Transforms/layout_lowering.mlir @@ -384,3 +384,43 @@ func.func @test_get_leaves_dynamic_only(%x: i32, %y: i64) -> (i32, i64) { %0:2 = fly.get_leaves(%t) {dynamicOnly = true} : (!fly.int_tuple<(4, ?, ?{i64})>) -> (i32, i64) return %0#0, %0#1 : i32, i64 } + +// ----- + +// === EqualOp Lowering: basis (scaled-basis / E) leaves === + +// equal on identical basis strides folds to true. +// CHECK-LABEL: @test_equal_basis_same +func.func @test_equal_basis_same() -> i1 { + %a = fly.make_int_tuple() : () -> !fly.int_tuple<(1E0, 1E1)> + // CHECK: %[[T:.*]] = arith.constant true + // CHECK: return %[[T]] + %r = fly.equal(%a, %a) : (!fly.int_tuple<(1E0, 1E1)>, !fly.int_tuple<(1E0, 1E1)>) -> i1 + return %r : i1 +} + +// ----- + +// equal on basis leaves with different modes folds to false (E0 != E1). +// CHECK-LABEL: @test_equal_basis_diff_modes +func.func @test_equal_basis_diff_modes() -> i1 { + %a = fly.make_int_tuple() : () -> !fly.int_tuple<(1E0)> + %b = fly.make_int_tuple() : () -> !fly.int_tuple<(1E1)> + // CHECK: %[[F:.*]] = arith.constant false + // CHECK: return %[[F]] + %r = fly.equal(%a, %b) : (!fly.int_tuple<(1E0)>, !fly.int_tuple<(1E1)>) -> i1 + return %r : i1 +} + +// ----- + +// a basis monomial never equals a plain integer leaf. +// CHECK-LABEL: @test_equal_int_vs_basis +func.func @test_equal_int_vs_basis() -> i1 { + %a = fly.make_int_tuple() : () -> !fly.int_tuple<(1)> + %b = fly.make_int_tuple() : () -> !fly.int_tuple<(1E0)> + // CHECK: %[[F:.*]] = arith.constant false + // CHECK: return %[[F]] + %r = fly.equal(%a, %b) : (!fly.int_tuple<(1)>, !fly.int_tuple<(1E0)>) -> i1 + return %r : i1 +} diff --git a/tests/unit/test_layout_algebra.py b/tests/unit/test_layout_algebra.py index 4f7a55d6c..cf74e2d72 100644 --- a/tests/unit/test_layout_algebra.py +++ b/tests/unit/test_layout_algebra.py @@ -427,6 +427,41 @@ def build(): _build_and_verify("zipped_tiled_flat_product", build, [360]) +# ============================================================================== +# Basis strides (scaled-basis / CuTe E) via fx.E +# ============================================================================== + + +def test_basis_stride_via_E(): + """fx.E builds the same basis-strided layout as make_identity_layout.""" + + def build(): + fx.make_layout(fx.make_shape(4, 8), fx.make_stride(fx.E(0), fx.E(1))) + fx.make_identity_layout(fx.make_shape(4, 8)) + fx.make_basis_stride(1, (0, 1)) + fx.make_stride(fx.E(0, 1, value=2)) + + def check(ir): + # fx.E(0), fx.E(1) produce a (1E0, 1E1) stride, identical to make_identity_layout. + assert "!fly.layout<(4,8):(1E0,1E1)>" in ir + # make_basis_stride(1, (0, 1)) yields the same flat basis stride. + assert "!fly.int_tuple<(1E0,1E1)>" in ir + # value and multi-mode forms: fx.E(0, 1, value=2) -> 2E0E1. + assert "!fly.int_tuple<(2E0E1)>" in ir + + _build_and_verify_ir("basis_stride_via_E", build, check) + + +def test_basis_identity_size(): + """A basis-strided identity layout lowers through the pipeline; size = 4*8 = 32.""" + + def build(): + layout = fx.make_layout(fx.make_shape(4, 8), fx.make_stride(fx.E(0), fx.E(1))) + return [fx.size(layout)] + + _build_and_verify("basis_identity_size", build, [32]) + + # ============================================================================== # Main # ============================================================================== From 7eab7d4fd359be56ecccde713242dccb4752031d Mon Sep 17 00:00:00 2001 From: Jin Pan Date: Tue, 2 Jun 2026 06:42:28 +0000 Subject: [PATCH 2/3] [Feat] Address PR #605 review on BasisAttr IntTupleBuilder support (#574) Resolve the inline review comments: - E()/make_basis_stride: validate value and modes as int32 in Python (operator.index, so NumPy integer scalars are accepted) and reject negative modes -- the IntTuple assembly format cannot round-trip a negative E (1E-1 fails to re-parse). Mirror the non-negative mode check in the public __fly_basis__ binding hook. - FlyExtension binding: gate the basis branch on a *truthy* __fly_basis__ (PyObject_IsTrue, not a strict Py_True/Py_False cast that throws on non-bool) so a falsy marker is not mistaken for a basis; read modes from any iterable, not only a list. - IntTupleUtils eq/ne: return an explicit getLeafStatic(0/1) leaf instead of materializeConstantLeaf, consistent with the sibling branches. - IntTupleUtils div(): correct the comment -- a basis divisor is reached via complement() of a rank>=2 identity layout, not logical_divide. - compositionImpl: a non-tiling divisor yields a 0-extent complement mode; assert before the % so it fails with a named message instead of a SIGFPE. Pre-existing and not basis-specific (a non-coalescible integer layout with the same divisor crashes identically). Add a rank-3 identity logical_divide regression with a valid tiler. - test_layout_algebra: normalize IR whitespace before matching; add an E() input-validation regression. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../flydsl/Dialect/Fly/Utils/LayoutUtils.h | 6 +++ lib/Bindings/Python/FlyExtension.cpp | 21 ++++++++-- lib/Dialect/Fly/Utils/IntTupleUtils.cpp | 11 +++-- python/flydsl/expr/primitive.py | 35 +++++++++++++++- tests/mlir/LayoutAlgebra/basis.mlir | 19 +++++++++ tests/unit/test_layout_algebra.py | 40 +++++++++++++++++-- 6 files changed, 120 insertions(+), 12 deletions(-) diff --git a/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h b/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h index d1148aa82..3b9315de0 100644 --- a/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h +++ b/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h @@ -737,6 +737,12 @@ std::pair compositionImpl(const IntTupleBuilder &b if (newShape.isStatic() && restShape.isStatic()) { int64_t restShapeVal = builder.getStaticValue(restShape); int64_t newShapeVal = builder.getStaticValue(newShape); + // A 0-extent mode here means the divisor does not tile the layout (e.g. an + // overlapping divisor whose complement is ill-formed). Reject it with a + // named assert rather than divide by zero (% below would SIGFPE). + assert(newShapeVal != 0 && + "composition: divisor is not a tiling layout (a complement mode has 0 " + "extent); its strides must be ordered and divisible"); assert(restShapeVal % newShapeVal == 0); } diff --git a/lib/Bindings/Python/FlyExtension.cpp b/lib/Bindings/Python/FlyExtension.cpp index f68947ecc..6d9a60484 100644 --- a/lib/Bindings/Python/FlyExtension.cpp +++ b/lib/Bindings/Python/FlyExtension.cpp @@ -71,12 +71,25 @@ struct IntTupleAttrBuilder { return IntTupleAttr::get(IntAttr::getStatic(ctx, cInt)); } else if (args.is_none()) { return IntTupleAttr::getLeafNone(ctx); - } else if (nb::hasattr(args, "__fly_basis__")) { - // Scaled-basis stride leaf (e.g. fx.E(0) -> 1E0), duck-typed on __fly_basis__. + } else if (nb::hasattr(args, "__fly_basis__") && + PyObject_IsTrue(nb::object(args.attr("__fly_basis__")).ptr()) == 1) { + // Scaled-basis stride leaf (e.g. fx.E(0) -> 1E0), duck-typed on a truthy + // __fly_basis__ (Python truthiness, so a falsy marker is not mistaken for a + // basis). The marker is public, so re-validate the payload here: + // accept any sequence/iterable of modes, and reject negative modes (the + // assembly format cannot round-trip a negative E). int32_t value = nb::cast(args.attr("value")); SmallVector modes; - for (auto mode : nb::cast(args.attr("modes"))) { - modes.push_back(nb::cast(mode)); + nb::object modesObj = args.attr("modes"); + for (auto mode : modesObj) { + int32_t m = nb::cast(mode); + if (m < 0) { + throw std::invalid_argument( + "basis mode must be a non-negative int (the IntTuple assembly format " + "cannot round-trip a negative E), got " + + std::to_string(m)); + } + modes.push_back(m); } return IntTupleAttr::get(BasisAttr::get(IntAttr::getStatic(ctx, value), modes)); } else { diff --git a/lib/Dialect/Fly/Utils/IntTupleUtils.cpp b/lib/Dialect/Fly/Utils/IntTupleUtils.cpp index d874ead25..b4e8ee31f 100644 --- a/lib/Dialect/Fly/Utils/IntTupleUtils.cpp +++ b/lib/Dialect/Fly/Utils/IntTupleUtils.cpp @@ -138,8 +138,10 @@ IntTupleAttr IntTupleBuilder::mul(IntTupleAttr lhs, IntTupleAttr r } } IntTupleAttr IntTupleBuilder::div(IntTupleAttr lhs, IntTupleAttr rhs) const { - // A basis divisor has no quotient mode (reachable for rank>=3 identity layouts - // via complement), so reject it rather than miscompute a stride. + // A basis divisor has no quotient mode. This is reachable via complement() of a + // basis-strided (identity) layout of rank >= 2: complementImpl computes + // div(minStride, lastStride) with both leaves basis. Reject it with a named + // assert rather than miscompute a stride. assert(lhs.isLeaf() && rhs.isLeafInt() && "div is undefined for a basis divisor; lower the identity layout first"); if (lhs.isLeafInt()) { @@ -196,7 +198,7 @@ IntTupleAttr IntTupleBuilder::eq(IntTupleAttr lhs, IntTupleAttr rh return IntTupleAttr::get(lhs.getLeafAsBasis() == rhs.getLeafAsBasis()); } // A scalar leaf and a basis monomial never coincide. - return materializeConstantLeaf(0); + return IntTupleAttr::getLeafStatic(ctx, 0); } IntTupleAttr IntTupleBuilder::ne(IntTupleAttr lhs, IntTupleAttr rhs) const { assert(lhs.isLeaf() && rhs.isLeaf()); @@ -206,7 +208,8 @@ IntTupleAttr IntTupleBuilder::ne(IntTupleAttr lhs, IntTupleAttr rh if (lhs.isLeafBasis() && rhs.isLeafBasis()) { return IntTupleAttr::get(lhs.getLeafAsBasis() != rhs.getLeafAsBasis()); } - return materializeConstantLeaf(1); + // A scalar leaf and a basis monomial never coincide. + return IntTupleAttr::getLeafStatic(ctx, 1); } IntTupleAttr IntTupleBuilder::min(IntTupleAttr lhs, IntTupleAttr rhs) const { diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index 41be852dd..8902e8913 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2025 FlyDSL Project Contributors +import operator from enum import IntEnum from typing import overload @@ -373,6 +374,32 @@ def __init__(self, value, modes): self.modes = modes +_INT32_MIN = -(2**31) +_INT32_MAX = 2**31 - 1 + + +def _as_int32(name, x, *, nonneg=False): + """Validate that ``x`` is an int32 (optionally non-negative) and return it as ``int``. + + Catches bad inputs in Python with an actionable error instead of deferring to + the nanobind ``int32_t`` cast on the C++ side. Accepts any integer-like value + (Python ``int``, NumPy integer scalars, ...) via ``operator.index`` while + rejecting ``float``/``str``. ``bool`` is rejected even though it subclasses + ``int`` -- ``E(True)`` is never intentional. + """ + if isinstance(x, bool): + raise TypeError(f"E {name} must be an int, got bool: {x!r}") + try: + xi = operator.index(x) + except TypeError: + raise TypeError(f"E {name} must be an int, got {type(x).__name__}: {x!r}") from None + lo = 0 if nonneg else _INT32_MIN + if not (lo <= xi <= _INT32_MAX): + kind = "a non-negative int32" if nonneg else "an int32" + raise ValueError(f"E {name} must be {kind} ({lo}..{_INT32_MAX}), got {xi}") + return xi + + def E(*modes, value=1): """Build a scaled-basis stride leaf. @@ -381,10 +408,16 @@ def E(*modes, value=1): E(1) -> 1E1 E(0, value=2) -> 2E0 E(0, 1) -> 1E0E1 (a single leaf along modes 0 and 1) + + Modes must be non-negative int32 values: the ``!fly.int_tuple`` assembly format + cannot round-trip a negative ``E`` (``1E-1`` fails to re-parse). The + coefficient ``value`` may be any int32 (``-2E0`` and ``0E0`` round-trip fine). """ if not modes: raise ValueError("E requires at least one mode") - return _BasisElem(value, list(modes)) + value = _as_int32("value", value) + modes = [_as_int32("mode", m, nonneg=True) for m in modes] + return _BasisElem(value, modes) def make_basis_stride(value, modes): diff --git a/tests/mlir/LayoutAlgebra/basis.mlir b/tests/mlir/LayoutAlgebra/basis.mlir index db4e48bfc..a39c88f11 100644 --- a/tests/mlir/LayoutAlgebra/basis.mlir +++ b/tests/mlir/LayoutAlgebra/basis.mlir @@ -35,3 +35,22 @@ func.func @test_logical_divide_identity() -> !fly.layout<((2, (2, 2)), 4) : ((1E %result = fly.logical_divide(%id, %div) : (!fly.layout<(4, 8) : (1E0, 1E1)>, !fly.layout<(2, 4) : (1, 2)>) -> !fly.layout<((2, (2, 2)), 4) : ((1E0, (2E0, 1E1)), 2E1)> return %result : !fly.layout<((2, (2, 2)), 4) : ((1E0, (2E0, 1E1)), 2E1)> } + +// ----- + +// Rank-3 identity logical_divide with a valid (tiling) divisor walks the basis +// strides through complement/composition without crashing. A non-tiling divisor +// such as (2,4,2):(1,2,4) -- whose complement has a 0-extent mode -- is rejected +// by an assert in compositionImpl instead of dividing by zero (see issue #574). +// CHECK-LABEL: @test_logical_divide_identity_rank3 +func.func @test_logical_divide_identity_rank3() + -> !fly.layout<((2, (2, 2), 2), (2, 2)) : ((1E0, (2E0, 1E1), 2E1), (4E1, 1E2))> { + %s = fly.static : !fly.int_tuple<(4, 8, 2)> + %id = fly.make_identity_layout(%s) : (!fly.int_tuple<(4, 8, 2)>) -> !fly.layout<(4, 8, 2) : (1E0, 1E1, 1E2)> + %ds = fly.static : !fly.int_tuple<(2, 4, 2)> + %dd = fly.static : !fly.int_tuple<(1, 2, 8)> + %div = fly.make_layout(%ds, %dd) : (!fly.int_tuple<(2, 4, 2)>, !fly.int_tuple<(1, 2, 8)>) -> !fly.layout<(2, 4, 2) : (1, 2, 8)> + // CHECK: fly.logical_divide(%{{.*}}, %{{.*}}) + %result = fly.logical_divide(%id, %div) : (!fly.layout<(4, 8, 2) : (1E0, 1E1, 1E2)>, !fly.layout<(2, 4, 2) : (1, 2, 8)>) -> !fly.layout<((2, (2, 2), 2), (2, 2)) : ((1E0, (2E0, 1E1), 2E1), (4E1, 1E2))> + return %result : !fly.layout<((2, (2, 2), 2), (2, 2)) : ((1E0, (2E0, 1E1), 2E1), (4E1, 1E2))> +} diff --git a/tests/unit/test_layout_algebra.py b/tests/unit/test_layout_algebra.py index cf74e2d72..a4a4367e4 100644 --- a/tests/unit/test_layout_algebra.py +++ b/tests/unit/test_layout_algebra.py @@ -442,16 +442,50 @@ def build(): fx.make_stride(fx.E(0, 1, value=2)) def check(ir): + # Normalize whitespace so the checks do not depend on the MLIR printer's + # exact spacing around commas/colons. + compact = "".join(ir.split()) # fx.E(0), fx.E(1) produce a (1E0, 1E1) stride, identical to make_identity_layout. - assert "!fly.layout<(4,8):(1E0,1E1)>" in ir + assert "!fly.layout<(4,8):(1E0,1E1)>" in compact # make_basis_stride(1, (0, 1)) yields the same flat basis stride. - assert "!fly.int_tuple<(1E0,1E1)>" in ir + assert "!fly.int_tuple<(1E0,1E1)>" in compact # value and multi-mode forms: fx.E(0, 1, value=2) -> 2E0E1. - assert "!fly.int_tuple<(2E0E1)>" in ir + assert "!fly.int_tuple<(2E0E1)>" in compact _build_and_verify_ir("basis_stride_via_E", build, check) +def test_basis_E_rejects_invalid_modes(): + """fx.E validates its inputs in Python instead of deferring to the C++ cast. + + Negative modes are the key case: the !fly.int_tuple assembly format cannot + round-trip a negative E (`1E-1` fails to re-parse), so they must be + rejected at construction time. + """ + with pytest.raises(ValueError): + fx.E() # at least one mode required + with pytest.raises(ValueError): + fx.E(-1) # negative mode is not round-trippable + with pytest.raises(ValueError): + fx.make_basis_stride(1, (0, -2)) # negative mode via the flat helper + with pytest.raises(TypeError): + fx.E(0.0) # non-int mode + with pytest.raises(TypeError): + fx.E(0, value="x") # non-int coefficient + with pytest.raises(ValueError): + fx.E(2**31) # mode out of int32 range + + # The coefficient may be any int32, including negative/zero (these round-trip). + assert fx.E(0, value=-2).value == -2 + assert fx.E(0, value=0).value == 0 + + # Integer-like values (NumPy integer scalars) are accepted and normalized to int. + np = pytest.importorskip("numpy") + elem = fx.E(np.int32(2), value=np.int64(3)) + assert elem.modes == [2] and elem.value == 3 + assert all(isinstance(m, int) for m in elem.modes) and isinstance(elem.value, int) + + def test_basis_identity_size(): """A basis-strided identity layout lowers through the pipeline; size = 4*8 = 32.""" From 7255fff8b8d984dbd2e8bdc534fa68ec29d9654b Mon Sep 17 00:00:00 2001 From: Jin Pan Date: Tue, 2 Jun 2026 07:05:49 +0000 Subject: [PATCH 3/3] [Fix] make_stride: raise instead of segfault on a non-stride object (#605 review) The generic fallback in IntTupleAttrBuilder reported the rejected value's type via nb::type_name(args), but nb_type_name expects a *type* object -- passing an instance reinterprets it as a PyTypeObject and segfaults (and on Python <3.11 the __name__ lookup on the instance yields a NULL that feeds PyUnicode_FromFormat). Use Py_TYPE(args.ptr())->tp_name, which is always valid, so any non-stride object (e.g. an object exposing a falsy __fly_basis__ marker) raises a clean ValueError. Add a regression. Co-Authored-By: Claude Opus 4.8 (1M context) --- lib/Bindings/Python/FlyExtension.cpp | 8 ++++++-- tests/unit/test_layout_algebra.py | 27 +++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/lib/Bindings/Python/FlyExtension.cpp b/lib/Bindings/Python/FlyExtension.cpp index 6d9a60484..29c12dce9 100644 --- a/lib/Bindings/Python/FlyExtension.cpp +++ b/lib/Bindings/Python/FlyExtension.cpp @@ -94,8 +94,12 @@ struct IntTupleAttrBuilder { return IntTupleAttr::get(BasisAttr::get(IntAttr::getStatic(ctx, value), modes)); } else { if (!nb::hasattr(args, "_CAPIPtr")) { - throw std::invalid_argument("Expected I32, got: " + - std::string(nb::str(nb::type_name(args)).c_str())); + // Report the instance's *type* name. nb::type_name expects a type object, + // so calling it on an arbitrary instance (e.g. an object with a falsy + // __fly_basis__ marker) reinterprets that instance as a PyTypeObject and + // segfaults instead of raising; Py_TYPE(...)->tp_name is always valid. + throw std::invalid_argument(std::string("Expected I32, got: ") + + Py_TYPE(args.ptr())->tp_name); } dyncElems.push_back(args); return IntTupleAttr::get(IntAttr::getDynamic(ctx)); diff --git a/tests/unit/test_layout_algebra.py b/tests/unit/test_layout_algebra.py index a4a4367e4..918f35785 100644 --- a/tests/unit/test_layout_algebra.py +++ b/tests/unit/test_layout_algebra.py @@ -486,6 +486,33 @@ def test_basis_E_rejects_invalid_modes(): assert all(isinstance(m, int) for m in elem.modes) and isinstance(elem.value, int) +def test_make_stride_rejects_non_stride_object(): + """A non-stride object must raise a clean error from make_stride, not segfault. + + Covers a falsy ``__fly_basis__`` marker (which is correctly not treated as a + basis leaf and falls through to the generic path) and an arbitrary object. + The generic path reports the value's type name, which must not be computed by + reinterpreting the instance as a Python type object. + """ + + class FalsyMarker: + __fly_basis__ = False + value = 1 + modes = [0] + + with Context() as ctx: + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + module = Module.create() + with InsertionPoint(module.body): + f = func.FuncOp("reject", FunctionType.get([], [])) + with InsertionPoint(f.add_entry_block()): + with pytest.raises(ValueError): + fx.make_stride(FalsyMarker()) + with pytest.raises(ValueError): + fx.make_stride(object()) + + def test_basis_identity_size(): """A basis-strided identity layout lowers through the pipeline; size = 4*8 = 32."""