-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #33 from LuxDL/ap/common_layers
Add Common layers
- Loading branch information
Showing
18 changed files
with
621 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,10 @@ | ||
name = "Boltz" | ||
uuid = "4544d5e4-abc5-4dea-817f-29e4c205d9c8" | ||
authors = ["Avik Pal <[email protected]> and contributors"] | ||
version = "0.3.5" | ||
version = "0.3.6" | ||
|
||
[deps] | ||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" | ||
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" | ||
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | ||
|
@@ -13,6 +14,8 @@ JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" | |
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" | ||
Lux = "b2108857-7c20-44ae-9111-449ecde12c47" | ||
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" | ||
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" | ||
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" | ||
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" | ||
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
|
@@ -21,27 +24,36 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | |
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" | ||
|
||
[weakdeps] | ||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | ||
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
|
||
[extensions] | ||
BoltzForwardDiffExt = "ForwardDiff" | ||
BoltzMetalheadExt = "Metalhead" | ||
BoltzZygoteExt = "Zygote" | ||
|
||
[compat] | ||
ADTypes = "1.3" | ||
Aqua = "0.8.7" | ||
ArgCheck = "2.3" | ||
Artifacts = "1.10" | ||
ChainRulesCore = "1.24" | ||
ComponentArrays = "0.15.13" | ||
ConcreteStructs = "0.2.3" | ||
ExplicitImports = "1.5" | ||
ForwardDiff = "0.10.36" | ||
GPUArraysCore = "0.1.6" | ||
JLD2 = "0.4.48" | ||
LazyArtifacts = "1.10" | ||
Lux = "0.5.50" | ||
LuxAMDGPU = "0.2.3" | ||
LuxCUDA = "0.3.2" | ||
LuxCore = "0.1.15" | ||
LuxDeviceUtils = "0.1.21" | ||
LuxLib = "0.3.26" | ||
LuxTestUtils = "0.1.15" | ||
Markdown = "1.10" | ||
Metalhead = "0.9" | ||
NNlib = "0.9.17" | ||
Pkg = "1.10" | ||
|
@@ -52,11 +64,14 @@ Reexport = "1.2.2" | |
Statistics = "1.10" | ||
Test = "1.10" | ||
WeightInitializers = "0.1.7" | ||
Zygote = "0.6.70" | ||
julia = "1.10" | ||
|
||
[extras] | ||
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" | ||
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" | ||
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" | ||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | ||
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" | ||
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" | ||
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" | ||
|
@@ -65,6 +80,7 @@ Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" | |
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" | ||
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
|
||
[targets] | ||
test = ["Aqua", "ExplicitImports", "LuxAMDGPU", "LuxCUDA", "LuxLib", "LuxTestUtils", "Metalhead", "Pkg", "ReTestItems", "Test"] | ||
test = ["Aqua", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxLib", "LuxTestUtils", "Metalhead", "Pkg", "ReTestItems", "Test", "Zygote"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
module BoltzForwardDiffExt | ||
|
||
using ADTypes: AutoForwardDiff | ||
using Boltz: Boltz, Layers | ||
using ForwardDiff: ForwardDiff | ||
|
||
@inline Boltz._is_extension_loaded(::Val{:ForwardDiff}) = true | ||
|
||
@inline Boltz._should_type_assert(::AbstractArray{<:ForwardDiff.Dual}) = false | ||
@inline Boltz._should_type_assert(::ForwardDiff.Dual) = false | ||
|
||
# Hamiltonian NN | ||
function Layers.hamiltonian_forward(::AutoForwardDiff, model, x) | ||
return ForwardDiff.gradient(sum ∘ model, x) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
module BoltzZygoteExt | ||
|
||
using ADTypes: AutoZygote | ||
using Boltz: Boltz, Layers | ||
using Zygote: Zygote | ||
|
||
@inline Boltz._is_extension_loaded(::Val{:Zygote}) = true | ||
|
||
# Hamiltonian NN | ||
function Layers.hamiltonian_forward(::AutoZygote, model, x) | ||
return only(Zygote.gradient(sum ∘ model, x)) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
module Basis | ||
|
||
using ..Boltz: _unsqueeze1 | ||
using ChainRulesCore: ChainRulesCore, NoTangent | ||
using ConcreteStructs: @concrete | ||
using Markdown: @doc_str | ||
|
||
const CRC = ChainRulesCore | ||
|
||
# The rrules in this file are hardcoded to be used exclusively with GeneralBasisFunction | ||
@concrete struct GeneralBasisFunction{name} | ||
f | ||
n::Int | ||
end | ||
|
||
function Base.show(io::IO, basis::GeneralBasisFunction{name}) where {name} | ||
print(io, "Basis.$(name)(order=$(basis.n))") | ||
end | ||
|
||
@inline function (basis::GeneralBasisFunction{name, F})(x::AbstractArray) where {name, F} | ||
return basis.f.(1:(basis.n), _unsqueeze1(x)) | ||
end | ||
|
||
@doc doc""" | ||
Chebyshev(n) | ||
Constructs a Chebyshev basis of the form $[T_{0}(x), T_{1}(x), \dots, T_{n-1}(x)]$ where | ||
$T_j(.)$ is the $j^{th}$ Chebyshev polynomial of the first kind. | ||
## Arguments | ||
- `n`: number of terms in the polynomial expansion. | ||
""" | ||
Chebyshev(n) = GeneralBasisFunction{:Chebyshev}(__chebyshev, n) | ||
|
||
@inline __chebyshev(i, x) = @fastmath cos(i * acos(x)) | ||
|
||
@doc doc""" | ||
Sin(n) | ||
Constructs a sine basis of the form $[\sin(x), \sin(2x), \dots, \sin(nx)]$. | ||
## Arguments | ||
- `n`: number of terms in the sine expansion. | ||
""" | ||
Sin(n) = GeneralBasisFunction{:Sin}(@fastmath(sin∘*), n) | ||
|
||
@doc doc""" | ||
Cos(n) | ||
Constructs a cosine basis of the form $[\cos(x), \cos(2x), \dots, \cos(nx)]$. | ||
## Arguments | ||
- `n`: number of terms in the cosine expansion. | ||
""" | ||
Cos(n) = GeneralBasisFunction{:Cos}(@fastmath(cos∘*), n) | ||
|
||
@doc doc""" | ||
Fourier(n) | ||
Constructs a Fourier basis of the form | ||
$F_j(x) = j is even ? cos((j÷2)x) : sin((j÷2)x)$ => $[F_0(x), F_1(x), \dots, F_n(x)]$. | ||
## Arguments | ||
- `n`: number of terms in the Fourier expansion. | ||
""" | ||
Fourier(n) = GeneralBasisFunction{:Fourier}(__fourier, n) | ||
|
||
@inline @fastmath function __fourier(i, x::AbstractFloat) | ||
s, c = sincos(i * x / 2) | ||
return ifelse(iseven(i), c, s) | ||
end | ||
|
||
@inline function __fourier(i, x) # No FastMath for non abstract floats | ||
s, c = sincos(i * x / 2) | ||
return ifelse(iseven(i), c, s) | ||
end | ||
|
||
@fastmath function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof(__fourier), i, x) | ||
ix_by_2 = @. i * x / 2 | ||
s = @. sin(ix_by_2) | ||
c = @. cos(ix_by_2) | ||
y = @. ifelse(iseven(i), c, s) | ||
|
||
∇fourier = let s = s, c = c, i = i | ||
Δ -> begin | ||
return (NoTangent(), NoTangent(), NoTangent(), | ||
dropdims(sum((i / 2) .* ifelse.(iseven.(i), -s, c) .* Δ; dims=1); dims=1)) | ||
end | ||
end | ||
|
||
return y, ∇fourier | ||
end | ||
|
||
@doc doc""" | ||
Legendre(n) | ||
Constructs a Legendre basis of the form $[P_{0}(x), P_{1}(x), \dots, P_{n-1}(x)]$ where | ||
$P_j(.)$ is the $j^{th}$ Legendre polynomial. | ||
## Arguments | ||
- `n`: number of terms in the polynomial expansion. | ||
""" | ||
Legendre(n) = GeneralBasisFunction{:Legendre}(__legendre_poly, n) | ||
|
||
## Source: https://github.com/ranocha/PolynomialBases.jl/blob/master/src/legendre.jl | ||
@inline function __legendre_poly(i, x) | ||
p = i - 1 | ||
a = one(x) | ||
b = x | ||
|
||
p ≤ 0 && return a | ||
p == 1 && return b | ||
|
||
for j in 2:p | ||
a, b = b, @fastmath(((2j - 1) * x * b - (j - 1) * a)/j) | ||
end | ||
|
||
return b | ||
end | ||
|
||
@doc doc""" | ||
Polynomial(n) | ||
Constructs a Polynomial basis of the form $[1, x, \dots, x^(n-1)]$. | ||
## Arguments | ||
- `n`: number of terms in the polynomial expansion. | ||
""" | ||
Polynomial(n) = GeneralBasisFunction{:Polynomial}(__polynomial, n) | ||
|
||
@inline __polynomial(i, x) = x^(i - 1) | ||
|
||
function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof(__polynomial), i, x) | ||
y_m1 = x .^ (i .- 2) | ||
y = y_m1 .* x | ||
∇polynomial = let y_m1 = y_m1, i = i | ||
Δ -> begin | ||
return (NoTangent(), NoTangent(), NoTangent(), | ||
dropdims(sum((i .- 1) .* y_m1 .* Δ; dims=1); dims=1)) | ||
end | ||
end | ||
return y, ∇polynomial | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.