
Notes from the September probprog community call #1410

Closed
mohamed82008 opened this issue Sep 20, 2020 · 39 comments

@mohamed82008
Member

mohamed82008 commented Sep 20, 2020

Some takeaway points from the great community call today (thanks @femtomc and @cscherrer for organizing this):

  1. There are plans to separate Gen's GFI to its own package at some point. CC: @alex-lew (Proposal: Separate package for abstract GFI probcomp/Gen.jl#306)
  2. As far as I can tell, the best entry point for a Turing-Gen integration is at the context level. Each of the GFI functions can be defined for a Turing model by overloading the model call with a different input context or by overloading the tilde and dot_tilde functions. I don't know how it will turn out but I think I found the right entry point.
  3. Omega.jl uses type-based dependence / independence tracking which is very analogous to the Tracker / ReverseDiff / ForwardDiff approach in AD. Omega can support inference techniques currently impossible in Turing but on a limited subset of models, i.e. the ones compatible with the Ω struct, I think. CC: @zenna
  4. The same entry point for Gen compatibility above can be used to allow a Turing model to use Omega's inference algorithms on the class of models that it supports. The main change required here will be to define the random variables to be instances of Omega's random variable type. Then the sample function can simply call the cond function or something similar in Omega, if I understand correctly. Again, more details need to be figured out here, but it should be doable as a summer or winter project.
  5. It was mentioned today that one limitation of the current PPLs is that we cannot condition on a "function" of the model's "observations" instead of the "observations" themselves. I am using "observations" here to refer to the thing on the LHS of ~. I think this might be possible today for some functions, namely bijectors. I think we can define a transformed distribution using Bijectors.jl and use that to observe for example 2y instead of y or log(y), etc. (see the sketch after this list). A tutorial here may be all we need. CC: @torfjelde
  6. Many of the optimizations in ProbabilityModels can be implemented at the distribution level making them available to other PPLs. CC: @chriselrod. I think it may help to have a macro that can let us more easily define a "complex" distribution that returns a named tuple, together with its custom adjoint. It would be interesting if we can automatically compose and inline rules in ChainRules explicitly at this level to make a bigger chain rule for our "complex distribution". The goal of this is to minimize the time spent in Zygote's type unstable parts by going through a single primitive instead of multiple AD primitives, kind of like a function barrier but for adjoints. CC: @oxinabox.
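
Regarding point 5, here is a minimal sketch of what I have in mind with Bijectors.jl. Treat it as an assumption-laden illustration: I'm assuming transformed and a Scale bijector compose the way shown, and exact names/signatures may differ between Bijectors versions.

using Turing, Bijectors

# Toy model: we only get to see z = 2y, where y ~ Normal(μ, 1).
# transformed(d, b) is the pushforward of d under the bijector b, so observing z
# under that distribution corresponds to observing y = b⁻¹(z) under d, with the
# change-of-variables correction handled by Bijectors.
@model function demo(z)
    μ ~ Normal(0, 1)
    z ~ transformed(Normal(μ, 1), Bijectors.Scale(2.0))
end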
@cscherrer

I think it may help to have a macro that can let us more easily define a "complex" distribution that returns a named tuple, together with its custom adjoint.

How about Soss.@model? ;)

@phipsgabler
Member

phipsgabler commented Sep 20, 2020

Ad (5): this should be easy (at least to write) in my proposed IR. Or in every system that has a separation of variable names and sampling, so that you can have normal assignment to a named variable:

{x} ~ Normal()
{log_x} = log({x})

It would still have to be supported by the PPL evaluating the model, of course. This is also the kind of thing that could allow constraining intermediate transformations -- something I believe @zenna mentioned at some point.

Ad (6): that could even be a combinator in Measures.jl, couldn't it?

@mohamed82008
Member Author

How about Soss.@model? ;)

Yes! But it would be nice to add the custom adjoint part to that as well and have Chris Elrod implement some heavily optimized logpdf methods for those distributions or the distribution generators.

@mohamed82008
Member Author

Ad (6): that could even be a combinator in Measures.jl, couldn't it?

Yes either Measures or Soss would be a good place for this.

@cscherrer

For multivariate normal, there's a lot of value in representing the covariance
as a tensor product. Something like Kronecker.jl, but without the flattening.

Then

x ~ MvNormal(Σ) |> iid(n)

would reduce to

x ~ NewGaussian(I(n) ⊗ Σ)

and the covariance between the (i,a)th and (j,b)th elements of x
is

0      if i≠j
Σ[a,b] if i=j

and if Σ is k×k, rand(NewGaussian(I(n) ⊗ Σ)) would generate an n×k matrix.
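
To spell out that structure with a flattened kron (purely for illustration; the point of the tensor-product representation is to never materialize this):

using LinearAlgebra

# If the n×k matrix x stacks n iid MvNormal(Σ) draws as rows, then vec(x') has
# covariance I(n) ⊗ Σ:
n, k = 4, 3
A = randn(k, k)
Σ = Matrix(Symmetric(A * A' + I))   # some k×k positive-definite Σ
K = kron(Matrix(1.0I, n, n), Σ)     # the flattened nk × nk covariance
# K[(i-1)*k + a, (j-1)*k + b] == (i == j ? Σ[a, b] : 0.0)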

@cscherrer

Yes either Measures or Soss would be a good place for this.

I was sort of half-joking originally, but defining a distribution (or a measure, more generally) over named tuples is sort of Soss's whole deal. Also, @tpapp 's TransformVariables is really good at this, if your starting point is a distribution over ℝⁿ.
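
For anyone unfamiliar, a tiny sketch of the TransformVariables approach (just its exported API, nothing Soss-specific):

using TransformVariables

# Map an unconstrained point in ℝ³ to a named tuple with constrained components.
t = as((σ = asℝ₊, p = as𝕀, μ = asℝ))
x = randn(dimension(t))   # a point in ℝ³
θ = transform(t, x)       # (σ = ..., p = ..., μ = ...) with σ > 0 and 0 < p < 1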

@phipsgabler
Member

One reason I'm already a huge fan of Measures.jl is the nice "multiple parametrizations" idea. So actually, there could be

x ~ NewGaussian(diag = Σ)

falling back to

NewGaussian(cov = I(length(Σ)) ⊗ Σ)

or similarly.

@chriselrod

chriselrod commented Sep 20, 2020

Like in Stan, I'd prefer to point people towards using Normal(::AbstractVector, ::LinearAlgebra.AbstractTriangular) when reasonable to skip the Cholesky factorization.

This is an unoptimized but simple and dependency-free (aside from standard library) implementation for calculating the density and gradients:

using LinearAlgebra
# y is an P x N matrix containing N total P-dimensional observations
# m is a vector of length P containing the means
# l is the lower triangular Cholesky factor of the covariance matrix
# dy, dm, and dl are the corresponding gradients
# b is preallocated memory with the same size and shape as y
function ldnorm!(dy, dm, dl, b, y, m, l)
    b .= m .- y
    LAPACK.trtrs!('L', 'N', 'N', l, b)
    lp = dot(b, b); dy .= b
    LAPACK.trtrs!('L', 'T', 'N', l, dy)
    #@avx should be a little faster, and let you move the 0-assignment into the loop nest
    dm .= 0
    @inbounds for n ∈ 1:size(y,2)
        @simd for p ∈ 1:size(y,1)
            dm[p] -= dy[p,n]
        end
    end
    BLAS.gemm!('N', 'T', 1.0, dy, b, 0.0, dl)
    N = size(y,2); lp *= -0.5
    @inbounds for p ∈ 1:size(dl,1) #@avx instead would be much faster
        dl[p,p] -= N * l[p,p]
        lp -= N * log(l[p,p])
    end
    lp
end

The optimized code I wrote that does only a single pass over the memory is much messier and currently broken, although it handles a variety of situations such as where the mean is defined as X * beta (fusing it into the rest of the calculations).

Unfortunately, LAPACK doesn't include a version of trtrs! for solving NxP with a PxP matrix, even though that situation is much easier to SIMD and therefore faster.
That is why the above uses PxN instead.

I could work on rewriting and cleaning it up, but I think long term I'd rather get LoopVectorization to be able to handle the kind of optimizations and loop dependency structures needed.
One of the major goals of that in the context of a PPL would be to let it convert high level expressions like logpdfs and linear algebra into loop nests, and then let loop optimization code optimize the combined expressions.

@devmotion
Member

One reason I'm already a huge fan of Measures.jl is the nice "multiple parametrizations" idea. So actually, there could be

x ~ NewGaussian(diag = Σ)

falling back to

NewGaussian(cov = I(length(Σ)) ⊗ Σ)

or similarly.

Isn't the same behaviour achieved by MvNormal without keyword arguments by dispatching on the type of the covariance matrix, such as AbstractVector, UniformScaling, Diagonal, etc.? IMO an annoyance with MvNormal is just that in the end the type of the covariance matrix has to be an AbstractPDMat (although optimizations for diagonal matrices etc. exist there as well). For the general case, it seems there is a PR to Distributions that adds keyword arguments to constructors: JuliaStats/Distributions.jl#823
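
To illustrate the dispatch point, a rough sketch (the exact wrapper types come from PDMats):

using Distributions, LinearAlgebra, PDMats

# The type of the covariance argument selects the specialization, but either way
# it ends up wrapped in an AbstractPDMat internally.
d_dense = MvNormal(zeros(3), Matrix(1.0I, 3, 3))
d_diag  = MvNormal(zeros(3), Diagonal([1.0, 2.0, 3.0]))
typeof(d_dense.Σ)   # a PDMats.PDMat
typeof(d_diag.Σ)    # a PDMats.PDiagMat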

@cscherrer

Oh wait, sorry I had misread @phipsgabler 's comment. I think it's a little weird how Distributions.jl works in this case. It has MvNormal as a distribution over vectors, but then taking multiple samples magically puts it into a matrix. It just makes the whole thing awkward to reason about. I'd rather have the size as part of the distribution explicitly.

IMO an annoyance with MvNormal is just that in the end the type of the covariance matrix has to be an AbstractPDMat (although optimizations for diagonal matrices etc. exist there as well).

Yep, I agree. Also, it's often handy to allow positive semidefinite covariance. Tim Holy has an approach to this in PositiveFactorizations.jl that seems promising.

For the general case, it seems there is a PR to Distributions that adds keyword arguments to constructors: JuliaStats/Distributions.jl#823

This looks nice, but it doesn't really seem to allow for different parameterizations. Instead, it uses keyword arguments to coerce to the standard parameterization.

A big advantage of reparameterization is to make computation cheaper, but this doesn't seem to help in that regard.

@devmotion
Member

It has MvNormal as a distribution over vectors, but then taking multiple samples magically puts it into a matrix.

I agree, there is some inconsistency in how rand works here. However, the API of Distributions also accepts AbstractArrays of AbstractVectors as samples of multivariate distributions, e.g., for evaluating the loglikelihood or in-place sampling with rand!.

@trappmartin
Member

Thanks @mohamed82008!

@torfjelde
Member

It was mentioned today that one limitation of the current PPLs is that we cannot condition on a "function" of the model's "observations" instead of the "observations" themselves. I am using "observations" here to refer to the thing on the LHS of ~. I think this might be possible today for some functions, namely bijectors. I think we can define a transformed distribution using Bijectors.jl and use that to observe for example 2y instead of y or log(y), etc. A tutorial here may be all we need.

In Turing we decide what's "observed" based on whether or not it's present in args for the model *and* is not missing, right? More specifically https://github.com/TuringLang/DynamicPPL.jl/blob/405546f5f034a9c78e3687e05f3713b998cdbf0c/src/compiler.jl#L6..L19

So I don't think the issue is related to whether or not the transformation is bijective, but rather that we can't handle something like

@model function demo(x)
    y = f(x)
    y ~ Likelihood()
end

because we don't know that the LHS is a function of the inputs.

Also, it's worth pointing out that it's still possible to condition on observations by using @addlogprob!, i.e.

@model function demo(x)
    y = f(x)
    @addlogprob! logpdf(Likelihood(), y)
end

Alternatively we can make it look nicer by doing something like

@model function demo(x)
    y = f(x)
    @observe y ~ Likelihood()
end

which simply expands to the above.
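
A rough sketch of what such a (currently hypothetical) @observe macro could expand to, reusing @addlogprob!:

# Hypothetical: rewrite `y ~ D` into an explicit likelihood increment.
macro observe(ex)
    ex.head === :call && ex.args[1] === :~ || error("expected `lhs ~ dist`")
    lhs, dist = ex.args[2], ex.args[3]
    return esc(:(@addlogprob! logpdf($dist, $lhs)))
end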

I think it may help to have a macro that can let us more easily define a "complex" distribution that returns a named tuple, together with its custom adjoint.

I really like going in this direction. Distributions over named tuples keep coming up as something that would be super-useful to have (we also recently added support for transforming named tuples in Bijectors.jl).

Also, @tpapp's TransformVariables is really good at this, if your starting point is a distribution over ℝⁿ.

Btw, we've also added support for transforming NamedTuples in Bijectors.jl now, so I'd be keen to hear what you TransformVariables.jl people think of the approach. Would be nice if we could converge on a joint effort :)

Also, I completely missed the meeting! Sorry! I've been in the process of moving and getting sorted for starting my studies, so I completely forgot about it.

@cscherrer

Btw, we've also added support for transforming NamedTuples in Bijectors.jl now, so I'd be keen to hear what you TransformVariables.jl people think of the approach. Would be nice if we could converge on a joint effort :)

THIS IS SO GREAT!!!!

Link: TuringLang/Bijectors.jl#95

Some background for others: It's often important to have the support of a variable depend stochastically on the value of another. It's not just a niche thing either, it comes up any time you want variables to be ordered in some way. And this is a big deal because exchangeability in posteriors is just a mess, causing most samplers to mode-switch all over the place, which in turn screws up most diagnostics.

So yeah, 💯
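
A concrete toy example of what I mean, as a sketch: the support of x2 depends stochastically on the sampled value of x1, which gives you an ordered pair directly instead of relying on exchangeable components.

using Turing

@model function ordered_pair()
    x1 ~ Normal()
    x2 ~ truncated(Normal(), x1, Inf)   # support of x2 depends on x1
end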

@oxinabox

oxinabox commented Sep 30, 2020

It would be interesting if we can automatically compose and inline rules in ChainRules explicitly at this level to make a bigger chain rule for our "complex distribution". The goal of this is to minimize the time spent in Zygote's type unstable parts by going through a single primitive instead of multiple AD primitives, kind of like a function barrier but for adjoints. CC: @oxinabox.

Yes, this should be possible.
I have long intended to do basically the same thing for Flux.Chain when it is used with only simple layers that we have rules for.
Right now, because we don't have JuliaDiff/ChainRulesCore.jl#68,
you need to know you have rules for every part, and you need to know that from the start; you can't back out later.

@mohamed82008
Member Author

mohamed82008 commented Sep 30, 2020

So what I am proposing is practically a symbolic AD layer on top of ChainRules to write a chain rule for a function that we really, really care about and that we can define in global scope (annotated with a macro) once for our users to use over and over again. For example, the logpdf of an MvNormal, which is a well-defined mathematical formula, but I don't want to sit down and manually derive the reverse and forward chain rules for it. And I don't want Zygote to scratch its head trying to connect all the chain rules at runtime. I want the symbolic AD to do it for me at parse time, generating a single chain rule for the whole function that Zygote can then use as a primitive.
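
To make the target concrete, this is the kind of fused rule I'd like generated for me instead of writing it by hand (a minimal hand-written sketch, differentiating only w.r.t. the mean and the observation; ChainRulesCore API details may vary between versions):

using ChainRulesCore, Distributions, LinearAlgebra

# One rrule for the whole MvNormal logpdf, so the AD backend hits a single
# primitive instead of stitching together many small ones.
function ChainRulesCore.rrule(::typeof(logpdf), d::MvNormal, x::AbstractVector)
    μ, Σ = mean(d), cov(d)
    z = Σ \ (x .- μ)          # Σ⁻¹ (x - μ)
    lp = logpdf(d, x)
    function logpdf_pullback(Δ)
        d̄ = Tangent{typeof(d)}(μ = Δ .* z)   # ∂lp/∂μ =  Σ⁻¹(x - μ)
        x̄ = -Δ .* z                          # ∂lp/∂x = -Σ⁻¹(x - μ)
        return NoTangent(), d̄, x̄
    end
    return lp, logpdf_pullback
end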

@mohamed82008
Member Author

Since we are talking AD and probprog, CC: @ChrisRackauckas.

@mohamed82008
Member Author

you need to know you have rules for every part; and you need to know you have rules from the start, can't back out later

I think even with this limitation, the proposal is still useful.

@ChrisRackauckas
Collaborator

So what I am proposing is practically a symbolic AD layer on top of ChainRules to write a chain rule for a function that we really, really care about and that we can define in global scope (annotated with a macro) once for our users to use over and over again. For example, the logpdf of an MvNormal, which is a well-defined mathematical formula, but I don't want to sit down and manually derive the reverse and forward chain rules for it. And I don't want Zygote to scratch its head trying to connect all the chain rules at runtime. I want the symbolic AD to do it for me at parse time, generating a single chain rule for the whole function that Zygote can then use as a primitive.

For symbolic AD here, you're just asking for ModelingToolkit?

@mohamed82008
Member Author

you're just asking for ModelingToolkit?

If MTK can connect ChainRules like lego to make a bigger chain rule, then yes. I really should play more with MTK.

@mohamed82008
Member Author

like lego

Or well...like a chain!

@ChrisRackauckas
Collaborator

It's built on DiffRules right now, but with @oxinabox 's changes to ChainRules it could probably use ChainRules now.

@mohamed82008
Member Author

Awesome! If this happens, I think it can solve like 90% of the AD needs if one is willing to work in global scope.

@mohamed82008
Member Author

So for a bigger picture, I imagine a world where more functions are annotated with @buildchain or something from MTK to define a complex chain rule symbolically for this function. Once ChainRules can call back to AD, this complex chain rule can make calls to Zygote if no chain rule exists for one of the sub-functions in the main function.

Then the entry point for AD can be Zygote (or ChainRules really at this point). Zygote calls ChainRules making use of MTK-generated ChainRules where possible, which themselves can call back into Zygote if needed. A beautiful harmony of symbolic AD (MTK), runtime-based AD (Zygote) and manual differentiation (ChainRules) all working together to avoid Zygote's type instability!

@ChrisRackauckas
Collaborator

That's exactly how we're using it in DiffEq, defining Jacobians, vjps, etc. to augment the AD world. The nice thing is that symbolic derivatives are more efficient than AD derivatives in the sparse case: coloring gets close, but is never as efficient as directly defining entries of the matrix. So we mainly use it to build big sparse things, but yes, it handles the cases that other source-to-source ADs like Zygote are bad at (scalarized stuff).

@mohamed82008
Member Author

The nice thing is that symbolic derivatives are more efficient than AD derivatives in the sparse case:

But it also needs to be taught when to give up, for example if a function is too big. I don't know if MTK unrolls loops or tries to work with a loop as a single node. If it's the former, then it will be pretty taxing on the compiler and can take a very long time to process long loops.

@ChrisRackauckas
Collaborator

It doesn't give up. We leave it to user choice. We have cases where a 500-second compile time is a reduction in compute time of 80 hours though... so it's really a choice, and I think any heuristic is wrong. Instead, a heuristic on that should exist (and be overridable) in Turing, or some high-level "glue AD".

@mohamed82008
Member Author

Sure but at least it needs to be taught how to give up and we can tell it when 😅

@mohamed82008
Member Author

For example, setting a time limit.

@ChrisRackauckas
Collaborator

Do the computation on a task and kill the task if it goes over a time limit. That's what we plan to do in AutoOptimize.jl.
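
Something along these lines, as a generic sketch (what AutoOptimize.jl actually does may differ; note a compute-bound Julia task can't really be killed, only abandoned):

# Run f() on a task; if it doesn't finish within `seconds`, give up on it and
# fall back to the non-symbolic path.
function run_with_timeout(f, seconds)
    task = Threads.@spawn f()
    if timedwait(() -> istaskdone(task), seconds) === :ok
        return fetch(task)
    end
    return nothing   # timed out: abandon the symbolic attempt
end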

@mohamed82008
Member Author

Nice!

@chriselrod

I haven't had the time to get it into a working state, but this much at least does work:

julia> using ProbabilityModels, LinearAlgebra, DistributionParameters

julia> sg = quote
           σ ~ Gamma(1.0, 0.05)
           L ~ LKJ(2.0)
           β ~ Normal(10.0) # μ = 0
           μ ~ Normal(100.0)

           σL = Diagonal(σ) * L
           Y₁ ~ Normal( X₁, β,       μ',  σL )
           Y₂ ~ Normal( X₂, β[1:7,:], μ', σL )
       end;

A model definition API would be something like @model ModelName begin followed by the contents of the above quote. It would define a struct named ModelName with a field holding a named tuple. If the named tuple holds a value for a variable, that variable would be treated as known data/prior (unless it is Missing); otherwise it'd be a parameter.

The macro would also define methods that dispatch on the struct to calculate the logdensity and gradient.

The macro would use read_model to read the model:

julia> m = ProbabilityModels.read_model(sg, Main);

Generating data to create an example named tuple:

julia> N, K₁, K₂, P = 100, 10, 7, 12;

julia> μ = rand(P);

julia> σU = rand(P, (3P)>>1) |> x -> cholesky(Hermitian(x * x')).U;

julia> X₁ = rand(N, K₁);

julia> X₂ = rand(N, K₂);

julia> β = randn(K₁, P);

julia> Y₁ = mul!(rand(N, P) * σU, X₁, β, 1.0, 1.0);

julia> Y₂ = mul!(rand(N, P) * σU, X₂, view(β, 1:K₂, :), 1.0, 1.0);

julia> datant = (
           Y₁ = Y₁, Y₂ = Y₂, X₁ = X₁, X₂ = X₂,
           μ = RealVector{12}(), β = RealMatrix{10,12}(),
           σ = RealVector{12,0.0,Inf}(), L = CorrelationMatrixCholesyFactor{12}()
       );

A model description would be inserted into the generated function methods for logdensity and gradients.
These can then use the type information from the named tuple to generate code, e.g. for the logdensity:

julia> ProbabilityModels.preprocess!(m, typeof(datant));

julia> ProbabilityModels.ReverseDiffExpressions.lower(m)
quote
    (var"###STACK##POINTER###", var"##targetconstrained#263") = ReverseDiffExpressions.stack_pointer_call(DistributionParameters.constrain, var"###STACK##POINTER###", var"#θ#", 120, RealArray{Tuple{12}, -Inf, Inf, 0}())
    var"##TARGET###0#" = ReverseDiffExpressionsBase.first(var"##targetconstrained#263")
    (var"###STACK##POINTER###", var"##targetconstrained#265") = ReverseDiffExpressions.stack_pointer_call(DistributionParameters.constrain, var"###STACK##POINTER###", var"#θ#", 132, RealArray{Tuple{12}, 0.0, Inf, 0}())
    σ = ReverseDiffExpressionsBase.second(var"##targetconstrained#265")
    (var"###STACK##POINTER###", var"####TARGET###1##267") = ReverseDiffExpressions.stack_pointer_call(logdensity, var"###STACK##POINTER###", Gamma{(true, false, false, false)}(), σ, 1.0, 0.05, var"##ONE##")
    var"##TARGET###1#" = ReverseDiffExpressions.vadd!(var"##TARGET###0#", var"####TARGET###1##267")
    (var"###STACK##POINTER###", var"##targetconstrained#266") = ReverseDiffExpressions.stack_pointer_call(DistributionParameters.constrain, var"###STACK##POINTER###", var"#θ#", 144, CorrelationMatrixCholesyFactor{12}())
    L = ReverseDiffExpressionsBase.second(var"##targetconstrained#266")
    (var"###STACK##POINTER###", var"####TARGET###2##268") = ReverseDiffExpressions.stack_pointer_call(logdensity, var"###STACK##POINTER###", LKJ{(true, false, false)}(), L, 2.0, var"##ONE##")
    var"##TARGET###2#" = ReverseDiffExpressions.vadd!(var"##TARGET###1#", var"####TARGET###2##268")
    (var"###STACK##POINTER###", var"##targetconstrained#264") = ReverseDiffExpressions.stack_pointer_call(DistributionParameters.constrain, var"###STACK##POINTER###", var"#θ#", 0, RealArray{Tuple{10,12}, -Inf, Inf, 0}())
    β = ReverseDiffExpressionsBase.second(var"##targetconstrained#264")
    (var"###STACK##POINTER###", var"####TARGET###3##269") = ReverseDiffExpressions.stack_pointer_call(logdensity, var"###STACK##POINTER###", Normal{(true, false, false)}(), β, 10.0, var"##ONE##")
    var"##TARGET###3#" = ReverseDiffExpressions.vadd!(var"##TARGET###2#", var"####TARGET###3##269")
    μ = ReverseDiffExpressionsBase.second(var"##targetconstrained#263")
    (var"###STACK##POINTER###", var"####TARGET###4##270") = ReverseDiffExpressions.stack_pointer_call(logdensity, var"###STACK##POINTER###", Normal{(true, false, false)}(), μ, 100.0, var"##ONE##")
    var"##TARGET###4#" = ReverseDiffExpressions.vadd!(var"##TARGET###3#", var"####TARGET###4##270")
    Y₁ = (var"#DATA#").Y₁
    X₁ = (var"#DATA#").X₁
    (var"###STACK##POINTER###", var"##LHS#259#0#") = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.adjoint, var"###STACK##POINTER###", μ)
    (var"###STACK##POINTER###", var"##LHS#258") = ReverseDiffExpressions.stack_pointer_call(Diagonal, var"###STACK##POINTER###", σ)
    (var"###STACK##POINTER###", σL) = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.:*, var"###STACK##POINTER###", var"##LHS#258", L)
    (var"###STACK##POINTER###", var"####TARGET###5##271") = ReverseDiffExpressions.stack_pointer_call(logdensity, var"###STACK##POINTER###", Normal{(false, false, true, true, true, false)}(), Y₁, X₁, β, var"##LHS#259", σL, var"##ONE##")
    var"##TARGET###5#" = ReverseDiffExpressions.vadd!(var"##TARGET###4#", var"####TARGET###5##271")
    Y₂ = (var"#DATA#").Y₂
    X₂ = (var"#DATA#").X₂
    var"##LHS#260" = Base.view(β, StaticUnitRange{1, 7}(), :)
    (var"###STACK##POINTER###", var"####TARGET###6##272") = ReverseDiffExpressions.stack_pointer_call(logdensity, var"###STACK##POINTER###", Normal{(false, false, true, true, true, false)}(), Y₂, X₂, var"##LHS#260", var"##LHS#259", σL, var"##ONE##")
    var"##TARGET###6#" = ReverseDiffExpressions.vadd!(var"##TARGET###5#", var"####TARGET###6##272")
    var"####TARGET###7##273" = ReverseDiffExpressionsBase.first(var"##targetconstrained#264")
    var"##TARGET###7#" = ReverseDiffExpressions.vadd!(var"##TARGET###6#", var"####TARGET###7##273")
    var"####TARGET###8##274" = ReverseDiffExpressionsBase.first(var"##targetconstrained#265")
    var"##TARGET###8#" = ReverseDiffExpressions.vadd!(var"##TARGET###7#", var"####TARGET###8##274")
    var"####TARGET###275" = ReverseDiffExpressionsBase.first(var"##targetconstrained#266")
    var"##TARGET##" = ReverseDiffExpressions.vadd!(var"##TARGET###8#", var"####TARGET###275")
    vsum(var"##TARGET##")
end

Verbose, but a few notes:

  1. Code is written to use a StackPointers.StackPointer so that functions can opt in to being non-allocating (using the stack pointer for working memory, and incrementing it if they need to allocate). This requires that escaping the memory be undefined behavior / not allowed.
  2. The indices for constrain, to constrain the input vector, are sorted to try and maximize the number of aligned loads and stores.
  3. The general order of lowering the expressions from the model tries to be cache friendly, clumping uses of the same memory together. It's fairly naive at this point, but could definitely be optimized.

Similarly, it can differentiate the model method:

julia> dm = ProbabilityModels.ReverseDiffExpressions.differentiate(m);

julia> ProbabilityModels.ReverseDiffExpressions.lower(dm)
quote
    (var"###STACK##POINTER###", var"##constrainpullbacktup#302") = ReverseDiffExpressions.stack_pointer_call(constrain_pullback!, var"###STACK##POINTER###", var"#∇#", var"#θ#", 132, RealArray{Tuple{12}, 0.0, Inf, 0}())
    var"##targetconstrained#265##BAR##" = ReverseDiffExpressionsBase.second(var"##constrainpullbacktup#302")
    var"σ##BAR###0#" = ReverseDiffExpressionsBase.second(var"##targetconstrained#265##BAR##")
    var"##targetconstrained#265" = ReverseDiffExpressionsBase.first(var"##constrainpullbacktup#302")
    σ = ReverseDiffExpressionsBase.second(var"##targetconstrained#265")
    (var"###STACK##POINTER###", var"##rrule_LHS#285") = ReverseDiffExpressions.stack_pointer_call(ChainRules.rrule, var"###STACK##POINTER###", Diagonal, σ)
    var"##LHS#258" = Base.first(var"##rrule_LHS#285")
    (var"###STACK##POINTER###", var"##temp#289") = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.adjoint, var"###STACK##POINTER###", var"##LHS#258")
    Y₁ = (var"#DATA#").Y₁
    X₁ = (var"#DATA#").X₁
    (var"###STACK##POINTER###", var"##constrainpullbacktup#300") = ReverseDiffExpressions.stack_pointer_call(constrain_pullback!, var"###STACK##POINTER###", var"#∇#", var"#θ#", 0, RealArray{Tuple{10,12}, -Inf, Inf, 0}())
    var"##targetconstrained#264##BAR##" = ReverseDiffExpressionsBase.second(var"##constrainpullbacktup#300")
    var"β##BAR###0#" = ReverseDiffExpressionsBase.second(var"##targetconstrained#264##BAR##")
    var"##targetconstrained#264" = ReverseDiffExpressionsBase.first(var"##constrainpullbacktup#300")
    β = ReverseDiffExpressionsBase.second(var"##targetconstrained#264")
    (var"###STACK##POINTER###", var"##tup#280") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (var"β##BAR###0#", nothing, nothing), Normal{(true, false, false)}(), β, 10.0, var"##ONE##")
    var"##∂tup#281" = ReverseDiffExpressionsBase.second(var"##tup#280")
    var"##β##BAR###1##306" = ReverseDiffExpressionsBase.first(var"##∂tup#281")
    var"β##BAR###1#" = ReverseDiffExpressions.vadd!(var"β##BAR###0#", var"##β##BAR###1##306")
    (var"###STACK##POINTER###", var"##constrainpullbacktup#298") = ReverseDiffExpressions.stack_pointer_call(constrain_pullback!, var"###STACK##POINTER###", var"#∇#", var"#θ#", 120, RealArray{Tuple{12}, -Inf, Inf, 0}())
    var"##targetconstrained#263##BAR##" = ReverseDiffExpressionsBase.second(var"##constrainpullbacktup#298")
    var"μ##BAR###0#" = ReverseDiffExpressionsBase.second(var"##targetconstrained#263##BAR##")
    var"##targetconstrained#263" = ReverseDiffExpressionsBase.first(var"##constrainpullbacktup#298")
    μ = ReverseDiffExpressionsBase.second(var"##targetconstrained#263")
    (var"###STACK##POINTER###", var"##rrule_LHS#291") = ReverseDiffExpressions.stack_pointer_call(ChainRules.rrule, var"###STACK##POINTER###", LoopVectorization.adjoint, μ)
    var"##LHS#259" = Base.first(var"##rrule_LHS#291")
    (var"###STACK##POINTER###", var"##constrainpullbacktup#304") = ReverseDiffExpressions.stack_pointer_call(constrain_pullback!, var"###STACK##POINTER###", var"#∇#", var"#θ#", 144, CorrelationMatrixCholesyFactor{12}())
    var"##targetconstrained#266##BAR##" = ReverseDiffExpressionsBase.second(var"##constrainpullbacktup#304")
    var"L##BAR###0#" = ReverseDiffExpressionsBase.second(var"##targetconstrained#266##BAR##")
    var"##targetconstrained#266" = ReverseDiffExpressionsBase.first(var"##constrainpullbacktup#304")
    L = ReverseDiffExpressionsBase.second(var"##targetconstrained#266")
    (var"###STACK##POINTER###", σL) = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.:*, var"###STACK##POINTER###", var"##LHS#258", L)
    (var"###STACK##POINTER###", var"##tup#294") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (nothing, nothing, var"β##BAR###1#", nothing, nothing, nothing), Normal{(false, false, true, true, true, false)}(), Y₁, X₁, β, var"##LHS#259", σL, var"##ONE##")
    var"##∂tup#295" = ReverseDiffExpressionsBase.second(var"##tup#294")
    var"##β##BAR###307" = ReverseDiffExpressionsBase.third(var"##∂tup#295")
    var"β##BAR##" = ReverseDiffExpressions.vadd!(var"β##BAR###1#", var"##β##BAR###307")
    var"##LHS#260##BAR###0#" = Base.view(var"β##BAR##", StaticUnitRange{1, 7}(), :)
    Y₂ = (var"#DATA#").Y₂
    X₂ = (var"#DATA#").X₂
    var"##LHS#260" = Base.view(β, StaticUnitRange{1, 7}(), :)
    var"##LHS#259##BAR###0#" = ReverseDiffExpressionsBase.fourth(var"##∂tup#295")
    var"σL##BAR###0#" = ReverseDiffExpressionsBase.fifth(var"##∂tup#295")
    (var"###STACK##POINTER###", var"##tup#296") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (nothing, nothing, var"##LHS#260##BAR###0#", var"##LHS#259##BAR###0#", var"σL##BAR###0#", nothing), Normal{(false, false, true, true, true, false)}(), Y₂, X₂, var"##LHS#260", var"##LHS#259", σL, var"##ONE##")
    var"##∂tup#297" = ReverseDiffExpressionsBase.second(var"##tup#296")
    var"####LHS#260##BAR###308" = ReverseDiffExpressionsBase.third(var"##∂tup#297")
    var"##LHS#260##BAR##" = ReverseDiffExpressions.vadd!(var"##LHS#260##BAR###0#", var"####LHS#260##BAR###308")
    (var"###STACK##POINTER###", var"##tup#294#1#") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (nothing, nothing, var"β##BAR##", var"##LHS#259##BAR###0#", var"σL##BAR###0#", nothing), Normal{(false, false, true, true, true, false)}(), Y₁, X₁, β, var"##LHS#259", σL, var"##ONE##")
    var"##σL##BAR###309" = ReverseDiffExpressionsBase.fifth(var"##∂tup#297")
    var"σL##BAR##" = ReverseDiffExpressions.vadd!(var"σL##BAR###0#", var"##σL##BAR###309")
    (var"###STACK##POINTER###", var"##L##BAR###1##310") = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.vmul, var"###STACK##POINTER###", var"##temp#289", var"σL##BAR##")
    var"L##BAR###1#" = ReverseDiffExpressions.vadd!(var"L##BAR###0#", var"##L##BAR###1##310")
    (var"###STACK##POINTER###", var"##tup#278") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (var"L##BAR###1#", nothing, nothing), LKJ{(true, false, false)}(), L, 2.0, var"##ONE##")
    var"##∂tup#279" = ReverseDiffExpressionsBase.second(var"##tup#278")
    var"##L##BAR###311" = ReverseDiffExpressionsBase.first(var"##∂tup#279")
    var"L##BAR##" = ReverseDiffExpressions.vadd!(var"L##BAR###1#", var"##L##BAR###311")
    (var"###STACK##POINTER###", var"##nothing#305") = ReverseDiffExpressions.stack_pointer_call(constrain_reverse!, var"###STACK##POINTER###", var"##targetconstrained#266##BAR##", CorrelationMatrixCholesyFactor{12}())
    (var"###STACK##POINTER###", var"##tup#276") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (var"σ##BAR###0#", nothing, nothing, nothing), Gamma{(true, false, false, false)}(), σ, 1.0, 0.05, var"##ONE##")
    var"##∂tup#277" = ReverseDiffExpressionsBase.second(var"##tup#276")
    var"##σ##BAR###1##312" = ReverseDiffExpressionsBase.first(var"##∂tup#277")
    var"σ##BAR###1#" = ReverseDiffExpressions.vadd!(var"σ##BAR###0#", var"##σ##BAR###1##312")
    var"##temp#286" = last(var"##rrule_LHS#285")
    (var"###STACK##POINTER###", var"##temp#288") = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.adjoint, var"###STACK##POINTER###", L)
    (var"###STACK##POINTER###", var"##LHS#258##BAR##") = ReverseDiffExpressions.stack_pointer_call(LoopVectorization.vmul, var"###STACK##POINTER###", var"σL##BAR##", var"##temp#288")
    (var"###STACK##POINTER###", var"##temp#287") = ReverseDiffExpressions.stack_pointer_call(callunthunk, var"###STACK##POINTER###", var"##temp#286", var"##LHS#258##BAR##")
    var"##σ##BAR###313" = LoopVectorization.second(var"##temp#287")
    var"σ##BAR##" = ReverseDiffExpressions.vadd!(var"σ##BAR###1#", var"##σ##BAR###313")
    (var"###STACK##POINTER###", var"##nothing#303") = ReverseDiffExpressions.stack_pointer_call(constrain_reverse!, var"###STACK##POINTER###", var"##targetconstrained#265##BAR##", RealArray{Tuple{12}, 0.0, Inf, 0}())
    (var"###STACK##POINTER###", var"##nothing#301") = ReverseDiffExpressions.stack_pointer_call(constrain_reverse!, var"###STACK##POINTER###", var"##targetconstrained#264##BAR##", RealArray{Tuple{10,12}, -Inf, Inf, 0}())
    (var"###STACK##POINTER###", var"##tup#282") = ReverseDiffExpressions.stack_pointer_call(∂logdensity!, var"###STACK##POINTER###", (var"μ##BAR###0#", nothing, nothing), Normal{(true, false, false)}(), μ, 100.0, var"##ONE##")
    var"##∂tup#283" = ReverseDiffExpressionsBase.second(var"##tup#282")
    var"##μ##BAR###1##314" = ReverseDiffExpressionsBase.first(var"##∂tup#283")
    var"μ##BAR###1#" = ReverseDiffExpressions.vadd!(var"μ##BAR###0#", var"##μ##BAR###1##314")
    var"##temp#292" = last(var"##rrule_LHS#291")
    var"####LHS#259##BAR###315" = ReverseDiffExpressionsBase.fourth(var"##∂tup#297")
    var"##LHS#259##BAR##" = ReverseDiffExpressions.vadd!(var"##LHS#259##BAR###0#", var"####LHS#259##BAR###315")
    (var"###STACK##POINTER###", var"##temp#293") = ReverseDiffExpressions.stack_pointer_call(callunthunk, var"###STACK##POINTER###", var"##temp#292", var"##LHS#259##BAR##")
    var"##μ##BAR###316" = LoopVectorization.second(var"##temp#293")
    var"μ##BAR##" = ReverseDiffExpressions.vadd!(var"μ##BAR###1#", var"##μ##BAR###316")
    (var"###STACK##POINTER###", var"##nothing#299") = ReverseDiffExpressions.stack_pointer_call(constrain_reverse!, var"###STACK##POINTER###", var"##targetconstrained#263##BAR##", RealArray{Tuple{12}, -Inf, Inf, 0}())
    var"##TARGET###0#" = ReverseDiffExpressionsBase.first(var"##tup#276")
    var"####TARGET###1##317" = ReverseDiffExpressionsBase.first(var"##tup#278")
    var"##TARGET###1#" = ReverseDiffExpressions.vadd!(var"##TARGET###0#", var"####TARGET###1##317")
    var"####TARGET###2##318" = ReverseDiffExpressionsBase.first(var"##tup#280")
    var"##TARGET###2#" = ReverseDiffExpressions.vadd!(var"##TARGET###1#", var"####TARGET###2##318")
    var"####TARGET###3##319" = ReverseDiffExpressionsBase.first(var"##tup#282")
    var"##TARGET###3#" = ReverseDiffExpressions.vadd!(var"##TARGET###2#", var"####TARGET###3##319")
    var"####TARGET###4##320" = ReverseDiffExpressionsBase.first(var"##tup#294")
    var"##TARGET###4#" = ReverseDiffExpressions.vadd!(var"##TARGET###3#", var"####TARGET###4##320")
    var"####TARGET###321" = ReverseDiffExpressionsBase.first(var"##tup#296")
    var"##TARGET##" = ReverseDiffExpressions.vadd!(var"##TARGET###4#", var"####TARGET###321")
    vsum(var"##TARGET##")
end

It checks whether gradients are considered known (and I'd implement gradients for all supported distributions), otherwise it would fall back to another library.

I don't think the above code would run, because I'm pretty sure I didn't implement all the needed methods.

My longer term ambition here is to be able to convert many distributions and linear algebra routines into equivalent loop-based representations, and then use LoopVectorization to reorder, fuse them, etc (still working on those optimizations).

@mohamed82008
Member Author

Cool! LoopVectorization as an optimization pass sounds very interesting. This will require the symbolic version of a chain rule though, which is not stored by default in ChainRules. Transforming linear algebra expressions to loop-based expressions sounds like a task Tullio can probably do with a little meta-programming help.

@ChrisRackauckas
Collaborator

Look at the code generated in https://mtk.sciml.ai/dev/tutorials/auto_parallel/ to get a sense of what MTK is currently doing. The current push is to get non-scalar forms as well, but that's what we have so far.

@chriselrod

chriselrod commented Sep 30, 2020

It'd be cool to add tullio / indexing notation support directly.
If we just add support manually, e.g. we want to say that sum with one argument becomes a loop that sums the numbers, we could also have predefined LoopSets representing the loop and skip the Expr -> LoopSet conversion step.

Although, it may be better to do the conversion, because Expr -> LoopSet promises a stable API, and I do not have one for the internal representation of the LoopSet, or for constructing one. While it'd be a good idea to create or formalize one for the latter, that hasn't happened yet.

@mohamed82008
Member Author

mohamed82008 commented Sep 30, 2020

Yes, and Tullio plus a macro can help with that, using a few if statements that check the types of A and b in A * b and write the appropriate loop expression in each branch.
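
For example, something along these lines (a tiny sketch of the index-notation-to-loop idea, not the branching macro itself):

using Tullio, LoopVectorization

A = randn(100, 50); b = randn(50)
# Index notation that Tullio lowers to an explicit loop nest (and, with
# LoopVectorization loaded, vectorizes):
@tullio y[i] := A[i, j] * b[j]   # equivalent to y = A * b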

@chriselrod

chriselrod commented Sep 30, 2020

The model definition would define the likelihood/gradient functions as @generated, to delay the compilation until all compile-time info is available. So the dimensionality of A and b should be known to the DSL.
If some of the axes are of static size, it should know those sizes, too.

I think it'd be easier to delay the substitution until then. Otherwise, you'd need a lot of branches, e.g. A or b could also be a Float64 instead of some sort of AbstractArray.

@mohamed82008
Member Author

mohamed82008 commented Sep 30, 2020

Otherwise, you'd need a lot of branches, e.g. A or b could also be a Float64 instead of some sort of AbstractArray.

A macro can do this, no problem. Each loop can be annotated with @avx, which can then use a generated function to find the sizes. Dead branches will be eliminated by the Julia compiler because the branching is all type-based.

@mohamed82008
Member Author

mohamed82008 commented Sep 30, 2020

Or maybe the entire body can be annotated to optimize the whole expression for each type scenario. But I imagine this is somewhat out of LoopVectorization's comfort zone.

@yebai yebai closed this as completed Dec 16, 2021