|
| 1 | +import ChainRulesCore |
| 2 | +using ChainRulesCore: ProjectTo, NoTangent, @not_implemented |
| 3 | + |
| 4 | +function ChainRulesCore.rrule(c::ChebPoly{1}, x::Real) |
| 5 | + project_x = ProjectTo(x) |
| 6 | + y, ∇y = chebgradient(c, x) |
| 7 | + chebpoly_pullback(∂y) = @not_implemented("no rrule for changes in ChebPoly itself"), project_x(real(∇y' * ∂y)) |
| 8 | + y, chebpoly_pullback |
| 9 | +end |
| 10 | + |
| 11 | +function ChainRulesCore.rrule(c::ChebPoly, x::AbstractVector{<:Real}) |
| 12 | + project_x = ProjectTo(x) |
| 13 | + y, J = chebjacobian(c, x) |
| 14 | + chebpoly_pullback(Δy) = @not_implemented("no rrule for changes in ChebPoly itself"), project_x(vec(real(J' * Δy))) |
| 15 | + y, chebpoly_pullback |
| 16 | +end |
| 17 | + |
| 18 | +ChainRulesCore.frule((Δself, Δx), c::ChebPoly{1}, x::Real) = |
| 19 | + ChainRulesCore.frule((Δself, SVector{1}(Δx)), c, SVector{1}(x)) |
| 20 | + |
| 21 | +function ChainRulesCore.frule((Δself, Δx), c::ChebPoly, x::AbstractVector) |
| 22 | + y, J = chebjacobian(c, x) |
| 23 | + if Δself isa ChainRulesCore.AbstractZero # Δself == 0 |
| 24 | + Δy = J * Δx |
| 25 | + return y, y isa Number ? Δy[1] : Δy |
| 26 | + else # need derivatives with respect to changes in c |
| 27 | + # additional Δx from changes in bound: |
| 28 | + # --- recall x0 = @. (x - c.lb) * 2 / (c.ub - c.lb) - 1, |
| 29 | + # but note that J already includes 2 / (c.ub - c.lb) |
| 30 | + d2 = @. (x - c.lb) / (c.ub - c.lb) |
| 31 | + Δx′ = @. Δx + (d2 - 1) * Δself.lb - d2 * Δself.ub |
| 32 | + Δy = J * Δx′ |
| 33 | + |
| 34 | + # dependence on coefs is linear |
| 35 | + Δcoefs = typeof(c)(Δself.coefs, c.lb, c.ub) |
| 36 | + |
| 37 | + return y, (y isa Number ? Δy[1] : Δy) + Δcoefs(x) |
| 38 | + end |
| 39 | +end |
0 commit comments