diff --git a/test/utils.jl b/test/utils.jl index 1681a0df28..025743c924 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -146,6 +146,15 @@ end end end + @testset "truncated_normal" begin + for sz in [(100,), (100, 100), (2, 3, 32, 64), (2, 3, 4, 32, 64)] + v = truncated_normal(sz...) + @test -1.0 < minimum(v) < 0.0 + @test 0.0 < maximum(v) < 1.0 + @test eltype(v) == Float32 + end + end + @testset "partial_application" begin big = 1e9