Skip to content

Commit 3b9efee

Browse files
committed
MAK v0.6.5 updates (#390)
* add TruncationUnion implementation * add projection mooncake rules * Ensure `blocktype` is correctly inferred for CuArray * mark tests as no longer broken * type stability improvements type stability improvements
1 parent 66187bb commit 3b9efee

8 files changed

Lines changed: 97 additions & 22 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ GPUArrays = "11.3.1"
4646
JET = "0.9, 0.10, 0.11"
4747
LRUCache = "1.0.2"
4848
LinearAlgebra = "1"
49-
MatrixAlgebraKit = "0.6.4"
49+
MatrixAlgebraKit = "0.6.5"
5050
Mooncake = "0.5"
5151
OhMyThreads = "0.8.0"
5252
Printf = "1"

ext/TensorKitMooncakeExt/linalg.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,48 @@ function Mooncake.rrule!!(::CoDual{typeof(inv)}, A_ΔA::CoDual{<:AbstractTensorM
8686

8787
return Ainv_ΔAinv, inv_pullback
8888
end
89+
90+
# single-output projections: project_hermitian!, project_antihermitian!
91+
for (f!, f, adj) in (
92+
(:project_hermitian!, :project_hermitian, :project_hermitian_adjoint),
93+
(:project_antihermitian!, :project_antihermitian, :project_antihermitian_adjoint),
94+
)
95+
@eval begin
96+
function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual{<:AbstractTensorMap}, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
97+
A, dA = arrayify(A_dA)
98+
arg, darg = A_dA === arg_darg ? (A, dA) : arrayify(arg_darg)
99+
100+
# don't need to copy/restore A since projections don't mutate input
101+
argc = copy(arg)
102+
arg = $f!(A, arg, Mooncake.primal(alg_dalg))
103+
104+
function $adj(::NoRData)
105+
$f!(darg)
106+
if dA !== darg
107+
add!(dA, darg)
108+
MatrixAlgebraKit.zero!(darg)
109+
end
110+
copy!(arg, argc)
111+
return ntuple(Returns(NoRData()), 4)
112+
end
113+
114+
return arg_darg, $adj
115+
end
116+
117+
function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractTensorMap}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
118+
A, dA = arrayify(A_dA)
119+
output = $f(A, Mooncake.primal(alg_dalg))
120+
output_doutput = Mooncake.zero_fcodual(output)
121+
122+
doutput = last(arrayify(output_doutput))
123+
function $adj(::NoRData)
124+
# TODO: need accumulating projection to avoid intermediate here
125+
add!(dA, $f(doutput))
126+
MatrixAlgebraKit.zero!(doutput)
127+
return ntuple(Returns(NoRData()), 3)
128+
end
129+
130+
return output_doutput, $adj
131+
end
132+
end
133+
end

src/factorizations/factorizations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ using TensorOperations: Index2Tuple
1919
using MatrixAlgebraKit
2020
import MatrixAlgebraKit as MAK
2121
using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, DiagonalAlgorithm
22-
using MatrixAlgebraKit: TruncationStrategy, NoTruncation, TruncationByValue,
23-
TruncationByError, TruncationIntersection, TruncationByFilter, TruncationByOrder
22+
using MatrixAlgebraKit: TruncationStrategy, NoTruncation, TruncationByValue, TruncationByError,
23+
TruncationIntersection, TruncationUnion, TruncationByFilter, TruncationByOrder
2424
using MatrixAlgebraKit: diagview
2525

2626
include("utility.jl")

src/factorizations/truncation.jl

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -265,21 +265,45 @@ function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationSpace)
265265
return SectorDict(c => MAK.findtruncated_svd(d, blockstrategy(c)) for (c, d) in pairs(values))
266266
end
267267

268+
# The implementations below assume that the `SectorDict` always contains an entry for every block sector
269+
# for example, if a block gets fully truncated, inds[c] = Int[].
270+
# This is always the case in the implementations above.
271+
268272
function MAK.findtruncated(values::SectorVector, strategy::TruncationIntersection)
269273
inds = map(Base.Fix1(MAK.findtruncated, values), strategy.components)
270-
return SectorDict(
271-
c => mapreduce(
272-
Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds
273-
) for c in intersect(map(keys, inds)...)
274-
)
274+
@assert TensorKit._allequal(keys, inds) "missing blocks are not supported right now"
275+
sectors = keys(first(inds))
276+
vals = map(keys(first(inds))) do c
277+
mapreduce(Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds)
278+
end
279+
return SectorDict{eltype(sectors), eltype(vals)}(sectors, vals)
275280
end
276281
function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationIntersection)
277282
inds = map(Base.Fix1(MAK.findtruncated_svd, values), strategy.components)
278-
return SectorDict(
279-
c => mapreduce(
280-
Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds
281-
) for c in intersect(map(keys, inds)...)
282-
)
283+
@assert TensorKit._allequal(keys, inds) "missing blocks are not supported right now"
284+
sectors = keys(first(inds))
285+
vals = map(keys(first(inds))) do c
286+
mapreduce(Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds)
287+
end
288+
return SectorDict{eltype(sectors), eltype(vals)}(sectors, vals)
289+
end
290+
function MAK.findtruncated(values::SectorVector, strategy::TruncationUnion)
291+
inds = map(Base.Fix1(MAK.findtruncated, values), strategy.components)
292+
@assert TensorKit._allequal(keys, inds) "missing blocks are not supported right now"
293+
sectors = keys(first(inds))
294+
vals = map(keys(first(inds))) do c
295+
mapreduce(Base.Fix2(getindex, c), MatrixAlgebraKit._ind_union, inds)
296+
end
297+
return SectorDict{eltype(sectors), eltype(vals)}(sectors, vals)
298+
end
299+
function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationUnion)
300+
inds = map(Base.Fix1(MAK.findtruncated_svd, values), strategy.components)
301+
@assert TensorKit._allequal(keys, inds) "missing blocks are not supported right now"
302+
sectors = keys(first(inds))
303+
vals = map(keys(first(inds))) do c
304+
mapreduce(Base.Fix2(getindex, c), MatrixAlgebraKit._ind_union, inds)
305+
end
306+
return SectorDict{eltype(sectors), eltype(vals)}(sectors, vals)
283307
end
284308

