Skip to content

Commit

Permalink
MoE Gemm perf tuning (microsoft#20541)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->

This PR supports profiling and tuning MoE Gemm kernels in the very first
run and store the best configuration to reuse in the following runs. The
Gemm id (the key to the config map, int64_t) is determined by num_rows,
gemm_n and gemm_k for each type.

First 32 bits are total_rows, next 16 bits are gemm_n, next 16 bits are
gemm_k
int64_t key = total_rows;
key = key << 16 | gemm_n;
key = key << 16 | gemm_k;

Mixtral-fp16 on 2 A100 with tp=2. batch size = 1, seq_len = 1k
|  | Prompt | Token |
| :---         |     :---:      |          ---: |
| before   | 138ms     | 16.4ms    |
| after      | 100ms       | 13.9ms      |


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
wangyems authored and poweiw committed Jun 25, 2024
1 parent da9ac9a commit fe88e84
Show file tree
Hide file tree
Showing 8 changed files with 1,039 additions and 828 deletions.
64 changes: 64 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#pragma once

#include <string>

namespace ort_fastertransformer {
// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape
// in the kernel layout details when doing weight only quantization.
Expand Down Expand Up @@ -120,6 +122,68 @@ struct CutlassGemmConfig {
mainloop_schedule(mainloop_schedule),
epilogue_schedule(epilogue_schedule),
cluster_shape(cluster_shape) {}

CutlassGemmConfig& operator=(const CutlassGemmConfig& other) {
tile_config = other.tile_config;
split_k_style = other.split_k_style;
split_k_factor = other.split_k_factor;
stages = other.stages;
return *this;
}

std::string to_string() {
std::string str = "tile_config: ";
switch (tile_config) {
case CutlassTileConfig::Undefined:
str += "Undefined";
break;
case CutlassTileConfig::ChooseWithHeuristic:
str += "ChooseWithHeuristic";
break;
case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
str += "CtaShape128x128x8_WarpShape64x64x8";
break;
case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
str += "CtaShape16x128x64_WarpShape16x32x64";
break;
case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
str += "CtaShape32x128x64_WarpShape32x32x64";
break;
case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
str += "CtaShape64x128x64_WarpShape32x64x64";
break;
case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64:
str += "CtaShape64x64x128_WarpShape32x64x64";
break;
case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
str += "CtaShape64x128x64_WarpShape64x32x64";
break;
case CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64:
str += "CtaShape128x64x64_WarpShape64x32x64";
break;
case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
str += "CtaShape128x128x64_WarpShape64x32x64";
break;
case CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64:
str += "CtaShape128x128x64_WarpShape64x64x64";
break;
case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
str += "CtaShape128x128x64_WarpShape128x32x64";
break;
case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64:
str += "CtaShape128x256x64_WarpShape64x64x64";
break;
case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64:
str += "CtaShape256x128x64_WarpShape64x64x64";
break;
case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
str += "CtaShape16x256x64_WarpShape16x64x64";
break;
}
str += ", stages: ";
str += std::to_string(stages);
return str;
}
};

} // namespace ort_fastertransformer
152 changes: 103 additions & 49 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,35 @@ struct TileShape {

TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) {
switch (tile_config) {
case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
return TileShape{16, 128};
case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64:
return TileShape{16, 256};
case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
return TileShape{32, 128};
case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64:
return TileShape{64, 64};
case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
return TileShape{64, 128};
case CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64:
return TileShape{128, 64};
case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
case CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64:
case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
return TileShape{128, 128};
case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64:
return TileShape{128, 256};
case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64:
return TileShape{256, 128};
default:
ORT_THROW("[FT Error][get_grid_shape_for_config] Invalid config");
ORT_THROW("[get_grid_shape_for_config] Invalid config");
}
}

bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, const TileShape tile_shape,
const int split_k_factor, const size_t workspace_bytes, const bool is_weight_only) {
int const split_k_factor, const size_t workspace_bytes, bool const is_weight_only) {
// All tile sizes have a k_tile of 64.
static constexpr int k_tile = 64;

Expand All @@ -58,64 +71,105 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k,
return false;
}

