Skip to content

Commit

Permalink
Reorganize plotting commands (#121)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Mar 30, 2021
1 parent 295c2ab commit 31d0ecb
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 57 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "AbstractGPs"
uuid = "99985d1d-32ba-4be9-9821-2ec096f28918"
authors = ["JuliaGaussianProcesses Team"]
version = "0.2.22"
version = "0.2.23"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
4 changes: 2 additions & 2 deletions examples/regression_1d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ scatter(
title="posterior (default parameters)", label="Train Data",
)
scatter!(x_test, y_test; label="Test Data")
plot!(p_fx, 0:0.001:1; label=false)
plot!(0:0.001:1, p_fx; label=false)

# ## Markov Chain Monte Carlo
#
Expand Down Expand Up @@ -430,5 +430,5 @@ scatter(
title="posterior (VI with sparse grid)", label="Train Data",
)
scatter!(x_test, y_test; label="Test Data")
plot!(ap, 0:0.001:1; label=false)
plot!(0:0.001:1, ap; label=false)
vline!(logistic.(opt.minimizer[3:end]); label="Pseudo-points")
3 changes: 3 additions & 0 deletions src/AbstractGPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,7 @@ module AbstractGPs

# Plotting utilities.
include(joinpath("util", "plotting.jl"))

# Deprecations.
include("deprecations.jl")
end # module
33 changes: 33 additions & 0 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
@recipe function f(gp::AbstractGP, x::AbstractArray)
Base.depwarn(
"`plot(gp::AbstractGP, x::AbstractArray)` is deprecated, " *
"use `plot(x, gp)` instead.",
:apply_recipe,
)
return x, gp
end
@recipe function f(gp::AbstractGP, xmin::Real, xmax::Real)
Base.depwarn(
"`plot(gp::AbstractGP, xmin::Real, xmax::Real)` is deprecated, use " *
"`plot(range(xmin, xmax; length=1_000), gp)` instead.",
:apply_recipe,
)
return range(xmin, xmax; length=1_000), gp
end

@recipe function f(z::AbstractVector, gp::AbstractGP, x::AbstractArray)
Base.depwarn(
"`plot(z::AbstractVector, gp::AbstractGP, x::AbstractArray)` is deprecated, " *
"use `plot(z, gp(x))` instead.",
:apply_recipe,
)
return z, gp(x)
end
@recipe function f(z::AbstractVector, gp::AbstractGP, xmin::Real, xmax::Real)
Base.depwarn(
"`plot(z::AbstractVector, gp::AbstractGP, xmin::Real, xmax::Real)` is deprecated, " *
"use `plot(z, gp(range(xmin, xmax; length=1_000))` instead.",
:apply_recipe,
)
return z, gp(range(xmin, xmax; length=1_000))
end
22 changes: 8 additions & 14 deletions src/util/plotting.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,18 @@
@recipe f(gp::AbstractGP, x::AbstractArray) = gp(x)
@recipe f(gp::AbstractGP, xmin::Real, xmax::Real) = gp(range(xmin, xmax; length=1_000))

@recipe f(z::AbstractVector, gp::AbstractGP, x::AbstractArray) = (z, gp(x))
@recipe function f(z::AbstractVector, gp::AbstractGP, xmin::Real, xmax::Real)
return (z, gp(range(xmin, xmax; length=1_000)))
end

@recipe f(x::AbstractVector, gp::AbstractGP) = gp(x)
@recipe f(gp::FiniteGP) = (gp.x, gp)
@recipe function f(z::AbstractVector, gp::FiniteGP)
length(z) == length(gp.x) ||
throw(DimensionMismatch("length of `z` and `gp.x` has to be equal"))
@recipe function f(x::AbstractVector, gp::FiniteGP)
length(x) == length(gp.x) ||
throw(DimensionMismatch("length of `x` and `gp.x` has to be equal"))
scale::Float64 = pop!(plotattributes, :ribbon_scale, 1.0)
scale > 0.0 || error("`bandwidth` keyword argument must be non-negative")

# compute marginals
μ, σ2 = mean_and_cov_diag(gp)
σ = map(sqrt, σ2)

ribbon := σ
ribbon := scale .* sqrt.(σ2)
fillalpha --> 0.3
linewidth --> 2
return z, μ
return x, μ
end

"""
Expand Down
41 changes: 41 additions & 0 deletions test/deprecations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
@testset "deprecations" begin
x = rand(10)
f = GP(SqExponentialKernel())
gp = f(x, 0.1)

# with `AbstractVector` and `AbstractRange`:
for x in (rand(10), 0:0.01:1)
rec = @test_deprecated RecipesBase.apply_recipe(Dict{Symbol,Any}(), f, x)
@test length(rec) == 1 && length(rec[1].args) == 2 # one series with two arguments
@test rec[1].args[1] == x
@test rec[1].args[2] == f
@test isempty(rec[1].plotattributes) # no default attributes

z = 1 .+ x
rec = @test_deprecated RecipesBase.apply_recipe(Dict{Symbol,Any}(), z, f, x)
@test length(rec) == 1 && length(rec[1].args) == 2 # one series with two arguments
@test rec[1].args[1] == z
@test rec[1].args[2] isa AbstractGPs.FiniteGP
@test rec[1].args[2].x == x
@test rec[1].args[2].f == f
@test isempty(rec[1].plotattributes) # no default attributes
end

# with minimum and maximum:
xmin = rand()
xmax = 4 + rand()
rec = @test_deprecated RecipesBase.apply_recipe(Dict{Symbol,Any}(), f, xmin, xmax)
@test length(rec) == 1 && length(rec[1].args) == 2 # one series with two arguments
@test rec[1].args[1] == range(xmin, xmax; length=1_000)
@test rec[1].args[2] == f
@test isempty(rec[1].plotattributes) # no default attributes

z = range(0, 1; length=1_000)
rec = @test_deprecated RecipesBase.apply_recipe(Dict{Symbol,Any}(), z, f, xmin, xmax)
@test length(rec) == 1 && length(rec[1].args) == 2 # one series with two arguments
@test rec[1].args[1] == z
@test rec[1].args[2] isa AbstractGPs.FiniteGP
@test rec[1].args[2].x == range(xmin, xmax; length=1_000)
@test rec[1].args[2].f == f
@test isempty(rec[1].plotattributes) # no default attributes
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ include("test_util.jl")
println(" ")
@info "Ran latent_gp tests"

include("deprecations.jl")
println(" ")
@info "Ran deprecation tests"

include("turing.jl")
println(" ")
@info "Ran Turing tests"
Expand Down
60 changes: 20 additions & 40 deletions test/util/plotting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,56 +14,36 @@
@test isempty(rec[1].plotattributes) # no default attributes

z = 1 .+ x
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), z, gp)
@test length(rec) == 1 && length(rec[1].args) == 2 # one series with two arguments
@test rec[1].args[1] == z
@test rec[1].args[2] == zero(x)
# 3 default attributes
attributes = rec[1].plotattributes
@test sort!(collect(keys(attributes))) == [:fillalpha, :linewidth, :ribbon]
@test attributes[:fillalpha] == 0.3
@test attributes[:linewidth] == 2
@test attributes[:ribbon] == sqrt.(cov_diag(gp))
for kwargs in (
Dict{Symbol, Any}(),
Dict{Symbol, Any}(:ribbon_scale => 3),
Dict{Symbol, Any}(:ribbon_scale => rand()),
)
scale = get(kwargs, :ribbon_scale, 1.0)
rec = RecipesBase.apply_recipe(kwargs, z, gp)
@test length(rec) == 1 && length(rec[1].args) == 2 # one series with two arguments
@test rec[1].args[1] == z
@test rec[1].args[2] == zero(x)
# 3 default attributes
attributes = rec[1].plotattributes
@test sort!(collect(keys(attributes))) == [:fillalpha, :linewidth, :ribbon]
@test attributes[:fillalpha] == 0.3
@test attributes[:linewidth] == 2
@test attributes[:ribbon] == scale .* sqrt.(cov_diag(gp))
end

# Check recipe dispatches for `AbstractGP`s
# with `AbstractVector` and `AbstractRange`:
for x in (rand(10), 0:0.01:1)
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), f, x)
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), x, f)
@test length(rec) == 1 && length(rec[1].args) == 1 # one series with one argument
@test rec[1].args[1] isa AbstractGPs.FiniteGP
@test rec[1].args[1].x == x
@test rec[1].args[1].f == f
@test isempty(rec[1].plotattributes) # no default attributes

z = 1 .+ x
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), z, f, x)
@test length(rec) == 1 && length(rec[1].args) == 2 # one series with two arguments
@test rec[1].args[1] == z
@test rec[1].args[2] isa AbstractGPs.FiniteGP
@test rec[1].args[2].x == x
@test rec[1].args[2].f == f
@test isempty(rec[1].plotattributes) # no default attributes
end

# with minimum and maximum:
xmin = rand()
xmax = 4 + rand()
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), f, xmin, xmax)
@test length(rec) == 1 && length(rec[1].args) == 1 # one series with one argument
@test rec[1].args[1] isa AbstractGPs.FiniteGP
@test rec[1].args[1].x == range(xmin, xmax; length=1_000)
@test rec[1].args[1].f == f
@test isempty(rec[1].plotattributes) # no default attributes

z = range(0, 1; length=1_000)
rec = RecipesBase.apply_recipe(Dict{Symbol, Any}(), z, f, xmin, xmax)
@test length(rec) == 1 && length(rec[1].args) == 2 # one series with two arguments
@test rec[1].args[1] == z
@test rec[1].args[2] isa AbstractGPs.FiniteGP
@test rec[1].args[2].x == range(xmin, xmax; length=1_000)
@test rec[1].args[2].f == f
@test isempty(rec[1].plotattributes) # no default attributes

# Check dimensions
# Checks
@test_throws DimensionMismatch plot(rand(5), gp)
@test_throws ErrorException plot(rand(10), gp; ribbon_scale=-0.5)
end

2 comments on commit 31d0ecb

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/33196

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.23 -m "<description of version>" 31d0ecb7c4b83253bb334bc466c2bc8ed2ad20c5
git push origin v0.2.23

Please sign in to comment.