49 changes: 47 additions & 2 deletions examples/models/llama/eval_llama_lib.py
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -46,9 +47,13 @@ def __init__(
use_kv_cache: bool = False,
generate_full_logits: bool = False,
enable_dynamic_shape: bool = True,
device: Optional[str] = None,
):
super().__init__(
model=model, tokenizer=tokenizer, max_seq_length=max_seq_length
model=model,
tokenizer=tokenizer,
max_seq_length=max_seq_length,
device=device,
)
self._model = model.to(self.device)
self._use_kv_cache = use_kv_cache
@@ -80,7 +85,47 @@ def _model_call(self, inps):
return logits

else:
return self._model(inps)
# lm-eval expects logits shaped [batch, seq, vocab].
# The KV-cache path above handles that separately. In the non-KV path,
# some exported graphs (when generate_full_logits=False) return only
# last-position logits [batch, vocab], so reconstruct per-position
# logits by running prefix calls.

seq_len = inps.shape[-1]

def pad_to_max_len(tokens: torch.Tensor) -> torch.Tensor:
if self._enable_dynamic_shape:
return tokens
token_len = tokens.shape[-1]
if token_len < self._max_seq_length:
pad_len = self._max_seq_length - token_len
pad_token = getattr(
self._tokenizer, "pad_id", self._tokenizer.eos_id
)
pad = torch.full(
(tokens.shape[0], pad_len),
pad_token,
dtype=tokens.dtype,
device=tokens.device,
)
return torch.cat((tokens, pad), dim=-1)
if token_len > self._max_seq_length:
return tokens[:, : self._max_seq_length]
return tokens

if self._generate_full_logits:
return self._model(pad_to_max_len(inps))

# Reconstruct full logits by running prefixes.
result_logits = []
for pos in range(min(seq_len, self._max_seq_length)):
prefix = pad_to_max_len(inps[:, : pos + 1])
logits = self._model(prefix)
if logits.dim() == 3:
logits = logits[:, -1, :]
result_logits.append(logits)

return torch.stack(result_logits, dim=1)
Copilot AI commented on lines +116 to +128 (Apr 27, 2026):

In the non-KV-cache + generate_full_logits=False case, logits are reconstructed by running the model once per prefix position (for pos in range(...)). This makes lm-eval O(seq_len) forward passes per sample, which can become extremely slow at larger seq lengths. If possible, prefer exporting with generate_full_logits=True for evaluation/calibration, or add a fast path/guard (e.g., only reconstruct up to the required positions or raise with guidance when seq_len is large).

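As a rough sketch of the guard option mentioned above (not part of this PR; the threshold name and value are assumptions made purely for illustration), something along these lines could sit at the top of the non-KV-cache branch of _model_call:

# Hypothetical guard, illustrative only: bound the cost of per-prefix reconstruction.
MAX_RECONSTRUCT_LEN = 128  # assumed conservative cap, not taken from the PR

if not self._generate_full_logits and seq_len > MAX_RECONSTRUCT_LEN:
    raise RuntimeError(
        f"Reconstructing per-position logits needs {seq_len} forward passes; "
        "re-export with generate_full_logits=True for evaluation/calibration, "
        f"or keep sequences at or below {MAX_RECONSTRUCT_LEN} tokens."
    )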
def _model_generate(self, context, max_length, eos_token_id):
raise Exception("unimplemented")
8 changes: 5 additions & 3 deletions examples/models/llama/evaluate/eager_eval.py
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -28,12 +29,13 @@ def __init__(
tokenizer: Union[SentencePieceTokenizer, Tiktoken, HuggingFaceTokenizer],
max_seq_length: Optional[int] = None,
use_kv_cache: bool = False,
device: Optional[str] = None,
):
device = "cuda" if torch.cuda.is_available() else "cpu"
super().__init__(device=device, pretrained="gpt2")
resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
super().__init__(device=resolved_device, pretrained="gpt2")
self._model = model
self._tokenizer = tokenizer
self._device = torch.device(device)
self._device = torch.device(resolved_device)
self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
self._use_kv_cache = use_kv_cache

15 changes: 15 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -229,6 +229,8 @@ def build_args_parser() -> argparse.ArgumentParser:
"vulkan_8w",
"tosa_8a8w",
"ethosu_8a8w",
"vgf_8a8w",
"vgf_16a8w",
],
help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
)
@@ -456,6 +458,18 @@ def build_args_parser() -> argparse.ArgumentParser:
)
parser.add_argument("-V", "--vulkan", action="store_true")
parser.add_argument("--vulkan-force-fp16", action="store_true")
parser.add_argument("--vgf", action="store_true")
parser.add_argument(
"--vgf-compile-spec",
default="TOSA-1.0+INT",
help="VGF compile spec, e.g. TOSA-1.0+INT or TOSA-1.0+INT+int16.",
)
parser.add_argument(
"--vgf-quantize-scope",
default="full",
choices=["full", "linear"],
help="VGF quantization scope. Use 'linear' to quantize only Linear modules.",
)
parser.add_argument("--mps", action="store_true")
parser.add_argument("--coreml", action="store_true")
parser.add_argument(
@@ -847,6 +861,7 @@ def get_quantizer_and_quant_params(llm_config):
llm_config.backend.vgf.compile_spec,
llm_config.backend.vgf.compiler_flags,
llm_config.quantization.pt2e_quantize.value,
llm_config.backend.vgf.quantize_scope.value,
)
quantizers.append(vgf_quantizer)
if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
11 changes: 6 additions & 5 deletions examples/models/llama/model.py
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -285,11 +286,11 @@ def get_example_inputs(self):
if self.use_kv_cache:
return self.get_example_inputs_kvcache_sdpa()
else:
return (
torch.tensor(
[[1, 2, 3]], dtype=torch.long
), # tokens, with kv cache our input token length is always just 1 token.
)
max_len = getattr(self.llm_config.export, "max_seq_length", 3)
max_len = max(3, int(max_len))
example_tokens = torch.arange(max_len, dtype=torch.int32).unsqueeze(0)
Copilot AI commented (Apr 27, 2026):

In the non-KV-cache path, the example token tensor is created with dtype int32. Because torch.export uses example inputs to specialize/guard the graph, this can lock the exported program to int32 token IDs, while most tokenization/eval codepaths in this repo use torch.long/int64 tokens. Consider generating example_tokens with dtype=torch.long (or otherwise ensuring the rest of the pipeline consistently uses int32) to avoid dtype guard failures at runtime.

Suggested change:
- example_tokens = torch.arange(max_len, dtype=torch.int32).unsqueeze(0)
+ example_tokens = torch.arange(max_len, dtype=torch.long).unsqueeze(0)
example_tokens = example_tokens % 100 + 1
return (example_tokens,)

# assumption is the custom op doesnt support dynamic shape right now. It might but its untested so lets first get static shape working
def get_example_inputs_kvcache_sdpa(self):
54 changes: 53 additions & 1 deletion examples/models/llama/tests/test_export_llama_lib.py
@@ -7,6 +7,8 @@

import unittest

import torch

from executorch.devtools.backend_debug import get_delegation_info

try:
@@ -28,7 +30,11 @@
build_args_parser,
get_quantizer_and_quant_params,
)
from executorch.extension.llm.export.config.llm_config import LlmConfig, Pt2eQuantize
from executorch.extension.llm.export.config.llm_config import (
LlmConfig,
Pt2eQuantize,
VgfQuantizeScope,
)

UNWANTED_OPS = [
"aten_permute_copy_default",
@@ -111,3 +117,49 @@ def test_get_quantizer_and_quant_params_returns_vgf_quantizer(self):
self.assertIsNone(quant_dtype)
self.assertEqual(len(quantizers), 1)
self.assertIsInstance(quantizers[0], VgfQuantizer)
self.assertIsNotNone(quantizers[0].global_config)
self.assertEqual(quantizers[0].module_type_config, {})

@unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available")
def test_get_quantizer_and_quant_params_returns_vgf_linear_quantizer(self):
llm_config = LlmConfig()
llm_config.backend.vgf.enabled = True
llm_config.backend.vgf.compile_spec = "TOSA-1.0+INT"
llm_config.backend.vgf.quantize_scope = VgfQuantizeScope.linear
llm_config.quantization.pt2e_quantize = Pt2eQuantize.vgf_8a8w

_pt2e_quant_params, quantizers, _quant_dtype = get_quantizer_and_quant_params(
llm_config
)

self.assertEqual(len(quantizers), 1)
self.assertIsInstance(quantizers[0], VgfQuantizer)
self.assertIsNone(quantizers[0].global_config)
self.assertIn(torch.nn.Linear, quantizers[0].module_type_config)

@unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available")
def test_vgf_16a8w_requires_int16_compile_spec_extension(self):
llm_config = LlmConfig()
llm_config.backend.vgf.enabled = True
llm_config.backend.vgf.compile_spec = "TOSA-1.0+INT"
llm_config.backend.vgf.quantize_scope = VgfQuantizeScope.linear
llm_config.quantization.pt2e_quantize = Pt2eQuantize.vgf_16a8w

with self.assertRaisesRegex(ValueError, "INT16 support"):
get_quantizer_and_quant_params(llm_config)

@unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available")
def test_vgf_16a8w_accepts_int16_compile_spec_extension(self):
llm_config = LlmConfig()
llm_config.backend.vgf.enabled = True
llm_config.backend.vgf.compile_spec = "TOSA-1.0+INT+int16"
llm_config.backend.vgf.quantize_scope = VgfQuantizeScope.linear
llm_config.quantization.pt2e_quantize = Pt2eQuantize.vgf_16a8w

_pt2e_quant_params, quantizers, _quant_dtype = get_quantizer_and_quant_params(
llm_config
)

self.assertEqual(len(quantizers), 1)
self.assertIsInstance(quantizers[0], VgfQuantizer)
self.assertIn(torch.nn.Linear, quantizers[0].module_type_config)
2 changes: 1 addition & 1 deletion examples/models/smollm2/135M_config.json
@@ -6,7 +6,7 @@
"n_kv_heads": 3,
"n_layers": 30,
"norm_eps": 1e-05,
"rope_theta": 10000.0,
"rope_theta": 100000.0,
"use_scaled_rope": false,
"vocab_size": 49152,
"use_hf_rope": false,
79 changes: 49 additions & 30 deletions extension/llm/export/builder.py
@@ -282,23 +282,39 @@ def calibrate_template(
module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int
):
# TODO: change criteria & support batch inputs if necessary
pos = torch.tensor(0, dtype=torch.int64)
pos = 0
token_list = tokenizer.encode(prompts, bos=True, eos=False)

pad_token = getattr(tokenizer, "pad_id", tokenizer.eos_id)

with torch.no_grad():
while token_list[-1] != tokenizer.eos_id and pos < max_len:
logits = module(
torch.full((1, 1), token_list[pos]),
{"input_pos": torch.tensor((pos,))},
)
if self.use_kv_cache:
logits = module(
torch.full((1, 1), token_list[pos]),
{"input_pos": torch.tensor((pos,))},
)
else:
prefix_tokens = list(token_list[: pos + 1])
if len(prefix_tokens) < max_len:
prefix_tokens.extend(
[pad_token] * (max_len - len(prefix_tokens))
)
else:
prefix_tokens = prefix_tokens[:max_len]

prefix = torch.tensor(
prefix_tokens, dtype=torch.long
).unsqueeze(0)
logits = module(prefix)

pos += 1
if pos >= len(token_list):
if self.generate_full_logits:
token_list.append(
torch.argmax(logits[:, -1], dim=-1).item()
)
next_token = torch.argmax(logits[:, -1], dim=-1).item()
else:
token_list.append(torch.argmax(logits[:], dim=-1).item())
next_token = torch.argmax(logits[:], dim=-1).item()
token_list.append(next_token)

calibrate_template(
module=prepared_module,
@@ -307,26 +323,31 @@
max_len=calibration_seq_length,
)

eval_wrapper = GraphModuleEvalWrapper(
model=prepared_module,
tokenizer=tokenizer,
max_seq_length=calibration_seq_length,
use_kv_cache=self.use_kv_cache,
generate_full_logits=self.generate_full_logits,
enable_dynamic_shape=self.enable_dynamic_shape,
)

# Evaluate the model
with torch.no_grad():
eval_results = simple_evaluate(
model=eval_wrapper,
tasks=calibration_tasks,
limit=calibration_limit,
if calibration_tasks:
eval_wrapper = GraphModuleEvalWrapper(
model=prepared_module,
tokenizer=tokenizer,
max_seq_length=calibration_seq_length,
use_kv_cache=self.use_kv_cache,
generate_full_logits=self.generate_full_logits,
enable_dynamic_shape=self.enable_dynamic_shape,
# The exported graph can contain ops like aten.full.default
# without explicit device, which default to CPU and can
# trigger device-mismatch errors when lm_eval runs on CUDA.
# Calibrate on CPU for stability.
device="cpu",
)

for task, res in eval_results["results"].items():
print(f"{task}: {res}")
logging.info("Calibration finish...")
with torch.no_grad():
eval_results = simple_evaluate(
model=eval_wrapper,
tasks=calibration_tasks,
limit=calibration_limit,
)
Copilot AI commented on lines +326 to +346 (Apr 27, 2026):

With the new if calibration_tasks: gate, simple_evaluate(..., limit=calibration_limit) will be called even when calibration_limit is None (the CLI default). In lm-eval this typically means evaluating the full dataset, which can make calibration unexpectedly long/expensive. Consider requiring calibration_limit when tasks are provided, supplying a conservative default, or emitting a clear warning when tasks are set without a limit.

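A minimal sketch of the warning option, assuming it would be placed just before the simple_evaluate call inside the if calibration_tasks: block (the message wording is illustrative, not part of this PR):

# Assumption: warn (rather than fail) when tasks are given without a limit.
if calibration_limit is None:
    logging.warning(
        "calibration_tasks were provided without calibration_limit; "
        "lm-eval will evaluate the full dataset, which can make calibration "
        "very slow. Set calibration_limit to bound the number of samples."
    )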
for task, res in eval_results["results"].items():
print(f"{task}: {res}")
logging.info("Calibration finish...")

def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
"""
@@ -360,9 +381,7 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
)
# Calibrate
if (
self.calibration_tasks is not None
and self.calibration_limit is not None
and self.calibration_seq_length is not None
self.calibration_seq_length is not None
and self.calibration_data is not None
and self.tokenizer_path is not None
):
17 changes: 17 additions & 0 deletions extension/llm/export/config/llm_config.py
@@ -377,6 +377,7 @@ class Pt2eQuantize(str, Enum):
tosa_8a8w = "tosa_8a8w"
ethosu_8a8w = "ethosu_8a8w"
vgf_8a8w = "vgf_8a8w"
vgf_16a8w = "vgf_16a8w"


class SpinQuant(str, Enum):
@@ -587,6 +588,11 @@ class EthosUConfig:
system_config: str = "default"


class VgfQuantizeScope(str, Enum):
full = "full"
linear = "linear"


@dataclass
class VgfConfig:
"""
@@ -596,6 +602,7 @@ class VgfConfig:
enabled: bool = False
compile_spec: Optional[str] = "TOSA-1.0+INT"
compiler_flags: List[str] = field(default_factory=list)
quantize_scope: VgfQuantizeScope = VgfQuantizeScope.full


@dataclass
Expand Down Expand Up @@ -815,6 +822,16 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
if hasattr(args, "group_size") and args.group_size:
llm_config.backend.openvino.nncf_compression_group_size = args.group_size

# VGF
if hasattr(args, "vgf"):
llm_config.backend.vgf.enabled = args.vgf
if hasattr(args, "vgf_compile_spec"):
llm_config.backend.vgf.compile_spec = args.vgf_compile_spec
if hasattr(args, "vgf_quantize_scope") and args.vgf_quantize_scope:
llm_config.backend.vgf.quantize_scope = VgfQuantizeScope(
args.vgf_quantize_scope
)

# TorchAoKernels
if any(
hasattr(args, a)