const int k_elements_per_split = static_cast<int>(k / split_k_factor);
int const k_elements_per_split = static_cast<int>(k / split_k_factor);
if ((k_elements_per_split % k_tile) != 0) {
return false;
}
}

// Check that the workspace has sufficient space for this split-k factor
const size_t ctas_in_m_dim = static_cast<int>((m + tile_shape.m - 1) / tile_shape.m);
const size_t ctas_in_n_dim = static_cast<int>((n + tile_shape.n - 1) / tile_shape.n);
const size_t required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;
int const ctas_in_m_dim = static_cast<int>((m + tile_shape.m - 1) / tile_shape.m);
int const ctas_in_n_dim = static_cast<int>((n + tile_shape.n - 1) / tile_shape.n);
int const required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;

if (required_ws_bytes > workspace_bytes) {
if (static_cast<size_t>(required_ws_bytes) > workspace_bytes) {
return false;
}

return true;
}

std::vector<CutlassTileConfig> get_candidate_tiles(const bool is_weight_only, const bool simt_configs_only) {
std::vector<CutlassTileConfig> simt_configs{CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};

std::vector<CutlassTileConfig> square_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64};
std::vector<CutlassTileConfig> get_candidate_tiles(
int const sm, bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only) {
enum class CutlassGemmType : char {
Default,
WeightOnly,
Simt,
Int8
};

CutlassGemmType gemm_type = CutlassGemmType::Default;
if (simt_configs_only) {
gemm_type = CutlassGemmType::Simt;
} else if (is_weight_only) {
gemm_type = CutlassGemmType::WeightOnly;
} else if (int8_configs_only) {
gemm_type = CutlassGemmType::Int8;
}

std::vector<CutlassTileConfig> quant_B_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
std::vector<CutlassTileConfig> base_configs{
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64};
if (sm >= 75) {
base_configs.push_back(CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64);
}

const std::vector<CutlassTileConfig> allowed_configs = is_weight_only ? quant_B_configs : square_configs;
return simt_configs_only ? simt_configs : allowed_configs;
switch (gemm_type) {
case CutlassGemmType::Simt:
return {CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
case CutlassGemmType::WeightOnly:
if (sm >= 75) {
return {CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64,
CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64,
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
} else {
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64};
}
case CutlassGemmType::Int8:
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64,
CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64};
default:
return base_configs;
}
}

std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only) {
std::vector<CutlassTileConfig> tiles = get_candidate_tiles(is_weight_only, simt_configs_only);
std::vector<CutlassGemmConfig> get_candidate_configs(int sm, bool const is_weight_only, bool const simt_configs_only,
bool const int8_configs_only, int const max_split_k) {
std::vector<CutlassTileConfig> tiles = get_candidate_tiles(sm, is_weight_only, simt_configs_only, int8_configs_only);

std::vector<CutlassGemmConfig> candidate_configs;
const int min_stages = 2;
const int max_stages = sm >= 80 ? 4 : 2;

for (const auto& tile_config : tiles) {
int const min_stages = int8_configs_only ? 3 : 2;
int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);
for (auto const& tile_config : tiles) {
for (int stages = min_stages; stages <= max_stages; ++stages) {
CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages};
CutlassGemmConfig config(tile_config, SplitKStyle::NO_SPLIT_K, 1, stages);
candidate_configs.push_back(config);
if (sm >= 75) {
for (int split_k_factor = 2; split_k_factor <= max_split_k; ++split_k_factor) {
candidate_configs.push_back(
CutlassGemmConfig{tile_config, SplitKStyle::SPLIT_K_SERIAL, split_k_factor, stages});
}
}
}
}

return candidate_configs;
}

CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
const std::vector<int>& occupancies, const int64_t m,
CutlassGemmConfig estimate_best_config_from_occupancies(std::vector<CutlassGemmConfig> const& candidate_configs,
std::vector<int> const& occupancies, const int64_t m,
const int64_t n, const int64_t k, const int64_t,
const int split_k_limit, const size_t workspace_bytes,
const int multi_processor_count, const int is_weight_only) {
int const split_k_limit, const size_t workspace_bytes,
int const multi_processor_count, int const is_weight_only) {
if (occupancies.size() != candidate_configs.size()) {
ORT_THROW(
"[FT Error][estimate_best_config_from_occupancies] occpancies and "
"[estimate_best_config_from_occupancies] occpancies and "
"candidate configs vectors must have equal length.");
}

Expand All @@ -126,7 +180,7 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
int config_waves = INT_MAX;
int current_m_tile = 0;

const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
int const max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
CutlassGemmConfig candidate_config = candidate_configs[ii];
TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config);
Expand All @@ -142,34 +196,34 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
continue;
}

const int ctas_in_m_dim = static_cast<int>((m + tile_shape.m - 1) / tile_shape.m);
const int ctas_in_n_dim = static_cast<int>((n + tile_shape.n - 1) / tile_shape.n);
int const ctas_in_m_dim = static_cast<int>((m + tile_shape.m - 1) / tile_shape.m);
int const ctas_in_n_dim = static_cast<int>((n + tile_shape.n - 1) / tile_shape.n);

for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) {
if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) {
const int ctas_per_wave = occupancy * multi_processor_count;
const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;
int const ctas_per_wave = occupancy * multi_processor_count;
int const ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;

const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
const float num_waves_fractional = ctas_for_problem / static_cast<float>(ctas_per_wave);
const float current_score = static_cast<float>(num_waves_total) - num_waves_fractional;
int const num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
float const num_waves_fractional = ctas_for_problem / static_cast<float>(ctas_per_wave);
float const current_score = static_cast<float>(num_waves_total) - num_waves_fractional;

const float score_slack = 0.1f;
if (current_score < config_score ||
((config_waves > num_waves_total) && (current_score < config_score + score_slack))) {
constexpr float score_slack = 0.1f;
if (current_score < config_score || ((config_waves > num_waves_total) &&
(current_score < config_score + score_slack))) {
config_score = current_score;
config_waves = num_waves_total;
SplitKStyle split_style = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
best_config =
CutlassGemmConfig{candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
best_config = CutlassGemmConfig(
candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages);
current_m_tile = tile_shape.m;
} else if (current_score == config_score &&
(best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor ||
current_m_tile < tile_shape.m)) {
} else if (current_score == config_score && (best_config.stages < candidate_config.stages ||
split_k_factor < best_config.split_k_factor ||
current_m_tile < tile_shape.m)) {
// Prefer deeper pipeline or smaller split-k
SplitKStyle split_style = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
best_config =
CutlassGemmConfig{candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
best_config = CutlassGemmConfig(
candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages);
current_m_tile = tile_shape.m;
config_waves = num_waves_total;
}
Expand All @@ -178,7 +232,7 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
}

if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) {
ORT_THROW("[FT Error] Heurisitc failed to find a valid config.");
ORT_THROW("Heurisitc failed to find a valid config.");
}

return best_config;
Expand Down
11 changes: 6 additions & 5 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ using namespace onnxruntime;

namespace ort_fastertransformer {

std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only);
std::vector<CutlassGemmConfig> get_candidate_configs(int sm, bool const is_weight_only, bool const simt_configs_only,
bool const int8_configs_only = false, int const max_split_k = 1);

CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
const std::vector<int>& occupancies, const int64_t m,
CutlassGemmConfig estimate_best_config_from_occupancies(std::vector<CutlassGemmConfig> const& candidate_configs,
std::vector<int> const& occupancies, const int64_t m,
const int64_t n, const int64_t k, const int64_t num_experts,
const int split_k_limit, const size_t workspace_bytes,
const int multi_processor_count, const int is_weight_only);
int const split_k_limit, const size_t workspace_bytes,
int const multi_processor_count, int const is_weight_only);

} // namespace ort_fastertransformer
Loading

0 comments on commit fe88e84

Please sign in to comment.