Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@ use_repo(remote_execution_configure, "local_config_remote_execution")
# Python dependencies
##############################################################

single_version_override(
module_name = "rules_python",
patch_strip = 1,
patches = [
# Upstreamed at https://github.com/bazel-contrib/rules_python/pull/3768
"//third_party/py:rules_python_local_wheel.patch",
],
version = "1.8.5",
)

single_version_override(
module_name = "protobuf",
patch_strip = 1,
Expand Down Expand Up @@ -139,6 +149,9 @@ pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip")
"numpy": ["numpy_headers"],
},
hub_name = "jax_pypi",
local_wheels = {
"libtpu": "dist/libtpu-*.whl", # On CI, we use the downloaded nightly version of libtpu.
},
python_version = python_version,
requirements_by_platform = {} if python_version in [
"3.11",
Expand Down
13 changes: 13 additions & 0 deletions third_party/py/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2026 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
167 changes: 167 additions & 0 deletions third_party/py/rules_python_local_wheel.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
diff --git a/python/private/pypi/extension.bzl b/python/private/pypi/extension.bzl
index 3927f61c..707fef59 100644
--- a/python/private/pypi/extension.bzl
+++ b/python/private/pypi/extension.bzl
@@ -262,6 +262,7 @@ You cannot use both the additive_build_content and additive_build_content_file a
builder.pip_parse(
module_ctx,
pip_attr = pip_attr,
+ is_root = mod.is_root,
)

# Keeps track of all the hub's whl repos across the different versions.
@@ -631,6 +632,9 @@ hubs can be created, and each program can use its respective hub's targets.
Targets from different hubs should not be used together.
""",
),
+ "local_wheels": attr.string_dict(
+ doc = "Dictionary mapping package names to local wheel file paths relative to the workspace root.",
+ ),
"parallel_download": attr.bool(
doc = """\
The flag allows to make use of parallel downloading feature in bazel 7.1 and above
diff --git a/python/private/pypi/hub_builder.bzl b/python/private/pypi/hub_builder.bzl
index 700f22e2..8e6e6840 100644
--- a/python/private/pypi/hub_builder.bzl
+++ b/python/private/pypi/hub_builder.bzl
@@ -102,7 +102,7 @@ def _build(self):
whl_libraries = self._whl_libraries,
)

-def _pip_parse(self, module_ctx, pip_attr):
+def _pip_parse(self, module_ctx, pip_attr, is_root = False):
python_version = pip_attr.python_version
if python_version in self._platforms:
fail((
@@ -152,6 +152,7 @@ def _pip_parse(self, module_ctx, pip_attr):
self,
module_ctx,
pip_attr = pip_attr,
+ is_root = is_root,
enable_pipstar = self._config.enable_pipstar or self._get_index_urls.get(pip_attr.python_version),
enable_pipstar_extract = self._config.enable_pipstar_extract or self._get_index_urls.get(pip_attr.python_version),
)
@@ -408,6 +409,7 @@ def _create_whl_repos(
module_ctx,
*,
pip_attr,
+ is_root = False,
enable_pipstar = False,
enable_pipstar_extract = False):
"""create all of the whl repositories
@@ -416,6 +418,7 @@ def _create_whl_repos(
self: the builder.
module_ctx: {type}`module_ctx`.
pip_attr: {type}`struct` - the struct that comes from the tag class iteration.
+ is_root: {type}`bool` - whether the calling module is the root workspace.
enable_pipstar: {type}`bool` - enable the pipstar or not.
enable_pipstar_extract: {type}`bool` - enable the pipstar extraction or not.
"""
@@ -464,7 +467,10 @@ def _create_whl_repos(

interpreter = _detect_interpreter(self, pip_attr)

+ local_wheels = _collect_local_wheels(module_ctx, pip_attr, is_root = is_root)
+
for whl in requirements_by_platform:
+ local_wheel = local_wheels.get(whl.name)
whl_library_args = common_args | _whl_library_args(
self,
whl = whl,
@@ -481,6 +487,7 @@ def _create_whl_repos(
python_version = _major_minor_version(pip_attr.python_version),
is_multiple_versions = whl.is_multiple_versions,
interpreter = interpreter,
+ local_wheel = local_wheel,
enable_pipstar = enable_pipstar,
enable_pipstar_extract = enable_pipstar_extract,
)
@@ -559,6 +566,7 @@ def _whl_repo(
python_version,
use_downloader,
interpreter,
+ local_wheel = None,
enable_pipstar = False,
enable_pipstar_extract = False):
args = dict(whl_library_args)
@@ -622,9 +630,17 @@ def _whl_repo(
# targets to each hub for each extra combination and solve this more cleanly as opposed to
# duplicating whl_library repositories.
target_platforms = src.target_platforms if is_multiple_versions else []
+ repo_name = whl_repo_name(src.filename, src.sha256, *target_platforms)
+
+ if local_wheel:
+ repo_name += "_local_override"
+ path_str = local_wheel._path if hasattr(local_wheel, "_path") else str(local_wheel)
+ args["urls"] = ["file://" + path_str]
+ args["filename"] = local_wheel.basename
+ args["sha256"] = ""

return struct(
- repo_name = whl_repo_name(src.filename, src.sha256, *target_platforms),
+ repo_name = repo_name,
args = args,
config_setting = whl_config_setting(
version = python_version,
@@ -637,3 +653,61 @@ def _use_downloader(self, python_version, whl_name):
normalize_name(whl_name),
self._get_index_urls.get(python_version) != None,
)
+
+def _collect_local_wheels(module_ctx, pip_attr, is_root = False):
+ if not is_root:
+ return {}
+
+ wheels = {}
+ explicit_wheels = getattr(pip_attr, "local_wheels", None)
+ if not explicit_wheels:
+ return wheels
+
+ workspace_root = module_ctx.path(Label("@@//:MODULE.bazel")).dirname
+
+ for pkg_name, wheel_path_str in explicit_wheels.items():
+ norm_name = normalize_name(pkg_name)
+ if "*" not in wheel_path_str:
+ wheel_path = workspace_root.get_child(wheel_path_str)
+ if wheel_path.exists:
+ wheels[norm_name] = wheel_path
+ else:
+ last_slash = wheel_path_str.rfind("/")
+ if last_slash >= 0:
+ dir_part = wheel_path_str[:last_slash]
+ pattern = wheel_path_str[last_slash + 1:]
+ else:
+ dir_part = ""
+ pattern = wheel_path_str
+
+ matched_wheel = None
+ target_dir = workspace_root.get_child(dir_part) if dir_part else workspace_root
+ if target_dir.exists:
+ candidates = target_dir.readdir()
+ else:
+ candidates = []
+
+ for candidate in candidates:
+ if not candidate.basename.endswith(".whl"):
+ continue
+ if _wildcard_match(candidate.basename, pattern):
+ if not matched_wheel or matched_wheel.basename < candidate.basename:
+ matched_wheel = candidate
+
+ if matched_wheel:
+ wheels[norm_name] = matched_wheel
+
+ return wheels
+
+def _wildcard_match(name, pattern):
+ if pattern.startswith("*") and pattern.endswith("*"):
+ return name.find(pattern[1:-1]) >= 0
+ elif pattern.startswith("*"):
+ return name.endswith(pattern[1:])
+ elif pattern.endswith("*"):
+ return name.startswith(pattern[:-1])
+ elif "*" in pattern:
+ parts = pattern.split("*", 1)
+ return name.startswith(parts[0]) and name.endswith(parts[1]) and len(name) >= len(parts[0]) + len(parts[1])
+ else:
+ return name == pattern
Loading