diff --git a/src/neo/IO/ByteArrayComparer.cs b/src/neo/IO/ByteArrayComparer.cs index 478c6e3c7d..8e9b2573c0 100644 --- a/src/neo/IO/ByteArrayComparer.cs +++ b/src/neo/IO/ByteArrayComparer.cs @@ -1,13 +1,30 @@ using System; using System.Collections.Generic; +using System.Runtime.CompilerServices; namespace Neo.IO { internal class ByteArrayComparer : IComparer { - public static readonly ByteArrayComparer Default = new ByteArrayComparer(); + public static readonly ByteArrayComparer Default = new ByteArrayComparer(1); + public static readonly ByteArrayComparer Reverse = new ByteArrayComparer(-1); + + private readonly int direction; + + private ByteArrayComparer(int direction) + { + this.direction = direction; + } public int Compare(byte[] x, byte[] y) + { + return direction > 0 + ? CompareInternal(x, y) + : -CompareInternal(x, y); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int CompareInternal(byte[] x, byte[] y) { int length = Math.Min(x.Length, y.Length); for (int i = 0; i < length; i++) diff --git a/src/neo/IO/Caching/CloneCache.cs b/src/neo/IO/Caching/CloneCache.cs index 559ee3d279..93e01e115a 100644 --- a/src/neo/IO/Caching/CloneCache.cs +++ b/src/neo/IO/Caching/CloneCache.cs @@ -24,15 +24,15 @@ protected override void DeleteInternal(TKey key) innerCache.Delete(key); } - protected override IEnumerable<(TKey, TValue)> FindInternal(byte[] key_prefix) + protected override TValue GetInternal(TKey key) { - foreach (var (key, value) in innerCache.Find(key_prefix)) - yield return (key, value.Clone()); + return innerCache[key].Clone(); } - protected override TValue GetInternal(TKey key) + protected override IEnumerable<(TKey, TValue)> SeekInternal(byte[] keyOrPreifx, SeekDirection direction) { - return innerCache[key].Clone(); + foreach (var (key, value) in innerCache.Seek(keyOrPreifx, direction)) + yield return (key, value.Clone()); } protected override TValue TryGetInternal(TKey key) diff --git a/src/neo/IO/Caching/DataCache.cs b/src/neo/IO/Caching/DataCache.cs index 4b39b5b7cc..21e4f95bb4 100644 --- a/src/neo/IO/Caching/DataCache.cs +++ b/src/neo/IO/Caching/DataCache.cs @@ -151,57 +151,18 @@ public void Delete(TKey key) /// Entries found with the desired prefix public IEnumerable<(TKey Key, TValue Value)> Find(byte[] key_prefix = null) { - IEnumerable<(byte[], TKey, TValue)> cached; - HashSet cachedKeySet; - lock (dictionary) - { - cached = dictionary - .Where(p => p.Value.State != TrackState.Deleted && (key_prefix == null || p.Key.ToArray().AsSpan().StartsWith(key_prefix))) - .Select(p => - ( - KeyBytes: p.Key.ToArray(), - p.Key, - p.Value.Item - )) - .OrderBy(p => p.KeyBytes, ByteArrayComparer.Default) - .ToArray(); - cachedKeySet = new HashSet(dictionary.Keys); - } - var uncached = FindInternal(key_prefix ?? Array.Empty()) - .Where(p => !cachedKeySet.Contains(p.Key)) - .Select(p => - ( - KeyBytes: p.Key.ToArray(), - p.Key, - p.Value - )); - using (var e1 = cached.GetEnumerator()) - using (var e2 = uncached.GetEnumerator()) - { - (byte[] KeyBytes, TKey Key, TValue Item) i1, i2; - bool c1 = e1.MoveNext(); - bool c2 = e2.MoveNext(); - i1 = c1 ? e1.Current : default; - i2 = c2 ? e2.Current : default; - while (c1 || c2) - { - if (!c2 || (c1 && ByteArrayComparer.Default.Compare(i1.KeyBytes, i2.KeyBytes) < 0)) - { - yield return (i1.Key, i1.Item); - c1 = e1.MoveNext(); - i1 = c1 ? e1.Current : default; - } - else - { - yield return (i2.Key, i2.Item); - c2 = e2.MoveNext(); - i2 = c2 ? e2.Current : default; - } - } - } + foreach (var (key, value) in Seek(key_prefix, SeekDirection.Forward)) + if (key.ToArray().AsSpan().StartsWith(key_prefix)) + yield return (key, value); } - protected abstract IEnumerable<(TKey Key, TValue Value)> FindInternal(byte[] key_prefix); + public IEnumerable<(TKey Key, TValue Value)> FindRange(TKey start, TKey end) + { + var endKey = end.ToArray(); + foreach (var (key, value) in Seek(start.ToArray(), SeekDirection.Forward)) + if (ByteArrayComparer.Default.Compare(key.ToArray(), endKey) < 0) + yield return (key, value); + } public IEnumerable GetChangeSet() { @@ -299,6 +260,67 @@ public TValue GetOrAdd(TKey key, Func factory) } } + /// + /// Seek to the entry with specific key + /// + /// The key to be sought + /// The direction of seek + /// An enumerator containing all the entries after seeking. + public IEnumerable<(TKey Key, TValue Value)> Seek(byte[] keyOrPrefix = null, SeekDirection direction = SeekDirection.Forward) + { + IEnumerable<(byte[], TKey, TValue)> cached; + HashSet cachedKeySet; + ByteArrayComparer comparer = direction == SeekDirection.Forward ? ByteArrayComparer.Default : ByteArrayComparer.Reverse; + lock (dictionary) + { + cached = dictionary + .Where(p => p.Value.State != TrackState.Deleted && (keyOrPrefix == null || comparer.Compare(p.Key.ToArray(), keyOrPrefix) >= 0)) + .Select(p => + ( + KeyBytes: p.Key.ToArray(), + p.Key, + p.Value.Item + )) + .OrderBy(p => p.KeyBytes, comparer) + .ToArray(); + cachedKeySet = new HashSet(dictionary.Keys); + } + var uncached = SeekInternal(keyOrPrefix ?? Array.Empty(), direction) + .Where(p => !cachedKeySet.Contains(p.Key)) + .Select(p => + ( + KeyBytes: p.Key.ToArray(), + p.Key, + p.Value + )); + using (var e1 = cached.GetEnumerator()) + using (var e2 = uncached.GetEnumerator()) + { + (byte[] KeyBytes, TKey Key, TValue Item) i1, i2; + bool c1 = e1.MoveNext(); + bool c2 = e2.MoveNext(); + i1 = c1 ? e1.Current : default; + i2 = c2 ? e2.Current : default; + while (c1 || c2) + { + if (!c2 || (c1 && comparer.Compare(i1.KeyBytes, i2.KeyBytes) < 0)) + { + yield return (i1.Key, i1.Item); + c1 = e1.MoveNext(); + i1 = c1 ? e1.Current : default; + } + else + { + yield return (i2.Key, i2.Item); + c2 = e2.MoveNext(); + i2 = c2 ? e2.Current : default; + } + } + } + } + + protected abstract IEnumerable<(TKey Key, TValue Value)> SeekInternal(byte[] keyOrPrefix, SeekDirection direction); + public TValue TryGet(TKey key) { lock (dictionary) diff --git a/src/neo/IO/Caching/SeekDirection.cs b/src/neo/IO/Caching/SeekDirection.cs new file mode 100644 index 0000000000..5387fd8311 --- /dev/null +++ b/src/neo/IO/Caching/SeekDirection.cs @@ -0,0 +1,8 @@ +namespace Neo.IO.Caching +{ + public enum SeekDirection : sbyte + { + Forward = 1, + Backward = -1 + } +} diff --git a/src/neo/Persistence/IReadOnlyStore.cs b/src/neo/Persistence/IReadOnlyStore.cs index 7a23bd4c80..234a36f534 100644 --- a/src/neo/Persistence/IReadOnlyStore.cs +++ b/src/neo/Persistence/IReadOnlyStore.cs @@ -1,3 +1,4 @@ +using Neo.IO.Caching; using System.Collections.Generic; namespace Neo.Persistence @@ -7,7 +8,7 @@ namespace Neo.Persistence /// public interface IReadOnlyStore { - IEnumerable<(byte[] Key, byte[] Value)> Find(byte table, byte[] prefix); + IEnumerable<(byte[] Key, byte[] Value)> Seek(byte table, byte[] key, SeekDirection direction); byte[] TryGet(byte table, byte[] key); } } diff --git a/src/neo/Persistence/MemorySnapshot.cs b/src/neo/Persistence/MemorySnapshot.cs index 5b1dc35742..1474b26df7 100644 --- a/src/neo/Persistence/MemorySnapshot.cs +++ b/src/neo/Persistence/MemorySnapshot.cs @@ -1,4 +1,5 @@ using Neo.IO; +using Neo.IO.Caching; using System; using System.Collections.Concurrent; using System.Collections.Generic; @@ -41,18 +42,19 @@ public void Dispose() { } - public IEnumerable<(byte[] Key, byte[] Value)> Find(byte table, byte[] prefix) + public void Put(byte table, byte[] key, byte[] value) { - IEnumerable> records = immutableData[table]; - if (prefix?.Length > 0) - records = records.Where(p => p.Key.AsSpan().StartsWith(prefix)); - records = records.OrderBy(p => p.Key, ByteArrayComparer.Default); - return records.Select(p => (p.Key, p.Value)); + writeBatch[table][key.EnsureNotNull()] = value; } - public void Put(byte table, byte[] key, byte[] value) + public IEnumerable<(byte[] Key, byte[] Value)> Seek(byte table, byte[] keyOrPrefix, SeekDirection direction = SeekDirection.Forward) { - writeBatch[table][key.EnsureNotNull()] = value; + ByteArrayComparer comparer = direction == SeekDirection.Forward ? ByteArrayComparer.Default : ByteArrayComparer.Reverse; + IEnumerable> records = immutableData[table]; + if (keyOrPrefix?.Length > 0) + records = records.Where(p => comparer.Compare(p.Key, keyOrPrefix) >= 0); + records = records.OrderBy(p => p.Key, comparer); + return records.Select(p => (p.Key, p.Value)); } public byte[] TryGet(byte table, byte[] key) diff --git a/src/neo/Persistence/MemoryStore.cs b/src/neo/Persistence/MemoryStore.cs index 5b6c09c58f..67bdd04152 100644 --- a/src/neo/Persistence/MemoryStore.cs +++ b/src/neo/Persistence/MemoryStore.cs @@ -1,4 +1,5 @@ using Neo.IO; +using Neo.IO.Caching; using System; using System.Collections.Concurrent; using System.Collections.Generic; @@ -26,16 +27,6 @@ public void Dispose() { } - public IEnumerable<(byte[] Key, byte[] Value)> Find(byte table, byte[] prefix) - { - IEnumerable> records = innerData[table]; - if (prefix?.Length > 0) - records = records.Where(p => p.Key.AsSpan().StartsWith(prefix)); - records = records.OrderBy(p => p.Key, ByteArrayComparer.Default); - foreach (var pair in records) - yield return (pair.Key, pair.Value); - } - public ISnapshot GetSnapshot() { return new MemorySnapshot(innerData); @@ -46,6 +37,17 @@ public void Put(byte table, byte[] key, byte[] value) innerData[table][key.EnsureNotNull()] = value; } + public IEnumerable<(byte[] Key, byte[] Value)> Seek(byte table, byte[] keyOrPrefix, SeekDirection direction = SeekDirection.Forward) + { + ByteArrayComparer comparer = direction == SeekDirection.Forward ? ByteArrayComparer.Default : ByteArrayComparer.Reverse; + IEnumerable> records = innerData[table]; + if (keyOrPrefix?.Length > 0) + records = records.Where(p => comparer.Compare(p.Key, keyOrPrefix) >= 0); + records = records.OrderBy(p => p.Key, comparer); + foreach (var pair in records) + yield return (pair.Key, pair.Value); + } + public byte[] TryGet(byte table, byte[] key) { innerData[table].TryGetValue(key.EnsureNotNull(), out byte[] value); diff --git a/src/neo/Persistence/StoreDataCache.cs b/src/neo/Persistence/StoreDataCache.cs index 995a34290f..831c88dce1 100644 --- a/src/neo/Persistence/StoreDataCache.cs +++ b/src/neo/Persistence/StoreDataCache.cs @@ -31,14 +31,14 @@ protected override void DeleteInternal(TKey key) snapshot?.Delete(prefix, key.ToArray()); } - protected override IEnumerable<(TKey, TValue)> FindInternal(byte[] key_prefix) + protected override TValue GetInternal(TKey key) { - return store.Find(prefix, key_prefix).Select(p => (p.Key.AsSerializable(), p.Value.AsSerializable())); + return store.TryGet(prefix, key.ToArray()).AsSerializable(); } - protected override TValue GetInternal(TKey key) + protected override IEnumerable<(TKey, TValue)> SeekInternal(byte[] keyOrPrefix, SeekDirection direction) { - return store.TryGet(prefix, key.ToArray()).AsSerializable(); + return store.Seek(prefix, keyOrPrefix, direction).Select(p => (p.Key.AsSerializable(), p.Value.AsSerializable())); } protected override TValue TryGetInternal(TKey key) diff --git a/tests/neo.UnitTests/IO/Caching/UT_DataCache.cs b/tests/neo.UnitTests/IO/Caching/UT_DataCache.cs index 2b07c8fa0c..74f3213f82 100644 --- a/tests/neo.UnitTests/IO/Caching/UT_DataCache.cs +++ b/tests/neo.UnitTests/IO/Caching/UT_DataCache.cs @@ -115,9 +115,10 @@ protected override void AddInternal(TKey key, TValue value) InnerDict.Add(key, value); } - protected override IEnumerable<(TKey, TValue)> FindInternal(byte[] key_prefix) + protected override IEnumerable<(TKey, TValue)> SeekInternal(byte[] keyOrPrefix, SeekDirection direction = SeekDirection.Forward) { - return InnerDict.Where(kvp => kvp.Key.ToArray().Take(key_prefix.Length).SequenceEqual(key_prefix)).Select(p => (p.Key, p.Value)); + ByteArrayComparer comparer = direction == SeekDirection.Forward ? ByteArrayComparer.Default : ByteArrayComparer.Reverse; + return InnerDict.Where(kvp => comparer.Compare(kvp.Key.ToArray(), keyOrPrefix) >= 0).Select(p => (p.Key, p.Value)); } protected override TValue GetInternal(TKey key) @@ -264,6 +265,43 @@ public void TestFind() items.Count().Should().Be(0); } + [TestMethod] + public void TestSeek() + { + myDataCache.Add(new MyKey("key1"), new MyValue("value1")); + myDataCache.Add(new MyKey("key2"), new MyValue("value2")); + + myDataCache.InnerDict.Add(new MyKey("key3"), new MyValue("value3")); + myDataCache.InnerDict.Add(new MyKey("key4"), new MyValue("value4")); + + var items = myDataCache.Seek(new MyKey("key3").ToArray(), SeekDirection.Backward).ToArray(); + items[0].Key.Should().Be(new MyKey("key3")); + items[0].Value.Should().Be(new MyValue("value3")); + items[1].Key.Should().Be(new MyKey("key2")); + items[1].Value.Should().Be(new MyValue("value2")); + items.Count().Should().Be(3); + + items = myDataCache.Seek(new MyKey("key5").ToArray(), SeekDirection.Forward).ToArray(); + items.Count().Should().Be(0); + } + + [TestMethod] + public void TestFindRange() + { + myDataCache.Add(new MyKey("key1"), new MyValue("value1")); + myDataCache.Add(new MyKey("key2"), new MyValue("value2")); + + myDataCache.InnerDict.Add(new MyKey("key3"), new MyValue("value3")); + myDataCache.InnerDict.Add(new MyKey("key4"), new MyValue("value4")); + + var items = myDataCache.FindRange(new MyKey("key3"), new MyKey("key5")).ToArray(); + items[0].Key.Should().Be(new MyKey("key3")); + items[0].Value.Should().Be(new MyValue("value3")); + items[1].Key.Should().Be(new MyKey("key4")); + items[1].Value.Should().Be(new MyValue("value4")); + items.Count().Should().Be(2); + } + [TestMethod] public void TestGetChangeSet() { diff --git a/tests/neo.UnitTests/IO/UT_ByteArrayComparer.cs b/tests/neo.UnitTests/IO/UT_ByteArrayComparer.cs index 562e2b6f6e..89c0c621d7 100644 --- a/tests/neo.UnitTests/IO/UT_ByteArrayComparer.cs +++ b/tests/neo.UnitTests/IO/UT_ByteArrayComparer.cs @@ -10,7 +10,7 @@ public class UT_ByteArrayComparer [TestMethod] public void TestCompare() { - ByteArrayComparer comparer = new ByteArrayComparer(); + ByteArrayComparer comparer = ByteArrayComparer.Default; byte[] x = new byte[0], y = new byte[0]; comparer.Compare(x, y).Should().Be(0); @@ -22,6 +22,16 @@ public void TestCompare() x = new byte[] { 1 }; y = new byte[] { 2 }; comparer.Compare(x, y).Should().Be(-1); + + comparer = ByteArrayComparer.Reverse; + x = new byte[] { 3 }; + comparer.Compare(x, y).Should().Be(-1); + y = x; + comparer.Compare(x, y).Should().Be(0); + + x = new byte[] { 1 }; + y = new byte[] { 2 }; + comparer.Compare(x, y).Should().Be(1); } } }