Skip to content

Commit

Permalink
Force garbage collection option (#551)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 authored Dec 11, 2023
1 parent 67b48c7 commit 52be5dc
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions lib/axon/loop.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1590,6 +1590,9 @@ defmodule Axon.Loop do
functions. JIT compilation must be used for gradient computations. Defaults
to true.
* `:force_garbage_collect?` - whether or not to force garbage collection after
each loop iteration. This may prevent OOMs, but it will slow down training.
* `:strict?` - whether or not to compile step functions strictly. If this flag
is set, the loop will raise on any cache miss during the training loop. Defaults
to true.
Expand All @@ -1605,6 +1608,7 @@ defmodule Axon.Loop do
{max_epochs, opts} = Keyword.pop(opts, :epochs, 1)
{max_iterations, opts} = Keyword.pop(opts, :iterations, -1)
{jit_compile?, opts} = Keyword.pop(opts, :jit_compile?, true)
{force_garbage_collection?, opts} = Keyword.pop(opts, :force_garbage_collection?, false)
{strict?, jit_opts} = Keyword.pop(opts, :strict?, true)
debug? = Keyword.get(jit_opts, :debug, false)

Expand Down Expand Up @@ -1693,7 +1697,7 @@ defmodule Axon.Loop do
end

{time, status_batch_fn_and_state} =
:timer.tc(&run_epoch/5, [batch_fn, handler_fns, state, data, debug?])
:timer.tc(&run_epoch/6, [batch_fn, handler_fns, state, data, debug?, force_garbage_collection?])

if debug? do
Logger.debug("Axon.Loop finished running epoch in #{us_to_ms(time)} ms")
Expand Down Expand Up @@ -1780,7 +1784,7 @@ defmodule Axon.Loop do
end
end

defp run_epoch(batch_fn, handler_fns, loop_state, data, debug?) do
defp run_epoch(batch_fn, handler_fns, loop_state, data, debug?, force_garbage_collection?) do
Enum.reduce_while(data, {:continue, batch_fn, loop_state}, fn data, {_, batch_fn, state} ->
case fire_event(:iteration_started, handler_fns, state, debug?) do
{:halt_epoch, state} ->
Expand Down Expand Up @@ -1837,6 +1841,10 @@ defmodule Axon.Loop do
{:halt, {:halt_loop, batch_fn, state}}

{:continue, state} ->
if force_garbage_collection? do
:erlang.garbage_collect()
end

state = %{state | iteration: iters + 1}

if max_iterations_reached?(max_iters, iters) do
Expand Down

0 comments on commit 52be5dc

Please sign in to comment.