Energy differential #6

Merged 8 commits on Oct 28, 2024
4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
@@ -55,10 +55,10 @@ jobs:
- uses: actions/checkout@v4
- uses: quarto-dev/quarto-actions/setup@v2
with:
version: 1.5.54
version: 1.5.57
- uses: julia-actions/setup-julia@latest
with:
version: '1'
version: '1.10'
- uses: julia-actions/cache@v2
- name: Cache Quarto
id: cache-quarto
2 changes: 1 addition & 1 deletion .github/workflows/FormatCheck.yml
@@ -15,7 +15,7 @@ jobs:
- uses: julia-actions/setup-julia@latest
with:
version: 1
- uses: actions/checkout@v1
- uses: actions/checkout@v4
- name: Install JuliaFormatter
run: |
using Pkg
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -0,0 +1,12 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Version [1.0.1] - 2024-10-26

### Added

- Added new methods `energy_differential` and `energy_penalty`. [#6]

4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "EnergySamplers"
uuid = "f446124b-5d5e-4171-a6dd-a1d99768d3ce"
authors = ["Patrick Altmeyer and contributors"]
version = "1.0.0"
version = "1.0.1"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
@@ -21,7 +21,7 @@ MLUtils = "0.4"
Optimisers = "0.3"
StatsBase = "0.33, 0.34"
Tables = "1.12"
Test = "1.10"
Test = "1"
julia = "1.10"

[extras]
18 changes: 9 additions & 9 deletions docs/src/index.md
@@ -17,15 +17,15 @@ The package adds two new optimisers that are compatible with the [Optimisers.jl]
1. Stochastic Gradient Langevin Dynamics (SGLD) (Welling and Teh 2011) — [`SGLD`](@ref).
2. Improper SGLD (see, for example, Grathwohl et al. (2020)) — [`ImproperSGLD`](@ref).

SGLD is an efficient gradient-based Markov Chain Monte Carlo (MCMC) method that can be used in the context of EBM to draw samples from the model posterior (Murphy 2023). Formally, we can draw from $p_{\theta}(\mathbf{x})$ as follows
SGLD is an efficient gradient-based Markov Chain Monte Carlo (MCMC) method that can be used in the context of EBM to draw samples from the model posterior (Murphy 2023). Formally, we can draw from $p_{\theta}(x)$ as follows

``` math
\begin{aligned}
\mathbf{x}_{j+1} &\leftarrow \mathbf{x}_j - \frac{\epsilon_j^2}{2} \nabla_x \mathcal{E}_{\theta}(\mathbf{x}_j) + \epsilon_j \mathbf{r}_j, && j=1,...,J
x_{j+1} &\leftarrow x_j - \frac{\epsilon_j^2}{2} \nabla_x \mathcal{E}_{\theta}(x_j) + \epsilon_j r_j, && j=1,...,J
\end{aligned}
```

where $\mathbf{r}_j \sim \mathcal{N}(\mathbf{0},\mathbf{I})$ is a stochastic term and the step-size $\epsilon_j$ is typically polynomially decayed (Welling and Teh 2011). To allow for faster sampling, it is common practice to choose the step-size $\epsilon_j$ and the standard deviation of $\mathbf{r}_j$ separately. While $\mathbf{x}_J$ is only guaranteed to distribute as $p_{\theta}(\mathbf{x})$ if $\epsilon \rightarrow 0$ and $J \rightarrow \infty$, the bias introduced for a small finite $\epsilon$ is negligible in practice (Murphy 2023). We denote this form of sampling as Improper SGLD.
where $r_j \sim \mathcal{N}(0,I)$ is a stochastic term and the step-size $\epsilon_j$ is typically polynomially decayed (Welling and Teh 2011). To allow for faster sampling, it is common practice to choose the step-size $\epsilon_j$ and the standard deviation of $r_j$ separately. While $x_J$ is only guaranteed to distribute as $p_{\theta}(x)$ if $\epsilon \rightarrow 0$ and $J \rightarrow \infty$, the bias introduced for a small finite $\epsilon$ is negligible in practice (Murphy 2023). We denote this form of sampling as Improper SGLD.
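
For intuition, the update rule above can be written out as a minimal, self-contained sketch in plain Julia. This is not the package's [`SGLD`](@ref)/[`ImproperSGLD`](@ref) optimisers; the function name and the default values for `ϵ`, `σ` and `J` are made up for illustration.

``` julia
# Minimal sketch of (improper) SGLD; ∇E computes the gradient of the energy.
function sample_sgld(∇E, x0; ϵ=0.05, σ=0.05, J=1_000)
    x = copy(x0)
    for _ in 1:J
        r = randn(eltype(x), size(x))          # rⱼ ~ N(0, I)
        x = x .- (ϵ^2 / 2) .* ∇E(x) .+ σ .* r  # improper SGLD: σ chosen separately from ϵ
    end
    return x
end

# Toy example: for a standard Gaussian, E(x) = x'x/2, so ∇E(x) = x.
x_draw = sample_sgld(identity, randn(2))
```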

### Example: Bayesian Inference with SGLD

@@ -154,7 +154,7 @@ plot(p1, p2, size=(800, 400))
│ The input will be converted, but any earlier layers may be very slow.
│ layer = Dense(3 => 1, σ) # 4 parameters
│ summary(x) = "3×7000 adjoint(::Matrix{Float64}) with eltype Float64"
└ @ Flux ~/.julia/packages/Flux/HBF2N/src/layers/stateless.jl:60
└ @ Flux ~/.julia/packages/Flux/htpCe/src/layers/stateless.jl:59
Final parameters are Float32[-2.3311744 1.1305944 -1.5102222 -4.0762844]
Test accuracy is 0.9666666666666667
Final parameters are Float32[-0.6106307 2.760134 -0.031244753 -5.8856964]
@@ -166,9 +166,9 @@

In the context of EBM, the optimisers can be used to sample from a model posterior. To this end, the package provides the following samplers:

1. [`UnconditionalSampler`](@ref) — samples from the unconditional distribution $p_{\theta}(\mathbf{x})$ as in Grathwohl et al. (2020).
2. [`ConditionalSampler`](@ref) — samples from the conditional distribution $p_{\theta}(\mathbf{x}|y)$ as in Grathwohl et al. (2020).
3. [`JointSampler`](@ref) — samples from the joint distribution $p_{\theta}(\mathbf{x},y)$ as in Kelly, Zemel, and Grathwohl (2021).
1. [`UnconditionalSampler`](@ref) — samples from the unconditional distribution $p_{\theta}(x)$ as in Grathwohl et al. (2020).
2. [`ConditionalSampler`](@ref) — samples from the conditional distribution $p_{\theta}(x|y)$ as in Grathwohl et al. (2020).
3. [`JointSampler`](@ref) — samples from the joint distribution $p_{\theta}(x,y)$ as in Kelly, Zemel, and Grathwohl (2021).

### Example: Joint Energy-Based Model

@@ -221,7 +221,7 @@ end
```

[ Info: Epoch 1
Accuracy: 0.9995
Accuracy: 0.99
[ Info: Epoch 2
Accuracy: 0.9995
[ Info: Epoch 3
@@ -270,7 +270,7 @@ plot(plt)
│ The input will be converted, but any earlier layers may be very slow.
│ layer = Dense(2 => 2) # 6 parameters
│ summary(x) = "2-element Vector{Float64}"
└ @ Flux ~/.julia/packages/Flux/HBF2N/src/layers/stateless.jl:60
└ @ Flux ~/.julia/packages/Flux/htpCe/src/layers/stateless.jl:59

![](index_files/figure-commonmark/cell-8-output-2.svg)

160 changes: 80 additions & 80 deletions docs/src/index_files/figure-commonmark/cell-5-output-2.svg
4,584 changes: 2,294 additions & 2,290 deletions docs/src/index_files/figure-commonmark/cell-8-output-2.svg
2 changes: 2 additions & 0 deletions src/EnergySamplers.jl
@@ -13,6 +13,8 @@ abstract type AbstractSamplingRule <: Optimisers.AbstractRule end
"Base type for samplers."
abstract type AbstractSampler end

const DOC_Grathwohl = "For details see Grathwohl et al. (2020) [[arXiv](https://arxiv.org/abs/1912.03263), [ICLR](https://iclr.cc/virtual_2020/poster_Hkxzx0NtDB.html)]."

export AbstractSampler, AbstractSamplingRule
export ConditionalSampler, UnconditionalSampler, JointSampler
export PMC
39 changes: 34 additions & 5 deletions src/utils.jl
@@ -1,12 +1,17 @@
using Flux
using StatsBase

"""
get_logits(f::Flux.Chain, x)

Retrieves the logits (linear predictions) of a `Chain` for the input `x`.
"""
get_logits(f::Flux.Chain, x) = f[end] isa Function ? f[1:(end - 1)](x) : f(x)
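
As a quick illustration of what `get_logits` does (a sketch only; the helper may need to be qualified as `EnergySamplers.get_logits` if it is not exported):

``` julia
using Flux

f = Chain(Dense(2 => 3), softmax)   # final element is an activation function
x = randn(Float32, 2, 5)

logits = get_logits(f, x)           # strips the trailing softmax: f[1:(end - 1)](x)
probs = f(x)                        # full forward pass; each column sums to one

@assert size(logits) == (3, 5)
@assert probs ≈ softmax(logits)
```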

@doc raw"""
energy(f, x)
_energy(f, x; agg=mean)

Computes the energy for unconditional samples $x \sim p_{\theta}(x)$: $E(x)=-\text{LogSumExp}_y f_{\theta}(x)[y]$.
Computes the energy for unconditional samples $x \sim p_{\theta}(x)$: $E(x)=-\text{LogSumExp}_y f_{\theta}(x)[y]$. $DOC_Grathwohl
"""
function _energy(f, x; agg=mean)
if f isa Flux.Chain
@@ -24,9 +24,9 @@ function _energy(f, x; agg=mean)
end
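
The unconditional formula is easy to check by hand; a toy sketch (the `logsumexp` helper and the logit values below are made up for illustration, not part of the package):

``` julia
# Numerically stabilised LogSumExp, hand-rolled for illustration.
logsumexp(z) = maximum(z) + log(sum(exp.(z .- maximum(z))))

logits = [2.0, -1.0, 0.5]   # f_θ(x) for a single input x over three classes
E_x = -logsumexp(logits)    # E(x) = -LogSumExp_y f_θ(x)[y] ≈ -2.24
```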

@doc raw"""
energy(f, x, y::Int; agg=mean)
_energy(f, x, y::Int; agg=mean)

Computes the energy for conditional samples $x \sim p_{\theta}(x|y)$: $E(x)=- f_{\theta}(x)[y]$.
Computes the energy for conditional samples $x \sim p_{\theta}(x|y)$: $E(x)=- f_{\theta}(x)[y]$. $DOC_Grathwohl
"""
function _energy(f, x, y::Int; agg=mean)
if f isa Flux.Chain
@@ -40,6 +45,30 @@ function _energy(f, x, y::Int; agg=mean)
E = agg(map(_y -> _E(_y, y), eachslice(ŷ; dims=ndims(ŷ))))
return E
else
return _E(_y, y)
return _E(ŷ, y)
end
end
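
The conditional case is even simpler to verify by hand (illustrative values only):

``` julia
logits = [2.0, -1.0, 0.5]   # f_θ(x) for a single input x
y = 2
E_xy = -logits[y]           # E(x|y) = -f_θ(x)[y] = 1.0
```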

@doc raw"""
energy_differential(f, xgen, xsampled, y::Int; agg=mean)

Computes the energy differential between a conditional sample ``x_{\text{gen}} \sim p_{\theta}(x|y)`` and an observed sample ``x_{\text{sample}} \sim p(x|y)`` as ``E(x_{\text{sample}}|y) - E(x_{\text{gen}}|y)`` with ``E(x|y) = -f_{\theta}(x)[y]``. $DOC_Grathwohl
"""
function energy_differential(f, xgen, xsampled, y::Int; agg=mean)
neg_loss = _energy(f, xgen, y; agg=agg) # negative loss associated with generated samples
pos_loss = _energy(f, xsampled, y; agg=agg) # positive loss associated with sampled samples
ℓ = pos_loss - neg_loss
return ℓ
end

@doc raw"""
energy_penalty(f, xgen, xsampled, y::Int; agg=mean)

Computes a Ridge penalty for the overall energies of a conditional sample ``x_{\text{gen}} \sim p_{\theta}(x|y)`` and an observed sample ``x_{\text{sample}} \sim p(x|y)``. $DOC_Grathwohl
"""
function energy_penalty(f, xgen, xsampled, y::Int; agg=mean, p=1)
neg_loss = _energy(f, xgen, y; agg=agg) # negative loss associated with generated samples
pos_loss = _energy(f, xsampled, y; agg=agg) # positive loss associated with sampled samples
ℓ = neg_loss .^ 2 .+ pos_loss .^ 2
return ℓ
end
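
A usage sketch mirroring the test added below; the toy model `f` simply stands in for a classifier's logit function and the input shapes follow the test:

``` julia
using EnergySamplers: energy_differential, energy_penalty

# Toy "model": maps each column of X to a single score (as in test/other.jl).
f(X::Matrix) = [prod(x) for x in eachcol(X)]

xgen = randn(10, 1)       # generated sample, x_gen ~ p_θ(x|y)
xsampled = randn(10, 1)   # observed sample, x_sample ~ p(x|y)

Δ = energy_differential(f, xgen, xsampled, 1)   # E(x_sample|y) - E(x_gen|y)
ρ = energy_penalty(f, xgen, xsampled, 1)        # squared energies of both terms
```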
11 changes: 11 additions & 0 deletions test/other.jl
@@ -0,0 +1,11 @@
using EnergySamplers: energy_differential, energy_penalty

@testset "Other things" begin
@testset "Energy differential" begin
f(X::Matrix) = [prod(x) for x in eachcol(X)]
x1 = randn(10, 1)
x2 = randn(10, 1)
@test isreal(energy_differential(f, x1, x2, 1))
@test isreal(energy_penalty(f, x1, x2, 1))
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -3,6 +3,6 @@ using Test

@testset "EnergySamplers.jl" begin
include("aqua.jl")

include("samplers.jl")
include("other.jl")
end