Skip to content

Commit

Permalink
Make ChainRulesCore and DensityInterface weak dependencies (JuliaStat…
Browse files Browse the repository at this point in the history
…s#1686)

* Make ChainRulesCore and DensityInterface weak dependencies

* Fixes

* More fixes

* Another fix
  • Loading branch information
devmotion authored Mar 3, 2023
1 parent eecfd3c commit b9d063f
Show file tree
Hide file tree
Showing 16 changed files with 201 additions and 182 deletions.
12 changes: 11 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"

[extensions]
DistributionsChainRulesCoreExt = "ChainRulesCore"
DistributionsDensityInterfaceExt = "DensityInterface"

[compat]
ChainRulesCore = "1"
DensityInterface = "0.4"
Expand All @@ -32,7 +40,9 @@ julia = "1.3"

[extras]
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -43,4 +53,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["StableRNGs", "Calculus", "ChainRulesTestUtils", "Distributed", "FiniteDifferences", "ForwardDiff", "JSON", "StaticArrays", "Test", "OffsetArrays"]
test = ["StableRNGs", "Calculus", "ChainRulesCore", "ChainRulesTestUtils", "DensityInterface", "Distributed", "FiniteDifferences", "ForwardDiff", "JSON", "StaticArrays", "Test", "OffsetArrays"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module DistributionsChainRulesCoreExt

using Distributions
using Distributions: LinearAlgebra, SpecialFunctions, StatsFuns
import ChainRulesCore

include("eachvariate.jl")
include("utils.jl")

include("univariate/continuous/uniform.jl")
include("univariate/discrete/negativebinomial.jl")
include("univariate/discrete/poissonbinomial.jl")

include("multivariate/dirichlet.jl")

end # module
10 changes: 10 additions & 0 deletions ext/DistributionsChainRulesCoreExt/eachvariate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
function ChainRulesCore.rrule(::Type{Distributions.EachVariate{V}}, x::AbstractArray{<:Real}) where {V}
y = Distributions.EachVariate{V}(x)
size_x = size(x)
function EachVariate_pullback(Δ)
# TODO: Should we also handle `Tangent{<:EachVariate}`?
Δ_out = reshape(mapreduce(vec, vcat, ChainRulesCore.unthunk(Δ)), size_x)
return (ChainRulesCore.NoTangent(), Δ_out)
end
return y, EachVariate_pullback
end
57 changes: 57 additions & 0 deletions ext/DistributionsChainRulesCoreExt/multivariate/dirichlet.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
function ChainRulesCore.frule((_, Δalpha)::Tuple{Any,Any}, ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
d = DT(alpha; check_args=check_args)
∂alpha0 = sum(Δalpha)
digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
∂lmnB = sum(Broadcast.instantiate(Broadcast.broadcasted(Δalpha, alpha) do Δalphai, alphai
Δalphai * (SpecialFunctions.digamma(alphai) - digamma_alpha0)
end))
Δd = ChainRulesCore.Tangent{typeof(d)}(; alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB)
return d, Δd
end

function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
d = DT(alpha; check_args=check_args)
digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
function Dirichlet_pullback(_Δd)
Δd = ChainRulesCore.unthunk(_Δd)
Δalpha = Δd.alpha .+ Δd.alpha0 .+ Δd.lmnB .* (SpecialFunctions.digamma.(alpha) .- digamma_alpha0)
return ChainRulesCore.NoTangent(), Δalpha
end
return d, Dirichlet_pullback
end

function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(Distributions._logpdf), d::Dirichlet, x::AbstractVector{<:Real})
Ω = Distributions._logpdf(d, x)
∂alpha = sum(Broadcast.instantiate(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalphai, Δxi, alphai, xi
StatsFuns.xlogy(Δalphai, xi) + (alphai - 1) * Δxi / xi
end))
∂lmnB = -Δd.lmnB
ΔΩ = ∂alpha + ∂lmnB
if !isfinite(Ω)
ΔΩ = oftype(ΔΩ, NaN)
end
return Ω, ΔΩ
end

function ChainRulesCore.rrule(::typeof(Distributions._logpdf), d::T, x::AbstractVector{<:Real}) where {T<:Dirichlet}
Ω = Distributions._logpdf(d, x)
isfinite_Ω = isfinite(Ω)
alpha = d.alpha
function _logpdf_Dirichlet_pullback(_ΔΩ)
ΔΩ = ChainRulesCore.unthunk(_ΔΩ)
∂alpha = _logpdf_Dirichlet_∂alphai.(x, ΔΩ, isfinite_Ω)
∂lmnB = isfinite_Ω ? -float(ΔΩ) : oftype(float(ΔΩ), NaN)
Δd = ChainRulesCore.Tangent{T}(; alpha=∂alpha, lmnB=∂lmnB)
Δx = _logpdf_Dirichlet_Δxi.(ΔΩ, alpha, x, isfinite_Ω)
return ChainRulesCore.NoTangent(), Δd, Δx
end
return Ω, _logpdf_Dirichlet_pullback
end
function _logpdf_Dirichlet_∂alphai(xi, ΔΩi, isfinite::Bool)
∂alphai = StatsFuns.xlogy.(ΔΩi, xi)
return isfinite ? ∂alphai : oftype(∂alphai, NaN)
end
function _logpdf_Dirichlet_Δxi(ΔΩi, alphai, xi, isfinite::Bool)
Δxi = ΔΩi * (alphai - 1) / xi
return isfinite ? Δxi : oftype(Δxi, NaN)
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
function ChainRulesCore.frule((_, Δd, _), ::typeof(logpdf), d::Uniform, x::Real)
# Compute log probability
a, b = params(d)
insupport = a <= x <= b
diff = b - a
Ω = insupport ? -log(diff) : log(zero(diff))

# Compute tangent
Δdiff = Δd.a - Δd.b
ΔΩ = (insupport ? Δdiff : zero(Δdiff)) / diff

return Ω, ΔΩ
end

function ChainRulesCore.rrule(::typeof(logpdf), d::Uniform, x::Real)
# Compute log probability
a, b = params(d)
insupport = a <= x <= b
diff = b - a
Ω = insupport ? -log(diff) : log(zero(diff))

# Define pullback
function logpdf_Uniform_pullback(Δ)
Δa = Δ / diff
Δd = if insupport
ChainRulesCore.Tangent{typeof(d)}(; a=Δa, b=-Δa)
else
ChainRulesCore.Tangent{typeof(d)}(; a=zero(Δa), b=zero(Δa))
end
return ChainRulesCore.NoTangent(), Δd, ChainRulesCore.ZeroTangent()
end

return Ω, logpdf_Uniform_pullback
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
## Callable struct to fix type inference issues caused by captured values
struct LogPDFNegativeBinomialPullback{D,T<:Real}
∂r::T
∂p::T
end

function (f::LogPDFNegativeBinomialPullback{D})(Δ) where {D}
Δr = Δ * f.∂r
Δp = Δ * f.∂p
Δd = ChainRulesCore.Tangent{D}(; r=Δr, p=Δp)
return ChainRulesCore.NoTangent(), Δd, ChainRulesCore.NoTangent()
end

function ChainRulesCore.rrule(::typeof(logpdf), d::NegativeBinomial, k::Real)
# Compute log probability (as in the definition of `logpdf(d, k)` above)
r, p = params(d)
z = StatsFuns.xlogy(r, p) + StatsFuns.xlog1py(k, -p)
if iszero(k)
Ω = z
∂r = oftype(z, log(p))
∂p = oftype(z, r/p)
elseif insupport(d, k)
Ω = z - log(k + r) - SpecialFunctions.logbeta(r, k + 1)
∂r = oftype(z, log(p) - inv(k + r) - SpecialFunctions.digamma(r) + SpecialFunctions.digamma(r + k + 1))
∂p = oftype(z, r/p - k / (1 - p))
else
Ω = oftype(z, -Inf)
∂r = oftype(z, NaN)
∂p = oftype(z, NaN)
end

# Define pullback
logpdf_NegativeBinomial_pullback = LogPDFNegativeBinomialPullback{typeof(d),typeof(z)}(∂r, ∂p)

return Ω, logpdf_NegativeBinomial_pullback
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
for f in (:poissonbinomial_pdf, :poissonbinomial_pdf_fft)
pullback = Symbol(f, :_pullback)
@eval begin
function ChainRulesCore.frule(
(_, Δp)::Tuple{<:Any,<:AbstractVector{<:Real}}, ::typeof(Distributions.$f), p::AbstractVector{<:Real}
)
y = Distributions.$f(p)
A = Distributions.poissonbinomial_pdf_partialderivatives(p)
return y, A' * Δp
end
function ChainRulesCore.rrule(::typeof(Distributions.$f), p::AbstractVector{<:Real})
y = Distributions.$f(p)
A = Distributions.poissonbinomial_pdf_partialderivatives(p)
function $pullback(Δy)
= ChainRulesCore.InplaceableThunk(
Δ -> LinearAlgebra.mul!(Δ, A, Δy, true, true),
ChainRulesCore.@thunk(A * Δy),
)
return ChainRulesCore.NoTangent(), p̄
end
return y, $pullback
end
end
end
1 change: 1 addition & 0 deletions ext/DistributionsChainRulesCoreExt/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ChainRulesCore.@non_differentiable Distributions.check_args(::Any, ::Bool)
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
module DistributionsDensityInterfaceExt

using Distributions
import DensityInterface

@inline DensityInterface.DensityKind(::Distribution) = DensityInterface.HasDensity()

for (di_func, d_func) in ((:logdensityof, :logpdf), (:densityof, :pdf))
Expand All @@ -17,3 +22,5 @@ for (di_func, d_func) in ((:logdensityof, :logpdf), (:densityof, :pdf))
end
end
end

end # module
11 changes: 5 additions & 6 deletions src/Distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ import PDMats: dim, PDMat, invquad

using SpecialFunctions

import ChainRulesCore

import DensityInterface

export
# re-export Statistics
mean, median, quantile, std, var, cov, cor,
Expand Down Expand Up @@ -310,8 +306,11 @@ include("pdfnorm.jl")
include("mixtures/mixturemodel.jl")
include("mixtures/unigmm.jl")

# Implementation of DensityInterface API
include("density_interface.jl")
# Extensions: Implementation of DensityInterface and ChainRulesCore API
if !isdefined(Base, :get_extension)
include("../ext/DistributionsChainRulesCoreExt/DistributionsChainRulesCoreExt.jl")
include("../ext/DistributionsDensityInterfaceExt.jl")
end

# Testing utilities for other packages which implement distributions.
include("test_utils.jl")
Expand Down
11 changes: 0 additions & 11 deletions src/eachvariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,6 @@ function EachVariate{V}(x::AbstractArray{<:Real,M}) where {V,M}
return EachVariate{V,typeof(x),typeof(ax),T,M-V}(x, ax)
end

function ChainRulesCore.rrule(::Type{EachVariate{V}}, x::AbstractArray{<:Real}) where {V}
y = EachVariate{V}(x)
size_x = size(x)
function EachVariate_pullback(Δ)
# TODO: Should we also handle `Tangent{<:EachVariate}`?
Δ_out = reshape(mapreduce(vec, vcat, ChainRulesCore.unthunk(Δ)), size_x)
return (ChainRulesCore.NoTangent(), Δ_out)
end
return y, EachVariate_pullback
end

Base.IteratorSize(::Type{EachVariate{V,P,A,T,N}}) where {V,P,A,T,N} = Base.HasShape{N}()

Base.axes(x::EachVariate) = x.axes
Expand Down
59 changes: 0 additions & 59 deletions src/multivariate/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,62 +375,3 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64},
elogp = mean_logp(suffstats(Dirichlet, P, w))
fit_dirichlet!(elogp, α; maxiter=maxiter, tol=tol, debug=debug)
end

## Differentiation
function ChainRulesCore.frule((_, Δalpha)::Tuple{Any,Any}, ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
d = DT(alpha; check_args=check_args)
∂alpha0 = sum(Δalpha)
digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
∂lmnB = sum(Broadcast.instantiate(Broadcast.broadcasted(Δalpha, alpha) do Δalphai, alphai
Δalphai * (SpecialFunctions.digamma(alphai) - digamma_alpha0)
end))
Δd = ChainRulesCore.Tangent{typeof(d)}(; alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB)
return d, Δd
end

function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
d = DT(alpha; check_args=check_args)
digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
function Dirichlet_pullback(_Δd)
Δd = ChainRulesCore.unthunk(_Δd)
Δalpha = Δd.alpha .+ Δd.alpha0 .+ Δd.lmnB .* (SpecialFunctions.digamma.(alpha) .- digamma_alpha0)
return ChainRulesCore.NoTangent(), Δalpha
end
return d, Dirichlet_pullback
end

function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real})
Ω = _logpdf(d, x)
∂alpha = sum(Broadcast.instantiate(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalphai, Δxi, alphai, xi
xlogy(Δalphai, xi) + (alphai - 1) * Δxi / xi
end))
∂lmnB = -Δd.lmnB
ΔΩ = ∂alpha + ∂lmnB
if !isfinite(Ω)
ΔΩ = oftype(ΔΩ, NaN)
end
return Ω, ΔΩ
end

