Skip to content

Commit

Permalink
PushErrorScope and PopErrorScope
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 11, 2024
1 parent d4a963d commit 8b61532
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 10 deletions.
20 changes: 20 additions & 0 deletions onnxruntime/core/providers/webgpu/compute_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,25 @@ ComputeContext::ComputeContext(OpKernelContext& kernel_context)
kernel_context_{kernel_context} {
}

// Opens an error scope on the WebGPU device that captures validation errors.
// The captured result is retrieved by a later PopErrorScope() call.
void ComputeContext::PushErrorScope() {
  constexpr auto kCapturedErrors = wgpu::ErrorFilter::Validation;
  webgpu_context_.Device().PushErrorScope(kCapturedErrors);
}

// Pops the most recently pushed error scope from the WebGPU device and waits
// for its result.
//
// Returns:
//   - the error from webgpu_context_.Wait() if waiting on the future fails;
//   - a FAIL status if the scope could not be popped or if the scope captured
//     an error (the device-provided message is appended);
//   - an OK status otherwise.
Status ComputeContext::PopErrorScope() {
  Status status{};

  ORT_RETURN_IF_ERROR(webgpu_context_.Wait(
      webgpu_context_.Device().PopErrorScope(
          wgpu::CallbackMode::WaitAnyOnly,
          [](wgpu::PopErrorScopeStatus pop_status, wgpu::ErrorType error_type, char const* message, Status* status) {
            // Report failures through the out-parameter rather than throwing:
            // this callback is invoked from inside the WebGPU implementation,
            // and an exception must not unwind through those native frames.
            if (pop_status != wgpu::PopErrorScopeStatus::Success) {
              *status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Instance dropped.");
              return;
            }
            if (error_type == wgpu::ErrorType::NoError) {
              return;
            }
            // Guard against a null message so the status text is always well-formed.
            *status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "WebGPU validation failed. ",
                                      message != nullptr ? message : "");
          },
          &status)));
  return status;
}

} // namespace webgpu
} // namespace onnxruntime
14 changes: 14 additions & 0 deletions onnxruntime/core/providers/webgpu/compute_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,20 @@ class ComputeContext {
return webgpu_context_.Run(*this, program);
}

//
// Push error scope.
//
// Opens an error scope on the WebGPU device that captures validation errors.
// Intended to be paired with a later call to PopErrorScope(), which retrieves
// whatever the scope captured.
//
// This is useful only when "skip_validation" is not set.
//
void PushErrorScope();

//
// Pop error scope.
//
// Pops the error scope from the WebGPU device and waits for its result.
// Returns a FAIL status whose message begins with "WebGPU validation failed. "
// when the scope captured an error, the error from the wait if waiting fails,
// and an OK status when no error was captured.
//
// This is useful only when "skip_validation" is not set.
//
Status PopErrorScope();

protected:
WebGpuContext& webgpu_context_;
OpKernelContext& kernel_context_;
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ std::vector<const char*> GetEnabledDeviceToggles() {
// Enable / disable other toggles that may affect the performance.
// Other toggles that may be useful: "dump_shaders", "disable_symbol_renaming"
constexpr const char* toggles[] = {
#ifdef NDEBUG
// TODO: when validation is skipped, the process may crash.
// Enabling this toggle needs a careful decision;
// revisit this flag before release.
"skip_validation",
#endif
"disable_robustness",
"disable_workgroup_init",
"d3d_disable_ieee_strictness",
Expand Down
18 changes: 8 additions & 10 deletions onnxruntime/core/providers/webgpu/webgpu_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,14 @@ class WebGpuKernel : public OpKernel {

// Runs this kernel's ComputeInternal() with a ComputeContext that wraps the
// given OpKernelContext.
//
// In debug builds (NDEBUG not defined) the call is bracketed by a WebGPU error
// scope: if popping the scope yields an error, that error is returned in
// place of ComputeInternal()'s own status. In release builds the status from
// ComputeInternal() is returned unchanged.
Status Compute(OpKernelContext* p_op_kernel_context) const override {
  ComputeContext context{*p_op_kernel_context};
#ifndef NDEBUG
  context.PushErrorScope();
#endif
  // ComputeInternal() must run exactly once, and inside the error scope, so
  // that any validation error it triggers is attributed to this kernel.
  Status s = ComputeInternal(context);
#ifndef NDEBUG
  ORT_RETURN_IF_ERROR(context.PopErrorScope());
#endif

  return s;
}

Expand Down

0 comments on commit 8b61532

Please sign in to comment.