Skip to content

[Feat] Support stride in T.serial#41

Merged
Rachmanino merged 2 commits into
mainfrom
yu/serial
Nov 20, 2025
Merged

[Feat] Support stride in T.serial#41
Rachmanino merged 2 commits into
mainfrom
yu/serial

Conversation

@Rachmanino

@Rachmanino Rachmanino commented Nov 20, 2025

Copy link
Copy Markdown
Collaborator

Summary by CodeRabbit

Release Notes

  • New Features
    • Added support for stepped serial loop iteration, allowing loops to increment by custom step values during execution.

✏️ Tip: You can customize this high-level summary in your review settings.

@coderabbitai

coderabbitai Bot commented Nov 20, 2025

Copy link
Copy Markdown

Walkthrough

Added support for stepped serial loops in tilelang. A new SerialStepSpec class represents stepped loop specifications. The serial() function now accepts an optional step parameter and returns either a ForFrame or SerialStepSpec accordingly. A new tilelang_visit_for() handler processes stepped loops during parsing with validation for positive integer steps.

Changes

Cohort / File(s) Summary
Stepped loop specification and API
tilelang/language/tir/ir.py
Added SerialStepSpec class for inclusive stepped loop representation. Extended serial() function signature to accept optional step parameter (PrimExpr | int | None), returning ForFrame when step is None or SerialStepSpec when step is provided. Includes validation requiring stop parameter when step is specified.
Stepped loop visitor handler
tilelang/language/overrides/parser.py
Added tilelang_visit_for() handler to process stepped serial loops. Validates positive integer steps, computes iteration extent, binds loop index to stepped value, and processes loop body. Falls back to default TVM behavior for non-SerialStepSpec cases.

Sequence Diagram

sequenceDiagram
    participant User
    participant serial as serial()
    participant Parser
    participant Handler as tilelang_visit_for()
    
    User->>serial: serial(start, stop, step)
    alt step provided
        serial->>serial: Validate step > 0
        serial-->>User: SerialStepSpec
    else step not provided
        serial-->>User: ForFrame (upstream)
    end
    
    Note over Parser: Parse loop with<br/>SerialStepSpec iterand
    Parser->>Handler: Visit For node
    Handler->>Handler: Compute iteration extent<br/>from (stop - start) / step
    Handler->>Handler: Bind loop index<br/>to stepped values
    Handler->>Handler: Process loop body
    Handler-->>Parser: Complete
Loading

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

  • tilelang/language/tir/ir.py: New SerialStepSpec class definition and conditional logic in serial() function for step handling; straightforward parameter validation and return type branching.
  • tilelang/language/overrides/parser.py: New visitor method with loop binding logic; review should focus on step validation, extent computation correctness, and error handling paths.

Poem

🐰 In loops that step with grace so neat,
Each bound and stride will now complete,
SerialStepSpec hops along,
With handler's visitor, steady and strong!
From start to stop, a rabbit's delight,

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title '[Feat] Support stride in T.serial' directly corresponds to the main change: adding a new step parameter to the serial function for stepped loop iteration.
Docstring Coverage ✅ Passed Docstring coverage is 83.33% which is sufficient. The required threshold is 80.00%.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch yu/serial

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions

Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
tilelang/language/tir/ir.py (1)

25-66: Align serial() docs and typing with new stepped behavior

The behavioral changes in serial() look consistent with the parser override, but there are a couple of small mismatches worth tightening up:

  • The Returns section still documents only frame.ForFrame, while the function now sometimes returns SerialStepSpec when step is provided. This can surprise users and tools relying on the docstring.
  • The signature and docstring allow step: PrimExpr | int | None, but the parser currently only accepts a positive Python int or IntImm‑like object (via .value) and rejects other PrimExpr kinds. That’s fine as an initial restriction, but it’s good to make it explicit.

Consider updating the docstring along these lines:

-    Returns
-    -------
-    res : frame.ForFrame
-        The ForFrame.
+    Returns
+    -------
+    res : frame.ForFrame or SerialStepSpec
+        When `step` is None, returns a `frame.ForFrame` (upstream behavior).
+        When `step` is provided, returns a `SerialStepSpec` that the parser
+        override lowers into an inclusive stepped loop. The current parser
+        requires `step` to be a positive integer or IntImm-like constant.

(Optionally, you could also narrow the type of step in the signature and in SerialStepSpec to reflect the “constant step only” constraint, or extend the parser to handle general PrimExpr steps.)

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a5ce988 and 1de65fd.

📒 Files selected for processing (2)
  • tilelang/language/overrides/parser.py (2 hunks)
  • tilelang/language/tir/ir.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/language/overrides/parser.py (2)
tilelang/language/tir/ir.py (2)
  • SerialStepSpec (10-22)
  • serial (25-66)
tilelang/language/parser/parser.py (2)
  • bind_for_value (78-111)
  • bind_assign_value (114-160)
🪛 Ruff (0.14.5)
tilelang/language/overrides/parser.py

149-149: Possible hardcoded password assigned to argument: "token"

(S106)

tilelang/language/tir/ir.py

