diff --git a/praxis/layers/base_ops.py b/praxis/layers/base_ops.py
index 63bc99c1..b3e81347 100644
--- a/praxis/layers/base_ops.py
+++ b/praxis/layers/base_ops.py
@@ -42,6 +42,17 @@ def __call__(self, equation: str, *args: JTensor) -> JTensor:
     return jnp.einsum(equation, *args)
 
 
+class EinsumGatedOp(base_layer.BaseLayer):
+  """Wrapper around two jnp.einsum for gated FFN."""
+
+  def __call__(self, equation: str, *args: JTensor) -> JTensor:
+    assert len(args) == 3
+    x, k, k_gated = args
+    y = jnp.einsum(equation, x, k)
+    y_gated = jnp.einsum(equation, x, k_gated)
+    return y, y_gated
+
+
 class ArrayLookup(base_layer.BaseLayer):
   """Wrapper around array indexing as used in embedding lookup."""
 
diff --git a/praxis/layers/grok.py b/praxis/layers/grok.py
index 95507b11..f66d143f 100644
--- a/praxis/layers/grok.py
+++ b/praxis/layers/grok.py
@@ -185,6 +185,9 @@ def GrokStackedTransformerHParams(
         fp8_ops.Fp8EinsumOp
     )
     p.moe_layer_tpl.einsum_tpl = pax_fiddle.Config(fp8_ops.Fp8EinsumOp)
+    p.moe_layer_tpl.einsum_gated_tpl = pax_fiddle.Config(
+        fp8_ops.Fp8EinsumGatedOp
+    )
   return p
 
 
diff --git a/praxis/layers/injection/fp8_nvidia_gpu.py b/praxis/layers/injection/fp8_nvidia_gpu.py
index 8af92382..cf8bff41 100644
--- a/praxis/layers/injection/fp8_nvidia_gpu.py
+++ b/praxis/layers/injection/fp8_nvidia_gpu.py
@@ -41,6 +41,30 @@
 from praxis import pax_fiddle
 from praxis import pytypes
 
+def _get_fp8_args(amax_history_length, mesh_shape):
+  OVERWRITE_WITH_GRADIENT = (
+      base_layer.WeightHParamsCollection.OVERWRITE_WITH_GRADIENT
+  )
+  DISALLOW_BFLOAT16_CONVERSION = (
+      base_layer.WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION
+  )
+  scale_args = {
+      'shape': [1],
+      'init': base_layer.WeightInit.Constant(1.0),
+      'dtype': jnp.float32,
+      'mesh_shape': mesh_shape,
+      'tensor_split_dims_mapping': None,
+      'collections': [OVERWRITE_WITH_GRADIENT, DISALLOW_BFLOAT16_CONVERSION],
+  }
+  amax_history_args = {
+      'shape': [amax_history_length],
+      'init': base_layer.WeightInit.Constant(0.0),
+      'dtype': jnp.float32,
+      'mesh_shape': mesh_shape,
+      'tensor_split_dims_mapping': None,
+      'collections': [OVERWRITE_WITH_GRADIENT, DISALLOW_BFLOAT16_CONVERSION],
+  }
+  return scale_args, amax_history_args
 
 class Fp8EinsumOp(base_layer.BaseLayer):
   """Wrapper around jnp.einsum used in standard Pax layers."""
@@ -48,28 +72,8 @@ class Fp8EinsumOp(base_layer.BaseLayer):
   amax_history_length: int = 1024
 
   def setup(self) -> None:
-    OVERWRITE_WITH_GRADIENT = (
-        base_layer.WeightHParamsCollection.OVERWRITE_WITH_GRADIENT
-    )
-    DISALLOW_BFLOAT16_CONVERSION = (
-        base_layer.WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION
-    )
-    scale_args = {
-        'shape': [1],
-        'init': base_layer.WeightInit.Constant(1.0),
-        'dtype': jnp.float32,
-        'mesh_shape': self.mesh_shape,
-        'tensor_split_dims_mapping': None,
-        'collections': [OVERWRITE_WITH_GRADIENT, DISALLOW_BFLOAT16_CONVERSION],
-    }
-    amax_history_args = {
-        'shape': [self.amax_history_length],
-        'init': base_layer.WeightInit.Constant(0.0),
-        'dtype': jnp.float32,
-        'mesh_shape': self.mesh_shape,
-        'tensor_split_dims_mapping': None,
-        'collections': [OVERWRITE_WITH_GRADIENT, DISALLOW_BFLOAT16_CONVERSION],
-    }
+    scale_args, amax_history_args = _get_fp8_args(
+        self.amax_history_length, self.mesh_shape)
     self.create_variable(
         'input_amax_history', base_layer.WeightHParams(**amax_history_args)
     )
@@ -126,3 +130,84 @@ def __call__(self, equation: str, *args: pytypes.JTensor) -> pytypes.JTensor:
     )
 
     return y
+
+class Fp8EinsumGatedOp(Fp8EinsumOp):
+  """Wrapper around two jnp.einsum for gated FFN."""
+
+  def setup(self) -> None:
+    super().setup()
+
+    scale_args, amax_history_args = _get_fp8_args(
+        self.amax_history_length, self.mesh_shape)
+    self.create_variable(
+        'kernel_amax_history_gated',
+        base_layer.WeightHParams(**amax_history_args),
+    )
+    self.create_variable(
+        'output_grad_amax_history_gated',
+        base_layer.WeightHParams(**amax_history_args),
+    )
+
+    self.create_variable(
+        'kernel_scale_gated', base_layer.WeightHParams(**scale_args)
+    )
+    self.create_variable(
+        'output_grad_scale_gated', base_layer.WeightHParams(**scale_args)
+    )
+
+  def __call__(self, equation: str, *args: pytypes.JTensor) -> pytypes.JTensor:
+    assert len(args) == 3
+    x, k, k_gated = args
+
+    comp_dtype = self.fprop_dtype
+    assert (
+        k.dtype == k_gated.dtype == comp_dtype
+    ), f'k dtype has to be {comp_dtype}, but got {k.dtype} and {k_gated.dtype}'
+    x = jnp.asarray(x, comp_dtype)
+
+    theta = self.theta
+
+    x_qdq = fp8_ops.in_qdq(
+        comp_dtype,
+        jnp.float8_e4m3fn,
+        x,
+        theta.input_scale,
+        theta.input_amax_history,
+    )
+    k_qdq = fp8_ops.in_qdq(
+        comp_dtype,
+        jnp.float8_e4m3fn,
+        k,
+        theta.kernel_scale,
+        theta.kernel_amax_history,
+    )
+    k_gated_qdq = fp8_ops.in_qdq(
+        comp_dtype,
+        jnp.float8_e4m3fn,
+        k_gated,
+        theta.kernel_scale_gated,
+        theta.kernel_amax_history_gated,
+    )
+    y_qdq = jnp.einsum(
+        equation, x_qdq, k_qdq, _dot_general=fp8_ops.dot_general_with_precision
+    )
+    y_gated_qdq = jnp.einsum(
+        equation, x_qdq, k_gated_qdq,
+        _dot_general=fp8_ops.dot_general_with_precision
+    )
+    y = fp8_ops.out_qdq(
+        comp_dtype,
+        jnp.float8_e5m2,
+        y_qdq,
+        theta.output_grad_scale,
+        theta.output_grad_amax_history,
+    )
+    y_gated = fp8_ops.out_qdq(
+        comp_dtype,
+        jnp.float8_e5m2,
+        y_gated_qdq,
+        theta.output_grad_scale_gated,
+        theta.output_grad_amax_history_gated,
+    )
+
+    return y, y_gated
diff --git a/praxis/layers/transformers.py b/praxis/layers/transformers.py
index 099d4510..74ed8b8d 100644
--- a/praxis/layers/transformers.py
+++ b/praxis/layers/transformers.py
@@ -662,6 +662,7 @@ class TransformerFeedForwardMoe(base_layer.BaseLayer):
   moe_gating_embedding_level: str = 'token'
   use_gated_activation: bool = False
   einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp)
+  einsum_gated_tpl: LayerTpl = template_field(base_ops.EinsumGatedOp)
 
   # SPMD partition related params.
   # M - model_dim, for both inputs and outputs
@@ -827,8 +828,7 @@ def setup(self) -> None:
     self.create_variable('wo_0', wo_pc)
     self.create_child('dispatch_einsum', self.einsum_tpl.clone())
     if self._is_ffn1_gated:
-      self.create_child('gated_ffn1_hidden0_einsum', self.einsum_tpl.clone())
-      self.create_child('gated_ffn1_hidden1_einsum', self.einsum_tpl.clone())
+      self.create_child('gated_ffn1_einsum', self.einsum_gated_tpl.clone())
     self.create_child('ffn2_einsum', self.einsum_tpl.clone())
     self.create_child('combine_einsum', self.einsum_tpl.clone())
 
@@ -1050,11 +1050,8 @@ def _dispatch_and_combine_expert_outputs(self, inputs, paddings, segment_ids):
     expert_inputs = self._split(expert_inputs, ap.egcm)
 
     if self._is_ffn1_gated:
-      hidden0 = self.gated_ffn1_hidden0_einsum(
-          'egcm,emh->egch', expert_inputs, theta_wi
-      )
-      hidden1 = self.gated_ffn1_hidden1_einsum(
-          'egcm,emh->egch', expert_inputs, theta_wi_gated
+      hidden0, hidden1 = self.gated_ffn1_einsum(
+          'egcm,emh->egch', expert_inputs, theta_wi, theta_wi_gated
       )
       if self.gating_func in ['top2', 'expert_choice_v2']:
         self._count_dead_neurons(hidden1, dispatch_tensor)
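
The patch replaces the two separate gated-FFN einsum children with a single gated einsum op, so that the FP8 variant can reuse one quantized copy of the activation while keeping separate scale/amax-history state for the gated kernel. Below is a minimal, Praxis-free sketch of what the new op computes for the MoE `'egcm,emh->egch'` contraction; the helper name `gated_einsum` and the toy shapes are illustrative only and are not part of the patch.

```python
# Sketch only (assumed names/shapes): mirrors EinsumGatedOp semantics in plain JAX.
import jax
import jax.numpy as jnp


def gated_einsum(equation, x, k, k_gated):
  # Same contraction applied twice with a shared left operand: one einsum for
  # the hidden projection and one for the gate, returned as a pair.
  y = jnp.einsum(equation, x, k)
  y_gated = jnp.einsum(equation, x, k_gated)
  return y, y_gated


# Toy MoE-style shapes: e=experts, g=groups, c=capacity, m=model, h=hidden.
e, g, c, m, h = 2, 1, 4, 8, 16
kx, kw, kg = jax.random.split(jax.random.PRNGKey(0), 3)
expert_inputs = jax.random.normal(kx, (e, g, c, m))
theta_wi = jax.random.normal(kw, (e, m, h))
theta_wi_gated = jax.random.normal(kg, (e, m, h))

hidden0, hidden1 = gated_einsum(
    'egcm,emh->egch', expert_inputs, theta_wi, theta_wi_gated)
# The caller then applies the activation to one branch and multiplies it
# elementwise with the other, as in gated FFN variants.
print(hidden0.shape, hidden1.shape)  # (2, 1, 4, 16) (2, 1, 4, 16)
```

In the FP8 version, fusing the two einsums behind one layer is what lets `Fp8EinsumGatedOp` quantize `expert_inputs` once (`x_qdq`) and attach dedicated `*_gated` scale and amax-history variables to the second kernel and its output gradient.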