diff --git a/test/triples.jl b/test/triples.jl index 79b8ac28..a0116f62 100644 --- a/test/triples.jl +++ b/test/triples.jl @@ -624,6 +624,59 @@ end end test_propagate(f7, (rand(2, 2), 4.0), (rand(2, 2), 1.0); test_deltas = true) + + function stoch_trip(val::Real, inf_pert::Real, fin_pert::Real, prob::Real) + Δs = StochasticAD.similar_new(StochasticAD.create_Δs(PrunedFIsBackend(), Int), + fin_pert, prob) + StochasticAD.StochasticTriple{0}(val, inf_pert, Δs) + end + + f8(x) = x + stoch_trip(1., 0.1, 10., 100.) + f8(x::StochasticAD.StochasticTriple) = StochasticAD.propagate(f8, x) + f8() = (x = stoch_trip(2., 0., 20., 100.); f8(x)) + samples = [f8() for _ in 1:10] + for s in samples + @test StochasticAD.value(s) == 3. + @test length(perturbations(s)) == 1 + @test perturbations(s)[1].weight == 200. + end + # check that the Δ is sometimes 10 and sometimes 20, + # which requires the Δs of both the added triples to be taken into account + Δs = [perturbations(s)[1].Δ for s in samples] + @test 10. in Δs + @test 20. in Δs + + function f9(value_1, value_2, rand_var) + if value_1 < value_2 + return (value_1 + rand(rand_var), value_2) + else + return (value_1, value_2 + rand(rand_var)) + end + end + + propagate_f9(value_1, value_2, rand_var) = StochasticAD.propagate((v1, v2) -> f(v1, v2, rand_var), value_1, value_2) + + f9(value_1::StochasticTriple, value_2, rand_var) = propagate_f9(value_1, value_2, rand_var) + f9(value_1, value_2::StochasticTriple, rand_var) = propagate_f9(value_1, value_2, rand_var) + f9(value_1::StochasticTriple, value_2::StochasticTriple, rand_var) = propagate_f9(value_1, value_2, rand_var) + + function g(p) + rand_var = Bernoulli(p) + value_1 = 0 + value_2 = 2 + for _ in 1:10 + value_1, value_2 = f9(value_1, value_2, rand_var) + end + return value_1, value_2 + end + + N = 100 + derivs = [derivative_estimate(p -> sum(g(p)), 0.5)] + standard_error = std(derivs) / sqrt(N) + estimate = mean(derivs) + expected_value = 10.01 # obtained by running for N = 1E6 + @test estimate - 5standard_error ≤ expected_value ≤ estimate + 5standard_error + end end