From 7f3624d09d64b3d0ddf589f61e535a015582c8c4 Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Thu, 9 May 2024 00:08:33 +0000 Subject: [PATCH] Check whether the bias tensor is `nullptr` before accessing the type. Prevents program crash due to null pointer dereference. --- .../lite/micro/kernels/transpose_conv.cc | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/tensorflow/lite/micro/kernels/transpose_conv.cc b/tensorflow/lite/micro/kernels/transpose_conv.cc index cd616602c22..ea0efae0607 100644 --- a/tensorflow/lite/micro/kernels/transpose_conv.cc +++ b/tensorflow/lite/micro/kernels/transpose_conv.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/transpose_conv.h" +#include +#include + #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/common.h" @@ -48,8 +51,9 @@ struct OpData { // A scratch buffer is required for quantized implementations. int scratch_buffer_index; - // TODO(b/192090531): Remove this once all 8x16 transpose conv models use - // 64-bit biases. + // Index to the converted 64-bit bias buffer from 16-bit bias. This is + // required to handle 16x8 transpose convolutions where a 16-bit bias is + // provided, whereas the kernel expects 64-bit biases. int bias_converted_buffer_index; // Multiplier and shift arrays are required for the int8 implementation. @@ -123,7 +127,9 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, if (input->type == kTfLiteInt16) { TFLITE_DCHECK(filter->type == kTfLiteInt8); TFLITE_DCHECK(output->type == kTfLiteInt16); - if (bias->type == kTfLiteInt16) { + // Handle the case where the bias is 16 bits for 16x8 transpose + // convolution where the kernel actually expects 64-bit biases. + if (bias != nullptr && bias->type == kTfLiteInt16) { TFLITE_DCHECK( context->RequestScratchBufferInArena( context, GetTensorShape(bias).FlatSize() * sizeof(std::int64_t), @@ -299,12 +305,10 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) { break; } case kTfLiteInt16: { - std::int64_t* scratch_buffer = static_cast( + auto* scratch_buffer = static_cast( context->GetScratchBuffer(context, data.scratch_buffer_index)); - // TODO(b/192090531): Remove this once all 8x16 transpose conv models use - // 64-bit biases. if (bias != nullptr && bias->type == kTfLiteInt16) { - std::int64_t* bias_converted_buffer = + auto* bias_converted_buffer = static_cast(context->GetScratchBuffer( context, data.bias_converted_buffer_index)); for (int i = 0; i < tflite::micro::GetTensorShape(bias).FlatSize();