Skip to content

Commit

Permalink
Removing FromApiStatus and unifying IREE_RETURN_IF_ERROR macros.
Browse files Browse the repository at this point in the history
  • Loading branch information
benvanik committed Aug 17, 2020
1 parent 04c39be commit 5527e48
Show file tree
Hide file tree
Showing 26 changed files with 416 additions and 637 deletions.
49 changes: 21 additions & 28 deletions bindings/java/com/google/iree/native/context_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,13 @@ std::vector<iree_vm_module_t*> GetModulesFromModuleWrappers(
} // namespace

Status ContextWrapper::Create(const InstanceWrapper& instance_wrapper) {
RETURN_IF_ERROR(
FromApiStatus(iree_vm_context_create(instance_wrapper.instance(),
iree_allocator_system(), &context_),
IREE_LOC));
RETURN_IF_ERROR(iree_vm_context_create(instance_wrapper.instance(),
iree_allocator_system(), &context_));
RETURN_IF_ERROR(CreateDefaultModules());
std::vector<iree_vm_module_t*> default_modules = {hal_module_};
return FromApiStatus(
iree_vm_context_register_modules(context_, default_modules.data(),
default_modules.size()),
IREE_LOC);
RETURN_IF_ERROR(iree_vm_context_register_modules(
context_, default_modules.data(), default_modules.size()));
return OkStatus();
}

