Skip to content

Commit e75ce11

Browse files
committed
more shared code in the setup
1 parent 2295e2c commit e75ce11

14 files changed

Lines changed: 154 additions & 389 deletions

File tree

test/chainrules/factorizations.jl

Lines changed: 4 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -11,113 +11,6 @@ using Zygote
1111
using MatrixAlgebraKit
1212
using MatrixAlgebraKit: LAPACK_HouseholderQR, LAPACK_HouseholderLQ, diagview
1313

14-
# Test utility
15-
# -------------
16-
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::AbstractTensorMap)
17-
return randn!(similar(x))
18-
end
19-
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::DiagonalTensorMap)
20-
V = x.domain
21-
return DiagonalTensorMap(randn(eltype(x), reduceddim(V)), V)
22-
end
23-
ChainRulesTestUtils.rand_tangent(::AbstractRNG, ::VectorSpace) = NoTangent()
24-
function ChainRulesTestUtils.test_approx(
25-
actual::AbstractTensorMap, expected::AbstractTensorMap, msg = ""; kwargs...
26-
)
27-
for (c, b) in blocks(actual)
28-
ChainRulesTestUtils.@test_msg msg isapprox(b, block(expected, c); kwargs...)
29-
end
30-
return nothing
31-
end
32-
33-
# Float32 and finite differences don't mix well
34-
precision(::Type{<:Union{Float32, Complex{Float32}}}) = 1.0e-2
35-
precision(::Type{<:Union{Float64, Complex{Float64}}}) = 1.0e-5
36-
37-
function test_ad_rrule(f, args...; check_inferred = false, kwargs...)
38-
test_rrule(
39-
Zygote.ZygoteRuleConfig(), f, args...;
40-
rrule_f = rrule_via_ad, check_inferred, kwargs...
41-
)
42-
return nothing
43-
end
44-
45-
# project_hermitian is non-differentiable for now
46-
_project_hermitian(x) = (x + x') / 2
47-
48-
# Gauge fixing tangents
49-
# ---------------------
50-
function remove_qrgauge_dependence!(ΔQ, t, Q)
51-
for (c, b) in blocks(ΔQ)
52-
m, n = size(block(t, c))
53-
minmn = min(m, n)
54-
Qc = block(Q, c)
55-
Q1 = view(Qc, 1:m, 1:minmn)
56-
ΔQ2 = view(b, :, (minmn + 1):m)
57-
mul!(ΔQ2, Q1, Q1' * ΔQ2)
58-
end
59-
return ΔQ
60-
end
61-
function remove_lqgauge_dependence!(ΔQ, t, Q)
62-
for (c, b) in blocks(ΔQ)
63-
m, n = size(block(t, c))
64-
minmn = min(m, n)
65-
Qc = block(Q, c)
66-
Q1 = view(Qc, 1:minmn, 1:n)
67-
ΔQ2 = view(b, (minmn + 1):n, :)
68-
mul!(ΔQ2, ΔQ2 * Q1', Q1)
69-
end
70-
return ΔQ
71-
end
72-
function remove_eiggauge_dependence!(
73-
ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D)
74-
)
75-
gaugepart = V' * ΔV
76-
for (c, b) in blocks(gaugepart)
77-
Dc = diagview(block(D, c))
78-
# for some reason this fails only on tests, and I cannot reproduce it in an
79-
# interactive session.
80-
# b[abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol] .= 0
81-
for j in axes(b, 2), i in axes(b, 1)
82-
abs(Dc[i] - Dc[j]) >= degeneracy_atol && (b[i, j] = 0)
83-
end
84-
end
85-
mul!(ΔV, V / (V' * V), gaugepart, -1, 1)
86-
return ΔV
87-
end
88-
function remove_eighgauge_dependence!(
89-
ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D)
90-
)
91-
gaugepart = project_antihermitian!(V' * ΔV)
92-
for (c, b) in blocks(gaugepart)
93-
Dc = diagview(block(D, c))
94-
# for some reason this fails only on tests, and I cannot reproduce it in an
95-
# interactive session.
96-
# b[abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol] .= 0
97-
for j in axes(b, 2), i in axes(b, 1)
98-
abs(Dc[i] - Dc[j]) >= degeneracy_atol && (b[i, j] = 0)
99-
end
100-
end
101-
mul!(ΔV, V, gaugepart, -1, 1)
102-
return ΔV
103-
end
104-
function remove_svdgauge_dependence!(
105-
ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(S)
106-
)
107-
gaugepart = project_antihermitian!(U' * ΔU + Vᴴ * ΔVᴴ')
108-
for (c, b) in blocks(gaugepart)
109-
Sd = diagview(block(S, c))
110-
# for some reason this fails only on tests, and I cannot reproduce it in an
111-
# interactive session.
112-
# b[abs.(transpose(diagview(Sc)) .- diagview(Sc)) .>= degeneracy_atol] .= 0
113-
for j in axes(b, 2), i in axes(b, 1)
114-
abs(Sd[i] - Sd[j]) >= degeneracy_atol && (b[i, j] = 0)
115-
end
116-
end
117-
mul!(ΔU, U, gaugepart, -1, 1)
118-
return ΔU, ΔVᴴ
119-
end
120-
12114
# Tests
12215
# -----
12316

@@ -181,7 +74,7 @@ for V in spacelist
18174
DiagonalTensorMap(randn(T, reduceddim(V[1])), V[1]),
18275
)
18376

184-
atol = rtol = precision(T) * dim(space(t))
77+
atol = rtol = default_tol(T) * dim(space(t))
18578
fkwargs = (; positive = true) # make FiniteDifferences happy
18679

18780
test_ad_rrule(qr_compact, t; fkwargs, atol, rtol)
@@ -218,7 +111,7 @@ for V in spacelist
218111
DiagonalTensorMap(randn(T, reduceddim(V[1])), V[1]),
219112
)
220113

221-
atol = rtol = precision(T) * dim(space(t))
114+
atol = rtol = default_tol(T) * dim(space(t))
222115
fkwargs = (; positive = true) # make FiniteDifferences happy
223116

224117
test_ad_rrule(lq_compact, t; fkwargs, atol, rtol)
@@ -254,7 +147,7 @@ for V in spacelist
254147
# DiagonalTensorMap(rand(T, reduceddim(V[1])), V[1]), # broken in MatrixAlgebraKit
255148
)
256149

257-
atol = rtol = precision(T) * dim(space(t))
150+
atol = rtol = default_tol(T) * dim(space(t))
258151

259152
d, v = eig_full(t)
260153
Δv = rand_tangent(v)
@@ -290,7 +183,7 @@ for V in spacelist
290183
# TODO: fix diagonaltensormap case
291184
# DiagonalTensorMap(rand(T, reduceddim(V1)), V1))
292185

293-
atol = rtol = degeneracy_atol = precision(T) * dim(space(t))
186+
atol = rtol = degeneracy_atol = default_tol(T) * dim(space(t))
294187
USVᴴ = svd_compact(t)
295188
ΔU, ΔS, ΔVᴴ = rand_tangent.(USVᴴ)
296189
ΔS2 = randn!(similar(ΔS, space(ΔS)))

test/chainrules/linalg.jl

Lines changed: 4 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -10,40 +10,6 @@ using LinearAlgebra
1010
using Zygote
1111
using MatrixAlgebraKit
1212

13-
# Test utility
14-
# -------------
15-
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::AbstractTensorMap)
16-
return randn!(similar(x))
17-
end
18-
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::DiagonalTensorMap)
19-
V = x.domain
20-
return DiagonalTensorMap(randn(eltype(x), reduceddim(V)), V)
21-
end
22-
ChainRulesTestUtils.rand_tangent(::AbstractRNG, ::VectorSpace) = NoTangent()
23-
function ChainRulesTestUtils.test_approx(
24-
actual::AbstractTensorMap, expected::AbstractTensorMap, msg = ""; kwargs...
25-
)
26-
for (c, b) in blocks(actual)
27-
ChainRulesTestUtils.@test_msg msg isapprox(b, block(expected, c); kwargs...)
28-
end
29-
return nothing
30-
end
31-
32-
# Float32 and finite differences don't mix well
33-
precision(::Type{<:Union{Float32, Complex{Float32}}}) = 1.0e-2
34-
precision(::Type{<:Union{Float64, Complex{Float64}}}) = 1.0e-5
35-
36-
function test_ad_rrule(f, args...; check_inferred = false, kwargs...)
37-
test_rrule(
38-
Zygote.ZygoteRuleConfig(), f, args...;
39-
rrule_f = rrule_via_ad, check_inferred, kwargs...
40-
)
41-
return nothing
42-
end
43-
44-
# project_hermitian is non-differentiable for now
45-
_project_hermitian(x) = (x + x') / 2
46-
4713
# Tests
4814
# -----
4915

@@ -194,8 +160,8 @@ for V in spacelist
194160
end
195161

196162
@timedtestset "Linear Algebra part II with scalartype $T" for T in eltypes
197-
atol = precision(T)
198-
rtol = precision(T)
163+
atol = default_tol(T)
164+
rtol = default_tol(T)
199165
for i in 1:3
200166
E = randn(T, (V[1:i]...) (V[1:i]...))
201167
test_rrule(LinearAlgebra.tr, E; atol, rtol)
@@ -212,8 +178,8 @@ for V in spacelist
212178
end
213179

214180
@timedtestset "Matrix functions ($T)" for T in eltypes
215-
atol = precision(T)
216-
rtol = precision(T)
181+
atol = default_tol(T)
182+
rtol = default_tol(T)
217183
for f in (sqrt, exp)
218184
check_inferred = false # !(T <: Real) # not type-stable for real functions
219185
t1 = randn(T, V[1] V[1])

test/chainrules/tensoroperations.jl

Lines changed: 2 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -10,48 +10,6 @@ using LinearAlgebra
1010
using Zygote
1111
using MatrixAlgebraKit
1212

13-
const _repartition = @static if isdefined(Base, :get_extension)
14-
Base.get_extension(TensorKit, :TensorKitChainRulesCoreExt)._repartition
15-
else
16-
TensorKit.TensorKitChainRulesCoreExt._repartition
17-
end
18-
19-
# Test utility
20-
# -------------
21-
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::AbstractTensorMap)
22-
return randn!(similar(x))
23-
end
24-
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::DiagonalTensorMap)
25-
V = x.domain
26-
return DiagonalTensorMap(randn(eltype(x), reduceddim(V)), V)
27-
end
28-
ChainRulesTestUtils.rand_tangent(::AbstractRNG, ::VectorSpace) = NoTangent()
29-
function ChainRulesTestUtils.test_approx(
30-
actual::AbstractTensorMap, expected::AbstractTensorMap, msg = ""; kwargs...
31-
)
32-
for (c, b) in blocks(actual)
33-
ChainRulesTestUtils.@test_msg msg isapprox(b, block(expected, c); kwargs...)
34-
end
35-
return nothing
36-
end
37-
38-
# Float32 and finite differences don't mix well
39-
precision(::Type{<:Union{Float32, Complex{Float32}}}) = 1.0e-2
40-
precision(::Type{<:Union{Float64, Complex{Float64}}}) = 1.0e-5
41-
42-
function randindextuple(N::Int, k::Int = rand(0:N))
43-
@assert 0 k N
44-
_p = randperm(N)
45-
return (tuple(_p[1:k]...), tuple(_p[(k + 1):end]...))
46-
end
47-
48-
function test_ad_rrule(f, args...; check_inferred = false, kwargs...)
49-
test_rrule(
50-
Zygote.ZygoteRuleConfig(), f, args...;
51-
rrule_f = rrule_via_ad, check_inferred, kwargs...
52-
)
53-
return nothing
54-
end
5513

