diff --git a/lib/axon/compiler.ex b/lib/axon/compiler.ex index ae449a99..07c73b5a 100644 --- a/lib/axon/compiler.ex +++ b/lib/axon/compiler.ex @@ -501,8 +501,8 @@ defmodule Axon.Compiler do res = value - |> apply_hooks(:forward, mode, hooks) - |> apply_hooks(:backward, mode, hooks) + |> apply_hooks(name, :forward, mode, hooks) + |> apply_hooks(name, :backward, mode, hooks) |> maybe_print_values(name, print_values) {res, {state, result_cache}} @@ -975,7 +975,7 @@ defmodule Axon.Compiler do layer_input = layer_input |> safe_policy_cast(policy, :compute) - |> apply_hooks(:pre_forward, mode, hooks) + |> apply_hooks(name, :pre_forward, mode, hooks) {layer_input, {state, result_cache, none?}} end @@ -1051,8 +1051,8 @@ defmodule Axon.Compiler do %StatefulOutput{output: out, state: out_state} -> new_out = out - |> apply_hooks(:forward, mode, hooks) - |> apply_hooks(:backward, mode, hooks) + |> apply_hooks(name, :forward, mode, hooks) + |> apply_hooks(name, :backward, mode, hooks) |> safe_policy_cast(policy, :output) new_state = Map.put(state, name, out_state) @@ -1061,8 +1061,8 @@ defmodule Axon.Compiler do out -> new_out = out - |> apply_hooks(:forward, mode, hooks) - |> apply_hooks(:backward, mode, hooks) + |> apply_hooks(name, :forward, mode, hooks) + |> apply_hooks(name, :backward, mode, hooks) |> safe_policy_cast(policy, :output) {new_out, state} @@ -1169,7 +1169,7 @@ defmodule Axon.Compiler do init_param(layer_id, param, layer_params, parent_templates, dtype, keys) end) - layer_params = apply_hooks(layer_params, :initialize, nil, hooks) + layer_params = apply_hooks(layer_params, name, :initialize, nil, hooks) params = if layer_params == %{} do @@ -1228,7 +1228,7 @@ defmodule Axon.Compiler do defp maybe_print_values(value, _, _), do: value - defp apply_hooks(res, event, mode, hooks) do + defp apply_hooks(res, layer_name, event, mode, hooks) do hooks |> Enum.reverse() |> Enum.reduce(res, fn {on_event, on_mode, hook_fn}, expr -> @@ -1238,11 +1238,11 @@ defmodule Axon.Compiler do if event? and mode? do if on_event == :backward do Nx.Defn.Kernel.custom_grad(expr, [expr], fn g -> - hooked_g = Nx.Defn.Kernel.hook(g, hook_fn) + hooked_g = Nx.Defn.Kernel.hook(g, String.to_atom(layer_name), hook_fn) [hooked_g] end) else - Nx.Defn.Kernel.hook(expr, hook_fn) + Nx.Defn.Kernel.hook(expr, String.to_atom(layer_name), hook_fn) end else expr diff --git a/test/axon/compiler_test.exs b/test/axon/compiler_test.exs index 3cad83dd..f955690f 100644 --- a/test/axon/compiler_test.exs +++ b/test/axon/compiler_test.exs @@ -4726,6 +4726,26 @@ defmodule CompilerTest do assert_receive {%Nx.Tensor{}, :from_relu} assert_receive {%Nx.Tensor{}, :from_sigmoid} end + + test "can be overriden at jit-time with layer name", config do + model = + Axon.input("input_0", shape: {nil, 1}) + |> Axon.attach_hook(fn x -> send(config.test, {x, :from_input}) end, on: :forward) + |> Axon.relu() + + inp = Nx.tensor([[1.0]]) + {_, predict_fn} = Axon.build(model) + + hook = fn val -> send(config.test, {val, :overridden}) end + + fun = Nx.Defn.jit(predict_fn, hooks: %{input_0: hook}) + apply(fun, [ModelState.empty(), inp]) + + assert_receive {from_inp, :overridden} + refute_receive {_, :from_input} + + assert_equal(from_inp, inp) + end end describe "integrated models" do