Skip to content

Commit

Permalink
add ==, isequal, <, and isless for DataFrameRow and GroupKey (#2669)
Browse files Browse the repository at this point in the history
  • Loading branch information
bkamins authored Mar 25, 2021
1 parent 34307bc commit 1d3f31b
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 155 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
additional column to be added in the last position in the resulting data frame
that will identify the source data frame.
([#2649](https://github.com/JuliaData/DataFrames.jl/pull/2649))
* `GroupKey` and `DataFrameRow` now consistently behave like `NamedTuple`
in comparisons and implement `hash`, `==`, `isequal`, `<`, and `isless`
([#2669](https://github.com/JuliaData/DataFrames.jl/pull/2669))
* since Julia 1.7 using broadcasting assignment on a `DataFrame` column
selected as a property (e.g. `df.col .= 1`) is allowed when column does not
exist and it allocates a fresh column
Expand Down
2 changes: 1 addition & 1 deletion src/DataFrames.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ include("dataframe/dataframe.jl")
include("subdataframe/subdataframe.jl")
include("dataframerow/dataframerow.jl")
include("groupeddataframe/groupeddataframe.jl")
include("dataframerow/utils.jl")
include("groupeddataframe/utils.jl")

include("other/broadcasting.jl")

Expand Down
86 changes: 44 additions & 42 deletions src/dataframerow/dataframerow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -450,55 +450,57 @@ Base.merge(a::DataFrameRow, b::DataFrameRow) = merge(NamedTuple(a), NamedTuple(b
Base.merge(a::DataFrameRow, b::Base.Iterators.Pairs) = merge(NamedTuple(a), b)
Base.merge(a::DataFrameRow, itr) = merge(NamedTuple(a), itr)

# Hash the values stored at row `r` of every column in `cols`, chaining the
# running hash `h` through the columns from first to last, so that rows with
# identical contents produce identical hashes.
# The columns are passed as a tuple of vectors to ensure type specialization.

# Single remaining column: fold its value into `h` and stop.
function rowhash(cols::Tuple{AbstractVector}, r::Int, h::UInt = zero(UInt))::UInt
    return hash(cols[1][r], h)
end

# Several columns: fold in the first one, then recurse on the remaining ones.
function rowhash(cols::Tuple{Vararg{AbstractVector}}, r::Int, h::UInt = zero(UInt))::UInt
    return rowhash(Base.tail(cols), r, hash(cols[1][r], h))
end

Base.hash(r::DataFrameRow, h::UInt = zero(UInt)) =
rowhash(ntuple(col -> parent(r)[!, parentcols(index(r), col)], length(r)), row(r), h)

function Base.:(==)(r1::DataFrameRow, r2::DataFrameRow)
if parent(r1) === parent(r2)
parentcols(index(r1)) == parentcols(index(r2)) || return false
row(r1) == row(r2) && return true
else
_names(r1) == _names(r2) || return false
Base.hash(r::DataFrameRow, h::UInt) = _nt_like_hash(r, h)

_getnames(x::DataFrameRow) = _names(x)
_getnames(x::NamedTuple) = propertynames(x)

# this is required as == does not allow for comparison between tuples and vectors
function _equal_names(r1, r2)
n1 = _getnames(r1)
n2 = _getnames(r2)
length(n1) == length(n2) || return false
for (a, b) in zip(n1, n2)
a == b || return false
end
all(((a, b),) -> a == b, zip(r1, r2))
return true
end

function Base.isequal(r1::DataFrameRow, r2::DataFrameRow)
if parent(r1) === parent(r2)
parentcols(index(r1)) == parentcols(index(r2)) || return false
row(r1) == row(r2) && return true
else
_names(r1) == _names(r2) || return false
for eqfun in (:isequal, :(==)),
(leftarg, rightarg) in ((:DataFrameRow, :DataFrameRow),
(:DataFrameRow, :NamedTuple),
(:NamedTuple, :DataFrameRow))
@eval function Base.$eqfun(r1::$leftarg, r2::$rightarg)
_equal_names(r1, r2) || return false
return all(((a, b),) -> $eqfun(a, b), zip(r1, r2))
end
all(((a, b),) -> isequal(a, b), zip(r1, r2))
end

# lexicographic ordering on DataFrame rows, missing > !missing
function Base.isless(r1::DataFrameRow, r2::DataFrameRow)
length(r1) == length(r2) ||
throw(ArgumentError("compared DataFrameRows must have the same number " *
"of columns (got $(length(r1)) and $(length(r2)))"))
if _names(r1) != _names(r2)
mismatch = findfirst(i -> _names(r1)[i] != _names(r2)[i], 1:length(r1))
throw(ArgumentError("compared DataFrameRows must have the same colum " *
"names but they differ in column number $mismatch " *
"where the names are :$(names(r1)[mismatch]) and " *
":$(_names(r2)[mismatch]) respectively"))
end
for (a, b) in zip(r1, r2)
isequal(a, b) || return isless(a, b)
for (eqfun, cmpfun) in ((:isequal, :isless), (:(==), :(<))),
(leftarg, rightarg) in ((:DataFrameRow, :DataFrameRow),
(:DataFrameRow, :NamedTuple),
(:NamedTuple, :DataFrameRow))
@eval function Base.$cmpfun(r1::$leftarg, r2::$rightarg)
if !_equal_names(r1, r2)
length(r1) == length(r2) ||
throw(ArgumentError("compared objects must have the same number " *
"of columns (got $(length(r1)) and $(length(r2)))"))
mismatch = findfirst(i -> _getnames(r1)[i] != _getnames(r2)[i], 1:length(r1))
throw(ArgumentError("compared objects must have the same property " *
"names but they differ in column number $mismatch " *
"where the names are :$(_getnames(r1)[mismatch]) and " *
":$(_getnames(r2)[mismatch]) respectively"))
end
for (a, b) in zip(r1, r2)
eq = $eqfun(a, b)
if ismissing(eq)
return missing
elseif !eq
return $cmpfun(a, b)
end
end
return false # here we know that r1 and r2 have equal lengths and all values were equal
end
return false
end

function DataFrame(dfr::DataFrameRow)
Expand Down
45 changes: 45 additions & 0 deletions src/groupeddataframe/groupeddataframe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,51 @@ end

Base.getproperty(key::GroupKey, p::AbstractString) = getproperty(key, Symbol(p))

# Hash a `GroupKey` by delegating to the shared `_nt_like_hash` helper.
function Base.hash(key::GroupKey, h::UInt)
    return _nt_like_hash(key, h)
end

# Names used when comparing a `GroupKey` with another row-like object:
# the `cols` field of the key's parent.
function _getnames(x::GroupKey)
    return parent(x).cols
end

# Generate `isequal` and `==` methods for every pairing of `GroupKey` with
# `GroupKey`, `DataFrameRow` and `NamedTuple`.
# Two objects with different names are never equal; otherwise the values are
# compared pairwise with the same function that is being defined (so `==` can
# yield `missing`, following the three-valued logic of `all`).
for eqfun in (:isequal, :(==))
    for (lhs, rhs) in ((:GroupKey, :GroupKey),
                       (:DataFrameRow, :GroupKey),
                       (:GroupKey, :DataFrameRow),
                       (:NamedTuple, :GroupKey),
                       (:GroupKey, :NamedTuple))
        @eval function Base.$eqfun(k1::$lhs, k2::$rhs)
            if !_equal_names(k1, k2)
                return false
            end
            return all($eqfun(a, b) for (a, b) in zip(k1, k2))
        end
    end
end

# Generate lexicographic `isless` and `<` methods for every pairing of
# `GroupKey` with `GroupKey`, `DataFrameRow` and `NamedTuple`.
# Each ordering function is paired with its matching equality function
# (`isless` with `isequal`, `<` with `==`): the first position where the values
# are not equal decides the result. For `<`, a `missing` equality result is
# returned as `missing` straight away.
# Throws `ArgumentError` when the two objects differ in length or in names.
for (eqfun, cmpfun) in ((:isequal, :isless), (:(==), :(<)))
    for (lhs, rhs) in ((:GroupKey, :GroupKey),
                       (:DataFrameRow, :GroupKey),
                       (:GroupKey, :DataFrameRow),
                       (:NamedTuple, :GroupKey),
                       (:GroupKey, :NamedTuple))
        @eval function Base.$cmpfun(k1::$lhs, k2::$rhs)
            if !_equal_names(k1, k2)
                # report a length mismatch first, then the first differing name
                if length(k1) != length(k2)
                    throw(ArgumentError("compared objects must have the same number " *
                                        "of columns (got $(length(k1)) and $(length(k2)))"))
                end
                mismatch = findfirst(i -> _getnames(k1)[i] != _getnames(k2)[i], 1:length(k1))
                throw(ArgumentError("compared objects must have the same column " *
                                    "names but they differ in column number $mismatch " *
                                    "where the names are :$(_getnames(k1)[mismatch]) and " *
                                    ":$(_getnames(k2)[mismatch]) respectively"))
            end
            for (a, b) in zip(k1, k2)
                eq = $eqfun(a, b)
                ismissing(eq) && return missing
                eq || return $cmpfun(a, b)
            end
            # k1 and k2 have equal lengths and every pair of values was equal,
            # so neither is strictly smaller than the other
            return false
        end
    end
end

function Base.NamedTuple(key::GroupKey)
N = NamedTuple{Tuple(parent(key).cols)}
N(_groupvalues(parent(key), getfield(key, :idx)))
Expand Down
75 changes: 0 additions & 75 deletions src/dataframerow/utils.jl → src/groupeddataframe/utils.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,3 @@
# Rows grouping.
# Maps row contents to the indices of all the equal rows.
# Used by groupby(), join(), nonunique()
#
# Instances are created by `group_rows` and queried through `findrow`,
# `findrows` and `getindex(gd, dfr)`.
# The row indices of group `g` are `rperm[starts[g]:stops[g]]`, and the group
# of row `i` is `groups[i]`.
struct RowGroupDict{T<:AbstractDataFrame}
"source data table"
df::T
"row hashes (optional, can be empty)"
rhashes::Vector{UInt}
"hashindex -> index of group-representative row (optional, can be empty)"
gslots::Vector{Int}
"group index for each row"
groups::Vector{Int}
"permutation of row indices that sorts them by groups"
rperm::Vector{Int}
"starts of ranges in rperm for each group"
starts::Vector{Int}
"stops of ranges in rperm for each group"
stops::Vector{Int}
end

# "kernel" functions for hashrows()
# adjust row hashes by the hashes of column elements
function hashrows_col!(h::Vector{UInt},
Expand Down Expand Up @@ -173,7 +153,6 @@ function refpool_and_array(x::AbstractArray)
return nothing, nothing
end

# Helper function for RowGroupDict.
# Returns a tuple:
# 1) the highest group index in the `groups` vector
# 2) vector of row hashes (may be empty if hash=Val(false))
Expand Down Expand Up @@ -440,57 +419,3 @@ function compute_indices(groups::AbstractVector{<:Integer}, ngroups::Integer)

return rperm, starts, stops
end

# Construct the RowGroupDict of `df`, with every column of `df` acting as a
# grouping key.
function group_rows(df::AbstractDataFrame)
    cols = ntuple(i -> df[!, i], ncol(df))
    # filled in place by row_group_slots with the group index of each row
    groups = Vector{Int}(undef, nrow(df))
    ngroups, rhashes, gslots, _ =
        row_group_slots(cols, Val(true), groups, false, false)
    rperm, starts, stops = compute_indices(groups, ngroups)
    return RowGroupDict(df, rhashes, gslots, groups, rperm, starts, stops)
end

# Find index of a row in gd that matches given row by content, 0 if not found
#
# `row` is a row index into `df`; `df_cols` and `gd_cols` are the columns of
# `df` and of `gd.df` as tuples of vectors (this is how `getindex(gd, dfr)`
# calls it). When `df` is the very table `gd` was built from, `row` is returned
# unchanged; otherwise the result is a value read from `gd.gslots`, or 0.
function findrow(gd::RowGroupDict,
df::AbstractDataFrame,
gd_cols::Tuple{Vararg{AbstractVector}},
df_cols::Tuple{Vararg{AbstractVector}},
row::Int)
(gd.df === df) && return row # same table, return itself
# different tables, content matching required
rhash = rowhash(df_cols, row)
# NOTE(review): `szm1` is used as a bit mask below, which presumably relies on
# length(gd.gslots) being a power of two — confirm against row_group_slots
szm1 = length(gd.gslots)-1
# `&` binds tighter than `+`, so this is (rhash & szm1) + 1: a 1-based start slot
slotix = ini_slotix = rhash & szm1 + 1
while true
g_row = gd.gslots[slotix]
# an empty slot (0) ends the search; a filled slot matches only if both the
# stored hash and the row contents agree
if g_row == 0 || # not found
(rhash == gd.rhashes[g_row] &&
isequal_row(gd_cols, g_row, df_cols, row)) # found
return g_row
end
slotix = (slotix & szm1) + 1 # miss, try the next slot
# back at the starting slot: stop probing
(slotix == ini_slotix) && break
end
return 0 # not found
end

# Return a view of the indices of all rows in `gd` whose content matches row
# `row` of `df`; the view is empty when there is no matching row.
function findrows(gd::RowGroupDict,
                  df::AbstractDataFrame,
                  gd_cols::Tuple{Vararg{AbstractVector}},
                  df_cols::Tuple{Vararg{AbstractVector}},
                  row::Int)
    hit = findrow(gd, df, gd_cols, df_cols, row)
    if hit == 0
        # no match: an empty range yields an empty view
        return view(gd.rperm, 0:-1)
    else
        group = gd.groups[hit]
        return view(gd.rperm, gd.starts[group]:gd.stops[group])
    end
end

# Look up `dfr` in `gd` by content and return a view of the indices of all rows
# in the matching group. Throws `KeyError(dfr)` when no row matches.
function Base.getindex(gd::RowGroupDict, dfr::DataFrameRow)
    src = parent(dfr)
    gd_cols = ntuple(i -> gd.df[!, i], ncol(gd.df))
    src_cols = ntuple(i -> src[!, i], ncol(src))
    hit = findrow(gd, src, gd_cols, src_cols, row(dfr))
    if hit == 0
        throw(KeyError(dfr))
    end
    group = gd.groups[hit]
    return view(gd.rperm, gd.starts[group]:gd.stops[group])
end
3 changes: 0 additions & 3 deletions src/other/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1604,7 +1604,6 @@ function precompile(all=false)
Base.precompile(Tuple{typeof(DataFrames._unstack),DataFrame,Array{Int,1},Int,GroupedDataFrame{DataFrame},Array{Any,1},GroupedDataFrame{DataFrame},Function,Bool,Bool})
Base.precompile(Tuple{typeof(DataFrames._combine_multicol),String,Function,GroupedDataFrame{DataFrame},Nothing})
Base.precompile(Tuple{DataFrames.Reduce{typeof(min),Nothing,Nothing},Array{Union{Missing, BigFloat},1},GroupedDataFrame{DataFrame}})
Base.precompile(Tuple{typeof(DataFrames.rowhash),Tuple{Array{Symbol,1}},Int,UInt})
Base.precompile(Tuple{typeof(combine),GroupedDataFrame{DataFrame},Pair{InvertedIndex{Symbol},ByRow{typeof(/)}}})
Base.precompile(Tuple{typeof(transform),DataFrame,Any,Any})
Base.precompile(Tuple{typeof(DataFrames._combine_process_pair_symbol),Bool,GroupedDataFrame{DataFrame},Dict{Symbol,Tuple{Bool,Int}},Array{DataFrames.TransformationResult,1},Nothing,Symbol,Bool,Base.RefValue{SubArray{Int,1,Array{Int,1},Tuple{Array{Int,1}},false}},Union{Function, Type},Tuple{Array{Int,1}}})
Expand Down Expand Up @@ -1988,7 +1987,6 @@ function precompile(all=false)
Base.precompile(Tuple{Core.kwftype(typeof(DataFrames.Type)),NamedTuple{(:id1, :id2_left, :x_left, :ID2_right, :x_right),Tuple{Array{Int,1},Array{Union{Missing, Int},1},Array{Union{Missing, Int},1},Array{Union{Missing, Int},1},Array{Union{Missing, Int},1}}},Type{DataFrame}})
Base.precompile(Tuple{typeof(collect),Base.Generator{DataFrames.DataFrameColumns{DataFrame},typeof(typeof)}})
Base.precompile(Tuple{typeof(dropmissing),DataFrame,Regex})
Base.precompile(Tuple{typeof(getindex),DataFrames.RowGroupDict{DataFrame},DataFrameRow{DataFrame,DataFrames.Index}})
Base.precompile(Tuple{typeof(DataFrames.genkeymap),GroupedDataFrame{DataFrame},Tuple{Array{Int,1}}})
Base.precompile(Tuple{typeof(DataFrames._combine_tables_with_first!),NamedTuple{(:x1,),Tuple{SubArray{Int,1,Array{Int,1},Tuple{Array{Int,1}},false}}},Tuple{Array{Int,1}},Array{Int,1},Int,Int,Function,GroupedDataFrame{DataFrame},Tuple{Array{Int,1}},Tuple{Symbol},Val{false}})
Base.precompile(Tuple{Core.kwftype(typeof(DataFrames.Type)),NamedTuple{(:b, :x),Tuple{Array{Int,1},Array{Float64,1}}},Type{DataFrame}})
Expand Down Expand Up @@ -2627,7 +2625,6 @@ function precompile(all=false)
Base.precompile(Tuple{Core.kwftype(typeof(DataFrames.Type)),NamedTuple{(:a,),Tuple{Int}},Type{DataFrame}})
Base.precompile(Tuple{Core.kwftype(typeof(DataFrames.Type)),NamedTuple{(:a, :b, :v1),Tuple{Array{Union{Missing, Symbol},1},Array{Union{Missing, Symbol},1},UnitRange{Int}}},Type{DataFrame}})
Base.precompile(Tuple{typeof(getindex),DataFrame,Int,All{Tuple{}}})
Base.precompile(Tuple{typeof(DataFrames.group_rows),DataFrame})
Base.precompile(Tuple{typeof(DataFrames._combine_process_pair_symbol),Bool,GroupedDataFrame{DataFrame},Dict{Symbol,Tuple{Bool,Int}},Array{DataFrames.TransformationResult,1},Nothing,Symbol,Bool,Complex{Float64},Union{Function, Type},Tuple{Array{Complex{Float64},1}}})
Base.precompile(Tuple{typeof(push!),DataFrame,Tuple{Int,Char}})
Base.precompile(Tuple{Core.kwftype(typeof(DataFrames.leftjoin)),NamedTuple{(:on,),Tuple{Array{Pair{Symbol,String},1}}},typeof(leftjoin),DataFrame,DataFrame})
Expand Down
11 changes: 11 additions & 0 deletions src/other/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,14 @@ function tforeach(f, x::AbstractArray; basesize::Integer)
end
return
end

# Hash a row-like object `v` from its values and property names.
# The values are folded into the hash from the last one to the first, starting
# from the hash of the empty tuple, and the result is combined with the
# `objectid` of the tuple of property names. An empty `v` hashes like an empty
# `NamedTuple`.
# `v` must support `length`, integer indexing and `propertynames`.
function _nt_like_hash(v, h::UInt)
    n = length(v)
    if n == 0
        return hash(NamedTuple(), h)
    end

    acc = hash((), h)
    i = n
    while i >= 1
        acc = hash(v[i], acc)
        i -= 1
    end

    return xor(objectid(Tuple(propertynames(v))), acc)
end
Loading

0 comments on commit 1d3f31b

Please sign in to comment.