-
Notifications
You must be signed in to change notification settings - Fork 228
Expand file tree
/
Copy pathrelease_portable_linux_jax_wheels.yml
More file actions
120 lines (114 loc) · 4.14 KB
/
release_portable_linux_jax_wheels.yml
File metadata and controls
120 lines (114 loc) · 4.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Copyright Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT
name: Release portable Linux JAX Wheels

# Two entry points with the same set of inputs:
#   - workflow_call: invoked from another workflow; the CloudFront URLs must be
#     supplied by the caller.
#   - workflow_dispatch: manual runs; every input has a usable default.
# The jobs below forward these inputs unchanged to build_linux_jax_wheels.yml.
on:  # yamllint disable-line rule:truthy
  workflow_call:
    inputs:
      amdgpu_family:
        description: AMD GPU family to build wheels for (e.g. "gfx94X-dcgpu")
        required: true
        type: string
      release_type:
        description: The type of release to build ("dev", "nightly", or "prerelease"). All developer-triggered jobs should use "dev"!
        type: string
        default: "dev"
      s3_subdir:
        description: S3 subdirectory, not including the GPU-family
        type: string
        default: "v2"
      s3_staging_subdir:
        description: Staging subdirectory to push the wheels to for testing
        type: string
        default: "v2-staging"
      cloudfront_url:
        description: CloudFront URL pointing to Python index
        required: true
        type: string
      cloudfront_staging_url:
        description: CloudFront base URL pointing to staging Python index
        required: true
        type: string
      # NOTE(review): neither rocm_version nor tar_url is required here, so both
      # may arrive empty; presumably the build workflow needs at least one of
      # them - confirm in build_linux_jax_wheels.yml.
      rocm_version:
        description: ROCm version to install (e.g. "7.10.0a20251124")
        type: string
      tar_url:
        description: "URL to TheRock tarball to build against (e.g. https://rocm.nightlies.amd.com/tarball/therock-dist-linux-gfx94X-dcgpu-7.10.0a20251124.tar.gz)"
        type: string
      ref:
        description: "TheRock branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow."
        type: string
        default: ''

  workflow_dispatch:
    inputs:
      amdgpu_family:
        description: AMD GPU family to build wheels for
        type: choice
        options:
          - gfx101X-dgpu
          - gfx103X-all
          - gfx110X-all
          - gfx1150
          - gfx1151
          - gfx1152
          - gfx1153
          - gfx120X-all
          - gfx900
          - gfx906
          - gfx908
          - gfx90a
          - gfx94X-dcgpu
          - gfx950-dcgpu
        default: gfx94X-dcgpu
      release_type:
        description: The type of release to build ("dev", "nightly", or "prerelease"). All developer-triggered jobs should use "dev"!
        type: string
        default: "dev"
      s3_subdir:
        description: S3 subdirectory, not including the GPU-family
        type: string
        default: "v2"
      s3_staging_subdir:
        description: Staging subdirectory to push the wheels to for testing
        type: string
        default: "v2-staging"
      cloudfront_url:
        description: CloudFront URL pointing to Python index
        type: string
        default: "https://rocm.devreleases.amd.com/v2"
      cloudfront_staging_url:
        description: CloudFront base URL pointing to staging Python index
        type: string
        default: "https://rocm.devreleases.amd.com/v2-staging"
      rocm_version:
        description: ROCm version to install (e.g. "7.10.0a20251124")
        type: string
      tar_url:
        description: "URL to TheRock tarball to build (e.g. https://rocm.nightlies.amd.com/tarball/therock-dist-linux-gfx94X-dcgpu-7.10.0a20251124.tar.gz)"
        type: string
      ref:
        description: "TheRock branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow."
        type: string
        default: ''
# GITHUB_TOKEN scopes for this workflow.
#   - id-token: write  -> jobs may request an OIDC token (presumably used by the
#                         build workflow to obtain cloud credentials - confirm).
#   - contents/packages -> read-only access to the repository and its packages.
permissions:
  contents: read
  packages: read
  id-token: write

# Display name for runs of this workflow. rocm_version is an optional input,
# so the last field in the parentheses may be empty.
run-name: Release portable Linux JAX Wheels (${{ inputs.amdgpu_family }}, ${{ inputs.release_type }}, ${{ inputs.rocm_version }})
jobs:
  # Fan out one call of the reusable build workflow per
  # (python_version, jax_ref) combination. Every workflow input is passed
  # through as-is; only python_version and jax_ref come from the matrix.
  release:
    name: Release | ${{ inputs.amdgpu_family }} | py ${{ matrix.python_version }} | jax ${{ matrix.jax_ref }}
    uses: ./.github/workflows/build_linux_jax_wheels.yml
    strategy:
      # Let the remaining combinations finish even if one of them fails.
      fail-fast: false
      matrix:
        # Quoted so versions such as "3.11" stay strings rather than floats.
        python_version:
          - "3.11"
          - "3.12"
          - "3.13"
          - "3.14"
        jax_ref:
          - "rocm-jaxlib-v0.8.0"
          - "rocm-jaxlib-v0.8.2"
          - "rocm-jaxlib-v0.9.0"
    with:
      # Per-combination values from the matrix.
      python_version: ${{ matrix.python_version }}
      jax_ref: ${{ matrix.jax_ref }}
      # Forwarded unchanged from this workflow's inputs.
      amdgpu_family: ${{ inputs.amdgpu_family }}
      release_type: ${{ inputs.release_type }}
      s3_subdir: ${{ inputs.s3_subdir }}
      s3_staging_subdir: ${{ inputs.s3_staging_subdir }}
      cloudfront_url: ${{ inputs.cloudfront_url }}
      cloudfront_staging_url: ${{ inputs.cloudfront_staging_url }}
      rocm_version: ${{ inputs.rocm_version }}
      tar_url: ${{ inputs.tar_url }}
      ref: ${{ inputs.ref }}