Skip to content
Draft
32 changes: 32 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,42 @@ jobs:
- tensors
- other
- mooncake
- enzyme/factorizations
- enzyme/linalg
- enzyme/tensoroperations
- enzyme/vectorinterface
- enzyme/indexmanipulations/add
- enzyme/indexmanipulations/flip_twist_unit
- chainrules
os:
- ubuntu-latest
- macOS-latest
- windows-latest
exclude:
- os: macos-latest
group: enzyme/factorizations
- os: macos-latest
group: enzyme/linalg
- os: macos-latest
group: enzyme/tensoroperations
- os: macos-latest
group: enzyme/vectorinterface
- os: macos-latest
group: enzyme/indexmanipulations/add
- os: macos-latest
group: enzyme/indexmanipulations/flip_twist_unit
- os: windows-latest
group: enzyme/factorizations
- os: windows-latest
group: enzyme/linalg
- os: windows-latest
group: enzyme/tensoroperations
- os: windows-latest
group: enzyme/vectorinterface
- os: windows-latest
group: enzyme/indexmanipulations/add
- os: windows-latest
group: enzyme/indexmanipulations/flip_twist_unit
uses: "QuantumKitHub/QuantumKitHubActions/.github/workflows/Tests.yml@main"
with:
group: "${{ matrix.group }}"
Expand All @@ -57,6 +88,7 @@ jobs:
- tensors
- other
- mooncake
- enzyme
- chainrules
os:
- ubuntu-latest
Expand Down
12 changes: 10 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
Expand All @@ -30,6 +32,8 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
TensorKitAdaptExt = "Adapt"
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
TensorKitChainRulesCoreExt = "ChainRulesCore"
TensorKitEnzymeExt = "Enzyme"
TensorKitEnzymeTestUtilsExt = "EnzymeTestUtils"
TensorKitFiniteDifferencesExt = "FiniteDifferences"
TensorKitMooncakeExt = "Mooncake"

Expand All @@ -43,6 +47,8 @@ ChainRulesCore = "1"
ChainRulesTestUtils = "1"
Combinatorics = "1"
Dictionaries = "0.4"
Enzyme = "0.13.134"
EnzymeTestUtils = "0.2.5"
FiniteDifferences = "0.12"
GPUArrays = "11.3.1"
JET = "0.9, 0.10, 0.11"
Expand All @@ -57,7 +63,7 @@ SafeTestsets = "0.1"
ScopedValues = "1.3.0"
Strided = "2"
TensorKitSectors = "0.3.6"
TensorOperations = "5.1"
TensorOperations = "5.5.2"
Test = "1"
TestExtras = "0.2,0.3"
TupleTools = "1.5"
Expand All @@ -75,6 +81,8 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Expand All @@ -88,4 +96,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[targets]
test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "JET", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake"]
test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "JET", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake", "Enzyme", "EnzymeTestUtils"]
21 changes: 21 additions & 0 deletions ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module TensorKitEnzymeExt

using Enzyme
using TensorKit
import TensorKit as TK
using VectorInterface
using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize
import TensorOperations as TO
using MatrixAlgebraKit
using TupleTools
using Random: AbstractRNG

include("utility.jl")
include("linalg.jl")
include("vectorinterface.jl")
include("tensoroperations.jl")
include("factorizations.jl")
include("indexmanipulations.jl")
#include("planaroperations.jl")

end
203 changes: 203 additions & 0 deletions ext/TensorKitEnzymeExt/factorizations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(MatrixAlgebraKit.copy_input)},
::Type{RT},
cache,
f::Annotation,
A::Annotation{<:AbstractTensorMap}
) where {RT}
copy_shadow = cache
if !isa(A, Const) && !isnothing(copy_shadow)
add!(A.dval, copy_shadow)
end
return (nothing, nothing)
end

# need these due to Enzyme choking on blocks
for f in (:project_hermitian, :project_antihermitian)
f! = Symbol(f, :!)
@eval begin
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof($f!)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
arg::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
$f!(A.val, arg.val, alg.val)
primal = EnzymeRules.needs_primal(config) ? arg.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? arg.dval : nothing
cache = nothing
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof($f!)},
::Type{RT},
cache,
A::Annotation{<:AbstractTensorMap},
arg::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
if !isa(A, Const)
$f!(arg.dval, arg.dval, alg.val)
if A.dval !== arg.dval
A.dval .+= arg.dval
make_zero!(arg.dval)
end
end
return (nothing, nothing, nothing)
end
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof($f)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
ret = $f(A.val, alg.val)
dret = make_zero(ret)
primal = EnzymeRules.needs_primal(config) ? ret : nothing
shadow = EnzymeRules.needs_shadow(config) ? dret : nothing
cache = dret
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof($f)},
::Type{RT},
cache,
A::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
dret = cache
if !isa(A, Const)
$f!(dret, dret, alg.val)
add!(A.dval, dret)
end
make_zero!(dret)
return (nothing, nothing)
end
end
end

for (f, pb) in (
(:eig_full, :(MatrixAlgebraKit.eig_pullback!)),
(:eigh_full, :(MatrixAlgebraKit.eigh_pullback!)),
(:lq_compact, :(MatrixAlgebraKit.lq_pullback!)),
(:qr_compact, :(MatrixAlgebraKit.qr_pullback!)),
)
@eval begin
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof($f)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
ret = $f(A.val, alg.val)
primal = EnzymeRules.needs_primal(config) ? ret : nothing
shadow = EnzymeRules.needs_shadow(config) ? make_zero(ret) : nothing
cache = (ret, shadow)
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof($f)},
::Type{RT},
cache,
A::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
!isa(A, Const) && $pb(A.dval, A.val, cache...)
return (nothing, nothing)
end
end
end

for f in (:svd_compact, :svd_full)
@eval begin
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof($f)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
USVᴴ = $f(A.val, alg.val)
primal = EnzymeRules.needs_primal(config) ? USVᴴ : nothing
shadow = EnzymeRules.needs_shadow(config) ? make_zero(USVᴴ) : nothing
cache = (USVᴴ, shadow)
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof($f)},
::Type{RT},
cache,
A::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
!isa(A, Const) && MatrixAlgebraKit.svd_pullback!(A.dval, A.val, cache...)
return (nothing, nothing)
end
end

# mutating version is not guaranteed to actually mutate
# so we can simply use the non-mutating version instead
f! = Symbol(f, :!)
#=@eval begin
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof($f!)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
USVᴴ::Annotation,
alg::Const,
) where {RT}
EnzymeRules.augmented_primal(func, RT, A, alg)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof($f!)},
::Type{RT},
cache,
A::Annotation{<:AbstractTensorMap},
USVᴴ::Annotation,
alg::Const,
) where {RT}
EnzymeRules.reverse(func, RT, A, alg)
end
end=# #hmmmm
end

# TODO
#=
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(svd_trunc)},
::Type{RT},
A::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}

USVᴴ = svd_compact(A.val, alg.val.alg)
USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.val.trunc)
ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind)
dUSVᴴtrunc = make_zero(USVᴴtrunc)
cache = (USVᴴtrunc, dUSVᴴtrunc)
return EnzymeRules.AugmentedReturn(USVᴴtrunc, dUSVᴴtrunc, cache)
end
function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(svd_trunc)},
::Type{RT},
cache,
A::Annotation{<:AbstractTensorMap},
alg::Const,
) where {RT}
USVᴴ, dUSVᴴ = cache
MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴ, dUSVᴴ)
return (nothing, nothing)
end=#
Loading
Loading