@@ -17,6 +17,10 @@ function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::Abstr
1717 return TensorKit. TensorMapWithStorage {T, A} (A (h_t. data), V)
1818end
1919
# Block type reported for any `CuTensorMap` type with scalar type `T`:
# always `CuMatrix{T, CUDA.DeviceMemory}`.
TensorKit.blocktype(::Type{<:CuTensorMap{T, S}}) where {T, S} =
    CuMatrix{T, CUDA.DeviceMemory}
23+
2024for (fname, felt) in ((:zeros , :zero ), (:ones , :one ))
2125 @eval begin
2226 function CUDA. $fname (
@@ -101,18 +105,6 @@ function TensorKit.scalar(t::CuTensorMap{T, S, 0, 0}) where {T, S}
101105 return isempty (inds) ? zero (scalartype (t)) : @allowscalar @inbounds t. data[only (inds)]
102106end
103107
104- function Base. convert (
105- TT:: Type{CuTensorMap{T, S, N₁, N₂}} ,
106- t:: AbstractTensorMap{<:Any, S, N₁, N₂}
107- ) where {T, S, N₁, N₂}
108- if typeof (t) === TT
109- return t
110- else
111- tnew = TT (undef, space (t))
112- return copy! (tnew, t)
113- end
114- end
115-
116108function LinearAlgebra. isposdef (t:: CuTensorMap )
117109 domain (t) == codomain (t) ||
118110 throw (SpaceMismatch (" `isposdef` requires domain and codomain to be the same" ))
@@ -138,10 +130,9 @@ function Base.promote_rule(
138130 return CuTensorMap{T, S, N₁, N₂}
139131end
140132
141- TensorKit. promote_storage_rule (:: Type{CuArray{T, N}} , :: Type{<:CuArray{T, N}} ) where {T, N} =
# Storage promotion for two CUDA array types that share element type `T` and rank `N`.
# Both arguments may be any subtype of `CuArray{T, N}`; the result is always
# `CuArray{T, N, CUDA.default_memory}`.
function TensorKit.promote_storage_rule(
        ::Type{<:CuArray{T, N}}, ::Type{<:CuArray{T, N}}
    ) where {T, N}
    return CuArray{T, N, CUDA.default_memory}
end
143135
144-
145136# CuTensorMap exponentiation:
146137function TensorKit. exp! (t:: CuTensorMap )
147138 domain (t) == codomain (t) ||
@@ -168,3 +159,30 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
168159 return tf
169160 end
170161end
162+
# Method of the non-threaded generic add kernel for `CuTensorMap` arguments.
# Walks over every entry of `transformer.data` and dispatches it to either
# `_add_transform_single!` or `_add_transform_multi!`. Always returns `nothing`.
function TensorKit._add_general_kernel_nonthreaded!(
        tdst::CuTensorMap, tsrc::CuTensorMap, p, transformer::TensorKit.GenericTreeTransformer, α, β, backend...
    )
    # scratch buffers are created once up front and handed to every multi-block call
    scratch = TensorKit.allocate_buffers(tdst, tsrc, transformer)

    for sub in transformer.data
        basis = sub[1]
        if length(basis) != 1
            # adapt the first entry to a CuArray; the remaining entries are forwarded as-is
            gpu_sub = (CUDA.adapt(CuArray, basis), sub[2:end]...)
            TensorKit._add_transform_multi!(tdst, tsrc, p, gpu_sub, scratch, α, β, backend...)
        else
            # only a single block: no intermediate buffers are needed
            TensorKit._add_transform_single!(tdst, tsrc, p, sub, α, β, backend...)
        end
    end
    return nothing
end
180+
# Create the pair of work buffers for a generic tree transformer acting on `CuTensorMap`s.
# The first buffer has the element type of `tdst.data`, the second that of `tsrc.data`;
# both have size `TensorKit.buffersize(transformer)`.
function TensorKit.allocate_buffers(
        tdst::CuTensorMap, tsrc::CuTensorMap, transformer::TensorKit.GenericTreeTransformer
    )
    n = TensorKit.buffersize(transformer)
    Tdst = eltype(tdst.data)
    Tsrc = eltype(tsrc.data)
    # zero-fill on purpose: re-used memory could otherwise leave garbage data in the buffers
    dstbuf = CUDA.zeros(Tdst, n)
    srcbuf = CUDA.zeros(Tsrc, n)
    return dstbuf, srcbuf
end
0 commit comments