diff --git a/src/NeuralVerification.jl b/src/NeuralVerification.jl index ecb4c181..dc077e4d 100644 --- a/src/NeuralVerification.jl +++ b/src/NeuralVerification.jl @@ -74,9 +74,7 @@ include("reachability/utils/reachability.jl") include("reachability/exactReach.jl") include("reachability/maxSens.jl") include("reachability/ai2.jl") -include("reachability/ai2z.jl") -include("reachability/box.jl") -export ExactReach, MaxSens, Ai2, Ai2z, Box +export ExactReach, MaxSens, Ai2, Ai2h, Ai2z, Box include("satisfiability/bab.jl") include("satisfiability/sherlock.jl") diff --git a/src/reachability/ai2.jl b/src/reachability/ai2.jl index 7bbe252e..098d2b0a 100644 --- a/src/reachability/ai2.jl +++ b/src/reachability/ai2.jl @@ -1,11 +1,25 @@ """ - Ai2 + Ai2{T} -Ai2 performs over-approximated reachability analysis to compute the over-approximated output reachable set for a network. +`Ai2` performs over-approximated reachability analysis to compute the over-approximated +output reachable set for a network. `T` can be `Hyperrectangle`, `Zonotope`, or +`HPolytope`, and determines the amount of over-approximation (and hence also performance +tradeoff). The original implementation (from [1]) uses Zonotopes, so we consider this +the "benchmark" case. The `HPolytope` case is more precise, but slower, and the opposite +is true of the `Hyperrectangle` case. + +Note that initializing `Ai2()` defaults to `Ai2{Zonotope}`. +The following aliases also exist for convenience: + +```julia +const Ai2h = Ai2{HPolytope} +const Ai2z = Ai2{Zonotope} +const Box = Ai2{Hyperrectangle} +``` # Problem requirement 1. Network: any depth, ReLU activation (more activations to be supported in the future) -2. Input: HPolytope +2. Input: AbstractPolytope 3. Output: AbstractPolytope # Return @@ -18,30 +32,58 @@ Reachability analysis using split and join. Sound but not complete. # Reference -T. Gehr, M. Mirman, D. Drashsler-Cohen, P. Tsankov, S. Chaudhuri, and M. Vechev, +[1] T. Gehr, M. Mirman, D. Drashsler-Cohen, P. Tsankov, S. Chaudhuri, and M. Vechev, "Ai2: Safety and Robustness Certification of Neural Networks with Abstract Interpretation," in *2018 IEEE Symposium on Security and Privacy (SP)*, 2018. + +## Note +Efficient over-approximation of intersections and unions involving zonotopes relies on Theorem 3.1 of + +[2] Singh, G., Gehr, T., Mirman, M., Püschel, M., & Vechev, M. (2018). Fast +and effective robustness certification. In Advances in Neural Information +Processing Systems (pp. 10802-10813). """ -struct Ai2 <: Solver end +struct Ai2{T<:Union{Hyperrectangle, Zonotope, HPolytope}} <: Solver end + +Ai2() = Ai2{Zonotope}() +const Ai2h = Ai2{HPolytope} +const Ai2z = Ai2{Zonotope} +const Box = Ai2{Hyperrectangle} function solve(solver::Ai2, problem::Problem) reach = forward_network(solver, problem.network, problem.input) return check_inclusion(reach, problem.output) end -forward_layer(solver::Ai2, layer::Layer, inputs::Vector{<:AbstractPolytope}) = forward_layer.(solver, layer, inputs) +forward_layer(solver::Ai2, L::Layer, inputs::Vector) = forward_layer.(solver, L, inputs) + +function forward_layer(solver::Ai2h, L::Layer{ReLU}, input::AbstractPolytope) + Ẑ = affine_map(L, input) + relued_subsets = forward_partition(L.activation, Ẑ) # defined in reachability.jl + return convex_hull(UnionSetArray(relued_subsets)) +end + +# method for Zonotope and Hyperrectangle, if the input set isn't a Zonotope +function forward_layer(solver::Union{Ai2z, Box}, L::Layer{ReLU}, input::AbstractPolytope) + return forward_layer(solver, L, overapproximate(input, Hyperrectangle)) +end + +function forward_layer(solver::Ai2z, L::Layer{ReLU}, input::AbstractZonotope) + Ẑ = affine_map(L, input) + return overapproximate(Rectification(Ẑ), Zonotope) +end + -function forward_layer(solver::Ai2, layer::Layer, input::AbstractPolytope) - outlinear = affine_map(layer, input) - relued_subsets = forward_partition(layer.activation, outlinear) # defined in ExactReach - return convex_hull(relued_subsets) +function forward_layer(solver::Box, L::Layer{ReLU}, input::AbstractZonotope) + Ẑ = approximate_affine_map(L, input) + return rectify(Ẑ) end -# extend lazysets convex_hull to a vector of polytopes -function LazySets.convex_hull(sets::Vector{<:AbstractPolytope}; backend = CDDLib.Library()) - hull = first(sets) - for P in sets - hull = convex_hull(hull, P, backend = backend) - end - return hull -end \ No newline at end of file +function forward_layer(solver::Ai2, L::Layer{Id}, input) + return affine_map(L, input) +end + + +function convex_hull(U::UnionSetArray{<:Any, <:HPolytope}) + tohrep(VPolytope(LazySets.convex_hull(U))) +end diff --git a/src/reachability/ai2z.jl b/src/reachability/ai2z.jl deleted file mode 100644 index 2561126f..00000000 --- a/src/reachability/ai2z.jl +++ /dev/null @@ -1,48 +0,0 @@ -""" - Ai2z <: AbstractSolver - -Ai2 performs over-approximated reachability analysis to compute the over-approximated output reachable set for a network. - -# Problem requirement -1. Network: any depth, ReLU activation (more activations to be supported in the future) -2. Input: Zonotope -3. Output: Zonotope - -# Return -`ReachabilityResult` - -# Method -Reachability analysis using split and join using Zonotopes as proposed on [1]. - -# Property -Sound but not complete. - -# Reference -T. Gehr, M. Mirman, D. Drashsler-Cohen, P. Tsankov, S. Chaudhuri, and M. Vechev, -"Ai2: Safety and Robustness Certification of Neural Networks with Abstract Interpretation," -in *2018 IEEE Symposium on Security and Privacy (SP)*, 2018. -""" -struct Ai2z <: Solver end - -function solve(solver::Ai2z, problem::Problem) - if isa(problem.input, LazySet) - input = [problem.input] - else - input = problem.input - end - f_n(x) = forward_network(solver, problem.network, x) - reach = map(f_n, input) - return check_inclusion(reach, problem.output) -end - -forward_layer(solver::Ai2z, layer::Layer, inputs::Vector{<:LazySet}) = forward_layer.(solver, layer, inputs) - -function forward_layer(solver::Ai2z, layer::Layer, input::AbstractPolytope) - return forward_layer(solver, layer, overapproximate(input, Hyperrectangle)) -end - -function forward_layer(solver::Ai2z, layer::Layer, input::AbstractZonotope) - outlinear = affine_map(layer, input) - relued_subsets = forward_partition(layer.activation, outlinear) - return relued_subsets -end diff --git a/src/reachability/box.jl b/src/reachability/box.jl deleted file mode 100644 index 76be4d3c..00000000 --- a/src/reachability/box.jl +++ /dev/null @@ -1,37 +0,0 @@ -""" - Box <: Solver - -Box performs over-approximated reachability analysis to compute the over-approximated output reachable set for a network. - -# Problem requirement -1. Network: any depth, ReLU activation (more activations to be supported in the future) -2. Input: Hyperrectangle -3. Output: Hyperrectangle - -# Return -`ReachabilityResult` - -# Method -Reachability analysis using using boxes. - -# Property -Sound but not complete. -""" -struct Box <: Solver end - -function solve(solver::Box, problem::Problem) - reach = forward_network(solver, problem.network, problem.input) - return check_inclusion(reach, problem.output) -end - -forward_layer(solver::Box, layer::Layer, inputs::Vector{<:LazySet}) = forward_layer.(solver, layer, inputs) - -function forward_layer(solver::Box, layer::Layer, input::AbstractPolytope) - return forward_layer(solver, layer, overapproximate(input, Hyperrectangle)) -end - -function forward_layer(solver::Box, layer::Layer, input::Hyperrectangle) - outlinear = overapproximate(AffineMap(layer.weights, input, layer.bias), Hyperrectangle) - relued_subsets = forward_partition(layer.activation, outlinear) - return relued_subsets -end diff --git a/src/reachability/exactReach.jl b/src/reachability/exactReach.jl index d839e3c7..ff8a2e2d 100644 --- a/src/reachability/exactReach.jl +++ b/src/reachability/exactReach.jl @@ -31,7 +31,7 @@ end forward_layer(solver::ExactReach, layer::Layer, input) = forward_layer(solver, layer, convert(HPolytope, input)) -function forward_layer(solver::ExactReach, layer::Layer, input::Vector{HPolytope}) +function forward_layer(solver::ExactReach, layer::Layer, input::Vector{<:HPolytope}) output = Vector{HPolytope}(undef, 0) for i in 1:length(input) input[i] = affine_map(layer, input[i]) diff --git a/src/reachability/utils/reachability.jl b/src/reachability/utils/reachability.jl index 3b157e57..6dfa9fd3 100644 --- a/src/reachability/utils/reachability.jl +++ b/src/reachability/utils/reachability.jl @@ -18,24 +18,19 @@ function check_inclusion(reach::Vector{<:LazySet}, output) for poly in reach issubset(poly, output) || return ReachabilityResult(:violated, reach) end - return ReachabilityResult(:holds, similar(reach, 0)) + return ReachabilityResult(:holds, reach) end function check_inclusion(reach::P, output) where P<:LazySet - if issubset(reach, output) - return ReachabilityResult(:holds, P[]) - end - return ReachabilityResult(:violated, [reach]) + return ReachabilityResult(issubset(reach, output) ? :holds : :violated, [reach]) end # return a vector so that append! is consistent with the relu forward_partition -forward_partition(act::Id, input::AbstractPolytope) = [input] - -forward_partition(act::Id, input::Zonotope) = input +forward_partition(act::Id, input) = [input] function forward_partition(act::ReLU, input::HPolytope) n = dim(input) - output = Vector{HPolytope}(undef, 0) + output = Vector{HPolytope{Float64}}(undef, 0) C, d = tosimplehrep(input) dh = [d; zeros(n)] for h in 0:(2^n)-1 @@ -56,14 +51,4 @@ function getP(h::Int64, n::Int64) vec[i] = ifelse(str[i] == '1', 1, 0) end return Diagonal(vec) -end - -# forward_partition for Zonotopes -function forward_partition(act::ReLU, input::Zonotope) - return overapproximate(Rectification(input), Zonotope) -end - -# for Hyperrectangles -function forward_partition(act::ReLU, input::Hyperrectangle) - return rectify(input) -end +end \ No newline at end of file diff --git a/test/identity_network.jl b/test/identity_network.jl index 13afdea8..c9cf4bf9 100644 --- a/test/identity_network.jl +++ b/test/identity_network.jl @@ -18,7 +18,7 @@ problem_holds = Problem(small_nnet, in_hpoly, convert(HPolytope, out_superset)) problem_violated = Problem(small_nnet, in_hpoly, convert(HPolytope, out_overlapping)) - for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2z(), Box()] + for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2h(), Box()] holds = solve(solver, problem_holds) violated = solve(solver, problem_violated) diff --git a/test/inactive_relus.jl b/test/inactive_relus.jl index 6f0bc40e..f454794a 100644 --- a/test/inactive_relus.jl +++ b/test/inactive_relus.jl @@ -18,7 +18,7 @@ problem_holds = Problem(small_nnet, in_hpoly, convert(HPolytope, out_superset)) problem_violated = Problem(small_nnet, in_hpoly, convert(HPolytope, out_overlapping)) - for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2z(), Box()] + for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2h(), Box()] holds = solve(solver, problem_holds) violated = solve(solver, problem_violated) diff --git a/test/relu_network.jl b/test/relu_network.jl index ddd9d7d7..669dbe72 100644 --- a/test/relu_network.jl +++ b/test/relu_network.jl @@ -18,7 +18,7 @@ problem_holds = Problem(small_nnet, in_hpoly, convert(HPolytope, out_superset)) problem_violated = Problem(small_nnet, in_hpoly, convert(HPolytope, out_overlapping)) - for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2z(), Box()] + for solver in [MaxSens(resolution = 0.6), ExactReach(), Ai2(), Ai2h(), Box()] holds = solve(solver, problem_holds) violated = solve(solver, problem_violated)