From 7a2e9bcbf63583400e6b69d9ecf99365d5e9ce0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Tue, 12 Dec 2023 09:33:45 +0100 Subject: [PATCH] Rename garbage collect option --- lib/axon/loop.ex | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/lib/axon/loop.ex b/lib/axon/loop.ex index 79d3ec14..5b383bf6 100644 --- a/lib/axon/loop.ex +++ b/lib/axon/loop.ex @@ -1590,7 +1590,7 @@ defmodule Axon.Loop do functions. JIT compilation must be used for gradient computations. Defaults to true. - * `:force_garbage_collect?` - whether or not to force garbage collection after + * `:garbage_collect` - whether or not to garbage collect after each loop iteration. This may prevent OOMs, but it will slow down training. * `:strict?` - whether or not to compile step functions strictly. If this flag @@ -1608,7 +1608,7 @@ defmodule Axon.Loop do {max_epochs, opts} = Keyword.pop(opts, :epochs, 1) {max_iterations, opts} = Keyword.pop(opts, :iterations, -1) {jit_compile?, opts} = Keyword.pop(opts, :jit_compile?, true) - {force_garbage_collection?, opts} = Keyword.pop(opts, :force_garbage_collection?, false) + {garbage_collect, opts} = Keyword.pop(opts, :garbage_collect, false) {strict?, jit_opts} = Keyword.pop(opts, :strict?, true) debug? = Keyword.get(jit_opts, :debug, false) @@ -1680,8 +1680,8 @@ defmodule Axon.Loop do batch_fn = {:non_compiled, build_batch_fn(step_fn, metric_fns), jit_compile?, strict?, jit_opts} - epoch_start..epoch_end//1 - |> Enum.reduce_while( + Enum.reduce_while( + epoch_start..epoch_end//1, {batch_fn, final_metrics_map, state}, fn epoch, {batch_fn, final_metrics_map, loop_state} -> case fire_event(:epoch_started, handler_fns, loop_state, debug?) do @@ -1697,7 +1697,14 @@ defmodule Axon.Loop do end {time, status_batch_fn_and_state} = - :timer.tc(&run_epoch/6, [batch_fn, handler_fns, state, data, debug?, force_garbage_collection?]) + :timer.tc(&run_epoch/6, [ + batch_fn, + handler_fns, + state, + data, + debug?, + garbage_collect + ]) if debug? do Logger.debug("Axon.Loop finished running epoch in #{us_to_ms(time)} ms") @@ -1784,7 +1791,7 @@ defmodule Axon.Loop do end end - defp run_epoch(batch_fn, handler_fns, loop_state, data, debug?, force_garbage_collection?) do + defp run_epoch(batch_fn, handler_fns, loop_state, data, debug?, garbage_collect) do Enum.reduce_while(data, {:continue, batch_fn, loop_state}, fn data, {_, batch_fn, state} -> case fire_event(:iteration_started, handler_fns, state, debug?) do {:halt_epoch, state} -> @@ -1841,7 +1848,7 @@ defmodule Axon.Loop do {:halt, {:halt_loop, batch_fn, state}} {:continue, state} -> - if force_garbage_collection? do + if garbage_collect do :erlang.garbage_collect() end