Skip to content

Commit

Permalink
[mono] Fix vector class retrieval and type checks for binary operand …
Browse files Browse the repository at this point in the history
…APIs (dotnet#107388)

- change the function to be split by the OP code rather than the type of the operands
- add type checks to the callsite to ensure that the operands are of the correct type
  • Loading branch information
matouskozak authored Oct 8, 2024
1 parent 733ef6a commit 43295bb
Showing 1 changed file with 109 additions and 100 deletions.
209 changes: 109 additions & 100 deletions src/mono/mono/mini/simd-intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -339,111 +339,73 @@ emit_simd_ins_for_binary_op (MonoCompile *cfg, MonoClass *klass, MonoMethodSigna
int instc0 = -1;
int op = OP_XBINOP;

if (id == SN_BitwiseAnd || id == SN_BitwiseOr || id == SN_Xor ||
id == SN_op_BitwiseAnd || id == SN_op_BitwiseOr || id == SN_op_ExclusiveOr) {
op = OP_XBINOP_FORCEINT;

switch (id) {
switch (id) {
case SN_Add:
case SN_op_Addition: {
if (type_enum_is_float (arg_type)) {
instc0 = OP_FADD;
} else {
instc0 = OP_IADD;
}
break;
}
case SN_BitwiseAnd:
case SN_op_BitwiseAnd:
case SN_op_BitwiseAnd: {
op = OP_XBINOP_FORCEINT;
instc0 = XBINOP_FORCEINT_AND;
break;
}
case SN_BitwiseOr:
case SN_op_BitwiseOr:
case SN_op_BitwiseOr: {
op = OP_XBINOP_FORCEINT;
instc0 = XBINOP_FORCEINT_OR;
break;
case SN_op_ExclusiveOr:
case SN_Xor:
instc0 = XBINOP_FORCEINT_XOR;
break;
}
} else {
if (type_enum_is_float (arg_type)) {
switch (id) {
case SN_Add:
case SN_op_Addition:
instc0 = OP_FADD;
break;
case SN_Divide:
case SN_op_Division: {
const char *class_name = m_class_get_name (klass);
if (strcmp ("Quaternion", class_name) && strcmp ("Plane", class_name)) {
if (!type_is_simd_vector (fsig->params [1]))
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [1]->dreg, args [0]->dreg, OP_FDIV);
else if (type_is_simd_vector (fsig->params [0]) && type_is_simd_vector (fsig->params [1])) {
instc0 = OP_FDIV;
break;
} else {
return NULL;
}
}
case SN_Divide:
case SN_op_Division: {
if (type_enum_is_float (arg_type)) {
instc0 = OP_FDIV;
break;
}
#ifdef TARGET_ARM64
case SN_Max:
#endif
case SN_MaxNative:
instc0 = OP_FMAX;
break;
#ifdef TARGET_ARM64
case SN_Min:
#endif
case SN_MinNative:
instc0 = OP_FMIN;
break;
case SN_Multiply:
case SN_op_Multiply: {
const char *class_name = m_class_get_name (klass);
if (strcmp ("Quaternion", class_name) && strcmp ("Plane", class_name)) {
if (!type_is_simd_vector (fsig->params [1]))
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [1]->dreg, args [0]->dreg, OP_FMUL);
else if (!type_is_simd_vector (fsig->params [0]))
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [0]->dreg, args [1]->dreg, OP_FMUL);
else if (type_is_simd_vector (fsig->params [0]) && type_is_simd_vector (fsig->params [1])) {
instc0 = OP_FMUL;
break;
} else {
return NULL;
}
if (MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [1])) { // vector / scalar
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [1]->dreg, args [0]->dreg, instc0);
}
instc0 = OP_FMUL;
break;
}
case SN_Subtract:
case SN_op_Subtraction:
instc0 = OP_FSUB;
break;
default:
g_assert_not_reached ();
}
} else {
switch (id) {
case SN_Add:
case SN_op_Addition:
instc0 = OP_IADD;
break;
case SN_Divide:
case SN_op_Division:
} else {
return NULL;
case SN_Max:
case SN_MaxNative:
}
break;
}
case SN_Max:
case SN_MaxNative: {
if (type_enum_is_float (arg_type)) {
instc0 = OP_FMAX;
} else {
instc0 = type_enum_is_unsigned (arg_type) ? OP_IMAX_UN : OP_IMAX;

#ifdef TARGET_AMD64
if (!COMPILE_LLVM (cfg) && instc0 == OP_IMAX_UN)
return NULL;
#endif
break;
case SN_Min:
case SN_MinNative:
}
break;
}
case SN_Min:
case SN_MinNative: {
if (type_enum_is_float (arg_type)) {
instc0 = OP_FMIN;
} else {
instc0 = type_enum_is_unsigned (arg_type) ? OP_IMIN_UN : OP_IMIN;

#ifdef TARGET_AMD64
if (!COMPILE_LLVM (cfg) && instc0 == OP_IMIN_UN)
return NULL;
#endif
break;
case SN_Multiply:
case SN_op_Multiply: {
}
break;
}
case SN_Multiply:
case SN_op_Multiply: {
if (type_enum_is_float (arg_type)) {
instc0 = OP_FMUL;
} else {
#ifdef TARGET_ARM64
if (!COMPILE_LLVM (cfg) && (arg_type == MONO_TYPE_I8 || arg_type == MONO_TYPE_U8 || arg_type == MONO_TYPE_I || arg_type == MONO_TYPE_U))
return NULL;
Expand All @@ -452,22 +414,34 @@ emit_simd_ins_for_binary_op (MonoCompile *cfg, MonoClass *klass, MonoMethodSigna
if (!COMPILE_LLVM (cfg))
return NULL;
#endif
if (fsig->params [1]->type != MONO_TYPE_GENERICINST)
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [1]->dreg, args [0]->dreg, OP_IMUL);
else if (fsig->params [0]->type != MONO_TYPE_GENERICINST)
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [0]->dreg, args [1]->dreg, OP_IMUL);
instc0 = OP_IMUL;
break;
}
case SN_Subtract:
case SN_op_Subtraction:
if (MONO_TYPE_IS_VECTOR_PRIMITIVE(fsig->params [1])) { // vector * scalar
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [1]->dreg, args [0]->dreg, instc0);
} else if (MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [0])) { // scalar * vector
return handle_mul_div_by_scalar (cfg, klass, arg_type, args [0]->dreg, args [1]->dreg, instc0);
}
break;
}
case SN_Subtract:
case SN_op_Subtraction: {
if (type_enum_is_float (arg_type)) {
instc0 = OP_FSUB;
} else {
instc0 = OP_ISUB;
break;
default:
g_assert_not_reached ();
}
break;
}
case SN_Xor:
case SN_op_ExclusiveOr: {
op = OP_XBINOP_FORCEINT;
instc0 = XBINOP_FORCEINT_XOR;
break;
}
default:
g_assert_not_reached ();
}

