[INFER] update tune_cublaslt_gemm op and fix some bugs (PaddlePaddle#…
yuanlehome authored Oct 11, 2024
1 parent f9eb62e commit 156182e
Showing 6 changed files with 115 additions and 65 deletions.
104 changes: 54 additions & 50 deletions csrc/gpu/tune_cublaslt_gemm.cu
@@ -18,11 +18,11 @@ limitations under the License. */

#include <algorithm>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <list>
#include <vector>
#include <iomanip>

#include "helper.h"

@@ -105,6 +105,13 @@ static inline bool time_compare_algo_para(const algoSelect_t& algo_para_a,
return (algo_para_a.time < algo_para_b.time);
}

// Get the remaining (free) GPU memory on the current device, in bytes.
size_t get_remaining_memory() {
size_t free, total;
CUDA_CHECK(cudaMemGetInfo(&free, &total));
return free;
}

template <typename InT, typename OutT, typename ScaleT = OutT>
static void TestMatmulRun(cublasLtHandle_t ltHandle,
cublasLtMatmulDesc_t matmulDesc,
@@ -122,7 +129,10 @@ static void TestMatmulRun(cublasLtHandle_t ltHandle,
cublasLtMatmulHeuristicResult_t heurResult;
cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(
ltHandle, matmulDesc, A_desc, B_desc, C_desc, C_desc, &algo, &heurResult);
if (algoStatus == CUBLAS_STATUS_SUCCESS) {

auto remainingMemorySize = 0.95 * get_remaining_memory();
if (algoStatus == CUBLAS_STATUS_SUCCESS &&
remainingMemorySize > heurResult.workspaceSize) {
ScaleT alpha = static_cast<ScaleT>(1), beta = static_cast<ScaleT>(0);
void* workSpace;
CUDA_CHECK(cudaMalloc(&workSpace, heurResult.workspaceSize));
@@ -166,8 +176,13 @@ static void TestMatmulRun(cublasLtHandle_t ltHandle,
}
CUDA_CHECK(cudaFree(workSpace));
} else {
std::cerr << "not enough workspace! current workspace is "
<< heurResult.workspaceSize;
std::cerr << "Not enough workspace! Required "
<< static_cast<double>(heurResult.workspaceSize) / 1024.0 /
1024.0 / 1024.0
<< " GiB" << ", But remaining "
<< static_cast<double>(remainingMemorySize) / 1024.0 / 1024.0 /
1024.0
<< " GiB" << std::endl;
perfResults.status = CUBLAS_STATUS_NOT_SUPPORTED; // Not enough workspace
}
}
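
The new guard only runs an algorithm when 95% of the currently free device memory can hold its workspace, using get_remaining_memory() above. If you want to sanity-check that free-memory figure from Python before launching the tuner, one illustrative option (assuming the pynvml package is installed; it is not used by this op) is:

import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # GPU 0
free_bytes = pynvml.nvmlDeviceGetMemoryInfo(handle).free
print(f"remaining GPU memory: {free_bytes / 1024**3:.2f} GiB")
pynvml.nvmlShutdown()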
@@ -442,7 +457,7 @@ void FindAlgo(const cublasLtHandle_t& ltHandle,
if (perfResults[i].status != CUBLAS_STATUS_SUCCESS) {
std::clog << "algo " << algos[i].algoId << " tile " << algos[i].tile
<< " stages " << algos[i].stages << " splitK_val "
<< algos[i].splitK_val;
<< algos[i].splitK_val << std::endl;
algos[i].time = std::numeric_limits<float>::max();
std::cerr << " TestMatmulRun with status " << perfResults[i].status
<< std::endl;
@@ -467,7 +482,7 @@ class DevContext {};
class CPUContext : public DevContext {};

class CUBLASLTContext : public DevContext {
public:
public:
CUBLASLTContext() { CUDA_CHECK(cublasLtCreate(&handle)); }

cublasLtHandle_t handle;
@@ -709,64 +724,51 @@ void GEMMInt8<int8_t, int32_t, CUBLASLTContext>(const CUBLASLTContext& dev_ctx,
CUDA_CHECK(cudaFree(workSpace));
}

void TuneCublasltGemm(const paddle::Tensor& M,
const paddle::Tensor& K,
void TuneCublasltGemm(const paddle::Tensor& K,
const paddle::Tensor& N,
const int M_start,
const int M_end,
const std::string& dtype,
bool is_test,
bool is_read_from_file,
const bool is_test,
const bool is_read_from_file,
const std::string& path) {
// Ensure that K and N are one-dimensional Tensors, and that exactly one of
// is_test / is_read_from_file is set.
assert(M.dims().size() == 1 && K.dims().size() == 1 && N.dims().size() == 1);
assert(M_end >= M_start);
assert(M_start >= 1);
assert(K.dims().size() == 1 && N.dims().size() == 1);
assert(is_test != is_read_from_file);

auto M_cpu = M.copy_to(paddle::CPUPlace(), false);
auto K_cpu = K.copy_to(paddle::CPUPlace(), false);
auto N_cpu = N.copy_to(paddle::CPUPlace(), false);
int64_t* M_data = M_cpu.data<int64_t>();
int64_t* K_data = K_cpu.data<int64_t>();
int64_t* N_data = N_cpu.data<int64_t>();

int M_size = M.numel();
int K_size = K.numel();
int N_size = N.numel();
assert(K_size == N_size);

int m_data = (int)M_data[0];
assert(m_data > 0);

std::vector<int> mm;

int m = 1, step = 1;
while (m <= m_data) {
mm.push_back(m);
m += step;

int m = M_start, step = 1;
while (m <= M_end) {
// update step
switch (m) {
case 4:
step = 4;
break;
case 16:
step = 16;
break;
case 64:
step = 32;
break;
case 256:
step = 64;
break;
case 512:
step = 128;
break;
case 1024:
step = 1024;
break;
case 8192:
step = 4096;
break;
if (m >= 8192) {
step = 4096;
} else if (m >= 1024) {
step = 1024;
} else if (m >= 512) {
step = 128;
} else if (m >= 256) {
step = 64;
} else if (m >= 64) {
step = 32;
} else if (m >= 16) {
step = 16;
} else if (m >= 4) {
step = 4;
} else {
step = 1;
}
mm.push_back(m);
m += step;
}

for (int j = 0; j < mm.size(); j++) {
@@ -792,16 +794,18 @@ void TuneCublasltGemm(const paddle::Tensor& M,
path);
} else {
// other dtype
std::cout << "Not currently supported" << std::endl;
throw std::runtime_error(dtype + " is not currently supported");
}
}
}
}

PD_BUILD_OP(tune_cublaslt_gemm)
.Inputs({"M", "K", "N"})
.Inputs({"K", "N"})
.Outputs({})
.Attrs({"dtype: std::string",
.Attrs({"M_start: int",
"M_end: int",
"dtype: std::string",
"is_test: bool",
"is_read_from_file: bool",
"path: std::string"})
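
For reference, TuneCublasltGemm now samples M non-uniformly between M_start and M_end with the coarsening step schedule shown above. A minimal Python sketch of the equivalent logic (m_candidates is a hypothetical helper used only for illustration):

def m_candidates(m_start, m_end):
    # Reproduce the step schedule from TuneCublasltGemm: small steps for
    # small M, progressively coarser steps as M grows.
    mm, m = [], m_start
    while m <= m_end:
        if m >= 8192:
            step = 4096
        elif m >= 1024:
            step = 1024
        elif m >= 512:
            step = 128
        elif m >= 256:
            step = 64
        elif m >= 64:
            step = 32
        elif m >= 16:
            step = 16
        elif m >= 4:
            step = 4
        else:
            step = 1
        mm.append(m)
        m += step
    return mm

# m_candidates(1, 32768) -> [1, 2, 3, 4, 8, 12, 16, 32, 48, 64, 96, ..., 28672, 32768]

With M_start=1 and M_end=32768 this yields a few dozen candidate M values rather than tuning every integer in the range.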
11 changes: 8 additions & 3 deletions csrc/utils/tune_cublaslt_int8_gemm.py
@@ -15,7 +15,8 @@
import paddle
from paddlenlp_ops import tune_cublaslt_gemm

M_tensor = paddle.to_tensor([32768])
M_start = 1
M_end = 32768

# llama3.1-8b
k1 = [4096, 4096, 4096, 14336]
@@ -36,7 +37,11 @@
K_tensor = paddle.to_tensor(k1 + k2 + k3 + k4)
N_tensor = paddle.to_tensor(n1 + n2 + n3 + n4)

Dtype = "int8"
Path = "./cublaslt_gemm_search.csv"

tune_cublaslt_gemm(M_tensor, K_tensor, N_tensor, Dtype, True, False, Path)
tune_cublaslt_gemm(K_tensor, N_tensor, M_start, M_end, "int8", True, False, Path)

# shape calculation formulas
# [qkv, out_linear, ffn1, ffn2]
# k = [hidden_size, hidden_size, hidden_size, intermediate_size//mp_size]
# n = [((num_attention_heads//mp_size)+2*(num_key_value_heads//mp_size))*(hidden_size//num_attention_heads), hidden_size, 2*(intermediate_size//mp_size), hidden_size]
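
As a quick illustration of these formulas, the sketch below recomputes the k/n lists for a single model. The config numbers (standard Llama-3.1-8B with mp_size=1) are assumptions chosen for the example; the k values should match the k1 list above:

hidden_size, intermediate_size = 4096, 14336
num_attention_heads, num_key_value_heads = 32, 8
mp_size = 1
head_dim = hidden_size // num_attention_heads

k = [hidden_size, hidden_size, hidden_size, intermediate_size // mp_size]
n = [
    ((num_attention_heads // mp_size) + 2 * (num_key_value_heads // mp_size)) * head_dim,
    hidden_size,
    2 * (intermediate_size // mp_size),
    hidden_size,
]
# k -> [4096, 4096, 4096, 14336], n -> [6144, 4096, 28672, 4096]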
18 changes: 14 additions & 4 deletions paddlenlp/experimental/transformers/llama/modeling.py
@@ -876,6 +876,8 @@ def set_state_dict(self, state_dict):
ffn_hidden_size=self.intermediate_size,
num_key_value_heads=self.num_key_value_heads,
mp_size=self.config.tensor_parallel_degree,
concat_qkv=True,
concat_ffn1=True,
)
self.transformer_block.weight_scales = weight_scales_loader.scale
self.transformer_block.act_scales = act_scale_loader.scale
@@ -1097,16 +1099,24 @@ def set_state_dict(self, state_dict):
dtype=paddle.get_default_dtype(),
)
self.transformer_block.linear_shifts[idx].set_value(
paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.shift_bias".format(idx)])
paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.shift_bias".format(idx)]).astype(
paddle.get_default_dtype()
)
)
self.transformer_block.linear_smooths[idx].set_value(
paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.smooth_weight".format(idx)])
paddle.to_tensor(
state_dict["llama.layers.{}.self_attn.o_proj.smooth_weight".format(idx)]
).astype(paddle.get_default_dtype())
)
self.transformer_block.ffn2_shifts[idx].set_value(
paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.shift_bias".format(idx)])
paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.shift_bias".format(idx)]).astype(
paddle.get_default_dtype()
)
)
self.transformer_block.ffn2_smooths[idx].set_value(
paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.smooth_weight".format(idx)])
paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.smooth_weight".format(idx)]).astype(
paddle.get_default_dtype()
)
)

if self.shift:
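
The llama changes above, and the mixtral and qwen2 changes below, all apply the same fix: tensors read from the state dict are cast to the runtime default dtype before set_value, so float32 shift/smooth scales from a checkpoint can be assigned to float16/bfloat16 parameters. A minimal, self-contained sketch of the pattern (the Block class and tensor values are illustrative stand-ins, not PaddleNLP code):

import paddle

paddle.set_default_dtype("float16")

class Block(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        # stand-in for e.g. transformer_block.linear_shifts[idx]
        self.linear_shift = self.create_parameter(shape=[4096], dtype=paddle.get_default_dtype())

block = Block()
loaded = paddle.ones([4096], dtype="float32")  # checkpoints typically store these scales as float32
# cast to the runtime default dtype before assignment to avoid a dtype mismatch
block.linear_shift.set_value(loaded.astype(paddle.get_default_dtype()))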
16 changes: 12 additions & 4 deletions paddlenlp/experimental/transformers/mixtral/modeling.py
@@ -716,16 +716,24 @@ def set_state_dict(self, state_dict):
if "a8w8" in self.quant_type:
if self.shift_smooth_all_linears:
self.transformer_block.linear_shifts[idx].set_value(
paddle.to_tensor(state_dict["mixtral.layers.{}.self_attn.o_proj.shift_bias".format(idx)])
paddle.to_tensor(
state_dict["mixtral.layers.{}.self_attn.o_proj.shift_bias".format(idx)]
).astype(paddle.get_default_dtype())
)
self.transformer_block.linear_smooths[idx].set_value(
paddle.to_tensor(state_dict["mixtral.layers.{}.self_attn.o_proj.smooth_weight".format(idx)])
paddle.to_tensor(
state_dict["mixtral.layers.{}.self_attn.o_proj.smooth_weight".format(idx)]
).astype(paddle.get_default_dtype())
)
self.transformer_block.ffn2_shifts[idx].set_value(
paddle.to_tensor(state_dict["mixtral.layers.{}.mlp.down_proj.shift_bias".format(idx)])
paddle.to_tensor(state_dict["mixtral.layers.{}.mlp.down_proj.shift_bias".format(idx)]).astype(
paddle.get_default_dtype()
)
)
self.transformer_block.ffn2_smooths[idx].set_value(
paddle.to_tensor(state_dict["mixtral.layers.{}.mlp.down_proj.smooth_weight".format(idx)])
paddle.to_tensor(
state_dict["mixtral.layers.{}.mlp.down_proj.smooth_weight".format(idx)]
).astype(paddle.get_default_dtype())
)

if self.shift:
18 changes: 14 additions & 4 deletions paddlenlp/experimental/transformers/qwen2/modeling.py
@@ -453,6 +453,8 @@ def set_state_dict(self, state_dict):
ffn_hidden_size=self.intermediate_size,
num_key_value_heads=self.num_key_value_heads,
mp_size=self.config.tensor_parallel_degree,
concat_qkv=True,
concat_ffn1=True,
)
self.transformer_block.weight_scales = weight_scales_loader.scale
self.transformer_block.act_scales = act_scale_loader.scale
@@ -704,16 +706,24 @@ def set_state_dict(self, state_dict):
dtype=paddle.get_default_dtype(),
)
self.transformer_block.linear_shifts[idx].set_value(
paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.shift_bias".format(idx)])
paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.shift_bias".format(idx)]).astype(
paddle.get_default_dtype()
)
)
self.transformer_block.linear_smooths[idx].set_value(
paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.smooth_weight".format(idx)])
paddle.to_tensor(
state_dict["qwen2.layers.{}.self_attn.o_proj.smooth_weight".format(idx)]
).astype(paddle.get_default_dtype())
)
self.transformer_block.ffn2_shifts[idx].set_value(
paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.shift_bias".format(idx)])
paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.shift_bias".format(idx)]).astype(
paddle.get_default_dtype()
)
)
self.transformer_block.ffn2_smooths[idx].set_value(
paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.smooth_weight".format(idx)])
paddle.to_tensor(state_dict["qwen2.layers.{}.mlp.down_proj.smooth_weight".format(idx)]).astype(
paddle.get_default_dtype()
)
)

if self.shift:
13 changes: 13 additions & 0 deletions paddlenlp/experimental/transformers/utils.py
@@ -108,6 +108,8 @@ def __init__(
ffn_hidden_size,
num_key_value_heads=-1,
mp_size=1,
concat_qkv=False,
concat_ffn1=False,
):
self.key_map = key_map_dict
self.scale = {}
@@ -126,6 +128,17 @@
n = num_head * dim_head
self.scale[scale_type] = np.full([num_of_layers, n], fill_value=0.1, dtype="float32")

# concat qkv and ffn1
if concat_qkv:
self.scale["qkv_weight_scale"] = np.full(
[num_of_layers, qkv_out_size // mp_size], fill_value=0.1, dtype="float32"
)

if concat_ffn1:
self.scale["ffn1_weight_scale"] = np.full(
[num_of_layers, ffn_hidden_size * 2 // mp_size], fill_value=0.1, dtype="float32"
)


class EmptyCacheScale:
"""
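
With the two new flags, the fake scale loader also pre-fills entries for the fused qkv and ffn1 (gate + up) weights. Below is a rough sketch of the added entries, using assumed Llama-3.1-8B sizes; the qkv_out_size formula mirrors the n formula in the tuning script above, while the real constructor derives it from its own arguments:

import numpy as np

num_of_layers, mp_size = 32, 1
ffn_hidden_size = 14336                                    # intermediate_size
num_heads, num_kv_heads, head_dim = 32, 8, 128
qkv_out_size = (num_heads + 2 * num_kv_heads) * head_dim   # 6144

scale = {}
# one row of dummy scales (0.1) per layer for the fused qkv weight
scale["qkv_weight_scale"] = np.full([num_of_layers, qkv_out_size // mp_size], 0.1, dtype="float32")
# fused ffn1 holds both gate and up projections, hence the factor of 2
scale["ffn1_weight_scale"] = np.full([num_of_layers, ffn_hidden_size * 2 // mp_size], 0.1, dtype="float32")
# resulting shapes: (32, 6144) and (32, 28672)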
