diff --git a/src/accelerate/checkpointing.py b/src/accelerate/checkpointing.py index 2b753e6e206..d1c4a84a73d 100644 --- a/src/accelerate/checkpointing.py +++ b/src/accelerate/checkpointing.py @@ -303,11 +303,11 @@ def load_accelerator_state( torch.sdaa.set_rng_state_all(states["torch_sdaa_manual_seed"]) elif is_musa_available(): torch.musa.set_rng_state_all(states["torch_musa_manual_seed"]) - elif is_hpu_available(): + if is_hpu_available(): torch.hpu.set_rng_state_all(states["torch_hpu_manual_seed"]) - elif is_neuron_available(): + if is_neuron_available(): torch.neuron.set_rng_state_all(states["torch_neuron_manual_seed"]) - else: + if is_cuda_available(): torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"]) if is_torch_xla_available(): xm.set_rng_state(states["xm_seed"])