Training a transformer-decoder network to produce Shakespeare-esque text, using only the Julia standard library (plus a few extra packages for testing, enabled only in the CI run, and for GPU support). I obtained my training data from karpathy/char-rnn, which originally compiled the dataset.
The src directory in the project root contains four files, each representing a different layer of abstraction:
- autodiff.jl implements a generic way to differentiate a function, represented as a graph of operations on "tensor" objects, with respect to any number of variables. The function must have a single scalar output (if it has multiple outputs, an implicit sum is applied at the end). This is implemented with reverse-mode automatic differentiation; a minimal sketch of the idea follows this list.
- optimizer.jl contains implementations of optimization algorithms and loss functions: stochastic gradient descent, Adam, mean squared error loss, and cross-entropy loss.
- transformer.jl implements self-attention (see the attention sketch below the list), a decoder block, and a training loop for a generative language model that consumes the tiny-shakespeare.txt dataset.
- Shakespeare.jl just loads the files above and exports high-level functions and type definitions; depending on flags set in the environment, it also installs the packages used for testing all the moving parts.
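
To make the autodiff.jl idea concrete, here is a minimal, self-contained sketch of reverse-mode differentiation over an operation graph with a single scalar output. The `Node` type, the overloaded operators, and `backward!` are hypothetical names used for illustration only; they do not mirror the actual API in autodiff.jl.

```julia
# Minimal reverse-mode autodiff sketch (illustrative only).
mutable struct Node
    value::Float64
    grad::Float64
    parents::Vector{Node}
    backprop::Function   # pushes this node's gradient back to its parents
end

Node(v::Real) = Node(Float64(v), 0.0, Node[], () -> nothing)

function Base.:*(a::Node, b::Node)
    out = Node(a.value * b.value)
    out.parents = [a, b]
    out.backprop = () -> begin
        a.grad += b.value * out.grad   # d(a*b)/da = b
        b.grad += a.value * out.grad   # d(a*b)/db = a
    end
    return out
end

function Base.:+(a::Node, b::Node)
    out = Node(a.value + b.value)
    out.parents = [a, b]
    out.backprop = () -> begin
        a.grad += out.grad             # d(a+b)/da = 1
        b.grad += out.grad             # d(a+b)/db = 1
    end
    return out
end

function backward!(out::Node)
    # Topologically order the graph so every node's gradient is fully
    # accumulated before it is propagated to its parents.
    order, seen = Node[], Set{UInt}()
    visit(n::Node) = if !(objectid(n) in seen)
        push!(seen, objectid(n))
        foreach(visit, n.parents)
        push!(order, n)
    end
    visit(out)
    out.grad = 1.0
    foreach(n -> n.backprop(), reverse(order))
end

x, y = Node(2.0), Node(3.0)
z = x * y + x                  # single scalar output: z = x*y + x
backward!(z)
x.grad, y.grad                 # (4.0, 2.0): ∂z/∂x = y + 1, ∂z/∂y = x
```

The key point is the topological ordering: each node's gradient is fully accumulated before being pushed to its parents, which is what lets a single backward pass produce all gradients of one scalar output.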
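
Similarly, the core of the decoder block in transformer.jl, causal self-attention, can be sketched in plain Julia. The weight names (Wq, Wk, Wv), the column-per-token layout, and the single unbatched head are assumptions made for brevity, not the actual code in transformer.jl.

```julia
using LinearAlgebra

# Column-wise, numerically stable softmax: each column sums to 1.
colsoftmax(A) = (E = exp.(A .- maximum(A; dims = 1)); E ./ sum(E; dims = 1))

# x is a d_model × T matrix: one column of embeddings per token position.
# Wq, Wk, Wv are d_head × d_model projection matrices.
function causal_self_attention(x, Wq, Wk, Wv)
    d_head = size(Wq, 1)
    T = size(x, 2)
    Q, K, V = Wq * x, Wk * x, Wv * x
    scores = (K' * Q) ./ sqrt(d_head)                    # scores[i, j]: key i vs. query j
    mask = [i > j ? -Inf : 0.0 for i in 1:T, j in 1:T]   # no attending to future positions
    A = colsoftmax(scores .+ mask)                       # attention weights per query column
    return V * A                                         # d_head × T mixed values
end

# Tiny usage example with made-up sizes.
d_model, d_head, T = 8, 4, 5
x = randn(d_model, T)
Wq, Wk, Wv = randn(d_head, d_model), randn(d_head, d_model), randn(d_head, d_model)
y = causal_self_attention(x, Wq, Wk, Wv)                 # size (d_head, T)
```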
For a simple example of the autodiff.jl API, see mnist.jl.
src/autodiff.jl has unit tests at the bottom of the file, which define expressions whose gradients are validated against numerical derivatives computed with FiniteDiff.jl.
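
The shape of such a check looks like the sketch below: a hand-derived analytic gradient is compared to a numerical one from FiniteDiff.jl. The expression f here is an illustrative stand-in, not one of the actual test cases.

```julia
using FiniteDiff, Test

f(x) = sum(x .^ 2) + prod(x)                  # scalar-valued test expression
analytic_grad(x) = 2 .* x .+ prod(x) ./ x     # ∂f/∂xᵢ = 2xᵢ + ∏ⱼ≠ᵢ xⱼ (xᵢ ≠ 0)

x = [1.5, -0.7, 2.3]
numerical = FiniteDiff.finite_difference_gradient(f, x)
@test isapprox(analytic_grad(x), numerical; rtol = 1e-6)
```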
A further test optimizes a parameter of a linear projection (a matmul) fed into a softmax function, comparing against values and initial states generated with PyTorch (see linear_softmax.py).
The graph shows the mean squared error between the expected parameter value from PyTorch and the actual parameter of the linear projection on the y-axis, and the optimization iteration on the x-axis.
The code that generates the graph and compares the PyTorch and Julia parameters is located in test.jl (test/runtests.jl didn't work with my LSP configuration for some reason, so this is an easy workaround).
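
For reference, the shape of that experiment (a single weight matrix projecting into a softmax, trained with cross-entropy and a plain gradient-descent update) can be sketched in a few lines of Julia. The dimensions, learning rate, and helper names below are illustrative and are not taken from test.jl or linear_softmax.py.

```julia
using LinearAlgebra, Random

softmax(z) = (e = exp.(z .- maximum(z)); e ./ sum(e))

# One linear projection W, a softmax on top, cross-entropy against a one-hot
# target, and plain gradient-descent updates on W.
function fit_linear_softmax(x, target; n_out = 3, lr = 0.1, steps = 200)
    W = randn(n_out, length(x))
    onehot = zeros(n_out); onehot[target] = 1.0
    for _ in 1:steps
        p = softmax(W * x)
        # Softmax + cross-entropy shortcut: ∂loss/∂W = (p - onehot) * x'
        W .-= lr .* ((p .- onehot) * x')
    end
    return W
end

Random.seed!(0)
x = randn(4)
W = fit_linear_softmax(x, 2)
softmax(W * x)    # probability mass concentrates on class 2
```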