diff --git a/onnxruntime/core/providers/js/operators/unsqueeze.h b/onnxruntime/core/providers/js/operators/unsqueeze.h index 7cbfdc38b742d..f15a3008895aa 100644 --- a/onnxruntime/core/providers/js/operators/unsqueeze.h +++ b/onnxruntime/core/providers/js/operators/unsqueeze.h @@ -26,8 +26,9 @@ class Unsqueeze final : public JsKernel, public UnsqueezeBase { if (num_inputs == 2) { // axes is an input const Tensor* axes_tensor = context->Input(1); ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); - ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, - "An axes tensor must be a vector tensor."); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 0 || + axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a scalar or a vector tensor."); auto nDims = static_cast(axes_tensor->Shape()[0]); const auto* data = axes_tensor->Data(); axes.assign(data, data + nDims);