diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py
index 23d00ff8c15..f89ea0a0075 100644
--- a/examples/models/llama/eval_llama_lib.py
+++ b/examples/models/llama/eval_llama_lib.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -46,9 +47,13 @@ def __init__(
         use_kv_cache: bool = False,
         generate_full_logits: bool = False,
         enable_dynamic_shape: bool = True,
+        device: Optional[str] = None,
     ):
         super().__init__(
-            model=model, tokenizer=tokenizer, max_seq_length=max_seq_length
+            model=model,
+            tokenizer=tokenizer,
+            max_seq_length=max_seq_length,
+            device=device,
         )
         self._model = model.to(self.device)
         self._use_kv_cache = use_kv_cache
@@ -80,7 +85,47 @@ def _model_call(self, inps):
             return logits
         else:
-            return self._model(inps)
+            # lm-eval expects logits shaped [batch, seq, vocab].
+            # The KV-cache path above handles that separately. In the non-KV path,
+            # some exported graphs (when generate_full_logits=False) return only
+            # last-position logits [batch, vocab], so reconstruct per-position
+            # logits by running prefix calls.
+
+            seq_len = inps.shape[-1]
+
+            def pad_to_max_len(tokens: torch.Tensor) -> torch.Tensor:
+                if self._enable_dynamic_shape:
+                    return tokens
+                token_len = tokens.shape[-1]
+                if token_len < self._max_seq_length:
+                    pad_len = self._max_seq_length - token_len
+                    pad_token = getattr(
+                        self._tokenizer, "pad_id", self._tokenizer.eos_id
+                    )
+                    pad = torch.full(
+                        (tokens.shape[0], pad_len),
+                        pad_token,
+                        dtype=tokens.dtype,
+                        device=tokens.device,
+                    )
+                    return torch.cat((tokens, pad), dim=-1)
+                if token_len > self._max_seq_length:
+                    return tokens[:, : self._max_seq_length]
+                return tokens
+
+            if self._generate_full_logits:
+                return self._model(pad_to_max_len(inps))
+
+            # Reconstruct full logits by running prefixes.
+            result_logits = []
+            for pos in range(min(seq_len, self._max_seq_length)):
+                prefix = pad_to_max_len(inps[:, : pos + 1])
+                logits = self._model(prefix)
+                if logits.dim() == 3:
+                    logits = logits[:, -1, :]
+                result_logits.append(logits)
+
+            return torch.stack(result_logits, dim=1)
 
     def _model_generate(self, context, max_length, eos_token_id):
         raise Exception("unimplemented")
 
diff --git a/examples/models/llama/evaluate/eager_eval.py b/examples/models/llama/evaluate/eager_eval.py
index 9d5d7ad447b..5c129e1c250 100644
--- a/examples/models/llama/evaluate/eager_eval.py
+++ b/examples/models/llama/evaluate/eager_eval.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -28,12 +29,13 @@ def __init__(
         tokenizer: Union[SentencePieceTokenizer, Tiktoken, HuggingFaceTokenizer],
         max_seq_length: Optional[int] = None,
         use_kv_cache: bool = False,
+        device: Optional[str] = None,
     ):
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        super().__init__(device=device, pretrained="gpt2")
+        resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        super().__init__(device=resolved_device, pretrained="gpt2")
         self._model = model
         self._tokenizer = tokenizer
-        self._device = torch.device(device)
+        self._device = torch.device(resolved_device)
         self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
         self._use_kv_cache = use_kv_cache
 
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 9cf1b4b4bf0..426cab8b248 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -229,6 +229,8 @@ def build_args_parser() -> argparse.ArgumentParser:
             "vulkan_8w",
             "tosa_8a8w",
             "ethosu_8a8w",
+            "vgf_8a8w",
+            "vgf_16a8w",
         ],
         help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
     )
@@ -456,6 +458,18 @@ def build_args_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument("-V", "--vulkan", action="store_true")
    parser.add_argument("--vulkan-force-fp16", action="store_true")
+    parser.add_argument("--vgf", action="store_true")
+    parser.add_argument(
+        "--vgf-compile-spec",
+        default="TOSA-1.0+INT",
+        help="VGF compile spec, e.g. TOSA-1.0+INT or TOSA-1.0+INT+int16.",
+    )
+    parser.add_argument(
+        "--vgf-quantize-scope",
+        default="full",
+        choices=["full", "linear"],
+        help="VGF quantization scope. Use 'linear' to quantize only Linear modules.",
+    )
     parser.add_argument("--mps", action="store_true")
     parser.add_argument("--coreml", action="store_true")
     parser.add_argument(
@@ -847,6 +861,7 @@ def get_quantizer_and_quant_params(llm_config):
             llm_config.backend.vgf.compile_spec,
             llm_config.backend.vgf.compiler_flags,
             llm_config.quantization.pt2e_quantize.value,
+            llm_config.backend.vgf.quantize_scope.value,
         )
         quantizers.append(vgf_quantizer)
     if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index f02621b66b2..3b265a074c6 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -285,11 +286,11 @@ def get_example_inputs(self):
         if self.use_kv_cache:
             return self.get_example_inputs_kvcache_sdpa()
         else:
-            return (
-                torch.tensor(
-                    [[1, 2, 3]], dtype=torch.long
-                ),  # tokens, with kv cache our input token length is always just 1 token.
-            )
+            max_len = getattr(self.llm_config.export, "max_seq_length", 3)
+            max_len = max(3, int(max_len))
+            example_tokens = torch.arange(max_len, dtype=torch.int32).unsqueeze(0)
+            example_tokens = example_tokens % 100 + 1
+            return (example_tokens,)
 
     # assumption is the custom op doesnt support dynamic shape right now. It might but its untested so lets first get static shape working
     def get_example_inputs_kvcache_sdpa(self):
diff --git a/examples/models/llama/tests/test_export_llama_lib.py b/examples/models/llama/tests/test_export_llama_lib.py
index 130a55f658c..f3dc403aa05 100644
--- a/examples/models/llama/tests/test_export_llama_lib.py
+++ b/examples/models/llama/tests/test_export_llama_lib.py
@@ -7,6 +7,8 @@
 
 import unittest
 
+import torch
+
 from executorch.devtools.backend_debug import get_delegation_info
 
 try:
@@ -28,7 +30,11 @@
     build_args_parser,
     get_quantizer_and_quant_params,
 )
-from executorch.extension.llm.export.config.llm_config import LlmConfig, Pt2eQuantize
+from executorch.extension.llm.export.config.llm_config import (
+    LlmConfig,
+    Pt2eQuantize,
+    VgfQuantizeScope,
+)
 
 UNWANTED_OPS = [
     "aten_permute_copy_default",
@@ -111,3 +117,49 @@ def test_get_quantizer_and_quant_params_returns_vgf_quantizer(self):
         self.assertIsNone(quant_dtype)
         self.assertEqual(len(quantizers), 1)
         self.assertIsInstance(quantizers[0], VgfQuantizer)
+        self.assertIsNotNone(quantizers[0].global_config)
+        self.assertEqual(quantizers[0].module_type_config, {})
+
+    @unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available")
+    def test_get_quantizer_and_quant_params_returns_vgf_linear_quantizer(self):
+        llm_config = LlmConfig()
+        llm_config.backend.vgf.enabled = True
+        llm_config.backend.vgf.compile_spec = "TOSA-1.0+INT"
+        llm_config.backend.vgf.quantize_scope = VgfQuantizeScope.linear
+        llm_config.quantization.pt2e_quantize = Pt2eQuantize.vgf_8a8w
+
+        _pt2e_quant_params, quantizers, _quant_dtype = get_quantizer_and_quant_params(
+            llm_config
+        )
+
+        self.assertEqual(len(quantizers), 1)
+        self.assertIsInstance(quantizers[0], VgfQuantizer)
+        self.assertIsNone(quantizers[0].global_config)
+        self.assertIn(torch.nn.Linear, quantizers[0].module_type_config)
+
+    @unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available")
+    def test_vgf_16a8w_requires_int16_compile_spec_extension(self):
+        llm_config = LlmConfig()
+        llm_config.backend.vgf.enabled = True
+        llm_config.backend.vgf.compile_spec = "TOSA-1.0+INT"
+        llm_config.backend.vgf.quantize_scope = VgfQuantizeScope.linear
+        llm_config.quantization.pt2e_quantize = Pt2eQuantize.vgf_16a8w
+
+        with self.assertRaisesRegex(ValueError, "INT16 support"):
+            get_quantizer_and_quant_params(llm_config)
+
+    @unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available")
+    def test_vgf_16a8w_accepts_int16_compile_spec_extension(self):
+        llm_config = LlmConfig()
+        llm_config.backend.vgf.enabled = True
+        llm_config.backend.vgf.compile_spec = "TOSA-1.0+INT+int16"
+        llm_config.backend.vgf.quantize_scope = VgfQuantizeScope.linear
+        llm_config.quantization.pt2e_quantize = Pt2eQuantize.vgf_16a8w
+
+        _pt2e_quant_params, quantizers, _quant_dtype = get_quantizer_and_quant_params(
+            llm_config
+        )
+
+        self.assertEqual(len(quantizers), 1)
+        self.assertIsInstance(quantizers[0], VgfQuantizer)
+        self.assertIn(torch.nn.Linear, quantizers[0].module_type_config)
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index c25c1190990..b4cf631b2be 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -282,23 +282,39 @@
         def calibrate_template(
             module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int
         ):
             # TODO: change criteria & support batch inputs if necessary
-            pos = torch.tensor(0, dtype=torch.int64)
+            pos = 0
             token_list = tokenizer.encode(prompts, bos=True, eos=False)
+            pad_token = getattr(tokenizer, "pad_id", tokenizer.eos_id)
+
             with torch.no_grad():
                 while token_list[-1] != tokenizer.eos_id and pos < max_len:
-                    logits = module(
-                        torch.full((1, 1), token_list[pos]),
-                        {"input_pos": torch.tensor((pos,))},
-                    )
+                    if self.use_kv_cache:
+                        logits = module(
+                            torch.full((1, 1), token_list[pos]),
+                            {"input_pos": torch.tensor((pos,))},
+                        )
+                    else:
+                        prefix_tokens = list(token_list[: pos + 1])
+                        if len(prefix_tokens) < max_len:
+                            prefix_tokens.extend(
+                                [pad_token] * (max_len - len(prefix_tokens))
+                            )
+                        else:
+                            prefix_tokens = prefix_tokens[:max_len]
+
+                        prefix = torch.tensor(
+                            prefix_tokens, dtype=torch.long
+                        ).unsqueeze(0)
+                        logits = module(prefix)
+
                     pos += 1
                     if pos >= len(token_list):
                         if self.generate_full_logits:
-                            token_list.append(
-                                torch.argmax(logits[:, -1], dim=-1).item()
-                            )
+                            next_token = torch.argmax(logits[:, -1], dim=-1).item()
                         else:
-                            token_list.append(torch.argmax(logits[:], dim=-1).item())
+                            next_token = torch.argmax(logits[:], dim=-1).item()
+                        token_list.append(next_token)
 
         calibrate_template(
             module=prepared_module,
@@ -307,26 +323,31 @@
             max_len=calibration_seq_length,
         )
 
-        eval_wrapper = GraphModuleEvalWrapper(
-            model=prepared_module,
-            tokenizer=tokenizer,
-            max_seq_length=calibration_seq_length,
-            use_kv_cache=self.use_kv_cache,
-            generate_full_logits=self.generate_full_logits,
-            enable_dynamic_shape=self.enable_dynamic_shape,
-        )
-
-        # Evaluate the model
-        with torch.no_grad():
-            eval_results = simple_evaluate(
-                model=eval_wrapper,
-                tasks=calibration_tasks,
-                limit=calibration_limit,
+        if calibration_tasks:
+            eval_wrapper = GraphModuleEvalWrapper(
+                model=prepared_module,
+                tokenizer=tokenizer,
+                max_seq_length=calibration_seq_length,
+                use_kv_cache=self.use_kv_cache,
+                generate_full_logits=self.generate_full_logits,
+                enable_dynamic_shape=self.enable_dynamic_shape,
+                # The exported graph can contain ops like aten.full.default
+                # without explicit device, which default to CPU and can
+                # trigger device-mismatch errors when lm_eval runs on CUDA.
+                # Calibrate on CPU for stability.
+                device="cpu",
             )
 
-            for task, res in eval_results["results"].items():
-                print(f"{task}: {res}")
-            logging.info("Calibration finish...")
+            with torch.no_grad():
+                eval_results = simple_evaluate(
+                    model=eval_wrapper,
+                    tasks=calibration_tasks,
+                    limit=calibration_limit,
+                )
+
+            for task, res in eval_results["results"].items():
+                print(f"{task}: {res}")
+            logging.info("Calibration finish...")
 
     def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
         """
@@ -360,9 +381,7 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage
         )
         # Calibrate
         if (
-            self.calibration_tasks is not None
-            and self.calibration_limit is not None
-            and self.calibration_seq_length is not None
+            self.calibration_seq_length is not None
             and self.calibration_data is not None
             and self.tokenizer_path is not None
         ):
diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py
index fa22ddad7ac..798aea99906 100644
--- a/extension/llm/export/config/llm_config.py
+++ b/extension/llm/export/config/llm_config.py
@@ -377,6 +377,7 @@ class Pt2eQuantize(str, Enum):
     tosa_8a8w = "tosa_8a8w"
     ethosu_8a8w = "ethosu_8a8w"
     vgf_8a8w = "vgf_8a8w"
+    vgf_16a8w = "vgf_16a8w"
 
 
 class SpinQuant(str, Enum):
@@ -587,6 +588,11 @@ class EthosUConfig:
     system_config: str = "default"
 
 
+class VgfQuantizeScope(str, Enum):
+    full = "full"
+    linear = "linear"
+
+
 @dataclass
 class VgfConfig:
     """
@@ -596,6 +602,7 @@ class VgfConfig:
     """
     enabled: bool = False
     compile_spec: Optional[str] = "TOSA-1.0+INT"
     compiler_flags: List[str] = field(default_factory=list)
+    quantize_scope: VgfQuantizeScope = VgfQuantizeScope.full
 
 
 @dataclass
@@ -815,6 +822,16 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
         if hasattr(args, "group_size") and args.group_size:
             llm_config.backend.openvino.nncf_compression_group_size = args.group_size
 
+        # VGF
+        if hasattr(args, "vgf"):
+            llm_config.backend.vgf.enabled = args.vgf
+        if hasattr(args, "vgf_compile_spec"):
+            llm_config.backend.vgf.compile_spec = args.vgf_compile_spec
+        if hasattr(args, "vgf_quantize_scope") and args.vgf_quantize_scope:
+            llm_config.backend.vgf.quantize_scope = VgfQuantizeScope(
+                args.vgf_quantize_scope
+            )
+
         # TorchAoKernels
         if any(
             hasattr(args, a)
diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py
index 0c78921e461..cd70610ee11 100644
--- a/extension/llm/export/quantizer_lib.py
+++ b/extension/llm/export/quantizer_lib.py
@@ -367,8 +367,10 @@ def get_vgf_quantizer(
     compile_spec: Optional[str],
     compiler_flags: Optional[List[str]],
     pt2e_quantize: str,
+    quantize_scope: str,
 ):
     from executorch.backends.arm.quantizer.arm_quantizer import (
+        get_symmetric_a16w8_quantization_config,
         get_symmetric_quantization_config,
         VgfQuantizer,
     )
@@ -379,8 +381,22 @@
 
     quantizer = VgfQuantizer(compile_spec_obj)
     if pt2e_quantize == "vgf_8a8w":
-        quantizer.set_global(get_symmetric_quantization_config())
+        quantization_config = get_symmetric_quantization_config()
+    elif pt2e_quantize == "vgf_16a8w":
+        if not compile_spec_obj.tosa_spec.support_extension("int16"):
+            raise ValueError(
+                "vgf_16a8w requires a VGF compile spec with INT16 support, "
+                "for example TOSA-1.0+INT+int16."
+            )
+        quantization_config = get_symmetric_a16w8_quantization_config()
     else:
         raise ValueError(f"Unsupported quantizer specification {pt2e_quantize}")
 
+    if quantize_scope == "full":
+        quantizer.set_global(quantization_config)
+    elif quantize_scope == "linear":
+        quantizer.set_module_type(torch.nn.Linear, quantization_config)
+    else:
+        raise ValueError(f"Unsupported VGF quantization scope {quantize_scope}")
+
     return quantizer
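
Illustrative sketch (not part of the patch above): the non-KV-cache branch added to GraphModuleEvalWrapper._model_call rebuilds the [batch, seq, vocab] logits that lm-eval expects by calling the model once per prefix and stacking the last-position outputs. The snippet below shows that idea in isolation; toy_model and VOCAB are hypothetical stand-ins for an exported graph that returns only last-position [batch, vocab] logits (generate_full_logits=False), and the padding/truncation to a fixed sequence length is omitted for brevity.

import torch

VOCAB = 16  # hypothetical vocabulary size for the stand-in model


def toy_model(tokens: torch.Tensor) -> torch.Tensor:
    # Stand-in for an exported graph with generate_full_logits=False:
    # it returns logits for the last position only, shaped [batch, vocab].
    return torch.randn(tokens.shape[0], VOCAB)


def rebuild_full_logits(tokens: torch.Tensor) -> torch.Tensor:
    # tokens: [batch, seq] -> logits: [batch, seq, vocab]
    # Run the model on every prefix tokens[:, : pos + 1] and keep the
    # last-position logits, mirroring the prefix loop in _model_call.
    per_position = []
    for pos in range(tokens.shape[-1]):
        per_position.append(toy_model(tokens[:, : pos + 1]))
    return torch.stack(per_position, dim=1)


if __name__ == "__main__":
    inps = torch.randint(0, VOCAB, (1, 5))
    print(rebuild_full_logits(inps).shape)  # torch.Size([1, 5, 16])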