Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into gwang-msft/qdq_mul
Browse files Browse the repository at this point in the history
  • Loading branch information
guoyu-wang committed Feb 7, 2022
2 parents 7a32847 + 0f5d0a0 commit 9cadda7
Show file tree
Hide file tree
Showing 55 changed files with 1,280 additions and 172 deletions.
8 changes: 8 additions & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,14 @@ if (onnxruntime_CROSS_COMPILING)
endif()
endif()

# For C++ compilers reporting a version below 9.0, stop attribute warnings from
# being promoted to errors. The probe only checks that the compiler accepts
# -Wno-error; when it does, the same relaxation is appended to both the C++ and
# the C flag strings.
if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
  check_cxx_compiler_flag(-Wno-error HAS_NOERROR)
  if (HAS_NOERROR)
    foreach(_ort_flags_var CMAKE_CXX_FLAGS CMAKE_C_FLAGS)
      string(APPEND ${_ort_flags_var} " -Wno-error=attributes")
    endforeach()
  endif()
endif()

# Mark symbols to be invisible, for macOS/iOS target only
# Due to many dependencies have different symbol visibility settings, set global compile flags here.
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin|iOS")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

/// <summary>
/// Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT_V2.
/// Please note that this struct is identical to OrtTensorRTProviderOptions but only to be used internally.
/// Please note that this struct is *similar* to OrtTensorRTProviderOptions but only to be used internally.
/// Going forward, new trt provider options are to be supported via this struct and usage of the publicly defined
/// OrtTensorRTProviderOptions will be deprecated over time.
/// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions.
/// </summary>
struct OrtTensorRTProviderOptionsV2 {
Expand Down
3 changes: 3 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,9 @@ typedef struct OrtTensorRTProviderOptions {
int trt_engine_decryption_enable; // enable engine decryption. Default 0 = false, nonzero = true
const char* trt_engine_decryption_lib_path; // specify engine decryption library path
int trt_force_sequential_engine_build; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true
// This is the legacy struct; do not add new fields here.
// For a new field that can be represented by a string, please add it in include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
// For a non-string field, a new separate API needs to be created to handle it.
} OrtTensorRTProviderOptions;

/** \brief MIGraphX Provider Options
Expand Down
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ struct SessionOptions : Base<OrtSessionOptions> {
SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
SessionOptions& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT_V2
SessionOptions& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX

SessionOptions& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
Expand Down
5 changes: 5 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,11 @@ inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT(const Or
return *this;
}

// Forwards the V2 TensorRT provider options to
// OrtApi::SessionOptionsAppendExecutionProvider_TensorRT_V2 for this session
// options handle. The returned status goes through ThrowOnError, and *this is
// returned so calls can be chained.
inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) {
  const auto& api = GetApi();
  ThrowOnError(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(p_, &provider_options));
  return *this;
}

inline SessionOptions& SessionOptions::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) {
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(p_, &provider_options));
return *this;
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
Tensor* output = context->Output(0, input->Shape());

int64_t input_length = input->Shape().Size();
if (input_length == 0) {
return Status::OK();
}
int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
typedef typename ToCudaType<T>::MappedType CudaT;

Expand Down
14 changes: 5 additions & 9 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@ Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {

Tensor* output = ctx->Output(0, input->Shape());

if (input->SizeInBytes() == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'input' has no data from upstream nodes");
if (input->Shape() != skip->Shape()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"skip is expected to have same shape as input");
}

if (skip->SizeInBytes() == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'skip' has no data from upstream nodes");
if (input->Shape().Size() == 0) {
return Status::OK();
}

const auto& input_dims = input->Shape().GetDims();
Expand All @@ -55,11 +56,6 @@ Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
"input is expected to have 3 dimensions, got ", input_dims.size());
}

if (input->Shape() != skip->Shape()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"skip is expected to have same shape as input");
}

const auto& gamma_dims = gamma->Shape().GetDims();
if (gamma_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
Expand Down
10 changes: 4 additions & 6 deletions onnxruntime/contrib_ops/cuda/layer_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,6 @@ Status LayerNorm<T, U, simplified>::ComputeInternal(OpKernelContext* ctx) const
auto bias_data = (simplified || (nullptr == bias)) ? nullptr : reinterpret_cast<const CudaT*>(bias->template Data<T>());

const TensorShape& x_shape = X->Shape();
// Sometimes due to conversion issue, the input 'X' has no data which is a case that cuda kernel cannot handle.
// Provide more error infomation here instead of CUDA errors.
if (X->SizeInBytes() == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'X' has no data from upstream nodes");
}

const int64_t axis = HandleNegativeAxis(axis_, x_shape.NumDimensions());

int n1 = gsl::narrow<int>(x_shape.SizeToDimension(axis));
Expand Down Expand Up @@ -101,6 +95,10 @@ Status LayerNorm<T, U, simplified>::ComputeInternal(OpKernelContext* ctx) const
inv_var_data = reinterpret_cast<CudaU*>(var->template MutableData<U>());
}

if (x_shape.Size() == 0) {
return Status::OK();
}

HostApplyLayerNorm<CudaT, CudaU, simplified>(GetDeviceProp(), Stream(), Y_data, mean_data, inv_var_data, X_data, n1, n2, epsilon_, scale_data, bias_data);
return Status::OK();
}
Expand Down
20 changes: 20 additions & 0 deletions onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FusedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul);
Expand Down Expand Up @@ -76,6 +80,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int8_t, QAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention);

// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, BFloat16_float, LayerNormalization);

template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
KernelCreateInfo info;
Expand All @@ -93,6 +102,10 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul)>,
Expand Down Expand Up @@ -152,6 +165,13 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int8_t, QAttention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention)>

// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu)>,
// TransposeMatMul is still here for backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, BFloat16_float, LayerNormalization)>,

};

for (auto& function_table_entry : function_table) {
Expand Down
163 changes: 163 additions & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2820,6 +2820,169 @@ Example 4:
}
});

// Documentation string attached to version 1 of the EfficientNMS_TRT schema.
static const char* EfficientNMS_TRT_ver1_doc =
    R"DOC(Efficient NMS TensorRT Plugin.)DOC";

// Schema for the EfficientNMS_TRT op, registered in kOnnxDomain since version 1.
// Inputs:  boxes, scores, and an optional anchors tensor (all of type T).
// Outputs: num_detections (int32), detection_boxes (T), detection_scores (T),
//          detection_classes (int32).
ONNX_CONTRIB_OPERATOR_SCHEMA(EfficientNMS_TRT)
    .SetDomain(kOnnxDomain)
    .SinceVersion(1)
    .SetDoc(EfficientNMS_TRT_ver1_doc)
    .Input(0, "boxes", "The boxes input tensor.", "T")
    .Input(1, "scores", "The scores input tensor.", "T")
    .Input(2, "anchors", "The anchors input tensor.", "T", OpSchema::Optional)
    .Output(0, "num_detections", "The num_detections output tensor.", "tensor(int32)")
    .Output(1, "detection_boxes", "The detection_boxes output tensor.", "T")
    .Output(2, "detection_scores", "The detection_scores output tensor.", "T")
    .Output(3, "detection_classes", "The detection_classes output tensor.", "tensor(int32)")
    .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
    .Attr("background_class", "Background class ID.", AttributeProto::INT)
    .Attr("box_coding", "Encoding type for the boxes or anchors inputs.", AttributeProto::INT)
    .Attr("iou_threshold", "Box IOU threshold value.", AttributeProto::FLOAT)
    .Attr("max_output_boxes", "Max detections to output.", AttributeProto::INT)
    .Attr("plugin_version", "Version number of the TRT plugin.", AttributeProto::STRING)
    .Attr("score_activation", "Activation function to apply to the scores input.", AttributeProto::INT)
    .Attr("score_threshold", "Score threshold value.", AttributeProto::FLOAT)
    .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
      using namespace ONNX_NAMESPACE;

      // Element types: outputs 0 and 3 are always int32, while outputs 1 and 2
      // take the element type of the 'boxes' input.
      ONNX_NAMESPACE::updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT32);
      propagateElemTypeFromInputToOutput(ctx, 0, 1);
      propagateElemTypeFromInputToOutput(ctx, 0, 2);
      ONNX_NAMESPACE::updateOutputElemType(ctx, 3, ONNX_NAMESPACE::TensorProto::INT32);

      // Output shapes are only inferred when 'boxes' carries a shape.
      if (!hasInputShape(ctx, 0)) {
        return;
      }

      // Falls back to 1 when the attribute is not present on the node.
      const auto* max_boxes_attr = ctx.getAttribute("max_output_boxes");
      const int64_t max_boxes = max_boxes_attr ? max_boxes_attr->i() : 1;
      if (max_boxes < 1) {
        fail_shape_inference("Attribute 'max_output_boxes' must be >= 1.");
      }

      // Every output leads with dimension 0 of 'boxes'.
      Dim batch_dim;
      unifyInputDim(ctx, 0, 0, batch_dim);

      // num_detections: [batch, 1]
      ONNX_NAMESPACE::TensorShapeProto counts_shape;
      *counts_shape.add_dim() = batch_dim;
      counts_shape.add_dim()->set_dim_value(1);

      // [batch, max_boxes] is the shape of both detection_scores and
      // detection_classes, and the prefix of detection_boxes.
      ONNX_NAMESPACE::TensorShapeProto per_box_shape;
      *per_box_shape.add_dim() = batch_dim;
      per_box_shape.add_dim()->set_dim_value(max_boxes);

      // detection_boxes: [batch, max_boxes, 4]
      ONNX_NAMESPACE::TensorShapeProto boxes_shape = per_box_shape;
      boxes_shape.add_dim()->set_dim_value(4);

      updateOutputShape(ctx, 0, counts_shape);
      updateOutputShape(ctx, 1, boxes_shape);
      updateOutputShape(ctx, 2, per_box_shape);
      updateOutputShape(ctx, 3, per_box_shape);
    });

// Documentation string attached to version 1 of the MultilevelCropAndResize_TRT schema.
static const char* MultilevelCropAndResize_TRT_ver1_doc =
    R"DOC(Multilevel Crop and Resize TensorRT Plugin.)DOC";

// Schema for the MultilevelCropAndResize_TRT op, registered in kOnnxDomain since
// version 1. It takes a boxes tensor plus four feature maps (all float) and
// produces a single 'patches' output.
ONNX_CONTRIB_OPERATOR_SCHEMA(MultilevelCropAndResize_TRT)
    .SetDomain(kOnnxDomain)
    .SinceVersion(1)
    .SetDoc(MultilevelCropAndResize_TRT_ver1_doc)
    .Input(0, "boxes", "The boxes input tensor.", "T")
    .Input(1, "feature_map_0", "The first feature map input tensor.", "T")
    .Input(2, "feature_map_1", "The second feature map input tensor.", "T")
    .Input(3, "feature_map_2", "The third feature map input tensor.", "T")
    .Input(4, "feature_map_3", "The fourth feature map input tensor.", "T")
    .Output(0, "patches", "The cropped patches output tensor.", "T")
    .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors.")
    .Attr("image_size", "Image size.", AttributeProto::INTS)
    .Attr("pooled_size", "Pooled size.", AttributeProto::INT)
    .Attr("plugin_version", "Version number of the TRT plugin.", AttributeProto::STRING)
    .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
      // The output takes its element type from the 'boxes' input.
      propagateElemTypeFromInputToOutput(ctx, 0, 0);

      // The output shape is only inferred when 'boxes' carries a shape.
      if (!hasInputShape(ctx, 0)) {
        return;
      }

      // Falls back to 1 when the attribute is not present on the node.
      const auto* pooled_attr = ctx.getAttribute("pooled_size");
      const int64_t pooled_extent = pooled_attr ? pooled_attr->i() : 1;
      if (pooled_extent < 1) {
        fail_shape_inference("Attribute 'pooled_size' must be >= 1.");
      }

      // Leading output dims: dims 0 and 1 of 'boxes', then dim 1 of 'feature_map_0'.
      Dim batch_dim;
      Dim box_count_dim;
      Dim channel_dim;
      unifyInputDim(ctx, 0, 0, batch_dim);
      unifyInputDim(ctx, 0, 1, box_count_dim);
      unifyInputDim(ctx, 1, 1, channel_dim);

      // patches: [batch, number_boxes, channels, pooled_size, pooled_size]
      ONNX_NAMESPACE::TensorShapeProto patches_shape;
      *patches_shape.add_dim() = batch_dim;
      *patches_shape.add_dim() = box_count_dim;
      *patches_shape.add_dim() = channel_dim;
      for (int spatial_axis = 0; spatial_axis < 2; ++spatial_axis) {
        patches_shape.add_dim()->set_dim_value(pooled_extent);
      }
      updateOutputShape(ctx, 0, patches_shape);
    });

// Documentation string attached to version 1 of the PyramidROIAlign_TRT schema.
static const char* PyramidROIAlign_TRT_ver1_doc =
    R"DOC(Pyramid ROI Align TensorRT Plugin.)DOC";

// Schema for the PyramidROIAlign_TRT op, registered in kOnnxDomain since
// version 1. It takes a boxes tensor plus four feature maps (all float) and
// produces a single 'patches' output.
ONNX_CONTRIB_OPERATOR_SCHEMA(PyramidROIAlign_TRT)
    .SetDomain(kOnnxDomain)
    .SinceVersion(1)
    .SetDoc(PyramidROIAlign_TRT_ver1_doc)
    .Input(0, "boxes", "The boxes input tensor.", "T")
    .Input(1, "feature_map_0", "The first feature map input tensor.", "T")
    .Input(2, "feature_map_1", "The second feature map input tensor.", "T")
    .Input(3, "feature_map_2", "The third feature map input tensor.", "T")
    .Input(4, "feature_map_3", "The fourth feature map input tensor.", "T")
    .Output(0, "patches", "The cropped patches output tensor.", "T")
    .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors.")
    .Attr("pooled_size", "Pooled size.", AttributeProto::INT)
    .Attr("plugin_version", "Version number of the TRT plugin.", AttributeProto::STRING)
    .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
      // The output takes its element type from the 'boxes' input.
      propagateElemTypeFromInputToOutput(ctx, 0, 0);

      // The output shape is only inferred when 'boxes' carries a shape.
      if (!hasInputShape(ctx, 0)) {
        return;
      }

      // Falls back to 1 when the attribute is not present on the node.
      const auto* pooled_attr = ctx.getAttribute("pooled_size");
      const int64_t pooled_extent = pooled_attr ? pooled_attr->i() : 1;
      if (pooled_extent < 1) {
        fail_shape_inference("Attribute 'pooled_size' must be >= 1.");
      }

      // Leading output dims: dims 0 and 1 of 'boxes', then dim 1 of 'feature_map_0'.
      Dim batch_dim;
      Dim box_count_dim;
      Dim channel_dim;
      unifyInputDim(ctx, 0, 0, batch_dim);
      unifyInputDim(ctx, 0, 1, box_count_dim);
      unifyInputDim(ctx, 1, 1, channel_dim);

      // patches: [batch, number_boxes, channels, pooled_size, pooled_size]
      ONNX_NAMESPACE::TensorShapeProto patches_shape;
      *patches_shape.add_dim() = batch_dim;
      *patches_shape.add_dim() = box_count_dim;
      *patches_shape.add_dim() = channel_dim;
      for (int spatial_axis = 0; spatial_axis < 2; ++spatial_axis) {
        patches_shape.add_dim()->set_dim_value(pooled_extent);
      }
      updateOutputShape(ctx, 0, patches_shape);
    });

static const char* Gelu_ver1_doc =
R"DOC(Gaussian Error Linear Unit.
A high-performing neural network activation function.The GELU nonlinearity is
Expand Down
Loading

0 comments on commit 9cadda7

Please sign in to comment.