beacon-biosignals/LegolasFlux.jl

LegolasFlux

Note: Upgrading from LegolasFlux v0.1.x to v0.2?

The only change is an update to Legolas v0.5. Be sure to check out the guidance for updating Legolas to v0.5 along with the rest of Legolas's documentation and tour.


LegolasFlux provides simple functionality for using Legolas.jl's extensible Arrow schemas as a means to serialize Flux models, similar to using Flux's params and loadparams!. Instead of those functions, LegolasFlux exports the analogous fetch_weights and load_weights!, which handle layers like BatchNorm correctly for this purpose.

The aim is to serialize only the numeric weights, not the code defining the model. This is a very different approach from e.g. BSON.jl, and hopefully much more robust. Note that in this package, we use weights to refer to the numeric arrays that are modified over the course of training a model; that includes biases as well as means and variances in e.g. BatchNorms (but not e.g. configuration settings).

With this approach, however, if you change the code such that the saved weights are no longer valid for the model (e.g. you add a layer), you will not be able to load those weights back into it.

Usage

using Flux

function make_my_model()
    return Chain(Dense(1,10), Dense(10, 10), Dense(10, 1))
end

my_model = make_my_model()
# train it? that part is optional ;)

# Now, let's save it!
using LegolasFlux

model_row = LegolasFlux.ModelV1(; weights = fetch_weights(cpu(my_model)),
                                architecture_version=1)
write_model_row("my_model.model.arrow", model_row)

# Great! Later on, we want to re-load our model weights.
fresh_model = make_my_model()

model_row = read_model_row("my_model.model.arrow")
load_weights!(fresh_model, model_row.weights)
# Now our weights have been loaded back into `fresh_model`.

We can use the architecture_version column to assign a version number to each architecture, in order to keep track of which architectures a given set of weights is valid for.
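As a minimal sketch of how such a check might look, the plain-Julia helper below compares a saved version number against the one the current code expects before any weights are loaded. The names CURRENT_ARCHITECTURE_VERSION and check_architecture_version are hypothetical and not part of LegolasFlux; the package itself does not enforce this check as far as this README describes.

```julia
# Hypothetical constant (not part of LegolasFlux): bump it whenever the model
# definition changes in a way that invalidates previously saved weights.
const CURRENT_ARCHITECTURE_VERSION = 1

# Hypothetical helper: throws unless `saved_version` matches `expected`.
# `saved_version` may be `missing`, since the column is optional.
function check_architecture_version(saved_version, expected=CURRENT_ARCHITECTURE_VERSION)
    if ismissing(saved_version)
        throw(ArgumentError("saved model row has no architecture_version"))
    elseif saved_version != expected
        throw(ArgumentError("weights were saved for architecture version $saved_version, " *
                            "but the current architecture is version $expected"))
    end
    return nothing
end

# Intended use with the rows from the usage example (not run here):
#   model_row = read_model_row("my_model.model.arrow")
#   check_architecture_version(model_row.architecture_version)
#   load_weights!(fresh_model, model_row.weights)
```

Failing early with an explicit message is a design choice; one could equally dispatch on the version number to pick a matching model constructor.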

See examples/digits.jl for a larger example, which also saves out extra metadata with the model, by using a Legolas schema extension.

LegolasFlux.ModelV1

LegolasFlux.ModelV1 is the central object of LegolasFlux. It acts as a Tables.jl-compatible row that stores the weights of a Flux model in the weights column and, optionally, an architecture_version (which defaults to missing).

ModelV1 is not exported because downstream code will likely want to define its own rows that extend the schema provided by LegolasFlux, and those rows might end up with similar names. See the next section for more on extensibility.

Extensibility

As a Legolas.jl schema, ModelV1 is meant to be extended. For example, let's say I had an MNIST classification model that I call Digits. I am very committed to reproducibility, so I store the commit_sha of my model's repo with every training run, and I also wish to save the accuracy and epoch. I might create a DigitsRowV1, which is a schema extension of the legolas-flux.model schema:

using Legolas, LegolasFlux
using Legolas: @schema, @version
using LegolasFlux: ModelV1  # not exported, so import it explicitly
@schema "digits-model" DigitsRow
@version DigitsRowV1 > ModelV1 begin
    # re-declare this ModelV1 field as parametric for this schema as well
    weights::(<:Union{Missing,Weights})
    epoch::Union{Missing, Int}
    accuracy::Union{Missing, Float32}
    commit_sha::Union{Missing, String}
end

Now I can use a DigitsRowV1 much like LegolasFlux's ModelV1. It has the same required weights column and optional architecture_version column, as well as the additional epoch, accuracy, and commit_sha columns. As a naming convention, one might name files produced by this row as e.g. training_run.digits.model.arrow.

When writing out a DigitsRowV1, I'll pass the schema version like so:

write_model_row(path, my_digits_row, DigitsRowV1SchemaVersion())

so that later, when I call read_model_row on this path, I'll get back a DigitsRowV1 instance.

Note that in this example the schema is called digits-model rather than just digits, since the package Digits might want to create other Legolas schemas as well at some point.

Check out the Legolas.jl repo to see more about how its extensible schema system works, and the example at examples/digits.jl.
