diff --git a/examples/notebooks/00_hello_flydsl.ipynb b/examples/notebooks/00_hello_flydsl.ipynb new file mode 100644 index 000000000..8e79384b6 --- /dev/null +++ b/examples/notebooks/00_hello_flydsl.ipynb @@ -0,0 +1,255 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a89f7d33", + "metadata": {}, + "source": [ + "\n", + "\n", + "\n", + "# Hello, FlyDSL\n", + "\n", + "**FlyDSL** is a Python DSL and MLIR compiler stack for writing high-performance\n", + "AMD GPU kernels. You write ordinary-looking Python; FlyDSL *traces* it into the\n", + "`fly` / `fly_rocdl` MLIR dialects, lowers that through ROCDL/LLVM, and emits an\n", + "HSACO binary that runs on the GPU.\n", + "\n", + "This is notebook **0 of an onboarding series** that builds up the\n", + "`flydsl.expr` foundation one idea at a time:\n", + "\n", + "| # | Notebook | Topic |\n", + "|---|----------|-------|\n", + "| 00 | *this one* | the mental model: `@kernel` / `@jit`, and how to read the IR |\n", + "| 01 | `01_numeric_types` | the scalar type system (ints, floats, bf16, fp8) |\n", + "| 02 | `02_struct` | aggregate value types with `@fx.struct` |\n", + "| 03 | `03_universal_ops` | target-agnostic `Universal*` atoms + a vector-add capstone |\n", + "\n", + "Layout algebra (`make_layout`, `logical_divide`, tiled copy, MMA) is intentionally\n", + "**not** covered yet — it gets its own series once these primitives are familiar.\n", + "\n", + "**Prerequisites:** a built/installed `flydsl`, a ROCm GPU, and `wurlitzer`\n", + "(`pip install wurlitzer`) so the notebook can show GPU `printf` output — see below." + ] + }, + { + "cell_type": "markdown", + "id": "88791ade", + "metadata": {}, + "source": [ + "## 1. 
Setup & sanity check\n", + "\n", + "Two imports cover almost everything:\n", + "\n", + "- `flydsl.compiler as flyc` — the compiler entry points (`@flyc.kernel`, `@flyc.jit`, `flyc.from_dlpack`).\n", + "- `flydsl.expr as fx` — the DSL surface you write *inside* a kernel (types, ops, atoms, `printf`).\n", + "\n", + "If the import below fails, FlyDSL isn't on your path yet — build it and\n", + "`pip install -e .` (see the project README), then restart the kernel." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5bf3709", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "import flydsl.compiler as flyc\n", + "import flydsl.expr as fx\n", + "from flydsl.runtime.device import get_rocm_arch\n", + "\n", + "print(\"torch sees GPU:\", torch.cuda.is_available())\n", + "print(\"ROCm arch :\", get_rocm_arch())" + ] + }, + { + "cell_type": "markdown", + "id": "dbaeeba1", + "metadata": {}, + "source": [ + "**A note on seeing GPU output.** `fx.printf` runs on the device and writes to the\n", + "process's stdout, which Jupyter does not capture on its own. The tiny helper below\n", + "runs a launcher and routes that output back into the notebook (via `wurlitzer`).\n", + "We'll use `show_gpu_output(...)` throughout the series. Outside Jupyter — running a\n", + "plain `.py` — you don't need any of this; `printf` just goes to your terminal." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd67645c", + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "from wurlitzer import pipes\n", + "\n", + "\n", + "def show_gpu_output(launcher, *args, **kwargs):\n", + " \"\"\"Run a @flyc.jit launcher and echo its GPU printf output into the notebook.\"\"\"\n", + " kwargs.setdefault(\"stream\", torch.cuda.Stream())\n", + " with pipes() as (out, _err):\n", + " launcher(*args, **kwargs)\n", + " torch.cuda.synchronize()\n", + " print(out.read(), end=\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "3d3a00c3", + "metadata": {}, + "source": [ + "## 2. Two decorators: `@flyc.kernel` and `@flyc.jit`\n", + "\n", + "FlyDSL splits a launch into two traced functions:\n", + "\n", + "- **`@flyc.kernel`** marks **device** code — the body that runs on each GPU thread.\n", + " Inside it you have intrinsics like `fx.thread_idx.x` and `fx.block_idx.x`.\n", + "- **`@flyc.jit`** marks the **host launcher**. It calls a kernel and `.launch(...)`es\n", + " it with a grid/block configuration.\n", + "\n", + "Both are *traced*, not interpreted: when first called, FlyDSL runs the Python once\n", + "to build MLIR, compiles it, and caches the result. So `block=(4, 1, 1)` is read at\n", + "**trace time**, while `fx.thread_idx.x` is a **runtime** value that differs per\n", + "thread.\n", + "\n", + "Here is the smallest possible kernel — it takes no tensors and just prints.\n", + "`fx.printf` uses `{}` placeholders (avoid a literal `%` in the format string — it\n", + "is consumed by the underlying device `printf`)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6596d590", + "metadata": {}, + "outputs": [], + "source": [ + "@flyc.kernel\n", + "def hello_kernel():\n", + " bid = fx.block_idx.x\n", + " tid = fx.thread_idx.x\n", + " fx.printf(\"hello from block {} thread {}\", bid, tid)\n", + "\n", + "\n", + "@flyc.jit\n", + "def hello(stream: fx.Stream = fx.Stream(None)):\n", + " hello_kernel().launch(grid=(1, 1, 1), block=(4, 1, 1), stream=stream)\n", + "\n", + "\n", + "show_gpu_output(hello)" + ] + }, + { + "cell_type": "markdown", + "id": "f078729c", + "metadata": {}, + "source": [ + "Four threads, four lines. The launch built a one-block grid of four threads, and\n", + "each thread reached the `printf`." + ] + }, + { + "cell_type": "markdown", + "id": "f5547c20", + "metadata": {}, + "source": [ + "## 3. Looking at the generated IR\n", + "\n", + "The fastest way to build intuition for what FlyDSL *did* is to read the MLIR it\n", + "produced. Set `FLYDSL_DUMP_IR=1` (and a dump directory) and FlyDSL writes one\n", + "`.mlir` file per compiler pass, from `00_origin.mlir` (the high-level `fly` IR)\n", + "down to the final ISA. The env var is read at **compile time**, so we set it, then\n", + "compile a fresh kernel and read back its first dump." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0c233269", + "metadata": {}, + "outputs": [], + "source": [ + "import contextlib\n", + "import glob\n", + "import io\n", + "import os\n", + "import tempfile\n", + "\n", + "dump_dir = tempfile.mkdtemp(prefix=\"flydsl_ir_\")\n", + "os.environ[\"FLYDSL_DUMP_IR\"] = \"1\"\n", + "os.environ[\"FLYDSL_DUMP_DIR\"] = dump_dir\n", + "\n", + "\n", + "@flyc.kernel\n", + "def add_one_kernel(x: fx.Int32):\n", + " fx.printf(\"x + 1 = {}\", x + fx.Int32(1))\n", + "\n", + "\n", + "@flyc.jit\n", + "def add_one(x: fx.Int32, stream: fx.Stream = fx.Stream(None)):\n", + " add_one_kernel(x).launch(grid=(1, 1, 1), block=(1, 1, 1), stream=stream)\n", + "\n", + "\n", + "# Compile + run once (silence the verbose per-pass dump log).\n", + "with contextlib.redirect_stdout(io.StringIO()):\n", + " add_one(fx.Int32(41), stream=torch.cuda.Stream())\n", + " torch.cuda.synchronize()\n", + "\n", + "os.environ.pop(\"FLYDSL_DUMP_IR\", None) # stop dumping...\n", + "os.environ.pop(\"FLYDSL_DUMP_DIR\", None) # ...and unset the dump-dir variable (the dumped files stay on disk)\n", + "\n", + "origin = sorted(glob.glob(os.path.join(dump_dir, \"*\", \"00_origin.mlir\")))[0]\n", + "with open(origin) as f:\n", + " print(f.read())" + ] + }, + { + "cell_type": "markdown", + "id": "5c0b1b6e", + "metadata": {}, + "source": [ + "Things to notice in that high-level `fly` IR:\n", + "\n", + "- `gpu.func @add_one_kernel_0(...) kernel { ... }` — the device kernel.\n", + "- `arith.addi %arg0, %c1_i32` — the `x + 1` you wrote, in MLIR form.\n", + "- `fly.print(...) {format = \"...\"}` — your `fx.printf`.\n", + "- `gpu.launch_func ... blocks in (...) threads in (...)` — the host-side launch.\n", + "\n", + "The numbered files after `00_origin.mlir` show each lowering step (layout\n", + "lowering, `fly`→`rocdl`, `gpu`→`llvm`, …) down to `*_final_isa.s`. 
Whenever a\n", + "kernel misbehaves, dumping the IR is the first move.\n", + "\n", + "---\n", + "**Next:** [`01_numeric_types`](01_numeric_types.ipynb) — the scalar type system you\n", + "just used (`fx.Int32`) in full: integers, floats, `bf16`, `fp8`, casts, and the\n", + "difference between compile-time and runtime values." + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/01_numeric_types.ipynb b/examples/notebooks/01_numeric_types.ipynb new file mode 100644 index 000000000..915fd8cf0 --- /dev/null +++ b/examples/notebooks/01_numeric_types.ipynb @@ -0,0 +1,247 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9f1ad4fb", + "metadata": {}, + "source": [ + "\n", + "\n", + "\n", + "# Numeric types\n", + "\n", + "Every scalar you compute with in a FlyDSL kernel has a **DSL numeric type**.\n", + "These types are thin Python wrappers over MLIR scalar types: `fx.Int32` is `i32`,\n", + "`fx.Float32` is `f32`, `fx.BFloat16` is `bf16`, and so on. They carry the type\n", + "information FlyDSL needs to emit the right `arith` / `math` ops and to check casts.\n", + "\n", + "This notebook covers the type families, how to construct and operate on values,\n", + "casts and promotion, and the difference between **compile-time** (`Constexpr`) and\n", + "**runtime** values." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "119e540e", + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "import flydsl.compiler as flyc\n", + "import flydsl.expr as fx\n", + "from wurlitzer import pipes\n", + "\n", + "\n", + "def show_gpu_output(launcher, *args, **kwargs):\n", + " \"\"\"Run a @flyc.jit launcher and echo its GPU printf output into the notebook.\n", + " (Device printf isn't captured by Jupyter on its own; wurlitzer routes it here.)\"\"\"\n", + " kwargs.setdefault(\"stream\", torch.cuda.Stream())\n", + " with pipes() as (out, _err):\n", + " launcher(*args, **kwargs)\n", + " torch.cuda.synchronize()\n", + " print(out.read(), end=\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "1c59421b", + "metadata": {}, + "source": [ + "## 1. The type families\n", + "\n", + "| Family | DSL types | MLIR |\n", + "|---|---|---|\n", + "| Signed int | `Int4`, `Int8`, `Int16`, `Int32`, `Int64` | `i4 … i64` |\n", + "| Unsigned int | `Uint8`, `Uint16`, `Uint32`, `Uint64` | `ui8 … ui64` |\n", + "| Float | `Float16`, `BFloat16`, `Float32`, `Float64` | `f16`, `bf16`, `f32`, `f64` |\n", + "| FP8 / FP6 / FP4 | `Float8E5M2`, `Float8E4M3FN`, …, `Float6E2M3FN`, `Float4E2M1FN` | OCP low-precision |\n", + "| Special | `Boolean` (`i1`), `Index` (loop/index type) | |\n", + "\n", + "Each type exposes its bit `.width` (no GPU needed to read it):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3c91613", + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "for t in [fx.Int8, fx.Int32, fx.Uint32, fx.Float16, fx.BFloat16,\n", + " fx.Float32, fx.Float8E4M3FN, fx.Float4E2M1FN, fx.Boolean, fx.Index]:\n", + " print(f\"{t.__name__:<14} width = {t.width}\")" + ] + }, + { + "cell_type": "markdown", + "id": "d98d8567", + "metadata": {}, + "source": [ + "## 2. 
Constructing and computing\n", + "\n", + "A value is created by calling the type: `fx.Int32(7)`, `fx.Float32(3.5)`.\n", + "\n", + "Arithmetic on these values **emits MLIR**, so it has to happen *inside a traced\n", + "kernel* (there is no MLIR context at the notebook top level). We surface the\n", + "results with `fx.printf`.\n", + "\n", + "Operators map to the obvious MLIR ops: `+ - * // / %` → `arith.add/sub/mul/...`,\n", + "bitwise `& | ^ ~ << >>`, and comparisons `< <= > >= == !=` (which return a\n", + "`Boolean`). Two display quirks in `printf`: avoid a literal `%` in the format\n", + "string (the device `printf` consumes it), and a true `Boolean` prints as `-1`\n", + "(all bits of the `i1` set)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90630369", + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "@flyc.kernel\n", + "def numeric_demo():\n", + " a = fx.Int32(7)\n", + " b = fx.Int32(3)\n", + " fx.printf(\"a+b={} a-b={} a*b={} a//b={} a-mod-b={}\", a + b, a - b, a * b, a // b, a % b)\n", + " fx.printf(\"a&b={} a|b={} a^b={} a<<1={}\", a & b, a | b, a ^ b, a << fx.Int32(1))\n", + " fx.printf(\"a>b -> {} (true is -1) a==b -> {}\", a > b, a == b)\n", + "\n", + "\n", + "@flyc.jit\n", + "def run_numeric(stream: fx.Stream = fx.Stream(None)):\n", + " numeric_demo().launch(grid=(1, 1, 1), block=(1, 1, 1), stream=stream)\n", + "\n", + "\n", + "show_gpu_output(run_numeric)" + ] + }, + { + "cell_type": "markdown", + "id": "ebf0dd18", + "metadata": {}, + "source": [ + "## 3. Casts and type promotion\n", + "\n", + "Convert explicitly with `.to(TargetType)`. In a mixed-type expression FlyDSL also\n", + "**promotes** operands to a common type (wider wins; float beats int; on a tie,\n", + "unsigned beats signed). The `fx.math` module provides the usual float functions\n", + "(`sqrt`, `exp`, `exp2`, `log`, `sin`, `fma`, …)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2b9c61c", + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "@flyc.kernel\n", + "def cast_demo():\n", + " i = fx.Int32(7)\n", + " f = i.to(fx.Float32) # explicit int -> float\n", + " g = fx.Float32(2.0)\n", + " fx.printf(\"f/g={} sqrt(f)={} exp2(g)={}\", f / g, fx.math.sqrt(f), fx.math.exp2(g))\n", + "\n", + " # mixed: Int32 + Float32 promotes to Float32\n", + " promoted = fx.Int32(3) + g\n", + " fx.printf(\"int(3) + float(2.0) = {} (promoted to f32)\", promoted)\n", + "\n", + "\n", + "@flyc.jit\n", + "def run_cast(stream: fx.Stream = fx.Stream(None)):\n", + " cast_demo().launch(grid=(1, 1, 1), block=(1, 1, 1), stream=stream)\n", + "\n", + "\n", + "show_gpu_output(run_cast)" + ] + }, + { + "cell_type": "markdown", + "id": "1757ec2d", + "metadata": {}, + "source": [ + "## 4. Compile-time (`Constexpr`) vs runtime values\n", + "\n", + "A kernel parameter typed as `fx.Constexpr[T]` is a **Python value at trace time** —\n", + "it can size loops, pick code paths, and is baked into the compiled kernel (it is\n", + "part of the cache key, so a new value triggers a recompile). A parameter typed as a\n", + "numeric type (e.g. `fx.Int32`) is a **runtime IR value**, materialized in the kernel." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae48217b", + "metadata": {}, + "outputs": [], + "source": [ + "@flyc.kernel\n", + "def constexpr_demo(scale: fx.Constexpr[int], x: fx.Int32):\n", + " # `scale` is a real Python int here — we can use it in plain Python.\n", + " label = \"doubling\" if scale == 2 else f\"x{scale}\"\n", + " fx.printf(\"scale is a python int at trace time: \" + label)\n", + " fx.printf(\"x * scale = {}\", x * scale)\n", + "\n", + "\n", + "@flyc.jit\n", + "def run_constexpr(x: fx.Int32, scale: fx.Constexpr[int], stream: fx.Stream = fx.Stream(None)):\n", + " constexpr_demo(scale, x).launch(grid=(1, 1, 1), block=(1, 1, 1), stream=stream)\n", + "\n", + "\n", + "show_gpu_output(run_constexpr, fx.Int32(10), 4)" + ] + }, + { + "cell_type": "markdown", + "id": "b30b757b", + "metadata": {}, + "source": [ + "## 5. Low-precision types\n", + "\n", + "AMD GPUs lean heavily on reduced precision, so the 16-bit and narrower formats in\n", + "the table above (`BFloat16`, the `Float8*` / `Float6*` / `Float4*` families) are\n", + "first-class DSL types — you declare buffers and accumulators with them just like\n", + "`Float32`. Their narrow mantissas trade accuracy for speed and memory bandwidth.\n", + "The payoff shows up where data is *moved* and *multiplied* in low precision — the\n", + "copy atoms (notebook 03) and the MMA atoms (the later layout/MMA series) — rather\n", + "than in scalar arithmetic, so we only introduce the type names here.\n", + "\n", + "---\n", + "**Next:** [`02_struct`](02_struct.ipynb) — bundling these scalars into aggregate\n", + "value types with `@fx.struct`." 
+ ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/02_struct.ipynb b/examples/notebooks/02_struct.ipynb new file mode 100644 index 000000000..d65c7d4d1 --- /dev/null +++ b/examples/notebooks/02_struct.ipynb @@ -0,0 +1,209 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a8756da0", + "metadata": {}, + "source": [ + "\n", + "\n", + "\n", + "# Structs\n", + "\n", + "Once you have scalars, you often want to **bundle** them — a pair of accumulators,\n", + "a small parameter block, a chunk of shared-memory scratch. FlyDSL provides\n", + "`@fx.struct`: a decorator that turns an annotated class into a **frozen, C-layout\n", + "value type** whose fields are DSL types.\n", + "\n", + "This notebook covers defining a struct, its memory layout, reading/replacing\n", + "fields, and the most common real use — describing a shared-memory block." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fac6b59a", + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "import flydsl.compiler as flyc\n", + "import flydsl.expr as fx\n", + "from flydsl.compiler.protocol import dsl_align_of, dsl_size_of\n", + "from wurlitzer import pipes\n", + "\n", + "\n", + "def show_gpu_output(launcher, *args, **kwargs):\n", + " \"\"\"Run a @flyc.jit launcher and echo its GPU printf output into the notebook.\"\"\"\n", + " kwargs.setdefault(\"stream\", torch.cuda.Stream())\n", + " with pipes() as (out, _err):\n", + " launcher(*args, **kwargs)\n", + " torch.cuda.synchronize()\n", + " print(out.read(), end=\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "34d40035", + "metadata": {}, + "source": [ + "## 1. Defining a struct\n", + "\n", + "Annotate fields with DSL types. Fields may be scalars, fixed-size `fx.Array`s,\n", + "nested structs, or `fx.Constexpr` (compile-time-only, contributes **no** storage)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da712ac3", + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "@fx.struct\n", + "class Pair:\n", + " a: fx.Int32\n", + " b: fx.Float32\n", + "\n", + "\n", + "@fx.struct\n", + "class WithConst:\n", + " a: fx.Int32\n", + " b: fx.Float32\n", + " n: fx.Constexpr[int] # compile-time only\n", + "\n", + "\n", + "@fx.struct\n", + "class Scratch:\n", + " buf: fx.Array[fx.Float32, 16]\n", + " count: fx.Int32" + ] + }, + { + "cell_type": "markdown", + "id": "f0d6954f", + "metadata": {}, + "source": [ + "## 2. Memory layout\n", + "\n", + "A struct lays out like a C struct: fields in declaration order, padded to each\n", + "field's natural alignment; the struct's alignment is the max of its fields'. Query\n", + "the computed size/alignment (in bytes) with `dsl_size_of` / `dsl_align_of` — no GPU\n", + "needed." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6025a0b", + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "print(f\"Pair size={dsl_size_of(Pair):>3} align={dsl_align_of(Pair)}\")\n", + "print(f\"WithConst size={dsl_size_of(WithConst):>3} align={dsl_align_of(WithConst)} (constexpr field adds 0 bytes)\")\n", + "print(f\"Scratch size={dsl_size_of(Scratch):>3} align={dsl_align_of(Scratch)} (16 * f32 + i32)\")" + ] + }, + { + "cell_type": "markdown", + "id": "260e1174", + "metadata": {}, + "source": [ + "`Pair` is `4 + 4 = 8` bytes. `WithConst` is also `8` — the `Constexpr` field exists\n", + "only at compile time. `Scratch` is `16*4 + 4 = 68` bytes.\n", + "\n", + "## 3. Constructing, reading, and `.replace()`\n", + "\n", + "Struct values are **frozen**: you read fields with `.field`, and produce a modified\n", + "copy with `.replace(field=...)`. Field values are DSL values, so — like any other\n", + "DSL computation — we build and inspect them *inside a kernel* and print with\n", + "`fx.printf`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "133393e6", + "metadata": {}, + "outputs": [], + "source": [ + "@fx.struct\n", + "class Accum:\n", + " total: fx.Float32\n", + " n: fx.Int32\n", + "\n", + "\n", + "@flyc.kernel\n", + "def struct_demo(x: fx.Int32):\n", + " p = Pair(x, x.to(fx.Float32)) # construct from a runtime value\n", + " fx.printf(\"p.a={} p.b={}\", p.a, p.b)\n", + "\n", + " q = p.replace(a=x + fx.Int32(100)) # frozen update -> new value\n", + " fx.printf(\"q.a={} q.b={} (p unchanged: p.a={})\", q.a, q.b, p.a)\n", + "\n", + " acc = Accum(total=fx.Float32(0.0), n=fx.Int32(0))\n", + " acc = acc.replace(total=acc.total + fx.Float32(2.5), n=acc.n + fx.Int32(1))\n", + " fx.printf(\"acc.total={} acc.n={}\", acc.total, acc.n)\n", + "\n", + "\n", + "@flyc.jit\n", + "def run_struct(x: fx.Int32, stream: fx.Stream = fx.Stream(None)):\n", + " struct_demo(x).launch(grid=(1, 1, 1), block=(1, 1, 1), stream=stream)\n", + "\n", + "\n", + "show_gpu_output(run_struct, fx.Int32(7))" + ] + }, + { + "cell_type": "markdown", + "id": "4208ec8c", + "metadata": {}, + "source": [ + "## 4. The common real use: shared-memory scratch\n", + "\n", + "The biggest payoff for structs is describing a **shared-memory layout**. A struct\n", + "with `fx.Array` fields names the LDS scratch a kernel needs, and the size/alignment\n", + "you just saw is exactly what the shared-memory allocator reserves.\n", + "\n", + "You can see this pattern in the shipped kernels — e.g. the reduction buffers in\n", + "`kernels/softmax_kernel.py` and `kernels/layernorm_kernel.py` are declared as a\n", + "`@fx.struct SharedStorage` with `fx.Array[fx.Float32, ...]` fields. 
Allocating and\n", + "viewing that storage inside a kernel uses the layout/shared-memory APIs, which are\n", + "the subject of a later notebook — here the point is just that **the struct is how\n", + "you describe the block.**\n", + "\n", + "---\n", + "**Next:** [`03_universal_ops`](03_universal_ops.ipynb) — moving data with the\n", + "target-agnostic `Universal*` atoms, ending in a complete vector-add kernel." + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/03_universal_ops.ipynb b/examples/notebooks/03_universal_ops.ipynb new file mode 100644 index 000000000..3ef10e63b --- /dev/null +++ b/examples/notebooks/03_universal_ops.ipynb @@ -0,0 +1,233 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "26482a53", + "metadata": {}, + "source": [ + "\n", + "\n", + "\n", + "# Target-agnostic operations: the `Universal*` atoms\n", + "\n", + "FlyDSL expresses data movement and math through **atoms** — small, named\n", + "operations like \"copy 32 bits\" or \"multiply-accumulate\". Atoms come in two flavors:\n", + "\n", + "- **Target-agnostic** `Universal*` atoms describe *what* to do. The backend decides\n", + " *how* to realize it on the actual GPU.\n", + "- **Architecture-specific** atoms under `fx.rocdl.*` (e.g. `BufferCopy32b`, `MFMA`)\n", + " name a particular hardware instruction family (here, AMD CDNA buffer ops / MFMA).\n", + "\n", + "This notebook focuses on the universal flavor — write once, let the compiler\n", + "specialize — and ends with a complete vector-add kernel built entirely from\n", + "universal atoms." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3bbff8f1", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "import flydsl.compiler as flyc\n", + "import flydsl.expr as fx" + ] + }, + { + "cell_type": "markdown", + "id": "c4bd1386", + "metadata": {}, + "source": [ + "## 1. The universal atom family\n", + "\n", + "- **Copy:** `UniversalCopy8b`, `UniversalCopy16b`, `UniversalCopy32b`,\n", + " `UniversalCopy64b`, `UniversalCopy128b` — move *N* bits between tensors.\n", + "- **Multiply-accumulate / atomics:** `UniversalFMA`, plus `UniversalAtomicAdd`,\n", + " `UniversalAtomicMax`, `UniversalAtomicMin`, `UniversalAtomicAnd/Or`,\n", + " `UniversalAtomicInc/Dec` round out the family. Their call shapes belong to the\n", + " later layout/MMA series — on AMD, matrix-multiply lowers through `fx.rocdl.MFMA`.\n", + "\n", + "This notebook uses only the **copy** atoms, since they need nothing but a tensor\n", + "and a dtype. You build one and invoke it like this:\n", + "\n", + "```python\n", + "atom = fx.make_copy_atom(fx.UniversalCopy32b(), fx.Float32)\n", + "fx.copy_atom_call(atom, src, dst) # copy src -> dst\n", + "```\n", + "\n", + "Swapping `fx.UniversalCopy32b()` for `fx.rocdl.BufferCopy32b()` would pin this to\n", + "CDNA buffer loads/stores; the universal atom keeps it portable." + ] + }, + { + "cell_type": "markdown", + "id": "cd23c384", + "metadata": {}, + "source": [ + "## 2. Capstone: a vector add, fully universal\n", + "\n", + "Each thread loads one element of `A` and one of `B` into registers, adds them, and\n", + "stores the result to `C`. The loads and the store are all `UniversalCopy32b`.\n", + "\n", + "**Passing tensors in.** A kernel parameter is an `fx.Tensor`. You can hand `@flyc.jit`\n", + "a raw torch tensor — as we do for `B` and `C`, which it adapts via DLPack — or convert\n", + "explicitly with `flyc.from_dlpack(t)`, as we do for `A`. 
The explicit form lets you add\n", + "`.mark_layout_dynamic(leading_dim=0, divisibility=4)`, telling the compiler that\n", + "dimension is sized at runtime and `divisibility`-byte aligned so it can vectorize.\n", + "\n", + "**Compute on register tensors.** `rA`/`rB`/`rC` are single-element *register tensors*,\n", + "not scalars, so we pull values out with `fx.memref_load_vec`, add them with an explicit\n", + "`fx.arith.addf`, and write the result back with `fx.memref_store_vec`. The `+` operator\n", + "overloading from notebook 01 acts on `fx.Int32`/`fx.Float32` *values* — not on a\n", + "register-tensor handle.\n", + "\n", + "We do use a little layout machinery (`make_layout`, `logical_divide`, `slice`) purely to\n", + "hand each thread *its* element — that is the subject of the next series, so don't dwell\n", + "on it here. The point of this notebook is the **copy**: the exact same `copy_atom_call`\n", + "code runs on any FlyDSL target." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd234b38", + "metadata": {}, + "outputs": [], + "source": [ + "# Turn IR dumping on *before* the first compile, so we can inspect it in the next cell.\n", + "import contextlib\n", + "import glob\n", + "import io\n", + "import os\n", + "import tempfile\n", + "\n", + "dump_dir = tempfile.mkdtemp(prefix=\"flydsl_ir_\")\n", + "os.environ[\"FLYDSL_DUMP_IR\"] = \"1\"\n", + "os.environ[\"FLYDSL_DUMP_DIR\"] = dump_dir\n", + "\n", + "\n", + "@flyc.kernel\n", + "def vadd_kernel(A: fx.Tensor, B: fx.Tensor, C: fx.Tensor, block_dim: fx.Constexpr[int]):\n", + " bid = fx.block_idx.x\n", + " tid = fx.thread_idx.x\n", + "\n", + " # Hand thread (bid, tid) its single element of each vector. 
This little\n", + " # partitioning is layout algebra — the topic of the next series; don't dwell on it.\n", + " tA = fx.logical_divide(A, fx.make_layout(block_dim, 1))\n", + " tB = fx.logical_divide(B, fx.make_layout(block_dim, 1))\n", + " tC = fx.logical_divide(C, fx.make_layout(block_dim, 1))\n", + " tA = fx.logical_divide(fx.slice(tA, (None, bid)), fx.make_layout(1, 1))\n", + " tB = fx.logical_divide(fx.slice(tB, (None, bid)), fx.make_layout(1, 1))\n", + " tC = fx.logical_divide(fx.slice(tC, (None, bid)), fx.make_layout(1, 1))\n", + "\n", + " # the target-agnostic copy atom\n", + " copy = fx.make_copy_atom(fx.UniversalCopy32b(), fx.Float32)\n", + "\n", + " rA = fx.make_rmem_tensor(fx.make_layout(1, 1), fx.Float32)\n", + " rB = fx.make_rmem_tensor(fx.make_layout(1, 1), fx.Float32)\n", + " rC = fx.make_rmem_tensor(fx.make_layout(1, 1), fx.Float32)\n", + "\n", + " fx.copy_atom_call(copy, fx.slice(tA, (None, tid)), rA) # global -> register\n", + " fx.copy_atom_call(copy, fx.slice(tB, (None, tid)), rB)\n", + "\n", + " vC = fx.arith.addf(fx.memref_load_vec(rA), fx.memref_load_vec(rB))\n", + " fx.memref_store_vec(vC, rC)\n", + "\n", + " fx.copy_atom_call(copy, rC, fx.slice(tC, (None, tid))) # register -> global\n", + "\n", + "\n", + "@flyc.jit\n", + "def vadd(A: fx.Tensor, B: fx.Tensor, C: fx.Tensor, n: fx.Int32, stream: fx.Stream = fx.Stream(None)):\n", + " block_dim = 64\n", + " grid_x = (n + block_dim - 1) // block_dim\n", + " vadd_kernel(A, B, C, block_dim).launch(grid=(grid_x, 1, 1), block=[block_dim, 1, 1], stream=stream)\n", + "\n", + "\n", + "# Run it and validate against torch (hiding the verbose per-pass dump log).\n", + "n = 256\n", + "A = torch.randint(0, 10, (n,), dtype=torch.float32).cuda()\n", + "B = torch.randint(0, 10, (n,), dtype=torch.float32).cuda()\n", + "C = torch.zeros(n, dtype=torch.float32).cuda()\n", + "tA = flyc.from_dlpack(A).mark_layout_dynamic(leading_dim=0, divisibility=4)\n", + "\n", + "with contextlib.redirect_stdout(io.StringIO()):\n", 
+ " vadd(tA, B, C, n, stream=torch.cuda.Stream())\n", + " torch.cuda.synchronize()\n", + "\n", + "os.environ.pop(\"FLYDSL_DUMP_IR\", None) # stop dumping...\n", + "os.environ.pop(\"FLYDSL_DUMP_DIR\", None) # ...and unset the dump-dir variable (the dumped files stay on disk)\n", + "print(\"matches torch:\", torch.allclose(C, A + B))\n", + "print(\"C[:8] =\", C[:8].tolist())" + ] + }, + { + "cell_type": "markdown", + "id": "37d6bafa", + "metadata": {}, + "source": [ + "## 3. Seeing \"universal\" in the IR\n", + "\n", + "Open the dumped IR and look at the high-level `00_origin.mlir`: the copies appear as\n", + "generic `fly` copy-atom ops, with no architecture baked in. The specialization to\n", + "the target architecture (gfx950 on the machine used here) happens later, in the `convert-fly-to-rocdl` pass." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b58253fe", + "metadata": {}, + "outputs": [], + "source": [ + "origin = sorted(glob.glob(os.path.join(dump_dir, \"*\", \"00_origin.mlir\")))[0]\n", + "with open(origin) as f:\n", + " atom_lines = [ln.strip() for ln in f.read().splitlines() if \"copy_atom\" in ln]\n", + "print(\"copy ops in 00_origin.mlir — note the target-agnostic `universal_copy<32>`:\\n\")\n", + "print(\"\\n\".join(atom_lines))" + ] + }, + { + "cell_type": "markdown", + "id": "9ed04817", + "metadata": {}, + "source": [ + "## Recap\n", + "\n", + "Across these four notebooks you have covered the `flydsl.expr` foundation:\n", + "\n", + "- **00** — the `@kernel` / `@jit` trace model and how to read dumped IR.\n", + "- **01** — the numeric type system: ints, floats, `bf16`/`fp8`, casts, `Constexpr`.\n", + "- **02** — `@fx.struct` aggregate value types and their memory layout.\n", + "- **03** — target-agnostic `Universal*` atoms and a complete vector add.\n", + "\n", + "**Next series:** layout algebra — `make_layout`, `logical_divide`, `partition`,\n", + "tiled copy, and the MMA atoms (`rocdl.MFMA`). That is where the small slicing we\n", + "glossed over above is explained from first principles." 
+ ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/README.md b/examples/notebooks/README.md new file mode 100644 index 000000000..bde53283b --- /dev/null +++ b/examples/notebooks/README.md @@ -0,0 +1,74 @@ + + + +# FlyDSL onboarding notebooks + +An interactive, bottom-up introduction to the `flydsl.expr` foundation. Work through +them in order — each builds on the last, and the series stops short of layout algebra +(`make_layout`, `logical_divide`, tiled copy, MMA), which gets its own follow-up series. + +| # | Notebook | Topic | +|---|----------|-------| +| 00 | [`00_hello_flydsl.ipynb`](00_hello_flydsl.ipynb) | the `@flyc.kernel` / `@flyc.jit` model; reading dumped IR | +| 01 | [`01_numeric_types.ipynb`](01_numeric_types.ipynb) | scalar types: ints, floats, `bf16`/`fp8`, casts, `Constexpr` | +| 02 | [`02_struct.ipynb`](02_struct.ipynb) | `@fx.struct` aggregate value types and their memory layout | +| 03 | [`03_universal_ops.ipynb`](03_universal_ops.ipynb) | target-agnostic `Universal*` atoms + a vector-add capstone | + +## API cheat-sheet + +The whole `flydsl.expr` foundation these notebooks cover, in one place — enough to +write a kernel without reading the source. Layout ops (`make_layout`, +`logical_divide`, tiled copy, MMA) are deliberately out of scope; they get their +own series. 
+ +```python +# Kernel + launch (00) +@flyc.kernel # device kernel; the body is traced to MLIR +@flyc.jit # host launch wrapper +kernel(args).launch(grid=(gx, 1, 1), block=[bx, 1, 1], stream=stream) +flyc.from_dlpack(t) # torch tensor -> fx.Tensor view (jit also accepts a raw torch tensor) + .mark_layout_dynamic(leading_dim=0, divisibility=4) # dim sized at runtime, n-byte aligned for vectorization + +# Scalars (01) — construct at top level; arithmetic and casts run only inside a trace +fx.Int32(7) fx.Float32(2.0) fx.Boolean(True) +v.to(fx.Float16) # cast (.width works at top level; ops need an active trace) +fx.Constexpr[int] # trace-time Python value; folds into the kernel and the JIT cache key + +# Structs (02) +@fx.struct # frozen aggregate value type; v.replace(field=...) returns a copy +from flydsl.compiler.protocol import dsl_size_of, dsl_align_of # host-side; NOT attributes of fx + +# Copy atoms + register tensors (03) +atom = fx.make_copy_atom(fx.UniversalCopy32b(), fx.Float32) # target-agnostic +fx.copy_atom_call(atom, src, dst) # copy src -> dst +rt = fx.make_rmem_tensor(fx.make_layout(1, 1), fx.Float32) # per-thread register tensor +fx.memref_load_vec(rt) / fx.memref_store_vec(val, rt) # read / write a register tensor +fx.arith.addf(a, b) # explicit op on loaded values (`+` is for fx scalar values, not tensors) +``` + +Three gotchas worth front-loading: + +- `fx.printf` takes only bare `{}` (no `{:.2f}`); a literal `%` is consumed by the + device printf (write `"mod"`); a true `Boolean` prints as `-1`. +- Device `printf` is not captured by Jupyter — wrap the launch in + `with wurlitzer.pipes() as (out, _): launch(...); torch.cuda.synchronize()`, then `print(out.read())`. +- `Constexpr` `fp8`/`bf16` math is not rounded until the value is materialized as its + MLIR type; only `f16`/`f32`/`f64` fold at trace time. 
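The second gotcha comes down to file descriptors: device `printf` output goes to the process's stdout (fd 1), while Jupyter only intercepts Python's `sys.stdout` object. Below is a minimal stdlib-only sketch of the redirection idea that `wurlitzer.pipes()` packages up. The helper name `capture_fd1` is hypothetical, `os.write(1, ...)` is only a stand-in for a GPU `printf`, and this is not how `wurlitzer` is actually implemented.

```python
import os
import sys
import tempfile
from contextlib import contextmanager


@contextmanager
def capture_fd1():
    """Temporarily point file descriptor 1 at a temp file and collect what lands there.

    Hypothetical helper for illustration; in the notebooks, wurlitzer does this job.
    """
    sys.stdout.flush()                  # drain Python-level buffering first
    saved_fd = os.dup(1)                # keep a handle on the real stdout
    captured = {}
    with tempfile.TemporaryFile(mode="w+b") as tmp:
        os.dup2(tmp.fileno(), 1)        # fd 1 now writes into the temp file
        try:
            yield captured
        finally:
            sys.stdout.flush()
            os.dup2(saved_fd, 1)        # restore the real stdout
            os.close(saved_fd)
            tmp.seek(0)
            captured["text"] = tmp.read().decode()


with capture_fd1() as result:
    # Stand-in for a device-side printf: bytes written straight to fd 1,
    # bypassing sys.stdout entirely.
    os.write(1, b"hello from fd 1\n")

print(repr(result["text"]))
```

With a real kernel, the launch and `torch.cuda.synchronize()` would sit inside the `with` block, so the output is written before fd 1 is restored.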
+ +## Running + +These notebooks execute kernels, so they need a built/installed FlyDSL and a ROCm GPU, +plus a couple of notebook tools: + +```bash +pip install jupyter wurlitzer +``` + +`wurlitzer` lets the notebooks show GPU `printf` output inline — Jupyter does not +capture device stdout on its own. Then open them with Jupyter, or run headless: + +```bash +jupyter nbconvert --to notebook --execute --inplace examples/notebooks/*.ipynb +``` + +Cell outputs are committed **cleared**; run the cells to populate them.
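
To keep outputs cleared before committing, one option is to strip them from the notebook JSON. The sketch below uses only the standard library and operates on an in-memory notebook dict. The function name `clear_outputs` is hypothetical and not part of FlyDSL or Jupyter; `nbconvert` also provides a `--clear-output` option for the same purpose.

```python
import json


def clear_outputs(notebook: dict) -> dict:
    """Remove outputs and execution counts from every code cell (nbformat 4 layout)."""
    for cell in notebook.get("cells", []):
        if cell.get("cell_type") == "code":
            cell["outputs"] = []
            cell["execution_count"] = None
    return notebook


# A tiny stand-in notebook with one executed code cell and one markdown cell.
nb = {
    "cells": [
        {
            "cell_type": "code",
            "execution_count": 3,
            "metadata": {},
            "outputs": [{"output_type": "stream", "name": "stdout", "text": ["hi\n"]}],
            "source": ["print('hi')"],
        },
        {"cell_type": "markdown", "metadata": {}, "source": ["# Title"]},
    ],
    "metadata": {},
    "nbformat": 4,
    "nbformat_minor": 5,
}

cleared = clear_outputs(nb)
print(json.dumps(cleared["cells"][0], indent=1))
```

To apply it to a file, load the `.ipynb` with `json.load`, call the function, and write the result back with `json.dump`.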