Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions tilelang/language/overrides/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tvm.tir import BufferLoad, Var

from tvm.script.parser.tir import parser as tvm_tir_parser
from tilelang.language.tir.ir import SerialStepSpec


def _get_node_span(node: doc.AST) -> tuple[int, int, int, int]:
Expand Down Expand Up @@ -142,3 +143,65 @@ def tilelang_visit_ann_assign(self, node: doc.AnnAssign) -> None: # pylint: dis
frame = T.LetStmt(rhs, var=ann_var)
frame.add_callback(partial(frame.__exit__, None, None, None))
frame.__enter__()


# Override For to support stepped serial: T.serial(start, end, step)
@dispatch.register(token="tir", type_name="For")
def tilelang_visit_for(self, node: doc.For) -> None: # pylint: disable=unused-argument
"""Override `For` to add support for T.serial(start, end, step).

When the iterable is a SerialStepSpec, lower it to a unit-step loop over
t in [0, floor_div(|end-start|, step)] and bind the loop variable using a
Let to `start + t*step` (inclusive semantics).
"""
iter_val = self.eval_expr(node.iter)

# Fast path: fall back to TVM default behavior when not a SerialStepSpec
if not isinstance(iter_val, SerialStepSpec):
if not isinstance(iter_val, T.frame.ForFrame):
self.report_error(
node.iter,
"Expect the for loop to be one of the following: "
"range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding",
)
with self.var_table.with_frame():
with iter_val as iters:
self.eval_assign(target=node.target, source=iters, bind_value=tvm_tir_parser.bind_for_value)
self.visit_body(node.body)
return

# Stepped inclusive serial: require positive integer step
start = iter_val.start
end = iter_val.stop
step = iter_val.step
annotations = iter_val.annotations

# Normalize step to Python int if possible, otherwise expect IntImm-like
if isinstance(step, int):
step_val = step
else:
step_val = getattr(step, "value", None)
if step_val is None:
self.report_error(node.iter, "T.serial step must be an integer or IntImm")
return

if step_val <= 0:
self.report_error(node.iter, "T.serial step must be a positive integer")
return

# Use tvm.tir.floordiv via builder ops from tilelang.tir.ir if available
# Avoid importing op wrappers; compute using arithmetic to keep it simple.
# We construct: T.ceildiv((end - start), step)
extent = T.ceildiv(end - start, step_val) # type: ignore[operator]

for_frame = T.serial(0, extent, annotations=annotations)
with self.var_table.with_frame():
with for_frame as t:
# Bind loop target as Let var: i = start + t * step
stepped_index = start + t * step_val # type: ignore[operator]
self.eval_assign(
target=node.target,
source=stepped_index,
bind_value=tvm_tir_parser.bind_assign_value,
)
self.visit_body(node.body)
Comment on lines +148 to +207

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Inclusive stepped T.serial is off‑by‑one when (end - start) is divisible by step (and when start == end)

The lowering in tilelang_visit_for mostly matches the documented inclusive semantics, but the extent calculation is short by one iteration whenever (end - start) is an exact multiple of step, including the start == end case.

  • For example, T.serial(0, 8, 2) is documented to produce i = 0, 2, 4, 6, 8, but:
    • end - start = 8
    • extent = T.ceildiv(8, 2) = 4
    • t runs 0..3, so i = 0, 2, 4, 6 and the i = 8 iteration is lost.
  • Similarly, for T.serial(5, 5, 1):
    • end - start = 0
    • extent = 0 → no iterations, instead of a single i = 5 iteration.

To implement the stated semantics i = start, start+step, …, <= end, the iteration count should be floor((end - start) / step) + 1, which can be expressed as ceildiv((end - start) + 1, step) for positive step.

Consider updating the extent computation and the accompanying comment as follows:

