diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index de1d7fca..de09ade5 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -5,110 +5,57 @@ on:
     branches:
       - main
 env:
-  OTP_VERSION: "25.0"
-  ELIXIR_VERSION: "1.14.0"
   MIX_ENV: test
   XLA_CACHE_DIR: ${{ github.workspace }}/cache/xla
   LIBTORCH_DIR: ${{ github.workspace }}/cache/torch
 jobs:
+  codespell:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: codespell-project/actions-codespell@v2
+        with:
+          skip: deps
+          ignore_words_list: whn,ehr
   main:
     runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - otp: "26.1.1"
+            elixir: "1.15.6"
+            lint: true
+          - otp: "25.3.2.6"
+            elixir: "1.14.5"
+          - otp: "25.3.2.6"
+            elixir: "1.14.5"
+            test_command_prepend: "USE_EXLA=true"
+          - otp: "25.3.2.6"
+            elixir: "1.14.5"
+            test_command_prepend: "USE_TORCHX=true"
     steps:
       - uses: actions/checkout@v3
       - name: Install Erlang & Elixir
         uses: erlef/setup-beam@v1
         with:
-          otp-version: "${{ env.OTP_VERSION }}"
-          elixir-version: "${{ env.ELIXIR_VERSION }}"
+          otp-version: "${{ matrix.otp }}"
+          elixir-version: "${{ matrix.elixir }}"
       - uses: actions/cache@v3
         with:
           path: |
             deps
             _build
             cache
-          key: ${{ runner.os }}-mix-${{ matrix.pair.elixir }}-${{ matrix.pair.otp }}-${{ hashFiles('**/mix.lock') }}
+          key: ${{ runner.os }}-mix-${{ matrix.elixir }}-${{ matrix.otp }}-${{ matrix.test_command_prepend }}-${{ hashFiles('**/mix.lock') }}
           restore-keys: |
             ${{ runner.os }}-mix-
       - name: Install mix dependencies
         run: mix deps.get
       - name: Check formatting
+        if: ${{ matrix.lint }}
         run: mix format --check-formatted
-      - name: Compile without optional deps
-        run: mix compile --skip-optional-deps --warnings-as-errors
+      - name: Check unused deps
+        if: ${{ matrix.lint }}
+        run: mix deps.unlock --check-unused
       - name: Run tests
-        run: mix test
-  exla_check:
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v3
-      - name: Install Erlang & Elixir
-        uses: erlef/setup-beam@v1
-        with:
-          otp-version: "${{ env.OTP_VERSION }}"
-          elixir-version: "${{ env.ELIXIR_VERSION }}"
-      - uses: actions/cache@v3
-        with:
-          path: |
-            deps
-            _build
-            cache
-          key: ${{ runner.os }}-mix-${{ matrix.pair.elixir }}-${{ matrix.pair.otp }}-${{ hashFiles('**/mix.lock') }}
-          restore-keys: |
-            ${{ runner.os }}-mix-
-      - name: Install mix dependencies
-        run: mix deps.get
-      - name: Run tests against EXLA
-        run: USE_EXLA=true mix do compile --warnings-as-errors, test
-  torchx_check:
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v3
-      - name: Install Erlang & Elixir
-        uses: erlef/setup-elixir@v1
-        with:
-          otp-version: "${{ env.OTP_VERSION }}"
-          elixir-version: "${{ env.ELIXIR_VERSION }}"
-      - uses: actions/cache@v3
-        with:
-          path: |
-            deps
-            _build
-            cache
-          key: ${{ runner.os }}-mix-${{ matrix.pair.elixir }}-${{ matrix.pair.otp }}-${{ hashFiles('**/mix.lock') }}
-          restore-keys: |
-            ${{ runner.os }}-mix-
-      - name: Install mix dependencies
-        run: mix deps.get
-      - name: Run tests against Torchx
-        run: USE_TORCHX=true mix do compile --warnings-as-errors, test
-  onnx_check:
-    runs-on: ubuntu-latest
-    steps:
-      - name: Install Erlang & Elixir
-        uses: erlef/setup-beam@v1
-        with:
-          otp-version: "${{ env.OTP_VERSION }}"
-          elixir-version: "${{ env.ELIXIR_VERSION }}"
-      - name: Install Python
-        uses: actions/setup-python@v2
-        with:
-          python-version: "3.8"
-      - name: Install ONNX
-        run: pip install numpy onnx onnxruntime
-      - name: Install transformers
-        run: pip install git+https://github.com/huggingface/transformers.git sentencepiece pillow torch tensorflow
-      - name: Install Protoc
-        uses: arduino/setup-protoc@v1
-        with:
-          version: "3.x"
-      - name: Checkout AxonOnnx
-        uses: actions/checkout@v3
-        with:
-          repository: elixir-nx/axon_onnx
-          ref: refs/heads/master
-      - uses: actions/checkout@v3
-        with:
-          path: tmp/axon
-      - name: Run ONNX tests
-        run: |
-          AXON_PATH="tmp/axon" mix do deps.get, compile --warnings-as-errors, test
+        run: ${{ matrix.test_command_prepend }} mix do compile --skip-optional-deps --warnings-as-errors, test
diff --git a/.gitignore b/.gitignore
index 426c4881..e4cae0d0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,6 +25,7 @@ axon-*.tar
 # Downloaded fixtures
 examples/vision/horses/
 examples/vision/humans/
+examples/structured/creditcard.csv
 
 # Temporary files for e.g. tests
 /tmp/
diff --git a/CHANGELOG.md b/CHANGELOG.md
index b2689487..38788229 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,52 @@
 # Changelog
 
+## v0.5.1 (2023-02-17)
+
+### Bug Fixes
+
+* Fixed incorrect results from group normalization
+
+## v0.5.0 (2023-02-16)
+
+### Enhancements
+
+* Bump Nx dependency
+* Update documentation to account for channels last default
+* Improve error message in compilation/build errors for models
+* Remove deprecated `transform`
+
+### Deprecations
+
+* Deprecate `Axon.Loop.handle/4`
+
+## v0.4.1 (2023-01-21)
+
+### Bug Fixes
+
+* Fixed a shape mismatch when training with certain optimizers
+
+## v0.4.0 (2023-01-19)
+
+### Enhancements
+
+* Add `Axon.pop_nodes/2` for popping nodes off of a graph
+* Update `Axon.freeze/2` and `Axon.unfreeze/2` for manipulating frozen portions of Axon graph
+* Add `Axon.Loop.monitor/5` for firing events based on loop state criteria
+* Add `Axon.Loop.kino_vega_lite_plot/4` for producing Kino plots during training
+* Add `Axon.Schedules.linear_decay/1`
+* Performance boosts to `Axon.Loop` which prevent compilation cache misses in most Axon training and evaluation loops
+* Add global event counts for more correct filtering during Axon loops
+* Use layer state to manage dropout keys, making training more deterministic when starting from the same key
+* Make building Axon models fully deterministic
+* Add a bidirectional combinator
+
+### Bug Fixes
+
+* Fix issue with namespaced stateful models not updating correctly during training
+* Fix bug in `Axon.Loop.early_stop/3` which incorrectly tracked progress and would not early stop loop
+* Fix bug in `Axon.Loop.reduce_lr_on_plateau/3` which incorrectly tracked progress and would not reduce learning rate
+* Fix bug in `Axon.Layers.conv_transpose/4` when using channels last
+
 ## v0.3.1 (2022-12-07)
 
 ### Enhancements
@@ -60,4 +107,4 @@
 
 ## v0.1.0 (2022-06-16)
 
-First release.
\ No newline at end of file
+First release.
diff --git a/README.md b/README.md
index 0a6078d0..c6e96731 100644
--- a/README.md
+++ b/README.md
@@ -8,10 +8,9 @@ Axon consists of the following components:
 
   * Functional API – A low-level API of numerical definitions (defn) of which all other APIs build on.
   * Model Creation API – A high-level model creation API which manages model initialization and application.
-  * Optimization API – An API for creating and using first-order optimization techniques based on the [Optax](https://github.com/deepmind/optax) library.
   * Training API – An API for quickly training models, inspired by [PyTorch Ignite](https://pytorch.org/ignite/index.html).
 
-Axon provides abstractions that enable easy integration while maintaining a level of separation between each component. You should be able to use any of the APIs without dependencies on others. By decoupling the APIs, Axon gives you full control over each aspect of creating and training a neural network.
+Axon provides abstractions that enable easy integration while maintaining a level of separation between each component. You should be able to use any of the APIs without dependencies on others. By decoupling the APIs, Axon gives you full control over each aspect of creating and training a neural network. Axon uses [Polaris](https://github.com/elixir-nx/polaris) for its optimization API.
 
 ## Overview
 
@@ -103,27 +102,13 @@ model =
 model_state =
   model
-  |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adamw(0.005))
+  |> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adamw(0.005))
   |> Axon.Loop.metric(:accuracy)
   |> Axon.Loop.handle(:iteration_completed, &log_metrics/1, every: 50)
   |> Axon.Loop.run(data, %{}, epochs: 10, compiler: EXLA)
 ```
 
-The step expects an optimizer as argument. The following are currently supported:
-
-* Adabelief
-* Adagrad
-* Adam
-* Adamw
-* Fromage
-* Lamb
-* Noisy SGD
-* Radam
-* RMSProp
-* SGD
-* Yogi
-
-It’s important to note that optimization API does not directly depend on Axon models. You can use the API to optimize any differentiable objective function.
+Axon uses [Polaris](https://github.com/elixir-nx/polaris) for its optimization API. It’s important to note that optimization API does not directly depend on Axon models. You can use the API to optimize any differentiable objective function.
 
 In the future we plan to support distributed training loops. We are also seeking ways to improve the performance of our training loops by running them entirely on native accelerators.
@@ -140,7 +125,7 @@ Then add Axon to your dependencies:
 ```elixir
 def deps do
   [
-    {:axon, "~> 0.2.0"}
+    {:axon, "~> 0.6"}
   ]
 end
 ```
@@ -150,13 +135,16 @@ You'll also likely want to include an `Nx` compiler such as `EXLA` for any pract
 ```elixir
 def deps do
   [
-    {:axon, "~> 0.2.0"},
-    {:exla, "~> 0.3.0"},
-    {:nx, "~> 0.3.0"}
+    {:axon, "~> 0.6"},
+    {:exla, "~> 0.6"},
   ]
 end
 ```
 
+## Integration with other platforms
+
+See [Ortex](https://github.com/elixir-nx/ortex) which provides full-blown compatibility for running ONNX models via ONNX Runtime bindings. Alternatively, see [AxonONNX](https://github.com/elixir-nx/axon_onnx) to convert ONNX models to Axon models whenever possible to achieve better integration with Nx.
+ ## Sponsors DockYard diff --git a/examples/basics/multi_input_example.exs b/examples/basics/multi_input_example.exs index 0ed5d8b7..ed990cbd 100644 --- a/examples/basics/multi_input_example.exs +++ b/examples/basics/multi_input_example.exs @@ -1,7 +1,7 @@ Mix.install([ - {:axon, "~> 0.3.0"}, - {:exla, "~> 0.4.1"}, - {:nx, "~> 0.4.1"} + {:axon, "~> 0.5"}, + {:exla, "~> 0.5"}, + {:nx, "~> 0.5"} ]) defmodule XOR do @@ -27,7 +27,7 @@ defmodule XOR do defp train_model(model, data, epochs) do model |> Axon.Loop.trainer(:binary_cross_entropy, :sgd) - |> Axon.Loop.run(data, %{}, epochs: epochs, iterations: 1000) + |> Axon.Loop.run(data, %{}, epochs: epochs, iterations: 1000, compiler: EXLA) end def run do diff --git a/examples/basics/multi_output_example.exs b/examples/basics/multi_output_example.exs index c20e3abe..45ae9d8e 100644 --- a/examples/basics/multi_output_example.exs +++ b/examples/basics/multi_output_example.exs @@ -1,12 +1,9 @@ Mix.install([ - {:axon, "~> 0.3.0"}, - # {:exla, "~> 0.2.2"}, - {:nx, "~> 0.4.1"} + {:axon, "~> 0.5"}, + {:exla, "~> 0.5"}, + {:nx, "~> 0.5"} ]) -# Specify EXLA as the default defn compiler -# EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host]) - defmodule Power do require Axon @@ -34,11 +31,14 @@ defmodule Power do # model input and y is the target. Because we have multiple targets, we represent # y as a tuple. In the future, Axon will support any Nx container as an output data = - Stream.repeatedly(fn -> - # Batch size of 32 - x = Nx.random_uniform({32, 1}, -10, 10, type: {:f, 32}) - {x, {Nx.power(x, 2), Nx.power(x, 3)}} - end) + Stream.unfold( + Nx.Random.key(:erlang.system_time()), + fn key -> + # Batch size of 32 + {x, next_key} = Nx.Random.uniform(key, -10, 10, shape: {32, 1}, type: {:f, 32}) + {{x, {Nx.pow(x, 2), Nx.pow(x, 3)}}, next_key} + end + ) # Create the training loop, notice we specify 2 MSE objectives, 1 for the first # output and 1 for the second output. 
This will create a loss function which is @@ -62,7 +62,7 @@ defmodule Power do params = model |> Axon.Loop.trainer([mean_squared_error: 0.5, mean_squared_error: 0.5], :adam) - |> Axon.Loop.run(data, %{}, iterations: 250, epochs: 5) + |> Axon.Loop.run(data, %{}, iterations: 250, epochs: 5, compiler: EXLA) IO.inspect(Axon.predict(model, params, Nx.tensor([[3]]))) end diff --git a/examples/generative/fashionmnist_autoencoder.exs b/examples/generative/fashionmnist_autoencoder.exs index d6db4cab..5644604c 100644 --- a/examples/generative/fashionmnist_autoencoder.exs +++ b/examples/generative/fashionmnist_autoencoder.exs @@ -1,16 +1,12 @@ Mix.install([ - {:axon, "~> 0.1.0"}, - {:exla, "~> 0.2.2"}, - {:nx, "~> 0.2.1"}, - {:scidata, "~> 0.1.6"} + {:axon, "~> 0.5"}, + {:exla, "~> 0.5"}, + {:nx, "~> 0.5"}, + {:scidata, "~> 0.1"} ]) -# Configure default platform with accelerator precedence as tpu > cuda > rocm > host -EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host]) - -defmodule Fashionmist do +defmodule FashionMNIST do require Axon - alias Axon.Loop.State defmodule Autoencoder do defp encoder(x, latent_dim) do @@ -22,7 +18,7 @@ defmodule Fashionmist do defp decoder(x) do x |> Axon.dense(784, activation: :sigmoid) - |> Axon.reshape({1, 28, 28}) + |> Axon.reshape({:batch, 1, 28, 28}) end def build_model(input_shape, latent_dim) do @@ -37,7 +33,7 @@ defmodule Fashionmist do |> Nx.from_binary(type) |> Nx.reshape({elem(shape, 0), 1, 28, 28}) |> Nx.divide(255.0) - |> Nx.to_batched_list(32) + |> Nx.to_batched(32) end defp train_model(model, train_images, epochs) do @@ -58,7 +54,7 @@ defmodule Fashionmist do sample_image = train_images - |> hd() + |> Enum.fetch!(0) |> Nx.slice_along_axis(0, 1) |> Nx.reshape({1, 1, 28, 28}) @@ -71,4 +67,4 @@ defmodule Fashionmist do end end -Fashionmist.run() +FashionMNIST.run() diff --git a/examples/generative/mnist_gan.exs b/examples/generative/mnist_gan.exs index eac6bdaa..c1725a1e 100644 --- a/examples/generative/mnist_gan.exs +++ b/examples/generative/mnist_gan.exs @@ -1,8 +1,9 @@ Mix.install([ - {:axon, "~> 0.1.0"}, - {:exla, "~> 0.2.2"}, - {:nx, "~> 0.2.1"}, - {:scidata, "~> 0.1.6"} + {:axon, "~> 0.5"}, + {:polaris, "~> 0.1"}, + {:exla, "~> 0.5"}, + {:nx, "~> 0.5"}, + {:scidata, "~> 0.1"} ]) EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host]) @@ -17,7 +18,7 @@ defmodule MNISTGAN do |> Nx.from_binary(type) |> Nx.reshape({elem(shape, 0), 1, 28, 28}) |> Nx.divide(255.0) - |> Nx.to_batched_list(32) + |> Nx.to_batched(32) end defp build_generator(z_dim) do @@ -33,7 +34,7 @@ defmodule MNISTGAN do |> Axon.batch_norm() |> Axon.dense(784) |> Axon.tanh() - |> Axon.reshape({1, 28, 28}) + |> Axon.reshape({:batch, 28, 28, 1}) end defp build_discriminator(input_shape) do @@ -53,9 +54,13 @@ defmodule MNISTGAN do |> Nx.divide(Nx.add(i, 1)) end - defn init(d_params, g_params, init_optim_d, init_optim_g) do + defn init(template, init_d, init_g, init_optim_d, init_optim_g) do + d_params = init_d.(template, %{}) + g_params = init_g.(Nx.broadcast(0.0, {1, 100}), %{}) + %{ iteration: Nx.tensor(0), + random_key: Nx.Random.key(9999), discriminator: %{ model_state: d_params, optimizer_state: init_optim_d.(d_params), @@ -77,7 +82,7 @@ defmodule MNISTGAN do # Update D fake_labels = Nx.iota({32, 2}, axis: 1) real_labels = Nx.reverse(fake_labels) - noise = Nx.random_normal({32, 100}) + {noise, random_next_key} = Nx.Random.normal(state[:random_key], shape: {32, 100}) {d_loss, d_grads} = value_and_grad(d_params, fn params -> @@ -96,7 +101,7 @@ defmodule MNISTGAN do d_optimizer_state = 
state[:discriminator][:optimizer_state] {d_updates, d_optimizer_state} = optim_d.(d_grads, d_optimizer_state, d_params) - d_params = Axon.Updates.apply_updates(d_params, d_updates) + d_params = Polaris.Updates.apply_updates(d_params, d_updates) # Update G {g_loss, g_grads} = @@ -111,10 +116,11 @@ defmodule MNISTGAN do g_optimizer_state = state[:generator][:optimizer_state] {g_updates, g_optimizer_state} = optim_g.(g_grads, g_optimizer_state, g_params) - g_params = Axon.Updates.apply_updates(g_params, g_updates) + g_params = Polaris.Updates.apply_updates(g_params, g_updates) %{ iteration: iter + 1, + random_key: random_next_key, discriminator: %{ model_state: d_params, optimizer_state: d_optimizer_state, @@ -129,14 +135,14 @@ defmodule MNISTGAN do end defp train_loop(d_model, g_model) do - {init_optim_d, optim_d} = Axon.Optimizers.adam(2.0e-3, b1: 0.5) - {init_optim_g, optim_g} = Axon.Optimizers.adam(2.0e-3, b1: 0.5) + {init_optim_d, optim_d} = Polaris.Optimizers.adam(learning_rate: 2.0e-3, b1: 0.5) + {init_optim_g, optim_g} = Polaris.Optimizers.adam(learning_rate: 2.0e-3, b1: 0.5) - {d_init_params, d_model} = Axon.compile(d_model, mode: :train) - {g_init_params, g_model} = Axon.compile(g_model, mode: :train) + {init_d, d_model} = Axon.build(d_model, mode: :train) + {init_g, g_model} = Axon.build(g_model, mode: :train) step = &batch_step(d_model, g_model, optim_d, optim_g, &1, &2) - init = fn %{} -> init(d_init_params, g_init_params, init_optim_d, init_optim_g) end + init = fn template, _state -> init(template, init_d, init_g, init_optim_d, init_optim_g) end Axon.Loop.loop(step, init) end @@ -152,7 +158,7 @@ defmodule MNISTGAN do defp view_generated_images(model, batch_size, state) do %State{step_state: pstate} = state - noise = Nx.random_normal({batch_size, 100}) + {noise, random_next_key} = Nx.Random.normal(pstate[:random_key], shape: {batch_size, 100}) preds = Axon.predict(model, pstate[:generator][:model_state], noise) preds @@ -160,7 +166,7 @@ defmodule MNISTGAN do |> Nx.to_heatmap() |> IO.inspect() - {:continue, state} + {:continue, put_in(state.step_state.random_key, random_next_key)} end def run() do @@ -168,16 +174,18 @@ defmodule MNISTGAN do train_images = transform_images(images) generator = build_generator(100) - discriminator = build_discriminator({nil, 1, 28, 28}) + discriminator = build_discriminator({nil, 28, 28, 1}) discriminator |> train_loop(generator) - |> Axon.Loop.log(:iteration_completed, &log_iteration/1, :stdio, every: 50) - |> Axon.Loop.handle(:epoch_completed, &view_generated_images(generator, 3, &1)) + |> Axon.Loop.log(&log_iteration/1, + event: :iteration_completed, + device: :stdio, + filter: [every: 50] + ) + |> Axon.Loop.handle_event(:epoch_completed, &view_generated_images(generator, 3, &1)) |> Axon.Loop.run(train_images, %{}, epochs: 10, compiler: EXLA) end end -EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host]) - MNISTGAN.run() diff --git a/examples/generative/text_generator.exs b/examples/generative/text_generator.exs index 97a2f099..a34c9a1c 100644 --- a/examples/generative/text_generator.exs +++ b/examples/generative/text_generator.exs @@ -1,13 +1,11 @@ # Based on https://machinelearningmastery.com/text-generation-lstm-recurrent-neural-networks-python-keras/ Mix.install([ - {:axon, github: "elixir-nx/axon"}, - {:nx, "~> 0.2.1"}, - {:exla, "~> 0.2.2"}, - {:req, "~> 0.3.0"} + {:axon, "~> 0.5"}, + {:nx, "~> 0.5"}, + {:exla, "~> 0.5"}, + {:req, "~> 0.3.3"} ]) -EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host]) - defmodule TextGenerator do require Axon @@ 
-18,7 +16,7 @@ defmodule TextGenerator do def build_model(characters_count) do Axon.input("input_chars", shape: {nil, @sequence_length, 1}) |> Axon.lstm(256) - |> then(fn {_, out} -> out end) + |> then(fn {out, _} -> out end) |> Axon.nx(fn t -> t[[0..-1//1, -1]] end) |> Axon.dropout(rate: 0.2) |> Axon.dense(characters_count, activation: :softmax) @@ -59,7 +57,7 @@ defmodule TextGenerator do |> Nx.tensor() |> Nx.divide(characters_count) |> Nx.reshape({:auto, @sequence_length, 1}) - |> Nx.to_batched_list(@batch_size) + |> Nx.to_batched(@batch_size) train_labels = text @@ -68,7 +66,7 @@ defmodule TextGenerator do |> Nx.tensor() |> Nx.reshape({:auto, 1}) |> Nx.equal(Nx.iota({characters_count})) - |> Nx.to_batched_list(@batch_size) + |> Nx.to_batched(@batch_size) {train_data, train_labels} end @@ -97,7 +95,7 @@ defmodule TextGenerator do params = model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(0.001)) + |> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adam(learning_rate: 0.001)) |> Axon.Loop.run(Stream.zip(train_data, train_labels), %{}, epochs: 20, compiler: EXLA) init_sequence = """ diff --git a/examples/structured/credit_card_fraud.exs b/examples/structured/credit_card_fraud.exs index 2cb73397..f3628a80 100644 --- a/examples/structured/credit_card_fraud.exs +++ b/examples/structured/credit_card_fraud.exs @@ -1,14 +1,14 @@ Mix.install([ - {:axon, "~> 0.1.0"}, - {:exla, "~> 0.2.2"}, - {:nx, "~> 0.2.1"}, - {:explorer, "~> 0.2.0"} + {:axon, "~> 0.5"}, + {:polaris, "~> 0.1"}, + {:exla, "~> 0.5"}, + {:nx, "~> 0.5"}, + {:explorer, "~> 0.6"} ]) -EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host]) - defmodule CreditCardFraud do alias Axon.Loop.State + require Explorer.DataFrame # Download data with a Kaggle account: https://www.kaggle.com/mlg-ulb/creditcardfraud/ @file_name "examples/structured/creditcard.csv" @@ -47,32 +47,18 @@ defmodule CreditCardFraud do end defp split_features_targets(df) do - features = Explorer.DataFrame.select(df, &(&1 == "Class"), :drop) - targets = Explorer.DataFrame.select(df, &(&1 == "Class"), :keep) + features = Explorer.DataFrame.discard(df, ["Class"]) + targets = Explorer.DataFrame.select(df, ["Class"]) {features, targets} end - defp normalize(name), - do: fn df -> - Explorer.Series.divide( - df[name], - Explorer.Series.max( - Explorer.Series.transform(df[name], fn x -> - if x >= 0 do - x - else - -x - end - end) - ) - ) - end - defp normalize_data(df) do df - |> Explorer.DataFrame.names() - |> Map.new(&{&1, normalize(&1)}) - |> then(&Explorer.DataFrame.mutate(df, &1)) + |> Explorer.DataFrame.mutate( + for col <- across() do + {col.name, col / max(abs(col))} + end + ) end defp df_to_tensor(df) do @@ -127,7 +113,7 @@ defmodule CreditCardFraud do model |> Axon.Loop.evaluator() |> metrics() - |> Axon.Loop.handle(:epoch_completed, &summarize/1) + |> Axon.Loop.handle_event(:epoch_completed, &summarize/1) |> Axon.Loop.run(test_data, model_state, compiler: EXLA) end @@ -145,12 +131,12 @@ defmodule CreditCardFraud do fraud = Nx.sum(train_targets) |> Nx.to_number() legit = Nx.size(train_targets) - fraud - batched_train_inputs = Nx.to_batched_list(train_inputs, 2048) - batched_train_targets = Nx.to_batched_list(train_targets, 2048) + batched_train_inputs = Nx.to_batched(train_inputs, 2048) + batched_train_targets = Nx.to_batched(train_targets, 2048) batched_train = Stream.zip(batched_train_inputs, batched_train_targets) - batched_test_inputs = Nx.to_batched_list(test_inputs, 2048) - batched_test_targets = 
Nx.to_batched_list(test_targets, 2048) + batched_test_inputs = Nx.to_batched(test_inputs, 2048) + batched_test_targets = Nx.to_batched(test_targets, 2048) batched_test = Stream.zip(batched_test_inputs, batched_test_targets) IO.puts("# of legit transactions (train): #{legit}") @@ -169,7 +155,7 @@ defmodule CreditCardFraud do reduction: :mean ) - optimizer = Axon.Optimizers.adam(1.0e-2) + optimizer = Polaris.Optimizers.adam(learning_rate: 1.0e-2) model |> train_model(loss, optimizer, batched_train) diff --git a/examples/vision/cifar10.exs b/examples/vision/cifar10.exs index 52e01dc9..c4906b51 100644 --- a/examples/vision/cifar10.exs +++ b/examples/vision/cifar10.exs @@ -1,22 +1,21 @@ Mix.install([ - {:axon, "~> 0.1.0"}, - {:exla, "~> 0.2.2"}, - {:nx, "~> 0.2.1"}, - {:scidata, "~> 0.1.3"} + {:axon, "~> 0.5"}, + {:exla, "~> 0.5"}, + {:nx, "~> 0.5"}, + {:scidata, "~> 0.1"} ]) -# Configure default platform with accelerator precedence as tpu > cuda > rocm > host -EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host]) - defmodule Cifar do require Axon defp transform_images({bin, type, shape}) do bin |> Nx.from_binary(type) - |> Nx.reshape({elem(shape, 0), 3, 32, 32}) + |> Nx.reshape(shape, names: [:count, :channels, :width, :height]) + # Move channels to last position to match what conv layer expects + |> Nx.transpose(axes: [:count, :width, :height, :channels]) |> Nx.divide(255.0) - |> Nx.to_batched_list(32) + |> Nx.to_batched(32) |> Enum.split(1500) end @@ -25,7 +24,7 @@ defmodule Cifar do |> Nx.from_binary(type) |> Nx.new_axis(-1) |> Nx.equal(Nx.tensor(Enum.to_list(0..9))) - |> Nx.to_batched_list(32) + |> Nx.to_batched(32) |> Enum.split(1500) end @@ -58,12 +57,16 @@ defmodule Cifar do end def run do - {images, labels} = Scidata.CIFAR10.download() + {{_, _, {_, channels, width, height}} = images, labels} = Scidata.CIFAR10.download() {train_images, test_images} = transform_images(images) {train_labels, test_labels} = transform_labels(labels) - model = build_model({nil, 3, 32, 32}) |> IO.inspect() + model = + # Move channels to last position to match what conv layer expects + {nil, width, height, channels} + |> build_model() + |> IO.inspect() IO.write("\n\n Training Model \n\n") @@ -79,4 +82,5 @@ defmodule Cifar do end end +Nx.default_backend(EXLA.Backend) Cifar.run() diff --git a/examples/vision/cnn_image_denoising.exs b/examples/vision/cnn_image_denoising.exs index 928bcf2d..bb8e31f3 100644 --- a/examples/vision/cnn_image_denoising.exs +++ b/examples/vision/cnn_image_denoising.exs @@ -1,12 +1,10 @@ Mix.install([ - {:axon, "~> 0.1.0"}, - {:exla, "~> 0.2.2"}, - {:nx, "~> 0.2.1"}, - {:scidata, "~> 0.1.6"} + {:axon, "~> 0.5"}, + {:exla, "~> 0.5"}, + {:nx, "~> 0.5"}, + {:scidata, "~> 0.1"} ]) -EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host]) - defmodule MnistDenoising do require Axon import Nx.Defn @@ -48,7 +46,7 @@ defmodule MnistDenoising do defp transform_images({bin, type, shape}) do bin |> Nx.from_binary(type) - |> Nx.reshape({elem(shape, 0), 1, 28, 28}) + |> Nx.reshape({elem(shape, 0), 28, 28, 1}) |> Nx.divide(255.0) |> Nx.to_batched_list(@batch_size) # Test split @@ -65,7 +63,7 @@ defmodule MnistDenoising do defp display_image(images) do images |> Nx.slice_along_axis(0, 1) - |> Nx.reshape({1, 28, 28}) + |> Nx.reshape({28, 28, 1}) |> Nx.to_heatmap() |> IO.inspect() end @@ -77,8 +75,7 @@ defmodule MnistDenoising do end defp encoder(input_shape) do - input_shape - |> Axon.input("input") + Axon.input("input", shape: input_shape) |> Axon.conv(32, kernel_size: {3, 3}, padding: :same, activation: 
:relu) |> Axon.max_pool(kernel_size: {2, 2}, padding: :same) |> Axon.conv(32, kernel_size: {3, 3}, padding: :same, activation: :relu) diff --git a/examples/vision/horses_or_humans.exs b/examples/vision/horses_or_humans.exs index 5364db29..c564f8c7 100644 --- a/examples/vision/horses_or_humans.exs +++ b/examples/vision/horses_or_humans.exs @@ -1,20 +1,17 @@ Mix.install([ {:stb_image, "~> 0.5.2"}, - {:axon, "~> 0.1.0"}, - {:exla, "~> 0.2.2"}, - {:nx, "~> 0.2.1"} + {:axon, "~> 0.5"}, + {:polaris, "~> 0.1"}, + {:exla, "~> 0.5"}, + {:nx, "~> 0.5"} ]) -EXLA.set_as_nx_default( - [:tpu, :cuda, :rocm, :host], - run_options: [keep_on_device: true] -) - defmodule HorsesOrHumans do alias Axon.Loop.State import Nx.Defn - # Download and extract from https://laurencemoroney.com/datasets.html + # Download and extract from + # https://www.kaggle.com/datasets/sanikamal/horses-or-humans-dataset # or you can use Req to download and extract the zip file and iterate # over the resulting data @directories "examples/vision/{horses,humans}/*" @@ -48,14 +45,13 @@ defmodule HorsesOrHumans do do: Nx.tensor([1, 0], type: {:u, 8}), else: Nx.tensor([0, 1], type: {:u, 8}) - {:ok, binary, shape, :u8, :rgba} = StbImage.from_file(filename) + {:ok, img} = StbImage.read_file(filename) {StbImage.to_nx(img), class} end - defp build_model(input_shape, transpose_shape) do + defp build_model(input_shape) do Axon.input("input", shape: input_shape) - |> Axon.transpose(transpose_shape) |> Axon.conv(16, kernel_size: {3, 3}, activation: :relu) |> Axon.max_pool(kernel_size: {2, 2}) |> Axon.conv(32, kernel_size: {3, 3}, activation: :relu) @@ -78,13 +74,13 @@ defmodule HorsesOrHumans do model |> Axon.Loop.trainer(:binary_cross_entropy, optimizer, log: 1) |> Axon.Loop.metric(:accuracy) - |> Axon.Loop.run(data, %{}, epochs: epochs, iterations: 100) + |> Axon.Loop.run(data, %{}, epochs: epochs, iterations: 100, compiler: EXLA) end def run() do - model = build_model({nil, 300, 300, 4}, [2, 0, 1]) |> IO.inspect() - optimizer = Axon.Optimizers.adam(1.0e-4) - centralized_optimizer = Axon.Updates.compose(Axon.Updates.centralize(), optimizer) + model = build_model({nil, 300, 300, 4}) |> IO.inspect() + optimizer = Polaris.Optimizers.adam(learning_rate: 1.0e-4) + centralized_optimizer = Polaris.Updates.compose(Polaris.Updates.centralize(), optimizer) data = data() diff --git a/examples/vision/mnist.exs b/examples/vision/mnist.exs index fbec6139..a643368c 100644 --- a/examples/vision/mnist.exs +++ b/examples/vision/mnist.exs @@ -1,13 +1,11 @@ Mix.install([ - {:axon, "~> 0.1.0"}, - {:exla, "~> 0.2.2"}, - {:nx, "~> 0.2.1"}, - {:scidata, "~> 0.1.6"} + {:axon, "~> 0.5"}, + {:polaris, "~> 0.1"}, + {:exla, "~> 0.5"}, + {:nx, "~> 0.5"}, + {:scidata, "~> 0.1"} ]) -# Configure default platform with accelerator precedence as tpu > cuda > rocm > host -EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host]) - defmodule Mnist do require Axon @@ -16,7 +14,7 @@ defmodule Mnist do |> Nx.from_binary(type) |> Nx.reshape({elem(shape, 0), 784}) |> Nx.divide(255.0) - |> Nx.to_batched_list(32) + |> Nx.to_batched(32) # Test split |> Enum.split(1750) end @@ -26,7 +24,7 @@ defmodule Mnist do |> Nx.from_binary(type) |> Nx.new_axis(-1) |> Nx.equal(Nx.tensor(Enum.to_list(0..9))) - |> Nx.to_batched_list(32) + |> Nx.to_batched(32) # Test split |> Enum.split(1750) end @@ -40,16 +38,16 @@ defmodule Mnist do defp train_model(model, train_images, train_labels, epochs) do model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adamw(0.005)) + |> 
Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adamw(learning_rate: 0.005)) |> Axon.Loop.metric(:accuracy, "Accuracy") - |> Axon.Loop.run(Stream.zip(train_images, train_labels), %{}, epochs: epochs) + |> Axon.Loop.run(Stream.zip(train_images, train_labels), %{}, epochs: epochs, compiler: EXLA) end defp test_model(model, model_state, test_images, test_labels) do model |> Axon.Loop.evaluator() |> Axon.Loop.metric(:accuracy, "Accuracy") - |> Axon.Loop.run(Stream.zip(test_images, test_labels), model_state) + |> Axon.Loop.run(Stream.zip(test_images, test_labels), model_state, compiler: EXLA) end def run do diff --git a/guides/guides.md b/guides/guides.md index d8538be1..4f3a32ea 100644 --- a/guides/guides.md +++ b/guides/guides.md @@ -4,29 +4,29 @@ Axon is a library for creating and training neural networks in Elixir. The Axon ## Model Creation -* [Your first Axon model](your_first_axon_model.html) -* [Sequential models](sequential_models.html) -* [Complex models](complex_models.html) -* [Multi-input / multi-output models](multi_input_multi_output_models.html) -* [Custom layers](custom_layers.html) -* [Model hooks](model_hooks.html) +* [Your first Axon model](model_creation/your_first_axon_model.livemd) +* [Sequential models](model_creation/sequential_models.livemd) +* [Complex models](model_creation/complex_models.livemd) +* [Multi-input / multi-output models](model_creation/multi_input_multi_output_models.livemd) +* [Custom layers](model_creation/custom_layers.livemd) +* [Model hooks](model_creation/model_hooks.livemd) ## Model Execution -* [Accelerating Axon](accelerating_axon.html) -* [Training and inference mode](training_and_inference_mode.html) +* [Accelerating Axon](model_execution/accelerating_axon.livemd) +* [Training and inference mode](model_execution/training_and_inference_mode.livemd) ## Training and Evaluation -* [Your first training loop](your_first_training_loop.html) -* [Instrumenting loops with metrics](instrumenting_loops_with_metrics.html) -* [Your first evalutaion loop](your_first_evaluation_loop.html) -* [Using loop event handlers](using_loop_event_handlers.html) -* [Custom models, loss functions, and optimizers](custom_models_loss_optimizers.html) -* [Writing custom metrics](writing_custom_metrics.html) -* [Writing custom event handlers](writing_custom_event_handlers.html) +* [Your first training loop](training_and_evaluation/your_first_training_loop.livemd) +* [Instrumenting loops with metrics](training_and_evaluation/instrumenting_loops_with_metrics.livemd) +* [Your first evaluation loop](training_and_evaluation/your_first_evaluation_loop.livemd) +* [Using loop event handlers](training_and_evaluation/using_loop_event_handlers.livemd) +* [Custom models, loss functions, and optimizers](training_and_evaluation/custom_models_loss_optimizers.livemd) +* [Writing custom metrics](training_and_evaluation/writing_custom_metrics.livemd) +* [Writing custom event handlers](training_and_evaluation/writing_custom_event_handlers.livemd) ## Serialization -* [Converting ONNX models to Axon](onnx_to_axon.html) +* [Converting ONNX models to Axon](serialization/onnx_to_axon.livemd) diff --git a/guides/model_creation/complex_models.livemd b/guides/model_creation/complex_models.livemd index 89d2f64c..0dc6cd2c 100644 --- a/guides/model_creation/complex_models.livemd +++ b/guides/model_creation/complex_models.livemd @@ -1,12 +1,9 @@ - - # Complex models ```elixir Mix.install([ - {:axon, github: "elixir-nx/axon"}, - {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", 
override: true}, - {:kino, "~> 0.7.0"} + {:axon, ">= 0.5.0"}, + {:kino, ">= 0.9.0"} ]) ``` @@ -55,19 +52,19 @@ Axon.Display.as_graph(out, template) ```mermaid graph TD; 3[/"data (:input) {2, 8}"/]; -6["dense_0 (:dense) {2, 32}"]; -9["dense_1 (:dense) {2, 64}"]; -10["relu_0 (:relu) {2, 64}"]; -13["dense_2 (:dense) {2, 32}"]; -14["container_0 (:container) {{2, 32}, {2, 32}}"]; -15["add_0 (:add) {2, 32}"]; -14 --> 15; -13 --> 14; -6 --> 14; -10 --> 13; -9 --> 10; -3 --> 9; -3 --> 6; +4["dense_0 (:dense) {2, 32}"]; +5["dense_1 (:dense) {2, 64}"]; +6["relu_0 (:relu) {2, 64}"]; +7["dense_2 (:dense) {2, 32}"]; +8["container_0 (:container) {{2, 32}, {2, 32}}"]; +9["add_0 (:add) {2, 32}"]; +8 --> 9; +7 --> 8; +4 --> 8; +6 --> 7; +5 --> 6; +3 --> 5; +3 --> 4; ``` And you can use `Axon.build/2` on `out` as you would any other Axon model: @@ -79,8 +76,8 @@ And you can use `Axon.build/2` on `out` as you would any other Axon model: ``` -{#Function<135.51955502/2 in Nx.Defn.Compiler.fun/2>, - #Function<135.51955502/2 in Nx.Defn.Compiler.fun/2>} +{#Function<135.109794929/2 in Nx.Defn.Compiler.fun/2>, + #Function<135.109794929/2 in Nx.Defn.Compiler.fun/2>} ``` ```elixir @@ -94,8 +91,8 @@ predict_fn.(params, Nx.iota({2, 8}, type: :f32)) #Nx.Tensor< f32[2][32] [ - [-3.4256787300109863, -0.866683840751648, -0.2629307508468628, 3.2555718421936035, 2.2740533351898193, 3.0403499603271484, -2.7904915809631348, 3.4799132347106934, -4.16396951675415, -4.545778274536133, 3.146249532699585, -3.0786540508270264, 3.4500746726989746, 1.1419837474822998, -0.7993628978729248, 2.3798861503601074, 4.787802696228027, 1.290929913520813, 1.8274409770965576, -1.5016275644302368, 3.441028118133545, -1.8077948093414307, 0.25549376010894775, -2.555987596511841, -4.643674850463867, 2.164360523223877, -0.30402517318725586, -2.54134464263916, -2.699089527130127, 4.074007511138916, -0.7711544036865234, -3.988246202468872], - [-11.235082626342773, -1.5991168022155762, -4.076810836791992, 11.091293334960938, 4.669280052185059, 12.756690979003906, -1.4954360723495483, 4.8143310546875, -14.211947441101074, -11.360504150390625, 6.239661693572998, -0.9994411468505859, 8.645132064819336, -0.5422897338867188, -1.4019453525543213, 9.633858680725098, 10.077424049377441, -0.3623824119567871, ...] + [-4.283246040344238, 1.8983498811721802, 3.697357654571533, -4.720174789428711, 4.1636152267456055, 1.001131534576416, -0.7027540802955627, -3.7821826934814453, 0.027841567993164062, 9.267499923706055, 3.33616304397583, -1.5465859174728394, 8.983413696289062, 3.7445120811462402, 2.2405576705932617, -3.61336350440979, -1.7320983409881592, 0.5740477442741394, -0.22006472945213318, -0.1806044578552246, 1.1092393398284912, -0.29313594102859497, -0.41948509216308594, 3.526411533355713, -0.9127179384231567, 1.8373844623565674, 1.1746022701263428, -0.6885149478912354, -1.4326229095458984, -1.3498257398605347, -5.803186416625977, 1.5204020738601685], + [-15.615742683410645, 6.555544853210449, 7.033155918121338, -12.33556842803955, 14.105436325073242, -4.230871200561523, 5.985136032104492, -8.445676803588867, 5.383096694946289, 23.413570404052734, 0.8907639980316162, -1.400709629058838, 19.19326400756836, 13.784171104431152, 9.641424179077148, -8.407038688659668, -5.688483238220215, 4.383636474609375, ...] 
] > ``` @@ -160,28 +157,28 @@ Axon.Display.as_graph(model, template) ```mermaid graph TD; -16[/"data (:input) {1, 28, 28, 3}"/]; -19["conv_0 (:conv) {1, 28, 28, 3}"]; -20["mish_0 (:mish) {1, 28, 28, 3}"]; -21["container_0 (:container) {{1, 28, 28, 3}, {1, 28, 28, 3}}"]; -22["add_0 (:add) {1, 28, 28, 3}"]; -23["max_pool_0 (:max_pool) {1, 14, 14, 3}"]; -24["flatten_0 (:flatten) {1, 588}"]; -27["dense_0 (:dense) {1, 32}"]; -28["relu_0 (:relu) {1, 32}"]; -31["dense_1 (:dense) {1, 32}"]; -32["relu_1 (:relu) {1, 32}"]; -35["dense_2 (:dense) {1, 1}"]; -32 --> 35; -31 --> 32; -28 --> 31; -27 --> 28; -24 --> 27; -23 --> 24; -22 --> 23; -21 --> 22; -16 --> 21; +10[/"data (:input) {1, 28, 28, 3}"/]; +11["conv_0 (:conv) {1, 28, 28, 3}"]; +12["mish_0 (:mish) {1, 28, 28, 3}"]; +13["container_0 (:container) {{1, 28, 28, 3}, {1, 28, 28, 3}}"]; +14["add_0 (:add) {1, 28, 28, 3}"]; +15["max_pool_0 (:max_pool) {1, 14, 14, 3}"]; +16["flatten_0 (:flatten) {1, 588}"]; +17["dense_0 (:dense) {1, 32}"]; +18["relu_0 (:relu) {1, 32}"]; +19["dense_1 (:dense) {1, 32}"]; +20["relu_1 (:relu) {1, 32}"]; +21["dense_2 (:dense) {1, 1}"]; 20 --> 21; 19 --> 20; -16 --> 19; +18 --> 19; +17 --> 18; +16 --> 17; +15 --> 16; +14 --> 15; +13 --> 14; +10 --> 13; +12 --> 13; +11 --> 12; +10 --> 11; ``` diff --git a/guides/model_creation/custom_layers.livemd b/guides/model_creation/custom_layers.livemd index 839d6e31..f2eb82b8 100644 --- a/guides/model_creation/custom_layers.livemd +++ b/guides/model_creation/custom_layers.livemd @@ -1,12 +1,9 @@ - - # Custom layers ```elixir Mix.install([ - {:axon, github: "elixir-nx/axon"}, - {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true}, - {:kino, "~> 0.7.0"} + {:axon, ">= 0.5.0"}, + {:kino, ">= 0.9.0"} ]) ``` @@ -28,7 +25,7 @@ To Axon, layers are really just `defn` implementations with special Axon inputs. The `defn` implementation looks like any other `defn` you'd write; however, it must always account for additional `opts` as an argument: ```elixir -defmodule CustomLayers do +defmodule CustomLayers0 do import Nx.Defn defn my_layer(input, opts \\ []) do @@ -44,7 +41,7 @@ end ``` -{:module, CustomLayers, <<70, 79, 82, 49, 0, 0, 11, ...>>, {:my_layer, 2}} +{:module, CustomLayers0, <<70, 79, 82, 49, 0, 0, 10, ...>>, true} ``` Regardless of the options you configure your layer to accept, the `defn` implementation will always receive a `:mode` option indicating whether or not the model is running in training or inference mode. You can customize the behavior of your layer depending on the mode. @@ -54,7 +51,7 @@ With an implementation defined, you need only to call `Axon.layer/3` to apply ou ```elixir input = Axon.input("data") -out = Axon.layer(&CustomLayers.my_layer/2, [input]) +out = Axon.layer(&CustomLayers0.my_layer/2, [input]) ``` @@ -86,7 +83,7 @@ graph TD; Notice that by default custom layers render with a default operation marked as `:custom`. This can make it difficult to determine which layer is which during inspection. You can control the rendering by passing `:op_name` to `Axon.layer/3`: ```elixir -out = Axon.layer(&CustomLayers.my_layer/2, [input], op_name: :my_layer) +out = Axon.layer(&CustomLayers0.my_layer/2, [input], op_name: :my_layer) Axon.Display.as_graph(out, template) ``` @@ -104,7 +101,7 @@ You can also control the name of your layer via the `:name` option. 
All other op ```elixir out = - Axon.layer(&CustomLayers.my_layer/2, [input], + Axon.layer(&CustomLayers0.my_layer/2, [input], name: "layer", op_name: :my_layer, alpha: 2.0 @@ -152,7 +149,7 @@ predict_fn.(params, Nx.iota({2, 8}, type: :f32)) Notice that this model does not have any trainable parameters because none of the layers have trainable parameters. You can introduce trainable parameters by passing inputs created with `Axon.param/3` to `Axon.layer/3`. For example, you can modify your original custom layer to take an additional trainable parameter: ```elixir -defmodule CustomLayers do +defmodule CustomLayers1 do import Nx.Defn defn my_layer(input, alpha, _opts \\ []) do @@ -166,7 +163,7 @@ end ``` -{:module, CustomLayers, <<70, 79, 82, 49, 0, 0, 11, ...>>, {:my_layer, 3}} +{:module, CustomLayers1, <<70, 79, 82, 49, 0, 0, 10, ...>>, true} ``` And then construct the layer with a regular Axon input and a trainable parameter: @@ -174,7 +171,7 @@ And then construct the layer with a regular Axon input and a trainable parameter ```elixir alpha = Axon.param("alpha", fn _ -> {} end) -out = Axon.layer(&CustomLayers.my_layer/3, [input, alpha], op_name: :my_layer) +out = Axon.layer(&CustomLayers1.my_layer/3, [input, alpha], op_name: :my_layer) ``` @@ -199,7 +196,7 @@ params = init_fn.(template, %{}) "my_layer_0" => %{ "alpha" => #Nx.Tensor< f32 - 1.194254994392395 + -1.2601861953735352 > } } @@ -212,7 +209,7 @@ Notice how your model now initializes with a trainable parameter `"alpha"` for y If you plan on re-using custom layers in many locations, it's recommended that you wrap them in an Elixir function as an interface: ```elixir -defmodule CustomLayers do +defmodule CustomLayers2 do import Nx.Defn def my_layer(%Axon{} = input, opts \\ []) do @@ -233,14 +230,14 @@ end ``` -{:module, CustomLayers, <<70, 79, 82, 49, 0, 0, 13, ...>>, {:my_layer_impl, 3}} +{:module, CustomLayers2, <<70, 79, 82, 49, 0, 0, 12, ...>>, true} ``` ```elixir out = input - |> CustomLayers.my_layer() - |> CustomLayers.my_layer() + |> CustomLayers2.my_layer() + |> CustomLayers2.my_layer() |> Axon.dense(1) ``` @@ -263,10 +260,10 @@ Axon.Display.as_graph(out, template) ```mermaid graph TD; 3[/"data (:input) {2, 8}"/]; -10["my_layer_0 (:my_layer) {2, 8}"]; -12["my_layer_1 (:my_layer) {2, 8}"]; -15["dense_0 (:dense) {2, 1}"]; -12 --> 15; -10 --> 12; -3 --> 10; +8["my_layer_0 (:my_layer) {2, 8}"]; +9["my_layer_1 (:my_layer) {2, 8}"]; +10["dense_0 (:dense) {2, 1}"]; +9 --> 10; +8 --> 9; +3 --> 8; ``` diff --git a/guides/model_creation/model_hooks.livemd b/guides/model_creation/model_hooks.livemd index 7d50b5d3..f2ffc10a 100644 --- a/guides/model_creation/model_hooks.livemd +++ b/guides/model_creation/model_hooks.livemd @@ -1,11 +1,8 @@ - - # Model hooks ```elixir Mix.install([ - {:axon, github: "elixir-nx/axon"}, - {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true} + {:axon, ">= 0.5.0"} ]) ``` @@ -49,10 +46,10 @@ dense_init: %{ "kernel" => #Nx.Tensor< f32[4][8] [ - [-0.40611347556114197, -0.1551232784986496, 0.08485020697116852, -0.6748610734939575, 0.04797258973121643, -0.059523195028305054, 0.4092640280723572, 0.1300794780254364], - [-0.3551754057407379, 0.3159058094024658, 0.25394684076309204, 0.22510826587677002, 0.2613920271396637, -0.15213526785373688, -0.15744848549365997, -0.46065202355384827], - [-0.5224899649620056, 0.3639957010746002, -0.19676287472248077, 0.5423932075500488, -0.4722306430339813, 0.26447463035583496, 0.18534891307353973, -0.6442952752113342], - [-0.5629043579101562, 0.6370815634727478, 
-0.43325361609458923, 0.5084872245788574, -0.1424017995595932, 0.4865548312664032, -0.5839526057243347, 0.09811079502105713] + [0.6067318320274353, 0.5483129620552063, -0.05663269758224487, -0.48249542713165283, -0.18357598781585693, 0.6496620774269104, 0.4919115900993347, -0.08380156755447388], + [-0.19745409488677979, 0.10483592748641968, -0.43387970328330994, -0.1041460633277893, -0.4129607081413269, -0.6482449769973755, 0.6696910262107849, 0.4690167307853699], + [-0.18194729089736938, -0.4856645464897156, 0.39400774240493774, -0.28496378660202026, 0.32120805978775024, -0.41854584217071533, 0.5671316981315613, -0.21937215328216553], + [0.4516749978065491, -0.23585206270217896, -0.6682141423225403, 0.4286096692085266, -0.14930623769760132, -0.3825327157974243, 0.2700549364089966, -0.3888852596282959] ] > } @@ -70,10 +67,10 @@ dense_init: %{ "kernel" => #Nx.Tensor< f32[4][8] [ - [-0.40611347556114197, -0.1551232784986496, 0.08485020697116852, -0.6748610734939575, 0.04797258973121643, -0.059523195028305054, 0.4092640280723572, 0.1300794780254364], - [-0.3551754057407379, 0.3159058094024658, 0.25394684076309204, 0.22510826587677002, 0.2613920271396637, -0.15213526785373688, -0.15744848549365997, -0.46065202355384827], - [-0.5224899649620056, 0.3639957010746002, -0.19676287472248077, 0.5423932075500488, -0.4722306430339813, 0.26447463035583496, 0.18534891307353973, -0.6442952752113342], - [-0.5629043579101562, 0.6370815634727478, -0.43325361609458923, 0.5084872245788574, -0.1424017995595932, 0.4865548312664032, -0.5839526057243347, 0.09811079502105713] + [0.6067318320274353, 0.5483129620552063, -0.05663269758224487, -0.48249542713165283, -0.18357598781585693, 0.6496620774269104, 0.4919115900993347, -0.08380156755447388], + [-0.19745409488677979, 0.10483592748641968, -0.43387970328330994, -0.1041460633277893, -0.4129607081413269, -0.6482449769973755, 0.6696910262107849, 0.4690167307853699], + [-0.18194729089736938, -0.4856645464897156, 0.39400774240493774, -0.28496378660202026, 0.32120805978775024, -0.41854584217071533, 0.5671316981315613, -0.21937215328216553], + [0.4516749978065491, -0.23585206270217896, -0.6682141423225403, 0.4286096692085266, -0.14930623769760132, -0.3825327157974243, 0.2700549364089966, -0.3888852596282959] ] > } @@ -89,18 +86,11 @@ predict_fn.(params, input) ``` -dense_forward: #Nx.Tensor< - f32[2][8] - [ - [-3.0888683795928955, 2.955142021179199, -1.4393397569656372, 2.8353562355041504, -1.1102746725082397, 1.8364784717559814, -1.538608431816101, -1.454910159111023], - [-10.475601196289062, 7.602581024169922, -2.604217529296875, 5.239866733551025, -2.331346035003662, 3.993962526321411, -2.125761032104492, -4.961938381195068] - ] -> relu: #Nx.Tensor< f32[2][8] [ - [0.0, 2.955142021179199, 0.0, 2.8353562355041504, 0.0, 1.8364784717559814, 0.0, 0.0], - [0.0, 7.602581024169922, 0.0, 5.239866733551025, 0.0, 3.993962526321411, 0.0, 0.0] + [0.7936763167381287, 0.0, 0.0, 0.61175537109375, 0.0, 0.0, 2.614119291305542, 0.0], + [3.5096981525421143, 0.0, 0.0, 0.0, 0.0, 0.0, 10.609275817871094, 0.0] ] > ``` @@ -111,8 +101,8 @@ relu: #Nx.Tensor< #Nx.Tensor< f32[2][8] [ - [0.0, 2.955142021179199, 0.0, 2.8353562355041504, 0.0, 1.8364784717559814, 0.0, 0.0], - [0.0, 7.602581024169922, 0.0, 5.239866733551025, 0.0, 3.993962526321411, 0.0, 0.0] + [0.7936763167381287, 0.0, 0.0, 0.61175537109375, 0.0, 0.0, 2.614119291305542, 0.0], + [3.5096981525421143, 0.0, 0.0, 0.0, 0.0, 0.0, 10.609275817871094, 0.0] ] > ``` @@ -136,18 +126,11 @@ predict_fn.(params, input) ``` -hook1: #Nx.Tensor< - 
f32[2][8] - [ - [1.3320910930633545, 1.712153673171997, -2.0420351028442383, 2.2541849613189697, -3.1382551193237305, -1.2241677045822144, -1.5477651357650757, -0.2126261293888092], - [2.1975531578063965, 3.722827911376953, -1.6301460266113281, 5.891226768493652, -10.79372787475586, -2.9982359409332275, -6.589874267578125, 1.5387766361236572] - ] -> hook2: #Nx.Tensor< f32[2][8] [ - [1.3320910930633545, 1.712153673171997, -2.0420351028442383, 2.2541849613189697, -3.1382551193237305, -1.2241677045822144, -1.5477651357650757, -0.2126261293888092], - [2.1975531578063965, 3.722827911376953, -1.6301460266113281, 5.891226768493652, -10.79372787475586, -2.9982359409332275, -6.589874267578125, 1.5387766361236572] + [-0.6567458510398865, 2.2303993701934814, -1.540865421295166, -1.873536229133606, -2.386439085006714, -1.248870849609375, -2.9092607498168945, -0.1976098120212555], + [2.4088101387023926, 5.939034461975098, -2.024522066116333, -7.58249568939209, -10.193460464477539, 0.33839887380599976, -10.836882591247559, 1.8173918724060059] ] > ``` @@ -158,8 +141,8 @@ hook2: #Nx.Tensor< #Nx.Tensor< f32[2][8] [ - [1.3320910930633545, 1.712153673171997, 0.0, 2.2541849613189697, 0.0, 0.0, 0.0, 0.0], - [2.1975531578063965, 3.722827911376953, 0.0, 5.891226768493652, 0.0, 0.0, 0.0, 1.5387766361236572] + [0.0, 2.2303993701934814, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [2.4088101387023926, 5.939034461975098, 0.0, 0.0, 0.0, 0.33839887380599976, 0.0, 1.8173918724060059] ] > ``` @@ -182,8 +165,8 @@ model = ``` -{#Function<136.40088443/2 in Nx.Defn.wrap_arity/2>, - #Function<136.40088443/2 in Nx.Defn.wrap_arity/2>} +{#Function<135.109794929/2 in Nx.Defn.Compiler.fun/2>, + #Function<135.109794929/2 in Nx.Defn.Compiler.fun/2>} ``` On initialization: @@ -203,10 +186,10 @@ params = init_fn.(input, %{}) "kernel" => #Nx.Tensor< f32[4][8] [ - [0.6784419417381287, 0.175045907497406, 0.010701040737330914, -0.5537784695625305, -0.010694148950278759, 0.7021086812019348, -0.3290281891822815, -0.6818609237670898], - [-0.6378231644630432, -0.5675055384635925, 0.031453751027584076, 0.4705190360546112, -0.002226108219474554, 0.48611924052238464, 0.5700677037239075, 0.6729928851127625], - [0.4596043527126312, -0.6557875871658325, -0.07168347388505936, -0.37926459312438965, -0.20766735076904297, 0.11274437606334686, -0.5166378617286682, -0.5115087032318115], - [-0.30842259526252747, -0.3418923616409302, 0.3374936282634735, 0.6272460222244263, 0.6156857013702393, 0.6739501357078552, -0.09081890434026718, 0.706954836845398] + [0.2199305295944214, -0.05434012413024902, -0.07989239692687988, -0.4456246793270111, -0.2792319655418396, -0.1601254940032959, -0.6115692853927612, 0.37740427255630493], + [-0.3606935739517212, 0.6091846823692322, -0.3203054368495941, -0.6252920031547546, -0.41500264406204224, -0.20729252696037292, -0.6763507127761841, -0.6776859164237976], + [0.659041702747345, -0.615885317325592, -0.45865312218666077, 0.18774819374084473, 0.31994110345840454, -0.3055777847766876, -0.3537192642688751, 0.4297131896018982], + [0.06112170219421387, 0.13321959972381592, 0.5566524863243103, -0.1115691065788269, -0.3557875156402588, -0.03118818998336792, -0.5788122415542603, -0.6988758444786072] ] > } @@ -224,10 +207,10 @@ params = init_fn.(input, %{}) "kernel" => #Nx.Tensor< f32[4][8] [ - [0.6784419417381287, 0.175045907497406, 0.010701040737330914, -0.5537784695625305, -0.010694148950278759, 0.7021086812019348, -0.3290281891822815, -0.6818609237670898], - [-0.6378231644630432, -0.5675055384635925, 0.031453751027584076, 
0.4705190360546112, -0.002226108219474554, 0.48611924052238464, 0.5700677037239075, 0.6729928851127625], - [0.4596043527126312, -0.6557875871658325, -0.07168347388505936, -0.37926459312438965, -0.20766735076904297, 0.11274437606334686, -0.5166378617286682, -0.5115087032318115], - [-0.30842259526252747, -0.3418923616409302, 0.3374936282634735, 0.6272460222244263, 0.6156857013702393, 0.6739501357078552, -0.09081890434026718, 0.706954836845398] + [0.2199305295944214, -0.05434012413024902, -0.07989239692687988, -0.4456246793270111, -0.2792319655418396, -0.1601254940032959, -0.6115692853927612, 0.37740427255630493], + [-0.3606935739517212, 0.6091846823692322, -0.3203054368495941, -0.6252920031547546, -0.41500264406204224, -0.20729252696037292, -0.6763507127761841, -0.6776859164237976], + [0.659041702747345, -0.615885317325592, -0.45865312218666077, 0.18774819374084473, 0.31994110345840454, -0.3055777847766876, -0.3537192642688751, 0.4297131896018982], + [0.06112170219421387, 0.13321959972381592, 0.5566524863243103, -0.1115691065788269, -0.3557875156402588, -0.03118818998336792, -0.5788122415542603, -0.6988758444786072] ] > }, @@ -239,14 +222,14 @@ params = init_fn.(input, %{}) "kernel" => #Nx.Tensor< f32[8][1] [ - [-0.7136709690093994], - [-0.16328231990337372], - [0.08359552919864655], - [0.07646285742521286], - [0.7133787274360657], - [-0.00617210753262043], - [0.2241944670677185], - [-0.055933959782123566] + [0.3259686231613159], + [0.4874255657196045], + [0.6338149309158325], + [0.4437469244003296], + [-0.22870665788650513], + [0.8108665943145752], + [7.919073104858398e-4], + [0.4469025135040283] ] > } @@ -272,15 +255,15 @@ predict_fn.(params, input) #Nx.Tensor< f32[2][8] [ - [-0.6438822746276855, -2.9047577381134033, 0.9005677103996277, 1.593727946281433, 1.4294962882995605, 2.7334585189819336, -0.7356647253036499, 1.7708399295806885], - [0.12331989407539368, -8.465315818786621, 2.132427453994751, 2.2526159286499023, 3.0098886489868164, 10.633148193359375, -2.20133376121521, 2.5171523094177246] + [1.1407549381256104, -0.22292715311050415, 0.43234577775001526, -0.5845029354095459, -0.8424829840660095, -0.9120126962661743, -3.1202259063720703, -1.9148870706558228], + [3.4583563804626465, 0.06578820943832397, -0.776448130607605, -4.563453197479248, -3.7628071308135986, -3.7287485599517822, -12.002032279968262, -4.19266414642334] ] > #Nx.Tensor< f32[2][8] [ - [-0.6438822746276855, -2.9047577381134033, 0.9005677103996277, 1.593727946281433, 1.4294962882995605, 2.7334585189819336, -0.7356647253036499, 1.7708399295806885], - [0.12331989407539368, -8.465315818786621, 2.132427453994751, 2.2526159286499023, 3.0098886489868164, 10.633148193359375, -2.20133376121521, 2.5171523094177246] + [1.1407549381256104, -0.22292715311050415, 0.43234577775001526, -0.5845029354095459, -0.8424829840660095, -0.9120126962661743, -3.1202259063720703, -1.9148870706558228], + [3.4583563804626465, 0.06578820943832397, -0.776448130607605, -4.563453197479248, -3.7628071308135986, -3.7287485599517822, -12.002032279968262, -4.19266414642334] ] > ``` @@ -291,8 +274,8 @@ predict_fn.(params, input) #Nx.Tensor< f32[2][1] [ - [1.100995421409607], - [2.2032604217529297] + [0.6458775401115417], + [1.1593825817108154] ] > ``` @@ -316,15 +299,15 @@ Nx.Defn.grad(fn params -> predict_fn.(params, input) end).(params) #Nx.Tensor< f32[2][8] [ - [-0.6438822746276855, -2.9047577381134033, 0.9005677103996277, 1.593727946281433, 1.4294962882995605, 2.7334585189819336, -0.7356647253036499, 1.7708399295806885], - [0.12331989407539368, 
-8.465315818786621, 2.132427453994751, 2.2526159286499023, 3.0098886489868164, 10.633148193359375, -2.20133376121521, 2.5171523094177246] + [1.1407549381256104, -0.22292715311050415, 0.43234577775001526, -0.5845029354095459, -0.8424829840660095, -0.9120126962661743, -3.1202259063720703, -1.9148870706558228], + [3.4583563804626465, 0.06578820943832397, -0.776448130607605, -4.563453197479248, -3.7628071308135986, -3.7287485599517822, -12.002032279968262, -4.19266414642334] ] > #Nx.Tensor< f32[2][8] [ - [-0.6438822746276855, -2.9047577381134033, 0.9005677103996277, 1.593727946281433, 1.4294962882995605, 2.7334585189819336, -0.7356647253036499, 1.7708399295806885], - [0.12331989407539368, -8.465315818786621, 2.132427453994751, 2.2526159286499023, 3.0098886489868164, 10.633148193359375, -2.20133376121521, 2.5171523094177246] + [1.1407549381256104, -0.22292715311050415, 0.43234577775001526, -0.5845029354095459, -0.8424829840660095, -0.9120126962661743, -3.1202259063720703, -1.9148870706558228], + [3.4583563804626465, 0.06578820943832397, -0.776448130607605, -4.563453197479248, -3.7628071308135986, -3.7287485599517822, -12.002032279968262, -4.19266414642334] ] > ``` @@ -336,15 +319,15 @@ Nx.Defn.grad(fn params -> predict_fn.(params, input) end).(params) "dense_0" => %{ "bias" => #Nx.Tensor< f32[8] - [-0.7136709690093994, 0.0, 0.1671910583972931, 0.15292571485042572, 1.4267574548721313, -0.01234421506524086, 0.0, -0.11186791956424713] + [0.6519372463226318, 0.4874255657196045, 0.6338149309158325, 0.0, 0.0, 0.0, 0.0, 0.0] >, "kernel" => #Nx.Tensor< f32[4][8] [ - [-2.8546838760375977, 0.0, 0.3343821167945862, 0.30585142970085144, 2.8535149097442627, -0.02468843013048172, 0.0, -0.22373583912849426], - [-3.568354845046997, 0.0, 0.5015732049942017, 0.45877712965011597, 4.280272483825684, -0.03703264519572258, 0.0, -0.3356037735939026], - [-4.2820258140563965, 0.0, 0.6687642335891724, 0.6117028594017029, 5.707029819488525, -0.04937686026096344, 0.0, -0.4474716782569885], - [-4.995697021484375, 0.0, 0.8359552621841431, 0.7646285891532898, 7.133787155151367, -0.0617210753262043, 0.0, -0.5593395829200745] + [1.3038744926452637, 1.949702262878418, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.9558117389678955, 2.4371278285980225, 0.6338149309158325, 0.0, 0.0, 0.0, 0.0, 0.0], + [2.6077489852905273, 2.924553394317627, 1.267629861831665, 0.0, 0.0, 0.0, 0.0, 0.0], + [3.259686231613159, 3.4119789600372314, 1.9014447927474976, 0.0, 0.0, 0.0, 0.0, 0.0] ] > }, @@ -356,14 +339,14 @@ Nx.Defn.grad(fn params -> predict_fn.(params, input) end).(params) "kernel" => #Nx.Tensor< f32[8][1] [ - [0.12331989407539368], + [4.599111557006836], + [0.06578820943832397], + [0.43234577775001526], + [0.0], + [0.0], [0.0], - [3.0329952239990234], - [3.846343994140625], - [4.439384937286377], - [13.366606712341309], [0.0], - [4.287992477416992] + [0.0] ] > } @@ -395,10 +378,10 @@ params = init_fn.(input, %{}) "kernel" => #Nx.Tensor< f32[4][8] [ - [0.13930729031562805, 0.6213980913162231, 0.5555388331413269, -0.18602639436721802, 0.37516212463378906, 0.025288991630077362, 0.5311357378959656, 0.2825106978416443], - [-0.14007511734962463, -0.1472432166337967, -0.011716545559465885, 0.06804006546735764, 0.4615606963634491, -0.024897094815969467, -0.2336975485086441, 0.10019711405038834], - [-0.29539188742637634, -0.5487134456634521, 0.41018739342689514, -0.49597275257110596, 0.2970600426197052, 0.4304136335849762, 0.13961079716682434, -0.4316418170928955], - [0.5435506105422974, -0.056049738079309464, 0.5059406161308289, 0.29488587379455566, 
0.5656863451004028, 0.43807661533355713, -0.5058187246322632, -0.6963644623756409] + [-0.13241732120513916, 0.6946331858634949, -0.6328000426292419, -0.684409499168396, -0.39569517970085144, -0.10005003213882446, 0.2501150965690613, 0.14561182260513306], + [-0.5495109558105469, 0.459137499332428, -0.4059434235095978, -0.4489462077617645, -0.6331832408905029, 0.05011630058288574, -0.35836488008499146, -0.2661571800708771], + [0.29260867834091187, 0.42186349630355835, 0.32596689462661743, -0.12340176105499268, 0.6767188906669617, 0.2658537030220032, 0.5745270848274231, 6.475448608398438e-4], + [0.16781508922576904, 0.23747843503952026, -0.5311254858970642, 0.22617805004119873, -0.5153165459632874, 0.19729173183441162, -0.5706893801689148, -0.5531126260757446] ] > } @@ -417,8 +400,8 @@ predict_fn.(params, input) #Nx.Tensor< f32[2][8] [ - [0.8997929096221924, -1.412819266319275, 2.3264801502227783, -0.039247818291187286, 2.752739906311035, 2.150160074234009, -1.4719321727752686, -2.852180004119873], - [1.8893564939498901, -1.9352525472640991, 8.166281700134277, -1.3155406713485718, 9.550616264343262, 5.625688552856445, -1.7470110654830933, -5.833373546600342] + [0.539151668548584, 2.0152997970581055, -1.347386121749878, -0.017215579748153687, -0.8256950974464417, 1.173698902130127, -0.9213788509368896, -1.9241999387741089], + [-0.3468663692474365, 9.267749786376953, -6.322994232177734, -4.139533042907715, -4.295599460601807, 2.8265457153320312, -1.3390271663665771, -4.616241931915283] ] > ``` @@ -430,8 +413,8 @@ predict_fn.(params, input) prediction: #Nx.Tensor< f32[2][8] [ - [0.8997929096221924, 0.0, 2.3264801502227783, 0.0, 2.752739906311035, 2.150160074234009, 0.0, 0.0], - [1.8893564939498901, 0.0, 8.166281700134277, 0.0, 9.550616264343262, 5.625688552856445, 0.0, 0.0] + [0.539151668548584, 2.0152997970581055, 0.0, 0.0, 0.0, 1.173698902130127, 0.0, 0.0], + [0.0, 9.267749786376953, 0.0, 0.0, 0.0, 2.8265457153320312, 0.0, 0.0] ] >, state: %{} @@ -455,10 +438,10 @@ params = init_fn.(input, %{}) "kernel" => #Nx.Tensor< f32[4][8] [ - [0.4261569678783417, -0.6842133402824402, -0.13853907585144043, 0.6665098667144775, 0.6171062588691711, 0.25513389706611633, -0.4866299033164978, -0.5819953680038452], - [-0.36037471890449524, -0.21852241456508636, -0.6355746388435364, -0.5705516934394836, -0.35449153184890747, -0.1527744084596634, -0.5036700367927551, -0.4164859354496002], - [0.6485253572463989, 0.30033791065216064, 0.35249730944633484, -0.31768497824668884, 0.020564774051308632, 0.147691547870636, 0.6939279437065125, 0.6060985922813416], - [0.006978582590818405, 0.5333927869796753, 0.30155065655708313, -0.09574121236801147, 0.3447912037372589, -0.11081335693597794, 0.5808792114257812, 0.04360806941986084] + [0.02683490514755249, -0.28041765093803406, 0.15839070081710815, 0.16674137115478516, -0.5444575548171997, -0.34951671957969666, 0.08247309923171997, 0.6700448393821716], + [0.6001952290534973, -0.26907777786254883, 0.4580194354057312, -0.060002803802490234, -0.5385662317276001, -0.46773862838745117, 0.25804388523101807, -0.6824946999549866], + [0.13328874111175537, -0.46421635150909424, -0.5192649960517883, -0.0429919958114624, 0.0771912932395935, -0.447194904088974, 0.30910569429397583, -0.6105270981788635], + [0.5253992676734924, 0.41786473989486694, 0.6903378367424011, 0.6038702130317688, 0.06673228740692139, 0.4242702126502991, -0.6737087368965149, -0.6956207156181335] ] > } @@ -477,8 +460,8 @@ predict_fn.(params, input) #Nx.Tensor< f32[2][8] [ - [0.9576117396354675, 1.9823317527770996, 
0.9740719795227051, 0.0, 0.7210116386413574, 0.0, 2.6268234252929688, 0.9265354871749878], - [3.842756509780884, 1.706311583518982, 0.49380895495414734, 0.0, 3.2328944206237793, 0.36711934208869934, 3.764852285385132, 0.0] + [2.4429705142974854, 0.056083738803863525, 1.490502953529358, 1.6656239032745361, 0.0, 0.0, 0.0, 0.0], + [7.585843086242676, 0.0, 4.640434741973877, 4.336091041564941, 0.0, 0.0, 0.0, 0.0] ] > ``` diff --git a/guides/model_creation/multi_input_multi_output_models.livemd b/guides/model_creation/multi_input_multi_output_models.livemd index 686b22ea..0a69d5ca 100644 --- a/guides/model_creation/multi_input_multi_output_models.livemd +++ b/guides/model_creation/multi_input_multi_output_models.livemd @@ -1,12 +1,9 @@ - - # Multi-input / multi-output models ```elixir Mix.install([ - {:axon, github: "elixir-nx/axon"}, - {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true}, - {:kino, "~> 0.7.0"} + {:axon, ">= 0.5.0"}, + {:kino, ">= 0.9.0"} ]) ``` @@ -144,17 +141,17 @@ Axon.Display.as_graph(out, template) ```mermaid graph TD; 7[/"data (:input) {2, 8}"/]; -10["dense_0 (:dense) {2, 32}"]; -11["relu_0 (:relu) {2, 32}"]; -14["dense_1 (:dense) {2, 64}"]; -15["relu_1 (:relu) {2, 64}"]; -16["container_0 (:container) {{2, 32}, {2, 64}}"]; -15 --> 16; -11 --> 16; -14 --> 15; -7 --> 14; +8["dense_0 (:dense) {2, 32}"]; +9["relu_0 (:relu) {2, 32}"]; +10["dense_1 (:dense) {2, 64}"]; +11["relu_1 (:relu) {2, 64}"]; +12["container_0 (:container) {{2, 32}, {2, 64}}"]; +11 --> 12; +9 --> 12; 10 --> 11; 7 --> 10; +8 --> 9; +7 --> 8; ``` When executed, containers will return a data structure which matches their input structure: @@ -171,14 +168,14 @@ predict_fn.(params, Nx.iota({2, 8}, type: :f32)) {#Nx.Tensor< f32[2][32] [ - [0.0, 0.0, 3.111135482788086, 0.48920655250549316, 0.0, 0.5125713348388672, 0.0, 0.0, 1.482532262802124, 0.0, 0.0, 0.0, 0.0, 3.103637933731079, 0.46897295117378235, 2.6465413570404053, 2.837477445602417, 0.6159781217575073, 1.3220927715301514, 0.0, 0.24302834272384644, 3.4662821292877197, 0.40560781955718994, 0.0, 0.0, 0.2682836055755615, 3.5352964401245117, 0.0, 0.6591103672981262, 2.5643503665924072, 0.0, 0.0], - [0.0, 0.0, 4.642599105834961, 0.0, 0.0, 1.8978865146636963, 2.2522430419921875, 0.0, 1.2110804319381714, 2.5524141788482666, 0.0, 0.742849588394165, 0.0, 8.30776596069336, 5.09386682510376, 4.69991397857666, 5.195588111877441, ...] + [0.4453479051589966, 1.7394963502883911, 0.8509911298751831, 0.35142624378204346, 0.0, 0.0, 0.0, 3.942654609680176, 0.0, 0.0, 0.0, 0.6140655279159546, 0.0, 5.719906330108643, 1.1410939693450928, 0.0, 2.6871578693389893, 3.373258352279663, 0.0, 0.0, 0.0, 0.3058185875415802, 0.0, 0.0, 1.3737146854400635, 2.2648088932037354, 1.3570061922073364, 0.0, 0.05746358633041382, 0.0, 2.046199321746826, 4.884631156921387], + [0.0, 2.0598671436309814, 2.4343056678771973, 3.2341041564941406, 0.0, 1.905256748199463, 0.0, 12.712749481201172, 0.0, 0.0, 0.0, 4.559232711791992, 0.0, 12.027459144592285, 0.8423471450805664, 0.0, 8.888325691223145, ...] 
] >, #Nx.Tensor< f32[2][64] [ - [0.0, 0.0, 0.7948622107505798, 0.0, 0.0, 0.0, 0.0, 0.0, 2.3980231285095215, 5.2512712478637695, 1.5820361375808716, 0.0, 2.6624603271484375, 0.0, 0.0, 0.0, 1.6954007148742676, 0.017102837562561035, 0.7754535675048828, 0.0, 1.891753911972046, 0.0, 2.7824556827545166, 0.0, 0.5906356573104858, 0.0, 0.0, 1.288651466369629, 0.6939071416854858, 0.8427785038948059, 1.5664646625518799, 0.38097164034843445, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3193289637565613, 0.0, 0.0, 0.35316526889801025, 0.0, 1.2567038536071777, 0.7732977867126465, 0.16440902650356293, 0.0, 1.9872947931289673, ...], + [2.211906909942627, 0.937014639377594, 0.017132893204689026, 0.0, 3.617021083831787, 1.3125507831573486, 1.1870051622390747, 0.0, 0.0, 1.245000958442688, 1.5268664360046387, 0.0, 2.16796612739563, 0.8091188669204712, 0.45314761996269226, 0.0, 0.05176612734794617, 0.0, 5.982738018035889, 1.58057701587677, 0.0, 0.0, 1.2986125946044922, 0.8577098250389099, 0.0, 1.1064631938934326, 1.1242716312408447, 1.8777625560760498, 3.4422712326049805, 0.13321448862552643, 2.753225088119507, 0.0, 0.45021766424179077, 0.5664225816726685, 0.0, 0.0, 0.0, 1.5448659658432007, 0.0, 0.7237715721130371, 0.1693495213985443, 0.0, 0.719341516494751, 0.0, 0.0, 4.644839763641357, 0.0, 3.597681760787964, ...], ... ] >} @@ -213,14 +210,14 @@ predict_fn.(params, Nx.iota({2, 8}, type: :f32)) x1: #Nx.Tensor< f32[2][32] [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.8718442916870117, 0.0, 1.813383936882019, 0.0, 0.0, 0.0, 0.0, 3.0636630058288574, 0.0, 1.1350113153457642, 1.7888737916946411, 0.0658932775259018, 0.0, 0.4498137831687927, 1.1311852931976318, 3.2784717082977295, 0.0, 2.4505443572998047, 3.346879005432129, 0.0, 0.0, 2.614570140838623, 0.0, 0.0, 0.8967163562774658, 0.0], - [0.0, 0.0, 0.0, 1.9045438766479492, 0.0, 0.0, 7.110898971557617, 0.09859625995159149, 8.149545669555664, 0.0, 0.0, 0.0, 0.0, 4.178244113922119, 0.0, 3.8360297679901123, 6.177351474761963, ...] + [1.4180752038955688, 1.8710994720458984, 0.0, 1.1198676824569702, 1.1357430219650269, 0.0, 0.0, 0.0, 2.907017469406128, 0.0, 0.3814663589000702, 0.0, 0.6225995421409607, 1.1952786445617676, 0.0, 3.6701409816741943, 3.581918716430664, 1.4750021696090698, 0.910987377166748, 0.0, 0.0, 0.0, 2.317782402038574, 0.8362345695495605, 0.0, 1.9256348609924316, 0.0, 0.0, 0.0, 1.8028252124786377, 1.448373556137085, 1.743951678276062], + [3.7401936054229736, 2.494429349899292, 0.0, 0.9745509624481201, 8.416919708251953, 0.0, 0.6044515371322632, 0.0, 2.5829238891601562, 0.0, 3.592892646789551, 0.0, 0.0, 4.004939079284668, 0.0, 9.755555152893066, 5.3506879806518555, ...] 
] >, x2: #Nx.Tensor< f32[2][64] [ - [0.41670602560043335, 0.0, 0.0, 0.0, 1.338260531425476, 0.0, 0.5181264877319336, 1.1024510860443115, 0.0, 0.0, 1.485485553741455, 0.0, 0.0, 1.9365136623382568, 0.0, 0.0, 0.0, 0.0, 2.6925604343414307, 0.6202171444892883, 0.0, 0.08886899054050446, 0.0, 1.3045244216918945, 0.0, 0.0545249879360199, 0.0, 1.2294358015060425, 0.0, 0.0, 0.670710563659668, 0.0, 4.161868572235107, 1.880513072013855, 2.6189277172088623, 0.5702207684516907, 0.0, 1.953904151916504, 0.0, 0.0, 1.370330572128296, 0.17245425283908844, 1.9922431707382202, 2.6845364570617676, 0.3711611032485962, 0.7940037250518799, 0.0, 2.12975811958313, ...], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.5240116119384766, 0.0, 1.6478428840637207, 0.0, 0.0, 0.0, 0.0, 2.1685361862182617, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.5010783672332764, 0.36673399806022644, 0.0, 0.0, 0.5610344409942627, 1.9324723482131958, 0.39768826961517334, 0.0, 0.0, 0.0, 0.0, 0.0, 0.054594263434410095, 0.6123883128166199, 0.15942004323005676, 0.7058550715446472, 0.0, 1.860019326210022, 0.2499483972787857, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03381317853927612, ...], ... ] > @@ -256,14 +253,14 @@ predict_fn.(params, Nx.iota({2, 8}, type: :f32)) x1: {#Nx.Tensor< f32[2][32] [ - [3.9104199409484863, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.051666498184204, 0.0, 1.086042881011963, 0.6107193827629089, 0.5136545896530151, 2.7927842140197754, 0.0, 0.0, 0.0, 0.0, 2.472961902618408, 0.13712915778160095, 0.49807000160217285, 1.7868735790252686, 5.796293258666992, 0.0, 0.0, 4.727283477783203, 0.0, 0.0, 2.129516363143921, 0.0, 0.0], - [11.746908187866211, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.840534687042236, 0.0, 0.0, 4.103122711181641, 1.0597835779190063, 8.971627235412598, ...] + [1.7373675107955933, 0.0, 5.150482177734375, 0.544252336025238, 0.275376558303833, 0.0, 0.0, 0.0, 0.0, 1.7849855422973633, 0.7857151031494141, 0.2273893654346466, 0.2701767086982727, 2.321484327316284, 2.685051441192627, 0.0, 2.547382116317749, 0.0, 0.0, 0.0, 0.722919225692749, 2.3600289821624756, 1.4695687294006348, 0.0, 0.0, 0.0, 1.0015852451324463, 1.2762010097503662, 0.0, 0.07927703857421875, 0.0, 0.6216219663619995], + [4.996878623962402, 0.0, 14.212154388427734, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.517582356929779, 0.0, 2.036062479019165, 2.907236337661743, 8.515787124633789, 7.998186111450195, ...] ] >, #Nx.Tensor< f32[2][64] [ - [0.951026439666748, 0.0, 0.6895619034767151, 0.12973949313163757, 3.0561492443084717, 0.0, 0.21812109649181366, 0.0, 0.6377829313278198, 0.0, 0.0, 0.0, 0.0, 1.6837494373321533, 0.0, 0.0, 0.0, 1.3907173871994019, 0.0, 0.0, 0.21352148056030273, 0.0, 1.2145031690597534, 0.0, 3.080430507659912, 0.0, 3.9572620391845703, 2.3347463607788086, 0.5280991196632385, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.616438627243042, 0.0, 1.1335082054138184, 2.228783369064331, 0.0, 1.0927692651748657, 0.0, 0.0, 0.0, 0.0, 2.7719650268554688, ...], + [1.2057430744171143, 0.0, 0.0, 0.8717040419578552, 1.7653638124465942, 0.0, 0.0, 0.0, 0.0, 0.9921279549598694, 0.0, 1.0860291719436646, 2.3648557662963867, 0.0, 0.0, 2.0518181324005127, 1.6323723793029785, 0.9113610982894897, 1.6805293560028076, 0.8101096749305725, 0.0, 0.0, 0.0, 2.2150073051452637, 0.0, 0.0, 0.0, 0.0, 0.0, 2.2320713996887207, 0.0, 2.553570508956909, 0.28632092475891113, 0.0, 0.0, 0.020383253693580627, 0.0, 0.2926883101463318, 1.3561311960220337, 0.8884503245353699, 3.1455295085906982, 0.0, 0.0, 1.237722635269165, 0.0, 2.149625539779663, ...], ... 
] >}, @@ -271,14 +268,14 @@ predict_fn.(params, Nx.iota({2, 8}, type: :f32)) x1: #Nx.Tensor< f32[2][32] [ - [3.9104199409484863, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.051666498184204, 0.0, 1.086042881011963, 0.6107193827629089, 0.5136545896530151, 2.7927842140197754, 0.0, 0.0, 0.0, 0.0, 2.472961902618408, 0.13712915778160095, 0.49807000160217285, 1.7868735790252686, 5.796293258666992, 0.0, 0.0, 4.727283477783203, 0.0, 0.0, 2.129516363143921, 0.0, 0.0], - [11.746908187866211, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.840534687042236, 0.0, 0.0, 4.103122711181641, 1.0597835779190063, ...] + [1.7373675107955933, 0.0, 5.150482177734375, 0.544252336025238, 0.275376558303833, 0.0, 0.0, 0.0, 0.0, 1.7849855422973633, 0.7857151031494141, 0.2273893654346466, 0.2701767086982727, 2.321484327316284, 2.685051441192627, 0.0, 2.547382116317749, 0.0, 0.0, 0.0, 0.722919225692749, 2.3600289821624756, 1.4695687294006348, 0.0, 0.0, 0.0, 1.0015852451324463, 1.2762010097503662, 0.0, 0.07927703857421875, 0.0, 0.6216219663619995], + [4.996878623962402, 0.0, 14.212154388427734, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.517582356929779, 0.0, 2.036062479019165, 2.907236337661743, 8.515787124633789, ...] ] >, x2: {#Nx.Tensor< f32[2][64] [ - [0.951026439666748, 0.0, 0.6895619034767151, 0.12973949313163757, 3.0561492443084717, 0.0, 0.21812109649181366, 0.0, 0.6377829313278198, 0.0, 0.0, 0.0, 0.0, 1.6837494373321533, 0.0, 0.0, 0.0, 1.3907173871994019, 0.0, 0.0, 0.21352148056030273, 0.0, 1.2145031690597534, 0.0, 3.080430507659912, 0.0, 3.9572620391845703, 2.3347463607788086, 0.5280991196632385, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.616438627243042, 0.0, 1.1335082054138184, 2.228783369064331, 0.0, 1.0927692651748657, 0.0, 0.0, 0.0, ...], + [1.2057430744171143, 0.0, 0.0, 0.8717040419578552, 1.7653638124465942, 0.0, 0.0, 0.0, 0.0, 0.9921279549598694, 0.0, 1.0860291719436646, 2.3648557662963867, 0.0, 0.0, 2.0518181324005127, 1.6323723793029785, 0.9113610982894897, 1.6805293560028076, 0.8101096749305725, 0.0, 0.0, 0.0, 2.2150073051452637, 0.0, 0.0, 0.0, 0.0, 0.0, 2.2320713996887207, 0.0, 2.553570508956909, 0.28632092475891113, 0.0, 0.0, 0.020383253693580627, 0.0, 0.2926883101463318, 1.3561311960220337, 0.8884503245353699, 3.1455295085906982, 0.0, 0.0, 1.237722635269165, ...], ... ] >} diff --git a/guides/model_creation/sequential_models.livemd b/guides/model_creation/sequential_models.livemd index e73505ca..c2931fa1 100644 --- a/guides/model_creation/sequential_models.livemd +++ b/guides/model_creation/sequential_models.livemd @@ -1,12 +1,9 @@ - - # Sequential models ```elixir Mix.install([ - {:axon, github: "elixir-nx/axon"}, - {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true}, - {:kino, "~> 0.7.0"} + {:axon, ">= 0.5.0"}, + {:kino, ">= 0.9.0"} ]) ``` @@ -67,16 +64,16 @@ Axon.Display.as_graph(model, template) ```mermaid graph TD; 3[/"data (:input) {2, 16}"/]; -6["dense_0 (:dense) {2, 32}"]; -7["relu_0 (:relu) {2, 32}"]; -8["dropout_0 (:dropout) {2, 32}"]; -11["dense_1 (:dense) {2, 1}"]; -12["softmax_0 (:softmax) {2, 1}"]; -11 --> 12; -8 --> 11; +4["dense_0 (:dense) {2, 32}"]; +5["relu_0 (:relu) {2, 32}"]; +6["dropout_0 (:dropout) {2, 32}"]; +7["dense_1 (:dense) {2, 1}"]; +8["softmax_0 (:softmax) {2, 1}"]; 7 --> 8; 6 --> 7; -3 --> 6; +5 --> 6; +4 --> 5; +3 --> 4; ``` Your model is more involved and as a result so is the execution graph! Now, using the same constructs from the last section, you can build and run your model: @@ -88,8 +85,8 @@ Your model is more involved and as a result so is the execution graph! 
Now, usin ``` -{#Function<137.55749718/2 in Nx.Defn.wrap_arity/2>, - #Function<137.55749718/2 in Nx.Defn.wrap_arity/2>} +{#Function<135.109794929/2 in Nx.Defn.Compiler.fun/2>, + #Function<135.109794929/2 in Nx.Defn.Compiler.fun/2>} ``` ```elixir @@ -108,8 +105,8 @@ params = init_fn.(template, %{}) "kernel" => #Nx.Tensor< f32[16][32] [ - [-0.25727564096450806, -0.31299564242362976, -0.1557893306016922, -0.3321501314640045, 0.34875044226646423, 0.15635445713996887, 0.25805917382240295, 0.316285640001297, 0.29047688841819763, -0.09108144044876099, 0.2781231701374054, 0.21326711773872375, -0.29581472277641296, -0.3105146288871765, -0.11265464127063751, 0.054490894079208374, -0.22294805943965912, 0.23276928067207336, 0.06426036357879639, 0.12059605121612549, -0.24530324339866638, 0.061366915702819824, 0.17463091015815735, -0.2774006724357605, 0.2621242105960846, 0.19262376427650452, -0.10884760320186615, -0.3156566321849823, 0.104307621717453, -0.22591334581375122, -0.09672778844833374, -0.18450938165187836], - [-0.32328563928604126, -0.3434811234474182, -0.3464450538158417, 0.14756330847740173, 0.010595977306365967, 0.32808688282966614, -0.3048470616340637, 0.011142522096633911, 0.10394474864006042, 0.04501914978027344, -0.26296690106391907, -0.1051199734210968, -0.0060880184173583984, 0.22103646397590637, -0.3040429651737213, ...], + [0.21433714032173157, -0.04525795578956604, 0.32405969500541687, -0.06933712959289551, -0.24735209345817566, 0.1957167088985443, -0.2714379131793976, -0.34026962518692017, 0.03781759738922119, -0.16317953169345856, -0.1272507756948471, -0.08459293842315674, 0.20401403307914734, 0.26613888144493103, -0.3234696388244629, 0.295791357755661, 0.29850414395332336, -0.22220905125141144, -0.33034151792526245, 0.32582345604896545, -0.19104702770709991, -0.3434463143348694, 0.031930625438690186, 0.32875487208366394, 0.17335721850395203, -0.0336279571056366, -0.02203202247619629, -0.30805233120918274, 0.01472097635269165, 0.293319970369339, 0.17995354533195496, 0.09916016459465027], + [-0.33202630281448364, -0.09507006406784058, -0.12178492546081543, -0.005500674247741699, -0.24997547268867493, 0.31693217158317566, 0.31857630610466003, 0.13662374019622803, 0.11216515302658081, -0.2711845338344574, -0.18932600319385529, -0.10278302431106567, -0.1910824328660965, -0.15239068865776062, 0.2373746931552887, ...], ... 
] > @@ -122,38 +119,38 @@ params = init_fn.(template, %{}) "kernel" => #Nx.Tensor< f32[32][1] [ - [-0.379288911819458], - [-0.05532142519950867], - [-0.07836392521858215], - [0.41381680965423584], - [0.33221137523651123], - [0.23515504598617554], - [-0.40667685866355896], - [-0.3503745198249817], - [0.2631032466888428], - [-0.13176566362380981], - [-0.3811171054840088], - [0.24656128883361816], - [0.17257028818130493], - [0.3528350591659546], - [0.4112042784690857], - [0.056196123361587524], - [0.138421893119812], - [-0.38378745317459106], - [-0.044070273637771606], - [0.11507803201675415], - [-0.3125251233577728], - [-0.11389034986495972], - [-0.27444711327552795], - [-0.30974721908569336], - [-0.3695589303970337], - [0.3146793246269226], - [0.005854517221450806], - [-0.03735968470573425], - [0.02763468027114868], - [-0.10707724094390869], - [0.10824829339981079], - [0.29013824462890625] + [-0.22355356812477112], + [0.09599864482879639], + [0.06676572561264038], + [-0.06866732239723206], + [0.1822824478149414], + [0.1860904097557068], + [-0.3795042335987091], + [-0.18182222545146942], + [0.4170041084289551], + [0.1812545657157898], + [0.18777817487716675], + [-0.15454193949699402], + [0.16937363147735596], + [-0.007449895143508911], + [0.421792209148407], + [-0.3314356803894043], + [-0.29834187030792236], + [0.3285354971885681], + [0.034806013107299805], + [0.1091541051864624], + [-0.385672390460968], + [0.004853636026382446], + [0.3387643098831177], + [0.03320261836051941], + [0.3905656933784485], + [-0.3835979700088501], + [-0.06302008032798767], + [0.03648516535758972], + [0.24170255661010742], + [0.01687285304069519], + [-0.017035305500030518], + [-0.2674438953399658] ] > } diff --git a/guides/model_creation/your_first_axon_model.livemd b/guides/model_creation/your_first_axon_model.livemd index 88c83def..ea340843 100644 --- a/guides/model_creation/your_first_axon_model.livemd +++ b/guides/model_creation/your_first_axon_model.livemd @@ -1,12 +1,9 @@ - - # Your first Axon model ```elixir Mix.install([ - {:axon, github: "elixir-nx/axon"}, - {:nx, github: "elixir-nx/nx", sparse: "nx", override: true}, - {:kino, "~> 0.7.0"} + {:axon, ">= 0.5.0"}, + {:kino, ">= 0.9.0"} ]) ``` @@ -64,8 +61,8 @@ You can see this in action by actually executing your model. You can build the ` ``` -{#Function<137.55749718/2 in Nx.Defn.wrap_arity/2>, - #Function<137.55749718/2 in Nx.Defn.wrap_arity/2>} +{#Function<135.109794929/2 in Nx.Defn.Compiler.fun/2>, + #Function<135.109794929/2 in Nx.Defn.Compiler.fun/2>} ``` Notice that `Axon.build/2` returns a tuple of `{init_fn, predict_fn}`. 
`init_fn` has the signature: diff --git a/guides/model_execution/accelerating_axon.livemd b/guides/model_execution/accelerating_axon.livemd index 5a13acbd..c0f3f040 100644 --- a/guides/model_execution/accelerating_axon.livemd +++ b/guides/model_execution/accelerating_axon.livemd @@ -1,16 +1,12 @@ - - # Accelerating Axon ```elixir Mix.install([ - {:axon, github: "elixir-nx/axon"}, - {:exla, "~> 0.3.0", github: "elixir-nx/nx", sparse: "exla"}, - {:torchx, github: "elixir-nx/nx", sparse: "torchx"}, - {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true}, - {:benchee, github: "akoutmos/benchee", branch: :adding_table_support}, - {:kino_benchee, github: "livebook-dev/kino_benchee"}, - {:kino, "~> 0.7.0", override: true} + {:axon, ">= 0.5.0"}, + {:exla, ">= 0.5.0"}, + {:torchx, ">= 0.5.0"}, + {:benchee, "~> 1.1"}, + {:kino, ">= 0.9.0", override: true} ]) ``` @@ -20,16 +16,9 @@ Mix.install([ :ok ``` -## Using Nx Compilers in Axon - -Axon is built entirely on top of Nx's numerical definitions `defn`. Functions declared with `defn` tell Nx to use *just-in-time compilation* to compile and execute the given numerical definition with an available Nx compiler. Numerical definitions enable acceleration on CPU/GPU/TPU via pluggable compilers. At the time of this writing, Nx has 2 officially supported compiler/backends on top of the default `BinaryBackend`: - -1. EXLA - Acceleration via Google's [XLA project](https://www.tensorflow.org/xla) -2. TorchX - Bindings to [LibTorch](https://pytorch.org/cppdocs/) - -By default, Nx and Axon run all computations using the `BinaryBackend` which is a pure Elixir implementation of various numerical routines. The `BinaryBackend` is guaranteed to run wherever an Elixir installation runs; however, it is **very** slow. Due to the computational expense of neural networks, you should basically never use the `BinaryBackend` and instead opt for one of the available accelerated libraries. +## Using Nx Backends in Axon -There are several ways to make use of Nx compilers from within Axon. First, create a simple model for benchmarking purposes: +Nx provides two mechanisms for accelerating your neural networks: backends and compilers. Before we learn how to effectively use them, first let's create a simple model for benchmarking purposes: ```elixir model = @@ -50,25 +39,32 @@ model = > ``` -By default, Axon will respect the default `defn` compilation options. You can set compilation options globally or per-process: +Backends are where your tensors (your neural network inputs and parameters) are located. By default, Nx and Axon run all computations using the `Nx.BinaryBackend` which is a pure Elixir implementation of various numerical routines. The `Nx.BinaryBackend` is guaranteed to run wherever an Elixir installation runs; however, it is **very** slow. Due to the computational expense of neural networks, you should basically never use the `Nx.BinaryBackend` and instead opt for one of the available accelerated libraries. At the time of writing, Nx officially supports two of them: -```elixir -# Sets the global compilation options -Nx.Defn.global_default_options(compiler: EXLA) -# Sets the process-level compilation options -Nx.Defn.default_options(compiler: EXLA) -``` +1. EXLA - Acceleration via Google's [XLA project](https://www.tensorflow.org/xla) +2. TorchX - Bindings to [LibTorch](https://pytorch.org/cppdocs/) - +Axon will respect the global and process-level Nx backend configuration. Compilers are covered more in-depth in the second half of this example. 
You can set the default backend using the following APIs: -``` -[compiler: EXLA] +```elixir +# Sets the global compilation options (for all Elixir processes) +Nx.global_default_backend(Torchx.Backend) +# OR +Nx.global_default_backend(EXLA.Backend) + +# Sets the process-level compilation options (current process only) +Nx.default_backend(Torchx.Backend) +# OR +Nx.default_backend(EXLA.Backend) ``` -When you call `Axon.build/2`, Axon automatically marks your initialization and forward functions as JIT compiled functions. When you invoke them, they will compile a specialized version of the function using your default compiler options: +Now all tensors and operations on them will run on the configured backend: ```elixir -inputs = Nx.random_uniform({2, 128}) +{inputs, _next_key} = + Nx.Random.key(9999) + |> Nx.Random.uniform(shape: {2, 128}) + {init_fn, predict_fn} = Axon.build(model) params = init_fn.(inputs, %{}) predict_fn.(params, inputs) @@ -76,20 +72,10 @@ predict_fn.(params, inputs) -``` - -10:34:02.503 [info] XLA service 0x7fbd5468c170 initialized for platform Host (this does not guarantee that XLA will be used). Devices: - -10:34:02.785 [info] StreamExecutor device (0): Host, Default Version - -``` - - - ``` #Nx.Tensor< + EXLA.Backend f32[2][1] - EXLA.Backend [ [1.0], [1.0] @@ -97,25 +83,27 @@ predict_fn.(params, inputs) > ``` -Notice that the inspected tensor indicates the computation has been dispatched to EXLA and the tensor's data points to an EXLA buffer. +As you swap backends above, you will get tensors allocated on different backends as results. You should be careful using multiple backends in the same project as attempting to mix tensors between backends may result in strange performance bugs or errors, as Nx will require you to explicitly convert between backends. - +With most larger models, using a compiler will bring more performance benefits in addition to the backend. -If you feel like setting the global or process-level compilation options is too intrusive, you can opt for more explicit behavior in a few ways. First, you can specify the JIT compiler when you build the model: +## Using Nx Compilers in Axon -```elixir -# Set back to defaults -Nx.Defn.global_default_options([]) -Nx.Defn.default_options([]) -``` +Axon is built entirely on top of Nx's numerical definitions `defn`. Functions declared with `defn` tell Nx to use *just-in-time compilation* to compile and execute the given numerical definition with an available Nx compiler. Numerical definitions enable acceleration on CPU/GPU/TPU via pluggable compilers. At the time of this writing, only EXLA supports a compiler in addition to its backend. - +When you call `Axon.build/2`, Axon can automatically mark your initialization and forward functions as JIT compiled functions. 
First let's make sure we are using the EXLA backend: +```elixir +Nx.default_backend(EXLA.Backend) ``` -[compiler: EXLA] -``` + +And now let's build another model, this time passing the EXLA compiler as an option: ```elixir +{inputs, _next_key} = + Nx.Random.key(9999) + |> Nx.Random.uniform(shape: {2, 128}) + {init_fn, predict_fn} = Axon.build(model, compiler: EXLA) params = init_fn.(inputs, %{}) predict_fn.(params, inputs) @@ -123,10 +111,28 @@ predict_fn.(params, inputs) +``` + +15:39:26.463 [info] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero + +15:39:26.473 [info] XLA service 0x7f3488329030 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices: + +15:39:26.473 [info] StreamExecutor device (0): NVIDIA GeForce RTX 3050 Ti Laptop GPU, Compute Capability 8.6 + +15:39:26.473 [info] Using BFC allocator. + +15:39:26.473 [info] XLA backend allocating 3605004288 bytes on device 0 for BFCAllocator. + +15:39:28.272 [info] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once. + +``` + + + ``` #Nx.Tensor< f32[2][1] - EXLA.Backend + EXLA.Backend [ [1.0], [1.0] @@ -147,7 +153,7 @@ exla_predict_fn = EXLA.jit(predict_fn) ``` -#Function<136.40088443/2 in Nx.Defn.wrap_arity/2> +#Function<135.109794929/2 in Nx.Defn.Compiler.fun/2> ``` ```elixir @@ -272,88 +278,12 @@ elixir predict 91.09 KB - 8.32x memory usage +80.14 KB **All measurements for memory usage were the same** ``` -Notice how calls to EXLA variants are significantly faster than their Elixir counterparts. These speedups become more pronounced with more complex models and workflows. +Notice how calls to EXLA variants are significantly faster. These speedups become more pronounced with more complex models and workflows. It's important to note that in order to use a given library as an Nx compiler, it must implement the Nx compilation behaviour. For example, you cannot invoke Torchx as an Nx compiler because it does not support JIT compilation at this time. -## Using Nx Backends in Axon - -In addition to JIT-compilation, Axon also supports the usage of Nx backends. Nx backends are slightly different than Nx compilers in the sense that they do not fuse calls within numerical definitions. Backends are more eager, sacrificing a bit of performance for convenience. Torchx and EXLA both support running via backends. - -Again, Axon will respect the global and process-level Nx backend configuration options. You can set the default backend using: - -```elixir -# Global default backend -Nx.global_default_backend(Torchx.Backend) -# Process default backend -Nx.default_backend(Torchx.Backend) -``` - - - -``` -{Nx.BinaryBackend, []} -``` - -Now when you invoke model functions, it will run them with the given backend: - -```elixir -{init_fn, predict_fn} = Axon.build(model) -params = init_fn.(inputs, %{}) -predict_fn.(params, inputs) -``` - - - -``` -#Nx.Tensor< - f32[2][1] - Torchx.Backend(cpu) - [ - [1.0], - [1.0] - ] -> -``` - -```elixir -# Global default backend -Nx.global_default_backend(EXLA.Backend) -# Process default backend -Nx.default_backend(EXLA.Backend) -``` - - - -``` -{Torchx.Backend, []} -``` - -```elixir -{init_fn, predict_fn} = Axon.build(model) -params = init_fn.(inputs, %{}) -predict_fn.(params, inputs) -``` - - - -``` -#Nx.Tensor< - f32[2][1] - EXLA.Backend - [ - [1.0], - [1.0] - ] -> -``` - -Unlike with JIT-compilation, you must set the backend at the top-level in order to invoke it. 
You should be careful using multiple backends in the same project as attempting to mix tensors between backends may result in strange performance bugs or errors. - -With most larger models, using a JIT compiler will be more performant than using a backend. - ## A Note on CPUs/GPUs/TPUs While Nx mostly tries to standardize behavior across compilers and backends, some behaviors are backend-specific. For example, the API for choosing an acceleration platform (e.g. CUDA/ROCm/TPU) is backend-specific. You should refer to your chosen compiler or backend's documentation for information on targeting various accelerators. Typically, you only need to change a few configuration options and your code will run as-is on a chosen accelerator. diff --git a/guides/model_execution/training_and_inference_mode.livemd b/guides/model_execution/training_and_inference_mode.livemd index b5190506..74e6de49 100644 --- a/guides/model_execution/training_and_inference_mode.livemd +++ b/guides/model_execution/training_and_inference_mode.livemd @@ -1,11 +1,8 @@ - - # Training and inference mode ```elixir Mix.install([ - {:axon, github: "elixir-nx/axon"}, - {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true} + {:axon, ">= 0.5.0"} ]) ``` @@ -40,8 +37,8 @@ predict_fn.(params, inputs) #Nx.Tensor< f32[2][1] [ - [-0.6138466000556946], - [-0.8409845232963562] + [0.6900148391723633], + [1.1159517765045166] ] > ``` @@ -60,8 +57,8 @@ predict_fn.(params, inputs) #Nx.Tensor< f32[2][1] [ - [0.7551136016845703], - [0.448221355676651] + [-1.1250841617584229], + [-1.161189317703247] ] > ``` @@ -89,7 +86,14 @@ predict_fn.(params, inputs) [0.0] ] >, - state: %{} + state: %{ + "dropout_0" => %{ + "key" => #Nx.Tensor< + u32[2] + [309162766, 2699730300] + > + } + } } ``` @@ -115,19 +119,19 @@ predict_fn.(params, inputs) prediction: #Nx.Tensor< f32[2][1] [ - [0.03675001487135887], - [-0.03674999624490738] + [0.4891311526298523], + [-0.4891311228275299] ] >, state: %{ "batch_norm_0" => %{ "mean" => #Nx.Tensor< f32[4] - [0.8784151673316956, 0.7386987209320068, 0.663623571395874, 0.8947045803070068] + [0.525083601474762, 0.8689039349555969, 0.03931800276041031, 0.0021854371298104525] >, "var" => #Nx.Tensor< f32[4] - [0.10050597041845322, 0.11294332146644592, 0.16061438620090485, 0.10003116726875305] + [0.13831248879432678, 0.10107331722974777, 0.10170891880989075, 0.10000484436750412] > } } diff --git a/guides/serialization/onnx_to_axon.livemd b/guides/serialization/onnx_to_axon.livemd index 64c7a51b..6c18af36 100644 --- a/guides/serialization/onnx_to_axon.livemd +++ b/guides/serialization/onnx_to_axon.livemd @@ -3,31 +3,34 @@ ```elixir Mix.install( [ - {:nx, "~> 0.3"}, - {:axon, "~> 0.2"}, - {:exla, "~> 0.3"}, - {:axon_onnx, "~> 0.2"}, - {:stb_image, "~> 0.5"}, - {:kino, "~> 0.7.0"} - ], - # change to "cuda111" for Nvidia GPU - system_env: %{"XLA_TARGET" => xla_target} + {:axon, ">= 0.5.0"}, + {:exla, ">= 0.5.0"}, + {:axon_onnx, ">= 0.4.0"}, + {:stb_image, ">= 0.6.0"}, + {:kino, ">= 0.9.0"}, + {:req, ">= 0.3.8"} + ] + # for Nvidia GPU change to "cuda111" for CUDA 11.1+ or "cuda118" for CUDA 11.8 + # CUDA 12.x not supported by XLA + # or you can put this value in ENV variables in Livebook settings + # XLA_TARGET=cuda111 + # system_env: %{"XLA_TARGET" => xla_target} ) ``` ## Converting an ONNX model into Axon -Axon is a new machine learning capability, specific to Elixir. We would like to take +Axon is a new machine learning capability, specific to Elixir. 
We would like to take advantage of a large amount of models that have been written in other languages and -machine learning frameworks. Let's take a look at how we could use a model developed +machine learning frameworks. Let's take a look at how we could use a model developed in another language. Converting models developed by data scientists into a production capable implementation is a -challenge for all languages and frameworks. [ONNX](https://onnx.ai/) is an interchange +challenge for all languages and frameworks. [ONNX](https://onnx.ai/) is an interchange format that allows models written in one language or framework to be converted into another language and framework. -The source model must use constructs mapped into ONNX. Also, the destination framework must +The source model must use constructs mapped into ONNX. Also, the destination framework must support the model's ONNX constructs. From an Elixir focus, we are interested in ONNX models that [axon_onnx](https://github.com/elixir-nx/axon_onnx) can convert into Axon models. @@ -38,7 +41,7 @@ that [axon_onnx](https://github.com/elixir-nx/axon_onnx) can convert into Axon m Elixir can get access to thousands of public models and your organization may have private models -written in other languages and frameworks. Axon will be hard pressed to quickly repeat the +written in other languages and frameworks. Axon will be hard pressed to quickly repeat the countless person-hours spent on developing models in other languages like Tensorflow and PyTorch. However, if the model can be converted into ONNX and then into Axon, we can directly run the model in Elixir. @@ -50,22 +53,22 @@ in Elixir. Axon runs on top of [Nx (Numerical Elixir)](https://hexdocs.pm/nx). Nx has backends for -both Google's XLA (via EXLA) and PyTorch (via Torchx). In this guide, we will use EXLA. -We'll also convert from an ONNX model into an Axon model using +both Google's XLA (via EXLA) and PyTorch (via Torchx). In this guide, we will use EXLA. +We'll also convert from an ONNX model into an Axon model using [`axon_onnx`](https://github.com/elixir-nx/axon_onnx). You can find all dependencies in the installation cell at the top of the notebook. -In there, you will also find the `XLA_TARGET` environment variable whick you can set -to "cuda111" or "rocm" if you have any of those GPUs available. Let's also configure +In there, you will also find the `XLA_TARGET` environment variable which you can set +to "cuda111" or "rocm" if you have any of those GPUs available. Let's also configure Nx to store tensors in EXLA by default: ```elixir -Nx.default_backend(EXLA.Backend) +# Nx.default_backend(EXLA.Backend) ``` -We'll also need local access to ONNX files. For this notebook, the models/onnx folder -contains the ONNX model file. This notebook assumes the output file location will be -in models axon. Copy your ONNX model files into the models/onnx folder. +We'll also need local access to ONNX files. For this notebook, the models/onnx folder +contains the ONNX model file. This notebook assumes the output file location will be +in models axon. Copy your ONNX model files into the models/onnx folder. This opinionated module presents a simple API for loading in an ONNX file and saving the converted Axon model in the provided directory. 
This API will allow us to @@ -82,7 +85,7 @@ defmodule OnnxToAxon do ## Examples - iex> OnnxToAxon.onnx_axon(path_to_onnx_file, path_to_axon_dir) + OnnxToAxon.onnx_axon(path_to_onnx_file, path_to_axon_dir) """ def onnx_axon(path_to_onnx_file, path_to_axon_dir) do @@ -112,12 +115,12 @@ The ONNX models were trained in Fast.ai (PyTorch) using the following notebooks: * https://github.com/meanderingstream/fastai_course22/blob/main/saving-a-basic-fastai-model-in-onnx.ipynb * https://github.com/meanderingstream/fastai_course22/blob/main/saving-cat-dog-breed-fastai-model-in-onnx.ipynb -To repeat this notebook, the onnx files for this notebook can be found on huggingface hub. Download the onnx models from: +To repeat this notebook, the onnx files for this notebook can be found on huggingface hub. Download the onnx models from: * https://huggingface.co/ScottMueller/Cats_v_Dogs.ONNX * https://huggingface.co/ScottMueller/Cat_Dog_Breeds.ONNX -Download the files and place them in a directory of your choice. By default, we will assume you downloaded them to the same directory as the notebook: +Download the files and place them in a directory of your choice. By default, we will assume you downloaded them to the same directory as the notebook: ```elixir File.cd!(__DIR__) @@ -126,46 +129,46 @@ File.cd!(__DIR__) Now let's convert an ONNX model into Axon ```elixir -path_to_onnx_file = "models/onnx/cats_v_dogs.onnx" -path_to_axon_dir = "models/axon" +path_to_onnx_file = "cats_v_dogs.onnx" +path_to_axon_dir = "." OnnxToAxon.onnx_axon(path_to_onnx_file, path_to_axon_dir) ``` ```elixir -path_to_onnx_file = "models/onnx/cat_dog_breeds.onnx" -path_to_axon_dir = "models/axon" +path_to_onnx_file = "cat_dog_breeds.onnx" +path_to_axon_dir = "." OnnxToAxon.onnx_axon(path_to_onnx_file, path_to_axon_dir) ``` ## Inference on ONNX derived models -To run inference on the model, you'll need 10 images focused on cats or dogs. You can download the images used in training the model at: +To run inference on the model, you'll need 10 images focused on cats or dogs. You can download the images used in training the model at: "https://s3.amazonaws.com/fast-ai-imageclas/oxford-iiit-pet.tgz" -Or you can find or use your own images. In this notebook, we are going to use the local copies of the Oxford Pets dataset that was used in training the model. +Or you can find or use your own images. In this notebook, we are going to use the local copies of the Oxford Pets dataset that was used in training the model. Let's load the Axon model. ```elixir -cats_v_dogs = File.read!("models/axon/cats_v_dogs.axon") +cats_v_dogs = File.read!("cats_v_dogs.axon") {cats_v_dogs_model, cats_v_dogs_params} = Axon.deserialize(cats_v_dogs) ``` -We need a tensor representation of an image. Let's start by looking at samples of +We need a tensor representation of an image. Let's start by looking at samples of our data. 
```elixir -File.read!("data/oxford-iiit-pet/images/havanese_71.jpg") +File.read!("oxford-iiit-pet/images/havanese_71.jpg") |> Kino.Image.new(:jpeg) ``` To manipulate the images, we will use the `StbImage` library: ```elixir -{:ok, img} = StbImage.read_file("data/oxford-iiit-pet/images/havanese_71.jpg") +{:ok, img} = StbImage.read_file("oxford-iiit-pet/images/havanese_71.jpg") %StbImage{data: binary, shape: shape, type: type} = StbImage.resize(img, 224, 224) ``` @@ -191,7 +194,7 @@ Next we resize the images: ```elixir resized_images = Enum.map(file_names, fn file_name -> - ("data/oxford-iiit-pet/images/" <> file_name) + ("oxford-iiit-pet/images/" <> file_name) |> IO.inspect(label: file_name) |> StbImage.read_file!() |> StbImage.resize(224, 224) @@ -219,14 +222,14 @@ defmodule Predictions do ## Examples - iex> Predictions.sindle_label_prediction(path_to_onnx_file, path_to_axon_dir) - ["dog", "cat", "dog"] + # iex> Predictions.sindle_label_prediction(path_to_onnx_file, path_to_axon_dir) + # ["dog", "cat", "dog"] """ def single_label_classification(predictions_batch, vocabulary) do IO.inspect(Nx.shape(predictions_batch), label: "predictions batch shape") - for prediction_tensor <- Nx.to_batched(predictions_batch) do + for prediction_tensor <- Nx.to_batched(predictions_batch, 1) do {_prediction_value, prediction_label} = prediction_tensor |> Nx.to_flat_list() @@ -308,7 +311,7 @@ cat_dog_vocabulary = [ ``` ```elixir -cat_dog_breeds = File.read!("models/axon/cat_dog_breeds.axon") +cat_dog_breeds = File.read!("cat_dog_breeds.axon") {cat_dog_breeds_model, cat_dog_breeds_params} = Axon.deserialize(cat_dog_breeds) ``` diff --git a/guides/training_and_evaluation/custom_models_loss_optimizers.livemd b/guides/training_and_evaluation/custom_models_loss_optimizers.livemd index f834be0e..385c56ad 100644 --- a/guides/training_and_evaluation/custom_models_loss_optimizers.livemd +++ b/guides/training_and_evaluation/custom_models_loss_optimizers.livemd @@ -343,11 +343,11 @@ Epoch: 0, Batch: 1000, loss: 0.1098235 ## Using custom optimizers in training loops -As you might expect, it's also possible to customize the optimizer passed to `Axon.Loop.trainer/3`. If you read the `Axon.Updates` documentation, you'll learn that optimizers are actually represented as the tuple `{init_fn, update_fn}` where `init_fn` initializes optimizer state from model state and `update_fn` scales gradients from optimizer state, gradients, and model state. +As you might expect, it's also possible to customize the optimizer passed to `Axon.Loop.trainer/3`. If you read the `Polaris.Updates` documentation, you'll learn that optimizers are actually represented as the tuple `{init_fn, update_fn}` where `init_fn` initializes optimizer state from model state and `update_fn` scales gradients from optimizer state, gradients, and model state. You likely won't have to implement a custom optimizer; however, you should know how to construct optimizers with different hyperparameters and how to apply different modifiers to different optimizers to customize the optimization process. -When you specify an optimizer as an atom in `Axon.Loop.trainer/3`, it maps directly to an optimizer declared in `Axon.Optimizers`. You can instead opt to declare your optimizer directly. This is most useful for controlling things like the learning rate and various optimizer hyperparameters: +When you specify an optimizer as an atom in `Axon.Loop.trainer/3`, it maps directly to an optimizer declared in `Polaris.Optimizers`. 
You can instead opt to declare your optimizer directly. This is most useful for controlling things like the learning rate and various optimizer hyperparameters: ```elixir train_data = @@ -365,7 +365,7 @@ model = |> Axon.relu() |> Axon.dense(1) -optimizer = {_init_optimizer_fn, _update_fn} = Axon.Optimizers.sgd(1.0e-3) +optimizer = {_init_optimizer_fn, _update_fn} = Polaris.Optimizers.sgd(learning_rate: 1.0e-3) model |> Axon.Loop.trainer(:mean_squared_error, optimizer) diff --git a/guides/training_and_evaluation/instrumenting_loops_with_metrics.livemd b/guides/training_and_evaluation/instrumenting_loops_with_metrics.livemd index 5f3db39f..a9a74cde 100644 --- a/guides/training_and_evaluation/instrumenting_loops_with_metrics.livemd +++ b/guides/training_and_evaluation/instrumenting_loops_with_metrics.livemd @@ -1,11 +1,8 @@ - - # Instrumenting loops with metrics ```elixir Mix.install([ - {:axon, github: "elixir-nx/axon"}, - {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true} + {:axon, ">= 0.5.0"} ]) ``` @@ -38,28 +35,28 @@ loop = ``` #Axon.Loop< + metrics: %{ + "loss" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>, + #Function<9.37390314/2 in Axon.Loop.build_loss_fn/1>}, + "mean_absolute_error" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>, + :mean_absolute_error} + }, handlers: %{ completed: [], epoch_completed: [ - {#Function<23.20267452/1 in Axon.Loop.log/5>, - #Function<5.20267452/1 in Axon.Loop.build_filter_fn/1>} + {#Function<27.37390314/1 in Axon.Loop.log/3>, + #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>} ], epoch_halted: [], epoch_started: [], halted: [], iteration_completed: [ - {#Function<23.20267452/1 in Axon.Loop.log/5>, - #Function<3.20267452/1 in Axon.Loop.build_filter_fn/1>} + {#Function<27.37390314/1 in Axon.Loop.log/3>, + #Function<64.37390314/2 in Axon.Loop.build_filter_fn/1>} ], iteration_started: [], started: [] }, - metrics: %{ - "loss" => {#Function<12.6031754/3 in Axon.Metrics.running_average/1>, - #Function<6.20267452/2 in Axon.Loop.build_loss_fn/1>}, - "mean_absolute_error" => {#Function<12.6031754/3 in Axon.Metrics.running_average/1>, - :mean_absolute_error} - }, ... 
> ``` @@ -71,7 +68,11 @@ When you run a loop with metrics, Axon will aggregate that metric over the cours ```elixir train_data = Stream.repeatedly(fn -> - xs = Nx.random_normal({8, 1}) + {xs, _next_key} = + :random.uniform(9999) + |> Nx.Random.key() + |> Nx.Random.normal(shape: {8, 1}) + ys = Nx.sin(xs) {xs, ys} end) @@ -82,7 +83,7 @@ Axon.Loop.run(loop, train_data, %{}, iterations: 1000) ``` -Epoch: 0, Batch: 1000, loss: 0.0646209 mean_absolute_error: 0.1720028 +Epoch: 0, Batch: 950, loss: 0.0590630 mean_absolute_error: 0.1463431 ``` @@ -92,46 +93,46 @@ Epoch: 0, Batch: 1000, loss: 0.0646209 mean_absolute_error: 0.1720028 "dense_0" => %{ "bias" => #Nx.Tensor< f32[8] - [-0.2462722808122635, 0.18984302878379822, 0.0016971784643828869, 0.19568635523319244, 0.33571094274520874, 0.07703055441379547, 0.29576605558395386, 0.14511419832706451] + [-0.015203186310827732, 0.1997198462486267, 0.09740892797708511, -0.007404750678688288, 0.11397464573383331, 0.3608400523662567, 0.07219560444355011, -0.06638865917921066] >, "kernel" => #Nx.Tensor< f32[1][8] [ - [-0.7807592749595642, -0.17303702235221863, 0.43004679679870605, -0.46043306589126587, -0.6577866077423096, 0.7490359544754028, -0.5164405703544617, -0.77418452501297] + [0.07889414578676224, 0.30445051193237305, 0.1377921849489212, 0.015571207739412785, 0.7115736603736877, -0.6404237151145935, 0.25553327798843384, 0.057831913232803345] ] > }, "dense_1" => %{ "bias" => #Nx.Tensor< f32[4] - [0.027583779767155647, 0.4279942214488983, -0.10632428526878357, -0.05149337649345398] + [0.10809992998838425, 0.0, 0.47775307297706604, -0.1641010195016861] >, "kernel" => #Nx.Tensor< f32[8][4] [ - [-0.5688502192497253, -0.49978527426719666, 0.0660838857293129, 0.30804139375686646], - [0.21578946709632874, 0.4183472990989685, 0.530754566192627, 0.1742597073316574], - [-0.17872463166713715, -0.08955764025449753, -0.7048909664154053, 0.053243234753608704], - [-0.41064000129699707, 0.3491946756839752, 0.3753710091114044, 0.6630277037620544], - [-0.1781950145959854, 0.5766432881355286, 0.5829672813415527, -0.34879636764526367], - [-0.026939965784549713, -0.44429031014442444, -0.12619371712207794, 0.0030224998481571674], - [0.411702424287796, 0.3330642879009247, -0.5062007308006287, -0.0731467455625534], - [-0.41474586725234985, 0.23881299793720245, 0.3847745358943939, -0.5769480466842651] + [-0.040330830961465836, -0.36995524168014526, 0.001599793671630323, 0.6012424826622009], + [0.21044284105300903, -0.39482879638671875, -0.5866784453392029, 0.15573620796203613], + [-0.09234675765037537, 0.27758270502090454, -0.6663768291473389, 0.6017312407493591], + [-0.4454570412635803, 0.1304328441619873, -0.31381309032440186, 0.1906844824552536], + [0.3460652530193329, -0.3017694056034088, -0.1680794507265091, -0.47811293601989746], + [0.28633055090904236, -0.34003201127052307, 0.6202688813209534, 0.18027405440807343], + [0.5729941129684448, 0.32222074270248413, 0.20647864043712616, 0.02462891861796379], + [-0.13146185874938965, -0.06700503826141357, 0.6600251793861389, -0.06442582607269287] ] > }, "dense_2" => %{ "bias" => #Nx.Tensor< f32[1] - [0.8004998564720154] + [0.4863035976886749] >, "kernel" => #Nx.Tensor< f32[4][1] [ - [-0.40993982553482056], - [-1.0208697319030762], - [0.18116380274295807], - [-0.8320646286010742] + [0.41491562128067017], + [-0.948100209236145], + [-1.2559744119644165], + [1.0097774267196655] ] > } @@ -150,7 +151,7 @@ model ``` -Epoch: 0, Batch: 1000, loss: 0.0559179 model error: 0.1430965 +Epoch: 0, Batch: 950, loss: 0.0607362 model error: 
0.1516546 ``` @@ -160,46 +161,46 @@ Epoch: 0, Batch: 1000, loss: 0.0559179 model error: 0.1430965 "dense_0" => %{ "bias" => #Nx.Tensor< f32[8] - [-0.2884136438369751, -0.016403740271925926, 0.30548375844955444, 0.2799474000930786, -0.017874717712402344, 0.3168976306915283, -0.10385002940893173, -0.18653006851673126] + [0.2577069401741028, 0.16761353611946106, 0.11587327718734741, 0.28539595007896423, -0.2071152776479721, -0.02039412036538124, -0.11152249574661255, 0.2389308214187622] >, "kernel" => #Nx.Tensor< f32[1][8] [ - [-0.44000443816185, 0.6495574712753296, -0.5427255034446716, -0.795007050037384, -0.0035864184610545635, -0.5102121233940125, 0.10152970999479294, -0.3913733959197998] + [-0.1265750676393509, 0.6902633309364319, -0.10233660787343979, -0.2544037103652954, -0.26677289605140686, -0.31035077571868896, 0.3845033347606659, -0.33032187819480896] ] > }, "dense_1" => %{ "bias" => #Nx.Tensor< f32[4] - [-0.24588409066200256, -0.05674195662140846, -0.08545850962400436, 0.27886852622032166] + [0.0, 0.16427761316299438, 0.02123815007507801, 0.22260485589504242] >, "kernel" => #Nx.Tensor< f32[8][4] [ - [0.6334101557731628, -0.44550418853759766, 0.34385600686073303, 0.24886265397071838], - [-0.5474148988723755, 0.09881290793418884, 0.14616712927818298, 0.8087677359580994], - [-0.15381869673728943, 0.5322079658508301, -0.6275551915168762, -0.4207017421722412], - [0.4673740863800049, 0.5706797242164612, 0.44344833493232727, -0.5382705926895142], - [0.6662552356719971, -0.3875215947628021, -0.5359503626823425, -0.6198058724403381], - [-0.2842515707015991, 0.2379448264837265, 0.581102728843689, -0.5942302346229553], - [0.039275627583265305, 0.6341984272003174, -0.10589496046304703, -0.3522306978702545], - [0.4015151560306549, -0.15162920951843262, -0.3449919819831848, 0.21970798075199127] + [-0.3859425485134125, 0.49959924817085266, -0.34108400344848633, 0.6222119331359863], + [-0.43326857686042786, -0.42272067070007324, 0.04245679825544357, -0.4357914626598358], + [-0.3065953850746155, 0.587925374507904, 0.2960704267024994, -0.31594154238700867], + [-0.35595524311065674, 0.6649497747421265, 0.4832736849784851, 0.3025558590888977], + [0.048333823680877686, -0.17023107409477234, 0.09139639884233475, -0.6511918902397156], + [-0.12099027633666992, -0.02014642395079136, 0.025831595063209534, -0.09945832937955856], + [0.3415437340736389, 0.41694650053977966, 0.24677544832229614, 0.06690020114183426], + [-0.1977071762084961, 0.39345067739486694, 0.26068705320358276, 0.35502269864082336] ] > }, "dense_2" => %{ "bias" => #Nx.Tensor< f32[1] - [0.26691529154777527] + [0.8329466581344604] >, "kernel" => #Nx.Tensor< f32[4][1] [ - [0.7088357210159302], - [-0.9271859526634216], - [-0.1610293984413147], - [0.6011591553688049] + [-0.23763614892959595], + [-1.031561255455017], + [0.1092313677072525], + [-0.7191486358642578] ] > } @@ -218,7 +219,7 @@ model ``` -Epoch: 0, Batch: 1000, loss: 0.0645265 total error: 158.5873566 +Epoch: 0, Batch: 950, loss: 0.0688004 total error: 151.4876404 ``` @@ -228,46 +229,46 @@ Epoch: 0, Batch: 1000, loss: 0.0645265 total error: 158.5873566 "dense_0" => %{ "bias" => #Nx.Tensor< f32[8] - [0.013307658955454826, 0.08766761422157288, -0.0048030223697423935, -0.07024712860584259, 0.261692613363266, 0.0028863451443612576, -0.12552864849567413, 0.10552618652582169] + [0.34921368956565857, 0.2217460423707962, 0.274880051612854, 0.016405446454882622, -0.11720903217792511, -0.20693546533584595, 0.14232252538204193, -0.07956698536872864] >, "kernel" => #Nx.Tensor< f32[1][8] [ - 
[-0.1647171825170517, -0.4144238233566284, -0.09969457238912582, -0.6063833832740784, 0.7182243466377258, -0.3485015034675598, -0.29005324840545654, -0.5282242298126221] + [-0.37851807475090027, -0.17135880887508392, -0.3878959119319916, 0.19248774647712708, 0.12453905493021011, -0.2750281095504761, 0.5614567995071411, 0.6186240315437317] ] > }, "dense_1" => %{ "bias" => #Nx.Tensor< f32[4] - [0.021465059369802475, -0.16003911197185516, 0.6696521043777466, -0.15482725203037262] + [-0.28566694259643555, 0.27262070775032043, -0.2875851094722748, 0.0] >, "kernel" => #Nx.Tensor< f32[8][4] [ - [0.3359515964984894, -0.21561087667942047, -0.48400720953941345, -0.3186679184436798], - [-0.08509980887174606, -0.031951334327459335, -0.6084564924240112, -0.39506790041923523], - [0.003889488521963358, -0.12886928021907806, 0.5679722428321838, 0.22699925303459167], - [-0.315458744764328, 0.5626247525215149, -0.4241454303264618, -0.11212264746427536], - [0.6759291291236877, -0.6508319973945618, 0.3511318564414978, 0.17946019768714905], - [-0.7148906588554382, 0.45404312014579773, 0.4150676727294922, 0.33603984117507935], - [0.398037314414978, 0.5080180764198303, 0.6770725250244141, -0.5274750590324402], - [0.5072763562202454, -0.7351003289222717, -0.583225429058075, -0.2974703013896942] + [0.23161421716213226, 0.8222984671592712, 0.09437259286642075, -0.4825701117515564], + [-0.38828352093696594, 0.6247998476028442, 0.5035035610198975, 0.0026152729988098145], + [0.5202338099479675, 0.7906754612922668, 0.08624745905399323, -0.5285568833351135], + [0.47950035333633423, -0.07571044564247131, 0.32921522855758667, -0.7011756896972656], + [-0.3601212203502655, 0.44817543029785156, 0.13981425762176514, -0.01014477014541626], + [-0.3157005310058594, -0.6309216618537903, 0.5622371435165405, 0.27447545528411865], + [-0.5749425292015076, -0.5073797702789307, -0.3527824282646179, 0.08027392625808716], + [-0.5331286191940308, 0.15432128310203552, -0.015716910362243652, -0.5225256681442261] ] > }, "dense_2" => %{ "bias" => #Nx.Tensor< f32[1] - [-0.8310347199440002] + [0.8275660872459412] >, "kernel" => #Nx.Tensor< f32[4][1] [ - [0.28011587262153625], - [0.542819082736969], - [1.2814348936080933], - [-0.5193246603012085] + [0.45810666680336], + [-1.0092405080795288], + [0.5322748422622681], + [-0.5989866852760315] ] > } diff --git a/guides/training_and_evaluation/using_loop_event_handlers.livemd b/guides/training_and_evaluation/using_loop_event_handlers.livemd index 84ba5f39..b96a0333 100644 --- a/guides/training_and_evaluation/using_loop_event_handlers.livemd +++ b/guides/training_and_evaluation/using_loop_event_handlers.livemd @@ -1,11 +1,8 @@ - - # Using loop event handlers ```elixir Mix.install([ - {:axon, github: "elixir-nx/axon"}, - {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true} + {:axon, ">= 0.5.0"} ]) ``` @@ -59,28 +56,28 @@ loop = ``` #Axon.Loop< + metrics: %{ + "loss" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>, + #Function<9.37390314/2 in Axon.Loop.build_loss_fn/1>} + }, handlers: %{ completed: [], epoch_completed: [ - {#Function<14.20267452/1 in Axon.Loop.checkpoint/2>, - #Function<5.20267452/1 in Axon.Loop.build_filter_fn/1>}, - {#Function<23.20267452/1 in Axon.Loop.log/5>, - #Function<5.20267452/1 in Axon.Loop.build_filter_fn/1>} + {#Function<17.37390314/1 in Axon.Loop.checkpoint/2>, + #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>}, + {#Function<27.37390314/1 in Axon.Loop.log/3>, + #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>} ], epoch_halted: 
[], epoch_started: [], halted: [], iteration_completed: [ - {#Function<23.20267452/1 in Axon.Loop.log/5>, - #Function<3.20267452/1 in Axon.Loop.build_filter_fn/1>} + {#Function<27.37390314/1 in Axon.Loop.log/3>, + #Function<64.37390314/2 in Axon.Loop.build_filter_fn/1>} ], iteration_started: [], started: [] }, - metrics: %{ - "loss" => {#Function<12.6031754/3 in Axon.Metrics.running_average/1>, - #Function<6.20267452/2 in Axon.Loop.build_loss_fn/1>} - }, ... > ``` @@ -90,7 +87,11 @@ Now when you execute your loop, it will save a checkpoint at the end of every ep ```elixir train_data = Stream.repeatedly(fn -> - xs = Nx.random_normal({8, 1}) + {xs, _next_key} = + :random.uniform(9999) + |> Nx.Random.key() + |> Nx.Random.normal(shape: {8, 1}) + ys = Nx.sin(xs) {xs, ys} end) @@ -101,11 +102,11 @@ Axon.Loop.run(loop, train_data, %{}, epochs: 5, iterations: 100) ``` -Epoch: 0, Batch: 100, loss: 0.2462310 -Epoch: 1, Batch: 100, loss: 0.1804814 -Epoch: 2, Batch: 100, loss: 0.1452925 -Epoch: 3, Batch: 100, loss: 0.1177117 -Epoch: 4, Batch: 100, loss: 0.1008184 +Epoch: 0, Batch: 50, loss: 0.5345965 +Epoch: 1, Batch: 50, loss: 0.4578816 +Epoch: 2, Batch: 50, loss: 0.4527244 +Epoch: 3, Batch: 50, loss: 0.4466343 +Epoch: 4, Batch: 50, loss: 0.4401709 ``` @@ -115,46 +116,46 @@ Epoch: 4, Batch: 100, loss: 0.1008184 "dense_0" => %{ "bias" => #Nx.Tensor< f32[8] - [0.36853691935539246, 0.24528849124908447, 0.13193830847740173, 0.03188902884721756, -0.06358373910188675, 0.044517479836940765, -0.1203451156616211, -6.352089694701135e-4] + [-0.1074252650141716, -0.0033432210329920053, -0.08044778555631638, 0.0016452680574730039, -0.01557128969579935, -0.061440952122211456, 0.061030879616737366, 0.012781506404280663] >, "kernel" => #Nx.Tensor< f32[1][8] [ - [0.49448737502098083, 0.5250089764595032, 0.7132464051246643, 0.47473379969596863, -0.043285828083753586, -0.14137212932109833, -0.07576408237218857, -0.48898136615753174] + [-0.3504936695098877, 0.6722151041030884, -0.5550820231437683, 0.05254736915230751, 0.7404129505157471, -0.24307608604431152, -0.7073894739151001, 0.6447222828865051] ] > }, "dense_1" => %{ "bias" => #Nx.Tensor< f32[4] - [0.30324652791023254, 0.0385407879948616, -0.16782516241073608, 0.1984063982963562] + [-0.19830459356307983, 0.0, 0.0, -0.04925372824072838] >, "kernel" => #Nx.Tensor< f32[8][4] [ - [0.2536502778530121, 0.375381737947464, 0.7119463086128235, -0.14521682262420654], - [0.20504063367843628, -0.11605211347341537, 0.49423739314079285, -0.03246872499585152], - [-0.13834621012210846, -0.2579476833343506, 0.34836748242378235, -0.4670639634132385], - [-0.11925031989812851, -0.6655324697494507, 0.5057039856910706, 0.496115118265152], - [0.15856991708278656, -0.2239169478416443, 0.5550385117530823, -0.3774339258670807], - [-0.326529860496521, -0.10192928463220596, 0.2961374819278717, 0.580808699131012], - [0.46179524064064026, -0.4794206917285919, 0.47078272700309753, -0.5654175877571106], - [-0.501025915145874, -0.38049301505088806, 0.3792027235031128, 0.685397207736969] + [0.4873020648956299, -0.3363800644874573, -0.6058675050735474, -0.47888076305389404], + [-0.18936580419540405, -0.5579301714897156, -0.49217337369918823, 0.04828363656997681], + [0.3202762305736542, -0.033479928970336914, 0.11928367614746094, -0.5225698351860046], + [0.3883931040763855, 0.07413274049758911, 0.548823893070221, -0.03494540974497795], + [-0.2598196268081665, -0.4546756446361542, 0.5866180062294006, 0.2946240305900574], + [0.2722054719924927, -0.5802338123321533, 0.4854300618171692, 
-0.5049118399620056], + [-0.415179044008255, -0.5426293611526489, -0.1631108522415161, -0.6544353365898132], + [-0.3079695403575897, 0.09391731023788452, -0.40262123942375183, -0.27837851643562317] ] > }, "dense_2" => %{ "bias" => #Nx.Tensor< f32[1] - [-0.4034360647201538] + [0.016238097101449966] >, "kernel" => #Nx.Tensor< f32[4][1] [ - [0.8062413334846497], - [0.6867087483406067], - [0.5137255787849426], - [-0.5783006548881531] + [0.3102125823497772], + [-1.078292727470398], + [0.7910841703414917], + [0.014510140754282475] ] > } @@ -166,22 +167,22 @@ You can also use event handlers for things as simple as implementing custom logg ```elixir model |> Axon.Loop.trainer(:mean_squared_error, :sgd) -|> Axon.Loop.log(:epoch_completed, fn _state -> "epoch is over\n" end, :stdio) +|> Axon.Loop.log(fn _state -> "epoch is over\n" end, event: :epoch_completed, device: :stdio) |> Axon.Loop.run(train_data, %{}, epochs: 5, iterations: 100) ``` ``` -Epoch: 0, Batch: 100, loss: 0.2134880 +Epoch: 0, Batch: 50, loss: 0.3220241 epoch is over -Epoch: 1, Batch: 100, loss: 0.1604774 +Epoch: 1, Batch: 50, loss: 0.2309804 epoch is over -Epoch: 2, Batch: 100, loss: 0.1294429 +Epoch: 2, Batch: 50, loss: 0.1759415 epoch is over -Epoch: 3, Batch: 100, loss: 0.1087099 +Epoch: 3, Batch: 50, loss: 0.1457551 epoch is over -Epoch: 4, Batch: 100, loss: 0.0940388 +Epoch: 4, Batch: 50, loss: 0.1247821 epoch is over ``` @@ -192,46 +193,46 @@ epoch is over "dense_0" => %{ "bias" => #Nx.Tensor< f32[8] - [0.1741544008255005, -0.013307991437613964, 0.0873112753033638, -0.04722493514418602, -0.12966567277908325, 0.04596322402358055, 0.3969370722770691, -0.04508184269070625] + [0.01846296526491642, -0.0016654117498546839, 0.39859917759895325, 0.21187178790569305, 0.08815062046051025, -0.11071830987930298, 0.06280634552240372, -0.11682439595460892] >, "kernel" => #Nx.Tensor< f32[1][8] [ - [0.31960299611091614, -0.5328841805458069, -0.24278149008750916, -0.47772416472435, 0.21538947522640228, -0.2799384295940399, 0.5947694778442383, 0.0497460775077343] + [0.08840499818325043, 0.44253841042518616, -0.6063749194145203, -0.1487167924642563, 0.24857401847839355, 0.1697462797164917, -0.5370600819587708, 0.1658734828233719] ] > }, "dense_1" => %{ "bias" => #Nx.Tensor< f32[4] - [0.25857725739479065, -0.07283111661672592, -0.10656370222568512, -0.08234459906816483] + [-0.08111556619405746, 0.32310858368873596, -0.059386227279901505, -0.09515857696533203] >, "kernel" => #Nx.Tensor< f32[8][4] [ - [0.3983175754547119, -0.5524351596832275, 0.36650899052619934, -0.23933114111423492], - [0.06517457216978073, 0.2564122974872589, 0.6227137446403503, -0.5661884546279907], - [-0.7012182474136353, 0.054501600563526154, -0.6726318597793579, 0.4774037301540375], - [-0.11393500864505768, 0.1726256012916565, -0.6723376512527466, 0.6044175028800964], - [-0.30502673983573914, 0.7011693120002747, 0.40034061670303345, -0.5748327374458313], - [-0.07724377512931824, -0.251364529132843, -0.6626797914505005, -0.20940908789634705], - [0.7290927767753601, 0.08563250303268433, -0.047927819192409515, -0.04336162284016609], - [-0.34993213415145874, 0.281339168548584, -0.49343380331993103, -0.2481663078069687] + [0.6057762503623962, -0.2633209824562073, 0.23028653860092163, -0.2710704505443573], + [0.03961030766367912, -0.335278183221817, 0.16016681492328644, 0.10653878003358841], + [0.36239713430404663, 0.8330743312835693, 0.4745633602142334, -0.29585230350494385], + [-0.04394621402025223, 0.45401355624198914, 0.5953336954116821, -0.6513576507568359], + 
[-0.6447072625160217, -0.6225455403327942, -0.4814218580722809, 0.6882413625717163], + [-0.44460421800613403, -0.04251839220523834, 0.4619944095611572, 0.24515877664089203], + [-0.49396005272865295, -0.08895684778690338, 0.5212237238883972, 0.24301064014434814], + [0.3074108958244324, 0.2640342712402344, 0.4197620749473572, -0.05698487162590027] ] > }, "dense_2" => %{ "bias" => #Nx.Tensor< f32[1] - [-0.6856028437614441] + [0.6520459651947021] >, "kernel" => #Nx.Tensor< f32[4][1] [ - [1.1966136693954468], - [-0.00546963419765234], - [-0.9349364042282104], - [0.9214714765548706] + [0.45083022117614746], + [-0.8733288049697876], + [-0.1894296556711197], + [0.030911535024642944] ] > } @@ -250,11 +251,11 @@ model ``` -Epoch: 0, Batch: 100, loss: 0.1791917 -Epoch: 1, Batch: 100, loss: 0.1373887 -Epoch: 2, Batch: 100, loss: 0.1156979 -Epoch: 3, Batch: 100, loss: 0.0965481 -Epoch: 4, Batch: 100, loss: 0.0865761 +Epoch: 0, Batch: 50, loss: 0.3180207 +Epoch: 1, Batch: 50, loss: 0.1975918 +Epoch: 2, Batch: 50, loss: 0.1353940 +Epoch: 3, Batch: 50, loss: 0.1055405 +Epoch: 4, Batch: 50, loss: 0.0890203 ``` @@ -264,46 +265,46 @@ Epoch: 4, Batch: 100, loss: 0.0865761 "dense_0" => %{ "bias" => #Nx.Tensor< f32[8] - [0.00938357226550579, 0.16315333545207977, 0.2767408788204193, -0.22733710706233978, 0.2830233573913574, -0.10280115902423859, -0.07500249892473221, 0.2947545647621155] + [0.047411054372787476, 0.1582564115524292, -0.027924394235014915, 0.1774083375930786, 0.09764095395803452, 0.1040089949965477, 0.006841400172561407, -0.11682236939668655] >, "kernel" => #Nx.Tensor< f32[1][8] [ - [0.522411048412323, 0.15686289966106415, 0.30727216601371765, 0.3295647203922272, 0.38795727491378784, 0.17159366607666016, 0.7608513236045837, 0.4526905119419098] + [0.20366023480892181, 0.7318703532218933, -0.028611917048692703, -0.5324040055274963, -0.6856501698493958, 0.21694214642047882, 0.3281741738319397, -0.13051153719425201] ] > }, "dense_1" => %{ "bias" => #Nx.Tensor< f32[4] - [-0.024011338129639626, 0.0, -0.00135718728415668, -0.0015321056125685573] + [0.1859581470489502, 0.3360026180744171, 0.24061667919158936, -0.016354668885469437] >, "kernel" => #Nx.Tensor< f32[8][4] [ - [0.606391966342926, -0.08385708928108215, 0.06838012486696243, -0.08704598248004913], - [0.5944894552230835, -0.17639528214931488, 0.26653605699539185, 0.35148826241493225], - [-0.06138936057686806, -0.024123376235365868, 0.29706713557243347, 0.5498997569084167], - [0.26888611912727356, 0.024979088455438614, -0.653775155544281, -0.4111217260360718], - [-0.5042538046836853, -0.6867390871047974, 0.13647332787513733, 0.7193269729614258], - [-0.052732646465301514, 0.099549300968647, -0.6970457434654236, 0.3078557252883911], - [-0.261769562959671, 0.17121906578540802, -0.08267408609390259, -0.2213396430015564], - [-0.09766292572021484, -0.5843542218208313, 0.369784414768219, 0.48434120416641235] + [0.07366377860307693, -0.3261552155017853, -0.6951385140419006, -0.4232194125652313], + [0.7334840893745422, -0.17827139794826508, -0.6411628127098083, -0.41898131370544434], + [0.4770638346672058, -0.4738321304321289, 0.5755389332771301, 0.30976954102516174], + [-0.498087614774704, 0.10546410828828812, 0.690037190914154, -0.5016340613365173], + [0.17509347200393677, 0.4518563449382782, -0.10358063131570816, 0.2223401516675949], + [0.6422480344772339, 0.19363932311534882, 0.2870054543018341, -0.1483648419380188], + [-0.10362248122692108, -0.7047968506813049, 0.02847556211054325, -0.18464618921279907], + [-0.6756409406661987, -0.42686882615089417, 
-0.5484509468078613, 0.596512496471405] ] > }, "dense_2" => %{ "bias" => #Nx.Tensor< f32[1] - [-0.6914201378822327] + [0.23296000063419342] >, "kernel" => #Nx.Tensor< f32[4][1] [ - [0.96906977891922], - [-0.5032458901405334], - [0.9275273680686951], - [0.8574270606040955] + [0.48827823996543884], + [-0.7908728122711182], + [-0.5326805114746094], + [0.3789232671260834] ] > } diff --git a/guides/training_and_evaluation/writing_custom_event_handlers.livemd b/guides/training_and_evaluation/writing_custom_event_handlers.livemd index b6ddf39d..17f42305 100644 --- a/guides/training_and_evaluation/writing_custom_event_handlers.livemd +++ b/guides/training_and_evaluation/writing_custom_event_handlers.livemd @@ -1,11 +1,8 @@ - - # Writing custom event handlers ```elixir Mix.install([ - {:axon, github: "elixir-nx/axon"}, - {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true} + {:axon, ">= 0.5.0"} ]) ``` @@ -22,7 +19,7 @@ If you require functionality not offered by any of Axon's built-in event handler All event handlers must accept an `%Axon.Loop.State{}` struct and return a tuple of `{control_term, state}` where `control_term` is one of `:continue`, `:halt_epoch`, or `:halt_loop` and `state` is the updated loop state: ```elixir -defmodule CustomEventHandler do +defmodule CustomEventHandler0 do alias Axon.Loop.State def my_weird_handler(%State{} = state) do @@ -35,7 +32,7 @@ end ``` -{:module, CustomEventHandler, <<70, 79, 82, 49, 0, 0, 6, ...>>, {:my_weird_handler, 1}} +{:module, CustomEventHandler0, <<70, 79, 82, 49, 0, 0, 6, ...>>, {:my_weird_handler, 1}} ``` To register event handlers, you use `Axon.Loop.handle/4`: @@ -52,35 +49,35 @@ model = loop = model |> Axon.Loop.trainer(:mean_squared_error, :sgd) - |> Axon.Loop.handle(:epoch_completed, &CustomEventHandler.my_weird_handler/1) + |> Axon.Loop.handle_event(:epoch_completed, &CustomEventHandler0.my_weird_handler/1) ``` ``` #Axon.Loop< + metrics: %{ + "loss" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>, + #Function<9.37390314/2 in Axon.Loop.build_loss_fn/1>} + }, handlers: %{ completed: [], epoch_completed: [ - {&CustomEventHandler.my_weird_handler/1, - #Function<5.33119226/1 in Axon.Loop.build_filter_fn/1>}, - {#Function<23.33119226/1 in Axon.Loop.log/5>, - #Function<5.33119226/1 in Axon.Loop.build_filter_fn/1>} + {&CustomEventHandler0.my_weird_handler/1, + #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>}, + {#Function<27.37390314/1 in Axon.Loop.log/3>, + #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>} ], epoch_halted: [], epoch_started: [], halted: [], iteration_completed: [ - {#Function<23.33119226/1 in Axon.Loop.log/5>, - #Function<3.33119226/1 in Axon.Loop.build_filter_fn/1>} + {#Function<27.37390314/1 in Axon.Loop.log/3>, + #Function<64.37390314/2 in Axon.Loop.build_filter_fn/1>} ], iteration_started: [], started: [] }, - metrics: %{ - "loss" => {#Function<12.46375131/3 in Axon.Metrics.running_average/1>, - #Function<6.33119226/2 in Axon.Loop.build_loss_fn/1>} - }, ... 
> ``` @@ -90,7 +87,11 @@ Axon will trigger your custom handler to run on the attached event: ```elixir train_data = Stream.repeatedly(fn -> - xs = Nx.random_normal({8, 1}) + {xs, _next_key} = + :random.uniform(9999) + |> Nx.Random.key() + |> Nx.Random.normal(shape: {8, 1}) + ys = Nx.sin(xs) {xs, ys} end) @@ -101,15 +102,15 @@ Axon.Loop.run(loop, train_data, %{}, epochs: 5, iterations: 100) ``` -Epoch: 0, Batch: 100, loss: 0.1905403 +Epoch: 0, Batch: 50, loss: 0.0990703 My weird handler: fired -Epoch: 1, Batch: 100, loss: 0.1478554 +Epoch: 1, Batch: 50, loss: 0.0567622 My weird handler: fired -Epoch: 2, Batch: 100, loss: 0.1184390 +Epoch: 2, Batch: 50, loss: 0.0492784 My weird handler: fired -Epoch: 3, Batch: 100, loss: 0.0983292 +Epoch: 3, Batch: 50, loss: 0.0462587 My weird handler: fired -Epoch: 4, Batch: 100, loss: 0.0845697 +Epoch: 4, Batch: 50, loss: 0.0452806 My weird handler: fired ``` @@ -120,46 +121,46 @@ My weird handler: fired "dense_0" => %{ "bias" => #Nx.Tensor< f32[8] - [0.014659373089671135, 0.08941870182752609, -0.09661660343408585, 0.2650177478790283, -0.06400775164365768, -0.07953602075576782, 0.22094617784023285, -0.014790073968470097] + [0.10819189250469208, 0.008151392452418804, -0.0318693183362484, 0.010302421636879444, 0.15788722038269043, 0.05119801685214043, 0.14268818497657776, -0.11528034508228302] >, "kernel" => #Nx.Tensor< f32[1][8] [ - [0.3581556975841522, 0.38828182220458984, -0.3311854302883148, -0.4059808552265167, 0.6334917545318604, 0.17008493840694427, -0.5630434155464172, 0.3790667653083801] + [-0.4275593161582947, 0.40442031621932983, 0.7287659645080566, -0.7832129597663879, 0.3329123258590698, -0.5598123073577881, 0.8389336466789246, 0.3197469413280487] ] > }, "dense_1" => %{ "bias" => #Nx.Tensor< f32[4] - [0.3047839403152466, -0.025677276775240898, 0.18113580346107483, 0.19019420444965363] + [0.0671013742685318, 0.13561469316482544, 0.06218714639544487, 0.2104845941066742] >, "kernel" => #Nx.Tensor< f32[8][4] [ - [-0.25477269291877747, 0.28833284974098206, -0.25498083233833313, 0.40912926197052], - [-0.387851357460022, 0.009837300516664982, -0.48930269479751587, -0.6119663715362549], - [0.49769237637519836, -0.45746952295303345, -0.3886529505252838, -0.49895355105400085], - [0.6451961994171143, 0.16054697334766388, 0.27802371978759766, -0.15226426720619202], - [0.17125651240348816, -0.048851024359464645, 0.19429178535938263, 0.24933232367038727], - [0.5465306043624878, -0.15836869180202484, 0.39782997965812683, -0.3635501563549042], - [-0.36660289764404297, -0.011948992498219013, 0.48680511116981506, 0.5263928174972534], - [-0.6284276843070984, -0.5880372524261475, 0.004470183979719877, -0.4550755023956299] + [0.4444102942943573, 0.4518184959888458, 0.45315614342689514, 0.35392478108406067], + [0.008407601155340672, -0.6081852912902832, -0.05863206833600998, 0.14386630058288574], + [-0.010219200514256954, -0.5528244376182556, 0.3754919469356537, -0.6242967247962952], + [0.3531058132648468, -0.18348301947116852, -0.0019897441379725933, 0.41002658009529114], + [0.676723062992096, -0.09349705278873444, 0.1101854145526886, 0.06494166702032089], + [0.1534113883972168, 0.6402403116226196, 0.23490086197853088, -0.2196572870016098], + [0.5835862755775452, -0.6581316590309143, -0.3047991394996643, -0.07485166192054749], + [-0.6115342378616333, 0.3316897749900818, -0.3606548309326172, 0.3397740423679352] ] > }, "dense_2" => %{ "bias" => #Nx.Tensor< f32[1] - [0.7117368578910828] + [0.10111129283905029] >, "kernel" => #Nx.Tensor< f32[4][1] [ - 
[-0.7743457555770874], - [0.3977936804294586], - [-1.0638943910598755], - [-0.6494196653366089] + [0.7433153390884399], + [-0.8213723301887512], + [-0.44361063838005066], + [-1.049617052078247] ] > } @@ -169,7 +170,7 @@ My weird handler: fired You can use event handlers to early-stop a loop or loop epoch by returning a `:halt_*` control term. Halt control terms can be one of `:halt_epoch` or `:halt_loop`. `:halt_epoch` halts the current epoch and continues to the next. `:halt_loop` halts the loop altogether. ```elixir -defmodule CustomEventHandler do +defmodule CustomEventHandler1 do alias Axon.Loop.State def always_halts(%State{} = state) do @@ -182,7 +183,7 @@ end ``` -{:module, CustomEventHandler, <<70, 79, 82, 49, 0, 0, 6, ...>>, {:always_halts, 1}} +{:module, CustomEventHandler1, <<70, 79, 82, 49, 0, 0, 6, ...>>, {:always_halts, 1}} ``` The loop will immediately stop executing and return the current state at the time it was halted: @@ -190,14 +191,14 @@ The loop will immediately stop executing and return the current state at the tim ```elixir model |> Axon.Loop.trainer(:mean_squared_error, :sgd) -|> Axon.Loop.handle(:epoch_completed, &CustomEventHandler.always_halts/1) +|> Axon.Loop.handle_event(:epoch_completed, &CustomEventHandler1.always_halts/1) |> Axon.Loop.run(train_data, %{}, epochs: 5, iterations: 100) ``` ``` -Epoch: 0, Batch: 100, loss: 0.1967763 +Epoch: 0, Batch: 50, loss: 0.2201974 stopping loop ``` @@ -208,46 +209,46 @@ stopping loop "dense_0" => %{ "bias" => #Nx.Tensor< f32[8] - [-0.05958094820380211, 0.08930676430463791, -0.006259916350245476, 0.05067025125026703, 0.10981185734272003, -0.011248357594013214, -0.007601946126669645, 0.036958880722522736] + [0.07676638662815094, -0.18689222633838654, 0.10066182911396027, -0.021994125097990036, 0.12006694823503494, -0.014219668693840504, 0.13600556552410126, -0.017512166872620583] >, "kernel" => #Nx.Tensor< f32[1][8] [ - [0.050393108278512955, -0.5486620664596558, 0.6901980042457581, 0.42280837893486023, 0.6446300745010376, 0.25207778811454773, -0.13566234707832336, 0.26625606417655945] + [-0.5354958772659302, -0.216745987534523, -0.5694359540939331, 0.023495405912399292, 0.17701618373394012, 0.011712944135069847, 0.5289720892906189, 0.07360327988862991] ] > }, "dense_1" => %{ "bias" => #Nx.Tensor< f32[4] - [-0.06729397922754288, 0.14259757101535797, -0.0020351663697510958, 0.16679106652736664] + [0.0012482400052249432, 0.09300543367862701, 0.08570009469985962, -0.018982920795679092] >, "kernel" => #Nx.Tensor< f32[8][4] [ - [-0.5964004397392273, -0.5631846785545349, 0.15613533556461334, 0.1943722516298294], - [0.19513694941997528, -0.24765732884407043, -0.06751974672079086, 0.6707308292388916], - [-0.6826592087745667, -0.006577506195753813, -0.6097249984741211, -0.5801466703414917], - [-0.30076032876968384, 0.34819719195365906, -0.5906499028205872, -0.37741175293922424], - [0.16266342997550964, 0.7666646838188171, 0.6456886529922485, -0.4589986801147461], - [-0.2686948776245117, -0.06113003194332123, 0.22663049399852753, -0.12092678993940353], - [-0.5785921216011047, -0.641874372959137, -0.24317769706249237, -0.2897084951400757], - [0.14917287230491638, 0.24462535977363586, -0.64858478307724, -0.5138146877288818] + [0.3016211688518524, 0.31998082995414734, -0.3300730884075165, 0.24982869625091553], + [0.03864569962024689, -0.44071364402770996, 0.6553062200546265, -0.5294798612594604], + [0.25020459294319153, 0.7249991297721863, 0.15611837804317474, -0.5045580863952637], + [-0.5500670075416565, 0.15677094459533691, 
-0.6531851291656494, -0.09289993345737457], + [0.1618722379207611, 0.4479053020477295, 0.705923318862915, -0.3853490352630615], + [-0.6752215623855591, 0.577272891998291, -0.1268012821674347, 0.6133111715316772], + [0.5361366271972656, -0.2996085286140442, 0.28480708599090576, 0.47739118337631226], + [-0.6443014144897461, -0.2866927981376648, 0.023463081568479538, -0.1491370052099228] ] > }, "dense_2" => %{ "bias" => #Nx.Tensor< f32[1] - [-0.11649220436811447] + [0.0047520860098302364] >, "kernel" => #Nx.Tensor< f32[4][1] [ - [0.7849427461624146], - [0.5966104865074158], - [-0.5520159602165222], - [-0.4974740147590637] + [0.3796459138393402], + [-0.9757304191589355], + [0.9530885815620422], + [-0.05134368687868118] ] > } @@ -257,7 +258,7 @@ stopping loop Note that halting an epoch will fire a different event than completing an epoch. So if you implement a custom handler to halt the loop when an epoch completes, it will never fire if the epoch always halts prematurely: ```elixir -defmodule CustomEventHandler do +defmodule CustomEventHandler2 do alias Axon.Loop.State def always_halts_epoch(%State{} = state) do @@ -275,7 +276,7 @@ end ``` -{:module, CustomEventHandler, <<70, 79, 82, 49, 0, 0, 7, ...>>, {:always_halts_loop, 1}} +{:module, CustomEventHandler2, <<70, 79, 82, 49, 0, 0, 8, ...>>, {:always_halts_loop, 1}} ``` If you run these handlers in conjunction, the loop will not terminate prematurely: @@ -283,8 +284,8 @@ If you run these handlers in conjunction, the loop will not terminate prematurel ```elixir model |> Axon.Loop.trainer(:mean_squared_error, :sgd) -|> Axon.Loop.handle(:iteration_completed, &CustomEventHandler.always_halts_epoch/1) -|> Axon.Loop.handle(:epoch_completed, &CustomEventHandler.always_halts_loop/1) +|> Axon.Loop.handle_event(:iteration_completed, &CustomEventHandler2.always_halts_epoch/1) +|> Axon.Loop.handle_event(:epoch_completed, &CustomEventHandler2.always_halts_loop/1) |> Axon.Loop.run(train_data, %{}, epochs: 5, iterations: 100) ``` @@ -293,13 +294,13 @@ model ``` Epoch: 0, Batch: 0, loss: 0.0000000 stopping epoch -Epoch: 0, Batch: 0, loss: 0.7256396 + stopping epoch -Epoch: 0, Batch: 0, loss: 0.4574284 + stopping epoch -Epoch: 0, Batch: 0, loss: 0.4981923 + stopping epoch -Epoch: 0, Batch: 0, loss: 0.4377063 + stopping epoch ``` @@ -310,46 +311,46 @@ stopping epoch "dense_0" => %{ "bias" => #Nx.Tensor< f32[8] - [9.248655405826867e-4, -0.0038722341414541006, -0.0015197680331766605, -0.001993122510612011, -0.0015419051051139832, -0.004070846363902092, 0.001461982261389494, 0.0043989671394228935] + [0.009215549565851688, -0.005282022058963776, -0.0023747326340526342, 0.002623362001031637, 0.003890525083988905, 6.010813522152603e-4, -0.0024882694706320763, 0.0029246946796774864] >, "kernel" => #Nx.Tensor< f32[1][8] [ - [-0.6537156701087952, 0.2857331335544586, -0.339731365442276, 0.46841081976890564, -0.5864744782447815, -0.364472359418869, -0.5385616421699524, -0.694677472114563] + [-0.3484582304954529, -0.39938971400260925, 0.03963512182235718, -0.3549930155277252, 0.09539157152175903, 0.5987873077392578, -0.23635399341583252, 0.01850329153239727] ] > }, "dense_1" => %{ "bias" => #Nx.Tensor< f32[4] - [0.0, -0.017093738541007042, 0.00152371556032449, -0.0019599769730120897] + [-0.00194685033056885, 0.007812315598130226, 0.01710106059908867, 0.0080711729824543] >, "kernel" => #Nx.Tensor< f32[8][4] [ - [-0.21336764097213745, -0.6211493611335754, 0.676548957824707, 0.3768426477909088], - [-0.24921125173568726, 0.217195525765419, 0.23704318702220917, 
0.1597728431224823], - [-0.12178827077150345, -0.4966273307800293, -0.283501535654068, 0.00888047181069851], - [-0.19504092633724213, 0.18697738647460938, 0.14705461263656616, 0.39286476373672485], - [-0.5945789813995361, -0.5958647727966309, -0.3320448100566864, -0.02747068926692009], - [-0.2157520055770874, -0.2990635335445404, -0.16008871793746948, 0.4921063184738159], - [-0.529068648815155, -0.383655846118927, -0.07292155921459198, -0.2834954559803009], - [-0.3056498169898987, -0.28507867455482483, 0.554026186466217, -0.24665579199790955] + [-0.6497661471366882, -0.3379145562648773, 0.3343344032764435, 0.4334254860877991], + [-0.37884217500686646, -0.41724908351898193, -0.19513007998466492, -0.22494879364967346], + [-0.42438197135925293, -0.40400123596191406, 0.5355109572410583, 0.4295356869697571], + [0.15086597204208374, 0.30529624223709106, 0.002222923096269369, 0.32834741473197937], + [-0.09336567670106888, 0.471781849861145, -0.06567475199699402, -0.4361487627029419], + [0.23664812743663788, 0.13572633266448975, -0.13837064802646637, -0.09471122920513153], + [0.6461064219474792, -0.2435072958469391, -0.04861235246062279, -0.1969985067844391], + [0.17856749892234802, 0.41614532470703125, -0.06008348613977432, -0.3271574079990387] ] > }, "dense_2" => %{ "bias" => #Nx.Tensor< f32[1] - [-0.010511377826333046] + [-0.005317525006830692] >, "kernel" => #Nx.Tensor< f32[4][1] [ - [0.9865502119064331], - [-0.686279296875], - [-0.15436960756778717], - [0.18355509638786316] + [-0.07891849428415298], + [0.32653072476387024], + [-0.5885495543479919], + [-0.2781771719455719] ] > } diff --git a/guides/training_and_evaluation/writing_custom_metrics.livemd b/guides/training_and_evaluation/writing_custom_metrics.livemd index 979776b4..dd9d5eda 100644 --- a/guides/training_and_evaluation/writing_custom_metrics.livemd +++ b/guides/training_and_evaluation/writing_custom_metrics.livemd @@ -1,11 +1,8 @@ - - # Writing custom metrics ```elixir Mix.install([ - {:axon, github: "elixir-nx/axon"}, - {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true} + {:axon, ">= 0.5.0"} ]) ``` @@ -32,7 +29,7 @@ end ``` -{:module, CustomMetric, <<70, 79, 82, 49, 0, 0, 8, ...>>, {:my_weird_metric, 2}} +{:module, CustomMetric, <<70, 79, 82, 49, 0, 0, 8, ...>>, true} ``` Then you can pass that directly to `Axon.Loop.metric/5`. 
You must provide a name for your custom metric: @@ -56,28 +53,28 @@ loop = ``` #Axon.Loop< + metrics: %{ + "loss" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>, + #Function<9.37390314/2 in Axon.Loop.build_loss_fn/1>}, + "my weird metric" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>, + &CustomMetric.my_weird_metric/2} + }, handlers: %{ completed: [], epoch_completed: [ - {#Function<23.77614421/1 in Axon.Loop.log/5>, - #Function<5.77614421/1 in Axon.Loop.build_filter_fn/1>} + {#Function<27.37390314/1 in Axon.Loop.log/3>, + #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>} ], epoch_halted: [], epoch_started: [], halted: [], iteration_completed: [ - {#Function<23.77614421/1 in Axon.Loop.log/5>, - #Function<3.77614421/1 in Axon.Loop.build_filter_fn/1>} + {#Function<27.37390314/1 in Axon.Loop.log/3>, + #Function<64.37390314/2 in Axon.Loop.build_filter_fn/1>} ], iteration_started: [], started: [] }, - metrics: %{ - "loss" => {#Function<12.46375131/3 in Axon.Metrics.running_average/1>, - #Function<6.77614421/2 in Axon.Loop.build_loss_fn/1>}, - "my weird metric" => {#Function<12.46375131/3 in Axon.Metrics.running_average/1>, - &CustomMetric.my_weird_metric/2} - }, ... > ``` @@ -87,7 +84,11 @@ Then when running, Axon will invoke your custom metric function and accumulate i ```elixir train_data = Stream.repeatedly(fn -> - xs = Nx.random_normal({8, 1}) + {xs, _next_key} = + :random.uniform(9999) + |> Nx.Random.key() + |> Nx.Random.normal(shape: {8, 1}) + ys = Nx.sin(xs) {xs, ys} end) @@ -98,7 +99,7 @@ Axon.Loop.run(loop, train_data, %{}, iterations: 1000) ``` -Epoch: 0, Batch: 1000, loss: 0.0468431 my weird metric: -5.7462921 +Epoch: 0, Batch: 950, loss: 0.0681635 my weird metric: -5.2842808 ``` @@ -108,46 +109,46 @@ Epoch: 0, Batch: 1000, loss: 0.0468431 my weird metric: -5.7462921 "dense_0" => %{ "bias" => #Nx.Tensor< f32[8] - [0.011475208215415478, 0.23035769164562225, 0.01538881566375494, 0.08167446404695511, 0.23642019927501678, 0.10298296064138412, 0.20279639959335327, -0.18916435539722443] + [0.0866982489824295, 0.4234408140182495, 0.18205422163009644, 0.34029239416122437, -0.25770726799964905, -0.07117943465709686, 0.11470477283000946, -0.027526771649718285] >, "kernel" => #Nx.Tensor< f32[1][8] [ - [0.7426201105117798, 0.734136700630188, -0.5648708343505859, -0.5230435132980347, 0.3056533932685852, 0.3383721709251404, -0.3518844544887543, -0.19460521638393402] + [-0.7088809013366699, 0.4486531913280487, 0.4666421115398407, 0.4163222312927246, 0.5076444149017334, 0.10119977593421936, 0.6628422141075134, -0.024421442300081253] ] > }, "dense_1" => %{ "bias" => #Nx.Tensor< f32[4] - [0.2185358852148056, 0.23043134808540344, 0.0, 0.2650437355041504] + [0.2924745976924896, 0.0065560233779251575, 0.0, -0.21106423437595367] >, "kernel" => #Nx.Tensor< f32[8][4] [ - [0.19164204597473145, -0.26440876722335815, 0.060297321528196335, 0.004777891095727682], - [0.019263261929154396, -0.6267783045768738, -0.33454063534736633, 0.33268266916275024], - [-0.18489953875541687, 0.4653063714504242, -0.6056118607521057, -0.046012550592422485], - [0.5975558161735535, -0.237883061170578, -0.6522921919822693, 0.019332828000187874], - [-0.7424253225326538, 0.593705952167511, 0.2551117241382599, 0.26270362734794617], - [0.018434584140777588, 0.15290242433547974, 0.08793036639690399, 0.1839984804391861], - [0.6048195958137512, -0.20294713973999023, -0.694927990436554, -0.45577046275138855], - [-0.628790020942688, 0.21741150319576263, -0.08936657756567001, 0.6170362234115601] + 
[-0.3407173752784729, -0.6905813217163086, -0.5984221696853638, -0.23955762386322021], + [0.42608022689819336, 0.5949274301528931, -0.24687853455543518, -0.4948572516441345], + [0.27617380023002625, -0.44326621294021606, -0.5848686099052429, 0.31592807173728943], + [0.5401414632797241, -0.1041281446814537, -0.4072037935256958, 0.4387882947921753], + [-0.5410752892494202, 0.4544697403907776, -0.6238576173782349, -0.2077195793390274], + [-0.41753143072128296, -0.11599045991897583, -0.22447934746742249, -0.5805748701095581], + [0.1651047021150589, -0.526184618473053, 0.34729963541030884, 0.3307822048664093], + [0.6879482865333557, 0.27184563875198364, -0.4907835125923157, -0.3555335998535156] ] > }, "dense_2" => %{ "bias" => #Nx.Tensor< f32[1] - [-0.03722470998764038] + [-0.8146252036094666] >, "kernel" => #Nx.Tensor< f32[4][1] [ - [-0.7919473648071289], - [-0.4341854751110077], - [-0.39114490151405334], - [0.9605273008346558] + [1.2187021970748901], + [0.13001228868961334], + [0.2703772783279419], + [-0.3591017723083496] ] > } @@ -180,30 +181,30 @@ loop = ``` #Axon.Loop< + metrics: %{ + "dense_0_kernel_mean" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>, + &Nx.mean/1}, + "dense_0_kernel_var" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>, + &Nx.variance/1}, + "loss" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>, + #Function<9.37390314/2 in Axon.Loop.build_loss_fn/1>} + }, handlers: %{ completed: [], epoch_completed: [ - {#Function<23.77614421/1 in Axon.Loop.log/5>, - #Function<5.77614421/1 in Axon.Loop.build_filter_fn/1>} + {#Function<27.37390314/1 in Axon.Loop.log/3>, + #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>} ], epoch_halted: [], epoch_started: [], halted: [], iteration_completed: [ - {#Function<23.77614421/1 in Axon.Loop.log/5>, - #Function<3.77614421/1 in Axon.Loop.build_filter_fn/1>} + {#Function<27.37390314/1 in Axon.Loop.log/3>, + #Function<64.37390314/2 in Axon.Loop.build_filter_fn/1>} ], iteration_started: [], started: [] }, - metrics: %{ - "dense_0_kernel_mean" => {#Function<12.46375131/3 in Axon.Metrics.running_average/1>, - &Nx.mean/1}, - "dense_0_kernel_var" => {#Function<12.46375131/3 in Axon.Metrics.running_average/1>, - &Nx.variance/1}, - "loss" => {#Function<12.46375131/3 in Axon.Metrics.running_average/1>, - #Function<6.77614421/2 in Axon.Loop.build_loss_fn/1>} - }, ... 
> ``` @@ -213,7 +214,11 @@ Axon will apply your custom output transform to the loop's step state and forwar ```elixir train_data = Stream.repeatedly(fn -> - xs = Nx.random_normal({8, 1}) + {xs, _next_key} = + :random.uniform(9999) + |> Nx.Random.key() + |> Nx.Random.normal(shape: {8, 1}) + ys = Nx.sin(xs) {xs, ys} end) @@ -224,7 +229,7 @@ Axon.Loop.run(loop, train_data, %{}, iterations: 1000) ``` -Epoch: 0, Batch: 1000, dense_0_kernel_mean: 0.0807205 dense_0_kernel_var: 0.1448047 loss: 0.0626600 +Epoch: 0, Batch: 950, dense_0_kernel_mean: -0.1978206 dense_0_kernel_var: 0.2699870 loss: 0.0605523 ``` @@ -234,46 +239,46 @@ Epoch: 0, Batch: 1000, dense_0_kernel_mean: 0.0807205 dense_0_kernel_var: 0.1448 "dense_0" => %{ "bias" => #Nx.Tensor< f32[8] - [-0.14429236948490143, 0.3176318109035492, 0.0036036474630236626, 0.01434470433741808, 0.21225003898143768, -0.1406097412109375, 0.32469284534454346, -0.18893203139305115] + [0.371105819940567, 0.26451945304870605, -0.048297226428985596, 0.14616385102272034, -0.19356133043766022, -0.2924956679344177, 0.08295489847660065, 0.25213995575904846] >, "kernel" => #Nx.Tensor< f32[1][8] [ - [0.2918722331523895, -0.44978663325309753, -0.28219935297966003, -0.10681337863206863, 0.5192054510116577, 0.312747985124588, -0.15127503871917725, 0.5638187527656555] + [-0.3888320028781891, -0.39463144540786743, 0.5427617430686951, -0.776488721370697, -0.2402891218662262, -0.6489362716674805, 0.772796094417572, -0.3739306926727295] ] > }, "dense_1" => %{ "bias" => #Nx.Tensor< f32[4] - [0.0, -0.003864143043756485, 0.5194356441497803, 0.028363214805722237] + [0.0, -0.006653765682131052, 0.0, 0.3086839020252228] >, "kernel" => #Nx.Tensor< f32[8][4] [ - [-0.6123268008232117, 0.22753892838954926, 0.12077417969703674, 0.4875330626964569], - [-0.5840837359428406, 0.2259720116853714, 0.4917944371700287, 0.22638437151908875], - [-0.22699439525604248, -0.6744257807731628, -0.2907045781612396, 0.35300591588020325], - [-0.16367988288402557, -0.5971682071685791, -0.39346548914909363, 0.5823913812637329], - [-0.5512545704841614, -0.6812713742256165, -0.5777145624160767, -0.653957188129425], - [-0.23620283603668213, -0.47966212034225464, -0.273225873708725, 0.3827615976333618], - [-0.5591338276863098, -0.1730434000492096, 0.25726518034935, 0.7179149389266968], - [0.3902169167995453, 0.6351881623268127, -0.602277398109436, 0.40137141942977905] + [-0.5556576251983643, 0.5547546148300171, -0.2708005905151367, 0.7341570258140564], + [-0.01800161600112915, 0.19749529659748077, -0.09523773193359375, 0.4989740252494812], + [-0.19737857580184937, -0.2741832435131073, -0.3699955344200134, 0.21036939322948456], + [-0.09787613153457642, -0.5631319284439087, 0.007957160472869873, 0.23681949079036713], + [-0.469108909368515, 0.24062377214431763, -0.012939095497131348, -0.5055088400840759], + [0.11229842901229858, -0.5476430058479309, 0.013744592666625977, -0.631401538848877], + [-0.5834296941757202, -0.42305096983909607, 0.1393480896949768, -0.4647532105445862], + [-0.3684111535549164, -0.5147689580917358, -0.3725535273551941, 0.46682292222976685] ] > }, "dense_2" => %{ "bias" => #Nx.Tensor< f32[1] - [0.824558675289154] + [0.8305950164794922] >, "kernel" => #Nx.Tensor< f32[4][1] [ - [0.9618374109268188], - [-0.028266794979572296], - [-1.1059081554412842], - [-0.7398673892021179] + [0.7111979722976685], + [-0.49341335892677307], + [-0.32701319456100464], + [-1.0638068914413452] ] > } @@ -296,7 +301,7 @@ end ``` -{:module, CustomAccumulator, <<70, 79, 82, 49, 0, 0, 11, ...>>, {:running_ema, 4}} 
+{:module, CustomAccumulator, <<70, 79, 82, 49, 0, 0, 11, ...>>, true} ``` Your accumulator must be an arity-3 function which accepts the current accumulated value, the current observation, and the current iteration and returns the aggregated metric. You can pass a function direct as an accumulator in your metric: @@ -329,28 +334,28 @@ loop = ``` #Axon.Loop< + metrics: %{ + "dense_0_kernel_ema_mean" => {#Function<15.37390314/3 in Axon.Loop.build_metric_fn/3>, + &Nx.mean/1}, + "loss" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>, + #Function<9.37390314/2 in Axon.Loop.build_loss_fn/1>} + }, handlers: %{ completed: [], epoch_completed: [ - {#Function<23.77614421/1 in Axon.Loop.log/5>, - #Function<5.77614421/1 in Axon.Loop.build_filter_fn/1>} + {#Function<27.37390314/1 in Axon.Loop.log/3>, + #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>} ], epoch_halted: [], epoch_started: [], halted: [], iteration_completed: [ - {#Function<23.77614421/1 in Axon.Loop.log/5>, - #Function<3.77614421/1 in Axon.Loop.build_filter_fn/1>} + {#Function<27.37390314/1 in Axon.Loop.log/3>, + #Function<64.37390314/2 in Axon.Loop.build_filter_fn/1>} ], iteration_started: [], started: [] }, - metrics: %{ - "dense_0_kernel_ema_mean" => {#Function<12.77614421/3 in Axon.Loop.build_metric_fn/3>, - &Nx.mean/1}, - "loss" => {#Function<12.46375131/3 in Axon.Metrics.running_average/1>, - #Function<6.77614421/2 in Axon.Loop.build_loss_fn/1>} - }, ... > ``` @@ -360,7 +365,11 @@ Then when you run the loop, Axon will use your custom accumulator: ```elixir train_data = Stream.repeatedly(fn -> - xs = Nx.random_normal({8, 1}) + {xs, _next_key} = + :random.uniform(9999) + |> Nx.Random.key() + |> Nx.Random.normal(shape: {8, 1}) + ys = Nx.sin(xs) {xs, ys} end) @@ -371,7 +380,7 @@ Axon.Loop.run(loop, train_data, %{}, iterations: 1000) ``` -Epoch: 0, Batch: 1000, dense_0_kernel_ema_mean: 0.2137861 loss: 0.0709054 +Epoch: 0, Batch: 950, dense_0_kernel_ema_mean: -0.0139760 loss: 0.0682910 ``` @@ -381,46 +390,46 @@ Epoch: 0, Batch: 1000, dense_0_kernel_ema_mean: 0.2137861 loss: 0.0709054 "dense_0" => %{ "bias" => #Nx.Tensor< f32[8] - [0.08160790055990219, -0.21322371065616608, -0.1431925743818283, 0.2848915755748749, -0.007875560782849789, 0.3923396170139313, -0.04444991424679756, 0.23083189129829407] + [-0.3344854414463043, -0.14519920945167542, 0.1061621680855751, 0.36911827325820923, 0.014146199449896812, 0.46089673042297363, -0.1707312911748886, -0.054649338126182556] >, "kernel" => #Nx.Tensor< f32[1][8] [ - [-0.6269387006759644, 0.3289071023464203, 0.19450749456882477, 0.7400281429290771, 0.23878233134746552, 0.36140456795692444, 0.10503113269805908, 0.3685782253742218] + [0.6524605751037598, -0.3795280158519745, -0.2069108486175537, 0.6815686821937561, -0.5734748840332031, 0.5515486001968384, -0.13509605824947357, -0.711794912815094] ] > }, "dense_1" => %{ "bias" => #Nx.Tensor< f32[4] - [0.2350393682718277, 0.06712433695793152, -0.03675961494445801, -0.06366443634033203] + [0.3078235387802124, -0.24773009121418, -0.027328377589583397, 0.0769796073436737] >, "kernel" => #Nx.Tensor< f32[8][4] [ - [-0.35826751589775085, -0.10699580609798431, -0.3681609034538269, 0.08517063409090042], - [-0.7694831490516663, 0.13644370436668396, -0.2390032261610031, 0.6069303154945374], - [-0.6424086689949036, 0.13374455273151398, -0.35404452681541443, 0.6343701481819153], - [-0.09528166800737381, 0.7048070430755615, 0.13699916005134583, 0.6482889652252197], - [-0.08044164627790451, 0.010588583536446095, 0.11140558868646622, 
0.33911004662513733], - [0.7361723780632019, 0.757600724697113, -0.0011848200811073184, 0.2799053192138672], - [0.3472788631916046, -0.5225644111633301, 0.04859891161322594, -0.4931156039237976], - [0.09371320903301239, 0.5478940606117249, 0.5831385254859924, -0.21019525825977325] + [-0.785156786441803, 0.07306647300720215, 0.339533269405365, -0.2188076674938202], + [0.29139244556427, 0.15977036952972412, 0.6193944215774536, -0.4305708408355713], + [-0.21063144505023956, -0.3738138973712921, -0.27965712547302246, 0.051842525601387024], + [0.7297297716140747, -0.08164620399475098, 0.07651054859161377, -0.43577027320861816], + [0.07917583733797073, -0.27750709652900696, 0.21028375625610352, -0.6430750489234924], + [0.7177602648735046, -0.2743614912033081, -0.5894488096237183, 0.634209156036377], + [0.4251592457294464, 0.6134526133537292, -0.35339266061782837, 0.4966743588447571], + [-0.49672019481658936, 0.46769094467163086, -0.44432300329208374, -0.3249942660331726] ] > }, "dense_2" => %{ "bias" => #Nx.Tensor< f32[1] - [-0.835706889629364] + [-0.8245151042938232] >, "kernel" => #Nx.Tensor< f32[4][1] [ - [1.0109968185424805], - [0.574639618396759], - [-0.01302765030413866], - [-0.008134203962981701] + [0.9500011205673218], + [0.9115968942642212], + [0.39282673597335815], + [0.19936752319335938] ] > } diff --git a/guides/training_and_evaluation/your_first_evaluation_loop.livemd b/guides/training_and_evaluation/your_first_evaluation_loop.livemd index 8941a369..47b1543f 100644 --- a/guides/training_and_evaluation/your_first_evaluation_loop.livemd +++ b/guides/training_and_evaluation/your_first_evaluation_loop.livemd @@ -1,11 +1,8 @@ - - # Your first evaluation loop ```elixir Mix.install([ - {:axon, github: "elixir-nx/axon"}, - {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true} + {:axon, ">= 0.5.0"} ]) ``` @@ -34,7 +31,11 @@ train_loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd) data = Stream.repeatedly(fn -> - xs = Nx.random_normal({8, 1}) + {xs, _next_key} = + :random.uniform(9999) + |> Nx.Random.key() + |> Nx.Random.normal(shape: {8, 1}) + ys = Nx.sin(xs) {xs, ys} end) @@ -45,7 +46,7 @@ trained_model_state = Axon.Loop.run(train_loop, data, %{}, iterations: 1000) ``` -Epoch: 0, Batch: 1000, loss: 0.0348526 +Epoch: 0, Batch: 950, loss: 0.1285532 ``` @@ -55,46 +56,46 @@ Epoch: 0, Batch: 1000, loss: 0.0348526 "dense_0" => %{ "bias" => #Nx.Tensor< f32[8] - [0.12334823608398438, 0.23830991983413696, 0.07463178038597107, -0.18479900062084198, -0.2544017434120178, -0.1100262850522995, 0.04137010499835014, 0.22781872749328613] + [-0.06848274916410446, 0.037988610565662384, -0.199247345328331, 0.18008524179458618, 0.10976515710353851, -0.10479626059532166, 0.562850832939148, -0.030415315181016922] >, "kernel" => #Nx.Tensor< f32[1][8] [ - [-0.7397015690803528, 0.8709579110145569, -0.33129510283470154, -0.4521639347076416, -0.5752679109573364, 0.5516160726547241, -0.1265108585357666, -0.5665484666824341] + [-0.2839881181716919, 0.11133058369159698, -0.5213645100593567, -0.14406965672969818, 0.37532612681388855, -0.28965434432029724, -0.9048429131507874, -5.540614947676659e-4] ] > }, "dense_1" => %{ "bias" => #Nx.Tensor< f32[4] - [7.311657827813178e-5, -0.027584673836827278, 0.20344746112823486, 0.1330498605966568] + [-0.2961483597755432, 0.3721822202205658, -0.1726730614900589, -0.20648165047168732] >, "kernel" => #Nx.Tensor< f32[8][4] [ - [-0.19199007749557495, 0.15660767257213593, 0.5446576476097107, 0.07457015663385391], - [0.034533075988292694, -0.10262273252010345, 
0.05103863775730133, 0.5708968639373779], - [-0.4212855398654938, -0.47742989659309387, 0.18940746784210205, -0.40659299492836], - [0.2127801775932312, -0.07477620989084244, -0.11274989694356918, 0.4552466869354248], - [-0.13839538395404816, 0.09832656383514404, -0.16157560050487518, 0.7074514627456665], - [-0.6366024017333984, 0.3754875361919403, -0.6808919906616211, -0.209626242518425], - [0.595952033996582, 0.6973875164985657, 0.4453340172767639, 0.6247327327728271], - [-0.6312451958656311, 0.33275362849235535, 0.5079866051673889, -0.2508215010166168] + [0.602420449256897, 0.46551579236984253, 0.3295630216598511, 0.484800785779953], + [0.05755739286541939, -0.2412092238664627, 0.27874955534935, 0.13457047939300537], + [-0.26997247338294983, -0.4479314386844635, 0.4976465106010437, -0.05715075880289078], + [-0.7245721220970154, 0.1187945082783699, 0.14330074191093445, 0.3257679343223572], + [-0.032964885234832764, -0.625235915184021, -0.05669135972857475, -0.7016372680664062], + [-0.08433973789215088, -0.07334757596254349, 0.08273869007825851, 0.46893611550331116], + [0.4123252332210541, 0.9876810312271118, -0.3525731563568115, 0.030163511633872986], + [0.6962482333183289, 0.5394620299339294, 0.6907036304473877, -0.5448697209358215] ] > }, "dense_2" => %{ "bias" => #Nx.Tensor< f32[1] - [0.17476916313171387] + [0.7519291043281555] >, "kernel" => #Nx.Tensor< f32[4][1] [ - [0.8893225193023682], - [-0.4548797905445099], - [-0.8288624286651611], - [0.8321414589881897] + [0.7839917540550232], + [-0.8586246967315674], + [0.8599083423614502], + [0.29766184091567993] ] > } @@ -111,6 +112,7 @@ test_loop = Axon.Loop.evaluator(model) ``` #Axon.Loop< + metrics: %{}, handlers: %{ completed: [], epoch_completed: [], @@ -118,13 +120,12 @@ test_loop = Axon.Loop.evaluator(model) epoch_started: [], halted: [], iteration_completed: [ - {#Function<23.20267452/1 in Axon.Loop.log/5>, - #Function<5.20267452/1 in Axon.Loop.build_filter_fn/1>} + {#Function<27.37390314/1 in Axon.Loop.log/3>, + #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>} ], iteration_started: [], started: [] }, - metrics: %{}, ... > ``` @@ -139,6 +140,10 @@ test_loop = test_loop |> Axon.Loop.metric(:mean_absolute_error) ``` #Axon.Loop< + metrics: %{ + "mean_absolute_error" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>, + :mean_absolute_error} + }, handlers: %{ completed: [], epoch_completed: [], @@ -146,16 +151,12 @@ test_loop = test_loop |> Axon.Loop.metric(:mean_absolute_error) epoch_started: [], halted: [], iteration_completed: [ - {#Function<23.20267452/1 in Axon.Loop.log/5>, - #Function<5.20267452/1 in Axon.Loop.build_filter_fn/1>} + {#Function<27.37390314/1 in Axon.Loop.log/3>, + #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>} ], iteration_started: [], started: [] }, - metrics: %{ - "mean_absolute_error" => {#Function<12.6031754/3 in Axon.Metrics.running_average/1>, - :mean_absolute_error} - }, ... 
> ``` @@ -169,7 +170,7 @@ Axon.Loop.run(test_loop, data, trained_model_state, iterations: 1000) ``` -Batch: 1000, mean_absolute_error: 0.0955574 +Batch: 999, mean_absolute_error: 0.0856894 ``` @@ -179,7 +180,7 @@ Batch: 1000, mean_absolute_error: 0.0955574 0 => %{ "mean_absolute_error" => #Nx.Tensor< f32 - 0.09555738419294357 + 0.08568935841321945 > } } diff --git a/guides/training_and_evaluation/your_first_training_loop.livemd b/guides/training_and_evaluation/your_first_training_loop.livemd index a8129ae3..15c1301f 100644 --- a/guides/training_and_evaluation/your_first_training_loop.livemd +++ b/guides/training_and_evaluation/your_first_training_loop.livemd @@ -1,11 +1,8 @@ - - # Your first training loop ```elixir Mix.install([ - {:axon, github: "elixir-nx/axon"}, - {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true} + {:axon, ">= 0.5.0"} ]) ``` @@ -30,7 +27,11 @@ Each entry is a batch of input data with a corresponding batch of labels. You ca ```elixir train_data = Stream.repeatedly(fn -> - xs = Nx.random_normal({8, 1}) + {xs, _next_key} = + :random.uniform(9999) + |> Nx.Random.key() + |> Nx.Random.normal(shape: {8, 1}) + ys = Nx.sin(xs) {xs, ys} end) @@ -39,7 +40,7 @@ train_data = ``` -#Function<50.127921642/2 in Stream.repeatedly/1> +#Function<51.6935098/2 in Stream.repeatedly/1> ``` The most basic supervised training loop in Axon requires 3 things: @@ -80,26 +81,26 @@ loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd) ``` #Axon.Loop< + metrics: %{ + "loss" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>, + #Function<9.37390314/2 in Axon.Loop.build_loss_fn/1>} + }, handlers: %{ completed: [], epoch_completed: [ - {#Function<23.20267452/1 in Axon.Loop.log/5>, - #Function<5.20267452/1 in Axon.Loop.build_filter_fn/1>} + {#Function<27.37390314/1 in Axon.Loop.log/3>, + #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>} ], epoch_halted: [], epoch_started: [], halted: [], iteration_completed: [ - {#Function<23.20267452/1 in Axon.Loop.log/5>, - #Function<3.20267452/1 in Axon.Loop.build_filter_fn/1>} + {#Function<27.37390314/1 in Axon.Loop.log/3>, + #Function<64.37390314/2 in Axon.Loop.build_filter_fn/1>} ], iteration_started: [], started: [] }, - metrics: %{ - "loss" => {#Function<12.17233431/3 in Axon.Metrics.running_average/1>, - #Function<6.20267452/2 in Axon.Loop.build_loss_fn/1>} - }, ... 
> ``` @@ -113,7 +114,7 @@ Axon.Loop.run(loop, train_data, %{}, iterations: 1000) ``` -Epoch: 0, Batch: 1000, loss: 0.0421094 +Epoch: 0, Batch: 950, loss: 0.0563023 ``` @@ -123,46 +124,46 @@ Epoch: 0, Batch: 1000, loss: 0.0421094 "dense_0" => %{ "bias" => #Nx.Tensor< f32[8] - [0.18567155301570892, -0.24138866364955902, 0.13732704520225525, 0.2081741988658905, 0.013805730268359184, 0.18336650729179382, 0.07754829525947571, -0.12579604983329773] + [-0.038592107594013214, 0.19925688207149506, -0.08018972724676132, -0.11267539858818054, 0.35166260600090027, -0.0794963389635086, 0.20298318564891815, 0.3049686849117279] >, "kernel" => #Nx.Tensor< f32[1][8] [ - [0.06517036259174347, -0.7166120409965515, 0.649202823638916, -0.3636767566204071, 0.33472830057144165, -0.6622008681297302, -0.6205887198448181, -0.1951046586036682] + [-0.06691190600395203, -0.32860732078552246, 0.22386932373046875, 0.16137443482875824, 0.23626506328582764, 0.2438151240348816, 0.2662005126476288, 0.32266947627067566] ] > }, "dense_1" => %{ "bias" => #Nx.Tensor< f32[4] - [0.2652607262134552, 0.1563350260257721, -0.12963515520095825, -0.15289783477783203] + [0.03138260543346405, 0.2621246576309204, 0.021843062713742256, -0.07498764991760254] >, "kernel" => #Nx.Tensor< f32[8][4] [ - [0.5483533143997192, 0.16270962357521057, -0.29001912474632263, 0.16584330797195435], - [-0.3257339596748352, 0.6900827884674072, 0.17480286955833435, -0.5176011323928833], - [-0.5791758298873901, 0.7136418223381042, 0.2863248288631439, 0.2406335324048996], - [0.5999854803085327, -0.09972921013832092, 0.16846133768558502, 0.21690420806407928], - [0.10213596373796463, 0.01878557913005352, 0.03252492845058441, -0.25937923789024353], - [0.4094444811344147, -0.48399242758750916, 0.18455447256565094, 0.40939682722091675], - [0.2809498906135559, 0.7121831178665161, 0.42944926023483276, -0.4959437847137451], - [-0.21076196432113647, -0.3021833896636963, -0.46126121282577515, -0.5571116805076599] + [0.541576087474823, 0.4923045039176941, 0.5933979749679565, -0.5083895921707153], + [0.5120893120765686, -0.6925638318061829, 0.36635661125183105, -0.05748361349105835], + [0.26158788800239563, -0.1788359135389328, -0.14064575731754303, -0.08323567360639572], + [0.6685130596160889, -0.4880330264568329, 0.5104460120201111, -0.3399733006954193], + [-0.6356683969497681, 0.770803689956665, -0.3876360058784485, -0.5178110599517822], + [0.4476216733455658, -0.21042484045028687, -0.4300518333911896, -0.2693784534931183], + [0.08789066225290298, 0.47043612599372864, 0.02871485985815525, 0.6908602714538574], + [0.45776790380477905, 0.6735268235206604, 0.40828803181648254, 0.19558420777320862] ] > }, "dense_2" => %{ "bias" => #Nx.Tensor< f32[1] - [0.3293934762477875] + [-0.748963475227356] >, "kernel" => #Nx.Tensor< f32[4][1] [ - [-1.041453242301941], - [0.6521084308624268], - [-0.5688052773475647], - [-0.5789349675178528] + [-0.22219088673591614], + [1.1391150951385498], + [-0.13221295177936554], + [-0.27904900908470154] ] > } @@ -180,9 +181,9 @@ Axon.Loop.run(loop, train_data, %{}, epochs: 3, iterations: 500) ``` -Epoch: 0, Batch: 500, loss: 0.0376754 -Epoch: 1, Batch: 500, loss: 0.0300909 -Epoch: 2, Batch: 500, loss: 0.0260511 +Epoch: 0, Batch: 450, loss: 0.0935063 +Epoch: 1, Batch: 450, loss: 0.0576384 +Epoch: 2, Batch: 450, loss: 0.0428323 ``` @@ -192,46 +193,46 @@ Epoch: 2, Batch: 500, loss: 0.0260511 "dense_0" => %{ "bias" => #Nx.Tensor< f32[8] - [-0.09743800014257431, 0.36350908875465393, 0.23338767886161804, 0.21299506723880768, -0.04753172770142555, 
-0.03144805133342743, 0.0230794008821249, -0.17029045522212982] + [-0.035534460097551346, 0.2604885697364807, -0.10573504120111465, -0.16461455821990967, 0.3610309064388275, -0.10921606421470642, 0.2061888873577118, 0.3162775933742523] >, "kernel" => #Nx.Tensor< f32[1][8] [ - [-0.14422392845153809, -0.3840259611606598, 0.7611677050590515, 0.1216919794678688, -0.4270862638950348, 0.43146076798439026, -0.3569082021713257, 0.4051334857940674] + [-0.05344606190919876, -0.3463115096092224, 0.23782028257846832, 0.20592278242111206, 0.2195105254650116, 0.2618684470653534, 0.2559347450733185, 0.3006669282913208] ] > }, "dense_1" => %{ "bias" => #Nx.Tensor< f32[4] - [0.21392156183719635, 0.02405611053109169, 0.2970339059829712, 0.02390623465180397] + [0.03086121939122677, 0.28601887822151184, 0.02634759061038494, -0.08197703212499619] >, "kernel" => #Nx.Tensor< f32[8][4] [ - [-0.12441369146108627, 0.44625332951545715, -0.2095455527305603, -0.28127536177635193], - [0.6052687764167786, 0.1358352154493332, -0.24579593539237976, 0.6278529167175293], - [-0.5855410695075989, 0.014370989985764027, 0.4479483664035797, -0.07460466772317886], - [0.5286814570426941, -0.6323351263999939, 0.4167028069496155, -0.4724753797054291], - [-0.3705250918865204, 0.41602230072021484, -0.626926600933075, -0.03850430250167847], - [0.22140666842460632, -0.6492624878883362, 0.09525017440319061, 0.3179352283477783], - [-0.27787405252456665, 0.43634578585624695, 0.2430884689092636, 0.18133315443992615], - [0.4248749911785126, -0.059922583401203156, -0.09462974965572357, 0.57406085729599] + [0.5404174327850342, 0.49248307943344116, 0.5927202701568604, -0.5083895921707153], + [0.5133915543556213, -0.7197086811065674, 0.3669036030769348, -0.057483553886413574], + [0.26609811186790466, -0.20234307646751404, -0.14102067053318024, -0.08141336590051651], + [0.673393964767456, -0.512398362159729, 0.5106634497642517, -0.3384905159473419], + [-0.6347945928573608, 0.7695014476776123, -0.3877493143081665, -0.5186421275138855], + [0.45236992835998535, -0.2351287305355072, -0.4305106997489929, -0.2674770951271057], + [0.08871842920780182, 0.46521952748298645, 0.02729635499417782, 0.691332221031189], + [0.4584391117095947, 0.6687410473823547, 0.4068295657634735, 0.19576647877693176] ] > }, "dense_2" => %{ "bias" => #Nx.Tensor< f32[1] - [0.015223611146211624] + [-0.7425869703292847] >, "kernel" => #Nx.Tensor< f32[4][1] [ - [-0.6736029386520386], - [-0.019722800701856613], - [0.932664692401886], - [-0.9208926558494568] + [-0.24965399503707886], + [1.1746525764465332], + [-0.12984804809093475], + [-0.2796761095523834] ] > } @@ -249,7 +250,7 @@ model ``` -Epoch: 0, Batch: 1000, loss: 0.0700251 +Epoch: 0, Batch: 900, loss: 0.1492715 ``` @@ -259,46 +260,46 @@ Epoch: 0, Batch: 1000, loss: 0.0700251 "dense_0" => %{ "bias" => #Nx.Tensor< f32[8] - [-0.10562735795974731, 0.3525764048099518, -0.0731351301074028, 0.3316117525100708, -0.08621923625469208, 0.15377338230609894, 0.02795499749481678, 0.19813594222068787] + [0.09267199039459229, 0.5775123834609985, -0.07691138982772827, 0.04283804073929787, -0.015639742836356163, -0.0725373700261116, -0.10598818212747574, 0.021243896335363388] >, "kernel" => #Nx.Tensor< f32[1][8] [ - [0.46547073125839233, -0.3838779926300049, 0.06413891166448593, 0.6604263186454773, 0.09603694081306458, -0.3142688274383545, -0.0673874095082283, -0.1551232486963272] + [0.07886508852243423, 0.826379120349884, 0.1022031158208847, -0.5164816975593567, 0.390212744474411, 0.2709604799747467, -0.05409134551882744, 
-0.6204537749290466] ] > }, "dense_1" => %{ "bias" => #Nx.Tensor< f32[4] - [0.16770508885383606, -0.11785938590765, -0.08730955421924591, 0.18854482471942902] + [-0.09577611088752747, 0.3303026556968689, -0.25102874636650085, -0.3312375247478485] >, "kernel" => #Nx.Tensor< f32[8][4] [ - [-0.32443270087242126, 0.33927711844444275, 0.5110990405082703, -0.34353166818618774], - [0.6843343377113342, -0.09189904481172562, 0.4550926983356476, -0.27025723457336426], - [0.029612643644213676, 0.3680649697780609, 0.5105444192886353, -0.1120513379573822], - [-0.12359219789505005, -0.2177252620458603, -0.2753210961818695, 0.7462171912193298], - [0.2723115086555481, 0.39580288529396057, -0.41799622774124146, 0.003858723910525441], - [0.21861012279987335, -0.37737029790878296, -0.5444738268852234, -0.12978340685367584], - [0.12569139897823334, 0.09505560994148254, 0.13603702187538147, 0.20154744386672974], - [0.4721740484237671, 0.27258655428886414, -0.6905713677406311, 0.09732398390769958] + [0.5508446097373962, -0.03904113546013832, 0.382876992225647, -0.6273598670959473], + [0.13289013504981995, 0.947068452835083, -0.27359727025032043, 0.4073275923728943], + [-0.10011858493089676, -0.32976964116096497, -0.3160743713378906, -0.3586210012435913], + [-0.628970205783844, -0.19567319750785828, -0.07241304218769073, -0.43270331621170044], + [-0.6155693531036377, -0.020595157518982887, -0.3254905045032501, 0.18614870309829712], + [-0.07561944425106049, -0.34477049112319946, -0.30149057507514954, -0.6603768467903137], + [-0.17559891939163208, -0.2768605649471283, 0.5830116868019104, 0.11386138200759888], + [-0.6376093626022339, -0.31125709414482117, 0.2749727964401245, -0.6777774691581726] ] > }, "dense_2" => %{ "bias" => #Nx.Tensor< f32[1] - [0.2536466121673584] + [-0.767456591129303] >, "kernel" => #Nx.Tensor< f32[4][1] [ - [-0.9850672483444214], - [-0.5319440960884094], - [-0.8099393844604492], - [0.6502916216850281] + [-0.3530634641647339], + [0.9497018456459045], + [0.31334763765335083], + [-0.624195396900177] ] > } diff --git a/lib/axon.ex b/lib/axon.ex index 45845a60..98a9de97 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -175,7 +175,7 @@ defmodule Axon do {init_fn, predict_fn} = Axon.build(model) - init_fn.(Nx.template({1, 1}, {:f, 32}), %{}) + params = init_fn.(Nx.template({1, 1}, {:f, 32}), %{}) predict_fn.(params, inputs) You may either set the default JIT compiler or backend globally, or @@ -185,7 +185,7 @@ defmodule Axon do {init_fn, predict_fn} = Axon.build(model, compiler: EXLA, mode: :train) - init_fn.(Nx.template({1, 1}, {:f, 32}), %{}) + params = init_fn.(Nx.template({1, 1}, {:f, 32}), %{}) predict_fn.(params, inputs) `predict_fn` by default runs in inference mode, which performs certain @@ -209,11 +209,72 @@ defmodule Axon do model_state = model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adamw(0.005)) + |> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adamw(learning_rate: 0.005)) |> Axon.Loop.run(train_data, epochs: 10, compiler: EXLA) - See `Axon.Updates` and `Axon.Loop` for a more in-depth treatment of + See `Polaris.Updates` and `Axon.Loop` for a more in-depth treatment of model optimization and model training. + + ## Using with `Nx.Serving` + + When deploying an `Axon` model to production, you usually want to batch + multiple prediction requests and run the inference for all of them at + once. Conveniently, `Nx` already has an abstraction for this task in the + form of `Nx.Serving`. 
Here's how you could define a serving for an `Axon` + model: + + def build_serving() do + # Configuration + batch_size = 4 + defn_options = [compiler: EXLA] + + Nx.Serving.new( + # This function runs on the serving startup + fn -> + # Build the Axon model and load params (usually from file) + model = build_model() + params = load_params() + + # Build the prediction defn function + {_init_fun, predict_fun} = Axon.build(model) + + inputs_template = %{"pixel_values" => Nx.template({batch_size, 224, 224, 3}, :f32)} + template_args = [Nx.to_template(params), inputs_template] + + # Compile the prediction function upfront for the configured batch_size + predict_fun = Nx.Defn.compile(predict_fun, template_args, defn_options) + + # The returned function is called for every accumulated batch + fn inputs -> + inputs = Nx.Batch.pad(inputs, batch_size - inputs.size) + predict_fun.(params, inputs) + end + end, + batch_size: batch_size + ) + end + + Then you would start the serving server as part of your application's + supervision tree: + + children = [ + ..., + {Nx.Serving, serving: build_serving(), name: MyApp.Serving, batch_timeout: 100} + ] + + With that in place, you can now ask serving for predictions all across + your application (controllers, live views, async jobs, etc.). Having a + tensor input you would do: + + inputs = %{"pixel_values" => ...} + batch = Nx.Batch.concatenate([inputs]) + result = Nx.Serving.batched_run(MyApp.Serving, batch) + + Usually you also want to do pre/post-processing of the model input/output. + You could make those preparations directly before/after `Nx.Serving.batched_run/2`, + however you can also make use of `Nx.Serving.client_preprocessing/2` and + `Nx.Serving.client_postprocessing/2` to encapsulate that logic as part of + the serving. """ alias __MODULE__, as: Axon alias Axon.Parameter @@ -381,21 +442,6 @@ defmodule Axon do layer(:input, [], name: name, shape: output_shape, op_name: :input, optional: optional) end - # TODO: remove on Axon v0.3 - - def input(input_shape, name) when is_binary(name) do - IO.warn( - "Passing shape as an argument to Axon.input/2 is deprecated, pass it as an option instead" - ) - - input(name, [{:shape, input_shape}]) - end - - @deprecated "Pass the shape as an option to Axon.input/2" - def input(input_shape, name, opts) when is_binary(name) do - input(name, [{:shape, input_shape} | opts]) - end - @doc """ Wraps an Axon model in an optional node. @@ -629,6 +675,69 @@ defmodule Axon do layer(:namespace, [axon], name: name) end + @doc """ + Returns a function which represents a self-contained re-usable block + of operations in a neural network. All parameters in the block are + shared between every usage of the block. + + This returns an arity-1 function which accepts a list of inputs which + are forwarded to `fun`. 
This is most often used in situations where + you wish to re-use parameters in a block: + + reused_dense = Axon.block(&Axon.dense(&1, 32)) + + Everytime `reused_dense` is invoked, it re-uses the same parameters: + + input = Axon.input("features") + # unique parameters + x1 = Axon.dense(input, 32) + # unique parameters + x2 = reused_dense.(x1) + # parameters shared + x3 = reused_dense.(x2) + + Subgraphs in blocks can be arbitrarily complex: + + reused_block = Axon.block(fn x -> + x + |> Axon.dense(32) + |> Axon.dense(64) + |> Axon.dense(32) + end) + + Blocks can also have multiple inputs, you can invoke a block with multiple + inputs by passing a list of arguments: + + reused_block = Axon.block(fn x, y, z -> + x = Axon.dense(x, 32) + y = Axon.dense(y, 32) + z = Axon.dense(z, 32) + + Axon.add([x, y, z]) + end) + + # invoke with a list + reused_block.([x, y, z]) + + Blocks prefix subgraph parameters with their name and a dot. As with other + Axon layers, if a name is not explicitly provided, one will be dynamically + generated. + """ + @doc type: :special + def block(fun, opts \\ []) when is_function(fun) do + opts = Keyword.validate!(opts, [:name]) + block_id = System.unique_integer([:positive, :monotonic]) + + fn inputs -> + layer(:block, List.wrap(inputs), + op_name: :block, + name: opts[:name], + block_fun: fun, + block_id: block_id + ) + end + end + @doc """ Adds a dense layer to the network. @@ -1521,6 +1630,39 @@ defmodule Axon do layer(pool, [x], opts) end + @doc """ + Adds a blur pooling layer to the network. + + See `Axon.Layers.blur_pool/2` for more details. + + ## Options + + * `:name` - layer name. + + * `:strides` - stride during convolution. Defaults to `1`. + + * `:channels` - channels location. One of `:first` or `:last`. + Defaults to `:last`. + """ + def blur_pool(%Axon{} = x, opts \\ []) do + opts = + Keyword.validate!(opts, [ + :name, + channels: :last + ]) + + channels = opts[:channels] + name = opts[:name] + + opts = [ + name: name, + channels: channels, + op_name: :blur_pool + ] + + layer(:blur_pool, [x], opts) + end + ## Adaptive Pooling @adaptive_pooling_layers [ @@ -1667,7 +1809,7 @@ defmodule Axon do * `:channel_index` - input feature index used for calculating mean and variance. Defaults to `-1`. - * `:epsilon` - numerical stability term. + * `:epsilon` - numerical stability term. Defaults to `1.0e-5`. """ @doc type: :normalization @@ -2182,6 +2324,30 @@ defmodule Axon do List.to_tuple(splits) end + @doc """ + Computes a sequence mask according to the given EOS token. + + Masks can be propagated to recurrent layers or custom layers to + indicate that a given token should be ignored in processing. This + is useful when you have sequences of variable length. + + Most commonly, `eos_token` is `0`. + + ## Options + + * `:name` - layer name. + """ + @doc type: :recurrent + def mask(%Axon{} = input, eos_token, opts \\ []) when is_integer(eos_token) do + opts = Keyword.validate!(opts, [:name]) + + fun = fn x, opts -> + Nx.equal(Nx.as_type(x, :s64), opts[:eos_token]) + end + + layer(fun, [input], eos_token: eos_token, op_name: :mask, name: opts[:name]) + end + @doc """ Applies the given forward function bidirectionally and merges the results with the given merge function. 
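The `mask/3` layer above pairs with the `:mask` option that the following hunks thread through `lstm`, `gru`, and `conv_lstm`. A minimal sketch of wiring the two together; the vocabulary size, embedding size, and unit counts are placeholders, and `0` is assumed to be the padding/EOS token:

    input = Axon.input("tokens")

    # 1 where the token equals the EOS/padding token, 0 elsewhere
    mask = Axon.mask(input, 0)

    {sequence, _state} =
      input
      |> Axon.embedding(1_000, 32)
      |> Axon.lstm(32, mask: mask)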
@@ -2292,16 +2458,17 @@ defmodule Axon do unroll: :dynamic, use_bias: true, kernel_initializer: :glorot_uniform, - bias_initializer: :zeros + bias_initializer: :zeros, + mask: Axon.constant(0) ]) activation = opts[:activation] gate = opts[:gate] unroll = opts[:unroll] - input_kernel_shape = fn inp, _ -> Axon.Shape.rnn_input_kernel(inp, units, :lstm) end - hidden_kernel_shape = fn inp, _ -> Axon.Shape.rnn_hidden_kernel(inp, units, :lstm) end - bias_shape = fn inp, _ -> Axon.Shape.rnn_bias(inp, units, :lstm) end + input_kernel_shape = fn inp, _, _ -> Axon.Shape.rnn_input_kernel(inp, units, :lstm) end + hidden_kernel_shape = fn inp, _, _ -> Axon.Shape.rnn_hidden_kernel(inp, units, :lstm) end + bias_shape = fn inp, _, _ -> Axon.Shape.rnn_bias(inp, units, :lstm) end kernel_initializer = opts[:kernel_initializer] @@ -2336,9 +2503,9 @@ defmodule Axon do bias = param("bias", {:tuple, List.duplicate(bias_shape, 4)}, initializer: bias_initializer) - {[x, hidden_state, input_kernel, hidden_kernel, bias], :lstm} + {[x, hidden_state, opts[:mask], input_kernel, hidden_kernel, bias], :lstm} else - {[x, hidden_state, input_kernel, hidden_kernel], &Axon.Layers.lstm/5} + {[x, hidden_state, opts[:mask], input_kernel, hidden_kernel], &Axon.Layers.lstm/6} end output = @@ -2484,6 +2651,7 @@ defmodule Axon do opts = Keyword.validate!(opts, [ :name, + mask: Axon.constant(0), activation: :tanh, gate: :sigmoid, unroll: :dynamic, @@ -2496,9 +2664,9 @@ defmodule Axon do gate = opts[:gate] unroll = opts[:unroll] - input_kernel_shape = fn inp, _ -> Axon.Shape.rnn_input_kernel(inp, units, :gru) end - hidden_kernel_shape = fn inp, _ -> Axon.Shape.rnn_hidden_kernel(inp, units, :gru) end - bias_shape = fn inp, _ -> Axon.Shape.rnn_bias(inp, units, :gru) end + input_kernel_shape = fn inp, _, _ -> Axon.Shape.rnn_input_kernel(inp, units, :gru) end + hidden_kernel_shape = fn inp, _, _ -> Axon.Shape.rnn_hidden_kernel(inp, units, :gru) end + bias_shape = fn inp, _, _ -> Axon.Shape.rnn_bias(inp, units, :gru) end kernel_initializer = opts[:kernel_initializer] @@ -2532,9 +2700,9 @@ defmodule Axon do bias = param("bias", {:tuple, List.duplicate(bias_shape, 4)}, initializer: bias_initializer) - [x, hidden_state, input_kernel, hidden_kernel, bias] + [x, hidden_state, opts[:mask], input_kernel, hidden_kernel, bias] else - [x, hidden_state, input_kernel, hidden_kernel] + [x, hidden_state, opts[:mask], input_kernel, hidden_kernel] end output = @@ -2666,6 +2834,7 @@ defmodule Axon do opts = Keyword.validate!(opts, [ :name, + mask: Axon.constant(0), padding: :same, kernel_size: 1, strides: 1, @@ -2681,17 +2850,17 @@ defmodule Axon do unroll = opts[:unroll] kernel_initializer = opts[:kernel_initializer] - hidden_kernel_shape = fn _, {inp, _} -> + hidden_kernel_shape = fn _, {inp, _}, _ -> shape = Tuple.delete_at(inp, 1) Axon.Shape.conv_kernel(shape, 4 * units, kernel_size, :first, 1) end - input_kernel_shape = fn inp, _ -> + input_kernel_shape = fn inp, _, _ -> shape = Tuple.delete_at(inp, 1) Axon.Shape.conv_kernel(shape, 4 * units, kernel_size, :first, 1) end - bias_shape = fn inp, _ -> + bias_shape = fn inp, _, _ -> shape = Tuple.delete_at(inp, 1) Axon.Shape.conv_bias(shape, 4 * units, kernel_size, :first, 1) end @@ -2716,9 +2885,9 @@ defmodule Axon do if opts[:use_bias] do bias_initializer = opts[:bias_initializer] b = param("bias", {:tuple, [bias_shape]}, initializer: bias_initializer) - {[x, hidden_state, wi, wh, b], :conv_lstm} + {[x, hidden_state, opts[:mask], wi, wh, b], :conv_lstm} else - {[x, hidden_state, wi, wh], :conv_lstm} + 
{[x, hidden_state, opts[:mask], wi, wh], :conv_lstm} end output = @@ -2938,7 +3107,7 @@ defmodule Axon do |> Axon.dense(1000, activation: :softmax) model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(0.005)) + |> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adam(learning_rate: 0.005)) |> Axon.Loop.run(data, epochs: 10) When compiled, frozen parameters are wrapped in `Nx.Defn.Kernel.stop_grad/1`, @@ -3010,7 +3179,7 @@ defmodule Axon do |> Axon.unfreeze(up: 25) model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(0.0005)) + |> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adam(learning_rate: 0.0005)) |> Axon.Loop.run(data, epochs: 10) When compiled, frozen parameters are wrapped in `Nx.Defn.Kernel.stop_grad/1`, @@ -3071,19 +3240,25 @@ defmodule Axon do """ @doc type: :debug - def attach_hook(%Axon{output: id, nodes: nodes} = axon, fun, opts \\ []) do - opts = Keyword.validate!(opts, on: :forward, mode: :both) - on_event = opts[:on] - mode = opts[:mode] + def attach_hook(x, fun, opts \\ []) + def attach_hook(%Axon{output: id, nodes: nodes} = axon, fun, opts) do updated_nodes = Map.update!(nodes, id, fn axon_node -> - %{axon_node | hooks: [{on_event, mode, fun}]} + attach_hook(axon_node, fun, opts) end) %{axon | nodes: updated_nodes} end + def attach_hook(%Axon.Node{hooks: hooks} = axon_node, fun, opts) do + opts = Keyword.validate!(opts, on: :forward, mode: :both) + on_event = opts[:on] + mode = opts[:mode] + + %{axon_node | hooks: [{on_event, mode, fun} | hooks]} + end + ## Graph Manipulation and Utilities # TODO: Revisit later with new decoupled structs @@ -3247,12 +3422,12 @@ defmodule Axon do you can use this function to visualize intermediate activations of all convolutional layers in a model: - instrumented_model = Axon. (model, fn - %Axon{op: :conv} = graph -> - Axon.attach_hook(graph, &visualize_activations/1) + instrumented_model = Axon.map_nodes(model, fn + %Axon.Node{op: :conv} = axon_node -> + Axon.attach_hook(axon_node, &visualize_activations/1) - graph -> - graph + axon_node -> + axon_node end) Another use case is to replace entire classes of layers @@ -3396,7 +3571,10 @@ defmodule Axon do @doc """ Builds the given model to `{init_fn, predict_fn}`. - Once built, a model can be passed as argument to `Nx.Defn`. + The given functions can be either given as arguments to `Nx.Defn` + functions or be invoked directly, to perform just-in-time compilation + and execution. If you want to compile the model (instead of just-in-time) + based on a predefined initialization shape, see `compile/4`. ## `init_fn` @@ -3416,19 +3594,29 @@ defmodule Axon do ## Options - * `:mode` - one of `:inference` or `:train`. Forwarded to layers - to control differences in compilation at training or inference time. - Defaults to `:inference` + * `:compiler` - the underlying `Nx.Defn` compiler to perform + JIT compilation when the functions are invoked. If none is + passed, it uses the default compiler configured in `Nx.Defn`; * `:debug` - if `true`, will log graph traversal and generation metrics. Also forwarded to JIT if debug mode is available for your chosen compiler or backend. Defaults to `false` - All other options are forwarded to the default JIT compiler - or backend. + * `:mode` - one of `:inference` or `:train`. Forwarded to layers + to control differences in compilation at training or inference time. + Defaults to `:inference` + + All other options are forwarded to the underlying JIT compiler. 
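  For example, assuming `model`, `template`, and `inputs` are defined
  elsewhere:

      {init_fn, predict_fn} = Axon.build(model, compiler: EXLA)

      # initialize parameters from an input template
      params = init_fn.(template, %{})

      # run the model; unknown options above were forwarded to the JIT compiler
      predict_fn.(params, inputs)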
""" @doc type: :model def build(model, opts \\ []) when is_list(opts) do + if opts[:backend] do + IO.warn( + "the :backend option has no effect on Axon.build/2. " <> + "Use Nx.default_backend/1 to set a backend instead" + ) + end + {init_fn, predict_fn} = Axon.Compiler.build(model, opts) opts = [on_conflict: :reuse] ++ opts {Nx.Defn.jit(init_fn, opts), Nx.Defn.jit(predict_fn, opts)} @@ -3445,8 +3633,11 @@ defmodule Axon do This function makes use of the built-in `Nx.Defn.compile/3`. Note that passing inputs which differ in shape or type from the templates - provided to this function will result in potentially expensive - recompilation. + provided to this function will result in a crash. + + ## Options + + It accepts the same options as `build/2`. """ @doc type: :model def compile(model, template, init_params \\ %{}, opts \\ []) when is_list(opts) do @@ -3554,8 +3745,10 @@ defmodule Axon do end @doc """ - Compiles and runs the given Axon model with `params` on - `input` with the given compiler options. + Builds and runs the given Axon `model` with `params` and `input`. + + This is equivalent to calling `build/2` and then invoking the + predict function. ## Options diff --git a/lib/axon/activations.ex b/lib/axon/activations.ex index 722ba0c8..440d9cd4 100644 --- a/lib/axon/activations.ex +++ b/lib/axon/activations.ex @@ -85,18 +85,16 @@ defmodule Axon.Activations do """ defn celu(x, opts \\ []) do opts = keyword!(opts, alpha: 1.0) - - transform( - opts[:alpha], - fn x -> - if Elixir.Kernel.==(x, 0), - do: raise(ArgumentError, ":alpha must be non-zero in CELU activation") - end - ) + validate_celu_alpha!(opts[:alpha]) Nx.select(Nx.greater(x, 0), x, opts[:alpha] * Nx.expm1(x / opts[:alpha])) end + deftransformp validate_celu_alpha!(alpha) do + if alpha == 0, + do: raise(ArgumentError, ":alpha must be non-zero in CELU activation") + end + @doc ~S""" Exponential linear unit activation. @@ -361,22 +359,22 @@ defmodule Axon.Activations do iex> Axon.Activations.log_sumexp(Nx.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], names: [:data])) #Nx.Tensor< f32[data: 1] - [0.45776283740997314] + [3.4577627182006836] > iex> Axon.Activations.log_sumexp(Nx.tensor([[-1.0, -2.0, -3.0], [1.0, 2.0, 3.0]], type: {:bf, 16}, names: [:batch, :data])) #Nx.Tensor< bf16[batch: 2][data: 1] [ - [0.404296875], - [0.404296875] + [-0.59375], + [3.390625] ] > """ defn log_sumexp(x, opts \\ []) do opts = keyword!(opts, axis: -1) - axes = transform(opts[:axis], &List.wrap/1) + axes = wrap(opts[:axis]) # This is a scaling term designed to prevent over/under flow when x is very # large. Consider cases where the intermediate value e^x with large positive @@ -392,7 +390,8 @@ defmodule Axon.Activations do # We are essentially treating the max value as a constant term, C. Thus there # is no need to differentiate through the max. See also: https://github.com/google/jax/pull/2260 # for a note on performance. 
- max_val = stop_grad(Nx.reduce_max(x, axes: axes, keep_axes: true)) + max_val = Nx.reduce_max(x, axes: axes, keep_axes: true) + max_val = stop_grad(Nx.select(Nx.is_infinity(max_val), 0, max_val)) stable_exp = x @@ -403,6 +402,7 @@ defmodule Axon.Activations do stable_exp |> Nx.sum(axes: axes, keep_axes: true) |> Nx.log() + |> Nx.add(max_val) res end @@ -457,12 +457,6 @@ defmodule Axon.Activations do defn log_softmax(x, opts \\ []) do opts = keyword!(opts, axis: -1) - transform({x, opts}, fn {x, opts} -> - if Elixir.Kernel.<=(Nx.rank(x), opts[:axis]) do - raise ArgumentError, "log_softmax axis must be within rank of tensor" - end - end) - shifted = x - stop_grad(Nx.reduce_max(x, axes: [opts[:axis]], keep_axes: true)) shifted @@ -525,7 +519,8 @@ defmodule Axon.Activations do defn relu(x) do custom_grad( Nx.max(x, 0), - fn _ans, g -> [{x, Nx.select(Nx.greater(x, 0), g, Nx.broadcast(0, g))}] end + [x], + fn g -> [Nx.select(Nx.greater(x, 0), g, Nx.broadcast(0, g))] end ) end @@ -593,7 +588,7 @@ defmodule Axon.Activations do defn sigmoid(x) do # Cache logits so they are available in certain calculations, # e.g. binary_cross_entropy and categorical_cross_entropy - transform(Nx.sigmoid(x), &Nx.Defn.Expr.metadata(&1, %{logits: x})) + cache_logits(x, Nx.sigmoid(x)) end @doc ~S""" @@ -707,13 +702,7 @@ defmodule Axon.Activations do """ defn softmax(x, opts \\ []) do opts = keyword!(opts, axis: -1) - axes = transform(opts[:axis], &List.wrap/1) - - transform({x, axes}, fn {x, axes} -> - Enum.each(axes, fn axis -> - Nx.Shape.normalize_axis(Nx.shape(x), axis, Nx.names(x)) - end) - end) + axes = wrap(opts[:axis]) # This is a scaling term designed to prevent over/under flow when x is very # large. Consider cases where the intermediate value e^x with large positive @@ -744,7 +733,7 @@ defmodule Axon.Activations do # Cache logits so they are available in certain calculations, # e.g. 
binary_cross_entropy and categorical_cross_entropy - transform(res, &Nx.Defn.Expr.metadata(&1, %{logits: x})) + cache_logits(x, res) end @doc ~S""" @@ -836,4 +825,12 @@ defmodule Axon.Activations do """ defn tanh(x), do: Nx.tanh(x) + + ## Helpers + + deftransformp cache_logits(input, output) do + Nx.Defn.Expr.metadata(output, %{logits: input}) + end + + deftransformp wrap(axis), do: List.wrap(axis) end diff --git a/lib/axon/compiler.ex b/lib/axon/compiler.ex index 9999dc4a..4da5c7cf 100644 --- a/lib/axon/compiler.ex +++ b/lib/axon/compiler.ex @@ -6,13 +6,29 @@ defmodule Axon.CompileError do formatted_mfa = Exception.format_mfa(module, fun, arity) formatted_msg = Exception.format(:error, exception.exception, exception.compile_stacktrace) + layer_info = + if exception.layer_stacktrace != [] do + """ + + The layer was defined at: + + #{Exception.format_stacktrace(exception.layer_stacktrace)} + """ + else + """ + + (pass debug: true to build/compile see where the layer was defined) + + """ + end + """ exception found when compiling layer #{formatted_mfa} named #{exception.name}: - #{indent(formatted_msg)} - The layer was defined at: + #{indent(formatted_msg)}\ + + #{layer_info}\ - #{Exception.format_stacktrace(exception.layer_stacktrace)} Compiling of the model was initiated at: """ end @@ -36,9 +52,9 @@ defmodule Axon.Compiler do seed = Keyword.get_lazy(opts, :seed, fn -> :erlang.system_time() end) config = %{mode: mode, debug?: debug?} - {time, {root_id, {cache, _op_counts}}} = + {time, {root_id, {cache, _op_counts, _block_cache}}} = :timer.tc(fn -> - to_model_funs(id, nodes, {%{}, %{}}, config) + to_model_funs(id, nodes, {%{}, %{}, %{}}, config) end) if debug? do @@ -186,8 +202,8 @@ defmodule Axon.Compiler do %{^key => %{} = nested} when not is_struct(nested) -> %{params | key => merge_params!(nested, value)} - %{^key => _} -> - %{params | key => value} + %{^key => template} -> + %{params | key => merge_type(key, template, value)} _ -> Logger.warning("found unexpected key in the initial parameters map: #{inspect(key)}") @@ -196,6 +212,19 @@ defmodule Axon.Compiler do end) end + defp merge_type(key, template, value) do + if Nx.type(template) != Nx.type(value) do + Logger.warning( + "initial type for parameter #{key} does not match policy," <> + " consider using Axon.MixedPrecision.cast before passing" <> + " initial state to model initialization function to avoid" <> + " type casts" + ) + end + + Nx.as_type(value, Nx.type(template)) + end + def compile(graph, _opts) do raise ArgumentError, "attempting to compile model functions from" <> @@ -204,17 +233,17 @@ defmodule Axon.Compiler do " output, use `Axon.container`" end - defp to_model_funs(id, nodes, {cache, op_counts}, config) do + defp to_model_funs(id, nodes, {cache, op_counts, block_cache}, config) do case cache do %{^id => {int_id, _}} -> - {int_id, {cache, op_counts}} + {int_id, {cache, op_counts, block_cache}} %{} -> - {id, model_funs, cache, op_counts} = - recur_model_funs(nodes[id], nodes, {cache, op_counts}, config) + {id, model_funs, cache, op_counts, block_cache} = + recur_model_funs(nodes[id], nodes, {cache, op_counts, block_cache}, config) int_id = map_size(cache) - {int_id, {Map.put(cache, id, {int_id, model_funs}), op_counts}} + {int_id, {Map.put(cache, id, {int_id, model_funs}), op_counts, block_cache}} end end @@ -256,11 +285,12 @@ defmodule Axon.Compiler do defp recur_model_funs( %Axon.Node{id: id, mode: node_mode, parent: [parent | _]}, nodes, - {cache, op_counts}, + {cache, op_counts, block_cache}, config ) when 
node_mode != :both and node_mode != config.mode do - {parent_id, {cache, op_counts}} = to_model_funs(parent, nodes, {cache, op_counts}, config) + {parent_id, {cache, op_counts, block_cache}} = + to_model_funs(parent, nodes, {cache, op_counts, block_cache}, config) predict_fun = fn params, inputs, state, cache, result_cache, fn_stacktrace -> call_predict_cache(parent_id, params, inputs, state, cache, result_cache, fn_stacktrace) @@ -271,13 +301,13 @@ defmodule Axon.Compiler do end model_funs = %{predict: predict_fun, init: init_fun} - {id, model_funs, cache, op_counts} + {id, model_funs, cache, op_counts, block_cache} end defp recur_model_funs( %Axon.Node{id: id, op: :constant, opts: [value: tensor], policy: %{output: output}}, _nodes, - {cache, op_counts}, + {cache, op_counts, block_cache}, _ ) do op_counts = Map.update(op_counts, :constant, 1, fn x -> x + 1 end) @@ -293,7 +323,7 @@ defmodule Axon.Compiler do end model_funs = %{predict: predict_fun, init: init_fun} - {id, model_funs, cache, op_counts} + {id, model_funs, cache, op_counts, block_cache} end defp recur_model_funs( @@ -305,7 +335,7 @@ defmodule Axon.Compiler do opts: [shape: _input_shape, optional: optional?] }, _nodes, - {cache, op_counts}, + {cache, op_counts, block_cache}, %{mode: mode} ) do name = name_fn.(:input, op_counts) @@ -331,16 +361,17 @@ defmodule Axon.Compiler do end model_funs = %{predict: predict_fun, init: init_fun} - {id, model_funs, cache, op_counts} + {id, model_funs, cache, op_counts, block_cache} end defp recur_model_funs( %Axon.Node{id: id, op: :optional, parent: [parent]}, nodes, - {cache, op_counts}, + {cache, op_counts, block_cache}, config ) do - {parent_id, {cache, op_counts}} = to_model_funs(parent, nodes, {cache, op_counts}, config) + {parent_id, {cache, op_counts, block_cache}} = + to_model_funs(parent, nodes, {cache, op_counts, block_cache}, config) predict_fun = fn params, inputs, state, cache, result_cache, fn_stacktrace -> {out, {state, result_cache}} = @@ -361,7 +392,7 @@ defmodule Axon.Compiler do end model_funs = %{predict: predict_fun, init: init_fun} - {id, model_funs, cache, op_counts} + {id, model_funs, cache, op_counts, block_cache} end defp recur_model_funs( @@ -370,7 +401,7 @@ defmodule Axon.Compiler do cache_and_counts, config ) do - {parent_ids, {cache, op_counts}} = + {parent_ids, {cache, op_counts, block_cache}} = deep_map_reduce(parents, cache_and_counts, &to_model_funs(&1, nodes, &2, config)) op_counts = Map.update(op_counts, :container, 1, fn x -> x + 1 end) @@ -427,13 +458,138 @@ defmodule Axon.Compiler do end model_funs = %{predict: predict_fun, init: init_fun} - {id, model_funs, cache, op_counts} + {id, model_funs, cache, op_counts, block_cache} + end + + defp recur_model_funs( + %Axon.Node{ + id: id, + op: :block, + parent: [parent], + opts: [block_fun: block_fun, block_id: block_id], + name: name_fn + }, + nodes, + cache_and_counts, + config + ) do + {[parent_id], {cache, op_counts, block_cache}} = + Enum.map_reduce( + [parent], + cache_and_counts, + &to_model_funs(&1, nodes, &2, config) + ) + + {{block_init_fun, block_predict_fun}, block_name, block_cache, op_counts} = + case block_cache do + %{^block_id => {funs, name}} = block_cache -> + {funs, name, block_cache, op_counts} + + %{} -> + funs = build(block_fun.(Axon.input("subgraph")), debug?: config.debug?) 
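          # The freshly built {init, predict} pair is cached under block_id
          # below, so every other usage of the same block reuses these
          # functions (and hence its parameters) instead of compiling a
          # new subgraph.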
+ name = name_fn.(:block, op_counts) + op_counts = Map.update(op_counts, :block, 1, fn x -> x + 1 end) + {funs, name, Map.put(block_cache, block_id, {funs, name}), op_counts} + end + + predict_fun = fn params, inputs, state, cache, result_cache, fn_stacktrace -> + # Recurse graph inputs and invoke cache to get parent results, + # state, and result_cache and then apply dtype policy and hooks + # to each input + {[layer_input], {state, result_cache, none?}} = + Enum.map_reduce( + [parent_id], + {state, result_cache, false}, + fn parent_id, {state, result_cache, none?} -> + {layer_input, {state, result_cache}} = + call_predict_cache( + parent_id, + params, + inputs, + state, + cache, + result_cache, + fn_stacktrace + ) + + none? = none? or propagating_none?(layer_input) + + {layer_input, {state, result_cache, none?}} + end + ) + + if none? do + {%Axon.None{}, {state, result_cache}} + else + block_params = params[block_name] || %{} + result = apply(block_predict_fun, [block_params, layer_input]) + + {out_result, out_state} = + case result do + # Make sure the none is non-propagating + %Axon.None{} -> %Axon.None{} + %{prediction: pred_expr, state: state_expr} -> {pred_expr, state_expr} + result -> {result, %{}} + end + + state = + if map_size(out_state) == 0 do + state + else + Map.put(state, block_name, out_state) + end + + {out_result, {state, result_cache}} + end + end + + init_fun = fn template, cache, result_cache, fn_stacktrace, keys -> + {[parent_shape], {parent_params, result_cache, none?}} = + Enum.map_reduce([parent_id], {%{}, result_cache, false}, fn + parent_id, {params, result_cache, none?} -> + {parent_shape, {params, result_cache}} = + call_init_cache( + parent_id, + template, + params, + cache, + result_cache, + fn_stacktrace, + keys + ) + + none? = none? or propagating_none?(parent_shape) + {parent_shape, {params, result_cache, none?}} + end) + + if none? do + {%Axon.None{}, {parent_params, result_cache}} + else + template = Nx.broadcast(0.0, parent_shape) + block_params = apply(block_init_fun, [template, %{}]) + + params = + if block_params == %{} do + %{} + else + Map.put(parent_params, block_name, block_params) + end + + {pred_expr, {_, result_cache}} = + predict_fun.(params, template, %{}, cache, result_cache, fn_stacktrace) + + {safe_shape(pred_expr), {params, result_cache}} + end + end + + model_funs = %{predict: predict_fun, init: init_fun} + {id, model_funs, cache, op_counts, block_cache} end defp recur_model_funs( %Axon.Node{id: id, op: :namespace, name: name_fn, parent: [parent]}, nodes, - {cache, op_counts}, + {cache, op_counts, block_cache}, config ) do name = name_fn.(:namespace, op_counts) @@ -446,8 +602,8 @@ defmodule Axon.Compiler do # All of the children of this namespace belong to it, so # we forward this name to the namespace, but everything after # it belongs to whatever namespace we're currently in - {parent_id, {cache, namespace_op_counts}} = - to_model_funs(parent, nodes, {cache, namespace_op_counts}, config) + {parent_id, {cache, namespace_op_counts, block_cache}} = + to_model_funs(parent, nodes, {cache, namespace_op_counts, block_cache}, config) # Update the global op_count of input layers, since they # are a global operation regardless of where they are @@ -503,7 +659,7 @@ defmodule Axon.Compiler do end model_funs = %{predict: predict_fun, init: init_fun} - {id, model_funs, cache, op_counts} + {id, model_funs, cache, op_counts, block_cache} end defp recur_model_funs( @@ -529,7 +685,7 @@ defmodule Axon.Compiler do # application within the function. 
We work only with # functions and IDs to avoid leaking entire graphs into # the closure - {parent_ids, {cache, op_counts}} = + {parent_ids, {cache, op_counts, block_cache}} = Enum.map_reduce( inputs, cache_and_counts, @@ -553,6 +709,7 @@ defmodule Axon.Compiler do &5, &6, op, + op_name, parent_ids, name, args, @@ -581,7 +738,7 @@ defmodule Axon.Compiler do ) model_funs = %{predict: predict_fun, init: init_fun} - {id, model_funs, cache, op_counts} + {id, model_funs, cache, op_counts, block_cache} end defp get_input(inputs, name, optional?) do @@ -632,6 +789,7 @@ defmodule Axon.Compiler do result_cache, fn_stacktrace, op, + op_name, parent_ids, name, args, @@ -718,7 +876,7 @@ defmodule Axon.Compiler do # in Axon.Layers. The implication of this is that every function which # can be invoked as a layer must have a definition in Axon.Layers even # if there is a distinction (e.g. with activations) - result = apply_layer(name, op, args, layer_stacktrace, fn_stacktrace) + result = apply_layer(name, op, args, layer_stacktrace, fn_stacktrace, op_name) result = case result do @@ -756,14 +914,30 @@ defmodule Axon.Compiler do end end - defp apply_layer(name, op, args, layer_stacktrace, fn_stacktrace) do + defp apply_layer(name, op, args, layer_stacktrace, fn_stacktrace, op_name) do try do - case op do - op when is_function(op) -> - apply(op, args) + result = + case op do + op when is_function(op) -> + apply(op, args) + + op when is_atom(op) -> + apply(Axon.Layers, op, args) + end + + case result do + out when is_tuple(out) -> + out + + %Axon.None{} = out -> + out + + %Axon.StatefulOutput{output: out} = stateful -> + out = Nx.Defn.Expr.metadata(Nx.Defn.Expr.tensor(out), %{axon_layer: op_name}) + %{stateful | output: out} - op when is_atom(op) -> - apply(Axon.Layers, op, args) + out -> + Nx.Defn.Expr.metadata(Nx.Defn.Expr.tensor(out), %{axon_layer: op_name}) end rescue exception -> @@ -889,9 +1063,9 @@ defmodule Axon.Compiler do if event? and mode? 
do if on_event == :backward do - Nx.Defn.Kernel.custom_grad(expr, fn _ans, g -> + Nx.Defn.Kernel.custom_grad(expr, [expr], fn g -> hooked_g = Nx.Defn.Kernel.hook(g, hook_fn) - [{expr, hooked_g}] + [hooked_g] end) else Nx.Defn.Kernel.hook(expr, hook_fn) diff --git a/lib/axon/defn.ex b/lib/axon/defn.ex index e1e7cb40..e7313ee4 100644 --- a/lib/axon/defn.ex +++ b/lib/axon/defn.ex @@ -19,4 +19,7 @@ defmodule Axon.Defn do @impl true def __compile__(_, _, _, _), do: raise("not implemented") + + @impl true + def __partitions_options__(_), do: raise("not implemented") end diff --git a/lib/axon/display.ex b/lib/axon/display.ex index cbd4fc4f..1e95e9c6 100644 --- a/lib/axon/display.ex +++ b/lib/axon/display.ex @@ -30,7 +30,7 @@ defmodule Axon.Display do Axon.Display.as_table(model, input) """ def as_table(%Axon{output: id, nodes: nodes}, input_templates) do - assert_table_rex!("ax_table/2") + assert_table_rex!("as_table/2") title = "Model" header = ["Layer", "Input Shape", "Output Shape", "Options", "Parameters"] diff --git a/lib/axon/initializers.ex b/lib/axon/initializers.ex index eeb2176a..a947f7a1 100644 --- a/lib/axon/initializers.ex +++ b/lib/axon/initializers.ex @@ -174,20 +174,16 @@ defmodule Axon.Initializers do """ def uniform(opts \\ []) do + opts = Keyword.validate!(opts, scale: 1.0e-2) + scale = Keyword.fetch!(opts, :scale) + fn shape, type, key -> - scale = opts[:scale] || 1.0e-2 - uniform_impl(key, shape: shape, type: type, scale: scale) + uniform_impl(key, scale, shape: shape, type: type) end end - defnp uniform_impl(key, opts \\ []) do - opts = keyword!(opts, [:shape, type: {:f, 32}, scale: 1.0e-2]) - shape = Nx.shape(opts[:shape]) - - Nx.Random.uniform_split(key, Nx.negate(opts[:scale]), opts[:scale], - type: opts[:type], - shape: shape - ) + defnp uniform_impl(key, scale, opts) do + Nx.Random.uniform_split(key, Nx.negate(scale), scale, opts) end @doc """ @@ -216,18 +212,15 @@ defmodule Axon.Initializers do """ def normal(opts \\ []) do + opts = Keyword.validate!(opts, scale: 1.0e-2, mean: 0.0) + scale = Keyword.fetch!(opts, :scale) + mean = Keyword.fetch!(opts, :mean) + fn shape, type, key -> - scale = opts[:scale] || 1.0e-2 - mean = opts[:mean] || 0.0 - normal_impl(key, shape: shape, type: type, scale: scale, mean: mean) + Nx.Random.normal_split(key, mean, scale, type: type, shape: shape) end end - defnp normal_impl(key, opts \\ []) do - opts = keyword!(opts, [:shape, type: {:f, 32}, scale: 1.0e-2, mean: 0.0]) - Nx.Random.normal_split(key, opts[:mean], opts[:scale], shape: opts[:shape], type: opts[:type]) - end - @doc """ Initializes parameters with the Lecun uniform initializer. @@ -261,25 +254,21 @@ defmodule Axon.Initializers do """ def lecun_uniform(opts \\ []) do + opts = Keyword.validate!(opts, scale: 1.0) + scale = Keyword.fetch!(opts, :scale) + fn shape, type, key -> - scale = opts[:scale] || 1.0 - lecun_uniform_impl(key, shape: shape, type: type, scale: scale) + variance_scaling_impl( + key, + scale, + shape: shape, + type: type, + mode: :fan_in, + distribution: :uniform + ) end end - defnp lecun_uniform_impl(key, opts \\ []) do - opts = keyword!(opts, [:shape, type: {:f, 32}, scale: 1.0]) - - variance_scaling_impl( - key, - shape: opts[:shape], - type: opts[:type], - scale: opts[:scale], - mode: :fan_in, - distribution: :uniform - ) - end - @doc """ Initializes parameters with the Lecun normal initializer. 
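The refactors in these hunks move option handling out of the private `defnp` helpers, but the public contract is unchanged: every constructor still returns an arity-3 function taking a shape, a type, and an `Nx.Random` key. A quick sketch of calling one directly (the shape and seed are arbitrary):

    init_fn = Axon.Initializers.lecun_normal()
    key = Nx.Random.key(42)
    kernel = init_fn.({64, 32}, {:f, 32}, key)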
@@ -313,25 +302,21 @@ defmodule Axon.Initializers do """ def lecun_normal(opts \\ []) do + opts = Keyword.validate!(opts, scale: 1.0) + scale = Keyword.fetch!(opts, :scale) + fn shape, type, key -> - scale = opts[:scale] || 1.0 - lecun_normal_impl(key, shape: shape, type: type, scale: scale) + variance_scaling_impl( + key, + scale, + shape: shape, + type: type, + mode: :fan_in, + distribution: :truncated_normal + ) end end - defnp lecun_normal_impl(key, opts \\ []) do - opts = keyword!(opts, [:shape, type: {:f, 32}, scale: 1.0]) - - variance_scaling_impl( - key, - shape: opts[:shape], - type: opts[:type], - scale: opts[:scale], - mode: :fan_in, - distribution: :truncated_normal - ) - end - @doc """ Initializes parameters with the Glorot uniform initializer. @@ -368,25 +353,21 @@ defmodule Axon.Initializers do """ def glorot_uniform(opts \\ []) do + opts = Keyword.validate!(opts, scale: 1.0) + scale = Keyword.fetch!(opts, :scale) + fn shape, type, key -> - scale = opts[:scale] || 1.0 - glorot_uniform_impl(key, shape: shape, type: type, scale: scale) + variance_scaling_impl( + key, + scale, + shape: shape, + type: type, + mode: :fan_avg, + distribution: :uniform + ) end end - defnp glorot_uniform_impl(key, opts \\ []) do - opts = keyword!(opts, [:shape, type: {:f, 32}, scale: 1.0]) - - variance_scaling_impl( - key, - shape: opts[:shape], - type: opts[:type], - scale: opts[:scale], - mode: :fan_avg, - distribution: :uniform - ) - end - @doc """ Initializes parameters with the Glorot normal initializer. @@ -423,25 +404,21 @@ defmodule Axon.Initializers do """ def glorot_normal(opts \\ []) do + opts = Keyword.validate!(opts, scale: 1.0) + scale = Keyword.fetch!(opts, :scale) + fn shape, type, key -> - scale = opts[:scale] || 1.0 - glorot_normal_impl(key, shape: shape, type: type, scale: scale) + variance_scaling_impl( + key, + scale, + shape: shape, + type: type, + mode: :fan_avg, + distribution: :truncated_normal + ) end end - defnp glorot_normal_impl(key, opts \\ []) do - opts = keyword!(opts, [:shape, type: {:f, 32}, scale: 1.0]) - - variance_scaling_impl( - key, - shape: opts[:shape], - type: opts[:type], - scale: opts[:scale], - mode: :fan_avg, - distribution: :truncated_normal - ) - end - @doc """ Initializes parameters with the He uniform initializer. @@ -475,25 +452,21 @@ defmodule Axon.Initializers do """ def he_uniform(opts \\ []) do + opts = Keyword.validate!(opts, scale: 2.0) + scale = Keyword.fetch!(opts, :scale) + fn shape, type, key -> - scale = opts[:scale] || 2.0 - he_uniform_impl(key, shape: shape, type: type, scale: scale) + variance_scaling_impl( + key, + scale, + shape: shape, + type: type, + mode: :fan_in, + distribution: :uniform + ) end end - defnp he_uniform_impl(key, opts \\ []) do - opts = keyword!(opts, [:shape, type: {:f, 32}, scale: 2.0]) - - variance_scaling_impl( - key, - shape: opts[:shape], - type: opts[:type], - scale: opts[:scale], - mode: :fan_in, - distribution: :uniform - ) - end - @doc """ Initializes parameters with the He normal initializer. 
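In model code these initializers are typically passed either by name or as a configured function. A sketch (layer sizes are placeholders):

    model =
      Axon.input("features")
      |> Axon.dense(128, activation: :relu, kernel_initializer: :he_uniform)
      |> Axon.dense(64, kernel_initializer: Axon.Initializers.he_normal(scale: 2.0))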
@@ -527,25 +500,21 @@ defmodule Axon.Initializers do """ def he_normal(opts \\ []) do + opts = Keyword.validate!(opts, scale: 2.0) + scale = Keyword.fetch!(opts, :scale) + fn shape, type, key -> - scale = opts[:scale] || 2.0 - he_normal_impl(key, shape: shape, type: type, scale: scale) + variance_scaling_impl( + key, + scale, + shape: shape, + type: type, + mode: :fan_in, + distribution: :truncated_normal + ) end end - defnp he_normal_impl(key, opts \\ []) do - opts = keyword!(opts, [:shape, type: {:f, 32}, scale: 2.0]) - - variance_scaling_impl( - key, - shape: opts[:shape], - type: opts[:type], - scale: opts[:scale], - mode: :fan_in, - distribution: :truncated_normal - ) - end - @doc """ Initializes parameters with variance scaling according to the given distribution and mode. @@ -586,30 +555,29 @@ defmodule Axon.Initializers do """ def variance_scaling(opts \\ []) do - fn shape, type, key -> - scale = opts[:scale] || 1.0 - mode = opts[:mode] || :fan_in - distribution = opts[:distribution] || :normal + opts = Keyword.validate!(opts, scale: 1.0, mode: :fan_in, distribution: :normal) + scale = Keyword.fetch!(opts, :scale) + mode = Keyword.fetch!(opts, :mode) + distribution = Keyword.fetch!(opts, :distribution) + fn shape, type, key -> variance_scaling_impl( key, + scale, shape: shape, type: type, - scale: scale, mode: mode, distribution: distribution ) end end - defnp variance_scaling_impl(key, opts \\ []) do - opts = - keyword!(opts, [:shape, type: {:f, 32}, scale: 1.0, mode: :fan_in, distribution: :normal]) + defnp variance_scaling_impl(key, scale, opts \\ []) do + opts = keyword!(opts, [:shape, type: {:f, 32}, mode: :fan_in, distribution: :normal]) fans = compute_fans(opts[:shape]) denominator = compute_denominator(fans, opts[:mode]) - - variance = Nx.divide(Nx.tensor(opts[:scale], type: opts[:type]), Nx.max(denominator, 1.0)) + variance = Nx.as_type(scale, opts[:type]) / Nx.max(denominator, 1.0) apply_distribution(key, opts[:distribution], variance, shape: opts[:shape], type: opts[:type]) end @@ -716,33 +684,20 @@ defmodule Axon.Initializers do assert_min_rank!("Axon.Initializers.orthogonal", "input_shape", shape, 2) - {{m, n}, random_seed} = - transform({key, shape, distribution, type}, fn {key, shape, distribution, type} -> - flat_shape = - if tuple_size(shape) > 2 do - tuple_list = shape |> Tuple.to_list() |> Enum.reverse() - n = hd(tuple_list) - m = Enum.reduce(tl(tuple_list), 1, &(&1 * &2)) - {m, n} - else - shape - end - - out = - case distribution do - :uniform -> - Nx.Random.uniform_split(key, 0.0, 1.0, shape: flat_shape, type: type) - - :normal -> - Nx.Random.normal_split(key, 0.0, 1.0, shape: flat_shape, type: type) - - dist -> - raise ArgumentError, - "invalid distribution #{inspect(dist)} passed to orthogonal/1" - end - - {flat_shape, out} - end) + {m, n} = get_flat_shape(shape) + + random_seed = + case distribution do + :uniform -> + Nx.Random.uniform_split(key, 0.0, 1.0, shape: {m, n}, type: type) + + :normal -> + Nx.Random.normal_split(key, 0.0, 1.0, shape: {m, n}, type: type) + + dist -> + raise ArgumentError, + "invalid distribution #{inspect(dist)} passed to orthogonal/1" + end {q, _r} = Nx.LinAlg.qr(random_seed, mode: :complete) @@ -754,6 +709,17 @@ defmodule Axon.Initializers do rand end + deftransformp get_flat_shape(shape) do + if tuple_size(shape) > 2 do + tuple_list = shape |> Tuple.to_list() |> Enum.reverse() + n = hd(tuple_list) + m = Enum.reduce(tl(tuple_list), 1, &(&1 * &2)) + {m, n} + else + shape + end + end + # Variance scaling branches defnp 
var_normal(key, variance, opts \\ []) do diff --git a/lib/axon/layers.ex b/lib/axon/layers.ex index 3336642d..4b64776d 100644 --- a/lib/axon/layers.ex +++ b/lib/axon/layers.ex @@ -188,8 +188,8 @@ defmodule Axon.Layers do assert_equal_rank!("Axon.Layers.bilinear", "input1", input1, "input2", input2) assert_rank!("Axon.Layers.bilinear", "kernel", kernel, 3) - inp1_axes = transform(Nx.rank(input1), fn rank -> [rank - 1] end) - inp2_axes = transform(Nx.rank(input2), fn rank -> [rank - 1] end) + inp1_axes = input1 |> last_axis() |> list_wrap() + inp2_axes = input2 |> last_axis() |> list_wrap() input1 |> Nx.dot(inp1_axes, [], kernel, [1], []) @@ -197,6 +197,8 @@ defmodule Axon.Layers do |> Nx.add(bias) end + deftransformp last_axis(input), do: Nx.rank(input) - 1 + ## Convolutional @doc """ @@ -239,7 +241,7 @@ defmodule Axon.Layers do Defaults to `1` or no dilation. * `:channels ` - channel configuration. One of `:first` or `:last`. - Defaults to `:first`. + Defaults to `:last`. ## Examples @@ -355,29 +357,8 @@ defmodule Axon.Layers do mode: :inference ) - bias_reshape = - transform( - {Nx.shape(bias), Nx.rank(input) - 2, opts[:channels]}, - fn {bias_shape, rank, channels} -> - Axon.Shape.conv_bias_reshape(bias_shape, rank, channels) - end - ) - - {permutations, kernel_permutation} = - transform({Nx.rank(input), opts[:channels]}, fn - {rank, :first} -> - perm = Enum.to_list(0..(rank - 1)) - {perm, perm} - - {rank, :last} -> - spatial = Enum.to_list(1..(rank - 2)//1) - perm = [0, rank - 1 | spatial] - kernel_perm = [rank - 1, rank - 2] ++ Enum.to_list(0..(rank - 3)//1) - {perm, kernel_perm} - - {_rank, invalid} -> - raise ArgumentError, "invalid channel configuration, #{inspect(invalid)}" - end) + bias_reshape = Axon.Shape.conv_bias_reshape(input, bias, opts[:channels]) + {permutations, kernel_permutation} = Axon.Shape.conv_permutations(input, opts[:channels]) input |> Nx.conv(kernel, @@ -431,7 +412,7 @@ defmodule Axon.Layers do Defaults to `1` or no dilation. * `:channels ` - channel configuration. One of `:first` or `:last`. - Defaults to `:first`. + Defaults to `:last`. ## Examples @@ -491,24 +472,18 @@ defmodule Axon.Layers do mode: :inference ) - strides = - transform( - {Nx.rank(input), opts[:strides]}, - fn - {_, [_ | _] = strides} -> strides - {rank, strides} -> List.duplicate(strides, rank - 2) - end - ) + strides = Axon.Shape.conv_transpose_strides(input, opts[:strides]) padding = - transform( - {Nx.shape(kernel), opts[:kernel_dilation], strides, opts[:padding], opts[:channels]}, - fn {shape, k_dilation, strides, padding, channels} -> - Axon.Shape.conv_transpose_padding(shape, k_dilation, strides, padding, channels) - end + Axon.Shape.conv_transpose_padding( + kernel, + opts[:kernel_dilation], + strides, + opts[:padding], + opts[:channels] ) - ones = transform(Nx.rank(input), &List.duplicate(1, &1 - 2)) + ones = list_duplicate(1, Nx.rank(input) - 2) conv(input, kernel, bias, strides: ones, @@ -560,7 +535,7 @@ defmodule Axon.Layers do Defaults to `1` or no dilation. * `:channels ` - channel configuration. One of `:first` or `:last`. - Defaults to `:first`. + Defaults to `:last`. 
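Note that these hunks correct the documented default rather than change behavior: the implementations already default to `channels: :last` (NHWC-style inputs). Callers working with channels-first data can still pass the option explicitly, e.g. (a sketch at the model level):

    Axon.input("images", shape: {nil, 3, 224, 224})
    |> Axon.conv(16, kernel_size: 3, channels: :first)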
""" @doc type: :convolutional @@ -597,14 +572,8 @@ defmodule Axon.Layers do mode: :inference ) - num_groups = - transform({Nx.shape(input), opts[:channels]}, fn - {shape, :first} -> - elem(shape, 1) - - {shape, :last} -> - elem(shape, tuple_size(shape) - 1) - end) + channel_index = channel_index_transform(input, opts[:channels]) + num_groups = Nx.axis_size(input, channel_index) conv(input, kernel, bias, strides: opts[:strides], @@ -616,6 +585,9 @@ defmodule Axon.Layers do ) end + deftransformp channel_index_transform(_input, :first), do: 1 + deftransformp channel_index_transform(input, :last), do: Nx.rank(input) - 1 + @doc """ Functional implementation of a 2-dimensional separable depthwise convolution. @@ -655,7 +627,7 @@ defmodule Axon.Layers do Defaults to `1` or no dilation. * `:channels ` - channel configuration. One of `:first` or `:last`. - Defaults to `:first`. + Defaults to `:last`. ## References @@ -722,7 +694,7 @@ defmodule Axon.Layers do Defaults to `1` or no dilation. * `:channels ` - channel configuration. One of `:first` or `:last`. - Defaults to `:first`. + Defaults to `:last`. ## References @@ -779,7 +751,7 @@ defmodule Axon.Layers do dilation. * `:channels ` - channel configuration. One of `:first` or `:last`. - Defaults to `:first`. + Defaults to `:last`. ## Examples @@ -816,47 +788,13 @@ defmodule Axon.Layers do ] ) - window_dimensions = - transform( - {Nx.rank(input), opts[:kernel_size], opts[:channels]}, - fn {rank, kernel_size, channels} -> - Axon.Shape.pool_window_size(kernel_size, rank - 2, channels) - end - ) + window_dimensions = Axon.Shape.pool_window_size(input, opts[:kernel_size], opts[:channels]) strides = - transform( - {Nx.rank(input), opts[:strides], window_dimensions, opts[:channels]}, - fn - {_, nil, dims, _} -> Tuple.to_list(dims) - {_, [_ | _] = strides, _, :first} -> [1, 1 | strides] - {_, [_ | _] = strides, _, :last} -> [1 | strides] ++ [1] - {rank, strides, _, :first} -> [1, 1 | List.duplicate(strides, rank - 2)] - {rank, strides, _, :last} -> [1 | List.duplicate(strides, rank - 2)] ++ [1] - end - ) + Axon.Shape.pool_window_strides(input, opts[:strides], window_dimensions, opts[:channels]) - dilations = - transform( - {Nx.rank(input), opts[:window_dilations], opts[:channels]}, - fn - {_, [_ | _] = dilations, :first} -> [1, 1 | dilations] - {rank, dilations, :first} -> [1, 1 | List.duplicate(dilations, rank - 2)] - {_, [_ | _] = dilations, :last} -> [1 | dilations] ++ [1] - {rank, dilations, :last} -> [1 | List.duplicate(dilations, rank - 2)] ++ [1] - end - ) - - padding = - transform( - {opts[:padding], opts[:channels]}, - fn - {:same, _} -> :same - {:valid, _} -> :valid - {padding, :first} -> [{0, 0}, {0, 0} | padding] - {padding, :last} -> [{0, 0} | padding] ++ [{0, 0}] - end - ) + dilations = Axon.Shape.pool_window_dilations(input, opts[:window_dilations], opts[:channels]) + padding = Axon.Shape.pool_window_padding(opts[:padding], opts[:channels]) input |> Nx.window_max(window_dimensions, @@ -896,7 +834,7 @@ defmodule Axon.Layers do dilation. * `:channels ` - channel configuration. One of `:first` or `:last`. - Defaults to `:first`. + Defaults to `:last`. 
""" @doc type: :pooling defn avg_pool(input, opts \\ []) do @@ -915,51 +853,13 @@ defmodule Axon.Layers do ] ) - window_dimensions = - transform( - {Nx.rank(input), opts[:kernel_size], opts[:channels]}, - fn {rank, kernel_size, channels} -> - Axon.Shape.pool_window_size(kernel_size, rank - 2, channels) - end - ) + window_dimensions = Axon.Shape.pool_window_size(input, opts[:kernel_size], opts[:channels]) strides = - transform( - {Nx.rank(input), opts[:strides], window_dimensions, opts[:channels]}, - fn - {_, nil, dims, _} -> Tuple.to_list(dims) - {_, [_ | _] = strides, _, :first} -> [1, 1 | strides] - {_, [_ | _] = strides, _, :last} -> [1 | strides] ++ [1] - {rank, strides, _, :first} -> [1, 1 | List.duplicate(strides, rank - 2)] - {rank, strides, _, :last} -> [1 | List.duplicate(strides, rank - 2)] ++ [1] - end - ) + Axon.Shape.pool_window_strides(input, opts[:strides], window_dimensions, opts[:channels]) - dilations = - transform( - {Nx.rank(input), opts[:window_dilations], opts[:channels]}, - fn - {_, [_ | _] = dilations, :first} -> [1, 1 | dilations] - {rank, dilations, :first} -> [1, 1 | List.duplicate(dilations, rank - 2)] - {_, [_ | _] = dilations, :last} -> [1 | dilations] ++ [1] - {rank, dilations, :last} -> [1 | List.duplicate(dilations, rank - 2)] ++ [1] - end - ) - - padding = - transform( - opts[:padding], - fn - :same -> - :same - - :valid -> - :valid - - padding -> - [{0, 0}, {0, 0} | padding] - end - ) + dilations = Axon.Shape.pool_window_dilations(input, opts[:window_dilations], opts[:channels]) + padding = Axon.Shape.pool_window_padding(opts[:padding], opts[:channels]) input |> Nx.window_mean(window_dimensions, @@ -1006,7 +906,7 @@ defmodule Axon.Layers do dilation. * `:channels ` - channel configuration. One of `:first` or `:last`. - Defaults to `:first`. + Defaults to `:last`. 
## Examples @@ -1041,62 +941,95 @@ defmodule Axon.Layers do ] ) - window_dimensions = - transform( - {Nx.rank(input), opts[:kernel_size], opts[:channels]}, - fn {rank, kernel_size, channels} -> - Axon.Shape.pool_window_size(kernel_size, rank - 2, channels) - end - ) + window_dimensions = Axon.Shape.pool_window_size(input, opts[:kernel_size], opts[:channels]) strides = - transform( - {Nx.rank(input), opts[:strides], window_dimensions, opts[:channels]}, - fn - {_, nil, dims, _} -> Tuple.to_list(dims) - {_, [_ | _] = strides, _, :first} -> [1, 1 | strides] - {_, [_ | _] = strides, _, :last} -> [1 | strides] ++ [1] - {rank, strides, _, :first} -> [1, 1 | List.duplicate(strides, rank - 2)] - {rank, strides, _, :last} -> [1 | List.duplicate(strides, rank - 2)] ++ [1] - end - ) - - dilations = - transform( - {Nx.rank(input), opts[:window_dilations], opts[:channels]}, - fn - {_, [_ | _] = dilations, :first} -> [1, 1 | dilations] - {rank, dilations, :first} -> [1, 1 | List.duplicate(dilations, rank - 2)] - {_, [_ | _] = dilations, :last} -> [1 | dilations] ++ [1] - {rank, dilations, :last} -> [1 | List.duplicate(dilations, rank - 2)] ++ [1] - end - ) - - padding = - transform( - opts[:padding], - fn - :same -> - :same - - :valid -> - :valid + Axon.Shape.pool_window_strides(input, opts[:strides], window_dimensions, opts[:channels]) - padding -> - [{0, 0}, {0, 0} | padding] - end - ) + dilations = Axon.Shape.pool_window_dilations(input, opts[:window_dilations], opts[:channels]) + padding = Axon.Shape.pool_window_padding(opts[:padding], opts[:channels]) norm = opts[:norm] input - |> Nx.power(norm) + |> Nx.pow(norm) |> Nx.window_sum(window_dimensions, strides: strides, padding: padding, window_dilations: dilations ) - |> Nx.power(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm)) + |> Nx.pow(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm)) + end + + @doc """ + Functional implementation of a 2-dimensional blur pooling layer. + + Blur pooling applies a spatial low-pass filter to the input. It is + often applied before pooling and convolutional layers as a way to + increase model accuracy without much additional computation cost. + + The blur pooling implementation follows from [MosaicML](https://github.com/mosaicml/composer/blob/dev/composer/algorithms/blurpool/blurpool_layers.py). 
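  A quick sketch of applying it directly to a batch of NHWC images (the
  shape is arbitrary):

      images = Nx.iota({1, 8, 8, 3}, type: :f32)
      Axon.Layers.blur_pool(images, channels: :last)

  The model-level counterpart added in this change can be used the same way:

      Axon.input("images", shape: {nil, 28, 28, 3})
      |> Axon.blur_pool()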
+ """ + @doc type: :pooling + defn blur_pool(input, opts \\ []) do + assert_rank!("blur_pool", "input", input, 4) + opts = keyword!(opts, channels: :last, mode: :train) + + filter = + Nx.tensor([ + [ + [ + [1, 2, 1], + [2, 4, 2], + [1, 2, 1] + ] + ] + ]) * 1 / 16.0 + + output_channels = + case opts[:channels] do + :last -> + Nx.axis_size(input, 3) + + :first -> + Nx.axis_size(input, 1) + end + + filter = compute_filter(filter, opts[:channels], output_channels) + + conv(input, filter, + padding: padding_for_filter(filter), + feature_group_size: output_channels, + channels: opts[:channels] + ) + end + + deftransformp compute_filter(filter, :first, out_channels) do + filter_shape = put_elem(Nx.shape(filter), 0, out_channels) + Nx.broadcast(filter, filter_shape) + end + + deftransformp compute_filter(filter, :last, out_channels) do + filter_shape = put_elem(Nx.shape(filter), 0, out_channels) + filter_permutation = [3, 2, 0, 1] + filter |> Nx.broadcast(filter_shape) |> Nx.transpose(axes: filter_permutation) + end + + deftransformp padding_for_filter(filter) do + {_, _, h, w} = Nx.shape(filter) + + cond do + rem(h, 2) == 0 -> + raise ArgumentError, "filter height must be odd" + + rem(w, 2) == 0 -> + raise ArgumentError, "filter width must be odd" + + true -> + :ok + end + + [{div(h, 2), div(h, 2)}, {div(w, 2), div(w, 2)}] end @doc """ @@ -1121,7 +1054,7 @@ defmodule Axon.Layers do Required. * `:channels ` - channel configuration. One of `:first` or `:last`. - Defaults to `:first`. + Defaults to `:last`. """ @doc type: :pooling defn adaptive_avg_pool(input, opts \\ []) do @@ -1129,26 +1062,11 @@ defmodule Axon.Layers do opts = keyword!(opts, [:output_size, channels: :last, mode: :inference]) - output_size = - transform({Nx.shape(input), opts[:output_size], opts[:channels]}, fn {shape, size, channels} -> - Axon.Shape.adaptive_pool_window_size(shape, size, channels) - end) - - window_strides = - transform( - {Nx.shape(input), Nx.rank(input), output_size, opts[:channels]}, - fn {shape, rank, output_size, channels} -> - Axon.Shape.adaptive_pool_window_strides(shape, output_size, rank - 2, channels) - end - ) + output_size = Axon.Shape.adaptive_pool_output_size(input, opts[:output_size], opts[:channels]) + window_strides = Axon.Shape.adaptive_pool_window_strides(input, output_size, opts[:channels]) window_dimensions = - transform( - {Nx.shape(input), Nx.rank(input), window_strides, output_size, opts[:channels]}, - fn {shape, rank, strides, output_size, channels} -> - Axon.Shape.adaptive_pool_window_size(shape, strides, output_size, rank - 2, channels) - end - ) + Axon.Shape.adaptive_pool_window_size(input, window_strides, output_size, opts[:channels]) Nx.window_mean(input, window_dimensions, padding: :valid, strides: window_strides) end @@ -1180,26 +1098,11 @@ defmodule Axon.Layers do opts = keyword!(opts, [:output_size, channels: :last, mode: :inference]) - output_size = - transform({Nx.shape(input), opts[:output_size], opts[:channels]}, fn {shape, size, channels} -> - Axon.Shape.adaptive_pool_window_size(shape, size, channels) - end) - - window_strides = - transform( - {Nx.shape(input), Nx.rank(input), output_size, opts[:channels]}, - fn {shape, rank, output_size, channels} -> - Axon.Shape.adaptive_pool_window_strides(shape, output_size, rank - 2, channels) - end - ) + output_size = Axon.Shape.adaptive_pool_output_size(input, opts[:output_size], opts[:channels]) + window_strides = Axon.Shape.adaptive_pool_window_strides(input, output_size, opts[:channels]) window_dimensions = - transform( - 
{Nx.shape(input), Nx.rank(input), window_strides, output_size, opts[:channels]}, - fn {shape, rank, strides, output_size, channels} -> - Axon.Shape.adaptive_pool_window_size(shape, strides, output_size, rank - 2, channels) - end - ) + Axon.Shape.adaptive_pool_window_size(input, window_strides, output_size, opts[:channels]) Nx.window_max(input, window_dimensions, padding: :valid, strides: window_strides) end @@ -1239,31 +1142,16 @@ defmodule Axon.Layers do norm = opts[:norm] - output_size = - transform({Nx.shape(input), opts[:output_size], opts[:channels]}, fn {shape, size, channels} -> - Axon.Shape.adaptive_pool_window_size(shape, size, channels) - end) - - window_strides = - transform( - {Nx.shape(input), Nx.rank(input), output_size, opts[:channels]}, - fn {shape, rank, output_size, channels} -> - Axon.Shape.adaptive_pool_window_strides(shape, output_size, rank - 2, channels) - end - ) + output_size = Axon.Shape.adaptive_pool_output_size(input, opts[:output_size], opts[:channels]) + window_strides = Axon.Shape.adaptive_pool_window_strides(input, output_size, opts[:channels]) window_dimensions = - transform( - {Nx.shape(input), Nx.rank(input), window_strides, output_size, opts[:channels]}, - fn {shape, rank, strides, output_size, channels} -> - Axon.Shape.adaptive_pool_window_size(shape, strides, output_size, rank - 2, channels) - end - ) + Axon.Shape.adaptive_pool_window_size(input, window_strides, output_size, opts[:channels]) input - |> Nx.power(norm) + |> Nx.pow(norm) |> Nx.window_sum(window_dimensions, padding: :valid, strides: window_strides) - |> Nx.power(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm)) + |> Nx.pow(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm)) end ## Normalization @@ -1292,7 +1180,7 @@ defmodule Axon.Layers do * `:momentum` - momentum to use for EMA update. - * `:training?` - if true, uses training mode batch norm. Defaults to false. + * `:mode` - if `:train`, uses training mode batch norm. Defaults to `:inference`. ## References @@ -1302,56 +1190,22 @@ defmodule Axon.Layers do defn batch_norm(input, gamma, beta, ra_mean, ra_var, opts \\ []) do opts = keyword!(opts, epsilon: 1.0e-5, channel_index: -1, momentum: 0.1, mode: :inference) - training? 
= - transform(opts[:mode], fn - :inference -> false - :train -> true - end) + axes = Axon.Shape.batch_norm_axes(input, opts[:channel_index]) - {axes, channel_index} = - transform({input, opts[:channel_index]}, fn {input, channel} -> - axes = Nx.axes(input) - axis = Nx.Shape.normalize_axis(Nx.shape(input), channel, Nx.names(input)) - {Axon.Shape.batch_norm_axes(axes, axis), axis} - end) + num_channels = Nx.axis_size(input, opts[:channel_index]) - num_channels = - transform({input, channel_index}, fn {inp, channel_idx} -> - elem(Nx.shape(inp), channel_idx) - end) + parameter_shape = norm_parameter_reshape(input, num_channels, opts[:channel_index]) - {gamma, beta, ra_mean, ra_var} = - transform( - {gamma, beta, ra_mean, ra_var, Nx.rank(input), num_channels, channel_index}, - fn {g, b, m, v, rank, num_channels, channel_idx} -> - new_shape = - 1 - |> List.duplicate(rank) - |> List.to_tuple() - |> put_elem(channel_idx, num_channels) - - {Nx.reshape(g, new_shape), Nx.reshape(b, new_shape), Nx.reshape(m, new_shape), - Nx.reshape(v, new_shape)} - end - ) + gamma = Nx.reshape(gamma, parameter_shape) + beta = Nx.reshape(beta, parameter_shape) + ra_mean = Nx.reshape(ra_mean, parameter_shape) + ra_var = Nx.reshape(ra_var, parameter_shape) - transform( - {input, gamma, beta, ra_mean, ra_var, axes, opts[:epsilon], opts[:momentum], training?}, - fn - {x, g, b, m, v, axes, eps, alpha, true} -> - {new_mean, new_var} = mean_and_variance(x, axes: axes) - out = normalize(x, new_mean, new_var, g, b, epsilon: eps) - ra_mean = update_ema(new_mean, m, alpha) - ra_var = update_ema(new_var, v, alpha) - - %Axon.StatefulOutput{ - output: out, - state: %{"mean" => ra_mean, "var" => ra_var} - } - - {x, g, b, m, v, _, eps, _, _} -> - normalize(x, m, v, g, b, epsilon: eps) - end + stateful_normalization_mode_transform(input, gamma, beta, ra_mean, ra_var, + axes: axes, + epsilon: opts[:epsilon], + momentum: opts[:momentum], + mode: opts[:mode] ) end @@ -1383,32 +1237,12 @@ defmodule Axon.Layers do opts = keyword!(opts, epsilon: 1.0e-5, channel_index: -1, mode: :inference) axes = opts[:channel_index] - channel_index = opts[:channel_index] + num_channels = Nx.axis_size(input, opts[:channel_index]) - num_channels = - transform({input, channel_index}, fn {inp, channel_idx} -> - names = List.duplicate(nil, Nx.rank(inp)) - axis = Nx.Shape.normalize_axis(Nx.shape(inp), channel_idx, names) - elem(Nx.shape(inp), axis) - end) - - {gamma, beta} = - transform({gamma, beta, input, Nx.rank(input), num_channels, channel_index}, fn {g, b, - input, - rank, - num_channels, - channel_idx} -> - names = List.duplicate(nil, rank) - axis = Nx.Shape.normalize_axis(Nx.shape(input), channel_idx, names) + parameter_shape = norm_parameter_reshape(input, num_channels, opts[:channel_index]) - new_shape = - 1 - |> List.duplicate(rank) - |> List.to_tuple() - |> put_elem(axis, num_channels) - - {Nx.reshape(g, new_shape), Nx.reshape(b, new_shape)} - end) + gamma = Nx.reshape(gamma, parameter_shape) + beta = Nx.reshape(beta, parameter_shape) {mean, var} = mean_and_variance(input, axes: [axes]) normalize(input, mean, var, gamma, beta, epsilon: opts[:epsilon]) @@ -1444,49 +1278,18 @@ defmodule Axon.Layers do defn group_norm(input, gamma, beta, opts \\ []) do opts = keyword!(opts, [:num_groups, epsilon: 1.0e-5, channel_index: -1, mode: :inference]) - channel_axis = - transform({Nx.shape(input), opts[:channel_index]}, fn - {shape, channel_index} -> - names = List.duplicate(nil, Nx.rank(shape)) - Nx.Shape.normalize_axis(shape, channel_index, names) - end) - - 
group_shape = - transform({Nx.shape(input), opts[:num_groups], channel_axis}, fn - {shape, groups, channel_axis} -> - Axon.Shape.group_norm_shape(shape, groups, channel_axis) - end) - - channel_index = opts[:channel_index] - - num_channels = - transform({input, channel_index}, fn {inp, channel_idx} -> - names = List.duplicate(nil, Nx.rank(inp)) - axis = Nx.Shape.normalize_axis(Nx.shape(inp), channel_idx, names) - elem(Nx.shape(inp), axis) - end) - - {gamma, beta} = - transform({gamma, beta, input, Nx.rank(input), num_channels, channel_index}, fn - {g, b, inp, rank, num_channels, channel_idx} -> - names = List.duplicate(nil, Nx.rank(inp)) - axis = Nx.Shape.normalize_axis(Nx.shape(inp), channel_idx, names) + channel_axis = normalize_group_norm_channel_axis(input, opts[:channel_index]) - new_shape = - 1 - |> List.duplicate(rank) - |> List.to_tuple() - |> put_elem(axis, num_channels) + group_shape = Axon.Shape.group_norm_shape(input, opts[:num_groups], opts[:channel_index]) + num_channels = Nx.axis_size(input, opts[:channel_index]) - {Nx.reshape(g, new_shape), Nx.reshape(b, new_shape)} - end) + parameter_shape = norm_parameter_reshape(input, num_channels, opts[:channel_index]) + gamma = Nx.reshape(gamma, parameter_shape) + beta = Nx.reshape(beta, parameter_shape) x = Nx.reshape(input, group_shape) - axes = - transform({x, channel_axis}, fn {x, channel_axis} -> - Axon.Shape.group_norm_axes(Nx.rank(x), channel_axis) - end) + axes = Axon.Shape.group_norm_axes(x, channel_axis) {mean, var} = mean_and_variance(x, axes: axes) x = (x - mean) * Nx.rsqrt(var + opts[:epsilon]) @@ -1494,6 +1297,10 @@ defmodule Axon.Layers do x * gamma + beta end + deftransformp normalize_group_norm_channel_axis(input, channel_index) do + Nx.Shape.normalize_axis(Nx.shape(input), channel_index, Nx.shape(input)) + end + @doc ~S""" Functional implementation of instance normalization. @@ -1527,59 +1334,63 @@ defmodule Axon.Layers do defn instance_norm(input, gamma, beta, ra_mean, ra_var, opts \\ []) do opts = keyword!(opts, epsilon: 1.0e-5, channel_index: -1, momentum: 0.1, mode: :inference) - training? 
= - transform(opts[:mode], fn - :inference -> false - :train -> true - end) + axes = Axon.Shape.instance_norm_axes(input, opts[:channel_index]) + num_channels = Nx.axis_size(input, opts[:channel_index]) - {axes, channel_index} = - transform({input, opts[:channel_index]}, fn {input, channel} -> - axes = Nx.axes(input) - axis = Nx.Shape.normalize_axis(Nx.shape(input), channel, Nx.names(input)) - {Axon.Shape.instance_norm_axes(axes, axis), axis} - end) + parameter_shape = norm_parameter_reshape(input, num_channels, opts[:channel_index]) - num_channels = - transform({input, channel_index}, fn {inp, channel_idx} -> - elem(Nx.shape(inp), channel_idx) - end) + gamma = Nx.reshape(gamma, parameter_shape) + beta = Nx.reshape(beta, parameter_shape) + ra_mean = Nx.reshape(ra_mean, parameter_shape) + ra_var = Nx.reshape(ra_var, parameter_shape) - {gamma, beta, ra_mean, ra_var} = - transform( - {gamma, beta, ra_mean, ra_var, Nx.rank(input), num_channels, channel_index}, - fn {g, b, m, v, rank, num_channels, channel_idx} -> - new_shape = - 1 - |> List.duplicate(rank) - |> List.to_tuple() - |> put_elem(channel_idx, num_channels) - - {Nx.reshape(g, new_shape), Nx.reshape(b, new_shape), Nx.reshape(m, new_shape), - Nx.reshape(v, new_shape)} - end - ) + stateful_normalization_mode_transform(input, gamma, beta, ra_mean, ra_var, + axes: axes, + epsilon: opts[:epsilon], + momentum: opts[:momentum], + mode: opts[:mode] + ) + end - transform( - {input, gamma, beta, ra_mean, ra_var, axes, opts[:epsilon], opts[:momentum], training?}, - fn - {x, g, b, m, v, axes, eps, alpha, true} -> - {new_mean, new_var} = mean_and_variance(x, axes: axes) - out = normalize(x, new_mean, new_var, g, b, epsilon: eps) - ra_mean = update_ema(new_mean, m, alpha) - ra_var = update_ema(new_var, v, alpha) - - %Axon.StatefulOutput{ - output: out, - state: %{"mean" => ra_mean, "var" => ra_var} - } - - {x, g, b, m, v, _, eps, _, _} -> - normalize(x, m, v, g, b, epsilon: eps) - end + deftransformp norm_parameter_reshape(input, num_channels, channel_index) do + 1 + |> List.duplicate(Nx.rank(input)) + |> List.to_tuple() + |> put_elem( + Nx.Shape.normalize_axis(Nx.shape(input), channel_index, Nx.names(input)), + num_channels ) end + deftransformp stateful_normalization_mode_transform( + input, + gamma, + beta, + ra_mean, + ra_var, + opts \\ [] + ) do + eps = opts[:epsilon] + alpha = opts[:momentum] + axes = opts[:axes] + + case opts[:mode] do + :train -> + {new_mean, new_var} = mean_and_variance(input, axes: axes) + out = normalize(input, new_mean, new_var, gamma, beta, epsilon: eps) + ra_mean = update_ema(new_mean, ra_mean, alpha) + ra_var = update_ema(new_var, ra_var, alpha) + + %Axon.StatefulOutput{ + output: out, + state: %{"mean" => ra_mean, "var" => ra_var} + } + + :inference -> + normalize(input, ra_mean, ra_var, gamma, beta, epsilon: eps) + end + end + ## Stochastic @doc ~S""" @@ -1609,7 +1420,7 @@ defmodule Axon.Layers do @doc type: :dropout defn dropout(input, key, opts \\ []) do opts = keyword!(opts, [:rate, noise_shape: Nx.shape(input), mode: :inference]) - keep_prob = Nx.tensor(1, type: Nx.type(input)) - Nx.tensor(opts[:rate], type: Nx.type(input)) + keep_prob = Nx.tensor(1, type: Nx.type(input)) - Nx.as_type(opts[:rate], Nx.type(input)) {rand, new_key} = Nx.Random.uniform(key, 0, 1, shape: opts[:noise_shape], type: Nx.type(input)) @@ -1660,10 +1471,7 @@ defmodule Axon.Layers do opts = keyword!(opts, rate: 0.5, channels: :last, mode: :inference) - noise_shape = - transform({Nx.shape(input), opts[:channels]}, fn {shape, channels} -> - 
Axon.Shape.spatial_dropout_noise_shape(shape, channels) - end) + noise_shape = Axon.Shape.spatial_dropout_noise_shape(input, opts[:channels]) dropout(input, key, rate: opts[:rate], @@ -1706,7 +1514,7 @@ defmodule Axon.Layers do mask = Nx.less(rand, keep_prob) - a = Nx.rsqrt(keep_prob * Nx.power(Nx.tensor(1, type: Nx.type(input)) * alpha_p, 2)) + a = Nx.rsqrt(keep_prob * Nx.pow(Nx.tensor(1, type: Nx.type(input)) * alpha_p, 2)) b = -a * alpha_p * rate x = Nx.select(mask, input, alpha_p) @@ -1738,10 +1546,7 @@ defmodule Axon.Layers do opts = keyword!(opts, rate: 0.5, channels: :last, mode: :inference) - noise_shape = - transform({Nx.shape(input), opts[:channels]}, fn {shape, channels} -> - Axon.Shape.spatial_dropout_noise_shape(shape, channels) - end) + noise_shape = Axon.Shape.spatial_dropout_noise_shape(input, opts[:channels]) keep_prob = 1 - opts[:rate] @@ -1808,14 +1613,7 @@ defmodule Axon.Layers do opts = keyword!(opts, channels: :last, keep_axes: false, mode: :inference) - all_but_batch_and_feature = - transform({Nx.rank(input), opts[:channels]}, fn - {rank, :first} -> - for i <- 2..(rank - 1), do: i - - {rank, :last} -> - for i <- 1..(rank - 2), do: i - end) + all_but_batch_and_feature = Axon.Shape.global_pool_axes(input, opts[:channels]) Nx.mean(input, axes: all_but_batch_and_feature, keep_axes: opts[:keep_axes]) end @@ -1872,14 +1670,7 @@ defmodule Axon.Layers do opts = keyword!(opts, keep_axes: false, channels: :last, mode: :inference) - all_but_batch_and_feature = - transform({Nx.rank(input), opts[:channels]}, fn - {rank, :first} -> - for i <- 2..(rank - 1), do: i - - {rank, :last} -> - for i <- 1..(rank - 2), do: i - end) + all_but_batch_and_feature = Axon.Shape.global_pool_axes(input, opts[:channels]) Nx.reduce_max(input, axes: all_but_batch_and_feature, keep_axes: opts[:keep_axes]) end @@ -1943,19 +1734,12 @@ defmodule Axon.Layers do norm = opts[:norm] - all_but_batch_and_feature = - transform({Nx.rank(input), opts[:channels]}, fn - {rank, :first} -> - for i <- 2..(rank - 1), do: i - - {rank, :last} -> - for i <- 1..(rank - 2), do: i - end) + all_but_batch_and_feature = Axon.Shape.global_pool_axes(input, opts[:channels]) input - |> Nx.power(norm) + |> Nx.pow(norm) |> Nx.sum(axes: all_but_batch_and_feature, keep_axes: opts[:keep_axes]) - |> Nx.power(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm)) + |> Nx.pow(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm)) end ## Sparse @@ -2031,32 +1815,29 @@ defmodule Axon.Layers do > """ @doc type: :shape - defn flatten(x, _opts \\ []) do - new_shape = transform(Nx.shape(x), &Axon.Shape.flatten/1) + deftransform flatten(input, _opts \\ []) do + shape = Nx.shape(input) + out_units = Nx.size(Tuple.delete_at(shape, 0)) + out_shape = {elem(shape, 0), out_units} - Nx.reshape(x, new_shape) + Nx.reshape(input, out_shape) end @doc false # Internal version of Nx.reshape for constructing reshape layers # without worrying about a batch dimension - defn reshape(x, opts \\ []) do - opts = keyword!(opts, [:shape, mode: :inference]) - - transform({opts[:shape], x}, fn {shape, x} -> - batch_size = Nx.axis_size(x, 0) - - new_shape = - shape - |> Tuple.to_list() - |> Enum.map(fn - :batch -> batch_size - val -> val - end) - |> List.to_tuple() - - Nx.reshape(x, new_shape) + deftransform reshape(x, opts \\ []) do + opts = Keyword.validate!(opts, [:shape, mode: :inference]) + batch_size = Nx.axis_size(x, 0) + + opts[:shape] + |> Tuple.to_list() + |> Enum.map(fn + :batch -> batch_size + val -> val end) + |> List.to_tuple() + |> then(&Nx.reshape(x, &1)) end 
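
For illustration, a minimal sketch of how the reworked `flatten/2` and the internal `reshape/2` helper above behave once converted to `deftransform` (the input tensor and shapes are assumed examples, not taken from the test suite):

    iex> input = Nx.iota({2, 3, 4}, type: :f32)
    iex> input |> Axon.Layers.flatten() |> Nx.shape()
    {2, 12}
    iex> input |> Axon.Layers.reshape(shape: {:batch, 4, 3}) |> Nx.shape()
    {2, 4, 3}

The `:batch` placeholder is replaced with the runtime batch size (`Nx.axis_size(x, 0)`), so the same reshape works for any batch dimension.
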
@doc false @@ -2064,34 +1845,27 @@ defmodule Axon.Layers do # worrying about batch or channel dimensions defn pad(x, opts \\ []) do opts = keyword!(opts, [:padding_config, :value, :channels, mode: :inference]) + config = padding_config_transform(opts[:padding_config], opts[:channels]) - config = - transform({opts[:padding_config], opts[:channels]}, fn - {config, :first} -> - [{0, 0, 0}, {0, 0, 0} | Enum.map(config, fn {x, y} -> {x, y, 0} end)] + Nx.pad(x, Nx.as_type(opts[:value], Nx.type(x)), config) + end - {config, :last} -> - [{0, 0, 0} | Enum.map(config, fn {x, y} -> {x, y, 0} end)] ++ [{0, 0, 0}] - end) + deftransform padding_config_transform(config, channels) do + case channels do + :first -> + [{0, 0, 0}, {0, 0, 0} | Enum.map(config, fn {x, y} -> {x, y, 0} end)] - Nx.pad(x, Nx.as_type(opts[:value], Nx.type(x)), config) + :last -> + [{0, 0, 0} | Enum.map(config, fn {x, y} -> {x, y, 0} end)] ++ [{0, 0, 0}] + end end @doc false # Internal version of Nx.transpose for constructing a transpose layer # without worrying about a batch dimension - defn transpose(x, opts \\ []) do - opts = keyword!(opts, [:axes, mode: :inference]) - - axes = - transform({Nx.shape(x), opts[:axes]}, fn - {shape, nil} -> - Nx.axes(shape) |> Enum.reverse() - - {_, axes} -> - axes - end) - + deftransform transpose(x, opts \\ []) do + opts = Keyword.validate!(opts, [:axes, mode: :inference]) + axes = opts[:axes] || Enum.reverse(Nx.axes(x)) Nx.transpose(x, axes: axes) end @@ -2102,20 +1876,7 @@ defmodule Axon.Layers do opts = keyword!(opts, [:cond, mode: :inference]) cond_expr = opts[:cond].(cond_input_expr) - transform(cond_expr, fn cond_expr -> - cond_rank = Nx.rank(cond_expr) - cond_type = Nx.type(cond_expr) - - unless Elixir.Kernel.and( - Elixir.Kernel.==(cond_rank, 0), - Elixir.Kernel.==(cond_type, {:u, 8}) - ) do - raise ArgumentError, - "cond_fn must return a scalar-boolean tensor" <> - " got result with rank #{inspect(cond_rank)} and" <> - " type #{inspect(cond_type)}" - end - end) + validate_conv_predicate!(cond_expr) if cond_expr do on_true_expr @@ -2124,6 +1885,21 @@ defmodule Axon.Layers do end end + deftransformp validate_conv_predicate!(cond_expr) do + cond_rank = Nx.rank(cond_expr) + cond_type = Nx.type(cond_expr) + + unless Elixir.Kernel.and( + Elixir.Kernel.==(cond_rank, 0), + Elixir.Kernel.==(cond_type, {:u, 8}) + ) do + raise ArgumentError, + "cond_fn must return a scalar-boolean tensor" <> + " got result with rank #{inspect(cond_rank)} and" <> + " type #{inspect(cond_type)}" + end + end + @doc false # Internal helper for constructing bias layers without defn bias(input, bias, _opts \\ []) do @@ -2134,7 +1910,7 @@ defmodule Axon.Layers do Resizes a batch of tensors to the given shape using one of a number of sampling methods. - Requires input option `:to` which should be a tuple specifying + Requires input option `:size` which should be a tuple specifying the resized spatial dimensions of the input tensor. Input tensor must be at least rank 3, with fixed `batch` and `channel` dimensions. Resizing will upsample or downsample using the given resize method. 
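
As a usage sketch for the resize layer described above (the image tensor and target size are illustrative assumptions):

    iex> image = Nx.iota({1, 4, 4, 3}, type: :f32)
    iex> image |> Axon.Layers.resize(size: {8, 8}, method: :nearest) |> Nx.shape()
    {1, 8, 8, 3}

With the default `channels: :last`, only the two spatial axes are rescaled; the batch and channel axes are left untouched.
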
@@ -2167,67 +1943,70 @@ defmodule Axon.Layers do ** (ArgumentError) expected :method to be either of :nearest, :bilinear, :bicubic, :lanczos3, :lanczos5, got: :foo """ @doc type: :shape - defn resize(input, opts \\ []) do + deftransform resize(input, opts \\ []) do assert_rank!("Axon.Layers.resize", "input", input, 4) opts = - keyword!(opts, [ + Keyword.validate!(opts, [ :size, method: :nearest, channels: :last, mode: :inference ]) - transform({input, opts}, fn {input, opts} -> - {spatial_axes, out_shape} = - input - |> spatial_axes_with_sizes(opts) - |> Enum.reject(fn {_axis, size, out_size} -> Elixir.Kernel.==(size, out_size) end) - |> Enum.map_reduce(Nx.shape(input), fn {axis, _size, out_size}, out_shape -> - {axis, put_elem(out_shape, axis, out_size)} - end) + {spatial_axes, out_shape} = + input + |> spatial_axes_with_sizes(opts) + |> Enum.reject(fn {_axis, size, out_size} -> Elixir.Kernel.==(size, out_size) end) + |> Enum.map_reduce(Nx.shape(input), fn {axis, _size, out_size}, out_shape -> + {axis, put_elem(out_shape, axis, out_size)} + end) - resized_input = - case opts[:method] do - :nearest -> - resize_nearest(input, out_shape, spatial_axes) + resized_input = + case opts[:method] do + :nearest -> + resize_nearest(input, out_shape, spatial_axes) - :bilinear -> - resize_with_kernel(input, out_shape, spatial_axes, &fill_linear_kernel/1) + :bilinear -> + resize_with_kernel(input, out_shape, spatial_axes, &fill_linear_kernel/1) - :bicubic -> - resize_with_kernel(input, out_shape, spatial_axes, &fill_cubic_kernel/1) + :bicubic -> + resize_with_kernel(input, out_shape, spatial_axes, &fill_cubic_kernel/1) - :lanczos3 -> - resize_with_kernel(input, out_shape, spatial_axes, &fill_lanczos_kernel(3, &1)) + :lanczos3 -> + resize_with_kernel(input, out_shape, spatial_axes, &fill_lanczos_kernel(3, &1)) - :lanczos5 -> - resize_with_kernel(input, out_shape, spatial_axes, &fill_lanczos_kernel(5, &1)) + :lanczos5 -> + resize_with_kernel(input, out_shape, spatial_axes, &fill_lanczos_kernel(5, &1)) - method -> - raise ArgumentError, - "expected :method to be either of :nearest, :bilinear, :bicubic, " <> - ":lanczos3, :lanczos5, got: #{inspect(method)}" - end + method -> + raise ArgumentError, + "expected :method to be either of :nearest, :bilinear, :bicubic, " <> + ":lanczos3, :lanczos5, got: #{inspect(method)}" + end - cast_to(resized_input, input) - end) + cast_to(resized_input, input) + end + + deftransformp spatial_axes_with_sizes(input, opts \\ []) do + {height_axis, width_axis} = spatial_axes(input, channels: opts[:channels]) + {height, width} = size(input, channels: opts[:channels]) + {out_height, out_width} = opts[:size] + [{height_axis, height, out_height}, {width_axis, width, out_width}] end - defnp spatial_axes(input, opts \\ []) do + deftransformp spatial_axes(input, opts \\ []) do channels = opts[:channels] - transform({input, channels}, fn {input, channels} -> - axes = - case channels do - :first -> [-2, -1] - :last -> [-3, -2] - end + axes = + case channels do + :first -> [-2, -1] + :last -> [-3, -2] + end - axes - |> Enum.map(&Nx.axis_index(input, &1)) - |> List.to_tuple() - end) + axes + |> Enum.map(&Nx.axis_index(input, &1)) + |> List.to_tuple() end defnp cast_to(left, right) do @@ -2236,56 +2015,58 @@ defmodule Axon.Layers do |> Nx.reshape(left, names: Nx.names(right)) end - defnp resize_nearest(input, out_shape, spatial_axes) do - transform({input, out_shape, spatial_axes}, fn {input, out_shape, spatial_axes} -> - singular_shape = List.duplicate(1, Nx.rank(input)) |> List.to_tuple() 
+  deftransformp resize_nearest(input, out_shape, spatial_axes) do
+    singular_shape = List.duplicate(1, Nx.rank(input)) |> List.to_tuple()
 
-      for axis <- spatial_axes, reduce: input do
-        input ->
-          input_shape = Nx.shape(input)
-          input_size = elem(input_shape, axis)
-          output_size = elem(out_shape, axis)
-          inv_scale = input_size / output_size
-          offset = (Nx.iota({output_size}) + 0.5) * inv_scale
-          offset = offset |> Nx.floor() |> Nx.as_type({:s, 32})
+    for axis <- spatial_axes, reduce: input do
+      input ->
+        input_shape = Nx.shape(input)
+        input_size = elem(input_shape, axis)
+        output_size = elem(out_shape, axis)
+        inv_scale = input_size / output_size
+        offset = Nx.iota({output_size}) |> Nx.add(0.5) |> Nx.multiply(inv_scale)
+        offset = offset |> Nx.floor() |> Nx.as_type({:s, 32})
 
-          offset =
-            offset
-            |> Nx.reshape(put_elem(singular_shape, axis, output_size))
-            |> Nx.broadcast(put_elem(input_shape, axis, output_size))
+        offset =
+          offset
+          |> Nx.reshape(put_elem(singular_shape, axis, output_size))
+          |> Nx.broadcast(put_elem(input_shape, axis, output_size))
 
-          Nx.take_along_axis(input, offset, axis: axis)
-      end
-    end)
+        Nx.take_along_axis(input, offset, axis: axis)
+    end
   end
 
   @f32_eps :math.pow(2, -23)
 
-  defnp resize_with_kernel(input, out_shape, spatial_axes, kernel_fun) do
-    transform({input, out_shape, spatial_axes}, fn {input, out_shape, spatial_axes} ->
-      for axis <- spatial_axes, reduce: input do
-        input ->
-          input_shape = Nx.shape(input)
-          input_size = elem(input_shape, axis)
-          output_size = elem(out_shape, axis)
+  deftransformp resize_with_kernel(input, out_shape, spatial_axes, kernel_fun) do
+    for axis <- spatial_axes, reduce: input do
+      input ->
+        input_shape = Nx.shape(input)
+        input_size = elem(input_shape, axis)
+        output_size = elem(out_shape, axis)
 
-          inv_scale = input_size / output_size
-          kernel_scale = Nx.max(1, inv_scale)
+        inv_scale = input_size / output_size
+        kernel_scale = Nx.max(1, inv_scale)
 
-          sample_f = (Nx.iota({1, output_size}) + 0.5) * inv_scale - 0.5
-          x = Nx.abs(sample_f - Nx.iota({input_size, 1})) / kernel_scale
-          weights = kernel_fun.(x)
+        sample_f =
+          Nx.add(Nx.iota({1, output_size}), 0.5) |> Nx.multiply(inv_scale) |> Nx.subtract(0.5)
 
-          weights_sum = Nx.sum(weights, axes: [0], keep_axes: true)
+        x = Nx.abs(Nx.subtract(sample_f, Nx.iota({input_size, 1}))) |> Nx.divide(kernel_scale)
+        weights = kernel_fun.(x)
 
-          weights =
-            Nx.select(Nx.abs(weights) > 1000 * @f32_eps, safe_divide(weights, weights_sum), 0)
+        weights_sum = Nx.sum(weights, axes: [0], keep_axes: true)
 
-          input = Nx.dot(input, [axis], weights, [0])
-          # The transformed axis is moved to the end, so we transpose back
-          reorder_axis(input, -1, axis)
-      end
-    end)
+        weights =
+          Nx.select(
+            Nx.greater(Nx.abs(weights), 1000 * @f32_eps),
+            safe_divide(weights, weights_sum),
+            0
+          )
+
+        input = Nx.dot(input, [axis], weights, [0])
+        # The transformed axis is moved to the end, so we transpose back
+        reorder_axis(input, -1, axis)
+    end
   end
 
   defnp fill_linear_kernel(x) do
@@ -2311,20 +2092,11 @@
     x / Nx.select(y != 0, y, 1)
   end
 
-  defnp reorder_axis(tensor, axis, target_axis) do
-    transform({tensor, axis, target_axis}, fn {tensor, axis, target_axis} ->
-      axes = Nx.axes(tensor)
-      {source_axis, axes} = List.pop_at(axes, axis)
-      axes = List.insert_at(axes, target_axis, source_axis)
-      Nx.transpose(tensor, axes: axes)
-    end)
-  end
-
-  defnp spatial_axes_with_sizes(input, opts \\ []) do
-    {height_axis, width_axis} = spatial_axes(input, channels: opts[:channels])
-    {height, width} = size(input, channels: 
opts[:channels]) - {out_height, out_width} = opts[:size] - [{height_axis, height, out_height}, {width_axis, width, out_width}] + deftransformp reorder_axis(tensor, axis, target_axis) do + axes = Nx.axes(tensor) + {source_axis, axes} = List.pop_at(axes, axis) + axes = List.insert_at(axes, target_axis, source_axis) + Nx.transpose(tensor, axes: axes) end defnp size(input, opts \\ []) do @@ -2341,23 +2113,16 @@ defmodule Axon.Layers do for activation <- @activation_layers do @doc false - defn unquote(activation)(input, _opts \\ []) do - transform(input, fn inp -> - Elixir.Kernel.apply(Axon.Activations, unquote(activation), [inp]) - end) + deftransform unquote(activation)(input, _opts \\ []) do + apply(Axon.Activations, unquote(activation), [input]) end end @activation_layers_with_opts [:celu, :elu, :hard_sigmoid, :hard_silu, :leaky_relu] ++ [:log_sumexp, :log_softmax, :selu, :softmax] for activation <- @activation_layers_with_opts do - defn unquote(activation)(input, opts \\ []) do - transform(input, fn inp -> - Elixir.Kernel.apply(Axon.Activations, unquote(activation), [ - inp, - Keyword.delete(opts, :mode) - ]) - end) + deftransform unquote(activation)(input, opts \\ []) do + apply(Axon.Activations, unquote(activation), [input, Keyword.delete(opts, :mode)]) end end @@ -2367,26 +2132,22 @@ defmodule Axon.Layers do @element_wise_layers [:add, :subtract, :multiply] for op <- @element_wise_layers do - defn unquote(op)(inputs, _opts \\ []) do - transform(inputs, fn inputs -> - [first | rest] = Tuple.to_list(inputs) + deftransform unquote(op)(inputs, _opts \\ []) do + [first | rest] = Tuple.to_list(inputs) - Enum.reduce(rest, first, fn next, acc -> - apply(Nx, unquote(op), [acc, next]) - end) + Enum.reduce(rest, first, fn next, acc -> + apply(Nx, unquote(op), [acc, next]) end) end end @doc false - defn concatenate(inputs, opts \\ []) do - opts = keyword!(opts, axis: -1, mode: :inference) + deftransform concatenate(inputs, opts \\ []) do + opts = Keyword.validate!(opts, axis: -1, mode: :inference) - transform(inputs, fn inputs -> - inputs - |> Tuple.to_list() - |> Nx.concatenate(axis: opts[:axis]) - end) + inputs + |> Tuple.to_list() + |> Nx.concatenate(axis: opts[:axis]) end ## Recurrent @@ -2404,6 +2165,7 @@ defmodule Axon.Layers do defn gru_cell( input, carry, + mask, input_kernel, hidden_kernel, bias, @@ -2419,7 +2181,10 @@ defmodule Axon.Layers do z = gate_fn.(dense(input, wiz, bz) + dense(hidden, whz, 0)) n = activation_fn.(dense(input, win, bin) + r * dense(hidden, whn, bhn)) + mask = Nx.broadcast(mask, hidden) + new_h = (1.0 - z) * n + z * hidden + new_h = Nx.select(Nx.as_type(mask, :u8), hidden, new_h) {new_h, {new_h}} end @@ -2437,6 +2202,7 @@ defmodule Axon.Layers do defn lstm_cell( input, carry, + mask, input_kernel, hidden_kernel, bias, @@ -2457,6 +2223,10 @@ defmodule Axon.Layers do new_c = f * cell + i * g new_h = o * activation_fn.(new_c) + mask = Nx.broadcast(mask, hidden) + + new_h = Nx.select(Nx.as_type(mask, :u8), hidden, new_h) + new_c = Nx.select(Nx.as_type(mask, :u8), cell, new_c) {new_h, {new_c, new_h}} end @@ -2476,14 +2246,14 @@ defmodule Axon.Layers do * [Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting](https://arxiv.org/abs/1506.04214) """ - defn conv_lstm_cell(input, carry, input_kernel, hidden_kernel, bias, opts \\ []) do + defn conv_lstm_cell(input, carry, _mask, input_kernel, hidden_kernel, bias, opts \\ []) do opts = keyword!(opts, strides: 1, padding: :same) {ih} = input_kernel {hh} = hidden_kernel {bi} = bias - {input, {cell, 
hidden}} = rank_down({input, carry}) + {cell, hidden} = rank_down(carry) gates = Nx.add( @@ -2497,49 +2267,43 @@ defmodule Axon.Layers do new_c = f * cell + Axon.Activations.sigmoid(i) * Axon.Activations.tanh(g) new_h = Axon.Activations.sigmoid(o) * Axon.Activations.tanh(new_c) - rank_up({new_h, {new_c, new_h}}) + {new_h, rank_up({new_c, new_h})} end - defnp split_gates(gates) do - transform(gates, fn gates -> - channels = elem(Nx.shape(gates), 1) - split_every = div(channels, 4) + deftransformp split_gates(gates) do + channels = elem(Nx.shape(gates), 1) + split_every = div(channels, 4) - split_dims = - for i <- 0..3 do - {i * split_every, split_every} - end + split_dims = + for i <- 0..3 do + {i * split_every, split_every} + end - split_dims - |> Enum.map(fn {start, len} -> Nx.slice_along_axis(gates, start, len, axis: 1) end) - |> List.to_tuple() - end) + split_dims + |> Enum.map(fn {start, len} -> Nx.slice_along_axis(gates, start, len, axis: 1) end) + |> List.to_tuple() end - defnp rank_down(rnn_data) do - transform(rnn_data, fn {input, {cell, hidden}} -> - [cell, hidden, input] = - for tensor <- [cell, hidden, input] do - Nx.squeeze(tensor, axes: [1]) - end + deftransformp rank_down({cell, hidden}) do + [cell, hidden] = + for tensor <- [cell, hidden] do + Nx.squeeze(tensor, axes: [1]) + end - {input, {cell, hidden}} - end) + {cell, hidden} end - defnp rank_up(rnn_data) do - transform(rnn_data, fn {input, {cell, hidden}} -> - [cell, hidden, input] = - for tensor <- [cell, hidden, input] do - new_shape = - Nx.shape(tensor) - |> Tuple.insert_at(1, 1) + deftransformp rank_up({cell, hidden}) do + [cell, hidden] = + for tensor <- [cell, hidden] do + new_shape = + Nx.shape(tensor) + |> Tuple.insert_at(1, 1) - Nx.reshape(tensor, new_shape) - end + Nx.reshape(tensor, new_shape) + end - {input, {cell, hidden}} - end) + {cell, hidden} end @doc """ @@ -2553,35 +2317,68 @@ defmodule Axon.Layers do This function will make use of an `defn` while-loop such and thus may be more efficient for long sequences. 
""" - defn dynamic_unroll(cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias) do - time_steps = transform(Nx.shape(input_sequence), &elem(&1, 1)) - - feature_dims = transform(Nx.rank(input_sequence), &List.duplicate(0, &1 - 2)) + defn dynamic_unroll(cell_fn, input_sequence, carry, mask, input_kernel, recurrent_kernel, bias) do + time_steps = Nx.axis_size(input_sequence, 1) + feature_dims = list_duplicate(0, Nx.rank(input_sequence) - 2) + mask = get_mask(mask, input_sequence) initial_shape = - transform({cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias}, fn - {cell_fn, inp, carry, inp_kernel, hid_kernel, bias} -> - seq = Nx.slice_along_axis(inp, 0, 1, axis: 1) - {seq, _} = cell_fn.(seq, carry, inp_kernel, hid_kernel, bias) - put_elem(Nx.shape(seq), 1, elem(Nx.shape(inp), 1)) - end) + unroll_initial_shape_transform( + cell_fn, + input_sequence, + carry, + mask, + input_kernel, + recurrent_kernel, + bias + ) init_sequence = Nx.broadcast(0.0, initial_shape) i = Nx.tensor(0) - {_, carry, output, _, _, _, _} = - while {i, carry, init_sequence, input_sequence, input_kernel, recurrent_kernel, bias}, + {_, carry, output, _, _, _, _, _} = + while {i, carry, init_sequence, input_sequence, mask, input_kernel, recurrent_kernel, bias}, Nx.less(i, time_steps) do sequence = Nx.slice_along_axis(input_sequence, i, 1, axis: 1) - indices = transform({feature_dims, i}, fn {feature_dims, i} -> [0, i] ++ feature_dims end) - {output, carry} = cell_fn.(sequence, carry, input_kernel, recurrent_kernel, bias) + sequence = Nx.squeeze(sequence, axes: [1]) + mask_token = Nx.slice_along_axis(mask, i, 1, axis: 1) + mask_token = Nx.reshape(mask_token, {Nx.axis_size(sequence, 0), 1}) + indices = compute_indices(i, feature_dims) + + {output, carry} = + cell_fn.(sequence, carry, mask_token, input_kernel, recurrent_kernel, bias) + + output = Nx.new_axis(output, 1) update_sequence = Nx.put_slice(init_sequence, indices, output) - {i + 1, carry, update_sequence, input_sequence, input_kernel, recurrent_kernel, bias} + + {i + 1, carry, update_sequence, input_sequence, mask, input_kernel, recurrent_kernel, + bias} end {output, carry} end + deftransformp compute_indices(i, feature_dims) do + [0, i] ++ feature_dims + end + + deftransformp unroll_initial_shape_transform( + cell_fn, + inp, + carry, + mask, + inp_kernel, + hid_kernel, + bias + ) do + seq = Nx.slice_along_axis(inp, 0, 1, axis: 1) + seq = Nx.squeeze(seq, axes: [1]) + mask_token = Nx.slice_along_axis(mask, 0, 1, axis: 1) + mask_token = Nx.reshape(mask_token, {Nx.axis_size(seq, 0), 1}) + {seq, _} = cell_fn.(seq, carry, mask_token, inp_kernel, hid_kernel, bias) + Tuple.insert_at(Nx.shape(seq), 1, elem(Nx.shape(inp), 1)) + end + @doc """ Statically unrolls an RNN. @@ -2594,29 +2391,47 @@ defmodule Axon.Layers do the entire operation appears as a part of the compilation graph. This makes it suitable for shorter sequences. 
""" - defn static_unroll(cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias) do - static_unroll_loop(cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias) + defn static_unroll(cell_fn, input_sequence, carry, mask, input_kernel, recurrent_kernel, bias) do + static_unroll_loop(cell_fn, input_sequence, carry, mask, input_kernel, recurrent_kernel, bias) end deftransformp static_unroll_loop( cell_fn, input_sequence, carry, + mask, input_kernel, recurrent_kernel, bias ) do time_steps = elem(Nx.shape(input_sequence), 1) + mask = get_mask(mask, input_sequence) {carry, outputs} = for t <- 0..(time_steps - 1), reduce: {carry, []} do {carry, outputs} -> input = Nx.slice_along_axis(input_sequence, t, 1, axis: 1) - {output, carry} = cell_fn.(input, carry, input_kernel, recurrent_kernel, bias) + input = Nx.squeeze(input, axes: [1]) + mask_token = Nx.slice_along_axis(mask, t, 1, axis: 1) + mask_token = Nx.reshape(mask_token, {Nx.axis_size(input, 0), 1}) + + {output, carry} = + cell_fn.(input, carry, mask_token, input_kernel, recurrent_kernel, bias) + {carry, [output | outputs]} end - {Nx.concatenate(Enum.reverse(outputs), axis: 1), carry} + {Nx.stack(Enum.reverse(outputs), axis: 1), carry} + end + + deftransformp get_mask(mask, sequence) do + case Nx.shape(mask) do + {} -> + Nx.broadcast(mask, {Nx.axis_size(sequence, 0), 1}) + + _ -> + mask + end end @recurrent_layers [lstm: {0, 0, 0, 0}, gru: {0, 0, 0, 0}, conv_lstm: {0}] @@ -2625,6 +2440,7 @@ defmodule Axon.Layers do deftransform unquote(rnn_op)( input, hidden_state, + mask, input_kernel, hidden_kernel, bias \\ [], @@ -2641,8 +2457,8 @@ defmodule Axon.Layers do Keyword.validate!(opts, mode: :inference, unroll: :static, - activation: :sigmoid, - gate: :tanh, + activation: :tanh, + gate: :sigmoid, conv_opts: [] ) @@ -2654,6 +2470,7 @@ defmodule Axon.Layers do cell_fn, input, hidden_state, + mask, input_kernel, hidden_kernel, bias @@ -2664,6 +2481,7 @@ defmodule Axon.Layers do cell_fn, input, hidden_state, + mask, input_kernel, hidden_kernel, bias @@ -2675,17 +2493,17 @@ defmodule Axon.Layers do defp get_cell_fn(:lstm, activation, gate, _) do gate_fn = &apply(Axon.Activations, gate, [&1]) act_fn = &apply(Axon.Activations, activation, [&1]) - &lstm_cell(&1, &2, &3, &4, &5, gate_fn, act_fn) + &lstm_cell(&1, &2, &3, &4, &5, &6, gate_fn, act_fn) end defp get_cell_fn(:gru, activation, gate, _) do gate_fn = &apply(Axon.Activations, gate, [&1]) act_fn = &apply(Axon.Activations, activation, [&1]) - &gru_cell(&1, &2, &3, &4, &5, gate_fn, act_fn) + &gru_cell(&1, &2, &3, &4, &5, &6, gate_fn, act_fn) end defp get_cell_fn(:conv_lstm, _, _, conv_opts) do - &conv_lstm_cell(&1, &2, &3, &4, &5, conv_opts) + &conv_lstm_cell(&1, &2, &3, &4, &5, &6, conv_opts) end @doc false @@ -2693,17 +2511,7 @@ defmodule Axon.Layers do assert_min_rank!("Axon.Layers.split", "input", input, 2) opts = keyword!(opts, [:index, :splits, axis: -1, mode: :train]) - shape = Nx.shape(input) - - {offset, size} = - transform( - {shape, opts[:index], opts[:splits], opts[:axis]}, - fn {shape, idx, splits, axis} -> - slice_size = Axon.Shape.split(shape, splits, axis) - offset = idx * slice_size - {offset, slice_size} - end - ) + {offset, size} = Axon.Shape.split(input, opts[:index], opts[:splits], opts[:axis]) Nx.slice_along_axis(input, offset, size, axis: opts[:axis]) end @@ -2716,10 +2524,14 @@ defmodule Axon.Layers do end deftransformp stack_columns_transform(container, ignore) do - container - |> Map.from_struct() - |> Enum.reject(fn {k, _} -> k in ignore end) - |> 
Enum.reduce([], fn {_, v}, acc -> [v | acc] end) + container.__struct__().__info__(:struct) + |> Enum.reduce([], fn %{field: k}, acc -> + if k in ignore do + acc + else + [Map.fetch!(container, k) | acc] + end + end) |> Enum.reverse() |> Nx.stack(axis: -1) end diff --git a/lib/axon/loop.ex b/lib/axon/loop.ex index 500e59b1..277ad4f3 100644 --- a/lib/axon/loop.ex +++ b/lib/axon/loop.ex @@ -5,7 +5,7 @@ defmodule Axon.Loop do Inspired heavily by [PyTorch Ignite](https://pytorch.org/ignite/index.html). - The main abstraction is the `%Loop{}` struct, which controls a nested + The main abstraction is the `%Axon.Loop{}` struct, which controls a nested reduction of the form: Enum.reduce(1..max_epochs, state, fn epoch, state -> @@ -14,7 +14,7 @@ defmodule Axon.Loop do `data` is assumed to be an `Enumerable` or `Stream` of input data which is handled by a processing function, `batch_step`. The purpose of the loop - abstraction is to take away much of the boilerplate used in solving machine + abstraction is to take away much of the boilerplate code used in solving machine learning tasks. Tasks such as normalizing a dataset, hyperparameter optimization, or training machine learning models boil down to writing one function: @@ -44,10 +44,10 @@ defmodule Axon.Loop do dataset for `N` epochs before finally returning the trained model state. By defining 1 function, we've created a training loop that works for most machine learning models. - In actuality, the loop abstraction accumulates a struct, `Axon.Loop.State`, which looks + In actuality, the loop abstraction accumulates a struct, `%Axon.Loop.State{}`, which looks like (assuming `container` is a generic Elixir container of tensors, e.g. map, tuple, etc.): - %State{ + %Axon.Loop.State{ epoch: integer(), max_epoch: integer(), iteration: integer(), @@ -90,6 +90,15 @@ defmodule Axon.Loop do new_state end + Note that any optimization and training anonymous functions that need to be used in the + `batch_step` function can be passed as extra arguments. For example: + + step_with_training_arguments = fn data, state, optimizer_update_fn, state_update_fn -> + # ...do something... + end + + step = &(step_with_training_arguments.(&1, &2, actual_optimizer_update_fn, actual_state_update_fn)) + ## Metrics Often times you want to compute metrics associated with your training iterations. @@ -143,14 +152,12 @@ defmodule Axon.Loop do :iteration_completed, # On iteration complete :epoch_completed, # On epoch complete :epoch_halted, # On epoch halt, if early halted - :halted, # On loop halt, if early halted - :completed # On loop completion ] - You can attach event handlers to events using `Axon.Loop.handle/4`: + You can attach event handlers to events using `Axon.Loop.handle_event/4`: loop - |> Axon.Loop.handle(:iteration_completed, &log_metrics/1, every: 100) + |> Axon.Loop.handle_event(:iteration_completed, &log_metrics/1, every: 100) |> Axon.Loop.run(data) The above will trigger `log_metrics/1` every 100 times the `:iteration_completed` event @@ -168,10 +175,10 @@ defmodule Axon.Loop do to the loop. If you have two handlers on the same event, they will trigger in order: loop - |> Axon.Loop.handle(:epoch_completed, &normalize_state/1) # Runs first - |> Axon.Loop.handle(:epoch_completed, &log_state/1) # Runs second + |> Axon.Loop.handle_event(:epoch_completed, &normalize_state/1) # Runs first + |> Axon.Loop.handle_event(:epoch_completed, &log_state/1) # Runs second - You may provide filters to filter when event handlers trigger. 
See `Axon.Loop.handle/4` + You may provide filters to filter when event handlers trigger. See `Axon.Loop.handle_event/4` for more details on valid filters. ## Factories @@ -191,8 +198,7 @@ defmodule Axon.Loop do In order to execute a loop, you should use `Axon.Loop.run/3`: - loop - |> Axon.Loop.run(data, epochs: 10) + Axon.Loop.run(loop, data, epochs: 10) ## Resuming loops @@ -203,7 +209,7 @@ defmodule Axon.Loop do |> Axon.Loop.from_state(state) |> Axon.Loop.run(data) """ - require Axon.Updates + require Polaris.Updates require Logger alias __MODULE__, as: Loop @@ -221,9 +227,7 @@ defmodule Axon.Loop do :iteration_started, :iteration_completed, :epoch_completed, - :epoch_halted, - :halted, - :completed + :epoch_halted ] @default_handlers %{ @@ -250,7 +254,7 @@ defmodule Axon.Loop do :soft_margin ] - @valid_axon_optimizers [ + @valid_polaris_optimizers [ :adabelief, :adagrad, :adam, @@ -302,12 +306,14 @@ defmodule Axon.Loop do for multi-output models, or an arity-2 function representing a custom loss function. - `optimizer` must be an atom matching the name of a valid optimizer in `Axon.Optimizers`, + `optimizer` must be an atom matching the name of a valid optimizer in `Polaris.Optimizers`, or a `{init_fn, update_fn}` tuple where `init_fn` is an arity-1 function which - initializes the optimizer state from attached parameters and `update_fn` is an - arity-3 function which scales gradient updates with respect to input parameters, - optimizer state, and gradients. See `Axon.Updates` for more information on building - optimizers. + initializes the optimizer state from the model parameters and `update_fn` is an + arity-3 function that receives `(gradient, optimizer_state, model_parameters)` and + scales gradient updates with respect to input parameters, optimizer state, and gradients. + The `update_fn` returns `{scaled_updates, optimizer_state}`, which can then be applied to + the model through `model_parameters = Axon.Update.apply_updates(model_parameters, scaled_updates)`. + See `Polaris.Updates` for more information on building optimizers. ## Options @@ -351,7 +357,7 @@ defmodule Axon.Loop do loss: Nx.tensor(0.0), gradient_step: Nx.tensor(0), model_state: model_state, - gradient_state: zeros_like(model_state), + gradient_state: zeros_like(model_state, type: :f32), optimizer_state: optimizer_state, loss_scale_state: loss_scale_state } @@ -458,7 +464,6 @@ defmodule Axon.Loop do opts = keyword!(opts, [:steps]) steps = opts[:steps] - # TODO: this explodes the graph {_, new_model_state, _, new_optimizer_state, new_gradient_state, new_gradient_step, _} = while {gradients, model_state, new_state, optimizer_state, gradient_state, gradient_step, flag = Nx.tensor(1)}, @@ -468,7 +473,7 @@ defmodule Axon.Loop do update_optimizer_fn.(gradients, optimizer_state, model_state) new_gradient_state = zeros_like(model_state) - new_model_state = Axon.Updates.apply_updates(model_state, updates, new_state) + new_model_state = Polaris.Updates.apply_updates(model_state, updates, new_state) {gradients, new_model_state, new_state, new_optimizer_state, new_gradient_state, 0, Nx.tensor(0)} @@ -616,11 +621,11 @@ defmodule Axon.Loop do for multi-output models, or an arity-2 function representing a custom loss function. 
- `optimizer` must be an atom matching the name of a valid optimizer in `Axon.Optimizers`, + `optimizer` must be an atom matching the name of a valid optimizer in `Polaris.Optimizers`, or a `{init_fn, update_fn}` tuple where `init_fn` is an arity-1 function which initializes the optimizer state from attached parameters and `update_fn` is an arity-3 function which scales gradient updates with respect to input parameters, - optimizer state, and gradients. See `Axon.Updates` for more information on building + optimizer state, and gradients. See `Polaris.Updates` for more information on building optimizers. This function creates a step function which outputs a map consisting of the following @@ -649,7 +654,7 @@ defmodule Axon.Loop do ### Customizing Optimizer model - |> Axon.Loop.trainer(:binary_cross_entropy, Axon.Optimizers.adam(0.05)) + |> Axon.Loop.trainer(:binary_cross_entropy, Polaris.Optimizers.adam(learning_rate: 0.05)) |> Axon.Loop.run(data) ### Custom loss @@ -657,7 +662,7 @@ defmodule Axon.Loop do loss_fn = fn y_true, y_pred -> Nx.cos(y_true, y_pred) end model - |> Axon.Loop.trainer(loss_fn, Axon.Optimizers.rmsprop(0.01)) + |> Axon.Loop.trainer(loss_fn, Polaris.Optimizers.rmsprop(learning_rate: 0.01)) |> Axon.Loop.run(data) ### Multiple objectives with multi-output model @@ -691,7 +696,7 @@ defmodule Axon.Loop do # Build loss now so we can use it as a metric loss_fn = build_loss_fn(loss) - step_opts = Keyword.take(opts, [:gradient_accumulation_steps, :loss_cale, :seed]) + step_opts = Keyword.take(opts, [:gradient_accumulation_steps, :loss_scale, :seed]) {init_fn, step_fn} = train_step(model, loss_fn, optimizer, step_opts) log_interval = opts[:log] || 50 @@ -754,7 +759,7 @@ defmodule Axon.Loop do end @doc """ - Creates a supervised evaluator from a model and model state. + Creates a supervised evaluator from a model. An evaluator can be used for things such as testing and validation of models after or during training. It assumes `model` is an Axon struct, container of @@ -775,8 +780,17 @@ defmodule Axon.Loop do |> Axon.Loop.evaluator() |> Axon.Loop.metric("Accuracy", :accuracy) - Applies an output transform which returns the map of metrics accumulated over - the given loop. + You must pass a compatible trained model state to `Axon.Loop.run/4` when using + supervised evaluation loops. For example, if you've binded the result of a training + run to `trained_model_state`, you can run the trained model through an evaluation + run like this: + + model + |> Axon.Loop.evaluator() + |> Axon.Loop.run(data, trained_model_state, compiler: EXLA) + + This function applies an output transform which returns the map of metrics accumulated + over the given loop. 
""" def evaluator(model) do {init_fn, step_fn} = eval_step(model) @@ -878,8 +892,6 @@ defmodule Axon.Loop do :iteration_completed, # On iteration complete :epoch_completed, # On epoch complete :epoch_halted, # On epoch halt, if early halted - :halted, # On loop halt, if early halted - :completed # On loop completion ] Generally, event handlers are side-effecting operations which provide some @@ -889,8 +901,8 @@ defmodule Axon.Loop do loop: loop - |> Axon.Loop.handle(:epoch_started, &normalize_step_state/1) # executes first - |> Axon.Loop.handle(:epoch_started, &log_step_state/1) # executes second + |> Axon.Loop.handle_event(:epoch_started, &normalize_step_state/1) # executes first + |> Axon.Loop.handle_event(:epoch_started, &log_step_state/1) # executes second Thus, if you have separate handlers which alter or depend on loop state, you need to ensure they are ordered correctly, or combined into a single @@ -923,11 +935,10 @@ defmodule Axon.Loop do only: N # Trigger on `N` event **Warning: If you modify the step state in an event handler, it will trigger - potentially excessive recompilation and result in significant additinal overhead + potentially excessive recompilation and result in significant additional overhead during loop execution.** """ - # TODO(seanmor5): Custom events - def handle(%Loop{handlers: handle_fns} = loop, event, handler, filter \\ :always) do + def handle_event(%Loop{handlers: handle_fns} = loop, event, handler, filter \\ :always) do filter = build_filter_fn(filter) handle_fns = @@ -945,6 +956,12 @@ defmodule Axon.Loop do %Loop{loop | handlers: handle_fns} end + @doc false + @deprecated "handle/4 is deprecated, use handle_event/4 instead" + def handle(%Loop{} = loop, event, handler, filter \\ :always) do + handle_event(loop, event, handler, filter) + end + @doc """ Adds a handler function which logs the given message produced by `message_fn` to the given IO device every `event` satisfying @@ -983,7 +1000,7 @@ defmodule Axon.Loop do end end - handle(loop, event, log_fn, filter) + handle_event(loop, event, log_fn, filter) end @doc """ @@ -1043,7 +1060,6 @@ defmodule Axon.Loop do metrics = Enum.reduce(metric_fns, evaluator, fn {k, {_, v}}, loop -> metric(loop, v, k) end) - |> log(fn _ -> "\n" end, event: :completed) |> run(validation_data, model_state) |> Access.get(0) |> Map.new(fn {k, v} -> @@ -1054,7 +1070,7 @@ defmodule Axon.Loop do {:continue, %{state | metrics: metrics}} end - handle(loop, event, validation_loop, filter) + handle_event(loop, event, validation_loop, filter) end @doc """ @@ -1102,7 +1118,7 @@ defmodule Axon.Loop do mode = opts[:mode] || :min patience = opts[:patience] || 3 - handle(loop, event, &monitor_impl(&1, metric, fun, name, mode, patience), filter) + handle_event(loop, event, &monitor_impl(&1, metric, fun, name, mode, patience), filter) end defp monitor_impl( @@ -1239,12 +1255,22 @@ defmodule Axon.Loop do `checkpoint_\#{epoch}_\#{iteration}.ckpt`. 
""" def checkpoint(%Loop{} = loop, opts \\ []) do - {event, opts} = Keyword.pop(opts, :event, :epoch_completed) - {filter, opts} = Keyword.pop(opts, :filter, :always) - {path, opts} = Keyword.pop(opts, :path, "checkpoint") - {file_pattern, opts} = Keyword.pop(opts, :file_pattern, &default_checkpoint_file/1) + opts = + Keyword.validate!(opts, [ + :criteria, + event: :epoch_completed, + filter: :always, + path: "checkpoint", + file_pattern: &default_checkpoint_file/1, + mode: :min + ]) + {criteria, opts} = Keyword.pop(opts, :criteria) - {mode, serialize_opts} = Keyword.pop(opts, :mode, :min) + {event, opts} = Keyword.pop!(opts, :event) + {filter, opts} = Keyword.pop!(opts, :filter) + {path, opts} = Keyword.pop!(opts, :path) + {file_pattern, opts} = Keyword.pop!(opts, :file_pattern) + {mode, serialize_opts} = Keyword.pop!(opts, :mode) checkpoint_fun = &checkpoint_impl(&1, path, file_pattern, serialize_opts) @@ -1255,7 +1281,7 @@ defmodule Axon.Loop do filter: filter ) else - handle(loop, event, checkpoint_fun, filter) + handle_event(loop, event, checkpoint_fun, filter) end end @@ -1419,40 +1445,45 @@ defmodule Axon.Loop do opts = Keyword.validate!(opts, event: :iteration_completed, filter: :always) - handle( + handle_event( loop, opts[:event], fn %{ metrics: metrics, - handler_metadata: handler_meta + handler_metadata: handler_metadata } = state -> unless Map.has_key?(metrics, metric) do raise ArgumentError, "invalid metric to plot, key #{inspect(metric)} not present in metrics" end - {iteration, handler_meta} = absolute_iteration(handler_meta) + plot_metadata_key = "plot_#{metric}" + plot_metadata = Map.get(handler_metadata, plot_metadata_key, %{}) + + {iteration, plot_metadata} = absolute_iteration(plot_metadata) Kino.VegaLite.push(plot, %{ "step" => iteration, metric => Nx.to_number(metrics[metric]) }) - {:continue, %{state | handler_metadata: handler_meta}} + next_handler_metadata = Map.put(handler_metadata, plot_metadata_key, plot_metadata) + + {:continue, %{state | handler_metadata: next_handler_metadata}} end, opts[:filter] ) end - defp absolute_iteration( - %{"plot" => %{"absolute_iteration" => absolute_iteration}} = handler_meta - ), - do: - {absolute_iteration, - put_in(handler_meta, ["plot", "absolute_iteration"], absolute_iteration + 1)} + defp absolute_iteration(plot_metadata) do + case plot_metadata do + %{"absolute_iteration" => iteration} -> + {iteration, Map.put(plot_metadata, "absolute_iteration", iteration + 1)} - defp absolute_iteration(handler_meta), - do: {0, Map.put(handler_meta, "plot", %{"absolute_iteration" => 1})} + %{} -> + {0, %{"absolute_iteration" => 1}} + end + end defp assert_kino_vega_lite!(fn_name) do unless Code.ensure_loaded?(Kino.VegaLite) do @@ -1713,7 +1744,8 @@ defmodule Axon.Loop do &Map.put(&2, &1, zero_metrics) ) - {_, state} = fire_event(status, handler_fns, %{state | metrics: final_metrics_map}, debug?) + state = %State{state | metrics: final_metrics, status: status} + {_, state} = fire_event(status, handler_fns, state, debug?) output_transform.(state) end @@ -1794,9 +1826,6 @@ defmodule Axon.Loop do Logger.debug("Axon.Loop finished batch step execution in #{us_to_ms(time)}ms") end - # Force a garbage collection so any device or copied data is deallocated. - :erlang.garbage_collect() - batch_fn = {:compiled, batch_fn} state = %{state | step_state: new_step_state, metrics: new_metrics} @@ -2005,12 +2034,12 @@ defmodule Axon.Loop do # Builds optimizer init and update functions either from an atom # or a tuple of init / update functions. 
The init and update functions - # match the signatures of those defined in Axon.Updates. If the + # match the signatures of those defined in Polaris.Updates. If the # optimizer is an atom, it must match the name of a function in - # Axon.Optimizers. + # Polaris.Optimizers. defp build_optimizer_fns(optimizer) - when is_atom(optimizer) and optimizer in @valid_axon_optimizers do - apply(Axon.Optimizers, optimizer, []) + when is_atom(optimizer) and optimizer in @valid_polaris_optimizers do + apply(Polaris.Optimizers, optimizer, []) end defp build_optimizer_fns({init_optimizer_fn, update_optimizer_fn}) @@ -2021,8 +2050,8 @@ defmodule Axon.Loop do defp build_optimizer_fns(invalid) do raise ArgumentError, "Invalid optimizer #{inspect(invalid)}, a valid optimizer" <> - " is an atom matching the name of an optimizer in Axon.Optimizers" <> - " or a tuple of {init_fn, update_fn}. See Axon.Updates for more" <> + " is an atom matching the name of an optimizer in Polaris.Optimizers" <> + " or a tuple of {init_fn, update_fn}. See Polaris.Updates for more" <> " information on building optimizers using the low-level API" end diff --git a/lib/axon/loop/state.ex b/lib/axon/loop/state.ex index 09782bb0..eaccef22 100644 --- a/lib/axon/loop/state.ex +++ b/lib/axon/loop/state.ex @@ -42,10 +42,14 @@ defmodule Axon.Loop.State do `event_counts` is a metadata field which stores information about the number of times each event has been fired. This is useful when creating custom filters. + + `status` refers to the loop state status after the loop has executed. You can + use this to determine if the loop ran to completion or if it was halted early. """ @enforce_keys [:step_state] defstruct [ :step_state, + :status, handler_metadata: %{}, epoch: 0, max_epoch: 1, diff --git a/lib/axon/loss_scale.ex b/lib/axon/loss_scale.ex index 116d1292..2a071c0a 100644 --- a/lib/axon/loss_scale.ex +++ b/lib/axon/loss_scale.ex @@ -7,7 +7,7 @@ defmodule Axon.LossScale do precision during the model training process. Each loss-scale implementation here returns a 3-tuple of the functions: - {init_fn, scale_fn, unscale_fn, adjust_fn} = Axon.LossScale.static(Nx.power(2, 15)) + {init_fn, scale_fn, unscale_fn, adjust_fn} = Axon.LossScale.static(Nx.pow(2, 15)) You can use these to scale/unscale loss and gradients as well as adjust the loss scale state. @@ -25,7 +25,7 @@ defmodule Axon.LossScale do @doc """ Implements identity loss-scale. """ - def identity() do + def identity(_opts \\ []) do scale_unscale_fun = fn x, _state -> x end adjust_fun = fn x, state -> {x, state} end {fn -> %{} end, scale_unscale_fun, adjust_fun} @@ -34,8 +34,9 @@ defmodule Axon.LossScale do @doc """ Implements static loss-scale. 
""" - def static(loss_scale \\ @default_loss_scale) do - loss_scale = Nx.backend_copy(loss_scale, Nx.BinaryBackend) + def static(opts \\ []) do + opts = Keyword.validate!(opts, init_scale: @default_loss_scale) + loss_scale = Nx.backend_copy(opts[:init_scale], Nx.BinaryBackend) {fn -> init_static(loss_scale) end, &scale_static/2, &unscale_static/2} end @@ -44,26 +45,28 @@ defmodule Axon.LossScale do end defnp scale_static(value, %{loss_scale: loss_scale}) do - transform({value, loss_scale}, fn {value, loss_scale} -> - deep_new(value, fn x -> x * loss_scale end) - end) + deep_new(value, fn x -> x * loss_scale end) end defnp unscale_static(value, %{loss_scale: loss_scale} = state) do inv_loss_scale = 1 / loss_scale - - unscaled = - transform({value, inv_loss_scale}, fn {value, inv_loss_scale} -> - deep_new(value, fn x -> x * inv_loss_scale end) - end) - + unscaled = deep_new(value, fn x -> x * inv_loss_scale end) {unscaled, state} end @doc """ Implements dynamic loss-scale. """ - def dynamic(loss_scale \\ @default_loss_scale, opts \\ []) do + def dynamic(opts \\ []) do + opts = + Keyword.validate!(opts, + init_scale: @default_loss_scale, + period: 2_000, + factor: 2, + min_loss_scale: 1 + ) + + {loss_scale, opts} = Keyword.pop(opts, :init_scale, @default_loss_scale) loss_scale = Nx.backend_copy(loss_scale, Nx.BinaryBackend) { @@ -81,19 +84,12 @@ defmodule Axon.LossScale do end defnp scale_dynamic(value, %{loss_scale: loss_scale}) do - transform({value, loss_scale}, fn {value, loss_scale} -> - deep_new(value, fn x -> x * loss_scale end) - end) + deep_new(value, fn x -> x * loss_scale end) end defnp unscale_dynamic(value, %{loss_scale: loss_scale} = state, opts \\ []) do inv_loss_scale = 1 / loss_scale - - unscaled = - transform({value, inv_loss_scale}, fn {value, inv_loss_scale} -> - deep_new(value, fn x -> x * inv_loss_scale end) - end) - + unscaled = deep_new(value, fn x -> x * inv_loss_scale end) {unscaled, adjust_dynamic(value, state, opts)} end @@ -101,24 +97,22 @@ defmodule Axon.LossScale do opts = keyword!(opts, period: 2_000, factor: 2, min_loss_scale: 1) grads_are_finite = - transform(grads, fn grads -> - deep_reduce(grads, Nx.tensor(1), fn x, acc -> - x - |> is_finite() - |> Nx.logical_and(acc) - end) + deep_reduce(grads, Nx.tensor(1), fn x, acc -> + x + |> is_finite() + |> Nx.logical_and(acc) end) new_loss_scale = - if grads_are_finite do - if counter == opts[:period] - 1 do - first_finite(loss_scale * opts[:factor], loss_scale) - else + Nx.select( + grads_are_finite, + Nx.select( + Nx.equal(counter, opts[:period] - 1), + first_finite(loss_scale * opts[:factor], loss_scale), loss_scale - end - else + ), Nx.max(opts[:min_loss_scale], loss_scale / opts[:factor]) - end + ) new_counter = Nx.remainder(counter + 1, opts[:period]) * grads_are_finite diff --git a/lib/axon/losses.ex b/lib/axon/losses.ex index 34cdcaf1..b5068fa5 100644 --- a/lib/axon/losses.ex +++ b/lib/axon/losses.ex @@ -134,20 +134,7 @@ defmodule Axon.Losses do # altogether if necessary. If either of them is set, then we need to set # both and perform this whole thing. If neither is set, we set this to # nil and then avoid the weighted avg later on. 
- weights = - transform({y_true, opts[:positive_weight], opts[:negative_weight]}, fn - {_, nil, nil} -> - nil - - {y_true, pos, nil} -> - Nx.take(Nx.tensor([1.0, pos], backend: Nx.Defn.Expr), y_true) - - {y_true, nil, neg} -> - Nx.take(Nx.tensor([neg, 1.0], backend: Nx.Defn.Expr), y_true) - - {y_true, pos, neg} -> - Nx.take(Nx.tensor([neg, pos], backend: Nx.Defn.Expr), y_true) - end) + weights = get_weights(y_true, opts[:positive_weight], opts[:negative_weight]) # Merge types before computing loss to prevent under/overflow. This # can especially happen when targets are encoded as u8 tensors. We @@ -207,6 +194,22 @@ defmodule Axon.Losses do reduction(possibly_weighted_avg_loss, opts[:reduction]) end + deftransformp get_weights(y_true, pos, neg) do + case {y_true, pos, neg} do + {_, nil, nil} -> + nil + + {y_true, pos, nil} -> + Nx.take(Nx.tensor([1.0, pos], backend: Nx.Defn.Expr), y_true) + + {y_true, nil, neg} -> + Nx.take(Nx.tensor([neg, 1.0], backend: Nx.Defn.Expr), y_true) + + {y_true, pos, neg} -> + Nx.take(Nx.tensor([neg, pos], backend: Nx.Defn.Expr), y_true) + end + end + defnp sigmoid_cross_entropy_from_logits(y_true, y_pred) do log_p = Axon.Activations.log_sigmoid(y_pred) log_not_p = Axon.Activations.log_sigmoid(-y_pred) @@ -218,7 +221,7 @@ defmodule Axon.Losses do $$l_i = -\sum_i^C \hat{y_i} \cdot \log(y_i)$$ - Categorical cross-entropy is typically used for multi-class classifcation problems. + Categorical cross-entropy is typically used for multi-class classification problems. By default, it expects `y_pred` to encode a probability distribution along the last axis. You can specify `from_logits: true` to indicate `y_pred` is a logits tensor. @@ -850,7 +853,7 @@ defmodule Axon.Losses do loss = y_true |> Nx.subtract(y_pred) - |> Nx.power(2) + |> Nx.pow(2) |> Nx.mean(axes: [-1]) reduction(loss, opts[:reduction]) @@ -895,14 +898,7 @@ defmodule Axon.Losses do n12 = Nx.max(w1 * w2, eps) loss = w12 / n12 - transform( - {opts[:reduction], loss}, - fn - {:mean, loss} -> Nx.mean(loss) - {:sum, loss} -> Nx.sum(loss) - {:none, loss} -> loss - end - ) + reduction(loss, opts[:reduction]) end @doc ~S""" @@ -963,6 +959,56 @@ defmodule Axon.Losses do reduction(loss, opts[:reduction]) end + @doc """ + Huber loss. + + ## Argument Shapes + + * `y_true` - $(d_0, d_1, ..., d_n)$ + * `y_pred` - $(d_0, d_1, ..., d_n)$ + + ## Options + + * `:reduction` - reduction mode. One of `:mean`, `:sum`, or `:none`. + Defaults to `:none`. + + * `:delta` - the point where the Huber loss function changes from a quadratic to linear. + Defaults to `1.0`. + + ## Examples + + iex> y_true = Nx.tensor([[1], [1.5], [2.0]]) + iex> y_pred = Nx.tensor([[0.8], [1.8], [2.1]]) + iex> Axon.Losses.huber(y_true, y_pred) + #Nx.Tensor< + f32[3][1] + [ + [0.019999997690320015], + [0.04499998688697815], + [0.004999990575015545] + ] + > + + iex> y_true = Nx.tensor([[1], [1.5], [2.0]]) + iex> y_pred = Nx.tensor([[0.8], [1.8], [2.1]]) + iex> Axon.Losses.huber(y_true, y_pred, reduction: :mean) + #Nx.Tensor< + f32 + 0.02333332598209381 + > + """ + defn huber(y_true, y_pred, opts \\ []) do + opts = keyword!(opts, reduction: :none, delta: 1.0) + + delta = opts[:delta] + + abs_diff = Nx.abs(y_pred - y_true) + + (abs_diff <= delta) + |> Nx.select(0.5 * abs_diff ** 2, delta * abs_diff - 0.5 * delta ** 2) + |> reduction(opts[:reduction]) + end + @doc """ Connectionist Temporal Classification loss. 
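
The Huber doctest values above can be checked by hand: with the default `delta: 1.0`, every absolute error falls into the quadratic branch of the loss.

    # absolute errors:  |0.8 - 1.0| = 0.2,  |1.8 - 1.5| = 0.3,  |2.1 - 2.0| = 0.1
    # all <= delta, so: 0.5 * 0.2 ** 2 = 0.02, 0.5 * 0.3 ** 2 = 0.045, 0.5 * 0.1 ** 2 = 0.005
    # :mean reduction:  (0.02 + 0.045 + 0.005) / 3 ≈ 0.023333

which matches the `reduction: :mean` doctest up to `f32` rounding.
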
@@ -1019,14 +1065,11 @@ defmodule Axon.Losses do {Nx.put_slice(loss, [b], Nx.reshape(loss_b, {1})), b + 1, y_true, s_true, y_pred} end - transform( - {opts[:reduction], loss}, - fn - {:mean, loss} -> Nx.divide(loss, l_true) |> Nx.mean() - {:sum, loss} -> Nx.sum(loss) - {:none, loss} -> loss - end - ) + case opts[:reduction] do + :mean -> Nx.divide(loss, l_true) |> Nx.mean() + :sum -> Nx.sum(loss) + :none -> loss + end end defnp get_limits(y_true, s_max, t_max) do @@ -1134,6 +1177,53 @@ defmodule Axon.Losses do t0_prob end + ## Modifiers + + @doc """ + Modifies the given loss function to smooth labels prior + to calculating loss. + + See `apply_label_smoothing/2` for details. + + ## Options + + * `:smoothing` - smoothing factor. Defaults to 0.1 + """ + def label_smoothing(loss_fun, opts \\ []) when is_function(loss_fun, 2) do + opts = Keyword.validate!(opts, smoothing: 0.1) + + fn y_true, y_pred -> + smoothed = apply_label_smoothing(y_true, y_pred, smoothing: opts[:smoothing]) + loss_fun.(smoothed, y_pred) + end + end + + @doc """ + Applies label smoothing to the given labels. + + Label smoothing is a regularization technique which shrink targets + towards a uniform distribution. Label smoothing can improve model + generalization. + + ## Options + + * `:smoothing` - smoothing factor. Defaults to 0.1 + + ## References + + * [Rethinking the Inception Architecture for Computer Vision](https://arxiv.org/abs/1512.00567) + """ + defn apply_label_smoothing(y_true, y_pred, opts \\ []) do + assert_min_rank!("apply_label_smoothing", "y_true", y_true, 2) + assert_min_rank!("apply_label_smoothing", "y_pred", y_pred, 2) + + opts = keyword!(opts, smoothing: 0.1) + n_classes = Nx.axis_size(y_pred, 1) + y_true * (1 - opts[:smoothing]) + opts[:smoothing] / n_classes + end + + ## Helpers + defnp reduction(loss, reduction \\ :none) do case reduction do :mean -> Nx.mean(loss) diff --git a/lib/axon/metrics.ex b/lib/axon/metrics.ex index 3df31355..cf1753e1 100644 --- a/lib/axon/metrics.ex +++ b/lib/axon/metrics.ex @@ -416,14 +416,7 @@ defmodule Axon.Metrics do defn top_k_categorical_accuracy(y_true, y_pred, opts \\ []) do opts = keyword!(opts, k: 5, sparse: false) - y_true = - transform(y_true, fn y_true -> - if opts[:sparse] do - y_true - else - top_k_index_transform(y_true) - end - end) + y_true = if opts[:sparse], do: y_true, else: top_k_index_transform(y_true) cond do Nx.rank(y_pred) == 2 -> @@ -449,7 +442,7 @@ defmodule Axon.Metrics do end end - defnp(top_k_index_transform(y_true), do: Nx.argmax(y_true, axis: -1, keep_axis: true)) + defnp top_k_index_transform(y_true), do: Nx.argmax(y_true, axis: -1, keep_axis: true) # Combinators diff --git a/lib/axon/mixed_precision.ex b/lib/axon/mixed_precision.ex index ad320781..fd948844 100644 --- a/lib/axon/mixed_precision.ex +++ b/lib/axon/mixed_precision.ex @@ -41,6 +41,7 @@ defmodule Axon.MixedPrecision do """ alias Axon.MixedPrecision.Policy + import Axon.Shared @doc """ Creates a mixed precision policy with the given options. 
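
A short sketch of the new label-smoothing helpers added above (the tensors are assumed examples; the smoothed values follow directly from `y * (1 - smoothing) + smoothing / n_classes`):

    y_true = Nx.tensor([[0, 1], [1, 0]])
    y_pred = Nx.tensor([[0.1, 0.9], [0.7, 0.3]])

    # With smoothing: 0.1 and 2 classes, each one-hot row [0, 1] becomes ~[0.05, 0.95]
    Axon.Losses.apply_label_smoothing(y_true, y_pred, smoothing: 0.1)

    # The modifier form wraps any arity-2 loss function
    smoothed_loss =
      Axon.Losses.label_smoothing(&Axon.Losses.categorical_cross_entropy/2, smoothing: 0.1)

    smoothed_loss.(y_true, y_pred)
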
@@ -54,10 +55,10 @@ defmodule Axon.MixedPrecision do ## Examples iex> Axon.MixedPrecision.create_policy(params: {:f, 16}, output: {:f, 16}) - %Policy{params: {:f, 16}, compute: {:f, 32}, output: {:f, 16}} + #Axon.MixedPrecision.Policy iex> Axon.MixedPrecision.create_policy(compute: {:bf, 16}) - %Policy{params: {:f, 32}, compute: {:bf, 16}, output: {:f, 32}} + #Axon.MixedPrecision.Policy """ def create_policy(opts \\ []) do params = opts[:params] || {:f, 32} @@ -121,4 +122,34 @@ defmodule Axon.MixedPrecision do def apply_policy(%Axon{} = axon, %Policy{} = policy) do apply_policy(%Axon{} = axon, %Policy{} = policy, & &1) end + + @doc """ + Casts the given container according to the given policy + and type. + + ## Examples + + iex> policy = Axon.MixedPrecision.create_policy(params: {:f, 16}) + iex> params = %{"dense" => %{"kernel" => Nx.tensor([1.0, 2.0, 3.0])}} + iex> params = Axon.MixedPrecision.cast(policy, params, :params) + iex> Nx.type(params["dense"]["kernel"]) + {:f, 16} + + iex> policy = Axon.MixedPrecision.create_policy(compute: {:bf, 16}) + iex> value = Nx.tensor([1.0, 2.0, 3.0]) + iex> value = Axon.MixedPrecision.cast(policy, value, :compute) + iex> Nx.type(value) + {:bf, 16} + + iex> policy = Axon.MixedPrecision.create_policy(output: {:bf, 16}) + iex> value = Nx.tensor([1.0, 2.0, 3.0]) + iex> value = Axon.MixedPrecision.cast(policy, value, :output) + iex> Nx.type(value) + {:bf, 16} + """ + def cast(%Policy{} = policy, tensor_or_container, variable_type) + when variable_type in [:compute, :params, :output] do + type = get_in(policy, [Access.key!(variable_type)]) + deep_new(tensor_or_container, fn x -> Nx.as_type(x, type) end) + end end diff --git a/lib/axon/mixed_precision/policy.ex b/lib/axon/mixed_precision/policy.ex index f34d2774..30f73256 100644 --- a/lib/axon/mixed_precision/policy.ex +++ b/lib/axon/mixed_precision/policy.ex @@ -10,9 +10,11 @@ defmodule Axon.MixedPrecision.Policy do def inspect(policy, _opts) do force_unfit( concat([ + "#Axon.MixedPrecision.Policy<", "p=#{Nx.Type.to_string(policy.params)} ", "c=#{Nx.Type.to_string(policy.compute)} ", - "o=#{Nx.Type.to_string(policy.output)}" + "o=#{Nx.Type.to_string(policy.output)}", + ">" ]) ) end diff --git a/lib/axon/optimizers.ex b/lib/axon/optimizers.ex index 9f3b12cc..8cadf3d6 100644 --- a/lib/axon/optimizers.ex +++ b/lib/axon/optimizers.ex @@ -1,59 +1,6 @@ defmodule Axon.Optimizers do - @moduledoc """ - Implementations of common gradient-based optimization algorithms. - - All of the methods in this module are written in terms of - the update methods defined in `Axon.Updates`. Axon treats - optimizers as the tuple: - - {init_fn, update_fn} - - where `init_fn` returns an initial optimizer state and `update_fn` - scales input gradients. `init_fn` accepts a model's parameters - and attaches state to each parameter. `update_fn` accepts - gradients, optimizer state, and current model parameters and - returns updated optimizer state and gradients. - - Custom optimizers are often created via the `Axon.Updates` API. 
- - ## Example - - Consider the following usage of the Adam optimizer in a basic - update function (assuming `objective` and the `dataset` are - defined elsewhere): - - defmodule Learning do - - import Nx.Defn - - defn init(params, init_fn) do - init_fn.(params) - end - - defn update(params, optimizer_state, inputs, targets, update_fn) do - {loss, gradient} = value_and_grad(params, &objective(&1, inputs, targets)) - {scaled_updates, new_optimizer_state} = update_fn.(gradient, optimizer_state, params) - {Axon.Updates.apply_updates(params, scaled_updates), new_optimizer_state, loss} - end - end - - model_params = Nx.random_uniform({784, 10}) - {init_fn, update_fn} = Axon.Optimizers.adam(0.005) - - optimizer_state = - Learning.init(params, init_fn) - - {new_params, new_optimizer_state, loss} = - Learning.update(params, optimizer_state, inputs, targets, update_fn) - - For a simpler approach, you can also use optimizers with the training API: - - model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(0.005)) - |> Axon.Loop.run(data, epochs: 10, compiler: EXLA) - - """ - alias Axon.Updates + @moduledoc false + alias Polaris.Updates @doc """ Adabelief optimizer. @@ -69,6 +16,7 @@ defmodule Axon.Optimizers do * [AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients](https://arxiv.org/abs/2010.07468) """ + @deprecated "Use Polaris.Optimizers.adabelief/1 instead" def adabelief(learning_rate \\ 1.0e-3, opts \\ []) do Updates.scale_by_belief(opts) |> scale_by_learning_rate(learning_rate) @@ -85,6 +33,7 @@ defmodule Axon.Optimizers do * [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](https://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) """ + @deprecated "Use Polaris.Optimizers.adagrad/1 instead" def adagrad(learning_rate \\ 1.0e-3, opts \\ []) do Updates.scale_by_rss(opts) |> scale_by_learning_rate(learning_rate) @@ -104,6 +53,7 @@ defmodule Axon.Optimizers do * [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980) """ + @deprecated "Use Polaris.Optimizers.adam/1 instead" def adam(learning_rate \\ 1.0e-3, opts \\ []) do Updates.scale_by_adam(opts) |> scale_by_learning_rate(learning_rate) @@ -120,6 +70,7 @@ defmodule Axon.Optimizers do * `:eps_root` - numerical stability term. Defaults to `0.0` * `:decay` - weight decay. Defaults to `0.0` """ + @deprecated "Use Polaris.Optimizers.adamw/1 instead" def adamw(learning_rate \\ 1.0e-3, opts \\ []) do {decay, opts} = Keyword.pop(opts, :decay, 0.0) @@ -144,6 +95,7 @@ defmodule Axon.Optimizers do * [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962) """ + @deprecated "Use Polaris.Optimizers.lamb/1 instead" def lamb(learning_rate \\ 1.0e-2, opts \\ []) do {decay, opts} = Keyword.pop(opts, :decay, 0.0) {min_norm, opts} = Keyword.pop(opts, :min_norm, 0.0) @@ -162,6 +114,7 @@ defmodule Axon.Optimizers do * `:eta` - used to compute variance of noise distribution. Defaults to `0.1` * `:gamma` - used to compute variance of noise distribution. 
Defaults to `0.55` """ + @deprecated "Use Polaris.Optimizers.noisy_sgd/1 instead" def noisy_sgd(learning_rate \\ 1.0e-2, opts \\ []) do scale_by_learning_rate(learning_rate) |> Updates.add_noise(opts) @@ -182,6 +135,7 @@ defmodule Axon.Optimizers do * [On the Variance of Adaptive Learning Rate and Beyond](https://arxiv.org/pdf/1908.03265.pdf) """ + @deprecated "Use Polaris.Optimizers.radam/1 instead" def radam(learning_rate \\ 1.0e-3, opts \\ []) do Updates.scale_by_radam(opts) |> scale_by_learning_rate(learning_rate) @@ -200,6 +154,7 @@ defmodule Axon.Optimizers do * `:decay` - EMA decay rate. Defaults to `0.9` * `:eps` - numerical stability term. Defaults to `1.0e-8` """ + @deprecated "Use Polaris.Optimizers.rmsprop/1 instead" def rmsprop(learning_rate \\ 1.0e-2, opts \\ []) do {centered, opts} = Keyword.pop(opts, :centered, false) {nesterov?, opts} = Keyword.pop(opts, :nesterov, false) @@ -227,6 +182,7 @@ defmodule Axon.Optimizers do to value of this term. * `:nesterov` - whether or not to use nesterov momentum. Defaults to `false` """ + @deprecated "Use Polaris.Optimizers.sgd/1 instead" def sgd(learning_rate \\ 1.0e-2, opts \\ []) do momentum = opts[:momentum] nesterov? = opts[:nesterov] || false @@ -254,6 +210,7 @@ defmodule Axon.Optimizers do * [Adaptive Methods for Nonconvex Optimization](https://papers.nips.cc/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf) """ + @deprecated "Use Polaris.Optimizers.yogi/1 instead" def yogi(learning_rate \\ 1.0e-2, opts \\ []) do Updates.scale_by_yogi(opts) |> scale_by_learning_rate(learning_rate) diff --git a/lib/axon/recurrent.ex b/lib/axon/recurrent.ex deleted file mode 100644 index 7a044466..00000000 --- a/lib/axon/recurrent.ex +++ /dev/null @@ -1,233 +0,0 @@ -defmodule Axon.Recurrent do - @moduledoc false - - import Nx.Defn - import Axon.Layers - - @doc """ - GRU Cell. - - When combined with `Axon.Recurrent.*_unroll`, implements a - GRU-based RNN. More memory efficient than traditional LSTM. - - ## References - - * [Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling](https://arxiv.org/pdf/1412.3555v1.pdf) - """ - @deprecated "Use Axon.Layers.gru_cell/7 instead" - defn gru_cell( - input, - carry, - input_kernel, - hidden_kernel, - bias, - gate_fn \\ &sigmoid/1, - activation_fn \\ &tanh/1 - ) do - {hidden} = carry - {wir, wiz, win} = input_kernel - {whr, whz, whn} = hidden_kernel - {br, bz, bin, bhn} = bias - - r = gate_fn.(dense(input, wir, br) + dense(hidden, whr, 0)) - z = gate_fn.(dense(input, wiz, bz) + dense(hidden, whz, 0)) - n = activation_fn.(dense(input, win, bin) + r * dense(hidden, whn, bhn)) - - new_h = (1.0 - z) * n + z * hidden - - {{new_h}, new_h} - end - - @doc """ - LSTM Cell. - - When combined with `Axon.Recurrent.*_unroll`, implements a - LSTM-based RNN. More memory efficient than traditional LSTM. 
- - ## References - - * [Long Short-Term Memory](http://www.bioinf.jku.at/publications/older/2604.pdf) - """ - @deprecated "Use Axon.Layers.lstm_cell/7 instead" - defn lstm_cell( - input, - carry, - input_kernel, - hidden_kernel, - bias, - gate_fn \\ &sigmoid/1, - activation_fn \\ &tanh/1 - ) do - {cell, hidden} = carry - {wii, wif, wig, wio} = input_kernel - {whi, whf, whg, who} = hidden_kernel - - {bi, bf, bg, bo} = bias - - i = gate_fn.(dense(input, wii, bi) + dense(hidden, whi, 0)) - f = gate_fn.(dense(input, wif, bf) + dense(hidden, whf, 0)) - g = activation_fn.(dense(input, wig, bg) + dense(hidden, whg, 0)) - o = gate_fn.(dense(input, wio, bo) + dense(hidden, who, 0)) - - new_c = f * cell + i * g - new_h = o * activation_fn.(new_c) - - {{new_c, new_h}, new_h} - end - - @doc """ - ConvLSTM Cell. - - When combined with `Axon.Recurrent.*_unroll`, implements a - ConvLSTM-based RNN. More memory efficient than traditional LSTM. - - ## Options - - * `:strides` - convolution strides. Defaults to `1`. - - * `:padding` - convolution padding. Defaults to `:same`. - - ## References - - * [Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting](https://arxiv.org/abs/1506.04214) - """ - @deprecated "Use Axon.Layers.conv_lstm_cell/6 instead" - defn conv_lstm_cell(input, carry, input_kernel, hidden_kernel, bias, opts \\ []) do - opts = keyword!(opts, strides: 1, padding: :same) - - {ih} = input_kernel - {hh} = hidden_kernel - {bi} = bias - - {{cell, hidden}, input} = rank_down({carry, input}) - - gates = - Nx.add( - conv(input, ih, bi, strides: opts[:strides], padding: opts[:padding]), - conv(hidden, hh, 0, strides: opts[:strides], padding: opts[:padding]) - ) - - {i, g, f, o} = split_gates(gates) - - f = sigmoid(f + 1) - new_c = f * cell + sigmoid(i) * tanh(g) - new_h = sigmoid(o) * tanh(new_c) - - rank_up({{new_c, new_h}, new_h}) - end - - defnp split_gates(gates) do - transform(gates, fn gates -> - channels = elem(Nx.shape(gates), 1) - split_every = div(channels, 4) - - split_dims = - for i <- 0..3 do - {i * split_every, split_every} - end - - split_dims - |> Enum.map(fn {start, len} -> Nx.slice_along_axis(gates, start, len, axis: 1) end) - |> List.to_tuple() - end) - end - - defnp rank_down(rnn_data) do - transform(rnn_data, fn {{cell, hidden}, input} -> - [cell, hidden, input] = - for tensor <- [cell, hidden, input] do - Nx.squeeze(tensor, axes: [1]) - end - - {{cell, hidden}, input} - end) - end - - defnp rank_up(rnn_data) do - transform(rnn_data, fn {{cell, hidden}, input} -> - [cell, hidden, input] = - for tensor <- [cell, hidden, input] do - new_shape = - Nx.shape(tensor) - |> Tuple.insert_at(1, 1) - - Nx.reshape(tensor, new_shape) - end - - {{cell, hidden}, input} - end) - end - - @doc """ - Dynamically unrolls an RNN. - - Unrolls implement a `scan` operation which applies a - transformation on the leading axis of `input_sequence` carrying - some state. In this instance `cell_fn` is an RNN cell function - such as `lstm_cell` or `gru_cell`. - - This function will make use of an `defn` while-loop such and thus - may be more efficient for long sequences. 
- """ - @deprecated "Use Axon.Layers.dynamic_unroll/6 instead" - defn dynamic_unroll(cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias) do - time_steps = transform(Nx.shape(input_sequence), &elem(&1, 1)) - - feature_dims = transform(Nx.rank(input_sequence), &List.duplicate(0, &1 - 2)) - - initial_shape = - transform({cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias}, fn - {cell_fn, inp, carry, inp_kernel, hid_kernel, bias} -> - seq = Nx.slice_along_axis(inp, 0, 1, axis: 1) - {_, seq} = cell_fn.(seq, carry, inp_kernel, hid_kernel, bias) - put_elem(Nx.shape(seq), 1, elem(Nx.shape(inp), 1)) - end) - - init_sequence = Nx.broadcast(0.0, initial_shape) - i = Nx.tensor(0) - - {_, carry, output, _, _, _, _} = - while {i, carry, init_sequence, input_sequence, input_kernel, recurrent_kernel, bias}, - Nx.less(i, time_steps) do - sequence = Nx.slice_along_axis(input_sequence, i, 1, axis: 1) - indices = transform({feature_dims, i}, fn {feature_dims, i} -> [0, i] ++ feature_dims end) - {carry, output} = cell_fn.(sequence, carry, input_kernel, recurrent_kernel, bias) - update_sequence = Nx.put_slice(init_sequence, indices, output) - {i + 1, carry, update_sequence, input_sequence, input_kernel, recurrent_kernel, bias} - end - - {carry, output} - end - - @doc """ - Statically unrolls an RNN. - - Unrolls implement a `scan` operation which applies a - transformation on the leading axis of `input_sequence` carrying - some state. In this instance `cell_fn` is an RNN cell function - such as `lstm_cell` or `gru_cell`. - - This function inlines the unrolling of the sequence such that - the entire operation appears as a part of the compilation graph. - This makes it suitable for shorter sequences. - """ - @deprecated "Use Axon.Layers.static_unroll/6 instead" - defn static_unroll(cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias) do - transform( - {cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias}, - fn {cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias} -> - time_steps = elem(Nx.shape(input_sequence), 1) - - {carry, outputs} = - for t <- 0..(time_steps - 1), reduce: {carry, []} do - {carry, outputs} -> - input = Nx.slice_along_axis(input_sequence, t, 1, axis: 1) - {carry, output} = cell_fn.(input, carry, input_kernel, recurrent_kernel, bias) - {carry, [output | outputs]} - end - - {carry, Nx.concatenate(Enum.reverse(outputs), axis: 1)} - end - ) - end -end diff --git a/lib/axon/schedules.ex b/lib/axon/schedules.ex index 892ed3b7..62f07e07 100644 --- a/lib/axon/schedules.ex +++ b/lib/axon/schedules.ex @@ -1,28 +1,6 @@ defmodule Axon.Schedules do - @moduledoc """ - Parameter Schedules. - - Parameter schedules are often used to anneal hyperparameters - such as the learning rate during the training process. Schedules - provide a mapping from the current time step to a learning rate - or another hyperparameter. - - Choosing a good learning rate and consequently a good learning - rate schedule is typically a process of trial and error. Learning - rates should be relatively small such that the learning curve - does not oscillate violently during the training process, but - not so small that learning proceeds too slowly. Using a - schedule slowly decreases oscillations during the training - process such that, as the model converges, training also - becomes more stable. - - All of the functions in this module are implemented as - numerical functions and can be JIT or AOT compiled with - any supported `Nx` compiler. 
- """ - + @moduledoc false import Nx.Defn - import Axon.Shared @doc """ Linear decay schedule. @@ -33,6 +11,7 @@ defmodule Axon.Schedules do * `:steps` - total number of decay steps. Defaults to `1000` """ + @deprecated "Use Polaris.Schedules.linear_decay/2 instead" def linear_decay(init_value, opts \\ []) do &apply_linear_decay(&1, [{:init_value, init_value} | opts]) end @@ -70,6 +49,7 @@ defmodule Axon.Schedules do * `:staircase` - discretize outputs. Defaults to `false` """ + @deprecated "Use Polaris.Schedules.exponential_decay/2 instead" def exponential_decay(init_value, opts \\ []) do &apply_exponential_decay(&1, [{:init_value, init_value} | opts]) end @@ -86,7 +66,7 @@ defmodule Axon.Schedules do init_value = opts[:init_value] rate = opts[:decay_rate] - staircase? = to_predicate(opts[:staircase]) + staircase? = opts[:staircase] k = opts[:transition_steps] start = opts[:transition_begin] @@ -104,7 +84,7 @@ defmodule Axon.Schedules do decayed_value = rate - |> Nx.power(p) + |> Nx.pow(p) |> Nx.multiply(init_value) Nx.select( @@ -132,6 +112,7 @@ defmodule Axon.Schedules do * [SGDR: Stochastic Gradient Descent with Warm Restarts](https://openreview.net/forum?id=Skq89Scxx¬eId=Skq89Scxx) """ + @deprecated "Use Polaris.Schedules.cosine_decay/2 instead" def cosine_decay(init_value, opts \\ []) do &apply_cosine_decay(&1, [{:init_value, init_value} | opts]) end @@ -160,13 +141,14 @@ defmodule Axon.Schedules do $$\gamma(t) = \gamma_0$$ """ + @deprecated "Use Polaris.Schedules.constant/2 instead" def constant(init_value, opts \\ []) do &apply_constant(&1, [{:init_value, init_value} | opts]) end defnp apply_constant(_step, opts \\ []) do opts = keyword!(opts, init_value: 0.01) - Nx.tensor(opts[:init_value]) + opts[:init_value] end @doc ~S""" @@ -185,6 +167,7 @@ defmodule Axon.Schedules do $k$ in above formulation. Defaults to `10` """ + @deprecated "Use Polaris.Schedules.polynomial_decay/2 instead" def polynomial_decay(init_value, opts \\ []) do &apply_polynomial_decay(&1, [{:init_value, init_value} | opts]) end @@ -211,7 +194,7 @@ defmodule Axon.Schedules do |> Nx.divide(k) |> Nx.negate() |> Nx.add(1) - |> Nx.power(p) + |> Nx.pow(p) |> Nx.multiply(Nx.subtract(init_value, end_value)) |> Nx.add(end_value) end diff --git a/lib/axon/shape.ex b/lib/axon/shape.ex index 5e7314b9..d06471f6 100644 --- a/lib/axon/shape.ex +++ b/lib/axon/shape.ex @@ -1,6 +1,8 @@ defmodule Axon.Shape do @moduledoc false + import Nx.Defn + # Collection of shape calculations for calculating the # output and trainable parameter shapes for high-level # layers. @@ -319,8 +321,11 @@ defmodule Axon.Shape do the input bias shape is a vector, otherwise we'll just attempt to let it broadcast itself. """ - def conv_bias_reshape(input_shape, spatial_rank, channels) do - case input_shape do + deftransform conv_bias_reshape(input, bias, channels) do + bias_shape = Nx.shape(bias) + spatial_rank = Nx.rank(input) - 2 + + case bias_shape do {} -> {} @@ -338,11 +343,51 @@ defmodule Axon.Shape do end end + @doc """ + Calculates the permutation options to pass to convolution + based on channels configuration. + + It returns both the input/output permutation and the kernel + permutation. 
+ """ + deftransform conv_permutations(input, channels) do + rank = Nx.rank(input) + + case channels do + :first -> + perm = Enum.to_list(0..(rank - 1)) + {perm, perm} + + :last -> + spatial = Enum.to_list(1..(rank - 2)//1) + perm = [0, rank - 1 | spatial] + kernel_perm = [rank - 1, rank - 2] ++ Enum.to_list(0..(rank - 3)//1) + {perm, kernel_perm} + + invalid -> + raise ArgumentError, "invalid channel configuration, #{inspect(invalid)}" + end + end + + @doc """ + Calculates strides for transposed convolution. + """ + deftransform conv_transpose_strides(input, strides) do + rank = Nx.rank(input) - 2 + + case strides do + [_ | _] = strides -> strides + strides -> List.duplicate(strides, rank) + end + end + @doc """ Calculates the padding needed for a transposed convolution. """ - def conv_transpose_padding(kernel_shape, kernel_dilation, strides, padding, channels) - when padding in [:valid, :same] do + deftransform conv_transpose_padding(kernel, kernel_dilation, strides, padding, channels) + when padding in [:valid, :same] do + kernel_shape = Nx.shape(kernel) + kernel_spatial_dims = case channels do :first -> @@ -395,7 +440,7 @@ defmodule Axon.Shape do end end - def conv_transpose_padding(_, _, _, padding, _), do: padding + deftransform conv_transpose_padding(_, _, _, padding, _), do: padding @doc """ Calculates the shape of a depthwise convolution kernel given the @@ -632,7 +677,9 @@ defmodule Axon.Shape do across batch or channel dimensions, so we just specify a size of `1` for each of those. """ - def pool_window_size(window, spatial_rank, channels) do + deftransform pool_window_size(input, window, channels) do + spatial_rank = Nx.rank(input) - 2 + spatial_dims = case window do x when is_integer(x) -> @@ -655,20 +702,70 @@ defmodule Axon.Shape do end @doc """ - Computes the window size from the given parent shape. + Calculates the window strides of a pooling operation. """ - def adaptive_pool_window_size(parent_shape, nil, channels) do + deftransform pool_window_strides(input, strides, window_dimensions, channels) do + rank = Nx.rank(input) + + case {strides, channels} do + {nil, _} -> Tuple.to_list(window_dimensions) + {[_ | _] = strides, :first} -> [1, 1 | strides] + {[_ | _] = strides, :last} -> [1 | strides] ++ [1] + {strides, :first} -> [1, 1 | List.duplicate(strides, rank - 2)] + {strides, :last} -> [1 | List.duplicate(strides, rank - 2)] ++ [1] + end + end + + @doc """ + Calculates window dilations of a pooling operation. + """ + deftransform pool_window_dilations(input, window_dilations, channels) do + rank = Nx.rank(input) + + case {window_dilations, channels} do + {nil, _} -> List.duplicate(1, rank) + {[_ | _] = dilations, :first} -> [1, 1 | dilations] + {[_ | _] = dilations, :last} -> [1 | dilations] ++ [1] + {dilations, :first} -> [1, 1 | List.duplicate(dilations, rank - 2)] + {dilations, :last} -> [1 | List.duplicate(dilations, rank - 2)] ++ [1] + end + end + + @doc """ + Calculates padding of a pooling operation based on input padding + and channels configuration. + """ + deftransform pool_window_padding(padding, channels) do + case {padding, channels} do + {:same, _} -> :same + {:valid, _} -> :valid + {padding, :first} -> [{0, 0}, {0, 0} | padding] + {padding, :last} -> [{0, 0} | padding] ++ [{0, 0}] + end + end + + @doc """ + Computes the adaptive pooling output size from the given parent + shape, output shape and channels configuration. 
+ """ + deftransform adaptive_pool_output_size(input, nil, channels) do + parent_shape = Nx.shape(input) + case channels do :first -> - parent_shape |> Tuple.delete_at(0) |> Tuple.delete_at(0) + parent_shape + |> Tuple.delete_at(0) + |> Tuple.delete_at(0) :last -> - parent_shape |> Tuple.delete_at(tuple_size(parent_shape) - 1) |> Tuple.delete_at(0) + parent_shape + |> Tuple.delete_at(tuple_size(parent_shape) - 1) + |> Tuple.delete_at(0) end end - def adaptive_pool_window_size(parent_shape, output_size, _channels) do - inner_rank = Nx.rank(parent_shape) - 2 + deftransform adaptive_pool_output_size(input, output_size, _channels) do + inner_rank = Nx.rank(input) - 2 tuple_or_duplicate(:output_size, output_size, inner_rank) end @@ -684,7 +781,10 @@ defmodule Axon.Shape do This preserves the size of the channel/batch dimension. """ - def adaptive_pool_window_strides(input_shape, output_spatial, spatial_rank, channels) do + deftransform adaptive_pool_window_strides(input, output_spatial, channels) do + input_shape = Nx.shape(input) + spatial_rank = Nx.rank(input) - 2 + idx = if channels == :first do 1 @@ -733,13 +833,15 @@ defmodule Axon.Shape do This preserves the size of the channel/batch dimension. """ - def adaptive_pool_window_size( - input_shape, - stride, - output_spatial, - spatial_rank, - channels - ) do + deftransform adaptive_pool_window_size( + input, + stride, + output_spatial, + channels + ) do + input_shape = Nx.shape(input) + spatial_rank = Nx.rank(input) - 2 + strides = case channels do :first -> @@ -813,16 +915,22 @@ defmodule Axon.Shape do @doc """ Calculates the reduction axes for batch normalization. """ - def batch_norm_axes(axes, channel_index) do - axes - |> Enum.filter(&(&1 != channel_index)) + deftransform batch_norm_axes(input, channel_index) do + axis = Nx.Shape.normalize_axis(Nx.shape(input), channel_index, Nx.names(input)) + + input + |> Nx.axes() + |> Enum.filter(&(&1 != axis)) end @doc """ Calculates the reduction axes for instance normalization. """ - def instance_norm_axes(axes, channel_index) do - reduction_axes = axes -- [0, channel_index] + deftransform instance_norm_axes(input, channel_index) do + axis = Nx.Shape.normalize_axis(Nx.shape(input), channel_index, Nx.names(input)) + axes = Nx.axes(input) + + reduction_axes = axes -- [0, axis] if reduction_axes == [] do raise ArgumentError, "rank of input shape must be at least 3" @@ -834,14 +942,17 @@ defmodule Axon.Shape do @doc """ Calculates the reduction axes for group normalization. """ - def group_norm_axes(rank, channel_index) do - Enum.to_list(1..(rank - 1)) -- [channel_index] + deftransform group_norm_axes(x, channel_index) do + Enum.to_list(1..(Nx.rank(x) - 1)) -- [channel_index] end @doc """ Calculates the reshape for group normalization. """ - def group_norm_shape(shape, num_groups, channel_index) do + deftransform group_norm_shape(input, num_groups, channel_index) do + shape = Nx.shape(input) + channel_index = Nx.Shape.normalize_axis(Nx.shape(input), channel_index, Nx.names(input)) + channels = elem(shape, channel_index) group_size = div(channels, num_groups) @@ -850,42 +961,25 @@ defmodule Axon.Shape do |> Tuple.insert_at(channel_index + 1, group_size) end - @doc """ - Calculates the shape after a flatten layer, which - flattens the non-minibatch dimensions into a single - dimension. 
- - ## Examples - - iex> Axon.Shape.flatten({nil, 1, 28, 28}) - {nil, 784} - - iex> Axon.Shape.flatten({32, 128}) - {32, 128} - - iex> Axon.Shape.flatten({nil, 10, 10}) - {nil, 100} - """ - def flatten(shape) do - out_units = Nx.size(Tuple.delete_at(shape, 0)) - - {elem(shape, 0), out_units} - end - @doc """ Computes split sizes for the given splits. """ - def split(shape, n, axis) do + deftransform split(input, index, splits, axis) do + shape = Nx.shape(input) + nil_names = List.duplicate(nil, Nx.rank(shape)) axis = Nx.Shape.normalize_axis(shape, axis, nil_names) - unless rem(elem(shape, axis), n) == 0 do + unless rem(elem(shape, axis), splits) == 0 do raise ArgumentError, - "unable to create #{n} even splits along axis #{axis}" <> + "unable to create #{splits} even splits along axis #{axis}" <> " of size #{elem(shape, axis)}" end - div(elem(shape, axis), n) + slice_size = div(elem(shape, axis), splits) + + offset = index * slice_size + {offset, slice_size} end @doc """ @@ -898,13 +992,15 @@ defmodule Axon.Shape do ## Examples - iex> Axon.Shape.spatial_dropout_noise_shape({nil, 3, 28, 28}, :first) - {nil, 1, 28, 28} + iex> Axon.Shape.spatial_dropout_noise_shape({1, 3, 28, 28}, :first) + {1, 1, 28, 28} - iex> Axon.Shape.spatial_dropout_noise_shape({nil, 28, 28, 3}, :last) - {nil, 28, 28, 1} + iex> Axon.Shape.spatial_dropout_noise_shape({1, 28, 28, 3}, :last) + {1, 28, 28, 1} """ - def spatial_dropout_noise_shape(input_shape, channels) do + deftransform spatial_dropout_noise_shape(input, channels) do + input_shape = Nx.shape(input) + if channels == :first do :erlang.setelement(2, input_shape, 1) else @@ -969,7 +1065,23 @@ defmodule Axon.Shape do " got #{inspect(shape)}" end - {elem(shape, 0), 1, units} + {elem(shape, 0), units} + end + + @doc """ + Returns the reduction axes for a global pooling operation + based on the input rank and channels configuration. + """ + deftransform global_pool_axes(input, channels) do + rank = Nx.rank(input) + + case channels do + :last -> + Enum.to_list(1..(rank - 2)) + + :first -> + Enum.to_list(2..(rank - 1)) + end end defp tuple_or_duplicate(key, tuple_or_integer, rank) do diff --git a/lib/axon/shared.ex b/lib/axon/shared.ex index 2c5e7421..6279488a 100644 --- a/lib/axon/shared.ex +++ b/lib/axon/shared.ex @@ -11,145 +11,114 @@ defmodule Axon.Shared do @doc """ Asserts `lhs` has same shape as `rhs`. """ - defn assert_shape!(caller, lhs_name, lhs, rhs_name, rhs) do - transform( - {lhs, rhs}, - fn {lhs, rhs} -> - lhs = Nx.shape(lhs) - rhs = Nx.shape(rhs) - - unless Elixir.Kernel.==(lhs, rhs) do - raise ArgumentError, - "#{caller}: expected input shapes #{lhs_name} and #{rhs_name}" <> - " to be equal, got #{inspect(lhs)} != #{inspect(rhs)}" - end - end - ) + deftransform assert_shape!(caller, lhs_name, lhs, rhs_name, rhs) do + lhs = Nx.shape(lhs) + rhs = Nx.shape(rhs) + + unless lhs == rhs do + raise ArgumentError, + "#{caller}: expected input shapes #{lhs_name} and #{rhs_name}" <> + " to be equal, got #{inspect(lhs)} != #{inspect(rhs)}" + end end @doc """ Asserts all shapes are equal. """ - defn assert_shape!(caller, shape_names, shapes) do - transform(shapes, fn [shape | shapes] -> - equal? = - Enum.all?(shapes, fn cur_shape -> - Elixir.Kernel.==(Nx.shape(cur_shape), Nx.shape(shape)) - end) - - unless equal? do - raise ArgumentError, - "#{caller}: expected all input shapes #{inspect(shape_names)}" <> - " to be equal, got #{inspect(shapes)}" - end - end) + deftransform assert_shape!(caller, shape_names, [shape | shapes]) do + equal? 
= + Enum.all?(shapes, fn cur_shape -> + Nx.shape(cur_shape) == Nx.shape(shape) + end) + + unless equal? do + raise ArgumentError, + "#{caller}: expected all input shapes #{inspect(shape_names)}" <> + " to be equal, got #{inspect(shapes)}" + end end @doc """ Asserts `inp` has explicit rank `rank`. """ - defn assert_rank!(caller, inp_name, inp, rank) do - transform( - {inp, rank}, - fn {x, y} -> - x = Nx.rank(x) - - unless Elixir.Kernel.==(x, y) do - raise ArgumentError, - "#{caller}: expected #{inp_name} to have rank equal to #{y}," <> - " got #{x} != #{y}" - end - end - ) + deftransform assert_rank!(caller, inp_name, inp, rank) do + x = Nx.rank(inp) + + unless x == rank do + raise ArgumentError, + "#{caller}: expected #{inp_name} to have rank equal to #{rank}," <> + " got #{x} != #{rank}" + end end @doc """ Asserts `lhs` has same rank as `rhs`. """ - defn assert_equal_rank!(caller, lhs_name, lhs, rhs_name, rhs) do - transform( - {lhs, rhs}, - fn {x, y} -> - x = if is_integer(x), do: x, else: Nx.rank(x) - y = if is_integer(y), do: y, else: Nx.rank(y) - - unless Elixir.Kernel.>=(x, y) do - raise ArgumentError, - "#{caller}: expected #{lhs_name} and #{rhs_name} ranks to be equal" <> - " got #{x} != #{y}" - end - end - ) + deftransform assert_equal_rank!(caller, lhs_name, lhs, rhs_name, rhs) do + x = if is_integer(lhs), do: lhs, else: Nx.rank(lhs) + y = if is_integer(rhs), do: rhs, else: Nx.rank(rhs) + + unless x >= y do + raise ArgumentError, + "#{caller}: expected #{lhs_name} and #{rhs_name} ranks to be equal" <> + " got #{x} != #{y}" + end end @doc """ Asserts all ranks are equal. """ - defn assert_equal_rank!(caller, rank_names, ranks) do - transform(ranks, fn [rank | ranks] -> - equal? = - Enum.all?(ranks, fn cur_rank -> - Elixir.Kernel.==(Nx.rank(cur_rank), Nx.rank(rank)) - end) - - unless equal? do - raise ArgumentError, - "#{caller}: expected all input ranks #{inspect(rank_names)}" <> - " to be equal, got #{inspect(ranks)}" - end - end) + deftransform assert_equal_rank!(caller, rank_names, [rank | ranks]) do + equal? = + Enum.all?(ranks, fn cur_rank -> + Nx.rank(cur_rank) == Nx.rank(rank) + end) + + unless equal? do + raise ArgumentError, + "#{caller}: expected all input ranks #{inspect(rank_names)}" <> + " to be equal, got #{inspect(ranks)}" + end end @doc """ Asserts `lhs` has at least rank `rhs`. """ - defn assert_min_rank!(caller, name, lhs, rhs) do - transform( - {lhs, rhs}, - fn {x, y} -> - x = if is_integer(x), do: x, else: Nx.rank(x) - y = if is_integer(y), do: y, else: Nx.rank(y) - - unless Elixir.Kernel.>=(x, y) do - raise ArgumentError, - "#{caller}: expected #{name} shape to have at least rank #{y}, got rank #{x}" - end - end - ) - end + deftransform assert_min_rank!(caller, name, lhs, rhs) do + x = if is_integer(lhs), do: lhs, else: Nx.rank(lhs) + y = if is_integer(rhs), do: rhs, else: Nx.rank(rhs) - @doc """ - Transforms the given Elixir value into a scalar predicate. - """ - defn to_predicate(term) do - transform(term, fn term -> if term, do: 1, else: 0 end) + unless x >= y do + raise ArgumentError, + "#{caller}: expected #{name} shape to have at least rank #{y}, got rank #{x}" + end end @doc """ Creates a zeros-like structure which matches the structure of the input. 
""" - defn zeros_like(params) do - transform( - params, - &deep_new(&1, fn x -> - fun = Axon.Initializers.zeros() - fun.(Nx.shape(x), Nx.type(x)) - end) - ) + deftransform zeros_like(params, opts \\ []) do + opts = Keyword.validate!(opts, [:type]) + fun = Axon.Initializers.zeros() + + deep_new(params, fn x -> + type = opts[:type] || Nx.type(x) + fun.(Nx.shape(x), type) + end) end @doc """ Creates a fulls-like tuple of inputs. """ - defn fulls_like(params, value) do - transform( - params, - &deep_new(&1, fn x -> - fun = Axon.Initializers.full(value) - fun.(Nx.shape(x), Nx.type(x)) - end) - ) + deftransform fulls_like(params, value, opts \\ []) do + opts = Keyword.validate!(opts, [:type]) + fun = Axon.Initializers.full(value) + + deep_new(params, fn x -> + type = opts[:type] || Nx.type(x) + fun.(Nx.shape(x), type) + end) end @doc """ @@ -259,18 +228,17 @@ defmodule Axon.Shared do end end - ## Numerical Helpers + ## List transforms in defn - # TODO: These should be contained somewhere else, like another library + deftransform list_duplicate(value, size) do + List.duplicate(value, size) + end - defn logsumexp(x, opts \\ []) do - opts = keyword!(opts, axes: [], keep_axes: false) + deftransform list_wrap(value), do: List.wrap(value) - x - |> Nx.exp() - |> Nx.sum(opts) - |> Nx.log() - end + ## Numerical Helpers + + # TODO: These should be contained somewhere else, like another library defn xlogy(x, y) do x_ok = Nx.not_equal(x, 0.0) @@ -282,25 +250,20 @@ defmodule Axon.Shared do defn reciprocal(x), do: Nx.divide(1, x) defn normalize(input, mean, variance, gamma, bias, opts \\ []) do - opts = keyword!(opts, epsilon: 1.0e-6) + [epsilon: epsilon] = keyword!(opts, epsilon: 1.0e-6) + # The select is so that we improve numerical stability by clipping + # both insignificant values of variance and NaNs to epsilon. scale = - variance - |> Nx.add(opts[:epsilon]) - |> Nx.rsqrt() - |> Nx.multiply(gamma) - - input - |> Nx.subtract(mean) - |> Nx.multiply(scale) - |> Nx.add(bias) + gamma * Nx.select(variance >= epsilon, Nx.rsqrt(variance + epsilon), Nx.rsqrt(epsilon)) + + scale * (input - mean) + bias end defn mean_and_variance(input, opts \\ []) do opts = keyword!(opts, [:axes]) mean = Nx.mean(input, axes: opts[:axes], keep_axes: true) - mean_of_squares = Nx.mean(Nx.multiply(input, input), axes: opts[:axes], keep_axes: true) - square_of_mean = Nx.multiply(mean, mean) - {mean, mean_of_squares - square_of_mean} + var = Nx.variance(input, axes: opts[:axes], keep_axes: true) + {mean, var} end end diff --git a/lib/axon/updates.ex b/lib/axon/updates.ex index bd004404..06e0b144 100644 --- a/lib/axon/updates.ex +++ b/lib/axon/updates.ex @@ -1,89 +1,6 @@ defmodule Axon.Updates do - @moduledoc ~S""" - Parameter update methods. - - Update methods transform the input tensor in some way, - usually by scaling or shifting the input with respect - to some input state. Update methods are composed - to create more advanced optimization methods such as AdaGrad - or Adam. Each update returns a tuple: - - {init_fn, update_fn} - - Which represent a state initialization and state update - function respectively. While each method in the Updates - API is a regular Elixir function, the two methods they - return are implemented as `defn`, so they can be accelerated - using any Nx backend or compiler. - - Update methods are just combinators that can be arbitrarily - composed to create complex optimizers. 
For example, the Adam - optimizer in Axon.Optimizers is implemented as: - - def adam(learning_rate, opts \\ []) do - Updates.scale_by_adam(opts) - |> Updates.scale(-learning_rate) - end - - Updates are maps of updates, often associated with parameters of - the same names. Using `Axon.Updates.apply_updates/3` will merge updates - and parameters by adding associated parameters and updates, and - ensuring any given model state is preserved. - - ## Custom combinators - - You can create your own combinators using the `stateless/2` and - `stateful/3` primitives. Every update method in this module is - implemented in terms of one of these two primitives. - - `stateless/2` represents a stateless update: - - def scale(combinator \\ Axon.Updates.identity(), step_size) do - stateless(combinator, &apply_scale(&1, &2, step_size)) - end - - defnp apply_scale(x, _params, step) do - transform( - {x, step}, - fn {updates, step} -> - deep_new(updates, fn x -> Nx.multiply(x, step) end) - end - ) - end - - Notice how the function given to `stateless/2` is defined within `defn`. - This is what allows the anonymous functions returned by `Axon.Updates` - to be used inside `defn`. - - `stateful/3` represents a stateful update and follows the same pattern: - - def my_stateful_update(updates) do - Axon.Updates.stateful(updates, &init_my_update/1, &apply_my_update/2) - end - - defnp init_my_update(params) do - state = zeros_like(params) - %{state: state} - end + @moduledoc false - defnp apply_my_update(updates, state) do - new_state = deep_new(state, fn v -> Nx.add(v, 0.01) end) - updates = transform({updates, new_state}, fn {updates, state} -> - deep_merge(updates, state, fn g, z -> Nx.multiply(g, z) end) - end) - {updates, %{state: new_state}} - end - - State associated with individual parameters should have keys that match the - keys of the parameter. For example, if you have parameters `%{kernel: kernel}` - with associated states `mu` and `nu` representing the first and second moments, - your state should look something like: - - %{ - mu: %{kernel: kernel_mu} - nu: %{kernel: kernel_nu} - } - """ import Nx.Defn import Axon.Shared @@ -92,6 +9,7 @@ defmodule Axon.Updates do $$f(x_i) = \alpha x_i$$ """ + @deprecated "Use Polaris.Updates.scale/2 instead" def scale(combinator \\ identity(), step_size) do stateless(combinator, &apply_scale(&1, &2, step_size)) end @@ -106,6 +24,7 @@ defmodule Axon.Updates do $$f(x_i) = \alpha x_i$$ """ + @deprecated "Use Polaris.Updates.scale_by_state/1 instead" def scale_by_state(combinator_or_step) def scale_by_state(step) when is_number(step) do @@ -144,6 +63,7 @@ defmodule Axon.Updates do * [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980) """ + @deprecated "Use Polaris.Updates.scale_by_adam/1 instead" def scale_by_adam(combinator_or_opts \\ []) def scale_by_adam(opts) when is_list(opts) do @@ -165,8 +85,8 @@ defmodule Axon.Updates do end defnp init_scale_by_adam(params) do - mus = zeros_like(params) - nus = zeros_like(params) + mus = zeros_like(params, type: :f32) + nus = zeros_like(params, type: :f32) count = Nx.tensor(0) %{mu: mus, nu: nus, count: count} end @@ -196,6 +116,7 @@ defmodule Axon.Updates do * `:eps` - numerical stability term. 
Defaults to `1.0e-7` """ + @deprecated "Use Polaris.Updates.scale_by_rss/1 instead" def scale_by_rss(combinator_or_opts \\ []) def scale_by_rss(opts) when is_list(opts) do @@ -219,7 +140,7 @@ defmodule Axon.Updates do end defnp init_scale_by_rss(params, value) do - sum_of_squares = fulls_like(params, value) + sum_of_squares = fulls_like(params, value, type: :f32) %{sum_of_squares: sum_of_squares} end @@ -227,7 +148,7 @@ defmodule Axon.Updates do opts = keyword!(opts, eps: 1.0e-7) eps = opts[:eps] - sum_of_squares = deep_merge(x, sum_of_squares, fn g, z -> Nx.power(g, 2) + z end) + sum_of_squares = deep_merge(x, sum_of_squares, fn g, z -> Nx.pow(g, 2) + z end) inv_sqrt_squares = deep_new(sum_of_squares, fn z -> Nx.rsqrt(z + eps) end) @@ -255,6 +176,7 @@ defmodule Axon.Updates do * [Overview of mini-batch gradient descent](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) """ + @deprecated "Use Polaris.Updates.scale_by_rms/1 instead" def scale_by_rms(combinator_or_opts \\ []) def scale_by_rms(opts) when is_list(opts) do @@ -278,7 +200,7 @@ defmodule Axon.Updates do end defnp init_scale_by_rms(params, scale) do - nu = fulls_like(params, scale) + nu = fulls_like(params, scale, type: :f32) %{nu: nu} end @@ -312,6 +234,7 @@ defmodule Axon.Updates do * [AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients](https://arxiv.org/abs/2010.07468) """ + @deprecated "Use Polaris.Updates.scale_by_belief/1 instead" def scale_by_belief(combinator_or_opts \\ []) def scale_by_belief(opts) when is_list(opts) do @@ -333,8 +256,8 @@ defmodule Axon.Updates do end defnp init_scale_by_belief(params) do - mus = zeros_like(params) - nus = zeros_like(params) + mus = zeros_like(params, type: :f32) + nus = zeros_like(params, type: :f32) count = Nx.tensor(0) %{mu: mus, nu: nus, count: count} end @@ -371,6 +294,7 @@ defmodule Axon.Updates do * [Overview of mini-batch gradient descent](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) """ + @deprecated "Use Polaris.Updates.scale_by_stddev/1 instead" def scale_by_stddev(combinator_or_opts \\ []) def scale_by_stddev(opts) when is_list(opts) do @@ -394,8 +318,8 @@ defmodule Axon.Updates do end defnp init_scale_by_stddev(params, value) do - mu = zeros_like(params) - nu = fulls_like(params, value) + mu = zeros_like(params, type: :f32) + nu = fulls_like(params, value, type: :f32) %{mu: mu, nu: nu} end @@ -409,7 +333,7 @@ defmodule Axon.Updates do mu_nu = deep_merge(mu, nu, fn m, n -> - Nx.rsqrt(-Nx.power(m, 2) + n + eps) + Nx.rsqrt(-Nx.pow(m, 2) + n + eps) end) x = deep_merge(x, mu_nu, fn g, mn -> g * mn end) @@ -425,6 +349,7 @@ defmodule Axon.Updates do counter. You might need to update the schedule to operate on per-batch schedule rather than per-epoch. 
""" + @deprecated "Use Polaris.Updates.scale_by_schedule/2 instead" def scale_by_schedule(combinator \\ identity(), schedule_fn) when is_function(schedule_fn, 1) do stateful( combinator, @@ -465,6 +390,7 @@ defmodule Axon.Updates do * [On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265) """ + @deprecated "Use Polaris.Updates.scale_by_radam/1 instead" def scale_by_radam(combinator_or_opts \\ []) def scale_by_radam(opts) when is_list(opts) do @@ -486,8 +412,8 @@ defmodule Axon.Updates do end defnp init_scale_by_radam(params) do - mu = zeros_like(params) - nu = zeros_like(params) + mu = zeros_like(params, type: :f32) + nu = zeros_like(params, type: :f32) count = Nx.tensor(0) %{mu: mu, nu: nu, count: count} end @@ -506,7 +432,7 @@ defmodule Axon.Updates do nu = update_moment(x, nu, b2, 2) count_inc = count + 1 - b2t = Nx.power(b2, count_inc) + b2t = Nx.pow(b2, count_inc) ro = ro_inf - 2 * count_inc * b2t / (1 - b2t) mu_hat = bias_correction(mu, b1, count + 1) @@ -525,10 +451,8 @@ defmodule Axon.Updates do defnp radam_update(ro, ro_inf, mu, nu, eps_root, eps) do r = Nx.sqrt((ro - 4) * (ro - 2) * ro_inf / ((ro_inf - 4) * (ro_inf - 2) * ro)) - transform({r, mu, nu, eps_root, eps}, fn {r, mu, nu, eps_root, eps} -> - deep_merge(mu, nu, fn m, v -> - r * m / (Nx.sqrt(v + eps_root) + eps) - end) + deep_merge(mu, nu, fn m, v -> + r * m / (Nx.sqrt(v + eps_root) + eps) end) end @@ -543,6 +467,7 @@ defmodule Axon.Updates do to `false` """ + @deprecated "Use Polaris.Updates.trace/1 instead" def trace(combinator_or_opts \\ []) def trace(opts) when is_list(opts) do @@ -564,7 +489,7 @@ defmodule Axon.Updates do end defnp init_trace(params) do - trace = zeros_like(params) + trace = zeros_like(params, type: :f32) %{trace: trace} end @@ -592,6 +517,7 @@ defmodule Axon.Updates do * `:delta` - maximum absolute value of the input. Defaults to `2.0` """ + @deprecated "Use Polaris.Updates.clip/1 instead" def clip(combinator_or_opts \\ []) def clip(opts) when is_list(opts) do @@ -623,6 +549,7 @@ defmodule Axon.Updates do * `:max_norm` - maximum norm value of input. Defaults to `1.0` """ + @deprecated "Use Polaris.Updates.clip_by_global_norm/1 instead" def clip_by_global_norm(combinator_or_opts \\ []) def clip_by_global_norm(opts) when is_list(opts) do @@ -646,7 +573,7 @@ defmodule Axon.Updates do sum_gs = deep_reduce(x, Nx.tensor(0.0), fn leaf, acc -> leaf - |> Nx.power(2) + |> Nx.pow(2) |> Nx.sum() |> Nx.add(acc) end) @@ -661,6 +588,7 @@ defmodule Axon.Updates do @doc """ Centralizes input by shifting updates by their mean. """ + @deprecated "Use Polaris.Updates.centralize/1 instead" def centralize(combinator_or_opts \\ []) def centralize(opts) when is_list(opts) do @@ -678,16 +606,16 @@ defmodule Axon.Updates do end defnp apply_centralize(x, _params, _opts \\ []) do - transform(x, fn x -> - deep_new(x, fn z -> - if Elixir.Kernel.>(Nx.rank(z), 1) do - axes = tl(Nx.axes(z)) - z - Nx.mean(z, axes: axes, keep_axes: true) - else - z - end - end) - end) + deep_new(x, ¢ralize_for_rank/1) + end + + deftransformp centralize_for_rank(input) do + if Nx.rank(input) > 1 do + input + |> Nx.subtract(Nx.mean(input, axes: tl(Nx.axes(input)), keep_axes: true)) + else + input + end end @doc """ @@ -699,6 +627,7 @@ defmodule Axon.Updates do * `:decay` - Rate of decay. Defaults to `0.0`. 
""" + @deprecated "Use Polaris.Updates.add_decayed_weights/1 instead" def add_decayed_weights(combinator_or_opts \\ []) def add_decayed_weights(opts) when is_list(opts) do @@ -737,6 +666,7 @@ defmodule Axon.Updates do * `:eps` - Numerical stability term. Defaults to `0.0`. """ + @deprecated "Use Polaris.Updates.scale_by_trust_ratio/1 instead" def scale_by_trust_ratio(combinator_or_opts \\ []) def scale_by_trust_ratio(opts) when is_list(opts) do @@ -781,12 +711,16 @@ defmodule Axon.Updates do ## Options + * `:seed` - Random seed to use. Defaults to the + current system time. + * `:eta` - Controls amount of noise to add. Defaults to `0.01`. * `:gamma` - Controls amount of noise to add. Defaults to `0.55`. """ + @deprecated "Use Polaris.Updates.add_noise/1 instead" def add_noise(combinator_or_opts \\ []) def add_noise(opts) when is_list(opts) do @@ -800,22 +734,26 @@ defmodule Axon.Updates do def add_noise({init_fn, apply_fn} = combinator, opts) when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do - stateful(combinator, &init_add_noise/1, &apply_add_noise(&1, &2, &3, opts)) + {seed, opts} = Keyword.pop_lazy(opts, :seed, fn -> :erlang.system_time() end) + stateful(combinator, &init_add_noise(&1, seed: seed), &apply_add_noise(&1, &2, &3, opts)) end - defnp init_add_noise(_params) do - %{count: Nx.tensor(0)} + defnp init_add_noise(_params, opts \\ []) do + %{count: Nx.tensor(0), key: Nx.Random.key(opts[:seed])} end - defnp apply_add_noise(x, %{count: count}, _params, opts \\ []) do + defnp apply_add_noise(x, %{count: count, key: key}, _params, opts \\ []) do opts = keyword!(opts, eta: 0.01, gamma: 0.55) - var = opts[:eta] / Nx.power(count + 1, opts[:gamma]) + var = opts[:eta] / Nx.pow(count + 1, opts[:gamma]) - noise = deep_new(x, fn z -> Nx.random_normal(z) end) + {noise, key} = + deep_map_reduce(x, key, fn z, key -> + Nx.Random.normal(key, shape: Nx.shape(z), type: Nx.type(z)) + end) updates = deep_merge(x, noise, fn g, n -> g + var * n end) - {updates, %{count: count + 1}} + {updates, %{count: count + 1, key: key}} end @doc """ @@ -837,6 +775,7 @@ defmodule Axon.Updates do * [Adaptive Methods for Nonconvex Optimization](https://proceedings.neurips.cc/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf) """ + @deprecated "Use Polaris.Updates.scale_by_yogi/1 instead" def scale_by_yogi(combinator_or_opts \\ []) def scale_by_yogi(opts) when is_list(opts) do @@ -860,7 +799,7 @@ defmodule Axon.Updates do end defnp init_scale_by_yogi(params, value) do - value = fulls_like(params, value) + value = fulls_like(params, value, type: :f32) mu = value nu = value count = Nx.tensor(0) @@ -878,7 +817,7 @@ defmodule Axon.Updates do nu = deep_merge(x, nu, fn g, v -> - v - (1 - b2) * Nx.sign(v - Nx.power(g, 2)) * Nx.power(g, 2) + v - (1 - b2) * Nx.sign(v - Nx.pow(g, 2)) * Nx.pow(g, 2) end) mu_hat = bias_correction(mu, b1, count + 1) @@ -895,6 +834,7 @@ defmodule Axon.Updates do Stateless updates do not depend on an update state and thus only require an implementation of an update function. """ + @deprecated "Use Polaris.Updates.stateless/2 instead" def stateless({parent_init_fn, parent_apply_fn} \\ identity(), apply_fn) do apply_fn = fn updates, state, params -> {updates, state} = parent_apply_fn.(updates, state, params) @@ -909,6 +849,7 @@ defmodule Axon.Updates do This is often as the initial update in many functions in this module. 
""" + @deprecated "Use Polaris.Updates.identity/1 instead" def identity() do {fn _params -> {} end, fn updates, state, _params -> {updates, state} end} end @@ -931,6 +872,7 @@ defmodule Axon.Updates do Axon.Updates.centralize() |> Axon.Updates.scale_by_rms() """ + @deprecated "Use Polaris.Updates.compose/2 instead" def compose({init_fn1, apply_fn1}, {init_fn2, apply_fn2}) do init_fn = fn params -> state = init_fn1.(params) @@ -956,6 +898,7 @@ defmodule Axon.Updates do implement some initialization function as well as an update function. """ + @deprecated "Use Polaris.Updates.stateful/3 instead" def stateful({parent_init_fn, parent_apply_fn} \\ identity(), init_fn, apply_fn) do init_fn = fn params -> state = parent_init_fn.(params) @@ -1007,11 +950,11 @@ defmodule Axon.Updates do ## Helpers defnp update_moment(x, moment, decay, order) do - deep_merge(x, moment, fn g, z -> (1 - decay) * Nx.power(g, order) + decay * z end) + deep_merge(x, moment, fn g, z -> (1 - decay) * Nx.pow(g, order) + decay * z end) end defnp bias_correction(moment, decay, count) do - deep_new(moment, fn z -> z / (1 - Nx.power(decay, count)) end) + deep_new(moment, fn z -> z / (1 - Nx.pow(decay, count)) end) end defnp safe_norm(g, min_norm) do diff --git a/mix.exs b/mix.exs index ef73d2b6..4154f74d 100644 --- a/mix.exs +++ b/mix.exs @@ -2,7 +2,7 @@ defmodule Axon.MixProject do use Mix.Project @source_url "https://github.com/elixir-nx/axon" - @version "0.3.0" + @version "0.6.0" def project do [ @@ -35,13 +35,14 @@ defmodule Axon.MixProject do # Run "mix help deps" to learn about dependencies. defp deps do [ - {:exla, "~> 0.4.0", [only: :test] ++ exla_opts()}, - {:torchx, "~> 0.4.0", [only: :test] ++ torchx_opts()}, - {:nx, "~> 0.4.0", nx_opts()}, + {:exla, "~> 0.6.0", [only: :test] ++ exla_opts()}, + {:torchx, "~> 0.6.0", [only: :test] ++ torchx_opts()}, + {:nx, "~> 0.6.0", nx_opts()}, {:ex_doc, "~> 0.23", only: :docs}, {:table_rex, "~> 3.1.1", optional: true}, {:kino, "~> 0.7", optional: true}, - {:kino_vega_lite, "~> 0.1.7", optional: true} + {:kino_vega_lite, "~> 0.1.7", optional: true}, + {:polaris, "~> 0.1"} ] end @@ -115,7 +116,7 @@ defmodule Axon.MixProject do groups_for_extras: [ "Guides: Model Creation": Path.wildcard("guides/model_creation/*.livemd"), "Guides: Model Execution": Path.wildcard("guides/model_execution/*.livemd"), - "Guides: Training and Evalutaion": + "Guides: Training and Evaluation": Path.wildcard("guides/training_and_evaluation/*.livemd"), "Guides: Serialization": Path.wildcard("guides/serialization/*.livemd"), "Examples: Basics": Path.wildcard("notebooks/basics/*.livemd"), @@ -155,7 +156,7 @@ defmodule Axon.MixProject do Axon.MixedPrecision, Axon.None, Axon.StatefulOutput, - Axon.Initalizers + Axon.Initializers ], Summary: [ Axon.Display @@ -169,11 +170,6 @@ defmodule Axon.MixProject do Axon.Recurrent, Axon.LossScale ], - Optimization: [ - Axon.Optimizers, - Axon.Updates, - Axon.Schedules - ], Loop: [ Axon.Loop, Axon.Loop.State diff --git a/mix.lock b/mix.lock index c2a7ff8d..a4010e04 100644 --- a/mix.lock +++ b/mix.lock @@ -1,21 +1,23 @@ %{ - "complex": {:hex, :complex, "0.4.2", "923e5db0be13dbb3ea00cf8459d9f75f3afdd9ff5a82742ded21064330d28273", [:mix], [], "hexpm", "069a085ef820ce675a2619fd125b963ff4514af2102c7f7d7965128e5ec0a429"}, - "dll_loader_helper": {:hex, :dll_loader_helper, "0.1.8", "1621409a3cb06c750fe845bf954785cffa5fe8f2fca41006008b891877603bf7", [:make, :mix, :rebar3], [], "hexpm", "cd373dc6a028f3e37eca26b073e3a75249513db2f9b0e42520423886801fa7d7"}, - "earmark_parser": 
{:hex, :earmark_parser, "1.4.29", "149d50dcb3a93d9f3d6f3ecf18c918fb5a2d3c001b5d3305c926cddfbd33355b", [:mix], [], "hexpm", "4902af1b3eb139016aed210888748db8070b8125c2342ce3dcae4f38dcc63503"}, - "elixir_make": {:hex, :elixir_make, "0.7.1", "314f2a5450254db0446ba94cc1ba12a25b83b457f24aa9cc21c128cead5d03aa", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "0f1ad4787b4d7489563351cbf85c9221a852f5441364a2cb3ffd36f2fda7f7fb"}, - "ex_doc": {:hex, :ex_doc, "0.29.1", "b1c652fa5f92ee9cf15c75271168027f92039b3877094290a75abcaac82a9f77", [:mix], [{:earmark_parser, "~> 1.4.19", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "b7745fa6374a36daf484e2a2012274950e084815b936b1319aeebcf7809574f6"}, - "exla": {:hex, :exla, "0.4.1", "409a3294720e31bbcd03c3eacd654686feb0ed7ba3e42314a269eeaa7cfd3c76", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.4.1", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.4.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "0410a0c38b94d2be4b713753c86f33ab8400b5cc8b59100266a0a6c58c17871d"}, - "kino": {:hex, :kino, "0.8.0", "07603a32c111959ed48f08ac3808a0dda05433d28f8d2f06d65b25b255966649", [:mix], [{:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "736568d4de9eb56d8903bae6fe08b7c06db44efe37bb883165e755e623881c51"}, - "kino_vega_lite": {:hex, :kino_vega_lite, "0.1.7", "c93fdfe6e35c4c5a4f8afd51a89786b2187e5a7da4595b13ea02a4329d9f0976", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}, {:vega_lite, "~> 0.1.4", [hex: :vega_lite, repo: "hexpm", optional: false]}], "hexpm", "59ee442f0532266749d15dc9af4e2875bec61ccfa1b07636bc396ee63dfde8e7"}, + "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, + "dll_loader_helper": {:hex, :dll_loader_helper, "1.1.0", "e7d015e980942a0d67e306827ec907e7e853a21186bd92bb968d986698591a0f", [:mix], [{:dll_loader_helper_beam, "~> 1.1", [hex: :dll_loader_helper_beam, repo: "hexpm", optional: false]}], "hexpm", "2b6c11ee7bb48f6a132ce8f872202f9e828c019988da1e2d40ad41496195df0c"}, + "dll_loader_helper_beam": {:hex, :dll_loader_helper_beam, "1.1.0", "d51232663985dbc998c59b5d080feecd5398d5b75a9f0293a9855db774c2684d", [:rebar3], [], "hexpm", "aa85d0d0e9398916a80b2fd751885877934ae3ea008288f99ff829c0b8ef1f55"}, + "earmark_parser": {:hex, :earmark_parser, "1.4.31", "a93921cdc6b9b869f519213d5bc79d9e218ba768d7270d46fdcf1c01bacff9e2", [:mix], [], "hexpm", "317d367ee0335ef037a87e46c91a2269fef6306413f731e8ec11fc45a7efd059"}, + "elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"}, + "ex_doc": {:hex, :ex_doc, "0.29.3", "f07444bcafb302db86e4f02d8bbcd82f2e881a0dcf4f3e4740e4b8128b9353f7", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: 
"hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "3dc6787d7b08801ec3b51e9bd26be5e8826fbf1a17e92d1ebc252e1a1c75bfe1"}, + "exla": {:hex, :exla, "0.6.0", "af63e45ce41ad25630967923147d14292a0cc48e507b8a3cf3bf3d5483099a28", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.6.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.5.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "5f6a4a105ea9ab207b9aa4de5a294730e2bfe9639f4b8d37a7c00da131090d7a"}, + "kino": {:hex, :kino, "0.9.0", "9d023e66ed29123ba414e978012a6e9958b09fbf5dddb5e0f4814e04df8223b7", [:mix], [{:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "46767bdbbdacc1c801d43b2dc6d2fe7fdf936bd74f4accdc5779f647f5eeda66"}, + "kino_vega_lite": {:hex, :kino_vega_lite, "0.1.8", "ec7e97778d6b774591e4cbf7fd27850abf7c0f5e9133a3d13e069aadfa04b5e3", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}, {:vega_lite, "~> 0.1.4", [hex: :vega_lite, repo: "hexpm", optional: false]}], "hexpm", "0bc3135a77550ea5c5bd7bfb1fb215416ebddbbc8b1e280e6de39366cd17a2f8"}, "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.0", "f8c570a0d33f8039513fbccaf7108c5d750f47d8defd44088371191b76492b0b", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "28b2cbdc13960a46ae9a8858c4bebdec3c9a6d7b4b9e7f4ed1502f8159f338e7"}, "makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"}, "nimble_parsec": {:hex, :nimble_parsec, "1.2.3", "244836e6e3f1200c7f30cb56733fd808744eca61fd182f731eac4af635cc6d0b", [:mix], [], "hexpm", "c8d789e39b9131acf7b99291e93dae60ab48ef14a7ee9d58c6964f59efb570b0"}, - "nx": {:hex, :nx, "0.4.1", "3cc8e420d0835ab7cac94f253950dee3ff927c68798ae88a3e4ff184a825b042", [:mix], [{:complex, "~> 0.4.2", [hex: :complex, repo: "hexpm", optional: false]}], "hexpm", "0b33fccaf76ebc6e79d53fe1149a70f99838e6505e9e7092e5a0a57b131b27c6"}, + "nx": {:hex, :nx, "0.6.0", "37c86eae824125a7e298dd1ee896953d9d671ce3630dcff74c77db17d734a85f", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "e1ad3cc70a5828a1aedb156b71e90863d9623a2dc9b35a5588f8627a07ee6cb4"}, + "polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"}, "table": {:hex, :table, "0.1.2", 
"87ad1125f5b70c5dea0307aa633194083eb5182ec537efc94e96af08937e14a8", [:mix], [], "hexpm", "7e99bc7efef806315c7e65640724bf165c3061cdc5d854060f74468367065029"}, "table_rex": {:hex, :table_rex, "3.1.1", "0c67164d1714b5e806d5067c1e96ff098ba7ae79413cc075973e17c38a587caa", [:mix], [], "hexpm", "678a23aba4d670419c23c17790f9dcd635a4a89022040df7d5d772cb21012490"}, - "telemetry": {:hex, :telemetry, "1.1.0", "a589817034a27eab11144ad24d5c0f9fab1f58173274b1e9bae7074af9cbee51", [:rebar3], [], "hexpm", "b727b2a1f75614774cff2d7565b64d0dfa5bd52ba517f16543e6fc7efcc0df48"}, - "torchx": {:hex, :torchx, "0.4.1", "5aa7f93d7aff85c9f5fbae4c534affa9d16a9ffe9bbbb261cf2dca8ead2f6ab8", [:make, :mix], [{:dll_loader_helper, "~> 0.1.0", [hex: :dll_loader_helper, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.4.1", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "3d1ee6b2588cadf2e70e4ea33449bda8e27f315b653f12d1f681070b83854b0e"}, + "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, + "torchx": {:hex, :torchx, "0.6.0", "e4a5f545e245c15aceeafcf9f22ac2ae0a87720c4a6b2f132e9909635f434e93", [:make, :mix], [{:dll_loader_helper, "~> 0.1 or ~> 1.0", [hex: :dll_loader_helper, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.6.0", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "35365dc51ee28dc86ca87c150dd3869bc83b207b2574bb2310c1be39e3867550"}, "vega_lite": {:hex, :vega_lite, "0.1.6", "145ab4908bc890b02cef3526e890e9b899528eaa7aa9d6fa642b52a8a2c682c6", [:mix], [{:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "078c0d8cd9a8eca4ae8f9527c45c01d69cefb6b2235fd5179a227ac2f031d7ac"}, - "xla": {:hex, :xla, "0.4.1", "c14a8214928f1aee68745b70c4f817c90e98740ceb69ad921071eb41792f9ecf", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "fe8323ceeebf114f183fcd3a09ab08d76a71e9fd9b1154109078a8355aa56366"}, + "xla": {:hex, :xla, "0.5.0", "fb8a02c02e5a4f4531fbf18a90c325e471037f983f0115d23f510e7dd9a6aa65", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "571ac797a4244b8ba8552ed0295a54397bd896708be51e4da6cbb784f6678061"}, } diff --git a/notebooks/generative/fashionmnist_autoencoder.livemd b/notebooks/generative/fashionmnist_autoencoder.livemd index 036776ff..43f10cf7 100644 --- a/notebooks/generative/fashionmnist_autoencoder.livemd +++ b/notebooks/generative/fashionmnist_autoencoder.livemd @@ -117,7 +117,7 @@ mean_square_error = fn y_pred, y -> |> Nx.mean() end -mean_absolute_erorr = fn y_pred, y -> +mean_absolute_error = fn y_pred, y -> y_pred |> Nx.subtract(y) |> Nx.abs() @@ -139,7 +139,7 @@ For the same image both errors should be 0, because when we have two exact copie ```elixir { mean_square_error.(shoe_image, shoe_image), - mean_absolute_erorr.(shoe_image, shoe_image) + mean_absolute_error.(shoe_image, shoe_image) } ``` @@ -148,7 +148,7 @@ Now the noised image: ```elixir { mean_square_error.(shoe_image, noised_shoe_image), - mean_absolute_erorr.(shoe_image, noised_shoe_image) + mean_absolute_error.(shoe_image, noised_shoe_image) } ``` @@ -157,7 +157,7 @@ And a different image: ```elixir { mean_square_error.(shoe_image, other_image), - mean_absolute_erorr.(shoe_image, other_image) + 
mean_absolute_error.(shoe_image, other_image) } ``` diff --git a/notebooks/generative/fashionmnist_vae.livemd b/notebooks/generative/fashionmnist_vae.livemd index 352357be..a4b4e968 100644 --- a/notebooks/generative/fashionmnist_vae.livemd +++ b/notebooks/generative/fashionmnist_vae.livemd @@ -251,7 +251,7 @@ end params = model - |> Axon.Loop.trainer(:mean_squared_error, Axon.Optimizers.adamw(0.001)) + |> Axon.Loop.trainer(:mean_squared_error, Polaris.Optimizers.adamw(learning_rate: 0.001)) |> KinoAxon.kino_early_stop() |> Axon.Loop.handle(:iteration_completed, render_example_handler, every: 450) |> Axon.Loop.validate(model, test_data) @@ -265,7 +265,7 @@ params = ## Splitting up the model -Cool! We now have the parameters for a trained, simple autoencoder. Our next step is to split up the model so we can use the encoder and decoder separately. By doing that, we'll be able to take an image and *encode* it to get the model's compressed image representation (the latent vector). We can then manipulate the latent vector and run the manipulated latent vector through the *decoder* to get a new image. +Cool! We now have the parameters for a trained, simple autoencoder. Our next step is to split up the model so we can use the encoder and decoder separately. By doing that, we'll be able to take an image and _encode_ it to get the model's compressed image representation (the latent vector). We can then manipulate the latent vector and run the manipulated latent vector through the _decoder_ to get a new image. Let's start by defining the encoder and decoder separately as two different models. @@ -311,7 +311,7 @@ So all we need to do is create a new Map that plucks out the right layers from o Fortunately, since we gave each of the layers names, this requires no work at all - we can use the Map as it is since the layer names match up! Axon will ignore any extra keys so those won't be a problem. -Note that naming the layers wasn't *required*, if the layers didn't have names we would have some renaming to do to get the names to match between the models. But giving them names made it very convenient :) +Note that naming the layers wasn't _required_, if the layers didn't have names we would have some renaming to do to get the names to match between the models. But giving them names made it very convenient :) Let's try encoding an image, printing the latent and then decoding the latent using our split up model to make sure it's working. @@ -474,7 +474,7 @@ end params = model - |> Axon.Loop.trainer(&CustomLoss.loss/2, Axon.Optimizers.adam(0.001)) + |> Axon.Loop.trainer(&CustomLoss.loss/2, Polaris.Optimizers.adam(learning_rate: 0.001)) |> KinoAxon.kino_early_stop() |> Axon.Loop.handle(:epoch_completed, render_example_handler) |> Axon.Loop.validate(model, test_data) diff --git a/notebooks/generative/mnist_autoencoder_using_kino.livemd b/notebooks/generative/mnist_autoencoder_using_kino.livemd index b56ed8a7..e251922e 100644 --- a/notebooks/generative/mnist_autoencoder_using_kino.livemd +++ b/notebooks/generative/mnist_autoencoder_using_kino.livemd @@ -75,13 +75,13 @@ test_images[[images: 0..2]] |> Nx.to_heatmap() An autoencoder is a a network that has the same sized input as output, with a "bottleneck" layer in the middle with far fewer parameters than the input. Its goal is to force the output to reconstruct the input. The bottleneck layer forces the network to learn a compressed representation of the input space. 
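To make the bottleneck idea concrete, here is a small, purely illustrative Axon sketch (not a cell from this notebook, and the layer sizes here are arbitrary; the sizes actually used are described below): the hidden layer is much narrower than the input and output, so the network is forced to learn a compressed representation.

```elixir
# Illustrative only: a toy autoencoder with a narrow bottleneck.
toy_autoencoder =
  Axon.input("image", shape: {nil, 784})
  # bottleneck: far fewer features than the input
  |> Axon.dense(64, activation: :relu)
  # reconstruct the original 784 values
  |> Axon.dense(784, activation: :sigmoid)
```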
-A *denoising* autoencoder is a small tweak on an autoencoder that takes a corrupted input (often corrupted by adding noise or zeroing out pixels) and reconstructs the original input, removing the noise in the process.
+A _denoising_ autoencoder is a small tweak on an autoencoder that takes a corrupted input (often corrupted by adding noise or zeroing out pixels) and reconstructs the original input, removing the noise in the process.

-The part of the autoencoder that takes the input and compresses it into the bottleneck layer is called the *encoder* and the part that takes the compressed representation and reconstructs the input is called the *decoder*. Usually the decoder mirrors the encoder.
+The part of the autoencoder that takes the input and compresses it into the bottleneck layer is called the _encoder_ and the part that takes the compressed representation and reconstructs the input is called the _decoder_. Usually the decoder mirrors the encoder.

 MNIST is a pretty easy dataset, so we're going to try a fairly small autoencoder.

-The input image has size 784 (28 rows * 28 cols * 1 pixel). We'll set up the encoder to turn that into 256 features, then 128, 64, and then 10 features for the bottleneck layer. The decoder will do the reverse, take the 10 features and go to 64, 128, 256 and 784. I'll use fully-connected (dense) layers.
+The input image has size 784 (28 rows x 28 cols x 1 pixel). We'll set up the encoder to turn that into 256 features, then 128, 64, and then 10 features for the bottleneck layer. The decoder will do the reverse, take the 10 features and go to 64, 128, 256 and 784. I'll use fully-connected (dense) layers.
@@ -197,14 +197,14 @@ Looks right (and tricky). Let's see how the model does.
 ```elixir
 params =
   model
-  |> Axon.Loop.trainer(:mean_squared_error, Axon.Optimizers.adamw(0.001))
+  |> Axon.Loop.trainer(:mean_squared_error, Polaris.Optimizers.adamw(learning_rate: 0.001))
   |> Axon.Loop.validate(model, test_data)
   |> Axon.Loop.run(train_data, %{}, epochs: 20, compiler: EXLA)

 :ok
 ```

-Now that we have a model that theoretically has learned *something*, we'll see what it's learned by running it on some images from the test set. We'll use Kino to allow us to select the image from the test set to run the model against. To avoid losing the params that took a while to train, we'll create another branch so we can experiment with the params and stop execution when needed without having to retrain.
+Now that we have a model that theoretically has learned _something_, we'll see what it's learned by running it on some images from the test set. We'll use Kino to allow us to select the image from the test set to run the model against. To avoid losing the params that took a while to train, we'll create another branch so we can experiment with the params and stop execution when needed without having to retrain.
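As a rough sketch of that step (hypothetical code; the tensor shape and the assumption that the model accepts the raw image are not taken from the notebook, which wires this up with Kino widgets for interactive selection), reconstructing one test image and comparing it to the original might look like:

```elixir
# Hypothetical sketch: run one test image through the trained params and
# compare original vs. reconstruction as heatmaps. A reshape may be needed
# depending on the model's expected input shape.
image = test_images[[images: 0..0]]
reconstruction = Axon.predict(model, params, image)

Kino.Layout.grid([Nx.to_heatmap(image), Nx.to_heatmap(reconstruction)], columns: 2)
```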
@@ -257,7 +257,7 @@ Note we used `Kino.animate/2` which runs asynchronously so we don't block execut ## A better training loop -*Note that we branch from the "Building a model" section since we only need the model definition for this section and not the previously trained model.* +_Note that we branch from the "Building a model" section since we only need the model definition for this section and not the previously trained model._ @@ -312,7 +312,7 @@ end params = model - |> Axon.Loop.trainer(:mean_squared_error, Axon.Optimizers.adamw(0.001)) + |> Axon.Loop.trainer(:mean_squared_error, Polaris.Optimizers.adamw(learning_rate: 0.001)) |> Axon.Loop.handle(:iteration_completed, render_example_handler, every: 450) |> Axon.Loop.validate(model, test_data) |> Axon.Loop.run(train_data, %{}, epochs: 20, compiler: EXLA) diff --git a/notebooks/structured/credit_card_fraud.livemd b/notebooks/structured/credit_card_fraud.livemd index e407723f..ac937424 100644 --- a/notebooks/structured/credit_card_fraud.livemd +++ b/notebooks/structured/credit_card_fraud.livemd @@ -17,7 +17,7 @@ alias Explorer.{DataFrame, Series} ## Introduction -This time we will examine the Credit Card Fraud Dataset. Due to confidentiality, the original data were preprocessed by principal component analysis (PCA), and then 31 principal components were selected for the final data set. The dataset is highly imbalanced. The positive class (frauds) account for 0.172% of all transactions. Eventually, we will create a classifier which has not only great accuracy but, what is even more important, a high *recall* and *precision* - two metrics that are much more indicative of performance with imbalanced classification problems. +This time we will examine the Credit Card Fraud Dataset. Due to confidentiality, the original data were preprocessed by principal component analysis (PCA), and then 31 principal components were selected for the final data set. The dataset is highly imbalanced. The positive class (frauds) account for 0.172% of all transactions. Eventually, we will create a classifier which has not only great accuracy but, what is even more important, a high _recall_ and _precision_ - two metrics that are much more indicative of performance with imbalanced classification problems. ## Data processing @@ -139,7 +139,7 @@ IO.puts("# of fraudulent transactions (train): #{fraud}") IO.puts("% fraudlent transactions (train): #{100 * (fraud / (legit + fraud))}%") ``` -As always, we define our train loop. We are using *binary cross-entropy* as our loss function and Adam as the optimizer with a learning rate of 0.01. Then we immediately start the training passing our train portion of the dataset. +As always, we define our train loop. We are using _binary cross-entropy_ as our loss function and Adam as the optimizer with a learning rate of 0.01. Then we immediately start the training passing our train portion of the dataset. ```elixir loss = @@ -151,7 +151,7 @@ loss = reduction: :mean ) -optimizer = Axon.Optimizers.adam(1.0e-2) +optimizer = Polaris.Optimizers.adam(learning_rate: 1.0e-2) params = model diff --git a/notebooks/text/lstm_generation.livemd b/notebooks/text/lstm_generation.livemd index 677188f7..7aabec7c 100644 --- a/notebooks/text/lstm_generation.livemd +++ b/notebooks/text/lstm_generation.livemd @@ -158,7 +158,7 @@ model = To train the network, we will use Axon's Loop API. It is pretty straightforward. -For the loss function we can use *categorical cross-entropy* since we are dealing with categories (each character) in our output. 
For the optimizer we can use *Adam*. +For the loss function we can use _categorical cross-entropy_ since we are dealing with categories (each character) in our output. For the optimizer we can use _Adam_. We will train our network for 20 epochs. Note that we are working with a fair amount data, so it may take a long time unless you run it on a GPU. @@ -171,7 +171,7 @@ IO.puts("Total batches: #{Enum.count(train_batches)}") params = model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(0.001)) + |> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adam(learning_rate: 0.001)) |> Axon.Loop.run(Stream.zip(train_batches, result_batches), %{}, epochs: 20, compiler: EXLA) :ok @@ -250,7 +250,7 @@ IO.puts("Total batches: #{Enum.count(train_batches)}") new_params = new_model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(0.001)) + |> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adam(learning_rate: 0.001)) |> Axon.Loop.run(Stream.zip(train_batches, result_batches), %{}, epochs: 50, compiler: EXLA) :ok diff --git a/notebooks/vision/horses_or_humans.livemd b/notebooks/vision/horses_or_humans.livemd index 225e2db3..ce24a7b9 100644 --- a/notebooks/vision/horses_or_humans.livemd +++ b/notebooks/vision/horses_or_humans.livemd @@ -143,7 +143,7 @@ The next step is creating our model. In this notebook, we choose the classic Con -| ![](https://miroslawmamczur.pl/wp-content/uploads/2021/03/06.gif) | +| ![](https://miroslawmamczur.pl/wp-content/uploads/2021/03/06.gif) | | :-------------------------------------------------------------------------------------: | | Figure 1: A step-by-step visualization of a convolution layer for `kernel_size: {3, 3}` | @@ -155,7 +155,7 @@ The next step is creating our model. In this notebook, we choose the classic Con | ![](https://production-media.paperswithcode.com/methods/MaxpoolSample2.png) | | :-------------------------------------------------------------------------: | -| Figure 2: Max pooling operation for `kernel_size: {2, 2}` | +| Figure 2: Max pooling operation for `kernel_size: {2, 2}` | @@ -163,7 +163,7 @@ The next step is creating our model. In this notebook, we choose the classic Con -| ![](https://miro.medium.com/max/1400/1*KkqxjvXTIV_b365B41ltfg.png) | +| ![](https://miro.medium.com/max/1400/1*KkqxjvXTIV_b365B41ltfg.png) | | :-------------------------------------------------------------------: | | Figure 3: The difference between standard dropout and spatial dropout | @@ -199,7 +199,7 @@ It's time to train our model. We specify the loss, optimizer and choose accuracy ```elixir data = HorsesHumans.DataProcessing.data_stream(files, batch_size) -optimizer = Axon.Optimizers.adam(1.0e-4) +optimizer = Polaris.Optimizers.adam(learning_rate: 1.0e-4) params = model @@ -215,7 +215,7 @@ params = We can improve the training by applying gradient centralization. It is a technique with a similar purpose to batch normalization. For each loss gradient, we subtract a mean value to have a gradient with mean equal to zero. This process prevents gradients from exploding. 
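The "subtract a mean value" step is easy to see on a toy tensor. This is only a simplified illustration (a whole-tensor mean on a made-up gradient); the actual `centralize` update applies the idea per parameter gradient:

```elixir
# Simplified illustration of centralization on a made-up gradient.
g = Nx.tensor([0.5, 1.5, -0.5, 2.5])
centralized = Nx.subtract(g, Nx.mean(g))

# The centralized gradient now has (approximately) zero mean.
Nx.mean(centralized)
```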
```elixir -centralized_optimizer = Axon.Updates.compose(Axon.Updates.centralize(), optimizer) +centralized_optimizer = Polaris.Updates.compose(Polaris.Updates.centralize(), optimizer) model |> Axon.Loop.trainer(:categorical_cross_entropy, centralized_optimizer, :identity, log: 1) @@ -242,7 +242,7 @@ input = Axon.predict(model, params, input) ``` -*Note: the model output refers to the probability that the image presents a horse and a human respectively.* +_Note: the model output refers to the probability that the image presents a horse and a human respectively._ diff --git a/test/axon/activations_test.exs b/test/axon/activations_test.exs index 47e0cc7c..2f6634cb 100644 --- a/test/axon/activations_test.exs +++ b/test/axon/activations_test.exs @@ -700,7 +700,7 @@ defmodule Axon.ActivationsTest do describe "log_softmax" do test "raises on bad axis" do - assert_raise ArgumentError, ~r/log_softmax axis must be within rank of tensor/, fn -> + assert_raise ArgumentError, "given axis (2) invalid for shape with rank 2", fn -> Axon.Activations.log_softmax(Nx.iota({1, 3}), axis: 2) end end @@ -1143,6 +1143,22 @@ defmodule Axon.ActivationsTest do actual = apply(jit(fn x -> grad(x, &Nx.sum(Axon.Activations.sigmoid(&1))) end), [a]) assert_all_close(expected, actual) end + + defn cache_test_sigmoid(x) do + x + |> Axon.Activations.sigmoid() + |> get_cached() + end + + deftransformp get_cached(res) do + %{data: %{args: [_, %{logits: inp}]}} = res + inp + end + + test "caches input logits" do + {a, _key} = Nx.Random.uniform(Nx.Random.key(42), shape: {10, 10}) + assert_all_close(cache_test_sigmoid(a), a) + end end describe "silu" do @@ -1348,6 +1364,17 @@ defmodule Axon.ActivationsTest do actual = apply(jit(fn x -> grad(x, &Nx.sum(Axon.Activations.softmax(&1))) end), [a]) assert_all_close(expected, actual, atol: 1.0e-7) end + + defn cache_test_softmax(x) do + x + |> Axon.Activations.softmax() + |> get_cached() + end + + test "caches input logits" do + {a, _key} = Nx.Random.uniform(Nx.Random.key(42), shape: {10, 10}) + assert_all_close(cache_test_softmax(a), a) + end end describe "softplus" do diff --git a/test/axon/compiler_test.exs b/test/axon/compiler_test.exs index a05da1e8..5db0d7ab 100644 --- a/test/axon/compiler_test.exs +++ b/test/axon/compiler_test.exs @@ -7,7 +7,7 @@ defmodule CompilerTest do describe "input" do test "single input, single output" do model = Axon.input("input_0", shape: {nil, 1}) - input = Nx.random_uniform({1, 1}, type: {:f, 32}) + input = random({1, 1}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -19,8 +19,8 @@ defmodule CompilerTest do {Axon.input("input_0", shape: {nil, 1}), Axon.input("input_1", shape: {nil, 1})} |> Axon.container() - input1 = Nx.random_uniform({1, 1}) - input2 = Nx.random_uniform({1, 1}) + input1 = random({1, 1}) + input2 = random({1, 1}) input = %{"input_0" => input1, "input_1" => input2} assert {init_fn, predict_fn} = Axon.build(model1) @@ -35,7 +35,7 @@ defmodule CompilerTest do test "output map" do model = %{foo: Axon.input("input_0", shape: {nil, 1})} |> Axon.container() - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert {init_fn, predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -48,8 +48,8 @@ defmodule CompilerTest do model1 = {input1, {input1, {input2, {}}, input2, %{foo: input1}}} |> Axon.container() - inp1 = Nx.random_uniform({1, 1}) - inp2 = Nx.random_uniform({1, 2}) + inp1 = random({1, 1}) + inp2 = random({1, 2}) input = %{"input_0" => inp1, "input_1" => 
inp2} assert {init_fn, predict_fn} = Axon.build(model1) @@ -67,9 +67,9 @@ defmodule CompilerTest do z = Axon.input("z", shape: {nil, 1}) model = {z, x, y} |> Axon.container() - x_val = Nx.random_uniform({1, 1}) - y_val = Nx.random_uniform({1, 1}) - z_val = Nx.random_uniform({1, 1}) + x_val = random({1, 1}) + y_val = random({1, 1}) + z_val = random({1, 1}) input = %{"x" => x_val, "y" => y_val, "z" => z_val} assert {init_fn, predict_fn} = Axon.build(model) @@ -99,7 +99,7 @@ defmodule CompilerTest do test "raises if input not found, no default value" do model = Axon.input("input_0", shape: {nil, 32}) - input = Nx.random_uniform({1, 16}) + input = random({1, 16}) assert {_, predict_fn} = Axon.build(model) exception = assert_raise ArgumentError, fn -> predict_fn.(%{}, %{foo: input}) end @@ -262,7 +262,7 @@ defmodule CompilerTest do test "initializes with no params" do for activation <- @activation_layers do model = Axon.input("input_0", shape: {nil, 32}) |> Axon.activation(activation) - input = Nx.random_uniform({1, 32}) + input = random({1, 32}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -272,7 +272,7 @@ defmodule CompilerTest do test "computes forward pass with default options" do for activation <- @activation_layers do model = Axon.input("input_0", shape: {nil, 1}) |> Axon.activation(activation) - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert {_init_fn, predict_fn} = Axon.build(model) assert_equal(predict_fn.(%{}, input), apply(Axon.Activations, activation, [input])) @@ -282,7 +282,7 @@ defmodule CompilerTest do test "computes forward pass with custom options" do for activation <- [:celu, :elu, :leaky_relu] do model = Axon.input("input_0", shape: {nil, 32}) |> Axon.activation(activation, alpha: 0.8) - input = Nx.random_uniform({1, 32}, type: {:f, 32}) + input = random({1, 32}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model) @@ -299,10 +299,10 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert {init_fn, predict_fn} = Axon.build(mp_model) - assert Nx.type(predict_fn.(init_fn.(input, %{}), Nx.random_uniform({1, 1}))) == {:bf, 16} + assert Nx.type(predict_fn.(init_fn.(input, %{}), random({1, 1}))) == {:bf, 16} end end end @@ -311,7 +311,7 @@ defmodule CompilerTest do test "initializes in default case" do model = Axon.input("input_0", shape: {nil, 1}) |> Axon.bias(name: "bias") - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{"bias" => %{"bias" => bias}} = init_fn.(input, %{}) @@ -324,7 +324,7 @@ defmodule CompilerTest do test "initializes in default case" do model = Axon.input("input_0", shape: {nil, 1}) |> Axon.dense(1, name: "dense") - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{"dense" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(input, %{}) @@ -335,7 +335,7 @@ defmodule CompilerTest do end test "initializes with custom initializers" do - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) model1 = Axon.input("input_0", shape: {nil, 1}) @@ -382,7 +382,7 @@ defmodule CompilerTest do Nx.Defn.grad(params, &Nx.mean(predict_fn.(&1, input))) end - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert %{"dense" => %{"kernel" => kernel_grad, "bias" => bias_grad}} = apply(Nx.Defn.jit(backward), [init_fn.(input, %{}), input]) 
@@ -396,7 +396,7 @@ defmodule CompilerTest do policy = AMP.create_policy(params: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 2}) + input = random({1, 2}) assert {init_fn, _} = Axon.build(mp_model) assert %{"dense" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(input, %{}) @@ -409,7 +409,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 2}) + input = random({1, 2}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -419,7 +419,7 @@ defmodule CompilerTest do model = Axon.input("input", shape: {nil, 2}) |> Axon.dense(1, name: "dense", use_bias: false) - input = Nx.random_uniform({1, 2}) + input = random({1, 2}) assert {init_fn, _} = Axon.build(model) assert %{"dense" => %{"kernel" => _} = dense_params} = init_fn.(input, %{}) @@ -430,7 +430,7 @@ defmodule CompilerTest do model = Axon.input("input", shape: {nil, 2}) |> Axon.dense(1, name: "dense", use_bias: false) - input = Nx.random_uniform({1, 2}) + input = random({1, 2}) assert {init_fn, predict_fn} = Axon.build(model) assert %{"dense" => %{"kernel" => k}} = params = init_fn.(input, %{}) @@ -445,7 +445,7 @@ defmodule CompilerTest do input2 = Axon.input("input_1", shape: {nil, 2}) model = Axon.bilinear(input1, input2, 1, name: "bilinear") - inputs = %{"input_0" => Nx.random_uniform({1, 1}), "input_1" => Nx.random_uniform({1, 2})} + inputs = %{"input_0" => random({1, 1}), "input_1" => random({1, 2})} assert {init_fn, _predict_fn} = Axon.build(model) assert %{"bilinear" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(inputs, %{}) @@ -460,11 +460,11 @@ defmodule CompilerTest do input2 = Axon.input("input_1", shape: {nil, 2}) model1 = Axon.bilinear(input1, input2, 1, name: "bilinear", kernel_initializer: :zeros) - inputs = %{"input_0" => Nx.random_uniform({1, 1}), "input_1" => Nx.random_uniform({1, 2})} + inputs = %{"input_0" => random({1, 1}), "input_1" => random({1, 2})} assert {init_fn, _predict_fn} = Axon.build(model1) assert %{"bilinear" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(inputs, %{}) - assert_equal(kernel, zeros({1, 1, 2})) + assert_equal(kernel, zeros({1, 2})) assert Nx.shape(bias) == {1} assert Nx.type(bias) == {:f, 32} @@ -522,7 +522,7 @@ defmodule CompilerTest do input2 = Axon.input("input_1", shape: {nil, 2}) model = Axon.bilinear(input1, input2, 1, name: "bilinear") |> Axon.freeze() - input = %{"input_0" => Nx.random_uniform({1, 1}), "input_1" => Nx.random_uniform({1, 2})} + input = %{"input_0" => random({1, 1}), "input_1" => random({1, 2})} assert {init_fn, predict_fn} = Axon.build(model) @@ -544,7 +544,7 @@ defmodule CompilerTest do policy = AMP.create_policy(params: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = %{"input_0" => Nx.random_uniform({1, 1}), "input_1" => Nx.random_uniform({1, 2})} + input = %{"input_0" => random({1, 1}), "input_1" => random({1, 2})} assert {init_fn, _} = Axon.build(mp_model) assert %{"bilinear" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(input, %{}) @@ -559,7 +559,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = %{"input_0" => Nx.random_uniform({1, 1}), "input_1" => Nx.random_uniform({1, 2})} + input = %{"input_0" => random({1, 1}), "input_1" => random({1, 2})} assert {init_fn, predict_fn} = Axon.build(mp_model) @@ -571,7 +571,7 @@ defmodule CompilerTest 
do input2 = Axon.input("input_1", shape: {nil, 2}) model = Axon.bilinear(input1, input2, 1, name: "bilinear", use_bias: false) - input = %{"input_0" => Nx.random_uniform({1, 1}), "input_1" => Nx.random_uniform({1, 2})} + input = %{"input_0" => random({1, 1}), "input_1" => random({1, 2})} assert {init_fn, _} = Axon.build(model) assert %{"bilinear" => %{"kernel" => _} = bilinear_params} = init_fn.(input, %{}) @@ -583,8 +583,8 @@ defmodule CompilerTest do input2 = Axon.input("input_1", shape: {nil, 2}) model = Axon.bilinear(input1, input2, 1, name: "bilinear", use_bias: false) - inp1 = Nx.random_uniform({1, 1}) - inp2 = Nx.random_uniform({1, 2}) + inp1 = random({1, 1}) + inp2 = random({1, 2}) input = %{"input_0" => inp1, "input_1" => inp2} @@ -602,7 +602,7 @@ defmodule CompilerTest do test "initializes in default case" do model = Axon.input("input_0", shape: {nil, 1}) |> Axon.embedding(1, 1, name: "embedding") - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{"embedding" => %{"kernel" => kernel}} = init_fn.(input, %{}) @@ -615,7 +615,7 @@ defmodule CompilerTest do Axon.input("input_0", shape: {nil, 1}) |> Axon.embedding(1, 1, name: "embedding", kernel_initializer: :zeros) - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert {init_fn, _predict_fn} = Axon.build(model1) assert %{"embedding" => %{"kernel" => kernel}} = init_fn.(input, %{}) @@ -685,7 +685,7 @@ defmodule CompilerTest do for pool <- @pooling_layers do model = apply(Axon, pool, [Axon.input("input", shape: {nil, 32, 1})]) - input = Nx.random_uniform({1, 32, 1}) + input = random({1, 32, 1}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -697,7 +697,7 @@ defmodule CompilerTest do for pool <- @pooling_layers do model1 = apply(Axon, pool, [Axon.input("input", shape: {nil, 32, 1})]) - input1 = Nx.random_uniform({1, 32, 1}, type: {:f, 32}) + input1 = random({1, 32, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model1) @@ -707,7 +707,7 @@ defmodule CompilerTest do ) model2 = apply(Axon, pool, [Axon.input("input", shape: {nil, 8, 4, 1})]) - input2 = Nx.random_uniform({1, 8, 4, 1}, type: {:f, 32}) + input2 = random({1, 8, 4, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model2) @@ -717,7 +717,7 @@ defmodule CompilerTest do ) model3 = apply(Axon, pool, [Axon.input("input", shape: {nil, 8, 4, 2, 1})]) - input3 = Nx.random_uniform({1, 8, 4, 2, 1}, type: {:f, 32}) + input3 = random({1, 8, 4, 2, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model3) @@ -732,21 +732,21 @@ defmodule CompilerTest do for pool <- @pooling_layers do opts1 = [kernel_size: 6] model1 = apply(Axon, pool, [Axon.input("input", shape: {nil, 32, 1}), opts1]) - input1 = Nx.random_uniform({1, 32, 1}, type: {:f, 32}) + input1 = random({1, 32, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model1) assert_equal(predict_fn.(%{}, input1), apply(Axon.Layers, pool, [input1, opts1])) opts2 = [kernel_size: 2, strides: 2, padding: :same] model2 = apply(Axon, pool, [Axon.input("input", shape: {nil, 8, 4, 1}), opts2]) - input2 = Nx.random_uniform({1, 8, 4, 1}, type: {:f, 32}) + input2 = random({1, 8, 4, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model2) assert_equal(predict_fn.(%{}, input2), apply(Axon.Layers, pool, [input2, opts2])) opts3 = [kernel_size: {2, 1, 2}, strides: [1, 2, 1], padding: [{0, 1}, {1, 1}, {0, 2}]] model3 = apply(Axon, pool, [Axon.input("input", shape: {nil, 8, 4, 2, 1}), opts3]) - input3 = 
Nx.random_uniform({1, 8, 4, 2, 1}, type: {:f, 32}) + input3 = random({1, 8, 4, 2, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model3) assert_equal(predict_fn.(%{}, input3), apply(Axon.Layers, pool, [input3, opts3])) @@ -755,7 +755,7 @@ defmodule CompilerTest do test "lp_pool computes forward pass with custom norm" do model = Axon.input("input", shape: {nil, 32, 1}) |> Axon.lp_pool(norm: 3) - input = Nx.random_uniform({1, 32, 1}, type: {:f, 32}) + input = random({1, 32, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model) assert_equal(predict_fn.(%{}, input), Axon.Layers.lp_pool(input, kernel_size: {1}, norm: 3)) @@ -767,11 +767,11 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 32, 1}) + input = random({1, 32, 1}) assert {init_fn, predict_fn} = Axon.build(mp_model) - assert Nx.type(predict_fn.(init_fn.(input, %{}), Nx.random_uniform({1, 32, 1}))) == + assert Nx.type(predict_fn.(init_fn.(input, %{}), random({1, 32, 1}))) == {:bf, 16} end end @@ -784,7 +784,7 @@ defmodule CompilerTest do [channels: :last, kernel_size: {2}] ]) - inp = Nx.random_uniform({1, 32, 1}) + inp = random({1, 32, 1}) assert {_, predict_fn} = Axon.build(model) @@ -813,6 +813,59 @@ defmodule CompilerTest do # end end + describe "blur_pool" do + test "initializes with no params" do + model = apply(Axon, :blur_pool, [Axon.input("input", shape: {nil, 32, 32, 1})]) + + input = random({1, 32, 32, 1}) + + assert {init_fn, _predict_fn} = Axon.build(model) + assert %{} = init_fn.(input, %{}) + end + + test "computes forward pass with default options" do + model2 = apply(Axon, :blur_pool, [Axon.input("input", shape: {nil, 8, 4, 1})]) + input2 = random({1, 8, 4, 1}, type: {:f, 32}) + + assert {_, predict_fn} = Axon.build(model2) + + assert_equal( + predict_fn.(%{}, input2), + apply(Axon.Layers, :blur_pool, [input2]) + ) + end + + test "computes forward pass with output policy" do + model = apply(Axon, :blur_pool, [Axon.input("input", shape: {nil, 32, 32, 1})]) + policy = AMP.create_policy(output: {:bf, 16}) + mp_model = AMP.apply_policy(model, policy) + + input = random({1, 32, 32, 1}) + + assert {init_fn, predict_fn} = Axon.build(mp_model) + + assert Nx.type(predict_fn.(init_fn.(input, %{}), random({1, 32, 32, 1}))) == + {:bf, 16} + end + + test "computes forward pass with channels last" do + model = + apply(Axon, :blur_pool, [ + Axon.input("input", shape: {nil, 32, 32, 1}), + [channels: :last] + ]) + + inp = random({1, 32, 32, 1}) + + assert {_, predict_fn} = Axon.build(model) + + assert_equal( + predict_fn.(%{}, inp), + apply(Axon.Layers, :blur_pool, [inp, [channels: :last]]) + ) + end + end + @adaptive_pooling_layers [:adaptive_avg_pool, :adaptive_max_pool, :adaptive_lp_pool] describe "adaptive pooling" do @@ -820,7 +873,7 @@ defmodule CompilerTest do for pool <- @adaptive_pooling_layers do model = apply(Axon, pool, [Axon.input("input", shape: {nil, 32, 1})]) - input = Nx.random_uniform({1, 32, 1}) + input = random({1, 32, 1}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -830,7 +883,7 @@ defmodule CompilerTest do test "computes forward pass with default options" do for pool <- @adaptive_pooling_layers do model1 = apply(Axon, pool, [Axon.input("input", shape: {nil, 32, 1})]) - input1 = Nx.random_uniform({1, 32, 1}, type: {:f, 32}) + input1 = random({1, 32, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model1) @@ -840,7 +893,7 @@ defmodule CompilerTest do 
) model2 = apply(Axon, pool, [Axon.input("input", shape: {nil, 8, 4, 1})]) - input2 = Nx.random_uniform({1, 8, 4, 1}, type: {:f, 32}) + input2 = random({1, 8, 4, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model2) @@ -850,7 +903,7 @@ defmodule CompilerTest do ) model3 = apply(Axon, pool, [Axon.input("input", shape: {nil, 8, 4, 2, 1})]) - input3 = Nx.random_uniform({1, 8, 4, 2, 1}, type: {:f, 32}) + input3 = random({1, 8, 4, 2, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model3) @@ -865,21 +918,21 @@ defmodule CompilerTest do for pool <- @adaptive_pooling_layers do opts1 = [output_size: 27] model1 = apply(Axon, pool, [Axon.input("input", shape: {nil, 32, 1}), opts1]) - input1 = Nx.random_uniform({1, 32, 1}, type: {:f, 32}) + input1 = random({1, 32, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model1) assert_equal(predict_fn.(%{}, input1), apply(Axon.Layers, pool, [input1, opts1])) opts2 = [output_size: {2, 3}] model2 = apply(Axon, pool, [Axon.input("input", shape: {nil, 8, 4, 1}), opts2]) - input2 = Nx.random_uniform({1, 8, 4, 1}, type: {:f, 32}) + input2 = random({1, 8, 4, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model2) assert_equal(predict_fn.(%{}, input2), apply(Axon.Layers, pool, [input2, opts2])) opts3 = [output_size: {4, 3, 1}] model3 = apply(Axon, pool, [Axon.input("input", shape: {nil, 8, 4, 2, 1}), opts3]) - input3 = Nx.random_uniform({1, 8, 4, 2, 1}, type: {:f, 32}) + input3 = random({1, 8, 4, 2, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model3) assert_equal(predict_fn.(%{}, input3), apply(Axon.Layers, pool, [input3, opts3])) @@ -892,7 +945,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 32, 1}) + input = random({1, 32, 1}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -907,7 +960,7 @@ defmodule CompilerTest do [channels: :last, output_size: {27}] ]) - inp = Nx.random_uniform({1, 32, 1}) + inp = random({1, 32, 1}) assert {_, predict_fn} = Axon.build(model) @@ -926,7 +979,7 @@ defmodule CompilerTest do for pool <- @global_pooling_layers do model = apply(Axon, pool, [Axon.input("input", shape: {nil, 1, 32})]) - input = Nx.random_uniform({1, 1, 32}) + input = random({1, 1, 32}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -936,19 +989,19 @@ defmodule CompilerTest do test "computes forward pass with default options" do for pool <- @global_pooling_layers do model1 = apply(Axon, pool, [Axon.input("input", shape: {nil, 1, 4})]) - input1 = Nx.random_uniform({1, 1, 4}, type: {:f, 32}) + input1 = random({1, 1, 4}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model1) assert_equal(predict_fn.(%{}, input1), apply(Axon.Layers, pool, [input1])) model2 = apply(Axon, pool, [Axon.input("input", shape: {nil, 1, 2, 2})]) - input2 = Nx.random_uniform({1, 1, 2, 2}, type: {:f, 32}) + input2 = random({1, 1, 2, 2}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model2) assert_equal(predict_fn.(%{}, input2), apply(Axon.Layers, pool, [input2])) model3 = apply(Axon, pool, [Axon.input("input", shape: {nil, 1, 2, 2, 1})]) - input3 = Nx.random_uniform({1, 1, 2, 2, 1}, type: {:f, 32}) + input3 = random({1, 1, 2, 2, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model3) assert_equal(predict_fn.(%{}, input3), apply(Axon.Layers, pool, [input3])) @@ -959,7 +1012,7 @@ defmodule CompilerTest do for pool <- 
@global_pooling_layers do opts1 = [keep_axes: true] model1 = apply(Axon, pool, [Axon.input("input", shape: {nil, 1, 2}), opts1]) - input1 = Nx.random_uniform({1, 1, 2}, type: {:f, 32}) + input1 = random({1, 1, 2}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model1) assert_equal(predict_fn.(%{}, input1), apply(Axon.Layers, pool, [input1, opts1])) @@ -972,7 +1025,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -993,7 +1046,7 @@ defmodule CompilerTest do [channels: :last, keep_axes: false] ]) - inp = Nx.random_uniform({1, 32, 1}) + inp = random({1, 32, 1}) assert {_, predict_fn} = Axon.build(model1) @@ -1023,7 +1076,7 @@ defmodule CompilerTest do [name: "dropout", seed: 0] ]) - input = Nx.random_uniform({1, 1, 32}) + input = random({1, 1, 32}) assert {init_fn, _predict_fn} = Axon.build(model, mode: :train) assert %{"dropout" => %{"key" => key}} = init_fn.(input, %{}) @@ -1039,7 +1092,7 @@ defmodule CompilerTest do [name: "dropout", seed: 0] ]) - input = Nx.random_uniform({1, 1, 32}) + input = random({1, 1, 32}) assert {init_fn, predict_fn} = Axon.build(model, mode: :train) @@ -1059,7 +1112,7 @@ defmodule CompilerTest do [rate: 0.5, name: "dropout", seed: 0] ]) - input = Nx.random_uniform({1, 16, 32}) + input = random({1, 16, 32}) assert {init_fn, predict_fn} = Axon.build(model, mode: :train) @@ -1074,7 +1127,7 @@ defmodule CompilerTest do test "computes forward pass with default options" do for dropout <- @dropout_layers do model1 = apply(Axon, dropout, [Axon.input("input", shape: {nil, 32, 32})]) - input1 = Nx.random_uniform({1, 32, 32}, type: {:f, 32}) + input1 = random({1, 32, 32}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model1, mode: :train) %{prediction: result1} = predict_fn.(init_fn.(input1, %{}), input1) @@ -1084,7 +1137,7 @@ defmodule CompilerTest do assert_not_equal(result1, input1) model2 = apply(Axon, dropout, [Axon.input("input", shape: {nil, 1, 8, 4})]) - input2 = Nx.random_uniform({1, 1, 8, 4}, type: {:f, 32}) + input2 = random({1, 1, 8, 4}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model2, mode: :train) %{prediction: result2} = predict_fn.(init_fn.(input2, %{}), input2) @@ -1094,7 +1147,7 @@ defmodule CompilerTest do assert_not_equal(result2, input2) model3 = apply(Axon, dropout, [Axon.input("input", shape: {nil, 1, 8, 4, 2})]) - input3 = Nx.random_uniform({1, 1, 8, 4, 2}, type: {:f, 32}) + input3 = random({1, 1, 8, 4, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model3, mode: :train) %{prediction: result3} = predict_fn.(init_fn.(input3, %{}), input3) @@ -1109,7 +1162,7 @@ defmodule CompilerTest do for dropout <- @dropout_layers do opts1 = [rate: 0.5] model1 = apply(Axon, dropout, [Axon.input("input", shape: {nil, 32, 128}), opts1]) - input1 = Nx.random_uniform({1, 32, 128}, type: {:f, 32}) + input1 = random({1, 32, 128}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model1, mode: :train) @@ -1127,7 +1180,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 32}) + input = random({1, 1, 32}) assert {init_fn, predict_fn} = Axon.build(mp_model, mode: :train) assert Nx.type(predict_fn.(init_fn.(input, %{}), input).prediction) == {:bf, 16} @@ -1137,7 
+1190,7 @@ defmodule CompilerTest do test "not present in inference mode" do for dropout <- @dropout_layers do model = apply(Axon, dropout, [Axon.input("input", shape: {nil, 1, 32})]) - input = Nx.random_uniform({1, 1, 32}) + input = random({1, 1, 32}) {init_fn, predict_fn} = Axon.build(model) assert_equal(predict_fn.(init_fn.(input, %{}), input), input) @@ -1148,7 +1201,7 @@ defmodule CompilerTest do for dropout <- @dropout_layers do input = Axon.input("input", shape: {nil, 1, 32}) model = Axon.add([input, apply(Axon, dropout, [input])]) - input = Nx.random_uniform({1, 1, 32}) + input = random({1, 1, 32}) {init_fn, _predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -1196,21 +1249,21 @@ defmodule CompilerTest do test "computes forward pass with default options" do model1 = Axon.input("input", shape: {nil, 1, 2}) |> Axon.conv(2, name: "conv") - input1 = Nx.random_uniform({1, 1, 2}, type: {:f, 32}) + input1 = random({1, 1, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model1) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input1, %{}) assert_equal(predict_fn.(params, input1), Axon.Layers.conv(input1, kernel, bias)) model2 = Axon.input("input", shape: {nil, 1, 2, 2}) |> Axon.conv(3, name: "conv") - input2 = Nx.random_uniform({1, 1, 2, 2}, type: {:f, 32}) + input2 = random({1, 1, 2, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model2) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input2, %{}) assert_equal(predict_fn.(params, input2), Axon.Layers.conv(input2, kernel, bias)) model3 = Axon.input("input", shape: {nil, 1, 2, 2, 2}) |> Axon.conv(4, name: "conv") - input3 = Nx.random_uniform({1, 1, 2, 2, 2}, type: {:f, 32}) + input3 = random({1, 1, 2, 2, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model3) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input3, %{}) @@ -1224,7 +1277,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 2, 1}) |> Axon.conv(2, [name: "conv", kernel_size: 2] ++ opts1) - input1 = Nx.random_uniform({1, 2, 1}) + input1 = random({1, 2, 1}) assert {init_fn, predict_fn} = Axon.build(model1) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input1, %{}) @@ -1236,7 +1289,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 4, 4, 1}) |> Axon.conv(2, [name: "conv", kernel_size: 2] ++ opts2) - input2 = Nx.random_uniform({1, 4, 4, 1}) + input2 = random({1, 4, 4, 1}) assert {init_fn, predict_fn} = Axon.build(model2) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input2, %{}) @@ -1248,7 +1301,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 2, 2, 2, 1}) |> Axon.conv(4, [name: "conv", kernel_size: {2, 1, 1}] ++ opts3) - input3 = Nx.random_uniform({1, 2, 2, 2, 1}) + input3 = random({1, 2, 2, 2, 1}) assert {init_fn, predict_fn} = Axon.build(model3) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input3, %{}) @@ -1263,7 +1316,7 @@ defmodule CompilerTest do assert {init_fn, predict_fn} = Axon.build(model) - input = Nx.random_uniform({1, 1, 32}) + input = random({1, 1, 32}) backward = fn params, input -> Nx.Defn.grad(params, &Nx.mean(predict_fn.(&1, input))) @@ -1294,7 +1347,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 32}) + input = random({1, 1, 32}) assert {init_fn, predict_fn} = 
Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -1304,7 +1357,7 @@ defmodule CompilerTest do model = Axon.input("input", shape: {nil, 1, 2}) |> Axon.conv(1, name: "conv", use_bias: false) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, _} = Axon.build(model) assert %{"conv" => %{"kernel" => _} = conv_params} = init_fn.(input, %{}) @@ -1315,7 +1368,7 @@ defmodule CompilerTest do model = Axon.input("input", shape: {nil, 1, 2}) |> Axon.conv(1, name: "conv", use_bias: false) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, predict_fn} = Axon.build(model) assert %{"conv" => %{"kernel" => k}} = params = init_fn.(input, %{}) @@ -1326,7 +1379,7 @@ defmodule CompilerTest do model = Axon.input("input", shape: {nil, 3, 3, 6}) |> Axon.conv(2, name: "conv", channels: :last) - input = Nx.random_uniform({1, 3, 3, 6}) + input = random({1, 3, 3, 6}) assert {init_fn, predict_fn} = Axon.build(model) assert %{"conv" => %{"kernel" => k, "bias" => b}} = params = init_fn.(input, %{}) @@ -1367,7 +1420,7 @@ defmodule CompilerTest do test "initializes in default case" do model = Axon.input("input", shape: {nil, 2, 2, 3}) |> Axon.depthwise_conv(3, name: "conv") - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, _} = Axon.build(model) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(input, %{}) @@ -1382,7 +1435,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 2, 2, 3}) |> Axon.depthwise_conv(3, name: "conv", kernel_initializer: :zeros) - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, _predict_fn} = Axon.build(model1) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(input, %{}) @@ -1403,14 +1456,14 @@ defmodule CompilerTest do test "computes forward pass with default options" do model1 = Axon.input("input", shape: {nil, 1, 8}) |> Axon.depthwise_conv(3, name: "conv") - input1 = Nx.random_uniform({1, 1, 8}, type: {:f, 32}) + input1 = random({1, 1, 8}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model1) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input1, %{}) assert_equal(predict_fn.(params, input1), Axon.Layers.depthwise_conv(input1, kernel, bias)) model2 = Axon.input("input", shape: {nil, 1, 2, 2}) |> Axon.depthwise_conv(4, name: "conv") - input2 = Nx.random_uniform({1, 1, 2, 2}, type: {:f, 32}) + input2 = random({1, 1, 2, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model2) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input2, %{}) @@ -1419,7 +1472,7 @@ defmodule CompilerTest do model3 = Axon.input("input", shape: {nil, 1, 2, 2, 2}) |> Axon.depthwise_conv(5, name: "conv") - input3 = Nx.random_uniform({1, 1, 2, 2, 2}, type: {:f, 32}) + input3 = random({1, 1, 2, 2, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model3) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input3, %{}) @@ -1433,7 +1486,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 8, 1}) |> Axon.depthwise_conv(1, [name: "conv", kernel_size: 2] ++ opts1) - input1 = Nx.random_uniform({1, 8, 1}) + input1 = random({1, 8, 1}) assert {init_fn, predict_fn} = Axon.build(model1) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input1, %{}) @@ -1449,7 +1502,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 4, 4, 1}) |> 
Axon.depthwise_conv(8, [name: "conv", kernel_size: 2] ++ opts2) - input2 = Nx.random_uniform({1, 4, 4, 1}) + input2 = random({1, 4, 4, 1}) assert {init_fn, predict_fn} = Axon.build(model2) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input2, %{}) @@ -1465,7 +1518,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 3, 2, 2, 1}) |> Axon.depthwise_conv(2, [name: "conv", kernel_size: {2, 1, 1}] ++ opts3) - input3 = Nx.random_uniform({1, 3, 2, 2, 1}) + input3 = random({1, 3, 2, 2, 1}) assert {init_fn, predict_fn} = Axon.build(model3) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input3, %{}) @@ -1482,7 +1535,7 @@ defmodule CompilerTest do |> Axon.depthwise_conv(1, name: "conv") |> Axon.freeze() - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, predict_fn} = Axon.build(model) @@ -1502,7 +1555,7 @@ defmodule CompilerTest do policy = AMP.create_policy(params: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, _} = Axon.build(mp_model) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(input, %{}) @@ -1515,7 +1568,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -1526,7 +1579,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 1, 2}) |> Axon.depthwise_conv(1, name: "conv", use_bias: false) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, _} = Axon.build(model) assert %{"conv" => %{"kernel" => _} = conv_params} = init_fn.(input, %{}) @@ -1538,7 +1591,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 1, 2}) |> Axon.depthwise_conv(1, name: "conv", use_bias: false) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, predict_fn} = Axon.build(model) assert %{"conv" => %{"kernel" => k}} = params = init_fn.(input, %{}) @@ -1550,7 +1603,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 3, 3, 6}) |> Axon.depthwise_conv(2, name: "conv", channels: :last) - input = Nx.random_uniform({1, 3, 3, 6}) + input = random({1, 3, 3, 6}) assert {init_fn, predict_fn} = Axon.build(model) assert %{"conv" => %{"kernel" => k, "bias" => b}} = params = init_fn.(input, %{}) @@ -1589,7 +1642,7 @@ defmodule CompilerTest do test "initializes in default case" do model = Axon.input("input", shape: {nil, 2, 2, 3}) |> Axon.conv_transpose(32, name: "conv") - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, _} = Axon.build(model) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(input, %{}) @@ -1604,7 +1657,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 2, 2, 3}) |> Axon.conv_transpose(32, name: "conv", kernel_initializer: :zeros) - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, _predict_fn} = Axon.build(model1) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(input, %{}) @@ -1625,14 +1678,14 @@ defmodule CompilerTest do test "computes forward pass with default options" do model1 = Axon.input("input", shape: {nil, 1, 4}) |> Axon.conv_transpose(3, name: "conv") - input1 = Nx.random_uniform({1, 1, 4}, type: {:f, 32}) + input1 = 
random({1, 1, 4}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model1) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input1, %{}) assert_equal(predict_fn.(params, input1), Axon.Layers.conv_transpose(input1, kernel, bias)) model2 = Axon.input("input", shape: {nil, 1, 4, 4}) |> Axon.conv_transpose(4, name: "conv") - input2 = Nx.random_uniform({1, 1, 4, 4}, type: {:f, 32}) + input2 = random({1, 1, 4, 4}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model2) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input2, %{}) @@ -1641,7 +1694,7 @@ defmodule CompilerTest do model3 = Axon.input("input", shape: {nil, 1, 2, 2, 2}) |> Axon.conv_transpose(5, name: "conv") - input3 = Nx.random_uniform({1, 1, 2, 2, 2}, type: {:f, 32}) + input3 = random({1, 1, 2, 2, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model3) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input3, %{}) @@ -1655,7 +1708,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 4, 1}) |> Axon.conv_transpose(1, [name: "conv", kernel_size: 2] ++ opts1) - input1 = Nx.random_uniform({1, 4, 1}) + input1 = random({1, 4, 1}) assert {init_fn, predict_fn} = Axon.build(model1) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input1, %{}) @@ -1671,7 +1724,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 4, 4, 1}) |> Axon.conv_transpose(8, [name: "conv", kernel_size: 2] ++ opts2) - input2 = Nx.random_uniform({1, 4, 4, 1}) + input2 = random({1, 4, 4, 1}) assert {init_fn, predict_fn} = Axon.build(model2) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input2, %{}) @@ -1687,7 +1740,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 2, 2, 2, 1}) |> Axon.conv_transpose(2, [name: "conv", kernel_size: {2, 1, 1}] ++ opts3) - input3 = Nx.random_uniform({1, 2, 2, 2, 1}) + input3 = random({1, 2, 2, 2, 1}) assert {init_fn, predict_fn} = Axon.build(model3) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input3, %{}) @@ -1704,7 +1757,7 @@ defmodule CompilerTest do |> Axon.conv_transpose(1, name: "conv") |> Axon.freeze() - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, predict_fn} = Axon.build(model) @@ -1724,7 +1777,7 @@ defmodule CompilerTest do policy = AMP.create_policy(params: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, _} = Axon.build(mp_model) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(input, %{}) @@ -1737,7 +1790,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -1748,7 +1801,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 1, 2}) |> Axon.conv_transpose(1, name: "conv", use_bias: false) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, _} = Axon.build(model) assert %{"conv" => %{"kernel" => _} = conv_params} = init_fn.(input, %{}) @@ -1760,7 +1813,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 1, 2}) |> Axon.conv_transpose(1, name: "conv", use_bias: false) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert 
{init_fn, predict_fn} = Axon.build(model) assert %{"conv" => %{"kernel" => k}} = params = init_fn.(input, %{}) @@ -1772,7 +1825,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 3, 3, 6}) |> Axon.conv_transpose(2, name: "conv", channels: :last) - input = Nx.random_uniform({1, 3, 3, 6}) + input = random({1, 3, 3, 6}) assert {init_fn, predict_fn} = Axon.build(model) assert %{"conv" => %{"kernel" => k, "bias" => b}} = params = init_fn.(input, %{}) @@ -1788,7 +1841,7 @@ defmodule CompilerTest do test "initializes in default case" do model = Axon.input("input", shape: {nil, 2, 2, 3}) |> Axon.separable_conv2d(3, name: "conv") - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, _} = Axon.build(model) @@ -1816,7 +1869,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 2, 2, 3}) |> Axon.separable_conv2d(3, name: "conv", kernel_initializer: :zeros) - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, _} = Axon.build(model1) @@ -1861,7 +1914,7 @@ defmodule CompilerTest do test "computes forward pass with default options" do model = Axon.input("input", shape: {nil, 3, 2, 2}) |> Axon.separable_conv2d(3, name: "conv") - input = Nx.random_uniform({1, 3, 2, 2}) + input = random({1, 3, 2, 2}) assert {init_fn, predict_fn} = Axon.build(model) @@ -1887,7 +1940,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 3, 3, 2}) |> Axon.separable_conv2d(3, [name: "conv", kernel_size: {2, 2}] ++ opts) - input = Nx.random_uniform({1, 3, 3, 2}) + input = random({1, 3, 3, 2}) assert {init_fn, predict_fn} = Axon.build(model) @@ -1912,7 +1965,7 @@ defmodule CompilerTest do |> Axon.separable_conv2d(1, name: "conv") |> Axon.freeze() - input = Nx.random_uniform({1, 1, 3, 2}) + input = random({1, 1, 3, 2}) assert {init_fn, predict_fn} = Axon.build(model) @@ -1940,7 +1993,7 @@ defmodule CompilerTest do policy = AMP.create_policy(params: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 3, 2}) + input = random({1, 1, 3, 2}) assert {init_fn, _} = Axon.build(mp_model) @@ -1964,7 +2017,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 3, 2}) + input = random({1, 1, 3, 2}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -1975,7 +2028,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 1, 2, 2}) |> Axon.separable_conv2d(1, name: "conv", use_bias: false) - input = Nx.random_uniform({1, 1, 2, 2}) + input = random({1, 1, 2, 2}) assert {init_fn, _} = Axon.build(model) assert %{"conv" => %{"kernel_1" => _, "kernel_2" => _} = conv_params} = init_fn.(input, %{}) @@ -1988,7 +2041,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 1, 2, 2}) |> Axon.separable_conv2d(1, name: "conv", use_bias: false) - input = Nx.random_uniform({1, 1, 2, 2}) + input = random({1, 1, 2, 2}) assert {init_fn, predict_fn} = Axon.build(model) assert %{"conv" => %{"kernel_1" => k1, "kernel_2" => k2}} = params = init_fn.(input, %{}) @@ -2004,7 +2057,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 3, 3, 6}) |> Axon.separable_conv2d(2, name: "conv", channels: :last) - input = Nx.random_uniform({1, 3, 3, 6}) + input = random({1, 3, 3, 6}) assert {init_fn, predict_fn} = Axon.build(model) @@ -2046,7 +2099,7 @@ defmodule CompilerTest do model = Axon.input("input", shape: {nil, 3, 2, 2, 3}) |> 
Axon.separable_conv3d(3, name: "conv") - input = Nx.random_uniform({1, 3, 2, 2, 3}) + input = random({1, 3, 2, 2, 3}) assert {init_fn, _} = Axon.build(model) @@ -2080,7 +2133,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 3, 3, 2, 2}) |> Axon.separable_conv3d(3, name: "conv", kernel_initializer: :zeros) - input = Nx.random_uniform({1, 3, 2, 2, 3}) + input = random({1, 3, 2, 2, 3}) assert {init_fn, _} = Axon.build(model1) @@ -2141,7 +2194,7 @@ defmodule CompilerTest do model = Axon.input("input", shape: {nil, 3, 2, 2, 2}) |> Axon.separable_conv3d(3, name: "conv") - input = Nx.random_uniform({1, 3, 2, 2, 2}) + input = random({1, 3, 2, 2, 2}) assert {init_fn, predict_fn} = Axon.build(model) @@ -2169,7 +2222,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 3, 2, 3, 3}) |> Axon.separable_conv3d(3, [name: "conv", kernel_size: {2, 2, 1}] ++ opts) - input = Nx.random_uniform({1, 3, 2, 3, 3}) + input = random({1, 3, 2, 3, 3}) assert {init_fn, predict_fn} = Axon.build(model) @@ -2196,7 +2249,7 @@ defmodule CompilerTest do |> Axon.separable_conv3d(1, name: "conv") |> Axon.freeze() - input = Nx.random_uniform({1, 1, 3, 2, 2}) + input = random({1, 1, 3, 2, 2}) assert {init_fn, predict_fn} = Axon.build(model) @@ -2230,7 +2283,7 @@ defmodule CompilerTest do policy = AMP.create_policy(params: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 3, 2, 2}) + input = random({1, 1, 3, 2, 2}) assert {init_fn, _} = Axon.build(mp_model) @@ -2260,7 +2313,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 3, 2, 2}) + input = random({1, 1, 3, 2, 2}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -2271,7 +2324,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 1, 3, 2, 2}) |> Axon.separable_conv3d(1, name: "conv", use_bias: false) - input = Nx.random_uniform({1, 1, 3, 2, 2}) + input = random({1, 1, 3, 2, 2}) assert {init_fn, _} = Axon.build(model) @@ -2288,7 +2341,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 1, 3, 2, 2}) |> Axon.separable_conv3d(1, name: "conv", use_bias: false) - input = Nx.random_uniform({1, 1, 3, 2, 2}) + input = random({1, 1, 3, 2, 2}) assert {init_fn, predict_fn} = Axon.build(model) @@ -2314,7 +2367,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 3, 3, 3, 6}) |> Axon.separable_conv3d(2, name: "conv", channels: :last) - input = Nx.random_uniform({1, 3, 3, 3, 6}) + input = random({1, 3, 3, 3, 6}) assert {init_fn, predict_fn} = Axon.build(model) @@ -2368,7 +2421,7 @@ defmodule CompilerTest do if norm != :instance_norm do model1 = apply(Axon, norm, [Axon.input("input", shape: {nil, 2}), [name: "norm"]]) - input = Nx.random_uniform({1, 2}) + input = random({1, 2}) assert {init_fn, _predict_fn} = Axon.build(model1) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = init_fn.(input, %{}) @@ -2380,7 +2433,7 @@ defmodule CompilerTest do model2 = apply(Axon, norm, [Axon.input("input", shape: {nil, 2, 2, 3}), [name: "norm"]]) - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, _predict_fn} = Axon.build(model2) @@ -2407,7 +2460,7 @@ defmodule CompilerTest do [name: "norm", gamma_initializer: :zeros] ]) - input = Nx.random_uniform({1, 2}) + input = random({1, 2}) assert {init_fn, _predict_fn} = Axon.build(model1) @@ -2429,7 +2482,7 @@ defmodule CompilerTest do [name: 
"norm", beta_initializer: :zeros] ]) - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, _predict_fn} = Axon.build(model2) @@ -2450,7 +2503,7 @@ defmodule CompilerTest do for norm <- @normalization_with_stats_layers do if norm != :instance_norm do model1 = apply(Axon, norm, [Axon.input("input", shape: {nil, 2}), [name: "norm"]]) - input1 = Nx.random_uniform({1, 2}, type: {:f, 32}) + input1 = random({1, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model1) @@ -2464,7 +2517,7 @@ defmodule CompilerTest do end model2 = apply(Axon, norm, [Axon.input("input", shape: {nil, 3, 2, 2}), [name: "norm"]]) - input2 = Nx.random_uniform({1, 3, 2, 2}, type: {:f, 32}) + input2 = random({1, 3, 2, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model2) @@ -2486,7 +2539,7 @@ defmodule CompilerTest do model1 = apply(Axon, norm, [Axon.input("input", shape: {nil, 2}), [name: "norm"] ++ opts1]) - input1 = Nx.random_uniform({1, 2}, type: {:f, 32}) + input1 = random({1, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model1) @@ -2504,7 +2557,7 @@ defmodule CompilerTest do model2 = apply(Axon, norm, [Axon.input("input", shape: {nil, 2, 2, 3}), [name: "norm"] ++ opts2]) - input2 = Nx.random_uniform({1, 2, 2, 3}, type: {:f, 32}) + input2 = random({1, 2, 2, 3}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model2) @@ -2524,7 +2577,7 @@ defmodule CompilerTest do apply(Axon, norm, [Axon.input("input", shape: {nil, 1, 2}), [name: "norm"]]) |> Axon.freeze() - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, predict_fn} = Axon.build(model) @@ -2546,7 +2599,7 @@ defmodule CompilerTest do policy = AMP.create_policy(params: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, _} = Axon.build(mp_model) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = init_fn.(input, %{}) @@ -2561,7 +2614,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -2577,7 +2630,7 @@ defmodule CompilerTest do if norm != :instance_norm do model1 = apply(Axon, norm, [Axon.input("input", shape: {nil, 2}), [name: "norm"]]) - input = Nx.random_uniform({1, 2}) + input = random({1, 2}) assert {init_fn, _predict_fn} = Axon.build(model1) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = init_fn.(input, %{}) @@ -2589,7 +2642,7 @@ defmodule CompilerTest do model2 = apply(Axon, norm, [Axon.input("input", shape: {nil, 2, 2, 3}), [name: "norm"]]) - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, _predict_fn} = Axon.build(model2) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = init_fn.(input, %{}) @@ -2609,7 +2662,7 @@ defmodule CompilerTest do [name: "norm", gamma_initializer: :zeros] ]) - input = Nx.random_uniform({1, 2}) + input = random({1, 2}) assert {init_fn, _predict_fn} = Axon.build(model1) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = init_fn.(input, %{}) @@ -2624,7 +2677,7 @@ defmodule CompilerTest do [name: "norm", beta_initializer: :zeros] ]) - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, _predict_fn} = Axon.build(model2) assert %{"norm" => %{"gamma" => gamma, "beta" => 
beta}} = init_fn.(input, %{}) @@ -2638,7 +2691,7 @@ defmodule CompilerTest do for norm <- @normalization_layers do if norm != :instance_norm do model1 = apply(Axon, norm, [Axon.input("input", shape: {nil, 2}), [name: "norm"]]) - input1 = Nx.random_uniform({1, 2}, type: {:f, 32}) + input1 = random({1, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model1) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = params = init_fn.(input1, %{}) @@ -2650,7 +2703,7 @@ defmodule CompilerTest do end model2 = apply(Axon, norm, [Axon.input("input", shape: {nil, 3, 2, 2}), [name: "norm"]]) - input2 = Nx.random_uniform({1, 3, 2, 2}, type: {:f, 32}) + input2 = random({1, 3, 2, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model2) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = params = init_fn.(input2, %{}) @@ -2666,7 +2719,7 @@ defmodule CompilerTest do model1 = apply(Axon, norm, [Axon.input("input", shape: {nil, 2}), [name: "norm"] ++ opts1]) - input1 = Nx.random_uniform({1, 2}, type: {:f, 32}) + input1 = random({1, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model1) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = params = init_fn.(input1, %{}) @@ -2682,7 +2735,7 @@ defmodule CompilerTest do model2 = apply(Axon, norm, [Axon.input("input", shape: {nil, 2, 2, 3}), [name: "norm"] ++ opts2]) - input2 = Nx.random_uniform({1, 2, 2, 3}, type: {:f, 32}) + input2 = random({1, 2, 2, 3}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model2) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = params = init_fn.(input2, %{}) @@ -2702,7 +2755,7 @@ defmodule CompilerTest do assert {init_fn, predict_fn} = Axon.build(model) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) backward = fn params, input -> Nx.Defn.grad(params, &Nx.mean(predict_fn.(&1, input))) @@ -2722,7 +2775,7 @@ defmodule CompilerTest do policy = AMP.create_policy(params: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, _} = Axon.build(mp_model) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = init_fn.(input, %{}) @@ -2737,7 +2790,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -2749,7 +2802,7 @@ defmodule CompilerTest do test "initializes in default case" do model = Axon.input("input", shape: {nil, 3}) |> Axon.group_norm(3, name: "norm") - input = Nx.random_uniform({1, 3}) + input = random({1, 3}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = init_fn.(input, %{}) @@ -2764,7 +2817,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 3}) |> Axon.group_norm(3, name: "norm", gamma_initializer: :zeros) - input = Nx.random_uniform({1, 3}) + input = random({1, 3}) assert {init_fn, _predict_fn} = Axon.build(model1) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = init_fn.(input, %{}) @@ -2776,7 +2829,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 3, 3}) |> Axon.group_norm(3, name: "norm", beta_initializer: :zeros) - input = Nx.random_uniform({1, 3, 3}) + input = random({1, 3, 3}) assert {init_fn, _predict_fn} = Axon.build(model2) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = init_fn.(input, 
%{}) @@ -2787,7 +2840,7 @@ defmodule CompilerTest do test "computes forward pass with default options" do model1 = Axon.input("input", shape: {nil, 2}) |> Axon.group_norm(2, name: "norm") - input1 = Nx.random_uniform({1, 2}) + input1 = random({1, 2}) assert {init_fn, predict_fn} = Axon.build(model1) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = params = init_fn.(input1, %{}) @@ -2798,7 +2851,7 @@ defmodule CompilerTest do ) model2 = Axon.input("input", shape: {nil, 2, 2, 3}) |> Axon.group_norm(3, name: "norm") - input2 = Nx.random_uniform({1, 2, 2, 3}) + input2 = random({1, 2, 2, 3}) assert {init_fn, predict_fn} = Axon.build(model2) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = params = init_fn.(input2, %{}) @@ -2815,7 +2868,7 @@ defmodule CompilerTest do model = Axon.input("input", shape: {nil, 2, 2, 3}) |> Axon.group_norm(3, [name: "norm"] ++ opts) - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, predict_fn} = Axon.build(model) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = params = init_fn.(input, %{}) @@ -2834,7 +2887,7 @@ defmodule CompilerTest do assert {init_fn, predict_fn} = Axon.build(model) - input = Nx.random_uniform({1, 2}) + input = random({1, 2}) backward = fn params, input -> Nx.Defn.grad(params, &Nx.mean(predict_fn.(&1, input))) @@ -2852,7 +2905,7 @@ defmodule CompilerTest do policy = AMP.create_policy(params: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 2}) + input = random({1, 2}) assert {init_fn, _} = Axon.build(mp_model) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = init_fn.(input, %{}) @@ -2865,7 +2918,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 3}) + input = random({1, 3}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -2876,7 +2929,7 @@ defmodule CompilerTest do test "initializes with no params" do model = Axon.input("input_0", shape: {nil, 32}) |> Axon.flatten() - input = Nx.random_uniform({1, 32}) + input = random({1, 32}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -2884,13 +2937,13 @@ defmodule CompilerTest do test "computes forward pass with default options" do model1 = Axon.input("input_0", shape: {nil, 32}) |> Axon.flatten() - input1 = Nx.random_uniform({1, 32}) + input1 = random({1, 32}) assert {_, predict_fn} = Axon.build(model1) assert_equal(predict_fn.(%{}, input1), Axon.Layers.flatten(input1)) model2 = Axon.input("input", shape: {nil, 3, 32, 32}) |> Axon.flatten() - input2 = Nx.random_uniform({1, 3, 32, 32}) + input2 = random({1, 3, 32, 32}) assert {_, predict_fn} = Axon.build(model2) assert_equal(predict_fn.(%{}, input2), Axon.Layers.flatten(input2)) @@ -2901,7 +2954,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 3}) + input = random({1, 3}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -2912,7 +2965,7 @@ defmodule CompilerTest do test "initializes with no params" do model = Axon.input("input", shape: {nil, 3, 32}) |> Axon.transpose() - input = Nx.random_uniform({1, 3, 32}) + input = random({1, 3, 32}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -2920,13 +2973,13 @@ defmodule 
CompilerTest do test "computes forward pass with default options" do model1 = Axon.input("input_0", shape: {nil, 32}) |> Axon.transpose([0, 1]) - input1 = Nx.random_uniform({1, 32}) + input1 = random({1, 32}) assert {_, predict_fn} = Axon.build(model1) assert_equal(predict_fn.(%{}, input1), Nx.transpose(input1, axes: [0, 1])) model2 = Axon.input("input", shape: {nil, 3, 32, 32}) |> Axon.transpose([0, 2, 1, 3]) - input2 = Nx.random_uniform({1, 3, 32, 32}) + input2 = random({1, 3, 32, 32}) assert {_, predict_fn} = Axon.build(model2) assert_equal(predict_fn.(%{}, input2), Nx.transpose(input2, axes: [0, 2, 1, 3])) @@ -2944,7 +2997,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 32}) + input = random({1, 32}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -2955,7 +3008,7 @@ defmodule CompilerTest do test "initializes with no params" do model = Axon.input("input", shape: {nil, 1, 32}) |> Axon.reshape({16, 2}) - input = Nx.random_uniform({1, 1, 32}) + input = random({1, 1, 32}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -2963,13 +3016,13 @@ defmodule CompilerTest do test "computes forward pass with default options" do model1 = Axon.input("input_0", shape: {nil, 32}) |> Axon.reshape({16, 2}) - input1 = Nx.random_uniform({1, 32}) + input1 = random({1, 32}) assert {_, predict_fn} = Axon.build(model1) assert_equal(predict_fn.(%{}, input1), Nx.reshape(input1, {1, 16, 2})) model2 = Axon.input("input", shape: {nil, 3, 32, 32}) |> Axon.reshape({3, 16, 2, 32}) - input2 = Nx.random_uniform({1, 3, 32, 32}) + input2 = random({1, 3, 32, 32}) assert {_, predict_fn} = Axon.build(model2) assert_equal(predict_fn.(%{}, input2), Nx.reshape(input2, {1, 3, 16, 2, 32})) @@ -2987,7 +3040,7 @@ defmodule CompilerTest do assert {_, predict_fn} = Axon.build(model) - input = Nx.random_uniform({2, 4, 6}) + input = random({2, 4, 6}) assert_equal(predict_fn.(%{}, input), Nx.reshape(input, {2, 3, 8})) end @@ -2996,7 +3049,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 32}) + input = random({1, 32}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -3007,7 +3060,7 @@ defmodule CompilerTest do test "initializes with no params" do model = Axon.input("input", shape: {nil, 1, 3, 3}) |> Axon.resize({4, 4}) - input = Nx.random_uniform({1, 1, 3, 3}) + input = random({1, 1, 3, 3}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -3015,7 +3068,7 @@ defmodule CompilerTest do test "computes forward pass with default options" do model1 = Axon.input("input", shape: {nil, 1, 3, 3}) |> Axon.resize({4, 4}) - input1 = Nx.random_uniform({1, 1, 3, 3}) + input1 = random({1, 1, 3, 3}) assert {_, predict_fn} = Axon.build(model1) assert_equal(predict_fn.(%{}, input1), Axon.Layers.resize(input1, size: {4, 4})) @@ -3026,7 +3079,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 3, 3}) + input = random({1, 1, 3, 3}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -3040,7 +3093,7 @@ defmodule CompilerTest do |> Axon.lstm(64, name: "lstm") |> 
Axon.container() - input = Nx.random_uniform({1, 32, 10}) + input = random({1, 32, 10}) assert {init_fn, _predict_fn} = Axon.build(model) @@ -3089,7 +3142,7 @@ defmodule CompilerTest do |> Axon.lstm(64, name: "lstm", kernel_initializer: :zeros) |> Axon.container() - input = Nx.random_uniform({1, 32, 10}) + input = random({1, 32, 10}) assert {init_fn, _predict_fn} = Axon.build(model1) @@ -3171,9 +3224,9 @@ defmodule CompilerTest do |> Axon.lstm(2, name: "lstm", recurrent_initializer: :zeros) |> Axon.container() - input = Nx.random_uniform({1, 8, 2}, type: {:f, 32}) + input = random({1, 8, 2}, type: {:f, 32}) - init_carry = {zeros({1, 1, 2}), zeros({1, 1, 2})} + init_carry = {zeros({1, 2}), zeros({1, 2})} assert {init_fn, predict_fn} = Axon.build(model) @@ -3192,9 +3245,10 @@ defmodule CompilerTest do assert_equal( predict_fn.(params, input), Axon.Layers.dynamic_unroll( - &Axon.Layers.lstm_cell/5, + &Axon.Layers.lstm_cell/6, input, init_carry, + Nx.tensor(0), k, h, b @@ -3213,14 +3267,15 @@ defmodule CompilerTest do ) |> Axon.container() - input1 = Nx.random_uniform({1, 8, 2}, type: {:f, 32}) + input1 = random({1, 8, 2}, type: {:f, 32}) - init_carry1 = {zeros({1, 1, 2}), zeros({1, 1, 2})} + init_carry1 = {zeros({1, 2}), zeros({1, 2})} - cell_fn1 = fn i, c, k, h, b -> + cell_fn1 = fn i, c, mask, k, h, b -> Axon.Layers.lstm_cell( i, c, + mask, k, h, b, @@ -3245,7 +3300,7 @@ defmodule CompilerTest do assert_all_close( predict_fn.(params, input1), - Axon.Layers.dynamic_unroll(cell_fn1, input1, init_carry1, k, h, b) + Axon.Layers.dynamic_unroll(cell_fn1, input1, init_carry1, Nx.tensor(0), k, h, b) ) model2 = @@ -3253,11 +3308,11 @@ defmodule CompilerTest do |> Axon.lstm(2, name: "lstm", unroll: :static, recurrent_initializer: :zeros) |> Axon.container() - input2 = Nx.random_uniform({1, 8, 2}, type: {:f, 32}) + input2 = random({1, 8, 2}, type: {:f, 32}) - init_carry2 = {zeros({1, 1, 2}), zeros({1, 1, 2})} + init_carry2 = {zeros({1, 2}), zeros({1, 2})} - cell_fn2 = &Axon.Layers.lstm_cell/5 + cell_fn2 = &Axon.Layers.lstm_cell/6 assert {init_fn, predict_fn} = Axon.build(model2) @@ -3275,7 +3330,7 @@ defmodule CompilerTest do assert_all_close( predict_fn.(params, input2), - Axon.Layers.static_unroll(cell_fn2, input2, init_carry2, k, h, b) + Axon.Layers.static_unroll(cell_fn2, input2, init_carry2, Nx.tensor(0), k, h, b) ) end @@ -3283,7 +3338,7 @@ defmodule CompilerTest do seq = Axon.input("input", shape: {nil, 8, 2}) {_, carry} = seq |> Axon.lstm(2, name: "encode", recurrent_initializer: :zeros) model = Axon.lstm(seq, carry, 2, name: "decode") |> Axon.container() - input = Nx.random_uniform({1, 8, 2}) + input = random({1, 8, 2}) assert {init_fn, predict_fn} = Axon.build(model) @@ -3291,12 +3346,20 @@ defmodule CompilerTest do {ei, eh, eb} = enc {di, dh, db} = dec - init_carry = {zeros({1, 1, 2}), zeros({1, 1, 2})} + init_carry = {zeros({1, 2}), zeros({1, 2})} {_, carry} = - Axon.Layers.dynamic_unroll(&Axon.Layers.lstm_cell/5, inp, init_carry, ei, eh, eb) + Axon.Layers.dynamic_unroll( + &Axon.Layers.lstm_cell/6, + inp, + init_carry, + Nx.tensor(0), + ei, + eh, + eb + ) - Axon.Layers.dynamic_unroll(&Axon.Layers.lstm_cell/5, inp, carry, di, dh, db) + Axon.Layers.dynamic_unroll(&Axon.Layers.lstm_cell/6, inp, carry, Nx.tensor(0), di, dh, db) end assert %{ @@ -3328,7 +3391,7 @@ defmodule CompilerTest do |> Axon.lstm(2, name: "lstm", use_bias: false) |> Axon.container() - input = Nx.random_uniform({1, 2, 1}) + input = random({1, 2, 1}) assert {init_fn, _} = Axon.build(model) @@ -3349,7 +3412,7 @@ defmodule 
CompilerTest do |> Axon.lstm(2, name: "lstm", use_bias: false, recurrent_initializer: :zeros) |> Axon.container() - input = Nx.random_uniform({1, 2, 1}) + input = random({1, 2, 1}) assert {init_fn, predict_fn} = Axon.build(model) @@ -3361,14 +3424,37 @@ defmodule CompilerTest do } = params = init_fn.(input, %{}) b = {Nx.tensor(0), Nx.tensor(0), Nx.tensor(0), Nx.tensor(0)} - c = {zeros({1, 1, 2}), zeros({1, 1, 2})} + c = {zeros({1, 2}), zeros({1, 2})} assert_equal( predict_fn.(params, input), - Axon.Layers.dynamic_unroll(&Axon.Layers.lstm_cell/5, input, c, k, h, b) + Axon.Layers.dynamic_unroll(&Axon.Layers.lstm_cell/6, input, c, Nx.tensor(0), k, h, b) ) end + test "mask actually works" do + sequence = Axon.input("review") + mask = Axon.mask(sequence, 0) + embedded = sequence |> Axon.embedding(2048, 64) + {rnn_sequence, _state} = Axon.lstm(embedded, 64, mask: mask) + + {init_fn, predict_fn} = Axon.build(rnn_sequence) + params = init_fn.(Nx.template({64, 64}, :s64), %{}) + + input = Nx.tensor([[1, 2, 3, 4]]) + padded = Nx.pad(input, 0, [{0, 0, 0}, {0, 60, 0}]) + out = predict_fn.(params, padded) + + last_token = out[[.., 3, ..]] + + for i <- 4..63 do + # all eos tokens will be ignored so we just propagate the value + # to the next token and thus these should all be the same as the + # last non eos token + assert_equal(last_token, out[[.., i, ..]]) + end + end + # TODO(seanmor5): https://github.com/elixir-nx/axon/issues/90 # test "initializes with parameter policy" do # end @@ -3394,7 +3480,7 @@ defmodule CompilerTest do |> Axon.conv_lstm(out_channel_n, name: "convlstm") |> Axon.container() - input = Nx.random_uniform({1, 10, 3, 6, 6}) + input = random({1, 10, 3, 6, 6}) assert {init_fn, _predict_fn} = Axon.build(model) @@ -3428,7 +3514,7 @@ defmodule CompilerTest do _heigth = 6 } - input = Nx.random_uniform({1, 10, 3, 6, 6}) + input = random({1, 10, 3, 6, 6}) out_channel_n = 4 @@ -3502,7 +3588,7 @@ defmodule CompilerTest do input = input_shape |> put_elem(0, batch_real) - |> Nx.random_uniform(type: {:f, 32}) + |> random(type: {:f, 32}) init_carry = {zeros(hidden_shape_real), zeros(hidden_shape_real)} @@ -3523,9 +3609,10 @@ defmodule CompilerTest do assert_equal( predict_fn.(params, input), Axon.Layers.dynamic_unroll( - &Axon.Layers.conv_lstm_cell/5, + &Axon.Layers.conv_lstm_cell/6, input, init_carry, + Nx.tensor(0), k, h, b @@ -3558,7 +3645,7 @@ defmodule CompilerTest do input = input_shape |> put_elem(0, batch_real) - |> Nx.random_uniform(type: {:f, 32}) + |> random(type: {:f, 32}) init_carry = {zeros(hidden_shape_real), zeros(hidden_shape_real)} @@ -3579,9 +3666,10 @@ defmodule CompilerTest do assert_equal( predict_fn.(params, input), Axon.Layers.static_unroll( - &Axon.Layers.conv_lstm_cell/5, + &Axon.Layers.conv_lstm_cell/6, input, init_carry, + Nx.tensor(0), k, h, b @@ -3612,14 +3700,15 @@ defmodule CompilerTest do input1 = input_shape |> put_elem(0, batch_real) - |> Nx.random_uniform(type: {:f, 32}) + |> random(type: {:f, 32}) init_carry1 = {zeros(hidden_shape_real), zeros(hidden_shape_real)} - cell_fn1 = fn i, c, k, h, b -> + cell_fn1 = fn i, c, mask, k, h, b -> Axon.Layers.conv_lstm_cell( i, c, + mask, k, h, b @@ -3642,7 +3731,7 @@ defmodule CompilerTest do assert_equal( predict_fn.(params, input1), - Axon.Layers.dynamic_unroll(cell_fn1, input1, init_carry1, k, h, b) + Axon.Layers.dynamic_unroll(cell_fn1, input1, init_carry1, Nx.tensor(0), k, h, b) ) model2 = @@ -3657,11 +3746,11 @@ defmodule CompilerTest do input2 = input_shape |> put_elem(0, batch_real) - |> 
Nx.random_uniform(type: {:f, 32}) + |> random(type: {:f, 32}) init_carry2 = {zeros(hidden_shape_real), zeros(hidden_shape_real)} - cell_fn2 = &Axon.Layers.conv_lstm_cell/5 + cell_fn2 = &Axon.Layers.conv_lstm_cell/6 assert {init_fn, predict_fn} = Axon.build(model2) @@ -3679,7 +3768,7 @@ defmodule CompilerTest do assert_equal( predict_fn.(params, input2), - Axon.Layers.static_unroll(cell_fn2, input2, init_carry2, k, h, b) + Axon.Layers.static_unroll(cell_fn2, input2, init_carry2, Nx.tensor(0), k, h, b) ) end @@ -3708,7 +3797,7 @@ defmodule CompilerTest do input = input_shape |> put_elem(0, batch_real) - |> Nx.random_uniform(type: {:f, 32}) + |> random(type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model) @@ -3720,15 +3809,24 @@ defmodule CompilerTest do {_, carry} = Axon.Layers.dynamic_unroll( - &Axon.Layers.conv_lstm_cell/5, + &Axon.Layers.conv_lstm_cell/6, inp, init_carry, + Nx.tensor(0), ei, eh, eb ) - Axon.Layers.dynamic_unroll(&Axon.Layers.conv_lstm_cell/5, inp, carry, di, dh, db) + Axon.Layers.dynamic_unroll( + &Axon.Layers.conv_lstm_cell/6, + inp, + carry, + Nx.tensor(0), + di, + dh, + db + ) end assert %{ @@ -3779,7 +3877,7 @@ defmodule CompilerTest do input = input_shape |> put_elem(0, batch_real) - |> Nx.random_uniform(type: {:f, 32}) + |> random(type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model) @@ -3796,7 +3894,7 @@ defmodule CompilerTest do assert_equal( predict_fn.(params, input), - Axon.Layers.dynamic_unroll(&Axon.Layers.conv_lstm_cell/5, input, c, k, h, b) + Axon.Layers.dynamic_unroll(&Axon.Layers.conv_lstm_cell/6, input, c, Nx.tensor(0), k, h, b) ) end end @@ -3806,7 +3904,7 @@ defmodule CompilerTest do model = Axon.input("input", shape: {nil, 32, 10}) |> Axon.gru(64, name: "gru") |> Axon.container() - input = Nx.random_uniform({1, 32, 10}) + input = random({1, 32, 10}) assert {init_fn, _} = Axon.build(model) @@ -3841,7 +3939,7 @@ defmodule CompilerTest do end test "initializes with custom initializers" do - input = Nx.random_uniform({1, 32, 10}) + input = random({1, 32, 10}) model1 = Axon.input("input", shape: {nil, 32, 10}) @@ -3912,8 +4010,8 @@ defmodule CompilerTest do |> Axon.gru(2, name: "gru", recurrent_initializer: :zeros) |> Axon.container() - input = Nx.random_uniform({1, 8, 2}) - carry = {zeros({1, 1, 2})} + input = random({1, 8, 2}) + carry = {zeros({1, 2})} assert {init_fn, predict_fn} = Axon.build(model) @@ -3927,7 +4025,7 @@ defmodule CompilerTest do assert_equal( predict_fn.(params, input), - Axon.Layers.dynamic_unroll(&Axon.Layers.gru_cell/5, input, carry, k, h, b) + Axon.Layers.dynamic_unroll(&Axon.Layers.gru_cell/6, input, carry, Nx.tensor(0), k, h, b) ) end @@ -3942,13 +4040,14 @@ defmodule CompilerTest do ) |> Axon.container() - input1 = Nx.random_uniform({1, 8, 2}) - carry1 = {zeros({1, 1, 2})} + input1 = random({1, 8, 2}) + carry1 = {zeros({1, 2})} - cell_fn1 = fn i, c, k, h, b -> + cell_fn1 = fn i, c, mask, k, h, b -> Axon.Layers.gru_cell( i, c, + mask, k, h, b, @@ -3971,9 +4070,9 @@ defmodule CompilerTest do h = {whr, whz, whn} b = {br, bz, bin, bhn} - assert_equal( + assert_all_close( predict_fn.(params, input1), - Axon.Layers.dynamic_unroll(cell_fn1, input1, carry1, k, h, b) + Axon.Layers.dynamic_unroll(cell_fn1, input1, carry1, Nx.tensor(0), k, h, b) ) model2 = @@ -3981,8 +4080,8 @@ defmodule CompilerTest do |> Axon.gru(2, name: "gru", recurrent_initializer: :zeros, unroll: :static) |> Axon.container() - input2 = Nx.random_uniform({1, 8, 2}) - carry2 = {zeros({1, 1, 2})} + input2 = random({1, 8, 2}) + carry2 = {zeros({1, 
2})} assert {init_fn, predict_fn} = Axon.build(model2) @@ -3998,9 +4097,9 @@ defmodule CompilerTest do h = {whr, whz, whn} b = {br, bz, bin, bhn} - assert_equal( + assert_all_close( predict_fn.(params, input2), - Axon.Layers.static_unroll(&Axon.Layers.gru_cell/5, input2, carry2, k, h, b) + Axon.Layers.static_unroll(&Axon.Layers.gru_cell/6, input2, carry2, Nx.tensor(0), k, h, b) ) end @@ -4008,16 +4107,26 @@ defmodule CompilerTest do seq = Axon.input("input", shape: {nil, 8, 2}) {_, carry} = Axon.gru(seq, 2, name: "encode", recurrent_initializer: :zeros) model = Axon.gru(seq, carry, 2, name: "decode") |> Axon.container() - input = Nx.random_uniform({1, 8, 2}) - carry = {zeros({1, 1, 2})} + + input = random({1, 8, 2}) + carry = {zeros({1, 2})} equiv_fn = fn inp, enc, dec -> {ei, eh, eb} = enc {di, dh, db} = dec - {_, carry} = Axon.Layers.dynamic_unroll(&Axon.Layers.gru_cell/5, inp, carry, ei, eh, eb) + {_, carry} = + Axon.Layers.dynamic_unroll( + &Axon.Layers.gru_cell/6, + inp, + carry, + Nx.tensor(0), + ei, + eh, + eb + ) - Axon.Layers.dynamic_unroll(&Axon.Layers.gru_cell/5, inp, carry, di, dh, db) + Axon.Layers.dynamic_unroll(&Axon.Layers.gru_cell/6, inp, carry, Nx.tensor(0), di, dh, db) end assert {init_fn, predict_fn} = Axon.build(model) @@ -4047,7 +4156,7 @@ defmodule CompilerTest do |> Axon.gru(2, name: "gru", use_bias: false) |> Axon.container() - input = Nx.random_uniform({1, 2, 1}) + input = random({1, 2, 1}) assert {init_fn, _} = Axon.build(model) @@ -4068,7 +4177,7 @@ defmodule CompilerTest do |> Axon.gru(2, name: "gru", use_bias: false, recurrent_initializer: :zeros) |> Axon.container() - input = Nx.random_uniform({1, 2, 1}) + input = random({1, 2, 1}) assert {init_fn, predict_fn} = Axon.build(model) assert %{ @@ -4079,14 +4188,37 @@ defmodule CompilerTest do } = params = init_fn.(input, %{}) b = {Nx.tensor(0), Nx.tensor(0), Nx.tensor(0), Nx.tensor(0)} - c = {zeros({1, 1, 2})} + c = {zeros({1, 2})} assert_all_close( predict_fn.(params, input), - Axon.Layers.dynamic_unroll(&Axon.Layers.gru_cell/5, input, c, k, h, b) + Axon.Layers.dynamic_unroll(&Axon.Layers.gru_cell/6, input, c, Nx.tensor(0), k, h, b) ) end + test "mask actually works" do + sequence = Axon.input("review") + mask = Axon.mask(sequence, 0) + embedded = sequence |> Axon.embedding(2048, 64) + {rnn_sequence, _state} = Axon.gru(embedded, 64, mask: mask) + + {init_fn, predict_fn} = Axon.build(rnn_sequence) + params = init_fn.(Nx.template({64, 64}, :s64), %{}) + + input = Nx.tensor([[1, 2, 3, 4]]) + padded = Nx.pad(input, 0, [{0, 0, 0}, {0, 60, 0}]) + out = predict_fn.(params, padded) + + last_token = out[[.., 3, ..]] + + for i <- 4..63 do + # all eos tokens will be ignored so we just propagate the value + # to the next token and thus these should all be the same as the + # last non eos token + assert_equal(last_token, out[[.., i, ..]]) + end + end + # TODO(seanmor5): https://github.com/elixir-nx/axon/issues/90 # test "" # TODO(seanmor5): https://github.com/elixir-nx/axon/issues/90 @@ -4110,8 +4242,8 @@ defmodule CompilerTest do ]) input = %{ - "input_0" => Nx.random_uniform({1, 32}), - "input_1" => Nx.random_uniform({1, 32}) + "input_0" => random({1, 32}), + "input_1" => random({1, 32}) } assert {init_fn, _} = Axon.build(model) @@ -4127,8 +4259,8 @@ defmodule CompilerTest do Axon.input("input_1", shape: {nil, 32}) ]) - input1_1 = Nx.random_uniform({1, 32}) - input1_2 = Nx.random_uniform({1, 32}) + input1_1 = random({1, 32}) + input1_2 = random({1, 32}) assert {_, predict_fn} = Axon.build(model1) 
assert_all_close( @@ -4145,9 +4277,9 @@ defmodule CompilerTest do ] ]) - input2_1 = Nx.random_uniform({1, 32}) - input2_2 = Nx.random_uniform({1, 32}) - input2_3 = Nx.random_uniform({1, 32}) + input2_1 = random({1, 32}) + input2_2 = random({1, 32}) + input2_3 = random({1, 32}) assert {_, predict_fn} = Axon.build(model2) assert_all_close( @@ -4169,8 +4301,8 @@ defmodule CompilerTest do mp_model = AMP.apply_policy(model, policy) input = %{ - "input_0" => Nx.random_uniform({1, 32}), - "input_1" => Nx.random_uniform({1, 32}) + "input_0" => random({1, 32}), + "input_1" => random({1, 32}) } assert {_, predict_fn} = Axon.build(mp_model) @@ -4179,8 +4311,8 @@ defmodule CompilerTest do end test "computes forward pass with broadcasting" do - inp1 = Nx.random_uniform({1, 1}) - inp2 = Nx.random_uniform({1, 2}) + inp1 = random({1, 1}) + inp2 = random({1, 2}) for op <- @binary_layers do model = @@ -4201,8 +4333,8 @@ defmodule CompilerTest do test "raises on bad shapes" do for op <- @binary_layers do assert_raise Axon.CompileError, ~r/cannot broadcast tensor/, fn -> - inp1 = Nx.random_uniform({1, 32}) - inp2 = Nx.random_uniform({1, 64}) + inp1 = random({1, 32}) + inp2 = random({1, 64}) model = apply(Axon, op, [ @@ -4223,7 +4355,7 @@ defmodule CompilerTest do Axon.input("input_1", shape: {nil, 32}) ) - input = %{"input_0" => Nx.random_uniform({1, 32}), "input_1" => Nx.random_uniform({1, 32})} + input = %{"input_0" => random({1, 32}), "input_1" => random({1, 32})} assert {init_fn, _} = Axon.build(model) assert %{} == init_fn.(input, %{}) @@ -4236,8 +4368,8 @@ defmodule CompilerTest do Axon.input("input_1", shape: {nil, 32}) ) - input1_1 = Nx.random_uniform({1, 32}) - input1_2 = Nx.random_uniform({1, 32}) + input1_1 = random({1, 32}) + input1_2 = random({1, 32}) assert {_, predict_fn} = Axon.build(model1) @@ -4253,9 +4385,9 @@ defmodule CompilerTest do Axon.input("input_2", shape: {nil, 32}) ]) - input2_1 = Nx.random_uniform({1, 32}) - input2_2 = Nx.random_uniform({1, 32}) - input2_3 = Nx.random_uniform({1, 32}) + input2_1 = random({1, 32}) + input2_2 = random({1, 32}) + input2_3 = random({1, 32}) assert {_, predict_fn} = Axon.build(model2) @@ -4273,8 +4405,8 @@ defmodule CompilerTest do axis: 1 ) - input1_1 = Nx.random_uniform({1, 1, 32}) - input1_2 = Nx.random_uniform({1, 1, 32}) + input1_1 = random({1, 1, 32}) + input1_2 = random({1, 1, 32}) assert {_, predict_fn} = Axon.build(model1) @@ -4294,8 +4426,8 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model1, policy) - input1_1 = Nx.random_uniform({1, 1, 32}) - input1_2 = Nx.random_uniform({1, 1, 32}) + input1_1 = random({1, 1, 32}) + input1_2 = random({1, 1, 32}) assert {_, predict_fn} = Axon.build(mp_model) @@ -4307,7 +4439,7 @@ defmodule CompilerTest do describe "pad" do test "initializes with no params" do model = Axon.input("input", shape: {nil, 3, 3}) |> Axon.pad([{1, 0}]) - input = Nx.random_uniform({1, 3, 3}) + input = random({1, 3, 3}) assert {init_fn, _} = Axon.build(model) assert %{} == init_fn.(input, %{}) @@ -4315,7 +4447,7 @@ defmodule CompilerTest do test "computes forward pass with default options" do model1 = Axon.input("input", shape: {nil, 3, 3}) |> Axon.pad([{1, 0}]) - input1 = Nx.random_uniform({1, 3, 3}) + input1 = random({1, 3, 3}) assert {_, predict_fn} = Axon.build(model1) @@ -4325,7 +4457,7 @@ defmodule CompilerTest do ) model2 = Axon.input("input", shape: {nil, 3, 3, 3}) |> Axon.pad([{0, 1}, {0, 1}]) - input2 = Nx.random_uniform({1, 3, 3, 3}) + input2 = random({1, 3, 3, 3}) 
assert {_, predict_fn} = Axon.build(model2) @@ -4335,7 +4467,7 @@ defmodule CompilerTest do ) model3 = Axon.input("input", shape: {nil, 3, 3, 3, 3}) |> Axon.pad([{0, 1}, {0, 1}, {1, 0}]) - input3 = Nx.random_uniform({1, 3, 3, 3, 3}) + input3 = random({1, 3, 3, 3, 3}) assert {_, predict_fn} = Axon.build(model3) @@ -4347,7 +4479,7 @@ defmodule CompilerTest do test "computes forward pass with custom options" do model = Axon.input("input", shape: {nil, 3, 3}) |> Axon.pad([{1, 0}], 2) - input = Nx.random_uniform({1, 3, 3}) + input = random({1, 3, 3}) assert {_, predict_fn} = Axon.build(model) @@ -4361,7 +4493,7 @@ defmodule CompilerTest do model = Axon.input("input", shape: {nil, 3, 3}) |> Axon.pad([{1, 0}]) policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 3, 3}) + input = random({1, 3, 3}) assert {_, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(%{}, input)) == {:bf, 16} @@ -4371,7 +4503,7 @@ defmodule CompilerTest do describe "nx" do test "computes special nx functions" do model = Axon.input("input", shape: {nil, 10}) |> Axon.nx(&Nx.sin/1) - input = Nx.random_uniform({1, 10}) + input = random({1, 10}) assert {_, predict_fn} = Axon.build(model) assert_all_close(predict_fn.(%{}, input), Nx.sin(input)) @@ -4417,7 +4549,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model1, policy) - input1_1 = Nx.random_uniform({1, 1, 32}) + input1_1 = random({1, 1, 32}) assert {_, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(%{}, input1_1)) == {:bf, 16} @@ -4433,7 +4565,7 @@ defmodule CompilerTest do assert_raise Axon.CompileError, ~r/cond_fn must return a scalar/, fn -> {_, predict_fn} = Axon.build(model) - predict_fn.(%{}, Nx.random_uniform({1, 1, 10})) + predict_fn.(%{}, random({1, 1, 10})) end end end @@ -4442,7 +4574,7 @@ defmodule CompilerTest do test "initializes with no parameters" do model = Axon.input("input", shape: {nil, 10}) |> Axon.split(5) |> Axon.container() - input = Nx.random_uniform({1, 10}) + input = random({1, 10}) assert {init_fn, _} = Axon.build(model) assert init_fn.(input, %{}) == %{} @@ -4521,7 +4653,7 @@ defmodule CompilerTest do {init_fn, predict_fn} = Axon.build(model) - inp = Nx.random_uniform({1, 1}) + inp = random({1, 1}) params = init_fn.(inp, %{}) axon_loss = fn inp, params -> Nx.sum(predict_fn.(params, inp)) end @@ -4565,7 +4697,7 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - inp = Nx.random_uniform({1, 2}) + inp = random({1, 2}) assert %{"dense_0" => dense_0_params, "dense_1" => dense_1_params} = init_fn.(inp, %{}) @@ -4632,7 +4764,7 @@ defmodule CompilerTest do describe "custom layers" do test "initializes with no parameters" do model = Axon.layer(fn x, _opts -> x end, [Axon.input("input_0", shape: {nil, 1})]) - inp = Nx.random_uniform({1, 1}) + inp = random({1, 1}) {init_fn, _} = Axon.build(model) assert Enum.empty?(init_fn.(inp, %{})) @@ -4650,7 +4782,7 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - inp = Nx.random_uniform({1, 1}) + inp = random({1, 1}) assert %{"layer_0" => %{"kernel" => kernel}} = init_fn.(inp, %{}) @@ -4670,7 +4802,7 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - inp = Nx.random_uniform({1, 1}) + inp = random({1, 1}) assert %{"custom_0" => %{"kernel" => _}, "custom_1" => %{"kernel" => _}} = init_fn.(inp, %{}) @@ -4687,7 +4819,7 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - input = Nx.random_uniform({1, 1}) + input = 
random({1, 1}) assert %{"layer_0" => %{"kernel" => kernel}} = params = init_fn.(input, %{}) @@ -4695,19 +4827,17 @@ defmodule CompilerTest do end defn layer_with_options(x, kernel, opts \\ []) do - transform({x, kernel, opts}, fn {x, kernel, opts} -> - if opts[:add] do - Nx.add(x, kernel) - else - Nx.multiply(x, kernel) - end - end) + if opts[:add] do + Nx.add(x, kernel) + else + Nx.multiply(x, kernel) + end end test "computes forward pass with options" do kernel_param = Axon.param("kernel", fn shape -> shape end) - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) model1 = Axon.layer(&layer_with_options/3, [Axon.input("input_0", shape: {nil, 1}), kernel_param], @@ -4741,7 +4871,7 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - inp = Nx.random_uniform({1, 1}) + inp = random({1, 1}) assert %{"model" => %{"dense_0" => %{"kernel" => k, "bias" => b}}} = init_fn.(inp, %{}) @@ -4760,7 +4890,7 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - inp = Nx.random_uniform({1, 1}) + inp = random({1, 1}) assert %{"nested" => %{"model" => %{"dense_0" => %{"kernel" => k, "bias" => b}}}} = init_fn.(inp, %{}) @@ -4771,12 +4901,12 @@ defmodule CompilerTest do assert Nx.type(b) == {:f, 32} end - test "initializes correclty with single namespace no params" do + test "initializes correctly with single namespace no params" do model = Axon.input("input_0", shape: {nil, 1}) |> Axon.namespace("model") {init_fn, _} = Axon.build(model) - inp = Nx.random_uniform({1, 1}) + inp = random({1, 1}) assert Enum.empty?(init_fn.(inp, %{})) end @@ -4789,7 +4919,7 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - inp = Nx.random_uniform({1, 1}) + inp = random({1, 1}) assert Enum.empty?(init_fn.(inp, %{})) end @@ -4798,7 +4928,7 @@ defmodule CompilerTest do x = Axon.input("input_0", shape: {nil, 1}) |> Axon.dense(2) |> Axon.namespace("x") y = Axon.input("input_1", shape: {nil, 1}) |> Axon.dense(2) |> Axon.namespace("y") - inp = %{"input_0" => Nx.random_uniform({1, 1}), "input_1" => Nx.random_uniform({1, 1})} + inp = %{"input_0" => random({1, 1}), "input_1" => random({1, 1})} model = Axon.add(x, y) @@ -4828,7 +4958,7 @@ defmodule CompilerTest do |> Axon.namespace("y") |> Axon.namespace("z") - inp = %{"input_0" => Nx.random_uniform({1, 1}), "input_1" => Nx.random_uniform({1, 1})} + inp = %{"input_0" => random({1, 1}), "input_1" => random({1, 1})} model = Axon.add(x, z) @@ -4853,7 +4983,7 @@ defmodule CompilerTest do x = Axon.input("input_0", shape: {nil, 1}) |> Axon.dense(2) |> Axon.namespace("x") y = Axon.input("input_1", shape: {nil, 1}) |> Axon.dense(2) - inp = %{"input_0" => Nx.random_uniform({1, 1}), "input_1" => Nx.random_uniform({1, 1})} + inp = %{"input_0" => random({1, 1}), "input_1" => random({1, 1})} model = Axon.add(x, y) @@ -4877,7 +5007,7 @@ defmodule CompilerTest do test "initializes correctly reusing namespace" do x = Axon.input("input_0", shape: {nil, 1}) |> Axon.dense(2) |> Axon.namespace("x") - inp = Nx.random_uniform({1, 1}) + inp = random({1, 1}) model = Axon.add(x, x) {init_fn, _} = Axon.build(model) @@ -4920,7 +5050,7 @@ defmodule CompilerTest do test "predicts correctly with single namespace" do model = Axon.input("input_0", shape: {nil, 1}) |> Axon.dense(2) |> Axon.namespace("model") - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) {init_fn, _} = Axon.build(model) @@ -4932,7 +5062,7 @@ defmodule CompilerTest do test "predicts correctly with single namespace no parameters" do model = Axon.input("input_0", shape: {nil, 1}) |> 
Axon.namespace("model") - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert_equal(Axon.predict(model, %{}, input), input) end @@ -4946,7 +5076,7 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert %{"nested" => %{"model" => %{"dense_0" => %{"kernel" => k, "bias" => b}}}} = params = init_fn.(input, %{}) @@ -4960,7 +5090,7 @@ defmodule CompilerTest do |> Axon.namespace("model") |> Axon.namespace("nested") - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert_equal(Axon.predict(model, %{}, input), input) end @@ -4973,8 +5103,8 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - input_0 = Nx.random_uniform({1, 1}) - input_1 = Nx.random_uniform({1, 1}) + input_0 = random({1, 1}) + input_1 = random({1, 1}) inputs = %{"input_0" => input_0, "input_1" => input_1} assert %{ @@ -4999,8 +5129,8 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - input_0 = Nx.random_uniform({1, 1}) - input_1 = Nx.random_uniform({1, 1}) + input_0 = random({1, 1}) + input_1 = random({1, 1}) inputs = %{"input_0" => input_0, "input_1" => input_1} assert %{ @@ -5020,8 +5150,8 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - input_0 = Nx.random_uniform({1, 1}) - input_1 = Nx.random_uniform({1, 1}) + input_0 = random({1, 1}) + input_1 = random({1, 1}) inputs = %{"input_0" => input_0, "input_1" => input_1} assert %{ @@ -5038,7 +5168,7 @@ defmodule CompilerTest do model = Axon.add(x, x) {init_fn, _} = Axon.build(model) - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert %{"x" => %{"dense_0" => %{"kernel" => k, "bias" => b}}} = params = init_fn.(input, %{}) @@ -5054,7 +5184,7 @@ defmodule CompilerTest do # model = Axon.add(inner, x) - # input = Nx.random_uniform({1, 1}) + # input = random({1, 1}) # assert %{"x" => %{"dense_0" => %{"kernel" => k, "bias" => b}}} = params = Axon.init(model) # expected = Nx.add(Axon.Layers.dense(input, k, b), Axon.Layers.dense(input, k, b)) @@ -5062,6 +5192,353 @@ defmodule CompilerTest do # end end + describe "block" do + test "initializes correctly with single dense layer, used once" do + block = Axon.block(&Axon.dense(&1, 32)) + model = block.(Axon.input("features")) + + {init_fn, _} = Axon.build(model) + + assert %{"block_0" => %{"dense_0" => %{"kernel" => k, "bias" => b}}} = + init_fn.(Nx.template({1, 1}, :f32), %{}) + + assert Nx.shape(k) == {1, 32} + assert Nx.shape(b) == {32} + assert Nx.type(k) == {:f, 32} + assert Nx.type(b) == {:f, 32} + end + + test "initializes correctly with single dense layer, used twice" do + block = Axon.block(&Axon.dense(&1, 1)) + + model = + Axon.input("features") + |> block.() + |> block.() + + {init_fn, _} = Axon.build(model) + + assert %{"block_0" => %{"dense_0" => %{"kernel" => k, "bias" => b}} = block_params} = + params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + assert Nx.shape(k) == {1, 1} + assert Nx.shape(b) == {1} + assert Nx.type(k) == {:f, 32} + assert Nx.type(b) == {:f, 32} + + # no additional dense layers in block + assert map_size(block_params) == 1 + # no additional blocks + assert map_size(params) == 1 + end + + test "initializes correctly with multiple dense layer, used once" do + block = + Axon.block(fn x -> + x + |> Axon.dense(32, activation: :relu) + |> Axon.dense(32, activation: :relu) + end) + + model = block.(Axon.input("features")) + {init_fn, _} = Axon.build(model) + + assert %{ + "block_0" => + %{ + "dense_0" => %{"kernel" => k1, "bias" => b1}, + "dense_1" => %{"kernel" => 
k2, "bias" => b2} + } = block_params + } = params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + assert Nx.shape(k1) == {1, 32} + assert Nx.shape(b1) == {32} + assert Nx.shape(k2) == {32, 32} + assert Nx.shape(b2) == {32} + assert Nx.type(k1) == {:f, 32} + assert Nx.type(b1) == {:f, 32} + assert Nx.type(k2) == {:f, 32} + assert Nx.type(b2) == {:f, 32} + + # no additional dense layers in block + assert map_size(block_params) == 2 + # no additional blocks + assert map_size(params) == 1 + end + + test "initializes correctly with multiple dense layer, used multiple times" do + block = + Axon.block(fn x -> + x + |> Axon.dense(32, activation: :relu) + |> Axon.dense(1, activation: :relu) + end) + + model = Enum.reduce(0..9, Axon.input("features"), fn _, x -> block.(x) end) + + {init_fn, _} = Axon.build(model) + + assert %{ + "block_0" => + %{ + "dense_0" => %{"kernel" => k1, "bias" => b1}, + "dense_1" => %{"kernel" => k2, "bias" => b2} + } = block_params + } = params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + assert Nx.shape(k1) == {1, 32} + assert Nx.shape(b1) == {32} + assert Nx.shape(k2) == {32, 1} + assert Nx.shape(b2) == {1} + assert Nx.type(k1) == {:f, 32} + assert Nx.type(b1) == {:f, 32} + assert Nx.type(k2) == {:f, 32} + assert Nx.type(b2) == {:f, 32} + + # no additional dense layers in block + assert map_size(block_params) == 2 + # no additional blocks + assert map_size(params) == 1 + end + + test "initializes correctly with multiple blocks in network" do + block1 = Axon.block(&Axon.dense(&1, 32)) + block2 = Axon.block(&Axon.dense(&1, 32)) + + model = + Axon.input("features") + |> block1.() + |> block2.() + + {init_fn, _} = Axon.build(model) + + assert %{ + "block_0" => + %{ + "dense_0" => %{"kernel" => k1, "bias" => b1} + } = block_0_params, + "block_1" => + %{ + "dense_0" => %{"kernel" => k2, "bias" => b2} + } = block_1_params + } = params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + assert Nx.shape(k1) == {1, 32} + assert Nx.shape(b1) == {32} + assert Nx.shape(k2) == {32, 32} + assert Nx.shape(b2) == {32} + assert Nx.type(k1) == {:f, 32} + assert Nx.type(b1) == {:f, 32} + assert Nx.type(k2) == {:f, 32} + assert Nx.type(b2) == {:f, 32} + + # no additional dense layers in block + assert map_size(block_0_params) == 1 + assert map_size(block_1_params) == 1 + # no additional blocks + assert map_size(params) == 2 + end + + test "initializes correctly with block inside of a block" do + block = + Axon.block(fn x -> + inner_block = Axon.block(&Axon.dense(&1, 1)) + + x |> inner_block.() |> inner_block.() + end) + + model = + Axon.input("features") + |> block.() + |> block.() + + {init_fn, _} = Axon.build(model) + + assert %{ + "block_0" => + %{ + "block_0" => %{"dense_0" => %{"kernel" => k, "bias" => b}} = inner_block_params + } = block_params + } = params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + assert Nx.shape(k) == {1, 1} + assert Nx.shape(b) == {1} + assert Nx.type(k) == {:f, 32} + assert Nx.type(b) == {:f, 32} + + assert map_size(inner_block_params) == 1 + assert map_size(block_params) == 1 + assert map_size(params) == 1 + end + + test "predicts correctly with single dense, used once" do + block = Axon.block(&Axon.dense(&1, 32)) + model = block.(Axon.input("features")) + + {init_fn, predict_fn} = Axon.build(model) + + assert %{"block_0" => %{"dense_0" => %{"kernel" => k, "bias" => b}}} = + params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + input = random({1, 1}) + + assert predict_fn.(params, input) == Axon.Layers.dense(input, k, b) + end + + test "predicts correctly with 
single dense, used twice" do + block = Axon.block(&Axon.dense(&1, 1)) + + model = + Axon.input("features") + |> block.() + |> block.() + + {init_fn, predict_fn} = Axon.build(model) + + assert %{"block_0" => %{"dense_0" => %{"kernel" => k, "bias" => b}}} = + params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + input = random({1, 1}) + + assert predict_fn.(params, input) == + input |> Axon.Layers.dense(k, b) |> Axon.Layers.dense(k, b) + end + + test "predicts correctly with multiple dense, used once" do + block = + Axon.block(fn x -> + x + |> Axon.dense(32, activation: :relu) + |> Axon.dense(1, activation: :relu) + end) + + model = block.(Axon.input("features")) + {init_fn, predict_fn} = Axon.build(model) + + assert %{ + "block_0" => %{ + "dense_0" => %{"kernel" => k1, "bias" => b1}, + "dense_1" => %{"kernel" => k2, "bias" => b2} + } + } = params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + expected_predict_fn = fn x, k1, b1, k2, b2 -> + x + |> Axon.Layers.dense(k1, b1) + |> Axon.Activations.relu() + |> Axon.Layers.dense(k2, b2) + |> Axon.Layers.relu() + end + + input = random({1, 1}) + + assert predict_fn.(params, input) == expected_predict_fn.(input, k1, b1, k2, b2) + end + + test "predicts correctly with multiple dense, used twice" do + block = + Axon.block(fn x -> + x + |> Axon.dense(32, activation: :relu) + |> Axon.dense(1, activation: :relu) + end) + + model = + Axon.input("features") + |> block.() + |> block.() + + {init_fn, predict_fn} = Axon.build(model) + + assert %{ + "block_0" => %{ + "dense_0" => %{"kernel" => k1, "bias" => b1}, + "dense_1" => %{"kernel" => k2, "bias" => b2} + } + } = params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + expected_predict_fn = fn x, k1, b1, k2, b2 -> + x + |> Axon.Layers.dense(k1, b1) + |> Axon.Activations.relu() + |> Axon.Layers.dense(k2, b2) + |> Axon.Layers.relu() + |> Axon.Layers.dense(k1, b1) + |> Axon.Activations.relu() + |> Axon.Layers.dense(k2, b2) + |> Axon.Layers.relu() + end + + input = random({1, 1}) + + assert predict_fn.(params, input) == expected_predict_fn.(input, k1, b1, k2, b2) + end + + test "predicts correctly with multiple blocks in network" do + block1 = Axon.block(&Axon.dense(&1, 32)) + block2 = Axon.block(&Axon.dense(&1, 32)) + + model = + Axon.input("features") + |> block1.() + |> block2.() + + {init_fn, predict_fn} = Axon.build(model) + + actual_predict_fn = fn x, k1, b1, k2, b2 -> + x + |> Axon.Layers.dense(k1, b1) + |> Axon.Layers.dense(k2, b2) + end + + assert %{ + "block_0" => %{ + "dense_0" => %{"kernel" => k1, "bias" => b1} + }, + "block_1" => %{ + "dense_0" => %{"kernel" => k2, "bias" => b2} + } + } = params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + input = random({1, 1}) + + assert predict_fn.(params, input) == actual_predict_fn.(input, k1, b1, k2, b2) + end + + test "predicts correctly with block inside of a block" do + block = + Axon.block(fn x -> + inner_block = Axon.block(&Axon.dense(&1, 1)) + + x |> inner_block.() |> inner_block.() + end) + + model = + Axon.input("features") + |> block.() + |> block.() + + {init_fn, predict_fn} = Axon.build(model) + + actual_predict_fn = fn x, k, b -> + x + |> Axon.Layers.dense(k, b) + |> Axon.Layers.dense(k, b) + |> Axon.Layers.dense(k, b) + |> Axon.Layers.dense(k, b) + end + + assert %{ + "block_0" => %{ + "block_0" => %{"dense_0" => %{"kernel" => k, "bias" => b}} + } + } = params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + input = random({1, 1}) + assert predict_fn.(params, input) == actual_predict_fn.(input, k, b) + end + end + describe "initializers" do test 
"work with functions" do model = @@ -5141,7 +5618,7 @@ defmodule CompilerTest do # {init_fn, _} = Axon.build(model) - # inp = Nx.random_uniform({1, 1}) + # inp = random({1, 1}) # assert_raise ArgumentError, # ~s{found unexpected key in the initial parameters map: "dense_2"}, @@ -5153,8 +5630,8 @@ defmodule CompilerTest do describe "containers" do test "allows accessors with custom layers" do - input1 = Nx.random_uniform({1, 1}) - input2 = Nx.random_uniform({1, 2}) + input1 = random({1, 1}) + input2 = random({1, 2}) inputs = %{"input_0" => input1, "input_1" => input2} inp1 = Axon.input("input_0", shape: {nil, 1}) @@ -5281,4 +5758,19 @@ defmodule CompilerTest do assert predict_fn1 == predict_fn2 end end + + describe "metadata" do + test "axon compiler attaches layer name as metadata to subgraphs" do + model = Axon.input("input", shape: {nil, 784}) |> Axon.dense(128) + + {init_fn, predict_fn} = Axon.build(model) + params = init_fn.(Nx.template({1, 784}, :f32), %{}) + input = Nx.broadcast(0.0, {1, 784}) + + expr_fn = Nx.Defn.jit(predict_fn, compiler: Axon.Defn) + expr = expr_fn.(params, input) + + assert %{data: %{op: :metadata, args: [_tensor, %{axon_layer: :dense}]}} = expr + end + end end diff --git a/test/axon/integration_test.exs b/test/axon/integration_test.exs index 19d35440..8c95ed56 100644 --- a/test/axon/integration_test.exs +++ b/test/axon/integration_test.exs @@ -26,7 +26,58 @@ defmodule Axon.IntegrationTest do ExUnit.CaptureIO.capture_io(fn -> results = model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(5.0e-3)) + |> Axon.Loop.trainer( + :categorical_cross_entropy, + Polaris.Optimizers.adam(learning_rate: 5.0e-3) + ) + # TODO: Fix default output transform + |> Map.update(:output_transform, nil, fn _ -> & &1 end) + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.validate(model, train) + |> Axon.Loop.run(train, %{}, epochs: 10) + + assert %{step_state: %{model_state: model_state}, metrics: %{9 => last_epoch_metrics}} = + results + + eval_results = + model + |> Axon.Loop.evaluator() + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.run(train, model_state) + + assert %{0 => %{"accuracy" => final_model_val_accuracy}} = eval_results + + assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.7) + assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"]) + assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2} + end) + end + + test "f64 input test" do + {train, _test} = get_test_data(100, 0, 10, {10}, 2, 1337) + + train = + train + |> Stream.map(fn {xs, ys} -> + {Nx.as_type(xs, :f64), one_hot(ys, num_classes: 2)} + end) + |> Enum.to_list() + + [{x_test, _}] = Enum.take(train, 1) + + model = + Axon.input("input") + |> Axon.dense(16) + |> Axon.dropout(rate: 0.1) + |> Axon.dense(2, activation: :softmax) + + ExUnit.CaptureIO.capture_io(fn -> + results = + model + |> Axon.Loop.trainer( + :categorical_cross_entropy, + Polaris.Optimizers.adam(learning_rate: 5.0e-3) + ) # TODO: Fix default output transform |> Map.update(:output_transform, nil, fn _ -> & &1 end) |> Axon.Loop.metric(:accuracy) @@ -74,7 +125,10 @@ defmodule Axon.IntegrationTest do ExUnit.CaptureIO.capture_io(fn -> results = model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(5.0e-3)) + |> Axon.Loop.trainer( + :categorical_cross_entropy, + Polaris.Optimizers.adam(learning_rate: 5.0e-3) + ) # TODO: Fix default output transform |> Map.update(:output_transform, nil, fn _ -> & &1 end) |> Axon.Loop.metric(:accuracy) @@ -121,7 +175,59 @@ 
defmodule Axon.IntegrationTest do ExUnit.CaptureIO.capture_io(fn -> results = model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(5.0e-3)) + |> Axon.Loop.trainer( + :categorical_cross_entropy, + Polaris.Optimizers.adam(learning_rate: 5.0e-3) + ) + # TODO: Fix default output transform + |> Map.update(:output_transform, nil, fn _ -> & &1 end) + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.validate(model, train) + |> Axon.Loop.run(train, %{}, epochs: 10) + + assert %{step_state: %{model_state: model_state}, metrics: %{9 => last_epoch_metrics}} = + results + + eval_results = + model + |> Axon.Loop.evaluator() + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.run(train, model_state) + + assert %{0 => %{"accuracy" => final_model_val_accuracy}} = eval_results + + assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.7) + assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"]) + assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2} + end) + end + + test "gradient accumulation test" do + {train, _test} = get_test_data(100, 0, 10, {10}, 2, 1337) + + train = + train + |> Stream.map(fn {xs, ys} -> + {xs, one_hot(ys, num_classes: 2)} + end) + |> Enum.to_list() + + [{x_test, _}] = Enum.take(train, 1) + + model = + Axon.input("input") + |> Axon.dense(16) + |> Axon.dropout(rate: 0.1) + |> Axon.dense(2, activation: :softmax) + + ExUnit.CaptureIO.capture_io(fn -> + results = + model + |> Axon.Loop.trainer( + :categorical_cross_entropy, + Polaris.Optimizers.adam(learning_rate: 5.0e-3), + gradient_accumulation_steps: 3 + ) # TODO: Fix default output transform |> Map.update(:output_transform, nil, fn _ -> & &1 end) |> Axon.Loop.metric(:accuracy) @@ -164,7 +270,11 @@ defmodule Axon.IntegrationTest do ExUnit.CaptureIO.capture_io(fn -> %{metrics: metrics1, step_state: step_state1} = model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(5.0e-3), seed: 1) + |> Axon.Loop.trainer( + :categorical_cross_entropy, + Polaris.Optimizers.adam(learning_rate: 5.0e-3), + seed: 1 + ) # TODO: Fix default output transform |> Map.update(:output_transform, nil, fn _ -> & &1 end) |> Axon.Loop.metric(:accuracy) @@ -173,7 +283,11 @@ defmodule Axon.IntegrationTest do %{metrics: metrics2, step_state: step_state2} = model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(5.0e-3), seed: 1) + |> Axon.Loop.trainer( + :categorical_cross_entropy, + Polaris.Optimizers.adam(learning_rate: 5.0e-3), + seed: 1 + ) # TODO: Fix default output transform |> Map.update(:output_transform, nil, fn _ -> & &1 end) |> Axon.Loop.metric(:accuracy) @@ -184,4 +298,250 @@ defmodule Axon.IntegrationTest do assert_equal(step_state1, step_state2) end) end + + describe "optimizer integration" do + @optimizers_and_args [ + {:adabelief, [[learning_rate: 5.0e-3]]}, + {:adagrad, [[learning_rate: 5.0e-3]]}, + {:adam, [[learning_rate: 5.0e-3]]}, + {:adamw, [[learning_rate: 5.0e-3]]}, + {:adamw, [[learning_rate: 5.0e-3, decay: 0.9]]}, + {:lamb, [[learning_rate: 5.0e-3]]}, + {:lamb, [[learning_rate: 5.0e-3, decay: 0.9]]}, + {:lamb, [[learning_rate: 5.0e-3, min_norm: 0.1]]}, + {:lamb, [[learning_rate: 5.0e-3, decay: 0.9, min_norm: 0.1]]}, + {:noisy_sgd, [[learning_rate: 5.0e-3]]}, + {:radam, [[learning_rate: 5.0e-3]]}, + {:rmsprop, [[learning_rate: 5.0e-3]]}, + {:rmsprop, [[learning_rate: 5.0e-3, centered: true]]}, + {:rmsprop, [[learning_rate: 5.0e-3, momentum: 0.9]]}, + {:rmsprop, [[learning_rate: 5.0e-3, nesterov: true, momentum: 0.9]]}, + 
{:rmsprop, [[learning_rate: 5.0e-3, centered: true, nesterov: true, momentum: 0.9]]}, + {:sgd, [[learning_rate: 5.0e-3]]}, + {:sgd, [[learning_rate: 5.0e-3, momentum: 0.9]]}, + {:sgd, [[learning_rate: 5.0e-3, momentum: 0.9, nesterov: true]]} + ] + + for {optimizer, [opts] = args} <- @optimizers_and_args do + lr = opts[:learning_rate] + + test "#{optimizer}, learning_rate: #{lr}, opts: #{inspect(opts)} trains simple model with dropout" do + {train, _test} = get_test_data(100, 0, 10, {10}, 2, 1337) + + train = + train + |> Stream.map(fn {xs, ys} -> + {xs, one_hot(ys, num_classes: 2)} + end) + |> Enum.to_list() + + [{x_test, _}] = Enum.take(train, 1) + + model = + Axon.input("input") + |> Axon.dense(16) + |> Axon.dropout(rate: 0.1) + |> Axon.dense(2, activation: :softmax) + + ExUnit.CaptureIO.capture_io(fn -> + results = + model + |> Axon.Loop.trainer( + :categorical_cross_entropy, + Polaris.Optimizers.unquote(optimizer)(unquote_splicing(args)) + ) + # TODO: Fix default output transform + |> Map.update(:output_transform, nil, fn _ -> & &1 end) + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.validate(model, train) + |> Axon.Loop.run(train, %{}, epochs: 10) + + assert %{step_state: %{model_state: model_state}, metrics: %{9 => last_epoch_metrics}} = + results + + eval_results = + model + |> Axon.Loop.evaluator() + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.run(train, model_state) + + assert %{0 => %{"accuracy" => final_model_val_accuracy}} = eval_results + + assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.7) + assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"]) + assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2} + end) + end + end + end + + describe "mixed precision training integration" do + @policies [ + {"compute f16", Axon.MixedPrecision.create_policy(compute: {:f, 16})}, + {"compute f16, params f16", + Axon.MixedPrecision.create_policy(compute: {:f, 16}, params: {:f, 16})}, + {"compute f16, params f16, output f16", + Axon.MixedPrecision.create_policy(params: {:f, 16}, compute: {:f, 16}, output: {:f, 16})} + ] + + @scales [:identity, :dynamic, :static] + + for {name, policy} <- @policies, scale <- @scales do + test "trains simple model with policy #{name}, scale #{inspect(scale)}" do + {train, _test} = get_test_data(100, 0, 10, {10}, 2, 1337) + + train = + train + |> Stream.map(fn {xs, ys} -> + {xs, one_hot(ys, num_classes: 2)} + end) + |> Enum.to_list() + + [{x_test, _}] = Enum.take(train, 1) + + model = + Axon.input("input") + |> Axon.dense(16) + |> Axon.dropout(rate: 0.1) + |> Axon.dense(2, activation: :softmax) + |> Axon.MixedPrecision.apply_policy(unquote(Macro.escape(policy))) + + ExUnit.CaptureIO.capture_io(fn -> + results = + model + |> Axon.Loop.trainer(:categorical_cross_entropy, :adam, loss_scale: unquote(scale)) + # TODO: Fix default output transform + |> Map.update(:output_transform, nil, fn _ -> & &1 end) + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.validate(model, train) + |> Axon.Loop.run(train, %{}, epochs: 10) + + assert %{step_state: %{model_state: model_state}, metrics: %{9 => last_epoch_metrics}} = + results + + eval_results = + model + |> Axon.Loop.evaluator() + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.run(train, model_state) + + assert %{0 => %{"accuracy" => final_model_val_accuracy}} = eval_results + + assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.60) + assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"]) + assert 
Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2} + assert Nx.type(model_state["dense_0"]["kernel"]) == unquote(Macro.escape(policy)).params + end) + end + + test "trains model with batch norm with policy #{name}, scale #{inspect(scale)}" do + {train, _test} = get_test_data(100, 0, 10, {10}, 2, 1337) + + train = + train + |> Stream.map(fn {xs, ys} -> + {xs, one_hot(ys, num_classes: 2)} + end) + |> Enum.to_list() + + [{x_test, _}] = Enum.take(train, 1) + + model = + Axon.input("input") + |> Axon.dense(16) + |> Axon.batch_norm() + |> Axon.dropout(rate: 0.1) + |> Axon.dense(2, activation: :softmax) + |> Axon.MixedPrecision.apply_policy( + unquote(Macro.escape(policy)), + except: [:batch_norm] + ) + + ExUnit.CaptureIO.capture_io(fn -> + results = + model + |> Axon.Loop.trainer(:categorical_cross_entropy, :adam, loss_scale: unquote(scale)) + # TODO: Fix default output transform + |> Map.update(:output_transform, nil, fn _ -> & &1 end) + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.validate(model, train) + |> Axon.Loop.run(train, %{}, epochs: 10) + + assert %{step_state: %{model_state: model_state}, metrics: %{9 => last_epoch_metrics}} = + results + + eval_results = + model + |> Axon.Loop.evaluator() + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.run(train, model_state) + + assert %{0 => %{"accuracy" => final_model_val_accuracy}} = eval_results + + assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.60) + assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"]) + assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2} + assert Nx.type(model_state["dense_0"]["kernel"]) == unquote(Macro.escape(policy)).params + end) + end + end + + test "mixed precision downcasts model when state is given to train" do + policy = + Axon.MixedPrecision.create_policy( + params: {:f, 16}, + compute: {:f, 16}, + output: {:f, 16} + ) + + {train, _test} = get_test_data(100, 0, 10, {10}, 2, 1337) + + train = + train + |> Stream.map(fn {xs, ys} -> + {xs, one_hot(ys, num_classes: 2)} + end) + |> Enum.to_list() + + [{x_test, _}] = Enum.take(train, 1) + + model = + Axon.input("input") + |> Axon.dense(16) + |> Axon.dropout(rate: 0.1) + |> Axon.dense(2, activation: :softmax) + + {init_fn, _} = Axon.build(model) + initial_state = init_fn.(Nx.template({1, 10}, :f32), %{}) + + mp_model = Axon.MixedPrecision.apply_policy(model, policy) + + ExUnit.CaptureIO.capture_io(fn -> + results = + mp_model + |> Axon.Loop.trainer(:categorical_cross_entropy, :adam, loss_scale: :dynamic) + # TODO: Fix default output transform + |> Map.update(:output_transform, nil, fn _ -> & &1 end) + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.validate(model, train) + |> Axon.Loop.run(train, initial_state, epochs: 10) + + assert %{step_state: %{model_state: model_state}, metrics: %{9 => last_epoch_metrics}} = + results + + eval_results = + model + |> Axon.Loop.evaluator() + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.run(train, model_state) + + assert %{0 => %{"accuracy" => final_model_val_accuracy}} = eval_results + + assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.60) + assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"]) + assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2} + assert Nx.type(model_state["dense_0"]["kernel"]) == policy.params + end) + end + end end diff --git a/test/axon/layers_test.exs b/test/axon/layers_test.exs index cc1fb9fc..37d930b9 100644 --- a/test/axon/layers_test.exs +++ 
b/test/axon/layers_test.exs @@ -186,9 +186,9 @@ defmodule Axon.LayersTest do describe "conv" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) - kernel = Nx.random_uniform({3, 1, 4, 4}) + kernel = random({3, 1, 4, 4}) t_kernel = Nx.transpose(kernel, axes: [2, 3, 1, 0]) bias = Nx.tensor(0.0) @@ -198,6 +198,19 @@ defmodule Axon.LayersTest do assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) end + test "channels last same as channels first with strides" do + input = random({1, 1, 28, 28}) + t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) + kernel = random({3, 1, 4, 4}) + t_kernel = Nx.transpose(kernel, axes: [2, 3, 1, 0]) + bias = Nx.tensor(0.0) + + first = Axon.Layers.conv(input, kernel, bias, channels: :first, strides: [1, 2]) + last = Axon.Layers.conv(t_input, t_kernel, bias, channels: :last, strides: [1, 2]) + + assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) + end + test "raises on input rank less than 3" do inp = Nx.iota({1, 1}) kernel = Nx.iota({2, 1, 1}) @@ -225,9 +238,9 @@ defmodule Axon.LayersTest do describe "conv_transpose" do test "channels first same as channels last" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) - kernel = Nx.random_uniform({3, 1, 4, 4}) + kernel = random({3, 1, 4, 4}) t_kernel = Nx.transpose(kernel, axes: [2, 3, 1, 0]) bias = Nx.tensor(0.0) @@ -237,6 +250,19 @@ defmodule Axon.LayersTest do assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) end + test "channels first same as channels last with strides" do + input = random({1, 1, 28, 28}) + t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) + kernel = random({3, 1, 4, 4}) + t_kernel = Nx.transpose(kernel, axes: [2, 3, 1, 0]) + bias = Nx.tensor(0.0) + + first = Axon.Layers.conv_transpose(input, kernel, bias, channels: :first, strides: [1, 2]) + last = Axon.Layers.conv_transpose(t_input, t_kernel, bias, channels: :last, strides: [1, 2]) + + assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) + end + test "correct valid padding, no strides" do inp = Nx.iota({1, 1, 4}, type: {:f, 32}) kernel = Nx.iota({3, 1, 2}, type: {:f, 32}) @@ -422,9 +448,9 @@ defmodule Axon.LayersTest do describe "depthwise conv" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 3, 28, 28}) + input = random({1, 3, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) - kernel = Nx.random_uniform({6, 1, 4, 4}) + kernel = random({6, 1, 4, 4}) t_kernel = Nx.transpose(kernel, axes: [2, 3, 1, 0]) bias = Nx.tensor(0.0) @@ -461,11 +487,11 @@ defmodule Axon.LayersTest do describe "separable_conv2d" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 3, 28, 28}) + input = random({1, 3, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) - k1 = Nx.random_uniform({6, 1, 4, 1}) + k1 = random({6, 1, 4, 1}) t_k1 = Nx.transpose(k1, axes: [2, 3, 1, 0]) - k2 = Nx.random_uniform({6, 1, 1, 4}) + k2 = random({6, 1, 1, 4}) t_k2 = Nx.transpose(k2, axes: [2, 3, 1, 0]) b1 = Nx.tensor(0.0) b2 = Nx.tensor(0.0) @@ -504,13 +530,13 @@ defmodule Axon.LayersTest do describe "separable_conv3d" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 3, 8, 8, 8}) + input = random({1, 3, 8, 8, 8}) t_input = Nx.transpose(input, axes: [0, 2, 3, 4, 1]) - k1 = Nx.random_uniform({6, 1, 4, 1, 1}) + k1 = random({6, 1, 4, 1, 1}) t_k1 = 
Nx.transpose(k1, axes: [2, 3, 4, 1, 0]) - k2 = Nx.random_uniform({6, 1, 1, 4, 1}) + k2 = random({6, 1, 1, 4, 1}) t_k2 = Nx.transpose(k2, axes: [2, 3, 4, 1, 0]) - k3 = Nx.random_uniform({6, 1, 1, 1, 4}) + k3 = random({6, 1, 1, 1, 4}) t_k3 = Nx.transpose(k3, axes: [2, 3, 4, 1, 0]) b1 = b2 = b3 = Nx.tensor(0.0) @@ -549,7 +575,7 @@ defmodule Axon.LayersTest do describe "max_pool" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = Axon.Layers.max_pool(input, kernel_size: {2, 2}, channels: :first) @@ -559,7 +585,7 @@ defmodule Axon.LayersTest do end test "channels last same as channels first with dilation" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = @@ -579,6 +605,27 @@ defmodule Axon.LayersTest do assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) end + test "channels last same as channels first with custom padding" do + input = random({1, 1, 28, 28}) + t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) + + first = + Axon.Layers.max_pool(input, + kernel_size: {2, 2}, + channels: :first, + padding: [{2, 2}, {1, 2}] + ) + + last = + Axon.Layers.max_pool(t_input, + kernel_size: {2, 2}, + channels: :last, + padding: [{2, 2}, {1, 2}] + ) + + assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) + end + test "raises on input rank less than 3" do inp = Nx.iota({1, 1}) @@ -592,7 +639,7 @@ defmodule Axon.LayersTest do describe "avg_pool" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = Axon.Layers.avg_pool(input, kernel_size: {2, 2}, channels: :first) @@ -602,7 +649,7 @@ defmodule Axon.LayersTest do end test "channels last same as channels first with dilation" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = @@ -622,6 +669,27 @@ defmodule Axon.LayersTest do assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) end + test "channels last same as channels first with custom padding" do + input = random({1, 1, 28, 28}) + t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) + + first = + Axon.Layers.avg_pool(input, + kernel_size: {2, 2}, + channels: :first, + padding: [{2, 2}, {1, 2}] + ) + + last = + Axon.Layers.avg_pool(t_input, + kernel_size: {2, 2}, + channels: :last, + padding: [{2, 2}, {1, 2}] + ) + + assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) + end + test "raises on input rank less than 3" do inp = Nx.iota({1, 1}) @@ -635,7 +703,7 @@ defmodule Axon.LayersTest do describe "lp pool" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = Axon.Layers.lp_pool(input, kernel_size: {2, 2}, channels: :first) @@ -645,7 +713,7 @@ defmodule Axon.LayersTest do end test "channels last same as channels first with dilation" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = @@ -665,6 +733,27 @@ defmodule Axon.LayersTest do assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) end + test "channels last same as channels first with custom padding" do + input = random({1, 1, 28, 28}) + t_input = Nx.transpose(input, axes: [0, 
2, 3, 1]) + + first = + Axon.Layers.lp_pool(input, + kernel_size: {2, 2}, + channels: :first, + padding: [{2, 2}, {1, 2}] + ) + + last = + Axon.Layers.lp_pool(t_input, + kernel_size: {2, 2}, + channels: :last, + padding: [{2, 2}, {1, 2}] + ) + + assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) + end + test "raises on input rank less than 3" do inp = Nx.iota({1, 1}) @@ -678,7 +767,7 @@ defmodule Axon.LayersTest do describe "adaptive avg pool" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = Axon.Layers.adaptive_avg_pool(input, output_size: {25, 25}, channels: :first) @@ -700,7 +789,7 @@ defmodule Axon.LayersTest do describe "adaptive max pool" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = Axon.Layers.adaptive_max_pool(input, output_size: {25, 25}, channels: :first) @@ -722,7 +811,7 @@ defmodule Axon.LayersTest do describe "adaptive lp pool" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = Axon.Layers.adaptive_lp_pool(input, output_size: {25, 25}, channels: :first) @@ -768,7 +857,7 @@ defmodule Axon.LayersTest do describe "global_max_pool" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = Axon.Layers.global_max_pool(input, channels: :first) @@ -790,7 +879,7 @@ defmodule Axon.LayersTest do describe "global_avg_pool" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = Axon.Layers.global_avg_pool(input, channels: :first) @@ -812,7 +901,7 @@ defmodule Axon.LayersTest do describe "global_lp_pool" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = Axon.Layers.global_lp_pool(input, channels: :first) @@ -857,10 +946,109 @@ defmodule Axon.LayersTest do end end + describe "lstm_cell" do + test "cell function matches results expected from pytorch" do + seq = + File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_input_seq.npy") + |> Nx.load_numpy!() + + c = + File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_input_c.npy") |> Nx.load_numpy!() + + h = + File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_input_h.npy") |> Nx.load_numpy!() + + wii = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_wii.npy") |> Nx.load_numpy!() + wif = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_wif.npy") |> Nx.load_numpy!() + wig = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_wig.npy") |> Nx.load_numpy!() + wio = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_wio.npy") |> Nx.load_numpy!() + whi = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_whi.npy") |> Nx.load_numpy!() + whf = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_whf.npy") |> Nx.load_numpy!() + whg = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_whg.npy") |> Nx.load_numpy!() + who = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_who.npy") |> Nx.load_numpy!() + bi 
= File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_bi.npy") |> Nx.load_numpy!() + bf = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_bf.npy") |> Nx.load_numpy!() + bg = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_bg.npy") |> Nx.load_numpy!() + bo = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_bo.npy") |> Nx.load_numpy!() + + expected_c = + File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_output_c.npy") |> Nx.load_numpy!() + + expected_h = + File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_output_h.npy") |> Nx.load_numpy!() + + {_, {new_c, new_h}} = + Axon.Layers.lstm_cell( + seq, + {c, h}, + Nx.tensor(0), + {wii, wif, wig, wio}, + {whi, whf, whg, who}, + {bi, bf, bg, bo} + ) + + assert_all_close(new_c, expected_c) + assert_all_close(new_h, expected_h) + end + end + + describe "lstm" do + test "matches results expected from pytorch with dynamic unroll" do + seq = File.read!("test/fixtures/lstm_test/test_lstm_input_seq.npy") |> Nx.load_numpy!() + + c = + File.read!("test/fixtures/lstm_test/test_lstm_input_c.npy") + |> Nx.load_numpy!() + |> Nx.squeeze() + + h = + File.read!("test/fixtures/lstm_test/test_lstm_input_h.npy") + |> Nx.load_numpy!() + |> Nx.squeeze() + + wii = File.read!("test/fixtures/lstm_test/test_lstm_wii.npy") |> Nx.load_numpy!() + wif = File.read!("test/fixtures/lstm_test/test_lstm_wif.npy") |> Nx.load_numpy!() + wig = File.read!("test/fixtures/lstm_test/test_lstm_wig.npy") |> Nx.load_numpy!() + wio = File.read!("test/fixtures/lstm_test/test_lstm_wio.npy") |> Nx.load_numpy!() + whi = File.read!("test/fixtures/lstm_test/test_lstm_whi.npy") |> Nx.load_numpy!() + whf = File.read!("test/fixtures/lstm_test/test_lstm_whf.npy") |> Nx.load_numpy!() + whg = File.read!("test/fixtures/lstm_test/test_lstm_whg.npy") |> Nx.load_numpy!() + who = File.read!("test/fixtures/lstm_test/test_lstm_who.npy") |> Nx.load_numpy!() + bi = File.read!("test/fixtures/lstm_test/test_lstm_bi.npy") |> Nx.load_numpy!() + bf = File.read!("test/fixtures/lstm_test/test_lstm_bf.npy") |> Nx.load_numpy!() + bg = File.read!("test/fixtures/lstm_test/test_lstm_bg.npy") |> Nx.load_numpy!() + bo = File.read!("test/fixtures/lstm_test/test_lstm_bo.npy") |> Nx.load_numpy!() + + expected_seq = + File.read!("test/fixtures/lstm_test/test_lstm_output_seq.npy") |> Nx.load_numpy!() + + expected_c = + File.read!("test/fixtures/lstm_test/test_lstm_output_c.npy") |> Nx.load_numpy!() + + expected_h = + File.read!("test/fixtures/lstm_test/test_lstm_output_h.npy") |> Nx.load_numpy!() + + {new_seq, {new_c, new_h}} = + Axon.Layers.lstm( + seq, + {c, h}, + Nx.tensor(0), + {wii, wif, wig, wio}, + {whi, whf, whg, who}, + {bi, bf, bg, bo}, + unroll: :dynamic + ) + + assert_all_close(new_seq, expected_seq, atol: 1.0e-3) + assert_all_close(new_c, expected_c, atol: 1.0e-3) + assert_all_close(new_h, expected_h, atol: 1.0e-3) + end + end + describe "dynamic_unroll" do test "computes carry and output identical to static_unroll" do input = Nx.iota({1, 4, 2}, type: {:f, 32}) - carry = {Nx.iota({1, 1, 8}, type: {:f, 32})} + carry = {Nx.iota({1, 8}, type: {:f, 32})} input_kernel = {Nx.iota({2, 8}, type: {:f, 32}), Nx.iota({2, 8}, type: {:f, 32}), @@ -874,13 +1062,29 @@ defmodule Axon.LayersTest do {Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32})} - cell_fn = &Axon.Layers.gru_cell/5 + cell_fn = &Axon.Layers.gru_cell/6 {s_output, {s_carry}} = - Axon.Layers.static_unroll(cell_fn, input, carry, input_kernel, hidden_kernel, bias) + 
Axon.Layers.static_unroll( + cell_fn, + input, + carry, + Nx.tensor(0), + input_kernel, + hidden_kernel, + bias + ) {d_output, {d_carry}} = - Axon.Layers.dynamic_unroll(cell_fn, input, carry, input_kernel, hidden_kernel, bias) + Axon.Layers.dynamic_unroll( + cell_fn, + input, + carry, + Nx.tensor(0), + input_kernel, + hidden_kernel, + bias + ) assert_equal(s_carry, d_carry) assert_equal(s_output, d_output) @@ -888,7 +1092,8 @@ defmodule Axon.LayersTest do defn grad_static_hidden_output(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(hidden_kernel, fn x -> - {output, _} = Axon.Layers.static_unroll(cell_fn, input, carry, input_kernel, x, bias) + {output, _} = + Axon.Layers.static_unroll(cell_fn, input, carry, Nx.tensor(0), input_kernel, x, bias) Nx.mean(output) end) @@ -896,7 +1101,8 @@ defmodule Axon.LayersTest do defn grad_dynamic_hidden_output(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(hidden_kernel, fn x -> - {output, _} = Axon.Layers.dynamic_unroll(cell_fn, input, carry, input_kernel, x, bias) + {output, _} = + Axon.Layers.dynamic_unroll(cell_fn, input, carry, Nx.tensor(0), input_kernel, x, bias) Nx.mean(output) end) @@ -904,7 +1110,7 @@ defmodule Axon.LayersTest do test "computes gradient identical to static unroll for hidden kernel w.r.t. output" do input = Nx.iota({1, 4, 2}, type: {:f, 32}) - carry = {Nx.iota({1, 1, 8}, type: {:f, 32})} + carry = {Nx.iota({1, 8}, type: {:f, 32})} input_kernel = {Nx.iota({2, 8}, type: {:f, 32}), Nx.iota({2, 8}, type: {:f, 32}), @@ -918,7 +1124,7 @@ defmodule Axon.LayersTest do {Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32})} - cell_fn = &Axon.Layers.gru_cell/5 + cell_fn = &Axon.Layers.gru_cell/6 assert_equal( grad_static_hidden_output(input, carry, input_kernel, hidden_kernel, bias, cell_fn), @@ -935,7 +1141,16 @@ defmodule Axon.LayersTest do defn grad_static_hidden_carry(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(hidden_kernel, fn x -> - {_, {carry}} = Axon.Layers.static_unroll(cell_fn, input, carry, input_kernel, x, bias) + {_, {carry}} = + Axon.Layers.static_unroll( + cell_fn, + input, + carry, + Nx.tensor([[0, 0, 0, 1]]), + input_kernel, + x, + bias + ) Nx.mean(carry) end) @@ -943,7 +1158,16 @@ defmodule Axon.LayersTest do defn grad_dynamic_hidden_carry(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(hidden_kernel, fn x -> - {_, {carry}} = Axon.Layers.dynamic_unroll(cell_fn, input, carry, input_kernel, x, bias) + {_, {carry}} = + Axon.Layers.dynamic_unroll( + cell_fn, + input, + carry, + Nx.tensor([[0, 0, 0, 1]]), + input_kernel, + x, + bias + ) Nx.mean(carry) end) @@ -951,7 +1175,7 @@ defmodule Axon.LayersTest do test "computes gradient identical to static_unroll for hidden kernel w.r.t carry" do input = Nx.iota({1, 4, 2}, type: {:f, 32}) - carry = {Nx.iota({1, 1, 8}, type: {:f, 32})} + carry = {Nx.iota({1, 8}, type: {:f, 32})} input_kernel = {Nx.iota({2, 8}, type: {:f, 32}), Nx.iota({2, 8}, type: {:f, 32}), @@ -965,7 +1189,7 @@ defmodule Axon.LayersTest do {Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32})} - cell_fn = &Axon.Layers.gru_cell/5 + cell_fn = &Axon.Layers.gru_cell/6 assert_equal( grad_static_hidden_carry(input, carry, input_kernel, hidden_kernel, bias, cell_fn), @@ -975,7 +1199,8 @@ defmodule Axon.LayersTest do defn grad_static_input_output(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(input_kernel, fn 
x -> - {output, _} = Axon.Layers.static_unroll(cell_fn, input, carry, x, hidden_kernel, bias) + {output, _} = + Axon.Layers.static_unroll(cell_fn, input, carry, Nx.tensor(0), x, hidden_kernel, bias) Nx.mean(output) end) @@ -983,7 +1208,8 @@ defmodule Axon.LayersTest do defn grad_dynamic_input_output(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(input_kernel, fn x -> - {output, _} = Axon.Layers.dynamic_unroll(cell_fn, input, carry, x, hidden_kernel, bias) + {output, _} = + Axon.Layers.dynamic_unroll(cell_fn, input, carry, Nx.tensor(0), x, hidden_kernel, bias) Nx.mean(output) end) @@ -991,7 +1217,7 @@ defmodule Axon.LayersTest do test "computes gradient identical to static unroll for input kernel w.r.t. output" do input = Nx.iota({1, 4, 2}, type: {:f, 32}) - carry = {Nx.iota({1, 1, 8}, type: {:f, 32})} + carry = {Nx.iota({1, 8}, type: {:f, 32})} input_kernel = {Nx.iota({2, 8}, type: {:f, 32}), Nx.iota({2, 8}, type: {:f, 32}), @@ -1005,7 +1231,7 @@ defmodule Axon.LayersTest do {Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32})} - cell_fn = &Axon.Layers.gru_cell/5 + cell_fn = &Axon.Layers.gru_cell/6 assert_equal( grad_static_input_output(input, carry, input_kernel, hidden_kernel, bias, cell_fn), @@ -1015,7 +1241,16 @@ defmodule Axon.LayersTest do defn grad_static_input_carry(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(input_kernel, fn x -> - {_, {carry}} = Axon.Layers.static_unroll(cell_fn, input, carry, x, hidden_kernel, bias) + {_, {carry}} = + Axon.Layers.static_unroll( + cell_fn, + input, + carry, + Nx.tensor([[0, 0, 0, 1]]), + x, + hidden_kernel, + bias + ) Nx.mean(carry) end) @@ -1023,7 +1258,16 @@ defmodule Axon.LayersTest do defn grad_dynamic_input_carry(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(input_kernel, fn x -> - {_, {carry}} = Axon.Layers.dynamic_unroll(cell_fn, input, carry, x, hidden_kernel, bias) + {_, {carry}} = + Axon.Layers.dynamic_unroll( + cell_fn, + input, + carry, + Nx.tensor([[0, 0, 0, 1]]), + x, + hidden_kernel, + bias + ) Nx.mean(carry) end) @@ -1031,7 +1275,7 @@ defmodule Axon.LayersTest do test "computes gradient identical to static unroll for input kernel w.r.t. 
carry" do input = Nx.iota({1, 4, 2}, type: {:f, 32}) - carry = {Nx.iota({1, 1, 8}, type: {:f, 32})} + carry = {Nx.iota({1, 8}, type: {:f, 32})} input_kernel = {Nx.iota({2, 8}, type: {:f, 32}), Nx.iota({2, 8}, type: {:f, 32}), @@ -1045,7 +1289,7 @@ defmodule Axon.LayersTest do {Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32})} - cell_fn = &Axon.Layers.gru_cell/5 + cell_fn = &Axon.Layers.gru_cell/6 assert_equal( grad_static_input_carry(input, carry, input_kernel, hidden_kernel, bias, cell_fn), @@ -1056,7 +1300,15 @@ defmodule Axon.LayersTest do defn grad_static_bias_output(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(bias, fn x -> {output, _} = - Axon.Layers.static_unroll(cell_fn, input, carry, input_kernel, hidden_kernel, x) + Axon.Layers.static_unroll( + cell_fn, + input, + carry, + Nx.tensor([[0, 0, 0, 1]]), + input_kernel, + hidden_kernel, + x + ) Nx.mean(output) end) @@ -1065,7 +1317,15 @@ defmodule Axon.LayersTest do defn grad_dynamic_bias_output(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(bias, fn x -> {output, _} = - Axon.Layers.dynamic_unroll(cell_fn, input, carry, input_kernel, hidden_kernel, x) + Axon.Layers.dynamic_unroll( + cell_fn, + input, + carry, + Nx.tensor([[0, 0, 0, 1]]), + input_kernel, + hidden_kernel, + x + ) Nx.mean(output) end) @@ -1073,7 +1333,7 @@ defmodule Axon.LayersTest do test "computes gradient identical to static unroll for bias w.r.t. output" do input = Nx.iota({1, 4, 2}, type: {:f, 32}) - carry = {Nx.iota({1, 1, 8}, type: {:f, 32})} + carry = {Nx.iota({1, 8}, type: {:f, 32})} input_kernel = {Nx.iota({2, 8}, type: {:f, 32}), Nx.iota({2, 8}, type: {:f, 32}), @@ -1087,7 +1347,7 @@ defmodule Axon.LayersTest do {Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32})} - cell_fn = &Axon.Layers.gru_cell/5 + cell_fn = &Axon.Layers.gru_cell/6 assert_equal( grad_static_bias_output(input, carry, input_kernel, hidden_kernel, bias, cell_fn), @@ -1098,7 +1358,15 @@ defmodule Axon.LayersTest do defn grad_static_bias_carry(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(bias, fn x -> {_, {carry}} = - Axon.Layers.static_unroll(cell_fn, input, carry, input_kernel, hidden_kernel, x) + Axon.Layers.static_unroll( + cell_fn, + input, + carry, + Nx.tensor([[0, 0, 0, 1]]), + input_kernel, + hidden_kernel, + x + ) Nx.mean(carry) end) @@ -1107,7 +1375,15 @@ defmodule Axon.LayersTest do defn grad_dynamic_bias_carry(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(bias, fn x -> {_, {carry}} = - Axon.Layers.dynamic_unroll(cell_fn, input, carry, input_kernel, hidden_kernel, x) + Axon.Layers.dynamic_unroll( + cell_fn, + input, + carry, + Nx.tensor([[0, 0, 0, 1]]), + input_kernel, + hidden_kernel, + x + ) Nx.mean(carry) end) @@ -1115,7 +1391,7 @@ defmodule Axon.LayersTest do test "computes gradient identical to static unroll for bias w.r.t. 
carry" do input = Nx.iota({1, 4, 2}, type: {:f, 32}) - carry = {Nx.iota({1, 1, 8}, type: {:f, 32})} + carry = {Nx.iota({1, 8}, type: {:f, 32})} input_kernel = {Nx.iota({2, 8}, type: {:f, 32}), Nx.iota({2, 8}, type: {:f, 32}), @@ -1129,7 +1405,7 @@ defmodule Axon.LayersTest do {Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32})} - cell_fn = &Axon.Layers.gru_cell/5 + cell_fn = &Axon.Layers.gru_cell/6 assert_equal( grad_static_bias_carry(input, carry, input_kernel, hidden_kernel, bias, cell_fn), @@ -1137,4 +1413,132 @@ defmodule Axon.LayersTest do ) end end + + describe "group_norm" do + test "matches pytorch" do + a = + Nx.tensor([ + [ + 0.8423, + 1.9226, + -1.1295, + -1.3154, + 1.2963, + -0.6821, + -0.0519, + 0.6875, + -0.0313, + -0.3328, + -0.2821, + -2.3289, + -1.7641, + -1.3184, + -0.0890, + 0.0625 + ], + [ + -1.0853, + 0.8060, + -0.1397, + -0.2169, + 0.9605, + 0.3947, + 0.4760, + 0.8097, + 0.0380, + -0.6314, + 0.5761, + 1.9309, + 0.5038, + -0.1892, + 1.8476, + 0.0517 + ] + ]) + + b = + Nx.tensor([ + -0.3101, + -1.5896, + -1.4963, + 0.1278, + -1.4580, + 1.3832, + 0.5709, + 0.5531, + -0.0588, + 1.0411, + 1.3503, + -1.2166, + 0.7133, + 0.0694, + 0.3150, + -0.1306 + ]) + + c = + Nx.tensor([ + 1.6585, + 2.3515, + -1.3456, + 0.2376, + -0.1333, + 0.5068, + 0.2441, + 1.0382, + 0.6879, + -0.5402, + -1.8304, + -0.8906, + -0.5329, + -0.3390, + -0.1877, + 0.1405 + ]) + + expected = + Nx.tensor([ + [ + 1.4768, + -0.1375, + 0.4536, + 0.0623, + -1.5881, + -0.5951, + 0.1157, + 1.2847, + 0.6378, + -0.0194, + -1.0751, + 1.3407, + -1.3700, + -0.3844, + 0.0597, + 0.0149 + ], + [ + 2.2986, + 0.9877, + -0.4434, + 0.1453, + -1.7321, + 0.8146, + 0.4430, + 1.5159, + 0.7202, + -1.9153, + -1.7368, + -2.8723, + -0.5429, + -0.3954, + 0.2952, + 0.2103 + ] + ]) + + actual = Axon.Layers.group_norm(a, b, c, num_groups: 2) + + assert_all_close(expected, actual, atol: 1.0e-3) + end + end end diff --git a/test/axon/loop_test.exs b/test/axon/loop_test.exs index 0c9f4c6a..05e8f0bb 100644 --- a/test/axon/loop_test.exs +++ b/test/axon/loop_test.exs @@ -34,7 +34,7 @@ defmodule Axon.LoopTest do ] valid_axon_optimizers = - Axon.Optimizers.__info__(:functions) + Polaris.Optimizers.__info__(:functions) |> Enum.map(fn {k, _} -> k end) |> Enum.uniq() @@ -82,7 +82,7 @@ defmodule Axon.LoopTest do test "trainer/3 returns a supervised training loop with custom optimizer" do model = Axon.input("input", shape: {nil, 1}) - optimizer = Axon.Optimizers.rmsprop(1.0e-3) + optimizer = Polaris.Optimizers.rmsprop(learning_rate: 1.0e-3) assert %Loop{init: init_fn, step: update_fn, output_transform: transform} = Loop.trainer(model, :mean_squared_error, optimizer) @@ -203,7 +203,7 @@ defmodule Axon.LoopTest do end) =~ "Batch" end - test "eval_step/1 evalutes model on a single batch" do + test "eval_step/1 evaluates model on a single batch" do inp = Nx.tensor([0, 1, 0, 1, 0, 1]) |> Nx.new_axis(-1) tar = Nx.tensor([1, 0, 1, 0, 1, 0]) |> Nx.new_axis(-1) @@ -360,7 +360,7 @@ defmodule Axon.LoopTest do Axon.input("input", shape: {nil, 1}) |> Axon.dense(1) |> Loop.trainer(:binary_cross_entropy, :sgd, log: 0) - |> Loop.handle( + |> Loop.handle_event( :epoch_completed, fn %State{step_state: pstate} = state -> { @@ -376,7 +376,7 @@ defmodule Axon.LoopTest do } end ) - |> Loop.handle( + |> Loop.handle_event( :completed, fn %State{step_state: %{counter: counter}} = state -> assert 4 = counter @@ -396,7 +396,7 @@ defmodule Axon.LoopTest do Axon.input("input", shape: {nil, 1}) |> Axon.dense(1) |> 
Loop.trainer(:binary_cross_entropy, :sgd, log: 0) - |> Loop.handle( + |> Loop.handle_event( :epoch_completed, fn %State{step_state: pstate} = state -> { @@ -416,7 +416,7 @@ defmodule Axon.LoopTest do } end ) - |> Loop.handle( + |> Loop.handle_event( :completed, fn %State{step_state: %{counter: counter}} = state -> assert {{4}, 4} = counter @@ -477,7 +477,7 @@ defmodule Axon.LoopTest do end def send_handler(loop, event) do - Axon.Loop.handle(loop, event, fn state -> + Axon.Loop.handle_event(loop, event, fn state -> send(self(), event) {:continue, state} end) @@ -540,15 +540,6 @@ defmodule Axon.LoopTest do refute_received :iteration_completed end - test "fires correctly on :completed" do - ExUnit.CaptureIO.capture_io(fn -> - run_dummy_loop!(:completed, 5, 10) - end) - - assert_received :completed - refute_received :completed - end - test "fires correctly on :epoch_halted" do model = Axon.input("foo") @@ -562,7 +553,7 @@ defmodule Axon.LoopTest do ExUnit.CaptureIO.capture_io(fn -> model |> Axon.Loop.trainer(:binary_cross_entropy, :sgd) - |> Axon.Loop.handle(:iteration_completed, fn state -> + |> Axon.Loop.handle_event(:iteration_completed, fn state -> {:halt_epoch, state} end) |> send_handler(:epoch_halted) @@ -576,30 +567,6 @@ defmodule Axon.LoopTest do refute_received :epoch_halted end - test "fires correctly on :halted" do - model = Axon.input("foo") - - data = - Stream.repeatedly(fn -> - xs = Nx.tensor([[Enum.random(0..10)]]) - ys = Nx.greater(xs, 5) - {xs, ys} - end) - - ExUnit.CaptureIO.capture_io(fn -> - model - |> Axon.Loop.trainer(:binary_cross_entropy, :sgd) - |> Axon.Loop.handle(:iteration_completed, fn state -> - {:halt_loop, state} - end) - |> send_handler(:halted) - |> Axon.Loop.run(data, %{}, epochs: 5, iterations: 10) - end) - - assert_received :halted - refute_received :halted - end - test "events fire in order" do model = Axon.input("foo") @@ -618,7 +585,6 @@ defmodule Axon.LoopTest do |> send_handler(:iteration_started) |> send_handler(:iteration_completed) |> send_handler(:epoch_completed) - |> send_handler(:completed) |> Axon.Loop.run(data, %{}, epochs: 1, iterations: 1) end) @@ -627,7 +593,6 @@ defmodule Axon.LoopTest do assert_received :iteration_started assert_received :iteration_completed assert_received :epoch_completed - assert_received :completed refute_received _ end @@ -651,7 +616,7 @@ defmodule Axon.LoopTest do end def send_handler(loop, event, filter) do - Axon.Loop.handle( + Axon.Loop.handle_event( loop, event, fn state -> @@ -770,7 +735,7 @@ defmodule Axon.LoopTest do describe "serialization" do test "serialize_state/deserialize_state preserve loop state" do model = Axon.input("input", shape: {nil, 1}) |> Axon.dense(2) - optimizer = Axon.Optimizers.adam(1.0e-2) + optimizer = Polaris.Optimizers.adam(learning_rate: 1.0e-2) loss = :binary_cross_entropy {init_fn, _} = Axon.Loop.train_step(model, loss, optimizer) @@ -813,7 +778,7 @@ defmodule Axon.LoopTest do [loop: loop] end - test "saves a ceckpoint on each epoch", %{loop: loop} do + test "saves a checkpoint on each epoch", %{loop: loop} do loop |> Loop.checkpoint() |> Loop.run([{Nx.tensor([[1]]), Nx.tensor([[2]])}], %{}, epochs: 3) @@ -822,6 +787,28 @@ defmodule Axon.LoopTest do File.ls!("checkpoint") |> Enum.sort() end + test "saves a checkpoint on custom events", %{loop: loop} do + data = List.duplicate({Nx.iota({1, 1}), Nx.iota({1, 1})}, 5) + + assert %Axon.Loop.State{epoch: 3, iteration: 0, event_counts: %{iteration_completed: 15}} = + loop + |> Map.put(:output_transform, & &1) + |> Loop.checkpoint(event: 
:iteration_completed, filter: [every: 2]) + |> Loop.run(data, %{}, epochs: 3) + + assert [ + "checkpoint_0_0.ckpt", + "checkpoint_0_2.ckpt", + "checkpoint_0_4.ckpt", + "checkpoint_1_1.ckpt", + "checkpoint_1_3.ckpt", + "checkpoint_2_0.ckpt", + "checkpoint_2_2.ckpt", + "checkpoint_2_4.ckpt" + ] == + File.ls!("checkpoint") |> Enum.sort() + end + test "uses the custom file_pattern function", %{loop: loop} do loop |> Loop.checkpoint(file_pattern: &"ckp_#{&1.epoch}.ckpt") @@ -863,7 +850,7 @@ defmodule Axon.LoopTest do model |> Axon.Loop.trainer(:binary_cross_entropy, :sgd) |> Axon.Loop.from_state(state1) - |> Axon.Loop.handle(:epoch_completed, fn %{epoch: epoch} = state -> + |> Axon.Loop.handle_event(:epoch_completed, fn %{epoch: epoch} = state -> assert epoch >= 3 {:continue, state} end) @@ -888,7 +875,7 @@ defmodule Axon.LoopTest do |> Axon.Loop.trainer(:binary_cross_entropy, :sgd) |> Axon.Loop.metric(:accuracy) |> Axon.Loop.validate(model, Enum.take(data, 5)) - |> Axon.Loop.handle( + |> Axon.Loop.handle_event( :epoch_completed, fn %{metrics: metrics} = state -> assert Map.has_key?(metrics, "validation_accuracy") @@ -918,7 +905,7 @@ defmodule Axon.LoopTest do |> Axon.Loop.metric(:accuracy) |> Axon.Loop.validate(model, Enum.take(data, 5)) |> Axon.Loop.early_stop("validation_accuracy", mode: :max) - |> Axon.Loop.handle( + |> Axon.Loop.handle_event( :epoch_completed, fn %{handler_metadata: meta} = state -> assert %{early_stop: %{"validation_accuracy" => _, :since_last_improvement => _}} = @@ -1006,7 +993,7 @@ defmodule Axon.LoopTest do |> Axon.Loop.metric(:accuracy) |> Axon.Loop.validate(model, Enum.take(data, 5)) |> Axon.Loop.reduce_lr_on_plateau("validation_accuracy", mode: :max) - |> Axon.Loop.handle( + |> Axon.Loop.handle_event( :epoch_completed, fn %{handler_metadata: meta} = state -> assert %{reduce_lr: %{"validation_accuracy" => _, :since_last_improvement => _}} = @@ -1039,7 +1026,10 @@ defmodule Axon.LoopTest do ExUnit.CaptureIO.capture_io(fn -> state = model - |> Axon.Loop.trainer(:binary_cross_entropy, Axon.Optimizers.sgd(initial_lr)) + |> Axon.Loop.trainer( + :binary_cross_entropy, + Polaris.Optimizers.sgd(learning_rate: initial_lr) + ) |> Axon.Loop.metric(my_metric, "counter", :running_sum) |> Axon.Loop.reduce_lr_on_plateau("counter", factor: 0.5, mode: :min, patience: 2) # TODO: This API needs to change @@ -1072,7 +1062,10 @@ defmodule Axon.LoopTest do ExUnit.CaptureIO.capture_io(fn -> state = model - |> Axon.Loop.trainer(:binary_cross_entropy, Axon.Optimizers.sgd(initial_lr)) + |> Axon.Loop.trainer( + :binary_cross_entropy, + Polaris.Optimizers.sgd(learning_rate: initial_lr) + ) |> Axon.Loop.metric(my_metric, "counter", :running_sum) |> Axon.Loop.reduce_lr_on_plateau("counter", factor: 0.5, mode: :max, patience: 2) # TODO: This API needs to change diff --git a/test/axon/loss_scale_test.exs b/test/axon/loss_scale_test.exs new file mode 100644 index 00000000..93fb70c0 --- /dev/null +++ b/test/axon/loss_scale_test.exs @@ -0,0 +1,297 @@ +defmodule Axon.LossScaleTest do + use ExUnit.Case + import AxonTestUtil + + import Axon.LossScale + + describe "identity/1" do + test "creates a loss scale tuple" do + assert {init_fn, scale_fn, adjust_fn} = identity() + assert is_function(init_fn, 0) + assert is_function(scale_fn, 2) + assert is_function(adjust_fn, 2) + end + + test "accepts options" do + assert {init_fn, scale_fn, adjust_fn} = identity([]) + assert is_function(init_fn, 0) + assert is_function(scale_fn, 2) + assert is_function(adjust_fn, 2) + end + + test "initializes to empty state" 
do + assert {init_fn, _, _} = identity() + assert init_fn.() == %{} + end + + test "scale function returns identity operation on x" do + assert {init_fn, scale_fn, _} = identity() + state = init_fn.() + x = Nx.tensor([1.0, 2.0, 3.0]) + + new_x = scale_fn.(x, state) + assert new_x == x + end + + test "adjust function returns identity operation on x and state" do + assert {init_fn, _, adjust_fn} = identity() + state = init_fn.() + x = Nx.tensor([1.0, 2.0, 3.0]) + + assert {new_x, new_state} = adjust_fn.(x, state) + assert new_x == x + assert new_state == state + end + end + + describe "static/1" do + test "creates a loss scale tuple" do + assert {init_fn, scale_fn, adjust_fn} = static() + assert is_function(init_fn, 0) + assert is_function(scale_fn, 2) + assert is_function(adjust_fn, 2) + end + + test "accepts options" do + assert {init_fn, scale_fn, adjust_fn} = static([]) + assert is_function(init_fn, 0) + assert is_function(scale_fn, 2) + assert is_function(adjust_fn, 2) + end + + test "initializes state with default loss scale" do + assert {init_fn, _, _} = static() + assert %{loss_scale: loss_scale} = init_fn.() + assert_equal(loss_scale, Nx.pow(2, 15)) + end + + test "initializes state with specified loss scale" do + init_scale = Nx.pow(3, 15) + assert {init_fn, _, _} = static(init_scale: init_scale) + assert %{loss_scale: loss_scale} = init_fn.() + assert_equal(loss_scale, init_scale) + end + + test "scale function returns a tree scaled by static scale" do + assert {init_fn, scale_fn, _} = static() + state = init_fn.() + a = Nx.tensor([1.0, 2.0, 3.0]) + c = Nx.tensor([4.0, 5.0, 6.0]) + x = %{a: a, b: %{c: c}} + + assert %{a: scaled_a, b: %{c: scaled_c}} = scale_fn.(x, state) + assert_equal(scaled_a, Nx.multiply(a, Nx.pow(2, 15))) + assert_equal(scaled_c, Nx.multiply(c, Nx.pow(2, 15))) + end + + test "scale function returns a tree scaled by static scale with custom scale" do + init_scale = Nx.pow(3, 15) + assert {init_fn, scale_fn, _} = static(init_scale: init_scale) + state = init_fn.() + a = Nx.tensor([1.0, 2.0, 3.0]) + c = Nx.tensor([4.0, 5.0, 6.0]) + x = %{a: a, b: %{c: c}} + + assert %{a: scaled_a, b: %{c: scaled_c}} = scale_fn.(x, state) + assert_equal(scaled_a, Nx.multiply(a, init_scale)) + assert_equal(scaled_c, Nx.multiply(c, init_scale)) + end + + test "adjust function returns unscaled tree with static state" do + assert {init_fn, scale_fn, adjust_fn} = static() + state = init_fn.() + a = Nx.tensor([1.0, 2.0, 3.0]) + c = Nx.tensor([4.0, 5.0, 6.0]) + x = %{a: a, b: %{c: c}} + + scaled_x = scale_fn.(x, state) + assert {unscaled_x, new_state} = adjust_fn.(scaled_x, state) + assert %{a: unscaled_a, b: %{c: unscaled_c}} = unscaled_x + assert %{loss_scale: new_loss_scale} = new_state + + assert_all_close(unscaled_a, a) + assert_all_close(unscaled_c, c) + assert_equal(new_loss_scale, Nx.pow(2, 15)) + end + + test "adjust function returns unscaled tree with static state and custom scale" do + init_scale = Nx.pow(3, 15) + + assert {init_fn, scale_fn, adjust_fn} = static(init_scale: init_scale) + state = init_fn.() + a = Nx.tensor([1.0, 2.0, 3.0]) + c = Nx.tensor([4.0, 5.0, 6.0]) + x = %{a: a, b: %{c: c}} + + scaled_x = scale_fn.(x, state) + assert {unscaled_x, new_state} = adjust_fn.(scaled_x, state) + assert %{a: unscaled_a, b: %{c: unscaled_c}} = unscaled_x + assert %{loss_scale: new_loss_scale} = new_state + + assert_all_close(unscaled_a, a) + assert_all_close(unscaled_c, c) + assert_equal(new_loss_scale, init_scale) + end + end + + describe "dynamic/1" do + test "creates a loss 
scale tuple" do + assert {init_fn, scale_fn, adjust_fn} = dynamic() + assert is_function(init_fn, 0) + assert is_function(scale_fn, 2) + assert is_function(adjust_fn, 2) + end + + test "accepts options" do + assert {init_fn, scale_fn, adjust_fn} = dynamic([]) + assert is_function(init_fn, 0) + assert is_function(scale_fn, 2) + assert is_function(adjust_fn, 2) + end + + test "initializes state with default loss scale" do + assert {init_fn, _, _} = dynamic() + assert %{loss_scale: loss_scale, counter: counter} = init_fn.() + assert_equal(loss_scale, Nx.pow(2, 15)) + assert_equal(counter, Nx.tensor(0)) + end + + test "initializes state with specified loss scale" do + init_scale = Nx.pow(3, 15) + assert {init_fn, _, _} = dynamic(init_scale: init_scale) + assert %{loss_scale: loss_scale, counter: counter} = init_fn.() + assert_equal(counter, Nx.tensor(0)) + assert_equal(loss_scale, init_scale) + end + + test "scale function returns a tree scaled by scale" do + assert {init_fn, scale_fn, _} = dynamic() + state = init_fn.() + a = Nx.tensor([1.0, 2.0, 3.0]) + c = Nx.tensor([4.0, 5.0, 6.0]) + x = %{a: a, b: %{c: c}} + + assert %{a: scaled_a, b: %{c: scaled_c}} = scale_fn.(x, state) + assert_equal(scaled_a, Nx.multiply(a, Nx.pow(2, 15))) + assert_equal(scaled_c, Nx.multiply(c, Nx.pow(2, 15))) + end + + test "scale function returns a tree scaled by scale with custom scale" do + init_scale = Nx.pow(3, 15) + assert {init_fn, scale_fn, _} = dynamic(init_scale: init_scale) + state = init_fn.() + a = Nx.tensor([1.0, 2.0, 3.0]) + c = Nx.tensor([4.0, 5.0, 6.0]) + x = %{a: a, b: %{c: c}} + + assert %{a: scaled_a, b: %{c: scaled_c}} = scale_fn.(x, state) + assert_equal(scaled_a, Nx.multiply(a, init_scale)) + assert_equal(scaled_c, Nx.multiply(c, init_scale)) + end + + test "adjust function unscales correctly" do + init_scale = Nx.tensor(10) + assert {init_fn, scale_fn, adjust_fn} = dynamic(init_scale: init_scale) + state = init_fn.() + + a = Nx.tensor([1.0, 2.0, 3.0]) + c = Nx.tensor([4.0, 5.0, 6.0]) + x = %{a: a, b: %{c: c}} + + scaled_x = scale_fn.(x, state) + assert {unscaled_x, _new_state} = adjust_fn.(scaled_x, state) + assert %{a: unscaled_a, b: %{c: unscaled_c}} = unscaled_x + + assert_all_close(unscaled_a, a) + assert_all_close(unscaled_c, c) + end + + test "adjust function increases loss scale according to period and factor when grads are finite" do + init_scale = Nx.tensor(10) + period = 5 + assert {init_fn, _, adjust_fn} = dynamic(init_scale: init_scale, period: period) + state = init_fn.() + + finite = Nx.tensor([1.0, 1.0, 1.0]) + + final_state = + for i <- 1..(period - 1), reduce: state do + new_state -> + {_, %{loss_scale: loss_scale, counter: counter} = new_state} = + adjust_fn.(finite, new_state) + + assert_equal(loss_scale, init_scale) + assert_equal(counter, Nx.tensor(i)) + new_state + end + + assert {_, %{loss_scale: final_scale, counter: final_counter}} = + adjust_fn.(finite, final_state) + + assert_equal(final_scale, Nx.tensor(20.0)) + assert_equal(final_counter, Nx.tensor(0)) + end + + test "adjust function reduces loss scale on non finite" do + init_scale = Nx.tensor(10) + period = 5 + factor = 2 + + assert {init_fn, _, adjust_fn} = + dynamic(init_scale: init_scale, period: period, factor: factor) + + state = init_fn.() + + non_finite = Nx.tensor([:infinity, :infinity, :infinity]) + + # TODO: increase to 99 when https://github.com/elixir-nx/complex/issues/26 + # is fixed + for i <- 0..62, reduce: state do + new_state -> + {_, %{loss_scale: loss_scale, counter: counter} = 
new_state} = + adjust_fn.(non_finite, new_state) + + expected_new_scale = Nx.max(1, Nx.divide(init_scale, Nx.pow(factor, i + 1))) + assert_equal(counter, Nx.tensor(0)) + assert_all_close(loss_scale, expected_new_scale) + + new_state + end + end + + test "adjust function reduces loss scale to min loss scale" do + init_scale = Nx.tensor(20) + period = 5 + factor = 2 + min_loss_scale = 2 + + assert {init_fn, _, adjust_fn} = + dynamic( + init_scale: init_scale, + period: period, + factor: factor, + min_loss_scale: min_loss_scale + ) + + state = init_fn.() + + non_finite = Nx.tensor([:infinity, :infinity, :infinity]) + + # TODO: increase to 99 when https://github.com/elixir-nx/complex/issues/26 + # is fixed + for i <- 0..62, reduce: state do + new_state -> + {_, %{loss_scale: loss_scale, counter: counter} = new_state} = + adjust_fn.(non_finite, new_state) + + expected_new_scale = + Nx.max(min_loss_scale, Nx.divide(init_scale, Nx.pow(factor, i + 1))) + + assert_equal(counter, Nx.tensor(0)) + assert_all_close(loss_scale, expected_new_scale) + + new_state + end + end + end +end diff --git a/test/axon/losses_test.exs b/test/axon/losses_test.exs index 922986c3..6f396cca 100644 --- a/test/axon/losses_test.exs +++ b/test/axon/losses_test.exs @@ -284,4 +284,25 @@ defmodule Axon.LossesTest do ) end end + + describe "apply_label_smoothing" do + test "correctly smooths labels" do + y_true = Nx.tensor([[0, 1, 0, 0, 0, 0]]) + y_pred = Nx.tensor([[0.5, 0.1, 0.1, 0.0, 0.2, 0.1]]) + + assert_all_close( + Axon.Losses.apply_label_smoothing(y_true, y_pred, smoothing: 0.1), + Nx.tensor([[0.0167, 0.9167, 0.0167, 0.0167, 0.0167, 0.0167]]), + atol: 1.0e-3 + ) + end + end + + describe "label_smoothing" do + test "returns an arity-2 function from loss function" do + loss = &Axon.Losses.categorical_cross_entropy/2 + smooth_loss = Axon.Losses.label_smoothing(loss, smoothing: 0.1) + assert is_function(smooth_loss, 2) + end + end end diff --git a/test/axon/mixed_precision_test.exs b/test/axon/mixed_precision_test.exs index 9b7f3c61..d858529e 100644 --- a/test/axon/mixed_precision_test.exs +++ b/test/axon/mixed_precision_test.exs @@ -1,86 +1,4 @@ -# defmodule MixedPrecisionTest do -# use Axon.Case, async: true - -# alias Axon.MixedPrecision.Policy -# alias Axon.MixedPrecision, as: AMP -# alias Axon.Loop - -# describe "creation and application" do -# test "create policy" do -# assert %Policy{params: {:f, 32}, compute: {:bf, 16}, output: {:f, 32}} = -# AMP.create_policy(compute: {:bf, 16}) - -# assert %Policy{params: {:bf, 16}, compute: {:f, 32}, output: {:bf, 16}} = -# AMP.create_policy(params: {:bf, 16}, output: {:bf, 16}) -# end - -# test "apply_policy" do -# model = -# Axon.input("input", shape: {nil, 784}) -# |> Axon.dense(128) -# |> Axon.batch_norm() -# |> Axon.dense(10) - -# policy = AMP.create_policy(compute: {:bf, 16}) - -# assert %Axon{ -# op: :dense, -# parent: [ -# %Axon{ -# op: :batch_norm, -# parent: [%Axon{op: :dense, policy: %Policy{compute: {:bf, 16}}}], -# policy: %Policy{compute: {:f, 32}} -# } -# ], -# policy: %Policy{compute: {:bf, 16}} -# } = AMP.apply_policy(model, policy, except: [:batch_norm]) -# end -# end - -# describe "compilation" do -# # TODO(seanmor5): Now that everything else has moved, maybe this -# # belongs in a train test or elsewhere -# test "correctly maintains parameter type after train step" do -# model = -# Axon.input("input", shape: {nil, 32}) -# |> Axon.dense(2, name: "dense1") -# |> Axon.batch_norm(name: "batch_norm") -# |> Axon.dense(1, activation: :sigmoid, name: "dense2") 
-
-# policy = AMP.create_policy(params: {:bf, 16})
-
-# mp_model = AMP.apply_policy(model, policy, except: [:batch_norm])
-
-# %Loop{init: init_fn, step: step_fn} =
-#   Axon.Loop.trainer(mp_model, :binary_cross_entropy, Axon.Optimizers.sgd(0.01))
-
-# v1 = Nx.random_uniform({1, 32})
-# v2 = Nx.random_uniform({1, 1})
-
-# pstate =
-#   apply(Nx.Defn.jit(step_fn), [
-#     {v1, v2},
-#     init_fn.({v1, v2}, %{})
-#   ])
-
-# params = pstate[:model_state]
-
-# assert Nx.type(params["dense1"]["kernel"]) == {:bf, 16}
-# assert Nx.type(params["dense1"]["bias"]) == {:bf, 16}
-# assert Nx.type(params["dense2"]["kernel"]) == {:bf, 16}
-# assert Nx.type(params["dense2"]["bias"]) == {:bf, 16}
-# assert Nx.type(params["batch_norm"]["gamma"]) == {:f, 32}
-# assert Nx.type(params["batch_norm"]["beta"]) == {:f, 32}
-# end
-# end
-
-# describe "inspection" do
-#   test "works" do
-#     policy = AMP.create_policy()
-
-#     assert inspect(policy) == """
-#            p=f32 c=f32 o=f32\
-#            """
-#   end
-# end
-# end
+defmodule Axon.MixedPrecisionTest do
+  use ExUnit.Case
+  doctest Axon.MixedPrecision
+end
diff --git a/test/axon_test.exs b/test/axon_test.exs
index c2a24df8..835aadd3 100644
--- a/test/axon_test.exs
+++ b/test/axon_test.exs
@@ -888,7 +888,7 @@ defmodule AxonTest do
              #Axon<
                inputs: %{"input_0" => {nil, 32, 10}}
                outputs: "lstm_output_sequence"
-               nodes: 6
+               nodes: 7
              >\
              """
     end
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_bf.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_bf.npy
new file mode 100644
index 00000000..f444dc0f
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_bf.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_bg.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_bg.npy
new file mode 100644
index 00000000..8bc03a47
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_bg.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_bi.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_bi.npy
new file mode 100644
index 00000000..203ae3ca
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_bi.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_bo.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_bo.npy
new file mode 100644
index 00000000..ab6e84d9
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_bo.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_input_c.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_input_c.npy
new file mode 100644
index 00000000..983d06ac
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_input_c.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_input_h.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_input_h.npy
new file mode 100644
index 00000000..0a167d7c
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_input_h.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_input_seq.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_input_seq.npy
new file mode 100644
index 00000000..5b96fe16
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_input_seq.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_output_c.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_output_c.npy
new file mode 100644
index 00000000..971ba017
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_output_c.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_output_h.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_output_h.npy
new file mode 100644
index 00000000..15b995ba
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_output_h.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_whf.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_whf.npy
new file mode 100644
index 00000000..e3883b75
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_whf.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_whg.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_whg.npy
new file mode 100644
index 00000000..1ce47cf0
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_whg.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_whi.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_whi.npy
new file mode 100644
index 00000000..80e7dc49
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_whi.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_who.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_who.npy
new file mode 100644
index 00000000..31ba7f5b
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_who.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_wif.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_wif.npy
new file mode 100644
index 00000000..fac88b34
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_wif.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_wig.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_wig.npy
new file mode 100644
index 00000000..f0f22966
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_wig.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_wii.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_wii.npy
new file mode 100644
index 00000000..982a6ced
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_wii.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_wio.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_wio.npy
new file mode 100644
index 00000000..49c363f4
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_wio.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_bf.npy b/test/fixtures/lstm_test/test_lstm_bf.npy
new file mode 100644
index 00000000..26f96ff3
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_bf.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_bg.npy b/test/fixtures/lstm_test/test_lstm_bg.npy
new file mode 100644
index 00000000..ba52553f
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_bg.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_bi.npy b/test/fixtures/lstm_test/test_lstm_bi.npy
new file mode 100644
index 00000000..25b16a5b
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_bi.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_bo.npy b/test/fixtures/lstm_test/test_lstm_bo.npy
new file mode 100644
index 00000000..06fd60cd
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_bo.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_input_c.npy b/test/fixtures/lstm_test/test_lstm_input_c.npy
new file mode 100644
index 00000000..23f8afa7
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_input_c.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_input_h.npy b/test/fixtures/lstm_test/test_lstm_input_h.npy
new file mode 100644
index 00000000..3f2c0b33
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_input_h.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_input_seq.npy b/test/fixtures/lstm_test/test_lstm_input_seq.npy
new file mode 100644
index 00000000..b8f4633d
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_input_seq.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_output_c.npy b/test/fixtures/lstm_test/test_lstm_output_c.npy
new file mode 100644
index 00000000..488515c3
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_output_c.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_output_h.npy b/test/fixtures/lstm_test/test_lstm_output_h.npy
new file mode 100644
index 00000000..98bb49ec
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_output_h.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_output_seq.npy b/test/fixtures/lstm_test/test_lstm_output_seq.npy
new file mode 100644
index 00000000..382acfa3
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_output_seq.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_whf.npy b/test/fixtures/lstm_test/test_lstm_whf.npy
new file mode 100644
index 00000000..1063b5a6
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_whf.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_whg.npy b/test/fixtures/lstm_test/test_lstm_whg.npy
new file mode 100644
index 00000000..470e25a2
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_whg.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_whi.npy b/test/fixtures/lstm_test/test_lstm_whi.npy
new file mode 100644
index 00000000..288a45cf
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_whi.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_who.npy b/test/fixtures/lstm_test/test_lstm_who.npy
new file mode 100644
index 00000000..95182e0d
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_who.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_wif.npy b/test/fixtures/lstm_test/test_lstm_wif.npy
new file mode 100644
index 00000000..2fee5e56
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_wif.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_wig.npy b/test/fixtures/lstm_test/test_lstm_wig.npy
new file mode 100644
index 00000000..0c485ad2
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_wig.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_wii.npy b/test/fixtures/lstm_test/test_lstm_wii.npy
new file mode 100644
index 00000000..159ab30b
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_wii.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_wio.npy b/test/fixtures/lstm_test/test_lstm_wio.npy
new file mode 100644
index 00000000..b441abf8
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_wio.npy differ
diff --git a/test/support/axon_test_util.ex b/test/support/axon_test_util.ex
index 633243fe..d5544b19 100644
--- a/test/support/axon_test_util.ex
+++ b/test/support/axon_test_util.ex
@@ -134,7 +134,7 @@ defmodule AxonTestUtil do
       {params, opt_state} = state
       gradients = Nx.Defn.grad(params, loss)
       {updates, new_state} = update_fn.(gradients, opt_state, params)
-      {Axon.Updates.apply_updates(updates, params), new_state}
+      {Polaris.Updates.apply_updates(updates, params), new_state}
     end
 
     {params, _} =
@@ -163,6 +163,13 @@ defmodule AxonTestUtil do
     end
   end
 
+  def random(shape, opts \\ []) do
+    Nx.Random.uniform_split(Nx.Random.key(:erlang.system_time()), 0.0, 1.0,
+      shape: shape,
+      type: opts[:type] || :f32
+    )
+  end
+
   def get_test_data(
         train_samples,
         test_samples,
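
Editor's note (not part of the diff): the test-support hunk above replaces the deprecated Axon.Updates with Polaris.Updates and adds a random/2 helper built on Nx.Random, the successor to the removed Nx.random_uniform/1 calls. Below is a minimal usage sketch of that helper; the shapes and the {:bf, 16} type are illustrative assumptions, not taken from the diff.

    # Illustrative use of the random/2 helper added in test/support/axon_test_util.ex.
    # The key is derived from :erlang.system_time/0, so draws differ between test
    # runs; a fixed key would be needed for reproducible fixtures.
    x = AxonTestUtil.random({1, 32})                  # defaults to type :f32
    y = AxonTestUtil.random({1, 1}, type: {:bf, 16})  # any Nx float type can be passed

    Nx.shape(x) #=> {1, 32}
    Nx.type(y)  #=> {:bf, 16}

Because Nx.Random.uniform_split/4 takes an explicit key and returns only the tensor, the helper stays a plain stateless function, at the cost of run-to-run reproducibility.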