-
Notifications
You must be signed in to change notification settings - Fork 1
/
quadrature.jl
63 lines (56 loc) · 2.11 KB
/
quadrature.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""
    batch_quadrature(
        fs::AbstractVector,
        ms::AbstractVector{<:Real},
        σs::AbstractVector{<:Real},
        num_points::Integer,
    )

Approximate the integrals
```julia
∫ fs[n](x) pdf(Normal(ms[n], σs[n]), x) dx
```
for all `n` in `eachindex(fs)` using Gauss-Hermite quadrature with `num_points` nodes.

# Arguments
- `fs`: the integrands; each element is called with a single argument.
- `ms`: the mean of the Gaussian paired with each integrand.
- `σs`: the standard deviation of the Gaussian paired with each integrand.
- `num_points`: the number of quadrature nodes, forwarded to `gausshermite`.

# Returns
The result of `map`ping over `fs`, `ms` and `σs`: one approximate integral per entry.

# Throws
- `ErrorException` if `ms` or `σs` does not have the same length as `fs`.
"""
function batch_quadrature(
    fs::AbstractVector,
    ms::AbstractVector{<:Real},
    σs::AbstractVector{<:Real},
    num_points::Integer,
)
    # Every integrand needs exactly one mean and one standard deviation. `error` already
    # throws an `ErrorException`, so no enclosing `throw` is required.
    length(fs) == length(ms) || error("length(fs) != length(ms)")
    length(fs) == length(σs) || error("length(fs) != length(σs)")

    # The nodes and weights do not depend on `n`, so build them once and share them
    # between all of the integrals.
    xs, ws = gausshermite(num_points)

    # Evaluate one quadrature rule per (integrand, mean, standard deviation) triple.
    return map(fs, ms, σs) do f, m, σ
        _gauss_hermite_quadrature(f, m, σ, xs, ws)
    end
end
# Tell Zygote not to differentiate through `gausshermite`: the quadrature nodes and
# weights are treated as constants, so no gradient is propagated into their construction.
Zygote.@nograd gausshermite
# Internal helper: Gauss-Hermite estimate of a Gaussian expectation of `f`, given
# precomputed nodes `xs` and weights `ws`. The Zygote pullback defined for this function
# deliberately returns `nothing` for `xs` and `ws`, on the assumption that their gradients
# are never needed. That is potentially not what you want in general.
function _gauss_hermite_quadrature(f, m::Real, σ::Real, xs, ws)
    # Change of variables taking a standard Gauss-Hermite node to the scale of the
    # Gaussian with mean `m` and standard deviation `σ`.
    shift(x) = m + sqrt(2) * σ * x
    # Weighted integrand evaluation at the j-th node.
    weighted(j) = ws[j] * f(shift(xs[j]))
    # Seed the sum with the first node and fold the remaining ones in from left to right.
    total = foldl((acc, j) -> acc + weighted(j), 2:length(xs); init=weighted(1))
    # The 1/√π factor normalises the Gauss-Hermite weights.
    return total / sqrt(π)
end
# Custom Zygote pullback for `_gauss_hermite_quadrature`. Cotangents are propagated to
# `f`, `m` and `σ` only; `nothing` is returned for the nodes `xs` and the weights `ws`.
function Zygote._pullback(
    ctx::Zygote.AContext, ::typeof(_gauss_hermite_quadrature), f, m::Real, σ::Real, xs, ws,
)
    # Primal value, computed with the plain (non-differentiated) implementation.
    primal = _gauss_hermite_quadrature(f, m, σ, xs, ws)

    function _gauss_hermite_quadrature_pullback(Δ::Real)
        # A single weighted integrand evaluation; the quadrature is a sum of these.
        term(f, x, w, m, σ) = w * f(m + sqrt(2) * σ * x)

        # Every term receives the same upstream cotangent, scaled by the 1/√π factor
        # that the primal applies to the sum.
        seed = Δ / sqrt(π)

        # Cotangents of the j-th term with respect to (f, m, σ). The entries for the
        # term function itself, the node and the weight are discarded.
        function node_cotangents(j)
            _, back = Zygote._pullback(ctx, term, f, xs[j], ws[j], m, σ)
            _, f̄, _, _, m̄, σ̄ = back(seed)
            return f̄, m̄, σ̄
        end

        # Start from the first node, then accumulate the remaining nodes in order.
        Δf, Δm, Δσ = node_cotangents(1)
        for j in 2:length(xs)
            f̄, m̄, σ̄ = node_cotangents(j)
            Δf = Zygote.accum(Δf, f̄)
            Δm = Zygote.accum(Δm, m̄)
            Δσ = Zygote.accum(Δσ, σ̄)
        end

        # One slot per argument: the function itself, f, m, σ, xs, ws.
        return nothing, Δf, Δm, Δσ, nothing, nothing
    end

    return primal, _gauss_hermite_quadrature_pullback
end