Fix the initialization of the cache when we have multi gpu #33303

Merged (12 commits, Sep 13, 2024)
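For context, here is a minimal sketch of the manual multi-GPU cache initialization this change enables. It mirrors the new test added in tests/generation/test_utils.py below; the checkpoint, device split, and cache sizes are illustrative.

from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

# Illustrative 2-layer checkpoint split across two GPUs (same split as the new test below).
device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
input_ids = tokenizer(["Hello world"], return_tensors="pt").input_ids.to("cuda:0")

# With the model split across GPUs, each layer's cache must live on that layer's device.
# The mapping can be read off `model.hf_device_map` (here: layer 0 -> GPU 0, layer 1 -> GPU 1).
past_key_values = StaticCache(
    config=model.config,
    batch_size=1,
    max_cache_len=30,
    device="cuda:0",
    dtype=model.dtype,
    layer_device_map={0: 0, 1: 1},
)
out = model.generate(input_ids, past_key_values=past_key_values, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))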
76 changes: 65 additions & 11 deletions src/transformers/cache_utils.py
@@ -990,6 +990,9 @@ class StaticCache(Cache):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split between different GPUs.
You can check which layers are mapped to which device by inspecting the associated device_map: `model.hf_device_map`.

Example:

@@ -1020,6 +1023,7 @@ def __init__(
device: torch.device = None,
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super().__init__()
if max_batch_size is not None:
@@ -1047,16 +1051,20 @@ def __init__(
# Note: There will be significant perf decrease if switching to use 5D tensors instead.
cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
for idx in range(config.num_hidden_layers):
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
if layer_device_map is not None:
layer_device = layer_device_map[idx]
else:
layer_device = device
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
# Notes:
# 1. `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
# breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
# it is not needed anyway)
# 2. `torch.export()` requires mutations to be registered as buffers.
if not is_torchdynamo_compiling():
self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
new_layer_key_cache = getattr(self, f"key_cache_{idx}")
new_layer_value_cache = getattr(self, f"value_cache_{idx}")
torch._dynamo.mark_static_address(new_layer_key_cache)
@@ -1089,9 +1097,21 @@ def update(
Return:
A tuple containing the updated key and value states.
"""

for state_str, state_device, self_state_device in [
("key_states", key_states.device, self.key_cache[layer_idx].device),
("value_states", value_states.device, self.value_cache[layer_idx].device),
]:
if state_device != self_state_device:
raise ValueError(
f"Computed {state_str} from layer {layer_idx} is on device {state_device} "
f"whereas stored {state_str} is on device {self_state_device}. "
f"If you are manually initializing the cache, make sure to pass the argument `layer_device_map` if you are using multi-gpu. "
" Otherwise, you can just pass `cache_implementation` in `model.generate()` to correctly initialize the cache."
)
Collaborator

I am really unsure this is worth it for us to run at every forward pass. I know we want to help our users, but we would need to make sure it does not cost us anything.

@gante (Member), Sep 12, 2024

there is a long discussion above, but tl;dr the options are:

  • don't warn at all
  • check devices in update (this implementation)

With torch.compile, these lines should get ignored anyway when called correctly (at tracing time they have the same device). We should benchmark compile to confirm, though. Assuming they have no throughput cost, I think it's a win to have the error.

Member

We can also wrap `update` in a try/except, rather than using an if/else.
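For illustration, such a try/except variant could look roughly like the sketch below. This is not what the PR implements; `checked_update` is a hypothetical helper, and catching RuntimeError is an assumption about how a cross-device index_copy_ would fail.

from transformers import StaticCache

def checked_update(cache: StaticCache, key_states, value_states, layer_idx, cache_kwargs=None):
    # Hypothetical wrapper: let the normal update run, and only translate a failure
    # into a friendlier message instead of checking devices on every forward pass.
    try:
        return cache.update(key_states, value_states, layer_idx, cache_kwargs)
    except RuntimeError as err:  # assumed failure mode for a cross-device index_copy_
        raise ValueError(
            "Cache tensors and computed key/value states are on different devices. "
            "If you initialized the cache manually for a model split across GPUs, "
            "pass `layer_device_map` when constructing it."
        ) from err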


cache_position = cache_kwargs.get("cache_position")
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)

k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]

@@ -1160,6 +1180,9 @@ class SlidingWindowCache(StaticCache):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split between different GPUs.
You can check which layers are mapped to which device by inspecting the associated device_map: `model.hf_device_map`.

Example:

@@ -1190,6 +1213,7 @@ def __init__(
device: torch.device = None,
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super().__init__()
if not hasattr(config, "sliding_window") or config.sliding_window is None:
@@ -1206,6 +1230,7 @@ def __init__(
device=device,
dtype=dtype,
max_batch_size=max_batch_size,
layer_device_map=layer_device_map,
)

def update(
@@ -1215,6 +1240,18 @@ def update(
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor]:
for state_str, state_device, self_state_device in [
("key_states", key_states.device, self.key_cache[layer_idx].device),
("value_states", value_states.device, self.value_cache[layer_idx].device),
]:
if state_device != self_state_device:
raise ValueError(
f"Computed {state_str} from layer {layer_idx} is on device {state_device} "
f"whereas stored {state_str} is on device {self_state_device}. "
f"If you are manually initializing the cache, make sure to pass the argument `layer_device_map` if you are using multi-gpu. "
" Otherwise, you can just pass `cache_implementation` in `model.generate()` to correctly initialize the cache."
)

cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
@@ -1239,7 +1276,6 @@ def update(
v_out = v_out[:, :, indices]

try:
cache_position.to(device=k_out.device)
k_out.index_copy_(2, cache_position, key_states)
v_out.index_copy_(2, cache_position, value_states)
except NotImplementedError:
@@ -1454,6 +1490,9 @@ class HybridCache(Cache):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (torch.dtype, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split between different GPUs.
You can check which layers are mapped to which device by inspecting the associated device_map: `model.hf_device_map`.

Example:

@@ -1484,6 +1523,7 @@ def __init__(
device: Union[torch.device, str] = "cpu",
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super().__init__()
if max_batch_size is not None:
@@ -1521,11 +1561,15 @@ def __init__(
self.head_dim,
)
for i in range(config.num_hidden_layers):
if layer_device_map is not None:
layer_device = layer_device_map[i]
else:
layer_device = device
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
# breaks when updating the cache.
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
@@ -1574,10 +1618,20 @@ def update(
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor]:
for state_str, state_device, self_state_device in [
("key_states", key_states.device, self.key_cache[layer_idx].device),
("value_states", value_states.device, self.value_cache[layer_idx].device),
]:
if state_device != self_state_device:
raise ValueError(
f"Computed {state_str} from layer {layer_idx} is on device {state_device} "
f"whereas stored {state_str} is on device {self_state_device}. "
f"If you are manually initializing the cache, make sure to pass the argument `layer_device_map` if you are using multi-gpu. "
" Otherwise, you can just pass `cache_implementation` in `model.generate()` to correctly initialize the cache."
)

cache_position = cache_kwargs.get("cache_position")
sliding_window = cache_kwargs.get("sliding_window")
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
if sliding_window:
27 changes: 27 additions & 0 deletions src/transformers/generation/utils.py
@@ -1449,12 +1449,39 @@ def _get_cache(
# models. May cause trouble with non-text modalities.
cache_dtype = self.get_output_embeddings().weight.dtype

def get_layer_device_map(execution_device_map: Optional[dict] = None):
if execution_device_map is None or len(execution_device_map) <= 1:
return None
layer_device_map = {}
for layer in execution_device_map:
for idx in range(self.config.num_hidden_layers):
if f".{idx}." in f"{layer}.":
layer_device_map[idx] = execution_device_map[layer]
break
for idx in range(self.config.num_hidden_layers):
if idx not in layer_device_map:
raise RuntimeError(f"layer {idx} has not been mapped to a device.")
return layer_device_map

execution_device_map = None
# Taken from dispatch_model from accelerate.
# This is needed here so that we don't have to modify accelerate in order to save the execution_device.
# For the offloaded case, we need the execution device, not just the device the weights are offloaded to.
if hasattr(self, "hf_device_map"):
main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
execution_device_map = {
name: main_device if device in ["cpu", "disk"] else device
for name, device in self.hf_device_map.items()
}
layer_device_map = get_layer_device_map(execution_device_map)

cache_kwargs = {
"config": self.config,
"batch_size": batch_size,
"max_cache_len": max_cache_len,
"device": device,
"dtype": cache_dtype,
"layer_device_map": layer_device_map,
}
self._cache = cache_cls(**cache_kwargs)
if requires_cross_attention_cache:
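To make the mapping above concrete, here is a small worked example of how an `hf_device_map` is turned into the `layer_device_map` handed to the cache, following the same logic as `_get_cache`. The device map values are illustrative.

# Illustrative device map for a 2-layer model, with the final norm offloaded to CPU.
hf_device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": "cpu", "lm_head": 0}
num_hidden_layers = 2

# Offloaded ("cpu"/"disk") modules still execute on the main device, so remap them.
main_device = [d for d in hf_device_map.values() if d not in ["cpu", "disk"]][0]
execution_device_map = {
    name: main_device if device in ["cpu", "disk"] else device
    for name, device in hf_device_map.items()
}

# Match each hidden layer index to the device of the module that contains it.
layer_device_map = {}
for name in execution_device_map:
    for idx in range(num_hidden_layers):
        if f".{idx}." in f"{name}.":
            layer_device_map[idx] = execution_device_map[name]
            break

print(layer_device_map)  # {0: 0, 1: 1}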
84 changes: 84 additions & 0 deletions tests/generation/test_utils.py
@@ -3381,6 +3381,90 @@ def test_special_tokens_fall_back_to_model_default(self):
self.assertTrue(test_bos_id == gen_output[0, 0])
self.assertTrue(generation_config.bos_token_id is None)

@pytest.mark.generate
@require_torch_multi_gpu
def test_generate_with_static_cache_multi_gpu(self):
"""
Tests that the static cache is set up correctly and that generate works correctly when using multiple GPUs.
"""
# need to split manually, as "auto" doesn't work well with an unbalanced model
device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
model = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")

text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)

generation_kwargs = {
"max_new_tokens": 20,
"cache_implementation": "static",
"return_dict_in_generate": True, # Required to return `past_key_values`
}

results = model.generate(input_ids, **generation_kwargs)
self.assertTrue(isinstance(results.past_key_values, StaticCache))

# check device of each layer
key_cache_0 = results.past_key_values.key_cache[0]
value_cache_0 = results.past_key_values.value_cache[0]
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))

key_cache_1 = results.past_key_values.key_cache[1]
value_cache_1 = results.past_key_values.value_cache[1]
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))

@pytest.mark.generate
@require_torch_multi_gpu
def test_init_static_cache_multi_gpu(self):
"""
Tests that the static cache is set up correctly when we initialize it manually in a multi-GPU setup.
"""
# need to split manually, as "auto" doesn't work well with an unbalanced model
device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
model = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")

text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)

generation_kwargs = {
"max_new_tokens": 20,
"return_dict_in_generate": True, # Required to return `past_key_values`
}

with self.assertRaisesRegex(ValueError, "If you are manually initializing the cache"):
past_key_values = StaticCache(
config=model.config, batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype
)
results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)

# deduced from the device_map: layer 0 on device 0 and layer 1 on device 1
layer_device_map = {0: 0, 1: 1}
past_key_values = StaticCache(
config=model.config,
batch_size=1,
max_cache_len=30,
device=torch_device,
dtype=model.dtype,
layer_device_map=layer_device_map,
)
results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)

# check device of each layer
key_cache_0 = results.past_key_values.key_cache[0]
value_cache_0 = results.past_key_values.value_cache[0]
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))

key_cache_1 = results.past_key_values.key_cache[1]
value_cache_1 = results.past_key_values.value_cache[1]
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))


@require_torch
class TokenHealingTestCase(unittest.TestCase):