Skip to content

Commit

Permalink
Merge pull request #2593 from JuliaReach/schillic/sample
Browse files Browse the repository at this point in the history
Add 'include_vertices' option to 'sample'
  • Loading branch information
schillic authored Feb 25, 2021
2 parents 1b108df + 10fbafb commit f8510e9
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/Utils/samples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ abstract type Sampler end
[sampler]=_default_sampler(X),
[rng]::AbstractRNG=GLOBAL_RNG,
[seed]::Union{Int, Nothing}=nothing,
[include_vertices]=false,
[VN]=Vector{N}) where {N}
Sampling of an arbitrary bounded set `X`.
Expand All @@ -31,6 +32,7 @@ Sampling of an arbitrary bounded set `X`.
on the type of `X`
- `rng` -- (optional, default: `GLOBAL_RNG`) random number generator
- `seed` -- (optional, default: `nothing`) seed for reseeding
- `include_vertices` -- (optional, default: `false`) option to include the vertices
- `VN` -- (optional, default: `Vector{N}`) vector type of the sampled points
### Output
Expand All @@ -42,16 +44,35 @@ vector).
### Algorithm
See the documentation of the respective `Sampler`.
### Notes
If `include_vertices == true`, we include all vertices computed with `vertices`.
Alternatively if a number ``k`` is passed, we plot the first ``k`` vertices
returned by `vertices`.
"""
function sample(X::LazySet{N}, num_samples::Int;
sampler=_default_sampler(X),
rng::AbstractRNG=GLOBAL_RNG,
seed::Union{Int, Nothing}=nothing,
include_vertices=false,
VN=Vector{N}) where {N}
@assert isbounded(X) "this function requires that the set `X` is bounded"

D = Vector{VN}(undef, num_samples) # preallocate output
_sample!(D, sampler(X); rng=rng, seed=seed)

if include_vertices != false
k = (include_vertices isa Bool) ? Inf : include_vertices
for v in vertices(X)
push!(D, v)
k -= 1
if k <= 0
break
end
end
end

return D
end

Expand Down
8 changes: 8 additions & 0 deletions test/unit_samples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,12 @@ for N in [Float64]
# specifying a distribution from Distributions.jl
@test LazySets.RejectionSampler(P2, Uniform).box_approx == [Uniform(-3.0, 1.0), Uniform(-4.0,2.0)]
@test LazySets.RejectionSampler(P2, Normal).box_approx == [Normal(-3.0, 1.0), Normal(-4.0, 2.0)]

# including vertices
for k in 0:4
p1 = sample(P1, 10; include_vertices=k)
@test length(p1) == 10 + k
end
@test length(sample(P1, 10; include_vertices=false)) == 10
@test length(sample(P1, 10; include_vertices=true)) == 42
end

0 comments on commit f8510e9

Please sign in to comment.