diff --git a/ext/StridedGPUArraysExt.jl b/ext/StridedGPUArraysExt.jl
index 608d8b5..5443e74 100644
--- a/ext/StridedGPUArraysExt.jl
+++ b/ext/StridedGPUArraysExt.jl
@@ -1,6 +1,6 @@
 module StridedGPUArraysExt
 
-using Strided, GPUArrays
+using Strided, GPUArrays, LinearAlgebra
 using GPUArrays: Adapt, KernelAbstractions
 using GPUArrays.KernelAbstractions: @kernel, @index
 
@@ -20,6 +20,14 @@ function Base.copy!(dst::AbstractArray{TD, ND}, src::StridedView{TS, NS, TAS, FS
     return dst
 end
 
+function Base.copyto!(dest::StridedView{T, N, <:AnyGPUArray{T}}, bc::Base.Broadcast.Broadcasted{Strided.StridedArrayStyle{N}}) where {T <: Number, N}
+    dims = size(dest)
+    any(isequal(0), dims) && return dest
+
+    GPUArrays._copyto!(dest, bc)
+    return dest
+end
+
 # lifted from GPUArrays.jl
 function Base.fill!(A::StridedView{T, N, TA, F}, x) where {T, N, TA <: AbstractGPUArray{T}, F <: ALL_FS}
     isempty(A) && return A
@@ -34,7 +42,7 @@ function Base.fill!(A::StridedView{T, N, TA, F}, x) where {T, N, TA <: AbstractG
     return A
 end
 
-function Strided.__mul!(
+function LinearAlgebra.mul!(
     C::StridedView{TC, 2, <:AnyGPUArray{TC}},
     A::StridedView{TA, 2, <:AnyGPUArray{TA}},
     B::StridedView{TB, 2, <:AnyGPUArray{TB}},
diff --git a/test/amd.jl b/test/amd.jl
index cc1c7a3..2152a8b 100644
--- a/test/amd.jl
+++ b/test/amd.jl
@@ -16,6 +16,12 @@ for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
         axes(f1(A1)) == axes(f2(A2)) || continue
         @test collect(ROCMatrix(copy!(f2(A2), f1(A1)))) == AMDGPU.Adapt.adapt(Vector{T}, copy!(B2, B1))
         @test copy!(zA1, f1(A1)) == copy!(zA2, B1)
+        A3 = ROCArray(randn(T, (m1, m2)))
+        A3c = copy(A3)
+        B3 = f1(StridedView(A3c))
+        @. B1 = 2 * B1 - B3 / 3 # test copyto! of Broadcasted
+        @. A1 = 2 * A1 - A3 / 3 # test copyto! of Broadcasted
+        @test AMDGPU.Adapt.adapt(Vector{T}, f1(A1)) == AMDGPU.Adapt.adapt(Vector{T}, B1)
         x = rand(T)
         @test f1(StridedView(AMDGPU.Adapt.adapt(Vector{T}, fill!(A1c, x)))) == AMDGPU.Adapt.adapt(Vector{T}, fill!(B1, x))
     end
diff --git a/test/cuda.jl b/test/cuda.jl
index eb115d3..1ed05ae 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -12,6 +12,12 @@ for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
         axes(f1(A1)) == axes(f2(A2)) || continue
         @test collect(CuMatrix(copy!(f2(A2), f1(A1)))) == CUDA.Adapt.adapt(Vector{T}, copy!(B2, B1))
         @test copy!(zA1, f1(A1)) == copy!(zA2, B1)
+        A3 = CuArray(randn(T, (m1, m2)))
+        A3c = copy(A3)
+        B3 = f1(StridedView(A3c))
+        @. B1 = 2 * B1 - B3 / 3 # test copyto! of Broadcasted
+        @. A1 = 2 * A1 - A3 / 3 # test copyto! of Broadcasted
+        @test CUDA.Adapt.adapt(Vector{T}, f1(A1)) == CUDA.Adapt.adapt(Vector{T}, B1)
         x = rand(T)
         @test f1(StridedView(CUDA.Adapt.adapt(Vector{T}, fill!(A1c, x)))) == CUDA.Adapt.adapt(Vector{T}, fill!(B1, x))
     end
diff --git a/test/jlarrays.jl b/test/jlarrays.jl
index 5aceb35..fa163f6 100644
--- a/test/jlarrays.jl
+++ b/test/jlarrays.jl
@@ -12,6 +12,12 @@
         axes(f1(A1)) == axes(f2(A2)) || continue
         @test collect(Matrix(copy!(f2(A2), f1(A1)))) == JLArrays.Adapt.adapt(Vector{T}, copy!(B2, B1))
         @test copy!(zA1, f1(A1)) == copy!(zA2, B1)
+        A3 = JLArray(randn(T, (m1, m2)))
+        A3c = copy(A3)
+        B3 = f1(StridedView(A3c))
+        @. B1 = 2 * B1 - B3 / 3 # test copyto! of Broadcasted
+        @. A1 = 2 * A1 - A3 / 3 # test copyto! of Broadcasted
+        @test JLArrays.Adapt.adapt(Vector{T}, f1(A1)) == JLArrays.Adapt.adapt(Vector{T}, B1)
         x = rand(T)
         @test f1(StridedView(JLArrays.Adapt.adapt(Vector{T}, fill!(A1c, x)))) == JLArrays.Adapt.adapt(Vector{T}, fill!(B1, x))
     end