-    # Use tvm.tir.floordiv via builder ops from tilelang.tir.ir if available
-    # Avoid importing op wrappers; compute using arithmetic to keep it simple.
-    # We construct: T.ceildiv((end - start), step)
-    extent = T.ceildiv(end - start, step_val) # type: ignore[operator]
+    # Compute inclusive iteration count:
+    # N = floor((end - start) / step) + 1  for step > 0,
+    # which is equivalent to ceildiv((end - start) + 1, step).
+    # This realizes i = start, start + step, ..., <= end.
+    delta = end - start  # type: ignore[operator]
+    extent = T.ceildiv(delta + 1, step_val)  # type: ignore[operator]

You may also want to:

  • Align the docstring at Lines 153‑155 with the actual formula (ceildiv((end - start) + 1, step) instead of floor_div(|end-start|, step)), to avoid confusion.
  • (Optional) Clarify that, in the current parser, step must be a positive constant or IntImm (Line 179–187), even though the type annotation allows a general PrimExpr. Either narrowing the type in SerialStepSpec/serial() or explicitly documenting this constraint would make behavior clearer.

I'll verify the review comment by investigating the codebase. Let me search for the relevant code and test cases.


---

🏁 Script executed:

```shell
#!/bin/bash
# Find and display the actual code at lines 148-207 in parser.py
if [ -f "tilelang/language/overrides/parser.py" ]; then
  echo "=== Current code at lines 148-207 ==="
  sed -n '148,207p' tilelang/language/overrides/parser.py
else
  find . -name "parser.py" -path "*/language/overrides/*" -exec cat {} \;
fi

Length of output: 2618


🏁 Script executed:

#!/bin/bash
# Search for SerialStepSpec definition
echo "=== SerialStepSpec definition ==="
rg -n "class SerialStepSpec" -A 10

Length of output: 778


🏁 Script executed:

#!/bin/bash
# Search for T.serial usage and tests
echo "=== Test files for serial ==="
fd -e py -path "*test*" | head -20

Length of output: 198


🏁 Script executed:

#!/bin/bash
# Look for documentation about T.serial semantics
echo "=== Looking for T.serial semantics documentation ==="
rg -n "T\.serial" -B 2 -A 2 --type=py -g "*.py" | head -50

Length of output: 3780


🏁 Script executed:

#!/bin/bash
# Search for existing tests of stepped serial loops
echo "=== Searching for stepped serial tests ==="
rg -n "T\.serial.*step" -C 3 --type=py

Length of output: 2842


Let me search more thoroughly for test cases and documentation about the expected semantics.


---

🏁 Script executed:

```shell
#!/bin/bash
# Find all test files in the repository
echo "=== Test files ==="
find . -name "*test*.py" -o -name "test_*.py" | head -30

Length of output: 1539


🏁 Script executed:

#!/bin/bash
# Search for tests that involve parser or For loops
echo "=== Parser-related tests ==="
rg -l "tilelang_visit_for\|SerialStepSpec" --type=py

Length of output: 28


🏁 Script executed:

#!/bin/bash
# Look for documentation about the semantics in docstrings and comments
echo "=== Documentation about inclusive semantics ==="
rg -n "inclusive" -B 3 -A 3 --type=py | head -60

Length of output: 3508


🏁 Script executed:

#!/bin/bash
# Find where T.ceildiv is imported/used
echo "=== T.ceildiv definition or import ==="
rg -n "ceildiv" -B 2 -A 2 --type=py | head -40

Length of output: 2305


🏁 Script executed:

#!/bin/bash
# Look at the serial function definition more completely
echo "=== Full serial function definition ==="
rg -n "def serial" -A 30 tilelang/language/tir/ir.py

Length of output: 1157


🏁 Script executed:

#!/by/bash
# Search for existing tests specifically
find . \( -name "*test*" -type d \) | head -10

Length of output: 100


Let me check the testing directories for any tests of stepped serial loops:


---

🏁 Script executed:

