diff --git a/docs/layout_system_guide.md b/docs/layout_system_guide.md index bd3d23396..574f93094 100644 --- a/docs/layout_system_guide.md +++ b/docs/layout_system_guide.md @@ -102,6 +102,11 @@ row_major = fx.make_ordered_layout((M, N), order=(1, 0)) # stride order: N-firs # Identity layout / tensor identity = fx.make_identity_layout((M, N)) id_tensor = fx.make_identity_tensor((M, N)) + +# Scaled-basis strides (CuTe E): fx.E(mode, *, value=1) +# fx.E(0) -> 1E0, fx.E(0, value=2) -> 2E0, fx.E(0, 1) -> 1E0E1 +basis_layout = fx.make_layout((4, 8), fx.make_stride(fx.E(0), fx.E(1))) # same as make_identity_layout((4, 8)) +flat_basis = fx.make_basis_stride(1, (0, 1)) # !fly.int_tuple<(1E0, 1E1)> ``` --- diff --git a/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h b/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h index d1148aa82..3b9315de0 100644 --- a/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h +++ b/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h @@ -737,6 +737,12 @@ std::pair compositionImpl(const IntTupleBuilder &b if (newShape.isStatic() && restShape.isStatic()) { int64_t restShapeVal = builder.getStaticValue(restShape); int64_t newShapeVal = builder.getStaticValue(newShape); + // A 0-extent mode here means the divisor does not tile the layout (e.g. an + // overlapping divisor whose complement is ill-formed). Reject it with a + // named assert rather than divide by zero (% below would SIGFPE). + assert(newShapeVal != 0 && + "composition: divisor is not a tiling layout (a complement mode has 0 " + "extent); its strides must be ordered and divisible"); assert(restShapeVal % newShapeVal == 0); } diff --git a/lib/Bindings/Python/FlyExtension.cpp b/lib/Bindings/Python/FlyExtension.cpp index aead80cfc..29c12dce9 100644 --- a/lib/Bindings/Python/FlyExtension.cpp +++ b/lib/Bindings/Python/FlyExtension.cpp @@ -71,10 +71,35 @@ struct IntTupleAttrBuilder { return IntTupleAttr::get(IntAttr::getStatic(ctx, cInt)); } else if (args.is_none()) { return IntTupleAttr::getLeafNone(ctx); + } else if (nb::hasattr(args, "__fly_basis__") && + PyObject_IsTrue(nb::object(args.attr("__fly_basis__")).ptr()) == 1) { + // Scaled-basis stride leaf (e.g. fx.E(0) -> 1E0), duck-typed on a truthy + // __fly_basis__ (Python truthiness, so a falsy marker is not mistaken for a + // basis). The marker is public, so re-validate the payload here: + // accept any sequence/iterable of modes, and reject negative modes (the + // assembly format cannot round-trip a negative E). + int32_t value = nb::cast(args.attr("value")); + SmallVector modes; + nb::object modesObj = args.attr("modes"); + for (auto mode : modesObj) { + int32_t m = nb::cast(mode); + if (m < 0) { + throw std::invalid_argument( + "basis mode must be a non-negative int (the IntTuple assembly format " + "cannot round-trip a negative E), got " + + std::to_string(m)); + } + modes.push_back(m); + } + return IntTupleAttr::get(BasisAttr::get(IntAttr::getStatic(ctx, value), modes)); } else { if (!nb::hasattr(args, "_CAPIPtr")) { - throw std::invalid_argument("Expected I32, got: " + - std::string(nb::str(nb::type_name(args)).c_str())); + // Report the instance's *type* name. nb::type_name expects a type object, + // so calling it on an arbitrary instance (e.g. an object with a falsy + // __fly_basis__ marker) reinterprets that instance as a PyTypeObject and + // segfaults instead of raising; Py_TYPE(...)->tp_name is always valid. + throw std::invalid_argument(std::string("Expected I32, got: ") + + Py_TYPE(args.ptr())->tp_name); } dyncElems.push_back(args); return IntTupleAttr::get(IntAttr::getDynamic(ctx)); diff --git a/lib/Dialect/Fly/Utils/IntTupleUtils.cpp b/lib/Dialect/Fly/Utils/IntTupleUtils.cpp index 2027f1aa8..b4e8ee31f 100644 --- a/lib/Dialect/Fly/Utils/IntTupleUtils.cpp +++ b/lib/Dialect/Fly/Utils/IntTupleUtils.cpp @@ -138,58 +138,86 @@ IntTupleAttr IntTupleBuilder::mul(IntTupleAttr lhs, IntTupleAttr r } } IntTupleAttr IntTupleBuilder::div(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); - return IntTupleAttr::get(lhs.getLeafAsInt() / rhs.getLeafAsInt()); + // A basis divisor has no quotient mode. This is reachable via complement() of a + // basis-strided (identity) layout of rank >= 2: complementImpl computes + // div(minStride, lastStride) with both leaves basis. Reject it with a named + // assert rather than miscompute a stride. + assert(lhs.isLeaf() && rhs.isLeafInt() && + "div is undefined for a basis divisor; lower the identity layout first"); + if (lhs.isLeafInt()) { + return IntTupleAttr::get(lhs.getLeafAsInt() / rhs.getLeafAsInt()); + } + return IntTupleAttr::get(intSafeDiv(lhs.getLeafAsBasis(), rhs.getLeafAsInt())); } IntTupleAttr IntTupleBuilder::mod(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); + assert(lhs.isLeafInt() && rhs.isLeafInt() && "mod is defined only on integer leaves"); return IntTupleAttr::get(lhs.getLeafAsInt() % rhs.getLeafAsInt()); } IntTupleAttr IntTupleBuilder::logicalAnd(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); + assert(lhs.isLeafInt() && rhs.isLeafInt() && "logicalAnd requires integer leaves"); return IntTupleAttr::get(lhs.getLeafAsInt() && rhs.getLeafAsInt()); } IntTupleAttr IntTupleBuilder::logicalOr(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); + assert(lhs.isLeafInt() && rhs.isLeafInt() && "logicalOr requires integer leaves"); return IntTupleAttr::get(lhs.getLeafAsInt() || rhs.getLeafAsInt()); } IntTupleAttr IntTupleBuilder::logicalNot(IntTupleAttr val) const { - assert(val.isLeafInt()); + assert(val.isLeafInt() && "logicalNot requires an integer leaf"); return IntTupleAttr::get(!val.getLeafAsInt()); } IntTupleAttr IntTupleBuilder::lt(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); + // Ordering is defined only on integer leaves: basis directions are free-module + // generators with no total order, and strides are ordered via getStaticValue. + assert(lhs.isLeafInt() && rhs.isLeafInt() && + "lt is defined only on integer leaves; basis directions have no order"); return IntTupleAttr::get(lhs.getLeafAsInt() < rhs.getLeafAsInt()); } IntTupleAttr IntTupleBuilder::le(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); + assert(lhs.isLeafInt() && rhs.isLeafInt() && + "le is defined only on integer leaves; basis directions have no order"); return IntTupleAttr::get(lhs.getLeafAsInt() <= rhs.getLeafAsInt()); } IntTupleAttr IntTupleBuilder::gt(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); + assert(lhs.isLeafInt() && rhs.isLeafInt() && + "gt is defined only on integer leaves; basis directions have no order"); return IntTupleAttr::get(lhs.getLeafAsInt() > rhs.getLeafAsInt()); } IntTupleAttr IntTupleBuilder::ge(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); + assert(lhs.isLeafInt() && rhs.isLeafInt() && + "ge is defined only on integer leaves; basis directions have no order"); return IntTupleAttr::get(lhs.getLeafAsInt() >= rhs.getLeafAsInt()); } IntTupleAttr IntTupleBuilder::eq(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); - return IntTupleAttr::get(lhs.getLeafAsInt() == rhs.getLeafAsInt()); + assert(lhs.isLeaf() && rhs.isLeaf()); + if (lhs.isLeafInt() && rhs.isLeafInt()) { + return IntTupleAttr::get(lhs.getLeafAsInt() == rhs.getLeafAsInt()); + } + if (lhs.isLeafBasis() && rhs.isLeafBasis()) { + return IntTupleAttr::get(lhs.getLeafAsBasis() == rhs.getLeafAsBasis()); + } + // A scalar leaf and a basis monomial never coincide. + return IntTupleAttr::getLeafStatic(ctx, 0); } IntTupleAttr IntTupleBuilder::ne(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); - return IntTupleAttr::get(lhs.getLeafAsInt() != rhs.getLeafAsInt()); + assert(lhs.isLeaf() && rhs.isLeaf()); + if (lhs.isLeafInt() && rhs.isLeafInt()) { + return IntTupleAttr::get(lhs.getLeafAsInt() != rhs.getLeafAsInt()); + } + if (lhs.isLeafBasis() && rhs.isLeafBasis()) { + return IntTupleAttr::get(lhs.getLeafAsBasis() != rhs.getLeafAsBasis()); + } + // A scalar leaf and a basis monomial never coincide. + return IntTupleAttr::getLeafStatic(ctx, 1); } IntTupleAttr IntTupleBuilder::min(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); + assert(lhs.isLeafInt() && rhs.isLeafInt() && "min is defined only on integer leaves"); return IntTupleAttr::get(intMin(lhs.getLeafAsInt(), rhs.getLeafAsInt())); } IntTupleAttr IntTupleBuilder::max(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); + assert(lhs.isLeafInt() && rhs.isLeafInt() && "max is defined only on integer leaves"); return IntTupleAttr::get(intMax(lhs.getLeafAsInt(), rhs.getLeafAsInt())); } IntTupleAttr IntTupleBuilder::safeDiv(IntTupleAttr lhs, IntTupleAttr rhs) const { @@ -209,13 +237,14 @@ IntTupleAttr IntTupleBuilder::ceilDiv(IntTupleAttr lhs, IntTupleAt } } IntTupleAttr IntTupleBuilder::shapeDiv(IntTupleAttr lhs, IntTupleAttr rhs) const { - assert(lhs.isLeafInt() && rhs.isLeafInt()); + assert(lhs.isLeafInt() && rhs.isLeafInt() && "shapeDiv requires integer leaves"); return IntTupleAttr::get(intShapeDiv(lhs.getLeafAsInt(), rhs.getLeafAsInt())); } IntTupleAttr IntTupleBuilder::applySwizzle(IntTupleAttr v, SwizzleAttr swizzle) const { - assert(v.isLeafInt() && "applySwizzle only supports leafInt IntTupleAttr"); + assert(v.isLeafInt() && + "applySwizzle is undefined on a basis leaf; lower the identity layout first"); return IntTupleAttr::get(intApplySwizzle(v.getLeafAsInt(), swizzle)); } @@ -226,8 +255,10 @@ IntTupleAttr IntTupleBuilder::applyCoordSwizzle(IntTupleAttr coord } IntTupleAttr row = selectPath(*this, coord, swizzle.getModeRow()); IntTupleAttr col = selectPath(*this, coord, swizzle.getModeCol()); - assert(row.isLeafInt() && "coord swizzle row mode must select a leaf int"); - assert(col.isLeafInt() && "coord swizzle col mode must select a leaf int"); + assert(row.isLeafInt() && + "coord swizzle row mode must select an integer leaf (basis unsupported)"); + assert(col.isLeafInt() && + "coord swizzle col mode must select an integer leaf (basis unsupported)"); IntTupleAttr newCol = IntTupleAttr::get(intApplyCoordSwizzle(row.getLeafAsInt(), col.getLeafAsInt(), swizzle)); diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index 10f0de307..37ef5bd8c 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2025 FlyDSL Project Contributors +import operator from enum import IntEnum from typing import overload @@ -89,6 +90,8 @@ "make_int_tuple", "make_shape", "make_stride", + "E", + "make_basis_stride", "make_coord", "make_layout", "make_layout_like", @@ -356,6 +359,76 @@ def make_stride(*stride, loc=None, ip=None): return fly.make_stride(IntTupleTy, dyncElems, loc=loc, ip=ip) +class _BasisElem: + """A scaled-basis stride leaf (CuTe ScaledBasis / E), e.g. ``fx.E(0)`` -> ``1E0``. + + Duck-typed by the C++ int-tuple builder via the ``__fly_basis__`` marker. It is a + compile-time-only value (no runtime ``ir.Value`` operand), accepted anywhere a + stride entry is, e.g. ``make_stride(E(0), E(1))``. + """ + + __fly_basis__ = True + + def __init__(self, value, modes): + self.value = value + self.modes = modes + + +_INT32_MIN = -(2**31) +_INT32_MAX = 2**31 - 1 + + +def _as_int32(name, x, *, nonneg=False): + """Validate that ``x`` is an int32 (optionally non-negative) and return it as ``int``. + + Catches bad inputs in Python with an actionable error instead of deferring to + the nanobind ``int32_t`` cast on the C++ side. Accepts any integer-like value + (Python ``int``, NumPy integer scalars, ...) via ``operator.index`` while + rejecting ``float``/``str``. ``bool`` is rejected even though it subclasses + ``int`` -- ``E(True)`` is never intentional. + """ + if isinstance(x, bool): + raise TypeError(f"E {name} must be an int, got bool: {x!r}") + try: + xi = operator.index(x) + except TypeError: + raise TypeError(f"E {name} must be an int, got {type(x).__name__}: {x!r}") from None + lo = 0 if nonneg else _INT32_MIN + if not (lo <= xi <= _INT32_MAX): + kind = "a non-negative int32" if nonneg else "an int32" + raise ValueError(f"E {name} must be {kind} ({lo}..{_INT32_MAX}), got {xi}") + return xi + + +def E(*modes, value=1): + """Build a scaled-basis stride leaf. + + Examples: + E(0) -> 1E0 + E(1) -> 1E1 + E(0, value=2) -> 2E0 + E(0, 1) -> 1E0E1 (a single leaf along modes 0 and 1) + + Modes must be non-negative int32 values: the ``!fly.int_tuple`` assembly format + cannot round-trip a negative ``E`` (``1E-1`` fails to re-parse). The + coefficient ``value`` may be any int32 (``-2E0`` and ``0E0`` round-trip fine). + """ + if not modes: + raise ValueError("E requires at least one mode") + value = _as_int32("value", value) + modes = [_as_int32("mode", m, nonneg=True) for m in modes] + return _BasisElem(value, modes) + + +def make_basis_stride(value, modes): + """Build a flat basis stride, one ``E`` leaf per mode. + + ``make_basis_stride(1, (0, 1))`` -> ``(1E0, 1E1)``, the stride + ``make_identity_layout`` emits for that rank. + """ + return make_stride(*[E(m, value=value) for m in modes]) + + @traced_op def make_coord(*coord, loc=None, ip=None): """Build a coordinate used for indexing / slicing a layout. diff --git a/tests/mlir/LayoutAlgebra/basis.mlir b/tests/mlir/LayoutAlgebra/basis.mlir new file mode 100644 index 000000000..a39c88f11 --- /dev/null +++ b/tests/mlir/LayoutAlgebra/basis.mlir @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2025 FlyDSL Project Contributors +// RUN: %fly-opt %s | FileCheck %s + +// Tests for BasisAttr (scaled-basis / CuTe E) stride leaves flowing through +// the IntTupleBuilder ops: division of a basis coefficient, and a full identity +// layout (1E0, 1E1, ...) through logical_divide. The inferred result types below +// are checked by fly-opt's type inference (a wrong type fails to parse). + +// ----- + +// div divides the scalar coefficient of each basis leaf, keeping its modes: +// (2E0, 8E1) / (2, 4) = (1E0, 2E1) +// CHECK-LABEL: @test_int_tuple_div_basis +func.func @test_int_tuple_div_basis() -> !fly.int_tuple<(1E0, 2E1)> { + %a = fly.static : !fly.int_tuple<(2E0, 8E1)> + %b = fly.static : !fly.int_tuple<(2, 4)> + // CHECK: fly.int_tuple_div(%{{.*}}, %{{.*}}) + %result = fly.int_tuple_div(%a, %b) : (!fly.int_tuple<(2E0, 8E1)>, !fly.int_tuple<(2, 4)>) -> !fly.int_tuple<(1E0, 2E1)> + return %result : !fly.int_tuple<(1E0, 2E1)> +} + +// ----- + +// logical_divide partitions a basis-strided identity layout; the algebra walks the +// basis strides (via complement/div) instead of asserting on a non-int leaf. +// CHECK-LABEL: @test_logical_divide_identity +func.func @test_logical_divide_identity() -> !fly.layout<((2, (2, 2)), 4) : ((1E0, (2E0, 1E1)), 2E1)> { + %s = fly.static : !fly.int_tuple<(4, 8)> + %id = fly.make_identity_layout(%s) : (!fly.int_tuple<(4, 8)>) -> !fly.layout<(4, 8) : (1E0, 1E1)> + %ds = fly.static : !fly.int_tuple<(2, 4)> + %dd = fly.static : !fly.int_tuple<(1, 2)> + %div = fly.make_layout(%ds, %dd) : (!fly.int_tuple<(2, 4)>, !fly.int_tuple<(1, 2)>) -> !fly.layout<(2, 4) : (1, 2)> + // CHECK: fly.logical_divide(%{{.*}}, %{{.*}}) + %result = fly.logical_divide(%id, %div) : (!fly.layout<(4, 8) : (1E0, 1E1)>, !fly.layout<(2, 4) : (1, 2)>) -> !fly.layout<((2, (2, 2)), 4) : ((1E0, (2E0, 1E1)), 2E1)> + return %result : !fly.layout<((2, (2, 2)), 4) : ((1E0, (2E0, 1E1)), 2E1)> +} + +// ----- + +// Rank-3 identity logical_divide with a valid (tiling) divisor walks the basis +// strides through complement/composition without crashing. A non-tiling divisor +// such as (2,4,2):(1,2,4) -- whose complement has a 0-extent mode -- is rejected +// by an assert in compositionImpl instead of dividing by zero (see issue #574). +// CHECK-LABEL: @test_logical_divide_identity_rank3 +func.func @test_logical_divide_identity_rank3() + -> !fly.layout<((2, (2, 2), 2), (2, 2)) : ((1E0, (2E0, 1E1), 2E1), (4E1, 1E2))> { + %s = fly.static : !fly.int_tuple<(4, 8, 2)> + %id = fly.make_identity_layout(%s) : (!fly.int_tuple<(4, 8, 2)>) -> !fly.layout<(4, 8, 2) : (1E0, 1E1, 1E2)> + %ds = fly.static : !fly.int_tuple<(2, 4, 2)> + %dd = fly.static : !fly.int_tuple<(1, 2, 8)> + %div = fly.make_layout(%ds, %dd) : (!fly.int_tuple<(2, 4, 2)>, !fly.int_tuple<(1, 2, 8)>) -> !fly.layout<(2, 4, 2) : (1, 2, 8)> + // CHECK: fly.logical_divide(%{{.*}}, %{{.*}}) + %result = fly.logical_divide(%id, %div) : (!fly.layout<(4, 8, 2) : (1E0, 1E1, 1E2)>, !fly.layout<(2, 4, 2) : (1, 2, 8)>) -> !fly.layout<((2, (2, 2), 2), (2, 2)) : ((1E0, (2E0, 1E1), 2E1), (4E1, 1E2))> + return %result : !fly.layout<((2, (2, 2), 2), (2, 2)) : ((1E0, (2E0, 1E1), 2E1), (4E1, 1E2))> +} diff --git a/tests/mlir/Transforms/layout_lowering.mlir b/tests/mlir/Transforms/layout_lowering.mlir index f27685ece..03e14a731 100644 --- a/tests/mlir/Transforms/layout_lowering.mlir +++ b/tests/mlir/Transforms/layout_lowering.mlir @@ -384,3 +384,43 @@ func.func @test_get_leaves_dynamic_only(%x: i32, %y: i64) -> (i32, i64) { %0:2 = fly.get_leaves(%t) {dynamicOnly = true} : (!fly.int_tuple<(4, ?, ?{i64})>) -> (i32, i64) return %0#0, %0#1 : i32, i64 } + +// ----- + +// === EqualOp Lowering: basis (scaled-basis / E) leaves === + +// equal on identical basis strides folds to true. +// CHECK-LABEL: @test_equal_basis_same +func.func @test_equal_basis_same() -> i1 { + %a = fly.make_int_tuple() : () -> !fly.int_tuple<(1E0, 1E1)> + // CHECK: %[[T:.*]] = arith.constant true + // CHECK: return %[[T]] + %r = fly.equal(%a, %a) : (!fly.int_tuple<(1E0, 1E1)>, !fly.int_tuple<(1E0, 1E1)>) -> i1 + return %r : i1 +} + +// ----- + +// equal on basis leaves with different modes folds to false (E0 != E1). +// CHECK-LABEL: @test_equal_basis_diff_modes +func.func @test_equal_basis_diff_modes() -> i1 { + %a = fly.make_int_tuple() : () -> !fly.int_tuple<(1E0)> + %b = fly.make_int_tuple() : () -> !fly.int_tuple<(1E1)> + // CHECK: %[[F:.*]] = arith.constant false + // CHECK: return %[[F]] + %r = fly.equal(%a, %b) : (!fly.int_tuple<(1E0)>, !fly.int_tuple<(1E1)>) -> i1 + return %r : i1 +} + +// ----- + +// a basis monomial never equals a plain integer leaf. +// CHECK-LABEL: @test_equal_int_vs_basis +func.func @test_equal_int_vs_basis() -> i1 { + %a = fly.make_int_tuple() : () -> !fly.int_tuple<(1)> + %b = fly.make_int_tuple() : () -> !fly.int_tuple<(1E0)> + // CHECK: %[[F:.*]] = arith.constant false + // CHECK: return %[[F]] + %r = fly.equal(%a, %b) : (!fly.int_tuple<(1)>, !fly.int_tuple<(1E0)>) -> i1 + return %r : i1 +} diff --git a/tests/unit/test_layout_algebra.py b/tests/unit/test_layout_algebra.py index 4f7a55d6c..918f35785 100644 --- a/tests/unit/test_layout_algebra.py +++ b/tests/unit/test_layout_algebra.py @@ -427,6 +427,102 @@ def build(): _build_and_verify("zipped_tiled_flat_product", build, [360]) +# ============================================================================== +# Basis strides (scaled-basis / CuTe E) via fx.E +# ============================================================================== + + +def test_basis_stride_via_E(): + """fx.E builds the same basis-strided layout as make_identity_layout.""" + + def build(): + fx.make_layout(fx.make_shape(4, 8), fx.make_stride(fx.E(0), fx.E(1))) + fx.make_identity_layout(fx.make_shape(4, 8)) + fx.make_basis_stride(1, (0, 1)) + fx.make_stride(fx.E(0, 1, value=2)) + + def check(ir): + # Normalize whitespace so the checks do not depend on the MLIR printer's + # exact spacing around commas/colons. + compact = "".join(ir.split()) + # fx.E(0), fx.E(1) produce a (1E0, 1E1) stride, identical to make_identity_layout. + assert "!fly.layout<(4,8):(1E0,1E1)>" in compact + # make_basis_stride(1, (0, 1)) yields the same flat basis stride. + assert "!fly.int_tuple<(1E0,1E1)>" in compact + # value and multi-mode forms: fx.E(0, 1, value=2) -> 2E0E1. + assert "!fly.int_tuple<(2E0E1)>" in compact + + _build_and_verify_ir("basis_stride_via_E", build, check) + + +def test_basis_E_rejects_invalid_modes(): + """fx.E validates its inputs in Python instead of deferring to the C++ cast. + + Negative modes are the key case: the !fly.int_tuple assembly format cannot + round-trip a negative E (`1E-1` fails to re-parse), so they must be + rejected at construction time. + """ + with pytest.raises(ValueError): + fx.E() # at least one mode required + with pytest.raises(ValueError): + fx.E(-1) # negative mode is not round-trippable + with pytest.raises(ValueError): + fx.make_basis_stride(1, (0, -2)) # negative mode via the flat helper + with pytest.raises(TypeError): + fx.E(0.0) # non-int mode + with pytest.raises(TypeError): + fx.E(0, value="x") # non-int coefficient + with pytest.raises(ValueError): + fx.E(2**31) # mode out of int32 range + + # The coefficient may be any int32, including negative/zero (these round-trip). + assert fx.E(0, value=-2).value == -2 + assert fx.E(0, value=0).value == 0 + + # Integer-like values (NumPy integer scalars) are accepted and normalized to int. + np = pytest.importorskip("numpy") + elem = fx.E(np.int32(2), value=np.int64(3)) + assert elem.modes == [2] and elem.value == 3 + assert all(isinstance(m, int) for m in elem.modes) and isinstance(elem.value, int) + + +def test_make_stride_rejects_non_stride_object(): + """A non-stride object must raise a clean error from make_stride, not segfault. + + Covers a falsy ``__fly_basis__`` marker (which is correctly not treated as a + basis leaf and falls through to the generic path) and an arbitrary object. + The generic path reports the value's type name, which must not be computed by + reinterpreting the instance as a Python type object. + """ + + class FalsyMarker: + __fly_basis__ = False + value = 1 + modes = [0] + + with Context() as ctx: + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + module = Module.create() + with InsertionPoint(module.body): + f = func.FuncOp("reject", FunctionType.get([], [])) + with InsertionPoint(f.add_entry_block()): + with pytest.raises(ValueError): + fx.make_stride(FalsyMarker()) + with pytest.raises(ValueError): + fx.make_stride(object()) + + +def test_basis_identity_size(): + """A basis-strided identity layout lowers through the pipeline; size = 4*8 = 32.""" + + def build(): + layout = fx.make_layout(fx.make_shape(4, 8), fx.make_stride(fx.E(0), fx.E(1))) + return [fx.size(layout)] + + _build_and_verify("basis_identity_size", build, [32]) + + # ============================================================================== # Main # ==============================================================================