Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test-gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
runs-on: "cirun-aws-gpu--${{ github.run_id }}"
strategy:
matrix:
extras: ["torch,cupy-cuda12", "torch", "cupy-cuda12"]
extras: ["torch,cupy-cuda12", "torch", "cupy-cuda12", "jax-cuda12"]
# Setting a timeout of 30 minutes, as the AWS costs money
# At time of writing, a typical run takes about 5 minutes
timeout-minutes: 30
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ with ad.settings.override(remove_unused_categories=False):
batch_size=4096,
chunk_size=32,
preload_nchunks=256,
to_torch=True
to="torch"
)
# `use_collection` automatically uses the on-disk `X` and full `obs` in the `Loader`
# but the `load_adata` arg can override this behavior
Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@
" preload_nchunks=32, # Number of chunks to preload + shuffle - default settings should work well\n",
" # If True, preloaded chunks are moved to GPU memory via `cupy`, which can put more pressure on GPU memory but will accelerate loading ~20%\n",
" preload_to_gpu=False,\n",
" to_torch=True,\n",
" to=\"torch\",\n",
")\n",
"\n",
"# Add in the shuffled data that should be used for training.\n",
Expand Down
10 changes: 9 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ optional-dependencies.doc = [
"sphinxcontrib-bibtex>=1",
"sphinxext-opengraph",
]
optional-dependencies.jax = [ "jax" ]
optional-dependencies.jax-cuda12 = [
"jax[cuda12]",
]
optional-dependencies.jax-cuda13 = [
"jax[cuda13]",
]
optional-dependencies.test = [
"annbatch[zarrs]",
"coverage",
Expand All @@ -88,7 +95,7 @@ envs.docs.scripts.clean = "git clean -fdX -- {args:docs}"
envs.hatch-test.features = [ "test" ]
envs.hatch-test.python = "3.14"
envs.hatch-test.matrix = [
{ deps = [ "min-low", "pre", "torch", "min-high" ] },
{ deps = [ "min-low", "pre", "torch", "min-high", "jax" ] },
]
# If the matrix variable `deps` is set to "pre",
# set the environment variable `UV_PRERELEASE` to "allow".
Expand All @@ -106,6 +113,7 @@ envs.hatch-test.overrides.matrix.deps.python = [
]
envs.hatch-test.overrides.matrix.deps.features = [
{ if = [ "torch" ], value = "torch" },
{ if = [ "jax" ], value = "jax" },
]

[tool.ruff]
Expand Down
7 changes: 7 additions & 0 deletions src/annbatch/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,10 @@
from cupyx.scipy.sparse import csr_matrix as CupyCSRMatrix
else:
CupyCSRMatrix = type("csr_matrix", (), {"__module__": "cupyx.scipy.sparse"})

if TYPE_CHECKING or find_spec("jax"):
from jax import Array as JaxArray
from jax.experimental.sparse import CSR as JAXCsrMatrix
else:
JAXCsrMatrix = type("CSR", (), {"__module__": "jax.experimental.sparse"})
JaxArray = type("Array", (), {"__module__": "jax"})
Loading
Loading