Skip to content

Commit 4561a1e

Browse files
committed
Incremental work on Enzyme support
1 parent cdf66fb commit 4561a1e

38 files changed

+2730
-3
lines changed

.github/workflows/CI.yml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,37 @@ jobs:
3131
- tensors
3232
- other
3333
- mooncake
34+
- enzyme/factorizations
35+
- enzyme/linalg
36+
- enzyme/tensoroperations
37+
- enzyme/vectorinterface
38+
- enzyme/indexmanipulations
3439
- chainrules
3540
os:
3641
- ubuntu-latest
3742
- macOS-latest
3843
- windows-latest
44+
exclude:
45+
- os: macos-latest
46+
group: enzyme/factorizations
47+
- os: macos-latest
48+
group: enzyme/linalg
49+
- os: macos-latest
50+
group: enzyme/tensoroperations
51+
- os: macos-latest
52+
group: enzyme/vectorinterface
53+
- os: macos-latest
54+
group: enzyme/indexmanipulations
55+
- os: windows-latest
56+
group: enzyme/factorizations
57+
- os: windows-latest
58+
group: enzyme/linalg
59+
- os: windows-latest
60+
group: enzyme/tensoroperations
61+
- os: windows-latest
62+
group: enzyme/vectorinterface
63+
- os: windows-latest
64+
group: enzyme/indexmanipulations
3965
uses: "QuantumKitHub/QuantumKitHubActions/.github/workflows/Tests.yml@main"
4066
with:
4167
group: "${{ matrix.group }}"
@@ -57,6 +83,7 @@ jobs:
5783
- tensors
5884
- other
5985
- mooncake
86+
- enzyme
6087
- chainrules
6188
os:
6289
- ubuntu-latest

Project.toml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
2222
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
2323
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2424
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
25+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
26+
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
2527
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2628
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2729
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
@@ -30,6 +32,8 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
3032
TensorKitAdaptExt = "Adapt"
3133
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
3234
TensorKitChainRulesCoreExt = "ChainRulesCore"
35+
TensorKitEnzymeExt = "Enzyme"
36+
TensorKitEnzymeTestUtilsExt = "EnzymeTestUtils"
3337
TensorKitFiniteDifferencesExt = "FiniteDifferences"
3438
TensorKitMooncakeExt = "Mooncake"
3539

@@ -43,6 +47,8 @@ ChainRulesCore = "1"
4347
ChainRulesTestUtils = "1"
4448
Combinatorics = "1"
4549
Dictionaries = "0.4"
50+
Enzyme = "0.13.134"
51+
EnzymeTestUtils = "0.2.5"
4652
FiniteDifferences = "0.12"
4753
GPUArrays = "11.3.1"
4854
JET = "0.9, 0.10, 0.11"
@@ -75,6 +81,8 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
7581
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
7682
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
7783
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
84+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
85+
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
7886
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
7987
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
8088
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
@@ -88,4 +96,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
8896
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
8997

