Skip to content

Commit

Permalink
small streams refactoring, tests for accounts
Browse files Browse the repository at this point in the history
Signed-off-by: garyschulte <[email protected]>
  • Loading branch information
garyschulte committed Sep 22, 2023
1 parent 06ff1c4 commit 8179a31
Show file tree
Hide file tree
Showing 7 changed files with 276 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,15 @@
import java.util.function.Predicate;
import java.util.function.Supplier;

import kotlin.Pair;
import org.apache.tuweni.bytes.Bytes;
import org.apache.tuweni.bytes.Bytes32;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@SuppressWarnings("unused")
public class BonsaiWorldStateKeyValueStorage implements WorldStateStorage, FlatWorldStateStorage, AutoCloseable {
public class BonsaiWorldStateKeyValueStorage
implements WorldStateStorage, FlatWorldStateStorage, AutoCloseable {
private static final Logger LOG = LoggerFactory.getLogger(BonsaiWorldStateKeyValueStorage.class);

// 0x776f726c64526f6f74
Expand Down Expand Up @@ -256,6 +258,15 @@ public NavigableMap<Bytes32, Bytes> streamFlatAccounts(
.streamAccountFlatDatabase(composedWorldStateStorage, startKeyHash, endKeyHash, max);
}

@Override
public NavigableMap<Bytes32, Bytes> streamFlatAccounts(
final Bytes startKeyHash,
final Bytes32 endKeyHash,
final Predicate<Pair<Bytes32, Bytes>> takeWhile) {
return getFlatDbStrategy()
.streamAccountFlatDatabase(composedWorldStateStorage, startKeyHash, endKeyHash, takeWhile);
}

@Override
public NavigableMap<Bytes32, Bytes> streamFlatStorages(
final Hash accountHash, final Bytes startKeyHash, final Bytes32 endKeyHash, final long max) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.NavigableMap;
import java.util.Optional;
import java.util.TreeMap;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand Down Expand Up @@ -193,18 +194,18 @@ public NavigableMap<Bytes32, Bytes> streamAccountFlatDatabase(
final Bytes startKeyHash,
final Bytes32 endKeyHash,
final long max) {
final Stream<Pair<Bytes32, Bytes>> pairStream =
storage
.streamFromKey(
ACCOUNT_INFO_STATE, startKeyHash.toArrayUnsafe(), endKeyHash.toArrayUnsafe())
.limit(max)
.map(pair -> new Pair<>(Bytes32.wrap(pair.getKey()), Bytes.wrap(pair.getValue())));

final TreeMap<Bytes32, Bytes> collected =
pairStream.collect(
Collectors.toMap(Pair::getFirst, Pair::getSecond, (v1, v2) -> v1, TreeMap::new));
pairStream.close();
return collected;
return toNavigableMap(accountsToPairStream(storage, startKeyHash, endKeyHash).limit(max));
}

public NavigableMap<Bytes32, Bytes> streamAccountFlatDatabase(
final SegmentedKeyValueStorage storage,
final Bytes startKeyHash,
final Bytes32 endKeyHash,
final Predicate<Pair<Bytes32, Bytes>> takeWhile) {

return toNavigableMap(
accountsToPairStream(storage, startKeyHash, endKeyHash).takeWhile(takeWhile));
}

public NavigableMap<Bytes32, Bytes> streamStorageFlatDatabase(
Expand All @@ -213,19 +214,49 @@ public NavigableMap<Bytes32, Bytes> streamStorageFlatDatabase(
final Bytes startKeyHash,
final Bytes32 endKeyHash,
final long max) {
final Stream<Pair<Bytes32, Bytes>> pairStream =
storage
.streamFromKey(
ACCOUNT_STORAGE_STORAGE,
Bytes.concatenate(accountHash, startKeyHash).toArrayUnsafe(),
Bytes.concatenate(accountHash, endKeyHash).toArrayUnsafe())
.limit(max)
.map(
pair ->
new Pair<>(
Bytes32.wrap(Bytes.wrap(pair.getKey()).slice(Hash.SIZE)),
RLP.encodeValue(Bytes.wrap(pair.getValue()).trimLeadingZeros())));

return toNavigableMap(
storageToPairStream(storage, accountHash, startKeyHash, endKeyHash).limit(max));
}

public NavigableMap<Bytes32, Bytes> streamStorageFlatDatabase(
final SegmentedKeyValueStorage storage,
final Hash accountHash,
final Bytes startKeyHash,
final Bytes32 endKeyHash,
final Predicate<Pair<Bytes32, Bytes>> takeWhile) {

return toNavigableMap(
storageToPairStream(storage, accountHash, startKeyHash, endKeyHash).takeWhile(takeWhile));
}

private static Stream<Pair<Bytes32, Bytes>> storageToPairStream(
final SegmentedKeyValueStorage storage,
final Hash accountHash,
final Bytes startKeyHash,
final Bytes32 endKeyHash) {

return storage
.streamFromKey(
ACCOUNT_STORAGE_STORAGE,
Bytes.concatenate(accountHash, startKeyHash).toArrayUnsafe(),
Bytes.concatenate(accountHash, endKeyHash).toArrayUnsafe())
.map(
pair ->
new Pair<>(
Bytes32.wrap(Bytes.wrap(pair.getKey()).slice(Hash.SIZE)),
RLP.encodeValue(Bytes.wrap(pair.getValue()).trimLeadingZeros())));
}

private static Stream<Pair<Bytes32, Bytes>> accountsToPairStream(
final SegmentedKeyValueStorage storage, final Bytes startKeyHash, final Bytes32 endKeyHash) {
return storage
.streamFromKey(ACCOUNT_INFO_STATE, startKeyHash.toArrayUnsafe(), endKeyHash.toArrayUnsafe())
.map(pair -> new Pair<>(Bytes32.wrap(pair.getKey()), Bytes.wrap(pair.getValue())));
}

private static NavigableMap<Bytes32, Bytes> toNavigableMap(
final Stream<Pair<Bytes32, Bytes>> pairStream) {
final TreeMap<Bytes32, Bytes> collected =
pairStream.collect(
Collectors.toMap(Pair::getFirst, Pair::getSecond, (v1, v2) -> v1, TreeMap::new));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,4 @@ public interface FlatWorldStateStorage {
FlatDbMode getFlatDbMode();

FlatDbStrategy getFlatDbStrategy();

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.TreeMap;
import java.util.function.Predicate;

import kotlin.Pair;
import org.apache.tuweni.bytes.Bytes;
import org.apache.tuweni.bytes.Bytes32;

Expand Down Expand Up @@ -66,6 +67,21 @@ default NavigableMap<Bytes32, Bytes> streamFlatAccounts(
return new TreeMap<>();
}

/**
* Streams flat accounts within a specified range.
*
* @param startKeyHash The start key hash of the range.
* @param endKeyHash The end key hash of the range.
* @param takeWhile Function to limit the number of entries to stream.
* @return A map of flat accounts. (Empty map in this default implementation)
*/
default NavigableMap<Bytes32, Bytes> streamFlatAccounts(
final Bytes startKeyHash,
final Bytes32 endKeyHash,
final Predicate<Pair<Bytes32, Bytes>> takeWhile) {
return new TreeMap<>();
}

/**
* Streams flat storages within a specified range.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,38 +14,48 @@
*/
package org.hyperledger.besu.ethereum.eth.manager.snap;

import kotlin.collections.ArrayDeque;
import org.apache.tuweni.bytes.Bytes;
import org.apache.tuweni.bytes.Bytes32;
import org.apache.tuweni.units.bigints.UInt256;
import org.hyperledger.besu.datatypes.Hash;
import org.hyperledger.besu.ethereum.eth.manager.EthMessages;
import org.hyperledger.besu.ethereum.eth.messages.snap.AccountRangeMessage;
import org.hyperledger.besu.ethereum.eth.messages.snap.ByteCodesMessage;
import org.hyperledger.besu.ethereum.eth.messages.snap.GetAccountRangeMessage;
import org.hyperledger.besu.ethereum.eth.messages.snap.GetStorageRangeMessage;
import org.hyperledger.besu.ethereum.eth.messages.snap.SnapV1;
import org.hyperledger.besu.ethereum.eth.messages.snap.StorageRangeMessage;
import org.hyperledger.besu.ethereum.eth.messages.snap.TrieNodesMessage;
import org.hyperledger.besu.ethereum.p2p.rlpx.wire.MessageData;
import org.hyperledger.besu.ethereum.proof.WorldStateProofProvider;
import org.hyperledger.besu.ethereum.rlp.BytesValueRLPOutput;
import org.hyperledger.besu.ethereum.worldstate.WorldStateArchive;
import org.hyperledger.besu.ethereum.worldstate.WorldStateStorage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collections;
import java.util.HashMap;
import java.util.NavigableMap;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import kotlin.Pair;
import kotlin.collections.ArrayDeque;
import org.apache.tuweni.bytes.Bytes;
import org.apache.tuweni.bytes.Bytes32;
import org.apache.tuweni.units.bigints.UInt256;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** See https://github.com/ethereum/devp2p/blob/master/caps/snap.md */
@SuppressWarnings("unused")
class SnapServer {
private static final Logger LOGGER = LoggerFactory.getLogger(SnapServer.class);
private static final int MAX_ENTRIES_PER_REQUEST = 100000;
private static final int MAX_RESPONSE_SIZE = 2 * 1024 * 1024;
private static final AccountRangeMessage EMPTY_ACCOUNT_RANGE = AccountRangeMessage.create(
new HashMap<>(), new ArrayDeque<>());
private static final AccountRangeMessage EMPTY_ACCOUNT_RANGE =
AccountRangeMessage.create(new HashMap<>(), new ArrayDeque<>());
private static final StorageRangeMessage EMPTY_STORAGE_RANGE =
StorageRangeMessage.create(new ArrayDeque<>(), Collections.emptyList());

private final EthMessages snapMessages;
private final Function<Hash, Optional<WorldStateStorage>> worldStateStorageProvider;
Expand Down Expand Up @@ -77,43 +87,101 @@ private void registerResponseConstructors() {
MessageData constructGetAccountRangeResponse(final MessageData message) {
final GetAccountRangeMessage getAccountRangeMessage = GetAccountRangeMessage.readFrom(message);
final GetAccountRangeMessage.Range range = getAccountRangeMessage.range(true);

final int maxResponseBytes = Math.min(range.responseBytes().intValue(), MAX_RESPONSE_SIZE);

LOGGER.info("Receive get account range message from {} to {}",
range.startKeyHash().toHexString(), range.endKeyHash().toHexString());

var worldStateHash = getAccountRangeMessage
.range(true).worldStateRootHash();

return worldStateStorageProvider.apply(worldStateHash)
.map(storage -> {
NavigableMap<Bytes32, Bytes>
accounts = storage.streamFlatAccounts(range.startKeyHash(), range.endKeyHash(),
MAX_ENTRIES_PER_REQUEST);

if (accounts.isEmpty()) {
// fetch next account after range, if it exists
accounts = storage.streamFlatAccounts(range.endKeyHash(), UInt256.MAX_VALUE, 1L);
}

final var worldStateProof = new WorldStateProofProvider(storage);
final ArrayDeque<Bytes> proof = new ArrayDeque<>(
worldStateProof.getAccountProofRelatedNodes(range.worldStateRootHash(),
Hash.wrap(range.startKeyHash())));
if (!accounts.isEmpty()) {
proof.addAll(worldStateProof.getAccountProofRelatedNodes(range.worldStateRootHash(),
Hash.wrap(accounts.lastKey())));
}
return AccountRangeMessage.create(accounts, proof);

})
// TODO: drop to TRACE
LOGGER.info(
"Receive get account range message from {} to {}",
range.startKeyHash().toHexString(),
range.endKeyHash().toHexString());

var worldStateHash = getAccountRangeMessage.range(true).worldStateRootHash();

return worldStateStorageProvider
.apply(worldStateHash)
.map(
storage -> {
NavigableMap<Bytes32, Bytes> accounts =
storage.streamFlatAccounts(
range.startKeyHash(),
range.endKeyHash(),
takeWhilePredicate(maxResponseBytes));

if (accounts.isEmpty()) {
// fetch next account after range, if it exists
accounts = storage.streamFlatAccounts(range.endKeyHash(), UInt256.MAX_VALUE, 1L);
}

final var worldStateProof = new WorldStateProofProvider(storage);
final ArrayDeque<Bytes> proof =
new ArrayDeque<>(
worldStateProof.getAccountProofRelatedNodes(
range.worldStateRootHash(), Hash.wrap(range.startKeyHash())));
if (!accounts.isEmpty()) {
proof.addAll(
worldStateProof.getAccountProofRelatedNodes(
range.worldStateRootHash(), Hash.wrap(accounts.lastKey())));
}
return AccountRangeMessage.create(accounts, proof);
})
.orElse(EMPTY_ACCOUNT_RANGE);
}

private MessageData constructGetStorageRangeResponse(final MessageData message) {
// TODO implement
return StorageRangeMessage.create(new ArrayDeque<>(), new ArrayDeque<>());
final GetStorageRangeMessage getStorageRangeMessage = GetStorageRangeMessage.readFrom(message);
final GetStorageRangeMessage.StorageRange range = getStorageRangeMessage.range(true);
final int maxResponseBytes = Math.min(range.responseBytes().intValue(), MAX_RESPONSE_SIZE);

Check notice

Code scanning / CodeQL

Unread local variable Note

Variable 'int maxResponseBytes' is never read.

// TODO: drop to TRACE
LOGGER
.atInfo()
.setMessage("Receive get storage range message from {} to {} for {}")
.addArgument(() -> range.startKeyHash().toHexString())
.addArgument(() -> range.endKeyHash())
.addArgument(
() ->
range.hashes().stream()
.map(Bytes32::toHexString)
.collect(Collectors.joining(",", "[", "]")))
.log();

return EMPTY_STORAGE_RANGE;
// return worldStateStorageProvider
// .apply(range.worldStateRootHash())
// .map(
// storage -> {
// NavigableMap<Bytes32, Bytes> accounts =
// storage.streamFlatAccounts(
// range.startKeyHash(), range.endKeyHash(), MAX_ENTRIES_PER_REQUEST);
//
// // for the first account, honor startHash
// Bytes32 startKeyBytes = range.startKeyHash();
// Bytes32 endKeyBytes = range.endKeyHash();
// NavigableMap<Bytes32, Bytes> collectedStorages = new TreeMap<>();
// ArrayList<Bytes> collectedProofs = new ArrayList<>();
// for (var forAccountHash : range.hashes()) {
// var accountStorages =
// storage.streamFlatStorages(
// Hash.wrap(forAccountHash),
// startKeyBytes,
// endKeyBytes,
// MAX_ENTRIES_PER_REQUEST);
// boolean shouldGetMore = false;
//// visitCollectedStorage(collectedStorages, collectedProofs,
// accountStorages, maxResponseBytes);
// // todo proofs for this accountHash
//
// if (shouldGetMore) {
// // reset startkeyBytes for subsequent accounts
// startKeyBytes = Bytes32.ZERO;
// } else {
// break;
// }
// }
//
// return StorageRangeMessage.create(new ArrayDeque<>(), Collections.emptyList());
// })
// .orElse(EMPTY_STORAGE_RANGE);
}

private MessageData constructGetBytecodesResponse(final MessageData message) {

Check notice

Code scanning / CodeQL

Useless parameter Note

The parameter 'message' is never used.
Expand All @@ -125,4 +193,22 @@ private MessageData constructGetTrieNodesResponse(final MessageData message) {
// TODO: what is this expecting? account state tries or storage tries?
return TrieNodesMessage.create(new ArrayDeque<>());
}

private static Predicate<Pair<Bytes32, Bytes>> takeWhilePredicate(final int maxResponseBytes) {
final AtomicInteger byteLimit = new AtomicInteger(0);
final AtomicInteger recordLimit = new AtomicInteger(0);
return pair ->
recordLimit.addAndGet(1) < MAX_ENTRIES_PER_REQUEST
&& byteLimit.accumulateAndGet(
0,
(cur, __) -> {
var rlpOutput = new BytesValueRLPOutput();
rlpOutput.startList();
rlpOutput.writeBytes(pair.getFirst());
rlpOutput.writeRLPBytes(pair.getSecond());
rlpOutput.endList();
return cur + rlpOutput.encoded().size();
})
< maxResponseBytes;
}
}
Loading

0 comments on commit 8179a31

Please sign in to comment.