add EmbeddingLocation.MTIA to tbe module
Summary:
For MTIA, TBE weights have the following characteristics:
1. fp32 qparams instead of fp16,
2. a row alignment of 1,
3. no cacheline alignment.

Differential Revision: D52860637
842974287 authored and facebook-github-bot committed Jan 19, 2024
1 parent 9a3c5b2 commit bfea529
Showing 2 changed files with 25 additions and 14 deletions.
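
The three characteristics in the summary map directly onto the values touched in this diff: an fp32 scale/bias pair takes 8 bytes of qparams per row instead of 4, the row alignment drops from 16 to 1, and cacheline alignment is skipped when laying out the physical weights. A minimal sketch of the per-row arithmetic, using a hypothetical padded_row_bytes helper that is not part of FBGEMM:

# Hypothetical illustration of the row layout implied by the summary; not part of the commit.
def padded_row_bytes(
    dim: int, weight_bits: int, scale_bias_size_in_bytes: int, row_alignment: int
) -> int:
    # packed int-N weights followed by the scale/bias qparams
    unpadded = (dim * weight_bits + 7) // 8 + scale_bias_size_in_bytes
    # round up to the row alignment
    return (unpadded + row_alignment - 1) // row_alignment * row_alignment

dim = 128
print(padded_row_bytes(dim, 8, 4, 16))  # default: fp16 qparams (4 B), alignment 16 -> 144
print(padded_row_bytes(dim, 8, 8, 1))   # MTIA: fp32 qparams (8 B), alignment 1 -> 136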
@@ -26,6 +26,7 @@ class EmbeddingLocation(enum.IntEnum):
     MANAGED = 1
     MANAGED_CACHING = 2
     HOST = 3
+    MTIA = 4
 
 
 class CacheAlgorithm(enum.Enum):
@@ -100,7 +100,7 @@ def nbit_construct_split_state(
             placements.append(EmbeddingLocation.HOST)
             offsets.append(host_size)
             host_size += state_size
-        elif location == EmbeddingLocation.DEVICE:
+        elif location == EmbeddingLocation.DEVICE or location == EmbeddingLocation.MTIA:
             placements.append(EmbeddingLocation.DEVICE)
             offsets.append(dev_size)
             dev_size += state_size
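
The hunk above routes MTIA tables through the same dev-weights bookkeeping as DEVICE tables. A standalone sketch of that placement loop, using only what is visible in this hunk and made-up state sizes:

# Simplified, self-contained version of the placement logic shown above.
from enum import IntEnum

class EmbeddingLocation(IntEnum):
    DEVICE = 0
    MANAGED = 1
    MANAGED_CACHING = 2
    HOST = 3
    MTIA = 4

placements, offsets = [], []
dev_size = host_size = 0
for location, state_size in [(EmbeddingLocation.MTIA, 1024), (EmbeddingLocation.HOST, 512)]:
    if location == EmbeddingLocation.HOST:
        placements.append(EmbeddingLocation.HOST)
        offsets.append(host_size)
        host_size += state_size
    elif location == EmbeddingLocation.DEVICE or location == EmbeddingLocation.MTIA:
        # MTIA rows land in the same dense dev weights tensor as DEVICE rows
        placements.append(EmbeddingLocation.DEVICE)
        offsets.append(dev_size)
        dev_size += state_size

print([p.name for p in placements])  # ['DEVICE', 'HOST']
print(offsets)                       # [0, 0]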
@@ -844,26 +844,36 @@ def initialize_physical_weights_placements_and_offsets(
     def reset_weights_placements_and_offsets(
         self, device: torch.device, location: int
     ) -> None:
+        # Overwrite location in embedding_specs with new location
+        # Use map since can't script enum call (ie. EmbeddingLocation(value))
+        INT_TO_EMBEDDING_LOCATION = {
+            EmbeddingLocation.DEVICE.value: EmbeddingLocation.DEVICE,
+            EmbeddingLocation.MANAGED.value: EmbeddingLocation.MANAGED,
+            EmbeddingLocation.MANAGED_CACHING.value: EmbeddingLocation.MANAGED_CACHING,
+            EmbeddingLocation.HOST.value: EmbeddingLocation.HOST,
+            EmbeddingLocation.MTIA.value: EmbeddingLocation.MTIA,
+        }
         # Reset device/location denoted in embedding specs
-        self.reset_embedding_spec_location(device, location)
+        target_location = INT_TO_EMBEDDING_LOCATION[location]
+        if target_location == EmbeddingLocation.MTIA:
+            self.scale_bias_size_in_bytes = 8
+        self.reset_embedding_spec_location(device, target_location)
         # Initialize all physical/logical weights placements and offsets without initializing large dev weights tensor
-        self.initialize_physical_weights_placements_and_offsets()
+        self.initialize_physical_weights_placements_and_offsets(
+            cacheline_alignment=target_location != EmbeddingLocation.MTIA
+        )
         self.initialize_logical_weights_placements_and_offsets()
 
     def reset_embedding_spec_location(
-        self, device: torch.device, location: int
+        self, device: torch.device, target_location: EmbeddingLocation
     ) -> None:
-        # Overwrite location in embedding_specs with new location
-        # Use map since can't script enum call (ie. EmbeddingLocation(value))
-        INT_TO_EMBEDDING_LOCATION = {
-            0: EmbeddingLocation.DEVICE,
-            1: EmbeddingLocation.MANAGED,
-            2: EmbeddingLocation.MANAGED_CACHING,
-            3: EmbeddingLocation.HOST,
-        }
-        target_location = INT_TO_EMBEDDING_LOCATION[location]
         self.current_device = device
-        self.row_alignment = 1 if target_location == EmbeddingLocation.HOST else 16
+        self.row_alignment = (
+            1
+            if target_location == EmbeddingLocation.HOST
+            or target_location == EmbeddingLocation.MTIA
+            else 16
+        )
         self.embedding_specs = [
             (spec[0], spec[1], spec[2], spec[3], target_location)
             for spec in self.embedding_specs
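
Pulled out of the class for illustration, the MTIA-specific decisions in the two methods above reduce to roughly the following standalone sketch; the helper name and the else-branch default of 4 bytes (an fp16 scale/bias pair) are assumptions, everything else mirrors the diff:

# Standalone sketch of the MTIA-specific branches introduced above; not FBGEMM API.
from enum import IntEnum

class EmbeddingLocation(IntEnum):
    DEVICE = 0
    MANAGED = 1
    MANAGED_CACHING = 2
    HOST = 3
    MTIA = 4

def mtia_layout_overrides(location: int) -> dict:
    # TorchScript can't call EmbeddingLocation(value), hence the explicit int-to-enum map in the diff.
    int_to_location = {member.value: member for member in EmbeddingLocation}
    target_location = int_to_location[location]
    return {
        # fp32 scale/bias pair instead of fp16 -> 8 bytes of qparams per row
        "scale_bias_size_in_bytes": 8 if target_location == EmbeddingLocation.MTIA else 4,
        # HOST and MTIA rows are tightly packed; everything else stays 16-byte aligned
        "row_alignment": 1
        if target_location in (EmbeddingLocation.HOST, EmbeddingLocation.MTIA)
        else 16,
        # MTIA skips cacheline alignment of the physical weight offsets
        "cacheline_alignment": target_location != EmbeddingLocation.MTIA,
    }

print(mtia_layout_overrides(EmbeddingLocation.MTIA.value))
# {'scale_bias_size_in_bytes': 8, 'row_alignment': 1, 'cacheline_alignment': False}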
