Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions backends/qualcomm/quantizer/quant_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


import re
import textwrap
from abc import ABC, abstractmethod
from enum import IntEnum, unique
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple
Expand Down Expand Up @@ -424,3 +425,76 @@ def summary(self, max_rows: int = -1):
rows.append(["..."] * len(headers))

return tabulate(rows, headers=headers, tablefmt="grid")

def to_source(self) -> str:
    """
    Serialize this QuantRecipe into a Python source string.

    The result is a single parenthesized, fluent-style expression at zero
    indentation: a ``QuantRecipe(...)`` constructor call followed by one
    chained ``.add_node_target(...)`` / ``.add_regex(...)`` call per
    registered strategy. Each strategy's ``note`` is emitted as leading
    ``#`` comment lines above its call.

    Returns:
        str: Python source for recreating this recipe. NOTE(review): the
        emitted code references free names (``QuantDtype``,
        ``QuantGranularity``, ``MinMaxObserver``, ``torch``, ``verbose``)
        that must already be in scope wherever the string is executed.
    """

    def _dtype(d: QuantDtype) -> str:
        # Render an enum member as a source reference, e.g. "QuantDtype.use_8a8w".
        return f"QuantDtype.{d.name}"

    def _granularity(g: QuantGranularity) -> str:
        # Render a granularity enum member as a source reference.
        return f"QuantGranularity.{g.name}"

    def _comments(note: str) -> str:
        # Turn a free-form (possibly multi-line) note into "# ..." comment
        # lines; an empty/whitespace-only note produces no output at all.
        lines = note.strip().splitlines() if note.strip() else []
        return "".join(f"# {ln}\n" for ln in lines)

    # Indentation unit used inside generated call argument lists.
    indent = "\t"

    def _args(*lines: str) -> str:
        # Format each argument on its own indented line with a trailing
        # comma, ready to be wrapped in "(\n...)".
        return "".join(f"{indent}{ln},\n" for ln in lines)

    strategy_blocks: List[str] = []
    for strategy in self._strategies:
        # Only emit extra_kwargs when non-empty, so the generated call
        # stays minimal for the common case.
        extra_kwargs_flag = (
            [f"extra_kwargs={strategy.extra_kwargs!r}"]
            if strategy.extra_kwargs
            else []
        )
        if isinstance(strategy, ByNodeTarget):
            # Rebuild each OpOverload as a "torch.ops.<ns>.<op>.<overload>"
            # source reference; sorted by string form for deterministic output.
            targets_repr = ", ".join(
                f"torch.ops.{t._overloadpacket._qualified_op_name.replace('::', '.')}.{t._overloadname}"
                for t in sorted(strategy.targets, key=lambda t: str(t))
            )
            args = _args(
                f"{{{targets_repr}}}",
                _dtype(strategy.quant_dtype),
                str(strategy.is_qat),
                "act_observer=MinMaxObserver",
                f"granularity={_granularity(strategy.granularity)}",
                *extra_kwargs_flag,
                f"act_symmetric={strategy.act_symmetric}",
                f"note={strategy.note!r}",
            )
            call = f".add_node_target(\n{args})"
        elif isinstance(strategy, ByNameRegex):
            # Emit patterns as raw-string literals, sorted for determinism.
            patterns_repr = ", ".join(f'r"{p}"' for p in sorted(strategy.patterns))
            args = _args(
                f"{{{patterns_repr}}}",
                _dtype(strategy.quant_dtype),
                str(strategy.is_qat),
                "act_observer=MinMaxObserver",
                f"granularity={_granularity(strategy.granularity)}",
                *extra_kwargs_flag,
                f"act_symmetric={strategy.act_symmetric}",
                f"note={strategy.note!r}",
            )
            call = f".add_regex(\n{args})"
        else:
            # Unknown strategy kinds are silently skipped rather than
            # failing serialization of the whole recipe.
            continue

        strategy_blocks.append(_comments(strategy.note) + call)

    # Constructor call mirroring this recipe's defaults. NOTE(review): the
    # emitted "verbose=verbose" expects a `verbose` variable at the use site.
    header_args = _args(
        "self.default_quant_dtype",
        str(self._default_is_qat),
        "act_observer=MinMaxObserver",
        f"granularity={_granularity(self._default_granularity)}",
        "verbose=verbose",
    )
    header = f"QuantRecipe(\n{header_args})"
    chained = "\n".join(strategy_blocks)
    body = header + "\n" + chained
    # Wrap in parentheses and indent one level so the chained calls form
    # a single valid Python expression.
    return "(\n" + textwrap.indent(body, indent) + "\n)"