285309
# Truncation error

src/tensors/abstracttensor.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,6 @@ end
313313
#------------------------------------------------------------
314314
InnerProductStyle(t::AbstractTensorMap) = InnerProductStyle(typeof(t))
315315

316-
blocktype(t::AbstractTensorMap) = blocktype(typeof(t))
317-
318316
numout(t::AbstractTensorMap) = numout(typeof(t))
319317
numin(t::AbstractTensorMap) = numin(typeof(t))
320318
numind(t::AbstractTensorMap) = numind(typeof(t))
@@ -441,6 +439,7 @@ See also [`blocks`](@ref), [`blocksectors`](@ref), [`blockdim`](@ref) and [`hasb
441439
442440
Return the type of the matrix blocks of a tensor.
443441
""" blocktype
442+
blocktype(t::AbstractTensorMap) = blocktype(typeof(t))
444443
function blocktype(::Type{T}) where {T <: AbstractTensorMap}
445444
return Core.Compiler.return_type(block, Tuple{T, sectortype(T)})
446445
end

src/tensors/tensor.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -455,9 +455,7 @@ block(t::TensorMap, c::Sector) = blocks(t)[c]
455455

456456
blocks(t::TensorMap) = BlockIterator(t, fusionblockstructure(t).blockstructure)
457457

458-
function blocktype(::Type{TT}) where {TT <: TensorMap}
459-
A = storagetype(TT)
460-
T = eltype(A)
458+
function blocktype(::Type{TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A <: Vector{T}}
461459
return Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}
462460
end
463461

test/cuda/tensors.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ for V in spacelist
9898
next = @constinferred Nothing iterate(bs, state)
9999
b2 = @constinferred block(t, first(blocksectors(t)))
100100
@test b1 == b2
101-
@test_broken eltype(bs) === Pair{typeof(c), typeof(b1)}
102-
@test_broken typeof(b1) === TensorKit.blocktype(t)
101+
@test eltype(bs) === Pair{typeof(c), typeof(b1)}
102+
@test typeof(b1) === TensorKit.blocktype(t)
103103
@test typeof(c) === sectortype(t)
104104
end
105105
end
@@ -162,8 +162,8 @@ for V in spacelist
162162
next = @constinferred Nothing iterate(bs, state)
163163
b2 = @constinferred block(t', first(blocksectors(t')))
164164
@test b1 == b2
165-
@test_broken eltype(bs) === Pair{typeof(c), typeof(b1)}
166-
@test_broken typeof(b1) === TensorKit.blocktype(t')
165+
@test eltype(bs) === Pair{typeof(c), typeof(b1)}
166+
@test typeof(b1) === TensorKit.blocktype(t')
167167
@test typeof(c) === sectortype(t)
168168
# linear algebra
169169
@test isa(@constinferred(norm(t)), real(T))

test/tensors/factorizations.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,15 @@ for V in spacelist
310310
@test norm(t - U5 * S5 * Vᴴ5) ϵ5 atol = eps(real(T))^(4 / 5)
311311
@test minimum(diagview(S5)) >= λ
312312
test_dim_isapprox(domain(S5), nvals)
313+
314+
trunc = truncrank(nvals) | trunctol(; atol = λ - 10eps(λ))
315+
U5, S5, Vᴴ5, ϵ5 = @constinferred svd_trunc(t; trunc)
316+
@test t * Vᴴ5' U5 * S5
317+
@test isisometric(U5)
318+
@test isisometric(Vᴴ5; side = :right)
319+
@test norm(t - U5 * S5 * Vᴴ5) ϵ5 atol = eps(real(T))^(4 / 5)
320+
@test minimum(diagview(S5)) >= λ
321+
test_dim_isapprox(domain(S5), nvals)
313322
end
314323
end
315324

0 commit comments

Comments
 (0)