Commit 9217310
minor fix
kaixih committed Jul 19, 2024
1 parent cf12a06 commit 9217310
Showing 1 changed file with 8 additions and 7 deletions.
praxis/layers/injection/fp8_nvidia_gpu.py
@@ -41,7 +41,7 @@
 from praxis import pax_fiddle
 from praxis import pytypes
 
-def _get_fp8_args():
+def _get_fp8_args(amax_history_length, mesh_shape):
   OVERWRITE_WITH_GRADIENT = (
       base_layer.WeightHParamsCollection.OVERWRITE_WITH_GRADIENT
   )
@@ -52,15 +52,15 @@ def _get_fp8_args():
       'shape': [1],
       'init': base_layer.WeightInit.Constant(1.0),
       'dtype': jnp.float32,
-      'mesh_shape': self.mesh_shape,
+      'mesh_shape': mesh_shape,
       'tensor_split_dims_mapping': None,
       'collections': [OVERWRITE_WITH_GRADIENT, DISALLOW_BFLOAT16_CONVERSION],
   }
   amax_history_args = {
-      'shape': [self.amax_history_length],
+      'shape': [amax_history_length],
       'init': base_layer.WeightInit.Constant(0.0),
       'dtype': jnp.float32,
-      'mesh_shape': self.mesh_shape,
+      'mesh_shape': mesh_shape,
       'tensor_split_dims_mapping': None,
       'collections': [OVERWRITE_WITH_GRADIENT, DISALLOW_BFLOAT16_CONVERSION],
   }
@@ -72,7 +72,8 @@ class Fp8EinsumOp(base_layer.BaseLayer):
   amax_history_length: int = 1024
 
   def setup(self) -> None:
-    scale_args, amax_history_args = _get_fp8_args()
+    scale_args, amax_history_args = _get_fp8_args(self.amax_history_length,
+                                                  self.mesh_shape)
 
     self.create_variable(
         'input_amax_history', base_layer.WeightHParams(**amax_history_args)
@@ -136,8 +137,8 @@ class Fp8EinsumGatedOp(Fp8EinsumOp):
 
   def setup(self) -> None:
     super().setup()
-
-    scale_args, amax_history_args = _get_fp8_args()
+    scale_args, amax_history_args = _get_fp8_args(self.amax_history_length,
+                                                  self.mesh_shape)
 
     self.create_variable(
         'kernel_amax_history_gated',
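
Context for the fix: _get_fp8_args is a module-level function, so the previous references to self.mesh_shape and self.amax_history_length could not resolve (there is no self inside a free function); the commit threads those values through as explicit parameters instead. The minimal sketch below, using hypothetical names rather than the actual praxis code, illustrates the same pattern:

def _get_args(amax_history_length, mesh_shape):
  # Build the kwargs dicts from explicitly passed values. Referencing
  # self.mesh_shape here would raise a NameError, since self is undefined
  # inside a module-level function.
  scale_args = {'shape': [1], 'mesh_shape': mesh_shape}
  amax_history_args = {'shape': [amax_history_length], 'mesh_shape': mesh_shape}
  return scale_args, amax_history_args

class EinsumLayer:
  # Hypothetical stand-in for Fp8EinsumOp; only the attributes the helper
  # needs are modeled here.
  amax_history_length = 1024
  mesh_shape = None

  def setup(self):
    # The layer forwards its own attributes to the free helper, mirroring
    # the updated _get_fp8_args call sites in the diff above.
    return _get_args(self.amax_history_length, self.mesh_shape)

The same call appears in both Fp8EinsumOp.setup and Fp8EinsumGatedOp.setup, which is why the commit updates two call sites.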