60 changes: 60 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9039,6 +9039,66 @@ def test_intermediate_debugger(self):
f"CSV valid count: {csv_valid_count}. SVG valid count: {svg_valid_count}"
)

def test_analyzer_to_file_generation(self):
    """
    End-to-end test for PerLayerSqnrAnalyzer → SqnrReport → file generation.

    Exports a small model in FP32 and QDQ form, runs the per-layer SQNR
    analysis, then verifies the two generated artifacts: the CSV summary's
    row count and header schema, and (when sensitive layers were found)
    that the suggested-recipe .py file is syntactically valid Python.
    """
    # Imported lazily so the test module loads even when the llama example
    # scripts are unavailable.
    from executorch.examples.qualcomm.oss_scripts.llama.mix_precision_analyzer import (
        PerLayerSqnrAnalyzer,
        save_suggest_recipes,
    )

    module = SimpleModel()  # noqa: F405
    sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
    # FP32 reference graph vs. quantize-dequantize graph of the same model.
    fp32_gm = torch.export.export(module, sample_input, strict=True).module()
    qdq_gm = self.get_qdq_module(
        module, sample_input, quant_dtype=QuantDtype.use_8a4w
    )

    report = PerLayerSqnrAnalyzer(
        model_name="simple_conv",
        num_layers=4,
        fp32_gm=fp32_gm,
        qdq_gm=qdq_gm,
    ).analyze([sample_input], num_sharding=4)

    # Layers whose SQNR falls below the threshold become recipe overrides.
    overrides = report.suggest_recipe_overrides(sqnr_threshold=22.0)

    with tempfile.TemporaryDirectory() as tmp_dir:
        report.save_analysis_summary(output_dir=tmp_dir)
        save_suggest_recipes(report, overrides, output_dir=tmp_dir)

        # --- save_analysis_summary csv file ---
        with open(f"{tmp_dir}/simple_conv_quantization_error.csv") as f:
            csv_content = f.read()
        rows = list(csv.reader(csv_content.splitlines()))
        self.assertEqual(len(rows), 5)  # 1 header + 4 group rows
        self.assertEqual(
            rows[0],
            [
                "group_name",
                "avg_sqnr",
                "median_sqnr",
                "min_sqnr",
                "max_sqnr",
                "count",
            ],
        )
        print(f"Sensitivity analysis:\n{csv_content}")

        # --- save_suggest_recipes .py file (only written when sensitive layers exist) ---
        if overrides:
            with open(f"{tmp_dir}/simple_conv_suggest_recipe.py") as f:
                py_content = f.read()
            # generated file must be valid Python
            try:
                compile(py_content, "simple_conv_suggest_recipe.py", "exec")
            except SyntaxError as e:
                self.fail(
                    f"Generated recipe file has syntax error: {e}\n{py_content}"
                )
            # Sanity-check the generated file carries its usage banner.
            self.assertIn("HOW TO USE THESE RECIPES", py_content)


def setup_environment():
parser = setup_common_args_and_variables()
Expand Down
14 changes: 14 additions & 0 deletions examples/qualcomm/oss_scripts/llama/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,20 @@ Example:
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "I would like to learn python, could you teach me with a simple example?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen2_5-0_5b --eval_methods sqnr_eval
```

#### Quantization Guidance

To automatically identify sensitive layers and generate a mixed-precision recipe suggestion, add the `--quant_recipe_suggestion` flag. During calibration, the analyzer compares FP32 and QDQ intermediate outputs layer-by-layer using SQNR, then writes two files to the working directory:

- `{model_name}_quantization_error.csv` — per-group SQNR statistics sorted by sensitivity (most sensitive first)
- `{model_name}_suggest_recipe.py` — ready-to-use `StaticLLMQuantRecipe` subclasses that apply higher-precision quantization to the most sensitive groups.

Example:
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "I would like to learn python, could you teach me with a simple example?" --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen3-1_7b --tasks wikitext --limit 1 --quant_recipe_suggestion --compile_only
```

After the run, pick one of the generated classes from `qwen3-1_7b_suggest_recipe.py` as your new recipe. For a full walkthrough, see [quantization_guidance.md](quantization_guidance.md).

#### Use attention sink for multi-turn conversations
Attention sink is a way to evict the cache when the maximum context length is reached.
There are two main concepts behind attention sink:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def __init__(
self.max_seq_length = pte_max_context_len

def run(self, prompt):
golden_logits = INFERENCE_REGISTRY[True](
golden_logits, _ = INFERENCE_REGISTRY[True](
get_example_inputs=self.get_example_inputs,
prompt=prompt,
module=self.source_model,
Expand Down
Loading
Loading