Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions arctic_training/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,21 @@ def validate_peft_config_type(cls, value: Optional[Dict]) -> Optional[Dict]:

@field_validator("attn_implementation", mode="after")
def validate_attn_implementation(cls, value: str) -> str:
if value in ["flash_attention_2", "flash_attention_3"]:
if value == "flash_attention_2":
try:
import flash_attn # noqa: F401
except (ImportError, ModuleNotFoundError):
raise ValueError(
f"{value} requires the flash_attn package. Install with"
"flash_attention_2 requires the flash_attn package. Install with"
" `pip install flash_attn`. Please refer to documentation at"
" https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2."
)
elif value == "flash_attention_3":
try:
import flash_attn_3 # noqa: F401
except (ImportError, ModuleNotFoundError):
raise ValueError(
"flash_attention_3 requires the flash_attn_3 package."
" For FA3 build from the github source: git clone https://github.com/Dao-AILab/flash-attention;"
" cd flash-attention/hopper; pip install . --no-build-isolation --no-clean"
)
Expand Down
Loading