diff --git a/scico/flax/train/input_pipeline.py b/scico/flax/train/input_pipeline.py index abb966233..757ac7a5e 100644 --- a/scico/flax/train/input_pipeline.py +++ b/scico/flax/train/input_pipeline.py @@ -26,7 +26,7 @@ from .typed_dict import DataSetDict DType = Any -KeyArray = Union[Array, jax.random.PRNGKeyArray] +KeyArray = Union[Array, jax.Array] class IterateData: diff --git a/scico/flax/train/state.py b/scico/flax/train/state.py index 21dab6ec0..9c6952cb4 100644 --- a/scico/flax/train/state.py +++ b/scico/flax/train/state.py @@ -21,7 +21,7 @@ from .typed_dict import ConfigDict, ModelVarDict ModuleDef = Any -KeyArray = Union[Array, jax.random.PRNGKeyArray] +KeyArray = Union[Array, jax.Array] PyTree = Any ArrayTree = optax.Params diff --git a/scico/flax/train/steps.py b/scico/flax/train/steps.py index 8b3df81ac..8901e1881 100644 --- a/scico/flax/train/steps.py +++ b/scico/flax/train/steps.py @@ -19,7 +19,7 @@ from .state import TrainState from .typed_dict import DataSetDict, MetricsDict -KeyArray = Union[Array, jax.random.PRNGKeyArray] +KeyArray = Union[Array, jax.Array] PyTree = Any diff --git a/scico/flax/train/trainer.py b/scico/flax/train/trainer.py index 8ce144214..93cf97b84 100644 --- a/scico/flax/train/trainer.py +++ b/scico/flax/train/trainer.py @@ -47,10 +47,11 @@ from .typed_dict import ConfigDict, DataSetDict, MetricsDict, ModelVarDict ModuleDef = Any -KeyArray = Union[Array, jax.random.PRNGKeyArray] +KeyArray = Union[Array, jax.Array] PyTree = Any DType = Any + # sync across replicas def sync_batch_stats(state: TrainState) -> TrainState: """Sync the batch statistics across replicas."""