From 83d20d446e56974d0367ba6d6cb94489203e1a90 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Sat, 21 Mar 2026 14:09:40 +0200 Subject: [PATCH 1/4] Add graph capture --- src/hip/HIP.jl | 2 + src/hip/graph.jl | 136 +++++++++++++++++++++++++++++++++++++++ src/hip/module.jl | 5 +- src/memory.jl | 7 +- t.jl | 51 +++++++++++++++ test/core/graph_tests.jl | 56 ++++++++++++++++ test/runtests.jl | 2 +- 7 files changed, 254 insertions(+), 5 deletions(-) create mode 100644 src/hip/graph.jl create mode 100644 t.jl create mode 100644 test/core/graph_tests.jl diff --git a/src/hip/HIP.jl b/src/hip/HIP.jl index 85998c97a..6303e440f 100644 --- a/src/hip/HIP.jl +++ b/src/hip/HIP.jl @@ -1,5 +1,6 @@ module HIP export HIPError, devices, device_synchronize, default_stream +export HIPGraph, @captured, capture, instantiate, update, is_capturing, launch using CEnum @@ -90,6 +91,7 @@ include("stream.jl") include("event.jl") include("pool.jl") include("module.jl") +include("graph.jl") """ Blocks until all kernels on all streams have completed. 
diff --git a/src/hip/graph.jl b/src/hip/graph.jl new file mode 100644 index 000000000..6f8259d30 --- /dev/null +++ b/src/hip/graph.jl @@ -0,0 +1,136 @@ +function unchecked_hipStreamEndCapture(stream, pGraph) + AMDGPU.prepare_state() + @gcsafe_ccall(libhip.hipStreamEndCapture(stream::hipStream_t, pGraph::Ptr{hipGraph_t})::hipError_t) +end + +mutable struct HIPGraph + handle::hipGraph_t + + function HIPGraph(flags = hipStreamCaptureModeGlobal) + handle_ref = Ref{hipGraph_t}() + hipGraphCreate(handle_ref, flags) + + obj = new(handle_ref[]) + finalizer(obj) do obj + hipGraphDestroy(obj) + end + return obj + end + + global function capture(f::Function; flags = hipStreamCaptureModeGlobal, throw_error::Bool = true) + gc_state = GC.enable(false) + stream = AMDGPU.stream() + try + hipStreamBeginCapture(stream, flags) + f() + finally + handle_ref = Ref{hipGraph_t}() + st = unchecked_hipStreamEndCapture(stream, handle_ref) + GC.enable(gc_state) + + if st == hipErrorStreamCaptureInvalidated && !throw_error + return nothing + elseif st != hipSuccess + throw(HIPError(st)) + end + + obj = new(handle_ref[]) + finalizer(hipGraphDestroy, obj) + return obj + end + return nothing + end +end + +Base.unsafe_convert(::Type{hipGraph_t}, graph::HIPGraph) = graph.handle + +mutable struct HIPGraphExec + handle::hipGraphExec_t + + global function instantiate(graph::HIPGraph) + handle_ref = Ref{hipGraphExec_t}() + hipGraphInstantiateWithFlags(handle_ref, graph, 0) + obj = new(handle_ref[]) + + finalizer(obj) do obj + hipGraphExecDestroy(obj) + end + return obj + end +end + +Base.unsafe_convert(::Type{hipGraphExec_t}, exec::HIPGraphExec) = exec.handle + +launch(exec::HIPGraphExec, stream::HIPStream = AMDGPU.stream()) = hipGraphLaunch(exec, stream) + +function update(exec::HIPGraphExec, graph::HIPGraph; throw_error::Bool = true) + error_node = Ref{hipGraphNode_t}() + update_res_ref = Ref{hipGraphExecUpdateResult}() + hipGraphExecUpdate(exec, graph, error_node, update_res_ref) + + update_res = 
update_res_ref[] + if update_res != hipGraphExecUpdateSuccess + throw_error && error("Failed to update HIPGraphExec: `$(update_res)`.") + return false + end + return true +end + +function capture_status(stream::HIPStream) + status_ref = Ref{hipStreamCaptureStatus}() + id_ref = Ref{Culonglong}() + hipStreamGetCaptureInfo(stream, status_ref, id_ref) + status = status_ref[] + return (; status, id=(status == hipStreamCaptureStatusActive) ? id_ref[] : nothing) +end + +is_capturing(stream::HIPStream = AMDGPU.stream()) = + capture_status(stream).status == hipStreamCaptureStatusActive + +macro captured(ex) + @gensym exec + @eval __module__ begin + const $exec = Ref{$HIPGraphExec}() + end + quote + executed = false + GC.enable(false) + graph = try + capture(; throw_error=false) do + $(esc(ex)) + end + finally + GC.enable(true) + end + + if graph === nothing + # if the capture failed, this may have been due to JIT compilation. + # execute the body out of capture, and try capturing again. + $(esc(ex)) + + # don't tolerate capture failures now so that the user will be informed + GC.enable(false) + graph = try + capture() do + $(esc(ex)) + end + catch + rethrow() + finally + GC.enable(true) + end + executed = true + end + + # TODO updating should be done manually by users. + # Update or instantiate. + # if !isassigned($(esc(exec))) || !update($(esc(exec))[], graph; throw_error=false) + # $(esc(exec))[] = instantiate(graph) + # end + + # when allocation nodes are present on AMD ROCm — always reinstantiate for now. 
+ $(esc(exec))[] = instantiate(graph) + executed || launch($(esc(exec))[]) + $(esc(exec))[] + end +end diff --git a/src/hip/module.jl b/src/hip/module.jl index 8bbdda00c..88d0f02b5 100644 --- a/src/hip/module.jl +++ b/src/hip/module.jl @@ -2,7 +2,10 @@ mutable struct HIPModule handle::hipModule_t function HIPModule(data) - device_synchronize() + # During stream capture no GPU work is actually executing, so syncing + # would call hipStreamQuery on a capturing stream, which returns + # hipErrorStreamCaptureUnsupported and invalidates the capture. + is_capturing() || device_synchronize() mod_ref = Ref{hipModule_t}() hipModuleLoadData(mod_ref, data) diff --git a/src/memory.jl b/src/memory.jl index 8a40d1b0e..c864b75cb 100644 --- a/src/memory.jl +++ b/src/memory.jl @@ -409,9 +409,10 @@ mutable struct Managed{M} const mem::M stream::HIPStream dirty::Bool + captured::Bool - function Managed(mem; stream=AMDGPU.stream(), dirty=true) - new{typeof(mem)}(mem, stream, dirty) + function Managed(mem; stream=AMDGPU.stream(), dirty=true, captured=false) + new{typeof(mem)}(mem, stream, dirty, captured) end end @@ -472,7 +473,7 @@ function pool_alloc(::Type{B}, bytesize) where B maybe_collect() time = Base.@elapsed begin s = AMDGPU.stream() - managed = Managed(B(bytesize; stream=s); stream=s) + managed = Managed(B(bytesize; stream=s); stream=s, captured=AMDGPU.is_capturing()) end Base.@atomic alloc_stats.alloc_count += 1 diff --git a/t.jl b/t.jl new file mode 100644 index 000000000..6167ab51d --- /dev/null +++ b/t.jl @@ -0,0 +1,51 @@ +using AMDGPU +using GPUArrays + +# Notes: +# - if function contains malloc & respective free calls -> can just relaunch graph. +# - if only malloc, but no free -> capture allocs with AllocCache first -> then capture graph itself. +# - if rand calls -> call rand before capture to init RNG. +# +# - updating graph, does not update malloc addresses, so is not supported, only instantiation. +# - TODO write cases when updating makes sense: e.g. 
changing `.+ 1f0` to `.+ 2f0`. + +# function f(o) +# x = AMDGPU.rand(Float32, size(o)) +# y = AMDGPU.rand(Float32, size(o)) +# z = x * y +# o .+= z .+ 1f0 +# AMDGPU.unsafe_free!(x) +# AMDGPU.unsafe_free!(y) +# AMDGPU.unsafe_free!(z) +# return +# end + +function f(o) + x = AMDGPU.rand(Float32, size(o)) + y = AMDGPU.rand(Float32, size(o)) + o .+= x * y .+ 1f0 + return +end + +function main() + cache = GPUArrays.AllocCache() + z = AMDGPU.zeros(Float32, 4, 4) + + GPUArrays.@cached cache begin + f(z) + end + + # g = AMDGPU.@captured begin + g = GPUArrays.@cached cache AMDGPU.@captured begin + f(z) + end + display(z); println() + + for i in 1:10 + AMDGPU.launch(g) + display(z); println() + end + + return +end +main() diff --git a/test/core/graph_tests.jl b/test/core/graph_tests.jl new file mode 100644 index 000000000..7c2db67f3 --- /dev/null +++ b/test/core/graph_tests.jl @@ -0,0 +1,56 @@ +using Test +using AMDGPU +using GPUArrays + +@testset "HIP Graphs" begin + @testset "+1" begin + z = AMDGPU.zeros(Int, 4, 4) + f!(o) = o .+= one(eltype(o)) + + graph = AMDGPU.@captured f!(z) + @test sum(z) == 16 + + AMDGPU.launch(graph) + @test sum(z) == 16 * 2 + AMDGPU.launch(graph) + @test sum(z) == 16 * 3 + end + + @testset "malloc/free" begin + z = AMDGPU.zeros(Int, 4, 4) + function f!(o) + x = AMDGPU.ones(eltype(o), size(o)) + o .+= x .+ one(eltype(o)) + AMDGPU.unsafe_free!(x) + end + + graph = AMDGPU.@captured f!(z) + @test sum(z) == 32 + + AMDGPU.launch(graph) + @test sum(z) == 32 * 2 + AMDGPU.launch(graph) + @test sum(z) == 32 * 3 + end + + @testset "only malloc + alloc cache" begin + z = AMDGPU.zeros(Int, 4, 4) + function f!(o) + x = AMDGPU.ones(eltype(o), size(o)) + y = AMDGPU.ones(eltype(o), size(o)) + o .+= (x * y) .+ one(eltype(o)) + end + + cache = GPUArrays.AllocCache() + # Pre-populate alloc cache, to avoid malloc calls during capture. + GPUArrays.@cached cache f!(z) + # Capture with alloc cache. 
+ graph = GPUArrays.@cached cache AMDGPU.@captured f!(z) + @test sum(z) == length(z) * 5 * 2 + + AMDGPU.launch(graph) + @test sum(z) == length(z) * 5 * 3 + AMDGPU.launch(graph) + @test sum(z) == length(z) * 5 * 4 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index a1829c758..4ace51b73 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,7 +24,6 @@ end @info "System information:\n" InteractiveUtils.versioninfo() - AMDGPU.versioninfo() # Autodiscovered tests @@ -39,6 +38,7 @@ include(gpuarrays_testsuite) for name in keys(TestSuite.tests) testsuite["gpuarrays/$name"] = :(TestSuite.tests[$name](AMDGPU.ROCArray)) end +@info "Available tests: `$(keys(testsuite))`." args = parse_args(ARGS) From d4b3767390883999d69d92111ca700510763eb62 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Sun, 22 Mar 2026 00:36:38 +0200 Subject: [PATCH 2/4] Add tests & update docs --- docs/Project.toml | 1 + docs/make.jl | 1 + docs/src/api/graphs.md | 61 +++++++++++++++++++++++++++ docs/src/tutorials/profiling.md | 9 ++-- src/hip/HIP.jl | 2 +- src/hip/graph.jl | 75 ++++++++++++++++++++++++--------- t.jl | 51 ---------------------- test/core/graph_tests.jl | 22 ++++++++-- 8 files changed, 142 insertions(+), 80 deletions(-) create mode 100644 docs/src/api/graphs.md delete mode 100644 t.jl diff --git a/docs/Project.toml b/docs/Project.toml index 99f6fc13c..c4a9f43e0 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -2,6 +2,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterVitepress = "4710194d-e776-4893-9690-8d956a29c365" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +LiveServer = "16fef848-5104-11e9-1b77-fb7a48bbb589" SIMD = "fdea26ae-647d-5447-a871-4b548cad5224" [compat] diff --git a/docs/make.jl b/docs/make.jl index 886bdb456..168fb3c8b 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -41,6 +41,7 @@ function main() "Devices" => "api/devices.md", "Streams" => "api/streams.md", "Kernel Programming" => "api/kernel_programming.md", + "Graphs" => 
"api/graphs.md", "Exceptions" => "api/exceptions.md", "Memory" => "api/memory.md", "Host-Call" => "api/hostcall.md", diff --git a/docs/src/api/graphs.md b/docs/src/api/graphs.md new file mode 100644 index 000000000..747b28b52 --- /dev/null +++ b/docs/src/api/graphs.md @@ -0,0 +1,61 @@ +# Graphs + +[Graphs](https://rocm.docs.amd.com/projects/HIP/en/latest/how-to/hip_runtime_api/hipgraph.html) +allow capturing GPU kernels and executing them as one unit, reducing host overhead. + +Simple operations can be captured as is: + +```@example graph-1 +using AMDGPU + +f!(o) = o .+= 1f0 + +z = AMDGPU.zeros(Int, 4, 4) +graph = AMDGPU.@captured f!(z) +@assert sum(z) == 16 + +AMDGPU.launch(graph) +@assert sum(z) == 16 * 2 +``` + +However, if your code contains more complex flow, it requires more preparations: +- if code contains malloc and respective frees, then it can be captured and relaunched as is. +- if code contains **only** allocations (without freeing), allocations must be cached with `GPUArrays.@cached` beforehand (see example below). +- other unsupported operations (e.g. RNG init) must be done beforehand as well. +- updating graph, does not update allocated pointers, only instantiation is supported in such cases. + +```@example graph-2 +using AMDGPU, GPUArrays + +function f(o) + x = AMDGPU.rand(Float32, size(o)) + y = AMDGPU.rand(Float32, size(o)) + o .+= sin.(x) * cos.(y) .+ 1f0 + return +end + +cache = GPUArrays.AllocCache() +z = AMDGPU.zeros(Float32, 256, 256) +N = 10 + +# Execute function normally and cache all allocations. +GPUArrays.@cached cache f(z) + +# Capture graph using AllocCache to avoid capturing malloc/free calls. +graph = GPUArrays.@cached cache AMDGPU.@captured f(z) + +# Allocations cache must be kept alive while executing graph. 
+for i in 1:N + AMDGPU.launch(graph) +end +AMDGPU.synchronize() +``` + +```@docs +AMDGPU.capture +AMDGPU.@captured +AMDGPU.instantiate +AMDGPU.update +AMDGPU.is_capturing +AMDGPU.launch +``` diff --git a/docs/src/tutorials/profiling.md b/docs/src/tutorials/profiling.md index fba978fbb..affe3d445 100644 --- a/docs/src/tutorials/profiling.md +++ b/docs/src/tutorials/profiling.md @@ -2,8 +2,8 @@ ## rocprof -[rocprofv2](https://github.com/ROCm/rocprofiler?tab=readme-ov-file#rocprofiler-v2) -allows profiling both HSA & HIP API calls (rocprof being deprecated). +[rocprofv3](https://rocm.docs.amd.com/projects/rocprofiler-sdk/en/latest/how-to/using-rocprofv3.html) +allows profiling both HSA & HIP API calls. Let's profile simple copying kernel saved in `profile.jl` file: ```julia @@ -39,11 +39,10 @@ main(2^24) ### Profiling problematic code ```bash -ENABLE_JITPROFILING=1 rocprofv2 --plugin perfetto --hip-trace --hsa-trace --kernel-trace -o prof julia ./profile.jl +ENABLE_JITPROFILING=1 rocprofv3 --output-directory ./profiling --output-format pftrace --hip-trace --hsa-trace --kernel-trace -- julia ./profile.jl ``` -This will produce `prof_output.pftrace` file which can be visualized -using [Perfetto UI](https://ui.perfetto.dev/). +This will produce `.pftrace` file which can be visualized using [Perfetto UI](https://ui.perfetto.dev/). 
![image](../assets/profile_1.png) diff --git a/src/hip/HIP.jl b/src/hip/HIP.jl index 6303e440f..13d1b4a4e 100644 --- a/src/hip/HIP.jl +++ b/src/hip/HIP.jl @@ -1,6 +1,6 @@ module HIP export HIPError, devices, device_synchronize, default_stream -export HIPGraph, @captured, capture, instantiate, update, is_capturing, launch +export HIPGraph, HIPGraphExec, @captured, capture, instantiate, update, is_capturing, launch using CEnum diff --git a/src/hip/graph.jl b/src/hip/graph.jl index 6f8259d30..41557d578 100644 --- a/src/hip/graph.jl +++ b/src/hip/graph.jl @@ -1,3 +1,18 @@ +""" + instantiate(graph::HIPGraph)::HIPGraphExec + +Instantiate captured graph making it executable with [`launch`](@ref). +""" +instantiate + +""" + capture(f::Function; flags = hipStreamCaptureModeGlobal, throw_error::Bool = true)::Union{Nothing, HIPGraph} + +Capture given function `f` to a graph. +If successful, returns a captured graph that needs to be [`instantiate`](@ref)'d to obtain executable graph. +""" +capture + function unchecked_hipStreamEndCapture(stream, pGraph) AMDGPU.prepare_state() @gcsafe_ccall(libhip.hipStreamEndCapture(stream::hipStream_t, pGraph::Ptr{hipGraph_t})::hipError_t) @@ -17,7 +32,7 @@ mutable struct HIPGraph return obj end - global function capture(f::Function; flags = hipStreamCaptureModeGlobal, throw_error::Bool = true) + global function capture(f::Function; flags = hipStreamCaptureModeGlobal, throw_error::Bool = true)::Union{Nothing, HIPGraph} gc_state = GC.enable(false) stream = AMDGPU.stream() try @@ -61,9 +76,24 @@ end Base.unsafe_convert(::Type{hipGraphExec_t}, exec::HIPGraphExec) = exec.handle -launch(exec::HIPGraphExec, stream::HIPStream = AMDGPU.stream()) = hipGraphLaunch(exec, stream) +""" + launch(exec::HIPGraphExec, stream::HIPStream = AMDGPU.stream()) + +Launch executable graph on a given stream.
+""" +function launch(exec::HIPGraphExec, stream::HIPStream = AMDGPU.stream()) + hipGraphLaunch(exec, stream) +end + +""" + update(exec::HIPGraphExec, graph::HIPGraph; throw_error::Bool = true)::Bool + +Given executable graph, perform update with graph. +Return `true` if successful, `false` otherwise. -function update(exec::HIPGraphExec, graph::HIPGraph; throw_error::Bool = true) +If `throw_error=false` allows avoiding throwing an exception if update was not successful. +""" +function update(exec::HIPGraphExec, graph::HIPGraph; throw_error::Bool = true)::Bool error_node = Ref{hipGraphNode_t}() update_res_ref = Ref{hipGraphExecUpdateResult}() hipGraphExecUpdate(exec, graph, error_node, update_res_ref) @@ -84,14 +114,26 @@ function capture_status(stream::HIPStream) return (; status, id=(status == hipStreamCaptureStatusActive) ? id_ref[] : nothing) end -is_capturing(stream::HIPStream = AMDGPU.stream()) = +""" + is_capturing(stream::HIPStream = AMDGPU.stream())::Bool + +For a given `stream` check if capturing for a graph is performed. +""" +function is_capturing(stream::HIPStream = AMDGPU.stream())::Bool capture_status(stream).status == hipStreamCaptureStatusActive +end -macro captured(ex) - @gensym exec - @eval __module__ begin - const $exec = Ref{$HIPGraphExec}() +""" + graph = AMDGPU.@captured begin + # code to capture in a graph. end + +Macro to capture a given expression in a graph & execute it. +Returns captured graph, that can be relaunched with [`launch`](@ref) or updated with [`update`](@ref). + +If capture fails (e.g. due to JIT), attempts recovery, compilation and re-capture. +""" +macro captured(ex) quote executed = false GC.enable(false) @@ -104,11 +146,11 @@ macro captured(ex) end if graph === nothing - # if the capture failed, this may have been due to JIT compilation. + # If the capture failed, this may have been due to JIT compilation. # execute the body out of capture, and try capturing again. 
$(esc(ex)) - # don't tolerate capture failures now so that the user will be informed + # Don't tolerate capture failures now so that the user will be informed. GC.enable(false) graph = try capture() do @@ -122,15 +164,8 @@ macro captured(ex) executed = true end - # TODO updating should be done manually by users. - # Update or instantiate. - # if !isassigned($(esc(exec))) || !update($(esc(exec))[], graph; throw_error=false) - # $(esc(exec))[] = instantiate(graph) - # end - - # when allocation nodes are present on AMD ROCm — always reinstantiate for now. - $(esc(exec))[] = instantiate(graph) - executed || launch($(esc(exec))[]) - $(esc(exec))[] + exec = instantiate(graph) + executed || launch(exec) + exec end end diff --git a/t.jl b/t.jl deleted file mode 100644 index 6167ab51d..000000000 --- a/t.jl +++ /dev/null @@ -1,51 +0,0 @@ -using AMDGPU -using GPUArrays - -# Notes: -# - if function contains malloc & respective free calls -> can just relaunch graph. -# - if only malloc, but no free -> capture allocs with AllocCache first -> then capture graph itself. -# - if rand calls -> call rand before capture to init RNG. -# -# - updating graph, does not update malloc addresses, so is not supported, only instantiation. -# - TODO write cases when updating makes sense: e.g. changing `.+ 1f0` to `.+ 2f0`. 
- -# function f(o) -# x = AMDGPU.rand(Float32, size(o)) -# y = AMDGPU.rand(Float32, size(o)) -# z = x * y -# o .+= z .+ 1f0 -# AMDGPU.unsafe_free!(x) -# AMDGPU.unsafe_free!(y) -# AMDGPU.unsafe_free!(z) -# return -# end - -function f(o) - x = AMDGPU.rand(Float32, size(o)) - y = AMDGPU.rand(Float32, size(o)) - o .+= x * y .+ 1f0 - return -end - -function main() - cache = GPUArrays.AllocCache() - z = AMDGPU.zeros(Float32, 4, 4) - - GPUArrays.@cached cache begin - f(z) - end - - # g = AMDGPU.@captured begin - g = GPUArrays.@cached cache AMDGPU.@captured begin - f(z) - end - display(z); println() - - for i in 1:10 - AMDGPU.launch(g) - display(z); println() - end - - return -end -main() diff --git a/test/core/graph_tests.jl b/test/core/graph_tests.jl index 7c2db67f3..381105d18 100644 --- a/test/core/graph_tests.jl +++ b/test/core/graph_tests.jl @@ -4,9 +4,9 @@ using GPUArrays @testset "HIP Graphs" begin @testset "+1" begin - z = AMDGPU.zeros(Int, 4, 4) f!(o) = o .+= one(eltype(o)) + z = AMDGPU.zeros(Int, 4, 4) graph = AMDGPU.@captured f!(z) @test sum(z) == 16 @@ -17,13 +17,13 @@ using GPUArrays end @testset "malloc/free" begin - z = AMDGPU.zeros(Int, 4, 4) function f!(o) x = AMDGPU.ones(eltype(o), size(o)) o .+= x .+ one(eltype(o)) AMDGPU.unsafe_free!(x) end + z = AMDGPU.zeros(Int, 4, 4) graph = AMDGPU.@captured f!(z) @test sum(z) == 32 @@ -34,13 +34,13 @@ using GPUArrays end @testset "only malloc + alloc cache" begin - z = AMDGPU.zeros(Int, 4, 4) function f!(o) x = AMDGPU.ones(eltype(o), size(o)) y = AMDGPU.ones(eltype(o), size(o)) o .+= (x * y) .+ one(eltype(o)) end + z = AMDGPU.zeros(Int, 4, 4) cache = GPUArrays.AllocCache() # Pre-populate alloc cache, to avoid malloc calls during capture. 
GPUArrays.@cached cache f!(z) @@ -53,4 +53,20 @@ using GPUArrays AMDGPU.launch(graph) @test sum(z) == length(z) * 5 * 4 end + + @testset "Update graph" begin + f1!(o) = o .+= 1f0 + f2!(o) = o .+= 2f0 + + z = AMDGPU.zeros(Int, 4, 4) + graph = AMDGPU.@captured f1!(z) + @test sum(z) == 16 + + g_new = AMDGPU.capture() do + f2!(z) + end + AMDGPU.update(graph, g_new) + AMDGPU.launch(graph) + @test sum(z) == 16 * 3 + end end From b68d800562eb7daff586c8e2c8477df3b33c2349 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Sun, 22 Mar 2026 01:19:30 +0200 Subject: [PATCH 3/4] update test --- test/core/graph_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/core/graph_tests.jl b/test/core/graph_tests.jl index 381105d18..e2c355cd8 100644 --- a/test/core/graph_tests.jl +++ b/test/core/graph_tests.jl @@ -65,7 +65,7 @@ using GPUArrays g_new = AMDGPU.capture() do f2!(z) end - AMDGPU.update(graph, g_new) + @test AMDGPU.update(graph, g_new) AMDGPU.launch(graph) @test sum(z) == 16 * 3 end From 34be28d9b28aede70a76c7f50bcf939c017e900f Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Mon, 23 Mar 2026 21:55:05 +0200 Subject: [PATCH 4/4] Avoid invoking hostcalls (not supported with graph capture) --- docs/src/api/graphs.md | 3 ++- test/core/graph_tests.jl | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/src/api/graphs.md b/docs/src/api/graphs.md index 747b28b52..f5914af48 100644 --- a/docs/src/api/graphs.md +++ b/docs/src/api/graphs.md @@ -8,7 +8,7 @@ Simple operations can be captured as is: ```@example graph-1 using AMDGPU -f!(o) = o .+= 1f0 +f!(o) = o .+= one(eltype(o)) z = AMDGPU.zeros(Int, 4, 4) graph = AMDGPU.@captured f!(z) @@ -19,6 +19,7 @@ AMDGPU.launch(graph) ``` However, if your code contains more complex flow, it requires more preparations: +- code **must not** result in hostcall invocation. - if code contains malloc and respective frees, then it can be captured and relaunched as is.
- if code contains **only** allocations (without freeing), allocations must be cached with `GPUArrays.@cached` beforehand (see example below). - other unsupported operations (e.g. RNG init) must be done beforehand as well. diff --git a/test/core/graph_tests.jl b/test/core/graph_tests.jl index e2c355cd8..dde0fd1fd 100644 --- a/test/core/graph_tests.jl +++ b/test/core/graph_tests.jl @@ -55,8 +55,8 @@ using GPUArrays end @testset "Update graph" begin - f1!(o) = o .+= 1f0 - f2!(o) = o .+= 2f0 + f1!(o) = o .+= one(eltype(o)) + f2!(o) = o .+= eltype(o)(2) z = AMDGPU.zeros(Int, 4, 4) graph = AMDGPU.@captured f1!(z)