Update to Zig doc comments now that we can wrap them
I fixed the VSCode Rewrap plugin for Zig in stkb/Rewrap#389
MadLittleMods committed Nov 3, 2023
1 parent ef1eee5 commit c0e42d2
Showing 8 changed files with 266 additions and 275 deletions.
19 changes: 17 additions & 2 deletions README.md
@@ -1,13 +1,28 @@
# Zig neural network library

This is meant to be a clear, annotated, from scratch, neural network implementation in Zig.
This is meant to be a clear, annotated, from scratch, neural network library in
Zig.

To add some buzzword details, it's a multi-layer perceptron (MLP) with backpropagation
(reverse-mode automatic differentiation) and stochastic gradient descent (SGD).

This is heavily inspired by my [first neural network
implementation](https://github.com/MadLittleMods/zig-ocr-neural-network), which was
based on [Sebastian Lague's video](https://www.youtube.com/watch?v=hfMk-kjRv4c). This
implementation is a bit simpler to understand because it follows the pattern from [Omar
Aflak's (The Independent Code) video](https://www.youtube.com/watch?v=pauPCy_s0Ok),
where layers just have `forward(...)`/`backward(...)` methods and the activations are
just another layer in the network (reverse-mode automatic differentiation). See the
[*developer notes*](./dev-notes.md) for more details.


## Usage

Tested with Zig 0.11.0

Compatible with the Zig package manager. Just define it as a dependency in your `build.zig.zon` file.
Compatible with the Zig package manager. Just define it as a dependency in your
`build.zig.zon` file.

TODO: Update to use `zig-neural-network` instead of `zshuffle` once we figure it out
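
For reference, here is a minimal sketch of what such a `build.zig.zon` entry could look
like. This is not from the repo, and the dependency name, URL, and hash are hypothetical
placeholders (as the TODO above notes, the final package name is still being figured
out):

```zig
// build.zig.zon (sketch) -- the name, URL, and hash below are placeholders, not real values.
.{
    .name = "my-app",
    .version = "0.0.1",
    .dependencies = .{
        .@"zig-neural-network" = .{
            // Point at a tagged release or commit tarball of the library.
            .url = "https://github.com/<owner>/zig-neural-network/archive/<commit>.tar.gz",
            // `zig build` will print the expected hash if this one doesn't match.
            .hash = "1220xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
        },
    },
}
```

From `build.zig` you would then typically fetch it with `b.dependency("zig-neural-network", .{})`
and add the module it exposes to your executable; the exact module name depends on how
the library's `build.zig` registers it.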

25 changes: 1 addition & 24 deletions examples/xor/main.zig
@@ -20,28 +20,7 @@ pub fn main() !void {
}
}

// I wish we could just inline these declarations as literals in the `layers` array
// so there wasn't a chance to re-assemble them in the wrong order or accidentally use
// layers multiple times. But Zig will create any literal/temporary declaration as
// `const` with no way to specify that they are `var`/mutable
// (https://ziglang.org/download/0.10.0/release-notes.html#Address-of-Temporaries-Now-Produces-Const-Pointers).
// And we would also have trouble with deinitializing them since we wouldn't have a
// handle to them.
// var dense_layer1 = neural_network.DenseLayer.init(2, 3, allocator);
// defer dense_layer1.deinit();
// var activation_layer1 = neural_network.ActivationLayer(neural_network.ActivationFunction{ .elu = {} }).init();
// var dense_layer2 = neural_network.DenseLayer.init(3, 2);
// defer dense_layer2.deinit();
// var activation_layer2 = neural_network.ActivationLayer(neural_network.ActivationFunction{ .soft_max = {} }).init();

// var layers = [_]neural_network.Layer{
// dense_layer1.layer(),
// activation_layer1.layer(),
// dense_layer2.layer(),
// activation_layer2.layer(),
// };

neural_network.NeuralNetwork.initFromLayerSizes(
try neural_network.NeuralNetwork.initFromLayerSizes(
&[_]u32{ 2, 3, 2 },
neural_network.ActivationFunction{
// .relu = .{},
@@ -59,6 +38,4 @@
},
allocator,
);

std.log.debug("layers {any}", .{layers});
}
331 changes: 167 additions & 164 deletions src/activation_functions.zig

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/layers/activation_layer.zig
@@ -10,8 +10,8 @@ const ActivationFunction = @import("../activation_functions.zig").ActivationFunc
pub fn ActivationLayer(activation_function: ActivationFunction) type {
return struct {
const Self = @This();
// Store any inputs we get during the forward pass so we can use them during the
// backward pass.
/// Store any inputs we get during the forward pass so we can use them during
/// the backward pass.
inputs: []f64 = undefined,

pub fn init() !Self {
25 changes: 15 additions & 10 deletions src/layers/dense_layer.zig
@@ -12,19 +12,22 @@ pub const DenseLayer = struct {
const Self = @This();
num_input_nodes: usize,
num_output_nodes: usize,
// Weights for each incoming connection. Each node in this layer has a weighted
// connection to each node in the previous layer (num_input_nodes * num_output_nodes).
//
// The weights are stored in row-major order where each row is the incoming
// connection weights for a single node in this layer.
// Size: num_output_nodes * num_input_nodes
/// Weights for each incoming connection. Each node in this layer has a weighted
/// connection to each node in the previous layer (num_input_nodes *
/// num_output_nodes).
///
/// The weights are stored in row-major order where each row is the incoming
/// connection weights for a single node in this layer.
///
/// Size: num_output_nodes * num_input_nodes
weights: []f64,
// Bias for each node in the layer (num_output_nodes)
// Size: num_output_nodes
/// Bias for each node in the layer (num_output_nodes)
///
/// Size: num_output_nodes
biases: []f64,

// Store any inputs we get during the forward pass so we can use them during the
// backward pass.
/// Store any inputs we get during the forward pass so we can use them during the
/// backward pass.
inputs: []f64 = undefined,

pub fn init(
@@ -197,6 +200,8 @@
return self.weights[weight_index];
}

/// Helper to access the weight for a specific connection since
/// the weights are stored in a flat array.
pub fn getFlatWeightIndex(self: *Self, node_index: usize, node_in_index: usize) usize {
return (node_index * self.num_input_nodes) + node_in_index;
}
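
To make the row-major weight layout described in the doc comments above concrete, here
is a standalone sketch (not part of this diff) of the same flat-index calculation for a
hypothetical layer with 2 inputs and 3 output nodes:

```zig
const std = @import("std");

// Hypothetical layer shape used only for this illustration.
const num_input_nodes: usize = 2;

// Same formula as `getFlatWeightIndex`: row = node in this layer, column = incoming node.
// Flat layout for 3 output nodes: [w00, w01, w10, w11, w20, w21]
fn flatWeightIndex(node_index: usize, node_in_index: usize) usize {
    return (node_index * num_input_nodes) + node_in_index;
}

test "row-major weight indexing" {
    // Output node 0, input 1 -> index 1; output node 2, input 0 -> index 4.
    try std.testing.expectEqual(@as(usize, 1), flatWeightIndex(0, 1));
    try std.testing.expectEqual(@as(usize, 4), flatWeightIndex(2, 0));
}
```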
89 changes: 44 additions & 45 deletions src/loss_functions.zig
@@ -1,21 +1,19 @@
const std = @import("std");
const log = std.log.scoped(.zig_neural_network);

// Loss functions are also known as cost/error functions.
//
// Math equation references (loss and derivative):
// https://stats.stackexchange.com/questions/154879/a-list-of-cost-functions-used-in-neural-networks-alongside-applications/154880#154880

// TODO: In the future, we could add negative log likelihood, MeanAbsoluteError (L1 loss),
// RootMeanSquaredError, Focal Loss, etc.

// SquaredError (also known as L2 loss)
//
// Used for regression problems where the output data comes from a normal/gaussian
// distribution like predicting something based on a trend (extrapolate). Does not
// penalize misclassifications as much as it could/should for binary/multi-class
// classification problems. Although this answer says that it doesn't matter,
// https://stats.stackexchange.com/a/568253/360344. Useful when TODO: What?
/// SquaredError (also known as L2 loss)
///
/// Used for regression problems where the output data comes from a normal/gaussian
/// distribution like predicting something based on a trend (extrapolate). Does not
/// penalize misclassifications as much as it could/should for binary/multi-class
/// classification problems. Although this answer says that it doesn't matter,
/// https://stats.stackexchange.com/a/568253/360344. Useful when TODO: What?
pub const SquaredError = struct {
// Sum of Squared Errors (SSE)
pub fn vector_loss(
@@ -51,42 +49,42 @@ pub const SquaredError = struct {
}
};

// Cross-Entropy is also referred to as Logarithmic loss. Used for binary or multi-class
// classification problems where the output data comes from a bernoulli distribution
// which just means we have buckets/categories with expected probabilities.
//
// Also called Sigmoid Cross-Entropy loss or Binary Cross-Entropy Loss
// https://gombru.github.io/2018/05/23/cross_entropy_loss/
//
// Why use Cross-Entropy for loss?
//
// 1. The first reason is that it gives steeper gradients for regimes that matter which
// results in fewer training epochs to get to the local minimum.
// 2. The second and probably more important reason is that when you're predicting more
// than one class (e.g., not just "cat" or "not", but also a third like "cat",
// "dog", or "neither"), Cross-Entropy gives you calibrated[1] results, whereas
// SquaredError will not necessarily.
// - [1] To explain what "calibrated" means; say you want a composite model. The
// first predicts the probability that a customer will convert in one of 3
// different price buckets. The second then multiplies the probability by the size
// of the bucket ($5, $30, or $80 let's say). If there's any reason the model is
// more accurate for some classes than others (extremely common), then an error in
// the $5 bucket has very different effects on the resulting estimate of lifetime
// customer value than an error in the $80 bucket. If your probabilities are
// calibrated then you can blindly do the aforementioned multiplication and know
// that on average it's correct. Otherwise, you're prone to (drastically) under or
// over valuing a customer and making incorrect decisions as a result of that
// information.
//
// With just 2 classes, the optimal values of SquaredError and Cross-Entropy are
// identical, so only learning rate applies. With 3 or more, Cross-Entropy is
// potentially more calibrated any time there is shared information in the inputs. The
// simplest case of that is when two inputs overlap completely, but a neural network
// maps close inputs to similar outputs (caveats apply), so for vector-valued inputs
// like images (where you feed in all of the pixels) you'll see the same effect just
// from images that look close to each other.
//
// https://machinelearningmastery.com/cross-entropy-for-machine-learning/
/// Cross-Entropy is also referred to as Logarithmic loss. Used for binary or multi-class
/// classification problems where the output data comes from a bernoulli distribution
/// which just means we have buckets/categories with expected probabilities.
///
/// Also called Sigmoid Cross-Entropy loss or Binary Cross-Entropy Loss
/// https://gombru.github.io/2018/05/23/cross_entropy_loss/
///
/// Why use Cross-Entropy for loss?
///
/// 1. The first reason is that it gives steeper gradients for regimes that matter which
/// results in fewer training epochs to get to the local minimum.
/// 2. The second and probably more important reason is that when you're predicting more
/// than one class (e.g., not just "cat" or "not", but also a third like "cat",
/// "dog", or "neither"), Cross-Entropy gives you calibrated[1] results, whereas
/// SquaredError will not necessarily.
/// - [1] To explain what "calibrated" means; say you want a composite model. The
/// first predicts the probability that a customer will convert in one of 3
/// different price buckets. The second then multiplies the probability by the size
/// of the bucket ($5, $30, or $80 let's say). If there's any reason the model is
/// more accurate for some classes than others (extremely common), then an error in
/// the $5 bucket has very different effects on the resulting estimate of lifetime
/// customer value than an error in the $80 bucket. If your probabilities are
/// calibrated then you can blindly do the aforementioned multiplication and know
/// that on average it's correct. Otherwise, you're prone to (drastically) under or
/// over valuing a customer and making incorrect decisions as a result of that
/// information.
///
/// With just 2 classes, the optimal values of SquaredError and Cross-Entropy are
/// identical, so only learning rate applies. With 3 or more, Cross-Entropy is
/// potentially more calibrated any time there is shared information in the inputs. The
/// simplest case of that is when two inputs overlap completely, but a neural network
/// maps close inputs to similar outputs (caveats apply), so for vector-valued inputs
/// like images (where you feed in all of the pixels) you'll see the same effect just
/// from images that look close to each other.
///
/// https://machinelearningmastery.com/cross-entropy-for-machine-learning/
pub const CrossEntropy = struct {
pub fn vector_loss(
self: @This(),
@@ -285,6 +283,7 @@ test "Slope check loss functions" {
}
}

/// Loss functions are also known as cost/error functions.
pub const LossFunction = union(enum) {
squared_error: SquaredError,
cross_entropy: CrossEntropy,
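
As a companion to the Cross-Entropy doc comment above, here is a standalone sketch of
the standard categorical cross-entropy formula, `loss = -Σ expected[i] * ln(predicted[i])`.
It is not the library's `vector_loss` implementation (that code is collapsed in this
diff), and the clamping constant is just an illustrative choice to avoid `ln(0)`:

```zig
const std = @import("std");

/// Sketch only: categorical cross-entropy for a single sample.
/// `expected` is the target distribution (e.g. one-hot), `predicted` is the
/// network's output probabilities.
fn crossEntropyLoss(predicted: []const f64, expected: []const f64) f64 {
    std.debug.assert(predicted.len == expected.len);
    var loss: f64 = 0.0;
    for (predicted, expected) |p, y| {
        // Clamp so a confidently wrong prediction doesn't produce ln(0) = -inf.
        const clamped = std.math.clamp(p, 1e-12, 1.0 - 1e-12);
        loss -= y * @log(clamped);
    }
    return loss;
}

test "loss shrinks as the prediction approaches the target" {
    const target = [_]f64{ 1.0, 0.0, 0.0 };
    const confident = crossEntropyLoss(&[_]f64{ 0.9, 0.05, 0.05 }, &target);
    const uncertain = crossEntropyLoss(&[_]f64{ 0.34, 0.33, 0.33 }, &target);
    try std.testing.expect(confident < uncertain);
}
```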
5 changes: 0 additions & 5 deletions src/main.zig
@@ -9,8 +9,3 @@ pub const ActivationLayer = @import("./layers/activation_layer.zig").ActivationL

pub const ActivationFunction = @import("./activation_functions.zig").ActivationFunction;
pub const LossFunction = @import("./loss_functions.zig").LossFunction;

// pub fn add(a: i32, b: i32) i32 {
// log.debug("add() called", .{});
// return a + b;
// }
43 changes: 20 additions & 23 deletions src/neural_network.zig
@@ -7,31 +7,28 @@ const LossFunction = @import("./loss_functions.zig").LossFunction;
pub const NeuralNetwork = struct {
const Self = @This();

// ex.
// ```
// var dense_layer1 = neural_network.DenseLayer.init(2, 3, allocator);
// defer dense_layer1.deinit();
// var activation_layer1 = neural_network.ActivationLayer(neural_network.ActivationFunction{ .elu = {} }).init();
// var dense_layer2 = neural_network.DenseLayer.init(3, 2);
// defer dense_layer2.deinit();
// var activation_layer2 = neural_network.ActivationLayer(neural_network.ActivationFunction{ .soft_max = {} }).init();
/// ex.
/// ```
/// var dense_layer1 = neural_network.DenseLayer.init(2, 3, allocator);
/// defer dense_layer1.deinit();
/// var activation_layer1 = neural_network.ActivationLayer(neural_network.ActivationFunction{ .elu = {} }).init();
/// var dense_layer2 = neural_network.DenseLayer.init(3, 2);
/// defer dense_layer2.deinit();
/// var activation_layer2 = neural_network.ActivationLayer(neural_network.ActivationFunction{ .soft_max = {} }).init();
//
// var layers = [_]neural_network.Layer{
// dense_layer1.layer(),
// activation_layer1.layer(),
// dense_layer2.layer(),
// activation_layer2.layer(),
// };
/// var layers = [_]neural_network.Layer{
/// dense_layer1.layer(),
/// activation_layer1.layer(),
/// dense_layer2.layer(),
/// activation_layer2.layer(),
/// };
//
// neural_network.NeuralNetwork.init(
// layers,
// neural_network.LossFunction{
// .squared_error = {},
// // .cross_entropy = {},
// },
// allocator,
// );
// ```
/// neural_network.NeuralNetwork.init(
/// layers,
/// neural_network.LossFunction{ .squared_error = {} },
/// allocator,
/// );
/// ```
pub fn initFromLayers(
layers: []const Layer,
loss_function: LossFunction,
