BlockMatrixTensor #13

Open · wants to merge 10 commits into base: main
24 changes: 22 additions & 2 deletions Project.toml
@@ -6,10 +6,30 @@ version = "0.1.0"
[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
ProfileView = "c46f51b8-102a-5cf2-8d2c-8597cb0e0da7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ArrayLayouts = "1"
ForwardDiff = "0.10.36"
Lux = "0.5.62"
Plots = "1.40.5"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Lux", "ForwardDiff", "Test"]
69 changes: 69 additions & 0 deletions examples/convlayer.jl
@@ -0,0 +1,69 @@
using LinearAlgebra, MLDatasets, Plots, DualArrays, Random, FillArrays

#GOAL: Implement and differentiate a convolutional neural network layer
function convlayer(img, ker, xstride = 1, ystride = 1)
n, m = size(ker)
t = eltype(ker)

n2, m2 = size(img)
n3, m3 = div(n2 - n, xstride) + 1, div(m2 - m, ystride) + 1
fmap = zeros(promote_type(eltype(img), t), n3, m3)
#Apply kernel to each (strided) section of the image
for i = 1:n3, j = 1:m3
i2, j2 = (i - 1) * xstride + 1, (j - 1) * ystride + 1
ft = img[i2:i2+n-1, j2:j2+m-1] .* ker
fmap[i, j] = sum(ft)
end
fmap
end
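#Sanity check (a sketch, assuming the 28×28 MNIST images used below): with unit
#strides a 3×3 kernel gives a (28 - 3 + 1) × (28 - 3 + 1) = 26×26 feature map,
#which is where the 10×676 dense-layer weight shape in model_loss comes from, e.g.
#   size(convlayer(rand(28, 28), rand(3, 3))) == (26, 26)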

function softmax(x)
s = sum(exp.(x))
exp.(x) / s
end

function dense_layer(W, b, x, f::Function = identity)
ret = W*x
println("Multiplication complete")
ret += b
println("Addition Complete")
f(ret)
end

function cross_entropy(x, y)
-sum(y .* log.(x))
end
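#With a one-hot target (the OneElement used in model_loss below) this reduces to
#-log(x[class]), the usual negative log-likelihood of the true class.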

function model_loss(x, y, w)
ker = reshape(w[1:9], 3, 3)
weights = reshape(w[10:6769], 10, 676)
biases = w[6770:6779]
println("Reshape Complete")
l1 = vec(DualMatrix(convlayer(x, ker)))
println("Conv layer complete")
l2 = dense_layer(weights, biases, l1, softmax)
println("Dense Layer Complete")
target = OneElement(1, y+1, 10)
loss = cross_entropy(l2, target)
println("Loss complete")
loss.value, loss.partials
end
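#Parameter bookkeeping assumed above: 9 kernel weights (3×3), 6760 dense weights
#(10 × 676, since the conv output is 26×26 = 676 pixels) and 10 biases,
#i.e. 9 + 6760 + 10 = 6779 parameters, matching rand(6779) in train_model below.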

function train_model()
p = rand(6779)
epochs = 1000
lr = 0.02
dataset = MNIST(:train)

for i = 1:epochs
img, label = dataset[i]
d = DualVector(p, I(6779))

loss, grads = model_loss(img, label, d)
println(loss)
p = p - lr * grads
end
end

train_model()


2 changes: 1 addition & 1 deletion examples/heatequation.jl
@@ -9,7 +9,7 @@
# Compare performance and results.
##

using LinearAlgebra, BandedMatrices, OrdinaryDiffEq, Plots,
ForwardDiff

U = 500
45 changes: 45 additions & 0 deletions examples/neuralnet.jl
@@ -0,0 +1,45 @@
import Lux: relu
using DualArrays, LinearAlgebra, Plots, FillArrays, Test

#GOAL: Learn exp() using simple one-layer relu neural network.

N = 20
d = 0.1

#Domain over which to learn exp
data = collect(d:d:N*d)

function model(a, b, x)
return relu.(a .* x + b)
end

function model_loss(w)
sum((model(w[1:N], w[(N+1):end], data) - exp.(data)) .^ 2)
end

function gradient_descent_sparse(n, lr = 0.01)
weights = ones(2 * N)
for i = 1:n
dw = DualVector(weights, Eye{Float64}(2*N))
grads = model_loss(dw).partials
weights -= lr * grads
end
model(weights[1:N], weights[(N + 1):end], data)
end

function gradient_descent_dense(n, lr = 0.01)
weights = ones(2 * N)
for i = 1:n
dw = DualVector(weights, Matrix(I, 2*N, 2*N))
grads = model_loss(dw).partials
weights -= lr * grads
end
model(weights[1:N], weights[(N + 1):end], data)
end
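#The two versions differ only in how the DualVector is seeded: Eye{Float64}(2*N) is
#FillArrays' lazy identity, which should let the Jacobian stay structured (hence the
#"sparse" name), while Matrix(I, 2*N, 2*N) materialises a dense identity. The tests
#below check that both seeds give the same learned model.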

@time densesol = gradient_descent_dense(500)
@time sparsesol = gradient_descent_sparse(500)
@test densesol == sparsesol
@test sparsesol ≈ exp.(data) rtol = 1e-2


41 changes: 34 additions & 7 deletions examples/newtonpendulum.jl
@@ -6,24 +6,36 @@
# via discretisation and Newton's method.
##

using LinearAlgebra, ForwardDiff, Plots, DualArrays, FillArrays, Test

#Boundary Conditions
a = 0.1
b = 0.0

#Time step, time period and number of interior grid points for the discretisation.
ts = 0.1

Tmax = 5.0
N = Int(Tmax/ts) - 1

#LHS of ode
function f(x)
n = length(x)
D = Tridiagonal([ones(n) / ts ; 0.0], [1.0; -2ones(n) / ts; 1.0], [0.0; ones(n) / ts])
(D * [a; x; b])[2:end-1] + sin.(x)
end


function f(u, a, b, Tmax)
h = Tmax/(length(u)-1)
[u[1] - a;
(u[1:end-2] - 2u[2:end-1] + u[3:end])/h^2 + sin.(u[2:end-1]);
u[end] - b]
end
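#This second form keeps the boundary conditions as explicit residual rows: interior
#rows are the central difference (u[k-1] - 2u[k] + u[k+1]) / h^2 plus sin(u[k]), and
#the first/last rows enforce u[1] = a and u[end] = b, so the unknown vector includes
#the boundary values themselves.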



#Newton's method using ForwardDiff.jl
function newton_method_forwarddiff(f, x0, n)
x = x0
@@ -38,18 +50,33 @@ function newton_method_dualvector(f, x0, n)
x = x0
l = length(x0)
for i = 1:n
∇f = f(DualVector(x, Eye(l))).jacobian
x = x - ∇f \ f(x)
end
x
end

function newton_method_dualvector2(f, x0, n)
x = x0
l = length(x0)
for i = 1:n
∇f = f(DualVector(x, Eye(l)), a, b, Tmax).jacobian
x = x - ∇f \ f(x, a, b, Tmax)
end
x
end

#Initial guess
x0 = zeros(N)

#Solve and plot both solution and LHS ('deviation' from system)
@time sol1 = newton_method_forwarddiff(f, x0, 100);
@time sol2 = newton_method_dualvector(f, x0, 100);
@test sol1 ≈ sol2
@time sol = newton_method_dualvector2(f, x0, 100);

x0 = zeros(N); x0[1] = a; x0[end] = b;

plot(0:ts:Tmax, [a; sol; b])
plot!(0:ts:Tmax, [0; f(sol); 0])

76 changes: 76 additions & 0 deletions src/BlockMatrixTensor.jl
@@ -0,0 +1,76 @@
#4-Tensor with data stored in a BlockMatrix
#Enables sparsity for DualMatrix jacobian.
using BlockArrays, SparseArrays

struct BlockMatrixTensor{T} <: AbstractArray{T, 4}
data::BlockMatrix{T}
end

#Construct BlockMatrixTensor from 4-tensor. Useful for testing purposes.
function BlockMatrixTensor(x::AbstractArray{T, 4}) where {T}
n, m, s, t = size(x)
data = BlockMatrix(zeros(T, n * s, m * t), fill(s, n), fill(t, m))
for i = 1:n, j = 1:m
view(data, Block(i, j)) .= x[i, j, :, :]
end
BlockMatrixTensor(data)
end

# Useful for creating a BlockMatrixTensor from blocks()
BlockMatrixTensor(x::Matrix{T}) where {T <: AbstractMatrix} = BlockMatrixTensor(mortar(x))

size(x::BlockMatrixTensor) = (blocksize(x.data)..., blocksizes(x.data)[1,1]...)

#Indexing: scalar entries and entire blocks
getindex(x::BlockMatrixTensor, a::Int, b::Int, c, d) = blocks(x.data)[a, b][c, d]

getindex(x::BlockMatrixTensor, a::Int, b::Int, ::Colon, ::Colon) = blocks(x.data)[a, b]
getindex(x::BlockMatrixTensor, a::Int, b, ::Colon, ::Colon) = BlockMatrixTensor(reshape(blocks(x.data)[a, b], 1, :))
getindex(x::BlockMatrixTensor, a, b::Int, ::Colon, ::Colon) = BlockMatrixTensor(reshape(blocks(x.data)[a, b], :, 1))
getindex(x::BlockMatrixTensor, a, b, ::Colon, ::Colon) = BlockMatrixTensor(blocks(x.data)[a,b])
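# Usage sketch (only names defined above): a 2×3 tensor of 4×5 blocks is stored as
# an 8×15 BlockMatrix, so whole blocks come back via trailing colons, e.g.
#   A = rand(2, 3, 4, 5); B = BlockMatrixTensor(A)
#   size(B) == (2, 3, 4, 5)
#   B[1, 2, :, :] == A[1, 2, :, :]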

# For populating a BlockMatrixTensor
function setindex!(A::BlockMatrixTensor, value, a::Int, b::Int, ::Colon, ::Colon)
blocks(A.data)[a, b] = value

end

function show(io::IO, m::MIME"text/plain", x::BlockMatrixTensor)
print(io, "BlockMatrixTensor containing: \n")
show(io, m, x.data)

end
show(io::IO, x::BlockMatrixTensor) = show(io, x.data)

for op in (:*, :/)
@eval $op(x::BlockMatrixTensor, y::Number) = BlockMatrixTensor($op(x.data, y))

@eval $op(x::Number, y::BlockMatrixTensor) = BlockMatrixTensor($op(x, y.data))
end

#Block-wise broadcast
broadcasted(f::Function, x::BlockMatrixTensor, y::AbstractMatrix) = BlockMatrixTensor(f.(blocks(x.data), y))
broadcasted(f::Function, x::BlockMatrixTensor, y::AbstractVector) = BlockMatrixTensor(f.(x, reshape(y, :, 1)))
broadcasted(f::Function, x::AbstractVecOrMat, y::BlockMatrixTensor) = f.(y, x)
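# Sketch of the semantics: broadcasting against an n×m matrix works block-by-block,
# e.g. x .* y scales block (i, j) by y[i, j] rather than broadcasting inside blocks;
# a vector y is reshaped to a column first. Note the AbstractVecOrMat-first method
# forwards to the reversed call, so it assumes f treats its arguments symmetrically.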

function sum(x::BlockMatrixTensor; dims = Colon())
# Blockwise sum
if dims == 1:2
sum(blocks(x.data))
elseif dims == 1 || dims == 2
BlockMatrixTensor(sum(blocks(x.data); dims))
else
# For now, treat all other cases as if summing the 4-Tensor
sum(Array(x); dims = dims)

end
end
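# Sketch: for the (2, 3, 4, 5) example above, sum(x; dims = 1:2) adds the six 4×5
# blocks into a single 4×5 Matrix, while dims = 1 or dims = 2 reduces across block
# rows or block columns and keeps the block structure.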

function reshape(x::BlockMatrixTensor, dims::Vararg{Union{Colon, Int}, 4})
#Reshape block-wise
#TODO: Implement non-blockwise
if dims[3] isa Colon && dims[4] isa Colon
BlockMatrixTensor(reshape(blocks(x.data), dims[1], dims[2]))
end
end

#'Flatten': converts BlockMatrixTensor to Matrix by removing block structure
# Mimics the reshape achieved by this for general 4-Tensors
flatten(x::BlockMatrixTensor) = hcat((vcat((x[i, j, :, :] for i = 1:size(x, 1))...) for j = 1:size(x, 2))...)
flatten(x::AbstractArray{T, 4}) where {T} = reshape(x, size(x, 1) * size(x, 3), size(x, 2) * size(x, 4))
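# Sketch: flattening the (2, 3, 4, 5) example gives an 8×15 Matrix in which block
# (i, j) occupies rows 4(i-1)+1:4i and columns 5(j-1)+1:5j, i.e. the block structure
# is dropped but entries are not reordered within a block.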