29-29: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


64-64: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (2)
tilelang/language/overrides/parser.py (1)

11-11: Importing SerialStepSpec looks correct

The import cleanly wires the new stepped-serial spec into the parser override; no issues here.

tilelang/language/tir/ir.py (1)

10-22: SerialStepSpec container matches parser needs

The spec class is minimal and its fields (start, stop, step, annotations) line up with how the parser override consumes them. No issues here.

Comment on lines +148 to +207
# Override For to support stepped serial: T.serial(start, end, step)
@dispatch.register(token="tir", type_name="For")
def tilelang_visit_for(self, node: doc.For) -> None: # pylint: disable=unused-argument
"""Override `For` to add support for T.serial(start, end, step).

When the iterable is a SerialStepSpec, lower it to a unit-step loop over
t in [0, floor_div(|end-start|, step)] and bind the loop variable using a
Let to `start + t*step` (inclusive semantics).
"""
iter_val = self.eval_expr(node.iter)

# Fast path: fall back to TVM default behavior when not a SerialStepSpec
if not isinstance(iter_val, SerialStepSpec):
if not isinstance(iter_val, T.frame.ForFrame):
self.report_error(
node.iter,
"Expect the for loop to be one of the following: "
"range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding",
)
with self.var_table.with_frame():
with iter_val as iters:
self.eval_assign(target=node.target, source=iters, bind_value=tvm_tir_parser.bind_for_value)
self.visit_body(node.body)
return

# Stepped inclusive serial: require positive integer step
start = iter_val.start
end = iter_val.stop
step = iter_val.step
annotations = iter_val.annotations

# Normalize step to Python int if possible, otherwise expect IntImm-like
if isinstance(step, int):
step_val = step
else:
step_val = getattr(step, "value", None)
if step_val is None:
self.report_error(node.iter, "T.serial step must be an integer or IntImm")
return

if step_val <= 0:
self.report_error(node.iter, "T.serial step must be a positive integer")
return

# Use tvm.tir.floordiv via builder ops from tilelang.tir.ir if available
# Avoid importing op wrappers; compute using arithmetic to keep it simple.
# We construct: T.ceildiv((end - start), step)
extent = T.ceildiv(end - start, step_val) # type: ignore[operator]

for_frame = T.serial(0, extent, annotations=annotations)
with self.var_table.with_frame():
with for_frame as t:
# Bind loop target as Let var: i = start + t * step
stepped_index = start + t * step_val # type: ignore[operator]
self.eval_assign(
target=node.target,
source=stepped_index,
bind_value=tvm_tir_parser.bind_assign_value,
)
self.visit_body(node.body)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Inclusive stepped T.serial is off‑by‑one when (end - start) is divisible by step (and when start == end)

The lowering in tilelang_visit_for mostly matches the documented inclusive semantics, but the extent calculation is short by one iteration whenever (end - start) is an exact multiple of step, including the start == end case.

  • For example, T.serial(0, 8, 2) is documented to produce i = 0, 2, 4, 6, 8, but:
    • end - start = 8
    • extent = T.ceildiv(8, 2) = 4
    • t runs 0..3, so i = 0, 2, 4, 6 and the i = 8 iteration is lost.
  • Similarly, for T.serial(5, 5, 1):
    • end - start = 0
    • extent = 0 → no iterations, instead of a single i = 5 iteration.

To implement the stated semantics i = start, start+step, …, <= end, the iteration count should be floor((end - start) / step) + 1, which can be expressed as ceildiv((end - start) + 1, step) for positive step.

Consider updating the extent computation and the accompanying comment as follows:

-    # Use tvm.tir.floordiv via builder ops from tilelang.tir.ir if available
-    # Avoid importing op wrappers; compute using arithmetic to keep it simple.
-    # We construct: T.ceildiv((end - start), step)
-    extent = T.ceildiv(end - start, step_val) # type: ignore[operator]
+    # Compute inclusive iteration count:
+    # N = floor((end - start) / step) + 1  for step > 0,
+    # which is equivalent to ceildiv((end - start) + 1, step).
+    # This realizes i = start, start + step, ..., <= end.
+    delta = end - start  # type: ignore[operator]
+    extent = T.ceildiv(delta + 1, step_val)  # type: ignore[operator]

You may also want to:

  • Align the docstring at Lines 153‑155 with the actual formula (ceildiv((end - start) + 1, step) instead of floor_div(|end-start|, step)), to avoid confusion.
  • (Optional) Clarify that, in the current parser, step must be a positive constant or IntImm (Line 179–187), even though the type annotation allows a general PrimExpr. Either narrowing the type in SerialStepSpec/serial() or explicitly documenting this constraint would make behavior clearer.

I'll verify the review comment by investigating the codebase. Let me search for the relevant code and test cases.


---

🏁 Script executed:

