diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
index 230c693f3e..221a5ea897 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
@@ -294,7 +294,10 @@ def __init__(  # noqa C901
         )
 
         hash_size_cumsum = [0] + list(accumulate(rows))
-        self.total_hash_size_bits = int(log2(float(hash_size_cumsum[-1])) + 1)
+        if hash_size_cumsum[-1] == 0:
+            self.total_hash_size_bits: int = 0
+        else:
+            self.total_hash_size_bits: int = int(log2(float(hash_size_cumsum[-1])) + 1)
         # The last element is to easily access # of rows of each table by
         # hash_size_cumsum[t + 1] - hash_size_cumsum[t]
         hash_size_cumsum = [hash_size_cumsum[t] for t in self.feature_table_map] + [
@@ -1362,7 +1365,10 @@ def __init__(
         assert self.D_offsets.numel() == T + 1
 
         hash_size_cumsum = [0] + list(accumulate(rows))
-        self.total_hash_size_bits = int(log2(float(hash_size_cumsum[-1])) + 1)
+        if hash_size_cumsum[-1] == 0:
+            self.total_hash_size_bits: int = 0
+        else:
+            self.total_hash_size_bits: int = int(log2(float(hash_size_cumsum[-1])) + 1)
         # The last element is to easily access # of rows of each table by
         # hash_size_cumsum[t + 1] - hash_size_cumsum[t]
         hash_size_cumsum = [hash_size_cumsum[t] for t in feature_table_map] + [
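
To illustrate the edge case this patch guards against, here is a minimal, self-contained sketch of the same computation outside FBGEMM. When every table has zero rows (or `rows` is empty), `hash_size_cumsum[-1]` is `0`, and `math.log2(0.0)` raises `ValueError: math domain error`, which is what the old code hit. The helper name `total_hash_size_bits` below is hypothetical, chosen to mirror the attribute the diff sets; it is not an FBGEMM API.

```python
# Sketch (not FBGEMM code) of the guarded bit-width computation above.
from itertools import accumulate
from math import log2
from typing import List


def total_hash_size_bits(rows: List[int]) -> int:  # hypothetical helper
    """Bits needed to index the cumulative hash space of all tables."""
    hash_size_cumsum = [0] + list(accumulate(rows))
    if hash_size_cumsum[-1] == 0:
        # No rows at all: report 0 bits instead of calling log2(0),
        # which raises ValueError.
        return 0
    return int(log2(float(hash_size_cumsum[-1])) + 1)


print(total_hash_size_bits([100, 50]))  # 8: 2**8 = 256 >= 150 cumulative rows
print(total_hash_size_bits([]))         # 0: previously a ValueError
```

The `if`/`else` form (rather than, say, `max(..., 0)`) keeps the non-degenerate path byte-for-byte identical to the original expression, so existing behavior is unchanged whenever at least one row exists.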