diff --git a/src/accelerate/utils/torch_xla.py b/src/accelerate/utils/torch_xla.py index 140133926c2..fdcba097847 100644 --- a/src/accelerate/utils/torch_xla.py +++ b/src/accelerate/utils/torch_xla.py @@ -44,8 +44,7 @@ def install_xla(upgrade: bool = False): # get the current version of torch torch_version = importlib.metadata.version("torch") torch_version_trunc = torch_version[: torch_version.rindex(".")] - xla_wheel = f"https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-{torch_version_trunc}-cp37-cp37m-linux_x86_64.whl" - xla_install_cmd = ["pip", "install", xla_wheel] + xla_install_cmd = ["pip", "install", f"torch_xla=={torch_version_trunc}"] subprocess.run(xla_install_cmd, check=True) else: raise RuntimeError("`install_xla` utility works only on google colab.")