Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MoE Gemm perf tuning #20541

Merged
merged 8 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
static_cast<int>(k_), reinterpret_cast<char*>(work_space.get()), reinterpret_cast<CudaT*>(fc2_output.get()),
reinterpret_cast<CudaT*>(expert_scales.get()),
reinterpret_cast<int*>(expanded_source_row_to_expanded_dest_row.get()),
reinterpret_cast<int*>(expert_for_source_row.get()), Stream(context));
reinterpret_cast<int*>(expert_for_source_row.get()), Stream(context), best_config_map_ptr_->map);

Tensor* output = context->Output(0, input->Shape());

Expand Down
64 changes: 64 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#pragma once

#include <string>

namespace ort_fastertransformer {
// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape
// in the kernel layout details when doing weight only quantization.
Expand Down Expand Up @@ -120,6 +122,68 @@
mainloop_schedule(mainloop_schedule),
epilogue_schedule(epilogue_schedule),
cluster_shape(cluster_shape) {}

CutlassGemmConfig& operator=(const CutlassGemmConfig& other) {
tile_config = other.tile_config;
split_k_style = other.split_k_style;
split_k_factor = other.split_k_factor;
stages = other.stages;
return *this;
}

std::string to_string() {
std::string str = "tile_config: ";
switch (tile_config) {
case CutlassTileConfig::Undefined:
str += "Undefined";
break;
case CutlassTileConfig::ChooseWithHeuristic:
str += "ChooseWithHeuristic";
break;
case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
str += "CtaShape128x128x8_WarpShape64x64x8";
break;
case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
str += "CtaShape16x128x64_WarpShape16x32x64";
break;
case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
str += "CtaShape32x128x64_WarpShape32x32x64";
break;
case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
str += "CtaShape64x128x64_WarpShape32x64x64";
break;
case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64:
str += "CtaShape64x64x128_WarpShape32x64x64";
break;
case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
str += "CtaShape64x128x64_WarpShape64x32x64";
break;
case CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64:
str += "CtaShape128x64x64_WarpShape64x32x64";
break;
case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
str += "CtaShape128x128x64_WarpShape64x32x64";
break;
case CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64:
str += "CtaShape128x128x64_WarpShape64x64x64";
break;
case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
str += "CtaShape128x128x64_WarpShape128x32x64";
break;
case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64:
str += "CtaShape128x256x64_WarpShape64x64x64";
break;
case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64:
str += "CtaShape256x128x64_WarpShape64x64x64";
break;
case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
str += "CtaShape16x256x64_WarpShape16x64x64";
break;
};

Check warning on line 182 in onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 You don't need a ; after a } [readability/braces] [4] Raw Output: onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h:182: You don't need a ; after a } [readability/braces] [4]
str += ", stages: ";
str += std::to_string(stages);
return str;
}
};

} // namespace ort_fastertransformer
156 changes: 105 additions & 51 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,35 @@

TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) {
switch (tile_config) {
case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
return TileShape{16, 128};
case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
return TileShape{16, 256};
case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
return TileShape{32, 128};
case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64:
return TileShape{64, 64};
case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
return TileShape{64, 128};
case CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64:
return TileShape{128, 64};
case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
case CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64:
case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
return TileShape{128, 128};
case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64:
return TileShape{128, 256};
case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64:
return TileShape{256, 128};
default:
ORT_THROW("[FT Error][get_grid_shape_for_config] Invalid config");
throw std::runtime_error("[TensorRT-LLm Error][get_grid_shape_for_config] Invalid config");
}
}

bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, const TileShape tile_shape,
const int split_k_factor, const size_t workspace_bytes, const bool is_weight_only) {
int const split_k_factor, const size_t workspace_bytes, bool const is_weight_only) {
// All tile sizes have a k_tile of 64.
static constexpr int k_tile = 64;

Expand All @@ -58,64 +71,105 @@
return false;
}

const int k_elements_per_split = static_cast<int>(k / split_k_factor);
int const k_elements_per_split = k / split_k_factor;
if ((k_elements_per_split % k_tile) != 0) {
return false;
}
}

