diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index f320bf61f0ddf..bccbdc8b2dbd0 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -237,10 +237,12 @@ static bool IsTypeSupported(const NodeArg* node_arg) { case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2FNUZ: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32: @@ -277,6 +279,9 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ: mgx_type = migraphx_shape_fp8e5m2fnuz_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: + mgx_type = migraphx_shape_int8_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: mgx_type = migraphx_shape_int8_type; break; @@ -289,6 +294,9 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: mgx_type = migraphx_shape_int64_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4: + mgx_type = migraphx_shape_uint8_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: mgx_type = migraphx_shape_uint8_type; break; @@ -633,7 +641,7 @@ static bool IsNodeSupported(const std::set& op_set, } // whether an operator implemented in migraphx - if (op_set.count(optype) == 0) { + if (op_set.count(optype) == 0 or op_set.count(optype+"fusion")) { return false; }