Skip to content

Commit 2359d28

Browse files
committed
More tweaks
1 parent 66187bb commit 2359d28

File tree

6 files changed

+134
-125
lines changed

6 files changed

+134
-125
lines changed

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorKit"
22
uuid = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
3-
authors = ["Jutho Haegeman, Lukas Devos"]
43
version = "0.16.3"
4+
authors = ["Jutho Haegeman, Lukas Devos"]
55

66
[deps]
77
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
@@ -53,7 +53,7 @@ Printf = "1"
5353
Random = "1"
5454
SafeTestsets = "0.1"
5555
ScopedValues = "1.3.0"
56-
Strided = "2"
56+
Strided = "2.3.4"
5757
TensorKitSectors = "0.3.5"
5858
TensorOperations = "5.1"
5959
Test = "1"
@@ -87,3 +87,6 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
8787

8888
[targets]
8989
test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake", "JET"]
90+
91+
[sources]
92+
Strided = {url = "https://github.com/QuantumKitHub/Strided.jl", rev = "ksh/copyto"}

ext/TensorKitCUDAExt/TensorKitCUDAExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using TensorKit.Factorizations
1010
using TensorKit.Strided
1111
using TensorKit.Factorizations: AbstractAlgorithm
1212
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check
13-
import TensorKit: randisometry, rand, randn
13+
import TensorKit: randisometry, rand, randn, _copyto!, _add_general_kernel_nonthreaded!, blocktype
1414

1515
using TensorKit: MatrixAlgebraKit
1616

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::Abstr
1717
return TensorKit.TensorMapWithStorage{T, A}(A(h_t.data), V)
1818
end
1919

20+
function TensorKit.blocktype(::Type{<:CuTensorMap{T, S}}) where {T, S}
21+
return CuMatrix{T, CUDA.DeviceMemory}
22+
end
23+
2024
for (fname, felt) in ((:zeros, :zero), (:ones, :one))
2125
@eval begin
2226
function CUDA.$fname(
@@ -101,18 +105,6 @@ function TensorKit.scalar(t::CuTensorMap{T, S, 0, 0}) where {T, S}
101105
return isempty(inds) ? zero(scalartype(t)) : @allowscalar @inbounds t.data[only(inds)]
102106
end
103107

104-
function Base.convert(
105-
TT::Type{CuTensorMap{T, S, N₁, N₂}},
106-
t::AbstractTensorMap{<:Any, S, N₁, N₂}
107-
) where {T, S, N₁, N₂}
108-
if typeof(t) === TT
109-
return t
110-
else
111-
tnew = TT(undef, space(t))
112-
return copy!(tnew, t)
113-
end
114-
end
115-
116108
function LinearAlgebra.isposdef(t::CuTensorMap)
117109
domain(t) == codomain(t) ||
118110
throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same"))
@@ -138,10 +130,9 @@ function Base.promote_rule(
138130
return CuTensorMap{T, S, N₁, N₂}
139131
end
140132

141-
TensorKit.promote_storage_rule(::Type{CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} =
133+
TensorKit.promote_storage_rule(::Type{<:CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} =
142134
CuArray{T, N, CUDA.default_memory}
143135

144-
145136
# CuTensorMap exponentation:
146137
function TensorKit.exp!(t::CuTensorMap)
147138
domain(t) == codomain(t) ||
@@ -168,3 +159,30 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
168159
return tf
169160
end
170161
end
162+
163+
function TensorKit._add_general_kernel_nonthreaded!(
164+
tdst::CuTensorMap, tsrc::CuTensorMap, p, transformer::TensorKit.GenericTreeTransformer, α, β, backend...
165+
)
166+
# preallocate buffers
167+
buffers = TensorKit.allocate_buffers(tdst, tsrc, transformer)
168+
169+
for subtransformer in transformer.data
170+
# Special case without intermediate buffers whenever there is only a single block
171+
if length(subtransformer[1]) == 1
172+
TensorKit._add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...)
173+
else
174+
cu_subtransformer = tuple(CUDA.adapt(CuArray, subtransformer[1]), subtransformer[2:end]...)
175+
TensorKit._add_transform_multi!(tdst, tsrc, p, cu_subtransformer, buffers, α, β, backend...)
176+
end
177+
end
178+
return nothing
179+
end
180+
181+
function TensorKit.allocate_buffers(
182+
tdst::CuTensorMap, tsrc::CuTensorMap, transformer::TensorKit.GenericTreeTransformer
183+
)
184+
sz = TensorKit.buffersize(transformer)
185+
# force zeros to ensure the buffers are empty
186+
# otherwise memory re-use can fill them with garbage data
187+
return CUDA.zeros(eltype(tdst.data), sz), CUDA.zeros(eltype(tsrc.data), sz)
188+
end

src/tensors/abstracttensor.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,11 @@ storagetype(t) = storagetype(typeof(t))
5353
function storagetype(::Type{T}) where {T <: AbstractTensorMap}
5454
if T isa Union
5555
# attempt to be slightly more specific by promoting unions
56-
Ma = storagetype(T.a)
57-
Mb = storagetype(T.b)
58-
return promote_storagetype(Ma, Mb)
56+
return promote_storagetype(T.a, T.b)
57+
elseif eltype(T) isa Union
58+
# attempt to be slightly more specific by promoting unions
59+
TU = eltype(T)
60+
return promote_storagetype(TU.a, TU.b)
5961
else
6062
# fallback definition by using scalartype
6163
return similarstoragetype(scalartype(T))
@@ -103,8 +105,9 @@ similarstoragetype(X::Type, ::Type{T}) where {T <: Number} =
103105

104106
# implement on tensors
105107
similarstoragetype(::Type{TT}) where {TT <: AbstractTensorMap} = similarstoragetype(storagetype(TT))
106-
similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number} =
107-
similarstoragetype(storagetype(TT), T)
108+
function similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number}
109+
return similarstoragetype(storagetype(TT), T)
110+
end
108111

109112
# implement on arrays
110113
similarstoragetype(::Type{A}) where {A <: DenseVector{<:Number}} = A

src/tensors/braidingtensor.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,15 @@ end
171171
has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false
172172
function add_transform!(
173173
tdst::AbstractTensorMap,
174-
tsrc::BraidingTensor, (p₁, p₂)::Index2Tuple,
174+
tsrc::BraidingTensor{T, S},
175+
(p₁, p₂)::Index2Tuple,
175176
fusiontreetransform,
176177
α::Number, β::Number, backend::AbstractBackend...
177-
)
178+
) where {T, S}
179+
tsrc_map = similar(tdst, storagetype(tdst), space(tsrc))
180+
copy!(tsrc_map, tsrc)
178181
return add_transform!(
179-
tdst, TensorMap(tsrc), (p₁, p₂), fusiontreetransform, α, β,
182+
tdst, tsrc_map, (p₁, p₂), fusiontreetransform, α, β,
180183
backend...
181184
)
182185
end

0 commit comments

Comments
 (0)