diff --git a/test/device/wmma.jl b/test/device/wmma.jl index a2fc0fcf..2a187f6a 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -230,18 +230,20 @@ using CUDAnative.WMMA return end - @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev, alpha, beta) - d = Array(d_dev) - - new_a = (a_layout == ColMajor) ? a : transpose(a) - new_b = (b_layout == ColMajor) ? b : transpose(b) - new_c = (c_layout == ColMajor) ? c : transpose(c) - new_d = (d_layout == ColMajor) ? d : transpose(d) - - if do_mac - @test all(isapprox.(alpha * new_a * new_b + beta * new_c, new_d; rtol=sqrt(eps(Float16)))) - else - @test all(isapprox.(alpha * new_a * new_b, new_d; rtol=sqrt(eps(Float16)))) + @test_broken_if VERSION >= v"1.5.0-DEV.393" begin + @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev, alpha, beta) + d = Array(d_dev) + + new_a = (a_layout == ColMajor) ? a : transpose(a) + new_b = (b_layout == ColMajor) ? b : transpose(b) + new_c = (c_layout == ColMajor) ? c : transpose(c) + new_d = (d_layout == ColMajor) ? d : transpose(d) + + if do_mac + all(isapprox.(alpha * new_a * new_b + beta * new_c, new_d; rtol=sqrt(eps(Float16)))) + else + all(isapprox.(alpha * new_a * new_b, new_d; rtol=sqrt(eps(Float16)))) + end end end diff --git a/test/util.jl b/test/util.jl index 9b799b8d..05ed6d41 100644 --- a/test/util.jl +++ b/test/util.jl @@ -102,3 +102,13 @@ function julia_script(code, args=``) proc.exitcode, read(out, String), read(err, String) end +# tests that are conditionall broken +macro test_broken_if(cond, ex...) + quote + if $(esc(cond)) + @test_broken $(map(esc, ex)...) + else + @test $(map(esc, ex)...) + end + end +end