Skip to content

Commit

Permalink
Merge pull request #6 from JuliaTrustworthyAI/energy-differential
Browse files Browse the repository at this point in the history
Energy differential
  • Loading branch information
pat-alt authored Oct 28, 2024
2 parents 383309c + 8937a89 commit 926f37c
Show file tree
Hide file tree
Showing 11 changed files with 2,448 additions and 2,390 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ jobs:
- uses: actions/checkout@v4
- uses: quarto-dev/quarto-actions/setup@v2
with:
version: 1.5.54
version: 1.5.57
- uses: julia-actions/setup-julia@latest
with:
version: '1'
version: '1.10'
- uses: julia-actions/cache@v2
- name: Cache Quarto
id: cache-quarto
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/FormatCheck.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
- uses: julia-actions/setup-julia@latest
with:
version: 1
- uses: actions/checkout@v1
- uses: actions/checkout@v4
- name: Install JuliaFormatter
run: |
using Pkg
Expand Down
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Version [1.0.1] - 2024-10-26

### Added

- Added new methods `energy_differential` and `energy_penalty`. [#6]

4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnergySamplers"
uuid = "f446124b-5d5e-4171-a6dd-a1d99768d3ce"
authors = ["Patrick Altmeyer and contributors"]
version = "1.0.0"
version = "1.0.1"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand All @@ -21,7 +21,7 @@ MLUtils = "0.4"
Optimisers = "0.3"
StatsBase = "0.33, 0.34"
Tables = "1.12"
Test = "1.10"
Test = "1"
julia = "1.10"

[extras]
Expand Down
18 changes: 9 additions & 9 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ The package adds two new optimisers that are compatible with the [Optimisers.jl]
1. Stochastic Gradient Langevin Dynamics (SGLD) (Welling and Teh 2011) — [`SGLD`](@ref).
2. Improper SGLD (see, for example, Grathwohl et al. (2020)) — [`ImproperSGLD`](@ref).

SGLD is an efficient gradient-based Markov Chain Monte Carlo (MCMC) method that can be used in the context of EBM to draw samples from the model posterior (Murphy 2023). Formally, we can draw from $p_{\theta}(\mathbf{x})$ as follows
SGLD is an efficient gradient-based Markov Chain Monte Carlo (MCMC) method that can be used in the context of EBM to draw samples from the model posterior (Murphy 2023). Formally, we can draw from $p_{\theta}(x)$ as follows

``` math
\begin{aligned}
\mathbf{x}_{j+1} &\leftarrow \mathbf{x}_j - \frac{\epsilon_j^2}{2} \nabla_x \mathcal{E}_{\theta}(\mathbf{x}_j) + \epsilon_j \mathbf{r}_j, && j=1,...,J
x_{j+1} &\leftarrow x_j - \frac{\epsilon_j^2}{2} \nabla_x \mathcal{E}_{\theta}(x_j) + \epsilon_j r_j, && j=1,...,J
\end{aligned}
```

where $\mathbf{r}_j \sim \mathcal{N}(\mathbf{0},\mathbf{I})$ is a stochastic term and the step-size $\epsilon_j$ is typically polynomially decayed (Welling and Teh 2011). To allow for faster sampling, it is common practice to choose the step-size $\epsilon_j$ and the standard deviation of $\mathbf{r}_j$ separately. While $\mathbf{x}_J$ is only guaranteed to distribute as $p_{\theta}(\mathbf{x})$ if $\epsilon \rightarrow 0$ and $J \rightarrow \infty$, the bias introduced for a small finite $\epsilon$ is negligible in practice (Murphy 2023). We denote this form of sampling as Improper SGLD.
where $r_j \sim \mathcal{N}(0,I)$ is a stochastic term and the step-size $\epsilon_j$ is typically polynomially decayed (Welling and Teh 2011). To allow for faster sampling, it is common practice to choose the step-size $\epsilon_j$ and the standard deviation of $r_j$ separately. While $x_J$ is only guaranteed to distribute as $p_{\theta}(x)$ if $\epsilon \rightarrow 0$ and $J \rightarrow \infty$, the bias introduced for a small finite $\epsilon$ is negligible in practice (Murphy 2023). We denote this form of sampling as Improper SGLD.

### Example: Bayesian Inferecne with SGLD

Expand Down Expand Up @@ -154,7 +154,7 @@ plot(p1, p2, size=(800, 400))
│ The input will be converted, but any earlier layers may be very slow.
│ layer = Dense(3 => 1, σ) # 4 parameters
│ summary(x) = "3×7000 adjoint(::Matrix{Float64}) with eltype Float64"
└ @ Flux ~/.julia/packages/Flux/HBF2N/src/layers/stateless.jl:60
└ @ Flux ~/.julia/packages/Flux/htpCe/src/layers/stateless.jl:59
Final parameters are Float32[-2.3311744 1.1305944 -1.5102222 -4.0762844]
Test accuracy is 0.9666666666666667
Final parameters are Float32[-0.6106307 2.760134 -0.031244753 -5.8856964]
Expand All @@ -166,9 +166,9 @@ plot(p1, p2, size=(800, 400))

In the context of EBM, the optimisers can be used to sample from a model posterior. To this end, the package provides the following samples:

1. [`UnconditionalSampler`](@ref) — samples from the unconditional distribution $p_{\theta}(\mathbf{x})$ as in Grathwohl et al. (2020).
2. [`ConditionalSampler`](@ref) — samples from the conditional distribution $p_{\theta}(\mathbf{x}|y)$ as in Grathwohl et al. (2020).
3. [`JointSampler`](@ref) — samples from the joint distribution $p_{\theta}(\mathbf{x},y)$ as in Kelly, Zemel, and Grathwohl (2021).
1. [`UnconditionalSampler`](@ref) — samples from the unconditional distribution $p_{\theta}(x)$ as in Grathwohl et al. (2020).
2. [`ConditionalSampler`](@ref) — samples from the conditional distribution $p_{\theta}(x|y)$ as in Grathwohl et al. (2020).
3. [`JointSampler`](@ref) — samples from the joint distribution $p_{\theta}(x,y)$ as in Kelly, Zemel, and Grathwohl (2021).

### Example: Joint Energy-Based Model

Expand Down Expand Up @@ -221,7 +221,7 @@ end
```

[ Info: Epoch 1
Accuracy: 0.9995
Accuracy: 0.99
[ Info: Epoch 2
Accuracy: 0.9995
[ Info: Epoch 3
Expand Down Expand Up @@ -270,7 +270,7 @@ plot(plt)
│ The input will be converted, but any earlier layers may be very slow.
│ layer = Dense(2 => 2) # 6 parameters
│ summary(x) = "2-element Vector{Float64}"
└ @ Flux ~/.julia/packages/Flux/HBF2N/src/layers/stateless.jl:60
└ @ Flux ~/.julia/packages/Flux/htpCe/src/layers/stateless.jl:59

![](index_files/figure-commonmark/cell-8-output-2.svg)

Expand Down
160 changes: 80 additions & 80 deletions docs/src/index_files/figure-commonmark/cell-5-output-2.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4,584 changes: 2,294 additions & 2,290 deletions docs/src/index_files/figure-commonmark/cell-8-output-2.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 2 additions & 0 deletions src/EnergySamplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ abstract type AbstractSamplingRule <: Optimisers.AbstractRule end
"Base type for samplers."
abstract type AbstractSampler end

const DOC_Grathwohl = "For details see Grathwohl et al. (2020) [[arXiv](https://arxiv.org/abs/1912.03263), [ICLR](https://iclr.cc/virtual_2020/poster_Hkxzx0NtDB.html)]."

export AbstractSampler, AbstractSamplingRule
export ConditionalSampler, UnconditionalSampler, JointSampler
export PMC
Expand Down
39 changes: 34 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
using Flux
using StatsBase

"""
get_logits(f::Flux.Chain, x)
Retrieves the logits (linear predictions) of a `Chain` for the input `x`.
"""
get_logits(f::Flux.Chain, x) = f[end] isa Function ? f[1:(end - 1)](x) : f(x)

@doc raw"""
energy(f, x)
_energy(f, x; agg=mean)
Computes the energy for unconditional samples $x \sim p_{\theta}(x)$: $E(x)=-\text{LogSumExp}_y f_{\theta}(x)[y]$.
Computes the energy for unconditional samples $x \sim p_{\theta}(x)$: $E(x)=-\text{LogSumExp}_y f_{\theta}(x)[y]$. $DOC_Grathwohl
"""
function _energy(f, x; agg=mean)
if f isa Flux.Chain
Expand All @@ -24,9 +29,9 @@ function _energy(f, x; agg=mean)
end

@doc raw"""
energy(f, x, y::Int; agg=mean)
_energy(f, x, y::Int; agg=mean)
Computes the energy for conditional samples $x \sim p_{\theta}(x|y)$: $E(x)=- f_{\theta}(x)[y]$.
Computes the energy for conditional samples $x \sim p_{\theta}(x|y)$: $E(x)=- f_{\theta}(x)[y]$. $DOC_Grathwohl
"""
function _energy(f, x, y::Int; agg=mean)
if f isa Flux.Chain
Expand All @@ -40,6 +45,30 @@ function _energy(f, x, y::Int; agg=mean)
E = agg(map(_y -> _E(_y, y), eachslice(ŷ; dims=ndims(ŷ))))
return E
else
return _E(_y, y)
return _E(, y)
end
end

@doc raw"""
energy_differential(f, xgen, xsampled, y::Int; agg=mean)
Computes the energy differential between a conditional sample ``x_{\text{gen}} \sim p_{\theta}(x|y)`` and an observed sample ``x_{\text{sample}} \sim p(x|y)`` as ``E(x_{\text{sample}}|y) - E(x_{\text{gen}}|y)`` with ``E(x|y) = -f_{\theta}(x)[y]``. $DOC_Grathwohl
"""
function energy_differential(f, xgen, xsampled, y::Int; agg=mean)
neg_loss = _energy(f, xgen, y; agg=agg) # negative loss associated with generated samples
pos_loss = _energy(f, xsampled, y; agg=agg) # positive loss associated with sampled samples
= pos_loss - neg_loss
return
end

@doc raw"""
energy_penalty(f, xgen, xsampled, y::Int; agg=mean)
Computes the a Ridge penalty for the overall energies of the conditional samples ``x_{\text{gen}} \sim p_{\theta}(x|y)`` and an observed sample ``x_{\text{sample}} \sim p(x|y)``. $DOC_Grathwohl
"""
function energy_penalty(f, xgen, xsampled, y::Int; agg=mean, p=1)
neg_loss = _energy(f, xgen, y; agg=agg) # negative loss associated with generated samples
pos_loss = _energy(f, xsampled, y; agg=agg) # positive loss associated with sampled samples
= neg_loss .^ 2 .+ pos_loss .^ 2
return
end
11 changes: 11 additions & 0 deletions test/other.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using EnergySamplers: energy_differential, energy_penalty

@testset "Other things" begin
@testset "Energy differential" begin
f(X::Matrix) = [prod(x) for x in eachcol(X)]
x1 = randn(10, 1)
x2 = randn(10, 1)
@test isreal(energy_differential(f, x1, x2, 1))
@test isreal(energy_penalty(f, x1, x2, 1))
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ using Test

@testset "EnergySamplers.jl" begin
include("aqua.jl")

include("samplers.jl")
include("other.jl")
end

2 comments on commit 926f37c

@pat-alt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

Added

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/118194

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.0.1 -m "<description of version>" 926f37cbd150dec0def78339e646ca85d72251da
git push origin v1.0.1

Please sign in to comment.