Skip to content

Commit

Permalink
Add support for int4 inputs
Browse files Browse the repository at this point in the history
Map things to int8 right now as we don't explicitly set an int4 input type and pack/unpack int4 operands
  • Loading branch information
TedThemistokleous committed Dec 4, 2024
1 parent 531f35c commit 9b44fcc
Showing 1 changed file with 9 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,12 @@ static bool IsTypeSupported(const NodeArg* node_arg) {
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2FNUZ:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32:
Expand Down Expand Up @@ -277,6 +279,9 @@ static bool getMIGraphXType(ONNXTensorElementDataType type,
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ:
mgx_type = migraphx_shape_fp8e5m2fnuz_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4:
mgx_type = migraphx_shape_int8_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
mgx_type = migraphx_shape_int8_type;
break;
Expand All @@ -289,6 +294,9 @@ static bool getMIGraphXType(ONNXTensorElementDataType type,
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
mgx_type = migraphx_shape_int64_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4:
mgx_type = migraphx_shape_uint8_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
mgx_type = migraphx_shape_uint8_type;
break;
Expand Down Expand Up @@ -633,7 +641,7 @@ static bool IsNodeSupported(const std::set<std::string>& op_set,
}

// whether an operator implemented in migraphx
if (op_set.count(optype) == 0) {
if (op_set.count(optype) == 0 or op_set.count(optype+"fusion")) {
return false;
}

Expand Down

0 comments on commit 9b44fcc

Please sign in to comment.