diff --git a/neural_testbed/logging/base.py b/neural_testbed/logging/base.py index faf4355..773726c 100644 --- a/neural_testbed/logging/base.py +++ b/neural_testbed/logging/base.py @@ -88,10 +88,10 @@ def problem(self) -> testbed_base.TestbedProblem: def clean_results(results: Dict[str, Any]) -> Dict[str, Any]: - """Cleans the results for logging (can't log jax arrays).""" + """Cleans the results for logging.""" def clean_result(value: Any) -> Any: value = loggers.to_numpy(value) - if isinstance(value, chex.ArrayNumpy) and value.size == 1: + if isinstance(value, chex.Array) and value.size == 1: value = float(value) return value