Skip to content

Commit

Permalink
Minor refactor.
Browse files Browse the repository at this point in the history
  • Loading branch information
wenchenvincent committed Jun 29, 2024
1 parent bc083b8 commit a6f52ae
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
10 changes: 0 additions & 10 deletions flax/linen/fp8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,6 @@ def get_fp8_max(fp8_dtype, out_dtype):
jnp.float8_e4m3fnuz, jnp.float8_e5m2fnuz)
return jnp.finfo(fp8_dtype).max.astype(out_dtype)

def get_fp8_dtypes(fp8_genre):
  """Return the ``(e4m3_dtype, e5m2_dtype)`` pair for an FP8 genre.

  Args:
    fp8_genre: Either ``'OCP'`` or ``'NANOO'``.

  Returns:
    A 2-tuple of ``jnp`` FP8 dtypes:
      * ``'OCP'``   -> ``(jnp.float8_e4m3fn, jnp.float8_e5m2)``
      * ``'NANOO'`` -> ``(jnp.float8_e4m3fnuz, jnp.float8_e5m2fnuz)``

  Raises:
    ValueError: If ``fp8_genre`` is not one of the two supported names.
  """
  # Each branch touches only its own jnp attributes, so a genre whose dtypes
  # are unavailable does not break lookups for the other genre.
  if fp8_genre == 'OCP':
    return jnp.float8_e4m3fn, jnp.float8_e5m2
  if fp8_genre == 'NANOO':
    return jnp.float8_e4m3fnuz, jnp.float8_e5m2fnuz
  # Raise explicitly instead of using `assert`: asserts are removed under
  # `python -O`, which would let an unknown genre silently map to NANOO.
  raise ValueError(
      f"fp8_genre must be 'OCP' or 'NANOO', got {fp8_genre!r}"
  )

def quantize(x, q_dtype, scale, compute_dtype):
# Explicitly cast the max values to the compute dtype to avoid unnecessary
# casting to FP32 during the subsequent math operations.
Expand Down
18 changes: 14 additions & 4 deletions tests/linen/linen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,6 +1243,15 @@ def test_hashable(self):
self.assertNotEqual(hash(id1), hash(id1c))
self.assertNotEqual(hash(id1), hash(id1dc))

def get_fp8_dtypes(fp8_genre):
  """Map an FP8 genre name to its ``(e4m3_dtype, e5m2_dtype)`` pair.

  Args:
    fp8_genre: Either ``'OCP'`` or ``'NANOO'``.

  Returns:
    ``(jnp.float8_e4m3fn, jnp.float8_e5m2)`` for ``'OCP'`` and
    ``(jnp.float8_e4m3fnuz, jnp.float8_e5m2fnuz)`` for ``'NANOO'``.

  Raises:
    ValueError: If ``fp8_genre`` is neither ``'OCP'`` nor ``'NANOO'``.
  """
  if fp8_genre == 'OCP':
    e4m3_dtype, e5m2_dtype = jnp.float8_e4m3fn, jnp.float8_e5m2
  elif fp8_genre == 'NANOO':
    e4m3_dtype, e5m2_dtype = jnp.float8_e4m3fnuz, jnp.float8_e5m2fnuz
  else:
    # An explicit raise survives `python -O`; an `assert` would be stripped
    # and an unknown genre would then be treated as NANOO.
    raise ValueError(
        f"fp8_genre must be 'OCP' or 'NANOO', got {fp8_genre!r}"
    )
  return e4m3_dtype, e5m2_dtype

class Fp8Test(parameterized.TestCase):
@parameterized.parameters(
Expand All @@ -1257,7 +1266,7 @@ def test_fp8_dot_general_injection(self, fp8_genre):
compute_dtype=jnp.float32,
)

e4m3_dtype, e5m2_dtype = fp8_ops.get_fp8_dtypes(fp8_genre)
e4m3_dtype, e5m2_dtype = get_fp8_dtypes(fp8_genre)

init_key, random_key = random.split(random.PRNGKey(seed=123), 2)
x = cast_to_representable(
Expand Down Expand Up @@ -1361,7 +1370,7 @@ def loss_fn(vars):
scale_x, amax_history_x = jnp.ones(()), jnp.zeros((1024,))
scale_k, amax_history_k = jnp.ones(()), jnp.zeros((1024,))
scale_g, amax_history_g = jnp.ones(()), jnp.zeros((1024,))
e4m3_dtype, e5m2_dtype = fp8_ops.get_fp8_dtypes(fp8_genre)
e4m3_dtype, e5m2_dtype = get_fp8_dtypes(fp8_genre)
e4m3_max = jnp.finfo(e4m3_dtype).max.astype(jnp.float32)
e5m2_max = jnp.finfo(e5m2_dtype).max.astype(jnp.float32)

Expand Down Expand Up @@ -1422,7 +1431,7 @@ def test_fp8_meta_dtype(self, fp8_genre, use_jit):
self.skipTest("TODO: requires newer jax that has earray")
f32 = jnp.dtype('float32')
fm32 = fp8_ops.fm32
e4m3_dtype, _ = fp8_ops.get_fp8_dtypes(fp8_genre)
e4m3_dtype, _ = get_fp8_dtypes(fp8_genre)
e4m3_max = 448 if fp8_genre == 'OCP' else 240

# Create a scan loop with reused ah_f32 and sf_f32. So, the autograd will
Expand Down Expand Up @@ -1456,4 +1465,5 @@ def body_fun(carry, _):
np.testing.assert_allclose(new_ah, [4., 2., 3.])
np.testing.assert_allclose(new_sf, [3. / e4m3_max])


if __name__ == '__main__':
absltest.main()

0 comments on commit a6f52ae

Please sign in to comment.