Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix grouping for geom_path, geom_line, geom_step, and geom_density #103

Merged
merged 9 commits into from
May 8, 2024
Merged
1 change: 1 addition & 0 deletions src/TidierPlots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ include("extract_aes.jl")
include("geom.jl")
include("ggplot.jl")
include("ggsave.jl")
include("grouping.jl")
include("labs.jl")
include("label_functions.jl")
include("legend.jl")
Expand Down
37 changes: 31 additions & 6 deletions src/draw.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,39 @@ function Makie.SpecApi.Axis(plot::GGPlot)
end
end

required_aes_data = [p.makie_function(p.raw) for p in [given_aes[a] for a in Symbol.(required_aes)]]
optional_aes_data = [a => p.makie_function(p.raw) for (a, p) in given_aes if !(String(a) in required_aes)]
if length(intersect(keys(given_aes), geom.grouping_aes)) == 0
# if there are no grouping_aes given, we only need one PlotSpec
required_aes_data = [p.makie_function(p.raw) for p in [given_aes[a] for a in Symbol.(required_aes)]]
optional_aes_data = [a => p.makie_function(p.raw) for (a, p) in given_aes if !(String(a) in required_aes)]

args = Tuple([geom.visual, required_aes_data...])
kwargs = merge(args_dict_makie, Dict(optional_aes_data))
args = Tuple([geom.visual, required_aes_data...])
kwargs = merge(args_dict_makie, Dict(optional_aes_data))

# push completed PlotSpec (type, args, and kwargs) to the list of plots
push!(plot_list, Makie.PlotSpec(args...; kwargs...))
# push completed PlotSpec (type, args, and kwargs) to the list of plots
push!(plot_list, Makie.PlotSpec(args...; kwargs...))
else
# if there is a aes in the grouping_aes list given, we will need multiple PlotSpecs
# make a list of modified given_aes objects which only include the points from their subsets
grouping_columns = [aes_dict_makie[a] for a in [intersect(keys(given_aes), geom.grouping_aes)...]]
subgroup_given_aes = subgroup_split(given_aes, plot_data[!, grouping_columns])

# push each one to the overall plot_list
for sub in subgroup_given_aes
required_aes_data = [p.makie_function(p.raw) for p in [sub[a] for a in Symbol.(required_aes)]]
optional_aes_data = [a => p.makie_function(p.raw) for (a, p) in sub if !(String(a) in required_aes)]

args = Tuple([geom.visual, required_aes_data...])
kwargs = merge(args_dict_makie, Dict(optional_aes_data))

# if we are grouping, we only need a single value rather than a vector
for aes in [intersect(keys(given_aes), geom.grouping_aes)...]
kwargs[aes] = first(kwargs[aes])
end

# push completed PlotSpec (type, args, and kwargs) to the list of plots
push!(plot_list, Makie.PlotSpec(args...; kwargs...))
end
end
end

# rename and correct types on all axis options
Expand Down
6 changes: 4 additions & 2 deletions src/geom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ function build_geom(
spec_api_function,
aes_function,
column_transformations;
special_aes = Dict())
special_aes = Dict(),
grouping_aes = Symbol[])

if haskey(args_dict, "data")
if args_dict["data"] isa DataFrame
Expand All @@ -30,6 +31,7 @@ function build_geom(
spec_api_function,
Dict(),
aes_function,
column_transformations
column_transformations,
grouping_aes
)
end
2 changes: 1 addition & 1 deletion src/geoms/geom_density.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@ ggplot(penguins, @aes(x=bill_length_mm)) +
geom_density(color = (:red, 0.3), strokecolor = :red, stroke = 2)
```
"""
geom_density = geom_template("geom_density", ["x"], :Density)
geom_density = geom_template("geom_density", ["x"], :Density; grouping_aes = [:color, :colour])
27 changes: 17 additions & 10 deletions src/geoms/geom_path.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
function stat_sort_by_x(aes_dict::Dict{String, Symbol},
args_dict::Dict{Any, Any}, required_aes::Vector{String}, plot_data::DataFrame)

x_column = aes_dict["x"]

perm = sortperm(plot_data[!, x_column])

return (aes_dict, args_dict, required_aes, plot_data[perm, :])
end

"""
geom_line(aes(...), ...)
geom_line(plot::GGPlot, aes(...), ...)
Expand Down Expand Up @@ -36,11 +46,9 @@ df = DataFrame(x = xs, y = sin.(xs))
ggplot(df, @aes(x = x, y = y)) + geom_line()
```
"""
geom_line = geom_template("geom_line", ["x", "y"], :Lines;
column_transformations = Dict{Symbol, Pair{Vector{Symbol}, AesTransform}}(
:y => [:y, :x]=>sort_by,
:x => [:x, :x]=>sort_by
)
geom_line = geom_template("geom_line", ["x", "y"], :Lines;
aes_function = stat_sort_by_x,
grouping_aes = [:color, :colour]
)


Expand Down Expand Up @@ -83,10 +91,8 @@ ggplot(df, @aes(x = x, y = y)) + geom_step()
```
"""
geom_step = geom_template("geom_step", ["x", "y"], :Stairs;
column_transformations = Dict{Symbol, Pair{Vector{Symbol}, AesTransform}}(
:y => [:y, :x]=>sort_by,
:x => [:x, :x]=>sort_by
)
aes_function = stat_sort_by_x,
grouping_aes = [:color, :colour]
)


Expand Down Expand Up @@ -126,4 +132,5 @@ ggplot(penguins, @aes(x = bill_length_mm, y = bill_depth_mm)) +
geom_path()
```
"""
geom_path = geom_template("geom_path", ["x", "y"], :Lines)
geom_path = geom_template("geom_path", ["x", "y"], :Lines;
grouping_aes = [:color, :colour])
6 changes: 4 additions & 2 deletions src/geoms/geom_template.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ function geom_template(name::AbstractString,
spec_api_function::Symbol;
aes_function::Function = do_nothing,
column_transformations::Dict{Symbol, Pair{Vector{Symbol}, AesTransform}} = Dict{Symbol, Pair{Vector{Symbol}, AesTransform}}(),
extra_args::Dict = Dict())
extra_args::Dict = Dict(),
grouping_aes::Vector{Symbol} = Symbol[])

extract_geom_aes = make_aes_extractor(required_aes)

Expand All @@ -16,7 +17,8 @@ function geom_template(name::AbstractString,
required_aes,
spec_api_function,
aes_function,
merge(transforms, column_transformations))
merge(transforms, column_transformations);
grouping_aes = grouping_aes)
end

function geom_function(plot::GGPlot, args...; kwargs...)
Expand Down
19 changes: 19 additions & 0 deletions src/grouping.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
function subgroup_split(given_aes, grouping_columns)
grouping_columns.index = 1:nrow(grouping_columns)
group_index_list = Vector{Int}.([df.index for df in groupby(grouping_columns, Not(:index))])
return [subset_aes(given_aes, index) for index in group_index_list]
end

function subset_aes(given_aes, index)
subset_given_aes_dict = Dict{Symbol, PlottableData}()
for (key, value) in given_aes
subset_given_aes_item = PlottableData(
value.raw[index],
value.makie_function,
value.label_target,
value.label_function
)
push!(subset_given_aes_dict, key => subset_given_aes_item)
end
return subset_given_aes_dict
end
1 change: 1 addition & 0 deletions src/structs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ struct Geom
axis_options::Dict
aes_function::Function
column_transformations::Dict
grouping_aes::Vector{Symbol}
end

struct GGPlot
Expand Down
Loading