diff --git a/Sources/kong.c b/Sources/kong.c index 1590aa3..5ee896d 100644 --- a/Sources/kong.c +++ b/Sources/kong.c @@ -102,6 +102,64 @@ void resolve_member_type(statement *parent_block, type_ref parent_type, expressi e->type = e->member.right->type; } +static bool types_compatible(type_id left, type_id right){ + if (left == right) { + return true; + } + + if ((left == int_id && right == float_id) || (left == float_id && right == int_id)) { + return true; + } + if ((left == int2_id && right == float2_id) || (left == float2_id && right == int2_id)) { + return true; + } + if ((left == int3_id && right == float3_id) || (left == float3_id && right == int3_id)) { + return true; + } + if ((left == int4_id && right == float4_id) || (left == float4_id && right == int4_id)) { + return true; + } + + return false; +} + +static type_ref upgrade_type(type_ref left_type, type_ref right_type) { + type_id left = left_type.type; + type_id right = right_type.type; + + if (left == right) { + return left_type; + } + + if (left == int_id && right == float_id) { + return right_type; + } + if (left == float_id && right == int_id) { + return left_type; + } + if (left == int2_id && right == float2_id) { + return right_type; + } + if (left == float2_id && right == int2_id) { + return left_type; + } + if (left == int3_id && right == float3_id) { + return right_type; + } + if (left == float3_id && right == int3_id) { + return left_type; + } + if (left == int4_id && right == float4_id) { + return right_type; + } + if (left == float4_id && right == int4_id) { + return left_type; + } + + kong_log(LOG_LEVEL_WARNING, "Suspicious type upgrade"); + return left_type; +} + void resolve_types_in_expression(statement *parent, expression *e) { switch (e->kind) { case EXPRESSION_BINARY: { @@ -123,12 +181,15 @@ void resolve_types_in_expression(statement *parent, expression *e) { case OPERATOR_MULTIPLY_ASSIGN: { type_id left_type = e->binary.left->type.type; type_id right_type = e->binary.right->type.type; - if (left_type == right_type || (left_type == float4x4_id && right_type == float4_id)) { + if (left_type == float4x4_id && right_type == float4_id) { e->type = e->binary.right->type; } else if (right_type == float_id && (left_type == float2_id || left_type == float3_id || left_type == float4_id)) { e->type = e->binary.left->type; } + else if (types_compatible(left_type, right_type)) { + e->type = e->binary.right->type; + } else { debug_context context = {0}; error(context, "Type mismatch %s vs %s", get_name(get_type(left_type)->name), get_name(get_type(right_type)->name)); @@ -138,14 +199,23 @@ void resolve_types_in_expression(statement *parent, expression *e) { case OPERATOR_MINUS: case OPERATOR_PLUS: case OPERATOR_DIVIDE: - case OPERATOR_MOD: + case OPERATOR_MOD: { + type_id left_type = e->binary.left->type.type; + type_id right_type = e->binary.right->type.type; + if (!types_compatible(left_type, right_type)) { + debug_context context = {0}; + error(context, "Type mismatch %s vs %s", get_name(get_type(left_type)->name), get_name(get_type(right_type)->name)); + } + e->type = upgrade_type(e->binary.left->type, e->binary.right->type); + break; + } case OPERATOR_ASSIGN: case OPERATOR_DIVIDE_ASSIGN: case OPERATOR_MINUS_ASSIGN: case OPERATOR_PLUS_ASSIGN: { type_id left_type = e->binary.left->type.type; type_id right_type = e->binary.right->type.type; - if (left_type != right_type) { + if (!types_compatible(left_type, right_type)) { debug_context context = {0}; error(context, "Type mismatch %s vs %s", get_name(get_type(left_type)->name), get_name(get_type(right_type)->name)); } diff --git a/Sources/types.c b/Sources/types.c index 520f107..2f46964 100644 --- a/Sources/types.c +++ b/Sources/types.c @@ -16,6 +16,10 @@ type_id float2_id; type_id float3_id; type_id float4_id; type_id float4x4_id; +type_id int_id; +type_id int2_id; +type_id int3_id; +type_id int4_id; type_id bool_id; type_id function_type_id; type_id tex2d_type_id; @@ -138,6 +142,96 @@ static void vec4_found_vec4(char *permutation) { ++t->members.size; } +static void int2_found_int(char *permutation) { + type *t = get_type(int2_id); + debug_context context = {0}; + check(t->members.size < MAX_MEMBERS, context, "Out of members"); + t->members.m[t->members.size].name = add_name(permutation); + t->members.m[t->members.size].type.type = int_id; + t->members.m[t->members.size].type.array_size = 0; + ++t->members.size; +} + +static void int2_found_int2(char *permutation) { + type *t = get_type(int2_id); + debug_context context = {0}; + check(t->members.size < MAX_MEMBERS, context, "Out of members"); + t->members.m[t->members.size].name = add_name(permutation); + t->members.m[t->members.size].type.type = int2_id; + t->members.m[t->members.size].type.array_size = 0; + ++t->members.size; +} + +static void int3_found_int(char *permutation) { + type *t = get_type(int3_id); + debug_context context = {0}; + check(t->members.size < MAX_MEMBERS, context, "Out of members"); + t->members.m[t->members.size].name = add_name(permutation); + t->members.m[t->members.size].type.type = int_id; + t->members.m[t->members.size].type.array_size = 0; + ++t->members.size; +} + +static void int3_found_int2(char *permutation) { + type *t = get_type(int3_id); + debug_context context = {0}; + check(t->members.size < MAX_MEMBERS, context, "Out of members"); + t->members.m[t->members.size].name = add_name(permutation); + t->members.m[t->members.size].type.type = int2_id; + t->members.m[t->members.size].type.array_size = 0; + ++t->members.size; +} + +static void int3_found_int3(char *permutation) { + type *t = get_type(int3_id); + debug_context context = {0}; + check(t->members.size < MAX_MEMBERS, context, "Out of members"); + t->members.m[t->members.size].name = add_name(permutation); + t->members.m[t->members.size].type.type = int3_id; + t->members.m[t->members.size].type.array_size = 0; + ++t->members.size; +} + +static void int4_found_int(char *permutation) { + type *t = get_type(int4_id); + debug_context context = {0}; + check(t->members.size < MAX_MEMBERS, context, "Out of members"); + t->members.m[t->members.size].name = add_name(permutation); + t->members.m[t->members.size].type.type = int_id; + t->members.m[t->members.size].type.array_size = 0; + ++t->members.size; +} + +static void int4_found_int2(char *permutation) { + type *t = get_type(int4_id); + debug_context context = {0}; + check(t->members.size < MAX_MEMBERS, context, "Out of members"); + t->members.m[t->members.size].name = add_name(permutation); + t->members.m[t->members.size].type.type = int2_id; + t->members.m[t->members.size].type.array_size = 0; + ++t->members.size; +} + +static void int4_found_int3(char *permutation) { + type *t = get_type(int4_id); + debug_context context = {0}; + check(t->members.size < MAX_MEMBERS, context, "Out of members"); + t->members.m[t->members.size].name = add_name(permutation); + t->members.m[t->members.size].type.type = int3_id; + t->members.m[t->members.size].type.array_size = 0; + ++t->members.size; +} + +static void int4_found_int4(char *permutation) { + type *t = get_type(int4_id); + debug_context context = {0}; + check(t->members.size < MAX_MEMBERS, context, "Out of members"); + t->members.m[t->members.size].name = add_name(permutation); + t->members.m[t->members.size].type.type = int4_id; + t->members.m[t->members.size].type.array_size = 0; + ++t->members.size; +} + void init_type_ref(type_ref *t, name_id name) { t->name = name; t->type = NO_TYPE; @@ -165,6 +259,8 @@ void types_init(void) { get_type(bool_id)->built_in = true; float_id = add_type(add_name("float")); get_type(float_id)->built_in = true; + int_id = add_type(add_name("int")); + get_type(int_id)->built_in = true; { float2_id = add_type(add_name("float2")); @@ -205,6 +301,45 @@ void types_init(void) { permute(letters, (int)strlen(letters), 3, vec4_found_vec4); } + { + int2_id = add_type(add_name("int2")); + get_type(int2_id)->built_in = true; + const char *letters = "xy"; + permute(letters, (int)strlen(letters), 1, int2_found_int); + permute(letters, (int)strlen(letters), 2, int2_found_int2); + letters = "rg"; + permute(letters, (int)strlen(letters), 1, int2_found_int); + permute(letters, (int)strlen(letters), 2, int2_found_int2); + } + + { + int3_id = add_type(add_name("int3")); + get_type(int3_id)->built_in = true; + const char *letters = "xyz"; + permute(letters, (int)strlen(letters), 1, int3_found_int); + permute(letters, (int)strlen(letters), 2, int3_found_int2); + permute(letters, (int)strlen(letters), 3, int3_found_int3); + letters = "rgb"; + permute(letters, (int)strlen(letters), 1, int3_found_int); + permute(letters, (int)strlen(letters), 2, int3_found_int2); + permute(letters, (int)strlen(letters), 3, int3_found_int3); + } + + { + int4_id = add_type(add_name("int4")); + get_type(int4_id)->built_in = true; + const char *letters = "xyzw"; + permute(letters, (int)strlen(letters), 1, int4_found_int); + permute(letters, (int)strlen(letters), 2, int4_found_int2); + permute(letters, (int)strlen(letters), 3, int4_found_int3); + permute(letters, (int)strlen(letters), 3, int4_found_int4); + letters = "rgba"; + permute(letters, (int)strlen(letters), 1, int4_found_int); + permute(letters, (int)strlen(letters), 2, int4_found_int2); + permute(letters, (int)strlen(letters), 3, int4_found_int3); + permute(letters, (int)strlen(letters), 3, int4_found_int4); + } + { float4x4_id = add_type(add_name("float4x4")); get_type(float4x4_id)->built_in = true; diff --git a/Sources/types.h b/Sources/types.h index a4260ca..b65a20d 100644 --- a/Sources/types.h +++ b/Sources/types.h @@ -83,6 +83,10 @@ extern type_id float2_id; extern type_id float3_id; extern type_id float4_id; extern type_id float4x4_id; +extern type_id int_id; +extern type_id int2_id; +extern type_id int3_id; +extern type_id int4_id; extern type_id bool_id; extern type_id tex2d_type_id; extern type_id texcube_type_id;