Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dectests #92

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 148 additions & 123 deletions scripts/dectest.jl
Original file line number Diff line number Diff line change
@@ -1,47 +1,109 @@
function _precision(line)
function (@main)(args=ARGS)
name, dectest_path, output_path = args

open(output_path, "w") do io
println(io, """
using Decimals
using Test
using Decimals: @with_context

@testset \"$name\" begin""")

translate(io, dectest_path)

println(io, "end")
end
end

function translate(io, dectest_path)
directives = Dict{Symbol, Any}()

for line in eachline(dectest_path)
line = strip(line)

isempty(line) && continue
startswith(line, "--") && continue

line = lowercase(line)

if startswith(line, "version:")
# ...
elseif startswith(line, "extended:")
# ...
elseif startswith(line, "clamp:")
# ...
elseif startswith(line, "precision:")
directives[:precision] = parse_precision(line)
elseif startswith(line, "rounding:")
directives[:rounding] = parse_rounding(line)
elseif startswith(line, "maxexponent:")
directives[:Emax] = parse_maxexponent(line)
elseif startswith(line, "minexponent:")
directives[:Emin] = parse_minexponent(line)
else
if directives[:rounding] == RoundingMode{:Unsupported}
continue
end

test = parse_test(line)
any(isspecial, test.operands) && continue
isspecial(test.result) && continue

dectest = decimal_test(test, directives)
println(io, dectest)
end
end
end

function isspecial(value)
value = lowercase(value)
return occursin(r"(inf|nan|#|\?)", value)
end

function parse_precision(line)
m = match(r"^precision:\s*(\d+)$", line)
isnothing(m) && throw(ArgumentError(line))
return parse(Int, m[1])
end

function _rounding(line)
function parse_rounding(line)
m = match(r"^rounding:\s*(\w+)$", line)
isnothing(m) && throw(ArgumentError(line))
r = m[1]
if r == "ceiling"
return "RoundUp"
return RoundUp
elseif r == "down"
return "RoundToZero"
return RoundToZero
elseif r == "floor"
return "RoundDown"
return RoundDown
elseif r == "half_even"
return "RoundNearest"
return RoundNearest
elseif r == "half_up"
return "RoundNearestTiesAway"
return RoundNearestTiesAway
elseif r == "up"
return "RoundFromZero"
return RoundFromZero
elseif r == "half_down"
return "RoundHalfDownUnsupported"
return RoundingMode{:Unsupported}
elseif r == "05up"
return "Round05UpUnsupported"
return RoundingMode{:Unsupported}
else
throw(ArgumentError(r))
end
end

function _maxexponent(line)
function parse_maxexponent(line)
m = match(r"^maxexponent:\s*\+?(\d+)$", line)
isnothing(m) && throw(ArgumentError(line))
return parse(Int, m[1])
end

function _minexponent(line)
function parse_minexponent(line)
m = match(r"^minexponent:\s*(-\d+)$", line)
isnothing(m) && throw(ArgumentError(line))
return parse(Int, m[1])
end

function _test(line)
function parse_test(line)
occursin("->", line) || throw(ArgumentError(line))
lhs, rhs = split(line, "->")
id, operation, operands... = split(lhs)
Expand All @@ -50,134 +112,97 @@ function _test(line)
return (;id, operation, operands, result, conditions)
end

function decimal(x)
function clean(@nospecialize ex)
if isa(ex, Expr)
if Meta.isexpr(ex, :macrocall)
return Expr(:macrocall, ex.args[1], nothing, map(clean, ex.args[3:end])...)
else
return Expr(ex.head, map(clean, ex.args)...)
end
elseif isa(ex, LineNumberNode)
return nothing
else
return ex
end
end

function decimal_test(test, directives)
ctxt = decimal_context(directives)
op = decimal_operation(test.operation, test.operands)
res = operation_result(test.operation, test.result)

if :overflow in test.conditions
ex = :(@with_context($ctxt, @test_throws OverflowError $op))
elseif :division_undefined in test.conditions
ex = :(@with_context($ctxt, @test_throws UndefinedDivisionError $op))
elseif :division_by_zero in test.conditions
ex = :(@with_context($ctxt, @test_throws DivisionByZeroError $op))
else
ex = :(@with_context($ctxt, @test $op == $(res)))
end
return clean(ex)
end

function dec(x)
x = strip(x, ['\'', '\"'])
return "dec\"$x\""
return :(@dec_str $("$x"))
end

function decimal_context(directives)
names = Tuple(sort!(collect(keys(directives))))
values = Tuple([directives[name] for name in names])
params = NamedTuple{names}(values)
return params
end

function operation_result(operation, result)
if operation == "compare"
return parse(Int, result)
else
return dec(result)
end
end