```shell
#!/bin/bash
# Search testing directories for stepped serial tests
echo "=== Testing directories content ==="
find tilelang/testing -name "*.py" -type f 2>/dev/null | head -20

Length of output: 64


🏁 Script executed:

#!/bin/bash
# Check if any test imports or uses SerialStepSpec or T.serial with step
echo "=== Search all Python files for stepped serial usage ==="
rg -n "T\.serial\(" --type=py -A 1 | grep -E "(step|,.*,.*\))" | head -20

Length of output: 2231


🏁 Script executed:

#!/bin/bash
# Verify the mathematical correctness by checking the exact semantics
echo "=== Check exact extent calculation semantics ==="
rg -n "For.*SerialStepSpec\|class For" -B 2 -A 8 --type=py | head -50

Length of output: 48


Fix off-by-one error in inclusive stepped T.serial iteration count

The review comment correctly identifies a critical off-by-one bug. The documented semantics specify inclusive iteration (i = start, start+step, ..., <= end per tilelang/language/tir/ir.py:42-45), but the extent calculation is short by one when (end - start) is divisible by step (including start == end).

Verification confirms the mathematical error:

  • T.serial(0, 8, 2) should produce i = 0,2,4,6,8 (5 iterations), but current code produces 4 iterations.
  • T.serial(5, 5, 1) should produce i = 5 (1 iteration), but current code produces 0.
  • The proposed formula ceildiv((end - start) + 1, step_val) correctly implements floor((end-start)/step) + 1.

Additionally, update the docstring at lines 153–155, which currently states t in [0, floor_div(|end-start|, step)] but should reflect the correct formula.

🧰 Tools
🪛 Ruff (0.14.5)

149-149: Possible hardcoded password assigned to argument: "token"

(S106)

🤖 Prompt for AI Agents
In tilelang/language/overrides/parser.py around lines 148 to 207, the
inclusive-stepped T.serial lowering computes extent incorrectly (off-by-one) and
the docstring is inaccurate; change the extent calculation to use ceildiv((end -
start) + 1, step_val) (i.e. T.ceildiv((end - start) + 1, step_val)) so the loop
yields floor((end-start)/step)+1 iterations (covers cases like start==end), and
update the docstring at lines ~153-155 to state that we lower to t in [0,
floor((end-start)/step)] (reflecting inclusive semantics). Ensure the rest of
the lowering uses the computed extent unchanged.

40 changes: 37 additions & 3 deletions tilelang/language/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,26 @@
import functools


class SerialStepSpec:
"""A lightweight spec object for stepped serial loops.

This is consumed by the TileLang TIR parser override to realize
inclusive stepped loops like T.serial(start, end, step).
"""

def __init__(self, start: PrimExpr, stop: PrimExpr, step: PrimExpr | int,
annotations: dict[str, Any] | None):
self.start = start
self.stop = stop
self.step = step
self.annotations = annotations


def serial(start: PrimExpr,
stop: PrimExpr = None,
stop: PrimExpr | None = None,
step: PrimExpr | int | None = None,
*,
annotations: dict[str, Any] = None) -> frame.ForFrame:
annotations: dict[str, Any] = None) -> frame.ForFrame | SerialStepSpec:
"""The serial For statement.

Parameters
Expand All @@ -21,6 +37,14 @@ def serial(start: PrimExpr,
stop : PrimExpr
The maximum value of iteration.

step : PrimExpr | int | None
Optional step size of iteration. When provided as the third positional
argument (or keyword), the loop iterates inclusively with stride `step`:
i = start, start+step, ..., <= end. If `end-start` is not divisible by
`step`, the last value will be the largest `start + k*step` such that
it does not exceed `end` (for positive step). Negative steps are not
currently supported.

annotations : Dict[str, Any]
The optional annotations of the For statement.

Expand All @@ -29,7 +53,17 @@ def serial(start: PrimExpr,
res : frame.ForFrame
The ForFrame.
"""
return _ir.serial(start=start, stop=stop, annotations=annotations)
# If no step is provided, delegate to the upstream builder (supports
# both one-arg and two-arg forms).
if step is None:
return _ir.serial(start=start, stop=stop, annotations=annotations)

# Step provided: return a spec for the parser override to lower into an
# inclusive stepped loop. Require `stop` to be provided explicitly.
if stop is None:
raise TypeError("T.serial(start, end, step): `end` must be provided when `step` is set")

return SerialStepSpec(start=start, stop=stop, step=step, annotations=annotations)


def parallel(start: PrimExpr,
Expand Down
Loading