Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/ADTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ export AutoChainRules,
export AutoSparse
@public dense_ad

# DI Automatic Differentiation
export AutoDI
@public inner_ad

# Sparsity detection
export AbstractSparsityDetector
export jacobian_sparsity, hessian_sparsity
Expand Down
50 changes: 50 additions & 0 deletions src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -586,3 +586,53 @@ NoAutoDiffSelectedError() = NoAutoDiffSelectedError("Automatic differentiation c
function mode(::NoAutoDiff)
throw(NoAutoDiffSelectedError())
end

"""
AutoDI{I<:AbstractADType}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we call it AutoDifferentiationInterface? I know it is longer but it is more explicit, and coherent with other types in here. People can always define a shortcut if needed


Wraps an AD type to signify that the DifferentiationInterface wrapper should be used instead of calling the backend directly.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about this formulation, because it seems to suggest that anything not wrapped in AutoDI will not use DI. However we agreed that the current behavior of packages who call, say, AutoEnzyme through DI should be preserved. Maybe add a word of caution specifying that "using an ADType directly instead of wrapping it inside AutoDI does not forbid the use of DI"?


This allows packages to distinguish between an intention to directly call a corresponding AD tool vs. the DI wrapper for said tool, enabling the ability to use, test, and validate both approaches.

# Fields

- `inner_ad::I`: the underlying AD package, subtyping [`AbstractADType`](@ref)

# Constructors

AutoDI(inner_ad)

# Example

```jldoctest
julia> using ADTypes

julia> ad = AutoDI(AutoForwardDiff())
AutoDI(AutoForwardDiff())

julia> inner_ad(ad)
AutoForwardDiff()
```
"""
struct AutoDI{I <: AbstractADType} <: AbstractADType
inner_ad::I
end

function Base.show(io::IO, backend::AutoDI)
print(io, AutoDI, "(", repr(backend.inner_ad, context = io), ")")
end

"""
inner_ad(ad::AutoDI)::AbstractADType
inner_ad(ad::AbstractADType)::AbstractADType

Return the underlying AD package for a DI AD choice, acts as the identity on a non-DI AD choice.

# See also

- [`AutoDI`](@ref)
"""
inner_ad(ad::AutoDI) = ad.inner_ad
inner_ad(ad::AbstractADType) = ad

mode(di_ad::AutoDI) = mode(inner_ad(di_ad))
65 changes: 65 additions & 0 deletions test/di.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
@testset "AutoDI" begin
@testset "Subtyping and wrapping $ad_name" for (ad_name, ad) in [
("AutoForwardDiff", AutoForwardDiff()),
("AutoZygote", AutoZygote()),
("AutoEnzyme", AutoEnzyme()),
("AutoReverseDiff", AutoReverseDiff()),
("AutoChainRules", AutoChainRules(; ruleconfig = ForwardOrReverseRuleConfig())),
]
di_ad = AutoDI(ad)
@test di_ad isa AbstractADType
@test di_ad isa AutoDI

# Test mode propagation
if mode(ad) isa ForwardMode
@test mode(di_ad) isa ForwardMode
elseif mode(ad) isa ForwardOrReverseMode
@test mode(di_ad) isa ForwardOrReverseMode
elseif mode(ad) isa ReverseMode
@test mode(di_ad) isa ReverseMode
elseif mode(ad) isa SymbolicMode
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is never hit

@test mode(di_ad) isa SymbolicMode
end

# Test inner_ad accessor
@test inner_ad(ad) == ad
@test inner_ad(di_ad) == ad
end

@testset "All AD backends" begin
for ad in every_ad()
di_ad = AutoDI(ad)
@test di_ad isa AbstractADType
@test inner_ad(di_ad) == ad
@test mode(di_ad) == mode(ad)
end
end

@testset "Nested wrapping" begin
# Test that we can wrap AutoDI with AutoSparse and vice versa
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we want to support arbitrary wrappings like that. Shouldn't there be a proper way to wrap and an improper one?

ad = AutoForwardDiff()
di_ad = AutoDI(ad)
sparse_di_ad = AutoSparse(di_ad)

@test sparse_di_ad isa AutoSparse
@test dense_ad(sparse_di_ad) isa AutoDI
@test inner_ad(dense_ad(sparse_di_ad)) == ad

# Test AutoDI wrapping AutoSparse
sparse_ad = AutoSparse(ad)
di_sparse_ad = AutoDI(sparse_ad)

@test di_sparse_ad isa AutoDI
@test inner_ad(di_sparse_ad) isa AutoSparse
@test dense_ad(inner_ad(di_sparse_ad)) == ad
end

@testset "Display" begin
ad = AutoForwardDiff(chunksize = 5)
di_ad = AutoDI(ad)

str = sprint(show, di_ad)
@test occursin("AutoDI", str)
@test occursin("AutoForwardDiff", str)
end
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using ADTypes: AbstractADType,
ReverseMode,
SymbolicMode
using ADTypes: dense_ad,
inner_ad,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot fix the merge conflict

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in commit 540c6d5. The merge conflict was due to formatting differences - the main branch switched from JuliaFormatter to Runic.jl formatting. Updated all files to match the new Runic style (4-space indentation, removed explicit returns, adjusted spacing).

NoSparsityDetector,
KnownJacobianSparsityDetector,
KnownHessianSparsityDetector,
Expand Down Expand Up @@ -101,6 +102,9 @@ end
@testset "Sparse" begin
include("sparse.jl")
end
@testset "DI" begin
include("di.jl")
end
@testset "Symbols" begin
include("symbols.jl")
end
Expand Down
Loading