Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ GPUArrays = "11.3.1"
JET = "0.9, 0.10, 0.11"
LRUCache = "1.0.2"
LinearAlgebra = "1"
MatrixAlgebraKit = "0.6.4"
MatrixAlgebraKit = "0.6.5"
Mooncake = "0.5"
OhMyThreads = "0.8.0"
Printf = "1"
Expand Down
45 changes: 45 additions & 0 deletions ext/TensorKitMooncakeExt/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,48 @@ function Mooncake.rrule!!(::CoDual{typeof(inv)}, A_ΔA::CoDual{<:AbstractTensorM

return Ainv_ΔAinv, inv_pullback
end

# single-output projections: project_hermitian!, project_antihermitian!
for (f!, f, adj) in (
(:project_hermitian!, :project_hermitian, :project_hermitian_adjoint),
(:project_antihermitian!, :project_antihermitian, :project_antihermitian_adjoint),
)
@eval begin
function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual{<:AbstractTensorMap}, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
A, dA = arrayify(A_dA)
arg, darg = A_dA === arg_darg ? (A, dA) : arrayify(arg_darg)

# don't need to copy/restore A since projections don't mutate input
argc = copy(arg)
arg = $f!(A, arg, Mooncake.primal(alg_dalg))

function $adj(::NoRData)
$f!(darg)
if dA !== darg
add!(dA, darg)
MatrixAlgebraKit.zero!(darg)
end
copy!(arg, argc)
return ntuple(Returns(NoRData()), 4)
end

return arg_darg, $adj
end

function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractTensorMap}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
A, dA = arrayify(A_dA)
output = $f(A, Mooncake.primal(alg_dalg))
output_doutput = Mooncake.zero_fcodual(output)

doutput = last(arrayify(output_doutput))
function $adj(::NoRData)
# TODO: need accumulating projection to avoid intermediate here
add!(dA, $f(doutput))
MatrixAlgebraKit.zero!(doutput)
return ntuple(Returns(NoRData()), 3)
end

return output_doutput, $adj
end
end
end
4 changes: 2 additions & 2 deletions src/factorizations/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ using TensorOperations: Index2Tuple
using MatrixAlgebraKit
import MatrixAlgebraKit as MAK
using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, DiagonalAlgorithm
using MatrixAlgebraKit: TruncationStrategy, NoTruncation, TruncationByValue,
TruncationByError, TruncationIntersection, TruncationByFilter, TruncationByOrder
using MatrixAlgebraKit: TruncationStrategy, NoTruncation, TruncationByValue, TruncationByError,
TruncationIntersection, TruncationUnion, TruncationByFilter, TruncationByOrder
using MatrixAlgebraKit: diagview

include("utility.jl")
Expand Down
16 changes: 16 additions & 0 deletions src/factorizations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,22 @@ function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationInterse
) for c in intersect(map(keys, inds)...)
)
end
function MAK.findtruncated(values::SectorVector, strategy::TruncationUnion)
inds = map(Base.Fix1(MAK.findtruncated, values), strategy.components)
return SectorDict(
c => reduce(
MatrixAlgebraKit._ind_union, [ind[c] for ind in inds if haskey(ind, c)]
) for c in union(map(keys, inds)...)
)
end
function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationUnion)
inds = map(Base.Fix1(MAK.findtruncated_svd, values), strategy.components)
return SectorDict(
c => reduce(
MatrixAlgebraKit._ind_union, [ind[c] for ind in inds if haskey(ind, c)]
) for c in union(map(keys, inds)...)
)
end

# Truncation error
# ----------------
Expand Down
3 changes: 1 addition & 2 deletions src/tensors/abstracttensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,6 @@ end
#------------------------------------------------------------
InnerProductStyle(t::AbstractTensorMap) = InnerProductStyle(typeof(t))

blocktype(t::AbstractTensorMap) = blocktype(typeof(t))

numout(t::AbstractTensorMap) = numout(typeof(t))
numin(t::AbstractTensorMap) = numin(typeof(t))
numind(t::AbstractTensorMap) = numind(typeof(t))
Expand Down Expand Up @@ -441,6 +439,7 @@ See also [`blocks`](@ref), [`blocksectors`](@ref), [`blockdim`](@ref) and [`hasb

Return the type of the matrix blocks of a tensor.
""" blocktype
blocktype(t::AbstractTensorMap) = blocktype(typeof(t))
function blocktype(::Type{T}) where {T <: AbstractTensorMap}
return Core.Compiler.return_type(block, Tuple{T, sectortype(T)})
end
Expand Down
4 changes: 1 addition & 3 deletions src/tensors/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -455,9 +455,7 @@ block(t::TensorMap, c::Sector) = blocks(t)[c]

blocks(t::TensorMap) = BlockIterator(t, fusionblockstructure(t).blockstructure)

function blocktype(::Type{TT}) where {TT <: TensorMap}
A = storagetype(TT)
T = eltype(A)
function blocktype(::Type{TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A <: Vector{T}}
return Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}
end

Expand Down
8 changes: 4 additions & 4 deletions test/cuda/tensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ for V in spacelist
next = @constinferred Nothing iterate(bs, state)
b2 = @constinferred block(t, first(blocksectors(t)))
@test b1 == b2
@test_broken eltype(bs) === Pair{typeof(c), typeof(b1)}
@test_broken typeof(b1) === TensorKit.blocktype(t)
@test eltype(bs) === Pair{typeof(c), typeof(b1)}
@test typeof(b1) === TensorKit.blocktype(t)
@test typeof(c) === sectortype(t)
end
end
Expand Down Expand Up @@ -162,8 +162,8 @@ for V in spacelist
next = @constinferred Nothing iterate(bs, state)
b2 = @constinferred block(t', first(blocksectors(t')))
@test b1 == b2
@test_broken eltype(bs) === Pair{typeof(c), typeof(b1)}
@test_broken typeof(b1) === TensorKit.blocktype(t')
@test eltype(bs) === Pair{typeof(c), typeof(b1)}
@test typeof(b1) === TensorKit.blocktype(t')
@test typeof(c) === sectortype(t)
# linear algebra
@test isa(@constinferred(norm(t)), real(T))
Expand Down
9 changes: 9 additions & 0 deletions test/tensors/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,15 @@ for V in spacelist
@test norm(t - U5 * S5 * Vᴴ5) ≈ ϵ5 atol = eps(real(T))^(4 / 5)
@test minimum(diagview(S5)) >= λ
test_dim_isapprox(domain(S5), nvals)

trunc = truncrank(nvals) | trunctol(; atol = λ - 10eps(λ))
U5, S5, Vᴴ5, ϵ5 = @constinferred svd_trunc(t; trunc)
@test t * Vᴴ5' ≈ U5 * S5
@test isisometric(U5)
@test isisometric(Vᴴ5; side = :right)
@test norm(t - U5 * S5 * Vᴴ5) ≈ ϵ5 atol = eps(real(T))^(4 / 5)
@test minimum(diagview(S5)) >= λ
test_dim_isapprox(domain(S5), nvals)
end
end

Expand Down
Loading