[Feat] Support stride in T.serial #41
Walkthrough
Added support for stepped serial loops in tilelang.
Sequence Diagram
sequenceDiagram
participant User
participant serial as serial()
participant Parser
participant Handler as tilelang_visit_for()
User->>serial: serial(start, stop, step)
alt step provided
serial->>serial: Validate step > 0
serial-->>User: SerialStepSpec
else step not provided
serial-->>User: ForFrame (upstream)
end
Note over Parser: Parse loop with<br/>SerialStepSpec iterand
Parser->>Handler: Visit For node
Handler->>Handler: Compute iteration extent<br/>from (stop - start) / step
Handler->>Handler: Bind loop index<br/>to stepped values
Handler->>Handler: Process loop body
Handler-->>Parser: Complete
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~12 minutes
Pre-merge checks: ✅ Passed checks (3 passed)
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Actionable comments posted: 1
🧹 Nitpick comments (1)
tilelang/language/tir/ir.py (1)
25-66: Align `serial()` docs and typing with new stepped behavior

The behavioral changes in `serial()` look consistent with the parser override, but there are a couple of small mismatches worth tightening up:

- The Returns section still documents only `frame.ForFrame`, while the function now sometimes returns `SerialStepSpec` when `step` is provided. This can surprise users and tools relying on the docstring.
- The signature and docstring allow `step: PrimExpr | int | None`, but the parser currently only accepts a positive Python `int` or IntImm‑like object (via `.value`) and rejects other `PrimExpr` kinds. That’s fine as an initial restriction, but it’s good to make it explicit.

Consider updating the docstring along these lines:

-    Returns
-    -------
-    res : frame.ForFrame
-        The ForFrame.
+    Returns
+    -------
+    res : frame.ForFrame or SerialStepSpec
+        When `step` is None, returns a `frame.ForFrame` (upstream behavior).
+        When `step` is provided, returns a `SerialStepSpec` that the parser
+        override lowers into an inclusive stepped loop. The current parser
+        requires `step` to be a positive integer or IntImm-like constant.

(Optionally, you could also narrow the type of `step` in the signature and in `SerialStepSpec` to reflect the “constant step only” constraint, or extend the parser to handle general `PrimExpr` steps.)
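To make the "constant step only" constraint concrete, here is a minimal sketch of the normalization described above: accept a Python `int` or an object exposing an integer `.value`, and reject everything else. The names `normalize_step` and `FakeIntImm` are hypothetical and exist only for illustration; they are not part of tilelang or TVM, and the sketch raises `ValueError` where the real parser would call `report_error`.

```python
from typing import Any


class FakeIntImm:
    """Hypothetical stand-in for an IntImm-like constant (not a TVM class)."""

    def __init__(self, value: int) -> None:
        self.value = value


def normalize_step(step: Any) -> int:
    """Return `step` as a positive Python int, or raise ValueError.

    Assumption: only a Python int, or an object whose `.value` is an int,
    counts as a constant step. Anything else is rejected.
    """
    if isinstance(step, int):
        step_val = step
    else:
        # IntImm-like objects expose their constant through `.value`.
        step_val = getattr(step, "value", None)
        if not isinstance(step_val, int):
            raise ValueError("T.serial step must be an integer or IntImm")
    if step_val <= 0:
        raise ValueError("T.serial step must be a positive integer")
    return step_val


if __name__ == "__main__":
    print(normalize_step(2))              # plain Python int
    print(normalize_step(FakeIntImm(4)))  # IntImm-like constant
```

Narrowing the annotation of `step` to match this behavior would let type checkers flag unsupported arguments before the parser does.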
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- tilelang/language/overrides/parser.py (2 hunks)
- tilelang/language/tir/ir.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/language/overrides/parser.py (2)
- tilelang/language/tir/ir.py (2): SerialStepSpec (10-22), serial (25-66)
- tilelang/language/parser/parser.py (2): bind_for_value (78-111), bind_assign_value (114-160)
🪛 Ruff (0.14.5)
tilelang/language/overrides/parser.py
149-149: Possible hardcoded password assigned to argument: "token"
(S106)
tilelang/language/tir/ir.py
29-29: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
64-64: Avoid specifying long messages outside the exception class
(TRY003)
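For readers unfamiliar with the two `ir.py` findings, the sketch below shows what satisfying RUF013 and TRY003 typically looks like. The flagged lines themselves are not reproduced on this page, so the function `serial_like` and the exception `NonPositiveStepError` are hypothetical illustrations and not the project's code.

```python
from typing import Optional


class NonPositiveStepError(ValueError):
    """Hypothetical exception: the long message lives in the class (TRY003)."""

    def __init__(self, step: int) -> None:
        super().__init__(f"step must be a positive integer, got {step}")
        self.step = step


# RUF013: write Optional[int] (or `int | None`) instead of `step: int = None`.
def serial_like(start: int, stop: int, step: Optional[int] = None) -> range:
    """Toy half-open range builder used only to illustrate the annotations."""
    if step is None:
        return range(start, stop)
    if step <= 0:
        raise NonPositiveStepError(step)
    return range(start, stop, step)


if __name__ == "__main__":
    print(list(serial_like(0, 3)))
    print(list(serial_like(0, 5, 2)))
```

Both findings are stylistic; neither changes runtime behavior.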
🔇 Additional comments (2)
tilelang/language/overrides/parser.py (1)
11-11: Importing `SerialStepSpec` looks correct

The import cleanly wires the new stepped-serial spec into the parser override; no issues here.
tilelang/language/tir/ir.py (1)
10-22: `SerialStepSpec` container matches parser needs

The spec class is minimal and its fields (`start`, `stop`, `step`, `annotations`) line up with how the parser override consumes them. No issues here.
# Override For to support stepped serial: T.serial(start, end, step)
@dispatch.register(token="tir", type_name="For")
def tilelang_visit_for(self, node: doc.For) -> None:  # pylint: disable=unused-argument
    """Override `For` to add support for T.serial(start, end, step).

    When the iterable is a SerialStepSpec, lower it to a unit-step loop over
    t in [0, floor_div(|end-start|, step)] and bind the loop variable using a
    Let to `start + t*step` (inclusive semantics).
    """
    iter_val = self.eval_expr(node.iter)

    # Fast path: fall back to TVM default behavior when not a SerialStepSpec
    if not isinstance(iter_val, SerialStepSpec):
        if not isinstance(iter_val, T.frame.ForFrame):
            self.report_error(
                node.iter,
                "Expect the for loop to be one of the following: "
                "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding",
            )
        with self.var_table.with_frame():
            with iter_val as iters:
                self.eval_assign(target=node.target, source=iters, bind_value=tvm_tir_parser.bind_for_value)
                self.visit_body(node.body)
        return

    # Stepped inclusive serial: require positive integer step
    start = iter_val.start
    end = iter_val.stop
    step = iter_val.step
    annotations = iter_val.annotations

    # Normalize step to Python int if possible, otherwise expect IntImm-like
    if isinstance(step, int):
        step_val = step
    else:
        step_val = getattr(step, "value", None)
        if step_val is None:
            self.report_error(node.iter, "T.serial step must be an integer or IntImm")
            return

    if step_val <= 0:
        self.report_error(node.iter, "T.serial step must be a positive integer")
        return

    # Use tvm.tir.floordiv via builder ops from tilelang.tir.ir if available
    # Avoid importing op wrappers; compute using arithmetic to keep it simple.
    # We construct: T.ceildiv((end - start), step)
    extent = T.ceildiv(end - start, step_val)  # type: ignore[operator]

    for_frame = T.serial(0, extent, annotations=annotations)
    with self.var_table.with_frame():
        with for_frame as t:
            # Bind loop target as Let var: i = start + t * step
            stepped_index = start + t * step_val  # type: ignore[operator]
            self.eval_assign(
                target=node.target,
                source=stepped_index,
                bind_value=tvm_tir_parser.bind_assign_value,
            )
            self.visit_body(node.body)
🧩 Analysis chain
Inclusive stepped `T.serial` is off‑by‑one when `(end - start)` is divisible by `step` (and when `start == end`)

The lowering in `tilelang_visit_for` mostly matches the documented inclusive semantics, but the extent calculation is short by one iteration whenever `(end - start)` is an exact multiple of `step`, including the `start == end` case.

- For example, `T.serial(0, 8, 2)` is documented to produce `i = 0, 2, 4, 6, 8`, but:
  - `end - start = 8`
  - `extent = T.ceildiv(8, 2) = 4`
  - `t` runs `0..3`, so `i = 0, 2, 4, 6` and the `i = 8` iteration is lost.
- Similarly, for `T.serial(5, 5, 1)`:
  - `end - start = 0`
  - `extent = 0` → no iterations, instead of a single `i = 5` iteration.

To implement the stated semantics i = start, start+step, …, <= end, the iteration count should be floor((end - start) / step) + 1, which can be expressed as ceildiv((end - start) + 1, step) for positive step.
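The off-by-one and the proposed fix can be checked with plain integer arithmetic. The sketch below is an assumption-laden model and not the parser itself: `ceildiv` is a hypothetical pure-Python stand-in for `T.ceildiv`, and `lowered_indices` imitates the lowered loop `t in [0, extent)` with `i = start + t * step`.

```python
from typing import List


def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division for b > 0 (stand-in for T.ceildiv)."""
    return -(-a // b)


def current_extent(start: int, end: int, step: int) -> int:
    """Extent as computed by the code under review."""
    return ceildiv(end - start, step)


def proposed_extent(start: int, end: int, step: int) -> int:
    """Extent with the suggested `+ 1` applied before dividing."""
    return ceildiv(end - start + 1, step)


def lowered_indices(start: int, step: int, extent: int) -> List[int]:
    """Model the lowered loop: t in [0, extent), i = start + t * step."""
    return [start + t * step for t in range(extent)]


def inclusive_reference(start: int, end: int, step: int) -> List[int]:
    """Documented semantics: start, start + step, ..., while i <= end."""
    return list(range(start, end + 1, step))


if __name__ == "__main__":
    for start, end, step in [(0, 8, 2), (5, 5, 1), (0, 7, 2)]:
        cur = lowered_indices(start, step, current_extent(start, end, step))
        new = lowered_indices(start, step, proposed_extent(start, end, step))
        print((start, end, step), "current:", cur, "proposed:", new)

    # Brute-force the identity ceildiv(d + 1, s) == floor(d / s) + 1.
    for d in range(0, 50):
        for s in range(1, 10):
            assert ceildiv(d + 1, s) == d // s + 1
```

In this model the two extents differ only when `end - start` is an exact multiple of `step`, which matches the cases listed above.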
Consider updating the extent computation and the accompanying comment as follows:
- # Use tvm.tir.floordiv via builder ops from tilelang.tir.ir if available
- # Avoid importing op wrappers; compute using arithmetic to keep it simple.
- # We construct: T.ceildiv((end - start), step)
- extent = T.ceildiv(end - start, step_val) # type: ignore[operator]
+ # Compute inclusive iteration count:
+ # N = floor((end - start) / step) + 1 for step > 0,
+ # which is equivalent to ceildiv((end - start) + 1, step).
+ # This realizes i = start, start + step, ..., <= end.
+ delta = end - start # type: ignore[operator]
+ extent = T.ceildiv(delta + 1, step_val) # type: ignore[operator]

You may also want to:

- Align the docstring at Lines 153‑155 with the actual formula (`ceildiv((end - start) + 1, step)` instead of `floor_div(|end-start|, step)`), to avoid confusion.
- (Optional) Clarify that, in the current parser, `step` must be a positive constant or IntImm (Lines 179–187), even though the type annotation allows a general `PrimExpr`. Either narrowing the type in `SerialStepSpec`/`serial()` or explicitly documenting this constraint would make behavior clearer.

I'll verify the review comment by investigating the codebase. Let me search for the relevant code and test cases.
---
🏁 Script executed:
```shell
#!/bin/bash
# Find and display the actual code at lines 148-207 in parser.py
if [ -f "tilelang/language/overrides/parser.py" ]; then
  echo "=== Current code at lines 148-207 ==="
  sed -n '148,207p' tilelang/language/overrides/parser.py
else
  find . -name "parser.py" -path "*/language/overrides/*" -exec cat {} \;
fi
```
Length of output: 2618

🏁 Script executed:
```shell
#!/bin/bash
# Search for SerialStepSpec definition
echo "=== SerialStepSpec definition ==="
rg -n "class SerialStepSpec" -A 10
```
Length of output: 778

🏁 Script executed:
```shell
#!/bin/bash
# Search for T.serial usage and tests
echo "=== Test files for serial ==="
fd -e py -path "*test*" | head -20
```
Length of output: 198

🏁 Script executed:
```shell
#!/bin/bash
# Look for documentation about T.serial semantics
echo "=== Looking for T.serial semantics documentation ==="
rg -n "T\.serial" -B 2 -A 2 --type=py -g "*.py" | head -50
```
Length of output: 3780

🏁 Script executed:
```shell
#!/bin/bash
# Search for existing tests of stepped serial loops
echo "=== Searching for stepped serial tests ==="
rg -n "T\.serial.*step" -C 3 --type=py
```
Length of output: 2842

Let me search more thoroughly for test cases and documentation about the expected semantics.
---
🏁 Script executed:
```shell
#!/bin/bash
# Find all test files in the repository
echo "=== Test files ==="
find . -name "*test*.py" -o -name "test_*.py" | head -30
```
Length of output: 1539

🏁 Script executed:
```shell
#!/bin/bash
# Search for tests that involve parser or For loops
echo "=== Parser-related tests ==="
rg -l "tilelang_visit_for\|SerialStepSpec" --type=py
```
Length of output: 28

🏁 Script executed:
```shell
#!/bin/bash
# Look for documentation about the semantics in docstrings and comments
echo "=== Documentation about inclusive semantics ==="
rg -n "inclusive" -B 3 -A 3 --type=py | head -60
```
Length of output: 3508

🏁 Script executed:
```shell
#!/bin/bash
# Find where T.ceildiv is imported/used
echo "=== T.ceildiv definition or import ==="
rg -n "ceildiv" -B 2 -A 2 --type=py | head -40
```
Length of output: 2305

🏁 Script executed:
```shell
#!/bin/bash
# Look at the serial function definition more completely
echo "=== Full serial function definition ==="
rg -n "def serial" -A 30 tilelang/language/tir/ir.py
```
Length of output: 1157

🏁 Script executed:
```shell
#!/bin/bash
# Search for existing tests specifically
find . \( -name "*test*" -type d \) | head -10
```
Length of output: 100

Let me check the testing directories for any tests of stepped serial loops:
---
🏁 Script executed:
```shell
#!/bin/bash
# Search testing directories for stepped serial tests
echo "=== Testing directories content ==="
find tilelang/testing -name "*.py" -type f 2>/dev/null | head -20
```
Length of output: 64

🏁 Script executed:
```shell
#!/bin/bash
# Check if any test imports or uses SerialStepSpec or T.serial with step
echo "=== Search all Python files for stepped serial usage ==="
rg -n "T\.serial\(" --type=py -A 1 | grep -E "(step|,.*,.*\))" | head -20
```
Length of output: 2231

🏁 Script executed:
```shell
#!/bin/bash
# Verify the mathematical correctness by checking the exact semantics
echo "=== Check exact extent calculation semantics ==="
rg -n "For.*SerialStepSpec\|class For" -B 2 -A 8 --type=py | head -50
```
Length of output: 48

Fix off-by-one error in inclusive stepped `T.serial` iteration count

The review comment correctly identifies a critical off-by-one bug. The documented semantics specify inclusive iteration (`i = start, start+step, ..., <= end` per tilelang/language/tir/ir.py:42-45), but the extent calculation is short by one when `(end - start)` is divisible by `step` (including `start == end`).

Verification confirms the mathematical error:

- `T.serial(0, 8, 2)` should produce `i = 0,2,4,6,8` (5 iterations), but current code produces 4 iterations.
- `T.serial(5, 5, 1)` should produce `i = 5` (1 iteration), but current code produces 0.
- The proposed formula `ceildiv((end - start) + 1, step_val)` correctly implements `floor((end-start)/step) + 1`.

Additionally, update the docstring at lines 153–155, which currently states `t in [0, floor_div(|end-start|, step)]` but should reflect the correct formula.
🤖 Prompt for AI Agents
In tilelang/language/overrides/parser.py around lines 148 to 207, the
inclusive-stepped T.serial lowering computes extent incorrectly (off-by-one) and
the docstring is inaccurate; change the extent calculation to use ceildiv((end -
start) + 1, step_val) (i.e. T.ceildiv((end - start) + 1, step_val)) so the loop
yields floor((end-start)/step)+1 iterations (covers cases like start==end), and
update the docstring at lines ~153-155 to state that we lower to t in [0,
floor((end-start)/step)] (reflecting inclusive semantics). Ensure the rest of
the lowering uses the computed extent unchanged.