From 162621b8515f0ec77a3c8f7144153bfe3c4f54cf Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 23 Apr 2026 12:38:52 -0700 Subject: [PATCH 1/3] Refactor validation for attention implementations Refactor attention implementation validation to handle specific cases for 'flash_attention_2' and 'flash_attention_3'. Updated error messages for clarity. Currently even if FA3 is installed, but not FA2, it'll fail to load. --- arctic_training/config/model.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/arctic_training/config/model.py b/arctic_training/config/model.py index d5c9b820..70f12d1e 100644 --- a/arctic_training/config/model.py +++ b/arctic_training/config/model.py @@ -84,15 +84,22 @@ def validate_peft_config_type(cls, value: Optional[Dict]) -> Optional[Dict]: @field_validator("attn_implementation", mode="after") def validate_attn_implementation(cls, value: str) -> str: - if value in ["flash_attention_2", "flash_attention_3"]: + if value == "flash_attention_2": try: import flash_attn # noqa: F401 except (ImportError, ModuleNotFoundError): raise ValueError( - f"{value} requires the flash_attn package. Install with" + "flash_attention_2 requires the flash_attn package. Install with" " `pip install flash_attn`. Please refer to documentation at" " https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2." + ) + elif value == "flash_attention_3": + try: + import flash_attn_3 # noqa: F401 + except (ImportError, ModuleNotFoundError): + raise ValueError( + f"flash_attention_3 requires the flash_attn_3 package." " For FA3 build from the github source: git clone https://github.com/Dao-AILab/flash-attention;" " cd flash-attention/hopper; pip install . --no-build-isolation --no-clean" - ) + ) return value From 6fc6aa16eb673debfbecaf9260cf1e503b7a5254 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 23 Apr 2026 12:39:47 -0700 Subject: [PATCH 2/3] Update model.py --- arctic_training/config/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arctic_training/config/model.py b/arctic_training/config/model.py index 70f12d1e..5c3a3585 100644 --- a/arctic_training/config/model.py +++ b/arctic_training/config/model.py @@ -101,5 +101,5 @@ def validate_attn_implementation(cls, value: str) -> str: f"flash_attention_3 requires the flash_attn_3 package." " For FA3 build from the github source: git clone https://github.com/Dao-AILab/flash-attention;" " cd flash-attention/hopper; pip install . --no-build-isolation --no-clean" - ) + ) return value From 152e93463158383bc15712352bd20bf13b3ea9d6 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Thu, 23 Apr 2026 19:58:15 +0000 Subject: [PATCH 3/3] style Signed-off-by: Stas Bekman --- arctic_training/config/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arctic_training/config/model.py b/arctic_training/config/model.py index 5c3a3585..709b8125 100644 --- a/arctic_training/config/model.py +++ b/arctic_training/config/model.py @@ -98,7 +98,7 @@ def validate_attn_implementation(cls, value: str) -> str: import flash_attn_3 # noqa: F401 except (ImportError, ModuleNotFoundError): raise ValueError( - f"flash_attention_3 requires the flash_attn_3 package." + "flash_attention_3 requires the flash_attn_3 package." " For FA3 build from the github source: git clone https://github.com/Dao-AILab/flash-attention;" " cd flash-attention/hopper; pip install . --no-build-isolation --no-clean" )