Skip to content

Commit

Permalink
Optimise epoch transition (#5346)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajsutton authored Apr 14, 2022
1 parent e89d37c commit 3efa75f
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 42 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ For information on changes in released versions of Teku, see the [releases page]

### Additions and Improvements
- Improved performance when regenerating non-finalized states that had to be dropped from memory.
- Performance optimizations for Gnosis beacon chain
- Performance optimizations for Gnosis beacon chain.
- Improved performance when processing epoch transitions.

### Bug Fixes
- Added stricter limits on attestation pool size.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ public class EpochTransitionBenchmark {
public void init() throws Exception {
AbstractBlockProcessor.blsVerifyDeposit = false;

spec = TestSpecFactory.createMainnetAltair();
String blocksFile =
"/blocks/blocks_epoch_"
+ spec.getSlotsPerEpoch(UInt64.ZERO)
Expand All @@ -102,7 +103,6 @@ public void init() throws Exception {
BlsKeyPairIO.createReaderForResource(keysFile).readAll(validatorsCount);

final BlockImportNotifications blockImportNotifications = mock(BlockImportNotifications.class);
spec = TestSpecFactory.createMainnetPhase0();
epochProcessor = spec.getGenesisSpec().getEpochProcessor();
wsValidator = WeakSubjectivityFactory.lenientValidator();

Expand Down Expand Up @@ -142,7 +142,7 @@ public void init() throws Exception {
spec.getGenesisSpec()
.getValidatorStatusFactory()
.createValidatorStatuses(preEpochTransitionState);
preEpochTransitionState.updated(mbs -> preEpochTransitionMutableState = mbs);
preEpochTransitionState.hashTreeRoot();
attestationDeltas =
epochProcessor.getRewardAndPenaltyDeltas(preEpochTransitionState, validatorStatuses);

Expand All @@ -153,6 +153,7 @@ public void init() throws Exception {
public void epochTransition(Blackhole bh) {
try {
preEpochTransitionState = epochProcessor.processEpoch(preEpochTransitionState);
bh.consume(preEpochTransitionState.hashTreeRoot());
} catch (EpochProcessingException e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import tech.pegasys.teku.spec.datastructures.state.ForkInfo;
import tech.pegasys.teku.spec.datastructures.state.Validator;
import tech.pegasys.teku.spec.datastructures.state.beaconstate.common.BeaconStateFields;
import tech.pegasys.teku.spec.datastructures.state.beaconstate.common.BeaconStateInvariants;
import tech.pegasys.teku.spec.datastructures.state.beaconstate.common.analysis.ValidatorStats;
import tech.pegasys.teku.spec.datastructures.state.beaconstate.versions.altair.BeaconStateAltair;
import tech.pegasys.teku.spec.datastructures.state.beaconstate.versions.bellatrix.BeaconStateBellatrix;
Expand All @@ -43,17 +44,17 @@ public interface BeaconState extends SszContainer, ValidatorStats {
BeaconStateSchema<? extends BeaconState, ? extends MutableBeaconState> getBeaconStateSchema();

default UInt64 getGenesisTime() {
final int fieldIndex = getSchema().getFieldIndex(BeaconStateFields.GENESIS_TIME);
final int fieldIndex = BeaconStateInvariants.GENESIS_TIME_FIELD.getIndex();
return ((SszUInt64) get(fieldIndex)).get();
}

default Bytes32 getGenesisValidatorsRoot() {
final int fieldIndex = getSchema().getFieldIndex(BeaconStateFields.GENESIS_VALIDATORS_ROOT);
final int fieldIndex = BeaconStateInvariants.GENESIS_VALIDATORS_ROOT_FIELD.getIndex();
return ((SszBytes32) get(fieldIndex)).get();
}

default UInt64 getSlot() {
final int fieldIndex = getSchema().getFieldIndex(BeaconStateFields.SLOT);
final int fieldIndex = BeaconStateInvariants.SLOT_FIELD.getIndex();
return ((SszUInt64) get(fieldIndex)).get();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ public class BeaconStateInvariants {
static final SszSchema<SszUInt64> SLOT_SCHEMA = SszPrimitiveSchemas.UINT64_SCHEMA;

// Fields
static final SszField GENESIS_TIME_FIELD =
public static final SszField GENESIS_TIME_FIELD =
new SszField(0, BeaconStateFields.GENESIS_TIME, GENESIS_TIME_SCHEMA);
static final SszField GENESIS_VALIDATORS_ROOT_FIELD =
public static final SszField GENESIS_VALIDATORS_ROOT_FIELD =
new SszField(1, BeaconStateFields.GENESIS_VALIDATORS_ROOT, GENESIS_VALIDATORS_ROOT_SCHEMA);
static final SszField SLOT_FIELD = new SszField(2, BeaconStateFields.SLOT, SLOT_SCHEMA);
public static final SszField SLOT_FIELD = new SszField(2, BeaconStateFields.SLOT, SLOT_SCHEMA);

// Return list of invariant fields
static List<SszField> getInvariantFields() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@
import static tech.pegasys.teku.spec.constants.ParticipationFlags.TIMELY_HEAD_FLAG_INDEX;
import static tech.pegasys.teku.spec.logic.versions.altair.helpers.MiscHelpersAltair.PARTICIPATION_FLAG_WEIGHTS;

import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import java.util.List;
import tech.pegasys.teku.infrastructure.unsigned.UInt64;
import tech.pegasys.teku.spec.config.SpecConfigAltair;
import tech.pegasys.teku.spec.constants.ParticipationFlags;
import tech.pegasys.teku.spec.datastructures.state.beaconstate.BeaconState;
import tech.pegasys.teku.spec.datastructures.state.beaconstate.versions.altair.BeaconStateAltair;
import tech.pegasys.teku.spec.logic.common.statetransition.epoch.RewardAndPenaltyDeltas;
import tech.pegasys.teku.spec.logic.common.statetransition.epoch.RewardAndPenaltyDeltas.RewardAndPenalty;
Expand All @@ -39,7 +38,6 @@ public class RewardsAndPenaltiesCalculatorAltair extends RewardsAndPenaltiesCalc
private final BeaconStateAccessorsAltair beaconStateAccessorsAltair;

private final BeaconStateAltair stateAltair;
private final Int2ObjectMap<UInt64> baseRewardCache = new Int2ObjectOpenHashMap<>();

public RewardsAndPenaltiesCalculatorAltair(
final SpecConfigAltair specConfig,
Expand All @@ -53,12 +51,7 @@ public RewardsAndPenaltiesCalculatorAltair(
this.beaconStateAccessorsAltair = beaconStateAccessors;
}

/**
* Return attestation reward/penalty deltas for each validator
*
* @return
* @throws IllegalArgumentException
*/
/** Return attestation reward/penalty deltas for each validator */
@Override
public RewardAndPenaltyDeltas getDeltas() throws IllegalArgumentException {
final RewardAndPenaltyDeltas deltas =
Expand Down Expand Up @@ -92,14 +85,19 @@ public void processFlagIndexDeltas(final RewardAndPenaltyDeltas deltas, final in
final UInt64 activeIncrements =
totalBalances.getCurrentEpochActiveValidators().dividedBy(effectiveBalanceIncrement);

// Cache baseRewardPerIncrement - while it is also cached in transition caches,
// looking it up from there for every single validator is quite expensive.
final UInt64 baseRewardPerIncrement =
beaconStateAccessorsAltair.getBaseRewardPerIncrement(state);
for (int i = 0; i < statusList.size(); i++) {
final ValidatorStatus validator = statusList.get(i);
if (!validator.isEligibleValidator()) {
continue;
}
final RewardAndPenalty validatorDeltas = deltas.getDelta(i);

final UInt64 baseReward = getBaseReward(i);
final UInt64 baseReward =
getBaseReward(effectiveBalanceIncrement, baseRewardPerIncrement, validator);
if (isUnslashedPrevEpochParticipatingIndex(validator, flagIndex)) {
if (!isInactivityLeak()) {
final UInt64 rewardNumerator =
Expand All @@ -113,6 +111,24 @@ public void processFlagIndexDeltas(final RewardAndPenaltyDeltas deltas, final in
}
}

/**
* Calculate the base reward for the validator.
*
* <p>This is equivalent to {@link BeaconStateAccessorsAltair#getBaseReward(BeaconState, int)} but
* uses the ValidatorStatus to get the effective balance and uses the precalculated
* baseRewardPerIncrement. This is significantly faster than having to go back to the state for
* the data.
*/
private UInt64 getBaseReward(
final UInt64 effectiveBalanceIncrement,
final UInt64 baseRewardPerIncrement,
final ValidatorStatus validator) {
return validator
.getCurrentEpochEffectiveBalance()
.dividedBy(effectiveBalanceIncrement)
.times(baseRewardPerIncrement);
}

/**
* Corresponds to altair beacon chain accessor get_inactivity_penalty_deltas
*
Expand Down Expand Up @@ -180,9 +196,4 @@ private boolean isUnslashedPrevEpochParticipatingIndex(
return validatorStatus.isNotSlashed()
&& validatorHasPrevEpochParticipationFlag(validatorStatus, flagIndex);
}

private UInt64 getBaseReward(final int validatorIndex) {
return baseRewardCache.computeIfAbsent(
validatorIndex, index -> beaconStateAccessorsAltair.getBaseReward(state, validatorIndex));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.function.Consumer;
import java.util.stream.Stream;
import tech.pegasys.teku.infrastructure.ssz.SszPrimitive;
import tech.pegasys.teku.infrastructure.ssz.collections.SszMutablePrimitiveCollection;
Expand Down Expand Up @@ -67,23 +67,26 @@ protected TreeUpdates changesToNewNodes(
SszCollectionSchema<?, ?> type = getSchema();
int elementsPerChunk = type.getElementsPerChunk();

List<Map.Entry<Integer, SszElementT>> newChildren = newChildValues.collect(Collectors.toList());
int prevChildNodeIndex = 0;
List<NodeUpdate<ElementT, SszElementT>> nodeUpdates = new ArrayList<>();
NodeUpdate<ElementT, SszElementT> curNodeUpdate = null;

for (Map.Entry<Integer, SszElementT> entry : newChildren) {
int childIndex = entry.getKey();
int childNodeIndex = childIndex / elementsPerChunk;

if (curNodeUpdate == null || childNodeIndex != prevChildNodeIndex) {
long gIndex = type.getChildGeneralizedIndex(childNodeIndex);
curNodeUpdate = new NodeUpdate<>(gIndex, elementsPerChunk);
nodeUpdates.add(curNodeUpdate);
prevChildNodeIndex = childNodeIndex;
}
curNodeUpdate.addUpdate(childIndex % elementsPerChunk, entry.getValue());
}
final List<NodeUpdate<ElementT, SszElementT>> nodeUpdates = new ArrayList<>();
newChildValues.forEach(
new Consumer<>() {
private int prevChildNodeIndex = 0;
private NodeUpdate<ElementT, SszElementT> curNodeUpdate = null;

@Override
public void accept(final Map.Entry<Integer, SszElementT> entry) {
int childIndex = entry.getKey();
int childNodeIndex = childIndex / elementsPerChunk;

if (curNodeUpdate == null || childNodeIndex != prevChildNodeIndex) {
long gIndex = type.getChildGeneralizedIndex(childNodeIndex);
curNodeUpdate = new NodeUpdate<>(gIndex, elementsPerChunk);
nodeUpdates.add(curNodeUpdate);
prevChildNodeIndex = childNodeIndex;
}
curNodeUpdate.addUpdate(childIndex % elementsPerChunk, entry.getValue());
}
});

LongList gIndices = new LongArrayList();
List<TreeNode> newValues = new ArrayList<>();
Expand Down

0 comments on commit 3efa75f

Please sign in to comment.