diff --git a/tilelang/language/overrides/parser.py b/tilelang/language/overrides/parser.py index 01d59b6078..6c028efc18 100644 --- a/tilelang/language/overrides/parser.py +++ b/tilelang/language/overrides/parser.py @@ -8,6 +8,7 @@ from tvm.tir import BufferLoad, Var from tvm.script.parser.tir import parser as tvm_tir_parser +from tilelang.language.tir.ir import SerialStepSpec def _get_node_span(node: doc.AST) -> tuple[int, int, int, int]: @@ -142,3 +143,65 @@ def tilelang_visit_ann_assign(self, node: doc.AnnAssign) -> None: # pylint: dis frame = T.LetStmt(rhs, var=ann_var) frame.add_callback(partial(frame.__exit__, None, None, None)) frame.__enter__() + + +# Override For to support stepped serial: T.serial(start, end, step) +@dispatch.register(token="tir", type_name="For") +def tilelang_visit_for(self, node: doc.For) -> None: # pylint: disable=unused-argument + """Override `For` to add support for T.serial(start, end, step). + + When the iterable is a SerialStepSpec, lower it to a unit-step loop over + t in [0, floor_div(|end-start|, step)] and bind the loop variable using a + Let to `start + t*step` (inclusive semantics). + """ + iter_val = self.eval_expr(node.iter) + + # Fast path: fall back to TVM default behavior when not a SerialStepSpec + if not isinstance(iter_val, SerialStepSpec): + if not isinstance(iter_val, T.frame.ForFrame): + self.report_error( + node.iter, + "Expect the for loop to be one of the following: " + "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding", + ) + with self.var_table.with_frame(): + with iter_val as iters: + self.eval_assign(target=node.target, source=iters, bind_value=tvm_tir_parser.bind_for_value) + self.visit_body(node.body) + return + + # Stepped inclusive serial: require positive integer step + start = iter_val.start + end = iter_val.stop + step = iter_val.step + annotations = iter_val.annotations + + # Normalize step to Python int if possible, otherwise expect IntImm-like + if isinstance(step, int): + step_val = step + else: + step_val = getattr(step, "value", None) + if step_val is None: + self.report_error(node.iter, "T.serial step must be an integer or IntImm") + return + + if step_val <= 0: + self.report_error(node.iter, "T.serial step must be a positive integer") + return + + # Use tvm.tir.floordiv via builder ops from tilelang.tir.ir if available + # Avoid importing op wrappers; compute using arithmetic to keep it simple. + # We construct: T.ceildiv((end - start), step) + extent = T.ceildiv(end - start, step_val) # type: ignore[operator] + + for_frame = T.serial(0, extent, annotations=annotations) + with self.var_table.with_frame(): + with for_frame as t: + # Bind loop target as Let var: i = start + t * step + stepped_index = start + t * step_val # type: ignore[operator] + self.eval_assign( + target=node.target, + source=stepped_index, + bind_value=tvm_tir_parser.bind_assign_value, + ) + self.visit_body(node.body) diff --git a/tilelang/language/tir/ir.py b/tilelang/language/tir/ir.py index 6d2487c066..977e65036c 100644 --- a/tilelang/language/tir/ir.py +++ b/tilelang/language/tir/ir.py @@ -7,10 +7,26 @@ import functools +class SerialStepSpec: + """A lightweight spec object for stepped serial loops. + + This is consumed by the TileLang TIR parser override to realize + inclusive stepped loops like T.serial(start, end, step). + """ + + def __init__(self, start: PrimExpr, stop: PrimExpr, step: PrimExpr | int, + annotations: dict[str, Any] | None): + self.start = start + self.stop = stop + self.step = step + self.annotations = annotations + + def serial(start: PrimExpr, - stop: PrimExpr = None, + stop: PrimExpr | None = None, + step: PrimExpr | int | None = None, *, - annotations: dict[str, Any] = None) -> frame.ForFrame: + annotations: dict[str, Any] = None) -> frame.ForFrame | SerialStepSpec: """The serial For statement. Parameters @@ -21,6 +37,14 @@ def serial(start: PrimExpr, stop : PrimExpr The maximum value of iteration. + step : PrimExpr | int | None + Optional step size of iteration. When provided as the third positional + argument (or keyword), the loop iterates inclusively with stride `step`: + i = start, start+step, ..., <= end. If `end-start` is not divisible by + `step`, the last value will be the largest `start + k*step` such that + it does not exceed `end` (for positive step). Negative steps are not + currently supported. + annotations : Dict[str, Any] The optional annotations of the For statement. @@ -29,7 +53,17 @@ def serial(start: PrimExpr, res : frame.ForFrame The ForFrame. """ - return _ir.serial(start=start, stop=stop, annotations=annotations) + # If no step is provided, delegate to the upstream builder (supports + # both one-arg and two-arg forms). + if step is None: + return _ir.serial(start=start, stop=stop, annotations=annotations) + + # Step provided: return a spec for the parser override to lower into an + # inclusive stepped loop. Require `stop` to be provided explicitly. + if stop is None: + raise TypeError("T.serial(start, end, step): `end` must be provided when `step` is set") + + return SerialStepSpec(start=start, stop=stop, step=step, annotations=annotations) def parallel(start: PrimExpr,