Print correct number of key value heads on dimension assertion. (#414)
dstripelis authored Apr 14, 2024 · 1 parent 7e69eb9 · commit f9cbbf7
Showing 6 changed files with 6 additions and 6 deletions.
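
Each of the six changed model loaders gets the same one-line fix in _load_gqa: the assertion already compares the fused QKV weight shape against the per-shard num_key_value_heads, but its error message interpolated the global config.num_key_value_heads, so a failing check reported an "expected" shape that was not the one actually asserted. Below is a minimal sketch of the pattern, assuming the usual tensor-parallel sharding in which head counts are divided by the world size; the sharding arithmetic and the helper name are illustrative assumptions, not code from this commit.

# Hypothetical, simplified sketch of the check this commit touches. Names
# follow the assertion in the diffs below; the head-count sharding is an
# assumption about typical tensor-parallel GQA loaders, not part of this diff.
def check_fused_qkv_shape(weight_shape, config, head_size, world_size):
    # Q, K and V are fused into one column-parallel weight, so each shard
    # holds its slice of the query heads plus its slice of the K/V heads.
    num_heads = config.num_attention_heads // world_size
    num_key_value_heads = config.num_key_value_heads // world_size

    expected = [(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]
    # The fix: interpolate the sharded num_key_value_heads (the value the
    # comparison actually uses), not the global config.num_key_value_heads.
    assert list(weight_shape) == expected, f"{list(weight_shape)} != {expected}"

With illustrative numbers (a hypothetical config with 32 attention heads, 8 key-value heads, head_size 128, hidden_size 4096, and world_size 2), the assertion checks against [(16 + 2 * 4) * 128, 4096] == [3072, 4096]; before this fix, a mismatch would have been reported against (16 + 2 * 8) * 128 == 4096, a value the check never uses.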
@@ -162,7 +162,7 @@ def _load_gqa(config, prefix: str, weights):
     assert list(weight.shape) == [
         (num_heads + 2 * num_key_value_heads) * head_size,
         config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    ], f"{list(weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}"

     if config.attention_bias:
         w = [
@@ -160,7 +160,7 @@ def _load_gqa(config, prefix: str, weights):
     assert list(weight.shape) == [
         (num_heads + 2 * num_key_value_heads) * head_size,
         config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    ], f"{list(weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}"

     return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))

@@ -204,7 +204,7 @@ def _load_gqa(config, prefix: str, weights):
     assert list(weight.shape) == [
         (num_heads + 2 * num_key_value_heads) * head_size,
         config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    ], f"{list(weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}"

     return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))

@@ -210,7 +210,7 @@ def _load_gqa(config, prefix: str, weights):
     assert list(weight.shape) == [
         (num_heads + 2 * num_key_value_heads) * head_size,
         config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    ], f"{list(weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}"

     return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))

@@ -180,7 +180,7 @@ def _load_gqa(config, prefix: str, weights):
     assert list(weight.shape) == [
         (num_heads + 2 * num_key_value_heads) * head_size,
         config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    ], f"{list(weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}"

     return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))

@@ -136,7 +136,7 @@ def _load_gqa(config, prefix: str, weights):
     assert list(weight.shape) == [
         (num_heads + 2 * num_key_value_heads) * head_size,
         config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    ], f"{list(weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}"

     return TensorParallelColumnLinear(get_linear(weight, bias=True, quantize=config.quantize))
