diff --git a/.github/workflows/build_linux_jax_wheels.yml b/.github/workflows/build_linux_jax_wheels.yml index ca174673e87..d88e78db95a 100644 --- a/.github/workflows/build_linux_jax_wheels.yml +++ b/.github/workflows/build_linux_jax_wheels.yml @@ -184,7 +184,9 @@ jobs: run: | ls -lah pushd rocm-jax - [[ "${{ inputs.jax_ref }}" != *"0.8.0"* && "${{ inputs.build_jaxlib }}" == "true" ]] && SOURCE_ARG="--jax-source-dir=${{ github.workspace }}/jax" || SOURCE_ARG="" + # --jax-source-dir is only needed when building jaxlib from source + # (JAX <= 0.9.0). For JAX >= 0.9.1, jaxlib comes from upstream PyPI. + [[ "${{ inputs.build_jaxlib }}" == "true" ]] && SOURCE_ARG="--jax-source-dir=${{ github.workspace }}/jax" || SOURCE_ARG="" python3 build/ci_build \ --compiler=clang \ --python-versions="${{ inputs.python_version }}" \ diff --git a/.github/workflows/release_portable_linux_jax_wheels.yml b/.github/workflows/release_portable_linux_jax_wheels.yml index abcc52368da..959770beaaf 100644 --- a/.github/workflows/release_portable_linux_jax_wheels.yml +++ b/.github/workflows/release_portable_linux_jax_wheels.yml @@ -103,10 +103,8 @@ jobs: fail-fast: false matrix: python_version: ["3.11", "3.12", "3.13", "3.14"] - jax_ref: ["rocm-jaxlib-v0.8.0", "rocm-jaxlib-v0.8.2", "rocm-jaxlib-v0.9.0", "rocm-jaxlib-v0.9.1"] + jax_ref: ["rocm-jaxlib-v0.8.2", "rocm-jaxlib-v0.9.0", "rocm-jaxlib-v0.9.1"] include: - - jax_ref: "rocm-jaxlib-v0.8.0" - build_jaxlib: true - jax_ref: "rocm-jaxlib-v0.8.2" build_jaxlib: true - jax_ref: "rocm-jaxlib-v0.9.0" diff --git a/.github/workflows/test_jax_dockerfile.yml b/.github/workflows/test_jax_dockerfile.yml index c66ce650ebb..e258582fe32 100644 --- a/.github/workflows/test_jax_dockerfile.yml +++ b/.github/workflows/test_jax_dockerfile.yml @@ -22,7 +22,7 @@ on: required: true description: JAX plugin branch to checkout type: string - default: "rocm-jaxlib-v0.6.0" + default: "rocm-jaxlib-v0.8.2" workflow_call: inputs: @@ -40,7 +40,7 @@ on: jax_plugin_branch: description: JAX plugin branch to checkout to use for test scripts type: string - default: "rocm-jaxlib-v0.8.0" + default: "rocm-jaxlib-v0.8.2" permissions: contents: read diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index 485ea328521..c11670a412d 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -37,11 +37,11 @@ on: type: string default: "rocm-jaxlib-v0.8.2" jax_version: - description: "Base JAX version (e.g. 0.8.0) for installing jax from PyPI. Extracted from built wheels by write_jax_versions.py." + description: "Base JAX version (e.g. 0.8.2) for installing jax from PyPI. Extracted from built wheels by write_jax_versions.py." required: false type: string jaxlib_version: - description: "jaxlib wheel version (e.g. 0.9.0+rocm7 or 0.8.0+rocm7.12.0). Extracted from built wheels by write_jax_versions.py." + description: "jaxlib wheel version (e.g. 0.9.0+rocm7 or 0.8.2+rocm7.12.0). Extracted from built wheels by write_jax_versions.py." required: false type: string jax_plugin_version: @@ -108,11 +108,11 @@ on: type: string default: "rocm-jaxlib-v0.8.2" jax_version: - description: "Base JAX version (e.g. 0.8.0). Leave empty to auto-detect from rocm-jax requirements." + description: "Base JAX version (e.g. 0.8.2). Leave empty to auto-detect from rocm-jax requirements." required: false type: string jaxlib_version: - description: "jaxlib wheel version (e.g. 0.9.0+rocm7 or 0.8.0+rocm7.12.0). Leave empty to auto-compute from rocm_version." + description: "jaxlib wheel version (e.g. 0.9.0+rocm7 or 0.8.2+rocm7.12.0). Leave empty to auto-compute from rocm_version." required: false type: string jax_plugin_version: diff --git a/build_tools/github_actions/compute_jax_package_version.py b/build_tools/github_actions/compute_jax_package_version.py index 2667e2c504b..0dc4a14be79 100644 --- a/build_tools/github_actions/compute_jax_package_version.py +++ b/build_tools/github_actions/compute_jax_package_version.py @@ -23,8 +23,8 @@ The following strings are appended to the file specified in the "GITHUB_ENV" environment variable: - JAX_VERSION=0.8.0 - JAXLIB_VERSION=0.8.0+rocm7.12.0.dev0.e1a5d395 + JAX_VERSION=0.8.2 + JAXLIB_VERSION=0.8.2+rocm7.12.0.dev0.e1a5d395 """ import argparse @@ -37,7 +37,7 @@ def extract_jax_version_from_requirements(requirements_path: str) -> str: """Extracts the JAX version from a requirements.txt file. - Looks for lines like 'jax==0.8.0' or 'jaxlib==0.8.0' and returns + Looks for lines like 'jax==0.8.2' or 'jaxlib==0.8.2' and returns the version number. """ pattern = re.compile(r"^\s*(jax|jaxlib)\s*==\s*([^#\s]+)") diff --git a/external-builds/jax/README.md b/external-builds/jax/README.md index 78e5735e92f..9d8354d8a3b 100644 --- a/external-builds/jax/README.md +++ b/external-builds/jax/README.md @@ -36,8 +36,8 @@ Support for JAX is provided via stable release branches of | JAX version | Linux | Windows | | ----------- | --------------------------------------------------------------------------------------------------------------- | ---------------- | | 0.9.1 | ✅ Supported via [ROCm/rocm-jax `rocm-jaxlib-v0.9.1`](https://github.com/ROCm/rocm-jax/tree/rocm-jaxlib-v0.9.1) | ❌ Not supported | +| 0.9.0 | ✅ Supported via [ROCm/rocm-jax `rocm-jaxlib-v0.9.0`](https://github.com/ROCm/rocm-jax/tree/rocm-jaxlib-v0.9.0) | ❌ Not supported | | 0.8.2 | ✅ Supported via [ROCm/rocm-jax `rocm-jaxlib-v0.8.2`](https://github.com/ROCm/rocm-jax/tree/rocm-jaxlib-v0.8.2) | ❌ Not supported | -| 0.8.0 | ✅ Supported via [ROCm/rocm-jax `rocm-jaxlib-v0.8.0`](https://github.com/ROCm/rocm-jax/tree/rocm-jaxlib-v0.8.0) | ❌ Not supported | See also: @@ -97,7 +97,7 @@ provide it via **tarballs** with arbitrary install locations. 1. Choose your configuration: - - **JAX version**: e.g. `0.8.2` or `0.8.0` + - **JAX version**: e.g. `0.9.1`, `0.9.0`, or `0.8.2` - **Python version**: e.g. `3.12` - **TheRock tarball**: A tarball URL, a local tarball file path, or a directory containing a ROCm installation. Nightly tarballs are available @@ -118,8 +118,10 @@ provide it via **tarballs** with arbitrary install locations. ``` > [!NOTE] - > The `--jax-source-dir` flag is required for JAX 0.8.2 and points to the - > cloned `jax` repository directory. For JAX 0.8.0, this flag can be omitted. + > The `--jax-source-dir` flag is required when building jaxlib from source + > (JAX \<= 0.9.0) and points to the cloned `jax` repository directory. + > For JAX >= 0.9.1, jaxlib is installed from upstream PyPI, so this flag + > can be omitted. 1. Locate built wheels: