diff --git a/interval/search.go b/interval/search.go index 41496c0..839e728 100644 --- a/interval/search.go +++ b/interval/search.go @@ -1,5 +1,9 @@ package interval +import ( + "errors" +) + // Find returns the value which interval key exactly matches with the given start and end interval. // It returns true as the second return value if an exaclty matching interval key is found in the tree; // otherwise, false. @@ -484,13 +488,13 @@ func maxEnd[V, T any](n *node[V, T], searchEnd T, cmp CmpFunc[T], visit func(*no var StopTraversal = errors.New("stop tree traversal") // VisitFunc is called on all values. Returning non-nil error will stop iteration. -// If the returned error is [StopTraversal], the iteration is interrupted, but no error is returned to the caller. -type VisitFunc[V, T any] func(V, T) error +// If the returned error is [StopTraversal], the iteration is interrupted, but no error is returned to the caller. +type VisitFunc[V, T any] func(T, T, V) error // InOrderTraverse traverses the tree in order and applies VisitFunc to each node. It's safe for concurrent use. To prevent deadlock, avoid calling other tree methods within visitFunc. func (st *SearchTree[V, T]) InOrderTraverse(visitFunc VisitFunc[V, T]) error { - tree.mu.RLock() - defer tree.mu.RUnlock() + st.mu.RLock() + defer st.mu.RUnlock() var inOrder func(n *node[V, T]) error inOrder = func(n *node[V, T]) error { @@ -504,7 +508,7 @@ func (st *SearchTree[V, T]) InOrderTraverse(visitFunc VisitFunc[V, T]) error { } // Visit current node - err := visitFunc(n.Interval.Val, n.Interval.Start) + err := visitFunc(n.Interval.Start, n.Interval.End, n.Interval.Val) if err != nil { return err } @@ -513,18 +517,18 @@ func (st *SearchTree[V, T]) InOrderTraverse(visitFunc VisitFunc[V, T]) error { return inOrder(n.Right) } - err := inOrder(tree.root) + err := inOrder(st.root) // Do not percolate StopTraversal error to the caller. if errors.Is(err, StopTraversal) { - return nil + return nil } return err } // InOrderTraverse traverses the tree in order and applies VisitFunc to each node. It's safe for concurrent use. To prevent deadlock, avoid calling other tree methods within visitFunc. -func (st *MultiValueSearchTree[V, T]) InOrderTraverse(visitFunc VisitFunc[V, T]) error { - tree.mu.RLock() - defer tree.mu.RUnlock() +func (st *MultiValueSearchTree[V, T]) InOrderTraverse(visitFunc VisitFunc[[]V, T]) error { + st.mu.RLock() + defer st.mu.RUnlock() var inOrder func(n *node[V, T]) error inOrder = func(n *node[V, T]) error { @@ -538,7 +542,7 @@ func (st *MultiValueSearchTree[V, T]) InOrderTraverse(visitFunc VisitFunc[V, T]) } // Visit current node - err := visitFunc(n.Interval.Vals, n.Interval.Start) + err := visitFunc(n.Interval.Start, n.Interval.End, n.Interval.Vals) if err != nil { return err } @@ -547,10 +551,10 @@ func (st *MultiValueSearchTree[V, T]) InOrderTraverse(visitFunc VisitFunc[V, T]) return inOrder(n.Right) } - err := inOrder(tree.root) + err := inOrder(st.root) // Do not percolate StopTraversal error to the caller. if errors.Is(err, StopTraversal) { - return nil + return nil } return err } diff --git a/interval/search_test.go b/interval/search_test.go index c737613..3cf87fe 100644 --- a/interval/search_test.go +++ b/interval/search_test.go @@ -587,6 +587,124 @@ func TestSearchTree_Select(t *testing.T) { } } +func TestSearchTree_InOrderTraverse(t *testing.T) { + type insert struct { + start int + end int + val string + } + tests := []struct { + name string + inserts []insert + expectedVisits []string + }{ + { + name: "empty interval", + inserts: []insert{}, + expectedVisits: []string{}, + }, + { + name: "single interval", + inserts: []insert{ + {start: 1, end: 10, val: "node1"}, + }, + expectedVisits: []string{"node1"}, + }, + { + name: "multiple intervals", + inserts: []insert{ + {start: 1, end: 10, val: "node1"}, + {start: 5, end: 15, val: "node2"}, + {start: 10, end: 20, val: "node3"}, + {start: 15, end: 25, val: "node4"}, + {start: 20, end: 30, val: "node5"}, + }, + expectedVisits: []string{"node1", "node2", "node3", "node4", "node5"}, + }, + { + name: "multiple intervals with same end", + inserts: []insert{ + {start: 1, end: 10, val: "node1"}, + {start: 5, end: 15, val: "node2"}, + {start: 10, end: 20, val: "node3"}, + {start: 15, end: 25, val: "node4"}, + {start: 20, end: 30, val: "node5"}, + {start: 25, end: 30, val: "node6"}, + }, + expectedVisits: []string{"node1", "node2", "node3", "node4", "node5", "node6"}, + }, + { + name: "multiple intervals with same end and same start", + inserts: []insert{ + {start: 20, end: 30, val: "node5"}, + {start: 25, end: 30, val: "node6"}, + {start: 15, end: 30, val: "node7"}, + }, + expectedVisits: []string{"node7", "node5", "node6"}, + }, + { + name: "interval spanning entire range", + inserts: []insert{ + {start: 1, end: 5, val: "node1"}, + {start: 5, end: 10, val: "node2"}, + {start: 10, end: 20, val: "node3"}, + {start: 0, end: 30, val: "node4"}, + }, + expectedVisits: []string{"node4", "node1", "node2", "node3"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + st := NewSearchTree[string](func(x, y int) int { return x - y }) + + for _, insert := range tc.inserts { + st.Insert(insert.start, insert.end, insert.val) + } + + got := []string{} + + err := st.InOrderTraverse(func(start, end int, node string) error { + got = append(got, node) + return nil + }) + if err != nil { + t.Fatalf("st.InOrderTraverse(): error %v", err) + } + + if !reflect.DeepEqual(got, tc.expectedVisits) { + t.Errorf("st.InOderTraverse(): got unexpected value %v; want %v", got, tc.expectedVisits) + } + }) + } + + t.Run("stop traversal", func(t *testing.T) { + st := NewSearchTree[string](func(x, y int) int { return x - y }) + st.Insert(17, 19, "node1") + st.Insert(5, 8, "node2") + st.Insert(21, 24, "node3") + st.Insert(4, 8, "node4") + + want := []string{"node4", "node2", "node1"} + got := []string{} + err := st.InOrderTraverse(func(start, end int, node string) error { + got = append(got, node) + if node == "node1" { + return StopTraversal + } + return nil + }) + + if err != nil { + t.Fatalf("st.InOrderTraverse(): error %v", err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("st.InOderTraverse(): got unexpected value %v; want %v", got, want) + } + }) +} + func TestMultiValueSearchTree_AnyIntersection(t *testing.T) { st := NewMultiValueSearchTree[string](func(x, y int) int { return x - y }) defer mustBeValidTree(t, st.root) @@ -1115,3 +1233,122 @@ func TestMultiValueSearchTree_MaxEnd(t *testing.T) { } } + +func TestMultiValueSearchTree_InOrderTraverse(t *testing.T) { + type insert struct { + start int + end int + val string + } + tests := []struct { + name string + inserts []insert + expectedVisits [][]string + }{ + { + name: "empty interval", + inserts: []insert{}, + expectedVisits: [][]string{}, + }, + { + name: "single interval", + inserts: []insert{ + {start: 1, end: 10, val: "node1"}, + {start: 1, end: 10, val: "node2"}, + }, + expectedVisits: [][]string{{"node1", "node2"}}, + }, + { + name: "multiple intervals", + inserts: []insert{ + {start: 1, end: 10, val: "node1"}, + {start: 5, end: 15, val: "node2"}, + {start: 10, end: 20, val: "node3"}, + {start: 15, end: 25, val: "node4"}, + {start: 20, end: 30, val: "node5"}, + }, + expectedVisits: [][]string{{"node1"}, {"node2"}, {"node3"}, {"node4"}, {"node5"}}, + }, + { + name: "multiple intervals with same end", + inserts: []insert{ + {start: 1, end: 10, val: "node1"}, + {start: 5, end: 15, val: "node2"}, + {start: 10, end: 20, val: "node3"}, + {start: 15, end: 25, val: "node4"}, + {start: 20, end: 30, val: "node5"}, + {start: 25, end: 30, val: "node6"}, + }, + expectedVisits: [][]string{{"node1"}, {"node2"}, {"node3"}, {"node4"}, {"node5"}, {"node6"}}, + }, + { + name: "multiple intervals with same end and same start", + inserts: []insert{ + {start: 20, end: 30, val: "node5"}, + {start: 25, end: 30, val: "node6"}, + {start: 15, end: 30, val: "node7"}, + }, + expectedVisits: [][]string{{"node7"}, {"node5"}, {"node6"}}, + }, + { + name: "interval spanning entire range", + inserts: []insert{ + {start: 1, end: 5, val: "node1"}, + {start: 5, end: 10, val: "node2"}, + {start: 10, end: 20, val: "node3"}, + {start: 0, end: 30, val: "node4"}, + }, + expectedVisits: [][]string{{"node4"}, {"node1"}, {"node2"}, {"node3"}}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + st := NewMultiValueSearchTree[string](func(x, y int) int { return x - y }) + + for _, insert := range tc.inserts { + st.Insert(insert.start, insert.end, insert.val) + } + + got := [][]string{} + + err := st.InOrderTraverse(func(start, end int, node []string) error { + got = append(got, node) + return nil + }) + if err != nil { + t.Fatalf("st.InOrderTraverse(): error %v", err) + } + + if !reflect.DeepEqual(got, tc.expectedVisits) { + t.Errorf("st.InOderTraverse(): got unexpected value %v; want %v", got, tc.expectedVisits) + } + }) + } + + t.Run("stop traversal", func(t *testing.T) { + st := NewMultiValueSearchTree[string](func(x, y int) int { return x - y }) + st.Insert(17, 19, "node1") + st.Insert(5, 8, "node2") + st.Insert(21, 24, "node3") + st.Insert(4, 8, "node4") + + want := [][]string{{"node4"}, {"node2"}, {"node1"}} + got := [][]string{} + err := st.InOrderTraverse(func(start, end int, node []string) error { + got = append(got, node) + if node[0] == "node1" { + return StopTraversal + } + return nil + }) + + if err != nil { + t.Fatalf("st.InOrderTraverse(): error %v", err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("st.InOderTraverse(): got unexpected value %v; want %v", got, want) + } + }) +}