@@ -11,113 +11,6 @@ using Zygote
1111using MatrixAlgebraKit
1212using MatrixAlgebraKit: LAPACK_HouseholderQR, LAPACK_HouseholderLQ, diagview
1313
14- # Test utility
15- # -------------
16- function ChainRulesTestUtils. rand_tangent (rng:: AbstractRNG , x:: AbstractTensorMap )
17- return randn! (similar (x))
18- end
19- function ChainRulesTestUtils. rand_tangent (rng:: AbstractRNG , x:: DiagonalTensorMap )
20- V = x. domain
21- return DiagonalTensorMap (randn (eltype (x), reduceddim (V)), V)
22- end
23- ChainRulesTestUtils. rand_tangent (:: AbstractRNG , :: VectorSpace ) = NoTangent ()
24- function ChainRulesTestUtils. test_approx (
25- actual:: AbstractTensorMap , expected:: AbstractTensorMap , msg = " " ; kwargs...
26- )
27- for (c, b) in blocks (actual)
28- ChainRulesTestUtils. @test_msg msg isapprox (b, block (expected, c); kwargs... )
29- end
30- return nothing
31- end
32-
33- # Float32 and finite differences don't mix well
34- precision (:: Type{<:Union{Float32, Complex{Float32}}} ) = 1.0e-2
35- precision (:: Type{<:Union{Float64, Complex{Float64}}} ) = 1.0e-5
36-
37- function test_ad_rrule (f, args... ; check_inferred = false , kwargs... )
38- test_rrule (
39- Zygote. ZygoteRuleConfig (), f, args... ;
40- rrule_f = rrule_via_ad, check_inferred, kwargs...
41- )
42- return nothing
43- end
44-
45- # project_hermitian is non-differentiable for now
46- _project_hermitian (x) = (x + x' ) / 2
47-
48- # Gauge fixing tangents
49- # ---------------------
50- function remove_qrgauge_dependence! (ΔQ, t, Q)
51- for (c, b) in blocks (ΔQ)
52- m, n = size (block (t, c))
53- minmn = min (m, n)
54- Qc = block (Q, c)
55- Q1 = view (Qc, 1 : m, 1 : minmn)
56- ΔQ2 = view (b, :, (minmn + 1 ): m)
57- mul! (ΔQ2, Q1, Q1' * ΔQ2)
58- end
59- return ΔQ
60- end
61- function remove_lqgauge_dependence! (ΔQ, t, Q)
62- for (c, b) in blocks (ΔQ)
63- m, n = size (block (t, c))
64- minmn = min (m, n)
65- Qc = block (Q, c)
66- Q1 = view (Qc, 1 : minmn, 1 : n)
67- ΔQ2 = view (b, (minmn + 1 ): n, :)
68- mul! (ΔQ2, ΔQ2 * Q1' , Q1)
69- end
70- return ΔQ
71- end
72- function remove_eiggauge_dependence! (
73- ΔV, D, V; degeneracy_atol = MatrixAlgebraKit. default_pullback_degeneracy_atol (D)
74- )
75- gaugepart = V' * ΔV
76- for (c, b) in blocks (gaugepart)
77- Dc = diagview (block (D, c))
78- # for some reason this fails only on tests, and I cannot reproduce it in an
79- # interactive session.
80- # b[abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol] .= 0
81- for j in axes (b, 2 ), i in axes (b, 1 )
82- abs (Dc[i] - Dc[j]) >= degeneracy_atol && (b[i, j] = 0 )
83- end
84- end
85- mul! (ΔV, V / (V' * V), gaugepart, - 1 , 1 )
86- return ΔV
87- end
88- function remove_eighgauge_dependence! (
89- ΔV, D, V; degeneracy_atol = MatrixAlgebraKit. default_pullback_degeneracy_atol (D)
90- )
91- gaugepart = project_antihermitian! (V' * ΔV)
92- for (c, b) in blocks (gaugepart)
93- Dc = diagview (block (D, c))
94- # for some reason this fails only on tests, and I cannot reproduce it in an
95- # interactive session.
96- # b[abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol] .= 0
97- for j in axes (b, 2 ), i in axes (b, 1 )
98- abs (Dc[i] - Dc[j]) >= degeneracy_atol && (b[i, j] = 0 )
99- end
100- end
101- mul! (ΔV, V, gaugepart, - 1 , 1 )
102- return ΔV
103- end
104- function remove_svdgauge_dependence! (
105- ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = MatrixAlgebraKit. default_pullback_degeneracy_atol (S)
106- )
107- gaugepart = project_antihermitian! (U' * ΔU + Vᴴ * ΔVᴴ' )
108- for (c, b) in blocks (gaugepart)
109- Sd = diagview (block (S, c))
110- # for some reason this fails only on tests, and I cannot reproduce it in an
111- # interactive session.
112- # b[abs.(transpose(diagview(Sc)) .- diagview(Sc)) .>= degeneracy_atol] .= 0
113- for j in axes (b, 2 ), i in axes (b, 1 )
114- abs (Sd[i] - Sd[j]) >= degeneracy_atol && (b[i, j] = 0 )
115- end
116- end
117- mul! (ΔU, U, gaugepart, - 1 , 1 )
118- return ΔU, ΔVᴴ
119- end
120-
12114# Tests
12215# -----
12316
@@ -181,7 +74,7 @@ for V in spacelist
18174 DiagonalTensorMap (randn (T, reduceddim (V[1 ])), V[1 ]),
18275 )
18376
184- atol = rtol = precision (T) * dim (space (t))
77+ atol = rtol = default_tol (T) * dim (space (t))
18578 fkwargs = (; positive = true ) # make FiniteDifferences happy
18679
18780 test_ad_rrule (qr_compact, t; fkwargs, atol, rtol)
@@ -218,7 +111,7 @@ for V in spacelist
218111 DiagonalTensorMap (randn (T, reduceddim (V[1 ])), V[1 ]),
219112 )
220113
221- atol = rtol = precision (T) * dim (space (t))
114+ atol = rtol = default_tol (T) * dim (space (t))
222115 fkwargs = (; positive = true ) # make FiniteDifferences happy
223116
224117 test_ad_rrule (lq_compact, t; fkwargs, atol, rtol)
@@ -254,7 +147,7 @@ for V in spacelist
254147 # DiagonalTensorMap(rand(T, reduceddim(V[1])), V[1]), # broken in MatrixAlgebraKit
255148 )
256149
257- atol = rtol = precision (T) * dim (space (t))
150+ atol = rtol = default_tol (T) * dim (space (t))
258151
259152 d, v = eig_full (t)
260153 Δv = rand_tangent (v)
@@ -290,7 +183,7 @@ for V in spacelist
290183 # TODO : fix diagonaltensormap case
291184 # DiagonalTensorMap(rand(T, reduceddim(V1)), V1))
292185
293- atol = rtol = degeneracy_atol = precision (T) * dim (space (t))
186+ atol = rtol = degeneracy_atol = default_tol (T) * dim (space (t))
294187 USVᴴ = svd_compact (t)
295188 ΔU, ΔS, ΔVᴴ = rand_tangent .(USVᴴ)
296189 ΔS2 = randn! (similar (ΔS, space (ΔS)))
0 commit comments