Skip to content

Commit

Permalink
implemented sort for float and bool
Browse files Browse the repository at this point in the history
  • Loading branch information
benplotke committed Jul 11, 2024
1 parent 51d728a commit c168931
Showing 1 changed file with 91 additions and 10 deletions.
101 changes: 91 additions & 10 deletions crates/compiler/gen_llvm/src/llvm/sort.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::build::BuilderExt;
use crate::llvm::build::Env;
use inkwell::values::{BasicValueEnum, IntValue};
use inkwell::IntPredicate;
use roc_builtins::bitcode::IntWidth;
use inkwell::{IntPredicate, FloatPredicate};
use roc_builtins::bitcode::{IntWidth, FloatWidth};
use roc_mono::layout::{
Builtin, InLayout, LayoutIds, LayoutInterner, LayoutRepr, STLayoutInterner,
};
Expand All @@ -21,8 +21,12 @@ pub fn generic_compare<'a, 'ctx>(
LayoutRepr::Builtin(Builtin::Int(int_width)) => {
int_compare(env, lhs_val, rhs_val, int_width)
}
LayoutRepr::Builtin(Builtin::Float(_)) => todo!(),
LayoutRepr::Builtin(Builtin::Bool) => todo!(),
LayoutRepr::Builtin(Builtin::Float(float_width)) => {
float_cmp(env, lhs_val, rhs_val, float_width)
}
LayoutRepr::Builtin(Builtin::Bool) => {
bool_compare(env, lhs_val, rhs_val)
}
LayoutRepr::Builtin(Builtin::Decimal) => todo!(),
LayoutRepr::Builtin(Builtin::Str) => todo!(),
LayoutRepr::Builtin(Builtin::List(_)) => todo!(),
Expand All @@ -48,22 +52,24 @@ fn int_compare<'ctx>(
// (a > b) + 2 * (a < b);
let lhs_gt_rhs = int_gt(env, lhs_val, rhs_val, builtin);
let lhs_lt_rhs = int_lt(env, lhs_val, rhs_val, builtin);
let two = env.ptr_int().const_int(2, false);
let two = env.context.i8_type().const_int(2, false);
let lhs_lt_rhs_times_two =
env.builder
.new_build_int_mul(lhs_lt_rhs, two, "lhs_lt_rhs_times_two");
env.builder
.new_build_int_sub(lhs_gt_rhs, lhs_lt_rhs_times_two, "int_compare")
.new_build_int_add(lhs_gt_rhs, lhs_lt_rhs_times_two, "int_compare")
}



fn int_lt<'ctx>(
env: &Env<'_, 'ctx, '_>,
lhs_val: BasicValueEnum<'ctx>,
rhs_val: BasicValueEnum<'ctx>,
builtin: IntWidth,
) -> IntValue<'ctx> {
use IntWidth::*;
match builtin {
let lhs_lt_rhs = match builtin {
I128 => env.builder.new_build_int_compare(
IntPredicate::SLT,
lhs_val.into_int_value(),
Expand Down Expand Up @@ -124,7 +130,8 @@ fn int_lt<'ctx>(
rhs_val.into_int_value(),
"lhs_lt_rhs_u8",
),
}
};
env.builder.new_build_int_cast_sign_flag(lhs_lt_rhs, env.context.i8_type(), false, "lhs_lt_rhs_cast")
}

fn int_gt<'ctx>(
Expand All @@ -134,7 +141,7 @@ fn int_gt<'ctx>(
builtin: IntWidth,
) -> IntValue<'ctx> {
use IntWidth::*;
match builtin {
let lhs_gt_rhs = match builtin {
I128 => env.builder.new_build_int_compare(
IntPredicate::SGT,
lhs_val.into_int_value(),
Expand Down Expand Up @@ -195,5 +202,79 @@ fn int_gt<'ctx>(
rhs_val.into_int_value(),
"lhs_gt_rhs_u8",
),
}
};
env.builder.new_build_int_cast_sign_flag(lhs_gt_rhs, env.context.i8_type(), false, "lhs_gt_rhs_cast")
}


