diff --git a/lib/axon/loop.ex b/lib/axon/loop.ex
index 326f514d..79d3ec14 100644
--- a/lib/axon/loop.ex
+++ b/lib/axon/loop.ex
@@ -1590,6 +1590,9 @@ defmodule Axon.Loop do
       functions. JIT compilation must be used for gradient computations. Defaults
       to true.
 
+    * `:force_garbage_collection?` - whether or not to force garbage collection
+      after each loop iteration. This may prevent OOMs, but it will slow down training.
+
     * `:strict?` - whether or not to compile step functions strictly. If this flag
       is set, the loop will raise on any cache miss during the training loop. Defaults
       to true.
@@ -1605,6 +1608,7 @@
     {max_epochs, opts} = Keyword.pop(opts, :epochs, 1)
     {max_iterations, opts} = Keyword.pop(opts, :iterations, -1)
     {jit_compile?, opts} = Keyword.pop(opts, :jit_compile?, true)
+    {force_garbage_collection?, opts} = Keyword.pop(opts, :force_garbage_collection?, false)
     {strict?, jit_opts} = Keyword.pop(opts, :strict?, true)
 
     debug? = Keyword.get(jit_opts, :debug, false)
@@ -1693,7 +1697,7 @@
       end
 
       {time, status_batch_fn_and_state} =
-        :timer.tc(&run_epoch/5, [batch_fn, handler_fns, state, data, debug?])
+        :timer.tc(&run_epoch/6, [batch_fn, handler_fns, state, data, debug?, force_garbage_collection?])
 
       if debug? do
         Logger.debug("Axon.Loop finished running epoch in #{us_to_ms(time)} ms")
@@ -1780,7 +1784,7 @@
     end
   end
 
-  defp run_epoch(batch_fn, handler_fns, loop_state, data, debug?) do
+  defp run_epoch(batch_fn, handler_fns, loop_state, data, debug?, force_garbage_collection?) do
    Enum.reduce_while(data, {:continue, batch_fn, loop_state}, fn data, {_, batch_fn, state} ->
      case fire_event(:iteration_started, handler_fns, state, debug?) do
        {:halt_epoch, state} ->
@@ -1837,6 +1841,10 @@
               {:halt, {:halt_loop, batch_fn, state}}
 
             {:continue, state} ->
+              if force_garbage_collection? do
+                :erlang.garbage_collect()
+              end
+
               state = %{state | iteration: iters + 1}
 
               if max_iterations_reached?(max_iters, iters) do
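For reference, a minimal sketch of how the new option would be enabled from user code. The model, loss, optimizer, and `train_data` below are placeholder assumptions; only the `:force_garbage_collection?` flag comes from this patch:

```elixir
# Hypothetical training script: a standard Axon training loop with the
# new option turned on.
model
|> Axon.Loop.trainer(:mean_squared_error, :adam)
|> Axon.Loop.run(train_data, %{},
  epochs: 5,
  # Defaults to false; trades throughput for lower peak memory by forcing
  # a collection of the loop process after every iteration.
  force_garbage_collection?: true
)
```

Note that `:erlang.garbage_collect/0` with no arguments only collects the calling process, i.e. the process running the loop; the intent here is presumably to drop lingering tensor references promptly so the backend can release the associated device memory sooner.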