function ChainRulesCore.rrule(::typeof(_logpdf), d::T, x::AbstractVector{<:Real}) where {T<:Dirichlet}
Ω = _logpdf(d, x)
isfinite_Ω = isfinite(Ω)
alpha = d.alpha
function _logpdf_Dirichlet_pullback(_ΔΩ)
ΔΩ = ChainRulesCore.unthunk(_ΔΩ)
∂alpha = _logpdf_Dirichlet_∂alphai.(x, ΔΩ, isfinite_Ω)
∂lmnB = isfinite_Ω ? -float(ΔΩ) : oftype(float(ΔΩ), NaN)
Δd = ChainRulesCore.Tangent{T}(; alpha=∂alpha, lmnB=∂lmnB)
Δx = _logpdf_Dirichlet_Δxi.(ΔΩ, alpha, x, isfinite_Ω)
return ChainRulesCore.NoTangent(), Δd, Δx
end
return Ω, _logpdf_Dirichlet_pullback
end
function _logpdf_Dirichlet_∂alphai(xi, ΔΩi, isfinite::Bool)
∂alphai = xlogy.(ΔΩi, xi)
return isfinite ? ∂alphai : oftype(∂alphai, NaN)
end
function _logpdf_Dirichlet_Δxi(ΔΩi, alphai, xi, isfinite::Bool)
Δxi = ΔΩi * (alphai - 1) / xi
return isfinite ? Δxi : oftype(Δxi, NaN)
end
37 changes: 0 additions & 37 deletions src/univariate/continuous/uniform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,40 +165,3 @@ function fit_mle(::Type{T}, x::AbstractArray{<:Real}) where {T<:Uniform}
end
return T(extrema(x)...)
end

# ChainRules definitions

function ChainRulesCore.frule((_, Δd, _), ::typeof(logpdf), d::Uniform, x::Real)
# Compute log probability
a, b = params(d)
insupport = a <= x <= b
diff = b - a
Ω = insupport ? -log(diff) : log(zero(diff))

# Compute tangent
Δdiff = Δd.a - Δd.b
ΔΩ = (insupport ? Δdiff : zero(Δdiff)) / diff

return Ω, ΔΩ
end

function ChainRulesCore.rrule(::typeof(logpdf), d::Uniform, x::Real)
# Compute log probability
a, b = params(d)
insupport = a <= x <= b
diff = b - a
Ω = insupport ? -log(diff) : log(zero(diff))

# Define pullback
function logpdf_Uniform_pullback(Δ)
Δa = Δ / diff
Δd = if insupport
ChainRulesCore.Tangent{typeof(d)}(; a=Δa, b=-Δa)
else
ChainRulesCore.Tangent{typeof(d)}(; a=zero(Δa), b=zero(Δa))
end
return ChainRulesCore.NoTangent(), Δd, ChainRulesCore.ZeroTangent()
end

return Ω, logpdf_Uniform_pullback
end
Loading

0 comments on commit b9d063f

Please sign in to comment.