-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathutils.jl
More file actions
28 lines (24 loc) · 915 Bytes
/
utils.jl
File metadata and controls
28 lines (24 loc) · 915 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
"""
    get_config(backend)

Return the `Config` to use for a Mooncake backend.

When the backend carries no configuration (its type parameter is `Nothing`), a
default-constructed `Config()` is returned. Otherwise the backend's own `config` field is
returned as-is.
"""
function get_config(::AnyAutoMooncake{Nothing})
    # No user-supplied configuration: fall back to a default-constructed Config.
    return Config()
end

function get_config(backend::AnyAutoMooncake{<:Config})
    return backend.config
end
"""
    zero_tangent_unwrap(c::DI.Context)

Unwrap the value stored in context `c` with `DI.unwrap`, then return its `zero_tangent`.
"""
@inline function zero_tangent_unwrap(c::DI.Context)
    inner = DI.unwrap(c)
    return zero_tangent(inner)
end
"""
    first_unwrap(c, dc)

Return the 2-tuple `(DI.unwrap(c), dc)`. Only the first argument is unwrapped; `dc` is
passed through unchanged.
"""
@inline function first_unwrap(c, dc)
    unwrapped = DI.unwrap(c)
    return (unwrapped, dc)
end
"""
    call_and_return(f!, y, x, contexts...)

Call `f!(y, x, contexts...)` for its effect on `y`, discard whatever `f!` itself returns,
and return `y`.

The `f!::F ... where {F}` signature forces specialization on the type of the function
argument.
"""
call_and_return(f!::F, y, x, contexts...) where {F} = (f!(y, x, contexts...); y)
"""
    adaptive_tangent_to_primal!!(primal, tangent)

Convert `tangent` back into a primal-like value, using `primal` as the template passed to
Mooncake. Which Mooncake conversion is used depends on `new_friendly_tangents()`:
`tangent_to_friendly!!` when it is true, `tangent_to_primal!!` otherwise.

`@static` evaluates the condition when this method is expanded. Only the selected call
remains in the compiled body. NOTE(review): presumably `new_friendly_tangents()` reports
whether the loaded Mooncake provides `tangent_to_friendly!!`; confirm at its definition.
"""
function adaptive_tangent_to_primal!!(primal, tangent)
    @static if !new_friendly_tangents()
        return Mooncake.tangent_to_primal!!(primal, tangent)
    else
        # TODO: optimize performance by allocating cache during prep
        return Mooncake.tangent_to_friendly!!(primal, tangent)
    end
end
"""
    zero_tangent_or_primal(x, backend::AnyAutoMooncake)

Return a zero value associated with `x`. The form depends on the backend's
`friendly_tangents` configuration flag:

- `false`: the raw `zero_tangent(x)`.
- `true`: `zero_tangent(x)` converted via `adaptive_tangent_to_primal!!` into
  `_copy_output(x)`. This is intended as a safer stand-in for `zero(x)`.

The return type therefore depends on a runtime configuration value, not only on the
argument types.
"""
function zero_tangent_or_primal(x, backend::AnyAutoMooncake)
    use_friendly = get_config(backend).friendly_tangents
    use_friendly || return zero_tangent(x)
    # "zero(x) but safer": write a zero tangent into a fresh output built from x.
    # NOTE(review): assumes _copy_output(x) does not alias x, since the `!!` conversion
    # may reuse its first argument; confirm at _copy_output's definition.
    return adaptive_tangent_to_primal!!(_copy_output(x), zero_tangent(x))
end