Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ADTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ export AutoChainRules,
AutoFiniteDifferences,
AutoForwardDiff,
AutoGTPSA,
AutoHyperHessians,
AutoModelingToolkit,
AutoMooncake,
AutoMooncakeForward,
Expand Down
32 changes: 32 additions & 0 deletions src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,38 @@ struct AutoZygote <: AbstractADType end

mode(::AutoZygote) = ReverseMode()

"""
AutoHyperHessians{chunksize}

Struct used to select the [HyperHessians.jl](https://github.com/KristofferC/HyperHessians.jl) backend for automatic differentiation.

Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).

# Constructors

AutoHyperHessians(; chunksize=nothing)

# Type parameters

- `chunksize`: the preferred chunk size to evaluate several derivatives at once. If `nothing`, HyperHessians chooses automatically.
"""
struct AutoHyperHessians{chunksize} <: AbstractADType end

function AutoHyperHessians(; chunksize::Union{Nothing, Int} = nothing)
if chunksize isa Int
chunksize > 0 || throw(ArgumentError("chunksize must be positive, got $chunksize"))
end
return AutoHyperHessians{chunksize}()
end

mode(::AutoHyperHessians) = ForwardMode()

function Base.show(io::IO, ::AutoHyperHessians{chunksize}) where {chunksize}
print(io, AutoHyperHessians, "(")
chunksize !== nothing && print(io, "chunksize=", repr(chunksize; context = io))
return print(io, ")")
end

"""
NoAutoDiff

Expand Down
4 changes: 2 additions & 2 deletions src/symbols.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ Auto(package::Symbol, args...; kws...) = Auto(Val(package), args...; kws...)

for backend in (
:ChainRules, :Diffractor, :Enzyme, :Reactant, :FastDifferentiation,
:FiniteDiff, :FiniteDifferences, :ForwardDiff, :GTPSA, :Mooncake, :PolyesterForwardDiff,
:ReverseDiff, :Symbolics, :Tapir, :TaylorDiff, :Tracker, :Zygote,
:FiniteDiff, :FiniteDifferences, :ForwardDiff, :GTPSA, :HyperHessians, :Mooncake,
:PolyesterForwardDiff, :ReverseDiff, :Symbolics, :Tapir, :TaylorDiff, :Tracker, :Zygote,
)
@eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))(
args...; kws...
Expand Down
14 changes: 14 additions & 0 deletions test/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,20 @@ end
@test ad.descriptor == Val(:descriptor)
end

@testset "AutoHyperHessians" begin
ad = AutoHyperHessians()
@test ad isa AbstractADType
@test ad isa AutoHyperHessians{nothing}
@test mode(ad) isa ForwardMode

ad = AutoHyperHessians(; chunksize = 8)
@test ad isa AbstractADType
@test ad isa AutoHyperHessians{8}
@test mode(ad) isa ForwardMode

@test_throws ArgumentError AutoHyperHessians(; chunksize = -1)
end

@testset "AutoMooncake" begin
ad = AutoMooncake(; config = :config)
@test ad isa AbstractADType
Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ function every_ad()
AutoFiniteDifferences(; fdm = :fdm),
AutoForwardDiff(),
AutoGTPSA(),
AutoHyperHessians(),
AutoPolyesterForwardDiff(),
AutoReverseDiff(),
AutoSymbolics(),
Expand All @@ -72,6 +73,8 @@ function every_ad_with_options()
AutoForwardDiff(chunksize = 3, tag = :tag),
AutoGTPSA(),
AutoGTPSA(descriptor = Val(:descriptor)),
AutoHyperHessians(),
AutoHyperHessians(chunksize = 8),
AutoMooncake(; config = :config),
AutoMooncakeForward(; config = :config),
AutoPolyesterForwardDiff(),
Expand Down
1 change: 1 addition & 0 deletions test/symbols.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using Test
@test ADTypes.Auto(:FiniteDiff) isa AutoFiniteDiff
@test ADTypes.Auto(:FiniteDifferences, 1.0) isa AutoFiniteDifferences{Float64}
@test ADTypes.Auto(:ForwardDiff) isa AutoForwardDiff
@test ADTypes.Auto(:HyperHessians) isa AutoHyperHessians
@test ADTypes.Auto(:Mooncake) isa AutoMooncake
@test ADTypes.Auto(:PolyesterForwardDiff) isa AutoPolyesterForwardDiff
@test ADTypes.Auto(:ReverseDiff) isa AutoReverseDiff
Expand Down
Loading