diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index bd9be9c452..af792e4c02 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -4135,7 +4135,9 @@ def test_nbit_forward_gpu_no_cache( else: weights_ty: SparseType = nbit_weights_ty mixed_weights_ty = False - output_dtype = random.choice([SparseType.FP32, SparseType.FP16]) + output_dtype = random.choice( + [SparseType.FP32, SparseType.FP16, SparseType.BF16] + ) self.execute_nbit_forward_( T, D,