Notes from the September probprog community call #1410
How about
Ad (5): this should be easy (at least to write) in my proposed IR, or in any system that separates variable names from sampling, so that you can have normal assignment to a named variable:

{x} ~ Normal()
{log_x} = log({x})

It would still have to be supported by the PPL evaluating the model, of course. This is also the kind of thing that could allow constraining intermediate transformations -- something I believe @zenna mentioned at some point. Ad (6): that could even be a combinator in Measures.jl, couldn't it?
Yes! But it would be nice to add the custom adjoint part to that as well and have Chris Elrod implement some heavily optimized logpdf methods for those distributions or the distribution generators.
Yes, either Measures or Soss would be a good place for this.
For multivariate normal, there's a lot of value in representing the covariance in structured form. Then x ~ MvNormal(Σ) |> iid(n) would reduce to x ~ NewGaussian(I(n) ⊗ Σ), and the correlation between the (i,a)th and (j,b)th elements of x is determined by the Kronecker structure: if Σ is k×k, the full covariance is nk×nk but never has to be materialized.
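To make the benefit concrete, here is a minimal sketch (mine, not from the thread) of a logpdf for the I(n) ⊗ Σ case that only ever touches the k×k factor; kron_mvnormal_logpdf and the reshaping convention are made up for illustration.

```julia
using LinearAlgebra

# For x with covariance I(n) ⊗ Σ (n iid blocks, each with k×k covariance Σ),
# logdet(I(n) ⊗ Σ) = n * logdet(Σ) and the quadratic form reduces to
# tr(Xᵀ Σ⁻¹ X) where X is x reshaped to k×n, so the nk×nk matrix never appears.
function kron_mvnormal_logpdf(x::AbstractVector, Σ::AbstractMatrix, n::Integer)
    k = size(Σ, 1)
    X = reshape(x, k, n)              # column j is the j-th iid k-vector
    C = cholesky(Symmetric(Σ))
    quad = sum(abs2, C.L \ X)         # tr(Xᵀ Σ⁻¹ X) via one triangular solve
    return -0.5 * (n * k * log(2π) + n * logdet(C) + quad)
end
```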
I was sort of half-joking originally, but defining a distribution (or a measure, more generally) over named tuples is sort of Soss's whole deal. Also, @tpapp's TransformVariables is really good at this, if your starting point is a distribution over ℝⁿ.
One reason I'm already a huge fan of Measures.jl is the nice "multiple parametrizations" idea. So actually, there could be
falling back to
or similarly.
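A rough sketch of how the multiple-parametrizations idea could look (purely illustrative; MyNormal and this logdensity are made-up names, not Measures.jl/MeasureTheory.jl code): dispatching on the named tuple of parameters lets each parametrization have its own, cheaper density.

```julia
struct MyNormal{N<:NamedTuple}
    par::N
end
MyNormal(; kwargs...) = MyNormal(values(kwargs))

# Mean/standard-deviation parametrization (dropping the 2π constant).
logdensity(d::MyNormal{<:NamedTuple{(:μ, :σ)}}, x) =
    -0.5 * ((x - d.par.μ) / d.par.σ)^2 - log(d.par.σ)

# Mean/precision parametrization: no division, and the log term is cheaper.
logdensity(d::MyNormal{<:NamedTuple{(:μ, :τ)}}, x) =
    -0.5 * d.par.τ * (x - d.par.μ)^2 + 0.5 * log(d.par.τ)
```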
Like in Stan, I'd prefer to point people towards using the Cholesky factor of the covariance directly. This is an unoptimized but simple and dependency-free (aside from the standard library) implementation for calculating the density and gradients:

using LinearAlgebra
# y is an P x N matrix containing N total P-dimensional observations
# m is a vector of length P containing the means
# l is the lower triangular Cholesky factor of the covariance matrix
# dy, dm, and dl are the corresponding gradients
# b is preallocated memory with the same size and shape as y
function ldnorm!(dy, dm, dl, b, y, m, l)
b .= m .- y
LAPACK.trtrs!('L', 'N', 'N', l, b)
lp = dot(b, b); dy .= b
LAPACK.trtrs!('L', 'T', 'N', l, dy)
#@avx should be a little faster, and let you move the 0-assignment into the loop nest
dm .= 0
@inbounds for n ∈ 1:size(y,2)
@simd for p ∈ 1:size(y,1)
dm[p] -= dy[p,n]
end
end
BLAS.gemm!('N', 'T', 1.0, dy, b, 0.0, dl)
N = size(y,2); lp *= -0.5
@inbounds for p ∈ 1:size(dl,1) #@avx instead would be much faster
dl[p,p] -= N / l[p,p] # derivative of the -N*log(l[p,p]) term below
lp -= N * log(l[p,p])
end
lp
end

The optimized code I wrote that does only a single pass over the memory is much messier and currently broken, although it handles a variety of situations such as where the mean is defined as... Unfortunately, LAPACK doesn't include a version of... I could work on rewriting and cleaning it up, but I think long term I'd rather get LoopVectorization to be able to handle the kind of optimizations and loop dependency structures needed.
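A quick usage sketch for the function above (sizes and data are made up, just to show the calling convention):

```julia
using LinearAlgebra

P, N = 3, 100
y = randn(P, N)                                    # N observations of dimension P
m = randn(P)
A = randn(P, P)
l = Matrix(cholesky(Symmetric(A * A' + P * I)).L)  # dense lower Cholesky factor
dy, b  = similar(y), similar(y)
dm, dl = similar(m), similar(l)
lp = ldnorm!(dy, dm, dl, b, y, m, l)               # log-density; gradients land in dy, dm, dl
```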
Isn't the same behaviour achieved by MvNormal without keyword arguments, by dispatching on the type of the covariance matrix, such as...?
Oh wait, sorry, I had misread @phipsgabler's comment. I think it's a little weird how Distributions.jl works in this case. It has...
Yep, I agree. Also, it's often handy to allow positive semidefinite covariance. Tim Holy has an approach to this in PositiveFactorizations.jl that seems promising.
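For reference, the documented PositiveFactorizations.jl entry point is cholesky(Positive, A); a tiny example (the matrix here is made up):

```julia
using LinearAlgebra, PositiveFactorizations

A = [1.0 2.0; 2.0 1.0]      # indefinite: ordinary cholesky(A) throws
F = cholesky(Positive, A)   # factorization of a positive-definite modification of A
F.L * F.L'                  # ≈ that modified matrix, usable as a covariance
```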
This looks nice, but it doesn't really seem to allow for different parameterizations. Instead, it uses keyword arguments to coerce to the standard parameterization. A big advantage of reparameterization is making computation cheaper, and this doesn't seem to help in that regard.
I agree, there is some inconsistency in how...
Thanks @mohamed82008!
In Turing we decide what's "observed" based on whether or not it's present in the model arguments. So I don't think the issue is related to whether or not the transformation is bijective, but rather that we can't handle something like

@model function demo(x)
    y = f(x)
    y ~ Likelihood()
end

because we don't know that the LHS is a function of the inputs. Also, it's worth pointing out that it's still possible to condition on observations by using the @addlogprob! macro:

@model function demo(x)
    y = f(x)
    @addlogprob! logpdf(Likelihood(), y)
end

Alternatively, we can make it look nicer by doing something like

@model function demo(x)
    y = f(x)
    @observe y ~ Likelihood()
end

which simply expands to the above.
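Just to illustrate that last point, a hypothetical sketch of such an @observe macro (not an existing Turing/DynamicPPL macro; it only rewrites the tilde expression into the @addlogprob! form used above, assuming it runs inside a model where @addlogprob! and logpdf are in scope):

```julia
macro observe(expr)
    # expects an expression of the form `lhs ~ dist`
    @assert expr isa Expr && expr.head === :call && expr.args[1] === :~
    lhs, dist = expr.args[2], expr.args[3]
    return esc(:(@addlogprob! logpdf($dist, $lhs)))
end
```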
I really like going in this direction. Distributions for named tuples keep coming up as something that would be super useful to have (we also added support for transforming something like that in Bijectors.jl recently).
Btw, we've also added support for transforming... Also, I completely missed the meeting! Sorry! I've been in the process of moving and getting sorted for starting my studies, so I completely forgot about it.
THIS IS SO GREAT!!!! Link: TuringLang/Bijectors.jl#95

Some background for others: it's often important to have the support of a variable depend stochastically on the value of another. It's not just a niche thing either; it comes up any time you want variables to be ordered in some way. And this is a big deal, because exchangeability in posteriors is just a mess, causing most samplers to mode-switch all over the place, which in turn screws up most diagnostics. So yeah, 💯
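For readers who haven't seen it, the standard construction behind this kind of constraint (the textbook/Stan-style ordered transform, not the actual Bijectors.jl code) maps an unconstrained vector to a strictly increasing one:

```julia
# y₁ = x₁ and yᵢ = yᵢ₋₁ + exp(xᵢ) for i > 1, so y is always strictly ordered;
# the log|det Jacobian| is simply Σᵢ₌₂ xᵢ.
function to_ordered(x::AbstractVector)
    y = similar(x, float(eltype(x)))
    y[1] = x[1]
    logdetjac = zero(eltype(y))
    for i in 2:length(x)
        y[i] = y[i-1] + exp(x[i])
        logdetjac += x[i]
    end
    return y, logdetjac
end
```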
Yes, this should be possible.
So what I am proposing is practically a symbolic AD layer on top of ChainRules: write a chain rule for a function that we really, really care about, which we can define in global scope (annotated with a macro) once, for our users to use over and over again. For example, the logpdf of an...
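At the lowest level, the kind of rule being talked about looks like this (a hand-written ChainRules rule for a logpdf; normlogpdf is a made-up example function, with the univariate normal standing in for the heavier multivariate cases):

```julia
using ChainRulesCore

normlogpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

function ChainRulesCore.rrule(::typeof(normlogpdf), x, μ, σ)
    z = (x - μ) / σ
    lp = -0.5 * z^2 - log(σ) - 0.5 * log(2π)
    function normlogpdf_pullback(Δ)
        # ∂lp/∂x = -z/σ, ∂lp/∂μ = z/σ, ∂lp/∂σ = (z² - 1)/σ
        return (NoTangent(), Δ * (-z / σ), Δ * (z / σ), Δ * ((z^2 - 1) / σ))
    end
    return lp, normlogpdf_pullback
end
```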
Since we are talking AD and probprog, CC: @ChrisRackauckas.
I think even with this limitation, the proposal is still useful.
For symbolic AD here, you're just asking for ModelingToolkit?
If MTK can connect ChainRules like Lego to make a bigger chain rule, then yes. I really should play more with MTK.
Or well... like a chain!
It's built on DiffRules right now, but with @oxinabox's changes to ChainRules it could probably use ChainRules now.
Awesome! If this happens, I think it can solve like 90% of the AD needs if one is willing to work in global scope.
So for a bigger picture, I imagine a world where more functions are annotated with such a macro. Then the entry point for AD can be Zygote (or really ChainRules at this point). Zygote calls ChainRules, making use of MTK-generated ChainRules where possible, which themselves can call back into Zygote if needed. A beautiful harmony of symbolic AD (MTK), runtime-based AD (Zygote) and manual differentiation (ChainRules), all working together to avoid Zygote's type instability!
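As a rough sketch of the symbolically generated piece (using the Symbolics.jl layer that MTK builds on; the exact API may differ, and this example is mine): trace a scalar logpdf symbolically, take the gradient, and compile it into plain Julia code that a hand-registered ChainRules rule could then wrap.

```julia
using Symbolics

@variables μ σ x
lp = -((x - μ)^2) / (2σ^2) - log(σ)       # unnormalized normal logpdf

grad = Symbolics.gradient(lp, [μ, σ])      # symbolic ∂lp/∂μ and ∂lp/∂σ
grad_f, _ = build_function(grad, [μ, σ, x]; expression = Val(false))

grad_f([0.0, 1.0, 0.5])                    # evaluate the compiled gradient
```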
That's exactly how we're using it in DiffEq, defining Jacobians, vjps, etc. to augment the AD world. The nice thing is that symbolic derivatives are more efficient than AD derivatives in the sparse case: coloring gets close, but is never as efficient as directly defining entries of the matrix. So we mainly use it to build big sparse things, but yes, it handles the case that other source-to-source ADs like Zygote are bad at (scalarized stuff).
But it also needs to be taught when to give up, for example, if a function is too big. I don't know if MTK unrolls loops or tries to work with a loop as a single node. If it's the former, then it will be pretty taxing on the compiler, and it can take a very long time to parse long loops.
It doesn't give up. We leave it to user choice. We have cases where a 500-second compile time is a reduction in compute time by 80 hours though... so it's really a choice, and I think any heuristic is wrong. Instead, a heuristic on that should exist (and be overridable) in Turing, or some high-level "glue AD".
Sure, but at least it needs to be taught how to give up, and we can tell it when 😅
For example, setting a time limit.
Run the computation on a task and kill the task if it goes over a time limit. That's what we plan to do in AutoOptimize.jl.
Nice!
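In plain Julia that idea looks roughly like this (a generic sketch, not AutoOptimize.jl code; the helper names in the usage comment are made up, and since Julia tasks can't be force-killed, the slow task is simply abandoned while a fallback is returned):

```julia
function with_timeout(f, fallback, seconds)
    result = Ref{Any}(nothing)
    task = Threads.@spawn (result[] = f())
    status = timedwait(() -> istaskdone(task), seconds; pollint = 0.1)
    return status === :ok ? result[] : fallback()
end

# e.g. try an expensive symbolic pass for 60 s, otherwise use ordinary AD:
# with_timeout(() -> build_symbolic_gradient(model), () -> generic_ad(model), 60.0)
```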
I haven't had the time to get it into a working state, but this much at least does work:

julia> using ProbabilityModels, LinearAlgebra, DistributionParameters
julia> sg = quote
σ ~ Gamma(1.0, 0.05)
L ~ LKJ(2.0)
β ~ Normal(10.0) # μ = 0
μ ~ Normal(100.0)
σL = Diagonal(σ) * L
Y₁ ~ Normal( X₁, β, μ', σL )
Y₂ ~ Normal( X₂, β[1:7,:], μ', σL )
end;

A model definition API would be something like... The macro would also define methods for functions that dispatch on the struct to calculate the logdensity and gradient. The macro would use read_model:

julia> m = ProbabilityModels.read_model(sg, Main);

Generating data to create an example named tuple:

julia> N, K₁, K₂, P = 100, 10, 7, 12;
julia> μ = rand(P);
julia> σU = rand(P, (3P)>>1) |> x -> cholesky(Hermitian(x * x')).U;
julia> X₁ = rand(N, K₁);
julia> X₂ = rand(N, K₂);
julia> β = randn(K₁, P);
julia> Y₁ = mul!(rand(N, P) * σU, X₁, β, 1.0, 1.0);
julia> Y₂ = mul!(rand(N, P) * σU, X₂, view(β, 1:K₂, :), 1.0, 1.0);
julia> datant = (
Y₁ = Y₁, Y₂ = Y₂, X₁ = X₁, X₂ = X₂,
μ = RealVector{12}(), β = RealMatrix{10,12}(),
σ = RealVector{12,0.0,Inf}(), L = CorrelationMatrixCholesyFactor{12}()
);

A model description would be inserted into the generated function methods for logdensity and gradients.

julia> ProbabilityModels.preprocess!(m, typeof(datant));
julia> ProbabilityModels.ReverseDiffExpressions.lower(m)
quote
(var"###STACK##POINTER###", var"##targetconstrained#263") = ReverseDiffExpressions.stack_pointer_call(DistributionParameters.constrain, var"###STACK##POINTER###", var"#θ#", 120, RealArray{Tuple{12}, -Inf, Inf, 0}())
var"##TARGET###0#" = ReverseDiffExpressionsBase.first(var"##targetconstrained#263")
(var"###STACK##POINTER###", var"##targetconstrained#265") = ReverseDiffExpressions.stack_pointer_call(DistributionParameters.constrain, var"###STACK##POINTER###", var"#θ#", 132, RealArray{Tuple{12}, 0.0, Inf, 0}())
σ = ReverseDiffExpressionsBase.second(var"##targetconstrained#265")
(var"###STACK##POINTER###", var"####TARGET###1##267") = ReverseDiffExpressions.stack_pointer_call(logdensity, var"###STACK##POINTER###", Gamma{(true, false, false, false)}(), σ, 1.0, 0.05, var"##ONE##")
var"##TARGET###1#" = ReverseDiffExpressions.vadd!(var"##TARGET###0#", var"####TARGET###1##267")
(var"###STACK##POINTER###", var"##targetconstrained#266") = ReverseDiffExpressions.stack_pointer_call(DistributionParameters.constrain, var"###STACK##POINTER###", var"#θ#", 144, CorrelationMatrixCholesyFactor{12}())
L = ReverseDiffExpressionsBase.second(var"##targetconstrained#266")
(var"###STACK##POINTER###", var"####TARGET###2##268") = ReverseDiffExpressions.stack_pointer_call(logdensity, var"###STACK##POINTER###", LKJ{(true, false, false)}(), L, 2.0, var"##ONE##")
var"##TARGET###2#" = ReverseDiffExpressions.vadd!(var"##TARGET###1#", var"####TARGET###2##268")
(var"###STACK##POINTER###", var"##targetconstrained#264") = ReverseDiffExpressions.stack_pointer_call(DistributionParameters.constrain, var"###STACK##POINTER###", var"#θ#", 0, RealArray{Tuple{10,12}, -Inf, Inf, 0}())
β = ReverseDiffExpressionsBase.second(var"##targetconstrained#264")
(var"###STACK##POINTER###", var"####TARGET###3##269") = ReverseDiffExpressions.stack_pointer_call(logdensity, var"###STACK##POINTER###", Normal{(true, false, false)}(), β, 10.0, var"##ONE##")
var"##TARGET###3#" = ReverseDiffExpressions.vadd!(var"##TARGET###2#", var"####TARGET###3##269")
μ = ReverseDiffExpressionsBase.second(var"##targetconstrained#263")
(var"###STACK##POINTER###", var"####TARGET###4##270") = ReverseDiffExpressions.stack_pointer_call(logdensity, var"###STACK##POINTER###", Normal{(true, false, false)}(), μ, 100.0, var"##ONE##")
var"##TARGET###4#" = ReverseDiffExpressions.vadd!(var"##TARGET###3#", var"####TARGET###4##270")
Y₁ = (var"#DATA#").Y₁
X₁ = (var"#DATA#").X₁
(var"###STACK##POINTER###", var"##LHS#259#0#") = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.adjoint, var"###STACK##POINTER###", μ)
(var"###STACK##POINTER###", var"##LHS#258") = ReverseDiffExpressions.stack_pointer_call(Diagonal, var"###STACK##POINTER###", σ)
(var"###STACK##POINTER###", σL) = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.:*, var"###STACK##POINTER###", var"##LHS#258", L)
(var"###STACK##POINTER###", var"####TARGET###5##271") = ReverseDiffExpressions.stack_pointer_call(logdensity, var"###STACK##POINTER###", Normal{(false, false, true, true, true, false)}(), Y₁, X₁, β, var"##LHS#259", σL, var"##ONE##")
var"##TARGET###5#" = ReverseDiffExpressions.vadd!(var"##TARGET###4#", var"####TARGET###5##271")
Y₂ = (var"#DATA#").Y₂
X₂ = (var"#DATA#").X₂
var"##LHS#260" = Base.view(β, StaticUnitRange{1, 7}(), :)
(var"###STACK##POINTER###", var"####TARGET###6##272") = ReverseDiffExpressions.stack_pointer_call(logdensity, var"###STACK##POINTER###", Normal{(false, false, true, true, true, false)}(), Y₂, X₂, var"##LHS#260", var"##LHS#259", σL, var"##ONE##")
var"##TARGET###6#" = ReverseDiffExpressions.vadd!(var"##TARGET###5#", var"####TARGET###6##272")
var"####TARGET###7##273" = ReverseDiffExpressionsBase.first(var"##targetconstrained#264")
var"##TARGET###7#" = ReverseDiffExpressions.vadd!(var"##TARGET###6#", var"####TARGET###7##273")
var"####TARGET###8##274" = ReverseDiffExpressionsBase.first(var"##targetconstrained#265")
var"##TARGET###8#" = ReverseDiffExpressions.vadd!(var"##TARGET###7#", var"####TARGET###8##274")
var"####TARGET###275" = ReverseDiffExpressionsBase.first(var"##targetconstrained#266")
var"##TARGET##" = ReverseDiffExpressions.vadd!(var"##TARGET###8#", var"####TARGET###275")
vsum(var"##TARGET##")
end

Verbose, but a few notes:
Similarly, it can differentiate the model method:

julia> dm = ProbabilityModels.ReverseDiffExpressions.differentiate(m);
julia> ProbabilityModels.ReverseDiffExpressions.lower(dm)
quote
(var"###STACK##POINTER###", var"##constrainpullbacktup#302") = ReverseDiffExpressions.stack_pointer_call(constrain_pullback!, var"###STACK##POINTER###", var"#∇#", var"#θ#", 132, RealArray{Tuple{12}, 0.0, Inf, 0}())
var"##targetconstrained#265##BAR##" = ReverseDiffExpressionsBase.second(var"##constrainpullbacktup#302")
var"σ##BAR###0#" = ReverseDiffExpressionsBase.second(var"##targetconstrained#265##BAR##")
var"##targetconstrained#265" = ReverseDiffExpressionsBase.first(var"##constrainpullbacktup#302")
σ = ReverseDiffExpressionsBase.second(var"##targetconstrained#265")
(var"###STACK##POINTER###", var"##rrule_LHS#285") = ReverseDiffExpressions.stack_pointer_call(ChainRules.rrule, var"###STACK##POINTER###", Diagonal, σ)
var"##LHS#258" = Base.first(var"##rrule_LHS#285")
(var"###STACK##POINTER###", var"##temp#289") = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.adjoint, var"###STACK##POINTER###", var"##LHS#258")
Y₁ = (var"#DATA#").Y₁
X₁ = (var"#DATA#").X₁
(var"###STACK##POINTER###", var"##constrainpullbacktup#300") = ReverseDiffExpressions.stack_pointer_call(constrain_pullback!, var"###STACK##POINTER###", var"#∇#", var"#θ#", 0, RealArray{Tuple{10,12}, -Inf, Inf, 0}())
var"##targetconstrained#264##BAR##" = ReverseDiffExpressionsBase.second(var"##constrainpullbacktup#300")
var"β##BAR###0#" = ReverseDiffExpressionsBase.second(var"##targetconstrained#264##BAR##")
var"##targetconstrained#264" = ReverseDiffExpressionsBase.first(var"##constrainpullbacktup#300")
β = ReverseDiffExpressionsBase.second(var"##targetconstrained#264")
(var"###STACK##POINTER###", var"##tup#280") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (var"β##BAR###0#", nothing, nothing), Normal{(true, false, false)}(), β, 10.0, var"##ONE##")
var"##∂tup#281" = ReverseDiffExpressionsBase.second(var"##tup#280")
var"##β##BAR###1##306" = ReverseDiffExpressionsBase.first(var"##∂tup#281")
var"β##BAR###1#" = ReverseDiffExpressions.vadd!(var"β##BAR###0#", var"##β##BAR###1##306")
(var"###STACK##POINTER###", var"##constrainpullbacktup#298") = ReverseDiffExpressions.stack_pointer_call(constrain_pullback!, var"###STACK##POINTER###", var"#∇#", var"#θ#", 120, RealArray{Tuple{12}, -Inf, Inf, 0}())
var"##targetconstrained#263##BAR##" = ReverseDiffExpressionsBase.second(var"##constrainpullbacktup#298")
var"μ##BAR###0#" = ReverseDiffExpressionsBase.second(var"##targetconstrained#263##BAR##")
var"##targetconstrained#263" = ReverseDiffExpressionsBase.first(var"##constrainpullbacktup#298")
μ = ReverseDiffExpressionsBase.second(var"##targetconstrained#263")
(var"###STACK##POINTER###", var"##rrule_LHS#291") = ReverseDiffExpressions.stack_pointer_call(ChainRules.rrule, var"###STACK##POINTER###", LoopVectorization.adjoint, μ)
var"##LHS#259" = Base.first(var"##rrule_LHS#291")
(var"###STACK##POINTER###", var"##constrainpullbacktup#304") = ReverseDiffExpressions.stack_pointer_call(constrain_pullback!, var"###STACK##POINTER###", var"#∇#", var"#θ#", 144, CorrelationMatrixCholesyFactor{12}())
var"##targetconstrained#266##BAR##" = ReverseDiffExpressionsBase.second(var"##constrainpullbacktup#304")
var"L##BAR###0#" = ReverseDiffExpressionsBase.second(var"##targetconstrained#266##BAR##")
var"##targetconstrained#266" = ReverseDiffExpressionsBase.first(var"##constrainpullbacktup#304")
L = ReverseDiffExpressionsBase.second(var"##targetconstrained#266")
(var"###STACK##POINTER###", σL) = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.:*, var"###STACK##POINTER###", var"##LHS#258", L)
(var"###STACK##POINTER###", var"##tup#294") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (nothing, nothing, var"β##BAR###1#", nothing, nothing, nothing), Normal{(false, false, true, true, true, false)}(), Y₁, X₁, β, var"##LHS#259", σL, var"##ONE##")
var"##∂tup#295" = ReverseDiffExpressionsBase.second(var"##tup#294")
var"##β##BAR###307" = ReverseDiffExpressionsBase.third(var"##∂tup#295")
var"β##BAR##" = ReverseDiffExpressions.vadd!(var"β##BAR###1#", var"##β##BAR###307")
var"##LHS#260##BAR###0#" = Base.view(var"β##BAR##", StaticUnitRange{1, 7}(), :)
Y₂ = (var"#DATA#").Y₂
X₂ = (var"#DATA#").X₂
var"##LHS#260" = Base.view(β, StaticUnitRange{1, 7}(), :)
var"##LHS#259##BAR###0#" = ReverseDiffExpressionsBase.fourth(var"##∂tup#295")
var"σL##BAR###0#" = ReverseDiffExpressionsBase.fifth(var"##∂tup#295")
(var"###STACK##POINTER###", var"##tup#296") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (nothing, nothing, var"##LHS#260##BAR###0#", var"##LHS#259##BAR###0#", var"σL##BAR###0#", nothing), Normal{(false, false, true, true, true, false)}(), Y₂, X₂, var"##LHS#260", var"##LHS#259", σL, var"##ONE##")
var"##∂tup#297" = ReverseDiffExpressionsBase.second(var"##tup#296")
var"####LHS#260##BAR###308" = ReverseDiffExpressionsBase.third(var"##∂tup#297")
var"##LHS#260##BAR##" = ReverseDiffExpressions.vadd!(var"##LHS#260##BAR###0#", var"####LHS#260##BAR###308")
(var"###STACK##POINTER###", var"##tup#294#1#") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (nothing, nothing, var"β##BAR##", var"##LHS#259##BAR###0#", var"σL##BAR###0#", nothing), Normal{(false, false, true, true, true, false)}(), Y₁, X₁, β, var"##LHS#259", σL, var"##ONE##")
var"##σL##BAR###309" = ReverseDiffExpressionsBase.fifth(var"##∂tup#297")
var"σL##BAR##" = ReverseDiffExpressions.vadd!(var"σL##BAR###0#", var"##σL##BAR###309")
(var"###STACK##POINTER###", var"##L##BAR###1##310") = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.vmul, var"###STACK##POINTER###", var"##temp#289", var"σL##BAR##")
var"L##BAR###1#" = ReverseDiffExpressions.vadd!(var"L##BAR###0#", var"##L##BAR###1##310")
(var"###STACK##POINTER###", var"##tup#278") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (var"L##BAR###1#", nothing, nothing), LKJ{(true, false, false)}(), L, 2.0, var"##ONE##")
var"##∂tup#279" = ReverseDiffExpressionsBase.second(var"##tup#278")
var"##L##BAR###311" = ReverseDiffExpressionsBase.first(var"##∂tup#279")
var"L##BAR##" = ReverseDiffExpressions.vadd!(var"L##BAR###1#", var"##L##BAR###311")
(var"###STACK##POINTER###", var"##nothing#305") = ReverseDiffExpressions.stack_pointer_call(constrain_reverse!, var"###STACK##POINTER###", var"##targetconstrained#266##BAR##", CorrelationMatrixCholesyFactor{12}())
(var"###STACK##POINTER###", var"##tup#276") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (var"σ##BAR###0#", nothing, nothing, nothing), Gamma{(true, false, false, false)}(), σ, 1.0, 0.05, var"##ONE##")
var"##∂tup#277" = ReverseDiffExpressionsBase.second(var"##tup#276")
var"##σ##BAR###1##312" = ReverseDiffExpressionsBase.first(var"##∂tup#277")
var"σ##BAR###1#" = ReverseDiffExpressions.vadd!(var"σ##BAR###0#", var"##σ##BAR###1##312")
var"##temp#286" = last(var"##rrule_LHS#285")
(var"###STACK##POINTER###", var"##temp#288") = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.adjoint, var"###STACK##POINTER###", L)
(var"###STACK##POINTER###", var"##LHS#258##BAR##") = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.vmul, var"###STACK##POINTER###", var"σL##BAR##", var"##temp#288")
(var"###STACK##POINTER###", var"##temp#287") = ReverseDiffExpressions.stack_pointer_call(callunthunk, var"###STACK##POINTER###", var"##temp#286", var"##LHS#258##BAR##")
var"##σ##BAR###313" = LoopVectorization.second(var"##temp#287")
var"σ##BAR##" = ReverseDiffExpressions.vadd!(var"σ##BAR###1#", var"##σ##BAR###313")
(var"###STACK##POINTER###", var"##nothing#303") = ReverseDiffExpressions.stack_pointer_call(constrain_reverse!, var"###STACK##POINTER###", var"##targetconstrained#265##BAR##", RealArray{Tuple{12}, 0.0, Inf, 0}())
(var"###STACK##POINTER###", var"##nothing#301") = ReverseDiffExpressions.stack_pointer_call(constrain_reverse!, var"###STACK##POINTER###", var"##targetconstrained#264##BAR##", RealArray{Tuple{10,12}, -Inf, Inf, 0}())
(var"###STACK##POINTER###", var"##tup#282") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (var"μ##BAR###0#", nothing, nothing), Normal{(true, false, false)}(), μ, 100.0, var"##ONE##")
var"##∂tup#283" = ReverseDiffExpressionsBase.second(var"##tup#282")
var"##μ##BAR###1##314" = ReverseDiffExpressionsBase.first(var"##∂tup#283")
var"μ##BAR###1#" = ReverseDiffExpressions.vadd!(var"μ##BAR###0#", var"##μ##BAR###1##314")
var"##temp#292" = last(var"##rrule_LHS#291")
var"####LHS#259##BAR###315" = ReverseDiffExpressionsBase.fourth(var"##∂tup#297")
var"##LHS#259##BAR##" = ReverseDiffExpressions.vadd!(var"##LHS#259##BAR###0#", var"####LHS#259##BAR###315")
(var"###STACK##POINTER###", var"##temp#293") = ReverseDiffExpressions.stack_pointer_call(callunthunk, var"###STACK##POINTER###", var"##temp#292", var"##LHS#259##BAR##")
var"##μ##BAR###316" = LoopVectorization.second(var"##temp#293")
var"μ##BAR##" = ReverseDiffExpressions.vadd!(var"μ##BAR###1#", var"##μ##BAR###316")
(var"###STACK##POINTER###", var"##nothing#299") = ReverseDiffExpressions.stack_pointer_call(constrain_reverse!, var"###STACK##POINTER###", var"##targetconstrained#263##BAR##", RealArray{Tuple{12}, -Inf, Inf, 0}())
var"##TARGET###0#" = ReverseDiffExpressionsBase.first(var"##tup#276")
var"####TARGET###1##317" = ReverseDiffExpressionsBase.first(var"##tup#278")
var"##TARGET###1#" = ReverseDiffExpressions.vadd!(var"##TARGET###0#", var"####TARGET###1##317")
var"####TARGET###2##318" = ReverseDiffExpressionsBase.first(var"##tup#280")
var"##TARGET###2#" = ReverseDiffExpressions.vadd!(var"##TARGET###1#", var"####TARGET###2##318")
var"####TARGET###3##319" = ReverseDiffExpressionsBase.first(var"##tup#282")
var"##TARGET###3#" = ReverseDiffExpressions.vadd!(var"##TARGET###2#", var"####TARGET###3##319")
var"####TARGET###4##320" = ReverseDiffExpressionsBase.first(var"##tup#294")
var"##TARGET###4#" = ReverseDiffExpressions.vadd!(var"##TARGET###3#", var"####TARGET###4##320")
var"####TARGET###321" = ReverseDiffExpressionsBase.first(var"##tup#296")
var"##TARGET##" = ReverseDiffExpressions.vadd!(var"##TARGET###4#", var"####TARGET###321")
vsum(var"##TARGET##")
end

It checks whether gradients are considered known (and I'd implement gradients for all supported distributions); otherwise it would fall back to another library. I don't think the above code would run, because I'm pretty sure I didn't implement all the needed methods. My longer-term ambition here is to be able to convert many distributions and linear algebra routines into equivalent loop-based representations, and then use LoopVectorization to reorder and fuse them, etc. (still working on those optimizations).
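To illustrate the "distribution as a loop" direction with a toy case (my example, not ProbabilityModels code; depending on the LoopVectorization version the macro is spelled @avx or @turbo):

```julia
using LoopVectorization

# Sum of iid normal log-densities written as an explicit loop, which
# LoopVectorization can vectorize; the closed-form constant is added once.
function normal_logpdf_sum(y::AbstractVector, μ, σ)
    s = zero(promote_type(eltype(y), typeof(μ)))
    @avx for i in eachindex(y)
        z = (y[i] - μ) / σ
        s += -0.5 * z * z
    end
    return s - length(y) * (log(σ) + 0.5 * log(2π))
end
```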
Cool! LoopVectorization as an optimization pass sounds very interesting. This will require the symbolic version of a chain rule, though, which is not stored by default in ChainRules. Transforming linear algebra expressions to loop-based expressions sounds like a task Tullio can probably do with a little metaprogramming help.
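For concreteness, the index-notation style Tullio already supports looks like this (a generic example, not tied to this thread):

```julia
using Tullio

A = randn(4, 3); B = randn(3, 5)
@tullio C[i, j] := A[i, k] * B[k, j]   # generates the nested loops over i, j, k
C ≈ A * B                              # true
```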
Look at the code generated in https://mtk.sciml.ai/dev/tutorials/auto_parallel/ to get a sense of what MTK is currently doing. The current push is to get non-scalar forms as well, but that's what we have so far. |
It'd be cool to add Tullio / indexing-notation support directly. Although it may be better to do the conversion, because...
Yes, and Tullio + a macro can help with that, using a few if statements checking what the types of...
The model definition would define the likelihood/gradient functions as... I think it'd be easier to delay the substitution until then. Otherwise, you'd need a lot of branches, e.g. ...
A macro can do this no problem. Each loop can be annotated with...
Or maybe the entire body can be annotated to optimize the whole expression for each type scenario. But I imagine this is somewhat out of LoopVectorization's comfort zone.
So some takeaway points from the great community call today (thanks @femtomc and @cscherrer for organizing this):

1. tilde and dot_tilde functions. I don't know how it will turn out, but I think I found the right entry point.
2. Omega can support inference techniques currently impossible in Turing, but on a limited subset of models, i.e. the ones compatible with the Ω struct, I think. CC: @zenna
3. Omega's inference algorithms on the class of models that it supports. The main change required here will be to define the random variables to be instances of the random variable type in Omega. Then the sample function can simply call the cond function or something similar in Omega, iiuc. Again, more details need to be figured out here, but it should be doable as a summer or winter project.
4. ~: I think this might be possible today for some functions, namely bijectors. I think we can define a transformed distribution using Bijectors.jl and use that to observe, for example, 2y instead of y, or log(y), etc. A tutorial here may be all we need. CC: @torfjelde
5. ProbabilityModels can be implemented at the distribution level, making them available to other PPLs. CC: @chriselrod. I think it may help to have a macro that lets us more easily define a "complex" distribution that returns a named tuple, together with its custom adjoint. It would be interesting if we could automatically compose and inline rules in ChainRules explicitly at this level to make a bigger chain rule for our "complex distribution". The goal of this is to minimize the time spent in Zygote's type-unstable parts by going through a single primitive instead of multiple AD primitives, kind of like a function barrier but for adjoints. CC: @oxinabox.