9098
[targets]
91-
test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "JET", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake"]
99+
test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "JET", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake", "Enzyme", "EnzymeTestUtils"]
100+
101+
[sources]
102+
TensorOperations = {url = "https://github.com/quantumkithub/tensoroperations.jl", rev = "master"}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
module TensorKitEnzymeExt
2+
3+
using Enzyme
4+
using TensorKit
5+
import TensorKit as TK
6+
using VectorInterface
7+
using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize
8+
import TensorOperations as TO
9+
using MatrixAlgebraKit
10+
using TupleTools
11+
using Random: AbstractRNG
12+
13+
include("utility.jl")
14+
include("linalg.jl")
15+
include("vectorinterface.jl")
16+
include("tensoroperations.jl")
17+
include("factorizations.jl")
18+
include("indexmanipulations.jl")
19+
#include("planaroperations.jl")
20+
21+
end
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
function EnzymeRules.reverse(
2+
config::EnzymeRules.RevConfigWidth{1},
3+
func::Const{typeof(MatrixAlgebraKit.copy_input)},
4+
::Type{RT},
5+
cache,
6+
f::Annotation,
7+
A::Annotation{<:AbstractTensorMap}
8+
) where {RT}
9+
copy_shadow = cache
10+
if !isa(A, Const) && !isnothing(copy_shadow)
11+
add!(A.dval, copy_shadow)
12+
end
13+
return (nothing, nothing)
14+
end
15+
16+
for (f, pb) in (
17+
(:eig_full, :(MatrixAlgebraKit.eig_pullback!)),
18+
(:eigh_full, :(MatrixAlgebraKit.eigh_pullback!)),
19+
(:lq_compact, :(MatrixAlgebraKit.lq_pullback!)),
20+
(:qr_compact, :(MatrixAlgebraKit.qr_pullback!)),
21+
)
22+
@eval begin
23+
function EnzymeRules.augmented_primal(
24+
config::EnzymeRules.RevConfigWidth{1},
25+
func::Const{typeof($f)},
26+
::Type{RT},
27+
A::Annotation{<:AbstractTensorMap},
28+
alg::Const,
29+
) where {RT}
30+
ret = $f(A.val, alg.val)
31+
primal = EnzymeRules.needs_primal(config) ? ret : nothing
32+
shadow = EnzymeRules.needs_shadow(config) ? make_zero(ret) : nothing
33+
cache = (ret, shadow)
34+
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
35+
end
36+
function EnzymeRules.reverse(
37+
config::EnzymeRules.RevConfigWidth{1},
38+
func::Const{typeof($f)},
39+
::Type{RT},
40+
cache,
41+
A::Annotation{<:AbstractTensorMap},
42+
alg::Const,
43+
) where {RT}
44+
!isa(A, Const) && $pb(A.dval, A.val, cache...)
45+
return (nothing, nothing)
46+
end
47+
end
48+
end
49+
50+
for f in (:svd_compact, :svd_full)
51+
@eval begin
52+
function EnzymeRules.augmented_primal(
53+
config::EnzymeRules.RevConfigWidth{1},
54+
func::Const{typeof($f)},
55+
::Type{RT},
56+
A::Annotation{<:AbstractTensorMap},
57+
alg::Const,
58+
) where {RT}
59+
USVᴴ = $f(A.val, alg.val)
60+
primal = EnzymeRules.needs_primal(config) ? USVᴴ : nothing
61+
shadow = EnzymeRules.needs_shadow(config) ? make_zero(USVᴴ) : nothing
62+
cache = (USVᴴ, shadow)
63+
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
64+
end
65+
function EnzymeRules.reverse(
66+
config::EnzymeRules.RevConfigWidth{1},
67+
func::Const{typeof($f)},
68+
::Type{RT},
69+
cache,
70+
A::Annotation{<:AbstractTensorMap},
71+
alg::Const,
72+
) where {RT}
73+
!isa(A, Const) && MatrixAlgebraKit.svd_pullback!(A.dval, A.val, cache...)
74+
return (nothing, nothing)
75+
end
76+
end
77+
78+
# mutating version is not guaranteed to actually mutate
79+
# so we can simply use the non-mutating version instead
80+
f! = Symbol(f, :!)
81+
#=@eval begin
82+
function EnzymeRules.augmented_primal(
83+
config::EnzymeRules.RevConfigWidth{1},
84+
func::Const{typeof($f!)},
85+
::Type{RT},
86+
A::Annotation{<:AbstractTensorMap},
87+
USVᴴ::Annotation,
88+
alg::Const,
89+
) where {RT}
90+
EnzymeRules.augmented_primal(func, RT, A, alg)
91+
end
92+
function EnzymeRules.reverse(
93+
config::EnzymeRules.RevConfigWidth{1},
94+
func::Const{typeof($f!)},
95+
::Type{RT},
96+
cache,
97+
A::Annotation{<:AbstractTensorMap},
98+
USVᴴ::Annotation,
99+
alg::Const,
100+
) where {RT}
101+
EnzymeRules.reverse(func, RT, A, alg)
102+
end
103+
end=# #hmmmm
104+
end
105+
106+
# TODO
107+
#=
108+
function EnzymeRules.augmented_primal(
109+
config::EnzymeRules.RevConfigWidth{1},
110+
func::Const{typeof(svd_trunc)},
111+
::Type{RT},
112+
A::Annotation{<:AbstractTensorMap},
113+
alg::Const,
114+
) where {RT}
115+
116+
USVᴴ = svd_compact(A.val, alg.val.alg)
117+
USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.val.trunc)
118+
ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind)
119+
dUSVᴴtrunc = make_zero(USVᴴtrunc)
120+
cache = (USVᴴtrunc, dUSVᴴtrunc)
121+
return EnzymeRules.AugmentedReturn(USVᴴtrunc, dUSVᴴtrunc, cache)
122+
end
123+
function EnzymeRules.reverse(
124+
config::EnzymeRules.RevConfigWidth{1},
125+
func::Const{typeof(svd_trunc)},
126+
::Type{RT},
127+
cache,
128+
A::Annotation{<:AbstractTensorMap},
129+
alg::Const,
130+
) where {RT}
131+
USVᴴ, dUSVᴴ = cache
132+
MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴ, dUSVᴴ)
133+
return (nothing, nothing)
134+
end=#

0 commit comments

Comments
 (0)