diff --git a/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt b/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt index 46828f14ebba3..4371cd3db965e 100644 --- a/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt +++ b/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt @@ -27,6 +27,7 @@ iree_cc_library( iree::vm iree::vm::bytecode::module iree_pjrt_deps::headers + iree_pjrt_deps::protos PUBLIC ) diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc index a577cb0b82ad3..68b45ecbbb07f 100644 --- a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc +++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc @@ -6,6 +6,7 @@ #include "iree_pjrt/common/api_impl.h" +#include #include #include #include @@ -1002,7 +1003,7 @@ iree_status_t DeviceInstance::TransposeBroadcastDeviceBuffer( // Compile program and check for errors: LoadedExecutableInstance* executable; - auto* error = this->client().Compile(&program, &executable); + auto* error = this->client().Compile(&program, {}, &executable); if (error) { auto errinst = ErrorInstance::FromError(error); auto ret = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, @@ -1351,22 +1352,14 @@ void ClientInstance::BindApi(PJRT_Api* api) { LoadedExecutableInstance* executable; // Read compilation options. - // TODO: Port CompileOptionsProto into the project or leave ommitted. - // xla::CompileOptionsProto options_proto; - // if (!options_proto.ParseFromArray(args->compile_options, - // args->compile_options_size)) { - // return MakeError(iree_make_status(IREE_STATUS_INTERNAL, - // "could not parse compilation - // options")); - // } - // auto options = xla::CompileOptions::FromProto(options_proto); - // if (!options.ok()) { - // return MakeError( - // iree_make_status(IREE_STATUS_INTERNAL, - // std::string(options.status().message()).c_str())); - // } - - auto* error = client->Compile(args->program, /**options,*/ &executable); + xla::CompileOptionsProto options_proto; + if (!options_proto.ParseFromArray(args->compile_options, + args->compile_options_size)) { + return MakeError(iree_make_status(IREE_STATUS_INTERNAL, + "could not parse compilation options")); + } + + auto* error = client->Compile(args->program, options_proto, &executable); if (error) return error; args->executable = *executable; return nullptr; @@ -1451,7 +1444,7 @@ iree_status_t ClientInstance::PopulateDevices() { } PJRT_Error* ClientInstance::Compile(const PJRT_Program* program, - /*xla::CompileOptions options,*/ + xla::CompileOptionsProto options, LoadedExecutableInstance** out_executable) { std::unique_ptr artifact_tx; if (platform().artifact_dumper().enabled()) { @@ -1570,11 +1563,28 @@ PJRT_Error* ClientInstance::Compile(const PJRT_Program* program, output->GetDataSize())); } + // Calculate the devices for this computation from the device assignment, + // skipping any replica device ids outside the addressable range. + std::vector<DeviceInstance*> devices; + + const auto& build_options = options.executable_build_options(); + if (build_options.has_device_assignment()) { + const auto& device_assignment = build_options.device_assignment(); + for (auto id : + device_assignment.computation_devices(0).replica_device_ids()) { + if (id < addressable_devices_.size()) + devices.push_back(addressable_devices_[id]); + } + } + + if (devices.empty()) { + devices = addressable_devices_; + } + auto executable = std::make_unique<LoadedExecutableInstance>( *this, new ExecutableImage(std::move(output), std::string(program->code, program->code_size)), - addressable_devices_); + devices); status = executable->LoadAll(); if (!iree_status_is_ok(status)) {
return MakeError(status); diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.h b/integrations/pjrt/src/iree_pjrt/common/api_impl.h index c6debcae8bf76..98912de73544a 100644 --- a/integrations/pjrt/src/iree_pjrt/common/api_impl.h +++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.h @@ -25,6 +25,7 @@ #include "iree_pjrt/common/layout_utils.h" #include "iree_pjrt/common/platform.h" #include "xla/pjrt/c/pjrt_c_api.h" +#include "compile_options.pb.h" namespace iree::pjrt { @@ -452,7 +453,7 @@ class ClientInstance { // Compiles. // See TODOs in PJRT_Client_Compile. PJRT_Error* Compile( - const PJRT_Program* program, /*xla::CompileOptions options, */ + const PJRT_Program* program, xla::CompileOptionsProto options, LoadedExecutableInstance** executable); // --------------------------------------------------------------------------- diff --git a/integrations/pjrt/src/iree_pjrt/vulkan/client.cc b/integrations/pjrt/src/iree_pjrt/vulkan/client.cc index 853ead814b247..228cbe9eca8cc 100644 --- a/integrations/pjrt/src/iree_pjrt/vulkan/client.cc +++ b/integrations/pjrt/src/iree_pjrt/vulkan/client.cc @@ -32,7 +32,7 @@ iree_status_t VulkanClientInstance::CreateDriver( } bool VulkanClientInstance::SetDefaultCompilerFlags(CompilerJob* compiler_job) { - return compiler_job->SetFlag("--iree-hal-target-backends=vulkan"); + return compiler_job->SetFlag("--iree-hal-target-backends=vulkan-spirv"); } } // namespace iree::pjrt::vulkan diff --git a/integrations/pjrt/third_party/pjrt_c_api/CMakeLists.txt b/integrations/pjrt/third_party/pjrt_c_api/CMakeLists.txt index 52e7fae256ec8..b5bd73baae28f 100644 --- a/integrations/pjrt/third_party/pjrt_c_api/CMakeLists.txt +++ b/integrations/pjrt/third_party/pjrt_c_api/CMakeLists.txt @@ -26,3 +26,26 @@ iree_cc_library( "xla/pjrt/c/pjrt_c_api.h" PUBLIC ) + +find_package(Protobuf REQUIRED) + +protobuf_generate_cpp( + PJRT_PROTO_SRCS + PJRT_PROTO_HDRS + "xla/pjrt/compile_options.proto" + "xla/xla_data.proto" +) + +iree_cc_library( + NAME + protos + INCLUDES + ${Protobuf_INCLUDE_DIRS} + HDRS + ${PJRT_PROTO_HDRS} + SRCS + ${PJRT_PROTO_SRCS} + DEPS + ${Protobuf_LIBRARIES} + PUBLIC +) diff --git a/integrations/pjrt/third_party/pjrt_c_api/README.md b/integrations/pjrt/third_party/pjrt_c_api/README.md index 404428d63634e..c55093a30aa22 100644 --- a/integrations/pjrt/third_party/pjrt_c_api/README.md +++ b/integrations/pjrt/third_party/pjrt_c_api/README.md @@ -1,6 +1,7 @@ # pjrt_c_api -This directory contains a fork of C headers needed to build a PJRT plugin. +This directory contains a fork of C headers and .proto files +needed to build a PJRT plugin. It is intended to be sync'd with upstream for major/breaking changes and releases. diff --git a/integrations/pjrt/third_party/pjrt_c_api/xla/pjrt/compile_options.proto b/integrations/pjrt/third_party/pjrt_c_api/xla/pjrt/compile_options.proto new file mode 100644 index 0000000000000..4cdd2bed25fa3 --- /dev/null +++ b/integrations/pjrt/third_party/pjrt_c_api/xla/pjrt/compile_options.proto @@ -0,0 +1,165 @@ +syntax = "proto3"; + +package xla; + +// TODO: To avoid introducing too many XLA source files into IREE PJRT, +// we currently remove some fields in this file that are not in use, +// so that their message definitions are not required. +// If we want to uncomment these removed fields, we must also +// add the corresponding schema files, as below. + +// import "xla/stream_executor/device_description.proto"; +// import "xla/xla.proto"; import "xla_data.proto"; + +// A serialization of xla::ExecutableBuildOptions.
+// Next id: 24. +message ExecutableBuildOptionsProto { + // If set, this is the device to build the computation for. Valid + // device_ordinal values are: 0 to # of devices - 1. These values are + // identical to the device ordinal values used by StreamExecutor. The built + // executable will be executable on any device equivalent to the specified + // device as determined by Backend::devices_equivalent(). A value of -1 + // indicates this option has not been set. + int64 device_ordinal = 1; + + // If set, this specifies the layout of the result of the computation. If not + // set, the service will choose the layout of the result. A Shape is used to + // store the layout to accommodate tuple result shapes. A value of nullptr + // indicates the option has not been set. + xla.ShapeProto result_layout = 2; + + // Expose access to the XLA compilation environments, which will be passed to + // the compilation process. + // xla.CompilationEnvironmentsProto comp_envs = 13; + + // Expose access to the XLA debug options which will be passed to the + // compilation process. + // xla.DebugOptions debug_options = 3; + + // The number of replicas of this computation that are to be executed. + // Defaults to 1. + int64 num_replicas = 4; + + // The number of partitions in this computation. Defaults to 1. + int64 num_partitions = 5; + + // Indicates whether to use SPMD (true) or MPMD (false) partitioning when + // num_partitions > 1 and XLA is requested to partition the input program. + bool use_spmd_partitioning = 6; + + // Whether to automatically generate XLA shardings for SPMD partitioner. + bool use_auto_spmd_partitioning = 7; + + // The amount of effort to spend on optimizing for minimizing program + // execution time, as a value in [-1.0, +1.0]. The baseline is 0.0, which + // strongly prioritizes execution time at the cost of longer compile times, + // suitable for production workloads. A value of -0.5 would be appropriate for + // research use cases that prefer faster compilations to iterate more quickly. + // Positive values, on the other hand, might enable costly optimizations that + // are off by default. + float exec_time_optimization_effort = 20; + + // The amount of effort to spend on making the program fit in memory (where + // "fit in memory" here has a backend-dependent meaning), as a value in + // [-1.0,+1.0]. The baseline is 0.0, which expends significant effort on + // attempting to make the program fit. A value of -1.0 would be appropriate + // for use cases that wish to spend minimal effort here and fail as quickly as + // possible instead. Positive values, on the other hand, might enable costly + // algorithms to reduce memory usage that are off by default. + float memory_fitting_effort = 21; + + // Whether HLOs should be deduplicated. + bool deduplicate_hlo = 8; + + // If set, this specifies a static device assignment for the computation. + // Otherwise, the computation will be compiled generically and can be run with + // any device assignment compatible with the computation's replica and + // partition counts. + xla.DeviceAssignmentProto device_assignment = 9; + + // Whether input and output buffers are aliased if the associated parameter is + // passed through XLA modules without being changed. + bool alias_passthrough_params = 10; + + // By default, XLA builds an executable by invoking standard compilation, i.e. + // running Compiler::Compile, or both Compiler::RunHloPasses and + // Compiler::RunBackend.
When run_backend_only is set to true, XLA builds an + // executable by invoking only RunBackend and skips invoking RunHloPasses, + // which can be used to compile post-optimization HLO modules. + bool run_backend_only = 11; + + // Allows sharding propagation to propagate to the parameters. This changes + // the input shape of the computation (which is undesirable), but it can be + // used to run a partial compilation to determine what the input sharding of a + // computation would be if XLA were allowed to propagate the sharding, which + // can be used by a higher-level framework as a way to query the intermediate + // sharding of operations when multiple computations are chained and merged + // together. + // This is a vector of bool, because the user can control which parameters can + // have the sharding substituted. If only one boolean value is passed in the + // vector, it is interpreted as the value to be applied for every parameter. + repeated bool allow_spmd_sharding_propagation_to_parameters = 18; + + // Allows sharding propagation to propagate to the outputs. This changes the + // output shape of the computation (which is undesirable), but it can be used + // to run a partial compilation to determine what the output sharding of a + // computation would be if XLA were allowed to propagate the sharding, which + // can be used by a higher-level framework as a way to query the intermediate + // sharding of operations when multiple computations are chained and merged + // together. + // This is a vector of bool, because the user can control (if the output of + // the computation is a tuple) which elements of the tuple can have the + // sharding substituted and which don't. If only one boolean value is passed + // in the vector, it is interpreted as the value to be applied for every + // single element of the output tuple. One value per element of the tuple + // means that each value is attached to one of the output elements. + repeated bool allow_spmd_sharding_propagation_to_output = 12; + + // Opaque profile data for any feedback directed optimizations. + bytes fdo_profile = 14; + + int64 device_memory_size = 15; + + // Mesh shape in auto sharding options. + repeated int64 auto_spmd_partitioning_mesh_shape = 16; + + // Mesh ids in auto sharding options. + repeated int64 auto_spmd_partitioning_mesh_ids = 17; + + // Use Shardy, a new partitioner, to replace the existing + // ShardingPropagation and SpmdPartitioner. See go/xla-sdy-pipeline for + // details. + bool use_shardy_partitioner = 19; + + int64 process_index = 22; + int64 process_count = 23; +} + +message OptionOverrideProto { + oneof value { + string string_field = 1; + bool bool_field = 2; + int64 int_field = 3; + double double_field = 4; + } +} + +message CompileOptionsProto { + // Refer to CompileOptions for documentation of fields. + // repeated ShapeProto argument_layouts = 1; + bool parameter_is_tupled_arguments = 2; + ExecutableBuildOptionsProto executable_build_options = 3; + bool compile_portable_executable = 4; + int64 profile_version = 5; + bytes serialized_multi_slice_config = 6; + map<string, OptionOverrideProto> env_option_overrides = 7; + + // stream_executor.GpuTargetConfigProto target_config = 8; +} + +// Helper for serializing opaque executables alongside CompileOptions.
+message ExecutableAndOptionsProto { + bytes serialized_executable = 1; + CompileOptionsProto compile_options = 2; +} diff --git a/integrations/pjrt/third_party/pjrt_c_api/xla/xla_data.proto b/integrations/pjrt/third_party/pjrt_c_api/xla/xla_data.proto new file mode 100644 index 0000000000000..7d9563b11ab79 --- /dev/null +++ b/integrations/pjrt/third_party/pjrt_c_api/xla/xla_data.proto @@ -0,0 +1,1153 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla; + +option cc_enable_arenas = true; + +// Primitive types are the individual values that can be held in rectangular +// multidimensional arrays. A description of the rectangular multidimensional +// array dimensions / primitive type is given by Shape, below. +// +// LINT.IfChange +enum PrimitiveType { + // Invalid primitive type to serve as default. + PRIMITIVE_TYPE_INVALID = 0; + + // Predicates are two-state booleans. + PRED = 1; + + // Signed integral values of fixed width. + S2 = 26; + S4 = 21; + S8 = 2; + S16 = 3; + S32 = 4; + S64 = 5; + + // Unsigned integral values of fixed width. + U2 = 27; + U4 = 22; + U8 = 6; + U16 = 7; + U32 = 8; + U64 = 9; + + // Floating-point values of fixed width. + // + // Note: if f16s are not natively supported on the device, they will be + // converted to f16 from f32 at arbitrary points in the computation. + F16 = 10; + F32 = 11; + + // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit + // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent + // and 7 bits for the mantissa. + BF16 = 16; + + F64 = 12; + + // FP8 dtypes, as described in this paper: https://arxiv.org/abs/2209.05433 + // + // F8E5M2 has 5 exponent bits and 2 mantissa bits, and is similar to the + // existing IEEE types. + // + // F8E4M3 has 4 exponent bits and 3 mantissa bits, and is similar to the + // existing IEEE types. + // + // F8E4M3FN has 4 exponent bits and 3 mantissa bits. The "FN" means only + // Finite and NaN values are supported. Unlike IEEE types, infinities are not + // supported. NaN is represented when the exponent and mantissa bits are all + // 1s. All other values are finite. + // + // F8E4M3B11FNUZ has 4 exponent bits and 3 mantissa bits and a bias of 11. The + // "FNUZ" means only Finite and NaN values are supported; zero is unsigned. + // Unlike IEEE types, infinities are not supported. NaN is represented when + // the exponent and mantissa bits are all 0s with a sign bit of 1. All other + // values are finite. + // + // F8E3M4 has 3 exponent bits and 4 mantissa bits, and is similar to the + // existing IEEE types. + // + // Support for these dtypes is under development. They do not yet work + // properly in most cases. + // TODO(b/259609697): Fully support FP8.
+ F8E5M2 = 19; + F8E4M3 = 28; + F8E4M3FN = 20; + F8E4M3B11FNUZ = 23; + F8E3M4 = 29; + + // FP8 dtypes, as described in this paper: https://arxiv.org/abs/2206.02915 + // + // F8E5M2FNUZ has 5 exponent bits and 2 mantissa bits. + // F8E4M3FNUZ has 4 exponent bits and 3 mantissa bits. + // + // The "FNUZ" means only Finite and NaN values are supported; zero is + // unsigned. Unlike IEEE types, infinities are not supported. NaN is + // represented when the exponent and mantissa bits are all 0s with a sign bit + // of 1. All other values are finite. + // + // These differences mean there's an additional exponent value available. To + // keep the same dynamic range as an IEEE-like FP8 type, the exponent is + // biased one more than would be expected given the number of exponent bits + // (8 for Float8E4M3FNUZ and 16 for Float8E5M2FNUZ). + F8E5M2FNUZ = 24; + F8E4M3FNUZ = 25; + + // Complex values of fixed width. + C64 = 15; // Paired F32 (real, imag), as in std::complex. + C128 = 18; // Paired F64 (real, imag), as in std::complex. + + // A tuple is a polymorphic sequence; e.g. a shape that holds different + // sub-shapes. They are used for things like returning multiple values from a + // computation; e.g. a computation that returns weights and biases may have a + // signature that results in a tuple like (f32[784x2000], f32[2000]) + // + // If a shape proto has the tuple element type, it may not have any entries + // in the dimensions field. + TUPLE = 13; + + // An opaque type used for passing context-specific data to a custom + // operation. Shapes of this primitive type will have empty dimensions and + // tuple_shapes fields. + // + // (OPAQUE would be a better name for this identifier, but that conflicts with + // a macro defined in windows.h.) + OPAQUE_TYPE = 14; + + // A token type threaded between side-effecting operations. Shapes of this + // primitive type will have empty dimensions and tuple_shapes fields. + TOKEN = 17; + + // Next = 30 +} +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc +// ) + +// Describes the padding configuration for Pad operation. The padding amount on +// both edges as well as between the elements are specified for each dimension. +message PaddingConfig { + // Describes the padding configuration for a dimension. + message PaddingConfigDimension { + // Padding amount on the low-end (next to the index 0). May be negative. + int64 edge_padding_low = 1; + + // Padding amount on the high-end (next to the highest index). May be + // negative. + int64 edge_padding_high = 2; + + // Padding amount between the elements. May not be negative. + int64 interior_padding = 3; + } + + // The padding configuration for all dimensions. + repeated PaddingConfigDimension dimensions = 1; +} + +// A DimLevelType indicates the encoding method for a dimension in an array. +// The semantics of this field are identical to those of the MLIR SparseTensor +// dialect. +// This should be kept in sync with the SparseTensor DimLevelType enum: +// https://github.com/llvm/llvm-project/blob/5674a3c88088e668b684326c2194a6282e8270ff/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td#L86 +enum DimLevelType { + // The corresponding dimension is Dense, every entry is stored. + DIM_DENSE = 0; + // The corresponding dimension is Compressed, only nonzeros are stored. + DIM_COMPRESSED = 1; + // The corresponding dimension contains a single coordinate, no sibling + // elements for each parent. 
+ DIM_SINGLETON = 2; + // The corresponding dimension is Compressed, but with potential trailing + // zeros, thus an extra upper bound (high) is used to exclude those zeros. + // E.g., indices = [1, 2, 0, 0, 3, 4, 0, 0], position = [(0, 2), (4, 6)]. + DIM_LOOSE_COMPRESSED = 3; +} + +// Describes a tile used in tiling-based layout. Refer to +// g3doc/third_party/xla/docs/tiled_layout.md for details about tiling-based +// layout. +message TileProto { + // Number of elements in each dimension of the tile. It's ordered from the + // most major dimension of the tile to the most minor dimension of the tile. + // The dimensions correspond to a suffix of the dimensions of the shape being + // tiled. + repeated int64 dimensions = 1; +} + +// Describes how data should be split between different memories. +message SplitConfigProto { + // The dimension that is split. + int64 dimension = 1; + // The indices where each split point occurs. For example, if the dimension + // size is 1024, a split_indices value of {512} indicates a two-way split of + // data through the middle. + repeated int64 split_indices = 2; +} + +// A layout describes how the array is placed in (1D) memory space. This +// includes the minor-to-major ordering of dimensions within a shape. +// +// Clients must specify the layouts of input Literals to the +// computation. Layouts specified in interior operations which take Shapes (for +// example, Convert) are ignored. +// +// See the XLA documentation for more information on shapes and layouts. +// +// LINT.IfChange +message LayoutProto { + // The dimension level type list for this array, specifying the way in which + // each array dimension is represented in memory. If this list is empty, the + // array is assumed to be dense. + repeated DimLevelType dim_level_types = 9; + + // Whether each dimension is unique or ordered. Each of the following lists + // must be empty, or have one entry for each entry of dim_level_types. If + // either list is empty, all dimensions are assumed to be unique and ordered, + // respectively. Entries in this list may not be false for some DimLevelType + // values (such as DIM_DENSE in particular). + repeated bool dim_unique = 13; + repeated bool dim_ordered = 14; + + // Sequence of dimension numbers, from minor (fastest varying index) to major + // (slowest varying index). This field is required. + repeated int64 minor_to_major = 1; + + // A sequence of tiles, starting from the tile that's applied first to the + // Shape. + // + // TODO(b/119839262): implement tiling in each backend or add Unimplemented + // error. + repeated TileProto tiles = 6; + + // The shape is padded at the end to a multiple of this value, in terms of the + // number of elements. This is useful when tiling does not bring the shape to + // certain desired granules. Tiling effectively pads/reshapes/transposes the + // shape to another shape. This field pads the total number of elements of + // that new shape to a multiple of a certain number of elements. This is + // useful when, for example, we want a layout that does not tile the data but + // still requires it to be padded to a certain number of elements. + int64 tail_padding_alignment_in_elements = 16; + + // (Optional) Bit size of each element. When unspecified or 0, defaults + // to ShapeUtil::ByteSizeOfPrimitiveType. + int64 element_size_in_bits = 7; + + // Memory space where this array resides. The integer field is interpreted in + // a backend-specific manner. + int64 memory_space = 8; + + // The integer types to be used for indices and pointers.
These fields must + // not be used unless the layout represents a sparse array. The PrimitiveType + // must correspond to an unsigned integer (U8, U16, U32, or U64). + // If not provided, the compiler will use the largest unsigned integer + // that is naturally supported by the target device (U32 or U64 in currently + // supported devices). + PrimitiveType index_primitive_type = 11; + PrimitiveType pointer_primitive_type = 12; + + // The physical, on-device shape used to represent the shape this layout + // belongs to. Only used for sparse arrays. + // The layout(s) contained within the physical shape should not also contain + // a physical shape. + ShapeProto physical_shape = 10; + + // The dynamic shape metadata size in bytes in front of the shape data. The + // field may be non-zero for a static shape whose associated buffer is for a + // dynamic shape, e.g. a result of SliceToDynamic. + int64 dynamic_shape_metadata_prefix_bytes = 15; + + // The split configurations which describe if/how the data is split between + // different memories. + repeated SplitConfigProto split_configs = 17; + + // Important: if any field is added, be sure to modify ShapeUtil::Equal() and + // LayoutUtil::Hash appropriately to account for the new field. + + reserved 2; + reserved "padded_dimensions"; + reserved 3; + reserved "padding_value"; + reserved 4; + reserved "format"; + reserved 5; + reserved "max_sparse_elements"; +} +// LINT.ThenChange( \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc) + +// A shape describes the number of dimensions in the array, the size of each +// dimension, and the primitive component type. +// +// Tuples are a special case in that they have rank zero and have tuple_shapes +// defined. +// +// See the XLA documentation for more information on shapes and layouts. +// +// LINT.IfChange +message ShapeProto { + reserved 1; + reserved "rank"; + + // The element type for this shape. + PrimitiveType element_type = 2; + + // The size (number of elements) for each dimension, or an upper bound on the + // size if the dimension is dynamic. In XLA, dimensions are numbered from 0 + // to N-1 for an N-dimensional array. The first element of 'dimensions' is the + // size of dimension 0, the second element is the size of dimension 1, and so + // forth. Empty list indicates a scalar. + // + // If the respective element in 'is_dimension_dynamic' is true then the value + // in this field represents an upper bound on the size of the dimension. + repeated int64 dimensions = 3; + + // For tuples only, the shapes of constituent shapes in the tuple sequence. + repeated ShapeProto tuple_shapes = 4; + + // The layout used to back this shape. + LayoutProto layout = 5; + + // For arrays, this indicates whether or not each dimension is + // dynamically-sized. The number of elements in this repeated field should be + // zero (indicating that no dimensions are dynamic) or equal to the number of + // elements in the 'dimensions' field. + repeated bool is_dynamic_dimension = 6; + + // Important: if any field is added, be sure to modify ShapeUtil::Equal(), + // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for + // the new field. +} +// LINT.ThenChange( \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc) + +// Shape of the parameters and output of a computation (like a traditional +// function signature). 
+message ProgramShapeProto { + repeated ShapeProto parameters = 1; + ShapeProto result = 2; + repeated string parameter_names = 3; +} + +// Statistics of a computation. +message ComputationStats { + // The number of floating point operations in the computation. + double flop_count = 1; + + // The number of transcendental operations (e.g., exp) in the computation. + double transcendental_count = 2; +} + +// The types of optimization profiles in use for Op-level optimizations. +enum ProfileType { + INVALID = 0; + WINDOW = 1; + FLAG = 2; + INTEGER = 3; +} + +// The source of the optimization profile. +enum ProfileSource { + PROFILE_SOURCE_UNKNOWN_SOURCE = 0; + PROFILE_SOURCE_EMBEDDED = 1; + PROFILE_SOURCE_REMOTE = 2; +} + +// The compilation event that triggered the use of the profile. +enum CompilationEvent { + COMPILATION_EVENT_UNKNOWN_EVENT = 0; + COMPILATION_EVENT_FIRST_COMPILATION = 1; + COMPILATION_EVENT_RECOMPILATION = 2; +} + +// Symbolization metadata for HLO Instructions. +// +// This metadata is used for debugging XLA code generation, as well as +// performance profiling of XLA-generated executables. +message OpMetadata { + // The framework op name that generated this XLA op. + // + // Frameworks that build on top of XLA should mirror the names of their ops + // back to users by specifying the op_type. In this way, even if the + // framework's "ops" are implemented as multiple XLA HLO Ops, they can be + // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as + // multiple ops, then each op should have its op_type set to "SoftMax".) + string op_type = 1; + // The user-specified name of the op. + // + // This name is often unique within a computation. Note: some frameworks + // add auto-generated names if the user does not provide one. + string op_name = 2; + // Indicates the file and line that this op is associated with in a user's + // program. + // + // e.g. it could be the file and line of user code that generated the op. + string source_file = 3; + int32 source_line = 4; + + // Deprecated, use [ProfileInfo][profile_type] instead. + repeated ProfileType profile_type = 5 [deprecated = true]; + + reserved 6; + reserved "creation_pass_id"; + + reserved 7; + reserved "logical_creation_pass_id"; + + // The footprint of the generated code for the instruction. + int64 size_of_generated_code_in_bytes = 8; + // The size of the working set, i.e., the amount of memory, used by the + // instruction in a compiler-managed fast device memory. + int64 size_of_memory_working_set_in_bytes = 9; + + // Information about the optimization profile that this operation contains. + message ProfileInfo { + // The type of optimization profiles that this operation contains. + repeated ProfileType profile_type = 1; + // Speedup of tuned config compared to default config. + // TODO(b/203817882) Set the relative_speedup. + double relative_speedup = 2; + // The source of the optimization profiles that this operation contains. + ProfileSource profile_source = 3; + // The compilation event that triggered the use of the profiles. + CompilationEvent compilation_event = 4; + } + + // Profile information for the Op. + ProfileInfo profile_info = 10; + + // Deduplicated HLO name for this op. In some cases, we can have multiple + // instructions (e.g. fusions) that are considered duplicates. We want to + // group them together under the same name so that we can group them together + // during analysis (e.g. HLO Op Profile tool in Xprof). + // E.g.
If we have fusion.1, fusion.2, and fusion.3 marked as duplicates, + // fusion.2 and fusion.3 will have deduplicated_name = fusion.1 + string deduplicated_name = 12; + + // Whether to preserve the layout of the HLO op. + bool preserve_layout = 13; + + // 1-based position of the frame in frames flat array. + // Ids are 1-based so that the value 0 can represent an unset property. + int32 stack_frame_id = 15; + + // Instruction name available upon scheduling. + string scheduling_name = 16; + + reserved 14; +} + +// Profile data from the execution of a computation. +message ExecutionProfile { + // Whether the executable was read from the compilation cache. + bool compilation_cache_hit = 1; + + // The time in milliseconds spent to compile the computation. This is only set + // if the executable was not read from the compilation cache + // (compilation_cache_hit == false). + int64 compile_time_ms = 2; + + // The number of cycles spent for the computation. This does not include the + // time taken for the data transfers between the host and the device. This is + // a target-dependent field and only used for debugging purposes. + int64 compute_cycle_count = 3; + + // The time in nanoseconds spent for the computation, without data transfer. + int64 compute_time_ns = 4; + + // The time in nanoseconds spent for the entire computation, including the + // result data transfer time. The current implementation does not spend any + // cycles for the input data transfer since the memory is initialized with the + // proper values before the execution. + int64 compute_and_transfer_time_ns = 5; + + // The size of the binary code in the executable. + int64 executable_size_in_bytes = 6; + + // Whether this profile was drawn from a cache of profiles instead of from + // execution on the hardware. + bool profile_cache_hit = 7; + + // Whether a warm-up run of the computation was executed before the + // measured execution. + bool warmup_run_executed = 8; +} + +// Handle given to a user that represents an execution that the user launched +// asynchronously on the device. +message ExecutionHandle { + int64 handle = 1; +} + +// Handle given to a user that represents a globally accessible allocation. +// Contrast this against a ComputationDataHandle, which is not globally +// accessible, since it only exists within a specific computation. +message GlobalDataHandle { + int64 handle = 1; +} + +// Handle given to a user that represents a replicated virtual device. Each +// replicated device represents N physical devices for execution where N is the +// number of replicas. +message DeviceHandle { + int64 handle = 1; + + // The number of model-parallel virtual devices that communicate via XLA + // Send/Recv instructions. + int64 device_count = 2; +} + +// Handle given to a user to represent a channel between two computations +// via a Send and Recv instruction pair. Channels are unbuffered, so Send +// instructions will be blocked until the data is transferred. +message ChannelHandle { + int64 handle = 1; + enum ChannelType { + // Invalid primitive type to serve as default. + CHANNEL_TYPE_INVALID = 0; + + // A channel for sending data between devices. + DEVICE_TO_DEVICE = 1; + + // A channel for sending data from the device to the host. Can only be used + // with a Send operation. + DEVICE_TO_HOST = 2; + + // A channel for sending data from the host to the device. Can only be used + // with a Recv operation.
+ HOST_TO_DEVICE = 3; + } + ChannelType type = 2; +} + +// DeviceAssignmentProto is a serialized form of DeviceAssignment class, which +// represents the device ids assigned to a set of replicated computations. +// See xla::DeviceAssignment class comment for more details. +message DeviceAssignmentProto { + int32 replica_count = 1; + int32 computation_count = 2; + + // Each logical computation runs on replica_count physical devices. + // ComputationDevice represents the device ids assigned to the replicas. + message ComputationDevice { + repeated int64 replica_device_ids = 1; + } + repeated ComputationDevice computation_devices = 3; +} + +// Literals are used when the server and client need to exchange materialized +// data / results. Literals are also used to describe constants used in +// computations. +// +// Transfers to/from the client are encoded in literal form, and the structure +// of the repeated fields is implied by the shape. +message LiteralProto { + ShapeProto shape = 1; + repeated bool preds = 2; + bytes s2s = 26; + bytes s4s = 21; + bytes s8s = 15; + bytes u2s = 27; + bytes u4s = 22; + bytes u8s = 3; + repeated int32 s32s = 4; + repeated int64 s64s = 5; + repeated uint32 u32s = 6; + repeated uint64 u64s = 7; + repeated float f32s = 8; + repeated double f64s = 9; + repeated float c64s = 12; // Stored as interleaved real, imag floats, as in std::complex<float>. + repeated double c128s = 18; // Stored as interleaved real, imag doubles, as in std::complex<double>. + repeated LiteralProto tuple_literals = 10; + // The F16s, BF16s, U16s and S16s are encoded in little endian byte order + bytes f16s = 11; + bytes bf16s = 13; + bytes u16s = 16; + bytes s16s = 17; + bytes f8e5m2s = 19; + bytes f8e4m3s = 28; + bytes f8e4m3fns = 20; + bytes f8e4m3b11fnuzs = 23; + bytes f8e5m2fnuzs = 24; + bytes f8e4m3fnuzs = 25; + bytes f8e3m4s = 29; + repeated int64 sparse_indices = 14; + // Next = 30 +} + +message WindowDimension { + // The size of the window in this dimension. For a rectangle, this would be + // the width or height. + int64 size = 1; + + // The stride at which the window moves across the base area in this + // dimension. In other words, this is the spacing between different + // positions of the window in this dimension. + int64 stride = 2; + + // If positive, this is the amount of padding to add to the base area at the + // low end of this dimension; if negative, its absolute value is the number of + // elements removed from the low end of this dimension. For example, in the + // horizontal dimension of a rectangle, this would be the number of padding + // values to pad on the left, given that indices increase when going right. + // The actual padding value depends upon the context. Convolution pads with + // zeros. ReduceWindow and SelectAndScatter pads with the reduce function's + // init value. + int64 padding_low = 3; + + // As padding_low, but on the high end of this dimension. For example, in the + // horizontal dimension of a rectangle, this would be the number of values to + // pad on the right, given that indices increase when going right. + int64 padding_high = 4; + + // Dilation factor of the sliding window in this dimension. A dilation factor + // of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are + // implicitly placed between each kernel element. This value may not be less + // than 1. See documentation for convolution. + int64 window_dilation = 5; + + // Dilation factor of the base area in this dimension. A dilation factor of 1 + // means no dilation.
base_dilation - 1 no-op entries ("holes") are implicitly + // placed between each base area element. This value may not be less than 1. + // See documentation for convolution. + int64 base_dilation = 6; + + // Window reversal means that this dimension was logically reversed before the + // operation. + bool window_reversal = 7; +} + +// Describes the windowing in an operation such as convolution. +// +// The window is moved across a base area and for each position of the +// window a computation is performed. The field below describes the +// window and the movement of the window across a base area. +message Window { + repeated WindowDimension dimensions = 1; +} + +// Describes the dimension numbers for a gather operation. +// +// See https://www.tensorflow.org/performance/xla/operation_semantics#gather for +// more details. +message GatherDimensionNumbers { + // "Window indices" is a term for a set of indices that index into the + // interior of a dynamic-slice from the input tensor, the starting indices for + // which were computed from output_gather_dims (see the operation semantic for + // how this is defined) and the start_indices tensor. + // + // The window indices for a specific output index Out are computed as: + // + // i = 0 + // for (k : [0, input_tensor_shape.rank)) + // window_indices[k] = + // if k in collapsed_slice_dims + // then 0 + // else Out[offset_dims[i++]] + repeated int64 offset_dims = 1; + repeated int64 collapsed_slice_dims = 2; + + // This is interpreted as a map from i to start_index_map[i]. It + // transforms the gather index looked up from the start_indices tensor into + // the starting index in the input space. + repeated int64 start_index_map = 3; + + // The dimension in the start_indices input that contains the starting + // indices. + int64 index_vector_dim = 4; + + // These are the batch dimensions in the operand. + repeated int64 operand_batching_dims = 5; + + // These are the batch dimensions in the index, and the list should be the + // same size as operand_batching_dims. + repeated int64 start_indices_batching_dims = 6; +} + +// Describes the dimension numbers for a scatter operation. +// +// All the fields are similar to the corresponding fields in +// GatherDimensionNumbers. Differences are noted below. +message ScatterDimensionNumbers { + // The set of dimensions in the updates shape that are window dimensions. + repeated int64 update_window_dims = 1; + // The set of window dimensions that must be inserted into the updates shape. + repeated int64 inserted_window_dims = 2; + + repeated int64 scatter_dims_to_operand_dims = 3; + int64 index_vector_dim = 4; + + // These are the batch dimensions in the input. + repeated int64 input_batching_dims = 5; + + // These are the batch dimensions in the index. + repeated int64 scatter_indices_batching_dims = 6; +} + +message ConvolutionDimensionNumbers { + // The number of the dimension that represents batch in the input. + int64 input_batch_dimension = 7; + + // The number of the dimension that represents features in the input. + int64 input_feature_dimension = 8; + + // The dimension numbers for the spatial dimensions that the window + // moves through in the input. + repeated int64 input_spatial_dimensions = 11; + + // The number of the dimension that represents input features in the + // convolutional kernel (rhs). + int64 kernel_input_feature_dimension = 3; + + // The number of the dimension that represents output features in + // the convolutional kernel (rhs).
+ int64 kernel_output_feature_dimension = 4; + + // The dimension numbers for the spatial dimensions that the window + // moves through in the kernel (rhs). window.strides(0) is the + // stride in the kernel_spatial_dimensions(0) dimension. + repeated int64 kernel_spatial_dimensions = 6; + + // The number of the dimension that represents batch in the output. + int64 output_batch_dimension = 9; + + // The number of the dimension that represents features in the output. + int64 output_feature_dimension = 10; + + // The dimension numbers for the spatial dimensions that the window + // moves through in the output. + repeated int64 output_spatial_dimensions = 12; + + // Next = 13 +} + +enum PaddingType { + PADDING_INVALID = 0; + PADDING_VALID = 1; // Only the valid portion of the base is covered. + PADDING_SAME = 2; // Extra padding is added to produce the same output size as the input. +} + +enum FftType { + FFT = 0; // Forward FFT; complex in, complex out. + IFFT = 1; // Inverse FFT; complex in, complex out. + RFFT = 2; // Forward real FFT; real in, fft_length / 2 + 1 complex out + IRFFT = 3; // Inverse real FFT; fft_length / 2 + 1 complex in, + // fft_length real out +} + +message DotDimensionNumbers { + // The dimension numbers that represent the 'lhs' contracting dimensions. + repeated int64 lhs_contracting_dimensions = 1; + // The dimension numbers that represent the 'rhs' contracting dimensions. + repeated int64 rhs_contracting_dimensions = 2; + // The dimension numbers that represent the 'lhs' batch dimensions. + repeated int64 lhs_batch_dimensions = 3; + // The dimension numbers that represent the 'rhs' batch dimensions. + repeated int64 rhs_batch_dimensions = 4; +} + +message RaggedDotDimensionNumbers { + // The contracting and batch dimensions of the 'lhs' and 'rhs'. + DotDimensionNumbers dot_dimension_numbers = 1; + // The dimension numbers that represent the 'lhs' ragged dimensions. + repeated int64 lhs_ragged_dimensions = 2; + // The dimension numbers that represent the 'rhs' group dimensions. + repeated int64 rhs_group_dimensions = 3; +} + +enum SparsityType { + SPARSITY_INVALID = 0; + + // Structured N:M sparsity. + SPARSITY_STRUCTURED_N_M = 1; + + // Next: 2 +} + +// Contains sparsity metadata for a sparse dot operation. +// The only supported type at the moment is structured 2:4 sparsity, which is +// natively supported on NVIDIA GPUs. +// Restrictions: +// - only one operand of the dot operation may be sparse; +// - only the contracting dimension may be sparse. +message SparsityDescriptor { + SparsityType type = 1; + + // Sparse operand index (0 or 1). + int32 index = 2; + // Sparse dimension number. + int32 dimension = 3; + + // Structured N:M sparsity (N < M). + int32 n = 4; + int32 m = 5; + + // Next: 6 +} + +enum RandomDistribution { + RNG_INVALID = 0; + + // Creates a uniform-distribution-generated random number on the semi-open + // interval [parameter[0], parameter[1]). + RNG_UNIFORM = 1; + + // Creates a normal-distribution-generated random number with mean + // parameter[0] and standard deviation parameter[1]. + RNG_NORMAL = 2; + + // Next: 4 +} + +enum RandomAlgorithm { + RNG_DEFAULT = 0; // Backend dependent default algorithm. + RNG_THREE_FRY = 1; + RNG_PHILOX = 2; + // Next: 3 +} + +message TriangularSolveOptions { + // If true, solves ax = b. If false, solves xa = b. + bool left_side = 1; + + // If true, 'a' is lower triangular. If false, 'a' is upper triangular. + bool lower = 2; + + // If true, the diagonal elements of 'a' are assumed to be 1 and not accessed.
+ bool unit_diagonal = 3; + + // Should we transpose or use the adjoint of 'a'? + enum Transpose { + TRANSPOSE_INVALID = 0; + NO_TRANSPOSE = 1; // Don't transpose 'a'. + TRANSPOSE = 2; // Transpose 'a'. + ADJOINT = 3; // Complex conjugate and transpose 'a'. + } + Transpose transpose_a = 4; +} + +message CholeskyOptions { + // If true, uses the lower triangle of `a`. If false, uses the upper triangle + // of `a`. + bool lower = 1; +} + +// Attributes of the sort custom call (cub::DeviceRadixSort). +message SortOptions { + bool descending = 1; +} + +// Generic map of attributes used to pass hints / configuration options from +// the Python frontend to the XLA backend. +message FrontendAttributes { + map<string, string> map = 1; +} + +// Represents a single statistic to track. +message Statistic { + // Must be a single word consisting of any alphanumeric characters. + string stat_name = 1; + // Must be within a range of [0, 100], in order for the graph dumper to + // properly render the statistic onto the graph. + double stat_val = 2; +} + +// Represents the information needed to visualize propagation statistics when +// rendering an HLO graph. This includes an array of statistics as well as the +// index of the statistic to render. +message StatisticsViz { + int64 stat_index_to_visualize = 1; + repeated Statistic statistics = 2; +} + +// LINT.IfChange +message OpSharding { + enum Type { + // This sharding is replicated across all devices (implies maximal, + // all other fields are unused). + REPLICATED = 0; + // This sharding is maximal - one device runs the entire operation. + MAXIMAL = 1; + // This sharding is a tuple - only the tuple_shardings field is valid. + TUPLE = 2; + // None of the above; tile_shape and tile_assignment are both used. + OTHER = 3; + // This op is manually sharded: the shapes are already partitioned and the + // partitioner should not change this op. + MANUAL = 4; + // This sharding is a placeholder sharding with the lowest precedence; it + // can be overwritten by any other sharding. + UNKNOWN = 5; + } + Type type = 1; + // The shape of the sharded tile. + ShapeProto tile_shape = 2; + // The shape of the tile assignment tensor - this must be the same rank as + // tile_shape and the product of its dimensions must equal + // tile_assignment_devices.size(). + repeated int64 tile_assignment_dimensions = 3; + // Flattened list of device IDs. The order of flattening is the same as used + // by IndexUtil::MultiToLinearIndex(tile_assignment_shape). + // Only one of tile_assignment_devices and iota_reshape_dims shall be + // non-empty. + repeated int64 tile_assignment_devices = 4; + // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape, + // in pre-order. The tuple shape could be nested; here we store just a + // flattened list of all leaves in the tuple shape. Note that the tuple shape + // is not stored here; shardings do not store the shapes to which they are + // applied, this is inferred from the instruction this sharding gets attached + // to. + repeated OpSharding tuple_shardings = 5; + + // Only used for OTHER type. If true, data is sharded according to other + // dimensions of tile_assignment(), but replicated across devices along the + // last dimension. (Experimental) + bool replicate_on_last_tile_dim = 6; + // This field is used to track the source of this sharding, usually derived + // from instructions. Multiple metadata may be populated if sharding is + // combined with other shardings.
Metadata should not be populated when + // type == TUPLE; instead, metadata should be set on the individual tuple + // elements. + repeated OpMetadata metadata = 7; + + // This field is used to represent the sharding type of each subgroup. + // For example, sharding={devices=[2,2,2,2]0,1,2,...,15 last_tile_dims={ + // replicate, manual, unreduced}} means that each of the last 3 dimensions + // in [2,2,2,2] represents a subgrouping in replicate, manual, + // unreduced sharding type respectively. + repeated Type last_tile_dims = 8; + + // Dimensions used to reshape the 1D iota array of device IDs. + // Only one of tile_assignment_devices and iota_reshape_dims shall be + // non-empty. + repeated int64 iota_reshape_dims = 9; + + // Dimension permutations used to transpose the iota array reshaped to + // iota_reshape_dims. This must have the same size as iota_reshape_dims. + repeated int32 iota_transpose_perm = 10; + + // This field decides whether this op is in a shard group. + bool is_shard_group = 11; + + // This field is used to store the unique id of the shard group. + int64 shard_group_id = 12; + + // Used to decide whether this op is to be sharded like some other op, or + // which other ops will be sharded like it. + enum ShardGroupType { + // This op will be sharded exactly the same as the other op. (hard + // restriction) + AS = 0; + // This op will try to allow sharding propagation within the same group even + // if there are no data dependencies among them, but there is no guarantee + // that the final shardings within the same group will be exactly the same. + // (soft restriction) + LIKE = 1; + } + + ShardGroupType shard_group_type = 13; +} +// LINT.ThenChange() + +// Describes the replica groups in a cross replica op (e.g., all-reduce and +// all-to-all). +message ReplicaGroup { + // The ids of the replicas that belong to the same group. The ordering of the + // ids matters in some ops (e.g., all-to-all). + repeated int64 replica_ids = 1; +} + +// Represents a list of replica groups (a list of lists of devices) with +// reshaping and transposing an iota array (iota tile assignment). Can be used +// to represent certain common patterns of device lists in a compact, scalable +// format. +message IotaReplicaGroupListProto { + // Number of replica groups. + int64 num_replica_groups = 1; + + // Number of devices per group. + int64 num_devices_per_group = 2; + + // The dimensions used to reshape the 1D iota array of device IDs. + repeated int64 iota_reshape_dims = 3; + + // The dimension permutations used to transpose the iota array reshaped to + // iota_reshape_dims. This must have the same size as iota_reshape_dims. + repeated int32 iota_transpose_perm = 4; +} + +// Represents a series of devices participating in a collective operation (e.g., +// all-reduce and all-to-all). While this directly translates to a list of +// replica groups, it may be used to represent these lists in a compact form. +message CollectiveDeviceListProto { + // ReplicaGroupV1: List of replica groups. Legacy way of representing device + // lists. + repeated ReplicaGroup replica_groups = 1; + + // ReplicaGroupV2: Represents a list of replica groups with reshaping and + // transposing an iota array. + IotaReplicaGroupListProto iota_replica_group_list = 2; +} + +// Describes the source target pair in the collective permute op. +message SourceTarget { + int64 source = 1; + int64 target = 2; +} + +// Describes the types of accuracy the user can request for unary ops with +// multiple implementations.
+message ResultAccuracy { + enum Mode { + DEFAULT = 0; + HIGHEST = 1; + } + message Tolerance { + // Absolute error tolerance for unary instructions. + double atol = 1; + // Relative error tolerance for unary instructions. + double rtol = 2; + // The error in ulps (units in the last place) is relative to machine + // precision. + int64 ulps = 3; + } + oneof specs { + // Choose either DEFAULT or HIGHEST precision implementation. + Mode mode = 1; + Tolerance tolerance = 2; + } +} + +// Used to indicate the precision configuration. It has backend-specific +// meaning. +message PrecisionConfig { + enum Precision { + DEFAULT = 0; + HIGH = 1; + HIGHEST = 2; + // Each U8/S8 value in a tensor actually represents 2 nibble values. + PACKED_NIBBLE = 3; + + // Next: 4 + } + + // The algorithm used to evaluate the instruction. + // + // The naming convention for the dot instruction is + // ALG_DOT_{A_TYPE}_{B_TYPE}_{ACCUM_TYPE}[_X{NUM_OPS}] where A_TYPE, B_TYPE + // and ACCUM_TYPE correspond to the types in the "primitive dot operations" + // (such as TensorCore operations) and NUM_OPS is the number of such + // operations used per "primitive tile". When the NUM_OPS + // field is skipped, it is assumed to be 1. The types mentioned in the name + // are independent of the storage types. + // + // In general, A_TYPE and B_TYPE are the precisions that the LHS and RHS of + // the operation are rounded to, and ACCUM_TYPE is the accumulation type. If a + // backend does not support the given algorithm, an error is raised. The + // Algorithm enum is intended to eventually replace the Precision enum. + // + enum Algorithm { + // If the algorithm is `ALG_UNSET`, we will decide the algorithm based on + // the operand_precision values (for now). + ALG_UNSET = 0; + // The storage type can be any 8-bit floating point type. + ALG_DOT_ANY_F8_ANY_F8_F32 = 1; + // The storage type can be any 8-bit floating point type. Intermediate + // results will not periodically be promoted to a higher precision. This + // corresponds to CUBLASLT_MATMUL_DESC_FAST_ACCUM. Triton's + // maxNumImpreciseAcc=32 setting may be similar. + ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM = 2; + ALG_DOT_F16_F16_F16 = 3; + ALG_DOT_F16_F16_F32 = 4; + ALG_DOT_BF16_BF16_BF16 = 5; + ALG_DOT_BF16_BF16_F32 = 6; + // An algorithm which uses 3 BF16_BF16_F32 matmuls to achieve better + // precision. + ALG_DOT_BF16_BF16_F32_X3 = 7; + // An algorithm which uses 6 BF16_BF16_F32 matmuls to achieve better + // precision (similar to F32). + ALG_DOT_BF16_BF16_F32_X6 = 8; + ALG_DOT_TF32_TF32_F32 = 9; + // An algorithm which uses 3 TF32_TF32_F32 matmuls to achieve better + // precision (similar to F32). + ALG_DOT_TF32_TF32_F32_X3 = 10; + ALG_DOT_F32_F32_F32 = 11; + ALG_DOT_F64_F64_F64 = 12; + + // Next: 13 + } + + repeated Precision operand_precision = 1; + + // Currently doesn't do anything, but we plan to support it for dot and + // possibly more instructions. + // + // TODO(b/316147294): Support this on GPU and add this to StableHLO as well. + // + // If this is set, then `operand_precision` should be set to DEFAULT and it + // will be ignored. + Algorithm algorithm = 2; + + // Next: 8 +} + +// Describes whether all data-parallelism replicas will receive the same +// parameter data at each buffer. +message ParameterReplication { + // A list of boolean values for the flattened leaf buffers. Each value + // indicates whether the corresponding leaf buffer is replicated. + // + // If this field is empty, it means no buffer is replicated.
Otherwise, the + // number of elements in this field must match the number of leaf buffers in + // the HLO instruction's shape. + repeated bool replicated_at_leaf_buffers = 1; +} + +// A backend-config for kWhile loops that stores the loop's trip count, if it is +// known. +// +// This is useful for backends that can implement a `for i in 0..N` loop more +// efficiently than a `while` loop. For example, on GPUs, we can implement a +// `for i in 0..N` loop by enqueueing the kernels for the loop body N times, +// whereas implementing a `while` loop requires a host-device sync on each +// iteration. +message WhileLoopBackendConfig { + message KnownTripCount { + int64 n = 1; + } + // This indirection lets us distinguish between known-trip-count == 0 and + // unknown-trip-count. + KnownTripCount known_trip_count = 1; +} + +// Specifies a pair of output/operand buffers that alias each other for +// kCustomCall and kFusion +message OutputOperandAliasing { + repeated int64 output_shape_index = 1; + int64 operand_index = 2; + repeated int64 operand_shape_index = 3; +} + +message OriginalArrayProto { + repeated int64 leaf_shape_index = 1; + string instruction_name = 2; + repeated int64 shape_index = 3; +} + +message OriginalValueProto { + repeated OriginalArrayProto leaves = 1; +}
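Note on usage (a reviewer sketch, not part of the patch): with this change, PJRT_Client_Compile expects PJRT_Client_Compile_Args::compile_options to carry a serialized xla::CompileOptionsProto, and an absent or empty device assignment falls back to all addressable devices. The snippet below is a minimal illustration of how a caller might produce those bytes using the protobuf API generated from compile_options.proto above; the helper name MakeCompileOptions and the two-replica assignment are illustrative assumptions, not part of this PR.

#include <string>

#include "compile_options.pb.h"

// Builds serialized compile options pinning one computation to replica
// devices 0 and 1. ClientInstance::Compile maps each replica_device_id to an
// entry in addressable_devices_ and skips ids that are out of range.
std::string MakeCompileOptions() {
  xla::CompileOptionsProto options;
  xla::ExecutableBuildOptionsProto* build =
      options.mutable_executable_build_options();
  build->set_num_replicas(2);
  xla::DeviceAssignmentProto* assignment = build->mutable_device_assignment();
  assignment->set_replica_count(2);
  assignment->set_computation_count(1);
  xla::DeviceAssignmentProto::ComputationDevice* computation =
      assignment->add_computation_devices();
  computation->add_replica_device_ids(0);
  computation->add_replica_device_ids(1);
  // The resulting bytes go into PJRT_Client_Compile_Args::compile_options
  // together with compile_options_size.
  return options.SerializeAsString();
}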