diff --git a/backends/qualcomm/quantizer/quant_recipe.py b/backends/qualcomm/quantizer/quant_recipe.py
index 47d8fdf0e68..b2eb41841f0 100644
--- a/backends/qualcomm/quantizer/quant_recipe.py
+++ b/backends/qualcomm/quantizer/quant_recipe.py
@@ -6,6 +6,7 @@
 import re
+import textwrap
 from abc import ABC, abstractmethod
 from enum import IntEnum, unique
 from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple
@@ -424,3 +425,76 @@ def summary(self, max_rows: int = -1):
             rows.append(["..."] * len(headers))
 
         return tabulate(rows, headers=headers, tablefmt="grid")
+
+    def to_source(self) -> str:
+        """
+        Serializes this QuantRecipe into a Python source string at zero indentation.
+        """
+
+        def _dtype(d: QuantDtype) -> str:
+            return f"QuantDtype.{d.name}"
+
+        def _granularity(g: QuantGranularity) -> str:
+            return f"QuantGranularity.{g.name}"
+
+        def _comments(note: str) -> str:
+            lines = note.strip().splitlines() if note.strip() else []
+            return "".join(f"# {ln}\n" for ln in lines)
+
+        indent = "\t"
+
+        def _args(*lines: str) -> str:
+            return "".join(f"{indent}{ln},\n" for ln in lines)
+
+        strategy_blocks: List[str] = []
+        for strategy in self._strategies:
+            extra_kwargs_flag = (
+                [f"extra_kwargs={strategy.extra_kwargs!r}"]
+                if strategy.extra_kwargs
+                else []
+            )
+            if isinstance(strategy, ByNodeTarget):
+                targets_repr = ", ".join(
+                    f"torch.ops.{t._overloadpacket._qualified_op_name.replace('::', '.')}.{t._overloadname}"
+                    for t in sorted(strategy.targets, key=lambda t: str(t))
+                )
+                args = _args(
+                    f"{{{targets_repr}}}",
+                    _dtype(strategy.quant_dtype),
+                    str(strategy.is_qat),
+                    "act_observer=MinMaxObserver",
+                    f"granularity={_granularity(strategy.granularity)}",
+                    *extra_kwargs_flag,
+                    f"act_symmetric={strategy.act_symmetric}",
+                    f"note={strategy.note!r}",
+                )
+                call = f".add_node_target(\n{args})"
+            elif isinstance(strategy, ByNameRegex):
+                patterns_repr = ", ".join(f'r"{p}"' for p in sorted(strategy.patterns))
+                args = _args(
+                    f"{{{patterns_repr}}}",
+                    _dtype(strategy.quant_dtype),
+                    str(strategy.is_qat),
+                    "act_observer=MinMaxObserver",
+                    f"granularity={_granularity(strategy.granularity)}",
+                    *extra_kwargs_flag,
+                    f"act_symmetric={strategy.act_symmetric}",
+                    f"note={strategy.note!r}",
+                )
+                call = f".add_regex(\n{args})"
+            else:
+                continue
+
+            strategy_blocks.append(_comments(strategy.note) + call)
+
+        header_args = _args(
+            "self.default_quant_dtype",
+            str(self._default_is_qat),
+            "act_observer=MinMaxObserver",
+            f"granularity={_granularity(self._default_granularity)}",
+            "verbose=verbose",
+        )
+        header = f"QuantRecipe(\n{header_args})"
+        chained = "\n".join(strategy_blocks)
+        body = header + "\n" + chained
+        return "(\n" + textwrap.indent(body, indent) + "\n)"
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 48f07da06e9..111f5dc9baa 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -9039,6 +9039,66 @@ def test_intermediate_debugger(self):
             f"CSV valid count: {csv_valid_count}. SVG valid count: {svg_valid_count}"
         )
 
+    def test_analyzer_to_file_generation(self):
+        """
+        End-to-end test for PerLayerSqnrAnalyzer → SqnrReport → file generation.
+        """
+        from executorch.examples.qualcomm.oss_scripts.llama.mix_precision_analyzer import (
+            PerLayerSqnrAnalyzer,
+            save_suggest_recipes,
+        )
+
+        module = SimpleModel()  # noqa: F405
+        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
+        fp32_gm = torch.export.export(module, sample_input, strict=True).module()
+        qdq_gm = self.get_qdq_module(
+            module, sample_input, quant_dtype=QuantDtype.use_8a4w
+        )
+
+        report = PerLayerSqnrAnalyzer(
+            model_name="simple_conv",
+            num_layers=4,
+            fp32_gm=fp32_gm,
+            qdq_gm=qdq_gm,
+        ).analyze([sample_input], num_sharding=4)
+
+        overrides = report.suggest_recipe_overrides(sqnr_threshold=22.0)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            report.save_analysis_summary(output_dir=tmp_dir)
+            save_suggest_recipes(report, overrides, output_dir=tmp_dir)
+
+            # --- save_analysis_summary csv file ---
+            with open(f"{tmp_dir}/simple_conv_quantization_error.csv") as f:
+                csv_content = f.read()
+            rows = list(csv.reader(csv_content.splitlines()))
+            self.assertEqual(len(rows), 5)  # 1 header + 4 group rows
+            self.assertEqual(
+                rows[0],
+                [
+                    "group_name",
+                    "avg_sqnr",
+                    "median_sqnr",
+                    "min_sqnr",
+                    "max_sqnr",
+                    "count",
+                ],
+            )
+            print(f"Sensitivity analysis:\n{csv_content}")
+
+            # --- save_suggest_recipes .py file (only written when sensitive layers exist) ---
+            if overrides:
+                with open(f"{tmp_dir}/simple_conv_suggest_recipe.py") as f:
+                    py_content = f.read()
+                # generated file must be valid Python
+                try:
+                    compile(py_content, "simple_conv_suggest_recipe.py", "exec")
+                except SyntaxError as e:
+                    self.fail(
+                        f"Generated recipe file has syntax error: {e}\n{py_content}"
+                    )
+                self.assertIn("HOW TO USE THESE RECIPES", py_content)
+
 
 def setup_environment():
     parser = setup_common_args_and_variables()
diff --git a/examples/qualcomm/oss_scripts/llama/README.md b/examples/qualcomm/oss_scripts/llama/README.md
index fb926e9f613..6445a1ede7d 100644
--- a/examples/qualcomm/oss_scripts/llama/README.md
+++ b/examples/qualcomm/oss_scripts/llama/README.md
@@ -467,6 +467,20 @@ Example:
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "I would like to learn python, could you teach me with a simple example?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --eval_methods sqnr_eval
 ```
+
+#### Quantization Guidance
+
+To automatically identify sensitive layers and generate a mixed-precision recipe suggestion, add the `--quant_recipe_suggestion` flag. During calibration, the analyzer compares FP32 and QDQ intermediate outputs layer by layer using SQNR, then writes two files to the working directory:
+
+- `{model_name}_quantization_error.csv` — per-group SQNR statistics sorted by sensitivity (most sensitive first)
+- `{model_name}_suggest_recipe.py` — ready-to-use `StaticLLMQuantRecipe` subclasses that apply higher-precision quantization to the most sensitive groups.
+
+Example:
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "I would like to learn python, could you teach me with a simple example?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen3-1_7b --tasks wikitext --limit 1 --quant_recipe_suggestion --compile_only
+```
+
+After the run, pick one of the generated classes from `qwen3-1_7b_suggest_recipe.py` as your new recipe, as sketched below. For a full walkthrough, see [quantization_guidance.md](quantization_guidance.md).
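+
+For instance, once a generated class has been copied into `static_llm_quant_recipe.py`, swapping it in is a one-line import change (the class names below are illustrative; use the names the analyzer actually prints):
+
+```python
+# Before:
+from executorch.examples.qualcomm.oss_scripts.llama.static_llm_quant_recipe import Qwen3_1_7BQuantRecipe
+
+# After:
+from executorch.examples.qualcomm.oss_scripts.llama.static_llm_quant_recipe import (
+    QWEN3_1_7B_BlockSize64QuantRecipe as Qwen3_1_7BQuantRecipe,
+)
+```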
+
 #### Use attention sink for multi-turn conversations
 Attention sink is a way to evict cache when maximum context length be reached. There are two mainly concept for attention sink:
diff --git a/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py b/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py
index 92c901b3990..d944246922d 100644
--- a/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py
+++ b/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py
@@ -403,7 +403,7 @@ def __init__(
         self.max_seq_length = pte_max_context_len
 
     def run(self, prompt):
-        golden_logits = INFERENCE_REGISTRY[True](
+        golden_logits, _ = INFERENCE_REGISTRY[True](
             get_example_inputs=self.get_example_inputs,
             prompt=prompt,
             module=self.source_model,
diff --git a/examples/qualcomm/oss_scripts/llama/decoder_utils.py b/examples/qualcomm/oss_scripts/llama/decoder_utils.py
index d3261e1bb68..226cb4d7d42 100644
--- a/examples/qualcomm/oss_scripts/llama/decoder_utils.py
+++ b/examples/qualcomm/oss_scripts/llama/decoder_utils.py
@@ -107,6 +107,10 @@ def __init__(  # noqa: C901
         self.max_seq_length = max_seq_length
         self.use_i64_token = use_i64_token
         self.seq_mse_candidates = seq_mse_candidates
+        self._input_samples = None
+
+    def get_input_samples(self):
+        return self._input_samples
 
     def _model_call(self, inps):
         all_logits = None
@@ -115,7 +119,7 @@ def _model_call(self, inps):
             kwargs["ar_len"] = self.ar_len
             kwargs["seq_mse_candidates"] = self.seq_mse_candidates
 
-        all_logits = INFERENCE_REGISTRY[self._use_kv_cache](
+        all_logits, self._input_samples = INFERENCE_REGISTRY[self._use_kv_cache](
             self.get_example_inputs,
             inps,
             self._model,
@@ -403,6 +407,7 @@ def _prefill_chunking(
     k_caches,
     v_caches,
     total_token_list,
+    last_input_sample=None,
 ):
     with torch.no_grad():
         num_prompt_tokens = len(total_token_list)
@@ -446,6 +451,13 @@ def _prefill_chunking(
                     *k_caches,
                     *v_caches,
                 )
+                last_input_sample = (
+                    tmp_token_list,
+                    *inputs.atten_mask,
+                    tmp_pos,
+                    *k_caches,
+                    *v_caches,
+                )
             else:
                 logits, new_k_caches, new_v_caches = module(
                     tmp_embedding,
@@ -454,6 +466,13 @@ def _prefill_chunking(
                     *k_caches,
                     *v_caches,
                 )
+                last_input_sample = (
+                    tmp_embedding,
+                    *inputs.atten_mask,
+                    tmp_pos,
+                    *k_caches,
+                    *v_caches,
+                )
 
             if collect_logits:
                 result_logits.append(logits[:, :num_tokens_in_chunk])
@@ -493,7 +512,7 @@ def _prefill_chunking(
                 torch.argmax(logits[:, num_tokens_in_chunk - 1], dim=-1).item()
             )
 
-    return pos
+    return pos, last_input_sample
 
 
 def _generate(
@@ -508,6 +527,7 @@ def _generate(
     v_caches,
     total_token_list,
     lookahead_config,
+    last_input_sample=None,
 ):
     max_cache_len = max_seq_len - ar_len
     num_tokens = len(total_token_list)
@@ -543,6 +563,13 @@ def _generate(
                     *k_caches,
                     *v_caches,
                 )
+                last_input_sample = (
+                    tmp_token_list,
+                    *inputs.atten_mask,
+                    tmp_pos,
+                    *k_caches,
+                    *v_caches,
+                )
             else:
                 logits, new_k_caches, new_v_caches = module(
                     embedding,
@@ -551,6 +578,13 @@ def _generate(
                     *k_caches,
                     *v_caches,
                 )
+                last_input_sample = (
+                    embedding,
+                    *inputs.atten_mask,
+                    tmp_pos,
+                    *k_caches,
+                    *v_caches,
+                )
 
             pos, k_caches, v_caches = smart_mask_updater(
                 1,
@@ -604,6 +638,15 @@ def _generate(
                     *k_caches,
                     *v_caches,
                 )
+                last_input_sample = (
+                    torch.tensor(input_tokens, dtype=inputs.input_ids_dtype).unsqueeze(
+                        0
+                    ),
+                    *inputs.atten_mask,
+                    pos_offsets + pos,
+                    *k_caches,
+                    *v_caches,
+                )
             else:
                 logits, new_k_caches, new_v_caches = module(
                     tok_embedding(
@@ -616,6 +659,17 @@ def _generate(
                     *k_caches,
                     *v_caches,
                 )
+                last_input_sample = (
+                    tok_embedding(
+                        torch.tensor(
+                            input_tokens, dtype=inputs.input_ids_dtype
+                        ).unsqueeze(0)
+                    ),
+                    *inputs.atten_mask,
+                    pos_offsets + pos,
+                    *k_caches,
+                    *v_caches,
+                )
             # collect outputs
             output_tokens = torch.argmax(logits, dim=-1).flatten().tolist()
             # update ngram pool
@@ -658,6 +712,7 @@ def _generate(
         logging.info(
             f"lookahead accepted / total generated: {accepted_tokens} / {generated_tokens}"
         )
+    return last_input_sample
 
 
 @register_inference(use_kv_cache=True)
@@ -676,6 +731,7 @@ def kv_inference(  # noqa: C901
    seq_mse_candidates=0,
    lookahead_config=None,
 ):
+    input_samples = []  # Record input sample for quantization error analysis
    is_multimodal = all(
        [
            tok_embedding is not None,
@@ -774,7 +830,7 @@ def kv_inference(  # noqa: C901
     # 4. decoder forward
     with torch.no_grad():
         # Phase 1: Prefill the prompt in ar_len chunks.
-        cur_pos = _prefill_chunking(
+        cur_pos, prefill_input_sample = _prefill_chunking(
             inputs,
             module,
             ar_len,
@@ -788,7 +844,7 @@ def kv_inference(  # noqa: C901
 
         # Phase 2: Generate tokens until the EOS token is generated or max_seq_len is reached.
         # When run on wikitext for ppl evaluation, this while-loop is not expected to run.
-        _generate(
+        generate_input_sample = _generate(
             inputs,
             cur_pos,
             module,
@@ -801,11 +857,15 @@ def kv_inference(  # noqa: C901
             total_token_list,
             lookahead_config,
         )
+        if generate_input_sample is not None:
+            input_samples.append(generate_input_sample)
+        else:
+            input_samples.append(prefill_input_sample)
 
     logging.info(f"kv inference result:\n{tokenizer.decode(total_token_list)}")
     if collect_logits:
         result_logits = torch.cat(result_logits, dim=1)
-    return result_logits
+    return result_logits, input_samples
 
 
 @register_inference(use_kv_cache=False)
@@ -821,6 +881,7 @@ def prefill_inference(
    use_i64_token=False,
    collect_logits=False,
 ):
+    input_samples = None  # Record input sample for quantization error analysis
    is_multimodal = all(
        [
            tok_embedding is not None,
@@ -873,8 +934,10 @@ def prefill_inference(
                 image_token_id,
             )
             results = module(multimodal_embedding, *atten_mask)
+            input_samples = (multimodal_embedding, *atten_mask)
         else:
             results = module(tmp_token_list, *atten_mask)
+            input_samples = (tmp_token_list, *atten_mask)
         if len(results) == 3:
             logits, _, _ = results
         elif len(results) == 1:
@@ -886,7 +949,7 @@ def prefill_inference(
         pos += 1
     if isinstance(prompt, str):
         logging.info(f"prefill inference result:\n{tokenizer.decode(token_list)}")
-    return result_logits
+    return result_logits, [input_samples]
 
 
 def graph_module_inference(
@@ -923,7 +986,7 @@ def graph_module_inference(
         kwargs["ar_len"] = ar_len
         kwargs["lookahead_config"] = lookahead_config
 
-        INFERENCE_REGISTRY[use_kv_cache](
+        _, input_samples = INFERENCE_REGISTRY[use_kv_cache](
             get_example_inputs,
             prompt,
             module,
@@ -937,6 +1000,7 @@ def graph_module_inference(
             **kwargs,
         )
         logging.info(f"Prompt summary for {event_name}")
+        return input_samples
     else:
         calibration_wrapper = GraphModuleCalibrationWrapper(
             model=module,
@@ -958,3 +1022,5 @@ def graph_module_inference(
         logging.info(f"Evaluation summary for {event_name}")
         for task, res in eval_results["results"].items():
             logging.info(f"{task}: {res}")
+
+        return calibration_wrapper.get_input_samples()
diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py
index 3aa1fa81610..c0a48de8300 100755
--- a/examples/qualcomm/oss_scripts/llama/llama.py
+++ b/examples/qualcomm/oss_scripts/llama/llama.py
@@ -535,6 +535,12 @@ def _build_parser():
         help="Thread count for calibration forward passes. 0 = auto-tune (default).",
     )
 
+    parser.add_argument(
+        "--quant_recipe_suggestion",
+        action="store_true",
+        help="Enable automatic quantization recipe suggestion during PTQ calibration",
+    )
+
     return parser
diff --git a/examples/qualcomm/oss_scripts/llama/mix_precision_analyzer.py b/examples/qualcomm/oss_scripts/llama/mix_precision_analyzer.py
new file mode 100644
index 00000000000..02f19a0b676
--- /dev/null
+++ b/examples/qualcomm/oss_scripts/llama/mix_precision_analyzer.py
@@ -0,0 +1,688 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import csv
+import logging
+import math
+import re
+import statistics
+import textwrap
+from collections import defaultdict
+from dataclasses import dataclass, field
+from functools import reduce
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from executorch.backends.qualcomm.quantizer.quant_recipe import (
+    ByNameRegex,
+    ByNodeTarget,
+    QuantGranularity,
+    QuantRecipe,
+)
+from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+from executorch.devtools.inspector._intermediate_output_capturer import (
+    IntermediateOutputCapturer,
+)
+from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY
+from torchao.quantization.pt2e import MinMaxObserver
+from torchao.quantization.utils import compute_error
+
+
+class PerLayerSqnrAnalyzer:
+    """
+    Computes per-layer SQNR by comparing fp32 and QDQ intermediate outputs.
+
+    Args:
+        model_name: Name of the model being analyzed.
+        num_layers: Total number of transformer layers in the model.
+        fp32_gm: fp32 exported GraphModule (before prepare_pt2e).
+        qdq_gm: QDQ GraphModule (after convert_pt2e).
+        analysis_recipe: The QuantRecipe used to produce qdq_gm. Stored in the
+            returned SqnrReport and used as the baseline for diff annotation
+            in save_suggest_recipes(). Pass None if the QDQ model was not
+            produced via a QuantRecipe; in that case, codegen skips
+            diff annotation entirely.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        num_layers: int,
+        fp32_gm: torch.fx.GraphModule,
+        qdq_gm: torch.fx.GraphModule,
+        analysis_recipe: Optional[QuantRecipe] = None,
+    ):
+        self.model_name = model_name
+        self.num_layers = num_layers
+        self.fp32_gm = fp32_gm
+        self.qdq_gm = qdq_gm
+        self.analysis_recipe = analysis_recipe
+        self.targets = {
+            torch.ops.aten.conv2d.default,
+        }
+        self.q_ops = {
+            torch.ops.torchao.quantize_affine,
+            torch.ops.quantized_decomposed.quantize_per_channel.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        }
+        self.dq_ops = {
+            torch.ops.torchao.dequantize_affine,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+        }
+
+    def analyze(self, samples: List[Tuple], num_sharding: int = 5) -> "SqnrReport":
+        """
+        Evaluates both the fp32 and QDQ graphs using the provided samples
+        and computes the per-node Signal-to-Quantization-Noise Ratio (SQNR).
+
+        Args:
+            samples: A list of tuples containing tensors corresponding to the model's inputs.
+            num_sharding: Number of contiguous layer groups to bucket the model into for SQNR
+                aggregation. Rather than flagging individual layers, layers are grouped into
+                ``num_sharding`` consecutive ranges (e.g. layers 0-7, 8-15, …) and the SQNR
+                is averaged within each group. Upgrading isolated layers is usually
+                ineffective, because quantization error from surrounding low-precision
+                layers accumulates and dominates downstream behavior.
+
+        Returns:
+            An ``SqnrReport`` object containing the aggregated analysis results.
+        """
+        input_samples = [sample for sample in samples if sample is not None]
+
+        if not input_samples:
+            logging.warning("No input samples provided for analysis.")
+            return SqnrReport(
+                self.model_name, defaultdict(list), [], self.analysis_recipe
+            )
+
+        self._assign_debug_handles(self.fp32_gm)
+        self._assign_debug_handles(self.qdq_gm)
+
+        num_samples = len(input_samples)
+        logging.info(f"num samples: {num_samples}")
+
+        # Accumulate SQNR per module path across all input samples
+        path_sqnr_sum = defaultdict(float)
+        for sample in input_samples:
+            fp_outputs = self._capture(self.fp32_gm, sample)
+            qdq_outputs = self._capture(self.qdq_gm, sample)
+            for path, sqnr in self._match_and_score(fp_outputs, qdq_outputs).items():
+                path_sqnr_sum[path] += sqnr
+
+        # Average the SQNRs and group them by normalized layer ranges
+        report = defaultdict(list)
+        for path, total_sqnr in path_sqnr_sum.items():
+            group = self._normalize_group_name(
+                path, self.num_layers, num_sharding=num_sharding
+            )
+            report[group].append(total_sqnr / num_samples)
+
+        return SqnrReport(
+            self.model_name,
+            report,
+            self._collect_conv_in_channels(),
+            self.analysis_recipe,
+        )
+
+    def _assign_debug_handles(self, gm: torch.fx.GraphModule) -> None:
+        call_nodes = []
+        for node in gm.graph.nodes:
+            if node.op != "call_function" or node.target not in self.targets:
+                continue
+            users = list(node.users.keys())
+            maybe_q_op = users[0]
+            if maybe_q_op.target in self.q_ops:
+                maybe_q_op_users = list(maybe_q_op.users.keys())
+                if maybe_q_op_users:
+                    dq_node = maybe_q_op_users[0]
+                    call_nodes.append(dq_node)
+            else:
+                call_nodes.append(node)
+
+        for i, node in enumerate(call_nodes):
+            node.meta[DEBUG_HANDLE_KEY] = i
+
+    def _collect_conv_in_channels(self) -> List[int]:
+        """Collects in_channels from all conv2d nodes in fp32_gm."""
+        in_channels = []
+        for node in self.fp32_gm.graph.nodes:
+            if node.op == "call_function" and node.target in self.targets:
+                weight_node = node.args[1]
+                weight = weight_node.meta.get("val", None)
+                if weight is not None:
+                    in_channels.append(weight.shape[1])
+        return in_channels
+
+    def _capture(
+        self, gm: torch.fx.GraphModule, inputs: Tuple
+    ) -> Dict[int, Tuple[str, torch.Tensor]]:
+        """
+        Executes the graph module using IntermediateOutputCapturer and returns a mapping
+        from debug_handle to a tuple of (module_path, output_tensor) for every captured tensor.
+        """
+        with torch.no_grad():
+            raw = IntermediateOutputCapturer(gm).run_and_capture(*inputs)
+
+        handle_idx_to_node = {
+            n.meta[DEBUG_HANDLE_KEY]: n
+            for n in gm.graph.nodes
+            if DEBUG_HANDLE_KEY in n.meta
+        }
+
+        outputs: Dict[int, Tuple[str, torch.Tensor]] = {}
+        for handle, tensor in raw.items():
+            handle_idx = handle[0]
+            if not isinstance(tensor, torch.Tensor):
+                continue
+            if handle_idx not in handle_idx_to_node:
+                continue
+            if path := self._module_path(handle_idx_to_node[handle_idx]):
+                outputs[handle_idx] = (path, tensor)
+        return outputs
+
+    def _module_path(self, node: torch.fx.Node) -> Optional[str]:
+        if node.target in self.dq_ops:
+            args = node.args
+            node = args[0].args[0]
+        if "nn_module_stack" in node.meta:
+            return list(node.meta["nn_module_stack"].values())[-1][0]
+        return None
+
+    def _normalize_group_name(
+        self, group_name: str, total_layers: int, num_sharding: int
+    ) -> str:
+        """Buckets layer indices into broader ranges for SQNR aggregation."""
+        m = re.search(r"(layers)[_.](\d+)[_.]", group_name)
+        if not m:
+            return group_name
+        prefix = m.group(1)
+        layer_id = int(m.group(2))
+        step = max(1, total_layers // num_sharding)
+        idx = min(layer_id // step, num_sharding - 1)
+        start = idx * step
+        end = total_layers - 1 if idx == num_sharding - 1 else start + step - 1
+        return re.sub(r"(layers)[_.]\d+[_.]", rf"{prefix}.[{start}-{end}].", group_name)
+
+    def _match_and_score(
+        self,
+        fp_outputs: Dict[int, Tuple[str, torch.Tensor]],
+        qdq_outputs: Dict[int, Tuple[str, torch.Tensor]],
+    ) -> Dict[str, float]:
+        """
+        Compares corresponding fp32 and QDQ output tensors and computes their SQNR.
+        Returns a dictionary mapping module paths to their SQNR values.
+        """
+        results = {}
+        for handle, (path, fp_tensor) in fp_outputs.items():
+            if handle in qdq_outputs and fp_tensor.dtype != torch.bool:
+                _, qdq_tensor = qdq_outputs[handle]
+                sqnr = compute_error(fp_tensor, qdq_tensor)
+                if math.isfinite(sqnr):
+                    results[path] = sqnr
+
+        return results
+
+
+@dataclass
+class GroupSqnrStats:
+    group_name: str
+    avg_sqnr: float
+    median_sqnr: float
+    min_sqnr: float
+    max_sqnr: float
+    count: int
+
+
+@dataclass
+class SqnrReport:
+    """Aggregated SQNR results from PerLayerSqnrAnalyzer."""
+
+    model_name: str
+    results: Dict[str, List[float]] = field(default_factory=lambda: defaultdict(list))
+    conv_in_channels: List[int] = field(default_factory=list)
+    analysis_recipe: Optional[QuantRecipe] = field(default=None)
+
+    def _compute_blk_sizes_candidate(
+        self, min_size: int = 16, max_size: int = 64
+    ) -> List[int]:
+        """
+        Derives block size candidates from the GCD of all conv in_channels,
+        returning all divisors of that GCD in [min_size, max_size].
+        Falls back to [min_size, max_size] if no channels are available.
+        For example, if the conv in_channels are {1024, 2048}, their GCD is
+        1024, so the candidates are [16, 32, 64] (its divisors in range).
+
+        Empirically, block sizes in the range [16, 64] offer a good accuracy/compression
+        trade-off for LPBQ quantization. Smaller values (e.g. 16) preserve more accuracy
+        at the cost of larger model size; larger values (e.g. 64) compress more aggressively.
+        You can widen the search range by adjusting ``min_size`` and ``max_size``.
+        """
+        if not self.conv_in_channels:
+            return [min_size, max_size]
+        gcd = reduce(math.gcd, self.conv_in_channels)
+        return sorted(d for d in range(min_size, max_size + 1) if gcd % d == 0) or [
+            min_size,
+            max_size,
+        ]
+
+    def _group_to_regex(self, group_name: str) -> str:
+        """
+        Converts a normalized group name containing wildcards or bucketed intervals
+        into a regular expression pattern suitable for ``QuantRecipe.add_regex()``.
+
+        Examples:
+        - ``layers.*.feed_forward.w2_conv`` -> ``r"layers\\..*\\.feed_forward\\.w2_conv"``
+        - ``layers.[7-13].feed_forward`` -> ``r"layers\\.(7|8|9|10|11|12|13)\\.feed_forward"``
+        """
+        pattern = re.escape(group_name)
+        # re.escape turns ".*." into "\.\*\." — restore the wildcard
+        pattern = pattern.replace(r"\.\*\.", r"\..*\.")
+
+        def expand_range(match):
+            start, end = int(match.group(1)), int(match.group(2))
+            return "(" + "|".join(str(i) for i in range(start, end + 1)) + ")"
+
+        return re.sub(r"\\\[(\d+)\\-(\d+)\\\]", expand_range, pattern)
+
+    def _group_stats(self) -> List[GroupSqnrStats]:
+        stats = [
+            GroupSqnrStats(
+                group_name=grp,
+                avg_sqnr=(sum(vals) / len(vals)).item(),
+                median_sqnr=statistics.median(vals).item(),
+                min_sqnr=min(vals).item(),
+                max_sqnr=max(vals).item(),
+                count=len(vals),
+            )
+            for grp, vals in self.results.items()
+        ]
+        stats.sort(key=lambda s: s.median_sqnr)
+        return stats
+
+    def save_analysis_summary(self, output_dir: Optional[str] = None) -> None:
+        """Writes per-group SQNR statistics to ``{model_name}_quantization_error.csv``."""
+        stats = self._group_stats()
+        output_path = (
+            f"{output_dir}/{self.model_name}_quantization_error.csv"
+            if output_dir
+            else f"{self.model_name}_quantization_error.csv"
+        )
+        with open(output_path, "w", newline="") as f:
+            writer = csv.writer(f)
+            writer.writerow(
+                [
+                    "group_name",
+                    "avg_sqnr",
+                    "median_sqnr",
+                    "min_sqnr",
+                    "max_sqnr",
+                    "count",
+                ]
+            )
+            for s in stats:
+                writer.writerow(
+                    [
+                        s.group_name,
+                        s.avg_sqnr,
+                        s.median_sqnr,
+                        s.min_sqnr,
+                        s.max_sqnr,
+                        s.count,
+                    ]
+                )
+        logging.info(f"SQNR analysis summary report saved to {output_path}")
+
+    def suggest_recipe_overrides(
+        self,
+        blk_sizes_candidate: Optional[List[int]] = None,
+        sqnr_threshold: float = 10.0,
+        default_precision: QuantDtype = QuantDtype.use_16a4w_block,
+        higher_precision: QuantDtype = QuantDtype.use_16a8w,
+    ) -> List[QuantRecipe]:
+        """
+        Suggests precision upgrades on top of the recipe used during analysis.
+
+        This function is intended to locate layer groups where the current quantization
+        precision (as captured in ``report.analysis_recipe``) produces insufficient SQNR,
+        and recommend upgrading only those groups to a higher precision. The suggested
+        recipes are not standalone replacements — they are refinements of the analysis
+        recipe, preserving its base structure while selectively elevating sensitive layers.
+
+        Decision logic:
+        - Groups with ``avg_sqnr``, ``median_sqnr``, or ``min_sqnr`` falling below
+          ``sqnr_threshold`` are flagged as sensitive.
+        - Sensitive groups are upgraded to ``higher_precision`` with PER_CHANNEL granularity.
+        - Non-sensitive conv2d layers use ``default_precision`` with PER_BLOCK granularity,
+          swept across multiple block size candidates.
+
+        Args:
+            blk_sizes_candidate: Block size candidates for non-sensitive layers. If ``None``
+                (default), candidates are derived automatically from the GCD
+                of all conv in_channels, keeping only divisors in [16, 64].
+            sqnr_threshold: The SQNR threshold (in dB) below which a group is considered sensitive.
+                Defaults to 10.0 dB.
+            default_precision: The base precision dtype for non-sensitive conv2d layers
+                (used with PER_BLOCK granularity).
+            higher_precision: The elevated precision dtype for sensitive layers
+                (used with PER_CHANNEL granularity).
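+
+        Example (illustrative; the analyzer, samples, and threshold below are
+        placeholders)::
+
+            report = analyzer.analyze(calibration_samples)
+            recipes = report.suggest_recipe_overrides(sqnr_threshold=15.0)
+            for candidate in recipes:  # one refined recipe per block size candidate
+                print(candidate.summary())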
+
+        Returns:
+            A list of ``QuantRecipe`` objects, one per block size candidate, each representing
+            a refined version of the analysis recipe with sensitive layers upgraded.
+            Returns an empty list if no sensitive layers are detected.
+        """
+        if blk_sizes_candidate is None:
+            blk_sizes_candidate = self._compute_blk_sizes_candidate()
+            logging.info(
+                f"[SqnrAnalyzer] Auto-derived blk_sizes_candidate: {blk_sizes_candidate}"
+            )
+
+        stats = self._group_stats()
+        sensitive = [
+            s
+            for s in stats
+            if s.avg_sqnr < sqnr_threshold
+            or s.median_sqnr < sqnr_threshold
+            or s.min_sqnr < sqnr_threshold
+        ]
+
+        if not sensitive:
+            logging.info(
+                "[SqnrAnalyzer] No sensitive layers detected. Keep the current configuration."
+            )
+            return []
+
+        # Build keys for what the sensitive-layer pass will add, so we can skip
+        # exact duplicates when copying analysis_recipe (sensitive replaces original
+        # only when all four attributes match).
+        sensitive_keys: set = {
+            (
+                frozenset({self._group_to_regex(s.group_name)}),
+                higher_precision,
+                QuantGranularity.PER_CHANNEL,
+                (),
+            )
+            for s in sensitive
+        }
+
+        recipes: List[QuantRecipe] = []
+        for blk_size in blk_sizes_candidate:
+            recipe = QuantRecipe(
+                QuantDtype.use_16a4w,
+                False,
+                act_observer=MinMaxObserver,
+                granularity=QuantGranularity.PER_TENSOR,
+                verbose=True,
+            ).add_node_target(
+                {torch.ops.aten.conv2d.default},
+                default_precision,
+                False,
+                act_observer=MinMaxObserver,
+                granularity=QuantGranularity.PER_BLOCK,
+                extra_kwargs={"block_size": (1, blk_size, 1, 1)},
+                note="We use LPBQ for base precision",
+            )
+
+            # Carry over ByNameRegex strategies from analysis_recipe, skipping only
+            # those that are identical to what the sensitive-layer pass will add
+            # (same patterns, quant_dtype, granularity, and extra_kwargs).
+            if self.analysis_recipe is not None:
+                for strategy in self.analysis_recipe._strategies:
+                    if isinstance(strategy, ByNameRegex):
+                        key = (
+                            frozenset(strategy.patterns),
+                            strategy.quant_dtype,
+                            strategy.granularity,
+                            tuple(sorted(strategy.extra_kwargs.items())),
+                        )
+                        if key in sensitive_keys:
+                            continue
+                        recipe.add_regex(
+                            strategy.patterns,
+                            strategy.quant_dtype,
+                            strategy.is_qat,
+                            act_observer=strategy.act_observer,
+                            granularity=strategy.granularity,
+                            act_symmetric=strategy.act_symmetric,
+                            extra_kwargs=strategy.extra_kwargs,
+                            note=strategy.note,
+                        )
+
+            # Add sensitive layers with upgraded precision, replacing any original entry.
+            for s in sensitive:
+                pattern = self._group_to_regex(s.group_name)
+                note = "[SqnrAnalyzer]:\n"
+                if s.avg_sqnr < sqnr_threshold:
+                    note += (
+                        f" - avg_sqnr={s.avg_sqnr:.2f} < threshold={sqnr_threshold}\n"
+                    )
+                if s.median_sqnr < sqnr_threshold:
+                    note += f" - median_sqnr={s.median_sqnr:.2f} < threshold={sqnr_threshold}\n"
+                if s.min_sqnr < sqnr_threshold:
+                    note += (
+                        f" - min_sqnr={s.min_sqnr:.2f} < threshold={sqnr_threshold}\n"
+                    )
+                recipe.add_regex(
+                    {pattern},
+                    higher_precision,
+                    False,
+                    act_observer=MinMaxObserver,
+                    granularity=QuantGranularity.PER_CHANNEL,
+                    note=note,
+                )
+            recipes.append(recipe)
+
+        return recipes
+
+
+def save_suggest_recipes(  # noqa: C901
+    report: "SqnrReport",
+    suggest_recipe: List[QuantRecipe],
+    output_dir: Optional[str] = None,
+) -> None:
+    """
+    Generates and saves a Python script containing quantization recipe classes
+    based on the suggested QuantRecipe objects from ``SqnrReport.suggest_recipe_overrides()``.
+
+    The baseline recipe for diff annotation is taken from ``report.analysis_recipe`` — the
+    ``QuantRecipe`` used to produce the QDQ model during analysis. If ``report.analysis_recipe``
+    is None (i.e. the QDQ model was not produced via a QuantRecipe), diff annotation is skipped
+    and all add_regex calls are emitted without [Original recipe] / [Added by SqnrAnalyzer] tags.
+
+    Args:
+        report: The ``SqnrReport`` returned by ``PerLayerSqnrAnalyzer.analyze()``.
+            Provides ``model_name`` and ``analysis_recipe`` (the analysis-time baseline).
+        suggest_recipe: List of QuantRecipe objects from
+            ``SqnrReport.suggest_recipe_overrides()``.
+    """
+    if not suggest_recipe:
+        logging.info(
+            "No sensitive layers detected. You may keep your current configuration."
+        )
+        return
+
+    model_name = report.model_name
+    analysis_recipe = report.analysis_recipe
+    class_name_prefix = model_name.upper().replace("-", "_")
+    output_path = (
+        f"{output_dir}/{model_name}_suggest_recipe.py"
+        if output_dir
+        else f"{model_name}_suggest_recipe.py"
+    )
+
+    file_header = textwrap.dedent(
+        """\
+        # Auto-generated by save_suggest_recipes()
+        #
+        # These recipes are REFINEMENTS of the recipe used during SQNR analysis.
+        # They preserve the base quantization structure and selectively upgrade
+        # layer groups where SQNR fell below the configured threshold.
+        #
+        # Each class below corresponds to a different LPBQ block size for the base layers.
+        # Review the [Original recipe] / [Added by SqnrAnalyzer] annotations to understand
+        # what changed relative to the analysis-time recipe, then pick the variant that
+        # gives the best accuracy / model-size trade-off on your target device.
+
+        import torch
+        from executorch.backends.qualcomm.quantizer.custom_annotation import annotate_kv_8bit
+        from executorch.backends.qualcomm.quantizer.quant_recipe import (
+            QuantGranularity,
+            QuantRecipe,
+        )
+        from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+        from torchao.quantization.pt2e import MinMaxObserver
+        from examples.qualcomm.oss_scripts.llama.static_llm_quant_recipe import StaticLLMQuantRecipe
+        """
+    )
+
+    # Collect existing strategies from the analysis recipe for diff annotation.
+    # A strategy is [Original recipe] only if patterns/targets AND quant_dtype,
+    # granularity, extra_kwargs all match an existing strategy in analysis_recipe.
+    # If analysis_recipe is None, diff annotation is skipped entirely.
+    original_regex_keys: set = set()
+    original_target_keys: set = set()
+    if analysis_recipe is not None:
+        for strategy in analysis_recipe._strategies:
+            if isinstance(strategy, ByNameRegex):
+                key = (
+                    frozenset(strategy.patterns),
+                    strategy.quant_dtype,
+                    strategy.granularity,
+                    tuple(sorted(strategy.extra_kwargs.items())),
+                )
+                original_regex_keys.add(key)
+            elif isinstance(strategy, ByNodeTarget):
+                key = (
+                    frozenset(strategy.targets),
+                    strategy.quant_dtype,
+                    strategy.granularity,
+                    tuple(sorted(strategy.extra_kwargs.items())),
+                )
+                original_target_keys.add(key)
+
+    generated_classes: List[str] = []
+    recipe_classes: List[str] = []
+
+    for recipe in suggest_recipe:
+        node_target_strategies = [
+            s for s in recipe._strategies if isinstance(s, ByNodeTarget)
+        ]
+        assert node_target_strategies and node_target_strategies[0].extra_kwargs.get(
+            "block_size"
+        ), "Expected at least one LPBQ node target strategy for PTQ recipes"
+        blk_size = node_target_strategies[0].extra_kwargs["block_size"][1]
+        class_name = f"{class_name_prefix}_BlockSize{blk_size}QuantRecipe"
+        generated_classes.append(class_name)
+
+        # Prepend [Original recipe] / [Added by SqnrAnalyzer] tag to each strategy's
+        # note before codegen, then restore. Only done when analysis_recipe is set.
+        saved_notes = {}
+        if analysis_recipe is not None:
+            for strategy in recipe._strategies:
+                if isinstance(strategy, ByNameRegex):
+                    key = (
+                        frozenset(strategy.patterns),
+                        strategy.quant_dtype,
+                        strategy.granularity,
+                        tuple(sorted(strategy.extra_kwargs.items())),
+                    )
+                    tag = (
+                        "[Original recipe]\n"
+                        if key in original_regex_keys
+                        else "[Added by SqnrAnalyzer]\n"
+                    )
+                elif isinstance(strategy, ByNodeTarget):
+                    key = (
+                        frozenset(strategy.targets),
+                        strategy.quant_dtype,
+                        strategy.granularity,
+                        tuple(sorted(strategy.extra_kwargs.items())),
+                    )
+                    tag = (
+                        "[Original recipe]\n"
+                        if key in original_target_keys
+                        else "[Added by SqnrAnalyzer]\n"
+                    )
+                else:
+                    continue
+                saved_notes[id(strategy)] = strategy.note
+                strategy.note = tag + strategy.note
+
+        recipe_body = recipe.to_source()
+
+        for strategy in recipe._strategies:
+            if id(strategy) in saved_notes:
+                strategy.note = saved_notes[id(strategy)]
+
+        indent = "\t"
+        init_body = (
+            "super().__init__()\n"
+            "\n"
+            "self.recipe = " + recipe_body.lstrip() + "\n"
+            "self.recipe.custom_quant_annotations.append(annotate_kv_8bit)"
+        )
+        class_body = (
+            f"default_quant_dtype = QuantDtype.{recipe._default_quant_dtype.name}\n"
+            "\n"
+            "def __init__(self, verbose: bool = False):\n"
+            + textwrap.indent(init_body, indent)
+        )
+        recipe_class = (
+            f"class {class_name}(StaticLLMQuantRecipe):\n"
+            + textwrap.indent(class_body, indent)
+            + "\n"
+        )
+        recipe_classes.append(recipe_class)
+
+    class_list = "\n".join(f"#    {cls}" for cls in generated_classes)
+    usage_comments = (
+        textwrap.dedent(
+            """\
+            #
+            # HOW TO USE THESE RECIPES
+            #
+            #
+            # The classes above were generated by the SQNR analyzer.
+            # Each variant uses a different LPBQ block size for the base layers
+            # while upgrading sensitive layers to higher precision.
+            #
+            # Suggested steps:
+            # 1. Pick one class to try:
+            """
+        )
+        + class_list
+        + textwrap.dedent(
+            f"""
+            #
+            # 2. In your export script, replace the original recipe import, e.g.:
+            #    # Before:
+            #    from examples.qualcomm.oss_scripts.llama.static_llm_quant_recipe import \\
+            #        {class_name_prefix}QuantRecipe
+            #    # After (using the first generated variant, once it has been
+            #    # copied into static_llm_quant_recipe.py):
+            #    from examples.qualcomm.oss_scripts.llama.static_llm_quant_recipe import \\
+            #        {generated_classes[0]} as {class_name_prefix}QuantRecipe
+            #
+            # 3. Run calibration + export and compare perplexity / accuracy.
+            # 4. If accuracy is still insufficient, try a smaller block size
+            #    or increase the SQNR threshold and re-run the analyzer.
+            """
+        )
+    )
+
+    lines: List[str] = (
+        file_header + "\n" + "\n".join(recipe_classes) + "\n" + usage_comments
+    ).splitlines()
+
+    with open(output_path, "w") as f:
+        f.write("\n".join(lines) + "\n")
+
+    logging.info(f"\n[SqnrAnalyzer] Recipe file written to: {output_path}")
+    logging.info("[SqnrAnalyzer] Generated classes:")
+    for cls in generated_classes:
+        logging.info(f"  - {cls}")
+    logging.info(
+        "[SqnrAnalyzer] Replace the original recipe class in your export script "
+        "with one of the above and re-run calibration to evaluate accuracy."
+    )
diff --git a/examples/qualcomm/oss_scripts/llama/quantization_guidance.md b/examples/qualcomm/oss_scripts/llama/quantization_guidance.md
new file mode 100644
index 00000000000..f42f5be38f0
--- /dev/null
+++ b/examples/qualcomm/oss_scripts/llama/quantization_guidance.md
@@ -0,0 +1,131 @@
+# LLMs Quantization Guidance
+
+## Mixed-Precision Quantization with SQNR Analysis
+
+When deploying LLMs at low precision (for example, `16a4w_block`), some layers can accumulate significantly higher quantization error and become accuracy bottlenecks.
+`mix_precision_analyzer.py` is an analysis tool that helps you identify these quantization-sensitive layers and selectively upgrade only them to higher precision, while keeping the rest of the model at the aggressive baseline (e.g. `16a4w_block`).
+The tool does not aim to find a globally optimal quantization recipe, but it narrows the search space so you can iterate from a directional starting point rather than guessing.
+
+### Overview
+
+`mix_precision_analyzer.py` provides two classes and one module-level function:
+
+- **`PerLayerSqnrAnalyzer`** — takes the FP32 `GraphModule` (before `prepare_pt2e`), the fake quant `GraphModule` (after `convert_pt2e`), and optionally the `QuantRecipe` used to produce the fake quant model. Runs both graphs on the same calibration inputs and computes per-conv2d layer SQNR by comparing intermediate outputs. Results are grouped by module path and bucketed across layer ranges.
+
+- **`SqnrReport`** — holds the grouped SQNR results and exposes two methods:
+  - `save_analysis_summary()` — writes a CSV with per-group statistics (columns: `group_name, avg_sqnr, median_sqnr, min_sqnr, max_sqnr, count`).
+  - `suggest_recipe_overrides(sqnr_threshold=10.0, default_precision=use_16a4w_block, higher_precision=use_16a8w)` — flags groups whose avg, median, or min SQNR falls below `sqnr_threshold` as sensitive, then builds a `QuantRecipe` that carries over all non-sensitive strategies from the analysis recipe and sets the sensitive groups to `higher_precision` per-channel. Returns a list of `QuantRecipe` objects (one per block size), or an empty list if no sensitive groups are found.
+
+- **`save_suggest_recipes(report, suggest_recipe, output_dir=None)`** — renders the override recipes into a ready-to-use `.py` file. Each strategy is annotated with `[Original recipe]` (carried over from the analysis recipe unchanged) or `[Added by SqnrAnalyzer]` (automatically added by the analyzer).
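+
+For orientation, the snippet below is a minimal sketch of driving the analyzer directly; `llama.py` wires this up automatically when `--quant_recipe_suggestion` is set. The model name, layer count, recipe, and calibration inputs here are placeholders:
+
+```python
+from executorch.examples.qualcomm.oss_scripts.llama.mix_precision_analyzer import (
+    PerLayerSqnrAnalyzer,
+    save_suggest_recipes,
+)
+
+# fp32_gm: exported GraphModule BEFORE prepare_pt2e
+# qdq_gm:  GraphModule AFTER convert_pt2e
+report = PerLayerSqnrAnalyzer(
+    model_name="my_model",      # used to name the output files
+    num_layers=28,              # total transformer layers in the model
+    fp32_gm=fp32_gm,
+    qdq_gm=qdq_gm,
+    analysis_recipe=recipe,     # optional: enables [Original recipe] diff tags
+).analyze(calibration_samples)  # the same inputs are fed to both graphs
+
+report.save_analysis_summary()                 # -> my_model_quantization_error.csv
+overrides = report.suggest_recipe_overrides()  # [] when nothing is sensitive
+save_suggest_recipes(report, overrides)        # -> my_model_suggest_recipe.py
+```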
+
+### Step-by-Step Workflow
+
+**1. Initial Quantization Configuration**
+
+When starting quantization from scratch, the recommended first step is to apply an aggressive baseline precision to the model's layers. You can configure this in `examples/qualcomm/oss_scripts/llama/static_llm_quant_recipe.py`.
+
+Set your target `conv2d` layers to use **LPBQ (`16a4w_block`) with a block size of 64**.
+
+For example, your base recipe might look like this:
+```python
+class Qwen3_1_7BQuantRecipe(StaticLLMQuantRecipe):
+    default_quant_dtype = QuantDtype.use_16a4w
+
+    def __init__(self, verbose: bool = False):
+        super().__init__()
+
+        self.recipe = (
+            QuantRecipe(
+                self.default_quant_dtype,
+                False,
+                act_observer=MinMaxObserver,
+                granularity=QuantGranularity.PER_TENSOR,
+                verbose=verbose,
+            ).add_node_target(
+                {
+                    torch.ops.aten.conv2d.default,
+                },
+                QuantDtype.use_16a4w_block,
+                is_qat=False,
+                act_observer=MinMaxObserver,
+                granularity=QuantGranularity.PER_BLOCK,
+                extra_kwargs={"block_size": (1, 64, 1, 1)},
+            )
+        )
+```
+
+**2. Run SQNR Evaluation**
+
+Once your baseline recipe is set, run the main script (`llama.py`) with the `--quant_recipe_suggestion` flag. The SQNR analyzer runs automatically during calibration and writes its reports to the working directory:
+
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py \
+    ... \
+    --quant_recipe_suggestion
+```
+
+Output files:
+
+- `{model_name}_quantization_error.csv` — per-group SQNR statistics sorted by sensitivity (most sensitive first)
+- `{model_name}_suggest_recipe.py` — ready-to-use `StaticLLMQuantRecipe` subclasses that apply higher-precision quantization to the most sensitive groups.
+
+
+**3. Analyze Sensitive Layers**
+
+The analyzer automatically flags layer groups where SQNR falls below `sqnr_threshold` (default: `10.0` dB). A lower SQNR means higher quantization error and greater sensitivity.
+
+The generated CSV is sorted by median SQNR ascending, placing the most problematic groups at the top. For example, for the Qwen3-1.7B model:
+
+- **`feed_forward.w2_conv`** (down-projection), **`feed_forward.w3_conv`**, and **`attention.wv_conv`** layers are consistently the most sensitive, with SQNR values below 10 dB.
+
+*Note: `sqnr_threshold` can be adjusted via `suggest_recipe_overrides(sqnr_threshold=...)`.*
+
+**4. Generated Recipe**
+
+The generated `{model_name}_suggest_recipe.py` contains one class per block size candidate, e.g.:
+
+```
+QWEN3_1_7B_BlockSize16QuantRecipe
+QWEN3_1_7B_BlockSize32QuantRecipe
+QWEN3_1_7B_BlockSize64QuantRecipe
+```
+
+Each class extends `StaticLLMQuantRecipe` and builds a `QuantRecipe` with two types of strategies, annotated inline:
+
+- `# [Original recipe]` — strategy was already present in the analysis-time recipe and is carried over unchanged.
+- `# [Added by SqnrAnalyzer]` — strategy is new or has different precision/granularity compared to the original; added because the layer group was flagged as sensitive.
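+
+For reference, a generated class takes roughly the following shape. This is an illustrative sketch rather than real analyzer output; the regex pattern, SQNR note, and block size all depend on your model and analysis run:
+
+```python
+class QWEN3_1_7B_BlockSize64QuantRecipe(StaticLLMQuantRecipe):
+    default_quant_dtype = QuantDtype.use_16a4w
+
+    def __init__(self, verbose: bool = False):
+        super().__init__()
+
+        self.recipe = (
+            QuantRecipe(
+                self.default_quant_dtype,
+                False,
+                act_observer=MinMaxObserver,
+                granularity=QuantGranularity.PER_TENSOR,
+                verbose=verbose,
+            )
+            # [Original recipe]
+            # We use LPBQ for base precision
+            .add_node_target(
+                {torch.ops.aten.conv2d.default},
+                QuantDtype.use_16a4w_block,
+                False,
+                act_observer=MinMaxObserver,
+                granularity=QuantGranularity.PER_BLOCK,
+                extra_kwargs={"block_size": (1, 64, 1, 1)},
+            )
+            # [Added by SqnrAnalyzer]:
+            #  - median_sqnr=8.42 < threshold=10.0
+            .add_regex(
+                {r"layers\.(0|1|2|3|4|5)\.feed_forward\.w2_conv"},
+                QuantDtype.use_16a8w,
+                False,
+                act_observer=MinMaxObserver,
+                granularity=QuantGranularity.PER_CHANNEL,
+            )
+        )
+        self.recipe.custom_quant_annotations.append(annotate_kv_8bit)
+```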
+
+**5. Apply the Suggested Recipe**
+
+The analyzer logs the generated classes and writes `{model_name}_suggest_recipe.py`:
+
+```
+[SqnrAnalyzer] Recipe file written to: {model_name}_suggest_recipe.py
+[SqnrAnalyzer] Generated classes:
+  - QWEN3_1_7B_BlockSize16QuantRecipe
+  - QWEN3_1_7B_BlockSize32QuantRecipe
+  - QWEN3_1_7B_BlockSize64QuantRecipe
+[SqnrAnalyzer] Replace the original recipe class in your export script with one of the above and re-run calibration to evaluate accuracy.
+```
+
+Copy the classes above into `static_llm_quant_recipe.py` or replace the original recipe import in `__init__.py`:
+
+```python
+# Before (in examples/qualcomm/oss_scripts/llama/__init__.py):
+from executorch.examples.qualcomm.oss_scripts.llama.static_llm_quant_recipe import Qwen3_1_7BQuantRecipe
+
+# After (example with block size 64):
+from executorch.examples.qualcomm.oss_scripts.llama.static_llm_quant_recipe import QWEN3_1_7B_BlockSize64QuantRecipe as Qwen3_1_7BQuantRecipe
+```
+
+**Iterative tuning tips:**
+
+- Start with `BlockSize64` as a balanced starting point.
+- If accuracy is still insufficient, try `BlockSize32`, then `BlockSize16`.
+- You can also enable `annotate_kv_8bit` in combination with the suggested recipes to balance accuracy and performance.
+- Consider enabling additional PTQ (Post-Training Quantization) techniques in `__init__.py`, such as `seq_mse` or `r3`, to further improve baseline accuracy.
+- Once satisfied, copy the final recipe into `static_llm_quant_recipe.py` as the permanent recipe for the model.
+
+> **Note:** The primary purpose of this SQNR analysis is to provide a guiding direction for mixed-precision quantization. While it identifies which layers are most sensitive and suggests a reasonable mixed-precision combination, it does not guarantee the best possible combination for every model. Not every model will benefit from this mixed-precision analysis, as the overall effectiveness of PTQ (Post-Training Quantization) can still be limited. The generated recipe classes are a starting point for exploration, not a final answer. You may need to experiment with different block sizes and threshold values, or manually craft overrides for specific layer groups, to find the optimal accuracy/performance trade-off for your target model. If you need to push accuracy to its limits, try QAT (Quantization-Aware Training) instead.
diff --git a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py
index 2d0713b175e..36ac506c91c 100644
--- a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py
+++ b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import copy
 import inspect
 import json
 import logging
@@ -63,6 +64,10 @@
 from executorch.examples.qualcomm.oss_scripts.llama.encoder.encoder_quant_recipe import (
     EncoderQuantRecipe,
 )
+from executorch.examples.qualcomm.oss_scripts.llama.mix_precision_analyzer import (
+    PerLayerSqnrAnalyzer,
+    save_suggest_recipes,
+)
 from executorch.examples.qualcomm.oss_scripts.llama.model.embedding import (
     TokenEmbedding,
 )
@@ -415,6 +420,36 @@ def _tag_ios(self, node, fixed_point_type):
 
         return quant_io_type
 
+    def _quant_recipe_suggestion(
+        self,
+        fp32_gm: torch.fx.GraphModule,
+        qdq_gm: torch.fx.GraphModule,
+        input_samples: list,
+        recipe: StaticLLMQuantRecipe,
+    ):
+        """
+        Compare fp32 vs QDQ intermediate outputs and write SQNR reports.
+
+        fp32_gm: FP32 exported GraphModule (before prepare_pt2e).
+        qdq_gm: QDQ GraphModule (after convert_pt2e).
+        input_samples: Input samples collected during calibration.
+
+        Output files:
+        ``{model_name}_quantization_error.csv``: per-group statistics
+        ``{model_name}_suggest_recipe.py``: Python script containing quantization recipe classes
+            based on the suggested quant recipe overrides.
+        """
+        model_name = self.control_args.decoder_model
+        report = PerLayerSqnrAnalyzer(
+            model_name=model_name,
+            num_layers=self.meta["get_n_layers"],
+            fp32_gm=fp32_gm,
+            qdq_gm=qdq_gm,
+            analysis_recipe=recipe,
+        ).analyze(input_samples)
+        report.save_analysis_summary()
+        suggest_recipe_overrides = report.suggest_recipe_overrides()
+        save_suggest_recipes(report, suggest_recipe_overrides)
+
     def _auto_tune_calibration_threads(self):
         """Find the optimal thread count for calibration via quick microbenchmark.
 
@@ -505,8 +540,9 @@ def _calibrate(
 
         # Task-based calibration: Only for text-only LLMs
         # Multimodal models (VLMs) cannot use task-based evaluation currently.
+        input_samples = []
         if has_task_calibration and not is_multimodal:
-            graph_module_inference(
+            input_sample = graph_module_inference(
                 use_kv_cache=self.meta["get_use_kv_cache"],
                 get_example_inputs=self.get_example_inputs,
                 module=model,
@@ -520,6 +556,7 @@ def _calibrate(
                 event_name=f"{event}_tasks",
                 seq_mse_candidates=self.config.seq_mse_candidates,
             )
+            input_samples.extend(input_sample)
 
         # prepare lookahead config if applicable
         lookahead_config = (
@@ -532,7 +569,7 @@ def _calibrate(
         # check user's prompt which helps calibrate special token
         for turn in zip(intermediate_outputs, user_calibration_data):
             hidden_states, prompt = turn
-            graph_module_inference(
+            input_sample = graph_module_inference(
                 use_kv_cache=self.meta["get_use_kv_cache"],
                 get_example_inputs=self.get_example_inputs,
                 hidden_states=hidden_states,  # hidden_states for multimodal
@@ -547,6 +584,8 @@ def _calibrate(
                 event_name=f"{event}_prompt",
                 lookahead_config=lookahead_config,
             )
+            input_samples.extend(input_sample)
+        return input_samples
 
     @log_info
     def quantize(self, request: Request):  # noqa: C901
@@ -617,6 +656,8 @@ def quantize(self, request: Request):  # noqa: C901
             self.decoder = torch.export.export(
                 self.decoder, self.export_input, strict=True
             ).module()
+            if self.control_args.quant_recipe_suggestion:
+                graph_module = copy.deepcopy(self.decoder)
 
             # Auto-tune thread count BEFORE prepare_pt2e so the benchmark
             # runs on the exported model without observers — no risk of
@@ -642,7 +683,7 @@ def quantize(self, request: Request):  # noqa: C901
                     original_threads,
                 )
                 try:
-                    self._calibrate(
+                    input_samples = self._calibrate(
                         model=self.decoder,
                         tokenizer=data.tokenizer,
                         event="prepare_pt2e",
@@ -659,6 +700,14 @@ def quantize(self, request: Request):  # noqa: C901
 
             self.decoder = convert_pt2e(self.decoder)
 
+            if self.control_args.quant_recipe_suggestion:
+                self._quant_recipe_suggestion(
+                    graph_module,
+                    self.decoder,
+                    input_samples,
+                    self.quant_recipe.recipe,
+                )
+
             # Saving Decode QDQ Model EP for SQNR evaluation
             if self.mode == Mode.DECODE:
                 qdq_ep = torch.export.export(