Commit
Merge pull request #80 from kaixih:fix_fp8_einsum
PiperOrigin-RevId: 655615204
pax authors committed Jul 24, 2024
2 parents 5b5ed76 + 203e9f7 commit da4fe8d
Showing 4 changed files with 145 additions and 42 deletions.
11 changes: 11 additions & 0 deletions praxis/layers/base_ops.py
@@ -42,6 +42,17 @@ def __call__(self, equation: str, *args: JTensor) -> JTensor:
return jnp.einsum(equation, *args)


class EinsumGatedOp(base_layer.BaseLayer):
"""Wrapper around two jnp.einsum for gated FFN."""

def __call__(self, equation: str, *args: JTensor) -> tuple[JTensor, JTensor]:
assert len(args) == 3
x, k, k_gated = args
y = jnp.einsum(equation, x, k)
y_gated = jnp.einsum(equation, x, k_gated)
return y, y_gated


class ArrayLookup(base_layer.BaseLayer):
"""Wrapper around array indexing as used in embedding lookup."""

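For orientation, here is a minimal sketch of the gated-einsum pattern the new EinsumGatedOp wraps, written with plain jnp.einsum outside the Praxis layer machinery; the equation and shapes are illustrative assumptions, not taken from this commit.

# Illustrative sketch (not from the diff): two einsums sharing one activation,
# as in EinsumGatedOp above. Shapes and the equation are made up.
import jax.numpy as jnp

def gated_einsum(equation, x, k, k_gated):
  # Both projections consume the same activation x.
  y = jnp.einsum(equation, x, k)
  y_gated = jnp.einsum(equation, x, k_gated)
  return y, y_gated

x = jnp.ones((2, 4, 8))      # (groups, capacity, model_dim), illustrative
k = jnp.ones((8, 16))        # FFN1 kernel
k_gated = jnp.ones((8, 16))  # gating kernel
y, y_gated = gated_einsum('gcm,mh->gch', x, k, k_gated)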
6 changes: 3 additions & 3 deletions praxis/layers/grok.py
@@ -185,9 +185,9 @@ def GrokStackedTransformerHParams(
fp8_ops.Fp8EinsumOp
)
p.moe_layer_tpl.einsum_tpl = pax_fiddle.Config(fp8_ops.Fp8EinsumOp)
# p.moe_layer_tpl.einsum_tpl = (
# pax_fiddle.Config(fp8_ops.Fp8EinsumOp)
# )
p.moe_layer_tpl.einsum_gated_tpl = pax_fiddle.Config(
fp8_ops.Fp8EinsumGatedOp
)
return p


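As a usage note, here is a sketch of how another config could opt in to the same FP8 wrappers. The import alias is an assumption inferred from this commit's file layout, and stacked_p stands for a hypothetical StackedTransformer template that exposes moe_layer_tpl.

# Assumed usage sketch; `stacked_p` is hypothetical and the import alias is
# inferred from praxis/layers/injection/fp8_nvidia_gpu.py in this commit.
from praxis import pax_fiddle
from praxis.layers.injection import fp8_nvidia_gpu as fp8_ops

stacked_p.moe_layer_tpl.einsum_tpl = pax_fiddle.Config(fp8_ops.Fp8EinsumOp)
stacked_p.moe_layer_tpl.einsum_gated_tpl = pax_fiddle.Config(
    fp8_ops.Fp8EinsumGatedOp
)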
152 changes: 119 additions & 33 deletions praxis/layers/injection/fp8_nvidia_gpu.py
@@ -31,6 +31,7 @@
"""Op wrappers to support FP8 GEMMs."""

from functools import partial
from typing import Union

from flax.linen import fp8_ops
from jax import custom_vjp
@@ -41,35 +41,45 @@
from praxis import pax_fiddle
from praxis import pytypes

JTensor = pytypes.JTensor


def _get_fp8_args(amax_history_length, mesh_shape):
OVERWRITE_WITH_GRADIENT = (
base_layer.WeightHParamsCollection.OVERWRITE_WITH_GRADIENT
)
DISALLOW_BFLOAT16_CONVERSION = (
base_layer.WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION
)
scale_args = {
'shape': [1],
'init': base_layer.WeightInit.Constant(1.0),
'dtype': jnp.float32,
'mesh_shape': mesh_shape,
'tensor_split_dims_mapping': None,
'collections': [OVERWRITE_WITH_GRADIENT, DISALLOW_BFLOAT16_CONVERSION],
}
amax_history_args = {
'shape': [amax_history_length],
'init': base_layer.WeightInit.Constant(0.0),
'dtype': jnp.float32,
'mesh_shape': mesh_shape,
'tensor_split_dims_mapping': None,
'collections': [OVERWRITE_WITH_GRADIENT, DISALLOW_BFLOAT16_CONVERSION],
}
return scale_args, amax_history_args


class Fp8EinsumOp(base_layer.BaseLayer):
"""Wrapper around jnp.einsum used in standard Pax layers."""

amax_history_length: int = 1024

def setup(self) -> None:
OVERWRITE_WITH_GRADIENT = (
base_layer.WeightHParamsCollection.OVERWRITE_WITH_GRADIENT
)
DISALLOW_BFLOAT16_CONVERSION = (
base_layer.WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION
)
scale_args = {
'shape': [1],
'init': base_layer.WeightInit.Constant(1.0),
'dtype': jnp.float32,
'mesh_shape': self.mesh_shape,
'tensor_split_dims_mapping': None,
'collections': [OVERWRITE_WITH_GRADIENT, DISALLOW_BFLOAT16_CONVERSION],
}
amax_history_args = {
'shape': [self.amax_history_length],
'init': base_layer.WeightInit.Constant(0.0),
'dtype': jnp.float32,
'mesh_shape': self.mesh_shape,
'tensor_split_dims_mapping': None,
'collections': [OVERWRITE_WITH_GRADIENT, DISALLOW_BFLOAT16_CONVERSION],
}
scale_args, amax_history_args = _get_fp8_args(
self.amax_history_length, self.mesh_shape
)

self.create_variable(
'input_amax_history', base_layer.WeightHParams(**amax_history_args)
)
Expand All @@ -87,18 +98,11 @@ def setup(self) -> None:
'output_grad_scale', base_layer.WeightHParams(**scale_args)
)

def __call__(self, equation: str, *args: pytypes.JTensor) -> pytypes.JTensor:
assert len(args) == 2
x = args[0]
k = args[1]

comp_dtype = self.fprop_dtype
assert (
k.dtype == comp_dtype
), f'k dtype has to be {comp_dtype}, but got {k.dtype}'
x = jnp.asarray(x, comp_dtype)

def quantized_einsum(
self, equation: str, x: JTensor, k: JTensor, return_quantized_x: bool
) -> JTensor | tuple[JTensor, JTensor]:
theta = self.theta
comp_dtype = self.fprop_dtype

x_qdq = fp8_ops.in_qdq(
comp_dtype,
@@ -125,4 +129,86 @@ def __call__(self, equation: str, *args: pytypes.JTensor) -> pytypes.JTensor:
theta.output_grad_amax_history,
)

if return_quantized_x:
return y, x_qdq
return y

def __call__(
self, equation: str, *args: JTensor
) -> Union[JTensor, tuple[JTensor, JTensor]]:
assert len(args) == 2
x = args[0]
k = args[1]

comp_dtype = self.fprop_dtype
assert (
k.dtype == comp_dtype
), f'k dtype has to be {comp_dtype}, but got {k.dtype}'
x = jnp.asarray(x, comp_dtype)

y = self.quantized_einsum(equation, x, k, return_quantized_x=False)

return y

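For context on what the scale and amax_history variables are for: the FP8 path uses delayed scaling, where the quantization scale is derived from the running maximum of recent abs-max values rather than from the current tensor alone. Below is a simplified stand-in sketch; the real update logic lives in flax.linen.fp8_ops, and the formulation and constants here are assumptions.

# Simplified delayed-scaling sketch; the formulation is assumed, not copied
# from fp8_ops.
import jax.numpy as jnp

E4M3_MAX = 448.0  # largest finite value of float8_e4m3fn

def update_amax_history(history, x):
  # Shift the window and record the current tensor's abs-max.
  return jnp.roll(history, 1).at[0].set(jnp.max(jnp.abs(x)))

def scale_from_history(history):
  # Map the largest recently seen value near the FP8 maximum.
  return jnp.maximum(jnp.max(history), 1e-12) / E4M3_MAX

history = jnp.zeros(16)                # illustrative amax_history_length
x = jnp.array([0.3, -2.5, 1.7])
history = update_amax_history(history, x)
scale = scale_from_history(history)    # roughly 2.5 / 448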

class Fp8EinsumGatedOp(Fp8EinsumOp):
"""Wrapper around two jnp.einsum for gated FFN."""

def setup(self) -> None:
super().setup()
scale_args, amax_history_args = _get_fp8_args(
self.amax_history_length, self.mesh_shape
)

self.create_variable(
'kernel_amax_history_gated',
base_layer.WeightHParams(**amax_history_args),
)
self.create_variable(
'output_grad_amax_history_gated',
base_layer.WeightHParams(**amax_history_args),
)

self.create_variable(
'kernel_scale_gated', base_layer.WeightHParams(**scale_args)
)
self.create_variable(
'output_grad_scale_gated', base_layer.WeightHParams(**scale_args)
)

def __call__(self, equation: str, *args: JTensor) -> tuple[JTensor, JTensor]:
assert len(args) == 3
x, k, k_gated = args

comp_dtype = self.fprop_dtype
assert (
k.dtype == k_gated.dtype == comp_dtype
), f'k dtype has to be {comp_dtype}, but got {k.dtype} and {k_gated.dtype}'
x = jnp.asarray(x, comp_dtype)

y, x_qdq = self.quantized_einsum(equation, x, k, return_quantized_x=True)

theta = self.theta

k_gated_qdq = fp8_ops.in_qdq(
comp_dtype,
jnp.float8_e4m3fn,
k_gated,
theta.kernel_scale_gated,
theta.kernel_amax_history_gated,
)
y_gated_qdq = jnp.einsum(
equation,
x_qdq,
k_gated_qdq,
_dot_general=fp8_ops.dot_general_with_precision,
)
y_gated = fp8_ops.out_qdq(
comp_dtype,
jnp.float8_e5m2,
y_gated_qdq,
theta.output_grad_scale_gated,
theta.output_grad_amax_history_gated,
)

return y, y_gated
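The structural point of the gated op is that the activation is quantized once (the x_qdq returned by quantized_einsum) and reused for both kernels, so only the gated kernel and its output gradient need extra scale/amax state. A plain-jnp stand-in follows; it uses per-tensor just-in-time scaling as a simplifying assumption instead of the delayed scaling the layer actually performs, and the shapes and equation are illustrative.

# Stand-in sketch; qdq() is a simplified quantize-dequantize, not
# flax.linen.fp8_ops, and everything below is illustrative.
import jax.numpy as jnp

E4M3_MAX = 448.0

def qdq(x):
  scale = jnp.maximum(jnp.max(jnp.abs(x)), 1e-12) / E4M3_MAX
  return (x / scale).astype(jnp.float8_e4m3fn).astype(x.dtype) * scale

def fp8_gated_einsum(equation, x, k, k_gated):
  x_qdq = qdq(x)  # quantized once, shared by both matmuls
  y = jnp.einsum(equation, x_qdq, qdq(k))
  y_gated = jnp.einsum(equation, x_qdq, qdq(k_gated))
  return y, y_gated

x = jnp.ones((2, 4, 8), jnp.bfloat16)
k = jnp.full((8, 16), 0.5, jnp.bfloat16)
y, y_gated = fp8_gated_einsum('gcm,mh->gch', x, k, k)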
18 changes: 12 additions & 6 deletions praxis/layers/transformers.py
@@ -662,6 +662,7 @@ class TransformerFeedForwardMoe(base_layer.BaseLayer):
moe_gating_embedding_level: str = 'token'
use_gated_activation: bool = False
einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp)
einsum_gated_tpl: LayerTpl = template_field(base_ops.EinsumGatedOp)

# SPMD partition related params.
# M - model_dim, for both inputs and outputs
@@ -825,7 +826,11 @@ def setup(self) -> None:
)
logging.debug('moe wo WeightHParams %s', wo_pc)
self.create_variable('wo_0', wo_pc)
self.create_child('einsum', self.einsum_tpl.clone())
self.create_child('dispatch_einsum', self.einsum_tpl.clone())
if self._is_ffn1_gated:
self.create_child('gated_ffn1_einsum', self.einsum_gated_tpl.clone())
self.create_child('ffn2_einsum', self.einsum_tpl.clone())
self.create_child('combine_einsum', self.einsum_tpl.clone())

def _split(self, t_in, sharding):
return base_layer.maybe_shard(t_in, sharding, self.mesh_axis_names)
@@ -1031,7 +1036,7 @@ def _dispatch_and_combine_expert_outputs(self, inputs, paddings, segment_ids):
if self.gating_func in ['top2', 'expert_choice_v2']:
combine_tensor = self._split(combine_tensor, ap.gsec)
dispatch_tensor = self._split(dispatch_tensor, ap.gsec)
expert_inputs = self.einsum(
expert_inputs = self.dispatch_einsum(
'gsec,gsm->egcm', dispatch_tensor, reshaped_inputs
)
elif self.gating_func == 'expert_choice':
@@ -1045,8 +1050,9 @@ def _dispatch_and_combine_expert_outputs(self, inputs, paddings, segment_ids):
expert_inputs = self._split(expert_inputs, ap.egcm)

if self._is_ffn1_gated:
hidden0 = self.einsum('egcm,emh->egch', expert_inputs, theta_wi)
hidden1 = self.einsum('egcm,emh->egch', expert_inputs, theta_wi_gated)
hidden0, hidden1 = self.gated_ffn1_einsum(
'egcm,emh->egch', expert_inputs, theta_wi, theta_wi_gated
)
if self.gating_func in ['top2', 'expert_choice_v2']:
self._count_dead_neurons(hidden1, dispatch_tensor)
hidden1 = self.activation(hidden1)
@@ -1061,13 +1067,13 @@ def _dispatch_and_combine_expert_outputs(self, inputs, paddings, segment_ids):
# Dropout.
hidden = self.relu_dropout(hidden)
# Output.
expert_output = self.einsum('egch,ehm->egcm', hidden, theta_wo)
expert_output = self.ffn2_einsum('egch,ehm->egcm', hidden, theta_wo)
expert_output = self._split(expert_output, ap.egcm)
# Now transpose and reshard.
transposed_expert_output = jnp.einsum('egcm->gecm', expert_output)
transposed_expert_output = self._split(transposed_expert_output, ap.gecm)
if self.gating_func in ['top2', 'expert_choice_v2']:
combined_output = self.einsum(
combined_output = self.combine_einsum(
'gecm,gsec->gsm', transposed_expert_output, combine_tensor
)
elif self.gating_func == 'expert_choice':
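Design note on the transformers.py change: the plain EinsumOp is stateless, so a single shared `einsum` child was sufficient, but the FP8 wrappers carry per-instance scale and amax-history variables. Splitting the call sites into dispatch_einsum, gated_ffn1_einsum, ffn2_einsum, and combine_einsum children presumably keeps each GEMM's quantization statistics separate instead of mixing them in one shared set of variables.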