// Check that the workspace has sufficient space for this split-k factor
const size_t ctas_in_m_dim = static_cast<int>((m + tile_shape.m - 1) / tile_shape.m);
const size_t ctas_in_n_dim = static_cast<int>((n + tile_shape.n - 1) / tile_shape.n);
const size_t required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;
int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
int const required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;

if (required_ws_bytes > workspace_bytes) {
if (static_cast<size_t>(required_ws_bytes) > workspace_bytes) {
return false;
}

return true;
}

std::vector<CutlassTileConfig> get_candidate_tiles(const bool is_weight_only, const bool simt_configs_only) {
std::vector<CutlassTileConfig> simt_configs{CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};

std::vector<CutlassTileConfig> square_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64};
std::vector<CutlassTileConfig> get_candidate_tiles(
int const sm, bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only) {
enum class CutlassGemmType : char {
Default,
WeightOnly,
Simt,
Int8
};

CutlassGemmType gemm_type = CutlassGemmType::Default;
if (simt_configs_only) {
gemm_type = CutlassGemmType::Simt;
} else if (is_weight_only) {
gemm_type = CutlassGemmType::WeightOnly;
} else if (int8_configs_only) {
gemm_type = CutlassGemmType::Int8;
}

std::vector<CutlassTileConfig> quant_B_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
std::vector<CutlassTileConfig> base_configs{
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64};
if (sm >= 75) {
base_configs.push_back(CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64);
}

const std::vector<CutlassTileConfig> allowed_configs = is_weight_only ? quant_B_configs : square_configs;
return simt_configs_only ? simt_configs : allowed_configs;
switch (gemm_type) {
case CutlassGemmType::Simt:
return {CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
case CutlassGemmType::WeightOnly:
if (sm >= 75) {
return {CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64,
CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64,
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
} else {
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64};
}
case CutlassGemmType::Int8:
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64,
CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64};
default:
return base_configs;
}
}

std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only) {
std::vector<CutlassTileConfig> tiles = get_candidate_tiles(is_weight_only, simt_configs_only);
std::vector<CutlassGemmConfig> get_candidate_configs(int sm, bool const is_weight_only, bool const simt_configs_only,
bool const int8_configs_only, int const max_split_k) {
std::vector<CutlassTileConfig> tiles = get_candidate_tiles(sm, is_weight_only, simt_configs_only, int8_configs_only);

std::vector<CutlassGemmConfig> candidate_configs;
const int min_stages = 2;
const int max_stages = sm >= 80 ? 4 : 2;

for (const auto& tile_config : tiles) {
int const min_stages = int8_configs_only ? 3 : 2;
int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);
for (auto const& tile_config : tiles) {
for (int stages = min_stages; stages <= max_stages; ++stages) {
CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages};
CutlassGemmConfig config(tile_config, SplitKStyle::NO_SPLIT_K, 1, stages);
candidate_configs.push_back(config);
if (sm >= 75) {
for (int split_k_factor = 2; split_k_factor <= max_split_k; ++split_k_factor) {
auto config = CutlassGemmConfig{tile_config, SplitKStyle::SPLIT_K_SERIAL, split_k_factor, stages};
candidate_configs.push_back(config);
}
}
}
}

return candidate_configs;
}

CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
const std::vector<int>& occupancies, const int64_t m,
CutlassGemmConfig estimate_best_config_from_occupancies(std::vector<CutlassGemmConfig> const& candidate_configs,
std::vector<int> const& occupancies, const int64_t m,
const int64_t n, const int64_t k, const int64_t,
const int split_k_limit, const size_t workspace_bytes,
const int multi_processor_count, const int is_weight_only) {
int const split_k_limit, const size_t workspace_bytes,
int const multi_processor_count, int const is_weight_only) {
if (occupancies.size() != candidate_configs.size()) {
ORT_THROW(
"[FT Error][estimate_best_config_from_occupancies] occpancies and "
throw std::runtime_error(
wangyems marked this conversation as resolved.
Show resolved Hide resolved
"[TensorRT-LLm Error][estimate_best_config_from_occupancies] occpancies and "
"candidate configs vectors must have equal length.");
}

Expand All @@ -126,8 +180,8 @@
int config_waves = INT_MAX;
int current_m_tile = 0;

const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
int const max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
for (int ii = 0; ii < static_cast<int>(candidate_configs.size()); ++ii) {
wangyems marked this conversation as resolved.
Show resolved Hide resolved
CutlassGemmConfig candidate_config = candidate_configs[ii];
TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config);
int occupancy = occupancies[ii];
Expand All @@ -142,34 +196,34 @@
continue;
}

const int ctas_in_m_dim = static_cast<int>((m + tile_shape.m - 1) / tile_shape.m);
const int ctas_in_n_dim = static_cast<int>((n + tile_shape.n - 1) / tile_shape.n);
int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;

for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) {
if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) {
const int ctas_per_wave = occupancy * multi_processor_count;
const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;
int const ctas_per_wave = occupancy * multi_processor_count;
int const ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;

const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
const float num_waves_fractional = ctas_for_problem / static_cast<float>(ctas_per_wave);
const float current_score = static_cast<float>(num_waves_total) - num_waves_fractional;
int const num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
float const num_waves_fractional = ctas_for_problem / float(ctas_per_wave);

Check warning on line 208 in onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Using deprecated casting style. Use static_cast<float>(...) instead [readability/casting] [4] Raw Output: onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc:208: Using deprecated casting style. Use static_cast<float>(...) instead [readability/casting] [4]
wangyems marked this conversation as resolved.
Show resolved Hide resolved
float const current_score = float(num_waves_total) - num_waves_fractional;

Check warning on line 209 in onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Using deprecated casting style. Use static_cast<float>(...) instead [readability/casting] [4] Raw Output: onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc:209: Using deprecated casting style. Use static_cast<float>(...) instead [readability/casting] [4]

const float score_slack = 0.1f;
if (current_score < config_score ||
((config_waves > num_waves_total) && (current_score < config_score + score_slack))) {
float const score_slack = 0.1f;
wangyems marked this conversation as resolved.
Show resolved Hide resolved
if (current_score < config_score || ((config_waves > num_waves_total) &&
(current_score < config_score + score_slack))) {
config_score = current_score;
config_waves = num_waves_total;
SplitKStyle split_style = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
best_config =
CutlassGemmConfig{candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
best_config = CutlassGemmConfig(
candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages);
current_m_tile = tile_shape.m;
} else if (current_score == config_score &&
(best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor ||
current_m_tile < tile_shape.m)) {
} else if (current_score == config_score && (best_config.stages < candidate_config.stages ||
split_k_factor < best_config.split_k_factor ||
current_m_tile < tile_shape.m)) {
// Prefer deeper pipeline or smaller split-k
SplitKStyle split_style = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
best_config =
CutlassGemmConfig{candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
best_config = CutlassGemmConfig(
candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages);
current_m_tile = tile_shape.m;
config_waves = num_waves_total;
}
Expand All @@ -178,7 +232,7 @@
}

if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) {
ORT_THROW("[FT Error] Heurisitc failed to find a valid config.");
throw std::runtime_error("[TensorRT-LLm Error] Heurisitc failed to find a valid config.");
wangyems marked this conversation as resolved.
Show resolved Hide resolved
}

return best_config;
Expand Down
11 changes: 6 additions & 5 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ using namespace onnxruntime;

namespace ort_fastertransformer {

std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only);
std::vector<CutlassGemmConfig> get_candidate_configs(int sm, bool const is_weight_only, bool const simt_configs_only,
bool const int8_configs_only = false, int const max_split_k = 1);

CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
const std::vector<int>& occupancies, const int64_t m,
CutlassGemmConfig estimate_best_config_from_occupancies(std::vector<CutlassGemmConfig> const& candidate_configs,
std::vector<int> const& occupancies, const int64_t m,
const int64_t n, const int64_t k, const int64_t num_experts,
const int split_k_limit, const size_t workspace_bytes,
const int multi_processor_count, const int is_weight_only);
int const split_k_limit, const size_t workspace_bytes,
int const multi_processor_count, int const is_weight_only);

} // namespace ort_fastertransformer
Loading
Loading