diff --git a/contracts/StorageContract.sol b/contracts/StorageContract.sol index ffe940a..b6b5c55 100644 --- a/contracts/StorageContract.sol +++ b/contracts/StorageContract.sol @@ -1,6 +1,8 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.0; +// TODO: upgrade OpenZeppelin to next release and import "@openzeppelin/contracts/utils/ReentrancyGuardTransient.sol" +import "./libraries/ReentrancyGuardTransient.sol"; import "./DecentralizedKV.sol"; import "./libraries/MiningLib.sol"; import "./libraries/RandaoLib.sol"; @@ -8,7 +10,7 @@ import "./libraries/RandaoLib.sol"; /// @custom:upgradeable /// @title StorageContract /// @notice EthStorage L1 Contract with Decentralized KV Interface and Proof of Storage Verification -abstract contract StorageContract is DecentralizedKV { +abstract contract StorageContract is DecentralizedKV, ReentrancyGuardTransient { /// @notice Represents the configuration of the storage contract. /// @custom:field maxKvSizeBits Maximum size of a single key-value pair. /// @custom:field shardSizeBits Storage shard size. @@ -78,7 +80,7 @@ abstract contract StorageContract is DecentralizedKV { /// @notice Treasury address address public treasury; - /// @notice + /// @notice Prepaid timestamp of last mined uint256 public prepaidLastMineTime; // TODO: Reserve extra slots (to a total of 50?) in the storage layout for future upgrades @@ -239,7 +241,10 @@ abstract contract StorageContract is DecentralizedKV { MiningLib.update(infos[_shardId], _minedTs, _diff); require(treasuryReward + minerReward <= address(this).balance, "StorageContract: not enough balance"); - // TODO: avoid reentrancy attack + // Actually `transfer` is limited by the amount of gas allocated, which is not sufficient to enable reentrancy attacks. + // However, this behavior may restrict the extensibility of scenarios where the receiver is a contract that requires + // additional gas for its fallback functions of proper operations. + // Therefore, we use `ReentrancyGuard` in case `call` replaces `transfer` in the future. payable(treasury).transfer(treasuryReward); payable(_miner).transfer(minerReward); emit MinedBlock(_shardId, _diff, infos[_shardId].blockMined, _minedTs, _miner, minerReward); @@ -307,7 +312,7 @@ abstract contract StorageContract is DecentralizedKV { bytes calldata _randaoProof, bytes[] calldata _inclusiveProofs, bytes[] calldata _decodeProof - ) public virtual { + ) public virtual nonReentrant { _mine( _blockNum, _shardId, _miner, _nonce, _encodedSamples, _masks, _randaoProof, _inclusiveProofs, _decodeProof ); diff --git a/contracts/libraries/ReentrancyGuardTransient.sol b/contracts/libraries/ReentrancyGuardTransient.sol new file mode 100644 index 0000000..0389b86 --- /dev/null +++ b/contracts/libraries/ReentrancyGuardTransient.sol @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.24; + +import {StorageSlot} from "./StorageSlot.sol"; + +/** + * @dev Variant of {ReentrancyGuard} that uses transient storage. + * + * NOTE: This variant only works on networks where EIP-1153 is available. + */ +abstract contract ReentrancyGuardTransient { + using StorageSlot for *; + + // keccak256(abi.encode(uint256(keccak256("openzeppelin.storage.ReentrancyGuard")) - 1)) & ~bytes32(uint256(0xff)) + bytes32 private constant REENTRANCY_GUARD_STORAGE = + 0x9b779b17422d0df92223018b32b4d1fa46e071723d6817e2486d003becc55f00; + + /** + * @dev Unauthorized reentrant call. + */ + error ReentrancyGuardReentrantCall(); + + /** + * @dev Prevents a contract from calling itself, directly or indirectly. + * Calling a `nonReentrant` function from another `nonReentrant` + * function is not supported. It is possible to prevent this from happening + * by making the `nonReentrant` function external, and making it call a + * `private` function that does the actual work. + */ + modifier nonReentrant() { + _nonReentrantBefore(); + _; + _nonReentrantAfter(); + } + + function _nonReentrantBefore() private { + // On the first call to nonReentrant, _status will be NOT_ENTERED + if (_reentrancyGuardEntered()) { + revert ReentrancyGuardReentrantCall(); + } + + // Any calls to nonReentrant after this point will fail + REENTRANCY_GUARD_STORAGE.asBoolean().tstore(true); + } + + function _nonReentrantAfter() private { + REENTRANCY_GUARD_STORAGE.asBoolean().tstore(false); + } + + /** + * @dev Returns true if the reentrancy guard is currently set to "entered", which indicates there is a + * `nonReentrant` function in the call stack. + */ + function _reentrancyGuardEntered() internal view returns (bool) { + return REENTRANCY_GUARD_STORAGE.asBoolean().tload(); + } +} diff --git a/contracts/libraries/StorageSlot.sol b/contracts/libraries/StorageSlot.sol new file mode 100644 index 0000000..2e4f736 --- /dev/null +++ b/contracts/libraries/StorageSlot.sol @@ -0,0 +1,311 @@ +// SPDX-License-Identifier: MIT +// OpenZeppelin Contracts (last updated v5.0.0) (utils/StorageSlot.sol) +// This file was procedurally generated from scripts/generate/templates/StorageSlot.js. + +pragma solidity ^0.8.24; + +/** + * @dev Library for reading and writing primitive types to specific storage slots. + * + * Storage slots are often used to avoid storage conflict when dealing with upgradeable contracts. + * This library helps with reading and writing to such slots without the need for inline assembly. + * + * The functions in this library return Slot structs that contain a `value` member that can be used to read or write. + * + * Example usage to set ERC-1967 implementation slot: + * ```solidity + * contract ERC1967 { + * // Define the slot. Alternatively, use the SlotDerivation library to derive the slot. + * bytes32 internal constant _IMPLEMENTATION_SLOT = 0x360894a13ba1a3210667c828492db98dca3e2076cc3735a920a3ca505d382bbc; + * + * function _getImplementation() internal view returns (address) { + * return StorageSlot.getAddressSlot(_IMPLEMENTATION_SLOT).value; + * } + * + * function _setImplementation(address newImplementation) internal { + * require(newImplementation.code.length > 0); + * StorageSlot.getAddressSlot(_IMPLEMENTATION_SLOT).value = newImplementation; + * } + * } + * ``` + * + * Since version 5.1, this library also support writing and reading value types to and from transient storage. + * + * * Example using transient storage: + * ```solidity + * contract Lock { + * // Define the slot. Alternatively, use the SlotDerivation library to derive the slot. + * bytes32 internal constant _LOCK_SLOT = 0xf4678858b2b588224636b8522b729e7722d32fc491da849ed75b3fdf3c84f542; + * + * modifier locked() { + * require(!_LOCK_SLOT.asBoolean().tload()); + * + * _LOCK_SLOT.asBoolean().tstore(true); + * _; + * _LOCK_SLOT.asBoolean().tstore(false); + * } + * } + * ``` + * + * TIP: Consider using this library along with {SlotDerivation}. + */ +library StorageSlot { + struct AddressSlot { + address value; + } + + struct BooleanSlot { + bool value; + } + + struct Bytes32Slot { + bytes32 value; + } + + struct Uint256Slot { + uint256 value; + } + + struct Int256Slot { + int256 value; + } + + struct StringSlot { + string value; + } + + struct BytesSlot { + bytes value; + } + + /** + * @dev Returns an `AddressSlot` with member `value` located at `slot`. + */ + function getAddressSlot(bytes32 slot) internal pure returns (AddressSlot storage r) { + assembly ("memory-safe") { + r.slot := slot + } + } + + /** + * @dev Returns a `BooleanSlot` with member `value` located at `slot`. + */ + function getBooleanSlot(bytes32 slot) internal pure returns (BooleanSlot storage r) { + assembly ("memory-safe") { + r.slot := slot + } + } + + /** + * @dev Returns a `Bytes32Slot` with member `value` located at `slot`. + */ + function getBytes32Slot(bytes32 slot) internal pure returns (Bytes32Slot storage r) { + assembly ("memory-safe") { + r.slot := slot + } + } + + /** + * @dev Returns a `Uint256Slot` with member `value` located at `slot`. + */ + function getUint256Slot(bytes32 slot) internal pure returns (Uint256Slot storage r) { + assembly ("memory-safe") { + r.slot := slot + } + } + + /** + * @dev Returns a `Int256Slot` with member `value` located at `slot`. + */ + function getInt256Slot(bytes32 slot) internal pure returns (Int256Slot storage r) { + assembly ("memory-safe") { + r.slot := slot + } + } + + /** + * @dev Returns a `StringSlot` with member `value` located at `slot`. + */ + function getStringSlot(bytes32 slot) internal pure returns (StringSlot storage r) { + assembly ("memory-safe") { + r.slot := slot + } + } + + /** + * @dev Returns an `StringSlot` representation of the string storage pointer `store`. + */ + function getStringSlot(string storage store) internal pure returns (StringSlot storage r) { + assembly ("memory-safe") { + r.slot := store.slot + } + } + + /** + * @dev Returns a `BytesSlot` with member `value` located at `slot`. + */ + function getBytesSlot(bytes32 slot) internal pure returns (BytesSlot storage r) { + assembly ("memory-safe") { + r.slot := slot + } + } + + /** + * @dev Returns an `BytesSlot` representation of the bytes storage pointer `store`. + */ + function getBytesSlot(bytes storage store) internal pure returns (BytesSlot storage r) { + assembly ("memory-safe") { + r.slot := store.slot + } + } + + /** + * @dev UDVT that represent a slot holding a address. + */ + type AddressSlotType is bytes32; + + /** + * @dev Cast an arbitrary slot to a AddressSlotType. + */ + function asAddress(bytes32 slot) internal pure returns (AddressSlotType) { + return AddressSlotType.wrap(slot); + } + + /** + * @dev UDVT that represent a slot holding a bool. + */ + type BooleanSlotType is bytes32; + + /** + * @dev Cast an arbitrary slot to a BooleanSlotType. + */ + function asBoolean(bytes32 slot) internal pure returns (BooleanSlotType) { + return BooleanSlotType.wrap(slot); + } + + /** + * @dev UDVT that represent a slot holding a bytes32. + */ + type Bytes32SlotType is bytes32; + + /** + * @dev Cast an arbitrary slot to a Bytes32SlotType. + */ + function asBytes32(bytes32 slot) internal pure returns (Bytes32SlotType) { + return Bytes32SlotType.wrap(slot); + } + + /** + * @dev UDVT that represent a slot holding a uint256. + */ + type Uint256SlotType is bytes32; + + /** + * @dev Cast an arbitrary slot to a Uint256SlotType. + */ + function asUint256(bytes32 slot) internal pure returns (Uint256SlotType) { + return Uint256SlotType.wrap(slot); + } + + /** + * @dev UDVT that represent a slot holding a int256. + */ + type Int256SlotType is bytes32; + + /** + * @dev Cast an arbitrary slot to a Int256SlotType. + */ + function asInt256(bytes32 slot) internal pure returns (Int256SlotType) { + return Int256SlotType.wrap(slot); + } + + /** + * @dev Load the value held at location `slot` in transient storage. + */ + function tload(AddressSlotType slot) internal view returns (address value) { + assembly ("memory-safe") { + value := tload(slot) + } + } + + /** + * @dev Store `value` at location `slot` in transient storage. + */ + function tstore(AddressSlotType slot, address value) internal { + assembly ("memory-safe") { + tstore(slot, value) + } + } + + /** + * @dev Load the value held at location `slot` in transient storage. + */ + function tload(BooleanSlotType slot) internal view returns (bool value) { + assembly ("memory-safe") { + value := tload(slot) + } + } + + /** + * @dev Store `value` at location `slot` in transient storage. + */ + function tstore(BooleanSlotType slot, bool value) internal { + assembly ("memory-safe") { + tstore(slot, value) + } + } + + /** + * @dev Load the value held at location `slot` in transient storage. + */ + function tload(Bytes32SlotType slot) internal view returns (bytes32 value) { + assembly ("memory-safe") { + value := tload(slot) + } + } + + /** + * @dev Store `value` at location `slot` in transient storage. + */ + function tstore(Bytes32SlotType slot, bytes32 value) internal { + assembly ("memory-safe") { + tstore(slot, value) + } + } + + /** + * @dev Load the value held at location `slot` in transient storage. + */ + function tload(Uint256SlotType slot) internal view returns (uint256 value) { + assembly ("memory-safe") { + value := tload(slot) + } + } + + /** + * @dev Store `value` at location `slot` in transient storage. + */ + function tstore(Uint256SlotType slot, uint256 value) internal { + assembly ("memory-safe") { + tstore(slot, value) + } + } + + /** + * @dev Load the value held at location `slot` in transient storage. + */ + function tload(Int256SlotType slot) internal view returns (int256 value) { + assembly ("memory-safe") { + value := tload(slot) + } + } + + /** + * @dev Store `value` at location `slot` in transient storage. + */ + function tstore(Int256SlotType slot, int256 value) internal { + assembly ("memory-safe") { + tstore(slot, value) + } + } +} diff --git a/contracts/test/StorageContractTest.t.sol b/contracts/test/StorageContractTest.t.sol index 74197c0..85d0814 100644 --- a/contracts/test/StorageContractTest.t.sol +++ b/contracts/test/StorageContractTest.t.sol @@ -1,13 +1,14 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.19; +import "@openzeppelin/contracts/utils/ReentrancyGuard.sol"; import "./TestStorageContract.sol"; import "../StorageContract.sol"; import "forge-std/Test.sol"; import "forge-std/Vm.sol"; contract StorageContractTest is Test { - uint256 constant STORAGE_COST = 1000; + uint256 constant STORAGE_COST = 10000000; uint256 constant SHARD_SIZE_BITS = 19; uint256 constant MAX_KV_SIZE = 17; uint256 constant PREPAID_AMOUNT = 2 * STORAGE_COST; @@ -15,9 +16,12 @@ contract StorageContractTest is Test { function setUp() public { storageContract = new TestStorageContract( - StorageContract.Config(MAX_KV_SIZE, SHARD_SIZE_BITS, 2, 0, 0, 0), 0, STORAGE_COST, 0 + StorageContract.Config(MAX_KV_SIZE, SHARD_SIZE_BITS, 2, 0, 0, 0), + 0, + STORAGE_COST, + 340282366367469178095360967382638002176 ); - storageContract.initialize(0, PREPAID_AMOUNT, 0, address(0x1), address(0x1)); + storageContract.initialize(0, PREPAID_AMOUNT, 0, vm.addr(1), address(0x1)); } function testMiningReward() public { @@ -40,4 +44,94 @@ contract StorageContractTest is Test { (,, reward) = storageContract.miningRewards(0, 1); assertEq(reward, storageContract.paymentIn(PREPAID_AMOUNT + STORAGE_COST * 2, 0, 1)); } + + function testRewardMiner() public { + address miner = vm.addr(2); + uint256 mineTs = 10000; + uint256 diff = 1; + + vm.expectRevert("StorageContract: not enough balance"); + storageContract.rewardMiner(0, miner, mineTs, 1); + + vm.deal(address(storageContract), 1000); + + (,, uint256 reward) = storageContract.miningRewards(0, mineTs); + storageContract.rewardMiner(0, miner, mineTs, diff); + (uint256 l, uint256 d, uint256 b) = storageContract.infos(0); + assertEq(l, mineTs); + assertEq(d, diff); + assertEq(b, 1); + assertEq(miner.balance, reward); + } + + function testReentrancy() public noGasMetering { + uint256 prefund = 1000; + // Without reentrancy protection, the fund could be drained by 29 times re-entrances given current params. + vm.deal(address(storageContract), prefund); + storageContract.setKvEntryCount(1); + Attacker attacker = new Attacker(storageContract); + vm.prank(address(attacker)); + + uint256 _blockNum = 1; + uint256 _shardId = 0; + uint256 _nonce = 0; + bytes32[] memory _encodedSamples = new bytes32[](0); + uint256[] memory _masks = new uint256[](0); + bytes memory _randaoProof = "0x01"; + bytes[] memory _inclusiveProofs = new bytes[](0); + bytes[] memory _decodeProof = new bytes[](0); + // currently this error is not reachable on github server + // vm.expectRevert(ReentrancyGuard.ReentrancyGuardReentrantCall.selector); + vm.expectRevert(); + storageContract.mine( + _blockNum, + _shardId, + address(attacker), + _nonce, + _encodedSamples, + _masks, + _randaoProof, + _inclusiveProofs, + _decodeProof + ); + } +} + +contract Attacker is Test { + TestStorageContract storageContract; + uint256 blockNumber = 1; + uint256 count = 0; + + constructor(TestStorageContract _storageContract) { + storageContract = _storageContract; + } + + fallback() external payable { + uint256 _shardId = 0; + uint256 _nonce = 0; + bytes32[] memory _encodedSamples = new bytes32[](0); + uint256[] memory _masks = new uint256[](0); + bytes memory _randaoProof = "0x01"; + bytes[] memory _inclusiveProofs = new bytes[](0); + bytes[] memory _decodeProof = new bytes[](0); + + blockNumber += 60; + vm.roll(blockNumber + 20); + vm.warp(block.number * 12); + uint256 reward = storageContract.miningReward(_shardId, blockNumber); + if (address(storageContract).balance >= reward) { + storageContract.mine( + blockNumber, + _shardId, + address(this), + _nonce, + _encodedSamples, + _masks, + _randaoProof, + _inclusiveProofs, + _decodeProof + ); + count++; + } + } } diff --git a/contracts/test/TestStorageContract.sol b/contracts/test/TestStorageContract.sol index f6ff4f9..79ca43e 100644 --- a/contracts/test/TestStorageContract.sol +++ b/contracts/test/TestStorageContract.sol @@ -41,4 +41,23 @@ contract TestStorageContract is StorageContract { function miningRewards(uint256 _shardId, uint256 _minedTs) public view returns (bool, uint256, uint256) { return _miningReward(_shardId, _minedTs); } + + function rewardMiner(uint256 _shardId, address _miner, uint256 _minedTs, uint256 _diff) public { + return _rewardMiner(_shardId, _miner, _minedTs, _diff); + } + + function _mine( + uint256 _blockNum, + uint256 _shardId, + address _miner, + uint256 _nonce, + bytes32[] memory _encodedSamples, + uint256[] memory _masks, + bytes calldata _randaoProof, + bytes[] calldata _inclusiveProofs, + bytes[] calldata _decodeProof + ) internal override { + uint256 mineTs = _getMinedTs(_blockNum); + _rewardMiner(_shardId, _miner, mineTs, 1); + } }