Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check whether the bias tensor is nullptr before accessing the type. #2566

Merged
merged 2 commits into from
May 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions tensorflow/lite/micro/kernels/transpose_conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ limitations under the License.

#include "tensorflow/lite/kernels/internal/reference/transpose_conv.h"

#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
Expand Down Expand Up @@ -48,8 +51,9 @@ struct OpData {
// A scratch buffer is required for quantized implementations.
int scratch_buffer_index;

// TODO(b/192090531): Remove this once all 8x16 transpose conv models use
// 64-bit biases.
// Index to the converted 64-bit bias buffer from 16-bit bias. This is
// required to handle 16x8 transpose convolutions where a 16-bit bias is
// provided, whereas the kernel expects 64-bit biases.
int bias_converted_buffer_index;

// Multiplier and shift arrays are required for the int8 implementation.
Expand Down Expand Up @@ -123,7 +127,9 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
if (input->type == kTfLiteInt16) {
TFLITE_DCHECK(filter->type == kTfLiteInt8);
TFLITE_DCHECK(output->type == kTfLiteInt16);
if (bias->type == kTfLiteInt16) {
// Handle the case where the bias is 16 bits for 16x8 transpose
// convolution where the kernel actually expects 64-bit biases.
if (bias != nullptr && bias->type == kTfLiteInt16) {
TFLITE_DCHECK(
context->RequestScratchBufferInArena(
context, GetTensorShape(bias).FlatSize() * sizeof(std::int64_t),
Expand Down Expand Up @@ -299,12 +305,10 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) {
break;
}
case kTfLiteInt16: {
std::int64_t* scratch_buffer = static_cast<int64_t*>(
auto* scratch_buffer = static_cast<int64_t*>(
context->GetScratchBuffer(context, data.scratch_buffer_index));
// TODO(b/192090531): Remove this once all 8x16 transpose conv models use
// 64-bit biases.
if (bias != nullptr && bias->type == kTfLiteInt16) {
std::int64_t* bias_converted_buffer =
auto* bias_converted_buffer =
static_cast<int64_t*>(context->GetScratchBuffer(
context, data.bias_converted_buffer_index));
for (int i = 0; i < tflite::micro::GetTensorShape(bias).FlatSize();
Expand Down
Loading