From 43295bb5378453d2ec4d9272cb44c6f50b4faa1f Mon Sep 17 00:00:00 2001 From: Matous Kozak <55735845+matouskozak@users.noreply.github.com> Date: Tue, 8 Oct 2024 07:24:44 +0200 Subject: [PATCH] [mono] Fix vector class retrieval and type checks for binary operand APIs (#107388) - change the function to be split by the OP code rather than the type of the operands - add type checks to the callsite to ensure that the operands are of the correct type --- src/mono/mono/mini/simd-intrinsics.c | 209 ++++++++++++++------------- 1 file changed, 109 insertions(+), 100 deletions(-) diff --git a/src/mono/mono/mini/simd-intrinsics.c b/src/mono/mono/mini/simd-intrinsics.c index 44e7072def713..7a706fa49d092 100644 --- a/src/mono/mono/mini/simd-intrinsics.c +++ b/src/mono/mono/mini/simd-intrinsics.c @@ -339,111 +339,73 @@ emit_simd_ins_for_binary_op (MonoCompile *cfg, MonoClass *klass, MonoMethodSigna int instc0 = -1; int op = OP_XBINOP; - if (id == SN_BitwiseAnd || id == SN_BitwiseOr || id == SN_Xor || - id == SN_op_BitwiseAnd || id == SN_op_BitwiseOr || id == SN_op_ExclusiveOr) { - op = OP_XBINOP_FORCEINT; - - switch (id) { + switch (id) { + case SN_Add: + case SN_op_Addition: { + if (type_enum_is_float (arg_type)) { + instc0 = OP_FADD; + } else { + instc0 = OP_IADD; + } + break; + } case SN_BitwiseAnd: - case SN_op_BitwiseAnd: + case SN_op_BitwiseAnd: { + op = OP_XBINOP_FORCEINT; instc0 = XBINOP_FORCEINT_AND; break; + } case SN_BitwiseOr: - case SN_op_BitwiseOr: + case SN_op_BitwiseOr: { + op = OP_XBINOP_FORCEINT; instc0 = XBINOP_FORCEINT_OR; break; - case SN_op_ExclusiveOr: - case SN_Xor: - instc0 = XBINOP_FORCEINT_XOR; - break; } - } else { - if (type_enum_is_float (arg_type)) { - switch (id) { - case SN_Add: - case SN_op_Addition: - instc0 = OP_FADD; - break; - case SN_Divide: - case SN_op_Division: { - const char *class_name = m_class_get_name (klass); - if (strcmp ("Quaternion", class_name) && strcmp ("Plane", class_name)) { - if (!type_is_simd_vector (fsig->params [1])) - return handle_mul_div_by_scalar (cfg, klass, arg_type, args [1]->dreg, args [0]->dreg, OP_FDIV); - else if (type_is_simd_vector (fsig->params [0]) && type_is_simd_vector (fsig->params [1])) { - instc0 = OP_FDIV; - break; - } else { - return NULL; - } - } + case SN_Divide: + case SN_op_Division: { + if (type_enum_is_float (arg_type)) { instc0 = OP_FDIV; - break; - } -#ifdef TARGET_ARM64 - case SN_Max: -#endif - case SN_MaxNative: - instc0 = OP_FMAX; - break; -#ifdef TARGET_ARM64 - case SN_Min: -#endif - case SN_MinNative: - instc0 = OP_FMIN; - break; - case SN_Multiply: - case SN_op_Multiply: { - const char *class_name = m_class_get_name (klass); - if (strcmp ("Quaternion", class_name) && strcmp ("Plane", class_name)) { - if (!type_is_simd_vector (fsig->params [1])) - return handle_mul_div_by_scalar (cfg, klass, arg_type, args [1]->dreg, args [0]->dreg, OP_FMUL); - else if (!type_is_simd_vector (fsig->params [0])) - return handle_mul_div_by_scalar (cfg, klass, arg_type, args [0]->dreg, args [1]->dreg, OP_FMUL); - else if (type_is_simd_vector (fsig->params [0]) && type_is_simd_vector (fsig->params [1])) { - instc0 = OP_FMUL; - break; - } else { - return NULL; - } + if (MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [1])) { // vector / scalar + return handle_mul_div_by_scalar (cfg, klass, arg_type, args [1]->dreg, args [0]->dreg, instc0); } - instc0 = OP_FMUL; - break; - } - case SN_Subtract: - case SN_op_Subtraction: - instc0 = OP_FSUB; - break; - default: - g_assert_not_reached (); - } - } else { - switch (id) { - case SN_Add: - case SN_op_Addition: - instc0 = OP_IADD; - break; - case SN_Divide: - case SN_op_Division: + } else { return NULL; - case SN_Max: - case SN_MaxNative: + } + break; + } + case SN_Max: + case SN_MaxNative: { + if (type_enum_is_float (arg_type)) { + instc0 = OP_FMAX; + } else { instc0 = type_enum_is_unsigned (arg_type) ? OP_IMAX_UN : OP_IMAX; + #ifdef TARGET_AMD64 if (!COMPILE_LLVM (cfg) && instc0 == OP_IMAX_UN) return NULL; #endif - break; - case SN_Min: - case SN_MinNative: + } + break; + } + case SN_Min: + case SN_MinNative: { + if (type_enum_is_float (arg_type)) { + instc0 = OP_FMIN; + } else { instc0 = type_enum_is_unsigned (arg_type) ? OP_IMIN_UN : OP_IMIN; + #ifdef TARGET_AMD64 if (!COMPILE_LLVM (cfg) && instc0 == OP_IMIN_UN) return NULL; #endif - break; - case SN_Multiply: - case SN_op_Multiply: { + } + break; + } + case SN_Multiply: + case SN_op_Multiply: { + if (type_enum_is_float (arg_type)) { + instc0 = OP_FMUL; + } else { #ifdef TARGET_ARM64 if (!COMPILE_LLVM (cfg) && (arg_type == MONO_TYPE_I8 || arg_type == MONO_TYPE_U8 || arg_type == MONO_TYPE_I || arg_type == MONO_TYPE_U)) return NULL; @@ -452,22 +414,34 @@ emit_simd_ins_for_binary_op (MonoCompile *cfg, MonoClass *klass, MonoMethodSigna if (!COMPILE_LLVM (cfg)) return NULL; #endif - if (fsig->params [1]->type != MONO_TYPE_GENERICINST) - return handle_mul_div_by_scalar (cfg, klass, arg_type, args [1]->dreg, args [0]->dreg, OP_IMUL); - else if (fsig->params [0]->type != MONO_TYPE_GENERICINST) - return handle_mul_div_by_scalar (cfg, klass, arg_type, args [0]->dreg, args [1]->dreg, OP_IMUL); instc0 = OP_IMUL; - break; } - case SN_Subtract: - case SN_op_Subtraction: + if (MONO_TYPE_IS_VECTOR_PRIMITIVE(fsig->params [1])) { // vector * scalar + return handle_mul_div_by_scalar (cfg, klass, arg_type, args [1]->dreg, args [0]->dreg, instc0); + } else if (MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [0])) { // scalar * vector + return handle_mul_div_by_scalar (cfg, klass, arg_type, args [0]->dreg, args [1]->dreg, instc0); + } + break; + } + case SN_Subtract: + case SN_op_Subtraction: { + if (type_enum_is_float (arg_type)) { + instc0 = OP_FSUB; + } else { instc0 = OP_ISUB; - break; - default: - g_assert_not_reached (); } + break; + } + case SN_Xor: + case SN_op_ExclusiveOr: { + op = OP_XBINOP_FORCEINT; + instc0 = XBINOP_FORCEINT_XOR; + break; } + default: + g_assert_not_reached (); } + return emit_simd_ins_for_sig (cfg, klass, op, instc0, arg_type, fsig, args); } @@ -1992,7 +1966,7 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi return NULL; #endif - MonoClass* klass = fsig->param_count > 0 ? args[0]->klass : cmethod->klass; + MonoClass *klass = fsig->param_count > 0 ? args [0]->klass : cmethod->klass; MonoTypeEnum arg0_type = fsig->param_count > 0 ? get_underlying_type (fsig->params [0]) : MONO_TYPE_VOID; if (cfg->verbose_level > 1) { @@ -2057,21 +2031,56 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi case SN_Add: case SN_BitwiseAnd: case SN_BitwiseOr: - case SN_Divide: case SN_Max: case SN_MaxNative: case SN_Min: case SN_MinNative: - case SN_Multiply: case SN_Subtract: - case SN_Xor: - if (!is_element_type_primitive (fsig->params [0])) + case SN_Xor: { + if (fsig->param_count != 2) return NULL; + + if (!is_element_type_primitive (fsig->params [0]) || !is_element_type_primitive (fsig->params [1])) + return NULL; + #ifndef TARGET_ARM64 if (((id == SN_Max) || (id == SN_Min)) && type_enum_is_float(arg0_type)) return NULL; #endif + return emit_simd_ins_for_binary_op (cfg, klass, fsig, args, arg0_type, id); + } + case SN_Divide: { + if (fsig->param_count != 2) + return NULL; + + if (!is_element_type_primitive (fsig->params [0]) || + !(MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [1]) || is_element_type_primitive (fsig->params [1]))) + return NULL; + + return emit_simd_ins_for_binary_op (cfg, klass, fsig, args, arg0_type, id); + } + case SN_Multiply: { + if (fsig->param_count != 2) + return NULL; + + MonoTypeEnum vector_inner_type = arg0_type; + if (MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [0])) { + if (MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [1]) || !is_element_type_primitive (fsig->params [1])) + return NULL; + // By default, we expect the first argument to be the vector type + // however, for Multiply, the first argument can be scalar. In this case, we need to + // get the vector type from the second argument. + klass = args [1]->klass; + vector_inner_type = get_underlying_type (fsig->params [1]); + } else if (MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [1])) { + if (!is_element_type_primitive (fsig->params [0])) + return NULL; + } else if (!(is_element_type_primitive (fsig->params [0]) && is_element_type_primitive (fsig->params [1]))) + return NULL; + + return emit_simd_ins_for_binary_op (cfg, klass, fsig, args, vector_inner_type, id); + } case SN_AndNot: { if (!is_element_type_primitive (fsig->params [0])) return NULL;