Add CoreML-stable RMSNorm for llama eager paths #19523
Conversation
Summary:

The standard `RMSNorm` formulation `x * rsqrt(mean(x²)) * weight` is numerically unstable on CoreML/ANE because the explicit FP32 cast around the mean reduction is silently stripped from the lowered graph, leaving the squared sum to overflow in FP16. The ANE PTE then diverges from the eager reference even on checkpoints fine-tuned in BF16/FP16.

This diff introduces `RMSNormCoreML` in `examples/models/llama/norm.py`. The module expresses the normalization as `x * sqrt(d) / vector_norm(x, dim=-1)` — `torch.linalg.vector_norm` keeps the reduction in a single op that survives CoreML lowering, so FP16 inference remains stable.

To avoid `0 / 0 = NaN` on zero-padded positions (chunked prefill in `StaticAttentionIOManager` pads each chunk to `input_len` with zeros), the denominator is floored with `sqrt(dim * eps)`. This matches standard RMSNorm's `rsqrt(mean(x²) + eps)` semantics on a zero input and is large enough to survive FP16 — a plain `1e-6` underflows. Real (non-zero) tokens satisfy `vector_norm(x) >> sqrt(dim * eps)`, so the floor is a no-op on real positions.

A new `use_coreml_norm: bool = False` field on `ModelArgs` opts into the new norm without disturbing existing models. When True, every llama-side norm site constructs `RMSNormCoreML`:

- `llama_transformer.py`: `attention_norm`, `ffn_norm`, and the final `self.norm` on `Transformer`.
- `attention.py`: `q_norm_fn` / `k_norm_fn` in the affine QK-norm path, AND the `else` branch of `_init_qk_norms` (the scaleless / non-affine QK-norm path that the original landing missed).
- `static_attention.py`: `q_norm` / `k_norm` in the scaleless path, propagated through `from_attention_mha` by detecting `rms_norm_class is RMSNormCoreML`.

The QNN/HTP export path is untouched and continues to use `torch.nn.RMSNorm`.

Differential Revision: D104862210
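For reference, a minimal sketch of the formulation described in the summary, assuming a conventional constructor signature; the real module lives in `examples/models/llama/norm.py` and may differ in its signature, defaults, and how `eps` is handled. The demo at the bottom is purely illustrative.

```python
import math

import torch
from torch import nn


class RMSNormCoreML(nn.Module):
    """RMSNorm variant meant to stay numerically stable after CoreML/ANE lowering.

    Sketch reconstructed from the PR summary; the actual implementation in
    examples/models/llama/norm.py may differ in details.
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        # Floor for the denominator: matches rsqrt(mean(x^2) + eps) on an
        # all-zero row, and sqrt(dim * eps) stays representable in FP16
        # where a bare 1e-6 would not survive.
        self._min_norm = math.sqrt(dim * eps)

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # Single-op reduction that survives CoreML lowering. Equivalent to
        # x * rsqrt(mean(x^2)) because mean(x^2) == ||x||^2 / dim.
        denom = torch.linalg.vector_norm(x, dim=-1, keepdim=True)
        denom = torch.clamp(denom, min=self._min_norm)
        return x * math.sqrt(self.dim) / denom

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._norm(x) * self.weight


if __name__ == "__main__":
    # Zero-padded rows (as produced by chunked prefill) normalize to zeros
    # instead of NaN thanks to the floored denominator.
    norm = RMSNormCoreML(dim=64)
    padded = torch.zeros(1, 4, 64)
    assert not torch.isnan(norm(padded)).any()
```

On the wiring side, the `use_coreml_norm` flag on `ModelArgs` would presumably just select this class instead of the default `RMSNorm` at each of the norm sites listed above.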
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19523

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV
There is 1 currently active SEV. If your PR is affected, please view it below.

❌ 1 Awaiting Approval, 4 New Failures, 2 Unrelated Failures
As of commit cda18f8 with merge base 99f1f0b:

NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.
@telgamal-1 has exported this pull request. If you are a Meta employee, you can view the originating Diff in D104862210.
This PR needs a
eps (float, optional): Stored for API compatibility; ignored in the math.

Attributes:
    eps (float): Stored for API compatibility; not consumed by `_norm`.
Can we assert eps is 0 rather than silently drop it?