From c3c53bbfcd7c32675fe7cc3cd1e4d0c4979ea3c2 Mon Sep 17 00:00:00 2001 From: Jeff Bezanson Date: Tue, 18 Feb 2020 18:15:33 -0500 Subject: [PATCH] WIP: add 2-argument `get` returning `Union{Nothing,Some}` --- base/abstractdict.jl | 10 +++++++ base/compiler/compiler.jl | 1 + base/dict.jl | 20 +++++++------- base/env.jl | 1 + base/iddict.jl | 55 +++++++++++++++++++++------------------ base/iterators.jl | 3 +-- base/weakkeydict.jl | 16 +++++++++--- src/iddict.c | 11 ++++++++ 8 files changed, 75 insertions(+), 42 deletions(-) diff --git a/base/abstractdict.jl b/base/abstractdict.jl index 845dcc54f8d66..c1135e03be4e6 100644 --- a/base/abstractdict.jl +++ b/base/abstractdict.jl @@ -500,6 +500,16 @@ function hash(a::AbstractDict, h::UInt) hash(hv, h) end +function get(default::Callable, d::AbstractDict, key) + val = get(d, key) + val === nothing ? default() : something(val) +end + +function get(d::AbstractDict, key, default) + val = get(d, key) + val === nothing ? default : something(val) +end + function getindex(t::AbstractDict, key) v = get(t, key, secret_table_token) if v === secret_table_token diff --git a/base/compiler/compiler.jl b/base/compiler/compiler.jl index 039f28de828be..945423a4bb01f 100644 --- a/base/compiler/compiler.jl +++ b/base/compiler/compiler.jl @@ -66,6 +66,7 @@ include("array.jl") include("abstractarray.jl") # core structures +include("some.jl") include("bitarray.jl") include("bitset.jl") include("abstractdict.jl") diff --git a/base/dict.jl b/base/dict.jl index 872c9ca0e3188..7233be6a9d62b 100644 --- a/base/dict.jl +++ b/base/dict.jl @@ -488,9 +488,9 @@ julia> get(d, "c", 3) """ get(collection, key, default) -function get(h::Dict{K,V}, key, default) where V where K +function get(h::Dict{K,V}, key) where V where K index = ht_keyindex(h, key) - @inbounds return (index < 0) ? default : h.vals[index]::V + @inbounds return (index < 0) ? nothing : Some{V}(h.vals[index]::V) end """ @@ -510,11 +510,6 @@ end """ get(::Function, collection, key) -function get(default::Callable, h::Dict{K,V}, key) where V where K - index = ht_keyindex(h, key) - @inbounds return (index < 0) ? default() : h.vals[index]::V -end - """ haskey(collection, key) -> Bool @@ -561,6 +556,11 @@ function getkey(h::Dict{K,V}, key, default) where V where K @inbounds return (index<0) ? default : h.keys[index]::K end +function getkey(h::Dict{K,V}, key) where V where K + index = ht_keyindex(h, key) + @inbounds return (index<0) ? nothing : Some{K}(h.keys[index]::K) +end + function _pop!(h::Dict, index) @inbounds val = h.vals[index] _delete!(h, index) @@ -769,12 +769,12 @@ function getindex(dict::ImmutableDict, key) end throw(KeyError(key)) end -function get(dict::ImmutableDict, key, default) +function get(dict::ImmutableDict{K,V}, key) where {K,V} while isdefined(dict, :parent) - dict.key == key && return dict.value + dict.key == key && return Some{V}(dict.value) dict = dict.parent end - return default + return nothing end # this actually defines reverse iteration (e.g. it should not be used for merge/copy/filter type operations) diff --git a/base/env.jl b/base/env.jl index 8f5256f25915e..175dd24dfc5db 100644 --- a/base/env.jl +++ b/base/env.jl @@ -78,6 +78,7 @@ const ENV = EnvDict() getindex(::EnvDict, k::AbstractString) = access_env(k->throw(KeyError(k)), k) get(::EnvDict, k::AbstractString, def) = access_env(k->def, k) +get(::EnvDict, k::AbstractString) = (v = get(ENV, k, nothing); v === nothing ? v : Some(v)) get(f::Callable, ::EnvDict, k::AbstractString) = access_env(k->f(), k) in(k::AbstractString, ::KeySet{String, EnvDict}) = _hasenv(k) pop!(::EnvDict, k::AbstractString) = (v = ENV[k]; _unsetenv(k); v) diff --git a/base/iddict.jl b/base/iddict.jl index 23ba65799a395..97195ab47e435 100644 --- a/base/iddict.jl +++ b/base/iddict.jl @@ -81,36 +81,46 @@ function setindex!(d::IdDict{K,V}, @nospecialize(val), @nospecialize(key)) where return d end -function get(d::IdDict{K,V}, @nospecialize(key), @nospecialize(default)) where {K, V} - val = ccall(:jl_eqtable_get, Any, (Any, Any, Any), d.ht, key, default) - val === default ? default : val::V +function get(d::IdDict{K,V}, @nospecialize(key)) where {K, V} + found = RefValue{Cint}(0) + val = ccall(:jl_eqtable_get1, Any, (Any, Any, Ptr{Cint}), d.ht, key, found) + if found[] == 1 + return Some{V}(val::V) + end + return nothing end function getindex(d::IdDict{K,V}, @nospecialize(key)) where {K, V} - val = get(d, key, secret_table_token) - val === secret_table_token && throw(KeyError(key)) - return val::V + val = get(d, key) + val === nothing && throw(KeyError(key)) + return something(val::Some{V}) end -function pop!(d::IdDict{K,V}, @nospecialize(key), @nospecialize(default)) where {K, V} +function _iddict_pop!(d::IdDict{K,V}, @nospecialize(key)) where {K, V} found = RefValue{Cint}(0) - val = ccall(:jl_eqtable_pop, Any, (Any, Any, Any, Ptr{Cint}), d.ht, key, default, found) + val = ccall(:jl_eqtable_pop, Any, (Any, Any, Any, Ptr{Cint}), d.ht, key, nothing, found) if found[] === Cint(0) - return default + return nothing else d.count -= 1 d.ndel += 1 - return val::V + return Some{V}(val::V) end end +function pop!(d::IdDict{K,V}, @nospecialize(key), @nospecialize(default)) where {K, V} + val = _iddict_pop!(d, key) + val === nothing && return default + return something(val::Some{V}) +end + function pop!(d::IdDict{K,V}, @nospecialize(key)) where {K, V} - val = pop!(d, key, secret_table_token) - val === secret_table_token && throw(KeyError(key)) - return val::V + val = _iddict_pop!(d, key) + val === nothing && throw(KeyError(key)) + return something(val::Some{V}) end function delete!(d::IdDict{K}, @nospecialize(key)) where K - pop!(d, key, secret_table_token) + _iddict_pop!(d, key) d end @@ -136,24 +146,17 @@ copy(d::IdDict) = typeof(d)(d) get!(d::IdDict{K,V}, @nospecialize(key), @nospecialize(default)) where {K, V} = (d[key] = get(d, key, default))::V -function get(default::Callable, d::IdDict{K,V}, @nospecialize(key)) where {K, V} - val = get(d, key, secret_table_token) - if val === secret_table_token - val = default() - end - return val -end - function get!(default::Callable, d::IdDict{K,V}, @nospecialize(key)) where {K, V} - val = get(d, key, secret_table_token) - if val === secret_table_token + val = get(d, key) + if val === nothing val = default() setindex!(d, val, key) + return val end - return val + return something(val::Some{V}) end -in(@nospecialize(k), v::KeySet{<:Any,<:IdDict}) = get(v.dict, k, secret_table_token) !== secret_table_token +in(@nospecialize(k), v::KeySet{<:Any,<:IdDict}) = get(v.dict, k) !== nothing # For some AbstractDict types, it is safe to implement filter! # by deleting keys during iteration. diff --git a/base/iterators.jl b/base/iterators.jl index 20baf2077e59d..93f02fe4e4597 100644 --- a/base/iterators.jl +++ b/base/iterators.jl @@ -249,8 +249,7 @@ keys(v::Pairs) = v.itr values(v::Pairs) = v.data getindex(v::Pairs, key) = v.data[key] setindex!(v::Pairs, value, key) = (v.data[key] = value; v) -get(v::Pairs, key, default) = get(v.data, key, default) -get(f::Base.Callable, v::Pairs, key) = get(f, v.data, key) +get(v::Pairs, key) = get(v.data, key) # zip diff --git a/base/weakkeydict.jl b/base/weakkeydict.jl index 079015ba8cd16..93df2ac4bb1aa 100644 --- a/base/weakkeydict.jl +++ b/base/weakkeydict.jl @@ -86,14 +86,22 @@ end function getkey(wkh::WeakKeyDict{K}, kk, default) where K return lock(wkh) do - k = getkey(wkh.ht, kk, secret_table_token) - k === secret_table_token && return default - return k.value::K + k = getkey(wkh.ht, kk) + k === nothing && return default + return something(k).value::K + end +end + +function getkey(wkh::WeakKeyDict{K}, kk) where K + return lock(wkh) do + k = getkey(wkh.ht, kk) + k === nothing && return nothing + return Some(something(k).value::K) end end map!(f,iter::ValueIterator{<:WeakKeyDict})= map!(f, values(iter.dict.ht)) -get(wkh::WeakKeyDict{K}, key, default) where {K} = lock(() -> get(wkh.ht, key, default), wkh) +get(wkh::WeakKeyDict{K}, key) where {K} = lock(() -> get(wkh.ht, key), wkh) get(default::Callable, wkh::WeakKeyDict{K}, key) where {K} = lock(() -> get(default, wkh.ht, key), wkh) function get!(wkh::WeakKeyDict{K}, key, default) where {K} !isa(key, K) && throw(ArgumentError("$(limitrepr(key)) is not a valid key for type $K")) diff --git a/src/iddict.c b/src/iddict.c index ca573247671e6..4ce8dd564a52a 100644 --- a/src/iddict.c +++ b/src/iddict.c @@ -158,6 +158,17 @@ jl_value_t *jl_eqtable_get(jl_array_t *h, jl_value_t *key, jl_value_t *deflt) return (bp == NULL) ? deflt : (jl_value_t *)*bp; } +JL_DLLEXPORT +jl_value_t *jl_eqtable_get1(jl_array_t *h, jl_value_t *key, int *found) +{ + void **bp = jl_table_peek_bp(h, key); + if (found) + *found = (bp != NULL); + if (bp == NULL) + return jl_nothing; + return (jl_value_t *)*bp; +} + JL_DLLEXPORT jl_value_t *jl_eqtable_pop(jl_array_t *h, jl_value_t *key, jl_value_t *deflt, int *found) {