// Return 0 for equals, 1 for greater than, and 2 for less than.
// We consider NaNs to be smaller than non-NaNs
// We use the below expression to calculate this
// (a == a) + 2*(b == b) - (a < b) - 2*(a > b) - 3*(a == b)
fn float_cmp<'ctx>(
env: &Env<'_, 'ctx, '_>,
lhs_val: BasicValueEnum<'ctx>,
rhs_val: BasicValueEnum<'ctx>,
float_width: FloatWidth,
) -> IntValue<'ctx> {
use FloatWidth::*;
let type_str = match float_width {
F64 => "F64",
F32 => "F32",
};

let make_cmp = |operation, a: BasicValueEnum<'ctx>, b: BasicValueEnum<'ctx>, op_name: &str| {
let full_op_name = format!("{}_{}", op_name, type_str);
let bool_result = env.builder.new_build_float_compare(
operation,
a.into_float_value(),
b.into_float_value(),
&full_op_name,
);
env.builder.new_build_int_cast_sign_flag(bool_result, env.context.i8_type(), false, &format!("{}_cast", full_op_name))
};

let two = env.context.i8_type().const_int(2, false);
let three = env.context.i8_type().const_int(3, false);

let lt_test = make_cmp(FloatPredicate::OLT, lhs_val, rhs_val, "rhs_lt_lhs");
let gt_test = make_cmp(FloatPredicate::OGT, lhs_val, rhs_val, "lhs_gt_rhs");
let eq_test = make_cmp(FloatPredicate::OEQ, lhs_val, rhs_val, "lhs_eq_rhs");
let lhs_not_nan_test = make_cmp(FloatPredicate::OEQ, lhs_val, lhs_val, "lhs_not_NaN");
let rhs_not_nan_test = make_cmp(FloatPredicate::OEQ, rhs_val, rhs_val, "rhs_not_NaN");

let rhs_not_nan_scaled = env.builder.new_build_int_mul(two, rhs_not_nan_test, "2 * rhs_not_nan");
let gt_scaled = env.builder.new_build_int_mul(two, gt_test, "2 * lhs_gt_rhs");
let eq_scaled = env.builder.new_build_int_mul(three, eq_test, "3 * lhs_eq_rhs");

let non_nans = env.builder.new_build_int_add(lhs_not_nan_test, rhs_not_nan_scaled, "(a == a) + 2*(b == b))");
let minus_lt = env.builder.new_build_int_sub(non_nans, lt_test, "(a == a) + 2*(b == b) - (a < b");
let minus_gt = env.builder.new_build_int_sub(minus_lt, gt_scaled, "(a == a) + 2*(b == b) - (a < b) - 2*(a > b)");
env.builder.new_build_int_sub(minus_gt, eq_scaled, "float_compare")
}

// 1 1 0
// 0 0 0
// 0 1 1
// 1 0 2
fn bool_compare<'ctx>(
env: &Env<'_, 'ctx, '_>,
lhs_val: BasicValueEnum<'ctx>,
rhs_val: BasicValueEnum<'ctx>,
) -> IntValue<'ctx> {

// (a < b)
let lhs_lt_rhs = env.builder.new_build_int_compare(IntPredicate::SLT, lhs_val.into_int_value(), rhs_val.into_int_value(), "lhs_lt_rhs_bool");
let lhs_lt_rhs_byte = env.builder.new_build_int_cast_sign_flag(lhs_lt_rhs, env.context.i8_type(), false, "lhs_lt_rhs_byte");

// (a > b)
let lhs_gt_rhs = env.builder.new_build_int_compare(IntPredicate::SGT, lhs_val.into_int_value(), rhs_val.into_int_value(), "lhs_gt_rhs_bool");
let lhs_gt_rhs_byte = env.builder.new_build_int_cast_sign_flag(lhs_gt_rhs, env.context.i8_type(), false, "lhs_gt_rhs_byte");

// (a > b) * 2
let two = env.context.i8_type().const_int(2, false);
let lhs_gt_rhs_times_two = env.builder.new_build_int_mul(lhs_gt_rhs_byte, two, "lhs_gt_rhs_times_two");

// (a < b) + (a > b) * 2
env.builder.new_build_int_add(lhs_lt_rhs_byte, lhs_gt_rhs_times_two, "bool_compare")
}

0 comments on commit c168931

Please sign in to comment.