Commit 18aa7fd
address review feedback
tianleiwu committed Feb 14, 2023
1 parent 80a96b5 commit 18aa7fd
Showing 4 changed files with 13 additions and 15 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc
@@ -32,7 +32,7 @@ BiasAdd<T>::BiasAdd(const OpKernelInfo& op_info) : CudaKernel(op_info) {
template <typename T>
Status BiasAdd<T>::ComputeInternal(OpKernelContext* context) const {
// Input: [batch_size, height*width, channels]
// Bias: [hidden_size]
// Bias: [channels]
// Skip: [batch_size, height*width, channels]
// Output: [batch_size, height*width, channels]

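For reference, the shape contract in the comment above can be checked with a small NumPy sketch (my own illustration, not part of this commit; the dimension values are placeholders):

```python
import numpy as np

# Placeholder dimensions; the op itself works for any batch_size, height*width, channels.
batch_size, hw, channels = 2, 64, 320

x = np.random.rand(batch_size, hw, channels).astype(np.float32)     # Input:  [batch_size, height*width, channels]
bias = np.random.rand(channels).astype(np.float32)                  # Bias:   [channels]
skip = np.random.rand(batch_size, hw, channels).astype(np.float32)  # Skip:   [batch_size, height*width, channels]

# BiasAdd as described by the comment: bias is broadcast over the last axis.
output = x + bias + skip
assert output.shape == (batch_size, hw, channels)                   # Output: [batch_size, height*width, channels]
```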
6 changes: 2 additions & 4 deletions onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu
@@ -23,6 +23,8 @@
#include "core/providers/cuda/cu_inc/common.cuh"
#include "contrib_ops/cuda/diffusion/bias_add_impl.h"

using namespace onnxruntime::cuda;

namespace onnxruntime {
namespace contrib {
namespace cuda {
@@ -34,11 +36,7 @@ __global__ void BiasAddKernel(T const* input, T const* bias, T const* residual,

#pragma unroll
for (int32_t i = 0; i < C / TPB; ++i) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
output[base_offset] = input[base_offset] + bias[bias_offset] + residual[base_offset];
#else
output[base_offset] = static_cast<T>(float(input[base_offset]) + float(bias[bias_offset]) + float(residual[base_offset]));
#endif
base_offset += TPB;
bias_offset += TPB;
}
8 changes: 3 additions & 5 deletions onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu
@@ -23,6 +23,8 @@
#include "core/providers/cuda/cu_inc/common.cuh"
#include "contrib_ops/cuda/diffusion/bias_split_gelu_impl.h"

using namespace onnxruntime::cuda;

namespace onnxruntime {
namespace contrib {
namespace cuda {
@@ -35,13 +37,9 @@ __global__ void biasSplitGeluKernel(T const* input, T const* bias, T* output) {

#pragma unroll
for (int32_t i = 0; i < HHS / TPB; ++i) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
auto value_left = (float)(input[index_input] + bias[index_bias]);
auto value_right = (float)(input[index_input + HHS] + bias[index_bias + HHS]);
#else
auto value_left = (float)(input[index_input]) + (float)(bias[index_bias]);
auto value_right = (float)(input[index_input + HHS]) + (float)(bias[index_bias + HHS]);
#endif

// Gelu is applied to right side only: Gelu(x) = x * 0.5 * (erf(x / sqrt(2)) + 1.0)
float gelu_right = value_right * 0.5f * (erff(value_right / 1.41421356237f) + 1.0f);
float result = value_left * gelu_right;
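A NumPy sketch of the computation this kernel implements (my own illustration, not code from the commit; `scipy` is assumed only for `erf`): the biased tensor is split into two halves of width HHS along the last axis, GELU is applied to the right half, and the two halves are multiplied elementwise.

```python
import numpy as np
from scipy.special import erf  # stands in for the kernel's erff()

def bias_split_gelu_reference(x, bias):
    """x: [batch, seq, 2*HHS], bias: [2*HHS] -> output: [batch, seq, HHS]."""
    y = x + bias                             # bias broadcast over the last axis
    left, right = np.split(y, 2, axis=-1)    # two halves of width HHS
    # Gelu(x) = x * 0.5 * (erf(x / sqrt(2)) + 1.0), applied to the right half only
    gelu_right = right * 0.5 * (erf(right / np.sqrt(2.0)) + 1.0)
    return left * gelu_right
```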
12 changes: 7 additions & 5 deletions onnxruntime/python/tools/transformers/float16.py
@@ -202,8 +202,10 @@ def convert_float_to_float16(
value_info_list = []
node_list = []

# List of Resize or GroupNorm nodes that is not in block list
resize_node_list = []
# Some operators (Like Resize or GroupNorm) have data type fixed as float for some input.
# When it is converted to float16, there are mixed types: some inputs are float32 and some are float16.
# This list keeps track of such nodes that are not in block list.
mixed_float_type_node_list = []

# type inference on input model
if func_infer_shape is not None:
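To make the comment above concrete (a hedged sketch of mine, not part of the diff): after conversion, a node such as Resize can have a float16 data input while other inputs must stay float32 per the operator spec, leaving it with mixed input types. A minimal way to collect such nodes might look like the following, with a made-up table standing in for the file's ALWAYS_FLOAT_INPUTS:

```python
# Illustrative table only; the real mapping lives in ALWAYS_FLOAT_INPUTS in float16.py.
EXAMPLE_ALWAYS_FLOAT_INPUTS = {"Resize": [1, 2], "GroupNorm": [1, 2]}

def collect_mixed_float_type_nodes(graph, node_block_list):
    """Sketch: gather nodes that keep some float32 inputs after fp16 conversion."""
    return [
        node
        for node in graph.node
        if node.op_type in EXAMPLE_ALWAYS_FLOAT_INPUTS and node.name not in node_block_list
    ]
```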
@@ -303,7 +305,7 @@ def convert_float_to_float16(
for attr in n.attribute:
next_level.append(attr)
else:
resize_node_list.append(n)
mixed_float_type_node_list.append(n)

# if q is model.graph.node.attribute, push q.g and q.graphs (GraphProto)
# and process node.attribute.t and node.attribute.tensors (TensorProto)
@@ -344,8 +346,8 @@ def convert_float_to_float16(
)
)

# Some operators has data type fixed as float for some input. Add a float16 to float cast for those inputs.
for node in resize_node_list:
# Some operators have data type fixed as float for some input. Add a float16 to float cast for those inputs.
for node in mixed_float_type_node_list:
for i, input_name in enumerate(node.input):
if i not in ALWAYS_FLOAT_INPUTS[node.op_type]:
continue
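As a sketch of what "add a float16 to float cast for those inputs" can look like in ONNX terms (my own illustration with hypothetical names, not the code in this file): a Cast node converting the now-float16 value back to float32 is spliced in front of each input index that must stay float.

```python
from onnx import TensorProto, helper

def insert_cast_back_to_float(graph, node, input_index):
    """Sketch: splice Cast(float16 -> float32) in front of one input that must stay float."""
    original_name = node.input[input_index]
    cast_output = original_name + "_as_float32"        # hypothetical naming scheme
    cast_node = helper.make_node(
        "Cast",
        inputs=[original_name],
        outputs=[cast_output],
        name=original_name + "_cast_to_float32",
        to=TensorProto.FLOAT,                          # cast back to float32
    )
    graph.node.extend([cast_node])                     # real code must also keep topological order
    node.input[input_index] = cast_output              # rewire the consumer to the cast output
```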
