fix duplicated quantization
kaixih committed Jul 19, 2024
1 parent 7e6f0dd commit cf12a06
Showing 4 changed files with 125 additions and 29 deletions.
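
A hedged reading of the diff below: in the gated feed-forward path of TransformerFeedForwardMoe, the old code created two separate einsum children from the same FP8 template, so each child ran its own input quantization on the shared expert activation. The new fused Fp8EinsumGatedOp quantizes that activation once and reuses it for both projections, which appears to be the duplicated quantization the commit title refers to. A minimal sketch of the difference; fp8_einsum_0, fp8_einsum_1, and fp8_einsum_gated are hypothetical stand-ins for the children created from the templates:

# Before: two Fp8EinsumOp children, each calling in_qdq on the same x.
y0 = fp8_einsum_0('egcm,emh->egch', x, w)        # quantizes x
y1 = fp8_einsum_1('egcm,emh->egch', x, w_gated)  # quantizes x again
# After: one Fp8EinsumGatedOp; x is quantized once and shared.
y0, y1 = fp8_einsum_gated('egcm,emh->egch', x, w, w_gated)
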
11 changes: 11 additions & 0 deletions praxis/layers/base_ops.py
@@ -42,6 +42,17 @@ def __call__(self, equation: str, *args: JTensor) -> JTensor:
return jnp.einsum(equation, *args)


class EinsumGatedOp(base_layer.BaseLayer):
"""Wrapper around two jnp.einsum for gated FFN."""

def __call__(self, equation: str, *args: JTensor) -> JTensor:
assert len(args) == 3
x, k, k_gated = args
y = jnp.einsum(equation, x, k)
y_gated = jnp.einsum(equation, x, k_gated)
return y, y_gated


class ArrayLookup(base_layer.BaseLayer):
"""Wrapper around array indexing as used in embedding lookup."""

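For reference, a minimal sketch of what the new op computes, expressed directly with jnp.einsum; the (4, 8) / (8, 16) shapes and the 'bm,mh->bh' equation are illustrative assumptions (the MoE layer uses 'egcm,emh->egch'):

import jax.numpy as jnp

x = jnp.ones((4, 8))         # shared activation
k = jnp.ones((8, 16))        # main FFN kernel
k_gated = jnp.ones((8, 16))  # gating kernel

# EinsumGatedOp()('bm,mh->bh', x, k, k_gated) returns this pair:
y = jnp.einsum('bm,mh->bh', x, k)
y_gated = jnp.einsum('bm,mh->bh', x, k_gated)
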
3 changes: 3 additions & 0 deletions praxis/layers/grok.py
@@ -185,6 +185,9 @@ def GrokStackedTransformerHParams(
fp8_ops.Fp8EinsumOp
)
p.moe_layer_tpl.einsum_tpl = pax_fiddle.Config(fp8_ops.Fp8EinsumOp)
p.moe_layer_tpl.einsum_gated_tpl = pax_fiddle.Config(
fp8_ops.Fp8EinsumGatedOp
)
return p


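The same pair of overrides could be applied to any TransformerFeedForwardMoe template, not just this Grok helper; a hedged sketch, where moe_p is assumed to be a pax_fiddle.Config of that layer and fp8_ops is the module already aliased in grok.py:

# Use the FP8 einsum for the plain projections and the fused FP8 gated
# einsum for the gated FFN1 projection.
moe_p.einsum_tpl = pax_fiddle.Config(fp8_ops.Fp8EinsumOp)
moe_p.einsum_gated_tpl = pax_fiddle.Config(fp8_ops.Fp8EinsumGatedOp)
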
129 changes: 107 additions & 22 deletions praxis/layers/injection/fp8_nvidia_gpu.py
@@ -41,35 +41,39 @@
from praxis import pax_fiddle
from praxis import pytypes

def _get_fp8_args(amax_history_length, mesh_shape):
OVERWRITE_WITH_GRADIENT = (
base_layer.WeightHParamsCollection.OVERWRITE_WITH_GRADIENT
)
DISALLOW_BFLOAT16_CONVERSION = (
base_layer.WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION
)
scale_args = {
'shape': [1],
'init': base_layer.WeightInit.Constant(1.0),
'dtype': jnp.float32,
'mesh_shape': mesh_shape,
'tensor_split_dims_mapping': None,
'collections': [OVERWRITE_WITH_GRADIENT, DISALLOW_BFLOAT16_CONVERSION],
}
amax_history_args = {
'shape': [amax_history_length],
'init': base_layer.WeightInit.Constant(0.0),
'dtype': jnp.float32,
'mesh_shape': mesh_shape,
'tensor_split_dims_mapping': None,
'collections': [OVERWRITE_WITH_GRADIENT, DISALLOW_BFLOAT16_CONVERSION],
}
return scale_args, amax_history_args

class Fp8EinsumOp(base_layer.BaseLayer):
"""Wrapper around jnp.einsum used in standard Pax layers."""

amax_history_length: int = 1024

def setup(self) -> None:
OVERWRITE_WITH_GRADIENT = (
base_layer.WeightHParamsCollection.OVERWRITE_WITH_GRADIENT
)
DISALLOW_BFLOAT16_CONVERSION = (
base_layer.WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION
)
scale_args = {
'shape': [1],
'init': base_layer.WeightInit.Constant(1.0),
'dtype': jnp.float32,
'mesh_shape': self.mesh_shape,
'tensor_split_dims_mapping': None,
'collections': [OVERWRITE_WITH_GRADIENT, DISALLOW_BFLOAT16_CONVERSION],
}
amax_history_args = {
'shape': [self.amax_history_length],
'init': base_layer.WeightInit.Constant(0.0),
'dtype': jnp.float32,
'mesh_shape': self.mesh_shape,
'tensor_split_dims_mapping': None,
'collections': [OVERWRITE_WITH_GRADIENT, DISALLOW_BFLOAT16_CONVERSION],
}
scale_args, amax_history_args = _get_fp8_args(self.amax_history_length, self.mesh_shape)

self.create_variable(
'input_amax_history', base_layer.WeightHParams(**amax_history_args)
)
@@ -126,3 +130,84 @@ def __call__(self, equation: str, *args: pytypes.JTensor) -> pytypes.JTensor:
)

return y

class Fp8EinsumGatedOp(Fp8EinsumOp):
"""Wrapper around two jnp.einsum for gated FFN."""

def setup(self) -> None:
super().setup()

scale_args, amax_history_args = _get_fp8_args(self.amax_history_length, self.mesh_shape)

self.create_variable(
'kernel_amax_history_gated',
base_layer.WeightHParams(**amax_history_args)
)
self.create_variable(
'output_grad_amax_history_gated',
base_layer.WeightHParams(**amax_history_args),
)

self.create_variable(
'kernel_scale_gated', base_layer.WeightHParams(**scale_args)
)
self.create_variable(
'output_grad_scale_gated', base_layer.WeightHParams(**scale_args)
)

def __call__(self, equation: str, *args: pytypes.JTensor) -> pytypes.JTensor:
assert len(args) == 3
x, k, k_gated = args

comp_dtype = self.fprop_dtype
assert (
k.dtype == k_gated.dtype == comp_dtype
), f'k dtype has to be {comp_dtype}, but got {k.dtype} and {k_gated.dtype}'
x = jnp.asarray(x, comp_dtype)

theta = self.theta

x_qdq = fp8_ops.in_qdq(
comp_dtype,
jnp.float8_e4m3fn,
x,
theta.input_scale,
theta.input_amax_history,
)
k_qdq = fp8_ops.in_qdq(
comp_dtype,
jnp.float8_e4m3fn,
k,
theta.kernel_scale,
theta.kernel_amax_history,
)
k_gated_qdq = fp8_ops.in_qdq(
comp_dtype,
jnp.float8_e4m3fn,
k_gated,
theta.kernel_scale_gated,
theta.kernel_amax_history_gated,
)
y_qdq = jnp.einsum(
equation, x_qdq, k_qdq, _dot_general=fp8_ops.dot_general_with_precision
)
y_gated_qdq = jnp.einsum(
equation, x_qdq, k_gated_qdq,
_dot_general=fp8_ops.dot_general_with_precision
)
y = fp8_ops.out_qdq(
comp_dtype,
jnp.float8_e5m2,
y_qdq,
theta.output_grad_scale,
theta.output_grad_amax_history,
)
y_gated = fp8_ops.out_qdq(
comp_dtype,
jnp.float8_e5m2,
y_gated_qdq,
theta.output_grad_scale_gated,
theta.output_grad_amax_history_gated,
)

return y, y_gated
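
For context, a self-contained illustration of the quantize-dequantize (QDQ) pattern that in_qdq and out_qdq are built around; this is a simplified sketch with a caller-supplied fixed scale, not the praxis implementation, which derives the scale from the amax-history and scale variables created above:

import jax.numpy as jnp

def fake_quant(x, scale, fp8_dtype=jnp.float8_e4m3fn):
  # Scale into FP8 range, round-trip through FP8, and scale back. The result
  # stays in the original compute dtype but carries FP8 rounding, so the
  # compiler can pattern-match the surrounding einsum into a real FP8 matmul.
  compute_dtype = x.dtype
  q = (x.astype(jnp.float32) / scale).astype(fp8_dtype)
  return (q.astype(jnp.float32) * scale).astype(compute_dtype)

x = jnp.ones((4, 8), jnp.bfloat16)
x_qdq = fake_quant(x, jnp.float32(1.0))  # roughly what in_qdq does, minus the scale updates
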
11 changes: 4 additions & 7 deletions praxis/layers/transformers.py
@@ -662,6 +662,7 @@ class TransformerFeedForwardMoe(base_layer.BaseLayer):
moe_gating_embedding_level: str = 'token'
use_gated_activation: bool = False
einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp)
einsum_gated_tpl: LayerTpl = template_field(base_ops.EinsumGatedOp)

# SPMD partition related params.
# M - model_dim, for both inputs and outputs
@@ -827,8 +828,7 @@ def setup(self) -> None:
self.create_variable('wo_0', wo_pc)
self.create_child('dispatch_einsum', self.einsum_tpl.clone())
if self._is_ffn1_gated:
self.create_child('gated_ffn1_hidden0_einsum', self.einsum_tpl.clone())
self.create_child('gated_ffn1_hidden1_einsum', self.einsum_tpl.clone())
self.create_child('gated_ffn1_einsum', self.einsum_gated_tpl.clone())
self.create_child('ffn2_einsum', self.einsum_tpl.clone())
self.create_child('combine_einsum', self.einsum_tpl.clone())

@@ -1050,11 +1050,8 @@ def _dispatch_and_combine_expert_outputs(self, inputs, paddings, segment_ids):
expert_inputs = self._split(expert_inputs, ap.egcm)

if self._is_ffn1_gated:
hidden0 = self.gated_ffn1_hidden0_einsum(
'egcm,emh->egch', expert_inputs, theta_wi
)
hidden1 = self.gated_ffn1_hidden1_einsum(
'egcm,emh->egch', expert_inputs, theta_wi_gated
hidden0, hidden1 = self.gated_ffn1_einsum(
'egcm,emh->egch', expert_inputs, theta_wi, theta_wi_gated
)
if self.gating_func in ['top2', 'expert_choice_v2']:
self._count_dead_neurons(hidden1, dispatch_tensor)
