a #4 (Open)

wants to merge 10 commits into base: corvo
28 changes: 26 additions & 2 deletions .vscode/settings.json
@@ -67,6 +67,30 @@
"unordered_set": "cpp",
"future": "cpp",
"cfenv": "cpp",
"typeindex": "cpp"
"typeindex": "cpp",
"__bit_reference": "cpp",
"__bits": "cpp",
"__config": "cpp",
"__debug": "cpp",
"__errc": "cpp",
"__hash_table": "cpp",
"__locale": "cpp",
"__mutex_base": "cpp",
"__node_handle": "cpp",
"__split_buffer": "cpp",
"__threading_support": "cpp",
"__tree": "cpp",
"__tuple": "cpp",
"__verbose_abort": "cpp",
"bit": "cpp",
"ios": "cpp",
"locale": "cpp",
"queue": "cpp",
"stack": "cpp",
"variant": "cpp",
"__nullptr": "cpp",
"__string": "cpp",
"compare": "cpp",
"concepts": "cpp"
}
}
11 changes: 11 additions & 0 deletions CMakeLists.txt
@@ -333,6 +333,8 @@ add_library(transformer-shared SHARED
$<TARGET_OBJECTS:FfnLayer>
$<TARGET_OBJECTS:FusedAttentionLayer>
$<TARGET_OBJECTS:GptContextAttentionLayer>
$<TARGET_OBJECTS:LlamaContextAttentionLayer>
$<TARGET_OBJECTS:LlamaDecoderSelfAttentionLayer>
$<TARGET_OBJECTS:GptJ>
$<TARGET_OBJECTS:GptJContextDecoder>
$<TARGET_OBJECTS:GptJDecoder>
@@ -353,6 +355,12 @@ add_library(transformer-shared SHARED
$<TARGET_OBJECTS:ParallelGptDecoderLayerWeight>
$<TARGET_OBJECTS:ParallelGptTritonBackend>
$<TARGET_OBJECTS:ParallelGptWeight>
$<TARGET_OBJECTS:Llama>
$<TARGET_OBJECTS:LlamaContextDecoder>
$<TARGET_OBJECTS:LlamaDecoder>
$<TARGET_OBJECTS:LlamaDecoderLayerWeight>
$<TARGET_OBJECTS:LlamaTritonBackend>
$<TARGET_OBJECTS:LlamaWeight>
$<TARGET_OBJECTS:T5Common>
$<TARGET_OBJECTS:T5Decoder>
$<TARGET_OBJECTS:T5Decoding>
@@ -361,6 +369,8 @@ add_library(transformer-shared SHARED
$<TARGET_OBJECTS:T5EncoderTritonBackend>
$<TARGET_OBJECTS:TensorParallelDecoderCrossAttentionLayer>
$<TARGET_OBJECTS:TensorParallelDecoderSelfAttentionLayer>
$<TARGET_OBJECTS:TensorParallelLlamaDecoderSelfAttentionLayer>
$<TARGET_OBJECTS:TensorParallelLlamaContextAttentionLayer>
$<TARGET_OBJECTS:TensorParallelDisentangledAttentionLayer>
$<TARGET_OBJECTS:TensorParallelGeluFfnLayer>
$<TARGET_OBJECTS:TensorParallelSiluFfnLayer>
@@ -394,6 +404,7 @@ add_library(transformer-shared SHARED
$<TARGET_OBJECTS:fpA_intB_gemm>
$<TARGET_OBJECTS:gen_relative_pos_bias>
$<TARGET_OBJECTS:gpt_kernels>
$<TARGET_OBJECTS:repeat_kv_kernels>
$<TARGET_OBJECTS:int8_gemm>
$<TARGET_OBJECTS:layernorm_int8_kernels>
$<TARGET_OBJECTS:layernorm_kernels>
1 change: 1 addition & 0 deletions examples/cpp/CMakeLists.txt
@@ -25,6 +25,7 @@ add_subdirectory(vit_int8)
add_subdirectory(wenet)

add_subdirectory(gptj)
add_subdirectory(llama)
add_subdirectory(gptneox)
add_subdirectory(multi_gpu_gpt)

22 changes: 22 additions & 0 deletions examples/cpp/llama/CMakeLists.txt
@@ -0,0 +1,22 @@
# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

add_executable(llama_example llama_example.cc)
target_link_libraries(llama_example PUBLIC -lcublas -lcublasLt -lcudart
Llama nvtx_utils gpt_example_utils word_list mpi_utils nccl_utils)

add_executable(llama_triton_example llama_triton_example.cc)
target_link_libraries(llama_triton_example PUBLIC -lcublas -lcublasLt -lcudart -lpthread
LlamaTritonBackend TransformerTritonBackend custom_ar_comm
gpt_example_utils word_list mpi_utils nccl_utils nvtx_utils)
2 changes: 2 additions & 0 deletions examples/cpp/llama/bad_words.csv
@@ -0,0 +1,2 @@
7768,3908
1,2
16 changes: 16 additions & 0 deletions examples/cpp/llama/check_with_huggingface.py
@@ -0,0 +1,16 @@
import transformers

from transformers import LlamaForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained('/data/llama-7b-hf')

prompt = "Hey, are you consciours? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors='pt')
model = LlamaForCausalLM.from_pretrained("/data/llama-7b-hf")
hf_config = vars(model.config)
print(hf_config)
# forward() returns model outputs (logits, hidden states), not generated ids
outputs = model.forward(inputs.input_ids, output_hidden_states=True)
print(outputs)

tokens = [0,18637,29892,526,366,1136,455,2470,29973,1815,366,5193,304,592,29973,18637,29892,526,366,1136,455,2470,29973,1815,366,5193,304,592,29973,18637,29892,526,366,1136,455,2470,29973,1815,366,5193,304,592,29973,18637,29892,526,366]
print(tokenizer.decode(tokens))
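For comparison, here is a minimal sketch (not part of this PR; it assumes the same local checkpoint path as the script above) of producing reference tokens with HuggingFace's generate API instead of a single forward pass, so the hard-coded token list above can be checked end to end:

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained('/data/llama-7b-hf')
model = LlamaForCausalLM.from_pretrained('/data/llama-7b-hf')

prompt = "Hey, are you consciours? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors='pt')

# Greedy decoding; the resulting ids can be compared against the token list above.
with torch.no_grad():
    output_ids = model.generate(inputs.input_ids, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output_ids[0]))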
233 changes: 233 additions & 0 deletions examples/cpp/llama/huggingface_llama_convert.py
@@ -0,0 +1,233 @@
# Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import configparser
import numpy as np
from pathlib import Path

import torch
import os
from transformers import LlamaForCausalLM, AutoConfig

def get_weight_data_type(data_type):
if data_type == "fp32":
return np.float32
elif data_type == "fp16":
return np.float16
else:
assert False, f"Invalid weight data type {data_type}"


def split_and_convert_process(saved_dir, factor, key, val):
if key.find("input_layernorm.weight") != -1 or key.find("post_attention_layernorm.weight") != -1:
# shared weights, only need to convert the weights of rank 0
saved_path = saved_dir + "/" + key + ".bin"
val.tofile(saved_path)
elif key.find("attention.dense.weight") != -1 or key.find("mlp.down_proj.weight") != -1:
split_vals = np.split(val, factor, axis=0)
for j in range(factor):
saved_path = saved_dir + "/" + key + ".%d.bin" % j
split_vals[j].tofile(saved_path)
elif key.find("mlp.gate_proj.weight") != -1 or key.find("mlp.up_proj.weight") != -1:
split_vals = np.split(val, factor, axis=-1)
for j in range(factor):
saved_path = saved_dir + "/" + key + ".%d.bin" % j
split_vals[j].tofile(saved_path)
elif key.find("attention.query_key_value.weight") != -1:
split_vals = np.split(val, factor, axis=-1)
for j in range(factor):
saved_path = saved_dir + "/" + key + ".%d.bin" % j
split_vals[j].tofile(saved_path)
else:
print("[ERROR] cannot find key '{}'".format(key))

def split_and_convert(args):
saved_dir = args.saved_dir + "/%d-gpu/" % args.infer_gpu_num

if not os.path.exists(saved_dir):
os.makedirs(saved_dir)

t_gpu_num = args.trained_gpu_num
i_gpu_num = args.infer_gpu_num
assert(i_gpu_num % t_gpu_num == 0)

factor = i_gpu_num // t_gpu_num
# load position_embedding from rank 0
# model = torch.load(ckpt_name)
print(f'load model from {args.in_file}')
# model = LlamaForCausalLM.from_pretrained(args.in_file, device_map='auto')
config = AutoConfig.from_pretrained(args.in_file)
# num_layers = 3
# config.num_hidden_layers = num_layers
print(config)
state_dict = {}
for f in os.listdir(args.in_file):
if not f.endswith('.bin'):
continue
w = torch.load(os.path.join(args.in_file, f), map_location='cpu')
keys = list(w.keys())
for k in keys:
if 'model.layers.' not in k:
continue
l = int(k.split('.')[2])
if l < config.num_hidden_layers:
continue
del w[k]
state_dict.update(w)

model = LlamaForCausalLM.from_pretrained(None, config=config, state_dict=state_dict)
hf_config = vars(model.config)
print(f"hf_config: {hf_config}")

print("named parameters:")
for name, param in model.named_parameters():
print(f"- {name}")

hidden_size = hf_config["hidden_size"]
head_num = hf_config["num_attention_heads"]
kv_head_num = hf_config["num_key_value_heads"]
head_size = hidden_size // head_num
# num_layers = hf_config["num_hidden_layers"]


np_weight_data_type = get_weight_data_type(args.weight_data_type)

try:
model_name = args.model_name
config = configparser.ConfigParser()
config['llama'] = {}
config['llama']['model_name'] = model_name
config['llama']["head_num"] = str(head_num)
config['llama']["kv_head_num"] = str(kv_head_num)
config['llama']["size_per_head"] = str(head_size)
config['llama']["inter_size"] = str(hf_config["intermediate_size"])
config['llama']["num_layer"] = str(num_layers)
config['llama']["rotary_embedding"] = str(head_size)
config['llama']['layernorm_eps'] = str(hf_config["rms_norm_eps"])
config['llama']["vocab_size"] = str(hf_config["vocab_size"])
config['llama']["start_id"] = str(hf_config["bos_token_id"])
config['llama']["end_id"] = str(hf_config["eos_token_id"])
config['llama']["weight_data_type"] = args.weight_data_type

with open((Path(saved_dir) / f"config.ini").as_posix(), 'w') as configfile:
config.write(configfile)
except Exception as e:
print(f"Fail to save the config in config.ini.")
print(e)

param_to_weights = lambda param: param.detach().cpu().numpy().astype(np_weight_data_type)

# layer-wise weights, example:
# - model.layers.0.self_attn.q_proj.weight
# - model.layers.0.self_attn.k_proj.weight
# - model.layers.0.self_attn.v_proj.weight
# - model.layers.0.self_attn.o_proj.weight
# - model.layers.0.mlp.gate_proj.weight
# - model.layers.0.mlp.down_proj.weight
# - model.layers.0.mlp.up_proj.weight
# - model.layers.0.input_layernorm.weight
# - model.layers.0.post_attention_layernorm.weight
for l in range(num_layers):
print(f"converting layer {l}")
# first merge QKV into a single weight
# concat direct to FT shape: [hidden_size, 3, head_num, head_size]
# copied from huggingface_gptj_ckpt_convert.py
# qkv_weights = np.stack([
# param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.q_proj.weight']),
# param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.k_proj.weight']),
# param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.v_proj.weight']),
# ])
# qkv_weights = np.transpose(qkv_weights, (2, 0, 1))
q_proj = param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.q_proj.weight'])
k_proj = param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.k_proj.weight'])
v_proj = param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.v_proj.weight'])
q_proj = np.split(q_proj, factor, axis=0)
k_proj = np.split(k_proj, factor, axis=0)
v_proj = np.split(v_proj, factor, axis=0)
for j in range(factor):
qkv_weights = np.concatenate((q_proj[j], k_proj[j], v_proj[j]), axis=0)
print(qkv_weights.shape)
# qkv_weights = np.transpose(qkv_weights, (2, 0, 1))
qkv_weights = np.transpose(qkv_weights)
qkv_weights_base_name = f'model.layers.{l}.attention.query_key_value.weight'
saved_path = saved_dir + "/" + qkv_weights_base_name + ".%d.bin" % j
qkv_weights.tofile(saved_path)
# qkv_weights = np.concatenate((
# param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.q_proj.weight']),
# param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.k_proj.weight']),
# param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.v_proj.weight']),
# ), axis=0)
# print(qkv_weights.shape)
# # qkv_weights = np.transpose(qkv_weights, (2, 0, 1))
# qkv_weights = np.transpose(qkv_weights)
# qkv_weights_base_name = f'model.layers.{l}.attention.query_key_value.weight'
# split_and_convert_process(saved_dir, factor, qkv_weights_base_name, qkv_weights)

# attention dense
o_weight = param_to_weights(model.state_dict()[f'model.layers.{l}.self_attn.o_proj.weight']).T
o_weight_base_name = f'model.layers.{l}.attention.dense.weight'
split_and_convert_process(saved_dir, factor, o_weight_base_name, o_weight)

# MLP
mlp_down_weight = param_to_weights(model.state_dict()[f'model.layers.{l}.mlp.down_proj.weight']).T
mlp_down_base_name = f'model.layers.{l}.mlp.down_proj.weight'
split_and_convert_process(saved_dir, factor, mlp_down_base_name, mlp_down_weight)

mlp_gate_weight = param_to_weights(model.state_dict()[f'model.layers.{l}.mlp.gate_proj.weight']).T
mlp_gate_base_name = f'model.layers.{l}.mlp.gate_proj.weight'
split_and_convert_process(saved_dir, factor, mlp_gate_base_name, mlp_gate_weight)

mlp_up_weight = param_to_weights(model.state_dict()[f'model.layers.{l}.mlp.up_proj.weight']).T
mlp_up_base_name = f'model.layers.{l}.mlp.up_proj.weight'
split_and_convert_process(saved_dir, factor, mlp_up_base_name, mlp_up_weight)

# LayerNorm
input_ln_weight = param_to_weights(model.state_dict()[f'model.layers.{l}.input_layernorm.weight'])
input_ln_base_name = f'model.layers.{l}.input_layernorm.weight'
split_and_convert_process(saved_dir, factor, input_ln_base_name, input_ln_weight)

post_attn_ln_weight = param_to_weights(model.state_dict()[f'model.layers.{l}.post_attention_layernorm.weight'])
post_attn_ln_base_name = f'model.layers.{l}.post_attention_layernorm.weight'
split_and_convert_process(saved_dir, factor, post_attn_ln_base_name, post_attn_ln_weight)

print(f"done layer {l}")


# final common weights
for name, param in model.named_parameters():
if name == 'model.embed_tokens.weight':
param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.wte.weight.bin")
elif name == 'model.norm.weight':
param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.final_layernorm.weight.bin")
elif name == 'lm_head.weight':
param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.lm_head.weight.bin")


if __name__ == "__main__":
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument('-saved_dir', '-o', type=str, help='path to the output directory', required=True)
parser.add_argument('-in_file', '-i', type=str, help='path to the input HuggingFace checkpoint directory', required=True)
parser.add_argument('-trained_gpu_num', '-t_g', type=int, help='How many gpus were used for training', default=1)
parser.add_argument('-infer_gpu_num', '-i_g', type=int, help='How many gpus for inference', required=True)
parser.add_argument("-weight_data_type", type=str, default="fp32", choices=["fp32", "fp16", "bf16"])
parser.add_argument('-model_name', '-m_n', type=str, help='model name', required=True)

args = parser.parse_args()
print("\n=============== Argument ===============")
for key in vars(args):
print("{}: {}".format(key, vars(args)[key]))
print("========================================")

split_and_convert(args)
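As a rough, self-contained illustration (not part of this PR; the shapes and the 2-way split factor are made up) of the splitting convention used by split_and_convert_process above, where attention.dense and mlp.down_proj are split along axis 0 of the transposed [in, out] weight while gate_proj, up_proj, and query_key_value are split along the last (output) axis:

import numpy as np

factor = 2      # infer_gpu_num // trained_gpu_num
hidden = 8      # toy hidden size, illustrative only

# attention.dense / mlp.down_proj: stored transposed as [in, out],
# each rank receives a slice of the input dimension (axis 0).
dense = np.zeros((hidden, hidden), dtype=np.float32)
dense_shards = np.split(dense, factor, axis=0)    # two (4, 8) shards

# mlp.gate_proj / mlp.up_proj / attention.query_key_value:
# each rank receives a slice of the output dimension (last axis).
gate = np.zeros((hidden, hidden), dtype=np.float32)
gate_shards = np.split(gate, factor, axis=-1)     # two (8, 4) shards

print([s.shape for s in dense_shards], [s.shape for s in gate_shards])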