Skip to content

Commit

Permalink
Merge pull request #5495 from google/benvanik-outline-all-the-things
Browse files Browse the repository at this point in the history
Force all constants to be outlined and stored in the constant pool.
  • Loading branch information
benvanik authored Apr 16, 2021
2 parents 3c8880e + 3525c5b commit eb35573
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ static PassRegistration<OutlineLargeConstantsPass> pass(
"iree-flow-outline-large-constants",
"Outlines large tensor constants into flow.variables at the module level.",
[] {
return std::make_unique<OutlineLargeConstantsPass>(kMinLargeConstantSize);
// TODO(#5493): add a flag for this.
return std::make_unique<OutlineLargeConstantsPass>(256);
});

} // namespace Flow
Expand Down
6 changes: 3 additions & 3 deletions iree/compiler/Dialect/Flow/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ std::unique_ptr<OperationPass<ModuleOp>> createExportBenchmarkFuncsPass();

// Outlines large tensor constants into flow.variables at the module level.
//
// NOTE: a total guess :) this feels like about the most per-dispatch-buffer
// data we'd want to embed in the command buffer.
static constexpr size_t kMinLargeConstantSize = 256;
// TODO(#5493): implement the support for inlining constants into the command
// buffer and raise this value to one that is measured to be good.
static constexpr size_t kMinLargeConstantSize = 1;
std::unique_ptr<OperationPass<ModuleOp>> createOutlineLargeConstantsPass(
size_t minLargeConstantSize = kMinLargeConstantSize);

Expand Down
5 changes: 5 additions & 0 deletions iree/hal/buffer_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ iree_hal_buffer_view_release(iree_hal_buffer_view_t* buffer_view);

// Returns the buffer underlying the buffer view.
// The caller must retain the returned buffer if they want to continue using it.
//
// NOTE: the returned buffer length will almost always be larger than the valid
// bytes representing this buffer view due to padding. Always query the actual
// valid length with iree_hal_buffer_view_byte_length instead of assuming the
// buffer is already clamped.
IREE_API_EXPORT iree_hal_buffer_t* IREE_API_CALL
iree_hal_buffer_view_buffer(const iree_hal_buffer_view_t* buffer_view);

Expand Down
21 changes: 14 additions & 7 deletions iree/modules/check/native_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,11 @@ class CheckModuleState final {
iree_hal_element_type_t element_type =
iree_hal_buffer_view_element_type(view);
iree_hal_buffer_t* buf = iree_hal_buffer_view_buffer(view);
iree_device_size_t size = iree_hal_buffer_view_byte_length(view);
iree_hal_buffer_mapping_t mapped_memory;
IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
buf, IREE_HAL_MEMORY_ACCESS_READ,
/*byte_offset=*/0, IREE_WHOLE_BUFFER, &mapped_memory));
IREE_RETURN_IF_ERROR(
iree_hal_buffer_map_range(buf, IREE_HAL_MEMORY_ACCESS_READ,
/*byte_offset=*/0, size, &mapped_memory));
IREE_RETURN_IF_ERROR(
::iree::ExpectAllTrue(mapped_memory.contents, element_type));
iree_hal_buffer_unmap_range(&mapped_memory);
Expand All @@ -215,13 +216,16 @@ class CheckModuleState final {
vm::ref<iree_hal_buffer_view_t> rhs_ref) {
auto* lhs = lhs_ref.get();
auto* rhs = rhs_ref.get();

iree_device_size_t lhs_size = iree_hal_buffer_view_byte_length(lhs);
size_t lhs_rank = iree_hal_buffer_view_shape_rank(lhs);
std::vector<iree_hal_dim_t> lhs_shape(lhs_rank);
if (lhs_rank > 0) {
IREE_RETURN_IF_ERROR(
iree_hal_buffer_view_shape(lhs, lhs_rank, lhs_shape.data(), nullptr));
}

iree_device_size_t rhs_size = iree_hal_buffer_view_byte_length(rhs);
size_t rhs_rank = iree_hal_buffer_view_shape_rank(rhs);
std::vector<iree_hal_dim_t> rhs_shape(rhs_rank);
if (rhs_rank > 0) {
Expand All @@ -238,12 +242,12 @@ class CheckModuleState final {
iree_hal_buffer_mapping_t lhs_mapped_memory;
IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
lhs_buf, IREE_HAL_MEMORY_ACCESS_READ,
/*byte_offset=*/0, IREE_WHOLE_BUFFER, &lhs_mapped_memory));
/*byte_offset=*/0, lhs_size, &lhs_mapped_memory));
iree_hal_buffer_t* rhs_buf = iree_hal_buffer_view_buffer(rhs);
iree_hal_buffer_mapping_t rhs_mapped_memory;
IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
rhs_buf, IREE_HAL_MEMORY_ACCESS_READ,
/*byte_offset=*/0, IREE_WHOLE_BUFFER, &rhs_mapped_memory));
/*byte_offset=*/0, rhs_size, &rhs_mapped_memory));

bool element_types_eq = lhs_element_type == rhs_element_type;
bool shape_eq = lhs_shape == rhs_shape;
Expand Down Expand Up @@ -288,13 +292,16 @@ class CheckModuleState final {
vm::ref<iree_hal_buffer_view_t> rhs_ref) {
auto* lhs = lhs_ref.get();
auto* rhs = rhs_ref.get();

iree_device_size_t lhs_size = iree_hal_buffer_view_byte_length(lhs);
size_t lhs_rank = iree_hal_buffer_view_shape_rank(lhs);
std::vector<iree_hal_dim_t> lhs_shape(lhs_rank);
if (lhs_rank > 0) {
IREE_RETURN_IF_ERROR(
iree_hal_buffer_view_shape(lhs, lhs_rank, lhs_shape.data(), nullptr));
}

iree_device_size_t rhs_size = iree_hal_buffer_view_byte_length(rhs);
size_t rhs_rank = iree_hal_buffer_view_shape_rank(rhs);
std::vector<iree_hal_dim_t> rhs_shape(rhs_rank);
if (rhs_rank > 0) {
Expand All @@ -311,12 +318,12 @@ class CheckModuleState final {
iree_hal_buffer_mapping_t lhs_mapped_memory;
IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
lhs_buf, IREE_HAL_MEMORY_ACCESS_READ,
/*byte_offset=*/0, IREE_WHOLE_BUFFER, &lhs_mapped_memory));
/*byte_offset=*/0, lhs_size, &lhs_mapped_memory));
iree_hal_buffer_t* rhs_buf = iree_hal_buffer_view_buffer(rhs);
iree_hal_buffer_mapping_t rhs_mapped_memory;
IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
rhs_buf, IREE_HAL_MEMORY_ACCESS_READ,
/*byte_offset=*/0, IREE_WHOLE_BUFFER, &rhs_mapped_memory));
/*byte_offset=*/0, rhs_size, &rhs_mapped_memory));

bool element_types_eq = lhs_element_type == rhs_element_type;
bool shape_eq = lhs_shape == rhs_shape;
Expand Down

0 comments on commit eb35573

Please sign in to comment.