Skip to content

Commit

Permalink
added check to avoid div 0 errors in cache report; added in place weight initial methods
Browse files Browse the repository at this point in the history

Summary: as in title

Differential Revision: D44096435

fbshipit-source-id: bc71dbfd9382c17f1abd09540217b807a6887902
  • Loading branch information
Xiao Sun authored and facebook-github-bot committed Mar 15, 2023
1 parent da01a59 commit 418d607
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,10 +979,11 @@ def print_uvm_cache_stats(self) -> None:
f"N_conflict_unique_misses: {uvm_cache_stats[4]}\n"
f"N_conflict_misses: {uvm_cache_stats[5]}\n"
)
logging.info(
f"unique indices / requested indices: {uvm_cache_stats[2]/uvm_cache_stats[1]}\n"
f"unique misses / requested indices: {uvm_cache_stats[3]/uvm_cache_stats[1]}\n"
)
if uvm_cache_stats[1]:
logging.info(
f"unique indices / requested indices: {uvm_cache_stats[2]/uvm_cache_stats[1]}\n"
f"unique misses / requested indices: {uvm_cache_stats[3]/uvm_cache_stats[1]}\n"
)

def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
self.timestep += 1
Expand Down Expand Up @@ -2347,10 +2348,11 @@ def print_uvm_cache_stats(self) -> None:
f"N_conflict_unique_misses: {uvm_cache_stats[4]}\n"
f"N_conflict_misses: {uvm_cache_stats[5]}\n"
)
logging.info(
f"unique indices / requested indices: {uvm_cache_stats[2]/uvm_cache_stats[1]}\n"
f"unique misses / requested indices: {uvm_cache_stats[3]/uvm_cache_stats[1]}\n"
)
if uvm_cache_stats[1]:
logging.info(
f"unique indices / requested indices: {uvm_cache_stats[2]/uvm_cache_stats[1]}\n"
f"unique misses / requested indices: {uvm_cache_stats[3]/uvm_cache_stats[1]}\n"
)

@torch.jit.export
def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
Expand Down Expand Up @@ -3042,6 +3044,23 @@ def fill_random_weights(self) -> None:
)
)

def fill_fp16_random_weights(self) -> None:
    """
    Fill every table's weight buffer in place with fp16 values drawn
    from N(0, 0.1).

    Each buffer's raw storage (presumably uint8, matching `copy_weights`
    — TODO confirm) is reinterpreted as float16 via `Tensor.view(dtype)`
    and mutated in place with `normal_`, so no new tensor is allocated
    and no copy is made.
    """
    self.initialize_weights()
    weights = self.split_embedding_weights()
    for dest_weight in weights:
        # normal_() mutates the float16 view in place. The original
        # trailing `.view(torch.uint8)` produced a discarded value
        # (a no-op), so it has been removed.
        dest_weight[0].view(torch.float16).normal_(0, 0.1)

def copy_weights(self, emb_tensor: torch.Tensor) -> None:
    """
    Copy `emb_tensor` into the first table's weight buffer in place.

    The source tensor is reinterpreted as raw uint8 bytes via
    `Tensor.view(torch.uint8)` to match the destination buffer's byte
    layout, so its total byte size must equal that of the destination
    (`copy_` raises otherwise).

    NOTE(review): the original docstring ("Fill the buffer with random
    weights") was a copy-paste error — this method copies, it does not
    randomize. Only table 0 is written; presumably callers use a
    single-table configuration — confirm.
    """
    self.initialize_weights()
    weights = self.split_embedding_weights()
    weights[0][0].copy_(emb_tensor.view(torch.uint8))

def assign_embedding_weights(
self, q_weight_list: List[Tuple[Tensor, Optional[Tensor]]]
) -> None:
Expand Down

0 comments on commit 418d607

Please sign in to comment.