Skip to content

Commit

Permalink
Add range function.
Browse files Browse the repository at this point in the history
  • Loading branch information
werbenhu committed Dec 31, 2024
1 parent b94a1d6 commit b0e74e0
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 18 deletions.
94 changes: 76 additions & 18 deletions ranklist.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ranklist

import (
"fmt"
"math/rand"
"sync"
)
Expand Down Expand Up @@ -31,16 +32,18 @@ func ZeroValue[K Ordered]() K {
return zero
}

// Entry represents a key-value pair
type Entry[K Ordered, V Ordered] struct {
Key K
Value V
}

// Node 定义跳表节点的结构
// Node defines the structure of a skip list node
type Node[K Ordered, V Ordered] struct {
// 节点的键
// Key of the node
key K

// 节点的值
// Value of the node
value V
// 节点的键值对
// Key-value pair of the node
data Entry[K, V]

// 每一层对应的前向指针数组
// Array of forward pointers for each level
Expand Down Expand Up @@ -83,8 +86,7 @@ type RankList[K Ordered, V Ordered] struct {
// NewNode creates a new skip list node
func NewNode[K Ordered, V Ordered](key K, value V, level int) *Node[K, V] {
return &Node[K, V]{
key: key,
value: value,
data: Entry[K, V]{Key: key, Value: value},
forward: make([]*Node[K, V], level),
span: make([]int, level),
level: level,
Expand Down Expand Up @@ -124,7 +126,7 @@ func (sl *RankList[K, V]) Set(key K, value V) {
// 如果节点已存在,先删除旧节点
// If node exists, remove old node first
if node, exists := sl.dict[key]; exists {
sl.del(node.key)
sl.del(node.data.Key)
}

// 用于记录每层的前驱节点
Expand Down Expand Up @@ -153,8 +155,8 @@ func (sl *RankList[K, V]) Set(key K, value V) {
for i := sl.level - 1; i >= 0; i-- {
for curr.forward[i] != nil {

if curr.forward[i].value > value ||
(curr.forward[i].value == value && curr.forward[i].key > key) {
if curr.forward[i].data.Value > value ||
(curr.forward[i].data.Value == value && curr.forward[i].data.Key > key) {
break
}
sum += curr.forward[i].span[i]
Expand Down Expand Up @@ -192,6 +194,14 @@ func (sl *RankList[K, V]) Set(key K, value V) {
sl.length++
}

// Length 返回跳表中当前元素的数量。
// Length returns the current number of elements in the skip list.
func (sl *RankList[K, V]) Length() int {
sl.RLock()
defer sl.RUnlock()
return sl.length
}

// Del 从跳表中删除指定键的节点。
// 如果键存在并且节点被删除,返回true;如果键不存在,返回false。
// Del removes the node with the specified key from the skip list.
Expand Down Expand Up @@ -221,8 +231,8 @@ func (sl *RankList[K, V]) del(key K) bool {
// Find the node to be deleted
for i := sl.level - 1; i >= 0; i-- {
for curr.forward[i] != nil &&
(curr.forward[i].value < node.value ||
(curr.forward[i].value == node.value && curr.forward[i].key < key)) {
(curr.forward[i].data.Value < node.data.Value ||
(curr.forward[i].data.Value == node.data.Value && curr.forward[i].data.Key < key)) {
curr = curr.forward[i]
}
prev[i] = curr
Expand Down Expand Up @@ -259,7 +269,7 @@ func (sl *RankList[K, V]) Get(key K) (V, bool) {
defer sl.RUnlock()

if node, exists := sl.dict[key]; exists {
return node.value, true
return node.data.Value, true
}
return ZeroValue[V](), false
}
Expand All @@ -285,13 +295,13 @@ func (sl *RankList[K, V]) Rank(key K) (int, bool) {
for i := sl.level - 1; i >= 0; i-- {
for curr.forward[i] != nil {

if curr.forward[i].value == node.value && curr.forward[i].key == key {
if curr.forward[i].data.Value == node.data.Value && curr.forward[i].data.Key == key {
rank += curr.forward[i].span[i]
return rank, true
}

if curr.forward[i].value > node.value ||
(curr.forward[i].value == node.value && curr.forward[i].key > key) {
if curr.forward[i].data.Value > node.data.Value ||
(curr.forward[i].data.Value == node.data.Value && curr.forward[i].data.Key > key) {
break
}

Expand All @@ -301,3 +311,51 @@ func (sl *RankList[K, V]) Rank(key K) (int, bool) {
}
return 0, false
}

// Range 获取指定排名区间内的榜单项(不包含END)
// 返回指定范围内的条目列表。
// Range retrieves the entries within the specified rank range (excluding END)
// Returns a list of entries within the specified range.
func (sl *RankList[K, V]) Range(start int, end int) []Entry[K, V] {
sl.RLock()
defer sl.RUnlock()

rank := 0
curr := sl.header
entries := make([]Entry[K, V], 0)

for i := sl.level - 1; i >= 0; i-- {
for curr.forward[i] != nil {
rank += curr.forward[i].span[i]
if rank >= start {
break
}
curr = curr.forward[i]
}
}

total := 0
for curr.forward[0] != nil && start+total < end {
entries = append(entries, curr.forward[0].data)
curr = curr.forward[0]
total++
}
return entries
}

// Print 打印跳表结构
func (sl *RankList[K, V]) Print() {
fmt.Printf("SkipList Level: %d, Length: %d\n", sl.level, sl.length)
for i := sl.level - 1; i >= 0; i-- {
curr := sl.header
fmt.Printf("L%d -> ", i+1)
for curr != nil {
if curr != sl.header {
fmt.Printf("[%v:%v:%v] -> ", curr.data.Key, curr.data.Value, curr.span[i])
}
curr = curr.forward[i]
}
fmt.Println("NIL")
}
fmt.Println("===================================")
}
86 changes: 86 additions & 0 deletions ranklist_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,89 @@ func TestEdgeCases(t *testing.T) {
t.Errorf("Length should be 0 after all deletions, got %d", sl.length)
}
}

func TestLength(t *testing.T) {
sl := New[int, int]()

for i := 0; i < 1000; i++ {
sl.Set(i, i)
}

if sl.Length() != 1000 {
t.Errorf("Length should be 1000 after insertions, got %d", sl.length)
}
}

func TestRankList_Range(t *testing.T) {
rankList := New[int, int]()

rankList.Set(1, 100)
rankList.Set(2, 200)
rankList.Set(3, 150)
rankList.Set(4, 120)
rankList.Set(5, 180)

tests := []struct {
start, end int
expected []Entry[int, int]
}{
{
start: 1,
end: 3,
expected: []Entry[int, int]{
{Key: 1, Value: 100},
{Key: 4, Value: 120},
},
},
// Test case: Range that includes all elements
{
start: 0,
end: 5,
expected: []Entry[int, int]{
{Key: 1, Value: 100},
{Key: 4, Value: 120},
{Key: 3, Value: 150},
{Key: 5, Value: 180},
{Key: 2, Value: 200},
},
},
// Test case: Empty range
{
start: 10,
end: 10,
expected: []Entry[int, int]{},
},
// Test case: Invalid range (start > end)
{
start: 3,
end: 1,
expected: []Entry[int, int]{},
},
// Test case: Range with an out-of-bounds end
{
start: 3,
end: 10,
expected: []Entry[int, int]{
{Key: 3, Value: 150},
{Key: 5, Value: 180},
{Key: 2, Value: 200},
},
},
}

for _, tt := range tests {
t.Run("Range Test", func(t *testing.T) {
result := rankList.Range(tt.start, tt.end)
if len(result) != len(tt.expected) {
t.Errorf("expected %d entries, got %d", len(tt.expected), len(result))
}

for i, entry := range result {
if entry.Key != tt.expected[i].Key || entry.Value != tt.expected[i].Value {
t.Errorf("at index %d: expected (%d, %d), got (%d, %d)",
i, tt.expected[i].Key, tt.expected[i].Value, entry.Key, entry.Value)
}
}
})
}
}

0 comments on commit b0e74e0

Please sign in to comment.