```shell
#!/bin/bash
# Find and display the actual code at lines 148-207 in parser.py
if [ -f "tilelang/language/overrides/parser.py" ]; then
  echo "=== Current code at lines 148-207 ==="
  sed -n '148,207p' tilelang/language/overrides/parser.py
else
  find . -name "parser.py" -path "*/language/overrides/*" -exec cat {} \;
fi

Length of output: 2618


🏁 Script executed:

#!/bin/bash
# Search for SerialStepSpec definition
echo "=== SerialStepSpec definition ==="
rg -n "class SerialStepSpec" -A 10

Length of output: 778


🏁 Script executed:

#!/bin/bash
# Search for T.serial usage and tests
echo "=== Test files for serial ==="
fd -e py -path "*test*" | head -20

Length of output: 198


🏁 Script executed:

#!/bin/bash
# Look for documentation about T.serial semantics
echo "=== Looking for T.serial semantics documentation ==="
rg -n "T\.serial" -B 2 -A 2 --type=py -g "*.py" | head -50

Length of output: 3780


🏁 Script executed:

#!/bin/bash
# Search for existing tests of stepped serial loops
echo "=== Searching for stepped serial tests ==="
rg -n "T\.serial.*step" -C 3 --type=py

Length of output: 2842


Let me search more thoroughly for test cases and documentation about the expected semantics.


---

🏁 Script executed:

```shell
#!/bin/bash
# Find all test files in the repository
echo "=== Test files ==="
find . -name "*test*.py" -o -name "test_*.py" | head -30

Length of output: 1539


🏁 Script executed:

#!/bin/bash
# Search for tests that involve parser or For loops
echo "=== Parser-related tests ==="
rg -l "tilelang_visit_for\|SerialStepSpec" --type=py

Length of output: 28


🏁 Script executed:

#!/bin/bash
# Look for documentation about the semantics in docstrings and comments
echo "=== Documentation about inclusive semantics ==="
rg -n "inclusive" -B 3 -A 3 --type=py | head -60

Length of output: 3508


🏁 Script executed:

#!/bin/bash
# Find where T.ceildiv is imported/used
echo "=== T.ceildiv definition or import ==="
rg -n "ceildiv" -B 2 -A 2 --type=py | head -40

Length of output: 2305


🏁 Script executed:

#!/bin/bash
# Look at the serial function definition more completely
echo "=== Full serial function definition ==="
rg -n "def serial" -A 30 tilelang/language/tir/ir.py

Length of output: 1157


🏁 Script executed:

#!/by/bash
# Search for existing tests specifically
find . \( -name "*test*" -type d \) | head -10

Length of output: 100


Let me check the testing directories for any tests of stepped serial loops:


---

🏁 Script executed:

```shell
#!/bin/bash
# Search testing directories for stepped serial tests
echo "=== Testing directories content ==="
find tilelang/testing -name "*.py" -type f 2>/dev/null | head -20

Length of output: 64


🏁 Script executed:

#!/bin/bash
# Check if any test imports or uses SerialStepSpec or T.serial with step
echo "=== Search all Python files for stepped serial usage ==="
rg -n "T\.serial\(" --type=py -A 1 | grep -E "(step|,.*,.*\))" | head -20

Length of output: 2231


🏁 Script executed:

#!/bin/bash
# Verify the mathematical correctness by checking the exact semantics
echo "=== Check exact extent calculation semantics ==="
rg -n "For.*SerialStepSpec\|class For" -B 2 -A 8 --type=py | head -50

Length of output: 48


Fix off-by-one error in inclusive stepped T.serial iteration count

The review comment correctly identifies a critical off-by-one bug. The documented semantics specify inclusive iteration (i = start, start+step, ..., <= end per tilelang/language/tir/ir.py:42-45), but the extent calculation is short by one when (end - start) is divisible by step (including start == end).

Verification confirms the mathematical error:

  • T.serial(0, 8, 2) should produce i = 0,2,4,6,8 (5 iterations), but current code produces 4 iterations.
  • T.serial(5, 5, 1) should produce i = 5 (1 iteration), but current code produces 0.
  • The proposed formula ceildiv((end - start) + 1, step_val) correctly implements floor((end-start)/step) + 1.

Additionally, update the docstring at lines 153–155, which currently states t in [0, floor_div(|end-start|, step)] but should reflect the correct formula.

🧰 Tools
🪛 Ruff (0.14.5)

149-149: Possible hardcoded password assigned to argument: "token"

(S106)

🤖 Prompt for AI Agents
In tilelang/language/overrides/parser.py around lines 148 to 207, the
inclusive-stepped T.serial lowering computes extent incorrectly (off-by-one) and
the docstring is inaccurate; change the extent calculation to use ceildiv((end -
start) + 1, step_val) (i.e. T.ceildiv((end - start) + 1, step_val)) so the loop
yields floor((end-start)/step)+1 iterations (covers cases like start==end), and
update the docstring at lines ~153-155 to state that we lower to t in [0,
floor((end-start)/step)] (reflecting inclusive semantics). Ensure the rest of
the lowering uses the computed extent unchanged.

@Rachmanino Rachmanino merged commit fe7bbdd into main Nov 20, 2025
2 of 3 checks passed
@chengyupku chengyupku deleted the yu/serial branch February 6, 2026 08:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants