diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc b/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc
index 6c3e8d7c321aa..5d5183221eda4 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc
+++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc
@@ -32,7 +32,7 @@ BiasAdd<T>::BiasAdd(const OpKernelInfo& op_info) : CudaKernel(op_info) {
 template <typename T>
 Status BiasAdd<T>::ComputeInternal(OpKernelContext* context) const {
   // Input: [batch_size, height*width, channels]
-  // Bias: [hidden_size]
+  // Bias: [channels]
   // Skip: [batch_size, height*width, channels]
   // Output: [batch_size, height*width, channels]
 
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu
index 0bc635dd85cbb..2983cc99e30b1 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu
@@ -23,6 +23,8 @@
 #include "core/providers/cuda/cu_inc/common.cuh"
 #include "contrib_ops/cuda/diffusion/bias_add_impl.h"
 
+using namespace onnxruntime::cuda;
+
 namespace onnxruntime {
 namespace contrib {
 namespace cuda {
@@ -34,11 +36,7 @@ __global__ void BiasAddKernel(T const* input, T const* bias, T const* residual,
 
 #pragma unroll
   for (int32_t i = 0; i < C / TPB; ++i) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
     output[base_offset] = input[base_offset] + bias[bias_offset] + residual[base_offset];
-#else
-    output[base_offset] = static_cast<T>(float(input[base_offset]) + float(bias[bias_offset]) + float(residual[base_offset]));
-#endif
     base_offset += TPB;
     bias_offset += TPB;
   }
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu
index 3cb95dad26b36..8069cbc0a1e0e 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu
@@ -23,6 +23,8 @@
 #include "core/providers/cuda/cu_inc/common.cuh"
 #include "contrib_ops/cuda/diffusion/bias_split_gelu_impl.h"
 
+using namespace onnxruntime::cuda;
+
 namespace onnxruntime {
 namespace contrib {
 namespace cuda {
@@ -35,13 +37,9 @@ __global__ void biasSplitGeluKernel(T const* input, T const* bias, T* output) {
 
 #pragma unroll
   for (int32_t i = 0; i < HHS / TPB; ++i) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
     auto value_left = (float)(input[index_input] + bias[index_bias]);
     auto value_right = (float)(input[index_input + HHS] + bias[index_bias + HHS]);
-#else
-    auto value_left = (float)(input[index_input]) + (float)(bias[index_bias]);
-    auto value_right = (float)(input[index_input + HHS]) + (float)(bias[index_bias + HHS]);
-#endif
+
     // Gelu is applied to right side only: Gelu(x) = x * 0.5 * (erf(x / sqrt(2)) + 1.0)
     float gelu_right = value_right * 0.5f * (erff(value_right / 1.41421356237f) + 1.0f);
     float result = value_left * gelu_right;
diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py
index 1b057c0d4deb5..a7904c39f8491 100644
--- a/onnxruntime/python/tools/transformers/float16.py
+++ b/onnxruntime/python/tools/transformers/float16.py
@@ -202,8 +202,10 @@ def convert_float_to_float16(
     queue = []
     value_info_list = []
     node_list = []
-    # List of Resize or GroupNorm nodes that is not in block list
-    resize_node_list = []
+    # Some operators (Like Resize or GroupNorm) have data type fixed as float for some input.
+    # When it is converted to float16, there are mixed types: some inputs are float32 and some are float16.
+    # This list keeps track of such nodes that are not in block list.
+    mixed_float_type_node_list = []
 
     # type inference on input model
     if func_infer_shape is not None:
@@ -303,7 +305,7 @@ def convert_float_to_float16(
                             for attr in n.attribute:
                                 next_level.append(attr)
                         else:
-                            resize_node_list.append(n)
+                            mixed_float_type_node_list.append(n)
 
             # if q is model.graph.node.attribute, push q.g and q.graphs (GraphProto)
             # and process node.attribute.t and node.attribute.tensors (TensorProto)
@@ -344,8 +346,8 @@ def convert_float_to_float16(
                     )
                 )
 
-    # Some operators has data type fixed as float for some input. Add a float16 to float cast for those inputs.
-    for node in resize_node_list:
+    # Some operators have data type fixed as float for some input. Add a float16 to float cast for those inputs.
+    for node in mixed_float_type_node_list:
         for i, input_name in enumerate(node.input):
             if i not in ALWAYS_FLOAT_INPUTS[node.op_type]:
                 continue
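
Note for reviewers: the float16.py hunks above track nodes whose inputs must stay float32 and later insert a float16-to-float Cast in front of those inputs (the ones listed in ALWAYS_FLOAT_INPUTS). The snippet below is a minimal, self-contained sketch of that idea, not the code in float16.py; the helper name insert_cast_for_float_input and the "_as_float32" suffix are hypothetical.

# Sketch only: keep one input of a node in float32 after the rest of the
# graph has been converted to float16, by casting the float16 tensor back
# to float32 right before the consuming node.
from onnx import TensorProto, helper


def insert_cast_for_float_input(graph, node, input_index):
    original_input = node.input[input_index]
    cast_output = original_input + "_as_float32"  # hypothetical naming scheme
    cast_node = helper.make_node(
        "Cast",
        inputs=[original_input],
        outputs=[cast_output],
        to=TensorProto.FLOAT,
        name=node.name + "_input_cast_" + str(input_index),
    )
    # Insert the Cast right before its consumer so the graph stays topologically sorted.
    position = list(graph.node).index(node)
    graph.node.insert(position, cast_node)
    # Rewire the consumer to read the casted (float32) tensor.
    node.input[input_index] = cast_output

The real code additionally consults value_info_list to confirm the input was actually converted to float16 before adding the Cast; the sketch skips that check for brevity.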