From fe88e846204bfde8a41169ce53a7ad1e0f4dfe71 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Mon, 6 May 2024 14:40:44 -0700 Subject: [PATCH] MoE Gemm perf tuning (#20541) ### Description This PR supports profiling and tuning MoE Gemm kernels in the very first run and store the best configuration to reuse in the following runs. The Gemm id (the key to the config map, int64_t) is determined by num_rows, gemm_n and gemm_k for each type. First 32 bits are total_rows, next 16 bits are gemm_n, next 16 bits are gemm_k int64_t key = total_rows; key = key << 16 | gemm_n; key = key << 16 | gemm_k; Mixtral-fp16 on 2 A100 with tp=2. batch size = 1, seq_len = 1k | | Prompt | Token | | :--- | :---: | ---: | | before | 138ms | 16.4ms | | after | 100ms | 13.9ms | ### Motivation and Context --- .../moe/cutlass_extensions/gemm_configs.h | 64 + .../cuda/moe/ft_moe/cutlass_heuristic.cc | 152 +- .../cuda/moe/ft_moe/cutlass_heuristic.h | 11 +- .../cuda/moe/ft_moe/moe_gemm_kernels.h | 46 +- .../moe/ft_moe/moe_gemm_kernels_template.h | 175 ++- .../contrib_ops/cuda/moe/ft_moe/moe_kernel.cu | 1393 ++++++++--------- .../contrib_ops/cuda/moe/ft_moe/moe_kernel.h | 22 +- onnxruntime/contrib_ops/cuda/moe/moe.cc | 4 +- 8 files changed, 1039 insertions(+), 828 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h index 0841218a480ba..12ad9d717766e 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h @@ -16,6 +16,8 @@ #pragma once +#include + namespace ort_fastertransformer { // Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape // in the kernel layout details when doing weight only quantization. @@ -120,6 +122,68 @@ struct CutlassGemmConfig { mainloop_schedule(mainloop_schedule), epilogue_schedule(epilogue_schedule), cluster_shape(cluster_shape) {} + + CutlassGemmConfig& operator=(const CutlassGemmConfig& other) { + tile_config = other.tile_config; + split_k_style = other.split_k_style; + split_k_factor = other.split_k_factor; + stages = other.stages; + return *this; + } + + std::string to_string() { + std::string str = "tile_config: "; + switch (tile_config) { + case CutlassTileConfig::Undefined: + str += "Undefined"; + break; + case CutlassTileConfig::ChooseWithHeuristic: + str += "ChooseWithHeuristic"; + break; + case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + str += "CtaShape128x128x8_WarpShape64x64x8"; + break; + case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: + str += "CtaShape16x128x64_WarpShape16x32x64"; + break; + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + str += "CtaShape32x128x64_WarpShape32x32x64"; + break; + case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + str += "CtaShape64x128x64_WarpShape32x64x64"; + break; + case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: + str += "CtaShape64x64x128_WarpShape32x64x64"; + break; + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + str += "CtaShape64x128x64_WarpShape64x32x64"; + break; + case CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: + str += "CtaShape128x64x64_WarpShape64x32x64"; + break; + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + str += "CtaShape128x128x64_WarpShape64x32x64"; + break; + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64: + str += "CtaShape128x128x64_WarpShape64x64x64"; + break; + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + str += "CtaShape128x128x64_WarpShape128x32x64"; + break; + case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: + str += "CtaShape128x256x64_WarpShape64x64x64"; + break; + case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: + str += "CtaShape256x128x64_WarpShape64x64x64"; + break; + case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + str += "CtaShape16x256x64_WarpShape16x64x64"; + break; + } + str += ", stages: "; + str += std::to_string(stages); + return str; + } }; } // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc index cd59e904ad9eb..9d84880654766 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc @@ -29,22 +29,35 @@ struct TileShape { TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) { switch (tile_config) { + case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: + return TileShape{16, 128}; + case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + return TileShape{16, 256}; case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: return TileShape{32, 128}; + case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: + return TileShape{64, 64}; case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: return TileShape{64, 128}; + case CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: + return TileShape{128, 64}; case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64: case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: return TileShape{128, 128}; + case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: + return TileShape{128, 256}; + case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: + return TileShape{256, 128}; default: - ORT_THROW("[FT Error][get_grid_shape_for_config] Invalid config"); + ORT_THROW("[get_grid_shape_for_config] Invalid config"); } } bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, const TileShape tile_shape, - const int split_k_factor, const size_t workspace_bytes, const bool is_weight_only) { + int const split_k_factor, const size_t workspace_bytes, bool const is_weight_only) { // All tile sizes have a k_tile of 64. static constexpr int k_tile = 64; @@ -58,64 +71,105 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, return false; } - const int k_elements_per_split = static_cast(k / split_k_factor); + int const k_elements_per_split = static_cast(k / split_k_factor); if ((k_elements_per_split % k_tile) != 0) { return false; } } // Check that the workspace has sufficient space for this split-k factor - const size_t ctas_in_m_dim = static_cast((m + tile_shape.m - 1) / tile_shape.m); - const size_t ctas_in_n_dim = static_cast((n + tile_shape.n - 1) / tile_shape.n); - const size_t required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; + int const ctas_in_m_dim = static_cast((m + tile_shape.m - 1) / tile_shape.m); + int const ctas_in_n_dim = static_cast((n + tile_shape.n - 1) / tile_shape.n); + int const required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; - if (required_ws_bytes > workspace_bytes) { + if (static_cast(required_ws_bytes) > workspace_bytes) { return false; } return true; } -std::vector get_candidate_tiles(const bool is_weight_only, const bool simt_configs_only) { - std::vector simt_configs{CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; - - std::vector square_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, - CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64, - CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64}; +std::vector get_candidate_tiles( + int const sm, bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only) { + enum class CutlassGemmType : char { + Default, + WeightOnly, + Simt, + Int8 + }; + + CutlassGemmType gemm_type = CutlassGemmType::Default; + if (simt_configs_only) { + gemm_type = CutlassGemmType::Simt; + } else if (is_weight_only) { + gemm_type = CutlassGemmType::WeightOnly; + } else if (int8_configs_only) { + gemm_type = CutlassGemmType::Int8; + } - std::vector quant_B_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, - CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, - CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64}; + std::vector base_configs{ + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64}; + if (sm >= 75) { + base_configs.push_back(CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64); + } - const std::vector allowed_configs = is_weight_only ? quant_B_configs : square_configs; - return simt_configs_only ? simt_configs : allowed_configs; + switch (gemm_type) { + case CutlassGemmType::Simt: + return {CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; + case CutlassGemmType::WeightOnly: + if (sm >= 75) { + return {CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64, + CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64, + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64}; + } else { + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64}; + } + case CutlassGemmType::Int8: + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; + default: + return base_configs; + } } -std::vector get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only) { - std::vector tiles = get_candidate_tiles(is_weight_only, simt_configs_only); +std::vector get_candidate_configs(int sm, bool const is_weight_only, bool const simt_configs_only, + bool const int8_configs_only, int const max_split_k) { + std::vector tiles = get_candidate_tiles(sm, is_weight_only, simt_configs_only, int8_configs_only); std::vector candidate_configs; - const int min_stages = 2; - const int max_stages = sm >= 80 ? 4 : 2; - - for (const auto& tile_config : tiles) { + int const min_stages = int8_configs_only ? 3 : 2; + int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2); + for (auto const& tile_config : tiles) { for (int stages = min_stages; stages <= max_stages; ++stages) { - CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages}; + CutlassGemmConfig config(tile_config, SplitKStyle::NO_SPLIT_K, 1, stages); candidate_configs.push_back(config); + if (sm >= 75) { + for (int split_k_factor = 2; split_k_factor <= max_split_k; ++split_k_factor) { + candidate_configs.push_back( + CutlassGemmConfig{tile_config, SplitKStyle::SPLIT_K_SERIAL, split_k_factor, stages}); + } + } } } return candidate_configs; } -CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector& candidate_configs, - const std::vector& occupancies, const int64_t m, +CutlassGemmConfig estimate_best_config_from_occupancies(std::vector const& candidate_configs, + std::vector const& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t, - const int split_k_limit, const size_t workspace_bytes, - const int multi_processor_count, const int is_weight_only) { + int const split_k_limit, const size_t workspace_bytes, + int const multi_processor_count, int const is_weight_only) { if (occupancies.size() != candidate_configs.size()) { ORT_THROW( - "[FT Error][estimate_best_config_from_occupancies] occpancies and " + "[estimate_best_config_from_occupancies] occpancies and " "candidate configs vectors must have equal length."); } @@ -126,7 +180,7 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector= multi_processor_count * 256 ? 1 : split_k_limit; + int const max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit; for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { CutlassGemmConfig candidate_config = candidate_configs[ii]; TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config); @@ -142,34 +196,34 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector((m + tile_shape.m - 1) / tile_shape.m); - const int ctas_in_n_dim = static_cast((n + tile_shape.n - 1) / tile_shape.n); + int const ctas_in_m_dim = static_cast((m + tile_shape.m - 1) / tile_shape.m); + int const ctas_in_n_dim = static_cast((n + tile_shape.n - 1) / tile_shape.n); for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) { if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) { - const int ctas_per_wave = occupancy * multi_processor_count; - const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor; + int const ctas_per_wave = occupancy * multi_processor_count; + int const ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor; - const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; - const float num_waves_fractional = ctas_for_problem / static_cast(ctas_per_wave); - const float current_score = static_cast(num_waves_total) - num_waves_fractional; + int const num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; + float const num_waves_fractional = ctas_for_problem / static_cast(ctas_per_wave); + float const current_score = static_cast(num_waves_total) - num_waves_fractional; - const float score_slack = 0.1f; - if (current_score < config_score || - ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) { + constexpr float score_slack = 0.1f; + if (current_score < config_score || ((config_waves > num_waves_total) && + (current_score < config_score + score_slack))) { config_score = current_score; config_waves = num_waves_total; SplitKStyle split_style = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; - best_config = - CutlassGemmConfig{candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages}; + best_config = CutlassGemmConfig( + candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages); current_m_tile = tile_shape.m; - } else if (current_score == config_score && - (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor || - current_m_tile < tile_shape.m)) { + } else if (current_score == config_score && (best_config.stages < candidate_config.stages || + split_k_factor < best_config.split_k_factor || + current_m_tile < tile_shape.m)) { // Prefer deeper pipeline or smaller split-k SplitKStyle split_style = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; - best_config = - CutlassGemmConfig{candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages}; + best_config = CutlassGemmConfig( + candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages); current_m_tile = tile_shape.m; config_waves = num_waves_total; } @@ -178,7 +232,7 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only); +std::vector get_candidate_configs(int sm, bool const is_weight_only, bool const simt_configs_only, + bool const int8_configs_only = false, int const max_split_k = 1); -CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector& candidate_configs, - const std::vector& occupancies, const int64_t m, +CutlassGemmConfig estimate_best_config_from_occupancies(std::vector const& candidate_configs, + std::vector const& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts, - const int split_k_limit, const size_t workspace_bytes, - const int multi_processor_count, const int is_weight_only); + int const split_k_limit, const size_t workspace_bytes, + int const multi_processor_count, int const is_weight_only); } // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h index 7e29dde8f897b..36127054cfd5e 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h @@ -18,9 +18,34 @@ #include "contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h" #include +#include +#include +#include namespace ort_fastertransformer { +struct MoEGemmConfigMap { + using MoEGemmConfigMapT = std::unordered_map; + + MoEGemmConfigMapT map; + std::mutex mutex; + + void Insert(int64_t key, CutlassGemmConfig config) { + std::lock_guard lock(mutex); + map[key] = config; + } + + bool Contains(int64_t key) { + std::lock_guard lock(mutex); + return map.find(key) != map.end(); + } + + CutlassGemmConfig Get(int64_t key) { + std::lock_guard lock(mutex); + return map[key]; + } +}; + enum class ActivationType { Gelu, Relu, Silu, @@ -43,19 +68,30 @@ class MoeGemmRunner { int num_experts, ActivationType activation_type, cudaStream_t stream); void moe_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cudaStream_t stream); + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, cudaStream_t stream); + + static MoEGemmConfigMap& GetGemmConfigMap() { + static MoEGemmConfigMap gFactory; + return gFactory; + } private: template void dispatch_to_arch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy = nullptr); + int num_experts, CutlassGemmConfig gemm_config, cudaStream_t stream, + int* occupancy = nullptr); + + template + void profile_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, cudaStream_t stream, int64_t key); template void run_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cudaStream_t stream); + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, cudaStream_t stream); private: int sm_; diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index d81808e217fbc..ef1f97b9e57a2 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -35,7 +35,6 @@ #include "cutlass/gemm/kernel/default_gemm_grouped.h" #include "cutlass/layout/matrix.h" #include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" #include "contrib_ops/cuda/moe/cutlass_extensions/compute_occupancy.h" #include "contrib_ops/cuda/moe/cutlass_extensions/epilogue_helpers.h" @@ -54,8 +53,7 @@ #include "cutlass_heuristic.h" #include "moe_gemm_kernels.h" -#include -#include +#include #include #include @@ -136,21 +134,21 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w if (can_implement != cutlass::Status::kSuccess) { std::string err_msg = "MoEFC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); - ORT_THROW("[FT Error][MoE Runner] " + err_msg); + ORT_THROW("[MoE Runner] " + err_msg); } auto init_status = gemm.initialize(args); if (init_status != cutlass::Status::kSuccess) { std::string err_msg = "Failed to initialize cutlass variable batched gemm. Error: " + std::string(cutlassGetStatusString(init_status)); - ORT_THROW("[FT Error][MoE Runner] " + err_msg); + ORT_THROW("[MoE Runner] " + err_msg); } auto run_status = gemm.run(stream); if (run_status != cutlass::Status::kSuccess) { std::string err_msg = "Failed to run cutlass variable batched gemm. Error: " + std::string(cutlassGetStatusString(run_status)); - ORT_THROW("[FT Error][MoE Runner] " + err_msg); + ORT_THROW("[MoE Runner] " + err_msg); } } @@ -163,7 +161,7 @@ struct dispatch_stages { cudaStream_t /*stream*/, [[maybe_unused]] int* occupancy = nullptr) { std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); - ORT_THROW("[FT Error][dispatch_stages::dispatch] " + err_msg); + ORT_THROW("[dispatch_stages::dispatch] " + err_msg); } }; @@ -218,7 +216,7 @@ void dispatch_gemm_config(const T* A, const WeightType* B, const T* weight_scale break; default: std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages); - ORT_THROW("[FT Error][MoE][dispatch_gemm_config] " + err_msg); + ORT_THROW("[MoE][dispatch_gemm_config] " + err_msg); break; } } @@ -238,8 +236,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weig if constexpr (arch::kMinComputeCapability >= 75) { dispatch_gemm_config, cutlass::gemm::GemmShape<16, 32, 64>>( - A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, - multi_processor_count, stream, occupancy); + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); } break; case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: @@ -247,27 +245,27 @@ void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weig if constexpr (arch::kMinComputeCapability >= 75) { dispatch_gemm_config, cutlass::gemm::GemmShape<16, 64, 64>>( - A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, - multi_processor_count, stream, occupancy); + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); } break; case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: dispatch_gemm_config, - cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, - total_rows_before_expert, gemm_n, gemm_k, num_experts, - gemm_config, multi_processor_count, stream, occupancy); + cutlass::gemm::GemmShape<32, 32, 64>>( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); break; case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: dispatch_gemm_config, - cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, C, - total_rows_before_expert, gemm_n, gemm_k, num_experts, - gemm_config, multi_processor_count, stream, occupancy); + cutlass::gemm::GemmShape<32, 64, 64>>( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); break; case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: dispatch_gemm_config, - cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, - total_rows_before_expert, gemm_n, gemm_k, num_experts, - gemm_config, multi_processor_count, stream, occupancy); + cutlass::gemm::GemmShape<64, 32, 64>>( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); break; case CutlassTileConfig::Undefined: ORT_THROW("GEMM config undefined."); @@ -288,36 +286,54 @@ template < typename T, typename WeightType, typename arch, typename EpilogueTag, typename std::enable_if::value && !std::is_same::value>::type* = nullptr> void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, CutlassGemmConfig gemm_config, int sm_version, + int64_t* total_rows_before_expert, int64_t /*total_rows*/, int64_t gemm_n, + int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { switch (gemm_config.tile_config) { + case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: + ORT_ENFORCE(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); + if constexpr (arch::kMinComputeCapability >= 75) { + dispatch_gemm_config, + cutlass::gemm::GemmShape<16, 32, 64>>( + A, B, weight_scales, biases, C, total_rows_before_expert, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + } + break; + case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + ORT_ENFORCE(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); + if constexpr (arch::kMinComputeCapability >= 75) { + dispatch_gemm_config, + cutlass::gemm::GemmShape<16, 64, 64>>( + A, B, weight_scales, biases, C, total_rows_before_expert, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); + } + break; case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: dispatch_gemm_config, - cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, - total_rows_before_expert, gemm_n, gemm_k, num_experts, - gemm_config, multi_processor_count, stream, occupancy); + cutlass::gemm::GemmShape<32, 32, 64>>( + A, B, weight_scales, biases, C, total_rows_before_expert, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); break; case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: dispatch_gemm_config, - cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, - total_rows_before_expert, gemm_n, gemm_k, num_experts, - gemm_config, multi_processor_count, stream, occupancy); + cutlass::gemm::GemmShape<64, 32, 64>>( + A, B, weight_scales, biases, C, total_rows_before_expert, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); break; case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: dispatch_gemm_config, cutlass::gemm::GemmShape<128, 32, 64>>( - A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, - multi_processor_count, stream, occupancy); + A, B, weight_scales, biases, C, total_rows_before_expert, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); break; case CutlassTileConfig::Undefined: - ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config undefined."); + ORT_THROW("GEMM config undefined."); break; case CutlassTileConfig::ChooseWithHeuristic: - ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config should have already been set by heuristic."); + ORT_THROW("GEMM config should have already been set by heuristic."); break; default: - ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] Config is invalid for mixed type tensorop GEMM."); + ORT_THROW("Config is invalid for mixed type tensorop GEMM."); break; } } @@ -332,19 +348,18 @@ void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weig switch (gemm_config.tile_config) { case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: dispatch_gemm_config, - cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, C, - total_rows_before_expert, gemm_n, gemm_k, num_experts, - gemm_config, multi_processor_count, stream, occupancy); + cutlass::gemm::GemmShape<64, 64, 8>>( + A, B, weight_scales, biases, C, total_rows_before_expert, + gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); break; case CutlassTileConfig::Undefined: - ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] gemm config undefined."); + ORT_THROW("GEMM config undefined."); break; case CutlassTileConfig::ChooseWithHeuristic: - ORT_THROW( - "[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] gemm config should have already been set by heuristic."); + ORT_THROW("GEMM config should have already been set by heuristic."); break; default: - ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] Unsupported config for float MoE gemm."); + ORT_THROW("Unsupported config for float MoE gemm."); break; } } @@ -381,34 +396,80 @@ void MoeGemmRunner::dispatch_to_arch(const T* A, con A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, stream, occupancy); } else { - ORT_THROW("[FT Error][MoE][GEMM Dispatch] Arch unsupported for MoE GEMM"); + ORT_THROW("[MoE][GEMM Dispatch] Arch unsupported for MoE GEMM"); } } template template -void MoeGemmRunner::run_gemm(const T* A, const WeightType* B, const T* weight_scales, - const T* biases, T* C, int64_t* total_rows_before_expert, - int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, cudaStream_t stream) { +void MoeGemmRunner::profile_gemm(const T* A, const WeightType* B, const T* weight_scales, + const T* biases, T* C, int64_t* total_rows_before_expert, + int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, cudaStream_t stream, int64_t key) { static constexpr bool is_weight_only = !std::is_same::value; static constexpr bool only_simt_configs = std::is_same::value; + std::vector candidate_configs = get_candidate_configs(sm_, is_weight_only, only_simt_configs); std::vector occupancies(candidate_configs.size()); + constexpr int warmup = 5; + constexpr int runs = 10; + float min_elapsed = std::numeric_limits::max(); + size_t chosen_config_id = 0; for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { - dispatch_to_arch(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, - num_experts, candidate_configs[ii], stream, &occupancies[ii]); - } + for (int jj = 0; jj < warmup; ++jj) { + dispatch_to_arch(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, + gemm_k, num_experts, candidate_configs[ii], stream); + } - static constexpr int workspace_bytes = 0; // No workspace for MoE GEMMs. - static constexpr int split_k_limit = 1; // MoE GEMM does not support split-k. - CutlassGemmConfig chosen_config = - estimate_best_config_from_occupancies(candidate_configs, occupancies, total_rows, gemm_n, gemm_k, num_experts, - split_k_limit, workspace_bytes, multi_processor_count_, is_weight_only); + cudaEvent_t start; + cudaEvent_t stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + cudaStreamSynchronize(stream); + cudaEventRecord(start, stream); + for (int jj = 0; jj < runs; ++jj) { + dispatch_to_arch(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, + gemm_k, num_experts, candidate_configs[ii], stream); + } + + cudaEventRecord(stop, stream); + cudaEventSynchronize(stop); + + float elapsed; + cudaEventElapsedTime(&elapsed, start, stop); + + cudaEventDestroy(start); + cudaEventDestroy(stop); + + if (elapsed < min_elapsed) { + min_elapsed = elapsed; + chosen_config_id = ii; + } + } + CutlassGemmConfig config = candidate_configs[chosen_config_id]; + GetGemmConfigMap().Insert(key, config); +} + +template +template +void MoeGemmRunner::run_gemm(const T* A, const WeightType* B, const T* weight_scales, + const T* biases, T* C, int64_t* total_rows_before_expert, + int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, cudaStream_t stream) { + // Generate Key to the GemmConfigMap + // First 32 bits are total_rows, next 16 bits are gemm_n, next 16 bits are gemm_k + int64_t key = total_rows; + key = key << 16 | gemm_n; + key = key << 16 | gemm_k; + + if (!GetGemmConfigMap().Contains(key)) { + profile_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, stream, key); + } dispatch_to_arch(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, - num_experts, chosen_config, stream); + num_experts, GetGemmConfigMap().Get(key), stream); } template @@ -435,10 +496,10 @@ void MoeGemmRunner::moe_gemm_bias_act(const T* A, const WeightTyp num_experts, stream); break; case ActivationType::InvalidType: - ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM"); + ORT_THROW("[MoE Runner] Invalid activation type for MoE GEMM"); break; default: { - ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM"); + ORT_THROW("[MoE Runner] Invalid activation type for MoE GEMM"); } } } diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index 39ce6aec90e1a..5f26de4810c42 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -56,116 +56,116 @@ static constexpr int WARP_SIZE = 32; // in the softmax kernel when we extend this module to support expert-choice routing. template __launch_bounds__(TPB) __global__ - void moe_softmax(const T* input, const bool* finished, T* output, const int num_cols) { - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmpStorage; + void moe_softmax(const T *input, const bool *finished, T *output, const int num_cols) { + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; - __shared__ float normalizing_factor; - __shared__ float float_max; + __shared__ float normalizing_factor; + __shared__ float float_max; - const int thread_row_offset = blockIdx.x * num_cols; + const int thread_row_offset = blockIdx.x * num_cols; - cub::Sum sum; - float threadData(-FLT_MAX); + cub::Sum sum; + float threadData(-FLT_MAX); - // Don't touch finished rows. - if ((finished != nullptr) && finished[blockIdx.x]) { - return; - } + // Don't touch finished rows. + if ((finished != nullptr) && finished[blockIdx.x]) { + return; + } - for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { - const int idx = thread_row_offset + ii; - threadData = max(static_cast(input[idx]), threadData); - } + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData = max(static_cast(input[idx]), threadData); + } - const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); - if (threadIdx.x == 0) { - float_max = maxElem; - } - __syncthreads(); + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + if (threadIdx.x == 0) { + float_max = maxElem; + } + __syncthreads(); - threadData = 0; + threadData = 0; - for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { - const int idx = thread_row_offset + ii; - threadData += exp((static_cast(input[idx]) - float_max)); - } + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData += exp((static_cast(input[idx]) - float_max)); + } - const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); - if (threadIdx.x == 0) { - normalizing_factor = 1.f / Z; - } - __syncthreads(); + if (threadIdx.x == 0) { + normalizing_factor = 1.f / Z; + } + __syncthreads(); - for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { - const int idx = thread_row_offset + ii; - const float val = exp((static_cast(input[idx]) - float_max)) * normalizing_factor; - output[idx] = T(val); - } + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + const float val = exp((static_cast(input[idx]) - float_max)) * normalizing_factor; + output[idx] = T(val); + } } #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 template -__launch_bounds__(TPB) __global__ void moe_top_k(const T*, const bool*, T*, int*, int*, int, int, bool) { - // Does not support pre-Kepler architectures - ; +__launch_bounds__(TPB) __global__ void moe_top_k(const T *, const bool *, T *, int *, int *, int, int, bool) { + // Does not support pre-Kepler architectures + ; } #else template __launch_bounds__(TPB) __global__ - void moe_top_k(const T* inputs_after_softmax, const bool* finished, T* output, int* indices, int* source_rows, + void moe_top_k(const T *inputs_after_softmax, const bool *finished, T *output, int *indices, int *source_rows, int num_experts, int k, bool normalize_routing_weights) { - using cub_kvp = cub::KeyValuePair; - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmpStorage; - - cub_kvp thread_kvp; - cub::ArgMax arg_max; - - int num_rows = gridDim.x; - const int block_row = blockIdx.x; - - const bool should_process_row = finished ? !finished[block_row] : true; - const int thread_read_offset = blockIdx.x * num_experts; - float output_row_sum = 0.f; - for (int k_idx = 0; k_idx < k; ++k_idx) { - thread_kvp.key = 0; - thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities - - cub_kvp inp_kvp; - for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { - const int idx = thread_read_offset + expert; - inp_kvp.key = expert; - inp_kvp.value = inputs_after_softmax[idx]; - - for (int prior_k = 0; prior_k < k_idx; ++prior_k) { - const int prior_winning_expert = indices[k * block_row + prior_k]; - - if (prior_winning_expert == expert) { - inp_kvp = thread_kvp; - } - } + using cub_kvp = cub::KeyValuePair; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; - thread_kvp = arg_max(inp_kvp, thread_kvp); - } + cub_kvp thread_kvp; + cub::ArgMax arg_max; - const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); - if (threadIdx.x == 0) { - const int idx = k * block_row + k_idx; - output[idx] = result_kvp.value; - indices[idx] = should_process_row ? result_kvp.key : num_experts; - source_rows[idx] = k_idx * num_rows + block_row; + int num_rows = gridDim.x; + const int block_row = blockIdx.x; + + const bool should_process_row = finished ? !finished[block_row] : true; + const int thread_read_offset = blockIdx.x * num_experts; + float output_row_sum = 0.f; + for (int k_idx = 0; k_idx < k; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = T(-1.f); + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs_after_softmax[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const int prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + const int idx = k * block_row + k_idx; + output[idx] = result_kvp.value; + indices[idx] = should_process_row ? result_kvp.key : num_experts; + source_rows[idx] = k_idx * num_rows + block_row; - if (normalize_routing_weights && k_idx == k - 1) { + if (normalize_routing_weights && k_idx == k - 1) { #pragma unroll - for (int ki = 0; ki < k; ++ki) { - output[idx - ki] = T(static_cast(output[idx - ki]) / output_row_sum); + for (int ki = 0; ki < k; ++ki) { + output[idx - ki] = T(static_cast(output[idx - ki]) / output_row_sum); + } + } } - } + __syncthreads(); } - __syncthreads(); - } } #endif @@ -184,279 +184,279 @@ __launch_bounds__(TPB) __global__ */ template -__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ - void topk_gating_softmax(const T* input, const bool* finished, T* output, int num_rows, int* indices, - int* source_rows, int k, bool normalize_routing_weights) { - // We begin by enforcing compile time assertions and setting up compile time constants. - static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); - static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2"); - static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); - static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); - - // Number of bytes each thread pulls in per load - static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); - static constexpr int ELTS_PER_ROW = NUM_EXPERTS; - static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; - static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; - - // Restrictions based on previous section. - static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); - static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); - static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2"); - static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size"); - - // We have NUM_EXPERTS elements per row. We specialize for small #experts - static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT; - static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW; - static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; - - // Restrictions for previous section. - static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp"); - - // ===================== From this point, we finally start computing run-time variables. ======================== - - // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps. - // This, each block processes a chunk of rows. We start by computing the start row for each block. - const int cta_base_row = blockIdx.x * ROWS_PER_CTA; - - // Now, using the base row per thread block, we compute the base row per warp. - const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP; - - // The threads in a warp are split into sub-groups that will work on a row. - // We compute row offset for each thread sub-group - const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW; - const int thread_row = warp_base_row + thread_row_in_warp; - - // Threads with indices out of bounds should early exit here. - if (thread_row >= num_rows) return; - const bool should_process_row = finished ? !finished[thread_row] : true; - - // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the - // row it will read. - const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW; - - // Now, we compute the group each thread belong to in order to determine the first column to start loads. - const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; - const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; - const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; - - // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory, - // this can support all powers of 2 up to 16. - using AccessType = cutlass::AlignedArray; - - // Finally, we pull in the data from global mem - cutlass::Array row_chunk_input; - AccessType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk_input); - const AccessType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); +__launch_bounds__(WARPS_PER_CTA *WARP_SIZE) __global__ + void topk_gating_softmax(const T *input, const bool *finished, T *output, int num_rows, int *indices, + int *source_rows, int k, bool normalize_routing_weights) { + // We begin by enforcing compile time assertions and setting up compile time constants. + static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); + static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2"); + static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); + static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); + + // Number of bytes each thread pulls in per load + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); + static constexpr int ELTS_PER_ROW = NUM_EXPERTS; + static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; + static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; + + // Restrictions based on previous section. + static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); + static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); + static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2"); + static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size"); + + // We have NUM_EXPERTS elements per row. We specialize for small #experts + static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT; + static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW; + static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; + + // Restrictions for previous section. + static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp"); + + // ===================== From this point, we finally start computing run-time variables. ======================== + + // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps. + // This, each block processes a chunk of rows. We start by computing the start row for each block. + const int cta_base_row = blockIdx.x * ROWS_PER_CTA; + + // Now, using the base row per thread block, we compute the base row per warp. + const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP; + + // The threads in a warp are split into sub-groups that will work on a row. + // We compute row offset for each thread sub-group + const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW; + const int thread_row = warp_base_row + thread_row_in_warp; + + // Threads with indices out of bounds should early exit here. + if (thread_row >= num_rows) + return; + const bool should_process_row = finished ? !finished[thread_row] : true; + + // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the + // row it will read. + const T *thread_row_ptr = input + thread_row * ELTS_PER_ROW; + + // Now, we compute the group each thread belong to in order to determine the first column to start loads. + const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; + const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; + const T *thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; + + // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory, + // this can support all powers of 2 up to 16. + using AccessType = cutlass::AlignedArray; + + // Finally, we pull in the data from global mem + cutlass::Array row_chunk_input; + AccessType *row_chunk_vec_ptr = reinterpret_cast(&row_chunk_input); + const AccessType *vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); #pragma unroll - for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { - row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; - } - - using ComputeType = float; - using Converter = cutlass::NumericArrayConverter; - Converter compute_type_converter; - cutlass::Array row_chunk = compute_type_converter(row_chunk_input); - - // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just - // convert to float afterwards for the exp + sum reduction. - ComputeType thread_max = row_chunk[0]; + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + } + + using ComputeType = float; + using Converter = cutlass::NumericArrayConverter; + Converter compute_type_converter; + cutlass::Array row_chunk = compute_type_converter(row_chunk_input); + + // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just + // convert to float afterwards for the exp + sum reduction. + ComputeType thread_max = row_chunk[0]; #pragma unroll - for (int ii = 1; ii < VPT; ++ii) { - thread_max = max(thread_max, row_chunk[ii]); - } + for (int ii = 1; ii < VPT; ++ii) { + thread_max = max(thread_max, row_chunk[ii]); + } // Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce. #pragma unroll - for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW)); - } + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW)); + } - // From this point, thread max in all the threads have the max within the row. - // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum. - float row_sum = 0; + // From this point, thread max in all the threads have the max within the row. + // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum. + float row_sum = 0; #pragma unroll - for (int ii = 0; ii < VPT; ++ii) { - row_chunk[ii] = expf(row_chunk[ii] - thread_max); - row_sum += row_chunk[ii]; - } + for (int ii = 0; ii < VPT; ++ii) { + row_chunk[ii] = expf(row_chunk[ii] - thread_max); + row_sum += row_chunk[ii]; + } // Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern. #pragma unroll - for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); - } + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); + } - // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables - // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to - // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row. - // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the - // argmax after computing the softmax. - const float reciprocal_row_sum = 1.f / row_sum; + // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables + // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to + // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row. + // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the + // argmax after computing the softmax. + const float reciprocal_row_sum = 1.f / row_sum; #pragma unroll - for (int ii = 0; ii < VPT; ++ii) { - row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum; - } - - // Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along - // with the max index.​ - int start_col = first_elt_read_by_thread; - static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; - - float output_row_sum = 0.f; - for (int k_idx = 0; k_idx < k; ++k_idx) { - // First, each thread does the local argmax - float max_val = row_chunk[0]; - int expert = start_col; + for (int ii = 0; ii < VPT; ++ii) { + row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum; + } + + // Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along + // with the max index.​ + int start_col = first_elt_read_by_thread; + static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + + float output_row_sum = 0.f; + for (int k_idx = 0; k_idx < k; ++k_idx) { + // First, each thread does the local argmax + float max_val = row_chunk[0]; + int expert = start_col; #pragma unroll - for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) { + for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) { #pragma unroll - for (int ii = 0; ii < ELTS_PER_LDG; ++ii) { - float val = row_chunk[ldg * ELTS_PER_LDG + ii]; - - // No check on the experts here since columns with the smallest index are processed first and only - // updated if > (not >=) - if (val > max_val) { - max_val = val; - expert = col + ii; + for (int ii = 0; ii < ELTS_PER_LDG; ++ii) { + float val = row_chunk[ldg * ELTS_PER_LDG + ii]; + + // No check on the experts here since columns with the smallest index are processed first and only + // updated if > (not >=) + if (val > max_val) { + max_val = val; + expert = col + ii; + } + } } - } - } // Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max. // This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can // then blank out their max with -inf and the warp can run more iterations... #pragma unroll - for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { - float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); - int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); - - // We want lower indices to "win" in every thread so we break ties this way - if (other_max > max_val || (other_max == max_val && other_expert < expert)) { - max_val = other_max; - expert = other_expert; - } - } + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); + int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); + + // We want lower indices to "win" in every thread so we break ties this way + if (other_max > max_val || (other_max == max_val && other_expert < expert)) { + max_val = other_max; + expert = other_expert; + } + } - // Write the max for this k iteration to global memory. - if (thread_group_idx == 0) { - // The lead thread from each sub-group will write out the final results to global memory. (This will be a - // single) thread per row of the input/output matrices. - const int idx = k * thread_row + k_idx; - output[idx] = T(max_val); - output_row_sum = output_row_sum + static_cast(max_val); - indices[idx] = should_process_row ? expert : NUM_EXPERTS; - source_rows[idx] = k_idx * num_rows + thread_row; - - if (normalize_routing_weights && k_idx == k - 1) { + // Write the max for this k iteration to global memory. + if (thread_group_idx == 0) { + // The lead thread from each sub-group will write out the final results to global memory. (This will be a + // single) thread per row of the input/output matrices. + const int idx = k * thread_row + k_idx; + output[idx] = T(max_val); + output_row_sum = output_row_sum + static_cast(max_val); + indices[idx] = should_process_row ? expert : NUM_EXPERTS; + source_rows[idx] = k_idx * num_rows + thread_row; + + if (normalize_routing_weights && k_idx == k - 1) { #pragma unroll - for (int ki = 0; ki < k; ++ki) { - output[idx - ki] = T(static_cast(output[idx - ki]) / output_row_sum); + for (int ki = 0; ki < k; ++ki) { + output[idx - ki] = T(static_cast(output[idx - ki]) / output_row_sum); + } + } } - } - } - // Finally, we clear the value in the thread with the current max if there is another iteration to run. - if (k_idx + 1 < k) { - const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; - const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW; - - // Only the thread in the group which produced the max will reset the "winning" value to -inf. - if (thread_group_idx == thread_to_clear_in_group) { - const int offset_for_expert = expert % ELTS_PER_LDG; - // Safe to set to any negative value since row_chunk values must be between 0 and 1. - row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = ComputeType(-10000.f); - } + // Finally, we clear the value in the thread with the current max if there is another iteration to run. + if (k_idx + 1 < k) { + const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; + const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW; + + // Only the thread in the group which produced the max will reset the "winning" value to -inf. + if (thread_group_idx == thread_to_clear_in_group) { + const int offset_for_expert = expert % ELTS_PER_LDG; + // Safe to set to any negative value since row_chunk values must be between 0 and 1. + row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = ComputeType(-10000.f); + } + } } - } } namespace detail { // Constructs some constants needed to partition the work across threads at compile time. -template -struct TopkConstants { - static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); - static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); - static constexpr int VECs_PER_THREAD = std::max(1, (int)EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); - static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; - static constexpr int THREADS_PER_ROW = EXPERTS / VPT; - static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; +template struct TopkConstants { + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); + static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); + static constexpr int VECs_PER_THREAD = std::max(1, (int)EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; + static constexpr int THREADS_PER_ROW = EXPERTS / VPT; + static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; }; -} // namespace detail +} // namespace detail template -void topk_gating_softmax_launcher_helper(const T* input, const bool* finished, T* output, int* indices, int* source_row, +void topk_gating_softmax_launcher_helper(const T *input, const bool *finished, T *output, int *indices, int *source_row, int num_rows, int /*num_experts*/, int k, bool normalize_routing_weights, cudaStream_t stream) { - static constexpr unsigned long MAX_BYTES_PER_LDG = 16; - - static constexpr int BYTES_PER_LDG = std::min((int)MAX_BYTES_PER_LDG, (int)sizeof(T) * EXPERTS); - using Constants = detail::TopkConstants; - static constexpr int VPT = Constants::VPT; - static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; - const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; - const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; - - dim3 block_dim(WARP_SIZE, WARPS_PER_TB); - topk_gating_softmax<<>>( - input, finished, output, num_rows, indices, source_row, k, normalize_routing_weights); + static constexpr unsigned long MAX_BYTES_PER_LDG = 16; + + static constexpr int BYTES_PER_LDG = std::min((int)MAX_BYTES_PER_LDG, (int)sizeof(T) * EXPERTS); + using Constants = detail::TopkConstants; + static constexpr int VPT = Constants::VPT; + static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; + const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; + + dim3 block_dim(WARP_SIZE, WARPS_PER_TB); + topk_gating_softmax<<>>( + input, finished, output, num_rows, indices, source_row, k, normalize_routing_weights); } template -void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_output, - int* indices, int* source_row, int num_rows, int num_experts, int k, +void topk_gating_softmax_kernelLauncher(const T *input, const bool *finished, T *output, T *softmax_temp_output, + int *indices, int *source_row, int num_rows, int num_experts, int k, bool normalize_routing_weights, cudaStream_t stream) { - static constexpr int WARPS_PER_TB = 4; + static constexpr int WARPS_PER_TB = 4; - switch (num_experts) { + switch (num_experts) { case 2: { - topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, normalize_routing_weights, stream); - break; + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, normalize_routing_weights, stream); + break; } case 4: { - topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, normalize_routing_weights, stream); - break; + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, normalize_routing_weights, stream); + break; } case 8: { - topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, normalize_routing_weights, stream); - break; + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, normalize_routing_weights, stream); + break; } case 16: { - topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, normalize_routing_weights, stream); - break; + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, normalize_routing_weights, stream); + break; } case 32: { - topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, normalize_routing_weights, stream); - break; + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, normalize_routing_weights, stream); + break; } case 64: { - topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, normalize_routing_weights, stream); - break; + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, normalize_routing_weights, stream); + break; } case 128: { - topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, normalize_routing_weights, stream); - break; + topk_gating_softmax_launcher_helper( + input, finished, output, indices, source_row, num_rows, num_experts, k, normalize_routing_weights, stream); + break; } case 256: { - topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, normalize_routing_weights, stream); - break; + topk_gating_softmax_launcher_helper( + input, finished, output, indices, source_row, num_rows, num_experts, k, normalize_routing_weights, stream); + break; } default: { - static constexpr int TPB = 256; - moe_softmax<<>>(input, finished, softmax_temp_output, num_experts); - moe_top_k<<>>(softmax_temp_output, finished, output, indices, source_row, - num_experts, k, normalize_routing_weights); + static constexpr int TPB = 256; + moe_softmax<<>>(input, finished, softmax_temp_output, num_experts); + moe_top_k<<>>(softmax_temp_output, finished, output, indices, source_row, + num_experts, k, normalize_routing_weights); + } } - } } // ========================== CUB Sorting things ==================================== @@ -466,404 +466,397 @@ CubKeyValueSorter::CubKeyValueSorter(int num_experts) : num_experts_(num_experts), num_bits_((int)log2(num_experts) + 1) {} void CubKeyValueSorter::update_num_experts(int num_experts) { - num_experts_ = num_experts; - num_bits_ = (int)log2(num_experts) + 1; + num_experts_ = num_experts; + num_bits_ = (int)log2(num_experts) + 1; } size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs) { - num_key_value_pairs_ = num_key_value_pairs; - size_t required_storage = 0; - int* null_int = nullptr; - cub::DeviceRadixSort::SortPairs(NULL, required_storage, null_int, null_int, null_int, null_int, - (int)num_key_value_pairs, 0, num_bits_); - return required_storage; + num_key_value_pairs_ = num_key_value_pairs; + size_t required_storage = 0; + int *null_int = nullptr; + cub::DeviceRadixSort::SortPairs(NULL, required_storage, null_int, null_int, null_int, null_int, + (int)num_key_value_pairs, 0, num_bits_); + return required_storage; } -void CubKeyValueSorter::run(void* workspace, const size_t workspace_size, const int* keys_in, int* keys_out, - const int* values_in, int* values_out, const size_t num_key_value_pairs, +void CubKeyValueSorter::run(void *workspace, const size_t workspace_size, const int *keys_in, int *keys_out, + const int *values_in, int *values_out, const size_t num_key_value_pairs, cudaStream_t stream) { - size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs); - size_t actual_ws_size = workspace_size; - - if (expected_ws_size > workspace_size) { - ORT_THROW("Error. The allocated workspace is too small to run this problem. Expected workspace size of at least ", - expected_ws_size, " but got problem size ", workspace_size, "\n"); - } - cub::DeviceRadixSort::SortPairs(workspace, actual_ws_size, keys_in, keys_out, values_in, values_out, - (int)num_key_value_pairs, 0, num_bits_, stream); + size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs); + size_t actual_ws_size = workspace_size; + + if (expected_ws_size > workspace_size) { + ORT_THROW( + "Error. The allocated workspace is too small to run this problem. Expected workspace size of at least ", + expected_ws_size, " but got problem size ", workspace_size, "\n"); + } + cub::DeviceRadixSort::SortPairs(workspace, actual_ws_size, keys_in, keys_out, values_in, values_out, + (int)num_key_value_pairs, 0, num_bits_, stream); } // ============================== Infer GEMM sizes ================================= -__device__ inline int find_total_elts_leq_target(const int* sorted_indices, const int arr_length, const int target) { - int64_t low = 0, high = arr_length - 1, target_location = -1; - while (low <= high) { - int64_t mid = (low + high) / 2; - - if (sorted_indices[mid] > target) { - high = mid - 1; - } else { - low = mid + 1; - target_location = mid; +__device__ inline int find_total_elts_leq_target(const int *sorted_indices, const int arr_length, const int target) { + int64_t low = 0, high = arr_length - 1, target_location = -1; + while (low <= high) { + int64_t mid = (low + high) / 2; + + if (sorted_indices[mid] > target) { + high = mid - 1; + } else { + low = mid + 1; + target_location = mid; + } } - } - return target_location + 1; + return target_location + 1; } // Sets up the gemm assuming the inputs, experts and outputs are stored in row major order. // Assumes we want to perform output = matmul(inputs, experts) + bias -__global__ void compute_total_rows_before_expert_kernel(const int* sorted_experts, const int sorted_experts_len, - const int64_t num_experts, int64_t* total_rows_before_expert) { - // First, compute the global tid. We only need 1 thread per expert. - const int expert = blockIdx.x * blockDim.x + threadIdx.x; - if (expert >= num_experts) return; - - // This should construct the last index where each expert occurs. - total_rows_before_expert[expert] = find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert); +__global__ void compute_total_rows_before_expert_kernel(const int *sorted_experts, const int sorted_experts_len, + const int64_t num_experts, int64_t *total_rows_before_expert) { + // First, compute the global tid. We only need 1 thread per expert. + const int expert = blockIdx.x * blockDim.x + threadIdx.x; + if (expert >= num_experts) + return; + + // This should construct the last index where each expert occurs. + total_rows_before_expert[expert] = find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert); } -__global__ void dispatch_activations_kernel(int64_t* total_rows_before_expert, int num_experts, int local_num_experts, +__global__ void dispatch_activations_kernel(int64_t *total_rows_before_expert, int num_experts, int local_num_experts, int local_experts_start_index) { - const int expert = blockIdx.x * blockDim.x + threadIdx.x; - const int local_experts_end_index = local_experts_start_index + local_num_experts - 1; + const int expert = blockIdx.x * blockDim.x + threadIdx.x; + const int local_experts_end_index = local_experts_start_index + local_num_experts - 1; - int total_past_rows = 0; - if (local_experts_start_index > 0) { - total_past_rows = total_rows_before_expert[local_experts_start_index - 1]; - } + int total_past_rows = 0; + if (local_experts_start_index > 0) { + total_past_rows = total_rows_before_expert[local_experts_start_index - 1]; + } - if (expert < local_experts_start_index || expert > local_experts_end_index) { - return; - } + if (expert < local_experts_start_index || expert > local_experts_end_index) { + return; + } - total_rows_before_expert[expert] -= total_past_rows; + total_rows_before_expert[expert] -= total_past_rows; } template CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights) - : has_fc3_(has_fc3), - total_past_rows_(0), - total_covered_rows_(0), + : has_fc3_(has_fc3), total_past_rows_(0), total_covered_rows_(0), normalize_routing_weights_(normalize_routing_weights) { - moe_gemm_runner_.initialize(sm_version); + moe_gemm_runner_.initialize(sm_version); } template size_t CutlassMoeFCRunner::getWorkspaceSize(size_t num_rows, const size_t hidden_size, const size_t inter_size, size_t num_experts, size_t k) { - total_covered_rows_ = k * num_rows; - - const size_t buf_size = pad_to_multiple_of_16(k * num_rows * hidden_size); - const size_t interbuf_size = pad_to_multiple_of_16(k * num_rows * inter_size); - const size_t padded_experts = pad_to_multiple_of_16(num_experts); - const size_t num_moe_inputs = pad_to_multiple_of_16(k * num_rows); - size_t num_softmax_outs = 0; - - const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); - if (!is_pow_2 || num_experts > 256) { - num_softmax_outs = pad_to_multiple_of_16(num_rows * num_experts); - } - - // softmax output, permuted_rows and permuted_experts have moved to outside of moe kernel, allocate them - // in Encoder or Decoder before invoking FfnLayer forward. - size_t total_ws_bytes = 3 * num_moe_inputs * sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ - total_ws_bytes += buf_size * sizeof(T); // permuted_data - total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ - total_ws_bytes += num_softmax_outs * sizeof(T); - const size_t bytes_for_fc1_result = has_fc3_ ? 2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T); - const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows)); - sorter_.update_num_experts(static_cast(num_experts)); - - size_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result; - if (sorter_ws_size_bytes > bytes_for_fc1_result) { - size_t remaining_bytes = pad_to_multiple_of_16(sorter_ws_size_bytes - bytes_for_fc1_result); - bytes_for_intermediate_and_sorting += remaining_bytes; - } - - total_ws_bytes += bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub sorting workspace - return total_ws_bytes; + total_covered_rows_ = k * num_rows; + + const size_t buf_size = pad_to_multiple_of_16(k * num_rows * hidden_size); + const size_t interbuf_size = pad_to_multiple_of_16(k * num_rows * inter_size); + const size_t padded_experts = pad_to_multiple_of_16(num_experts); + const size_t num_moe_inputs = pad_to_multiple_of_16(k * num_rows); + size_t num_softmax_outs = 0; + + const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + if (!is_pow_2 || num_experts > 256) { + num_softmax_outs = pad_to_multiple_of_16(num_rows * num_experts); + } + + // softmax output, permuted_rows and permuted_experts have moved to outside of moe kernel, allocate them + // in Encoder or Decoder before invoking FfnLayer forward. + size_t total_ws_bytes = 3 * num_moe_inputs * sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ + total_ws_bytes += buf_size * sizeof(T); // permuted_data + total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ + total_ws_bytes += num_softmax_outs * sizeof(T); + const size_t bytes_for_fc1_result = has_fc3_ ? 2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T); + const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows)); + sorter_.update_num_experts(static_cast(num_experts)); + + size_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result; + if (sorter_ws_size_bytes > bytes_for_fc1_result) { + size_t remaining_bytes = pad_to_multiple_of_16(sorter_ws_size_bytes - bytes_for_fc1_result); + bytes_for_intermediate_and_sorting += remaining_bytes; + } + + total_ws_bytes += bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub sorting workspace + return total_ws_bytes; } template -void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr, size_t num_rows, +void CutlassMoeFCRunner::configure_ws_ptrs(char *ws_ptr, size_t num_rows, const size_t hidden_size, const size_t inter_size, size_t num_experts, size_t k) { - const size_t buf_size = pad_to_multiple_of_16(k * num_rows * hidden_size); - const size_t interbuf_size = pad_to_multiple_of_16(k * num_rows * inter_size); - const size_t padded_experts = pad_to_multiple_of_16(num_experts); - const size_t num_moe_inputs = pad_to_multiple_of_16(k * num_rows); - - source_rows_ = reinterpret_cast(ws_ptr); - permuted_rows_ = source_rows_ + num_moe_inputs; - permuted_experts_ = permuted_rows_ + num_moe_inputs; - permuted_data_ = reinterpret_cast(permuted_experts_ + num_moe_inputs); - - total_rows_before_expert_ = reinterpret_cast(permuted_data_ + buf_size); - - if (has_fc3_) { - fc3_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); - fc1_result_ = reinterpret_cast(fc3_result_ + interbuf_size); - } else { - fc1_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); - } - - const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); - if (!is_pow_2 || num_experts > 256) { - softmax_out_ = reinterpret_cast(fc1_result_ + interbuf_size); - } else { - softmax_out_ = nullptr; - } + const size_t buf_size = pad_to_multiple_of_16(k * num_rows * hidden_size); + const size_t interbuf_size = pad_to_multiple_of_16(k * num_rows * inter_size); + const size_t padded_experts = pad_to_multiple_of_16(num_experts); + const size_t num_moe_inputs = pad_to_multiple_of_16(k * num_rows); + + source_rows_ = reinterpret_cast(ws_ptr); + permuted_rows_ = source_rows_ + num_moe_inputs; + permuted_experts_ = permuted_rows_ + num_moe_inputs; + permuted_data_ = reinterpret_cast(permuted_experts_ + num_moe_inputs); + + total_rows_before_expert_ = reinterpret_cast(permuted_data_ + buf_size); + + if (has_fc3_) { + fc3_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); + fc1_result_ = reinterpret_cast(fc3_result_ + interbuf_size); + } else { + fc1_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); + } + + const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + if (!is_pow_2 || num_experts > 256) { + softmax_out_ = reinterpret_cast(fc1_result_ + interbuf_size); + } else { + softmax_out_ = nullptr; + } } namespace { struct __align__(8) Half4 { - half2 x; - half2 y; + half2 x; + half2 y; }; // TODO(wy): move to common header -template -struct T4; -template <> -struct T4 { - using Type = float4; +template struct T4; +template <> struct T4 { + using Type = float4; }; -template <> -struct T4 { - using Type = Half4; +template <> struct T4 { + using Type = Half4; }; -template -struct T2; -template <> -struct T2 { - using Type = float2; +template struct T2; +template <> struct T2 { + using Type = float2; }; -template <> -struct T2 { - using Type = half2; +template <> struct T2 { + using Type = half2; }; inline __device__ float2 operator*(const float2 a, const float2 b) { return make_float2(a.x * b.x, a.y * b.y); } inline __device__ float4 operator*(const float4 a, const float4 b) { - return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); + return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); } // TODO(wy): use cuda common header and investigate pipeline build issue. -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \ ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) -inline __device__ half operator*(const half a, const half b) { - return __float2half(__half2float(a) * __half2float(b)); -} +inline __device__ half operator*(const half a, const half b) { return __float2half(__half2float(a) * __half2float(b)); } -inline __device__ half2 operator*(const half2 a, const half2 b) { - return make_half2(a.x * b.x, a.y * b.y); -} +inline __device__ half2 operator*(const half2 a, const half2 b) { return make_half2(a.x * b.x, a.y * b.y); } #endif // TODO(wy): use cuda common header and investigate pipeline build issue. inline __device__ Half4 operator*(const Half4 a, const Half4 b) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \ ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) - Half4 result; - result.x = a.x * b.x; - result.y = a.y * b.y; - return result; + Half4 result; + result.x = a.x * b.x; + result.y = a.y * b.y; + return result; #else - return Half4{__hmul2(a.x, b.x), __hmul2(a.y, b.y)}; + return Half4{__hmul2(a.x, b.x), __hmul2(a.y, b.y)}; #endif } -} // anonymous namespace +} // anonymous namespace -template -__global__ void elementWiseMulKernel(T* output, T const* input, size_t inter_size) { - int const tid = threadIdx.x; - int const token = blockIdx.x; - - output = output + token * inter_size; - input = input + token * inter_size; - for (int i = tid; i < inter_size; i += blockDim.x) { - T fc1_value = input[i]; - output[i] = fc1_value * output[i]; - } +template __global__ void elementWiseMulKernel(T *output, T const *input, size_t inter_size) { + int const tid = threadIdx.x; + int const token = blockIdx.x; + + output = output + token * inter_size; + input = input + token * inter_size; + for (int i = tid; i < inter_size; i += blockDim.x) { + T fc1_value = input[i]; + output[i] = fc1_value * output[i]; + } } template -void elementWiseMul(T* output, T const* input, int inter_size, int num_tokens, cudaStream_t stream) { - int const blocks = num_tokens; - - if (inter_size & 3 == 0) { - using vec_type = typename T4::Type; - int const threads = std::min(inter_size / 4, 1024); - elementWiseMulKernel<<>>( - reinterpret_cast(output), reinterpret_cast(input), inter_size / 4); - } else if (inter_size & 1 == 0) { - using vec_type = typename T2::Type; - int const threads = std::min(inter_size / 2, 1024); - elementWiseMulKernel<<>>( - reinterpret_cast(output), reinterpret_cast(input), inter_size / 2); - } else { - int const threads = std::min(inter_size, 1024); - elementWiseMulKernel<<>>(output, input, inter_size); - } +void elementWiseMul(T *output, T const *input, int inter_size, int num_tokens, cudaStream_t stream) { + int const blocks = num_tokens; + + if (inter_size & 3 == 0) { + using vec_type = typename T4::Type; + int const threads = std::min(inter_size / 4, 1024); + elementWiseMulKernel<<>>( + reinterpret_cast(output), reinterpret_cast(input), inter_size / 4); + } else if (inter_size & 1 == 0) { + using vec_type = typename T2::Type; + int const threads = std::min(inter_size / 2, 1024); + elementWiseMulKernel<<>>( + reinterpret_cast(output), reinterpret_cast(input), inter_size / 2); + } else { + int const threads = std::min(inter_size, 1024); + elementWiseMulKernel<<>>(output, input, inter_size); + } } template void CutlassMoeFCRunner::run_moe_fc( - const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, - const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc3_expert_weights, - const T* fc3_scales, const T* fc3_expert_biases, const WeightType* fc2_expert_weights, const T* fc2_scales, + const T *input_activations, const T *gating_output, const WeightType *fc1_expert_weights, const T *fc1_scales, + const T *fc1_expert_biases, ActivationType fc1_activation_type, const WeightType *fc3_expert_weights, + const T *fc3_scales, const T *fc3_expert_biases, const WeightType *fc2_expert_weights, const T *fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts, int local_num_experts, - int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, const bool* finished, int active_rows, - T* expert_scales, int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream) { - static constexpr bool scales_required = - std::is_same::value || std::is_same::value; - - if (scales_required) { - if (fc1_scales == nullptr) { - ORT_THROW("[FT Error][Run MoE FC] Scales expected but scale for first matmul is a null pointer"); - } else if (fc2_scales == nullptr) { - ORT_THROW("[FT Error][Run MoE FC] Scales expected but scale for second matmul is a null pointer"); - } - } else { - if (fc1_scales != nullptr) { - ORT_THROW("[FT Error][Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC1"); - } else if (fc2_scales != nullptr) { - ORT_THROW("[FT Error][Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC2"); + int local_experts_start_index, int k, char *workspace_ptr, T *fc2_result, const bool *finished, int active_rows, + T *expert_scales, int *expanded_source_row_to_expanded_dest_row, int *expert_for_source_row, cudaStream_t stream) { + static constexpr bool scales_required = + std::is_same::value || std::is_same::value; + + if (scales_required) { + if (fc1_scales == nullptr) { + ORT_THROW("[Run MoE FC] Scales expected but scale for first matmul is a null pointer"); + } else if (fc2_scales == nullptr) { + ORT_THROW("[Run MoE FC] Scales expected but scale for second matmul is a null pointer"); + } + } else { + if (fc1_scales != nullptr) { + ORT_THROW("[Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC1"); + } else if (fc2_scales != nullptr) { + ORT_THROW("[Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC2"); + } } - } - configure_ws_ptrs(workspace_ptr, static_cast(num_rows), static_cast(hidden_size), - static_cast(inter_size), static_cast(num_experts), static_cast(k)); - topk_gating_softmax_kernelLauncher(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row, - source_rows_, num_rows, num_experts, k, normalize_routing_weights_, stream); + configure_ws_ptrs(workspace_ptr, static_cast(num_rows), static_cast(hidden_size), + static_cast(inter_size), static_cast(num_experts), static_cast(k)); + topk_gating_softmax_kernelLauncher(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row, + source_rows_, num_rows, num_experts, k, normalize_routing_weights_, stream); - const int sorter_ws_size_bytes = static_cast(pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows))); - sorter_.run(reinterpret_cast(fc1_result_), sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, - source_rows_, permuted_rows_, k * num_rows, stream); + const int sorter_ws_size_bytes = static_cast(pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows))); + sorter_.run(reinterpret_cast(fc1_result_), sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, + source_rows_, permuted_rows_, k * num_rows, stream); - initialize_moe_routing_kernelLauncher(input_activations, permuted_data_, permuted_rows_, - expanded_source_row_to_expanded_dest_row, num_rows, active_rows, hidden_size, k, - stream); + initialize_moe_routing_kernelLauncher(input_activations, permuted_data_, permuted_rows_, + expanded_source_row_to_expanded_dest_row, num_rows, active_rows, hidden_size, + k, stream); - const int expanded_active_expert_rows = k * active_rows; - compute_total_rows_before_expert(permuted_experts_, expanded_active_expert_rows, num_experts, - total_rows_before_expert_, stream); + const int expanded_active_expert_rows = k * active_rows; + compute_total_rows_before_expert(permuted_experts_, expanded_active_expert_rows, num_experts, + total_rows_before_expert_, stream); - if (local_num_experts < num_experts) { - dispatch_activations(total_rows_before_expert_, num_experts, local_num_experts, local_experts_start_index, stream); - } + if (local_num_experts < num_experts) { + dispatch_activations(total_rows_before_expert_, num_experts, local_num_experts, local_experts_start_index, + stream); + } - moe_gemm_runner_.moe_gemm_bias_act(permuted_data_ + total_past_rows_ * hidden_size, fc1_expert_weights, fc1_scales, - fc1_expert_biases, fc1_result_ + total_past_rows_ * inter_size, - total_rows_before_expert_ + local_experts_start_index, expanded_active_expert_rows, - inter_size, hidden_size, local_num_experts, fc1_activation_type, stream); + // moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size, expanded_active_expert_rows); + moe_gemm_runner_.moe_gemm_bias_act( + permuted_data_ + total_past_rows_ * hidden_size, fc1_expert_weights, fc1_scales, fc1_expert_biases, + fc1_result_ + total_past_rows_ * inter_size, total_rows_before_expert_ + local_experts_start_index, + expanded_active_expert_rows, inter_size, hidden_size, local_num_experts, fc1_activation_type, stream); + + if (has_fc3_) { + if (scales_required) { + if (fc3_scales == nullptr) { + ORT_THROW("[Run MoE FC] Scales expected but scale for third matmul is a null pointer"); + } + } else { + if (fc3_scales != nullptr) { + ORT_THROW("[Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC3"); + } + } + if (fc3_expert_weights == nullptr) { + ORT_THROW("[Run MoE FC] FC3 weights are null"); + } + moe_gemm_runner_.moe_gemm(permuted_data_ + total_past_rows_ * hidden_size, fc3_expert_weights, fc3_scales, + fc3_expert_biases, fc3_result_ + total_past_rows_ * inter_size, + total_rows_before_expert_ + local_experts_start_index, expanded_active_expert_rows, + inter_size, hidden_size, local_num_experts, stream); - if (has_fc3_) { - if (scales_required) { - if (fc3_scales == nullptr) { - ORT_THROW("[FT Error][Run MoE FC] Scales expected but scale for third matmul is a null pointer"); - } - } else { - if (fc3_scales != nullptr) { - ORT_THROW("[FT Error][Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC3"); - } - } - if (fc3_expert_weights == nullptr) { - ORT_THROW("[FT Error][Run MoE FC] FC3 weights are null"); + elementWiseMul(fc1_result_ + total_past_rows_ * inter_size, fc3_result_ + total_past_rows_ * inter_size, + static_cast(inter_size), static_cast(total_covered_rows_), stream); } - moe_gemm_runner_.moe_gemm(permuted_data_ + total_past_rows_ * hidden_size, fc3_expert_weights, fc3_scales, - fc3_expert_biases, fc3_result_ + total_past_rows_ * inter_size, - total_rows_before_expert_ + local_experts_start_index, expanded_active_expert_rows, - inter_size, hidden_size, local_num_experts, stream); - elementWiseMul(fc1_result_ + total_past_rows_ * inter_size, fc3_result_ + total_past_rows_ * inter_size, - static_cast(inter_size), static_cast(total_covered_rows_), stream); - } - - moe_gemm_runner_.moe_gemm(fc1_result_ + total_past_rows_ * inter_size, fc2_expert_weights, fc2_scales, nullptr, - fc2_result + total_past_rows_ * hidden_size, - total_rows_before_expert_ + local_experts_start_index, expanded_active_expert_rows, - hidden_size, inter_size, local_num_experts, stream); + moe_gemm_runner_.moe_gemm(fc1_result_ + total_past_rows_ * inter_size, fc2_expert_weights, fc2_scales, nullptr, + fc2_result + total_past_rows_ * hidden_size, + total_rows_before_expert_ + local_experts_start_index, expanded_active_expert_rows, + hidden_size, inter_size, local_num_experts, stream); } #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700 template -void CutlassMoeFCRunner::run_moe_fc(const T*, const T*, const WeightType*, const T*, const T*, - ActivationType, const WeightType*, const T*, const T*, - const WeightType*, const T*, int, const int, const int, int, - int, int, int k, char*, T*, T*, int*, int*, cudaStream_t) { - // MoE gemm only supports Volta+ architectures - ORT_THROW("[FT Error][Run MoE FC] MoE gemm only supports Volta+ architectures"); +void CutlassMoeFCRunner::run_moe_fc(const T *, const T *, const WeightType *, const T *, + const T *, ActivationType, const WeightType *, const T *, + const T *, const WeightType *, const T *, int, const int, + const int, int, int, int, int k, char *, T *, T *, int *, + int *, cudaStream_t) { + // MoE gemm only supports Volta+ architectures + ORT_THROW("[Run MoE FC] MoE gemm only supports Volta+ architectures"); } #else template void CutlassMoeFCRunner::run_moe_fc( - const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, - const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc3_expert_weights, - const T* fc3_scales, const T* fc3_expert_biases, const WeightType* fc2_expert_weights, const T* fc2_scales, + const T *input_activations, const T *gating_output, const WeightType *fc1_expert_weights, const T *fc1_scales, + const T *fc1_expert_biases, ActivationType fc1_activation_type, const WeightType *fc3_expert_weights, + const T *fc3_scales, const T *fc3_expert_biases, const WeightType *fc2_expert_weights, const T *fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts, int local_num_experts, - int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, T* expert_scales, - int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream) { - run_moe_fc(input_activations, gating_output, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_activation_type, - fc3_expert_weights, fc3_scales, fc3_expert_biases, fc2_expert_weights, fc2_scales, num_rows, hidden_size, - inter_size, num_experts, local_num_experts, local_experts_start_index, k, workspace_ptr, fc2_result, - nullptr, num_rows, expert_scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row, stream); + int local_experts_start_index, int k, char *workspace_ptr, T *fc2_result, T *expert_scales, + int *expanded_source_row_to_expanded_dest_row, int *expert_for_source_row, cudaStream_t stream) { + run_moe_fc(input_activations, gating_output, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_activation_type, + fc3_expert_weights, fc3_scales, fc3_expert_biases, fc2_expert_weights, fc2_scales, num_rows, hidden_size, + inter_size, num_experts, local_num_experts, local_experts_start_index, k, workspace_ptr, fc2_result, + nullptr, num_rows, expert_scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row, + stream); } #endif template -void CutlassMoeFCRunner::compute_total_rows_before_expert(const int* sorted_indices, +void CutlassMoeFCRunner::compute_total_rows_before_expert(const int *sorted_indices, const int total_indices, int num_experts, - int64_t* total_rows_before_expert, + int64_t *total_rows_before_expert, cudaStream_t stream) { - const int threads = std::min(1024, num_experts); - const int blocks = (num_experts + threads - 1) / threads; + const int threads = std::min(1024, num_experts); + const int blocks = (num_experts + threads - 1) / threads; - compute_total_rows_before_expert_kernel<<>>(sorted_indices, total_indices, num_experts, - total_rows_before_expert); + compute_total_rows_before_expert_kernel<<>>(sorted_indices, total_indices, num_experts, + total_rows_before_expert); } template -void CutlassMoeFCRunner::dispatch_activations(int64_t* total_rows_before_expert, int num_experts, +void CutlassMoeFCRunner::dispatch_activations(int64_t *total_rows_before_expert, int num_experts, int local_num_experts, int local_experts_start_index, cudaStream_t stream) { - total_rows_before_expert_host_.resize(num_experts); - cudaMemcpyAsync(total_rows_before_expert_host_.data(), total_rows_before_expert, num_experts * sizeof(int64_t), - cudaMemcpyDeviceToHost, stream); + total_rows_before_expert_host_.resize(num_experts); + cudaMemcpyAsync(total_rows_before_expert_host_.data(), total_rows_before_expert, num_experts * sizeof(int64_t), + cudaMemcpyDeviceToHost, stream); - const int threads = std::min(1024, num_experts); - const int blocks = (num_experts + threads - 1) / threads; + const int threads = std::min(1024, num_experts); + const int blocks = (num_experts + threads - 1) / threads; - cudaEvent_t& copy_event = cuda_event_.Get(); - cudaEventCreateWithFlags(©_event, cudaEventDisableTiming); - cudaEventRecord(copy_event, stream); + cudaEvent_t ©_event = cuda_event_.Get(); + cudaEventCreateWithFlags(©_event, cudaEventDisableTiming); + cudaEventRecord(copy_event, stream); - dispatch_activations_kernel<<>>(total_rows_before_expert, num_experts, local_num_experts, - local_experts_start_index); + dispatch_activations_kernel<<>>(total_rows_before_expert, num_experts, + local_num_experts, local_experts_start_index); - get_total_rows_info(local_experts_start_index, local_num_experts, total_past_rows_, total_covered_rows_); + get_total_rows_info(local_experts_start_index, local_num_experts, total_past_rows_, total_covered_rows_); } template void CutlassMoeFCRunner::get_total_rows_info(int64_t experts_start_index, - int64_t local_num_experts, int64_t& total_past_rows, - int64_t& total_covered_rows) { - int64_t experts_end_index = experts_start_index + local_num_experts - 1; - total_past_rows = 0; + int64_t local_num_experts, int64_t &total_past_rows, + int64_t &total_covered_rows) { + int64_t experts_end_index = experts_start_index + local_num_experts - 1; + total_past_rows = 0; - cudaEventSynchronize(cuda_event_.Get()); + cudaEventSynchronize(cuda_event_.Get()); - if (experts_start_index > 0) { - total_past_rows = total_rows_before_expert_host_[experts_start_index - 1]; - } - total_covered_rows = total_rows_before_expert_host_[experts_end_index] - total_past_rows; + if (experts_start_index > 0) { + total_past_rows = total_rows_before_expert_host_[experts_start_index - 1]; + } + total_covered_rows = total_rows_before_expert_host_[experts_end_index] - total_past_rows; } // ========================== Permutation things ======================================= @@ -880,150 +873,150 @@ void CutlassMoeFCRunner::get_total_rows_info(int64_t expe // of the expanded index. template -__global__ void initialize_moe_routing_kernel(const T* unpermuted_input, T* permuted_output, - const int* expanded_dest_row_to_expanded_source_row, - int* expanded_source_row_to_expanded_dest_row, int num_rows, +__global__ void initialize_moe_routing_kernel(const T *unpermuted_input, T *permuted_output, + const int *expanded_dest_row_to_expanded_source_row, + int *expanded_source_row_to_expanded_dest_row, int num_rows, int active_rows, int cols) { - // Reverse permutation map. - // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the - // reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 - // thread block will be responsible for all k summations. - const int expanded_dest_row = blockIdx.x; - const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; - if (threadIdx.x == 0) { - expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_dest_row; - } - - if (blockIdx.x < active_rows) { - // Duplicate and permute rows - const int source_row = expanded_source_row % num_rows; - - const T* source_row_ptr = unpermuted_input + source_row * cols; - T* dest_row_ptr = permuted_output + expanded_dest_row * cols; + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the + // reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. + const int expanded_dest_row = blockIdx.x; + const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + if (threadIdx.x == 0) { + expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_dest_row; + } - for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { - dest_row_ptr[tid] = source_row_ptr[tid]; + if (blockIdx.x < active_rows) { + // Duplicate and permute rows + const int source_row = expanded_source_row % num_rows; + + const T *source_row_ptr = unpermuted_input + source_row * cols; + T *dest_row_ptr = permuted_output + expanded_dest_row * cols; + + for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { + dest_row_ptr[tid] = source_row_ptr[tid]; + } } - } } template -void initialize_moe_routing_kernelLauncher(const T* unpermuted_input, T* permuted_output, - const int* expanded_dest_row_to_expanded_source_row, - int* expanded_source_row_to_expanded_dest_row, int num_rows, int active_rows, +void initialize_moe_routing_kernelLauncher(const T *unpermuted_input, T *permuted_output, + const int *expanded_dest_row_to_expanded_source_row, + int *expanded_source_row_to_expanded_dest_row, int num_rows, int active_rows, int cols, int k, cudaStream_t stream) { - const int blocks = num_rows * k; - const int threads = std::min(cols, 1024); - initialize_moe_routing_kernel - <<>>(unpermuted_input, permuted_output, expanded_dest_row_to_expanded_source_row, - expanded_source_row_to_expanded_dest_row, num_rows, k * active_rows, cols); + const int blocks = num_rows * k; + const int threads = std::min(cols, 1024); + initialize_moe_routing_kernel + <<>>(unpermuted_input, permuted_output, expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, num_rows, k * active_rows, cols); } // Final kernel to unpermute and scale // This kernel unpermutes the original data, does the k-way reduction and performs the final skip connection. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 template -__global__ void finalize_moe_routing_kernel(const T*, T*, const T*, const T*, const T*, const T*, const int*, - const int*, int, const int) { - // Does not support pre-Kepler architectures - ; +__global__ void finalize_moe_routing_kernel(const T *, T *, const T *, const T *, const T *, const T *, const int *, + const int *, int, const int) { + // Does not support pre-Kepler architectures + ; } #else template -__global__ void finalize_moe_routing_kernel(const T* expanded_permuted_rows, T* reduced_unpermuted_output, - const T* skip_1, const T* skip_2, const T* bias, const T* scales, - const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, int cols, int k) { - const int original_row = blockIdx.x; - int num_rows = gridDim.x; - T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols; - - const T* skip_1_row_ptr = nullptr; - if (RESIDUAL_NUM == 1) { - skip_1_row_ptr = skip_1 + original_row * cols; - } - const T* skip_2_row_ptr = nullptr; - if (RESIDUAL_NUM == 2) { - skip_2_row_ptr = skip_2 + original_row * cols; - } - - for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { - T thread_output; - if (RESIDUAL_NUM == 0) { - thread_output = T(0); - } else if (RESIDUAL_NUM == 1) { - thread_output = skip_1_row_ptr[tid]; - } else if (RESIDUAL_NUM == 2) { - thread_output = skip_1_row_ptr[tid] + skip_2_row_ptr[tid]; +__global__ void finalize_moe_routing_kernel(const T *expanded_permuted_rows, T *reduced_unpermuted_output, + const T *skip_1, const T *skip_2, const T *bias, const T *scales, + const int *expanded_source_row_to_expanded_dest_row, + const int *expert_for_source_row, int cols, int k) { + const int original_row = blockIdx.x; + int num_rows = gridDim.x; + T *reduced_row_ptr = reduced_unpermuted_output + original_row * cols; + + const T *skip_1_row_ptr = nullptr; + if (RESIDUAL_NUM == 1) { + skip_1_row_ptr = skip_1 + original_row * cols; } - for (int k_idx = 0; k_idx < k; ++k_idx) { - const int expanded_original_row = original_row + k_idx * num_rows; - const int expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row]; + const T *skip_2_row_ptr = nullptr; + if (RESIDUAL_NUM == 2) { + skip_2_row_ptr = skip_2 + original_row * cols; + } + + for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { + T thread_output; + if (RESIDUAL_NUM == 0) { + thread_output = T(0); + } else if (RESIDUAL_NUM == 1) { + thread_output = skip_1_row_ptr[tid]; + } else if (RESIDUAL_NUM == 2) { + thread_output = skip_1_row_ptr[tid] + skip_2_row_ptr[tid]; + } + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int expanded_original_row = original_row + k_idx * num_rows; + const int expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row]; - const int64_t k_offset = original_row * k + k_idx; - const T row_scale = scales[k_offset]; - const T* expanded_permuted_rows_row_ptr = expanded_permuted_rows + expanded_permuted_row * cols; + const int64_t k_offset = original_row * k + k_idx; + const T row_scale = scales[k_offset]; + const T *expanded_permuted_rows_row_ptr = expanded_permuted_rows + expanded_permuted_row * cols; - const int expert_idx = expert_for_source_row[k_offset]; - const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr; + const int expert_idx = expert_for_source_row[k_offset]; + const T *bias_ptr = bias ? bias + expert_idx * cols : nullptr; - thread_output = - thread_output + row_scale * (expanded_permuted_rows_row_ptr[tid] + (bias_ptr ? bias_ptr[tid] : T(0))); + thread_output = + thread_output + row_scale * (expanded_permuted_rows_row_ptr[tid] + (bias_ptr ? bias_ptr[tid] : T(0))); + } + reduced_row_ptr[tid] = thread_output; } - reduced_row_ptr[tid] = thread_output; - } } #endif template -void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* bias, - const T* scales, const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, int num_rows, int cols, int k, +void finalize_moe_routing_kernelLauncher(const T *expanded_permuted_rows, T *reduced_unpermuted_output, const T *bias, + const T *scales, const int *expanded_source_row_to_expanded_dest_row, + const int *expert_for_source_row, int num_rows, int cols, int k, cudaStream_t stream) { - const int blocks = num_rows; - const int threads = std::min(cols, 1024); - finalize_moe_routing_kernel<<>>( - expanded_permuted_rows, reduced_unpermuted_output, nullptr, nullptr, bias, scales, - expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); + const int blocks = num_rows; + const int threads = std::min(cols, 1024); + finalize_moe_routing_kernel<<>>( + expanded_permuted_rows, reduced_unpermuted_output, nullptr, nullptr, bias, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); } template -void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip, - const T* bias, const T* scales, - const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, int num_rows, int cols, int k, +void finalize_moe_routing_kernelLauncher(const T *expanded_permuted_rows, T *reduced_unpermuted_output, const T *skip, + const T *bias, const T *scales, + const int *expanded_source_row_to_expanded_dest_row, + const int *expert_for_source_row, int num_rows, int cols, int k, cudaStream_t stream) { - const int blocks = num_rows; - const int threads = std::min(cols, 1024); - finalize_moe_routing_kernel - <<>>(expanded_permuted_rows, reduced_unpermuted_output, skip, nullptr, bias, scales, - expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); + const int blocks = num_rows; + const int threads = std::min(cols, 1024); + finalize_moe_routing_kernel + <<>>(expanded_permuted_rows, reduced_unpermuted_output, skip, nullptr, bias, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); } template -void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1, - const T* skip_2, const T* bias, const T* scales, - const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, int num_rows, int cols, int k, +void finalize_moe_routing_kernelLauncher(const T *expanded_permuted_rows, T *reduced_unpermuted_output, const T *skip_1, + const T *skip_2, const T *bias, const T *scales, + const int *expanded_source_row_to_expanded_dest_row, + const int *expert_for_source_row, int num_rows, int cols, int k, cudaStream_t stream) { - const int blocks = num_rows; - const int threads = std::min(cols, 1024); - if (skip_2 == nullptr) { - finalize_moe_routing_kernel<<>>( - expanded_permuted_rows, reduced_unpermuted_output, skip_1, skip_2, bias, scales, - expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); - } else { - finalize_moe_routing_kernel<<>>( - expanded_permuted_rows, reduced_unpermuted_output, skip_1, skip_2, bias, scales, - expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); - } + const int blocks = num_rows; + const int threads = std::min(cols, 1024); + if (skip_2 == nullptr) { + finalize_moe_routing_kernel<<>>( + expanded_permuted_rows, reduced_unpermuted_output, skip_1, skip_2, bias, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); + } else { + finalize_moe_routing_kernel<<>>( + expanded_permuted_rows, reduced_unpermuted_output, skip_1, skip_2, bias, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); + } } // ========================= TopK Softmax specializations =========================== -template void topk_gating_softmax_kernelLauncher(const float*, const bool*, float*, float*, int*, int*, int, int, int, - bool, cudaStream_t); -template void topk_gating_softmax_kernelLauncher(const half*, const bool*, half*, half*, int*, int*, int, int, int, - bool, cudaStream_t); +template void topk_gating_softmax_kernelLauncher(const float *, const bool *, float *, float *, int *, int *, int, int, + int, bool, cudaStream_t); +template void topk_gating_softmax_kernelLauncher(const half *, const bool *, half *, half *, int *, int *, int, int, + int, bool, cudaStream_t); // ==================== Variable batched GEMM specializations ================================== template class CutlassMoeFCRunner; @@ -1031,23 +1024,23 @@ template class CutlassMoeFCRunner; template class CutlassMoeFCRunner; // ===================== Specializations for init routing ========================= -template void initialize_moe_routing_kernelLauncher(const float*, float*, const int*, int*, int, int, int, int, +template void initialize_moe_routing_kernelLauncher(const float *, float *, const int *, int *, int, int, int, int, cudaStream_t); -template void initialize_moe_routing_kernelLauncher(const half*, half*, const int*, int*, int, int, int, int, +template void initialize_moe_routing_kernelLauncher(const half *, half *, const int *, int *, int, int, int, int, cudaStream_t); // ==================== Specializations for final routing =================================== -template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const int*, - const int*, int, int, int, cudaStream_t); -template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const int*, const int*, - int, int, int, cudaStream_t); -template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const float*, - const int*, const int*, int, int, int, cudaStream_t); -template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*, const int*, - const int*, int, int, int, cudaStream_t); -template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const float*, - const float*, const int*, const int*, int, int, int, cudaStream_t); -template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*, - const half*, const int*, const int*, int, int, int, cudaStream_t); - -} // namespace ort_fastertransformer +template void finalize_moe_routing_kernelLauncher(const float *, float *, const float *, const float *, const int *, + const int *, int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const half *, half *, const half *, const half *, const int *, + const int *, int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const float *, float *, const float *, const float *, const float *, + const int *, const int *, int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const half *, half *, const half *, const half *, const half *, + const int *, const int *, int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const float *, float *, const float *, const float *, const float *, + const float *, const int *, const int *, int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const half *, half *, const half *, const half *, const half *, + const half *, const int *, const int *, int, int, int, cudaStream_t); + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h index 5eef6f95f4820..18a26e6a43382 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -21,8 +21,8 @@ #include "moe_gemm_kernels.h" #include -#include "core/common/common.h" #include "contrib_ops/cuda/bert/transformer_cuda_common.h" +#include "core/common/common.h" #include "cutlass/numeric_types.h" @@ -53,8 +53,8 @@ static inline size_t pad_to_multiple_of_16(size_t input) { */ template void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_out, - int* indices, int* source_row, int num_rows, int num_experts, - int k, cudaStream_t stream); + int* indices, int* source_row, int num_rows, int num_experts, int k, + cudaStream_t stream); class CubKeyValueSorter { public: @@ -78,28 +78,28 @@ class CubKeyValueSorter { template void initialize_moe_routing_kernelLauncher(const T* unpermuted_input, T* permuted_output, const int* expanded_dest_row_to_expanded_source_row, - int* expanded_source_row_to_expanded_dest_row, int num_rows, - int active_rows, int cols, int k, cudaStream_t stream); + int* expanded_source_row_to_expanded_dest_row, int num_rows, int active_rows, + int cols, int k, cudaStream_t stream); template void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* bias, const T* scales, const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, int num_rows, int cols, - int k, cudaStream_t stream); + const int* expert_for_source_row, int num_rows, int cols, int k, + cudaStream_t stream); template void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip, const T* bias, const T* scales, const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, int num_rows, int cols, - int k, cudaStream_t stream); + const int* expert_for_source_row, int num_rows, int cols, int k, + cudaStream_t stream); template void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1, const T* skip_2, const T* bias, const T* scales, const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, int num_rows, int cols, - int k, cudaStream_t stream); + const int* expert_for_source_row, int num_rows, int cols, int k, + cudaStream_t stream); // Assumes inputs activations are row major. Weights need to be preprocessed by th_op/weight_quantize.cc . // Nested in a class to avoid multiple calls to cudaGetDeviceProperties as this call can be expensive. diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index dbd783c0cb11c..6aa75840e6dc0 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -22,7 +22,8 @@ REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) template -MoE::MoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) {} +MoE::MoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) { +} template Status MoE::ComputeInternal(OpKernelContext* context) const { @@ -72,6 +73,7 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, stream); const CudaT* fc_scales_ptr = nullptr; + moe_runner.run_moe_fc( reinterpret_cast(input->template Data()), reinterpret_cast(router_probs->template Data()),