Minimize duplicated code
kaixih committed Jul 20, 2024
1 parent 9217310 commit 203e9f7
Showing 1 changed file with 26 additions and 36 deletions.
62 changes: 26 additions & 36 deletions praxis/layers/injection/fp8_nvidia_gpu.py
@@ -41,6 +41,8 @@
 from praxis import pax_fiddle
 from praxis import pytypes
 
+JTensor = pytypes.JTensor
+
 def _get_fp8_args(amax_history_length, mesh_shape):
   OVERWRITE_WITH_GRADIENT = (
       base_layer.WeightHParamsCollection.OVERWRITE_WITH_GRADIENT
@@ -92,18 +94,11 @@ def setup(self) -> None:
         'output_grad_scale', base_layer.WeightHParams(**scale_args)
     )
 
-  def __call__(self, equation: str, *args: pytypes.JTensor) -> pytypes.JTensor:
-    assert len(args) == 2
-    x = args[0]
-    k = args[1]
-
-    comp_dtype = self.fprop_dtype
-    assert (
-        k.dtype == comp_dtype
-    ), f'k dtype has to be {comp_dtype}, but got {k.dtype}'
-    x = jnp.asarray(x, comp_dtype)
-
+  def quantized_einsum(
+      self, equation: str, x: JTensor, k: JTensor, return_quantized_x: bool
+  ) -> JTensor | tuple[JTensor, JTensor]:
     theta = self.theta
+    comp_dtype = self.fprop_dtype
 
     x_qdq = fp8_ops.in_qdq(
         comp_dtype,
@@ -130,6 +125,23 @@ def __call__(self, equation: str, *args: pytypes.JTensor) -> pytypes.JTensor:
         theta.output_grad_amax_history,
     )
 
+    if return_quantized_x:
+      return y, x_qdq
+    return y
+
+  def __call__(self, equation: str, *args: JTensor) -> JTensor:
+    assert len(args) == 2
+    x = args[0]
+    k = args[1]
+
+    comp_dtype = self.fprop_dtype
+    assert (
+        k.dtype == comp_dtype
+    ), f'k dtype has to be {comp_dtype}, but got {k.dtype}'
+    x = jnp.asarray(x, comp_dtype)
+
+    y = self.quantized_einsum(equation, x, k, return_quantized_x=False)
+
     return y
 
 class Fp8EinsumGatedOp(Fp8EinsumOp):
@@ -156,7 +168,7 @@ def setup(self) -> None:
         'output_grad_scale_gated', base_layer.WeightHParams(**scale_args)
     )
 
-  def __call__(self, equation: str, *args: pytypes.JTensor) -> pytypes.JTensor:
+  def __call__(self, equation: str, *args: JTensor) -> tuple[JTensor, JTensor]:
     assert len(args) == 3
     x, k, k_gated = args
 
@@ -166,43 +178,21 @@ def __call__(self, equation: str, *args: pytypes.JTensor) -> pytypes.JTensor:
     ), f'k dtype has to be {comp_dtype}, but got {k.dtype} and {k_gated.dtype}'
     x = jnp.asarray(x, comp_dtype)
 
+    y, x_qdq = self.quantized_einsum(equation, x, k, return_quantized_x=True)
+
     theta = self.theta
 
-    x_qdq = fp8_ops.in_qdq(
-        comp_dtype,
-        jnp.float8_e4m3fn,
-        x,
-        theta.input_scale,
-        theta.input_amax_history,
-    )
-    k_qdq = fp8_ops.in_qdq(
-        comp_dtype,
-        jnp.float8_e4m3fn,
-        k,
-        theta.kernel_scale,
-        theta.kernel_amax_history,
-    )
     k_gated_qdq = fp8_ops.in_qdq(
         comp_dtype,
         jnp.float8_e4m3fn,
         k_gated,
         theta.kernel_scale_gated,
         theta.kernel_amax_history_gated,
     )
-    y_qdq = jnp.einsum(
-        equation, x_qdq, k_qdq, _dot_general=fp8_ops.dot_general_with_precision
-    )
     y_gated_qdq = jnp.einsum(
         equation, x_qdq, k_gated_qdq,
         _dot_general=fp8_ops.dot_general_with_precision
     )
-    y = fp8_ops.out_qdq(
-        comp_dtype,
-        jnp.float8_e5m2,
-        y_qdq,
-        theta.output_grad_scale,
-        theta.output_grad_amax_history,
-    )
     y_gated = fp8_ops.out_qdq(
         comp_dtype,
         jnp.float8_e5m2,
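The refactor follows a common pattern: hoist the shared FP8 quantize-dequantize einsum path into a single helper with a flag for returning the quantized activations, so the gated subclass can reuse it instead of repeating the in_qdq/einsum/out_qdq sequence. The sketch below is a simplified, self-contained illustration of that pattern only, not the praxis code: fake_qdq, EinsumFp8, and GatedEinsumFp8 are hypothetical names, and it assumes a recent JAX that exposes jnp.float8_e4m3fn.

# Simplified sketch of the shared-helper pattern (illustrative names, not the praxis API).
import jax.numpy as jnp


def fake_qdq(x, q_dtype):
  # Stand-in for quantize-dequantize: cast to an FP8 dtype and back.
  return x.astype(q_dtype).astype(x.dtype)


class EinsumFp8:

  def quantized_einsum(self, equation, x, k, return_quantized_x):
    # Shared path: fake-quantize both operands, then run the einsum.
    x_qdq = fake_qdq(x, jnp.float8_e4m3fn)
    k_qdq = fake_qdq(k, jnp.float8_e4m3fn)
    y = jnp.einsum(equation, x_qdq, k_qdq)
    if return_quantized_x:
      return y, x_qdq
    return y

  def __call__(self, equation, x, k):
    return self.quantized_einsum(equation, x, k, return_quantized_x=False)


class GatedEinsumFp8(EinsumFp8):

  def __call__(self, equation, x, k, k_gated):
    # Reuse the shared helper and recover x_qdq so the gated branch does not
    # quantize the activations a second time.
    y, x_qdq = self.quantized_einsum(equation, x, k, return_quantized_x=True)
    k_gated_qdq = fake_qdq(k_gated, jnp.float8_e4m3fn)
    y_gated = jnp.einsum(equation, x_qdq, k_gated_qdq)
    return y, y_gated


# Tiny usage example:
op = GatedEinsumFp8()
x = jnp.ones((4, 8), jnp.bfloat16)
k = jnp.ones((8, 16), jnp.bfloat16)
y, y_gated = op('mk,kn->mn', x, k, k)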
