Skip to content

Commit

Permalink
Fixed test_fp8_meta_dtype.
Browse files Browse the repository at this point in the history
  • Loading branch information
wenchenvincent committed Jun 28, 2024
1 parent 541bddf commit bc083b8
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions tests/linen/linen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1411,12 +1411,19 @@ def loss_fn(vars):
np.testing.assert_allclose(fp8_vars['kernel_scale'][0], scale_k)
np.testing.assert_allclose(fp8_vars['output_grad_scale'][0], scale_g)

@parameterized.parameters([True, False])
def test_fp8_meta_dtype(self, use_jit):
@parameterized.parameters(
{'fp8_genre': 'OCP', 'use_jit': True},
{'fp8_genre': 'OCP', 'use_jit': False},
{'fp8_genre': 'NANOO', 'use_jit': True},
{'fp8_genre': 'NANOO', 'use_jit': False}
)
def test_fp8_meta_dtype(self, fp8_genre, use_jit):
if not use_jit and not fp8_ops.CAN_USE_EARRAY:
self.skipTest("TODO: requires newer jax that has earray")
f32 = jnp.dtype('float32')
fm32 = fp8_ops.fm32
e4m3_dtype, _ = fp8_ops.get_fp8_dtypes(fp8_genre)
e4m3_max = 448 if fp8_genre == 'OCP' else 240

# Create a scan loop that reuses ah_f32 and sf_f32, so that autograd will
# accumulate their gradients. We expect the max op (rather than the add op)
Expand All @@ -1426,7 +1433,7 @@ def outer(x, ah_f32, sf_f32):
sf_fm32 = jax.lax.convert_element_type(sf_f32, fm32)
array_x = jnp.array([x], f32)
def body_fun(carry, _):
carry = fp8_ops.in_qdq(f32, carry, sf_fm32, ah_fm32)
carry = fp8_ops.in_qdq(f32, e4m3_dtype, carry, sf_fm32, ah_fm32)
return carry, None
array_x, _ = jax.lax.scan(body_fun, array_x, None, length=3)
return array_x[0]
Expand All @@ -1443,12 +1450,10 @@ def body_fun(carry, _):
# 2nd iteration
grads, new_ah, new_sf = outer_fn(3., new_ah, new_sf)
np.testing.assert_allclose(new_ah, [3., 0., 2.])
np.testing.assert_allclose(new_sf, [2. / 448])
np.testing.assert_allclose(new_sf, [2. / e4m3_max])
# 3rd iteration
grads, new_ah, new_sf = outer_fn(4., new_ah, new_sf)
np.testing.assert_allclose(new_ah, [4., 2., 3.])
np.testing.assert_allclose(new_sf, [3. / 448])
np.testing.assert_allclose(new_sf, [3. / e4m3_max])


if __name__ == '__main__':
absltest.main()

0 comments on commit bc083b8

Please sign in to comment.