12 changes: 11 additions & 1 deletion Makefile
@@ -91,7 +91,7 @@
#
# ==============================================================================

-.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda qwen3_5_moe-cuda qwen3_5_moe-metal clean help
.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda gemma4_31b-mlx qwen3_5_moe-cuda qwen3_5_moe-metal clean help

help:
@echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make <target>\`. Available targets:"
@@ -127,6 +127,7 @@ help:
@echo " gemma3-cuda - Build Gemma3 runner with CUDA backend"
@echo " gemma3-cpu - Build Gemma3 runner with CPU backend"
@echo " gemma4_31b-cuda - Build Gemma 4 31B runner with CUDA backend"
@echo " gemma4_31b-mlx - Build Gemma 4 31B runner with MLX backend"
@echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner with CUDA backend"
@echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend"
@echo " clean - Clean build artifacts"
@@ -435,6 +436,15 @@ gemma4_31b-cuda:
@echo "✓ Build complete!"
@echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner"

gemma4_31b-mlx:
@echo "==> Building and installing ExecuTorch with MLX..."
cmake --workflow --preset mlx-release
@echo "==> Building Gemma 4 31B runner with MLX..."
cd examples/models/gemma4_31b && cmake --workflow --preset gemma4-31b-mlx
@echo ""
@echo "✓ Build complete!"
@echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner"

qwen3_5_moe-metal:
@echo "==> Building and installing ExecuTorch with Metal..."
cmake --workflow --preset llm-release-metal
15 changes: 12 additions & 3 deletions examples/models/gemma4_31b/CMakeLists.txt
@@ -42,14 +42,17 @@ list(
extension_flat_tensor
)

-# CUDA backend (the only supported backend for this example for now)
# Backend: CUDA or MLX (exactly one required)
if(EXECUTORCH_BUILD_CUDA)
find_package(CUDAToolkit REQUIRED)
list(APPEND link_libraries aoti_cuda_backend)
executorch_target_link_options_shared_lib(aoti_cuda_backend)
add_compile_definitions(EXECUTORCH_BUILD_CUDA)
elseif(TARGET mlxdelegate)
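# mlxdelegate is exported by an ExecuTorch install configured with EXECUTORCH_BUILD_MLX=ON.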
list(APPEND link_libraries mlxdelegate mlx)
executorch_target_link_options_shared_lib(mlxdelegate)
else()
-message(FATAL_ERROR "Set EXECUTORCH_BUILD_CUDA=ON")
message(FATAL_ERROR "Set EXECUTORCH_BUILD_CUDA=ON or EXECUTORCH_BUILD_MLX=ON")
endif()

# Tokenizer (HuggingFace tokenizer.json)
@@ -63,5 +66,11 @@ target_link_libraries(gemma4_31b_runner PUBLIC ${link_libraries})

if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
target_link_options_gc_sections(gemma4_31b_runner)
-target_link_options(gemma4_31b_runner PRIVATE "LINKER:-s")
if(NOT APPLE AND NOT MSVC)
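# The GNU-style "LINKER:-s" strip flag is not accepted by Apple's ld64 or MSVC's link.exe.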
target_link_options(gemma4_31b_runner PRIVATE "LINKER:-s")
endif()
endif()

if(TARGET mlxdelegate)
executorch_target_copy_mlx_metallib(gemma4_31b_runner)
endif()
31 changes: 31 additions & 0 deletions examples/models/gemma4_31b/CMakePresets.json
@@ -23,6 +23,17 @@
"string": "${hostSystemName}",
"list": ["Linux", "Windows"]
}
},
{
"name": "gemma4-31b-mlx",
"displayName": "Gemma 4 31B runner (MLX)",
"inherits": ["gemma4-31b-base"],
"cacheVariables": {},
"condition": {
"type": "equals",
"lhs": "${hostSystemName}",
"rhs": "Darwin"
}
}
],
"buildPresets": [
@@ -31,6 +42,12 @@
"displayName": "Build Gemma 4 31B runner (CUDA)",
"configurePreset": "gemma4-31b-cuda",
"targets": ["gemma4_31b_runner"]
},
{
"name": "gemma4-31b-mlx",
"displayName": "Build Gemma 4 31B runner (MLX)",
"configurePreset": "gemma4-31b-mlx",
"targets": ["gemma4_31b_runner"]
}
],
"workflowPresets": [
@@ -47,6 +64,20 @@
"name": "gemma4-31b-cuda"
}
]
},
{
"name": "gemma4-31b-mlx",
"displayName": "Configure and build Gemma 4 31B runner (MLX)",
"steps": [
{
"type": "configure",
"name": "gemma4-31b-mlx"
},
{
"type": "build",
"name": "gemma4-31b-mlx"
}
]
}
]
}
22 changes: 19 additions & 3 deletions examples/models/gemma4_31b/README.md
@@ -1,7 +1,7 @@
# Gemma 4 31B-IT

Text-only export of Google's Gemma 4 31B-IT to ExecuTorch with INT4/INT8
-weight quantization. Currently supports the CUDA backend.
weight quantization. Supports CUDA and MLX (Apple Silicon) backends.

For architecture and design notes see [model.md](model.md).

@@ -67,6 +67,8 @@ recipe. Writes `model.safetensors`, `config.json`, and `tokenizer.json` into

## Export to ExecuTorch

### CUDA

```bash
python examples/models/gemma4_31b/export.py \
--prequantized ./gemma4_31b_int4 \
@@ -75,7 +77,20 @@ python examples/models/gemma4_31b/export.py \
--backend cuda
```

-Writes `model.pte` and `model.ptd` into `--output-dir`.
### MLX (Apple Silicon)

```bash
python examples/models/gemma4_31b/export.py \
--prequantized ./gemma4_31b_int4 \
--output-dir ./gemma4_31b_exports_mlx \
--max-seq-len 4096 \
--backend mlx
```

The same quantized checkpoint works for both backends. MLX exports a single
method with dynamic sequence length and host-side sampling.

Writes `model.pte` (and optionally `model.ptd`) into `--output-dir`.
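
Because the MLX export returns only soft-capped logits, token selection happens on the host. Below is a minimal sketch of what that host-side loop can look like; the pybindings import and `forward` call mirror ExecuTorch's portable Python bindings, and the prompt ids are placeholders, so treat the details as assumptions rather than an API this example ships.

```python
# Hypothetical host-side sampling against the MLX export; the pybindings
# module path and forward() call are assumptions, not part of this example.
import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

module = _load_for_executorch("gemma4_31b_exports_mlx/model.pte")

def sample(logits: torch.Tensor, temperature: float = 0.8) -> int:
    # Greedy when temperature is 0; otherwise softmax + multinomial on the host.
    if temperature == 0.0:
        return int(torch.argmax(logits[0, -1]))
    probs = torch.softmax(logits[0, -1].float() / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))

tokens = torch.tensor([[2, 106]], dtype=torch.long)  # placeholder prompt ids
input_pos = torch.arange(tokens.shape[1], dtype=torch.long)
logits = module.forward([tokens, input_pos])[0]  # single method, dynamic T
next_token = sample(logits)
```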

## Eager inference

@@ -102,7 +117,8 @@ model produces sensible text.
## Build the runner

```bash
-make gemma4_31b-cuda
make gemma4_31b-cuda # Linux — CUDA backend
make gemma4_31b-mlx # macOS — MLX backend (Apple Silicon)
```

The binary lands at `cmake-out/examples/models/gemma4_31b/gemma4_31b_runner`.
139 changes: 135 additions & 4 deletions examples/models/gemma4_31b/export.py
@@ -19,6 +19,8 @@

Backends:
--backend cuda (default) CUDA via tinygemm INT4 + CudaPartitioner.
--backend mlx Apple Silicon via MLXPartitioner (single method,
dynamic seq_len, host-side sampling).
"""

import argparse
@@ -98,21 +100,36 @@ def load_and_quantize(
# Backend dispatch helpers


_SUPPORTED_BACKENDS = ("cuda", "mlx")


def _get_packers(backend: str) -> dict:
if backend == "cuda":
from executorch.examples.models.gemma4_31b.quant import DEFAULT_CUDA_PACKERS

return DEFAULT_CUDA_PACKERS
raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.")
if backend == "mlx":
from executorch.examples.models.gemma4_31b.quant import DEFAULT_MLX_PACKERS

return DEFAULT_MLX_PACKERS
raise ValueError(
f"Unsupported backend: {backend!r}. Supported: {_SUPPORTED_BACKENDS}."
)


def _pack_for_backend(model: nn.Module, path: str, backend: str) -> None:
if backend == "cuda":
from executorch.examples.models.gemma4_31b.quant import load_and_pack_for_cuda

load_and_pack_for_cuda(path, model)
elif backend == "mlx":
from executorch.examples.models.gemma4_31b.quant import load_and_pack_for_mlx

load_and_pack_for_mlx(path, model)
else:
raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.")
raise ValueError(
f"Unsupported backend: {backend!r}. Supported: {_SUPPORTED_BACKENDS}."
)


# ---------------------------------------------------------------------------
@@ -128,8 +145,12 @@ def export_and_lower(
"""Export and lower the model to ExecuTorch for the given backend."""
if backend == "cuda":
_export_cuda(model, config, output_dir)
elif backend == "mlx":
_export_mlx(model, config, output_dir)
else:
raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.")
raise ValueError(
f"Unsupported backend: {backend!r}. Supported: {_SUPPORTED_BACKENDS}."
)


def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> None:
@@ -258,6 +279,116 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -
print("Done.")


def _strip_sampler_from_forward(model: Gemma4_31B) -> None:
"""Replace forward with a ``(tokens, input_pos) → logits`` variant.

MLX samples on the host, so the on-device Gumbel-max sampler and its
temperature input are dead code. Stripping them produces a cleaner
exported graph.
"""
import types

def _clean_forward(self, tokens, input_pos):
x = self.embed_tokens(tokens) * self.embed_normalizer
sliding_mask, full_mask = self._build_masks(input_pos)
for layer in self.layers:
x = layer(x, input_pos, sliding_mask, full_mask)
x = self.norm(x)
logits = self.lm_head(x).float()
cap = self.logit_softcap.float()
return torch.tanh(logits / cap) * cap

model.forward = types.MethodType(_clean_forward, model)


def _export_mlx(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -> None:
"""Export to .pte via torch.export + MLX backend.

Unlike CUDA (which exports separate decode/prefill methods with an
Int4Tensor dispatch override), MLX uses a single method with dynamic
sequence length. No int4_dispatch import is needed: IntxUnpackedToInt8Tensor's
default dispatch produces the ``dequantize_affine → linear`` pattern
that MLX's QuantizedLinearHandler matches.
"""
import gc

from executorch.backends.mlx import MLXPartitioner
from executorch.backends.mlx.passes import get_default_passes
from executorch.exir import (
EdgeCompileConfig,
ExecutorchBackendConfig,
to_edge_transform_and_lower,
)
from executorch.exir.passes import MemoryPlanningPass
from torch.export import Dim, export

_strip_sampler_from_forward(model)
materialize_runtime_buffers(model, dtype=torch.bfloat16)

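# Cap the dynamic prefill length: at most max_seq_len - 1, and (presumably to
# keep a single prefill chunk within the sliding-attention window) twice sliding_window.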
max_prefill = min(config.max_seq_len - 1, config.sliding_window * 2)
seq_dim = Dim("seq_len", min=1, max=max_prefill)

print(f"Exporting (T in [1, {max_prefill}])...")
with torch.no_grad():
exported = export(
model,
(
torch.tensor([[0, 1]], dtype=torch.long),
torch.tensor([0, 1], dtype=torch.long),
),
dynamic_shapes=({1: seq_dim}, {0: seq_dim}),
strict=True,
)

del model
gc.collect()

print("Lowering to ExecuTorch with MLX backend...")
et_prog = to_edge_transform_and_lower(
exported,
transform_passes=get_default_passes(),
partitioner=[MLXPartitioner()],
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
_skip_dim_order=True,
),
constant_methods={
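# Constant methods bake model metadata into the .pte for the runner to query at load time.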
"get_max_seq_len": config.max_seq_len,
"get_vocab_size": config.vocab_size,
"get_n_layers": config.num_hidden_layers,
"get_max_prefill_chunk": max_prefill,
"use_kv_cache": True,
"use_sdpa_with_kv_cache": False,
"enable_dynamic_shape": True,
},
)

del exported
gc.collect()

et_program = et_prog.to_executorch(
config=ExecutorchBackendConfig(
extract_delegate_segments=True,
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
),
)

del et_prog
gc.collect()

os.makedirs(output_dir, exist_ok=True)
pte_path = os.path.join(output_dir, "model.pte")
print(f"Saving to {pte_path}...")
with open(pte_path, "wb") as f:
et_program.write_to_file(f)
print(f" {os.path.getsize(pte_path) / 1024**2:.1f} MB")

if et_program._tensor_data:
et_program.write_tensor_data_to_file(output_dir)
print(f" Saved tensor data (.ptd) to {output_dir}/")
print("Done.")


# ---------------------------------------------------------------------------
# CLI

@@ -302,7 +433,7 @@ def main() -> None:
parser.add_argument(
"--backend",
default="cuda",
choices=["cuda"],
choices=list(_SUPPORTED_BACKENDS),
help="Target backend for export.",
)
args = parser.parse_args()