Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix re-entrance #108

Merged
merged 9 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions contracts/StorageContract.sol
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,20 @@ abstract contract StorageContract is DecentralizedKV {
/// @notice Treasury address
address public treasury;

/// @notice
/// @notice Prepaid timestamp of last mined
uint256 public prepaidLastMineTime;

/// @notice Locker to prevent from reentrancy
bool private locked;

/// @notice Prevent from reentrancy
modifier noReentrant() {
require(!locked, "StorageContract: No reentrancy allowed!");
syntrust marked this conversation as resolved.
Show resolved Hide resolved
locked = true;
_;
locked = false;
}

// TODO: Reserve extra slots (to a total of 50?) in the storage layout for future upgrades

/// @notice Emitted when a block is mined.
Expand Down Expand Up @@ -239,7 +250,6 @@ abstract contract StorageContract is DecentralizedKV {
MiningLib.update(infos[_shardId], _minedTs, _diff);

require(treasuryReward + minerReward <= address(this).balance, "StorageContract: not enough balance");
// TODO: avoid reentrancy attack
syntrust marked this conversation as resolved.
Show resolved Hide resolved
payable(treasury).transfer(treasuryReward);
payable(_miner).transfer(minerReward);
emit MinedBlock(_shardId, _diff, infos[_shardId].blockMined, _minedTs, _miner, minerReward);
Expand Down Expand Up @@ -307,7 +317,7 @@ abstract contract StorageContract is DecentralizedKV {
bytes calldata _randaoProof,
bytes[] calldata _inclusiveProofs,
bytes[] calldata _decodeProof
) public virtual {
) public virtual noReentrant {
_mine(
_blockNum, _shardId, _miner, _nonce, _encodedSamples, _masks, _randaoProof, _inclusiveProofs, _decodeProof
);
Expand Down
102 changes: 99 additions & 3 deletions contracts/test/StorageContractTest.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,20 @@ import "forge-std/Test.sol";
import "forge-std/Vm.sol";

contract StorageContractTest is Test {
uint256 constant STORAGE_COST = 1000;
uint256 constant STORAGE_COST = 10000000;
uint256 constant SHARD_SIZE_BITS = 19;
uint256 constant MAX_KV_SIZE = 17;
uint256 constant PREPAID_AMOUNT = 2 * STORAGE_COST;
TestStorageContract storageContract;

function setUp() public {
storageContract = new TestStorageContract(
StorageContract.Config(MAX_KV_SIZE, SHARD_SIZE_BITS, 2, 0, 0, 0), 0, STORAGE_COST, 0
StorageContract.Config(MAX_KV_SIZE, SHARD_SIZE_BITS, 2, 0, 0, 0),
0,
STORAGE_COST,
340282366367469178095360967382638002176
);
storageContract.initialize(0, PREPAID_AMOUNT, 0, address(0x1), address(0x1));
storageContract.initialize(0, PREPAID_AMOUNT, 0, vm.addr(1), address(0x1));
}

function testMiningReward() public {
Expand All @@ -40,4 +43,97 @@ contract StorageContractTest is Test {
(,, reward) = storageContract.miningRewards(0, 1);
assertEq(reward, storageContract.paymentIn(PREPAID_AMOUNT + STORAGE_COST * 2, 0, 1));
}

function testRewardMiner() public {
address miner = vm.addr(2);
uint256 mineTs = 10000;
uint256 diff = 1;

vm.expectRevert("StorageContract: not enough balance");
storageContract.rewardMiner(0, miner, mineTs, 1);

vm.deal(address(storageContract), 1000);

(,, uint256 reward) = storageContract.miningRewards(0, mineTs);
storageContract.rewardMiner(0, miner, mineTs, diff);
(uint256 l, uint256 d, uint256 b) = storageContract.infos(0);
assertEq(l, mineTs);
assertEq(d, diff);
assertEq(b, 1);
assertEq(miner.balance, reward);
}

function testReentrancy() public {
vm.pauseGasMetering();
uint256 prefund = 1000;
// Without reentrancy protection, the fund could be drained by 29 times re-entrances given current params.
vm.deal(address(storageContract), prefund);
storageContract.setKvEntryCount(1);
Attacker attacker = new Attacker(storageContract);
vm.prank(address(attacker));

uint256 _blockNum = 1;
uint256 _shardId = 0;
uint256 _nonce = 0;
bytes32[] memory _encodedSamples = new bytes32[](0);
uint256[] memory _masks = new uint256[](0);
bytes memory _randaoProof = "0x01";
bytes[] memory _inclusiveProofs = new bytes[](0);
bytes[] memory _decodeProof = new bytes[](0);

vm.expectRevert("StorageContract: No reentrancy allowed!");
storageContract.mine(
_blockNum,
_shardId,
address(attacker),
_nonce,
_encodedSamples,
_masks,
_randaoProof,
_inclusiveProofs,
_decodeProof
);
}
}

contract Attacker {
// cannot access imported vm directly
address internal constant VM_ADDRESS = address(uint160(uint256(keccak256("hevm cheat code"))));
Vm vm = Vm(VM_ADDRESS);
TestStorageContract storageContract;
uint256 blockNumber = 1;
uint256 count = 0;

constructor(TestStorageContract _storageContract) {
storageContract = _storageContract;
}

fallback() external payable {
uint256 _shardId = 0;
uint256 _nonce = 0;
bytes32[] memory _encodedSamples = new bytes32[](0);
uint256[] memory _masks = new uint256[](0);
bytes memory _randaoProof = "0x01";
bytes[] memory _inclusiveProofs = new bytes[](0);
bytes[] memory _decodeProof = new bytes[](0);

blockNumber += 60;
vm.roll(blockNumber + 20);
vm.warp(block.number * 12);
uint256 reward = storageContract.miningReward(_shardId, blockNumber);
if (address(storageContract).balance >= reward) {
storageContract.mine(
blockNumber,
_shardId,
address(this),
_nonce,
_encodedSamples,
_masks,
_randaoProof,
_inclusiveProofs,
_decodeProof
);
count++;
}
}
}
19 changes: 19 additions & 0 deletions contracts/test/TestStorageContract.sol
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,23 @@ contract TestStorageContract is StorageContract {
function miningRewards(uint256 _shardId, uint256 _minedTs) public view returns (bool, uint256, uint256) {
return _miningReward(_shardId, _minedTs);
}

function rewardMiner(uint256 _shardId, address _miner, uint256 _minedTs, uint256 _diff) public {
return _rewardMiner(_shardId, _miner, _minedTs, _diff);
}

function _mine(
uint256 _blockNum,
uint256 _shardId,
address _miner,
uint256 _nonce,
bytes32[] memory _encodedSamples,
uint256[] memory _masks,
bytes calldata _randaoProof,
bytes[] calldata _inclusiveProofs,
bytes[] calldata _decodeProof
) internal override {
uint256 mineTs = _getMinedTs(_blockNum);
_rewardMiner(_shardId, _miner, mineTs, 1);
}
}
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"scripts": {
"install:all": "npm install && forge install",
"compile": "hardhat compile",
"test": "hardhat test && forge test",
"test": "hardhat test && forge test -vvvv",
"prettier:check": "prettier-check contracts/**/*.sol",
"prettier:fix": "prettier --write contracts/**/*.sol test/**/*.js scripts/**/*.js",
"deploy": "npx hardhat run scripts/deploy.js --network sepolia",
Expand Down
Loading