Skip to content

Commit

Permalink
test: ensure Loop.checkpoint reacts to event filter (#520)
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente authored Aug 20, 2023
1 parent 1bf0efe commit 7fc93c3
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 6 deletions.
20 changes: 15 additions & 5 deletions lib/axon/loop.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1255,12 +1255,22 @@ defmodule Axon.Loop do
`checkpoint_\#{epoch}_\#{iteration}.ckpt`.
"""
def checkpoint(%Loop{} = loop, opts \\ []) do
{event, opts} = Keyword.pop(opts, :event, :epoch_completed)
{filter, opts} = Keyword.pop(opts, :filter, :always)
{path, opts} = Keyword.pop(opts, :path, "checkpoint")
{file_pattern, opts} = Keyword.pop(opts, :file_pattern, &default_checkpoint_file/1)
opts =
Keyword.validate!(opts, [
:criteria,
event: :epoch_completed,
filter: :always,
path: "checkpoint",
file_pattern: &default_checkpoint_file/1,
mode: :min
])

{criteria, opts} = Keyword.pop(opts, :criteria)
{mode, serialize_opts} = Keyword.pop(opts, :mode, :min)
{event, opts} = Keyword.pop!(opts, :event)
{filter, opts} = Keyword.pop!(opts, :filter)
{path, opts} = Keyword.pop!(opts, :path)
{file_pattern, opts} = Keyword.pop!(opts, :file_pattern)
{mode, serialize_opts} = Keyword.pop!(opts, :mode)

checkpoint_fun = &checkpoint_impl(&1, path, file_pattern, serialize_opts)

Expand Down
24 changes: 23 additions & 1 deletion test/axon/loop_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,7 @@ defmodule Axon.LoopTest do
[loop: loop]
end

test "saves a ceckpoint on each epoch", %{loop: loop} do
test "saves a checkpoint on each epoch", %{loop: loop} do
loop
|> Loop.checkpoint()
|> Loop.run([{Nx.tensor([[1]]), Nx.tensor([[2]])}], %{}, epochs: 3)
Expand All @@ -787,6 +787,28 @@ defmodule Axon.LoopTest do
File.ls!("checkpoint") |> Enum.sort()
end

test "saves a checkpoint on custom events", %{loop: loop} do
data = List.duplicate({Nx.iota({1, 1}), Nx.iota({1, 1})}, 5)

assert %Axon.Loop.State{epoch: 3, iteration: 0, event_counts: %{iteration_completed: 15}} =
loop
|> Map.put(:output_transform, & &1)
|> Loop.checkpoint(event: :iteration_completed, filter: [every: 2])
|> Loop.run(data, %{}, epochs: 3)

assert [
"checkpoint_0_0.ckpt",
"checkpoint_0_2.ckpt",
"checkpoint_0_4.ckpt",
"checkpoint_1_1.ckpt",
"checkpoint_1_3.ckpt",
"checkpoint_2_0.ckpt",
"checkpoint_2_2.ckpt",
"checkpoint_2_4.ckpt"
] ==
File.ls!("checkpoint") |> Enum.sort()
end

test "uses the custom file_pattern function", %{loop: loop} do
loop
|> Loop.checkpoint(file_pattern: &"ckp_#{&1.epoch}.ckpt")
Expand Down

0 comments on commit 7fc93c3

Please sign in to comment.