Skip to content

Commit

Permalink
feat: more options for CUDA EP
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Sep 21, 2024
1 parent 7f71e6c commit e16fd5b
Showing 1 changed file with 73 additions and 2 deletions.
75 changes: 73 additions & 2 deletions src/execution_providers/cuda.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,50 @@
use std::ops::BitOr;

use crate::{
error::{Error, Result},
execution_providers::{ArenaExtendStrategy, ExecutionProvider, ExecutionProviderDispatch},
session::SessionBuilder
};

// https://github.com/microsoft/onnxruntime/blob/ffceed9d44f2f3efb9dd69fa75fea51163c91d91/onnxruntime/contrib_ops/cpu/bert/attention_common.h#L160-L171
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct CUDAExecutionProviderAttentionBackend(u32);

impl CUDAExecutionProviderAttentionBackend {
pub const FLASH_ATTENTION: Self = Self(1 << 0);
pub const EFFICIENT_ATTENTION: Self = Self(1 << 1);
pub const TRT_FUSED_ATTENTION: Self = Self(1 << 2);
pub const CUDNN_FLASH_ATTENTION: Self = Self(1 << 3);
pub const MATH: Self = Self(1 << 4);

pub const TRT_FLASH_ATTENTION: Self = Self(1 << 5);
pub const TRT_CROSS_ATTENTION: Self = Self(1 << 6);
pub const TRT_CAUSAL_ATTENTION: Self = Self(1 << 7);

pub fn none() -> Self {
CUDAExecutionProviderAttentionBackend(0)
}

pub fn all() -> Self {
Self::FLASH_ATTENTION
| Self::EFFICIENT_ATTENTION
| Self::TRT_FUSED_ATTENTION
| Self::CUDNN_FLASH_ATTENTION
| Self::MATH
| Self::TRT_FLASH_ATTENTION
| Self::TRT_CROSS_ATTENTION
| Self::TRT_CAUSAL_ATTENTION
}
}

impl BitOr for CUDAExecutionProviderAttentionBackend {
type Output = Self;
fn bitor(self, rhs: Self) -> Self::Output {
Self(rhs.0 | self.0)
}
}

/// The type of search done for cuDNN convolution algorithms.
#[derive(Debug, Clone)]
pub enum CUDAExecutionProviderCuDNNConvAlgoSearch {
Expand Down Expand Up @@ -44,6 +85,7 @@ impl Default for CUDAExecutionProviderCuDNNConvAlgoSearch {
pub struct CUDAExecutionProvider {
device_id: Option<i32>,
gpu_mem_limit: Option<usize>,
user_compute_stream: Option<*mut ()>,
arena_extend_strategy: Option<ArenaExtendStrategy>,
cudnn_conv_algo_search: Option<CUDAExecutionProviderCuDNNConvAlgoSearch>,
do_copy_in_default_stream: Option<bool>,
Expand All @@ -52,7 +94,9 @@ pub struct CUDAExecutionProvider {
enable_cuda_graph: Option<bool>,
enable_skip_layer_norm_strict_mode: Option<bool>,
use_tf32: Option<bool>,
prefer_nhwc: Option<bool>
prefer_nhwc: Option<bool>,
sdpa_kernel: Option<u32>,
fuse_conv_bias: Option<bool>
}

impl CUDAExecutionProvider {
Expand Down Expand Up @@ -173,6 +217,29 @@ impl CUDAExecutionProvider {
self
}

/// # Safety
/// The provided `stream` must outlive the environment/session created with the execution provider.
#[must_use]
pub unsafe fn with_compute_stream(mut self, stream: *mut ()) -> Self {
self.user_compute_stream = Some(stream);
self
}

#[must_use]
pub fn with_attention_backend(mut self, flags: CUDAExecutionProviderAttentionBackend) -> Self {
self.sdpa_kernel = Some(flags.0);
self
}

#[must_use]
pub fn with_fuse_conv_bias(mut self, enable: bool) -> Self {
self.fuse_conv_bias = Some(enable);
self
}

// https://github.com/microsoft/onnxruntime/blob/ffceed9d44f2f3efb9dd69fa75fea51163c91d91/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h#L48
// https://github.com/microsoft/onnxruntime/blob/fe8a10caa40f64a8fbd144e7049cf5b14c65542d/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc#L17

#[must_use]
pub fn build(self) -> ExecutionProviderDispatch {
self.into()
Expand Down Expand Up @@ -211,14 +278,18 @@ impl ExecutionProvider for CUDAExecutionProvider {
CUDAExecutionProviderCuDNNConvAlgoSearch::Heuristic => "HEURISTIC",
CUDAExecutionProviderCuDNNConvAlgoSearch::Default => "DEFAULT"
}),
// has_user_compute_stream = self.user_compute_stream.as_ref().map(|_| 1),
user_compute_stream = self.user_compute_stream.map(|x| x as usize),
gpu_mem_limit = self.gpu_mem_limit,
do_copy_in_default_stream = self.do_copy_in_default_stream.map(<bool as Into<i32>>::into),
cudnn_conv_use_max_workspace = self.cudnn_conv_use_max_workspace.map(<bool as Into<i32>>::into),
cudnn_conv1d_pad_to_nc1d = self.cudnn_conv1d_pad_to_nc1d.map(<bool as Into<i32>>::into),
enable_cuda_graph = self.enable_cuda_graph.map(<bool as Into<i32>>::into),
enable_skip_layer_norm_strict_mode = self.enable_skip_layer_norm_strict_mode.map(<bool as Into<i32>>::into),
use_tf32 = self.use_tf32.map(<bool as Into<i32>>::into),
prefer_nhwc = self.prefer_nhwc.map(<bool as Into<i32>>::into)
prefer_nhwc = self.prefer_nhwc.map(<bool as Into<i32>>::into),
sdpa_kernel = self.sdpa_kernel,
fuse_conv_bias = self.fuse_conv_bias.map(<bool as Into<i32>>::into)
};
if let Err(e) =
crate::error::status_to_result(crate::ortsys![unsafe UpdateCUDAProviderOptions(cuda_options, key_ptrs.as_ptr(), value_ptrs.as_ptr(), len as _)])
Expand Down

0 comments on commit e16fd5b

Please sign in to comment.