Introduce composite params, fix stability issue with unrolled RNNs #550

Merged · 6 commits · Nov 22, 2023
135 changes: 88 additions & 47 deletions lib/axon.ex
@@ -388,7 +388,9 @@ defmodule Axon do
be automatically initialized and used in subsequent applications
of Axon models.

Parameters *must* be specified in order of their usage.
You may specify the parameter shape as either a static shape or
as a function of the inputs to the given layer. If you specify the
parameter shape as a function, it will be given the shapes of the
layer's inputs when the parameter is initialized.

## Options

@@ -397,7 +399,18 @@
"""
@doc type: :special
def param(name, shape, opts \\ [])
when is_binary(name) and (is_tuple(shape) or is_function(shape)) do

def param(name, {:map, [_ | _] = inner_params}, opts) do
maybe_warn_on_param_opts(opts)

%Axon.Parameter{
name: name,
type: :map,
children: inner_params
}
end

def param(name, shape, opts) when is_tuple(shape) or is_function(shape) do
opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: {:f, 32})
initializer = validate_initializer!(opts[:initializer])
type = opts[:type] || {:f, 32}
@@ -410,6 +423,14 @@
}
end

defp maybe_warn_on_param_opts(opts) do
if Keyword.has_key?(opts, :initializer) or Keyword.has_key?(opts, :type) do
Logger.warning(
"Passing options to a composite parameter has no effect. Pass them to inner parameters instead"
)
end
end
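# Illustrative usage sketch (not part of this diff; the names "w1", "w2", and
# "gate_kernels" are hypothetical). A parameter shape may be a static tuple, a
# function of the layer's input shapes, or a composite {:map, ...} of inner params:
#
#     w1 = Axon.param("w1", fn input_shape -> {elem(input_shape, 1), 8} end)
#     w2 = Axon.param("w2", {8, 8}, initializer: :zeros)
#     gate_kernels = Axon.param("gate_kernels", {:map, [w1, w2]})
#
# Options such as :initializer and :type belong on the inner parameters; passing
# them to the composite itself has no effect (see the warning above).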

@doc """
Adds an input layer to the network.

@@ -2465,23 +2486,25 @@ defmodule Axon do
activation = opts[:activation]
gate = opts[:gate]
unroll = opts[:unroll]
kernel_initializer = opts[:kernel_initializer]

input_kernel_shape = fn inp, _, _ -> Axon.Shape.rnn_input_kernel(inp, units, :lstm) end
hidden_kernel_shape = fn inp, _, _ -> Axon.Shape.rnn_hidden_kernel(inp, units, :lstm) end
bias_shape = fn inp, _, _ -> Axon.Shape.rnn_bias(inp, units, :lstm) end

kernel_initializer = opts[:kernel_initializer]
wii = param("wii", input_kernel_shape, initializer: kernel_initializer)
wif = param("wif", input_kernel_shape, initializer: kernel_initializer)
wig = param("wig", input_kernel_shape, initializer: kernel_initializer)
wio = param("wio", input_kernel_shape, initializer: kernel_initializer)

# Parameters
input_kernel =
param("input_kernel", {:tuple, List.duplicate(input_kernel_shape, 4)},
initializer: kernel_initializer
)
whi = param("whi", hidden_kernel_shape, initializer: kernel_initializer)
whf = param("whf", hidden_kernel_shape, initializer: kernel_initializer)
whg = param("whg", hidden_kernel_shape, initializer: kernel_initializer)
who = param("who", hidden_kernel_shape, initializer: kernel_initializer)

hidden_kernel =
param("hidden_kernel", {:tuple, List.duplicate(hidden_kernel_shape, 4)},
initializer: kernel_initializer
)
# Parameters
input_kernel = param("input_kernel", {:map, [wii, wif, wig, wio]})
hidden_kernel = param("hidden_kernel", {:map, [whi, whf, whg, who]})

hidden_state_name =
case opts[:name] do
@@ -2500,8 +2523,12 @@
if opts[:use_bias] do
bias_initializer = opts[:bias_initializer]

bias =
param("bias", {:tuple, List.duplicate(bias_shape, 4)}, initializer: bias_initializer)
bi = param("bi", bias_shape, initializer: bias_initializer)
bf = param("bf", bias_shape, initializer: bias_initializer)
bg = param("bg", bias_shape, initializer: bias_initializer)
bo = param("bo", bias_shape, initializer: bias_initializer)

bias = param("bias", {:map, [bi, bf, bg, bo]})

{[x, hidden_state, opts[:mask], input_kernel, hidden_kernel, bias], :lstm}
else
@@ -2670,15 +2697,16 @@ defmodule Axon do

kernel_initializer = opts[:kernel_initializer]

input_kernel =
param("input_kernel", {:tuple, List.duplicate(input_kernel_shape, 3)},
initializer: kernel_initializer
)
wir = param("wir", input_kernel_shape, initializer: kernel_initializer)
wiz = param("wiz", input_kernel_shape, initializer: kernel_initializer)
win = param("win", input_kernel_shape, initializer: kernel_initializer)

hidden_kernel =
param("hidden_kernel", {:tuple, List.duplicate(hidden_kernel_shape, 3)},
initializer: kernel_initializer
)
whr = param("whr", hidden_kernel_shape, initializer: kernel_initializer)
whz = param("whz", hidden_kernel_shape, initializer: kernel_initializer)
whn = param("whn", hidden_kernel_shape, initializer: kernel_initializer)

input_kernel = param("input_kernel", {:map, [wir, wiz, win]})
hidden_kernel = param("hidden_kernel", {:map, [whr, whz, whn]})

hidden_state_name =
case opts[:name] do
@@ -2697,8 +2725,12 @@
if opts[:use_bias] do
bias_initializer = opts[:bias_initializer]

bias =
param("bias", {:tuple, List.duplicate(bias_shape, 4)}, initializer: bias_initializer)
br = param("br", bias_shape, initializer: bias_initializer)
bz = param("bz", bias_shape, initializer: bias_initializer)
bin = param("bin", bias_shape, initializer: bias_initializer)
bhn = param("bhn", bias_shape, initializer: bias_initializer)

bias = param("bias", {:map, [br, bz, bin, bhn]})

[x, hidden_state, opts[:mask], input_kernel, hidden_kernel, bias]
else
@@ -2865,8 +2897,8 @@ defmodule Axon do
Axon.Shape.conv_bias(shape, 4 * units, kernel_size, :first, 1)
end

wi = param("input_kernel", {:tuple, [input_kernel_shape]}, initializer: kernel_initializer)
wh = param("hidden_kernel", {:tuple, [hidden_kernel_shape]}, initializer: kernel_initializer)
wi = param("input_kernel", input_kernel_shape, initializer: kernel_initializer)
wh = param("hidden_kernel", hidden_kernel_shape, initializer: kernel_initializer)

hidden_state_name =
case opts[:name] do
@@ -2884,7 +2916,7 @@
{inputs, op} =
if opts[:use_bias] do
bias_initializer = opts[:bias_initializer]
b = param("bias", {:tuple, [bias_shape]}, initializer: bias_initializer)
b = param("bias", bias_shape, initializer: bias_initializer)
{[x, hidden_state, opts[:mask], wi, wh, b], :conv_lstm}
else
{[x, hidden_state, opts[:mask], wi, wh], :conv_lstm}
@@ -2977,31 +3009,40 @@ defmodule Axon do
"#{parent_name}_#{state_name}_hidden_state"
end

fun = fn inputs, key, _opts ->
shape = Axon.Shape.rnn_hidden_state(Nx.shape(inputs), units, rnn_type)
initializer =
if is_function(initializer) do
initializer
else
apply(Axon.Initializers, initializer, [])
end

case initializer do
fun when is_function(fun) ->
fun.(shape)
{:arity, arity} = Function.info(initializer, :arity)

fun when is_atom(fun) ->
fun = apply(Axon.Initializers, fun, [])
{:arity, arity} = Function.info(fun, :arity)
{fun, inputs} =
cond do
arity == 2 ->
fun =
fn inputs, _opts ->
shape = Axon.Shape.rnn_hidden_state(Nx.shape(inputs), units, rnn_type)
initializer.(shape, {:f, 32})
end

cond do
arity == 2 ->
fun.(shape, {:f, 32})
{fun, [x]}

arity == 3 ->
fun.(shape, {:f, 32}, key)
arity == 3 ->
fun =
fn inputs, key, _opts ->
shape = Axon.Shape.rnn_hidden_state(Nx.shape(inputs), units, rnn_type)
initializer.(shape, {:f, 32}, key)
end

true ->
raise ArgumentError, "bad arity for initializer"
end
{fun, [x, key_state]}

true ->
raise ArgumentError, "bad arity for initializer"
end
end

layer(fun, [x, key_state], name: name, op_name: :recurrent_state)
layer(fun, inputs, name: name, op_name: :recurrent_state)
end
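# Illustrative note (assumption, not part of this diff): deterministic initializers
# such as Axon.Initializers.zeros() return an arity-2 function of (shape, type), so
# the hidden-state layer needs no PRNG key, while random initializers such as
# Axon.Initializers.glorot_uniform() return an arity-3 function of (shape, type, key),
# so the key state is threaded in as an extra input:
#
#     {:arity, 2} = Function.info(Axon.Initializers.zeros(), :arity)
#     {:arity, 3} = Function.info(Axon.Initializers.glorot_uniform(), :arity)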

@doc """
@@ -3722,7 +3763,7 @@ defmodule Axon do
"""
@doc type: :debug
def trace_backward(model, inputs, params, loss, opts \\ []) do
{_, forward_fn} = build(model, opts)
{_, forward_fn} = build(model, opts ++ [mode: :train])

backward_fn = fn params, inputs, targets ->
Nx.Defn.grad(params, fn params ->
Expand All @@ -3731,7 +3772,7 @@ defmodule Axon do
end)
end

outputs = Nx.Defn.jit(forward_fn, compiler: Axon.Defn).(params, inputs)
%{prediction: outputs} = Nx.Defn.jit(forward_fn, compiler: Axon.Defn).(params, inputs)
inputs = [params, inputs, outputs]

apply(Nx.Defn.jit(backward_fn, compiler: Axon.Defn), inputs)
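A minimal usage sketch for the updated `trace_backward/5` (the model, loss, and shapes here are illustrative assumptions, not part of this PR):

```elixir
model = Axon.input("x", shape: {nil, 4}) |> Axon.dense(2)

# Initialize parameters from a template input.
{init_fn, _predict_fn} = Axon.build(model)
params = init_fn.(%{"x" => Nx.template({1, 4}, :f32)}, %{})

inputs = %{"x" => Nx.iota({1, 4}, type: :f32)}

# A scalar loss of (targets, predictions); trace_backward builds the model in
# :train mode, runs a forward pass, and traces the gradient of the loss of the
# model's own predictions with respect to params.
loss = fn y_true, y_pred ->
  Axon.Losses.mean_squared_error(y_true, y_pred, reduction: :mean)
end

backward_expr = Axon.trace_backward(model, inputs, params, loss)
```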
130 changes: 92 additions & 38 deletions lib/axon/compiler.ex
@@ -147,31 +147,12 @@ defmodule Axon.Compiler do
{_, %Axon.Node{id: id, op: op, name: name_fn, parameters: params}}, {keys, op_counts} ->
name = name_fn.(op, op_counts)
op_counts = Map.update(op_counts, op, 1, &(&1 + 1))

keys =
Enum.reduce(params, keys, fn
%Axon.Parameter{name: param_name, initializer: fun}, keys ->
{:arity, arity} = Function.info(fun, :arity)

cond do
arity == 2 ->
keys

arity == 3 ->
<<data::unsigned-size(32), _rest::binary>> =
:erlang.md5(name <> "." <> param_name)

[{{id, param_name}, data} | keys]

true ->
raise ArgumentError, "bad initializer arity"
end
end)

keys = get_node_keys(id, name, params, keys)
{keys, op_counts}
end)

{ids, data} = Enum.unzip(ids_and_data)
data = List.flatten(data)

case ids do
[] ->
@@ -186,16 +167,82 @@
|> then(&Nx.Random.fold_in(key, &1))

{keys, _} =
Enum.reduce(ids, {%{}, 0}, fn {layer_id, param_name}, {acc, i} ->
key = keys_tensor[i]
acc = Map.update(acc, layer_id, %{param_name => key}, &Map.put(&1, param_name, key))
{acc, i + 1}
Enum.reduce(ids, {%{}, 0}, fn {layer_id, param}, {acc, i} ->
{{root_name, keys}, i} = recur_slice_keys(keys_tensor, param, i)

layer_keys =
Map.update(acc, layer_id, %{root_name => keys}, &Map.put(&1, root_name, keys))

{layer_keys, i}
end)

keys
end
end

defp get_node_keys(id, parent_name, params, keys) do
Enum.reduce(params, keys, fn param, keys ->
case get_param_data(parent_name, param) do
nil -> keys
{param_name, data} -> [{{id, param_name}, data} | keys]
end
end)
end

defp get_param_data(parent_name, param) do
case param do
%Axon.Parameter{name: param_name, type: :map, children: inner_params} ->
parent_name = parent_name <> "." <> param_name

{inner_names, inner_data} =
Enum.map(inner_params, &get_param_data(parent_name, &1))
|> Enum.reject(&(&1 == nil))
|> Enum.unzip()

case inner_data do
[] ->
nil

[_ | _] ->
{{param_name, inner_names}, inner_data}
end

%Axon.Parameter{name: param_name, initializer: fun} ->
{:arity, arity} = Function.info(fun, :arity)

cond do
arity == 2 ->
nil

arity == 3 ->
<<data::unsigned-size(32), _rest::binary>> =
:erlang.md5(parent_name <> "." <> param_name)

{param_name, [data]}

true ->
raise ArgumentError, "bad initializer arity"
end
end
end

defp recur_slice_keys(keys_tensor, param, i) do
case param do
{composite_param_name, children} ->
{subkeys, i} =
Enum.reduce(children, {%{}, i}, fn child_param, {acc, i} ->
{{root_name, keys}, i} = recur_slice_keys(keys_tensor, child_param, i)
{Map.put(acc, root_name, keys), i}
end)

{{composite_param_name, subkeys}, i}

param_name when is_binary(param_name) ->
key = keys_tensor[i]
{{param_name, key}, i + 1}
end
end
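# Illustrative walk-through (assumption about indexing, not part of this diff): for
# the LSTM's composite input kernel the collected entry is
# {"input_kernel", ["wii", "wif", "wig", "wio"]}, and recur_slice_keys/3 starting
# at i = 0 returns
#
#   {{"input_kernel", %{"wii" => keys_tensor[0], "wif" => keys_tensor[1],
#                       "wig" => keys_tensor[2], "wio" => keys_tensor[3]}}, 4}
#
# whereas a plain parameter name such as "kernel" consumes a single slice and
# returns {{"kernel", keys_tensor[i]}, i + 1}.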

defp merge_params!(params, init_params) do
Enum.reduce(init_params, params, fn {key, value}, params ->
case params do
@@ -1020,21 +1067,28 @@
end

defp init_param(layer_id, param, layer_params, parent_shapes, dtype, keys) do
%{name: name, shape: shape, initializer: initializer} = param
%{name: name} = param

params =
case shape do
{:tuple, params} ->
params =
Enum.map(params, fn shape ->
shape = apply(shape, parent_shapes)
apply_initializer(layer_id, initializer, name, shape, dtype, keys)
end)

List.to_tuple(params)

shape ->
shape = apply(shape, parent_shapes)
case param do
%Axon.Parameter{name: parent_name, type: :map, children: children} ->
Enum.reduce(children, %{}, fn child_param, acc ->
init_param(parent_name, child_param, acc, parent_shapes, dtype, keys[layer_id])
end)

%Axon.Parameter{name: name, shape: shape, initializer: initializer} ->
shape =
case shape do
shape when is_function(shape) ->
apply(shape, parent_shapes)

shape when is_tuple(shape) ->
shape

other ->
raise "unsupported parameter shape, parameter shape should be a static tuple, a function, or a composite, got #{inspect(other)}"
end

apply_initializer(layer_id, initializer, name, shape, dtype, keys)
end

Expand Down