Status ContextWrapper::CreateWithModules(
Expand All @@ -56,25 +53,24 @@ Status ContextWrapper::CreateWithModules(
// beginning of the vector.
modules.insert(modules.begin(), hal_module_);

return FromApiStatus(iree_vm_context_create_with_modules(
instance_wrapper.instance(), modules.data(),
modules.size(), iree_allocator_system(), &context_),
IREE_LOC);
RETURN_IF_ERROR(iree_vm_context_create_with_modules(
instance_wrapper.instance(), modules.data(), modules.size(),
iree_allocator_system(), &context_));
return OkStatus();
}

Status ContextWrapper::RegisterModules(
const std::vector<ModuleWrapper*>& module_wrappers) {
auto modules = GetModulesFromModuleWrappers(module_wrappers);
return FromApiStatus(iree_vm_context_register_modules(
context_, modules.data(), modules.size()),
IREE_LOC);
RETURN_IF_ERROR(iree_vm_context_register_modules(context_, modules.data(),
modules.size()));
return OkStatus();
}

Status ContextWrapper::ResolveFunction(const FunctionWrapper& function_wrapper,
iree_string_view_t name) {
return FromApiStatus(iree_vm_context_resolve_function(
context_, name, function_wrapper.function()),
IREE_LOC);
return iree_vm_context_resolve_function(context_, name,
function_wrapper.function());
}

int ContextWrapper::id() const { return iree_vm_context_id(context_); }
Expand All @@ -88,16 +84,13 @@ ContextWrapper::~ContextWrapper() {

// TODO(jennik): Also create default string and tensorlist modules.
Status ContextWrapper::CreateDefaultModules() {
RETURN_IF_ERROR(FromApiStatus(
iree_hal_driver_registry_create_driver(iree_make_cstring_view("vmla"),
iree_allocator_system(), &driver_),
IREE_LOC));
RETURN_IF_ERROR(FromApiStatus(iree_hal_driver_create_default_device(
driver_, iree_allocator_system(), &device_),
IREE_LOC));
return FromApiStatus(
iree_hal_module_create(device_, iree_allocator_system(), &hal_module_),
IREE_LOC);
RETURN_IF_ERROR(iree_hal_driver_registry_create_driver(
iree_make_cstring_view("vmla"), iree_allocator_system(), &driver_));
RETURN_IF_ERROR(iree_hal_driver_create_default_device(
driver_, iree_allocator_system(), &device_));
RETURN_IF_ERROR(
iree_hal_module_create(device_, iree_allocator_system(), &hal_module_));
return OkStatus();
}

} // namespace java
Expand Down
3 changes: 1 addition & 2 deletions bindings/java/com/google/iree/native/instance_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ Status InstanceWrapper::Create() {
static std::once_flag setup_vm_once;
std::call_once(setup_vm_once, [] { SetupVm(); });

return FromApiStatus(
iree_vm_instance_create(iree_allocator_system(), &instance_), IREE_LOC);
return iree_vm_instance_create(iree_allocator_system(), &instance_);
}

iree_vm_instance_t* InstanceWrapper::instance() const { return instance_; }
Expand Down
8 changes: 3 additions & 5 deletions bindings/java/com/google/iree/native/module_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,9 @@ namespace java {

Status ModuleWrapper::Create(const uint8_t* flatbuffer_data,
iree_host_size_t length) {
return FromApiStatus(
iree_vm_bytecode_module_create(
iree_const_byte_span_t{flatbuffer_data, length},
iree_allocator_null(), iree_allocator_system(), &module_),
IREE_LOC);
return iree_vm_bytecode_module_create(
iree_const_byte_span_t{flatbuffer_data, length}, iree_allocator_null(),
iree_allocator_system(), &module_);
}

iree_vm_module_t* ModuleWrapper::module() const { return module_; }
Expand Down
4 changes: 3 additions & 1 deletion iree/base/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,9 @@ cc_test(
cc_library(
name = "status",
hdrs = ["status.h"],
deps = platform_trampoline_deps("status"),
deps = [
"//iree/base/internal:status",
],
)

cc_test(
Expand Down
5 changes: 0 additions & 5 deletions iree/base/api_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@ inline iree_status_t ToApiStatus(const Status& status) {
return iree_make_status(static_cast<iree_status_code_t>(status.code()));
}

inline StatusBuilder FromApiStatus(iree_status_t status_code,
SourceLocation loc) {
return StatusBuilder(static_cast<StatusCode>(status_code), loc);
}

// Internal helper for concatenating macro values.
#define IREE_API_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y
#define IREE_API_STATUS_MACROS_IMPL_CONCAT_(x, y) \
Expand Down
10 changes: 10 additions & 0 deletions iree/base/internal/status.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ std::ostream& operator<<(std::ostream& os, const StatusCode& x) {
return os;
}

Status::Status(iree_status_t status) {
// TODO(#265): just store status.
if (!iree_status_is_ok(status)) {
state_ = std::make_unique<State>();
state_->code = static_cast<StatusCode>(iree_status_code(status));
state_->message = std::string("TODO");
iree_status_ignore(status);
}
}

Status::Status(StatusCode code, absl::string_view message) {
if (code != StatusCode::kOk) {
state_ = std::make_unique<State>();
Expand Down
10 changes: 10 additions & 0 deletions iree/base/internal/status.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ class ABSL_MUST_USE_RESULT Status final {
// Creates an OK status with no message.
Status() = default;

Status(iree_status_t status);

// Creates a status with the specified code and error message.
// If `code` is kOk, `message` is ignored.
Status(StatusCode code, absl::string_view message);
Expand Down Expand Up @@ -155,6 +157,14 @@ std::ostream& operator<<(std::ostream& os, const Status& x);
// has been augmented by adding `msg` to the end of the original message.
Status Annotate(const Status& s, absl::string_view msg);

ABSL_MUST_USE_RESULT static inline bool IsOk(const Status& status) {
return status.code() == StatusCode::kOk;
}

ABSL_MUST_USE_RESULT static inline bool IsOk(iree_status_t status) {
return iree_status_is_ok(status);
}

ABSL_MUST_USE_RESULT static inline bool IsAborted(const Status& status) {
return status.code() == StatusCode::kAborted;
}
Expand Down
16 changes: 15 additions & 1 deletion iree/base/internal/status_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
// current location.
explicit StatusBuilder(StatusCode code, SourceLocation location);

explicit StatusBuilder(iree_status_t status, SourceLocation location)
: status_(static_cast<StatusCode>(iree_status_code(status)), ""),
loc_(location) {}

StatusBuilder(const StatusBuilder& sb);
StatusBuilder& operator=(const StatusBuilder& sb);
StatusBuilder(StatusBuilder&&) = default;
Expand All @@ -59,6 +63,13 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
operator Status() const&;
operator Status() &&;

// TODO(#265): toll-free result.
operator iree_status_t() && {
return iree_status_allocate(static_cast<iree_status_code_t>(status_.code()),
loc_.file_name(), loc_.line(),
iree_string_view_empty());
}

friend bool operator==(const StatusBuilder& lhs, const StatusCode& rhs) {
return lhs.code() == rhs;
}
Expand Down Expand Up @@ -171,7 +182,10 @@ StatusBuilder Win32ErrorToCanonicalStatusBuilder(uint32_t error,

#define IREE_STATUS_MACROS_IMPL_RETURN_IF_ERROR_(var, expr) \
auto var = (expr); \
if (IREE_UNLIKELY(!var.ok())) \
if (IREE_UNLIKELY(!::iree::IsOk(var))) \
return ::iree::StatusBuilder(std::move(var), IREE_LOC)

#undef IREE_RETURN_IF_ERROR
#define IREE_RETURN_IF_ERROR(expr, ...) RETURN_IF_ERROR(expr)

#endif // IREE_BASE_INTERNAL_STATUS_BUILDER_H_
5 changes: 0 additions & 5 deletions iree/base/status.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,8 @@
#ifndef IREE_BASE_STATUS_H_
#define IREE_BASE_STATUS_H_

#if defined(IREE_CONFIG_GOOGLE_INTERNAL) && IREE_CONFIG_GOOGLE_INTERNAL
#include "iree/base/google_internal/source_location_google.h"
#include "iree/base/google_internal/status_google.h"
#else
#include "iree/base/internal/status.h"
#include "iree/base/internal/status_builder.h"
#include "iree/base/internal/statusor.h"
#endif // IREE_CONFIG_GOOGLE_INTERNAL

#endif // IREE_BASE_STATUS_H_
2 changes: 1 addition & 1 deletion iree/hal/api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2073,7 +2073,7 @@ iree_hal_semaphore_fail(iree_hal_semaphore_t* semaphore, iree_status_t status) {
IREE_TRACE_SCOPE0("iree_hal_semaphore_fail");
IREE_ASSERT_ARGUMENT(semaphore);
auto* handle = reinterpret_cast<Semaphore*>(semaphore);
handle->Fail(FromApiStatus(status, IREE_LOC));
handle->Fail(std::move(status));
}

IREE_API_EXPORT iree_status_t IREE_API_CALL
Expand Down
66 changes: 19 additions & 47 deletions iree/hal/api_string_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
#include "iree/base/api_util.h"
#include "iree/base/memory.h"
#include "iree/base/status.h"
#include "iree/testing/status_matchers.h"
#include "iree/hal/api.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"

namespace iree {
namespace hal {
Expand All @@ -45,10 +45,7 @@ StatusOr<Shape> ParseShape(absl::string_view value) {
shape.size(), shape.data(), &actual_rank);
shape.resize(actual_rank);
} while (iree_status_is_out_of_range(status));
if (!iree_status_is_ok(status)) {
return FromApiStatus(status, IREE_LOC)
<< "Failed to parse shape '" << value << "'";
}
RETURN_IF_ERROR(std::move(status));
return std::move(shape);
}

Expand All @@ -63,9 +60,7 @@ StatusOr<std::string> FormatShape(absl::Span<const iree_hal_dim_t> value) {
&buffer[0], &actual_length);
buffer.resize(actual_length);
} while (iree_status_is_out_of_range(status));
if (!iree_status_is_ok(status)) {
return FromApiStatus(status, IREE_LOC);
}
RETURN_IF_ERROR(std::move(status));
return std::move(buffer);
}

Expand All @@ -75,10 +70,8 @@ StatusOr<iree_hal_element_type_t> ParseElementType(absl::string_view value) {
iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE;
iree_status_t status = iree_hal_parse_element_type(
iree_string_view_t{value.data(), value.size()}, &element_type);
if (!iree_status_is_ok(status)) {
return FromApiStatus(status, IREE_LOC)
<< "Failed to parse element type '" << value << "'";
}
RETURN_IF_ERROR(std::move(status))
<< "Failed to parse element type '" << value << "'";
return element_type;
}

Expand All @@ -93,9 +86,7 @@ StatusOr<std::string> FormatElementType(iree_hal_element_type_t value) {
&actual_length);
buffer.resize(actual_length);
} while (iree_status_is_out_of_range(status));
if (!iree_status_is_ok(status)) {
return FromApiStatus(status, IREE_LOC);
}
RETURN_IF_ERROR(std::move(status));
return std::move(buffer);
}

Expand All @@ -111,10 +102,8 @@ Status ParseElement(absl::string_view value,
iree_string_view_t{value.data(), value.size()}, element_type,
iree_byte_span_t{reinterpret_cast<uint8_t*>(buffer.data()),
buffer.size() * sizeof(T)});
if (!iree_status_is_ok(status)) {
return FromApiStatus(status, IREE_LOC)
<< "Failed to parse element '" << value << "'";
}
RETURN_IF_ERROR(std::move(status))
<< "Failed to parse element '" << value << "'";
return OkStatus();
}

Expand All @@ -132,10 +121,8 @@ StatusOr<std::string> FormatElement(T value,
element_type, result.size() + 1, &result[0], &actual_length);
result.resize(actual_length);
} while (iree_status_is_out_of_range(status));
if (!iree_status_is_ok(status)) {
return FromApiStatus(status, IREE_LOC)
<< "Failed to format buffer element '" << value << "'";
}
RETURN_IF_ERROR(std::move(status))
<< "Failed to format buffer element '" << value << "'";
return std::move(result);
}

Expand All @@ -148,14 +135,11 @@ template <typename T>
Status ParseBufferElements(absl::string_view value,
iree_hal_element_type_t element_type,
absl::Span<T> buffer) {
iree_status_t status = iree_hal_parse_buffer_elements(
RETURN_IF_ERROR(iree_hal_parse_buffer_elements(
iree_string_view_t{value.data(), value.size()}, element_type,
iree_byte_span_t{reinterpret_cast<uint8_t*>(buffer.data()),
buffer.size() * sizeof(T)});
if (!iree_status_is_ok(status)) {
return FromApiStatus(status, IREE_LOC)
<< "Failed to parse buffer elements '" << value << "'";
}
buffer.size() * sizeof(T)}))
<< "Failed to parse buffer elements '" << value << "'";
return OkStatus();
}

