diff --git a/arctic_training/config/model.py b/arctic_training/config/model.py index d5c9b820..709b8125 100644 --- a/arctic_training/config/model.py +++ b/arctic_training/config/model.py @@ -84,14 +84,21 @@ def validate_peft_config_type(cls, value: Optional[Dict]) -> Optional[Dict]: @field_validator("attn_implementation", mode="after") def validate_attn_implementation(cls, value: str) -> str: - if value in ["flash_attention_2", "flash_attention_3"]: + if value == "flash_attention_2": try: import flash_attn # noqa: F401 except (ImportError, ModuleNotFoundError): raise ValueError( - f"{value} requires the flash_attn package. Install with" + "flash_attention_2 requires the flash_attn package. Install with" " `pip install flash_attn`. Please refer to documentation at" " https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2." + ) + elif value == "flash_attention_3": + try: + import flash_attn_3 # noqa: F401 + except (ImportError, ModuleNotFoundError): + raise ValueError( + "flash_attention_3 requires the flash_attn_3 package." " For FA3 build from the github source: git clone https://github.com/Dao-AILab/flash-attention;" " cd flash-attention/hopper; pip install . --no-build-isolation --no-clean" )