Skip to content

Commit

Permalink
Sync from upstream TF.
Browse files Browse the repository at this point in the history
  • Loading branch information
TFLM-bot committed Nov 2, 2023
1 parent 6d337dc commit 91e4d7f
Show file tree
Hide file tree
Showing 6 changed files with 316 additions and 162 deletions.
98 changes: 97 additions & 1 deletion tensorflow/lite/core/api/flatbuffer_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/lite/core/api/flatbuffer_conversions.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
Expand Down Expand Up @@ -881,6 +882,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_STABLEHLO_GATHER: {
return ParseStablehloGather(op, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_STABLEHLO_REDUCE_WINDOW: {
return ParseStablehloReduceWindow(op, error_reporter, allocator,
builtin_data);
}
case BuiltinOperator_REDUCE_WINDOW: {
auto params = safe_allocator.Allocate<TfLiteReduceWindowParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
Expand Down Expand Up @@ -949,7 +954,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_STABLEHLO_CONVERT:
case BuiltinOperator_STABLEHLO_PAD:
case BuiltinOperator_STABLEHLO_DOT_GENERAL:
case BuiltinOperator_STABLEHLO_REDUCE_WINDOW:
case BuiltinOperator_STABLEHLO_SORT:
case BuiltinOperator_STABLEHLO_WHILE:
case BuiltinOperator_STABLEHLO_TRANSPOSE:
Expand Down Expand Up @@ -2096,6 +2100,98 @@ TfLiteStatus ParseResizeNearestNeighbor(const Operator* op,
return kTfLiteOk;
}

TfLiteStatus ParseStablehloReduceWindow(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data) {
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);

SafeBuiltinDataAllocator safe_allocator(allocator);
auto params = safe_allocator.Allocate<TfLiteStablehloReduceWindowParams>();

const StablehloReduceWindowOptions* schema_params =
op->builtin_options_2_as_StablehloReduceWindowOptions();
if (schema_params) {
if (!schema_params->window_dimensions() ||
schema_params->window_dimensions()->size() == 0) {
TF_LITE_REPORT_ERROR(error_reporter,
"'window_dimensions' attribute is not optional for "
"'stablehlo.reduce_window' and cannot be empty.");
return kTfLiteError;
}

const size_t rank = schema_params->window_dimensions()->size();

auto LoadAttr = [&error_reporter](
auto& params_array, auto* const flatbuffer_vector,
const char* attr_name, const size_t expected_size,
const int64_t fill_value) -> TfLiteStatus {
if (flatbuffer_vector && flatbuffer_vector->size()) {
if (expected_size != 0 && flatbuffer_vector->size() != expected_size) {
TF_LITE_REPORT_ERROR(
error_reporter,
"'%s' attribute of 'stablehlo.reduce_window' does not have the "
"expected size (%llu != %llu).",
attr_name, flatbuffer_vector->size(), expected_size);
return kTfLiteError;
}
TfLiteStatus status = FlatBufferIntVectorToArray(
sizeof(params_array), flatbuffer_vector, params_array,
error_reporter, "stablehlo.reduce_window");
if (status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Check the '%s' attribute.",
attr_name);
return status;
}
} else {
std::fill_n(params_array,
TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT,
fill_value);
}
return kTfLiteOk;
};

if (TfLiteStatus status = LoadAttr(
params->window_dimensions, schema_params->window_dimensions(),
"window_dimensions", /*expected_size=*/rank, /*fill_value=*/1);
status != kTfLiteOk) {
return status;
}
if (TfLiteStatus status = LoadAttr(
params->window_strides, schema_params->window_strides(),
"window_strides", /*expected_size=*/rank, /*fill_value=*/1);
status != kTfLiteOk) {
return status;
}
if (TfLiteStatus status = LoadAttr(
params->base_dilations, schema_params->base_dilations(),
"base_dilations", /*expected_size=*/rank, /*fill_value=*/1);
status != kTfLiteOk) {
return status;
}
if (TfLiteStatus status = LoadAttr(
params->window_dilations, schema_params->window_dilations(),
"window_dilations", /*expected_size=*/rank, /*fill_value=*/1);
status != kTfLiteOk) {
return status;
}
if (TfLiteStatus status =
LoadAttr(params->padding, schema_params->padding(), "padding",
/*expected_size=*/2 * rank, /*fill_value=*/0);
status != kTfLiteOk) {
return status;
}

params->body_subgraph_index = schema_params->body_subgraph_index();
*builtin_data = params.release();
return kTfLiteOk;
}
TF_LITE_REPORT_ERROR(
error_reporter,
"Could not get 'stablehlo.reduce_window' operation parameters.");
return kTfLiteError;
}

TfLiteStatus ParseStablehloScatter(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
Expand Down
5 changes: 5 additions & 0 deletions tensorflow/lite/core/api/flatbuffer_conversions.h
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,11 @@ TfLiteStatus ParseStablehloGather(const Operator* op,
BuiltinDataAllocator* allocator,
void** builtin_data);

TfLiteStatus ParseStablehloReduceWindow(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);

} // namespace tflite

#endif // TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
17 changes: 17 additions & 0 deletions tensorflow/lite/core/c/builtin_op_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ extern "C" {
#define TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT 8
#define TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT 8
#define TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT 8
#define TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT 8

// TODO(aselle): Consider using "if this then that" for testing.

Expand Down Expand Up @@ -605,6 +606,22 @@ typedef struct {
bool indices_are_sorted;
} TfLiteStablehloGatherParams;

typedef struct {
// See the stablehlo spec for the explanation of the attributes:
// https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_window
int64_t window_dimensions
[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT];
int64_t
window_strides[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT];
int64_t
base_dilations[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT];
int64_t window_dilations
[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT];
int64_t
padding[2 * TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT];
int body_subgraph_index;
} TfLiteStablehloReduceWindowParams;

enum TfLiteReduceWindowFunction {
TfLiteReduceWindowFunctionUnsupported,
TfLiteReduceWindowFunctionAdd,
Expand Down
15 changes: 10 additions & 5 deletions tensorflow/lite/core/c/c_api_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,13 @@ limitations under the License.
extern "C" {
#endif

/** \addtogroup c_api_types tensorflow/lite/c/c_api_types.h
// clang-format off
// NOLINTBEGIN(whitespace/line_length)
/** \defgroup c_api_types tensorflow/lite/c/c_api_types.h
* @{
*/
// NOLINTEND(whitespace/line_length)
// clang-format on

// Define TFL_CAPI_EXPORT macro to export a function properly with a shared
// library.
Expand Down Expand Up @@ -123,12 +127,11 @@ typedef enum {
kTfLiteInt4 = 18,
} TfLiteType;

/// Legacy. Will be deprecated in favor of TfLiteAffineQuantization.
/// Legacy. Will be deprecated in favor of `TfLiteAffineQuantization`.
/// If per-layer quantization is specified this field will still be populated in
/// addition to TfLiteAffineQuantization.
/// addition to `TfLiteAffineQuantization`.
/// Parameters for asymmetric quantization. Quantized values can be converted
/// back to float using:
/// real_value = scale * (quantized_value - zero_point)
/// back to float using: `real_value = scale * (quantized_value - zero_point)`
typedef struct TfLiteQuantizationParams {
float scale;
int32_t zero_point;
Expand Down Expand Up @@ -156,13 +159,15 @@ typedef struct TfLiteDelegate TfLiteDelegate;
/// This is an abstract type that is intended to have the same
/// role as TfLiteDelegate, but without exposing the implementation
/// details of how delegates are implemented.
///
/// WARNING: This is an experimental type and subject to change.
typedef struct TfLiteOpaqueDelegateStruct TfLiteOpaqueDelegateStruct;

/// TfLiteOpaqueDelegate: conditionally opaque version of
/// TfLiteDelegate; allows delegation of nodes to alternative backends.
/// For TF Lite in Play Services, this is an opaque type,
/// but for regular TF Lite, this is just a typedef for TfLiteDelegate.
///
/// WARNING: This is an experimental type and subject to change.
#if TFLITE_WITH_STABLE_ABI || TFLITE_USE_OPAQUE_DELEGATE
typedef TfLiteOpaqueDelegateStruct TfLiteOpaqueDelegate;
Expand Down
Loading

0 comments on commit 91e4d7f

Please sign in to comment.