function print_operation(io, operation, operands)
function decimal_operation(operation, operands)
if operation == "abs"
print_abs(io, operands...)
return decimal_abs(operands...)
elseif operation == "add"
print_add(io, operands...)
return decimal_add(operands...)
elseif operation == "apply"
print_apply(io, operands...)
return decimal_apply(operands...)
elseif operation == "compare"
print_compare(io, operands...)
return decimal_compare(operands...)
elseif operation == "divide"
print_divide(io, operands...)
return decimal_divide(operands...)
elseif operation == "max"
print_max(io, operands...)
return decimal_max(operands...)
elseif operation == "min"
print_min(io, operands...)
return decimal_min(operands...)
elseif operation == "minus"
print_minus(io, operands...)
return decimal_minus(operands...)
elseif operation == "multiply"
print_multiply(io, operands...)
return decimal_multiply(operands...)
elseif operation == "plus"
print_plus(io, operands...)
return decimal_plus(operands...)
elseif operation == "reduce"
print_reduce(io, operands...)
return decimal_reduce(operands...)
elseif operation == "subtract"
print_subtract(io, operands...)
return decimal_subtract(operands...)
else
throw(ArgumentError(operation))
end
end
print_abs(io, x) = print(io, "abs(", decimal(x), ")")
print_add(io, x, y) = print(io, decimal(x), " + ", decimal(y))
print_apply(io, x) = print(io, decimal(x))
print_compare(io, x, y) = print(io, "cmp(", decimal(x), ", ", decimal(y), ")")
print_divide(io, x, y) = print(io, decimal(x), " / ", decimal(y))
print_max(io, x, y) = print(io, "max(", decimal(x), ", ", decimal(y), ")")
print_min(io, x, y) = print(io, "min(", decimal(x), ", ", decimal(y), ")")
print_minus(io, x) = print(io, "-(", decimal(x), ")")
print_multiply(io, x, y) = print(io, decimal(x), " * ", decimal(y))
print_plus(io, x) = print(io, "+(", decimal(x), ")")
print_reduce(io, x) = print(io, "reduce(", decimal(x), ")")
print_subtract(io, x, y) = print(io, decimal(x), " - ", decimal(y))

function print_test(io, test, directives)
println(io, " # $(test.id)")

names = sort!(collect(keys(directives)))
params = join(("$k=$(directives[k])" for k in names), ", ")
print(io, " @with_context ($params) ")

if :overflow ∈ test.conditions
print(io, "@test_throws OverflowError ")
print_operation(io, test.operation, test.operands)
println(io)
elseif :division_undefined ∈ test.conditions
print(io, "@test_throws UndefinedDivisionError ")
print_operation(io, test.operation, test.operands)
println(io)
elseif :division_by_zero ∈ test.conditions
print(io, "@test_throws DivisionByZeroError ")
print_operation(io, test.operation, test.operands)
println(io)
else
print(io, "@test ")
print_operation(io, test.operation, test.operands)
print(io, " == ")
println(io, decimal(test.result))
end
end

function isspecial(value)
value = lowercase(value)
return occursin(r"(inf|nan|#)", value)
end

function translate(io, dectest_path)
directives = Dict{String, Any}()

for line in eachline(dectest_path)
line = strip(line)

isempty(line) && continue
startswith(line, "--") && continue

line = lowercase(line)

if startswith(line, "version:")
# ...
elseif startswith(line, "extended:")
# ...
elseif startswith(line, "clamp:")
# ...
elseif startswith(line, "precision:")
directives["precision"] = _precision(line)
elseif startswith(line, "rounding:")
directives["rounding"] = _rounding(line)
elseif startswith(line, "maxexponent:")
directives["Emax"] = _maxexponent(line)
elseif startswith(line, "minexponent:")
directives["Emin"] = _minexponent(line)
else
test = _test(line)
any(isspecial, test.operands) && continue
occursin("Unsupported", directives["rounding"]) && continue
print_test(io, test, directives)
end
end
end

function (@main)(args=ARGS)
name, dectest_path, output_path = args

open(output_path, "w") do io
println(io, """
using Decimals
using ScopedValues
using Test
using Decimals: @with_context

@testset \"$name\" begin""")

translate(io, dectest_path)

println(io, "end")
end
end
decimal_abs(x) = :(abs($(dec(x))))
decimal_add(x, y) = :($(dec(x)) + $(dec(y)))
decimal_apply(x) = dec(x)
decimal_compare(x, y) = :(cmp($(dec(x)), $(dec(y))))
decimal_divide(x, y) = :($(dec(x)) / $(dec(y)))
decimal_max(x, y) = :(max($(dec(x)), $(dec(y))))
decimal_min(x, y) = :(min($(dec(x)), $(dec(y))))
decimal_minus(x) = :(-($(dec(x))))
decimal_multiply(x, y) = :($(dec(x)) * $(dec(y)))
decimal_plus(x) = :(+($(dec(x))))
decimal_reduce(x) = :(normalize($(dec(x))))
decimal_subtract(x, y) = :($(dec(x)) - $(dec(y)))

Loading