Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/build_linux_jax_wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ jobs:
run: |
ls -lah
pushd rocm-jax
[[ "${{ inputs.jax_ref }}" != *"0.8.0"* && "${{ inputs.build_jaxlib }}" == "true" ]] && SOURCE_ARG="--jax-source-dir=${{ github.workspace }}/jax" || SOURCE_ARG=""
# --jax-source-dir is only needed when building jaxlib from source
# (JAX <= 0.9.0). For JAX >= 0.9.1, jaxlib comes from upstream PyPI.
[[ "${{ inputs.build_jaxlib }}" == "true" ]] && SOURCE_ARG="--jax-source-dir=${{ github.workspace }}/jax" || SOURCE_ARG=""
python3 build/ci_build \
--compiler=clang \
--python-versions="${{ inputs.python_version }}" \
Expand Down
4 changes: 1 addition & 3 deletions .github/workflows/release_portable_linux_jax_wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,8 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11", "3.12", "3.13", "3.14"]
jax_ref: ["rocm-jaxlib-v0.8.0", "rocm-jaxlib-v0.8.2", "rocm-jaxlib-v0.9.0", "rocm-jaxlib-v0.9.1"]
jax_ref: ["rocm-jaxlib-v0.8.2", "rocm-jaxlib-v0.9.0", "rocm-jaxlib-v0.9.1"]
include:
- jax_ref: "rocm-jaxlib-v0.8.0"
build_jaxlib: true
- jax_ref: "rocm-jaxlib-v0.8.2"
build_jaxlib: true
- jax_ref: "rocm-jaxlib-v0.9.0"
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test_jax_dockerfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ on:
required: true
description: JAX plugin branch to checkout
type: string
default: "rocm-jaxlib-v0.6.0"
default: "rocm-jaxlib-v0.8.2"

workflow_call:
inputs:
Expand All @@ -40,7 +40,7 @@ on:
jax_plugin_branch:
description: JAX plugin branch to checkout to use for test scripts
type: string
default: "rocm-jaxlib-v0.8.0"
default: "rocm-jaxlib-v0.8.2"

permissions:
contents: read
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/test_linux_jax_wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ on:
type: string
default: "rocm-jaxlib-v0.8.2"
jax_version:
description: "Base JAX version (e.g. 0.8.0) for installing jax from PyPI. Extracted from built wheels by write_jax_versions.py."
description: "Base JAX version (e.g. 0.8.2) for installing jax from PyPI. Extracted from built wheels by write_jax_versions.py."
required: false
type: string
jaxlib_version:
description: "jaxlib wheel version (e.g. 0.9.0+rocm7 or 0.8.0+rocm7.12.0). Extracted from built wheels by write_jax_versions.py."
description: "jaxlib wheel version (e.g. 0.9.0+rocm7 or 0.8.2+rocm7.12.0). Extracted from built wheels by write_jax_versions.py."
required: false
type: string
jax_plugin_version:
Expand Down Expand Up @@ -108,11 +108,11 @@ on:
type: string
default: "rocm-jaxlib-v0.8.2"
jax_version:
description: "Base JAX version (e.g. 0.8.0). Leave empty to auto-detect from rocm-jax requirements."
description: "Base JAX version (e.g. 0.8.2). Leave empty to auto-detect from rocm-jax requirements."
required: false
type: string
jaxlib_version:
description: "jaxlib wheel version (e.g. 0.9.0+rocm7 or 0.8.0+rocm7.12.0). Leave empty to auto-compute from rocm_version."
description: "jaxlib wheel version (e.g. 0.9.0+rocm7 or 0.8.2+rocm7.12.0). Leave empty to auto-compute from rocm_version."
required: false
type: string
jax_plugin_version:
Expand Down
6 changes: 3 additions & 3 deletions build_tools/github_actions/compute_jax_package_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
The following strings are appended to the file specified in the "GITHUB_ENV"
environment variable:

JAX_VERSION=0.8.0
JAXLIB_VERSION=0.8.0+rocm7.12.0.dev0.e1a5d395
JAX_VERSION=0.8.2
JAXLIB_VERSION=0.8.2+rocm7.12.0.dev0.e1a5d395
"""

import argparse
Expand All @@ -37,7 +37,7 @@
def extract_jax_version_from_requirements(requirements_path: str) -> str:
"""Extracts the JAX version from a requirements.txt file.

Looks for lines like 'jax==0.8.0' or 'jaxlib==0.8.0' and returns
Looks for lines like 'jax==0.8.2' or 'jaxlib==0.8.2' and returns
the version number.
"""
pattern = re.compile(r"^\s*(jax|jaxlib)\s*==\s*([^#\s]+)")
Expand Down
10 changes: 6 additions & 4 deletions external-builds/jax/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ Support for JAX is provided via stable release branches of
| JAX version | Linux | Windows |
| ----------- | --------------------------------------------------------------------------------------------------------------- | ---------------- |
| 0.9.1 | ✅ Supported via [ROCm/rocm-jax `rocm-jaxlib-v0.9.1`](https://github.com/ROCm/rocm-jax/tree/rocm-jaxlib-v0.9.1) | ❌ Not supported |
| 0.9.0 | ✅ Supported via [ROCm/rocm-jax `rocm-jaxlib-v0.9.0`](https://github.com/ROCm/rocm-jax/tree/rocm-jaxlib-v0.9.0) | ❌ Not supported |
| 0.8.2 | ✅ Supported via [ROCm/rocm-jax `rocm-jaxlib-v0.8.2`](https://github.com/ROCm/rocm-jax/tree/rocm-jaxlib-v0.8.2) | ❌ Not supported |
| 0.8.0 | ✅ Supported via [ROCm/rocm-jax `rocm-jaxlib-v0.8.0`](https://github.com/ROCm/rocm-jax/tree/rocm-jaxlib-v0.8.0) | ❌ Not supported |

See also:

Expand Down Expand Up @@ -97,7 +97,7 @@ provide it via **tarballs** with arbitrary install locations.

1. Choose your configuration:

- **JAX version**: e.g. `0.8.2` or `0.8.0`
- **JAX version**: e.g. `0.9.1`, `0.9.0`, or `0.8.2`
- **Python version**: e.g. `3.12`
- **TheRock tarball**: A tarball URL, a local tarball file path, or a
directory containing a ROCm installation. Nightly tarballs are available
Expand All @@ -118,8 +118,10 @@ provide it via **tarballs** with arbitrary install locations.
```

> [!NOTE]
> The `--jax-source-dir` flag is required for JAX 0.8.2 and points to the
> cloned `jax` repository directory. For JAX 0.8.0, this flag can be omitted.
> The `--jax-source-dir` flag is required when building jaxlib from source
> (JAX \<= 0.9.0) and points to the cloned `jax` repository directory.
> For JAX >= 0.9.1, jaxlib is installed from upstream PyPI, so this flag
> can be omitted.

1. Locate built wheels:

Expand Down
Loading