return emit_simd_ins_for_sig (cfg, klass, op, instc0, arg_type, fsig, args);
}

Expand Down Expand Up @@ -1992,7 +1966,7 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
return NULL;
#endif

MonoClass* klass = fsig->param_count > 0 ? args[0]->klass : cmethod->klass;
MonoClass *klass = fsig->param_count > 0 ? args [0]->klass : cmethod->klass;
MonoTypeEnum arg0_type = fsig->param_count > 0 ? get_underlying_type (fsig->params [0]) : MONO_TYPE_VOID;

if (cfg->verbose_level > 1) {
Expand Down Expand Up @@ -2057,21 +2031,56 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
case SN_Add:
case SN_BitwiseAnd:
case SN_BitwiseOr:
case SN_Divide:
case SN_Max:
case SN_MaxNative:
case SN_Min:
case SN_MinNative:
case SN_Multiply:
case SN_Subtract:
case SN_Xor:
if (!is_element_type_primitive (fsig->params [0]))
case SN_Xor: {
if (fsig->param_count != 2)
return NULL;

if (!is_element_type_primitive (fsig->params [0]) || !is_element_type_primitive (fsig->params [1]))
return NULL;

#ifndef TARGET_ARM64
if (((id == SN_Max) || (id == SN_Min)) && type_enum_is_float(arg0_type))
return NULL;
#endif

return emit_simd_ins_for_binary_op (cfg, klass, fsig, args, arg0_type, id);
}
case SN_Divide: {
if (fsig->param_count != 2)
return NULL;

if (!is_element_type_primitive (fsig->params [0]) ||
!(MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [1]) || is_element_type_primitive (fsig->params [1])))
return NULL;

return emit_simd_ins_for_binary_op (cfg, klass, fsig, args, arg0_type, id);
}
case SN_Multiply: {
if (fsig->param_count != 2)
return NULL;

MonoTypeEnum vector_inner_type = arg0_type;
if (MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [0])) {
if (MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [1]) || !is_element_type_primitive (fsig->params [1]))
return NULL;
// By default, we expect the first argument to be the vector type
// however, for Multiply, the first argument can be scalar. In this case, we need to
// get the vector type from the second argument.
klass = args [1]->klass;
vector_inner_type = get_underlying_type (fsig->params [1]);
} else if (MONO_TYPE_IS_VECTOR_PRIMITIVE (fsig->params [1])) {
if (!is_element_type_primitive (fsig->params [0]))
return NULL;
} else if (!(is_element_type_primitive (fsig->params [0]) && is_element_type_primitive (fsig->params [1])))
return NULL;

return emit_simd_ins_for_binary_op (cfg, klass, fsig, args, vector_inner_type, id);
}
case SN_AndNot: {
if (!is_element_type_primitive (fsig->params [0]))
return NULL;
Expand Down

0 comments on commit 43295bb

Please sign in to comment.