From c9b6f79b33db3e4c90aa6fbdf7fb5c1b6512658e Mon Sep 17 00:00:00 2001 From: Quinton Miller Date: Thu, 2 Nov 2023 02:55:57 +0800 Subject: [PATCH] Add `Math.fma` --- spec/std/math_spec.cr | 10 +++++++++ .../crystal/interpreter/instructions.cr | 10 +++++++++ .../crystal/interpreter/primitives.cr | 6 ++++++ src/math/libm.cr | 6 ++++++ src/math/math.cr | 21 +++++++++++++++++++ 5 files changed, 53 insertions(+) diff --git a/spec/std/math_spec.cr b/spec/std/math_spec.cr index 9106636a377e..ede2149110eb 100644 --- a/spec/std/math_spec.cr +++ b/spec/std/math_spec.cr @@ -266,6 +266,16 @@ describe "Math" do # div rem + it "fma" do + x = Math.fma(0.1, 10.0, -1.0) + x.should be_close(5.551115123125783e-17, 1e-25) + x.should_not eq(0.0) + + x = Math.fma(0.1_f32, 10.0_f32, -1.0_f32) + x.should be_close(1.4901161e-8_f32, 1e-16_f32) + x.should_not eq(0.0_f32) + end + describe ".pw2ceil" do {% for int in %w(Int8 Int16 Int32 Int64 Int128) %} it {{ int }} do diff --git a/src/compiler/crystal/interpreter/instructions.cr b/src/compiler/crystal/interpreter/instructions.cr index 2c55cfe255df..67966277bf15 100644 --- a/src/compiler/crystal/interpreter/instructions.cr +++ b/src/compiler/crystal/interpreter/instructions.cr @@ -1837,6 +1837,16 @@ require "./repl" push: true, code: LibM.floor_f64(value), }, + libm_fma_f32: { + pop_values: [value1 : Float32, value2 : Float32, value3 : Float32], + push: true, + code: LibM.fma_f32(value1, value2, value3), + }, + libm_fma_f64: { + pop_values: [value1 : Float64, value2 : Float64, value3 : Float64], + push: true, + code: LibM.fma_f64(value1, value2, value3), + }, libm_log_f32: { pop_values: [value : Float32], push: true, diff --git a/src/compiler/crystal/interpreter/primitives.cr b/src/compiler/crystal/interpreter/primitives.cr index 2063830e0eed..1d50568c5d62 100644 --- a/src/compiler/crystal/interpreter/primitives.cr +++ b/src/compiler/crystal/interpreter/primitives.cr @@ -552,6 +552,12 @@ class Crystal::Repl::Compiler when "interpreter_libm_floor_f64" accept_call_args(node) libm_floor_f64 node: node + when "interpreter_libm_fma_f32" + accept_call_args(node) + libm_fma_f32 node: node + when "interpreter_libm_fma_f64" + accept_call_args(node) + libm_fma_f64 node: node when "interpreter_libm_log_f32" accept_call_args(node) libm_log_f32 node: node diff --git a/src/math/libm.cr b/src/math/libm.cr index 8130a8ac31a2..59e66f12737e 100644 --- a/src/math/libm.cr +++ b/src/math/libm.cr @@ -47,6 +47,12 @@ lib LibM {% if flag?(:interpreted) %} @[Primitive(:interpreter_libm_floor_f64)] {% end %} fun floor_f64 = "llvm.floor.f64"(value : Float64) : Float64 + {% if flag?(:interpreted) %} @[Primitive(:interpreter_libm_fma_f32)] {% end %} + fun fma_f32 = "llvm.fma.f32"(value1 : Float32, value2 : Float32, value3 : Float32) : Float32 + + {% if flag?(:interpreted) %} @[Primitive(:interpreter_libm_fma_f64)] {% end %} + fun fma_f64 = "llvm.fma.f64"(value1 : Float64, value2 : Float64, value3 : Float64) : Float64 + {% if flag?(:interpreted) %} @[Primitive(:interpreter_libm_log_f32)] {% end %} fun log_f32 = "llvm.log.f32"(value : Float32) : Float32 diff --git a/src/math/math.cr b/src/math/math.cr index f29eeab386e4..efeef9ad9e12 100644 --- a/src/math/math.cr +++ b/src/math/math.cr @@ -575,6 +575,27 @@ module Math hypot(value1.to_f, value2.to_f) end + # Fused multiply-add; returns `value1 * value2 + value3`, performing a single + # rounding instead of two. + # + # ``` + # Math.fma(0.1, 10.0, -1.0) # => 5.551115123125783e-17 + # 1.0 * 10.0 - 1.0 # => 0.0 + # ``` + def fma(value1 : Float32, value2 : Float32, value3 : Float32) : Float32 + LibM.fma_f32(value1, value2, value3) + end + + # :ditto: + def fma(value1 : Float64, value2 : Float64, value3 : Float64) : Float64 + LibM.fma_f64(value1, value2, value3) + end + + # :ditto: + def fma(value1, value2, value3) + fma(value1.to_f, value2.to_f, value3.to_f) + end + # Returns the unbiased base 2 exponent of the given floating-point *value*. def ilogb(value : Float32) : Int32 LibM.ilogb_f32(value)