diff --git a/praxis/layers/injection/fp8_nvidia_gpu.py b/praxis/layers/injection/fp8_nvidia_gpu.py
index 9e8480d2..c93964cd 100644
--- a/praxis/layers/injection/fp8_nvidia_gpu.py
+++ b/praxis/layers/injection/fp8_nvidia_gpu.py
@@ -41,6 +41,8 @@
 from praxis import pax_fiddle
 from praxis import pytypes
 
+JTensor = pytypes.JTensor
+
 def _get_fp8_args(amax_history_length, mesh_shape):
   OVERWRITE_WITH_GRADIENT = (
       base_layer.WeightHParamsCollection.OVERWRITE_WITH_GRADIENT
@@ -92,18 +94,11 @@ def setup(self) -> None:
         'output_grad_scale', base_layer.WeightHParams(**scale_args)
     )
 
-  def __call__(self, equation: str, *args: pytypes.JTensor) -> pytypes.JTensor:
-    assert len(args) == 2
-    x = args[0]
-    k = args[1]
-
-    comp_dtype = self.fprop_dtype
-    assert (
-        k.dtype == comp_dtype
-    ), f'k dtype has to be {comp_dtype}, but got {k.dtype}'
-    x = jnp.asarray(x, comp_dtype)
-
+  def quantized_einsum(
+      self, equation: str, x: JTensor, k: JTensor, return_quantized_x: bool
+  ) -> JTensor | tuple[JTensor, JTensor]:
     theta = self.theta
+    comp_dtype = self.fprop_dtype
 
     x_qdq = fp8_ops.in_qdq(
         comp_dtype,
@@ -130,6 +125,23 @@ def __call__(self, equation: str, *args: pytypes.JTensor) -> pytypes.JTensor:
         theta.output_grad_amax_history,
     )
 
+    if return_quantized_x:
+      return y, x_qdq
+    return y
+
+  def __call__(self, equation: str, *args: JTensor) -> JTensor:
+    assert len(args) == 2
+    x = args[0]
+    k = args[1]
+
+    comp_dtype = self.fprop_dtype
+    assert (
+        k.dtype == comp_dtype
+    ), f'k dtype has to be {comp_dtype}, but got {k.dtype}'
+    x = jnp.asarray(x, comp_dtype)
+
+    y = self.quantized_einsum(equation, x, k, return_quantized_x=False)
+
     return y
 
 class Fp8EinsumGatedOp(Fp8EinsumOp):
@@ -156,7 +168,7 @@ def setup(self) -> None:
         'output_grad_scale_gated', base_layer.WeightHParams(**scale_args)
     )
 
-  def __call__(self, equation: str, *args: pytypes.JTensor) -> pytypes.JTensor:
+  def __call__(self, equation: str, *args: JTensor) -> tuple[JTensor, JTensor]:
     assert len(args) == 3
     x, k, k_gated = args
 
@@ -166,22 +178,10 @@ def __call__(self, equation: str, *args: pytypes.JTensor) -> pytypes.JTensor:
     ), f'k dtype has to be {comp_dtype}, but got {k.dtype} and {k_gated.dtype}'
     x = jnp.asarray(x, comp_dtype)
 
+    y, x_qdq = self.quantized_einsum(equation, x, k, return_quantized_x=True)
+
     theta = self.theta
 
-    x_qdq = fp8_ops.in_qdq(
-        comp_dtype,
-        jnp.float8_e4m3fn,
-        x,
-        theta.input_scale,
-        theta.input_amax_history,
-    )
-    k_qdq = fp8_ops.in_qdq(
-        comp_dtype,
-        jnp.float8_e4m3fn,
-        k,
-        theta.kernel_scale,
-        theta.kernel_amax_history,
-    )
     k_gated_qdq = fp8_ops.in_qdq(
         comp_dtype,
         jnp.float8_e4m3fn,
@@ -189,20 +189,10 @@ def __call__(self, equation: str, *args: pytypes.JTensor) -> pytypes.JTensor:
         theta.kernel_scale_gated,
         theta.kernel_amax_history_gated,
     )
-    y_qdq = jnp.einsum(
-        equation, x_qdq, k_qdq, _dot_general=fp8_ops.dot_general_with_precision
-    )
     y_gated_qdq = jnp.einsum(
         equation, x_qdq, k_gated_qdq,
         _dot_general=fp8_ops.dot_general_with_precision
     )
-    y = fp8_ops.out_qdq(
-        comp_dtype,
-        jnp.float8_e5m2,
-        y_qdq,
-        theta.output_grad_scale,
-        theta.output_grad_amax_history,
-    )
     y_gated = fp8_ops.out_qdq(
         comp_dtype,
         jnp.float8_e5m2,
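The change above factors the shared quantize-dequantize (QDQ) einsum path into `quantized_einsum`, so `Fp8EinsumGatedOp.__call__` only adds the extra `k_gated` branch. Below is a minimal standalone sketch of that QDQ pattern, assuming made-up helper names (`qdq`, `fp8_einsum_qdq`) and omitting the per-tensor scales, amax histories, and custom-VJP gradient handling that `fp8_ops.in_qdq` / `fp8_ops.out_qdq` provide in the real layer.

# Illustrative sketch only; `qdq` and `fp8_einsum_qdq` are hypothetical helpers,
# not praxis APIs, and no scaling or amax-history bookkeeping is done here.
import jax
import jax.numpy as jnp


def qdq(t, fp8_dtype):
  # Quantize-dequantize: round-trip through an FP8 dtype so values become
  # FP8-representable while the computation stays in the original dtype.
  return t.astype(fp8_dtype).astype(t.dtype)


def fp8_einsum_qdq(equation, x, k):
  # Activations and kernels use e4m3 (more mantissa bits); the real code
  # reserves e5m2 (more range) for the output-gradient QDQ in a custom VJP.
  x_q = qdq(x, jnp.float8_e4m3fn)
  k_q = qdq(k, jnp.float8_e4m3fn)
  return jnp.einsum(equation, x_q, k_q)


x = jax.random.normal(jax.random.PRNGKey(0), (4, 8), dtype=jnp.bfloat16)
k = jax.random.normal(jax.random.PRNGKey(1), (8, 16), dtype=jnp.bfloat16)
print(fp8_einsum_qdq('bd,df->bf', x, k).shape)  # (4, 16)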