diff --git a/praxis/layers/injection/fp8_nvidia_gpu.py b/praxis/layers/injection/fp8_nvidia_gpu.py
index cf8bff41..9e8480d2 100644
--- a/praxis/layers/injection/fp8_nvidia_gpu.py
+++ b/praxis/layers/injection/fp8_nvidia_gpu.py
@@ -41,7 +41,7 @@ from praxis import pax_fiddle
 from praxis import pytypes
 
 
-def _get_fp8_args():
+def _get_fp8_args(amax_history_length, mesh_shape):
   OVERWRITE_WITH_GRADIENT = (
       base_layer.WeightHParamsCollection.OVERWRITE_WITH_GRADIENT
   )
@@ -52,15 +52,15 @@ def _get_fp8_args():
       'shape': [1],
       'init': base_layer.WeightInit.Constant(1.0),
       'dtype': jnp.float32,
-      'mesh_shape': self.mesh_shape,
+      'mesh_shape': mesh_shape,
       'tensor_split_dims_mapping': None,
       'collections': [OVERWRITE_WITH_GRADIENT, DISALLOW_BFLOAT16_CONVERSION],
   }
   amax_history_args = {
-      'shape': [self.amax_history_length],
+      'shape': [amax_history_length],
       'init': base_layer.WeightInit.Constant(0.0),
       'dtype': jnp.float32,
-      'mesh_shape': self.mesh_shape,
+      'mesh_shape': mesh_shape,
       'tensor_split_dims_mapping': None,
       'collections': [OVERWRITE_WITH_GRADIENT, DISALLOW_BFLOAT16_CONVERSION],
   }
@@ -72,7 +72,8 @@ class Fp8EinsumOp(base_layer.BaseLayer):
   amax_history_length: int = 1024
 
   def setup(self) -> None:
-    scale_args, amax_history_args = _get_fp8_args()
+    scale_args, amax_history_args = _get_fp8_args(self.amax_history_length,
+                                                  self.mesh_shape)
 
     self.create_variable(
         'input_amax_history', base_layer.WeightHParams(**amax_history_args)
@@ -136,8 +137,8 @@ class Fp8EinsumGatedOp(Fp8EinsumOp):
 
   def setup(self) -> None:
     super().setup()
-
-    scale_args, amax_history_args = _get_fp8_args()
+    scale_args, amax_history_args = _get_fp8_args(self.amax_history_length,
+                                                  self.mesh_shape)
 
     self.create_variable(
         'kernel_amax_history_gated',
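
Context for the change: `_get_fp8_args` is a module-level function, so the old references to `self.mesh_shape` and `self.amax_history_length` would raise a `NameError` when the helper is called. The diff threads both values through as explicit parameters from the layers' `setup()` methods instead. Below is a minimal, self-contained sketch of the same pattern; the names `make_weight_args` and `FakeLayer` are hypothetical stand-ins for illustration only, not the real praxis classes or helpers.

```python
# Minimal sketch of the refactor pattern (hypothetical names, not praxis API).


def make_weight_args(amax_history_length, mesh_shape):
  # Module-level helper: there is no `self` here, so every value it needs
  # must be passed in explicitly by the caller.
  scale_args = {'shape': [1], 'mesh_shape': mesh_shape}
  amax_history_args = {
      'shape': [amax_history_length],
      'mesh_shape': mesh_shape,
  }
  return scale_args, amax_history_args


class FakeLayer:
  amax_history_length = 1024
  mesh_shape = None

  def setup(self):
    # The layer forwards its own attributes to the helper, mirroring the
    # `_get_fp8_args(self.amax_history_length, self.mesh_shape)` calls above.
    return make_weight_args(self.amax_history_length, self.mesh_shape)


print(FakeLayer().setup())
# ({'shape': [1], 'mesh_shape': None}, {'shape': [1024], 'mesh_shape': None})
```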