       iex> Axon.MixedPrecision.create_policy(compute: {:bf, 16})
-      %Policy{params: {:f, 32}, compute: {:bf, 16}, output: {:f, 32}}
+      #Axon.MixedPrecision.Policy
""" def create_policy(opts \\ []) do params = opts[:params] || {:f, 32} @@ -121,4 +122,34 @@ defmodule Axon.MixedPrecision do def apply_policy(%Axon{} = axon, %Policy{} = policy) do apply_policy(%Axon{} = axon, %Policy{} = policy, & &1) end + + @doc """ + Casts the given container according to the given policy + and type. + + ## Examples + + iex> policy = Axon.MixedPrecision.create_policy(params: {:f, 16}) + iex> params = %{"dense" => %{"kernel" => Nx.tensor([1.0, 2.0, 3.0])}} + iex> params = Axon.MixedPrecision.cast(policy, params, :params) + iex> Nx.type(params["dense"]["kernel"]) + {:f, 16} + + iex> policy = Axon.MixedPrecision.create_policy(compute: {:bf, 16}) + iex> value = Nx.tensor([1.0, 2.0, 3.0]) + iex> value = Axon.MixedPrecision.cast(policy, value, :compute) + iex> Nx.type(value) + {:bf, 16} + + iex> policy = Axon.MixedPrecision.create_policy(output: {:bf, 16}) + iex> value = Nx.tensor([1.0, 2.0, 3.0]) + iex> value = Axon.MixedPrecision.cast(policy, value, :output) + iex> Nx.type(value) + {:bf, 16} + """ + def cast(%Policy{} = policy, tensor_or_container, variable_type) + when variable_type in [:compute, :params, :output] do + type = get_in(policy, [Access.key!(variable_type)]) + deep_new(tensor_or_container, fn x -> Nx.as_type(x, type) end) + end end diff --git a/lib/axon/mixed_precision/policy.ex b/lib/axon/mixed_precision/policy.ex index f34d2774..30f73256 100644 --- a/lib/axon/mixed_precision/policy.ex +++ b/lib/axon/mixed_precision/policy.ex @@ -10,9 +10,11 @@ defmodule Axon.MixedPrecision.Policy do def inspect(policy, _opts) do force_unfit( concat([ + "#Axon.MixedPrecision.Policy<", "p=#{Nx.Type.to_string(policy.params)} ", "c=#{Nx.Type.to_string(policy.compute)} ", - "o=#{Nx.Type.to_string(policy.output)}" + "o=#{Nx.Type.to_string(policy.output)}", + ">" ]) ) end diff --git a/lib/axon/optimizers.ex b/lib/axon/optimizers.ex index 9f3b12cc..8cadf3d6 100644 --- a/lib/axon/optimizers.ex +++ b/lib/axon/optimizers.ex @@ -1,59 +1,6 @@ defmodule Axon.Optimizers do - @moduledoc """ - Implementations of common gradient-based optimization algorithms. - - All of the methods in this module are written in terms of - the update methods defined in `Axon.Updates`. Axon treats - optimizers as the tuple: - - {init_fn, update_fn} - - where `init_fn` returns an initial optimizer state and `update_fn` - scales input gradients. `init_fn` accepts a model's parameters - and attaches state to each parameter. `update_fn` accepts - gradients, optimizer state, and current model parameters and - returns updated optimizer state and gradients. - - Custom optimizers are often created via the `Axon.Updates` API. 
- - ## Example - - Consider the following usage of the Adam optimizer in a basic - update function (assuming `objective` and the `dataset` are - defined elsewhere): - - defmodule Learning do - - import Nx.Defn - - defn init(params, init_fn) do - init_fn.(params) - end - - defn update(params, optimizer_state, inputs, targets, update_fn) do - {loss, gradient} = value_and_grad(params, &objective(&1, inputs, targets)) - {scaled_updates, new_optimizer_state} = update_fn.(gradient, optimizer_state, params) - {Axon.Updates.apply_updates(params, scaled_updates), new_optimizer_state, loss} - end - end - - model_params = Nx.random_uniform({784, 10}) - {init_fn, update_fn} = Axon.Optimizers.adam(0.005) - - optimizer_state = - Learning.init(params, init_fn) - - {new_params, new_optimizer_state, loss} = - Learning.update(params, optimizer_state, inputs, targets, update_fn) - - For a simpler approach, you can also use optimizers with the training API: - - model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(0.005)) - |> Axon.Loop.run(data, epochs: 10, compiler: EXLA) - - """ - alias Axon.Updates + @moduledoc false + alias Polaris.Updates @doc """ Adabelief optimizer. @@ -69,6 +16,7 @@ defmodule Axon.Optimizers do * [AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients](https://arxiv.org/abs/2010.07468) """ + @deprecated "Use Polaris.Optimizers.adabelief/1 instead" def adabelief(learning_rate \\ 1.0e-3, opts \\ []) do Updates.scale_by_belief(opts) |> scale_by_learning_rate(learning_rate) @@ -85,6 +33,7 @@ defmodule Axon.Optimizers do * [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](https://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) """ + @deprecated "Use Polaris.Optimizers.adagrad/1 instead" def adagrad(learning_rate \\ 1.0e-3, opts \\ []) do Updates.scale_by_rss(opts) |> scale_by_learning_rate(learning_rate) @@ -104,6 +53,7 @@ defmodule Axon.Optimizers do * [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980) """ + @deprecated "Use Polaris.Optimizers.adam/1 instead" def adam(learning_rate \\ 1.0e-3, opts \\ []) do Updates.scale_by_adam(opts) |> scale_by_learning_rate(learning_rate) @@ -120,6 +70,7 @@ defmodule Axon.Optimizers do * `:eps_root` - numerical stability term. Defaults to `0.0` * `:decay` - weight decay. Defaults to `0.0` """ + @deprecated "Use Polaris.Optimizers.adamw/1 instead" def adamw(learning_rate \\ 1.0e-3, opts \\ []) do {decay, opts} = Keyword.pop(opts, :decay, 0.0) @@ -144,6 +95,7 @@ defmodule Axon.Optimizers do * [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962) """ + @deprecated "Use Polaris.Optimizers.lamb/1 instead" def lamb(learning_rate \\ 1.0e-2, opts \\ []) do {decay, opts} = Keyword.pop(opts, :decay, 0.0) {min_norm, opts} = Keyword.pop(opts, :min_norm, 0.0) @@ -162,6 +114,7 @@ defmodule Axon.Optimizers do * `:eta` - used to compute variance of noise distribution. Defaults to `0.1` * `:gamma` - used to compute variance of noise distribution. 
Defaults to `0.55` """ + @deprecated "Use Polaris.Optimizers.noisy_sgd/1 instead" def noisy_sgd(learning_rate \\ 1.0e-2, opts \\ []) do scale_by_learning_rate(learning_rate) |> Updates.add_noise(opts) @@ -182,6 +135,7 @@ defmodule Axon.Optimizers do * [On the Variance of Adaptive Learning Rate and Beyond](https://arxiv.org/pdf/1908.03265.pdf) """ + @deprecated "Use Polaris.Optimizers.radam/1 instead" def radam(learning_rate \\ 1.0e-3, opts \\ []) do Updates.scale_by_radam(opts) |> scale_by_learning_rate(learning_rate) @@ -200,6 +154,7 @@ defmodule Axon.Optimizers do * `:decay` - EMA decay rate. Defaults to `0.9` * `:eps` - numerical stability term. Defaults to `1.0e-8` """ + @deprecated "Use Polaris.Optimizers.rmsprop/1 instead" def rmsprop(learning_rate \\ 1.0e-2, opts \\ []) do {centered, opts} = Keyword.pop(opts, :centered, false) {nesterov?, opts} = Keyword.pop(opts, :nesterov, false) @@ -227,6 +182,7 @@ defmodule Axon.Optimizers do to value of this term. * `:nesterov` - whether or not to use nesterov momentum. Defaults to `false` """ + @deprecated "Use Polaris.Optimizers.sgd/1 instead" def sgd(learning_rate \\ 1.0e-2, opts \\ []) do momentum = opts[:momentum] nesterov? = opts[:nesterov] || false @@ -254,6 +210,7 @@ defmodule Axon.Optimizers do * [Adaptive Methods for Nonconvex Optimization](https://papers.nips.cc/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf) """ + @deprecated "Use Polaris.Optimizers.yogi/1 instead" def yogi(learning_rate \\ 1.0e-2, opts \\ []) do Updates.scale_by_yogi(opts) |> scale_by_learning_rate(learning_rate) diff --git a/lib/axon/recurrent.ex b/lib/axon/recurrent.ex deleted file mode 100644 index 7a044466..00000000 --- a/lib/axon/recurrent.ex +++ /dev/null @@ -1,233 +0,0 @@ -defmodule Axon.Recurrent do - @moduledoc false - - import Nx.Defn - import Axon.Layers - - @doc """ - GRU Cell. - - When combined with `Axon.Recurrent.*_unroll`, implements a - GRU-based RNN. More memory efficient than traditional LSTM. - - ## References - - * [Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling](https://arxiv.org/pdf/1412.3555v1.pdf) - """ - @deprecated "Use Axon.Layers.gru_cell/7 instead" - defn gru_cell( - input, - carry, - input_kernel, - hidden_kernel, - bias, - gate_fn \\ &sigmoid/1, - activation_fn \\ &tanh/1 - ) do - {hidden} = carry - {wir, wiz, win} = input_kernel - {whr, whz, whn} = hidden_kernel - {br, bz, bin, bhn} = bias - - r = gate_fn.(dense(input, wir, br) + dense(hidden, whr, 0)) - z = gate_fn.(dense(input, wiz, bz) + dense(hidden, whz, 0)) - n = activation_fn.(dense(input, win, bin) + r * dense(hidden, whn, bhn)) - - new_h = (1.0 - z) * n + z * hidden - - {{new_h}, new_h} - end - - @doc """ - LSTM Cell. - - When combined with `Axon.Recurrent.*_unroll`, implements a - LSTM-based RNN. More memory efficient than traditional LSTM. 
- - ## References - - * [Long Short-Term Memory](http://www.bioinf.jku.at/publications/older/2604.pdf) - """ - @deprecated "Use Axon.Layers.lstm_cell/7 instead" - defn lstm_cell( - input, - carry, - input_kernel, - hidden_kernel, - bias, - gate_fn \\ &sigmoid/1, - activation_fn \\ &tanh/1 - ) do - {cell, hidden} = carry - {wii, wif, wig, wio} = input_kernel - {whi, whf, whg, who} = hidden_kernel - - {bi, bf, bg, bo} = bias - - i = gate_fn.(dense(input, wii, bi) + dense(hidden, whi, 0)) - f = gate_fn.(dense(input, wif, bf) + dense(hidden, whf, 0)) - g = activation_fn.(dense(input, wig, bg) + dense(hidden, whg, 0)) - o = gate_fn.(dense(input, wio, bo) + dense(hidden, who, 0)) - - new_c = f * cell + i * g - new_h = o * activation_fn.(new_c) - - {{new_c, new_h}, new_h} - end - - @doc """ - ConvLSTM Cell. - - When combined with `Axon.Recurrent.*_unroll`, implements a - ConvLSTM-based RNN. More memory efficient than traditional LSTM. - - ## Options - - * `:strides` - convolution strides. Defaults to `1`. - - * `:padding` - convolution padding. Defaults to `:same`. - - ## References - - * [Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting](https://arxiv.org/abs/1506.04214) - """ - @deprecated "Use Axon.Layers.conv_lstm_cell/6 instead" - defn conv_lstm_cell(input, carry, input_kernel, hidden_kernel, bias, opts \\ []) do - opts = keyword!(opts, strides: 1, padding: :same) - - {ih} = input_kernel - {hh} = hidden_kernel - {bi} = bias - - {{cell, hidden}, input} = rank_down({carry, input}) - - gates = - Nx.add( - conv(input, ih, bi, strides: opts[:strides], padding: opts[:padding]), - conv(hidden, hh, 0, strides: opts[:strides], padding: opts[:padding]) - ) - - {i, g, f, o} = split_gates(gates) - - f = sigmoid(f + 1) - new_c = f * cell + sigmoid(i) * tanh(g) - new_h = sigmoid(o) * tanh(new_c) - - rank_up({{new_c, new_h}, new_h}) - end - - defnp split_gates(gates) do - transform(gates, fn gates -> - channels = elem(Nx.shape(gates), 1) - split_every = div(channels, 4) - - split_dims = - for i <- 0..3 do - {i * split_every, split_every} - end - - split_dims - |> Enum.map(fn {start, len} -> Nx.slice_along_axis(gates, start, len, axis: 1) end) - |> List.to_tuple() - end) - end - - defnp rank_down(rnn_data) do - transform(rnn_data, fn {{cell, hidden}, input} -> - [cell, hidden, input] = - for tensor <- [cell, hidden, input] do - Nx.squeeze(tensor, axes: [1]) - end - - {{cell, hidden}, input} - end) - end - - defnp rank_up(rnn_data) do - transform(rnn_data, fn {{cell, hidden}, input} -> - [cell, hidden, input] = - for tensor <- [cell, hidden, input] do - new_shape = - Nx.shape(tensor) - |> Tuple.insert_at(1, 1) - - Nx.reshape(tensor, new_shape) - end - - {{cell, hidden}, input} - end) - end - - @doc """ - Dynamically unrolls an RNN. - - Unrolls implement a `scan` operation which applies a - transformation on the leading axis of `input_sequence` carrying - some state. In this instance `cell_fn` is an RNN cell function - such as `lstm_cell` or `gru_cell`. - - This function will make use of an `defn` while-loop such and thus - may be more efficient for long sequences. 
- """ - @deprecated "Use Axon.Layers.dynamic_unroll/6 instead" - defn dynamic_unroll(cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias) do - time_steps = transform(Nx.shape(input_sequence), &elem(&1, 1)) - - feature_dims = transform(Nx.rank(input_sequence), &List.duplicate(0, &1 - 2)) - - initial_shape = - transform({cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias}, fn - {cell_fn, inp, carry, inp_kernel, hid_kernel, bias} -> - seq = Nx.slice_along_axis(inp, 0, 1, axis: 1) - {_, seq} = cell_fn.(seq, carry, inp_kernel, hid_kernel, bias) - put_elem(Nx.shape(seq), 1, elem(Nx.shape(inp), 1)) - end) - - init_sequence = Nx.broadcast(0.0, initial_shape) - i = Nx.tensor(0) - - {_, carry, output, _, _, _, _} = - while {i, carry, init_sequence, input_sequence, input_kernel, recurrent_kernel, bias}, - Nx.less(i, time_steps) do - sequence = Nx.slice_along_axis(input_sequence, i, 1, axis: 1) - indices = transform({feature_dims, i}, fn {feature_dims, i} -> [0, i] ++ feature_dims end) - {carry, output} = cell_fn.(sequence, carry, input_kernel, recurrent_kernel, bias) - update_sequence = Nx.put_slice(init_sequence, indices, output) - {i + 1, carry, update_sequence, input_sequence, input_kernel, recurrent_kernel, bias} - end - - {carry, output} - end - - @doc """ - Statically unrolls an RNN. - - Unrolls implement a `scan` operation which applies a - transformation on the leading axis of `input_sequence` carrying - some state. In this instance `cell_fn` is an RNN cell function - such as `lstm_cell` or `gru_cell`. - - This function inlines the unrolling of the sequence such that - the entire operation appears as a part of the compilation graph. - This makes it suitable for shorter sequences. - """ - @deprecated "Use Axon.Layers.static_unroll/6 instead" - defn static_unroll(cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias) do - transform( - {cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias}, - fn {cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias} -> - time_steps = elem(Nx.shape(input_sequence), 1) - - {carry, outputs} = - for t <- 0..(time_steps - 1), reduce: {carry, []} do - {carry, outputs} -> - input = Nx.slice_along_axis(input_sequence, t, 1, axis: 1) - {carry, output} = cell_fn.(input, carry, input_kernel, recurrent_kernel, bias) - {carry, [output | outputs]} - end - - {carry, Nx.concatenate(Enum.reverse(outputs), axis: 1)} - end - ) - end -end diff --git a/lib/axon/schedules.ex b/lib/axon/schedules.ex index 892ed3b7..62f07e07 100644 --- a/lib/axon/schedules.ex +++ b/lib/axon/schedules.ex @@ -1,28 +1,6 @@ defmodule Axon.Schedules do - @moduledoc """ - Parameter Schedules. - - Parameter schedules are often used to anneal hyperparameters - such as the learning rate during the training process. Schedules - provide a mapping from the current time step to a learning rate - or another hyperparameter. - - Choosing a good learning rate and consequently a good learning - rate schedule is typically a process of trial and error. Learning - rates should be relatively small such that the learning curve - does not oscillate violently during the training process, but - not so small that learning proceeds too slowly. Using a - schedule slowly decreases oscillations during the training - process such that, as the model converges, training also - becomes more stable. - - All of the functions in this module are implemented as - numerical functions and can be JIT or AOT compiled with - any supported `Nx` compiler. 
- """ - + @moduledoc false import Nx.Defn - import Axon.Shared @doc """ Linear decay schedule. @@ -33,6 +11,7 @@ defmodule Axon.Schedules do * `:steps` - total number of decay steps. Defaults to `1000` """ + @deprecated "Use Polaris.Schedules.linear_decay/2 instead" def linear_decay(init_value, opts \\ []) do &apply_linear_decay(&1, [{:init_value, init_value} | opts]) end @@ -70,6 +49,7 @@ defmodule Axon.Schedules do * `:staircase` - discretize outputs. Defaults to `false` """ + @deprecated "Use Polaris.Schedules.exponential_decay/2 instead" def exponential_decay(init_value, opts \\ []) do &apply_exponential_decay(&1, [{:init_value, init_value} | opts]) end @@ -86,7 +66,7 @@ defmodule Axon.Schedules do init_value = opts[:init_value] rate = opts[:decay_rate] - staircase? = to_predicate(opts[:staircase]) + staircase? = opts[:staircase] k = opts[:transition_steps] start = opts[:transition_begin] @@ -104,7 +84,7 @@ defmodule Axon.Schedules do decayed_value = rate - |> Nx.power(p) + |> Nx.pow(p) |> Nx.multiply(init_value) Nx.select( @@ -132,6 +112,7 @@ defmodule Axon.Schedules do * [SGDR: Stochastic Gradient Descent with Warm Restarts](https://openreview.net/forum?id=Skq89Scxx¬eId=Skq89Scxx) """ + @deprecated "Use Polaris.Schedules.cosine_decay/2 instead" def cosine_decay(init_value, opts \\ []) do &apply_cosine_decay(&1, [{:init_value, init_value} | opts]) end @@ -160,13 +141,14 @@ defmodule Axon.Schedules do $$\gamma(t) = \gamma_0$$ """ + @deprecated "Use Polaris.Schedules.constant/2 instead" def constant(init_value, opts \\ []) do &apply_constant(&1, [{:init_value, init_value} | opts]) end defnp apply_constant(_step, opts \\ []) do opts = keyword!(opts, init_value: 0.01) - Nx.tensor(opts[:init_value]) + opts[:init_value] end @doc ~S""" @@ -185,6 +167,7 @@ defmodule Axon.Schedules do $k$ in above formulation. Defaults to `10` """ + @deprecated "Use Polaris.Schedules.polynomial_decay/2 instead" def polynomial_decay(init_value, opts \\ []) do &apply_polynomial_decay(&1, [{:init_value, init_value} | opts]) end @@ -211,7 +194,7 @@ defmodule Axon.Schedules do |> Nx.divide(k) |> Nx.negate() |> Nx.add(1) - |> Nx.power(p) + |> Nx.pow(p) |> Nx.multiply(Nx.subtract(init_value, end_value)) |> Nx.add(end_value) end diff --git a/lib/axon/shape.ex b/lib/axon/shape.ex index 5e7314b9..d06471f6 100644 --- a/lib/axon/shape.ex +++ b/lib/axon/shape.ex @@ -1,6 +1,8 @@ defmodule Axon.Shape do @moduledoc false + import Nx.Defn + # Collection of shape calculations for calculating the # output and trainable parameter shapes for high-level # layers. @@ -319,8 +321,11 @@ defmodule Axon.Shape do the input bias shape is a vector, otherwise we'll just attempt to let it broadcast itself. """ - def conv_bias_reshape(input_shape, spatial_rank, channels) do - case input_shape do + deftransform conv_bias_reshape(input, bias, channels) do + bias_shape = Nx.shape(bias) + spatial_rank = Nx.rank(input) - 2 + + case bias_shape do {} -> {} @@ -338,11 +343,51 @@ defmodule Axon.Shape do end end + @doc """ + Calculates the permutation options to pass to convolution + based on channels configuration. + + It returns both the input/output permutation and the kernel + permutation. 
+ """ + deftransform conv_permutations(input, channels) do + rank = Nx.rank(input) + + case channels do + :first -> + perm = Enum.to_list(0..(rank - 1)) + {perm, perm} + + :last -> + spatial = Enum.to_list(1..(rank - 2)//1) + perm = [0, rank - 1 | spatial] + kernel_perm = [rank - 1, rank - 2] ++ Enum.to_list(0..(rank - 3)//1) + {perm, kernel_perm} + + invalid -> + raise ArgumentError, "invalid channel configuration, #{inspect(invalid)}" + end + end + + @doc """ + Calculates strides for transposed convolution. + """ + deftransform conv_transpose_strides(input, strides) do + rank = Nx.rank(input) - 2 + + case strides do + [_ | _] = strides -> strides + strides -> List.duplicate(strides, rank) + end + end + @doc """ Calculates the padding needed for a transposed convolution. """ - def conv_transpose_padding(kernel_shape, kernel_dilation, strides, padding, channels) - when padding in [:valid, :same] do + deftransform conv_transpose_padding(kernel, kernel_dilation, strides, padding, channels) + when padding in [:valid, :same] do + kernel_shape = Nx.shape(kernel) + kernel_spatial_dims = case channels do :first -> @@ -395,7 +440,7 @@ defmodule Axon.Shape do end end - def conv_transpose_padding(_, _, _, padding, _), do: padding + deftransform conv_transpose_padding(_, _, _, padding, _), do: padding @doc """ Calculates the shape of a depthwise convolution kernel given the @@ -632,7 +677,9 @@ defmodule Axon.Shape do across batch or channel dimensions, so we just specify a size of `1` for each of those. """ - def pool_window_size(window, spatial_rank, channels) do + deftransform pool_window_size(input, window, channels) do + spatial_rank = Nx.rank(input) - 2 + spatial_dims = case window do x when is_integer(x) -> @@ -655,20 +702,70 @@ defmodule Axon.Shape do end @doc """ - Computes the window size from the given parent shape. + Calculates the window strides of a pooling operation. """ - def adaptive_pool_window_size(parent_shape, nil, channels) do + deftransform pool_window_strides(input, strides, window_dimensions, channels) do + rank = Nx.rank(input) + + case {strides, channels} do + {nil, _} -> Tuple.to_list(window_dimensions) + {[_ | _] = strides, :first} -> [1, 1 | strides] + {[_ | _] = strides, :last} -> [1 | strides] ++ [1] + {strides, :first} -> [1, 1 | List.duplicate(strides, rank - 2)] + {strides, :last} -> [1 | List.duplicate(strides, rank - 2)] ++ [1] + end + end + + @doc """ + Calculates window dilations of a pooling operation. + """ + deftransform pool_window_dilations(input, window_dilations, channels) do + rank = Nx.rank(input) + + case {window_dilations, channels} do + {nil, _} -> List.duplicate(1, rank) + {[_ | _] = dilations, :first} -> [1, 1 | dilations] + {[_ | _] = dilations, :last} -> [1 | dilations] ++ [1] + {dilations, :first} -> [1, 1 | List.duplicate(dilations, rank - 2)] + {dilations, :last} -> [1 | List.duplicate(dilations, rank - 2)] ++ [1] + end + end + + @doc """ + Calculates padding of a pooling operation based on input padding + and channels configuration. + """ + deftransform pool_window_padding(padding, channels) do + case {padding, channels} do + {:same, _} -> :same + {:valid, _} -> :valid + {padding, :first} -> [{0, 0}, {0, 0} | padding] + {padding, :last} -> [{0, 0} | padding] ++ [{0, 0}] + end + end + + @doc """ + Computes the adaptive pooling output size from the given parent + shape, output shape and channels configuration. 
+ """ + deftransform adaptive_pool_output_size(input, nil, channels) do + parent_shape = Nx.shape(input) + case channels do :first -> - parent_shape |> Tuple.delete_at(0) |> Tuple.delete_at(0) + parent_shape + |> Tuple.delete_at(0) + |> Tuple.delete_at(0) :last -> - parent_shape |> Tuple.delete_at(tuple_size(parent_shape) - 1) |> Tuple.delete_at(0) + parent_shape + |> Tuple.delete_at(tuple_size(parent_shape) - 1) + |> Tuple.delete_at(0) end end - def adaptive_pool_window_size(parent_shape, output_size, _channels) do - inner_rank = Nx.rank(parent_shape) - 2 + deftransform adaptive_pool_output_size(input, output_size, _channels) do + inner_rank = Nx.rank(input) - 2 tuple_or_duplicate(:output_size, output_size, inner_rank) end @@ -684,7 +781,10 @@ defmodule Axon.Shape do This preserves the size of the channel/batch dimension. """ - def adaptive_pool_window_strides(input_shape, output_spatial, spatial_rank, channels) do + deftransform adaptive_pool_window_strides(input, output_spatial, channels) do + input_shape = Nx.shape(input) + spatial_rank = Nx.rank(input) - 2 + idx = if channels == :first do 1 @@ -733,13 +833,15 @@ defmodule Axon.Shape do This preserves the size of the channel/batch dimension. """ - def adaptive_pool_window_size( - input_shape, - stride, - output_spatial, - spatial_rank, - channels - ) do + deftransform adaptive_pool_window_size( + input, + stride, + output_spatial, + channels + ) do + input_shape = Nx.shape(input) + spatial_rank = Nx.rank(input) - 2 + strides = case channels do :first -> @@ -813,16 +915,22 @@ defmodule Axon.Shape do @doc """ Calculates the reduction axes for batch normalization. """ - def batch_norm_axes(axes, channel_index) do - axes - |> Enum.filter(&(&1 != channel_index)) + deftransform batch_norm_axes(input, channel_index) do + axis = Nx.Shape.normalize_axis(Nx.shape(input), channel_index, Nx.names(input)) + + input + |> Nx.axes() + |> Enum.filter(&(&1 != axis)) end @doc """ Calculates the reduction axes for instance normalization. """ - def instance_norm_axes(axes, channel_index) do - reduction_axes = axes -- [0, channel_index] + deftransform instance_norm_axes(input, channel_index) do + axis = Nx.Shape.normalize_axis(Nx.shape(input), channel_index, Nx.names(input)) + axes = Nx.axes(input) + + reduction_axes = axes -- [0, axis] if reduction_axes == [] do raise ArgumentError, "rank of input shape must be at least 3" @@ -834,14 +942,17 @@ defmodule Axon.Shape do @doc """ Calculates the reduction axes for group normalization. """ - def group_norm_axes(rank, channel_index) do - Enum.to_list(1..(rank - 1)) -- [channel_index] + deftransform group_norm_axes(x, channel_index) do + Enum.to_list(1..(Nx.rank(x) - 1)) -- [channel_index] end @doc """ Calculates the reshape for group normalization. """ - def group_norm_shape(shape, num_groups, channel_index) do + deftransform group_norm_shape(input, num_groups, channel_index) do + shape = Nx.shape(input) + channel_index = Nx.Shape.normalize_axis(Nx.shape(input), channel_index, Nx.names(input)) + channels = elem(shape, channel_index) group_size = div(channels, num_groups) @@ -850,42 +961,25 @@ defmodule Axon.Shape do |> Tuple.insert_at(channel_index + 1, group_size) end - @doc """ - Calculates the shape after a flatten layer, which - flattens the non-minibatch dimensions into a single - dimension. 
- - ## Examples - - iex> Axon.Shape.flatten({nil, 1, 28, 28}) - {nil, 784} - - iex> Axon.Shape.flatten({32, 128}) - {32, 128} - - iex> Axon.Shape.flatten({nil, 10, 10}) - {nil, 100} - """ - def flatten(shape) do - out_units = Nx.size(Tuple.delete_at(shape, 0)) - - {elem(shape, 0), out_units} - end - @doc """ Computes split sizes for the given splits. """ - def split(shape, n, axis) do + deftransform split(input, index, splits, axis) do + shape = Nx.shape(input) + nil_names = List.duplicate(nil, Nx.rank(shape)) axis = Nx.Shape.normalize_axis(shape, axis, nil_names) - unless rem(elem(shape, axis), n) == 0 do + unless rem(elem(shape, axis), splits) == 0 do raise ArgumentError, - "unable to create #{n} even splits along axis #{axis}" <> + "unable to create #{splits} even splits along axis #{axis}" <> " of size #{elem(shape, axis)}" end - div(elem(shape, axis), n) + slice_size = div(elem(shape, axis), splits) + + offset = index * slice_size + {offset, slice_size} end @doc """ @@ -898,13 +992,15 @@ defmodule Axon.Shape do ## Examples - iex> Axon.Shape.spatial_dropout_noise_shape({nil, 3, 28, 28}, :first) - {nil, 1, 28, 28} + iex> Axon.Shape.spatial_dropout_noise_shape({1, 3, 28, 28}, :first) + {1, 1, 28, 28} - iex> Axon.Shape.spatial_dropout_noise_shape({nil, 28, 28, 3}, :last) - {nil, 28, 28, 1} + iex> Axon.Shape.spatial_dropout_noise_shape({1, 28, 28, 3}, :last) + {1, 28, 28, 1} """ - def spatial_dropout_noise_shape(input_shape, channels) do + deftransform spatial_dropout_noise_shape(input, channels) do + input_shape = Nx.shape(input) + if channels == :first do :erlang.setelement(2, input_shape, 1) else @@ -969,7 +1065,23 @@ defmodule Axon.Shape do " got #{inspect(shape)}" end - {elem(shape, 0), 1, units} + {elem(shape, 0), units} + end + + @doc """ + Returns the reduction axes for a global pooling operation + based on the input rank and channels configuration. + """ + deftransform global_pool_axes(input, channels) do + rank = Nx.rank(input) + + case channels do + :last -> + Enum.to_list(1..(rank - 2)) + + :first -> + Enum.to_list(2..(rank - 1)) + end end defp tuple_or_duplicate(key, tuple_or_integer, rank) do diff --git a/lib/axon/shared.ex b/lib/axon/shared.ex index 2c5e7421..6279488a 100644 --- a/lib/axon/shared.ex +++ b/lib/axon/shared.ex @@ -11,145 +11,114 @@ defmodule Axon.Shared do @doc """ Asserts `lhs` has same shape as `rhs`. """ - defn assert_shape!(caller, lhs_name, lhs, rhs_name, rhs) do - transform( - {lhs, rhs}, - fn {lhs, rhs} -> - lhs = Nx.shape(lhs) - rhs = Nx.shape(rhs) - - unless Elixir.Kernel.==(lhs, rhs) do - raise ArgumentError, - "#{caller}: expected input shapes #{lhs_name} and #{rhs_name}" <> - " to be equal, got #{inspect(lhs)} != #{inspect(rhs)}" - end - end - ) + deftransform assert_shape!(caller, lhs_name, lhs, rhs_name, rhs) do + lhs = Nx.shape(lhs) + rhs = Nx.shape(rhs) + + unless lhs == rhs do + raise ArgumentError, + "#{caller}: expected input shapes #{lhs_name} and #{rhs_name}" <> + " to be equal, got #{inspect(lhs)} != #{inspect(rhs)}" + end end @doc """ Asserts all shapes are equal. """ - defn assert_shape!(caller, shape_names, shapes) do - transform(shapes, fn [shape | shapes] -> - equal? = - Enum.all?(shapes, fn cur_shape -> - Elixir.Kernel.==(Nx.shape(cur_shape), Nx.shape(shape)) - end) - - unless equal? do - raise ArgumentError, - "#{caller}: expected all input shapes #{inspect(shape_names)}" <> - " to be equal, got #{inspect(shapes)}" - end - end) + deftransform assert_shape!(caller, shape_names, [shape | shapes]) do + equal? 
= + Enum.all?(shapes, fn cur_shape -> + Nx.shape(cur_shape) == Nx.shape(shape) + end) + + unless equal? do + raise ArgumentError, + "#{caller}: expected all input shapes #{inspect(shape_names)}" <> + " to be equal, got #{inspect(shapes)}" + end end @doc """ Asserts `inp` has explicit rank `rank`. """ - defn assert_rank!(caller, inp_name, inp, rank) do - transform( - {inp, rank}, - fn {x, y} -> - x = Nx.rank(x) - - unless Elixir.Kernel.==(x, y) do - raise ArgumentError, - "#{caller}: expected #{inp_name} to have rank equal to #{y}," <> - " got #{x} != #{y}" - end - end - ) + deftransform assert_rank!(caller, inp_name, inp, rank) do + x = Nx.rank(inp) + + unless x == rank do + raise ArgumentError, + "#{caller}: expected #{inp_name} to have rank equal to #{rank}," <> + " got #{x} != #{rank}" + end end @doc """ Asserts `lhs` has same rank as `rhs`. """ - defn assert_equal_rank!(caller, lhs_name, lhs, rhs_name, rhs) do - transform( - {lhs, rhs}, - fn {x, y} -> - x = if is_integer(x), do: x, else: Nx.rank(x) - y = if is_integer(y), do: y, else: Nx.rank(y) - - unless Elixir.Kernel.>=(x, y) do - raise ArgumentError, - "#{caller}: expected #{lhs_name} and #{rhs_name} ranks to be equal" <> - " got #{x} != #{y}" - end - end - ) + deftransform assert_equal_rank!(caller, lhs_name, lhs, rhs_name, rhs) do + x = if is_integer(lhs), do: lhs, else: Nx.rank(lhs) + y = if is_integer(rhs), do: rhs, else: Nx.rank(rhs) + + unless x >= y do + raise ArgumentError, + "#{caller}: expected #{lhs_name} and #{rhs_name} ranks to be equal" <> + " got #{x} != #{y}" + end end @doc """ Asserts all ranks are equal. """ - defn assert_equal_rank!(caller, rank_names, ranks) do - transform(ranks, fn [rank | ranks] -> - equal? = - Enum.all?(ranks, fn cur_rank -> - Elixir.Kernel.==(Nx.rank(cur_rank), Nx.rank(rank)) - end) - - unless equal? do - raise ArgumentError, - "#{caller}: expected all input ranks #{inspect(rank_names)}" <> - " to be equal, got #{inspect(ranks)}" - end - end) + deftransform assert_equal_rank!(caller, rank_names, [rank | ranks]) do + equal? = + Enum.all?(ranks, fn cur_rank -> + Nx.rank(cur_rank) == Nx.rank(rank) + end) + + unless equal? do + raise ArgumentError, + "#{caller}: expected all input ranks #{inspect(rank_names)}" <> + " to be equal, got #{inspect(ranks)}" + end end @doc """ Asserts `lhs` has at least rank `rhs`. """ - defn assert_min_rank!(caller, name, lhs, rhs) do - transform( - {lhs, rhs}, - fn {x, y} -> - x = if is_integer(x), do: x, else: Nx.rank(x) - y = if is_integer(y), do: y, else: Nx.rank(y) - - unless Elixir.Kernel.>=(x, y) do - raise ArgumentError, - "#{caller}: expected #{name} shape to have at least rank #{y}, got rank #{x}" - end - end - ) - end + deftransform assert_min_rank!(caller, name, lhs, rhs) do + x = if is_integer(lhs), do: lhs, else: Nx.rank(lhs) + y = if is_integer(rhs), do: rhs, else: Nx.rank(rhs) - @doc """ - Transforms the given Elixir value into a scalar predicate. - """ - defn to_predicate(term) do - transform(term, fn term -> if term, do: 1, else: 0 end) + unless x >= y do + raise ArgumentError, + "#{caller}: expected #{name} shape to have at least rank #{y}, got rank #{x}" + end end @doc """ Creates a zeros-like structure which matches the structure of the input. 
""" - defn zeros_like(params) do - transform( - params, - &deep_new(&1, fn x -> - fun = Axon.Initializers.zeros() - fun.(Nx.shape(x), Nx.type(x)) - end) - ) + deftransform zeros_like(params, opts \\ []) do + opts = Keyword.validate!(opts, [:type]) + fun = Axon.Initializers.zeros() + + deep_new(params, fn x -> + type = opts[:type] || Nx.type(x) + fun.(Nx.shape(x), type) + end) end @doc """ Creates a fulls-like tuple of inputs. """ - defn fulls_like(params, value) do - transform( - params, - &deep_new(&1, fn x -> - fun = Axon.Initializers.full(value) - fun.(Nx.shape(x), Nx.type(x)) - end) - ) + deftransform fulls_like(params, value, opts \\ []) do + opts = Keyword.validate!(opts, [:type]) + fun = Axon.Initializers.full(value) + + deep_new(params, fn x -> + type = opts[:type] || Nx.type(x) + fun.(Nx.shape(x), type) + end) end @doc """ @@ -259,18 +228,17 @@ defmodule Axon.Shared do end end - ## Numerical Helpers + ## List transforms in defn - # TODO: These should be contained somewhere else, like another library + deftransform list_duplicate(value, size) do + List.duplicate(value, size) + end - defn logsumexp(x, opts \\ []) do - opts = keyword!(opts, axes: [], keep_axes: false) + deftransform list_wrap(value), do: List.wrap(value) - x - |> Nx.exp() - |> Nx.sum(opts) - |> Nx.log() - end + ## Numerical Helpers + + # TODO: These should be contained somewhere else, like another library defn xlogy(x, y) do x_ok = Nx.not_equal(x, 0.0) @@ -282,25 +250,20 @@ defmodule Axon.Shared do defn reciprocal(x), do: Nx.divide(1, x) defn normalize(input, mean, variance, gamma, bias, opts \\ []) do - opts = keyword!(opts, epsilon: 1.0e-6) + [epsilon: epsilon] = keyword!(opts, epsilon: 1.0e-6) + # The select is so that we improve numerical stability by clipping + # both insignificant values of variance and NaNs to epsilon. scale = - variance - |> Nx.add(opts[:epsilon]) - |> Nx.rsqrt() - |> Nx.multiply(gamma) - - input - |> Nx.subtract(mean) - |> Nx.multiply(scale) - |> Nx.add(bias) + gamma * Nx.select(variance >= epsilon, Nx.rsqrt(variance + epsilon), Nx.rsqrt(epsilon)) + + scale * (input - mean) + bias end defn mean_and_variance(input, opts \\ []) do opts = keyword!(opts, [:axes]) mean = Nx.mean(input, axes: opts[:axes], keep_axes: true) - mean_of_squares = Nx.mean(Nx.multiply(input, input), axes: opts[:axes], keep_axes: true) - square_of_mean = Nx.multiply(mean, mean) - {mean, mean_of_squares - square_of_mean} + var = Nx.variance(input, axes: opts[:axes], keep_axes: true) + {mean, var} end end diff --git a/lib/axon/updates.ex b/lib/axon/updates.ex index bd004404..06e0b144 100644 --- a/lib/axon/updates.ex +++ b/lib/axon/updates.ex @@ -1,89 +1,6 @@ defmodule Axon.Updates do - @moduledoc ~S""" - Parameter update methods. - - Update methods transform the input tensor in some way, - usually by scaling or shifting the input with respect - to some input state. Update methods are composed - to create more advanced optimization methods such as AdaGrad - or Adam. Each update returns a tuple: - - {init_fn, update_fn} - - Which represent a state initialization and state update - function respectively. While each method in the Updates - API is a regular Elixir function, the two methods they - return are implemented as `defn`, so they can be accelerated - using any Nx backend or compiler. - - Update methods are just combinators that can be arbitrarily - composed to create complex optimizers. 
For example, the Adam - optimizer in Axon.Optimizers is implemented as: - - def adam(learning_rate, opts \\ []) do - Updates.scale_by_adam(opts) - |> Updates.scale(-learning_rate) - end - - Updates are maps of updates, often associated with parameters of - the same names. Using `Axon.Updates.apply_updates/3` will merge updates - and parameters by adding associated parameters and updates, and - ensuring any given model state is preserved. - - ## Custom combinators - - You can create your own combinators using the `stateless/2` and - `stateful/3` primitives. Every update method in this module is - implemented in terms of one of these two primitives. - - `stateless/2` represents a stateless update: - - def scale(combinator \\ Axon.Updates.identity(), step_size) do - stateless(combinator, &apply_scale(&1, &2, step_size)) - end - - defnp apply_scale(x, _params, step) do - transform( - {x, step}, - fn {updates, step} -> - deep_new(updates, fn x -> Nx.multiply(x, step) end) - end - ) - end - - Notice how the function given to `stateless/2` is defined within `defn`. - This is what allows the anonymous functions returned by `Axon.Updates` - to be used inside `defn`. - - `stateful/3` represents a stateful update and follows the same pattern: - - def my_stateful_update(updates) do - Axon.Updates.stateful(updates, &init_my_update/1, &apply_my_update/2) - end - - defnp init_my_update(params) do - state = zeros_like(params) - %{state: state} - end + @moduledoc false - defnp apply_my_update(updates, state) do - new_state = deep_new(state, fn v -> Nx.add(v, 0.01) end) - updates = transform({updates, new_state}, fn {updates, state} -> - deep_merge(updates, state, fn g, z -> Nx.multiply(g, z) end) - end) - {updates, %{state: new_state}} - end - - State associated with individual parameters should have keys that match the - keys of the parameter. For example, if you have parameters `%{kernel: kernel}` - with associated states `mu` and `nu` representing the first and second moments, - your state should look something like: - - %{ - mu: %{kernel: kernel_mu} - nu: %{kernel: kernel_nu} - } - """ import Nx.Defn import Axon.Shared @@ -92,6 +9,7 @@ defmodule Axon.Updates do $$f(x_i) = \alpha x_i$$ """ + @deprecated "Use Polaris.Updates.scale/2 instead" def scale(combinator \\ identity(), step_size) do stateless(combinator, &apply_scale(&1, &2, step_size)) end @@ -106,6 +24,7 @@ defmodule Axon.Updates do $$f(x_i) = \alpha x_i$$ """ + @deprecated "Use Polaris.Updates.scale_by_state/1 instead" def scale_by_state(combinator_or_step) def scale_by_state(step) when is_number(step) do @@ -144,6 +63,7 @@ defmodule Axon.Updates do * [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980) """ + @deprecated "Use Polaris.Updates.scale_by_adam/1 instead" def scale_by_adam(combinator_or_opts \\ []) def scale_by_adam(opts) when is_list(opts) do @@ -165,8 +85,8 @@ defmodule Axon.Updates do end defnp init_scale_by_adam(params) do - mus = zeros_like(params) - nus = zeros_like(params) + mus = zeros_like(params, type: :f32) + nus = zeros_like(params, type: :f32) count = Nx.tensor(0) %{mu: mus, nu: nus, count: count} end @@ -196,6 +116,7 @@ defmodule Axon.Updates do * `:eps` - numerical stability term. 
Defaults to `1.0e-7` """ + @deprecated "Use Polaris.Updates.scale_by_rss/1 instead" def scale_by_rss(combinator_or_opts \\ []) def scale_by_rss(opts) when is_list(opts) do @@ -219,7 +140,7 @@ defmodule Axon.Updates do end defnp init_scale_by_rss(params, value) do - sum_of_squares = fulls_like(params, value) + sum_of_squares = fulls_like(params, value, type: :f32) %{sum_of_squares: sum_of_squares} end @@ -227,7 +148,7 @@ defmodule Axon.Updates do opts = keyword!(opts, eps: 1.0e-7) eps = opts[:eps] - sum_of_squares = deep_merge(x, sum_of_squares, fn g, z -> Nx.power(g, 2) + z end) + sum_of_squares = deep_merge(x, sum_of_squares, fn g, z -> Nx.pow(g, 2) + z end) inv_sqrt_squares = deep_new(sum_of_squares, fn z -> Nx.rsqrt(z + eps) end) @@ -255,6 +176,7 @@ defmodule Axon.Updates do * [Overview of mini-batch gradient descent](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) """ + @deprecated "Use Polaris.Updates.scale_by_rms/1 instead" def scale_by_rms(combinator_or_opts \\ []) def scale_by_rms(opts) when is_list(opts) do @@ -278,7 +200,7 @@ defmodule Axon.Updates do end defnp init_scale_by_rms(params, scale) do - nu = fulls_like(params, scale) + nu = fulls_like(params, scale, type: :f32) %{nu: nu} end @@ -312,6 +234,7 @@ defmodule Axon.Updates do * [AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients](https://arxiv.org/abs/2010.07468) """ + @deprecated "Use Polaris.Updates.scale_by_belief/1 instead" def scale_by_belief(combinator_or_opts \\ []) def scale_by_belief(opts) when is_list(opts) do @@ -333,8 +256,8 @@ defmodule Axon.Updates do end defnp init_scale_by_belief(params) do - mus = zeros_like(params) - nus = zeros_like(params) + mus = zeros_like(params, type: :f32) + nus = zeros_like(params, type: :f32) count = Nx.tensor(0) %{mu: mus, nu: nus, count: count} end @@ -371,6 +294,7 @@ defmodule Axon.Updates do * [Overview of mini-batch gradient descent](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) """ + @deprecated "Use Polaris.Updates.scale_by_stddev/1 instead" def scale_by_stddev(combinator_or_opts \\ []) def scale_by_stddev(opts) when is_list(opts) do @@ -394,8 +318,8 @@ defmodule Axon.Updates do end defnp init_scale_by_stddev(params, value) do - mu = zeros_like(params) - nu = fulls_like(params, value) + mu = zeros_like(params, type: :f32) + nu = fulls_like(params, value, type: :f32) %{mu: mu, nu: nu} end @@ -409,7 +333,7 @@ defmodule Axon.Updates do mu_nu = deep_merge(mu, nu, fn m, n -> - Nx.rsqrt(-Nx.power(m, 2) + n + eps) + Nx.rsqrt(-Nx.pow(m, 2) + n + eps) end) x = deep_merge(x, mu_nu, fn g, mn -> g * mn end) @@ -425,6 +349,7 @@ defmodule Axon.Updates do counter. You might need to update the schedule to operate on per-batch schedule rather than per-epoch. 
""" + @deprecated "Use Polaris.Updates.scale_by_schedule/2 instead" def scale_by_schedule(combinator \\ identity(), schedule_fn) when is_function(schedule_fn, 1) do stateful( combinator, @@ -465,6 +390,7 @@ defmodule Axon.Updates do * [On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265) """ + @deprecated "Use Polaris.Updates.scale_by_radam/1 instead" def scale_by_radam(combinator_or_opts \\ []) def scale_by_radam(opts) when is_list(opts) do @@ -486,8 +412,8 @@ defmodule Axon.Updates do end defnp init_scale_by_radam(params) do - mu = zeros_like(params) - nu = zeros_like(params) + mu = zeros_like(params, type: :f32) + nu = zeros_like(params, type: :f32) count = Nx.tensor(0) %{mu: mu, nu: nu, count: count} end @@ -506,7 +432,7 @@ defmodule Axon.Updates do nu = update_moment(x, nu, b2, 2) count_inc = count + 1 - b2t = Nx.power(b2, count_inc) + b2t = Nx.pow(b2, count_inc) ro = ro_inf - 2 * count_inc * b2t / (1 - b2t) mu_hat = bias_correction(mu, b1, count + 1) @@ -525,10 +451,8 @@ defmodule Axon.Updates do defnp radam_update(ro, ro_inf, mu, nu, eps_root, eps) do r = Nx.sqrt((ro - 4) * (ro - 2) * ro_inf / ((ro_inf - 4) * (ro_inf - 2) * ro)) - transform({r, mu, nu, eps_root, eps}, fn {r, mu, nu, eps_root, eps} -> - deep_merge(mu, nu, fn m, v -> - r * m / (Nx.sqrt(v + eps_root) + eps) - end) + deep_merge(mu, nu, fn m, v -> + r * m / (Nx.sqrt(v + eps_root) + eps) end) end @@ -543,6 +467,7 @@ defmodule Axon.Updates do to `false` """ + @deprecated "Use Polaris.Updates.trace/1 instead" def trace(combinator_or_opts \\ []) def trace(opts) when is_list(opts) do @@ -564,7 +489,7 @@ defmodule Axon.Updates do end defnp init_trace(params) do - trace = zeros_like(params) + trace = zeros_like(params, type: :f32) %{trace: trace} end @@ -592,6 +517,7 @@ defmodule Axon.Updates do * `:delta` - maximum absolute value of the input. Defaults to `2.0` """ + @deprecated "Use Polaris.Updates.clip/1 instead" def clip(combinator_or_opts \\ []) def clip(opts) when is_list(opts) do @@ -623,6 +549,7 @@ defmodule Axon.Updates do * `:max_norm` - maximum norm value of input. Defaults to `1.0` """ + @deprecated "Use Polaris.Updates.clip_by_global_norm/1 instead" def clip_by_global_norm(combinator_or_opts \\ []) def clip_by_global_norm(opts) when is_list(opts) do @@ -646,7 +573,7 @@ defmodule Axon.Updates do sum_gs = deep_reduce(x, Nx.tensor(0.0), fn leaf, acc -> leaf - |> Nx.power(2) + |> Nx.pow(2) |> Nx.sum() |> Nx.add(acc) end) @@ -661,6 +588,7 @@ defmodule Axon.Updates do @doc """ Centralizes input by shifting updates by their mean. """ + @deprecated "Use Polaris.Updates.centralize/1 instead" def centralize(combinator_or_opts \\ []) def centralize(opts) when is_list(opts) do @@ -678,16 +606,16 @@ defmodule Axon.Updates do end defnp apply_centralize(x, _params, _opts \\ []) do - transform(x, fn x -> - deep_new(x, fn z -> - if Elixir.Kernel.>(Nx.rank(z), 1) do - axes = tl(Nx.axes(z)) - z - Nx.mean(z, axes: axes, keep_axes: true) - else - z - end - end) - end) + deep_new(x, ¢ralize_for_rank/1) + end + + deftransformp centralize_for_rank(input) do + if Nx.rank(input) > 1 do + input + |> Nx.subtract(Nx.mean(input, axes: tl(Nx.axes(input)), keep_axes: true)) + else + input + end end @doc """ @@ -699,6 +627,7 @@ defmodule Axon.Updates do * `:decay` - Rate of decay. Defaults to `0.0`. 
""" + @deprecated "Use Polaris.Updates.add_decayed_weights/1 instead" def add_decayed_weights(combinator_or_opts \\ []) def add_decayed_weights(opts) when is_list(opts) do @@ -737,6 +666,7 @@ defmodule Axon.Updates do * `:eps` - Numerical stability term. Defaults to `0.0`. """ + @deprecated "Use Polaris.Updates.scale_by_trust_ratio/1 instead" def scale_by_trust_ratio(combinator_or_opts \\ []) def scale_by_trust_ratio(opts) when is_list(opts) do @@ -781,12 +711,16 @@ defmodule Axon.Updates do ## Options + * `:seed` - Random seed to use. Defaults to the + current system time. + * `:eta` - Controls amount of noise to add. Defaults to `0.01`. * `:gamma` - Controls amount of noise to add. Defaults to `0.55`. """ + @deprecated "Use Polaris.Updates.add_noise/1 instead" def add_noise(combinator_or_opts \\ []) def add_noise(opts) when is_list(opts) do @@ -800,22 +734,26 @@ defmodule Axon.Updates do def add_noise({init_fn, apply_fn} = combinator, opts) when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do - stateful(combinator, &init_add_noise/1, &apply_add_noise(&1, &2, &3, opts)) + {seed, opts} = Keyword.pop_lazy(opts, :seed, fn -> :erlang.system_time() end) + stateful(combinator, &init_add_noise(&1, seed: seed), &apply_add_noise(&1, &2, &3, opts)) end - defnp init_add_noise(_params) do - %{count: Nx.tensor(0)} + defnp init_add_noise(_params, opts \\ []) do + %{count: Nx.tensor(0), key: Nx.Random.key(opts[:seed])} end - defnp apply_add_noise(x, %{count: count}, _params, opts \\ []) do + defnp apply_add_noise(x, %{count: count, key: key}, _params, opts \\ []) do opts = keyword!(opts, eta: 0.01, gamma: 0.55) - var = opts[:eta] / Nx.power(count + 1, opts[:gamma]) + var = opts[:eta] / Nx.pow(count + 1, opts[:gamma]) - noise = deep_new(x, fn z -> Nx.random_normal(z) end) + {noise, key} = + deep_map_reduce(x, key, fn z, key -> + Nx.Random.normal(key, shape: Nx.shape(z), type: Nx.type(z)) + end) updates = deep_merge(x, noise, fn g, n -> g + var * n end) - {updates, %{count: count + 1}} + {updates, %{count: count + 1, key: key}} end @doc """ @@ -837,6 +775,7 @@ defmodule Axon.Updates do * [Adaptive Methods for Nonconvex Optimization](https://proceedings.neurips.cc/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf) """ + @deprecated "Use Polaris.Updates.scale_by_yogi/1 instead" def scale_by_yogi(combinator_or_opts \\ []) def scale_by_yogi(opts) when is_list(opts) do @@ -860,7 +799,7 @@ defmodule Axon.Updates do end defnp init_scale_by_yogi(params, value) do - value = fulls_like(params, value) + value = fulls_like(params, value, type: :f32) mu = value nu = value count = Nx.tensor(0) @@ -878,7 +817,7 @@ defmodule Axon.Updates do nu = deep_merge(x, nu, fn g, v -> - v - (1 - b2) * Nx.sign(v - Nx.power(g, 2)) * Nx.power(g, 2) + v - (1 - b2) * Nx.sign(v - Nx.pow(g, 2)) * Nx.pow(g, 2) end) mu_hat = bias_correction(mu, b1, count + 1) @@ -895,6 +834,7 @@ defmodule Axon.Updates do Stateless updates do not depend on an update state and thus only require an implementation of an update function. """ + @deprecated "Use Polaris.Updates.stateless/2 instead" def stateless({parent_init_fn, parent_apply_fn} \\ identity(), apply_fn) do apply_fn = fn updates, state, params -> {updates, state} = parent_apply_fn.(updates, state, params) @@ -909,6 +849,7 @@ defmodule Axon.Updates do This is often as the initial update in many functions in this module. 
""" + @deprecated "Use Polaris.Updates.identity/1 instead" def identity() do {fn _params -> {} end, fn updates, state, _params -> {updates, state} end} end @@ -931,6 +872,7 @@ defmodule Axon.Updates do Axon.Updates.centralize() |> Axon.Updates.scale_by_rms() """ + @deprecated "Use Polaris.Updates.compose/2 instead" def compose({init_fn1, apply_fn1}, {init_fn2, apply_fn2}) do init_fn = fn params -> state = init_fn1.(params) @@ -956,6 +898,7 @@ defmodule Axon.Updates do implement some initialization function as well as an update function. """ + @deprecated "Use Polaris.Updates.stateful/3 instead" def stateful({parent_init_fn, parent_apply_fn} \\ identity(), init_fn, apply_fn) do init_fn = fn params -> state = parent_init_fn.(params) @@ -1007,11 +950,11 @@ defmodule Axon.Updates do ## Helpers defnp update_moment(x, moment, decay, order) do - deep_merge(x, moment, fn g, z -> (1 - decay) * Nx.power(g, order) + decay * z end) + deep_merge(x, moment, fn g, z -> (1 - decay) * Nx.pow(g, order) + decay * z end) end defnp bias_correction(moment, decay, count) do - deep_new(moment, fn z -> z / (1 - Nx.power(decay, count)) end) + deep_new(moment, fn z -> z / (1 - Nx.pow(decay, count)) end) end defnp safe_norm(g, min_norm) do diff --git a/mix.exs b/mix.exs index ef73d2b6..4154f74d 100644 --- a/mix.exs +++ b/mix.exs @@ -2,7 +2,7 @@ defmodule Axon.MixProject do use Mix.Project @source_url "https://github.com/elixir-nx/axon" - @version "0.3.0" + @version "0.6.0" def project do [ @@ -35,13 +35,14 @@ defmodule Axon.MixProject do # Run "mix help deps" to learn about dependencies. defp deps do [ - {:exla, "~> 0.4.0", [only: :test] ++ exla_opts()}, - {:torchx, "~> 0.4.0", [only: :test] ++ torchx_opts()}, - {:nx, "~> 0.4.0", nx_opts()}, + {:exla, "~> 0.6.0", [only: :test] ++ exla_opts()}, + {:torchx, "~> 0.6.0", [only: :test] ++ torchx_opts()}, + {:nx, "~> 0.6.0", nx_opts()}, {:ex_doc, "~> 0.23", only: :docs}, {:table_rex, "~> 3.1.1", optional: true}, {:kino, "~> 0.7", optional: true}, - {:kino_vega_lite, "~> 0.1.7", optional: true} + {:kino_vega_lite, "~> 0.1.7", optional: true}, + {:polaris, "~> 0.1"} ] end @@ -115,7 +116,7 @@ defmodule Axon.MixProject do groups_for_extras: [ "Guides: Model Creation": Path.wildcard("guides/model_creation/*.livemd"), "Guides: Model Execution": Path.wildcard("guides/model_execution/*.livemd"), - "Guides: Training and Evalutaion": + "Guides: Training and Evaluation": Path.wildcard("guides/training_and_evaluation/*.livemd"), "Guides: Serialization": Path.wildcard("guides/serialization/*.livemd"), "Examples: Basics": Path.wildcard("notebooks/basics/*.livemd"), @@ -155,7 +156,7 @@ defmodule Axon.MixProject do Axon.MixedPrecision, Axon.None, Axon.StatefulOutput, - Axon.Initalizers + Axon.Initializers ], Summary: [ Axon.Display @@ -169,11 +170,6 @@ defmodule Axon.MixProject do Axon.Recurrent, Axon.LossScale ], - Optimization: [ - Axon.Optimizers, - Axon.Updates, - Axon.Schedules - ], Loop: [ Axon.Loop, Axon.Loop.State diff --git a/mix.lock b/mix.lock index c2a7ff8d..a4010e04 100644 --- a/mix.lock +++ b/mix.lock @@ -1,21 +1,23 @@ %{ - "complex": {:hex, :complex, "0.4.2", "923e5db0be13dbb3ea00cf8459d9f75f3afdd9ff5a82742ded21064330d28273", [:mix], [], "hexpm", "069a085ef820ce675a2619fd125b963ff4514af2102c7f7d7965128e5ec0a429"}, - "dll_loader_helper": {:hex, :dll_loader_helper, "0.1.8", "1621409a3cb06c750fe845bf954785cffa5fe8f2fca41006008b891877603bf7", [:make, :mix, :rebar3], [], "hexpm", "cd373dc6a028f3e37eca26b073e3a75249513db2f9b0e42520423886801fa7d7"}, - "earmark_parser": 
{:hex, :earmark_parser, "1.4.29", "149d50dcb3a93d9f3d6f3ecf18c918fb5a2d3c001b5d3305c926cddfbd33355b", [:mix], [], "hexpm", "4902af1b3eb139016aed210888748db8070b8125c2342ce3dcae4f38dcc63503"}, - "elixir_make": {:hex, :elixir_make, "0.7.1", "314f2a5450254db0446ba94cc1ba12a25b83b457f24aa9cc21c128cead5d03aa", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "0f1ad4787b4d7489563351cbf85c9221a852f5441364a2cb3ffd36f2fda7f7fb"}, - "ex_doc": {:hex, :ex_doc, "0.29.1", "b1c652fa5f92ee9cf15c75271168027f92039b3877094290a75abcaac82a9f77", [:mix], [{:earmark_parser, "~> 1.4.19", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "b7745fa6374a36daf484e2a2012274950e084815b936b1319aeebcf7809574f6"}, - "exla": {:hex, :exla, "0.4.1", "409a3294720e31bbcd03c3eacd654686feb0ed7ba3e42314a269eeaa7cfd3c76", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.4.1", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.4.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "0410a0c38b94d2be4b713753c86f33ab8400b5cc8b59100266a0a6c58c17871d"}, - "kino": {:hex, :kino, "0.8.0", "07603a32c111959ed48f08ac3808a0dda05433d28f8d2f06d65b25b255966649", [:mix], [{:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "736568d4de9eb56d8903bae6fe08b7c06db44efe37bb883165e755e623881c51"}, - "kino_vega_lite": {:hex, :kino_vega_lite, "0.1.7", "c93fdfe6e35c4c5a4f8afd51a89786b2187e5a7da4595b13ea02a4329d9f0976", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}, {:vega_lite, "~> 0.1.4", [hex: :vega_lite, repo: "hexpm", optional: false]}], "hexpm", "59ee442f0532266749d15dc9af4e2875bec61ccfa1b07636bc396ee63dfde8e7"}, + "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, + "dll_loader_helper": {:hex, :dll_loader_helper, "1.1.0", "e7d015e980942a0d67e306827ec907e7e853a21186bd92bb968d986698591a0f", [:mix], [{:dll_loader_helper_beam, "~> 1.1", [hex: :dll_loader_helper_beam, repo: "hexpm", optional: false]}], "hexpm", "2b6c11ee7bb48f6a132ce8f872202f9e828c019988da1e2d40ad41496195df0c"}, + "dll_loader_helper_beam": {:hex, :dll_loader_helper_beam, "1.1.0", "d51232663985dbc998c59b5d080feecd5398d5b75a9f0293a9855db774c2684d", [:rebar3], [], "hexpm", "aa85d0d0e9398916a80b2fd751885877934ae3ea008288f99ff829c0b8ef1f55"}, + "earmark_parser": {:hex, :earmark_parser, "1.4.31", "a93921cdc6b9b869f519213d5bc79d9e218ba768d7270d46fdcf1c01bacff9e2", [:mix], [], "hexpm", "317d367ee0335ef037a87e46c91a2269fef6306413f731e8ec11fc45a7efd059"}, + "elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"}, + "ex_doc": {:hex, :ex_doc, "0.29.3", "f07444bcafb302db86e4f02d8bbcd82f2e881a0dcf4f3e4740e4b8128b9353f7", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: 
"hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "3dc6787d7b08801ec3b51e9bd26be5e8826fbf1a17e92d1ebc252e1a1c75bfe1"}, + "exla": {:hex, :exla, "0.6.0", "af63e45ce41ad25630967923147d14292a0cc48e507b8a3cf3bf3d5483099a28", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.6.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.5.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "5f6a4a105ea9ab207b9aa4de5a294730e2bfe9639f4b8d37a7c00da131090d7a"}, + "kino": {:hex, :kino, "0.9.0", "9d023e66ed29123ba414e978012a6e9958b09fbf5dddb5e0f4814e04df8223b7", [:mix], [{:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "46767bdbbdacc1c801d43b2dc6d2fe7fdf936bd74f4accdc5779f647f5eeda66"}, + "kino_vega_lite": {:hex, :kino_vega_lite, "0.1.8", "ec7e97778d6b774591e4cbf7fd27850abf7c0f5e9133a3d13e069aadfa04b5e3", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}, {:vega_lite, "~> 0.1.4", [hex: :vega_lite, repo: "hexpm", optional: false]}], "hexpm", "0bc3135a77550ea5c5bd7bfb1fb215416ebddbbc8b1e280e6de39366cd17a2f8"}, "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.0", "f8c570a0d33f8039513fbccaf7108c5d750f47d8defd44088371191b76492b0b", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "28b2cbdc13960a46ae9a8858c4bebdec3c9a6d7b4b9e7f4ed1502f8159f338e7"}, "makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"}, "nimble_parsec": {:hex, :nimble_parsec, "1.2.3", "244836e6e3f1200c7f30cb56733fd808744eca61fd182f731eac4af635cc6d0b", [:mix], [], "hexpm", "c8d789e39b9131acf7b99291e93dae60ab48ef14a7ee9d58c6964f59efb570b0"}, - "nx": {:hex, :nx, "0.4.1", "3cc8e420d0835ab7cac94f253950dee3ff927c68798ae88a3e4ff184a825b042", [:mix], [{:complex, "~> 0.4.2", [hex: :complex, repo: "hexpm", optional: false]}], "hexpm", "0b33fccaf76ebc6e79d53fe1149a70f99838e6505e9e7092e5a0a57b131b27c6"}, + "nx": {:hex, :nx, "0.6.0", "37c86eae824125a7e298dd1ee896953d9d671ce3630dcff74c77db17d734a85f", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "e1ad3cc70a5828a1aedb156b71e90863d9623a2dc9b35a5588f8627a07ee6cb4"}, + "polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"}, "table": {:hex, :table, "0.1.2", 
"87ad1125f5b70c5dea0307aa633194083eb5182ec537efc94e96af08937e14a8", [:mix], [], "hexpm", "7e99bc7efef806315c7e65640724bf165c3061cdc5d854060f74468367065029"}, "table_rex": {:hex, :table_rex, "3.1.1", "0c67164d1714b5e806d5067c1e96ff098ba7ae79413cc075973e17c38a587caa", [:mix], [], "hexpm", "678a23aba4d670419c23c17790f9dcd635a4a89022040df7d5d772cb21012490"}, - "telemetry": {:hex, :telemetry, "1.1.0", "a589817034a27eab11144ad24d5c0f9fab1f58173274b1e9bae7074af9cbee51", [:rebar3], [], "hexpm", "b727b2a1f75614774cff2d7565b64d0dfa5bd52ba517f16543e6fc7efcc0df48"}, - "torchx": {:hex, :torchx, "0.4.1", "5aa7f93d7aff85c9f5fbae4c534affa9d16a9ffe9bbbb261cf2dca8ead2f6ab8", [:make, :mix], [{:dll_loader_helper, "~> 0.1.0", [hex: :dll_loader_helper, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.4.1", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "3d1ee6b2588cadf2e70e4ea33449bda8e27f315b653f12d1f681070b83854b0e"}, + "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, + "torchx": {:hex, :torchx, "0.6.0", "e4a5f545e245c15aceeafcf9f22ac2ae0a87720c4a6b2f132e9909635f434e93", [:make, :mix], [{:dll_loader_helper, "~> 0.1 or ~> 1.0", [hex: :dll_loader_helper, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.6.0", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "35365dc51ee28dc86ca87c150dd3869bc83b207b2574bb2310c1be39e3867550"}, "vega_lite": {:hex, :vega_lite, "0.1.6", "145ab4908bc890b02cef3526e890e9b899528eaa7aa9d6fa642b52a8a2c682c6", [:mix], [{:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "078c0d8cd9a8eca4ae8f9527c45c01d69cefb6b2235fd5179a227ac2f031d7ac"}, - "xla": {:hex, :xla, "0.4.1", "c14a8214928f1aee68745b70c4f817c90e98740ceb69ad921071eb41792f9ecf", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "fe8323ceeebf114f183fcd3a09ab08d76a71e9fd9b1154109078a8355aa56366"}, + "xla": {:hex, :xla, "0.5.0", "fb8a02c02e5a4f4531fbf18a90c325e471037f983f0115d23f510e7dd9a6aa65", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "571ac797a4244b8ba8552ed0295a54397bd896708be51e4da6cbb784f6678061"}, } diff --git a/notebooks/generative/fashionmnist_autoencoder.livemd b/notebooks/generative/fashionmnist_autoencoder.livemd index 036776ff..43f10cf7 100644 --- a/notebooks/generative/fashionmnist_autoencoder.livemd +++ b/notebooks/generative/fashionmnist_autoencoder.livemd @@ -117,7 +117,7 @@ mean_square_error = fn y_pred, y -> |> Nx.mean() end -mean_absolute_erorr = fn y_pred, y -> +mean_absolute_error = fn y_pred, y -> y_pred |> Nx.subtract(y) |> Nx.abs() @@ -139,7 +139,7 @@ For the same image both errors should be 0, because when we have two exact copie ```elixir { mean_square_error.(shoe_image, shoe_image), - mean_absolute_erorr.(shoe_image, shoe_image) + mean_absolute_error.(shoe_image, shoe_image) } ``` @@ -148,7 +148,7 @@ Now the noised image: ```elixir { mean_square_error.(shoe_image, noised_shoe_image), - mean_absolute_erorr.(shoe_image, noised_shoe_image) + mean_absolute_error.(shoe_image, noised_shoe_image) } ``` @@ -157,7 +157,7 @@ And a different image: ```elixir { mean_square_error.(shoe_image, other_image), - mean_absolute_erorr.(shoe_image, other_image) + 
mean_absolute_error.(shoe_image, other_image) } ``` diff --git a/notebooks/generative/fashionmnist_vae.livemd b/notebooks/generative/fashionmnist_vae.livemd index 352357be..a4b4e968 100644 --- a/notebooks/generative/fashionmnist_vae.livemd +++ b/notebooks/generative/fashionmnist_vae.livemd @@ -251,7 +251,7 @@ end params = model - |> Axon.Loop.trainer(:mean_squared_error, Axon.Optimizers.adamw(0.001)) + |> Axon.Loop.trainer(:mean_squared_error, Polaris.Optimizers.adamw(learning_rate: 0.001)) |> KinoAxon.kino_early_stop() |> Axon.Loop.handle(:iteration_completed, render_example_handler, every: 450) |> Axon.Loop.validate(model, test_data) @@ -265,7 +265,7 @@ params = ## Splitting up the model -Cool! We now have the parameters for a trained, simple autoencoder. Our next step is to split up the model so we can use the encoder and decoder separately. By doing that, we'll be able to take an image and *encode* it to get the model's compressed image representation (the latent vector). We can then manipulate the latent vector and run the manipulated latent vector through the *decoder* to get a new image. +Cool! We now have the parameters for a trained, simple autoencoder. Our next step is to split up the model so we can use the encoder and decoder separately. By doing that, we'll be able to take an image and _encode_ it to get the model's compressed image representation (the latent vector). We can then manipulate the latent vector and run the manipulated latent vector through the _decoder_ to get a new image. Let's start by defining the encoder and decoder separately as two different models. @@ -311,7 +311,7 @@ So all we need to do is create a new Map that plucks out the right layers from o Fortunately, since we gave each of the layers names, this requires no work at all - we can use the Map as it is since the layer names match up! Axon will ignore any extra keys so those won't be a problem. -Note that naming the layers wasn't *required*, if the layers didn't have names we would have some renaming to do to get the names to match between the models. But giving them names made it very convenient :) +Note that naming the layers wasn't _required_, if the layers didn't have names we would have some renaming to do to get the names to match between the models. But giving them names made it very convenient :) Let's try encoding an image, printing the latent and then decoding the latent using our split up model to make sure it's working. @@ -474,7 +474,7 @@ end params = model - |> Axon.Loop.trainer(&CustomLoss.loss/2, Axon.Optimizers.adam(0.001)) + |> Axon.Loop.trainer(&CustomLoss.loss/2, Polaris.Optimizers.adam(learning_rate: 0.001)) |> KinoAxon.kino_early_stop() |> Axon.Loop.handle(:epoch_completed, render_example_handler) |> Axon.Loop.validate(model, test_data) diff --git a/notebooks/generative/mnist_autoencoder_using_kino.livemd b/notebooks/generative/mnist_autoencoder_using_kino.livemd index b56ed8a7..e251922e 100644 --- a/notebooks/generative/mnist_autoencoder_using_kino.livemd +++ b/notebooks/generative/mnist_autoencoder_using_kino.livemd @@ -75,13 +75,13 @@ test_images[[images: 0..2]] |> Nx.to_heatmap() An autoencoder is a a network that has the same sized input as output, with a "bottleneck" layer in the middle with far fewer parameters than the input. Its goal is to force the output to reconstruct the input. The bottleneck layer forces the network to learn a compressed representation of the input space. 
-A *denoising* autoencoder is a small tweak on an autoencoder that takes a corrupted input (often corrupted by adding noise or zeroing out pixels) and reconstructs the original input, removing the noise in the process. +A _denoising_ autoencoder is a small tweak on an autoencoder that takes a corrupted input (often corrupted by adding noise or zeroing out pixels) and reconstructs the original input, removing the noise in the process. -The part of the autoencoder that takes the input and compresses it into the bottleneck layer is called the *encoder* and the part that takes the compressed representation and reconstructs the input is called the *decoder*. Usually the decoder mirrors the encoder. +The part of the autoencoder that takes the input and compresses it into the bottleneck layer is called the _encoder_ and the part that takes the compressed representation and reconstructs the input is called the _decoder_. Usually the decoder mirrors the encoder. MNIST is a pretty easy dataset, so we're going to try a fairly small autoencoder. -The input image has size 784 (28 rows * 28 cols * 1 pixel). We'll set up the encoder to turn that into 256 features, then 128, 64, and then 10 features for the bottleneck layer. The decoder will do the reverse, take the 10 features and go to 64, 128, 256 and 784. I'll use fully-connected (dense) layers. +The input image has size 784 (28 rows _ 28 cols _ 1 pixel). We'll set up the encoder to turn that into 256 features, then 128, 64, and then 10 features for the bottleneck layer. The decoder will do the reverse, take the 10 features and go to 64, 128, 256 and 784. I'll use fully-connected (dense) layers. @@ -197,14 +197,14 @@ Looks right (and tricky). Let's see how the model does. ```elixir params = model - |> Axon.Loop.trainer(:mean_squared_error, Axon.Optimizers.adamw(0.001)) + |> Axon.Loop.trainer(:mean_squared_error, Polaris.Optimizers.adamw(learning_rate: 0.001)) |> Axon.Loop.validate(model, test_data) |> Axon.Loop.run(train_data, %{}, epochs: 20, compiler: EXLA) :ok ``` -Now that we have a model that theoretically has learned *something*, we'll see what it's learned by running it on some images from the test set. We'll use Kino to allow us to select the image from the test set to run the model against. To avoid losing the params that took a while to train, we'll create another branch so we can experiment with the params and stop execution when needed without having to retrain. +Now that we have a model that theoretically has learned _something_, we'll see what it's learned by running it on some images from the test set. We'll use Kino to allow us to select the image from the test set to run the model against. To avoid losing the params that took a while to train, we'll create another branch so we can experiment with the params and stop execution when needed without having to retrain. 
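The hunks above only reflow the notebook's prose; the encoder/decoder sizes it describes (784 → 256 → 128 → 64 → 10 and back up to 784, all dense layers) stay the same. For readers skimming the patch, a minimal sketch of that architecture might look like the following. The layer names and activations here are assumptions for illustration, not the notebook's actual definitions.

```elixir
# Hypothetical sketch of the dense autoencoder the notebook prose describes:
# 784 -> 256 -> 128 -> 64 -> 10 (bottleneck) -> 64 -> 128 -> 256 -> 784.
# Layer names and activations are assumed, not taken from the notebook.
model =
  Axon.input("image", shape: {nil, 784})
  |> Axon.dense(256, activation: :relu, name: "encoder_1")
  |> Axon.dense(128, activation: :relu, name: "encoder_2")
  |> Axon.dense(64, activation: :relu, name: "encoder_3")
  |> Axon.dense(10, activation: :relu, name: "bottleneck")
  |> Axon.dense(64, activation: :relu, name: "decoder_1")
  |> Axon.dense(128, activation: :relu, name: "decoder_2")
  |> Axon.dense(256, activation: :relu, name: "decoder_3")
  |> Axon.dense(784, activation: :sigmoid, name: "output")
```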
@@ -257,7 +257,7 @@ Note we used `Kino.animate/2` which runs asynchronously so we don't block execut ## A better training loop -*Note that we branch from the "Building a model" section since we only need the model definition for this section and not the previously trained model.* +_Note that we branch from the "Building a model" section since we only need the model definition for this section and not the previously trained model._ @@ -312,7 +312,7 @@ end params = model - |> Axon.Loop.trainer(:mean_squared_error, Axon.Optimizers.adamw(0.001)) + |> Axon.Loop.trainer(:mean_squared_error, Polaris.Optimizers.adamw(learning_rate: 0.001)) |> Axon.Loop.handle(:iteration_completed, render_example_handler, every: 450) |> Axon.Loop.validate(model, test_data) |> Axon.Loop.run(train_data, %{}, epochs: 20, compiler: EXLA) diff --git a/notebooks/structured/credit_card_fraud.livemd b/notebooks/structured/credit_card_fraud.livemd index e407723f..ac937424 100644 --- a/notebooks/structured/credit_card_fraud.livemd +++ b/notebooks/structured/credit_card_fraud.livemd @@ -17,7 +17,7 @@ alias Explorer.{DataFrame, Series} ## Introduction -This time we will examine the Credit Card Fraud Dataset. Due to confidentiality, the original data were preprocessed by principal component analysis (PCA), and then 31 principal components were selected for the final data set. The dataset is highly imbalanced. The positive class (frauds) account for 0.172% of all transactions. Eventually, we will create a classifier which has not only great accuracy but, what is even more important, a high *recall* and *precision* - two metrics that are much more indicative of performance with imbalanced classification problems. +This time we will examine the Credit Card Fraud Dataset. Due to confidentiality, the original data were preprocessed by principal component analysis (PCA), and then 31 principal components were selected for the final data set. The dataset is highly imbalanced. The positive class (frauds) account for 0.172% of all transactions. Eventually, we will create a classifier which has not only great accuracy but, what is even more important, a high _recall_ and _precision_ - two metrics that are much more indicative of performance with imbalanced classification problems. ## Data processing @@ -139,7 +139,7 @@ IO.puts("# of fraudulent transactions (train): #{fraud}") IO.puts("% fraudlent transactions (train): #{100 * (fraud / (legit + fraud))}%") ``` -As always, we define our train loop. We are using *binary cross-entropy* as our loss function and Adam as the optimizer with a learning rate of 0.01. Then we immediately start the training passing our train portion of the dataset. +As always, we define our train loop. We are using _binary cross-entropy_ as our loss function and Adam as the optimizer with a learning rate of 0.01. Then we immediately start the training passing our train portion of the dataset. ```elixir loss = @@ -151,7 +151,7 @@ loss = reduction: :mean ) -optimizer = Axon.Optimizers.adam(1.0e-2) +optimizer = Polaris.Optimizers.adam(learning_rate: 1.0e-2) params = model diff --git a/notebooks/text/lstm_generation.livemd b/notebooks/text/lstm_generation.livemd index 677188f7..7aabec7c 100644 --- a/notebooks/text/lstm_generation.livemd +++ b/notebooks/text/lstm_generation.livemd @@ -158,7 +158,7 @@ model = To train the network, we will use Axon's Loop API. It is pretty straightforward. -For the loss function we can use *categorical cross-entropy* since we are dealing with categories (each character) in our output. 
For the optimizer we can use *Adam*. +For the loss function we can use _categorical cross-entropy_ since we are dealing with categories (each character) in our output. For the optimizer we can use _Adam_. We will train our network for 20 epochs. Note that we are working with a fair amount data, so it may take a long time unless you run it on a GPU. @@ -171,7 +171,7 @@ IO.puts("Total batches: #{Enum.count(train_batches)}") params = model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(0.001)) + |> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adam(learning_rate: 0.001)) |> Axon.Loop.run(Stream.zip(train_batches, result_batches), %{}, epochs: 20, compiler: EXLA) :ok @@ -250,7 +250,7 @@ IO.puts("Total batches: #{Enum.count(train_batches)}") new_params = new_model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(0.001)) + |> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adam(learning_rate: 0.001)) |> Axon.Loop.run(Stream.zip(train_batches, result_batches), %{}, epochs: 50, compiler: EXLA) :ok diff --git a/notebooks/vision/horses_or_humans.livemd b/notebooks/vision/horses_or_humans.livemd index 225e2db3..ce24a7b9 100644 --- a/notebooks/vision/horses_or_humans.livemd +++ b/notebooks/vision/horses_or_humans.livemd @@ -143,7 +143,7 @@ The next step is creating our model. In this notebook, we choose the classic Con -| ![](https://miroslawmamczur.pl/wp-content/uploads/2021/03/06.gif) | +| ![](https://miroslawmamczur.pl/wp-content/uploads/2021/03/06.gif) | | :-------------------------------------------------------------------------------------: | | Figure 1: A step-by-step visualization of a convolution layer for `kernel_size: {3, 3}` | @@ -155,7 +155,7 @@ The next step is creating our model. In this notebook, we choose the classic Con | ![](https://production-media.paperswithcode.com/methods/MaxpoolSample2.png) | | :-------------------------------------------------------------------------: | -| Figure 2: Max pooling operation for `kernel_size: {2, 2}` | +| Figure 2: Max pooling operation for `kernel_size: {2, 2}` | @@ -163,7 +163,7 @@ The next step is creating our model. In this notebook, we choose the classic Con -| ![](https://miro.medium.com/max/1400/1*KkqxjvXTIV_b365B41ltfg.png) | +| ![](https://miro.medium.com/max/1400/1*KkqxjvXTIV_b365B41ltfg.png) | | :-------------------------------------------------------------------: | | Figure 3: The difference between standard dropout and spatial dropout | @@ -199,7 +199,7 @@ It's time to train our model. We specify the loss, optimizer and choose accuracy ```elixir data = HorsesHumans.DataProcessing.data_stream(files, batch_size) -optimizer = Axon.Optimizers.adam(1.0e-4) +optimizer = Polaris.Optimizers.adam(learning_rate: 1.0e-4) params = model @@ -215,7 +215,7 @@ params = We can improve the training by applying gradient centralization. It is a technique with a similar purpose to batch normalization. For each loss gradient, we subtract a mean value to have a gradient with mean equal to zero. This process prevents gradients from exploding. 
```elixir -centralized_optimizer = Axon.Updates.compose(Axon.Updates.centralize(), optimizer) +centralized_optimizer = Polaris.Updates.compose(Polaris.Updates.centralize(), optimizer) model |> Axon.Loop.trainer(:categorical_cross_entropy, centralized_optimizer, :identity, log: 1) @@ -242,7 +242,7 @@ input = Axon.predict(model, params, input) ``` -*Note: the model output refers to the probability that the image presents a horse and a human respectively.* +_Note: the model output refers to the probability that the image presents a horse and a human respectively._ diff --git a/test/axon/activations_test.exs b/test/axon/activations_test.exs index 47e0cc7c..2f6634cb 100644 --- a/test/axon/activations_test.exs +++ b/test/axon/activations_test.exs @@ -700,7 +700,7 @@ defmodule Axon.ActivationsTest do describe "log_softmax" do test "raises on bad axis" do - assert_raise ArgumentError, ~r/log_softmax axis must be within rank of tensor/, fn -> + assert_raise ArgumentError, "given axis (2) invalid for shape with rank 2", fn -> Axon.Activations.log_softmax(Nx.iota({1, 3}), axis: 2) end end @@ -1143,6 +1143,22 @@ defmodule Axon.ActivationsTest do actual = apply(jit(fn x -> grad(x, &Nx.sum(Axon.Activations.sigmoid(&1))) end), [a]) assert_all_close(expected, actual) end + + defn cache_test_sigmoid(x) do + x + |> Axon.Activations.sigmoid() + |> get_cached() + end + + deftransformp get_cached(res) do + %{data: %{args: [_, %{logits: inp}]}} = res + inp + end + + test "caches input logits" do + {a, _key} = Nx.Random.uniform(Nx.Random.key(42), shape: {10, 10}) + assert_all_close(cache_test_sigmoid(a), a) + end end describe "silu" do @@ -1348,6 +1364,17 @@ defmodule Axon.ActivationsTest do actual = apply(jit(fn x -> grad(x, &Nx.sum(Axon.Activations.softmax(&1))) end), [a]) assert_all_close(expected, actual, atol: 1.0e-7) end + + defn cache_test_softmax(x) do + x + |> Axon.Activations.softmax() + |> get_cached() + end + + test "caches input logits" do + {a, _key} = Nx.Random.uniform(Nx.Random.key(42), shape: {10, 10}) + assert_all_close(cache_test_softmax(a), a) + end end describe "softplus" do diff --git a/test/axon/compiler_test.exs b/test/axon/compiler_test.exs index a05da1e8..5db0d7ab 100644 --- a/test/axon/compiler_test.exs +++ b/test/axon/compiler_test.exs @@ -7,7 +7,7 @@ defmodule CompilerTest do describe "input" do test "single input, single output" do model = Axon.input("input_0", shape: {nil, 1}) - input = Nx.random_uniform({1, 1}, type: {:f, 32}) + input = random({1, 1}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -19,8 +19,8 @@ defmodule CompilerTest do {Axon.input("input_0", shape: {nil, 1}), Axon.input("input_1", shape: {nil, 1})} |> Axon.container() - input1 = Nx.random_uniform({1, 1}) - input2 = Nx.random_uniform({1, 1}) + input1 = random({1, 1}) + input2 = random({1, 1}) input = %{"input_0" => input1, "input_1" => input2} assert {init_fn, predict_fn} = Axon.build(model1) @@ -35,7 +35,7 @@ defmodule CompilerTest do test "output map" do model = %{foo: Axon.input("input_0", shape: {nil, 1})} |> Axon.container() - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert {init_fn, predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -48,8 +48,8 @@ defmodule CompilerTest do model1 = {input1, {input1, {input2, {}}, input2, %{foo: input1}}} |> Axon.container() - inp1 = Nx.random_uniform({1, 1}) - inp2 = Nx.random_uniform({1, 2}) + inp1 = random({1, 1}) + inp2 = random({1, 2}) input = %{"input_0" => inp1, "input_1" => 
inp2} assert {init_fn, predict_fn} = Axon.build(model1) @@ -67,9 +67,9 @@ defmodule CompilerTest do z = Axon.input("z", shape: {nil, 1}) model = {z, x, y} |> Axon.container() - x_val = Nx.random_uniform({1, 1}) - y_val = Nx.random_uniform({1, 1}) - z_val = Nx.random_uniform({1, 1}) + x_val = random({1, 1}) + y_val = random({1, 1}) + z_val = random({1, 1}) input = %{"x" => x_val, "y" => y_val, "z" => z_val} assert {init_fn, predict_fn} = Axon.build(model) @@ -99,7 +99,7 @@ defmodule CompilerTest do test "raises if input not found, no default value" do model = Axon.input("input_0", shape: {nil, 32}) - input = Nx.random_uniform({1, 16}) + input = random({1, 16}) assert {_, predict_fn} = Axon.build(model) exception = assert_raise ArgumentError, fn -> predict_fn.(%{}, %{foo: input}) end @@ -262,7 +262,7 @@ defmodule CompilerTest do test "initializes with no params" do for activation <- @activation_layers do model = Axon.input("input_0", shape: {nil, 32}) |> Axon.activation(activation) - input = Nx.random_uniform({1, 32}) + input = random({1, 32}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -272,7 +272,7 @@ defmodule CompilerTest do test "computes forward pass with default options" do for activation <- @activation_layers do model = Axon.input("input_0", shape: {nil, 1}) |> Axon.activation(activation) - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert {_init_fn, predict_fn} = Axon.build(model) assert_equal(predict_fn.(%{}, input), apply(Axon.Activations, activation, [input])) @@ -282,7 +282,7 @@ defmodule CompilerTest do test "computes forward pass with custom options" do for activation <- [:celu, :elu, :leaky_relu] do model = Axon.input("input_0", shape: {nil, 32}) |> Axon.activation(activation, alpha: 0.8) - input = Nx.random_uniform({1, 32}, type: {:f, 32}) + input = random({1, 32}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model) @@ -299,10 +299,10 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert {init_fn, predict_fn} = Axon.build(mp_model) - assert Nx.type(predict_fn.(init_fn.(input, %{}), Nx.random_uniform({1, 1}))) == {:bf, 16} + assert Nx.type(predict_fn.(init_fn.(input, %{}), random({1, 1}))) == {:bf, 16} end end end @@ -311,7 +311,7 @@ defmodule CompilerTest do test "initializes in default case" do model = Axon.input("input_0", shape: {nil, 1}) |> Axon.bias(name: "bias") - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{"bias" => %{"bias" => bias}} = init_fn.(input, %{}) @@ -324,7 +324,7 @@ defmodule CompilerTest do test "initializes in default case" do model = Axon.input("input_0", shape: {nil, 1}) |> Axon.dense(1, name: "dense") - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{"dense" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(input, %{}) @@ -335,7 +335,7 @@ defmodule CompilerTest do end test "initializes with custom initializers" do - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) model1 = Axon.input("input_0", shape: {nil, 1}) @@ -382,7 +382,7 @@ defmodule CompilerTest do Nx.Defn.grad(params, &Nx.mean(predict_fn.(&1, input))) end - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert %{"dense" => %{"kernel" => kernel_grad, "bias" => bias_grad}} = apply(Nx.Defn.jit(backward), [init_fn.(input, %{}), input]) 
@@ -396,7 +396,7 @@ defmodule CompilerTest do policy = AMP.create_policy(params: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 2}) + input = random({1, 2}) assert {init_fn, _} = Axon.build(mp_model) assert %{"dense" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(input, %{}) @@ -409,7 +409,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 2}) + input = random({1, 2}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -419,7 +419,7 @@ defmodule CompilerTest do model = Axon.input("input", shape: {nil, 2}) |> Axon.dense(1, name: "dense", use_bias: false) - input = Nx.random_uniform({1, 2}) + input = random({1, 2}) assert {init_fn, _} = Axon.build(model) assert %{"dense" => %{"kernel" => _} = dense_params} = init_fn.(input, %{}) @@ -430,7 +430,7 @@ defmodule CompilerTest do model = Axon.input("input", shape: {nil, 2}) |> Axon.dense(1, name: "dense", use_bias: false) - input = Nx.random_uniform({1, 2}) + input = random({1, 2}) assert {init_fn, predict_fn} = Axon.build(model) assert %{"dense" => %{"kernel" => k}} = params = init_fn.(input, %{}) @@ -445,7 +445,7 @@ defmodule CompilerTest do input2 = Axon.input("input_1", shape: {nil, 2}) model = Axon.bilinear(input1, input2, 1, name: "bilinear") - inputs = %{"input_0" => Nx.random_uniform({1, 1}), "input_1" => Nx.random_uniform({1, 2})} + inputs = %{"input_0" => random({1, 1}), "input_1" => random({1, 2})} assert {init_fn, _predict_fn} = Axon.build(model) assert %{"bilinear" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(inputs, %{}) @@ -460,11 +460,11 @@ defmodule CompilerTest do input2 = Axon.input("input_1", shape: {nil, 2}) model1 = Axon.bilinear(input1, input2, 1, name: "bilinear", kernel_initializer: :zeros) - inputs = %{"input_0" => Nx.random_uniform({1, 1}), "input_1" => Nx.random_uniform({1, 2})} + inputs = %{"input_0" => random({1, 1}), "input_1" => random({1, 2})} assert {init_fn, _predict_fn} = Axon.build(model1) assert %{"bilinear" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(inputs, %{}) - assert_equal(kernel, zeros({1, 1, 2})) + assert_equal(kernel, zeros({1, 2})) assert Nx.shape(bias) == {1} assert Nx.type(bias) == {:f, 32} @@ -522,7 +522,7 @@ defmodule CompilerTest do input2 = Axon.input("input_1", shape: {nil, 2}) model = Axon.bilinear(input1, input2, 1, name: "bilinear") |> Axon.freeze() - input = %{"input_0" => Nx.random_uniform({1, 1}), "input_1" => Nx.random_uniform({1, 2})} + input = %{"input_0" => random({1, 1}), "input_1" => random({1, 2})} assert {init_fn, predict_fn} = Axon.build(model) @@ -544,7 +544,7 @@ defmodule CompilerTest do policy = AMP.create_policy(params: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = %{"input_0" => Nx.random_uniform({1, 1}), "input_1" => Nx.random_uniform({1, 2})} + input = %{"input_0" => random({1, 1}), "input_1" => random({1, 2})} assert {init_fn, _} = Axon.build(mp_model) assert %{"bilinear" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(input, %{}) @@ -559,7 +559,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = %{"input_0" => Nx.random_uniform({1, 1}), "input_1" => Nx.random_uniform({1, 2})} + input = %{"input_0" => random({1, 1}), "input_1" => random({1, 2})} assert {init_fn, predict_fn} = Axon.build(mp_model) @@ -571,7 +571,7 @@ defmodule CompilerTest 
do input2 = Axon.input("input_1", shape: {nil, 2}) model = Axon.bilinear(input1, input2, 1, name: "bilinear", use_bias: false) - input = %{"input_0" => Nx.random_uniform({1, 1}), "input_1" => Nx.random_uniform({1, 2})} + input = %{"input_0" => random({1, 1}), "input_1" => random({1, 2})} assert {init_fn, _} = Axon.build(model) assert %{"bilinear" => %{"kernel" => _} = bilinear_params} = init_fn.(input, %{}) @@ -583,8 +583,8 @@ defmodule CompilerTest do input2 = Axon.input("input_1", shape: {nil, 2}) model = Axon.bilinear(input1, input2, 1, name: "bilinear", use_bias: false) - inp1 = Nx.random_uniform({1, 1}) - inp2 = Nx.random_uniform({1, 2}) + inp1 = random({1, 1}) + inp2 = random({1, 2}) input = %{"input_0" => inp1, "input_1" => inp2} @@ -602,7 +602,7 @@ defmodule CompilerTest do test "initializes in default case" do model = Axon.input("input_0", shape: {nil, 1}) |> Axon.embedding(1, 1, name: "embedding") - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{"embedding" => %{"kernel" => kernel}} = init_fn.(input, %{}) @@ -615,7 +615,7 @@ defmodule CompilerTest do Axon.input("input_0", shape: {nil, 1}) |> Axon.embedding(1, 1, name: "embedding", kernel_initializer: :zeros) - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert {init_fn, _predict_fn} = Axon.build(model1) assert %{"embedding" => %{"kernel" => kernel}} = init_fn.(input, %{}) @@ -685,7 +685,7 @@ defmodule CompilerTest do for pool <- @pooling_layers do model = apply(Axon, pool, [Axon.input("input", shape: {nil, 32, 1})]) - input = Nx.random_uniform({1, 32, 1}) + input = random({1, 32, 1}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -697,7 +697,7 @@ defmodule CompilerTest do for pool <- @pooling_layers do model1 = apply(Axon, pool, [Axon.input("input", shape: {nil, 32, 1})]) - input1 = Nx.random_uniform({1, 32, 1}, type: {:f, 32}) + input1 = random({1, 32, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model1) @@ -707,7 +707,7 @@ defmodule CompilerTest do ) model2 = apply(Axon, pool, [Axon.input("input", shape: {nil, 8, 4, 1})]) - input2 = Nx.random_uniform({1, 8, 4, 1}, type: {:f, 32}) + input2 = random({1, 8, 4, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model2) @@ -717,7 +717,7 @@ defmodule CompilerTest do ) model3 = apply(Axon, pool, [Axon.input("input", shape: {nil, 8, 4, 2, 1})]) - input3 = Nx.random_uniform({1, 8, 4, 2, 1}, type: {:f, 32}) + input3 = random({1, 8, 4, 2, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model3) @@ -732,21 +732,21 @@ defmodule CompilerTest do for pool <- @pooling_layers do opts1 = [kernel_size: 6] model1 = apply(Axon, pool, [Axon.input("input", shape: {nil, 32, 1}), opts1]) - input1 = Nx.random_uniform({1, 32, 1}, type: {:f, 32}) + input1 = random({1, 32, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model1) assert_equal(predict_fn.(%{}, input1), apply(Axon.Layers, pool, [input1, opts1])) opts2 = [kernel_size: 2, strides: 2, padding: :same] model2 = apply(Axon, pool, [Axon.input("input", shape: {nil, 8, 4, 1}), opts2]) - input2 = Nx.random_uniform({1, 8, 4, 1}, type: {:f, 32}) + input2 = random({1, 8, 4, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model2) assert_equal(predict_fn.(%{}, input2), apply(Axon.Layers, pool, [input2, opts2])) opts3 = [kernel_size: {2, 1, 2}, strides: [1, 2, 1], padding: [{0, 1}, {1, 1}, {0, 2}]] model3 = apply(Axon, pool, [Axon.input("input", shape: {nil, 8, 4, 2, 1}), opts3]) - input3 = 
Nx.random_uniform({1, 8, 4, 2, 1}, type: {:f, 32}) + input3 = random({1, 8, 4, 2, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model3) assert_equal(predict_fn.(%{}, input3), apply(Axon.Layers, pool, [input3, opts3])) @@ -755,7 +755,7 @@ defmodule CompilerTest do test "lp_pool computes forward pass with custom norm" do model = Axon.input("input", shape: {nil, 32, 1}) |> Axon.lp_pool(norm: 3) - input = Nx.random_uniform({1, 32, 1}, type: {:f, 32}) + input = random({1, 32, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model) assert_equal(predict_fn.(%{}, input), Axon.Layers.lp_pool(input, kernel_size: {1}, norm: 3)) @@ -767,11 +767,11 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 32, 1}) + input = random({1, 32, 1}) assert {init_fn, predict_fn} = Axon.build(mp_model) - assert Nx.type(predict_fn.(init_fn.(input, %{}), Nx.random_uniform({1, 32, 1}))) == + assert Nx.type(predict_fn.(init_fn.(input, %{}), random({1, 32, 1}))) == {:bf, 16} end end @@ -784,7 +784,7 @@ defmodule CompilerTest do [channels: :last, kernel_size: {2}] ]) - inp = Nx.random_uniform({1, 32, 1}) + inp = random({1, 32, 1}) assert {_, predict_fn} = Axon.build(model) @@ -813,6 +813,59 @@ defmodule CompilerTest do # end end + describe "blur_pool" do + test "initializes with no params" do + model = apply(Axon, :blur_pool, [Axon.input("input", shape: {nil, 32, 32, 1})]) + + input = random({1, 32, 32, 1}) + + assert {init_fn, _predict_fn} = Axon.build(model) + assert %{} = init_fn.(input, %{}) + end + + test "computes forward pass with default options" do + model2 = apply(Axon, :blur_pool, [Axon.input("input", shape: {nil, 8, 4, 1})]) + input2 = random({1, 8, 4, 1}, type: {:f, 32}) + + assert {_, predict_fn} = Axon.build(model2) + + assert_equal( + predict_fn.(%{}, input2), + apply(Axon.Layers, :blur_pool, [input2]) + ) + end + + test "computes forward pass with output policy" do + model = apply(Axon, :blur_pool, [Axon.input("input", shape: {nil, 32, 32, 1})]) + policy = AMP.create_policy(output: {:bf, 16}) + mp_model = AMP.apply_policy(model, policy) + + input = random({1, 32, 32, 1}) + + assert {init_fn, predict_fn} = Axon.build(mp_model) + + assert Nx.type(predict_fn.(init_fn.(input, %{}), random({1, 32, 32, 1}))) == + {:bf, 16} + end + + test "computes forward pass with channels last" do + model = + apply(Axon, :blur_pool, [ + Axon.input("input", shape: {nil, 32, 32, 1}), + [channels: :last] + ]) + + inp = random({1, 32, 32, 1}) + + assert {_, predict_fn} = Axon.build(model) + + assert_equal( + predict_fn.(%{}, inp), + apply(Axon.Layers, :blur_pool, [inp, [channels: :last]]) + ) + end + end + @adaptive_pooling_layers [:adaptive_avg_pool, :adaptive_max_pool, :adaptive_lp_pool] describe "adaptive pooling" do @@ -820,7 +873,7 @@ defmodule CompilerTest do for pool <- @adaptive_pooling_layers do model = apply(Axon, pool, [Axon.input("input", shape: {nil, 32, 1})]) - input = Nx.random_uniform({1, 32, 1}) + input = random({1, 32, 1}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -830,7 +883,7 @@ defmodule CompilerTest do test "computes forward pass with default options" do for pool <- @adaptive_pooling_layers do model1 = apply(Axon, pool, [Axon.input("input", shape: {nil, 32, 1})]) - input1 = Nx.random_uniform({1, 32, 1}, type: {:f, 32}) + input1 = random({1, 32, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model1) @@ -840,7 +893,7 @@ defmodule CompilerTest do 
) model2 = apply(Axon, pool, [Axon.input("input", shape: {nil, 8, 4, 1})]) - input2 = Nx.random_uniform({1, 8, 4, 1}, type: {:f, 32}) + input2 = random({1, 8, 4, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model2) @@ -850,7 +903,7 @@ defmodule CompilerTest do ) model3 = apply(Axon, pool, [Axon.input("input", shape: {nil, 8, 4, 2, 1})]) - input3 = Nx.random_uniform({1, 8, 4, 2, 1}, type: {:f, 32}) + input3 = random({1, 8, 4, 2, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model3) @@ -865,21 +918,21 @@ defmodule CompilerTest do for pool <- @adaptive_pooling_layers do opts1 = [output_size: 27] model1 = apply(Axon, pool, [Axon.input("input", shape: {nil, 32, 1}), opts1]) - input1 = Nx.random_uniform({1, 32, 1}, type: {:f, 32}) + input1 = random({1, 32, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model1) assert_equal(predict_fn.(%{}, input1), apply(Axon.Layers, pool, [input1, opts1])) opts2 = [output_size: {2, 3}] model2 = apply(Axon, pool, [Axon.input("input", shape: {nil, 8, 4, 1}), opts2]) - input2 = Nx.random_uniform({1, 8, 4, 1}, type: {:f, 32}) + input2 = random({1, 8, 4, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model2) assert_equal(predict_fn.(%{}, input2), apply(Axon.Layers, pool, [input2, opts2])) opts3 = [output_size: {4, 3, 1}] model3 = apply(Axon, pool, [Axon.input("input", shape: {nil, 8, 4, 2, 1}), opts3]) - input3 = Nx.random_uniform({1, 8, 4, 2, 1}, type: {:f, 32}) + input3 = random({1, 8, 4, 2, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model3) assert_equal(predict_fn.(%{}, input3), apply(Axon.Layers, pool, [input3, opts3])) @@ -892,7 +945,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 32, 1}) + input = random({1, 32, 1}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -907,7 +960,7 @@ defmodule CompilerTest do [channels: :last, output_size: {27}] ]) - inp = Nx.random_uniform({1, 32, 1}) + inp = random({1, 32, 1}) assert {_, predict_fn} = Axon.build(model) @@ -926,7 +979,7 @@ defmodule CompilerTest do for pool <- @global_pooling_layers do model = apply(Axon, pool, [Axon.input("input", shape: {nil, 1, 32})]) - input = Nx.random_uniform({1, 1, 32}) + input = random({1, 1, 32}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -936,19 +989,19 @@ defmodule CompilerTest do test "computes forward pass with default options" do for pool <- @global_pooling_layers do model1 = apply(Axon, pool, [Axon.input("input", shape: {nil, 1, 4})]) - input1 = Nx.random_uniform({1, 1, 4}, type: {:f, 32}) + input1 = random({1, 1, 4}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model1) assert_equal(predict_fn.(%{}, input1), apply(Axon.Layers, pool, [input1])) model2 = apply(Axon, pool, [Axon.input("input", shape: {nil, 1, 2, 2})]) - input2 = Nx.random_uniform({1, 1, 2, 2}, type: {:f, 32}) + input2 = random({1, 1, 2, 2}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model2) assert_equal(predict_fn.(%{}, input2), apply(Axon.Layers, pool, [input2])) model3 = apply(Axon, pool, [Axon.input("input", shape: {nil, 1, 2, 2, 1})]) - input3 = Nx.random_uniform({1, 1, 2, 2, 1}, type: {:f, 32}) + input3 = random({1, 1, 2, 2, 1}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model3) assert_equal(predict_fn.(%{}, input3), apply(Axon.Layers, pool, [input3])) @@ -959,7 +1012,7 @@ defmodule CompilerTest do for pool <- 
@global_pooling_layers do opts1 = [keep_axes: true] model1 = apply(Axon, pool, [Axon.input("input", shape: {nil, 1, 2}), opts1]) - input1 = Nx.random_uniform({1, 1, 2}, type: {:f, 32}) + input1 = random({1, 1, 2}, type: {:f, 32}) assert {_, predict_fn} = Axon.build(model1) assert_equal(predict_fn.(%{}, input1), apply(Axon.Layers, pool, [input1, opts1])) @@ -972,7 +1025,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -993,7 +1046,7 @@ defmodule CompilerTest do [channels: :last, keep_axes: false] ]) - inp = Nx.random_uniform({1, 32, 1}) + inp = random({1, 32, 1}) assert {_, predict_fn} = Axon.build(model1) @@ -1023,7 +1076,7 @@ defmodule CompilerTest do [name: "dropout", seed: 0] ]) - input = Nx.random_uniform({1, 1, 32}) + input = random({1, 1, 32}) assert {init_fn, _predict_fn} = Axon.build(model, mode: :train) assert %{"dropout" => %{"key" => key}} = init_fn.(input, %{}) @@ -1039,7 +1092,7 @@ defmodule CompilerTest do [name: "dropout", seed: 0] ]) - input = Nx.random_uniform({1, 1, 32}) + input = random({1, 1, 32}) assert {init_fn, predict_fn} = Axon.build(model, mode: :train) @@ -1059,7 +1112,7 @@ defmodule CompilerTest do [rate: 0.5, name: "dropout", seed: 0] ]) - input = Nx.random_uniform({1, 16, 32}) + input = random({1, 16, 32}) assert {init_fn, predict_fn} = Axon.build(model, mode: :train) @@ -1074,7 +1127,7 @@ defmodule CompilerTest do test "computes forward pass with default options" do for dropout <- @dropout_layers do model1 = apply(Axon, dropout, [Axon.input("input", shape: {nil, 32, 32})]) - input1 = Nx.random_uniform({1, 32, 32}, type: {:f, 32}) + input1 = random({1, 32, 32}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model1, mode: :train) %{prediction: result1} = predict_fn.(init_fn.(input1, %{}), input1) @@ -1084,7 +1137,7 @@ defmodule CompilerTest do assert_not_equal(result1, input1) model2 = apply(Axon, dropout, [Axon.input("input", shape: {nil, 1, 8, 4})]) - input2 = Nx.random_uniform({1, 1, 8, 4}, type: {:f, 32}) + input2 = random({1, 1, 8, 4}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model2, mode: :train) %{prediction: result2} = predict_fn.(init_fn.(input2, %{}), input2) @@ -1094,7 +1147,7 @@ defmodule CompilerTest do assert_not_equal(result2, input2) model3 = apply(Axon, dropout, [Axon.input("input", shape: {nil, 1, 8, 4, 2})]) - input3 = Nx.random_uniform({1, 1, 8, 4, 2}, type: {:f, 32}) + input3 = random({1, 1, 8, 4, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model3, mode: :train) %{prediction: result3} = predict_fn.(init_fn.(input3, %{}), input3) @@ -1109,7 +1162,7 @@ defmodule CompilerTest do for dropout <- @dropout_layers do opts1 = [rate: 0.5] model1 = apply(Axon, dropout, [Axon.input("input", shape: {nil, 32, 128}), opts1]) - input1 = Nx.random_uniform({1, 32, 128}, type: {:f, 32}) + input1 = random({1, 32, 128}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model1, mode: :train) @@ -1127,7 +1180,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 32}) + input = random({1, 1, 32}) assert {init_fn, predict_fn} = Axon.build(mp_model, mode: :train) assert Nx.type(predict_fn.(init_fn.(input, %{}), input).prediction) == {:bf, 16} @@ -1137,7 
+1190,7 @@ defmodule CompilerTest do test "not present in inference mode" do for dropout <- @dropout_layers do model = apply(Axon, dropout, [Axon.input("input", shape: {nil, 1, 32})]) - input = Nx.random_uniform({1, 1, 32}) + input = random({1, 1, 32}) {init_fn, predict_fn} = Axon.build(model) assert_equal(predict_fn.(init_fn.(input, %{}), input), input) @@ -1148,7 +1201,7 @@ defmodule CompilerTest do for dropout <- @dropout_layers do input = Axon.input("input", shape: {nil, 1, 32}) model = Axon.add([input, apply(Axon, dropout, [input])]) - input = Nx.random_uniform({1, 1, 32}) + input = random({1, 1, 32}) {init_fn, _predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -1196,21 +1249,21 @@ defmodule CompilerTest do test "computes forward pass with default options" do model1 = Axon.input("input", shape: {nil, 1, 2}) |> Axon.conv(2, name: "conv") - input1 = Nx.random_uniform({1, 1, 2}, type: {:f, 32}) + input1 = random({1, 1, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model1) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input1, %{}) assert_equal(predict_fn.(params, input1), Axon.Layers.conv(input1, kernel, bias)) model2 = Axon.input("input", shape: {nil, 1, 2, 2}) |> Axon.conv(3, name: "conv") - input2 = Nx.random_uniform({1, 1, 2, 2}, type: {:f, 32}) + input2 = random({1, 1, 2, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model2) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input2, %{}) assert_equal(predict_fn.(params, input2), Axon.Layers.conv(input2, kernel, bias)) model3 = Axon.input("input", shape: {nil, 1, 2, 2, 2}) |> Axon.conv(4, name: "conv") - input3 = Nx.random_uniform({1, 1, 2, 2, 2}, type: {:f, 32}) + input3 = random({1, 1, 2, 2, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model3) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input3, %{}) @@ -1224,7 +1277,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 2, 1}) |> Axon.conv(2, [name: "conv", kernel_size: 2] ++ opts1) - input1 = Nx.random_uniform({1, 2, 1}) + input1 = random({1, 2, 1}) assert {init_fn, predict_fn} = Axon.build(model1) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input1, %{}) @@ -1236,7 +1289,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 4, 4, 1}) |> Axon.conv(2, [name: "conv", kernel_size: 2] ++ opts2) - input2 = Nx.random_uniform({1, 4, 4, 1}) + input2 = random({1, 4, 4, 1}) assert {init_fn, predict_fn} = Axon.build(model2) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input2, %{}) @@ -1248,7 +1301,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 2, 2, 2, 1}) |> Axon.conv(4, [name: "conv", kernel_size: {2, 1, 1}] ++ opts3) - input3 = Nx.random_uniform({1, 2, 2, 2, 1}) + input3 = random({1, 2, 2, 2, 1}) assert {init_fn, predict_fn} = Axon.build(model3) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input3, %{}) @@ -1263,7 +1316,7 @@ defmodule CompilerTest do assert {init_fn, predict_fn} = Axon.build(model) - input = Nx.random_uniform({1, 1, 32}) + input = random({1, 1, 32}) backward = fn params, input -> Nx.Defn.grad(params, &Nx.mean(predict_fn.(&1, input))) @@ -1294,7 +1347,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 32}) + input = random({1, 1, 32}) assert {init_fn, predict_fn} = 
Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -1304,7 +1357,7 @@ defmodule CompilerTest do model = Axon.input("input", shape: {nil, 1, 2}) |> Axon.conv(1, name: "conv", use_bias: false) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, _} = Axon.build(model) assert %{"conv" => %{"kernel" => _} = conv_params} = init_fn.(input, %{}) @@ -1315,7 +1368,7 @@ defmodule CompilerTest do model = Axon.input("input", shape: {nil, 1, 2}) |> Axon.conv(1, name: "conv", use_bias: false) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, predict_fn} = Axon.build(model) assert %{"conv" => %{"kernel" => k}} = params = init_fn.(input, %{}) @@ -1326,7 +1379,7 @@ defmodule CompilerTest do model = Axon.input("input", shape: {nil, 3, 3, 6}) |> Axon.conv(2, name: "conv", channels: :last) - input = Nx.random_uniform({1, 3, 3, 6}) + input = random({1, 3, 3, 6}) assert {init_fn, predict_fn} = Axon.build(model) assert %{"conv" => %{"kernel" => k, "bias" => b}} = params = init_fn.(input, %{}) @@ -1367,7 +1420,7 @@ defmodule CompilerTest do test "initializes in default case" do model = Axon.input("input", shape: {nil, 2, 2, 3}) |> Axon.depthwise_conv(3, name: "conv") - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, _} = Axon.build(model) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(input, %{}) @@ -1382,7 +1435,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 2, 2, 3}) |> Axon.depthwise_conv(3, name: "conv", kernel_initializer: :zeros) - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, _predict_fn} = Axon.build(model1) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(input, %{}) @@ -1403,14 +1456,14 @@ defmodule CompilerTest do test "computes forward pass with default options" do model1 = Axon.input("input", shape: {nil, 1, 8}) |> Axon.depthwise_conv(3, name: "conv") - input1 = Nx.random_uniform({1, 1, 8}, type: {:f, 32}) + input1 = random({1, 1, 8}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model1) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input1, %{}) assert_equal(predict_fn.(params, input1), Axon.Layers.depthwise_conv(input1, kernel, bias)) model2 = Axon.input("input", shape: {nil, 1, 2, 2}) |> Axon.depthwise_conv(4, name: "conv") - input2 = Nx.random_uniform({1, 1, 2, 2}, type: {:f, 32}) + input2 = random({1, 1, 2, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model2) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input2, %{}) @@ -1419,7 +1472,7 @@ defmodule CompilerTest do model3 = Axon.input("input", shape: {nil, 1, 2, 2, 2}) |> Axon.depthwise_conv(5, name: "conv") - input3 = Nx.random_uniform({1, 1, 2, 2, 2}, type: {:f, 32}) + input3 = random({1, 1, 2, 2, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model3) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input3, %{}) @@ -1433,7 +1486,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 8, 1}) |> Axon.depthwise_conv(1, [name: "conv", kernel_size: 2] ++ opts1) - input1 = Nx.random_uniform({1, 8, 1}) + input1 = random({1, 8, 1}) assert {init_fn, predict_fn} = Axon.build(model1) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input1, %{}) @@ -1449,7 +1502,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 4, 4, 1}) |> 
Axon.depthwise_conv(8, [name: "conv", kernel_size: 2] ++ opts2) - input2 = Nx.random_uniform({1, 4, 4, 1}) + input2 = random({1, 4, 4, 1}) assert {init_fn, predict_fn} = Axon.build(model2) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input2, %{}) @@ -1465,7 +1518,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 3, 2, 2, 1}) |> Axon.depthwise_conv(2, [name: "conv", kernel_size: {2, 1, 1}] ++ opts3) - input3 = Nx.random_uniform({1, 3, 2, 2, 1}) + input3 = random({1, 3, 2, 2, 1}) assert {init_fn, predict_fn} = Axon.build(model3) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input3, %{}) @@ -1482,7 +1535,7 @@ defmodule CompilerTest do |> Axon.depthwise_conv(1, name: "conv") |> Axon.freeze() - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, predict_fn} = Axon.build(model) @@ -1502,7 +1555,7 @@ defmodule CompilerTest do policy = AMP.create_policy(params: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, _} = Axon.build(mp_model) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(input, %{}) @@ -1515,7 +1568,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -1526,7 +1579,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 1, 2}) |> Axon.depthwise_conv(1, name: "conv", use_bias: false) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, _} = Axon.build(model) assert %{"conv" => %{"kernel" => _} = conv_params} = init_fn.(input, %{}) @@ -1538,7 +1591,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 1, 2}) |> Axon.depthwise_conv(1, name: "conv", use_bias: false) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, predict_fn} = Axon.build(model) assert %{"conv" => %{"kernel" => k}} = params = init_fn.(input, %{}) @@ -1550,7 +1603,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 3, 3, 6}) |> Axon.depthwise_conv(2, name: "conv", channels: :last) - input = Nx.random_uniform({1, 3, 3, 6}) + input = random({1, 3, 3, 6}) assert {init_fn, predict_fn} = Axon.build(model) assert %{"conv" => %{"kernel" => k, "bias" => b}} = params = init_fn.(input, %{}) @@ -1589,7 +1642,7 @@ defmodule CompilerTest do test "initializes in default case" do model = Axon.input("input", shape: {nil, 2, 2, 3}) |> Axon.conv_transpose(32, name: "conv") - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, _} = Axon.build(model) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(input, %{}) @@ -1604,7 +1657,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 2, 2, 3}) |> Axon.conv_transpose(32, name: "conv", kernel_initializer: :zeros) - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, _predict_fn} = Axon.build(model1) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(input, %{}) @@ -1625,14 +1678,14 @@ defmodule CompilerTest do test "computes forward pass with default options" do model1 = Axon.input("input", shape: {nil, 1, 4}) |> Axon.conv_transpose(3, name: "conv") - input1 = Nx.random_uniform({1, 1, 4}, type: {:f, 32}) + input1 = 
random({1, 1, 4}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model1) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input1, %{}) assert_equal(predict_fn.(params, input1), Axon.Layers.conv_transpose(input1, kernel, bias)) model2 = Axon.input("input", shape: {nil, 1, 4, 4}) |> Axon.conv_transpose(4, name: "conv") - input2 = Nx.random_uniform({1, 1, 4, 4}, type: {:f, 32}) + input2 = random({1, 1, 4, 4}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model2) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input2, %{}) @@ -1641,7 +1694,7 @@ defmodule CompilerTest do model3 = Axon.input("input", shape: {nil, 1, 2, 2, 2}) |> Axon.conv_transpose(5, name: "conv") - input3 = Nx.random_uniform({1, 1, 2, 2, 2}, type: {:f, 32}) + input3 = random({1, 1, 2, 2, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model3) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input3, %{}) @@ -1655,7 +1708,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 4, 1}) |> Axon.conv_transpose(1, [name: "conv", kernel_size: 2] ++ opts1) - input1 = Nx.random_uniform({1, 4, 1}) + input1 = random({1, 4, 1}) assert {init_fn, predict_fn} = Axon.build(model1) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input1, %{}) @@ -1671,7 +1724,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 4, 4, 1}) |> Axon.conv_transpose(8, [name: "conv", kernel_size: 2] ++ opts2) - input2 = Nx.random_uniform({1, 4, 4, 1}) + input2 = random({1, 4, 4, 1}) assert {init_fn, predict_fn} = Axon.build(model2) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input2, %{}) @@ -1687,7 +1740,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 2, 2, 2, 1}) |> Axon.conv_transpose(2, [name: "conv", kernel_size: {2, 1, 1}] ++ opts3) - input3 = Nx.random_uniform({1, 2, 2, 2, 1}) + input3 = random({1, 2, 2, 2, 1}) assert {init_fn, predict_fn} = Axon.build(model3) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = params = init_fn.(input3, %{}) @@ -1704,7 +1757,7 @@ defmodule CompilerTest do |> Axon.conv_transpose(1, name: "conv") |> Axon.freeze() - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, predict_fn} = Axon.build(model) @@ -1724,7 +1777,7 @@ defmodule CompilerTest do policy = AMP.create_policy(params: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, _} = Axon.build(mp_model) assert %{"conv" => %{"kernel" => kernel, "bias" => bias}} = init_fn.(input, %{}) @@ -1737,7 +1790,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -1748,7 +1801,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 1, 2}) |> Axon.conv_transpose(1, name: "conv", use_bias: false) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, _} = Axon.build(model) assert %{"conv" => %{"kernel" => _} = conv_params} = init_fn.(input, %{}) @@ -1760,7 +1813,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 1, 2}) |> Axon.conv_transpose(1, name: "conv", use_bias: false) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert 
{init_fn, predict_fn} = Axon.build(model) assert %{"conv" => %{"kernel" => k}} = params = init_fn.(input, %{}) @@ -1772,7 +1825,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 3, 3, 6}) |> Axon.conv_transpose(2, name: "conv", channels: :last) - input = Nx.random_uniform({1, 3, 3, 6}) + input = random({1, 3, 3, 6}) assert {init_fn, predict_fn} = Axon.build(model) assert %{"conv" => %{"kernel" => k, "bias" => b}} = params = init_fn.(input, %{}) @@ -1788,7 +1841,7 @@ defmodule CompilerTest do test "initializes in default case" do model = Axon.input("input", shape: {nil, 2, 2, 3}) |> Axon.separable_conv2d(3, name: "conv") - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, _} = Axon.build(model) @@ -1816,7 +1869,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 2, 2, 3}) |> Axon.separable_conv2d(3, name: "conv", kernel_initializer: :zeros) - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, _} = Axon.build(model1) @@ -1861,7 +1914,7 @@ defmodule CompilerTest do test "computes forward pass with default options" do model = Axon.input("input", shape: {nil, 3, 2, 2}) |> Axon.separable_conv2d(3, name: "conv") - input = Nx.random_uniform({1, 3, 2, 2}) + input = random({1, 3, 2, 2}) assert {init_fn, predict_fn} = Axon.build(model) @@ -1887,7 +1940,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 3, 3, 2}) |> Axon.separable_conv2d(3, [name: "conv", kernel_size: {2, 2}] ++ opts) - input = Nx.random_uniform({1, 3, 3, 2}) + input = random({1, 3, 3, 2}) assert {init_fn, predict_fn} = Axon.build(model) @@ -1912,7 +1965,7 @@ defmodule CompilerTest do |> Axon.separable_conv2d(1, name: "conv") |> Axon.freeze() - input = Nx.random_uniform({1, 1, 3, 2}) + input = random({1, 1, 3, 2}) assert {init_fn, predict_fn} = Axon.build(model) @@ -1940,7 +1993,7 @@ defmodule CompilerTest do policy = AMP.create_policy(params: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 3, 2}) + input = random({1, 1, 3, 2}) assert {init_fn, _} = Axon.build(mp_model) @@ -1964,7 +2017,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 3, 2}) + input = random({1, 1, 3, 2}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -1975,7 +2028,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 1, 2, 2}) |> Axon.separable_conv2d(1, name: "conv", use_bias: false) - input = Nx.random_uniform({1, 1, 2, 2}) + input = random({1, 1, 2, 2}) assert {init_fn, _} = Axon.build(model) assert %{"conv" => %{"kernel_1" => _, "kernel_2" => _} = conv_params} = init_fn.(input, %{}) @@ -1988,7 +2041,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 1, 2, 2}) |> Axon.separable_conv2d(1, name: "conv", use_bias: false) - input = Nx.random_uniform({1, 1, 2, 2}) + input = random({1, 1, 2, 2}) assert {init_fn, predict_fn} = Axon.build(model) assert %{"conv" => %{"kernel_1" => k1, "kernel_2" => k2}} = params = init_fn.(input, %{}) @@ -2004,7 +2057,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 3, 3, 6}) |> Axon.separable_conv2d(2, name: "conv", channels: :last) - input = Nx.random_uniform({1, 3, 3, 6}) + input = random({1, 3, 3, 6}) assert {init_fn, predict_fn} = Axon.build(model) @@ -2046,7 +2099,7 @@ defmodule CompilerTest do model = Axon.input("input", shape: {nil, 3, 2, 2, 3}) |> 
Axon.separable_conv3d(3, name: "conv") - input = Nx.random_uniform({1, 3, 2, 2, 3}) + input = random({1, 3, 2, 2, 3}) assert {init_fn, _} = Axon.build(model) @@ -2080,7 +2133,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 3, 3, 2, 2}) |> Axon.separable_conv3d(3, name: "conv", kernel_initializer: :zeros) - input = Nx.random_uniform({1, 3, 2, 2, 3}) + input = random({1, 3, 2, 2, 3}) assert {init_fn, _} = Axon.build(model1) @@ -2141,7 +2194,7 @@ defmodule CompilerTest do model = Axon.input("input", shape: {nil, 3, 2, 2, 2}) |> Axon.separable_conv3d(3, name: "conv") - input = Nx.random_uniform({1, 3, 2, 2, 2}) + input = random({1, 3, 2, 2, 2}) assert {init_fn, predict_fn} = Axon.build(model) @@ -2169,7 +2222,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 3, 2, 3, 3}) |> Axon.separable_conv3d(3, [name: "conv", kernel_size: {2, 2, 1}] ++ opts) - input = Nx.random_uniform({1, 3, 2, 3, 3}) + input = random({1, 3, 2, 3, 3}) assert {init_fn, predict_fn} = Axon.build(model) @@ -2196,7 +2249,7 @@ defmodule CompilerTest do |> Axon.separable_conv3d(1, name: "conv") |> Axon.freeze() - input = Nx.random_uniform({1, 1, 3, 2, 2}) + input = random({1, 1, 3, 2, 2}) assert {init_fn, predict_fn} = Axon.build(model) @@ -2230,7 +2283,7 @@ defmodule CompilerTest do policy = AMP.create_policy(params: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 3, 2, 2}) + input = random({1, 1, 3, 2, 2}) assert {init_fn, _} = Axon.build(mp_model) @@ -2260,7 +2313,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 3, 2, 2}) + input = random({1, 1, 3, 2, 2}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -2271,7 +2324,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 1, 3, 2, 2}) |> Axon.separable_conv3d(1, name: "conv", use_bias: false) - input = Nx.random_uniform({1, 1, 3, 2, 2}) + input = random({1, 1, 3, 2, 2}) assert {init_fn, _} = Axon.build(model) @@ -2288,7 +2341,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 1, 3, 2, 2}) |> Axon.separable_conv3d(1, name: "conv", use_bias: false) - input = Nx.random_uniform({1, 1, 3, 2, 2}) + input = random({1, 1, 3, 2, 2}) assert {init_fn, predict_fn} = Axon.build(model) @@ -2314,7 +2367,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 3, 3, 3, 6}) |> Axon.separable_conv3d(2, name: "conv", channels: :last) - input = Nx.random_uniform({1, 3, 3, 3, 6}) + input = random({1, 3, 3, 3, 6}) assert {init_fn, predict_fn} = Axon.build(model) @@ -2368,7 +2421,7 @@ defmodule CompilerTest do if norm != :instance_norm do model1 = apply(Axon, norm, [Axon.input("input", shape: {nil, 2}), [name: "norm"]]) - input = Nx.random_uniform({1, 2}) + input = random({1, 2}) assert {init_fn, _predict_fn} = Axon.build(model1) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = init_fn.(input, %{}) @@ -2380,7 +2433,7 @@ defmodule CompilerTest do model2 = apply(Axon, norm, [Axon.input("input", shape: {nil, 2, 2, 3}), [name: "norm"]]) - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, _predict_fn} = Axon.build(model2) @@ -2407,7 +2460,7 @@ defmodule CompilerTest do [name: "norm", gamma_initializer: :zeros] ]) - input = Nx.random_uniform({1, 2}) + input = random({1, 2}) assert {init_fn, _predict_fn} = Axon.build(model1) @@ -2429,7 +2482,7 @@ defmodule CompilerTest do [name: 
"norm", beta_initializer: :zeros] ]) - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, _predict_fn} = Axon.build(model2) @@ -2450,7 +2503,7 @@ defmodule CompilerTest do for norm <- @normalization_with_stats_layers do if norm != :instance_norm do model1 = apply(Axon, norm, [Axon.input("input", shape: {nil, 2}), [name: "norm"]]) - input1 = Nx.random_uniform({1, 2}, type: {:f, 32}) + input1 = random({1, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model1) @@ -2464,7 +2517,7 @@ defmodule CompilerTest do end model2 = apply(Axon, norm, [Axon.input("input", shape: {nil, 3, 2, 2}), [name: "norm"]]) - input2 = Nx.random_uniform({1, 3, 2, 2}, type: {:f, 32}) + input2 = random({1, 3, 2, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model2) @@ -2486,7 +2539,7 @@ defmodule CompilerTest do model1 = apply(Axon, norm, [Axon.input("input", shape: {nil, 2}), [name: "norm"] ++ opts1]) - input1 = Nx.random_uniform({1, 2}, type: {:f, 32}) + input1 = random({1, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model1) @@ -2504,7 +2557,7 @@ defmodule CompilerTest do model2 = apply(Axon, norm, [Axon.input("input", shape: {nil, 2, 2, 3}), [name: "norm"] ++ opts2]) - input2 = Nx.random_uniform({1, 2, 2, 3}, type: {:f, 32}) + input2 = random({1, 2, 2, 3}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model2) @@ -2524,7 +2577,7 @@ defmodule CompilerTest do apply(Axon, norm, [Axon.input("input", shape: {nil, 1, 2}), [name: "norm"]]) |> Axon.freeze() - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, predict_fn} = Axon.build(model) @@ -2546,7 +2599,7 @@ defmodule CompilerTest do policy = AMP.create_policy(params: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, _} = Axon.build(mp_model) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = init_fn.(input, %{}) @@ -2561,7 +2614,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -2577,7 +2630,7 @@ defmodule CompilerTest do if norm != :instance_norm do model1 = apply(Axon, norm, [Axon.input("input", shape: {nil, 2}), [name: "norm"]]) - input = Nx.random_uniform({1, 2}) + input = random({1, 2}) assert {init_fn, _predict_fn} = Axon.build(model1) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = init_fn.(input, %{}) @@ -2589,7 +2642,7 @@ defmodule CompilerTest do model2 = apply(Axon, norm, [Axon.input("input", shape: {nil, 2, 2, 3}), [name: "norm"]]) - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, _predict_fn} = Axon.build(model2) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = init_fn.(input, %{}) @@ -2609,7 +2662,7 @@ defmodule CompilerTest do [name: "norm", gamma_initializer: :zeros] ]) - input = Nx.random_uniform({1, 2}) + input = random({1, 2}) assert {init_fn, _predict_fn} = Axon.build(model1) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = init_fn.(input, %{}) @@ -2624,7 +2677,7 @@ defmodule CompilerTest do [name: "norm", beta_initializer: :zeros] ]) - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, _predict_fn} = Axon.build(model2) assert %{"norm" => %{"gamma" => gamma, "beta" => 
beta}} = init_fn.(input, %{}) @@ -2638,7 +2691,7 @@ defmodule CompilerTest do for norm <- @normalization_layers do if norm != :instance_norm do model1 = apply(Axon, norm, [Axon.input("input", shape: {nil, 2}), [name: "norm"]]) - input1 = Nx.random_uniform({1, 2}, type: {:f, 32}) + input1 = random({1, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model1) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = params = init_fn.(input1, %{}) @@ -2650,7 +2703,7 @@ defmodule CompilerTest do end model2 = apply(Axon, norm, [Axon.input("input", shape: {nil, 3, 2, 2}), [name: "norm"]]) - input2 = Nx.random_uniform({1, 3, 2, 2}, type: {:f, 32}) + input2 = random({1, 3, 2, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model2) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = params = init_fn.(input2, %{}) @@ -2666,7 +2719,7 @@ defmodule CompilerTest do model1 = apply(Axon, norm, [Axon.input("input", shape: {nil, 2}), [name: "norm"] ++ opts1]) - input1 = Nx.random_uniform({1, 2}, type: {:f, 32}) + input1 = random({1, 2}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model1) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = params = init_fn.(input1, %{}) @@ -2682,7 +2735,7 @@ defmodule CompilerTest do model2 = apply(Axon, norm, [Axon.input("input", shape: {nil, 2, 2, 3}), [name: "norm"] ++ opts2]) - input2 = Nx.random_uniform({1, 2, 2, 3}, type: {:f, 32}) + input2 = random({1, 2, 2, 3}, type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model2) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = params = init_fn.(input2, %{}) @@ -2702,7 +2755,7 @@ defmodule CompilerTest do assert {init_fn, predict_fn} = Axon.build(model) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) backward = fn params, input -> Nx.Defn.grad(params, &Nx.mean(predict_fn.(&1, input))) @@ -2722,7 +2775,7 @@ defmodule CompilerTest do policy = AMP.create_policy(params: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, _} = Axon.build(mp_model) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = init_fn.(input, %{}) @@ -2737,7 +2790,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 2}) + input = random({1, 1, 2}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -2749,7 +2802,7 @@ defmodule CompilerTest do test "initializes in default case" do model = Axon.input("input", shape: {nil, 3}) |> Axon.group_norm(3, name: "norm") - input = Nx.random_uniform({1, 3}) + input = random({1, 3}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = init_fn.(input, %{}) @@ -2764,7 +2817,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 3}) |> Axon.group_norm(3, name: "norm", gamma_initializer: :zeros) - input = Nx.random_uniform({1, 3}) + input = random({1, 3}) assert {init_fn, _predict_fn} = Axon.build(model1) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = init_fn.(input, %{}) @@ -2776,7 +2829,7 @@ defmodule CompilerTest do Axon.input("input", shape: {nil, 3, 3}) |> Axon.group_norm(3, name: "norm", beta_initializer: :zeros) - input = Nx.random_uniform({1, 3, 3}) + input = random({1, 3, 3}) assert {init_fn, _predict_fn} = Axon.build(model2) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = init_fn.(input, 
%{}) @@ -2787,7 +2840,7 @@ defmodule CompilerTest do test "computes forward pass with default options" do model1 = Axon.input("input", shape: {nil, 2}) |> Axon.group_norm(2, name: "norm") - input1 = Nx.random_uniform({1, 2}) + input1 = random({1, 2}) assert {init_fn, predict_fn} = Axon.build(model1) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = params = init_fn.(input1, %{}) @@ -2798,7 +2851,7 @@ defmodule CompilerTest do ) model2 = Axon.input("input", shape: {nil, 2, 2, 3}) |> Axon.group_norm(3, name: "norm") - input2 = Nx.random_uniform({1, 2, 2, 3}) + input2 = random({1, 2, 2, 3}) assert {init_fn, predict_fn} = Axon.build(model2) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = params = init_fn.(input2, %{}) @@ -2815,7 +2868,7 @@ defmodule CompilerTest do model = Axon.input("input", shape: {nil, 2, 2, 3}) |> Axon.group_norm(3, [name: "norm"] ++ opts) - input = Nx.random_uniform({1, 2, 2, 3}) + input = random({1, 2, 2, 3}) assert {init_fn, predict_fn} = Axon.build(model) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = params = init_fn.(input, %{}) @@ -2834,7 +2887,7 @@ defmodule CompilerTest do assert {init_fn, predict_fn} = Axon.build(model) - input = Nx.random_uniform({1, 2}) + input = random({1, 2}) backward = fn params, input -> Nx.Defn.grad(params, &Nx.mean(predict_fn.(&1, input))) @@ -2852,7 +2905,7 @@ defmodule CompilerTest do policy = AMP.create_policy(params: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 2}) + input = random({1, 2}) assert {init_fn, _} = Axon.build(mp_model) assert %{"norm" => %{"gamma" => gamma, "beta" => beta}} = init_fn.(input, %{}) @@ -2865,7 +2918,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 3}) + input = random({1, 3}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -2876,7 +2929,7 @@ defmodule CompilerTest do test "initializes with no params" do model = Axon.input("input_0", shape: {nil, 32}) |> Axon.flatten() - input = Nx.random_uniform({1, 32}) + input = random({1, 32}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -2884,13 +2937,13 @@ defmodule CompilerTest do test "computes forward pass with default options" do model1 = Axon.input("input_0", shape: {nil, 32}) |> Axon.flatten() - input1 = Nx.random_uniform({1, 32}) + input1 = random({1, 32}) assert {_, predict_fn} = Axon.build(model1) assert_equal(predict_fn.(%{}, input1), Axon.Layers.flatten(input1)) model2 = Axon.input("input", shape: {nil, 3, 32, 32}) |> Axon.flatten() - input2 = Nx.random_uniform({1, 3, 32, 32}) + input2 = random({1, 3, 32, 32}) assert {_, predict_fn} = Axon.build(model2) assert_equal(predict_fn.(%{}, input2), Axon.Layers.flatten(input2)) @@ -2901,7 +2954,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 3}) + input = random({1, 3}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -2912,7 +2965,7 @@ defmodule CompilerTest do test "initializes with no params" do model = Axon.input("input", shape: {nil, 3, 32}) |> Axon.transpose() - input = Nx.random_uniform({1, 3, 32}) + input = random({1, 3, 32}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -2920,13 +2973,13 @@ defmodule 
CompilerTest do test "computes forward pass with default options" do model1 = Axon.input("input_0", shape: {nil, 32}) |> Axon.transpose([0, 1]) - input1 = Nx.random_uniform({1, 32}) + input1 = random({1, 32}) assert {_, predict_fn} = Axon.build(model1) assert_equal(predict_fn.(%{}, input1), Nx.transpose(input1, axes: [0, 1])) model2 = Axon.input("input", shape: {nil, 3, 32, 32}) |> Axon.transpose([0, 2, 1, 3]) - input2 = Nx.random_uniform({1, 3, 32, 32}) + input2 = random({1, 3, 32, 32}) assert {_, predict_fn} = Axon.build(model2) assert_equal(predict_fn.(%{}, input2), Nx.transpose(input2, axes: [0, 2, 1, 3])) @@ -2944,7 +2997,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 32}) + input = random({1, 32}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -2955,7 +3008,7 @@ defmodule CompilerTest do test "initializes with no params" do model = Axon.input("input", shape: {nil, 1, 32}) |> Axon.reshape({16, 2}) - input = Nx.random_uniform({1, 1, 32}) + input = random({1, 1, 32}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -2963,13 +3016,13 @@ defmodule CompilerTest do test "computes forward pass with default options" do model1 = Axon.input("input_0", shape: {nil, 32}) |> Axon.reshape({16, 2}) - input1 = Nx.random_uniform({1, 32}) + input1 = random({1, 32}) assert {_, predict_fn} = Axon.build(model1) assert_equal(predict_fn.(%{}, input1), Nx.reshape(input1, {1, 16, 2})) model2 = Axon.input("input", shape: {nil, 3, 32, 32}) |> Axon.reshape({3, 16, 2, 32}) - input2 = Nx.random_uniform({1, 3, 32, 32}) + input2 = random({1, 3, 32, 32}) assert {_, predict_fn} = Axon.build(model2) assert_equal(predict_fn.(%{}, input2), Nx.reshape(input2, {1, 3, 16, 2, 32})) @@ -2987,7 +3040,7 @@ defmodule CompilerTest do assert {_, predict_fn} = Axon.build(model) - input = Nx.random_uniform({2, 4, 6}) + input = random({2, 4, 6}) assert_equal(predict_fn.(%{}, input), Nx.reshape(input, {2, 3, 8})) end @@ -2996,7 +3049,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 32}) + input = random({1, 32}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -3007,7 +3060,7 @@ defmodule CompilerTest do test "initializes with no params" do model = Axon.input("input", shape: {nil, 1, 3, 3}) |> Axon.resize({4, 4}) - input = Nx.random_uniform({1, 1, 3, 3}) + input = random({1, 1, 3, 3}) assert {init_fn, _predict_fn} = Axon.build(model) assert %{} = init_fn.(input, %{}) @@ -3015,7 +3068,7 @@ defmodule CompilerTest do test "computes forward pass with default options" do model1 = Axon.input("input", shape: {nil, 1, 3, 3}) |> Axon.resize({4, 4}) - input1 = Nx.random_uniform({1, 1, 3, 3}) + input1 = random({1, 1, 3, 3}) assert {_, predict_fn} = Axon.build(model1) assert_equal(predict_fn.(%{}, input1), Axon.Layers.resize(input1, size: {4, 4})) @@ -3026,7 +3079,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 1, 3, 3}) + input = random({1, 1, 3, 3}) assert {init_fn, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(init_fn.(input, %{}), input)) == {:bf, 16} @@ -3040,7 +3093,7 @@ defmodule CompilerTest do |> Axon.lstm(64, name: "lstm") |> 
Axon.container() - input = Nx.random_uniform({1, 32, 10}) + input = random({1, 32, 10}) assert {init_fn, _predict_fn} = Axon.build(model) @@ -3089,7 +3142,7 @@ defmodule CompilerTest do |> Axon.lstm(64, name: "lstm", kernel_initializer: :zeros) |> Axon.container() - input = Nx.random_uniform({1, 32, 10}) + input = random({1, 32, 10}) assert {init_fn, _predict_fn} = Axon.build(model1) @@ -3171,9 +3224,9 @@ defmodule CompilerTest do |> Axon.lstm(2, name: "lstm", recurrent_initializer: :zeros) |> Axon.container() - input = Nx.random_uniform({1, 8, 2}, type: {:f, 32}) + input = random({1, 8, 2}, type: {:f, 32}) - init_carry = {zeros({1, 1, 2}), zeros({1, 1, 2})} + init_carry = {zeros({1, 2}), zeros({1, 2})} assert {init_fn, predict_fn} = Axon.build(model) @@ -3192,9 +3245,10 @@ defmodule CompilerTest do assert_equal( predict_fn.(params, input), Axon.Layers.dynamic_unroll( - &Axon.Layers.lstm_cell/5, + &Axon.Layers.lstm_cell/6, input, init_carry, + Nx.tensor(0), k, h, b @@ -3213,14 +3267,15 @@ defmodule CompilerTest do ) |> Axon.container() - input1 = Nx.random_uniform({1, 8, 2}, type: {:f, 32}) + input1 = random({1, 8, 2}, type: {:f, 32}) - init_carry1 = {zeros({1, 1, 2}), zeros({1, 1, 2})} + init_carry1 = {zeros({1, 2}), zeros({1, 2})} - cell_fn1 = fn i, c, k, h, b -> + cell_fn1 = fn i, c, mask, k, h, b -> Axon.Layers.lstm_cell( i, c, + mask, k, h, b, @@ -3245,7 +3300,7 @@ defmodule CompilerTest do assert_all_close( predict_fn.(params, input1), - Axon.Layers.dynamic_unroll(cell_fn1, input1, init_carry1, k, h, b) + Axon.Layers.dynamic_unroll(cell_fn1, input1, init_carry1, Nx.tensor(0), k, h, b) ) model2 = @@ -3253,11 +3308,11 @@ defmodule CompilerTest do |> Axon.lstm(2, name: "lstm", unroll: :static, recurrent_initializer: :zeros) |> Axon.container() - input2 = Nx.random_uniform({1, 8, 2}, type: {:f, 32}) + input2 = random({1, 8, 2}, type: {:f, 32}) - init_carry2 = {zeros({1, 1, 2}), zeros({1, 1, 2})} + init_carry2 = {zeros({1, 2}), zeros({1, 2})} - cell_fn2 = &Axon.Layers.lstm_cell/5 + cell_fn2 = &Axon.Layers.lstm_cell/6 assert {init_fn, predict_fn} = Axon.build(model2) @@ -3275,7 +3330,7 @@ defmodule CompilerTest do assert_all_close( predict_fn.(params, input2), - Axon.Layers.static_unroll(cell_fn2, input2, init_carry2, k, h, b) + Axon.Layers.static_unroll(cell_fn2, input2, init_carry2, Nx.tensor(0), k, h, b) ) end @@ -3283,7 +3338,7 @@ defmodule CompilerTest do seq = Axon.input("input", shape: {nil, 8, 2}) {_, carry} = seq |> Axon.lstm(2, name: "encode", recurrent_initializer: :zeros) model = Axon.lstm(seq, carry, 2, name: "decode") |> Axon.container() - input = Nx.random_uniform({1, 8, 2}) + input = random({1, 8, 2}) assert {init_fn, predict_fn} = Axon.build(model) @@ -3291,12 +3346,20 @@ defmodule CompilerTest do {ei, eh, eb} = enc {di, dh, db} = dec - init_carry = {zeros({1, 1, 2}), zeros({1, 1, 2})} + init_carry = {zeros({1, 2}), zeros({1, 2})} {_, carry} = - Axon.Layers.dynamic_unroll(&Axon.Layers.lstm_cell/5, inp, init_carry, ei, eh, eb) + Axon.Layers.dynamic_unroll( + &Axon.Layers.lstm_cell/6, + inp, + init_carry, + Nx.tensor(0), + ei, + eh, + eb + ) - Axon.Layers.dynamic_unroll(&Axon.Layers.lstm_cell/5, inp, carry, di, dh, db) + Axon.Layers.dynamic_unroll(&Axon.Layers.lstm_cell/6, inp, carry, Nx.tensor(0), di, dh, db) end assert %{ @@ -3328,7 +3391,7 @@ defmodule CompilerTest do |> Axon.lstm(2, name: "lstm", use_bias: false) |> Axon.container() - input = Nx.random_uniform({1, 2, 1}) + input = random({1, 2, 1}) assert {init_fn, _} = Axon.build(model) @@ -3349,7 +3412,7 @@ defmodule 
CompilerTest do |> Axon.lstm(2, name: "lstm", use_bias: false, recurrent_initializer: :zeros) |> Axon.container() - input = Nx.random_uniform({1, 2, 1}) + input = random({1, 2, 1}) assert {init_fn, predict_fn} = Axon.build(model) @@ -3361,14 +3424,37 @@ defmodule CompilerTest do } = params = init_fn.(input, %{}) b = {Nx.tensor(0), Nx.tensor(0), Nx.tensor(0), Nx.tensor(0)} - c = {zeros({1, 1, 2}), zeros({1, 1, 2})} + c = {zeros({1, 2}), zeros({1, 2})} assert_equal( predict_fn.(params, input), - Axon.Layers.dynamic_unroll(&Axon.Layers.lstm_cell/5, input, c, k, h, b) + Axon.Layers.dynamic_unroll(&Axon.Layers.lstm_cell/6, input, c, Nx.tensor(0), k, h, b) ) end + test "mask actually works" do + sequence = Axon.input("review") + mask = Axon.mask(sequence, 0) + embedded = sequence |> Axon.embedding(2048, 64) + {rnn_sequence, _state} = Axon.lstm(embedded, 64, mask: mask) + + {init_fn, predict_fn} = Axon.build(rnn_sequence) + params = init_fn.(Nx.template({64, 64}, :s64), %{}) + + input = Nx.tensor([[1, 2, 3, 4]]) + padded = Nx.pad(input, 0, [{0, 0, 0}, {0, 60, 0}]) + out = predict_fn.(params, padded) + + last_token = out[[.., 3, ..]] + + for i <- 4..63 do + # all eos tokens will be ignored so we just propagate the value + # to the next token and thus these should all be the same as the + # last non eos token + assert_equal(last_token, out[[.., i, ..]]) + end + end + # TODO(seanmor5): https://github.com/elixir-nx/axon/issues/90 # test "initializes with parameter policy" do # end @@ -3394,7 +3480,7 @@ defmodule CompilerTest do |> Axon.conv_lstm(out_channel_n, name: "convlstm") |> Axon.container() - input = Nx.random_uniform({1, 10, 3, 6, 6}) + input = random({1, 10, 3, 6, 6}) assert {init_fn, _predict_fn} = Axon.build(model) @@ -3428,7 +3514,7 @@ defmodule CompilerTest do _heigth = 6 } - input = Nx.random_uniform({1, 10, 3, 6, 6}) + input = random({1, 10, 3, 6, 6}) out_channel_n = 4 @@ -3502,7 +3588,7 @@ defmodule CompilerTest do input = input_shape |> put_elem(0, batch_real) - |> Nx.random_uniform(type: {:f, 32}) + |> random(type: {:f, 32}) init_carry = {zeros(hidden_shape_real), zeros(hidden_shape_real)} @@ -3523,9 +3609,10 @@ defmodule CompilerTest do assert_equal( predict_fn.(params, input), Axon.Layers.dynamic_unroll( - &Axon.Layers.conv_lstm_cell/5, + &Axon.Layers.conv_lstm_cell/6, input, init_carry, + Nx.tensor(0), k, h, b @@ -3558,7 +3645,7 @@ defmodule CompilerTest do input = input_shape |> put_elem(0, batch_real) - |> Nx.random_uniform(type: {:f, 32}) + |> random(type: {:f, 32}) init_carry = {zeros(hidden_shape_real), zeros(hidden_shape_real)} @@ -3579,9 +3666,10 @@ defmodule CompilerTest do assert_equal( predict_fn.(params, input), Axon.Layers.static_unroll( - &Axon.Layers.conv_lstm_cell/5, + &Axon.Layers.conv_lstm_cell/6, input, init_carry, + Nx.tensor(0), k, h, b @@ -3612,14 +3700,15 @@ defmodule CompilerTest do input1 = input_shape |> put_elem(0, batch_real) - |> Nx.random_uniform(type: {:f, 32}) + |> random(type: {:f, 32}) init_carry1 = {zeros(hidden_shape_real), zeros(hidden_shape_real)} - cell_fn1 = fn i, c, k, h, b -> + cell_fn1 = fn i, c, mask, k, h, b -> Axon.Layers.conv_lstm_cell( i, c, + mask, k, h, b @@ -3642,7 +3731,7 @@ defmodule CompilerTest do assert_equal( predict_fn.(params, input1), - Axon.Layers.dynamic_unroll(cell_fn1, input1, init_carry1, k, h, b) + Axon.Layers.dynamic_unroll(cell_fn1, input1, init_carry1, Nx.tensor(0), k, h, b) ) model2 = @@ -3657,11 +3746,11 @@ defmodule CompilerTest do input2 = input_shape |> put_elem(0, batch_real) - |> 
Nx.random_uniform(type: {:f, 32}) + |> random(type: {:f, 32}) init_carry2 = {zeros(hidden_shape_real), zeros(hidden_shape_real)} - cell_fn2 = &Axon.Layers.conv_lstm_cell/5 + cell_fn2 = &Axon.Layers.conv_lstm_cell/6 assert {init_fn, predict_fn} = Axon.build(model2) @@ -3679,7 +3768,7 @@ defmodule CompilerTest do assert_equal( predict_fn.(params, input2), - Axon.Layers.static_unroll(cell_fn2, input2, init_carry2, k, h, b) + Axon.Layers.static_unroll(cell_fn2, input2, init_carry2, Nx.tensor(0), k, h, b) ) end @@ -3708,7 +3797,7 @@ defmodule CompilerTest do input = input_shape |> put_elem(0, batch_real) - |> Nx.random_uniform(type: {:f, 32}) + |> random(type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model) @@ -3720,15 +3809,24 @@ defmodule CompilerTest do {_, carry} = Axon.Layers.dynamic_unroll( - &Axon.Layers.conv_lstm_cell/5, + &Axon.Layers.conv_lstm_cell/6, inp, init_carry, + Nx.tensor(0), ei, eh, eb ) - Axon.Layers.dynamic_unroll(&Axon.Layers.conv_lstm_cell/5, inp, carry, di, dh, db) + Axon.Layers.dynamic_unroll( + &Axon.Layers.conv_lstm_cell/6, + inp, + carry, + Nx.tensor(0), + di, + dh, + db + ) end assert %{ @@ -3779,7 +3877,7 @@ defmodule CompilerTest do input = input_shape |> put_elem(0, batch_real) - |> Nx.random_uniform(type: {:f, 32}) + |> random(type: {:f, 32}) assert {init_fn, predict_fn} = Axon.build(model) @@ -3796,7 +3894,7 @@ defmodule CompilerTest do assert_equal( predict_fn.(params, input), - Axon.Layers.dynamic_unroll(&Axon.Layers.conv_lstm_cell/5, input, c, k, h, b) + Axon.Layers.dynamic_unroll(&Axon.Layers.conv_lstm_cell/6, input, c, Nx.tensor(0), k, h, b) ) end end @@ -3806,7 +3904,7 @@ defmodule CompilerTest do model = Axon.input("input", shape: {nil, 32, 10}) |> Axon.gru(64, name: "gru") |> Axon.container() - input = Nx.random_uniform({1, 32, 10}) + input = random({1, 32, 10}) assert {init_fn, _} = Axon.build(model) @@ -3841,7 +3939,7 @@ defmodule CompilerTest do end test "initializes with custom initializers" do - input = Nx.random_uniform({1, 32, 10}) + input = random({1, 32, 10}) model1 = Axon.input("input", shape: {nil, 32, 10}) @@ -3912,8 +4010,8 @@ defmodule CompilerTest do |> Axon.gru(2, name: "gru", recurrent_initializer: :zeros) |> Axon.container() - input = Nx.random_uniform({1, 8, 2}) - carry = {zeros({1, 1, 2})} + input = random({1, 8, 2}) + carry = {zeros({1, 2})} assert {init_fn, predict_fn} = Axon.build(model) @@ -3927,7 +4025,7 @@ defmodule CompilerTest do assert_equal( predict_fn.(params, input), - Axon.Layers.dynamic_unroll(&Axon.Layers.gru_cell/5, input, carry, k, h, b) + Axon.Layers.dynamic_unroll(&Axon.Layers.gru_cell/6, input, carry, Nx.tensor(0), k, h, b) ) end @@ -3942,13 +4040,14 @@ defmodule CompilerTest do ) |> Axon.container() - input1 = Nx.random_uniform({1, 8, 2}) - carry1 = {zeros({1, 1, 2})} + input1 = random({1, 8, 2}) + carry1 = {zeros({1, 2})} - cell_fn1 = fn i, c, k, h, b -> + cell_fn1 = fn i, c, mask, k, h, b -> Axon.Layers.gru_cell( i, c, + mask, k, h, b, @@ -3971,9 +4070,9 @@ defmodule CompilerTest do h = {whr, whz, whn} b = {br, bz, bin, bhn} - assert_equal( + assert_all_close( predict_fn.(params, input1), - Axon.Layers.dynamic_unroll(cell_fn1, input1, carry1, k, h, b) + Axon.Layers.dynamic_unroll(cell_fn1, input1, carry1, Nx.tensor(0), k, h, b) ) model2 = @@ -3981,8 +4080,8 @@ defmodule CompilerTest do |> Axon.gru(2, name: "gru", recurrent_initializer: :zeros, unroll: :static) |> Axon.container() - input2 = Nx.random_uniform({1, 8, 2}) - carry2 = {zeros({1, 1, 2})} + input2 = random({1, 8, 2}) + carry2 = {zeros({1, 
2})} assert {init_fn, predict_fn} = Axon.build(model2) @@ -3998,9 +4097,9 @@ defmodule CompilerTest do h = {whr, whz, whn} b = {br, bz, bin, bhn} - assert_equal( + assert_all_close( predict_fn.(params, input2), - Axon.Layers.static_unroll(&Axon.Layers.gru_cell/5, input2, carry2, k, h, b) + Axon.Layers.static_unroll(&Axon.Layers.gru_cell/6, input2, carry2, Nx.tensor(0), k, h, b) ) end @@ -4008,16 +4107,26 @@ defmodule CompilerTest do seq = Axon.input("input", shape: {nil, 8, 2}) {_, carry} = Axon.gru(seq, 2, name: "encode", recurrent_initializer: :zeros) model = Axon.gru(seq, carry, 2, name: "decode") |> Axon.container() - input = Nx.random_uniform({1, 8, 2}) - carry = {zeros({1, 1, 2})} + + input = random({1, 8, 2}) + carry = {zeros({1, 2})} equiv_fn = fn inp, enc, dec -> {ei, eh, eb} = enc {di, dh, db} = dec - {_, carry} = Axon.Layers.dynamic_unroll(&Axon.Layers.gru_cell/5, inp, carry, ei, eh, eb) + {_, carry} = + Axon.Layers.dynamic_unroll( + &Axon.Layers.gru_cell/6, + inp, + carry, + Nx.tensor(0), + ei, + eh, + eb + ) - Axon.Layers.dynamic_unroll(&Axon.Layers.gru_cell/5, inp, carry, di, dh, db) + Axon.Layers.dynamic_unroll(&Axon.Layers.gru_cell/6, inp, carry, Nx.tensor(0), di, dh, db) end assert {init_fn, predict_fn} = Axon.build(model) @@ -4047,7 +4156,7 @@ defmodule CompilerTest do |> Axon.gru(2, name: "gru", use_bias: false) |> Axon.container() - input = Nx.random_uniform({1, 2, 1}) + input = random({1, 2, 1}) assert {init_fn, _} = Axon.build(model) @@ -4068,7 +4177,7 @@ defmodule CompilerTest do |> Axon.gru(2, name: "gru", use_bias: false, recurrent_initializer: :zeros) |> Axon.container() - input = Nx.random_uniform({1, 2, 1}) + input = random({1, 2, 1}) assert {init_fn, predict_fn} = Axon.build(model) assert %{ @@ -4079,14 +4188,37 @@ defmodule CompilerTest do } = params = init_fn.(input, %{}) b = {Nx.tensor(0), Nx.tensor(0), Nx.tensor(0), Nx.tensor(0)} - c = {zeros({1, 1, 2})} + c = {zeros({1, 2})} assert_all_close( predict_fn.(params, input), - Axon.Layers.dynamic_unroll(&Axon.Layers.gru_cell/5, input, c, k, h, b) + Axon.Layers.dynamic_unroll(&Axon.Layers.gru_cell/6, input, c, Nx.tensor(0), k, h, b) ) end + test "mask actually works" do + sequence = Axon.input("review") + mask = Axon.mask(sequence, 0) + embedded = sequence |> Axon.embedding(2048, 64) + {rnn_sequence, _state} = Axon.gru(embedded, 64, mask: mask) + + {init_fn, predict_fn} = Axon.build(rnn_sequence) + params = init_fn.(Nx.template({64, 64}, :s64), %{}) + + input = Nx.tensor([[1, 2, 3, 4]]) + padded = Nx.pad(input, 0, [{0, 0, 0}, {0, 60, 0}]) + out = predict_fn.(params, padded) + + last_token = out[[.., 3, ..]] + + for i <- 4..63 do + # all eos tokens will be ignored so we just propagate the value + # to the next token and thus these should all be the same as the + # last non eos token + assert_equal(last_token, out[[.., i, ..]]) + end + end + # TODO(seanmor5): https://github.com/elixir-nx/axon/issues/90 # test "" # TODO(seanmor5): https://github.com/elixir-nx/axon/issues/90 @@ -4110,8 +4242,8 @@ defmodule CompilerTest do ]) input = %{ - "input_0" => Nx.random_uniform({1, 32}), - "input_1" => Nx.random_uniform({1, 32}) + "input_0" => random({1, 32}), + "input_1" => random({1, 32}) } assert {init_fn, _} = Axon.build(model) @@ -4127,8 +4259,8 @@ defmodule CompilerTest do Axon.input("input_1", shape: {nil, 32}) ]) - input1_1 = Nx.random_uniform({1, 32}) - input1_2 = Nx.random_uniform({1, 32}) + input1_1 = random({1, 32}) + input1_2 = random({1, 32}) assert {_, predict_fn} = Axon.build(model1) 
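          # A note on the calling convention in the assertions below: Axon.build/1
          # returns {init_fn, predict_fn}, and graphs with no trainable parameters
          # (such as these element-wise combinator layers) initialize to an empty map,
          # so the compiled forward pass is invoked with %{} as the parameter argument,
          # for example:
          #
          #   {init_fn, predict_fn} = Axon.build(model1)
          #   %{} = init_fn.(%{"input_0" => input1_1, "input_1" => input1_2}, %{})
          #   predict_fn.(%{}, %{"input_0" => input1_1, "input_1" => input1_2})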
assert_all_close( @@ -4145,9 +4277,9 @@ defmodule CompilerTest do ] ]) - input2_1 = Nx.random_uniform({1, 32}) - input2_2 = Nx.random_uniform({1, 32}) - input2_3 = Nx.random_uniform({1, 32}) + input2_1 = random({1, 32}) + input2_2 = random({1, 32}) + input2_3 = random({1, 32}) assert {_, predict_fn} = Axon.build(model2) assert_all_close( @@ -4169,8 +4301,8 @@ defmodule CompilerTest do mp_model = AMP.apply_policy(model, policy) input = %{ - "input_0" => Nx.random_uniform({1, 32}), - "input_1" => Nx.random_uniform({1, 32}) + "input_0" => random({1, 32}), + "input_1" => random({1, 32}) } assert {_, predict_fn} = Axon.build(mp_model) @@ -4179,8 +4311,8 @@ defmodule CompilerTest do end test "computes forward pass with broadcasting" do - inp1 = Nx.random_uniform({1, 1}) - inp2 = Nx.random_uniform({1, 2}) + inp1 = random({1, 1}) + inp2 = random({1, 2}) for op <- @binary_layers do model = @@ -4201,8 +4333,8 @@ defmodule CompilerTest do test "raises on bad shapes" do for op <- @binary_layers do assert_raise Axon.CompileError, ~r/cannot broadcast tensor/, fn -> - inp1 = Nx.random_uniform({1, 32}) - inp2 = Nx.random_uniform({1, 64}) + inp1 = random({1, 32}) + inp2 = random({1, 64}) model = apply(Axon, op, [ @@ -4223,7 +4355,7 @@ defmodule CompilerTest do Axon.input("input_1", shape: {nil, 32}) ) - input = %{"input_0" => Nx.random_uniform({1, 32}), "input_1" => Nx.random_uniform({1, 32})} + input = %{"input_0" => random({1, 32}), "input_1" => random({1, 32})} assert {init_fn, _} = Axon.build(model) assert %{} == init_fn.(input, %{}) @@ -4236,8 +4368,8 @@ defmodule CompilerTest do Axon.input("input_1", shape: {nil, 32}) ) - input1_1 = Nx.random_uniform({1, 32}) - input1_2 = Nx.random_uniform({1, 32}) + input1_1 = random({1, 32}) + input1_2 = random({1, 32}) assert {_, predict_fn} = Axon.build(model1) @@ -4253,9 +4385,9 @@ defmodule CompilerTest do Axon.input("input_2", shape: {nil, 32}) ]) - input2_1 = Nx.random_uniform({1, 32}) - input2_2 = Nx.random_uniform({1, 32}) - input2_3 = Nx.random_uniform({1, 32}) + input2_1 = random({1, 32}) + input2_2 = random({1, 32}) + input2_3 = random({1, 32}) assert {_, predict_fn} = Axon.build(model2) @@ -4273,8 +4405,8 @@ defmodule CompilerTest do axis: 1 ) - input1_1 = Nx.random_uniform({1, 1, 32}) - input1_2 = Nx.random_uniform({1, 1, 32}) + input1_1 = random({1, 1, 32}) + input1_2 = random({1, 1, 32}) assert {_, predict_fn} = Axon.build(model1) @@ -4294,8 +4426,8 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model1, policy) - input1_1 = Nx.random_uniform({1, 1, 32}) - input1_2 = Nx.random_uniform({1, 1, 32}) + input1_1 = random({1, 1, 32}) + input1_2 = random({1, 1, 32}) assert {_, predict_fn} = Axon.build(mp_model) @@ -4307,7 +4439,7 @@ defmodule CompilerTest do describe "pad" do test "initializes with no params" do model = Axon.input("input", shape: {nil, 3, 3}) |> Axon.pad([{1, 0}]) - input = Nx.random_uniform({1, 3, 3}) + input = random({1, 3, 3}) assert {init_fn, _} = Axon.build(model) assert %{} == init_fn.(input, %{}) @@ -4315,7 +4447,7 @@ defmodule CompilerTest do test "computes forward pass with default options" do model1 = Axon.input("input", shape: {nil, 3, 3}) |> Axon.pad([{1, 0}]) - input1 = Nx.random_uniform({1, 3, 3}) + input1 = random({1, 3, 3}) assert {_, predict_fn} = Axon.build(model1) @@ -4325,7 +4457,7 @@ defmodule CompilerTest do ) model2 = Axon.input("input", shape: {nil, 3, 3, 3}) |> Axon.pad([{0, 1}, {0, 1}]) - input2 = Nx.random_uniform({1, 3, 3, 3}) + input2 = random({1, 3, 3, 3}) 
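      # `random/2`, used throughout this file in place of the removed
      # `Nx.random_uniform/2`, is a test helper whose definition is not part of this
      # diff; a minimal sketch of what such a helper could look like (assumed, not
      # the actual implementation) is:
      #
      #   defp random(shape, opts \\ []) do
      #     type = opts[:type] || {:f, 32}
      #     key = Nx.Random.key(:rand.uniform(10_000))
      #     {out, _key} = Nx.Random.uniform(key, shape: shape, type: type)
      #     out
      #   end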
assert {_, predict_fn} = Axon.build(model2) @@ -4335,7 +4467,7 @@ defmodule CompilerTest do ) model3 = Axon.input("input", shape: {nil, 3, 3, 3, 3}) |> Axon.pad([{0, 1}, {0, 1}, {1, 0}]) - input3 = Nx.random_uniform({1, 3, 3, 3, 3}) + input3 = random({1, 3, 3, 3, 3}) assert {_, predict_fn} = Axon.build(model3) @@ -4347,7 +4479,7 @@ defmodule CompilerTest do test "computes forward pass with custom options" do model = Axon.input("input", shape: {nil, 3, 3}) |> Axon.pad([{1, 0}], 2) - input = Nx.random_uniform({1, 3, 3}) + input = random({1, 3, 3}) assert {_, predict_fn} = Axon.build(model) @@ -4361,7 +4493,7 @@ defmodule CompilerTest do model = Axon.input("input", shape: {nil, 3, 3}) |> Axon.pad([{1, 0}]) policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model, policy) - input = Nx.random_uniform({1, 3, 3}) + input = random({1, 3, 3}) assert {_, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(%{}, input)) == {:bf, 16} @@ -4371,7 +4503,7 @@ defmodule CompilerTest do describe "nx" do test "computes special nx functions" do model = Axon.input("input", shape: {nil, 10}) |> Axon.nx(&Nx.sin/1) - input = Nx.random_uniform({1, 10}) + input = random({1, 10}) assert {_, predict_fn} = Axon.build(model) assert_all_close(predict_fn.(%{}, input), Nx.sin(input)) @@ -4417,7 +4549,7 @@ defmodule CompilerTest do policy = AMP.create_policy(output: {:bf, 16}) mp_model = AMP.apply_policy(model1, policy) - input1_1 = Nx.random_uniform({1, 1, 32}) + input1_1 = random({1, 1, 32}) assert {_, predict_fn} = Axon.build(mp_model) assert Nx.type(predict_fn.(%{}, input1_1)) == {:bf, 16} @@ -4433,7 +4565,7 @@ defmodule CompilerTest do assert_raise Axon.CompileError, ~r/cond_fn must return a scalar/, fn -> {_, predict_fn} = Axon.build(model) - predict_fn.(%{}, Nx.random_uniform({1, 1, 10})) + predict_fn.(%{}, random({1, 1, 10})) end end end @@ -4442,7 +4574,7 @@ defmodule CompilerTest do test "initializes with no parameters" do model = Axon.input("input", shape: {nil, 10}) |> Axon.split(5) |> Axon.container() - input = Nx.random_uniform({1, 10}) + input = random({1, 10}) assert {init_fn, _} = Axon.build(model) assert init_fn.(input, %{}) == %{} @@ -4521,7 +4653,7 @@ defmodule CompilerTest do {init_fn, predict_fn} = Axon.build(model) - inp = Nx.random_uniform({1, 1}) + inp = random({1, 1}) params = init_fn.(inp, %{}) axon_loss = fn inp, params -> Nx.sum(predict_fn.(params, inp)) end @@ -4565,7 +4697,7 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - inp = Nx.random_uniform({1, 2}) + inp = random({1, 2}) assert %{"dense_0" => dense_0_params, "dense_1" => dense_1_params} = init_fn.(inp, %{}) @@ -4632,7 +4764,7 @@ defmodule CompilerTest do describe "custom layers" do test "initializes with no parameters" do model = Axon.layer(fn x, _opts -> x end, [Axon.input("input_0", shape: {nil, 1})]) - inp = Nx.random_uniform({1, 1}) + inp = random({1, 1}) {init_fn, _} = Axon.build(model) assert Enum.empty?(init_fn.(inp, %{})) @@ -4650,7 +4782,7 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - inp = Nx.random_uniform({1, 1}) + inp = random({1, 1}) assert %{"layer_0" => %{"kernel" => kernel}} = init_fn.(inp, %{}) @@ -4670,7 +4802,7 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - inp = Nx.random_uniform({1, 1}) + inp = random({1, 1}) assert %{"custom_0" => %{"kernel" => _}, "custom_1" => %{"kernel" => _}} = init_fn.(inp, %{}) @@ -4687,7 +4819,7 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - input = Nx.random_uniform({1, 1}) + input = 
random({1, 1}) assert %{"layer_0" => %{"kernel" => kernel}} = params = init_fn.(input, %{}) @@ -4695,19 +4827,17 @@ defmodule CompilerTest do end defn layer_with_options(x, kernel, opts \\ []) do - transform({x, kernel, opts}, fn {x, kernel, opts} -> - if opts[:add] do - Nx.add(x, kernel) - else - Nx.multiply(x, kernel) - end - end) + if opts[:add] do + Nx.add(x, kernel) + else + Nx.multiply(x, kernel) + end end test "computes forward pass with options" do kernel_param = Axon.param("kernel", fn shape -> shape end) - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) model1 = Axon.layer(&layer_with_options/3, [Axon.input("input_0", shape: {nil, 1}), kernel_param], @@ -4741,7 +4871,7 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - inp = Nx.random_uniform({1, 1}) + inp = random({1, 1}) assert %{"model" => %{"dense_0" => %{"kernel" => k, "bias" => b}}} = init_fn.(inp, %{}) @@ -4760,7 +4890,7 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - inp = Nx.random_uniform({1, 1}) + inp = random({1, 1}) assert %{"nested" => %{"model" => %{"dense_0" => %{"kernel" => k, "bias" => b}}}} = init_fn.(inp, %{}) @@ -4771,12 +4901,12 @@ defmodule CompilerTest do assert Nx.type(b) == {:f, 32} end - test "initializes correclty with single namespace no params" do + test "initializes correctly with single namespace no params" do model = Axon.input("input_0", shape: {nil, 1}) |> Axon.namespace("model") {init_fn, _} = Axon.build(model) - inp = Nx.random_uniform({1, 1}) + inp = random({1, 1}) assert Enum.empty?(init_fn.(inp, %{})) end @@ -4789,7 +4919,7 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - inp = Nx.random_uniform({1, 1}) + inp = random({1, 1}) assert Enum.empty?(init_fn.(inp, %{})) end @@ -4798,7 +4928,7 @@ defmodule CompilerTest do x = Axon.input("input_0", shape: {nil, 1}) |> Axon.dense(2) |> Axon.namespace("x") y = Axon.input("input_1", shape: {nil, 1}) |> Axon.dense(2) |> Axon.namespace("y") - inp = %{"input_0" => Nx.random_uniform({1, 1}), "input_1" => Nx.random_uniform({1, 1})} + inp = %{"input_0" => random({1, 1}), "input_1" => random({1, 1})} model = Axon.add(x, y) @@ -4828,7 +4958,7 @@ defmodule CompilerTest do |> Axon.namespace("y") |> Axon.namespace("z") - inp = %{"input_0" => Nx.random_uniform({1, 1}), "input_1" => Nx.random_uniform({1, 1})} + inp = %{"input_0" => random({1, 1}), "input_1" => random({1, 1})} model = Axon.add(x, z) @@ -4853,7 +4983,7 @@ defmodule CompilerTest do x = Axon.input("input_0", shape: {nil, 1}) |> Axon.dense(2) |> Axon.namespace("x") y = Axon.input("input_1", shape: {nil, 1}) |> Axon.dense(2) - inp = %{"input_0" => Nx.random_uniform({1, 1}), "input_1" => Nx.random_uniform({1, 1})} + inp = %{"input_0" => random({1, 1}), "input_1" => random({1, 1})} model = Axon.add(x, y) @@ -4877,7 +5007,7 @@ defmodule CompilerTest do test "initializes correctly reusing namespace" do x = Axon.input("input_0", shape: {nil, 1}) |> Axon.dense(2) |> Axon.namespace("x") - inp = Nx.random_uniform({1, 1}) + inp = random({1, 1}) model = Axon.add(x, x) {init_fn, _} = Axon.build(model) @@ -4920,7 +5050,7 @@ defmodule CompilerTest do test "predicts correctly with single namespace" do model = Axon.input("input_0", shape: {nil, 1}) |> Axon.dense(2) |> Axon.namespace("model") - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) {init_fn, _} = Axon.build(model) @@ -4932,7 +5062,7 @@ defmodule CompilerTest do test "predicts correctly with single namespace no parameters" do model = Axon.input("input_0", shape: {nil, 1}) |> 
Axon.namespace("model") - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert_equal(Axon.predict(model, %{}, input), input) end @@ -4946,7 +5076,7 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert %{"nested" => %{"model" => %{"dense_0" => %{"kernel" => k, "bias" => b}}}} = params = init_fn.(input, %{}) @@ -4960,7 +5090,7 @@ defmodule CompilerTest do |> Axon.namespace("model") |> Axon.namespace("nested") - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert_equal(Axon.predict(model, %{}, input), input) end @@ -4973,8 +5103,8 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - input_0 = Nx.random_uniform({1, 1}) - input_1 = Nx.random_uniform({1, 1}) + input_0 = random({1, 1}) + input_1 = random({1, 1}) inputs = %{"input_0" => input_0, "input_1" => input_1} assert %{ @@ -4999,8 +5129,8 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - input_0 = Nx.random_uniform({1, 1}) - input_1 = Nx.random_uniform({1, 1}) + input_0 = random({1, 1}) + input_1 = random({1, 1}) inputs = %{"input_0" => input_0, "input_1" => input_1} assert %{ @@ -5020,8 +5150,8 @@ defmodule CompilerTest do {init_fn, _} = Axon.build(model) - input_0 = Nx.random_uniform({1, 1}) - input_1 = Nx.random_uniform({1, 1}) + input_0 = random({1, 1}) + input_1 = random({1, 1}) inputs = %{"input_0" => input_0, "input_1" => input_1} assert %{ @@ -5038,7 +5168,7 @@ defmodule CompilerTest do model = Axon.add(x, x) {init_fn, _} = Axon.build(model) - input = Nx.random_uniform({1, 1}) + input = random({1, 1}) assert %{"x" => %{"dense_0" => %{"kernel" => k, "bias" => b}}} = params = init_fn.(input, %{}) @@ -5054,7 +5184,7 @@ defmodule CompilerTest do # model = Axon.add(inner, x) - # input = Nx.random_uniform({1, 1}) + # input = random({1, 1}) # assert %{"x" => %{"dense_0" => %{"kernel" => k, "bias" => b}}} = params = Axon.init(model) # expected = Nx.add(Axon.Layers.dense(input, k, b), Axon.Layers.dense(input, k, b)) @@ -5062,6 +5192,353 @@ defmodule CompilerTest do # end end + describe "block" do + test "initializes correctly with single dense layer, used once" do + block = Axon.block(&Axon.dense(&1, 32)) + model = block.(Axon.input("features")) + + {init_fn, _} = Axon.build(model) + + assert %{"block_0" => %{"dense_0" => %{"kernel" => k, "bias" => b}}} = + init_fn.(Nx.template({1, 1}, :f32), %{}) + + assert Nx.shape(k) == {1, 32} + assert Nx.shape(b) == {32} + assert Nx.type(k) == {:f, 32} + assert Nx.type(b) == {:f, 32} + end + + test "initializes correctly with single dense layer, used twice" do + block = Axon.block(&Axon.dense(&1, 1)) + + model = + Axon.input("features") + |> block.() + |> block.() + + {init_fn, _} = Axon.build(model) + + assert %{"block_0" => %{"dense_0" => %{"kernel" => k, "bias" => b}} = block_params} = + params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + assert Nx.shape(k) == {1, 1} + assert Nx.shape(b) == {1} + assert Nx.type(k) == {:f, 32} + assert Nx.type(b) == {:f, 32} + + # no additional dense layers in block + assert map_size(block_params) == 1 + # no additional blocks + assert map_size(params) == 1 + end + + test "initializes correctly with multiple dense layer, used once" do + block = + Axon.block(fn x -> + x + |> Axon.dense(32, activation: :relu) + |> Axon.dense(32, activation: :relu) + end) + + model = block.(Axon.input("features")) + {init_fn, _} = Axon.build(model) + + assert %{ + "block_0" => + %{ + "dense_0" => %{"kernel" => k1, "bias" => b1}, + "dense_1" => %{"kernel" => 
k2, "bias" => b2} + } = block_params + } = params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + assert Nx.shape(k1) == {1, 32} + assert Nx.shape(b1) == {32} + assert Nx.shape(k2) == {32, 32} + assert Nx.shape(b2) == {32} + assert Nx.type(k1) == {:f, 32} + assert Nx.type(b1) == {:f, 32} + assert Nx.type(k2) == {:f, 32} + assert Nx.type(b2) == {:f, 32} + + # no additional dense layers in block + assert map_size(block_params) == 2 + # no additional blocks + assert map_size(params) == 1 + end + + test "initializes correctly with multiple dense layer, used multiple times" do + block = + Axon.block(fn x -> + x + |> Axon.dense(32, activation: :relu) + |> Axon.dense(1, activation: :relu) + end) + + model = Enum.reduce(0..9, Axon.input("features"), fn _, x -> block.(x) end) + + {init_fn, _} = Axon.build(model) + + assert %{ + "block_0" => + %{ + "dense_0" => %{"kernel" => k1, "bias" => b1}, + "dense_1" => %{"kernel" => k2, "bias" => b2} + } = block_params + } = params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + assert Nx.shape(k1) == {1, 32} + assert Nx.shape(b1) == {32} + assert Nx.shape(k2) == {32, 1} + assert Nx.shape(b2) == {1} + assert Nx.type(k1) == {:f, 32} + assert Nx.type(b1) == {:f, 32} + assert Nx.type(k2) == {:f, 32} + assert Nx.type(b2) == {:f, 32} + + # no additional dense layers in block + assert map_size(block_params) == 2 + # no additional blocks + assert map_size(params) == 1 + end + + test "initializes correctly with multiple blocks in network" do + block1 = Axon.block(&Axon.dense(&1, 32)) + block2 = Axon.block(&Axon.dense(&1, 32)) + + model = + Axon.input("features") + |> block1.() + |> block2.() + + {init_fn, _} = Axon.build(model) + + assert %{ + "block_0" => + %{ + "dense_0" => %{"kernel" => k1, "bias" => b1} + } = block_0_params, + "block_1" => + %{ + "dense_0" => %{"kernel" => k2, "bias" => b2} + } = block_1_params + } = params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + assert Nx.shape(k1) == {1, 32} + assert Nx.shape(b1) == {32} + assert Nx.shape(k2) == {32, 32} + assert Nx.shape(b2) == {32} + assert Nx.type(k1) == {:f, 32} + assert Nx.type(b1) == {:f, 32} + assert Nx.type(k2) == {:f, 32} + assert Nx.type(b2) == {:f, 32} + + # no additional dense layers in block + assert map_size(block_0_params) == 1 + assert map_size(block_1_params) == 1 + # no additional blocks + assert map_size(params) == 2 + end + + test "initializes correctly with block inside of a block" do + block = + Axon.block(fn x -> + inner_block = Axon.block(&Axon.dense(&1, 1)) + + x |> inner_block.() |> inner_block.() + end) + + model = + Axon.input("features") + |> block.() + |> block.() + + {init_fn, _} = Axon.build(model) + + assert %{ + "block_0" => + %{ + "block_0" => %{"dense_0" => %{"kernel" => k, "bias" => b}} = inner_block_params + } = block_params + } = params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + assert Nx.shape(k) == {1, 1} + assert Nx.shape(b) == {1} + assert Nx.type(k) == {:f, 32} + assert Nx.type(b) == {:f, 32} + + assert map_size(inner_block_params) == 1 + assert map_size(block_params) == 1 + assert map_size(params) == 1 + end + + test "predicts correctly with single dense, used once" do + block = Axon.block(&Axon.dense(&1, 32)) + model = block.(Axon.input("features")) + + {init_fn, predict_fn} = Axon.build(model) + + assert %{"block_0" => %{"dense_0" => %{"kernel" => k, "bias" => b}}} = + params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + input = random({1, 1}) + + assert predict_fn.(params, input) == Axon.Layers.dense(input, k, b) + end + + test "predicts correctly with 
single dense, used twice" do + block = Axon.block(&Axon.dense(&1, 1)) + + model = + Axon.input("features") + |> block.() + |> block.() + + {init_fn, predict_fn} = Axon.build(model) + + assert %{"block_0" => %{"dense_0" => %{"kernel" => k, "bias" => b}}} = + params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + input = random({1, 1}) + + assert predict_fn.(params, input) == + input |> Axon.Layers.dense(k, b) |> Axon.Layers.dense(k, b) + end + + test "predicts correctly with multiple dense, used once" do + block = + Axon.block(fn x -> + x + |> Axon.dense(32, activation: :relu) + |> Axon.dense(1, activation: :relu) + end) + + model = block.(Axon.input("features")) + {init_fn, predict_fn} = Axon.build(model) + + assert %{ + "block_0" => %{ + "dense_0" => %{"kernel" => k1, "bias" => b1}, + "dense_1" => %{"kernel" => k2, "bias" => b2} + } + } = params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + expected_predict_fn = fn x, k1, b1, k2, b2 -> + x + |> Axon.Layers.dense(k1, b1) + |> Axon.Activations.relu() + |> Axon.Layers.dense(k2, b2) + |> Axon.Layers.relu() + end + + input = random({1, 1}) + + assert predict_fn.(params, input) == expected_predict_fn.(input, k1, b1, k2, b2) + end + + test "predicts correctly with multiple dense, used twice" do + block = + Axon.block(fn x -> + x + |> Axon.dense(32, activation: :relu) + |> Axon.dense(1, activation: :relu) + end) + + model = + Axon.input("features") + |> block.() + |> block.() + + {init_fn, predict_fn} = Axon.build(model) + + assert %{ + "block_0" => %{ + "dense_0" => %{"kernel" => k1, "bias" => b1}, + "dense_1" => %{"kernel" => k2, "bias" => b2} + } + } = params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + expected_predict_fn = fn x, k1, b1, k2, b2 -> + x + |> Axon.Layers.dense(k1, b1) + |> Axon.Activations.relu() + |> Axon.Layers.dense(k2, b2) + |> Axon.Layers.relu() + |> Axon.Layers.dense(k1, b1) + |> Axon.Activations.relu() + |> Axon.Layers.dense(k2, b2) + |> Axon.Layers.relu() + end + + input = random({1, 1}) + + assert predict_fn.(params, input) == expected_predict_fn.(input, k1, b1, k2, b2) + end + + test "predicts correctly with multiple blocks in network" do + block1 = Axon.block(&Axon.dense(&1, 32)) + block2 = Axon.block(&Axon.dense(&1, 32)) + + model = + Axon.input("features") + |> block1.() + |> block2.() + + {init_fn, predict_fn} = Axon.build(model) + + actual_predict_fn = fn x, k1, b1, k2, b2 -> + x + |> Axon.Layers.dense(k1, b1) + |> Axon.Layers.dense(k2, b2) + end + + assert %{ + "block_0" => %{ + "dense_0" => %{"kernel" => k1, "bias" => b1} + }, + "block_1" => %{ + "dense_0" => %{"kernel" => k2, "bias" => b2} + } + } = params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + input = random({1, 1}) + + assert predict_fn.(params, input) == actual_predict_fn.(input, k1, b1, k2, b2) + end + + test "predicts correctly with block inside of a block" do + block = + Axon.block(fn x -> + inner_block = Axon.block(&Axon.dense(&1, 1)) + + x |> inner_block.() |> inner_block.() + end) + + model = + Axon.input("features") + |> block.() + |> block.() + + {init_fn, predict_fn} = Axon.build(model) + + actual_predict_fn = fn x, k, b -> + x + |> Axon.Layers.dense(k, b) + |> Axon.Layers.dense(k, b) + |> Axon.Layers.dense(k, b) + |> Axon.Layers.dense(k, b) + end + + assert %{ + "block_0" => %{ + "block_0" => %{"dense_0" => %{"kernel" => k, "bias" => b}} + } + } = params = init_fn.(Nx.template({1, 1}, :f32), %{}) + + input = random({1, 1}) + assert predict_fn.(params, input) == actual_predict_fn.(input, k, b) + end + end + describe "initializers" do test 
"work with functions" do model = @@ -5141,7 +5618,7 @@ defmodule CompilerTest do # {init_fn, _} = Axon.build(model) - # inp = Nx.random_uniform({1, 1}) + # inp = random({1, 1}) # assert_raise ArgumentError, # ~s{found unexpected key in the initial parameters map: "dense_2"}, @@ -5153,8 +5630,8 @@ defmodule CompilerTest do describe "containers" do test "allows accessors with custom layers" do - input1 = Nx.random_uniform({1, 1}) - input2 = Nx.random_uniform({1, 2}) + input1 = random({1, 1}) + input2 = random({1, 2}) inputs = %{"input_0" => input1, "input_1" => input2} inp1 = Axon.input("input_0", shape: {nil, 1}) @@ -5281,4 +5758,19 @@ defmodule CompilerTest do assert predict_fn1 == predict_fn2 end end + + describe "metadata" do + test "axon compiler attaches layer name as metadata to subgraphs" do + model = Axon.input("input", shape: {nil, 784}) |> Axon.dense(128) + + {init_fn, predict_fn} = Axon.build(model) + params = init_fn.(Nx.template({1, 784}, :f32), %{}) + input = Nx.broadcast(0.0, {1, 784}) + + expr_fn = Nx.Defn.jit(predict_fn, compiler: Axon.Defn) + expr = expr_fn.(params, input) + + assert %{data: %{op: :metadata, args: [_tensor, %{axon_layer: :dense}]}} = expr + end + end end diff --git a/test/axon/integration_test.exs b/test/axon/integration_test.exs index 19d35440..8c95ed56 100644 --- a/test/axon/integration_test.exs +++ b/test/axon/integration_test.exs @@ -26,7 +26,58 @@ defmodule Axon.IntegrationTest do ExUnit.CaptureIO.capture_io(fn -> results = model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(5.0e-3)) + |> Axon.Loop.trainer( + :categorical_cross_entropy, + Polaris.Optimizers.adam(learning_rate: 5.0e-3) + ) + # TODO: Fix default output transform + |> Map.update(:output_transform, nil, fn _ -> & &1 end) + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.validate(model, train) + |> Axon.Loop.run(train, %{}, epochs: 10) + + assert %{step_state: %{model_state: model_state}, metrics: %{9 => last_epoch_metrics}} = + results + + eval_results = + model + |> Axon.Loop.evaluator() + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.run(train, model_state) + + assert %{0 => %{"accuracy" => final_model_val_accuracy}} = eval_results + + assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.7) + assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"]) + assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2} + end) + end + + test "f64 input test" do + {train, _test} = get_test_data(100, 0, 10, {10}, 2, 1337) + + train = + train + |> Stream.map(fn {xs, ys} -> + {Nx.as_type(xs, :f64), one_hot(ys, num_classes: 2)} + end) + |> Enum.to_list() + + [{x_test, _}] = Enum.take(train, 1) + + model = + Axon.input("input") + |> Axon.dense(16) + |> Axon.dropout(rate: 0.1) + |> Axon.dense(2, activation: :softmax) + + ExUnit.CaptureIO.capture_io(fn -> + results = + model + |> Axon.Loop.trainer( + :categorical_cross_entropy, + Polaris.Optimizers.adam(learning_rate: 5.0e-3) + ) # TODO: Fix default output transform |> Map.update(:output_transform, nil, fn _ -> & &1 end) |> Axon.Loop.metric(:accuracy) @@ -74,7 +125,10 @@ defmodule Axon.IntegrationTest do ExUnit.CaptureIO.capture_io(fn -> results = model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(5.0e-3)) + |> Axon.Loop.trainer( + :categorical_cross_entropy, + Polaris.Optimizers.adam(learning_rate: 5.0e-3) + ) # TODO: Fix default output transform |> Map.update(:output_transform, nil, fn _ -> & &1 end) |> Axon.Loop.metric(:accuracy) @@ -121,7 +175,59 @@ 
defmodule Axon.IntegrationTest do ExUnit.CaptureIO.capture_io(fn -> results = model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(5.0e-3)) + |> Axon.Loop.trainer( + :categorical_cross_entropy, + Polaris.Optimizers.adam(learning_rate: 5.0e-3) + ) + # TODO: Fix default output transform + |> Map.update(:output_transform, nil, fn _ -> & &1 end) + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.validate(model, train) + |> Axon.Loop.run(train, %{}, epochs: 10) + + assert %{step_state: %{model_state: model_state}, metrics: %{9 => last_epoch_metrics}} = + results + + eval_results = + model + |> Axon.Loop.evaluator() + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.run(train, model_state) + + assert %{0 => %{"accuracy" => final_model_val_accuracy}} = eval_results + + assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.7) + assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"]) + assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2} + end) + end + + test "gradient accumulation test" do + {train, _test} = get_test_data(100, 0, 10, {10}, 2, 1337) + + train = + train + |> Stream.map(fn {xs, ys} -> + {xs, one_hot(ys, num_classes: 2)} + end) + |> Enum.to_list() + + [{x_test, _}] = Enum.take(train, 1) + + model = + Axon.input("input") + |> Axon.dense(16) + |> Axon.dropout(rate: 0.1) + |> Axon.dense(2, activation: :softmax) + + ExUnit.CaptureIO.capture_io(fn -> + results = + model + |> Axon.Loop.trainer( + :categorical_cross_entropy, + Polaris.Optimizers.adam(learning_rate: 5.0e-3), + gradient_accumulation_steps: 3 + ) # TODO: Fix default output transform |> Map.update(:output_transform, nil, fn _ -> & &1 end) |> Axon.Loop.metric(:accuracy) @@ -164,7 +270,11 @@ defmodule Axon.IntegrationTest do ExUnit.CaptureIO.capture_io(fn -> %{metrics: metrics1, step_state: step_state1} = model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(5.0e-3), seed: 1) + |> Axon.Loop.trainer( + :categorical_cross_entropy, + Polaris.Optimizers.adam(learning_rate: 5.0e-3), + seed: 1 + ) # TODO: Fix default output transform |> Map.update(:output_transform, nil, fn _ -> & &1 end) |> Axon.Loop.metric(:accuracy) @@ -173,7 +283,11 @@ defmodule Axon.IntegrationTest do %{metrics: metrics2, step_state: step_state2} = model - |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(5.0e-3), seed: 1) + |> Axon.Loop.trainer( + :categorical_cross_entropy, + Polaris.Optimizers.adam(learning_rate: 5.0e-3), + seed: 1 + ) # TODO: Fix default output transform |> Map.update(:output_transform, nil, fn _ -> & &1 end) |> Axon.Loop.metric(:accuracy) @@ -184,4 +298,250 @@ defmodule Axon.IntegrationTest do assert_equal(step_state1, step_state2) end) end + + describe "optimizer integration" do + @optimizers_and_args [ + {:adabelief, [[learning_rate: 5.0e-3]]}, + {:adagrad, [[learning_rate: 5.0e-3]]}, + {:adam, [[learning_rate: 5.0e-3]]}, + {:adamw, [[learning_rate: 5.0e-3]]}, + {:adamw, [[learning_rate: 5.0e-3, decay: 0.9]]}, + {:lamb, [[learning_rate: 5.0e-3]]}, + {:lamb, [[learning_rate: 5.0e-3, decay: 0.9]]}, + {:lamb, [[learning_rate: 5.0e-3, min_norm: 0.1]]}, + {:lamb, [[learning_rate: 5.0e-3, decay: 0.9, min_norm: 0.1]]}, + {:noisy_sgd, [[learning_rate: 5.0e-3]]}, + {:radam, [[learning_rate: 5.0e-3]]}, + {:rmsprop, [[learning_rate: 5.0e-3]]}, + {:rmsprop, [[learning_rate: 5.0e-3, centered: true]]}, + {:rmsprop, [[learning_rate: 5.0e-3, momentum: 0.9]]}, + {:rmsprop, [[learning_rate: 5.0e-3, nesterov: true, momentum: 0.9]]}, + 
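      # Each {optimizer, [opts]} entry in this attribute is expanded by the generated
      # test below via `Polaris.Optimizers.unquote(optimizer)(unquote_splicing(args))`;
      # the entry above, for example, becomes
      # `Polaris.Optimizers.rmsprop(learning_rate: 5.0e-3, nesterov: true, momentum: 0.9)`,
      # and the resulting {init_fn, update_fn} pair is passed directly to
      # Axon.Loop.trainer/3.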
{:rmsprop, [[learning_rate: 5.0e-3, centered: true, nesterov: true, momentum: 0.9]]}, + {:sgd, [[learning_rate: 5.0e-3]]}, + {:sgd, [[learning_rate: 5.0e-3, momentum: 0.9]]}, + {:sgd, [[learning_rate: 5.0e-3, momentum: 0.9, nesterov: true]]} + ] + + for {optimizer, [opts] = args} <- @optimizers_and_args do + lr = opts[:learning_rate] + + test "#{optimizer}, learning_rate: #{lr}, opts: #{inspect(opts)} trains simple model with dropout" do + {train, _test} = get_test_data(100, 0, 10, {10}, 2, 1337) + + train = + train + |> Stream.map(fn {xs, ys} -> + {xs, one_hot(ys, num_classes: 2)} + end) + |> Enum.to_list() + + [{x_test, _}] = Enum.take(train, 1) + + model = + Axon.input("input") + |> Axon.dense(16) + |> Axon.dropout(rate: 0.1) + |> Axon.dense(2, activation: :softmax) + + ExUnit.CaptureIO.capture_io(fn -> + results = + model + |> Axon.Loop.trainer( + :categorical_cross_entropy, + Polaris.Optimizers.unquote(optimizer)(unquote_splicing(args)) + ) + # TODO: Fix default output transform + |> Map.update(:output_transform, nil, fn _ -> & &1 end) + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.validate(model, train) + |> Axon.Loop.run(train, %{}, epochs: 10) + + assert %{step_state: %{model_state: model_state}, metrics: %{9 => last_epoch_metrics}} = + results + + eval_results = + model + |> Axon.Loop.evaluator() + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.run(train, model_state) + + assert %{0 => %{"accuracy" => final_model_val_accuracy}} = eval_results + + assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.7) + assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"]) + assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2} + end) + end + end + end + + describe "mixed precision training integration" do + @policies [ + {"compute f16", Axon.MixedPrecision.create_policy(compute: {:f, 16})}, + {"compute f16, params f16", + Axon.MixedPrecision.create_policy(compute: {:f, 16}, params: {:f, 16})}, + {"compute f16, params f16, output f16", + Axon.MixedPrecision.create_policy(params: {:f, 16}, compute: {:f, 16}, output: {:f, 16})} + ] + + @scales [:identity, :dynamic, :static] + + for {name, policy} <- @policies, scale <- @scales do + test "trains simple model with policy #{name}, scale #{inspect(scale)}" do + {train, _test} = get_test_data(100, 0, 10, {10}, 2, 1337) + + train = + train + |> Stream.map(fn {xs, ys} -> + {xs, one_hot(ys, num_classes: 2)} + end) + |> Enum.to_list() + + [{x_test, _}] = Enum.take(train, 1) + + model = + Axon.input("input") + |> Axon.dense(16) + |> Axon.dropout(rate: 0.1) + |> Axon.dense(2, activation: :softmax) + |> Axon.MixedPrecision.apply_policy(unquote(Macro.escape(policy))) + + ExUnit.CaptureIO.capture_io(fn -> + results = + model + |> Axon.Loop.trainer(:categorical_cross_entropy, :adam, loss_scale: unquote(scale)) + # TODO: Fix default output transform + |> Map.update(:output_transform, nil, fn _ -> & &1 end) + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.validate(model, train) + |> Axon.Loop.run(train, %{}, epochs: 10) + + assert %{step_state: %{model_state: model_state}, metrics: %{9 => last_epoch_metrics}} = + results + + eval_results = + model + |> Axon.Loop.evaluator() + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.run(train, model_state) + + assert %{0 => %{"accuracy" => final_model_val_accuracy}} = eval_results + + assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.60) + assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"]) + assert 
Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2} + assert Nx.type(model_state["dense_0"]["kernel"]) == unquote(Macro.escape(policy)).params + end) + end + + test "trains model with batch norm with policy #{name}, scale #{inspect(scale)}" do + {train, _test} = get_test_data(100, 0, 10, {10}, 2, 1337) + + train = + train + |> Stream.map(fn {xs, ys} -> + {xs, one_hot(ys, num_classes: 2)} + end) + |> Enum.to_list() + + [{x_test, _}] = Enum.take(train, 1) + + model = + Axon.input("input") + |> Axon.dense(16) + |> Axon.batch_norm() + |> Axon.dropout(rate: 0.1) + |> Axon.dense(2, activation: :softmax) + |> Axon.MixedPrecision.apply_policy( + unquote(Macro.escape(policy)), + except: [:batch_norm] + ) + + ExUnit.CaptureIO.capture_io(fn -> + results = + model + |> Axon.Loop.trainer(:categorical_cross_entropy, :adam, loss_scale: unquote(scale)) + # TODO: Fix default output transform + |> Map.update(:output_transform, nil, fn _ -> & &1 end) + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.validate(model, train) + |> Axon.Loop.run(train, %{}, epochs: 10) + + assert %{step_state: %{model_state: model_state}, metrics: %{9 => last_epoch_metrics}} = + results + + eval_results = + model + |> Axon.Loop.evaluator() + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.run(train, model_state) + + assert %{0 => %{"accuracy" => final_model_val_accuracy}} = eval_results + + assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.60) + assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"]) + assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2} + assert Nx.type(model_state["dense_0"]["kernel"]) == unquote(Macro.escape(policy)).params + end) + end + end + + test "mixed precision downcasts model when state is given to train" do + policy = + Axon.MixedPrecision.create_policy( + params: {:f, 16}, + compute: {:f, 16}, + output: {:f, 16} + ) + + {train, _test} = get_test_data(100, 0, 10, {10}, 2, 1337) + + train = + train + |> Stream.map(fn {xs, ys} -> + {xs, one_hot(ys, num_classes: 2)} + end) + |> Enum.to_list() + + [{x_test, _}] = Enum.take(train, 1) + + model = + Axon.input("input") + |> Axon.dense(16) + |> Axon.dropout(rate: 0.1) + |> Axon.dense(2, activation: :softmax) + + {init_fn, _} = Axon.build(model) + initial_state = init_fn.(Nx.template({1, 10}, :f32), %{}) + + mp_model = Axon.MixedPrecision.apply_policy(model, policy) + + ExUnit.CaptureIO.capture_io(fn -> + results = + mp_model + |> Axon.Loop.trainer(:categorical_cross_entropy, :adam, loss_scale: :dynamic) + # TODO: Fix default output transform + |> Map.update(:output_transform, nil, fn _ -> & &1 end) + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.validate(model, train) + |> Axon.Loop.run(train, initial_state, epochs: 10) + + assert %{step_state: %{model_state: model_state}, metrics: %{9 => last_epoch_metrics}} = + results + + eval_results = + model + |> Axon.Loop.evaluator() + |> Axon.Loop.metric(:accuracy) + |> Axon.Loop.run(train, model_state) + + assert %{0 => %{"accuracy" => final_model_val_accuracy}} = eval_results + + assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.60) + assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"]) + assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2} + assert Nx.type(model_state["dense_0"]["kernel"]) == policy.params + end) + end + end end diff --git a/test/axon/layers_test.exs b/test/axon/layers_test.exs index cc1fb9fc..37d930b9 100644 --- a/test/axon/layers_test.exs +++ 
b/test/axon/layers_test.exs @@ -186,9 +186,9 @@ defmodule Axon.LayersTest do describe "conv" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) - kernel = Nx.random_uniform({3, 1, 4, 4}) + kernel = random({3, 1, 4, 4}) t_kernel = Nx.transpose(kernel, axes: [2, 3, 1, 0]) bias = Nx.tensor(0.0) @@ -198,6 +198,19 @@ defmodule Axon.LayersTest do assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) end + test "channels last same as channels first with strides" do + input = random({1, 1, 28, 28}) + t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) + kernel = random({3, 1, 4, 4}) + t_kernel = Nx.transpose(kernel, axes: [2, 3, 1, 0]) + bias = Nx.tensor(0.0) + + first = Axon.Layers.conv(input, kernel, bias, channels: :first, strides: [1, 2]) + last = Axon.Layers.conv(t_input, t_kernel, bias, channels: :last, strides: [1, 2]) + + assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) + end + test "raises on input rank less than 3" do inp = Nx.iota({1, 1}) kernel = Nx.iota({2, 1, 1}) @@ -225,9 +238,9 @@ defmodule Axon.LayersTest do describe "conv_transpose" do test "channels first same as channels last" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) - kernel = Nx.random_uniform({3, 1, 4, 4}) + kernel = random({3, 1, 4, 4}) t_kernel = Nx.transpose(kernel, axes: [2, 3, 1, 0]) bias = Nx.tensor(0.0) @@ -237,6 +250,19 @@ defmodule Axon.LayersTest do assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) end + test "channels first same as channels last with strides" do + input = random({1, 1, 28, 28}) + t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) + kernel = random({3, 1, 4, 4}) + t_kernel = Nx.transpose(kernel, axes: [2, 3, 1, 0]) + bias = Nx.tensor(0.0) + + first = Axon.Layers.conv_transpose(input, kernel, bias, channels: :first, strides: [1, 2]) + last = Axon.Layers.conv_transpose(t_input, t_kernel, bias, channels: :last, strides: [1, 2]) + + assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) + end + test "correct valid padding, no strides" do inp = Nx.iota({1, 1, 4}, type: {:f, 32}) kernel = Nx.iota({3, 1, 2}, type: {:f, 32}) @@ -422,9 +448,9 @@ defmodule Axon.LayersTest do describe "depthwise conv" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 3, 28, 28}) + input = random({1, 3, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) - kernel = Nx.random_uniform({6, 1, 4, 4}) + kernel = random({6, 1, 4, 4}) t_kernel = Nx.transpose(kernel, axes: [2, 3, 1, 0]) bias = Nx.tensor(0.0) @@ -461,11 +487,11 @@ defmodule Axon.LayersTest do describe "separable_conv2d" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 3, 28, 28}) + input = random({1, 3, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) - k1 = Nx.random_uniform({6, 1, 4, 1}) + k1 = random({6, 1, 4, 1}) t_k1 = Nx.transpose(k1, axes: [2, 3, 1, 0]) - k2 = Nx.random_uniform({6, 1, 1, 4}) + k2 = random({6, 1, 1, 4}) t_k2 = Nx.transpose(k2, axes: [2, 3, 1, 0]) b1 = Nx.tensor(0.0) b2 = Nx.tensor(0.0) @@ -504,13 +530,13 @@ defmodule Axon.LayersTest do describe "separable_conv3d" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 3, 8, 8, 8}) + input = random({1, 3, 8, 8, 8}) t_input = Nx.transpose(input, axes: [0, 2, 3, 4, 1]) - k1 = Nx.random_uniform({6, 1, 4, 1, 1}) + k1 = random({6, 1, 4, 1, 1}) t_k1 = 
Nx.transpose(k1, axes: [2, 3, 4, 1, 0]) - k2 = Nx.random_uniform({6, 1, 1, 4, 1}) + k2 = random({6, 1, 1, 4, 1}) t_k2 = Nx.transpose(k2, axes: [2, 3, 4, 1, 0]) - k3 = Nx.random_uniform({6, 1, 1, 1, 4}) + k3 = random({6, 1, 1, 1, 4}) t_k3 = Nx.transpose(k3, axes: [2, 3, 4, 1, 0]) b1 = b2 = b3 = Nx.tensor(0.0) @@ -549,7 +575,7 @@ defmodule Axon.LayersTest do describe "max_pool" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = Axon.Layers.max_pool(input, kernel_size: {2, 2}, channels: :first) @@ -559,7 +585,7 @@ defmodule Axon.LayersTest do end test "channels last same as channels first with dilation" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = @@ -579,6 +605,27 @@ defmodule Axon.LayersTest do assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) end + test "channels last same as channels first with custom padding" do + input = random({1, 1, 28, 28}) + t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) + + first = + Axon.Layers.max_pool(input, + kernel_size: {2, 2}, + channels: :first, + padding: [{2, 2}, {1, 2}] + ) + + last = + Axon.Layers.max_pool(t_input, + kernel_size: {2, 2}, + channels: :last, + padding: [{2, 2}, {1, 2}] + ) + + assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) + end + test "raises on input rank less than 3" do inp = Nx.iota({1, 1}) @@ -592,7 +639,7 @@ defmodule Axon.LayersTest do describe "avg_pool" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = Axon.Layers.avg_pool(input, kernel_size: {2, 2}, channels: :first) @@ -602,7 +649,7 @@ defmodule Axon.LayersTest do end test "channels last same as channels first with dilation" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = @@ -622,6 +669,27 @@ defmodule Axon.LayersTest do assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) end + test "channels last same as channels first with custom padding" do + input = random({1, 1, 28, 28}) + t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) + + first = + Axon.Layers.max_pool(input, + kernel_size: {2, 2}, + channels: :first, + padding: [{2, 2}, {1, 2}] + ) + + last = + Axon.Layers.max_pool(t_input, + kernel_size: {2, 2}, + channels: :last, + padding: [{2, 2}, {1, 2}] + ) + + assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) + end + test "raises on input rank less than 3" do inp = Nx.iota({1, 1}) @@ -635,7 +703,7 @@ defmodule Axon.LayersTest do describe "lp pool" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = Axon.Layers.lp_pool(input, kernel_size: {2, 2}, channels: :first) @@ -645,7 +713,7 @@ defmodule Axon.LayersTest do end test "channels last same as channels first with dilation" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = @@ -665,6 +733,27 @@ defmodule Axon.LayersTest do assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) end + test "channels last same as channels first with custom padding" do + input = random({1, 1, 28, 28}) + t_input = Nx.transpose(input, axes: [0, 
2, 3, 1]) + + first = + Axon.Layers.max_pool(input, + kernel_size: {2, 2}, + channels: :first, + padding: [{2, 2}, {1, 2}] + ) + + last = + Axon.Layers.max_pool(t_input, + kernel_size: {2, 2}, + channels: :last, + padding: [{2, 2}, {1, 2}] + ) + + assert_equal(first, Nx.transpose(last, axes: [0, 3, 1, 2])) + end + test "raises on input rank less than 3" do inp = Nx.iota({1, 1}) @@ -678,7 +767,7 @@ defmodule Axon.LayersTest do describe "adaptive avg pool" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = Axon.Layers.adaptive_avg_pool(input, output_size: {25, 25}, channels: :first) @@ -700,7 +789,7 @@ defmodule Axon.LayersTest do describe "adaptive max pool" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = Axon.Layers.adaptive_max_pool(input, output_size: {25, 25}, channels: :first) @@ -722,7 +811,7 @@ defmodule Axon.LayersTest do describe "adaptive lp pool" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = Axon.Layers.adaptive_lp_pool(input, output_size: {25, 25}, channels: :first) @@ -768,7 +857,7 @@ defmodule Axon.LayersTest do describe "global_max_pool" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = Axon.Layers.global_max_pool(input, channels: :first) @@ -790,7 +879,7 @@ defmodule Axon.LayersTest do describe "global_avg_pool" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = Axon.Layers.global_avg_pool(input, channels: :first) @@ -812,7 +901,7 @@ defmodule Axon.LayersTest do describe "global_lp_pool" do test "channels last same as channels first" do - input = Nx.random_uniform({1, 1, 28, 28}) + input = random({1, 1, 28, 28}) t_input = Nx.transpose(input, axes: [0, 2, 3, 1]) first = Axon.Layers.global_lp_pool(input, channels: :first) @@ -857,10 +946,109 @@ defmodule Axon.LayersTest do end end + describe "lstm_cell" do + test "cell function matches results expected from pytorch" do + seq = + File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_input_seq.npy") + |> Nx.load_numpy!() + + c = + File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_input_c.npy") |> Nx.load_numpy!() + + h = + File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_input_h.npy") |> Nx.load_numpy!() + + wii = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_wii.npy") |> Nx.load_numpy!() + wif = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_wif.npy") |> Nx.load_numpy!() + wig = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_wig.npy") |> Nx.load_numpy!() + wio = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_wio.npy") |> Nx.load_numpy!() + whi = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_whi.npy") |> Nx.load_numpy!() + whf = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_whf.npy") |> Nx.load_numpy!() + whg = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_whg.npy") |> Nx.load_numpy!() + who = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_who.npy") |> Nx.load_numpy!() + bi 
= File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_bi.npy") |> Nx.load_numpy!() + bf = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_bf.npy") |> Nx.load_numpy!() + bg = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_bg.npy") |> Nx.load_numpy!() + bo = File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_bo.npy") |> Nx.load_numpy!() + + expected_c = + File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_output_c.npy") |> Nx.load_numpy!() + + expected_h = + File.read!("test/fixtures/lstm_cell_test/test_lstm_cell_output_h.npy") |> Nx.load_numpy!() + + {_, {new_c, new_h}} = + Axon.Layers.lstm_cell( + seq, + {c, h}, + Nx.tensor(0), + {wii, wif, wig, wio}, + {whi, whf, whg, who}, + {bi, bf, bg, bo} + ) + + assert_all_close(new_c, expected_c) + assert_all_close(new_h, expected_h) + end + end + + describe "lstm" do + test "matches results expected from pytorch with dynamic unroll" do + seq = File.read!("test/fixtures/lstm_test/test_lstm_input_seq.npy") |> Nx.load_numpy!() + + c = + File.read!("test/fixtures/lstm_test/test_lstm_input_c.npy") + |> Nx.load_numpy!() + |> Nx.squeeze() + + h = + File.read!("test/fixtures/lstm_test/test_lstm_input_h.npy") + |> Nx.load_numpy!() + |> Nx.squeeze() + + wii = File.read!("test/fixtures/lstm_test/test_lstm_wii.npy") |> Nx.load_numpy!() + wif = File.read!("test/fixtures/lstm_test/test_lstm_wif.npy") |> Nx.load_numpy!() + wig = File.read!("test/fixtures/lstm_test/test_lstm_wig.npy") |> Nx.load_numpy!() + wio = File.read!("test/fixtures/lstm_test/test_lstm_wio.npy") |> Nx.load_numpy!() + whi = File.read!("test/fixtures/lstm_test/test_lstm_whi.npy") |> Nx.load_numpy!() + whf = File.read!("test/fixtures/lstm_test/test_lstm_whf.npy") |> Nx.load_numpy!() + whg = File.read!("test/fixtures/lstm_test/test_lstm_whg.npy") |> Nx.load_numpy!() + who = File.read!("test/fixtures/lstm_test/test_lstm_who.npy") |> Nx.load_numpy!() + bi = File.read!("test/fixtures/lstm_test/test_lstm_bi.npy") |> Nx.load_numpy!() + bf = File.read!("test/fixtures/lstm_test/test_lstm_bf.npy") |> Nx.load_numpy!() + bg = File.read!("test/fixtures/lstm_test/test_lstm_bg.npy") |> Nx.load_numpy!() + bo = File.read!("test/fixtures/lstm_test/test_lstm_bo.npy") |> Nx.load_numpy!() + + expected_seq = + File.read!("test/fixtures/lstm_test/test_lstm_output_seq.npy") |> Nx.load_numpy!() + + expected_c = + File.read!("test/fixtures/lstm_test/test_lstm_output_c.npy") |> Nx.load_numpy!() + + expected_h = + File.read!("test/fixtures/lstm_test/test_lstm_output_h.npy") |> Nx.load_numpy!() + + {new_seq, {new_c, new_h}} = + Axon.Layers.lstm( + seq, + {c, h}, + Nx.tensor(0), + {wii, wif, wig, wio}, + {whi, whf, whg, who}, + {bi, bf, bg, bo}, + unroll: :dynamic + ) + + assert_all_close(new_seq, expected_seq, atol: 1.0e-3) + assert_all_close(new_c, expected_c, atol: 1.0e-3) + assert_all_close(new_h, expected_h, atol: 1.0e-3) + end + end + describe "dynamic_unroll" do test "computes carry and output identical to static_unroll" do input = Nx.iota({1, 4, 2}, type: {:f, 32}) - carry = {Nx.iota({1, 1, 8}, type: {:f, 32})} + carry = {Nx.iota({1, 8}, type: {:f, 32})} input_kernel = {Nx.iota({2, 8}, type: {:f, 32}), Nx.iota({2, 8}, type: {:f, 32}), @@ -874,13 +1062,29 @@ defmodule Axon.LayersTest do {Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32})} - cell_fn = &Axon.Layers.gru_cell/5 + cell_fn = &Axon.Layers.gru_cell/6 {s_output, {s_carry}} = - Axon.Layers.static_unroll(cell_fn, input, carry, input_kernel, hidden_kernel, bias) + 
Axon.Layers.static_unroll( + cell_fn, + input, + carry, + Nx.tensor(0), + input_kernel, + hidden_kernel, + bias + ) {d_output, {d_carry}} = - Axon.Layers.dynamic_unroll(cell_fn, input, carry, input_kernel, hidden_kernel, bias) + Axon.Layers.dynamic_unroll( + cell_fn, + input, + carry, + Nx.tensor(0), + input_kernel, + hidden_kernel, + bias + ) assert_equal(s_carry, d_carry) assert_equal(s_output, d_output) @@ -888,7 +1092,8 @@ defmodule Axon.LayersTest do defn grad_static_hidden_output(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(hidden_kernel, fn x -> - {output, _} = Axon.Layers.static_unroll(cell_fn, input, carry, input_kernel, x, bias) + {output, _} = + Axon.Layers.static_unroll(cell_fn, input, carry, Nx.tensor(0), input_kernel, x, bias) Nx.mean(output) end) @@ -896,7 +1101,8 @@ defmodule Axon.LayersTest do defn grad_dynamic_hidden_output(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(hidden_kernel, fn x -> - {output, _} = Axon.Layers.dynamic_unroll(cell_fn, input, carry, input_kernel, x, bias) + {output, _} = + Axon.Layers.dynamic_unroll(cell_fn, input, carry, Nx.tensor(0), input_kernel, x, bias) Nx.mean(output) end) @@ -904,7 +1110,7 @@ defmodule Axon.LayersTest do test "computes gradient identical to static unroll for hidden kernel w.r.t. output" do input = Nx.iota({1, 4, 2}, type: {:f, 32}) - carry = {Nx.iota({1, 1, 8}, type: {:f, 32})} + carry = {Nx.iota({1, 8}, type: {:f, 32})} input_kernel = {Nx.iota({2, 8}, type: {:f, 32}), Nx.iota({2, 8}, type: {:f, 32}), @@ -918,7 +1124,7 @@ defmodule Axon.LayersTest do {Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32})} - cell_fn = &Axon.Layers.gru_cell/5 + cell_fn = &Axon.Layers.gru_cell/6 assert_equal( grad_static_hidden_output(input, carry, input_kernel, hidden_kernel, bias, cell_fn), @@ -935,7 +1141,16 @@ defmodule Axon.LayersTest do defn grad_static_hidden_carry(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(hidden_kernel, fn x -> - {_, {carry}} = Axon.Layers.static_unroll(cell_fn, input, carry, input_kernel, x, bias) + {_, {carry}} = + Axon.Layers.static_unroll( + cell_fn, + input, + carry, + Nx.tensor([[0, 0, 0, 1]]), + input_kernel, + x, + bias + ) Nx.mean(carry) end) @@ -943,7 +1158,16 @@ defmodule Axon.LayersTest do defn grad_dynamic_hidden_carry(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(hidden_kernel, fn x -> - {_, {carry}} = Axon.Layers.dynamic_unroll(cell_fn, input, carry, input_kernel, x, bias) + {_, {carry}} = + Axon.Layers.dynamic_unroll( + cell_fn, + input, + carry, + Nx.tensor([[0, 0, 0, 1]]), + input_kernel, + x, + bias + ) Nx.mean(carry) end) @@ -951,7 +1175,7 @@ defmodule Axon.LayersTest do test "computes gradient identical to static_unroll for hidden kernel w.r.t carry" do input = Nx.iota({1, 4, 2}, type: {:f, 32}) - carry = {Nx.iota({1, 1, 8}, type: {:f, 32})} + carry = {Nx.iota({1, 8}, type: {:f, 32})} input_kernel = {Nx.iota({2, 8}, type: {:f, 32}), Nx.iota({2, 8}, type: {:f, 32}), @@ -965,7 +1189,7 @@ defmodule Axon.LayersTest do {Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32})} - cell_fn = &Axon.Layers.gru_cell/5 + cell_fn = &Axon.Layers.gru_cell/6 assert_equal( grad_static_hidden_carry(input, carry, input_kernel, hidden_kernel, bias, cell_fn), @@ -975,7 +1199,8 @@ defmodule Axon.LayersTest do defn grad_static_input_output(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(input_kernel, fn 
x -> - {output, _} = Axon.Layers.static_unroll(cell_fn, input, carry, x, hidden_kernel, bias) + {output, _} = + Axon.Layers.static_unroll(cell_fn, input, carry, Nx.tensor(0), x, hidden_kernel, bias) Nx.mean(output) end) @@ -983,7 +1208,8 @@ defmodule Axon.LayersTest do defn grad_dynamic_input_output(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(input_kernel, fn x -> - {output, _} = Axon.Layers.dynamic_unroll(cell_fn, input, carry, x, hidden_kernel, bias) + {output, _} = + Axon.Layers.dynamic_unroll(cell_fn, input, carry, Nx.tensor(0), x, hidden_kernel, bias) Nx.mean(output) end) @@ -991,7 +1217,7 @@ defmodule Axon.LayersTest do test "computes gradient identical to static unroll for input kernel w.r.t. output" do input = Nx.iota({1, 4, 2}, type: {:f, 32}) - carry = {Nx.iota({1, 1, 8}, type: {:f, 32})} + carry = {Nx.iota({1, 8}, type: {:f, 32})} input_kernel = {Nx.iota({2, 8}, type: {:f, 32}), Nx.iota({2, 8}, type: {:f, 32}), @@ -1005,7 +1231,7 @@ defmodule Axon.LayersTest do {Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32})} - cell_fn = &Axon.Layers.gru_cell/5 + cell_fn = &Axon.Layers.gru_cell/6 assert_equal( grad_static_input_output(input, carry, input_kernel, hidden_kernel, bias, cell_fn), @@ -1015,7 +1241,16 @@ defmodule Axon.LayersTest do defn grad_static_input_carry(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(input_kernel, fn x -> - {_, {carry}} = Axon.Layers.static_unroll(cell_fn, input, carry, x, hidden_kernel, bias) + {_, {carry}} = + Axon.Layers.static_unroll( + cell_fn, + input, + carry, + Nx.tensor([[0, 0, 0, 1]]), + x, + hidden_kernel, + bias + ) Nx.mean(carry) end) @@ -1023,7 +1258,16 @@ defmodule Axon.LayersTest do defn grad_dynamic_input_carry(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(input_kernel, fn x -> - {_, {carry}} = Axon.Layers.dynamic_unroll(cell_fn, input, carry, x, hidden_kernel, bias) + {_, {carry}} = + Axon.Layers.dynamic_unroll( + cell_fn, + input, + carry, + Nx.tensor([[0, 0, 0, 1]]), + x, + hidden_kernel, + bias + ) Nx.mean(carry) end) @@ -1031,7 +1275,7 @@ defmodule Axon.LayersTest do test "computes gradient identical to static unroll for input kernel w.r.t. 
carry" do input = Nx.iota({1, 4, 2}, type: {:f, 32}) - carry = {Nx.iota({1, 1, 8}, type: {:f, 32})} + carry = {Nx.iota({1, 8}, type: {:f, 32})} input_kernel = {Nx.iota({2, 8}, type: {:f, 32}), Nx.iota({2, 8}, type: {:f, 32}), @@ -1045,7 +1289,7 @@ defmodule Axon.LayersTest do {Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32})} - cell_fn = &Axon.Layers.gru_cell/5 + cell_fn = &Axon.Layers.gru_cell/6 assert_equal( grad_static_input_carry(input, carry, input_kernel, hidden_kernel, bias, cell_fn), @@ -1056,7 +1300,15 @@ defmodule Axon.LayersTest do defn grad_static_bias_output(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(bias, fn x -> {output, _} = - Axon.Layers.static_unroll(cell_fn, input, carry, input_kernel, hidden_kernel, x) + Axon.Layers.static_unroll( + cell_fn, + input, + carry, + Nx.tensor([[0, 0, 0, 1]]), + input_kernel, + hidden_kernel, + x + ) Nx.mean(output) end) @@ -1065,7 +1317,15 @@ defmodule Axon.LayersTest do defn grad_dynamic_bias_output(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(bias, fn x -> {output, _} = - Axon.Layers.dynamic_unroll(cell_fn, input, carry, input_kernel, hidden_kernel, x) + Axon.Layers.dynamic_unroll( + cell_fn, + input, + carry, + Nx.tensor([[0, 0, 0, 1]]), + input_kernel, + hidden_kernel, + x + ) Nx.mean(output) end) @@ -1073,7 +1333,7 @@ defmodule Axon.LayersTest do test "computes gradient identical to static unroll for bias w.r.t. output" do input = Nx.iota({1, 4, 2}, type: {:f, 32}) - carry = {Nx.iota({1, 1, 8}, type: {:f, 32})} + carry = {Nx.iota({1, 8}, type: {:f, 32})} input_kernel = {Nx.iota({2, 8}, type: {:f, 32}), Nx.iota({2, 8}, type: {:f, 32}), @@ -1087,7 +1347,7 @@ defmodule Axon.LayersTest do {Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32})} - cell_fn = &Axon.Layers.gru_cell/5 + cell_fn = &Axon.Layers.gru_cell/6 assert_equal( grad_static_bias_output(input, carry, input_kernel, hidden_kernel, bias, cell_fn), @@ -1098,7 +1358,15 @@ defmodule Axon.LayersTest do defn grad_static_bias_carry(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(bias, fn x -> {_, {carry}} = - Axon.Layers.static_unroll(cell_fn, input, carry, input_kernel, hidden_kernel, x) + Axon.Layers.static_unroll( + cell_fn, + input, + carry, + Nx.tensor([[0, 0, 0, 1]]), + input_kernel, + hidden_kernel, + x + ) Nx.mean(carry) end) @@ -1107,7 +1375,15 @@ defmodule Axon.LayersTest do defn grad_dynamic_bias_carry(input, carry, input_kernel, hidden_kernel, bias, cell_fn) do grad(bias, fn x -> {_, {carry}} = - Axon.Layers.dynamic_unroll(cell_fn, input, carry, input_kernel, hidden_kernel, x) + Axon.Layers.dynamic_unroll( + cell_fn, + input, + carry, + Nx.tensor([[0, 0, 0, 1]]), + input_kernel, + hidden_kernel, + x + ) Nx.mean(carry) end) @@ -1115,7 +1391,7 @@ defmodule Axon.LayersTest do test "computes gradient identical to static unroll for bias w.r.t. 
carry" do input = Nx.iota({1, 4, 2}, type: {:f, 32}) - carry = {Nx.iota({1, 1, 8}, type: {:f, 32})} + carry = {Nx.iota({1, 8}, type: {:f, 32})} input_kernel = {Nx.iota({2, 8}, type: {:f, 32}), Nx.iota({2, 8}, type: {:f, 32}), @@ -1129,7 +1405,7 @@ defmodule Axon.LayersTest do {Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32}), Nx.iota({}, type: {:f, 32})} - cell_fn = &Axon.Layers.gru_cell/5 + cell_fn = &Axon.Layers.gru_cell/6 assert_equal( grad_static_bias_carry(input, carry, input_kernel, hidden_kernel, bias, cell_fn), @@ -1137,4 +1413,132 @@ defmodule Axon.LayersTest do ) end end + + describe "group_norm" do + test "matches pytorch" do + a = + Nx.tensor([ + [ + 0.8423, + 1.9226, + -1.1295, + -1.3154, + 1.2963, + -0.6821, + -0.0519, + 0.6875, + -0.0313, + -0.3328, + -0.2821, + -2.3289, + -1.7641, + -1.3184, + -0.0890, + 0.0625 + ], + [ + -1.0853, + 0.8060, + -0.1397, + -0.2169, + 0.9605, + 0.3947, + 0.4760, + 0.8097, + 0.0380, + -0.6314, + 0.5761, + 1.9309, + 0.5038, + -0.1892, + 1.8476, + 0.0517 + ] + ]) + + b = + Nx.tensor([ + -0.3101, + -1.5896, + -1.4963, + 0.1278, + -1.4580, + 1.3832, + 0.5709, + 0.5531, + -0.0588, + 1.0411, + 1.3503, + -1.2166, + 0.7133, + 0.0694, + 0.3150, + -0.1306 + ]) + + c = + Nx.tensor([ + 1.6585, + 2.3515, + -1.3456, + 0.2376, + -0.1333, + 0.5068, + 0.2441, + 1.0382, + 0.6879, + -0.5402, + -1.8304, + -0.8906, + -0.5329, + -0.3390, + -0.1877, + 0.1405 + ]) + + expected = + Nx.tensor([ + [ + 1.4768, + -0.1375, + 0.4536, + 0.0623, + -1.5881, + -0.5951, + 0.1157, + 1.2847, + 0.6378, + -0.0194, + -1.0751, + 1.3407, + -1.3700, + -0.3844, + 0.0597, + 0.0149 + ], + [ + 2.2986, + 0.9877, + -0.4434, + 0.1453, + -1.7321, + 0.8146, + 0.4430, + 1.5159, + 0.7202, + -1.9153, + -1.7368, + -2.8723, + -0.5429, + -0.3954, + 0.2952, + 0.2103 + ] + ]) + + actual = Axon.Layers.group_norm(a, b, c, num_groups: 2) + + assert_all_close(expected, actual, atol: 1.0e-3) + end + end end diff --git a/test/axon/loop_test.exs b/test/axon/loop_test.exs index 0c9f4c6a..05e8f0bb 100644 --- a/test/axon/loop_test.exs +++ b/test/axon/loop_test.exs @@ -34,7 +34,7 @@ defmodule Axon.LoopTest do ] valid_axon_optimizers = - Axon.Optimizers.__info__(:functions) + Polaris.Optimizers.__info__(:functions) |> Enum.map(fn {k, _} -> k end) |> Enum.uniq() @@ -82,7 +82,7 @@ defmodule Axon.LoopTest do test "trainer/3 returns a supervised training loop with custom optimizer" do model = Axon.input("input", shape: {nil, 1}) - optimizer = Axon.Optimizers.rmsprop(1.0e-3) + optimizer = Polaris.Optimizers.rmsprop(learning_rate: 1.0e-3) assert %Loop{init: init_fn, step: update_fn, output_transform: transform} = Loop.trainer(model, :mean_squared_error, optimizer) @@ -203,7 +203,7 @@ defmodule Axon.LoopTest do end) =~ "Batch" end - test "eval_step/1 evalutes model on a single batch" do + test "eval_step/1 evaluates model on a single batch" do inp = Nx.tensor([0, 1, 0, 1, 0, 1]) |> Nx.new_axis(-1) tar = Nx.tensor([1, 0, 1, 0, 1, 0]) |> Nx.new_axis(-1) @@ -360,7 +360,7 @@ defmodule Axon.LoopTest do Axon.input("input", shape: {nil, 1}) |> Axon.dense(1) |> Loop.trainer(:binary_cross_entropy, :sgd, log: 0) - |> Loop.handle( + |> Loop.handle_event( :epoch_completed, fn %State{step_state: pstate} = state -> { @@ -376,7 +376,7 @@ defmodule Axon.LoopTest do } end ) - |> Loop.handle( + |> Loop.handle_event( :completed, fn %State{step_state: %{counter: counter}} = state -> assert 4 = counter @@ -396,7 +396,7 @@ defmodule Axon.LoopTest do Axon.input("input", shape: {nil, 1}) |> Axon.dense(1) |> 
Loop.trainer(:binary_cross_entropy, :sgd, log: 0) - |> Loop.handle( + |> Loop.handle_event( :epoch_completed, fn %State{step_state: pstate} = state -> { @@ -416,7 +416,7 @@ defmodule Axon.LoopTest do } end ) - |> Loop.handle( + |> Loop.handle_event( :completed, fn %State{step_state: %{counter: counter}} = state -> assert {{4}, 4} = counter @@ -477,7 +477,7 @@ defmodule Axon.LoopTest do end def send_handler(loop, event) do - Axon.Loop.handle(loop, event, fn state -> + Axon.Loop.handle_event(loop, event, fn state -> send(self(), event) {:continue, state} end) @@ -540,15 +540,6 @@ defmodule Axon.LoopTest do refute_received :iteration_completed end - test "fires correctly on :completed" do - ExUnit.CaptureIO.capture_io(fn -> - run_dummy_loop!(:completed, 5, 10) - end) - - assert_received :completed - refute_received :completed - end - test "fires correctly on :epoch_halted" do model = Axon.input("foo") @@ -562,7 +553,7 @@ defmodule Axon.LoopTest do ExUnit.CaptureIO.capture_io(fn -> model |> Axon.Loop.trainer(:binary_cross_entropy, :sgd) - |> Axon.Loop.handle(:iteration_completed, fn state -> + |> Axon.Loop.handle_event(:iteration_completed, fn state -> {:halt_epoch, state} end) |> send_handler(:epoch_halted) @@ -576,30 +567,6 @@ defmodule Axon.LoopTest do refute_received :epoch_halted end - test "fires correctly on :halted" do - model = Axon.input("foo") - - data = - Stream.repeatedly(fn -> - xs = Nx.tensor([[Enum.random(0..10)]]) - ys = Nx.greater(xs, 5) - {xs, ys} - end) - - ExUnit.CaptureIO.capture_io(fn -> - model - |> Axon.Loop.trainer(:binary_cross_entropy, :sgd) - |> Axon.Loop.handle(:iteration_completed, fn state -> - {:halt_loop, state} - end) - |> send_handler(:halted) - |> Axon.Loop.run(data, %{}, epochs: 5, iterations: 10) - end) - - assert_received :halted - refute_received :halted - end - test "events fire in order" do model = Axon.input("foo") @@ -618,7 +585,6 @@ defmodule Axon.LoopTest do |> send_handler(:iteration_started) |> send_handler(:iteration_completed) |> send_handler(:epoch_completed) - |> send_handler(:completed) |> Axon.Loop.run(data, %{}, epochs: 1, iterations: 1) end) @@ -627,7 +593,6 @@ defmodule Axon.LoopTest do assert_received :iteration_started assert_received :iteration_completed assert_received :epoch_completed - assert_received :completed refute_received _ end @@ -651,7 +616,7 @@ defmodule Axon.LoopTest do end def send_handler(loop, event, filter) do - Axon.Loop.handle( + Axon.Loop.handle_event( loop, event, fn state -> @@ -770,7 +735,7 @@ defmodule Axon.LoopTest do describe "serialization" do test "serialize_state/deserialize_state preserve loop state" do model = Axon.input("input", shape: {nil, 1}) |> Axon.dense(2) - optimizer = Axon.Optimizers.adam(1.0e-2) + optimizer = Polaris.Optimizers.adam(learning_rate: 1.0e-2) loss = :binary_cross_entropy {init_fn, _} = Axon.Loop.train_step(model, loss, optimizer) @@ -813,7 +778,7 @@ defmodule Axon.LoopTest do [loop: loop] end - test "saves a ceckpoint on each epoch", %{loop: loop} do + test "saves a checkpoint on each epoch", %{loop: loop} do loop |> Loop.checkpoint() |> Loop.run([{Nx.tensor([[1]]), Nx.tensor([[2]])}], %{}, epochs: 3) @@ -822,6 +787,28 @@ defmodule Axon.LoopTest do File.ls!("checkpoint") |> Enum.sort() end + test "saves a checkpoint on custom events", %{loop: loop} do + data = List.duplicate({Nx.iota({1, 1}), Nx.iota({1, 1})}, 5) + + assert %Axon.Loop.State{epoch: 3, iteration: 0, event_counts: %{iteration_completed: 15}} = + loop + |> Map.put(:output_transform, & &1) + |> Loop.checkpoint(event: 
:iteration_completed, filter: [every: 2]) + |> Loop.run(data, %{}, epochs: 3) + + assert [ + "checkpoint_0_0.ckpt", + "checkpoint_0_2.ckpt", + "checkpoint_0_4.ckpt", + "checkpoint_1_1.ckpt", + "checkpoint_1_3.ckpt", + "checkpoint_2_0.ckpt", + "checkpoint_2_2.ckpt", + "checkpoint_2_4.ckpt" + ] == + File.ls!("checkpoint") |> Enum.sort() + end + test "uses the custom file_pattern function", %{loop: loop} do loop |> Loop.checkpoint(file_pattern: &"ckp_#{&1.epoch}.ckpt") @@ -863,7 +850,7 @@ defmodule Axon.LoopTest do model |> Axon.Loop.trainer(:binary_cross_entropy, :sgd) |> Axon.Loop.from_state(state1) - |> Axon.Loop.handle(:epoch_completed, fn %{epoch: epoch} = state -> + |> Axon.Loop.handle_event(:epoch_completed, fn %{epoch: epoch} = state -> assert epoch >= 3 {:continue, state} end) @@ -888,7 +875,7 @@ defmodule Axon.LoopTest do |> Axon.Loop.trainer(:binary_cross_entropy, :sgd) |> Axon.Loop.metric(:accuracy) |> Axon.Loop.validate(model, Enum.take(data, 5)) - |> Axon.Loop.handle( + |> Axon.Loop.handle_event( :epoch_completed, fn %{metrics: metrics} = state -> assert Map.has_key?(metrics, "validation_accuracy") @@ -918,7 +905,7 @@ defmodule Axon.LoopTest do |> Axon.Loop.metric(:accuracy) |> Axon.Loop.validate(model, Enum.take(data, 5)) |> Axon.Loop.early_stop("validation_accuracy", mode: :max) - |> Axon.Loop.handle( + |> Axon.Loop.handle_event( :epoch_completed, fn %{handler_metadata: meta} = state -> assert %{early_stop: %{"validation_accuracy" => _, :since_last_improvement => _}} = @@ -1006,7 +993,7 @@ defmodule Axon.LoopTest do |> Axon.Loop.metric(:accuracy) |> Axon.Loop.validate(model, Enum.take(data, 5)) |> Axon.Loop.reduce_lr_on_plateau("validation_accuracy", mode: :max) - |> Axon.Loop.handle( + |> Axon.Loop.handle_event( :epoch_completed, fn %{handler_metadata: meta} = state -> assert %{reduce_lr: %{"validation_accuracy" => _, :since_last_improvement => _}} = @@ -1039,7 +1026,10 @@ defmodule Axon.LoopTest do ExUnit.CaptureIO.capture_io(fn -> state = model - |> Axon.Loop.trainer(:binary_cross_entropy, Axon.Optimizers.sgd(initial_lr)) + |> Axon.Loop.trainer( + :binary_cross_entropy, + Polaris.Optimizers.sgd(learning_rate: initial_lr) + ) |> Axon.Loop.metric(my_metric, "counter", :running_sum) |> Axon.Loop.reduce_lr_on_plateau("counter", factor: 0.5, mode: :min, patience: 2) # TODO: This API needs to change @@ -1072,7 +1062,10 @@ defmodule Axon.LoopTest do ExUnit.CaptureIO.capture_io(fn -> state = model - |> Axon.Loop.trainer(:binary_cross_entropy, Axon.Optimizers.sgd(initial_lr)) + |> Axon.Loop.trainer( + :binary_cross_entropy, + Polaris.Optimizers.sgd(learning_rate: initial_lr) + ) |> Axon.Loop.metric(my_metric, "counter", :running_sum) |> Axon.Loop.reduce_lr_on_plateau("counter", factor: 0.5, mode: :max, patience: 2) # TODO: This API needs to change diff --git a/test/axon/loss_scale_test.exs b/test/axon/loss_scale_test.exs new file mode 100644 index 00000000..93fb70c0 --- /dev/null +++ b/test/axon/loss_scale_test.exs @@ -0,0 +1,297 @@ +defmodule Axon.LossScaleTest do + use ExUnit.Case + import AxonTestUtil + + import Axon.LossScale + + describe "identity/1" do + test "creates a loss scale tuple" do + assert {init_fn, scale_fn, adjust_fn} = identity() + assert is_function(init_fn, 0) + assert is_function(scale_fn, 2) + assert is_function(adjust_fn, 2) + end + + test "accepts options" do + assert {init_fn, scale_fn, adjust_fn} = identity([]) + assert is_function(init_fn, 0) + assert is_function(scale_fn, 2) + assert is_function(adjust_fn, 2) + end + + test "initializes to empty state" 
do + assert {init_fn, _, _} = identity() + assert init_fn.() == %{} + end + + test "scale function returns identity operation on x" do + assert {init_fn, scale_fn, _} = identity() + state = init_fn.() + x = Nx.tensor([1.0, 2.0, 3.0]) + + new_x = scale_fn.(x, state) + assert new_x == x + end + + test "adjust function returns identity operation on x and state" do + assert {init_fn, _, adjust_fn} = identity() + state = init_fn.() + x = Nx.tensor([1.0, 2.0, 3.0]) + + assert {new_x, new_state} = adjust_fn.(x, state) + assert new_x == x + assert new_state == state + end + end + + describe "static/1" do + test "creates a loss scale tuple" do + assert {init_fn, scale_fn, adjust_fn} = static() + assert is_function(init_fn, 0) + assert is_function(scale_fn, 2) + assert is_function(adjust_fn, 2) + end + + test "accepts options" do + assert {init_fn, scale_fn, adjust_fn} = static([]) + assert is_function(init_fn, 0) + assert is_function(scale_fn, 2) + assert is_function(adjust_fn, 2) + end + + test "initializes state with default loss scale" do + assert {init_fn, _, _} = static() + assert %{loss_scale: loss_scale} = init_fn.() + assert_equal(loss_scale, Nx.pow(2, 15)) + end + + test "initializes state with specified loss scale" do + init_scale = Nx.pow(3, 15) + assert {init_fn, _, _} = static(init_scale: init_scale) + assert %{loss_scale: loss_scale} = init_fn.() + assert_equal(loss_scale, init_scale) + end + + test "scale function returns a tree scaled by static scale" do + assert {init_fn, scale_fn, _} = static() + state = init_fn.() + a = Nx.tensor([1.0, 2.0, 3.0]) + c = Nx.tensor([4.0, 5.0, 6.0]) + x = %{a: a, b: %{c: c}} + + assert %{a: scaled_a, b: %{c: scaled_c}} = scale_fn.(x, state) + assert_equal(scaled_a, Nx.multiply(a, Nx.pow(2, 15))) + assert_equal(scaled_c, Nx.multiply(c, Nx.pow(2, 15))) + end + + test "scale function returns a tree scaled by static scale with custom scale" do + init_scale = Nx.pow(3, 15) + assert {init_fn, scale_fn, _} = static(init_scale: init_scale) + state = init_fn.() + a = Nx.tensor([1.0, 2.0, 3.0]) + c = Nx.tensor([4.0, 5.0, 6.0]) + x = %{a: a, b: %{c: c}} + + assert %{a: scaled_a, b: %{c: scaled_c}} = scale_fn.(x, state) + assert_equal(scaled_a, Nx.multiply(a, init_scale)) + assert_equal(scaled_c, Nx.multiply(c, init_scale)) + end + + test "adjust function returns unscaled tree with static state" do + assert {init_fn, scale_fn, adjust_fn} = static() + state = init_fn.() + a = Nx.tensor([1.0, 2.0, 3.0]) + c = Nx.tensor([4.0, 5.0, 6.0]) + x = %{a: a, b: %{c: c}} + + scaled_x = scale_fn.(x, state) + assert {unscaled_x, new_state} = adjust_fn.(scaled_x, state) + assert %{a: unscaled_a, b: %{c: unscaled_c}} = unscaled_x + assert %{loss_scale: new_loss_scale} = new_state + + assert_all_close(unscaled_a, a) + assert_all_close(unscaled_c, c) + assert_equal(new_loss_scale, Nx.pow(2, 15)) + end + + test "adjust function returns unscaled tree with static state and custom scale" do + init_scale = Nx.pow(3, 15) + + assert {init_fn, scale_fn, adjust_fn} = static(init_scale: init_scale) + state = init_fn.() + a = Nx.tensor([1.0, 2.0, 3.0]) + c = Nx.tensor([4.0, 5.0, 6.0]) + x = %{a: a, b: %{c: c}} + + scaled_x = scale_fn.(x, state) + assert {unscaled_x, new_state} = adjust_fn.(scaled_x, state) + assert %{a: unscaled_a, b: %{c: unscaled_c}} = unscaled_x + assert %{loss_scale: new_loss_scale} = new_state + + assert_all_close(unscaled_a, a) + assert_all_close(unscaled_c, c) + assert_equal(new_loss_scale, init_scale) + end + end + + describe "dynamic/1" do + test "creates a loss 
scale tuple" do + assert {init_fn, scale_fn, adjust_fn} = dynamic() + assert is_function(init_fn, 0) + assert is_function(scale_fn, 2) + assert is_function(adjust_fn, 2) + end + + test "accepts options" do + assert {init_fn, scale_fn, adjust_fn} = dynamic([]) + assert is_function(init_fn, 0) + assert is_function(scale_fn, 2) + assert is_function(adjust_fn, 2) + end + + test "initializes state with default loss scale" do + assert {init_fn, _, _} = dynamic() + assert %{loss_scale: loss_scale, counter: counter} = init_fn.() + assert_equal(loss_scale, Nx.pow(2, 15)) + assert_equal(counter, Nx.tensor(0)) + end + + test "initializes state with specified loss scale" do + init_scale = Nx.pow(3, 15) + assert {init_fn, _, _} = dynamic(init_scale: init_scale) + assert %{loss_scale: loss_scale, counter: counter} = init_fn.() + assert_equal(counter, Nx.tensor(0)) + assert_equal(loss_scale, init_scale) + end + + test "scale function returns a tree scaled by scale" do + assert {init_fn, scale_fn, _} = dynamic() + state = init_fn.() + a = Nx.tensor([1.0, 2.0, 3.0]) + c = Nx.tensor([4.0, 5.0, 6.0]) + x = %{a: a, b: %{c: c}} + + assert %{a: scaled_a, b: %{c: scaled_c}} = scale_fn.(x, state) + assert_equal(scaled_a, Nx.multiply(a, Nx.pow(2, 15))) + assert_equal(scaled_c, Nx.multiply(c, Nx.pow(2, 15))) + end + + test "scale function returns a tree scaled by scale with custom scale" do + init_scale = Nx.pow(3, 15) + assert {init_fn, scale_fn, _} = dynamic(init_scale: init_scale) + state = init_fn.() + a = Nx.tensor([1.0, 2.0, 3.0]) + c = Nx.tensor([4.0, 5.0, 6.0]) + x = %{a: a, b: %{c: c}} + + assert %{a: scaled_a, b: %{c: scaled_c}} = scale_fn.(x, state) + assert_equal(scaled_a, Nx.multiply(a, init_scale)) + assert_equal(scaled_c, Nx.multiply(c, init_scale)) + end + + test "adjust function unscales correctly" do + init_scale = Nx.tensor(10) + assert {init_fn, scale_fn, adjust_fn} = dynamic(init_scale: init_scale) + state = init_fn.() + + a = Nx.tensor([1.0, 2.0, 3.0]) + c = Nx.tensor([4.0, 5.0, 6.0]) + x = %{a: a, b: %{c: c}} + + scaled_x = scale_fn.(x, state) + assert {unscaled_x, _new_state} = adjust_fn.(scaled_x, state) + assert %{a: unscaled_a, b: %{c: unscaled_c}} = unscaled_x + + assert_all_close(unscaled_a, a) + assert_all_close(unscaled_c, c) + end + + test "adjust function increases loss scale according to period and factor when grads are finite" do + init_scale = Nx.tensor(10) + period = 5 + assert {init_fn, _, adjust_fn} = dynamic(init_scale: init_scale, period: period) + state = init_fn.() + + finite = Nx.tensor([1.0, 1.0, 1.0]) + + final_state = + for i <- 1..(period - 1), reduce: state do + new_state -> + {_, %{loss_scale: loss_scale, counter: counter} = new_state} = + adjust_fn.(finite, new_state) + + assert_equal(loss_scale, init_scale) + assert_equal(counter, Nx.tensor(i)) + new_state + end + + assert {_, %{loss_scale: final_scale, counter: final_counter}} = + adjust_fn.(finite, final_state) + + assert_equal(final_scale, Nx.tensor(20.0)) + assert_equal(final_counter, Nx.tensor(0)) + end + + test "adjust function reduces loss scale on non finite" do + init_scale = Nx.tensor(10) + period = 5 + factor = 2 + + assert {init_fn, _, adjust_fn} = + dynamic(init_scale: init_scale, period: period, factor: factor) + + state = init_fn.() + + non_finite = Nx.tensor([:infinity, :infinity, :infinity]) + + # TODO: increase to 99 when https://github.com/elixir-nx/complex/issues/26 + # is fixed + for i <- 0..62, reduce: state do + new_state -> + {_, %{loss_scale: loss_scale, counter: counter} = 
new_state} = + adjust_fn.(non_finite, new_state) + + expected_new_scale = Nx.max(1, Nx.divide(init_scale, Nx.pow(factor, i + 1))) + assert_equal(counter, Nx.tensor(0)) + assert_all_close(loss_scale, expected_new_scale) + + new_state + end + end + + test "adjust function reduces loss scale to min loss scale" do + init_scale = Nx.tensor(20) + period = 5 + factor = 2 + min_loss_scale = 2 + + assert {init_fn, _, adjust_fn} = + dynamic( + init_scale: init_scale, + period: period, + factor: factor, + min_loss_scale: min_loss_scale + ) + + state = init_fn.() + + non_finite = Nx.tensor([:infinity, :infinity, :infinity]) + + # TODO: increase to 99 when https://github.com/elixir-nx/complex/issues/26 + # is fixed + for i <- 0..62, reduce: state do + new_state -> + {_, %{loss_scale: loss_scale, counter: counter} = new_state} = + adjust_fn.(non_finite, new_state) + + expected_new_scale = + Nx.max(min_loss_scale, Nx.divide(init_scale, Nx.pow(factor, i + 1))) + + assert_equal(counter, Nx.tensor(0)) + assert_all_close(loss_scale, expected_new_scale) + + new_state + end + end + end +end diff --git a/test/axon/losses_test.exs b/test/axon/losses_test.exs index 922986c3..6f396cca 100644 --- a/test/axon/losses_test.exs +++ b/test/axon/losses_test.exs @@ -284,4 +284,25 @@ defmodule Axon.LossesTest do ) end end + + describe "apply_label_smoothing" do + test "correctly smooths labels" do + y_true = Nx.tensor([[0, 1, 0, 0, 0, 0]]) + y_pred = Nx.tensor([[0.5, 0.1, 0.1, 0.0, 0.2, 0.1]]) + + assert_all_close( + Axon.Losses.apply_label_smoothing(y_true, y_pred, smoothing: 0.1), + Nx.tensor([[0.0167, 0.9167, 0.0167, 0.0167, 0.0167, 0.0167]]), + atol: 1.0e-3 + ) + end + end + + describe "label_smoothing" do + test "returns an arity-2 function from loss function" do + loss = &Axon.Losses.categorical_cross_entropy/2 + smooth_loss = Axon.Losses.label_smoothing(loss, smoothing: 0.1) + assert is_function(smooth_loss, 2) + end + end end diff --git a/test/axon/mixed_precision_test.exs b/test/axon/mixed_precision_test.exs index 9b7f3c61..d858529e 100644 --- a/test/axon/mixed_precision_test.exs +++ b/test/axon/mixed_precision_test.exs @@ -1,86 +1,4 @@ -# defmodule MixedPrecisionTest do -# use Axon.Case, async: true - -# alias Axon.MixedPrecision.Policy -# alias Axon.MixedPrecision, as: AMP -# alias Axon.Loop - -# describe "creation and application" do -# test "create policy" do -# assert %Policy{params: {:f, 32}, compute: {:bf, 16}, output: {:f, 32}} = -# AMP.create_policy(compute: {:bf, 16}) - -# assert %Policy{params: {:bf, 16}, compute: {:f, 32}, output: {:bf, 16}} = -# AMP.create_policy(params: {:bf, 16}, output: {:bf, 16}) -# end - -# test "apply_policy" do -# model = -# Axon.input("input", shape: {nil, 784}) -# |> Axon.dense(128) -# |> Axon.batch_norm() -# |> Axon.dense(10) - -# policy = AMP.create_policy(compute: {:bf, 16}) - -# assert %Axon{ -# op: :dense, -# parent: [ -# %Axon{ -# op: :batch_norm, -# parent: [%Axon{op: :dense, policy: %Policy{compute: {:bf, 16}}}], -# policy: %Policy{compute: {:f, 32}} -# } -# ], -# policy: %Policy{compute: {:bf, 16}} -# } = AMP.apply_policy(model, policy, except: [:batch_norm]) -# end -# end - -# describe "compilation" do -# # TODO(seanmor5): Now that everything else has moved, maybe this -# # belongs in a train test or elsewhere -# test "correctly maintains parameter type after train step" do -# model = -# Axon.input("input", shape: {nil, 32}) -# |> Axon.dense(2, name: "dense1") -# |> Axon.batch_norm(name: "batch_norm") -# |> Axon.dense(1, activation: :sigmoid, name: "dense2") 
-
-#       policy = AMP.create_policy(params: {:bf, 16})
-
-#       mp_model = AMP.apply_policy(model, policy, except: [:batch_norm])
-
-#       %Loop{init: init_fn, step: step_fn} =
-#         Axon.Loop.trainer(mp_model, :binary_cross_entropy, Axon.Optimizers.sgd(0.01))
-
-#       v1 = Nx.random_uniform({1, 32})
-#       v2 = Nx.random_uniform({1, 1})
-
-#       pstate =
-#         apply(Nx.Defn.jit(step_fn), [
-#           {v1, v2},
-#           init_fn.({v1, v2}, %{})
-#         ])
-
-#       params = pstate[:model_state]
-
-#       assert Nx.type(params["dense1"]["kernel"]) == {:bf, 16}
-#       assert Nx.type(params["dense1"]["bias"]) == {:bf, 16}
-#       assert Nx.type(params["dense2"]["kernel"]) == {:bf, 16}
-#       assert Nx.type(params["dense2"]["bias"]) == {:bf, 16}
-#       assert Nx.type(params["batch_norm"]["gamma"]) == {:f, 32}
-#       assert Nx.type(params["batch_norm"]["beta"]) == {:f, 32}
-#     end
-#   end
-
-#   describe "inspection" do
-#     test "works" do
-#       policy = AMP.create_policy()
-
-#       assert inspect(policy) == """
-#              p=f32 c=f32 o=f32\
-#              """
-#     end
-#   end
-# end
+defmodule Axon.MixedPrecisionTest do
+  use ExUnit.Case
+  doctest Axon.MixedPrecision
+end
diff --git a/test/axon_test.exs b/test/axon_test.exs
index c2a24df8..835aadd3 100644
--- a/test/axon_test.exs
+++ b/test/axon_test.exs
@@ -888,7 +888,7 @@ defmodule AxonTest do
              #Axon<
                inputs: %{"input_0" => {nil, 32, 10}}
                outputs: "lstm_output_sequence"
-               nodes: 6
+               nodes: 7
              >\
              """
     end
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_bf.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_bf.npy
new file mode 100644
index 00000000..f444dc0f
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_bf.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_bg.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_bg.npy
new file mode 100644
index 00000000..8bc03a47
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_bg.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_bi.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_bi.npy
new file mode 100644
index 00000000..203ae3ca
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_bi.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_bo.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_bo.npy
new file mode 100644
index 00000000..ab6e84d9
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_bo.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_input_c.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_input_c.npy
new file mode 100644
index 00000000..983d06ac
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_input_c.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_input_h.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_input_h.npy
new file mode 100644
index 00000000..0a167d7c
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_input_h.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_input_seq.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_input_seq.npy
new file mode 100644
index 00000000..5b96fe16
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_input_seq.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_output_c.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_output_c.npy
new file mode 100644
index 00000000..971ba017
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_output_c.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_output_h.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_output_h.npy
new file mode 100644
index 00000000..15b995ba
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_output_h.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_whf.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_whf.npy
new file mode 100644
index 00000000..e3883b75
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_whf.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_whg.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_whg.npy
new file mode 100644
index 00000000..1ce47cf0
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_whg.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_whi.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_whi.npy
new file mode 100644
index 00000000..80e7dc49
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_whi.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_who.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_who.npy
new file mode 100644
index 00000000..31ba7f5b
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_who.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_wif.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_wif.npy
new file mode 100644
index 00000000..fac88b34
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_wif.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_wig.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_wig.npy
new file mode 100644
index 00000000..f0f22966
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_wig.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_wii.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_wii.npy
new file mode 100644
index 00000000..982a6ced
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_wii.npy differ
diff --git a/test/fixtures/lstm_cell_test/test_lstm_cell_wio.npy b/test/fixtures/lstm_cell_test/test_lstm_cell_wio.npy
new file mode 100644
index 00000000..49c363f4
Binary files /dev/null and b/test/fixtures/lstm_cell_test/test_lstm_cell_wio.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_bf.npy b/test/fixtures/lstm_test/test_lstm_bf.npy
new file mode 100644
index 00000000..26f96ff3
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_bf.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_bg.npy b/test/fixtures/lstm_test/test_lstm_bg.npy
new file mode 100644
index 00000000..ba52553f
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_bg.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_bi.npy b/test/fixtures/lstm_test/test_lstm_bi.npy
new file mode 100644
index 00000000..25b16a5b
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_bi.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_bo.npy b/test/fixtures/lstm_test/test_lstm_bo.npy
new file mode 100644
index 00000000..06fd60cd
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_bo.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_input_c.npy b/test/fixtures/lstm_test/test_lstm_input_c.npy
new file mode 100644
index 00000000..23f8afa7
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_input_c.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_input_h.npy b/test/fixtures/lstm_test/test_lstm_input_h.npy
new file mode 100644
index 00000000..3f2c0b33
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_input_h.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_input_seq.npy b/test/fixtures/lstm_test/test_lstm_input_seq.npy
new file mode 100644
index 00000000..b8f4633d
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_input_seq.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_output_c.npy b/test/fixtures/lstm_test/test_lstm_output_c.npy
new file mode 100644
index 00000000..488515c3
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_output_c.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_output_h.npy b/test/fixtures/lstm_test/test_lstm_output_h.npy
new file mode 100644
index 00000000..98bb49ec
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_output_h.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_output_seq.npy b/test/fixtures/lstm_test/test_lstm_output_seq.npy
new file mode 100644
index 00000000..382acfa3
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_output_seq.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_whf.npy b/test/fixtures/lstm_test/test_lstm_whf.npy
new file mode 100644
index 00000000..1063b5a6
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_whf.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_whg.npy b/test/fixtures/lstm_test/test_lstm_whg.npy
new file mode 100644
index 00000000..470e25a2
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_whg.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_whi.npy b/test/fixtures/lstm_test/test_lstm_whi.npy
new file mode 100644
index 00000000..288a45cf
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_whi.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_who.npy b/test/fixtures/lstm_test/test_lstm_who.npy
new file mode 100644
index 00000000..95182e0d
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_who.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_wif.npy b/test/fixtures/lstm_test/test_lstm_wif.npy
new file mode 100644
index 00000000..2fee5e56
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_wif.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_wig.npy b/test/fixtures/lstm_test/test_lstm_wig.npy
new file mode 100644
index 00000000..0c485ad2
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_wig.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_wii.npy b/test/fixtures/lstm_test/test_lstm_wii.npy
new file mode 100644
index 00000000..159ab30b
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_wii.npy differ
diff --git a/test/fixtures/lstm_test/test_lstm_wio.npy b/test/fixtures/lstm_test/test_lstm_wio.npy
new file mode 100644
index 00000000..b441abf8
Binary files /dev/null and b/test/fixtures/lstm_test/test_lstm_wio.npy differ
diff --git a/test/support/axon_test_util.ex b/test/support/axon_test_util.ex
index 633243fe..d5544b19 100644
--- a/test/support/axon_test_util.ex
+++ b/test/support/axon_test_util.ex
@@ -134,7 +134,7 @@ defmodule AxonTestUtil do
         {params, opt_state} = state
         gradients = Nx.Defn.grad(params, loss)
         {updates, new_state} = update_fn.(gradients, opt_state, params)
-        {Axon.Updates.apply_updates(updates, params), new_state}
+        {Polaris.Updates.apply_updates(updates, params), new_state}
       end

     {params, _} =
@@ -163,6 +163,13 @@ defmodule AxonTestUtil do
     end
   end

+  def random(shape, opts \\ []) do
+    Nx.Random.uniform_split(Nx.Random.key(:erlang.system_time()), 0.0, 1.0,
+      shape: shape,
+      type: opts[:type] || :f32
+    )
+  end
+
   def get_test_data(
         train_samples,
         test_samples,