Skip to content

Commit

Permalink
Rename garbage collect option
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Dec 12, 2023
1 parent 52be5dc commit 7a2e9bc
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions lib/axon/loop.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1590,7 +1590,7 @@ defmodule Axon.Loop do
functions. JIT compilation must be used for gradient computations. Defaults
to true.
* `:force_garbage_collect?` - whether or not to force garbage collection after
* `:garbage_collect` - whether or not to garbage collect after
each loop iteration. This may prevent OOMs, but it will slow down training.
* `:strict?` - whether or not to compile step functions strictly. If this flag
Expand All @@ -1608,7 +1608,7 @@ defmodule Axon.Loop do
{max_epochs, opts} = Keyword.pop(opts, :epochs, 1)
{max_iterations, opts} = Keyword.pop(opts, :iterations, -1)
{jit_compile?, opts} = Keyword.pop(opts, :jit_compile?, true)
{force_garbage_collection?, opts} = Keyword.pop(opts, :force_garbage_collection?, false)
{garbage_collect, opts} = Keyword.pop(opts, :garbage_collect, false)
{strict?, jit_opts} = Keyword.pop(opts, :strict?, true)
debug? = Keyword.get(jit_opts, :debug, false)

Expand Down Expand Up @@ -1680,8 +1680,8 @@ defmodule Axon.Loop do
batch_fn =
{:non_compiled, build_batch_fn(step_fn, metric_fns), jit_compile?, strict?, jit_opts}

epoch_start..epoch_end//1
|> Enum.reduce_while(
Enum.reduce_while(
epoch_start..epoch_end//1,
{batch_fn, final_metrics_map, state},
fn epoch, {batch_fn, final_metrics_map, loop_state} ->
case fire_event(:epoch_started, handler_fns, loop_state, debug?) do
Expand All @@ -1697,7 +1697,14 @@ defmodule Axon.Loop do
end

{time, status_batch_fn_and_state} =
:timer.tc(&run_epoch/6, [batch_fn, handler_fns, state, data, debug?, force_garbage_collection?])
:timer.tc(&run_epoch/6, [
batch_fn,
handler_fns,
state,
data,
debug?,
garbage_collect
])

if debug? do
Logger.debug("Axon.Loop finished running epoch in #{us_to_ms(time)} ms")
Expand Down Expand Up @@ -1784,7 +1791,7 @@ defmodule Axon.Loop do
end
end

defp run_epoch(batch_fn, handler_fns, loop_state, data, debug?, force_garbage_collection?) do
defp run_epoch(batch_fn, handler_fns, loop_state, data, debug?, garbage_collect) do
Enum.reduce_while(data, {:continue, batch_fn, loop_state}, fn data, {_, batch_fn, state} ->
case fire_event(:iteration_started, handler_fns, state, debug?) do
{:halt_epoch, state} ->
Expand Down Expand Up @@ -1841,7 +1848,7 @@ defmodule Axon.Loop do
{:halt, {:halt_loop, batch_fn, state}}

{:continue, state} ->
if force_garbage_collection? do
if garbage_collect do
:erlang.garbage_collect()
end

Expand Down

0 comments on commit 7a2e9bc

Please sign in to comment.