diff --git a/MODULE.bazel b/MODULE.bazel index c0e851220d3f..e059073ce861 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -90,6 +90,16 @@ use_repo(remote_execution_configure, "local_config_remote_execution") # Python dependencies ############################################################## +single_version_override( + module_name = "rules_python", + patch_strip = 1, + patches = [ + # Upstreamed at https://github.com/bazel-contrib/rules_python/pull/3768 + "//third_party/py:rules_python_local_wheel.patch", + ], + version = "1.8.5", +) + single_version_override( module_name = "protobuf", patch_strip = 1, @@ -139,6 +149,9 @@ pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip") "numpy": ["numpy_headers"], }, hub_name = "jax_pypi", + local_wheels = { + "libtpu": "dist/libtpu-*.whl", # On CI, we use the downloaded nightly version of libtpu. + }, python_version = python_version, requirements_by_platform = {} if python_version in [ "3.11", diff --git a/third_party/py/BUILD.bazel b/third_party/py/BUILD.bazel new file mode 100644 index 000000000000..554f568218e5 --- /dev/null +++ b/third_party/py/BUILD.bazel @@ -0,0 +1,13 @@ +# Copyright 2026 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/third_party/py/rules_python_local_wheel.patch b/third_party/py/rules_python_local_wheel.patch new file mode 100644 index 000000000000..dd0ffb5d083b --- /dev/null +++ b/third_party/py/rules_python_local_wheel.patch @@ -0,0 +1,167 @@ +diff --git a/python/private/pypi/extension.bzl b/python/private/pypi/extension.bzl +index 3927f61c..707fef59 100644 +--- a/python/private/pypi/extension.bzl ++++ b/python/private/pypi/extension.bzl +@@ -262,6 +262,7 @@ You cannot use both the additive_build_content and additive_build_content_file a + builder.pip_parse( + module_ctx, + pip_attr = pip_attr, ++ is_root = mod.is_root, + ) + + # Keeps track of all the hub's whl repos across the different versions. +@@ -631,6 +632,9 @@ hubs can be created, and each program can use its respective hub's targets. + Targets from different hubs should not be used together. + """, + ), ++ "local_wheels": attr.string_dict( ++ doc = "Dictionary mapping package names to local wheel file paths relative to the workspace root.", ++ ), + "parallel_download": attr.bool( + doc = """\ + The flag allows to make use of parallel downloading feature in bazel 7.1 and above +diff --git a/python/private/pypi/hub_builder.bzl b/python/private/pypi/hub_builder.bzl +index 700f22e2..8e6e6840 100644 +--- a/python/private/pypi/hub_builder.bzl ++++ b/python/private/pypi/hub_builder.bzl +@@ -102,7 +102,7 @@ def _build(self): + whl_libraries = self._whl_libraries, + ) + +-def _pip_parse(self, module_ctx, pip_attr): ++def _pip_parse(self, module_ctx, pip_attr, is_root = False): + python_version = pip_attr.python_version + if python_version in self._platforms: + fail(( +@@ -152,6 +152,7 @@ def _pip_parse(self, module_ctx, pip_attr): + self, + module_ctx, + pip_attr = pip_attr, ++ is_root = is_root, + enable_pipstar = self._config.enable_pipstar or self._get_index_urls.get(pip_attr.python_version), + enable_pipstar_extract = self._config.enable_pipstar_extract or self._get_index_urls.get(pip_attr.python_version), + ) +@@ -408,6 +409,7 @@ def _create_whl_repos( + module_ctx, + *, + pip_attr, ++ is_root = False, + enable_pipstar = False, + enable_pipstar_extract = False): + """create all of the whl repositories +@@ -416,6 +418,7 @@ def _create_whl_repos( + self: the builder. + module_ctx: {type}`module_ctx`. + pip_attr: {type}`struct` - the struct that comes from the tag class iteration. ++ is_root: {type}`bool` - whether the calling module is the root workspace. + enable_pipstar: {type}`bool` - enable the pipstar or not. + enable_pipstar_extract: {type}`bool` - enable the pipstar extraction or not. + """ +@@ -464,7 +467,10 @@ def _create_whl_repos( + + interpreter = _detect_interpreter(self, pip_attr) + ++ local_wheels = _collect_local_wheels(module_ctx, pip_attr, is_root = is_root) ++ + for whl in requirements_by_platform: ++ local_wheel = local_wheels.get(whl.name) + whl_library_args = common_args | _whl_library_args( + self, + whl = whl, +@@ -481,6 +487,7 @@ def _create_whl_repos( + python_version = _major_minor_version(pip_attr.python_version), + is_multiple_versions = whl.is_multiple_versions, + interpreter = interpreter, ++ local_wheel = local_wheel, + enable_pipstar = enable_pipstar, + enable_pipstar_extract = enable_pipstar_extract, + ) +@@ -559,6 +566,7 @@ def _whl_repo( + python_version, + use_downloader, + interpreter, ++ local_wheel = None, + enable_pipstar = False, + enable_pipstar_extract = False): + args = dict(whl_library_args) +@@ -622,9 +630,17 @@ def _whl_repo( + # targets to each hub for each extra combination and solve this more cleanly as opposed to + # duplicating whl_library repositories. + target_platforms = src.target_platforms if is_multiple_versions else [] ++ repo_name = whl_repo_name(src.filename, src.sha256, *target_platforms) ++ ++ if local_wheel: ++ repo_name += "_local_override" ++ path_str = local_wheel._path if hasattr(local_wheel, "_path") else str(local_wheel) ++ args["urls"] = ["file://" + path_str] ++ args["filename"] = local_wheel.basename ++ args["sha256"] = "" + + return struct( +- repo_name = whl_repo_name(src.filename, src.sha256, *target_platforms), ++ repo_name = repo_name, + args = args, + config_setting = whl_config_setting( + version = python_version, +@@ -637,3 +653,61 @@ def _use_downloader(self, python_version, whl_name): + normalize_name(whl_name), + self._get_index_urls.get(python_version) != None, + ) ++ ++def _collect_local_wheels(module_ctx, pip_attr, is_root = False): ++ if not is_root: ++ return {} ++ ++ wheels = {} ++ explicit_wheels = getattr(pip_attr, "local_wheels", None) ++ if not explicit_wheels: ++ return wheels ++ ++ workspace_root = module_ctx.path(Label("@@//:MODULE.bazel")).dirname ++ ++ for pkg_name, wheel_path_str in explicit_wheels.items(): ++ norm_name = normalize_name(pkg_name) ++ if "*" not in wheel_path_str: ++ wheel_path = workspace_root.get_child(wheel_path_str) ++ if wheel_path.exists: ++ wheels[norm_name] = wheel_path ++ else: ++ last_slash = wheel_path_str.rfind("/") ++ if last_slash >= 0: ++ dir_part = wheel_path_str[:last_slash] ++ pattern = wheel_path_str[last_slash + 1:] ++ else: ++ dir_part = "" ++ pattern = wheel_path_str ++ ++ matched_wheel = None ++ target_dir = workspace_root.get_child(dir_part) if dir_part else workspace_root ++ if target_dir.exists: ++ candidates = target_dir.readdir() ++ else: ++ candidates = [] ++ ++ for candidate in candidates: ++ if not candidate.basename.endswith(".whl"): ++ continue ++ if _wildcard_match(candidate.basename, pattern): ++ if not matched_wheel or matched_wheel.basename < candidate.basename: ++ matched_wheel = candidate ++ ++ if matched_wheel: ++ wheels[norm_name] = matched_wheel ++ ++ return wheels ++ ++def _wildcard_match(name, pattern): ++ if pattern.startswith("*") and pattern.endswith("*"): ++ return name.find(pattern[1:-1]) >= 0 ++ elif pattern.startswith("*"): ++ return name.endswith(pattern[1:]) ++ elif pattern.endswith("*"): ++ return name.startswith(pattern[:-1]) ++ elif "*" in pattern: ++ parts = pattern.split("*", 1) ++ return name.startswith(parts[0]) and name.endswith(parts[1]) and len(name) >= len(parts[0]) + len(parts[1]) ++ else: ++ return name == pattern