Fix the initialization of the cache when we have multi gpu #33303

Merged · 12 commits · Sep 13, 2024
40 changes: 29 additions & 11 deletions src/transformers/cache_utils.py
@@ -1030,6 +1030,9 @@ class StaticCache(Cache):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split between different GPUs.
You can check which layer is mapped to which device by looking at the associated device map: `model.hf_device_map`.

Example:
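A minimal sketch (not part of the diff) of how the new `layer_device_map` argument can be used when building a `StaticCache` by hand. The checkpoint and manual split mirror this PR's multi-GPU tests further down, and the map-derivation loop mirrors the `get_layer_device_map` helper added to `_get_cache`:

```python
from transformers import AutoModelForCausalLM, StaticCache

# Same tiny checkpoint and manual split as in the PR's multi-GPU tests.
device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
)

# Derive layer index -> device from `model.hf_device_map` by matching the
# ".{idx}." pattern in each module name (same idea as the helper added to
# `_get_cache` in this PR).
layer_device_map = {}
for name, device in model.hf_device_map.items():
    for idx in range(model.config.num_hidden_layers):
        if f".{idx}." in f"{name}.":
            layer_device_map[idx] = device
            break

past_key_values = StaticCache(
    config=model.config,
    batch_size=1,
    max_cache_len=30,
    device=model.device,
    dtype=model.dtype,
    layer_device_map=layer_device_map,
)
# The cache can then be passed to `model.generate(..., past_key_values=past_key_values)`.
```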

@@ -1060,6 +1063,7 @@ def __init__(
device: torch.device = None,
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super().__init__()
if max_batch_size is not None:
@@ -1088,16 +1092,20 @@ def __init__(
# Note: There will be significant perf decrease if switching to use 5D tensors instead.
cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
for idx in range(config.num_hidden_layers):
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
if layer_device_map is not None:
layer_device = layer_device_map[idx]
else:
layer_device = device
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
# Notes:
# 1. `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
# breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
# it is not needed anyway)
# 2. `torch.export()` requires mutations to be registered as buffers.
if not is_torchdynamo_compiling():
self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
new_layer_key_cache = getattr(self, f"key_cache_{idx}")
new_layer_value_cache = getattr(self, f"value_cache_{idx}")
torch._dynamo.mark_static_address(new_layer_key_cache)
@@ -1130,9 +1138,9 @@ def update(
Return:
A tuple containing the updated key and value states.
"""

cache_position = cache_kwargs.get("cache_position")
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)

k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]

@@ -1201,6 +1209,9 @@ class SlidingWindowCache(StaticCache):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split between different GPUs.
You can check which layer is mapped to which device by looking at the associated device map: `model.hf_device_map`.

Example:

@@ -1231,6 +1242,7 @@ def __init__(
device: torch.device = None,
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super().__init__()
if not hasattr(config, "sliding_window") or config.sliding_window is None:
@@ -1247,6 +1259,7 @@ def __init__(
device=device,
dtype=dtype,
max_batch_size=max_batch_size,
layer_device_map=layer_device_map,
)

def update(
@@ -1280,7 +1293,6 @@ def update(
v_out = v_out[:, :, indices]

try:
cache_position.to(device=k_out.device)
k_out.index_copy_(2, cache_position, key_states)
v_out.index_copy_(2, cache_position, value_states)
except NotImplementedError:
@@ -1495,6 +1507,9 @@ class HybridCache(Cache):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (torch.dtype, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split between different GPUs.
You can check which layer is mapped to which device by looking at the associated device map: `model.hf_device_map`.

Example:

@@ -1525,6 +1540,7 @@ def __init__(
device: Union[torch.device, str] = "cpu",
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super().__init__()
if max_batch_size is not None:
@@ -1562,11 +1578,15 @@ def __init__(
self.head_dim,
)
for i in range(config.num_hidden_layers):
if layer_device_map is not None:
layer_device = layer_device_map[i]
else:
layer_device = device
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
# breaks when updating the cache.
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
@@ -1617,8 +1637,6 @@ def update(
) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position")
sliding_window = cache_kwargs.get("sliding_window")
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
if sliding_window:
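The per-layer allocation introduced above is what lets the `update()` methods drop their per-call `.to(device=...)` copies: each layer's key/value buffers are now created on that layer's execution device, so incoming key and value states already live on the same device. A minimal sanity-check sketch, assuming the `past_key_values` and `layer_device_map` from the earlier sketch:

```python
import torch

# Sketch only: checks the invariant the removed `.to(device=...)` calls relied on,
# namely that each layer's cache tensors sit on that layer's execution device.
for idx, device in layer_device_map.items():
    assert past_key_values.key_cache[idx].device == torch.device(device)
    assert past_key_values.value_cache[idx].device == torch.device(device)
```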
27 changes: 27 additions & 0 deletions src/transformers/generation/utils.py
@@ -1446,12 +1446,39 @@ def _get_cache(
# models. May cause trouble with non-text modalities.
cache_dtype = self.get_output_embeddings().weight.dtype

def get_layer_device_map(execution_device_map: Optional[dict] = None):
if execution_device_map is None or len(execution_device_map) <= 1:
return None
layer_device_map = {}
for layer in execution_device_map:
for idx in range(self.config.num_hidden_layers):
if f".{idx}." in f"{layer}.":
layer_device_map[idx] = execution_device_map[layer]
break
for idx in range(self.config.num_hidden_layers):
if idx not in layer_device_map:
raise RuntimeError(f"layer {idx} has not been mapped to a device.")
return layer_device_map

execution_device_map = None
# Taken from `dispatch_model` in accelerate.
# This is needed here because we don't want to make changes in accelerate in order to save the execution_device.
# For the offloaded case, we need the execution device, not just the device where the module is offloaded to.
if hasattr(self, "hf_device_map"):
main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
execution_device_map = {
name: main_device if device in ["cpu", "disk"] else device
for name, device in self.hf_device_map.items()
}
layer_device_map = get_layer_device_map(execution_device_map)

cache_kwargs = {
"config": self.config if hasattr(self.config, "text_config") else self.config,
"max_batch_size": batch_size,
"max_cache_len": max_cache_len,
"device": device,
"dtype": cache_dtype,
"layer_device_map": layer_device_map,
}
self._cache = cache_cls(**cache_kwargs)
if requires_cross_attention_cache:
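To make the helper above concrete, here is a small standalone sketch of the same mapping logic applied to a hypothetical `hf_device_map` with an offloaded layer; the function name and the example map are assumptions for illustration, not part of the diff:

```python
from typing import Dict, Optional, Union

import torch


def layer_device_map_from(hf_device_map: dict, num_hidden_layers: int) -> Optional[Dict[int, Union[str, torch.device, int]]]:
    """Mirrors the logic of `get_layer_device_map` plus the execution-device substitution above."""
    if len(hf_device_map) <= 1:
        return None
    # Offloaded modules ("cpu"/"disk") still execute on the main device.
    main_device = [d for d in hf_device_map.values() if d not in ["cpu", "disk"]][0]
    execution_device_map = {
        name: main_device if device in ["cpu", "disk"] else device
        for name, device in hf_device_map.items()
    }
    layer_device_map = {}
    for name in execution_device_map:
        for idx in range(num_hidden_layers):
            if f".{idx}." in f"{name}.":
                layer_device_map[idx] = execution_device_map[name]
                break
    return layer_device_map


# Hypothetical split: layer 0 on GPU 0, layer 1 offloaded to CPU.
hf_device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": "cpu", "model.norm": 0, "lm_head": 0}
print(layer_device_map_from(hf_device_map, num_hidden_layers=2))
# -> {0: 0, 1: 0}  (the offloaded layer executes on the main device)
```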
85 changes: 85 additions & 0 deletions tests/generation/test_utils.py
@@ -3444,6 +3444,91 @@ def test_special_tokens_fall_back_to_model_default(self):
self.assertTrue(test_bos_id == gen_output[0, 0])
self.assertTrue(generation_config.bos_token_id is None)

@pytest.mark.generate
@require_torch_multi_gpu
def test_generate_with_static_cache_multi_gpu(self):
"""
Tests that the static cache has been set correctly and that generate works correctly when we are using multiple GPUs.
"""
# need to split manually as `device_map="auto"` doesn't work well with an unbalanced model
device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
model = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")

text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)

generation_kwargs = {
"max_new_tokens": 20,
"cache_implementation": "static",
"return_dict_in_generate": True, # Required to return `past_key_values`
}

results = model.generate(input_ids, **generation_kwargs)
self.assertTrue(isinstance(results.past_key_values, StaticCache))

# check device of each layer
key_cache_0 = results.past_key_values.key_cache[0]
value_cache_0 = results.past_key_values.value_cache[0]
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))

key_cache_1 = results.past_key_values.key_cache[1]
value_cache_1 = results.past_key_values.value_cache[1]
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))

@pytest.mark.generate
@require_torch_multi_gpu
def test_init_static_cache_multi_gpu(self):
"""
Tests that the static cache has been set correctly when we initialize it manually in a multi-GPU setup.
"""
# need to split manually as `device_map="auto"` doesn't work well with an unbalanced model
device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
model = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")

text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)

generation_kwargs = {
"max_new_tokens": 20,
"return_dict_in_generate": True, # Required to return `past_key_values`
}

# TODO: We need to raise a warning in case the cache is not set correctly
# with self.assertRaisesRegex(ValueError, "If you are manually initializing the cache"):
# past_key_values = StaticCache(
# config=model.config, batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype
# )
# results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)

# deduced from the device_map: layer 0 on device 0 and layer 1 on device 1
layer_device_map = {0: 0, 1: 1}
past_key_values = StaticCache(
config=model.config,
batch_size=1,
max_cache_len=30,
device=torch_device,
dtype=model.dtype,
layer_device_map=layer_device_map,
)
results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)

# check device of each layer
key_cache_0 = results.past_key_values.key_cache[0]
value_cache_0 = results.past_key_values.value_cache[0]
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))

key_cache_1 = results.past_key_values.key_cache[1]
value_cache_1 = results.past_key_values.value_cache[1]
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))


@require_torch
class TokenHealingTestCase(unittest.TestCase):