fix dtype of config.head_dim to int for directml llm example (#1159)
Fix `config.head_dim` to be of type `int`. The change in #1138 made the `head_dim` parameter a `float`.
thevishalagarwal authored May 15, 2024
1 parent 5b1a79b commit 8467d56
Showing 1 changed file with 1 addition and 1 deletion: examples/directml/llm/llm.py
```diff
@@ -52,7 +52,7 @@ def set_config_parameters(tokenizer: transformers.AutoTokenizer, repo_id: str, n

     config.hidden_size = llm_model.config.hidden_size
     config.num_heads = llm_model.config.num_attention_heads
-    config.head_dim = getattr(llm_model.config, "head_dim", config.hidden_size / config.num_heads)
+    config.head_dim = getattr(llm_model.config, "head_dim", config.hidden_size // config.num_heads)
     config.num_layers = num_layers or llm_model.config.num_hidden_layers
     config.vocab_size = llm_model.config.vocab_size
     config.model_type = main_model.config.model_type
```
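For context, a minimal standalone sketch (not taken from the repository) of why this one-character change matters: Python's true division `/` always produces a `float`, even when both operands are `int`s and the result is exact, while floor division `//` on two `int`s produces an `int`. The values below are illustrative, not from any specific model config.

```python
# Illustrative values only; real models define their own
# hidden_size and num_attention_heads in their config.
hidden_size = 4096
num_heads = 32

head_dim_true = hidden_size / num_heads    # true division -> float
head_dim_floor = hidden_size // num_heads  # floor division -> int

assert isinstance(head_dim_true, float)    # 128.0
assert isinstance(head_dim_floor, int)     # 128

# A float head_dim fails wherever it is used as a tensor dimension,
# since shape arguments must be integers:
try:
    shape = (1, num_heads, head_dim_true)
    list(range(int(num_heads)))  # fine
    _ = [0] * head_dim_true      # TypeError: can't multiply by float
except TypeError as e:
    print("float head_dim rejected:", e)
```

Using `getattr(..., "head_dim", hidden_size // num_heads)` keeps the fallback an `int`, matching the type of a `head_dim` attribute when the model config provides one directly.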