5614
# Tests
5715
# -----
@@ -108,8 +66,8 @@ for V in spacelist
10866

10967
symmetricbraiding &&
11068
@timedtestset "TensorOperations with scalartype $T" for T in eltypes
111-
atol = precision(T)
112-
rtol = precision(T)
69+
atol = default_tol(T)
70+
rtol = default_tol(T)
11371

11472
@timedtestset "tensortrace!" begin
11573
for _ in 1:5

test/factorizations/eig.jl

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,7 @@ using LinearAlgebra: LinearAlgebra
44
using MatrixAlgebraKit: diagview
55

66

7-
spacelist = if fast_tests
8-
(Vtr, Vℤ₃, VSU₂)
9-
elseif get(ENV, "CI", "false") == "true"
10-
println("Detected running on CI")
11-
if Sys.iswindows()
12-
(Vtr, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VIB_diag)
13-
elseif Sys.isapple()
14-
(Vtr, Vℤ₃, VfU₁, VfSU₂, VIB_M)
15-
else
16-
(Vtr, VU₁, VCU₁, VSU₂, VfSU₂, VIB_diag, VIB_M)
17-
end
18-
else
19-
(Vtr, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂, VIB_diag, VIB_M)
20-
end
7+
spacelist = factorization_spacelist(fast_tests)
218

