diff --git a/.changeset/friendly-owls-stare.md b/.changeset/friendly-owls-stare.md new file mode 100644 index 0000000000..1d75ff0b26 --- /dev/null +++ b/.changeset/friendly-owls-stare.md @@ -0,0 +1,5 @@ +--- +'@hyperlane-xyz/core': minor +--- + +Fixed misuse of aggregation hook funds for relaying messages by making sure msg.value is adequate and refunding if excess. diff --git a/solidity/contracts/hooks/aggregation/StaticAggregationHook.sol b/solidity/contracts/hooks/aggregation/StaticAggregationHook.sol index 5146ca2c02..a8e6fc7acd 100644 --- a/solidity/contracts/hooks/aggregation/StaticAggregationHook.sol +++ b/solidity/contracts/hooks/aggregation/StaticAggregationHook.sol @@ -13,11 +13,23 @@ pragma solidity >=0.8.0; @@@@@@@@@ @@@@@@@@@ @@@@@@@@@ @@@@@@@@*/ +// ============ Internal Imports ============ +import {StandardHookMetadata} from "../libs/StandardHookMetadata.sol"; +import {Message} from "../../libs/Message.sol"; +import {TypeCasts} from "../../libs/TypeCasts.sol"; import {AbstractPostDispatchHook} from "../libs/AbstractPostDispatchHook.sol"; import {IPostDispatchHook} from "../../interfaces/hooks/IPostDispatchHook.sol"; import {MetaProxy} from "../../libs/MetaProxy.sol"; +// ============ External Imports ============ +import {Address} from "@openzeppelin/contracts/utils/Address.sol"; + contract StaticAggregationHook is AbstractPostDispatchHook { + using Message for bytes; + using TypeCasts for bytes32; + using StandardHookMetadata for bytes; + using Address for address payable; + // ============ External functions ============ /// @inheritdoc IPostDispatchHook @@ -32,16 +44,29 @@ contract StaticAggregationHook is AbstractPostDispatchHook { ) internal override { address[] memory _hooks = hooks(message); uint256 count = _hooks.length; + uint256 valueRemaining = msg.value; for (uint256 i = 0; i < count; i++) { uint256 quote = IPostDispatchHook(_hooks[i]).quoteDispatch( metadata, message ); + require( + valueRemaining >= quote, + "StaticAggregationHook: insufficient value" + ); IPostDispatchHook(_hooks[i]).postDispatch{value: quote}( metadata, message ); + + valueRemaining -= quote; + } + + if (valueRemaining > 0) { + payable(metadata.refundAddress(message.senderAddress())).sendValue( + valueRemaining + ); } } diff --git a/solidity/test/hooks/AggregationHook.t.sol b/solidity/test/hooks/AggregationHook.t.sol index a37b20f1a1..bbd403ce33 100644 --- a/solidity/test/hooks/AggregationHook.t.sol +++ b/solidity/test/hooks/AggregationHook.t.sol @@ -3,12 +3,16 @@ pragma solidity ^0.8.13; import {Test} from "forge-std/Test.sol"; +import {Message} from "../../contracts/libs/Message.sol"; +import {TypeCasts} from "../../contracts/libs/TypeCasts.sol"; import {StaticAggregationHook} from "../../contracts/hooks/aggregation/StaticAggregationHook.sol"; import {StaticAggregationHookFactory} from "../../contracts/hooks/aggregation/StaticAggregationHookFactory.sol"; import {TestPostDispatchHook} from "../../contracts/test/TestPostDispatchHook.sol"; import {IPostDispatchHook} from "../../contracts/interfaces/hooks/IPostDispatchHook.sol"; contract AggregationHookTest is Test { + using TypeCasts for address; + StaticAggregationHookFactory internal factory; StaticAggregationHook internal hook; @@ -72,6 +76,55 @@ contract AggregationHookTest is Test { hook.postDispatch{value: _msgValue}("", message); } + function test_postDispatch_refundsExcess( + uint8 _hooks, + bytes calldata body + ) public { + uint256 fee = PER_HOOK_GAS_AMOUNT; + address[] memory hooksDeployed = deployHooks(_hooks, fee); + uint256 requiredValue = hooksDeployed.length * fee; + uint256 overpaidValue = requiredValue + 1000; + + vm.prank(address(this)); + + uint256 initialBalance = address(this).balance; + + bytes memory message = Message.formatMessage( + 1, + 0, + 1, + address(this).addressToBytes32(), + 2, + address(this).addressToBytes32(), + body + ); + hook.postDispatch{value: overpaidValue}("", message); + + assertEq(address(hook).balance, 0); + assertEq(address(this).balance, initialBalance - requiredValue); + } + + function testPostDispatch_preventsUsingContractFunds( + uint8 _hooks, + bytes calldata body + ) public { + uint256 fee = PER_HOOK_GAS_AMOUNT; + deployHooks(_hooks, fee); + vm.assume(_hooks > 0); + + vm.prank(address(this)); + + uint256 additionalFunds = 1 ether; + vm.deal(address(hook), additionalFunds); + + bytes memory message = abi.encodePacked("hello world"); + + vm.expectRevert("StaticAggregationHook: insufficient value"); + hook.postDispatch{value: 0}("", message); + + assertEq(address(hook).balance, additionalFunds); + } + function testQuoteDispatch(uint8 _hooks) public { uint256 fee = PER_HOOK_GAS_AMOUNT; address[] memory hooksDeployed = deployHooks(_hooks, fee); @@ -94,4 +147,6 @@ contract AggregationHookTest is Test { deployHooks(1, 0); assertEq(hook.hookType(), uint8(IPostDispatchHook.Types.AGGREGATION)); } + + receive() external payable {} }