Skip to content

Commit

Permalink
Add layer name to hook (#536)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 authored Jul 24, 2024
1 parent b93e87f commit 216fafe
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 11 deletions.
22 changes: 11 additions & 11 deletions lib/axon/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -501,8 +501,8 @@ defmodule Axon.Compiler do

res =
value
|> apply_hooks(:forward, mode, hooks)
|> apply_hooks(:backward, mode, hooks)
|> apply_hooks(name, :forward, mode, hooks)
|> apply_hooks(name, :backward, mode, hooks)
|> maybe_print_values(name, print_values)

{res, {state, result_cache}}
Expand Down Expand Up @@ -975,7 +975,7 @@ defmodule Axon.Compiler do
layer_input =
layer_input
|> safe_policy_cast(policy, :compute)
|> apply_hooks(:pre_forward, mode, hooks)
|> apply_hooks(name, :pre_forward, mode, hooks)

{layer_input, {state, result_cache, none?}}
end
Expand Down Expand Up @@ -1051,8 +1051,8 @@ defmodule Axon.Compiler do
%StatefulOutput{output: out, state: out_state} ->
new_out =
out
|> apply_hooks(:forward, mode, hooks)
|> apply_hooks(:backward, mode, hooks)
|> apply_hooks(name, :forward, mode, hooks)
|> apply_hooks(name, :backward, mode, hooks)
|> safe_policy_cast(policy, :output)

new_state = Map.put(state, name, out_state)
Expand All @@ -1061,8 +1061,8 @@ defmodule Axon.Compiler do
out ->
new_out =
out
|> apply_hooks(:forward, mode, hooks)
|> apply_hooks(:backward, mode, hooks)
|> apply_hooks(name, :forward, mode, hooks)
|> apply_hooks(name, :backward, mode, hooks)
|> safe_policy_cast(policy, :output)

{new_out, state}
Expand Down Expand Up @@ -1169,7 +1169,7 @@ defmodule Axon.Compiler do
init_param(layer_id, param, layer_params, parent_templates, dtype, keys)
end)

layer_params = apply_hooks(layer_params, :initialize, nil, hooks)
layer_params = apply_hooks(layer_params, name, :initialize, nil, hooks)

params =
if layer_params == %{} do
Expand Down Expand Up @@ -1228,7 +1228,7 @@ defmodule Axon.Compiler do

defp maybe_print_values(value, _, _), do: value

defp apply_hooks(res, event, mode, hooks) do
defp apply_hooks(res, layer_name, event, mode, hooks) do
hooks
|> Enum.reverse()
|> Enum.reduce(res, fn {on_event, on_mode, hook_fn}, expr ->
Expand All @@ -1238,11 +1238,11 @@ defmodule Axon.Compiler do
if event? and mode? do
if on_event == :backward do
Nx.Defn.Kernel.custom_grad(expr, [expr], fn g ->
hooked_g = Nx.Defn.Kernel.hook(g, hook_fn)
hooked_g = Nx.Defn.Kernel.hook(g, String.to_atom(layer_name), hook_fn)
[hooked_g]
end)
else
Nx.Defn.Kernel.hook(expr, hook_fn)
Nx.Defn.Kernel.hook(expr, String.to_atom(layer_name), hook_fn)
end
else
expr
Expand Down
20 changes: 20 additions & 0 deletions test/axon/compiler_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -4726,6 +4726,26 @@ defmodule CompilerTest do
assert_receive {%Nx.Tensor{}, :from_relu}
assert_receive {%Nx.Tensor{}, :from_sigmoid}
end

test "can be overriden at jit-time with layer name", config do
model =
Axon.input("input_0", shape: {nil, 1})
|> Axon.attach_hook(fn x -> send(config.test, {x, :from_input}) end, on: :forward)
|> Axon.relu()

inp = Nx.tensor([[1.0]])
{_, predict_fn} = Axon.build(model)

hook = fn val -> send(config.test, {val, :overridden}) end

fun = Nx.Defn.jit(predict_fn, hooks: %{input_0: hook})
apply(fun, [ModelState.empty(), inp])

assert_receive {from_inp, :overridden}
refute_receive {_, :from_input}

assert_equal(from_inp, inp)
end
end

describe "integrated models" do
Expand Down

0 comments on commit 216fafe

Please sign in to comment.