229
eltypes = (Float32, ComplexF64)
2310

test/factorizations/ortho.jl

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,7 @@ using LinearAlgebra: LinearAlgebra
44
using MatrixAlgebraKit: diagview
55

66

7-
spacelist = if fast_tests
8-
(Vtr, Vℤ₃, VSU₂)
9-
elseif get(ENV, "CI", "false") == "true"
10-
println("Detected running on CI")
11-
if Sys.iswindows()
12-
(Vtr, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VIB_diag)
13-
elseif Sys.isapple()
14-
(Vtr, Vℤ₃, VfU₁, VfSU₂, VIB_M)
15-
else
16-
(Vtr, VU₁, VCU₁, VSU₂, VfSU₂, VIB_diag, VIB_M)
17-
end
18-
else
19-
(Vtr, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂, VIB_diag, VIB_M)
20-
end
7+
spacelist = factorization_spacelist(fast_tests)
218

229
eltypes = (Float32, ComplexF64)
2310

test/factorizations/projections.jl

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,7 @@ using LinearAlgebra: LinearAlgebra
44
using MatrixAlgebraKit: diagview
55

66

7-
spacelist = if fast_tests
8-
(Vtr, Vℤ₃, VSU₂)
9-
elseif get(ENV, "CI", "false") == "true"
10-
println("Detected running on CI")
11-
if Sys.iswindows()
12-
(Vtr, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VIB_diag)
13-
elseif Sys.isapple()
14-
(Vtr, Vℤ₃, VfU₁, VfSU₂, VIB_M)
15-
else
16-
(Vtr, VU₁, VCU₁, VSU₂, VfSU₂, VIB_diag, VIB_M)
17-
end
18-
else
19-
(Vtr, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂, VIB_diag, VIB_M)
20-
end
7+
spacelist = factorization_spacelist(fast_tests)
218

229
eltypes = (Float32, ComplexF64)
2310

test/factorizations/svd.jl

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,7 @@ using LinearAlgebra: LinearAlgebra
44
using MatrixAlgebraKit: diagview
55

66

7-
spacelist = if fast_tests
8-
(Vtr, Vℤ₃, VSU₂)
9-
elseif get(ENV, "CI", "false") == "true"
10-
println("Detected running on CI")
11-
if Sys.iswindows()
12-
(Vtr, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VIB_diag)
13-
elseif Sys.isapple()
14-
(Vtr, Vℤ₃, VfU₁, VfSU₂, VIB_M)
15-
else
16-
(Vtr, VU₁, VCU₁, VSU₂, VfSU₂, VIB_diag, VIB_M)
17-
end
18-
else
19-
(Vtr, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂, VIB_diag, VIB_M)
20-
end
7+
spacelist = factorization_spacelist(fast_tests)
218

229
eltypes = (Float32, ComplexF64)
2310

0 commit comments

Comments
 (0)