Expand All @@ -181,9 +165,7 @@ StatusOr<std::string> FormatBufferElements(absl::Span<const T> data,
result.size() + 1, &result[0], &actual_length);
result.resize(actual_length);
} while (iree_status_is_out_of_range(status));
if (!iree_status_is_ok(status)) {
return FromApiStatus(status, IREE_LOC);
}
RETURN_IF_ERROR(std::move(status));
return std::move(result);
}

Expand Down Expand Up @@ -411,9 +393,7 @@ struct Allocator final
Allocator allocator;
iree_status_t status = iree_hal_allocator_create_host_local(
iree_allocator_system(), &allocator);
if (!iree_status_is_ok(status)) {
return FromApiStatus(status, IREE_LOC);
}
RETURN_IF_ERROR(std::move(status));
return std::move(allocator);
}
};
Expand All @@ -436,9 +416,7 @@ struct Buffer final : public Handle<iree_hal_buffer_t, iree_hal_buffer_retain,
std::vector<T> result(total_byte_length / sizeof(T));
iree_status_t status =
iree_hal_buffer_read_data(get(), 0, result.data(), total_byte_length);
if (!iree_status_is_ok(status)) {
return FromApiStatus(status, IREE_LOC);
}
RETURN_IF_ERROR(std::move(status));
return std::move(result);
}
};
Expand All @@ -457,9 +435,7 @@ struct BufferView final
iree_status_t status = iree_hal_buffer_view_create(
buffer, shape.data(), shape.size(), element_type,
iree_allocator_system(), &buffer_view);
if (!iree_status_is_ok(status)) {
return FromApiStatus(status, IREE_LOC);
}
RETURN_IF_ERROR(std::move(status));
return std::move(buffer_view);
}

Expand Down Expand Up @@ -510,9 +486,7 @@ struct BufferView final
iree_status_t status = iree_hal_buffer_view_parse(
iree_string_view_t{value.data(), value.size()}, allocator,
iree_allocator_system(), &buffer_view);
if (!iree_status_is_ok(status)) {
return FromApiStatus(status, IREE_LOC);
}
RETURN_IF_ERROR(std::move(status));
return std::move(buffer_view);
}

Expand All @@ -532,9 +506,7 @@ struct BufferView final
&actual_length);
result.resize(actual_length);
} while (iree_status_is_out_of_range(status));
if (!iree_status_is_ok(status)) {
return FromApiStatus(status, IREE_LOC);
}
RETURN_IF_ERROR(std::move(status));
return std::move(result);
}
};
Expand Down
Loading

0 comments on commit 5527e48

Please sign in to comment.