Skip to content

Commit

Permalink
evm: Prevent registering peers on the same chain
Browse files Browse the repository at this point in the history
  • Loading branch information
djb15 authored and barnjamin committed Apr 5, 2024
1 parent 6cd7cd7 commit 4086052
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 40 deletions.
3 changes: 3 additions & 0 deletions evm/src/NttManager/NttManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ contract NttManager is INttManager, RateLimiter, ManagerBase {
if (decimals == 0) {
revert InvalidPeerDecimals();
}
if (peerChainId == chainId) {
revert InvalidPeerSameChainId();
}

NttManagerPeer memory oldPeer = _getPeersStorage()[peerChainId];

Expand Down
4 changes: 4 additions & 0 deletions evm/src/interfaces/INttManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ interface INttManager is IManagerBase {
/// @dev Selector 0xbd28e889.
error UnexpectedMsgValue();

/// @notice Peer cannot be on the same chain
/// @dev Selector 0x20371f2a.
error InvalidPeerSameChainId();

/// @notice Transfer a given amount to a recipient on a given chain. This function is called
/// by the user to send the token cross-chain. This function will either lock or burn the
/// sender's tokens. Finally, this function will call into registered `Endpoint` contracts
Expand Down
55 changes: 33 additions & 22 deletions evm/test/NttManager.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ contract TestNttManager is Test, IRateLimiterEvents {

// 0x99'E''T''T'
uint16 constant chainId = 7;
uint16 constant chainId2 = 8;
uint256 constant DEVNET_GUARDIAN_PK =
0xcfb12303a19cde580bb4dd771639b0d26bc68353645571a8cff516ab2ee113a0;
WormholeSimulator guardian;
Expand Down Expand Up @@ -109,7 +110,7 @@ contract TestNttManager is Test, IRateLimiterEvents {
uint8 decimals = t.decimals();

nttManagerZeroRateLimiter.setPeer(
chainId, toWormholeFormat(address(0x1)), 9, type(uint64).max
chainId2, toWormholeFormat(address(0x1)), 9, type(uint64).max
);

t.mintDummy(address(user_A), 5 * 10 ** decimals);
Expand All @@ -119,13 +120,13 @@ contract TestNttManager is Test, IRateLimiterEvents {
t.approve(address(nttManagerZeroRateLimiter), 3 * 10 ** decimals);

uint64 s1 = nttManagerZeroRateLimiter.transfer(
1 * 10 ** decimals, chainId, toWormholeFormat(user_B)
1 * 10 ** decimals, chainId2, toWormholeFormat(user_B)
);
uint64 s2 = nttManagerZeroRateLimiter.transfer(
1 * 10 ** decimals, chainId, toWormholeFormat(user_B)
1 * 10 ** decimals, chainId2, toWormholeFormat(user_B)
);
uint64 s3 = nttManagerZeroRateLimiter.transfer(
1 * 10 ** decimals, chainId, toWormholeFormat(user_B)
1 * 10 ** decimals, chainId2, toWormholeFormat(user_B)
);
vm.stopPrank();

Expand Down Expand Up @@ -336,7 +337,7 @@ contract TestNttManager is Test, IRateLimiterEvents {

uint8 decimals = token.decimals();

newNttManager.setPeer(chainId, toWormholeFormat(address(0x1)), 9, type(uint64).max);
newNttManager.setPeer(chainId2, toWormholeFormat(address(0x1)), 9, type(uint64).max);
newNttManager.setOutboundLimit(packTrimmedAmount(type(uint64).max, 8).untrim(decimals));

token.mintDummy(address(user_A), 5 * 10 ** decimals);
Expand All @@ -348,7 +349,7 @@ contract TestNttManager is Test, IRateLimiterEvents {
vm.expectRevert(abi.encodeWithSelector(IManagerBase.NoEnabledTransceivers.selector));
newNttManager.transfer(
1 * 10 ** decimals,
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
false,
Expand Down Expand Up @@ -398,7 +399,7 @@ contract TestNttManager is Test, IRateLimiterEvents {

uint8 decimals = token.decimals();

nttManager.setPeer(chainId, toWormholeFormat(address(0x1)), 9, type(uint64).max);
nttManager.setPeer(chainId2, toWormholeFormat(address(0x1)), 9, type(uint64).max);
nttManager.setOutboundLimit(packTrimmedAmount(type(uint64).max, 8).untrim(decimals));

token.mintDummy(address(user_A), 5 * 10 ** decimals);
Expand All @@ -418,7 +419,7 @@ contract TestNttManager is Test, IRateLimiterEvents {

nttManager.transfer(
1 * 10 ** decimals,
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
false,
Expand All @@ -434,7 +435,7 @@ contract TestNttManager is Test, IRateLimiterEvents {

uint8 decimals = token.decimals();

nttManager.setPeer(chainId, toWormholeFormat(address(0x1)), 9, type(uint64).max);
nttManager.setPeer(chainId2, toWormholeFormat(address(0x1)), 9, type(uint64).max);
nttManager.setOutboundLimit(0);

token.mintDummy(address(user_A), 5 * 10 ** decimals);
Expand All @@ -448,7 +449,7 @@ contract TestNttManager is Test, IRateLimiterEvents {

uint64 s1 = nttManager.transfer(
1 * 10 ** decimals,
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
true,
Expand Down Expand Up @@ -485,7 +486,7 @@ contract TestNttManager is Test, IRateLimiterEvents {
// The next transfer has previous sequence number + 1
uint64 s2 = nttManager.transfer(
1 * 10 ** decimals,
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
true,
Expand Down Expand Up @@ -683,7 +684,7 @@ contract TestNttManager is Test, IRateLimiterEvents {

uint8 decimals = token.decimals();

nttManager.setPeer(chainId, toWormholeFormat(address(0x1)), 9, type(uint64).max);
nttManager.setPeer(chainId2, toWormholeFormat(address(0x1)), 9, type(uint64).max);
nttManager.setOutboundLimit(packTrimmedAmount(type(uint64).max, 8).untrim(decimals));

token.mintDummy(address(user_A), 5 * 10 ** decimals);
Expand All @@ -694,23 +695,23 @@ contract TestNttManager is Test, IRateLimiterEvents {

uint64 s1 = nttManager.transfer(
1 * 10 ** decimals,
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
false,
new bytes(1)
);
uint64 s2 = nttManager.transfer(
1 * 10 ** decimals,
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
false,
new bytes(1)
);
uint64 s3 = nttManager.transfer(
1 * 10 ** decimals,
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
false,
Expand All @@ -724,7 +725,7 @@ contract TestNttManager is Test, IRateLimiterEvents {

function test_transferWithAmountAndDecimalsThatCouldOverflow() public {
// The source chain has 18 decimals trimmed to 8, and the peer has 6 decimals trimmed to 6
nttManager.setPeer(chainId, toWormholeFormat(address(0x1)), 6, type(uint64).max);
nttManager.setPeer(chainId2, toWormholeFormat(address(0x1)), 6, type(uint64).max);

address user_A = address(0x123);
address user_B = address(0x456);
Expand All @@ -745,13 +746,23 @@ contract TestNttManager is Test, IRateLimiterEvents {

vm.expectRevert("SafeCast: value doesn't fit in 64 bits");
nttManager.transfer(
amount, chainId, toWormholeFormat(user_B), toWormholeFormat(user_A), false, new bytes(1)
amount,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
false,
new bytes(1)
);

// A (slightly) more sensible amount should work normally
amount = (type(uint64).max * 10 ** (decimals - 6 - 2)) - 150000000000; // Subtract this to make sure we don't have dust
nttManager.transfer(
amount, chainId, toWormholeFormat(user_B), toWormholeFormat(user_A), false, new bytes(1)
amount,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
false,
new bytes(1)
);
}

Expand Down Expand Up @@ -958,7 +969,7 @@ contract TestNttManager is Test, IRateLimiterEvents {

uint256 maxAmount = 5 * 10 ** decimals;
token.mintDummy(from, maxAmount);
nttManager.setPeer(chainId, toWormholeFormat(address(0x1)), 9, type(uint64).max);
nttManager.setPeer(chainId2, toWormholeFormat(address(0x1)), 9, type(uint64).max);
nttManager.setOutboundLimit(packTrimmedAmount(type(uint64).max, 8).untrim(decimals));
nttManager.setInboundLimit(
packTrimmedAmount(type(uint64).max, 8).untrim(decimals),
Expand All @@ -983,7 +994,7 @@ contract TestNttManager is Test, IRateLimiterEvents {
);
nttManager.transfer(
amountWithDust,
chainId,
chainId2,
toWormholeFormat(to),
toWormholeFormat(from),
false,
Expand Down Expand Up @@ -1203,7 +1214,7 @@ contract TestNttManager is Test, IRateLimiterEvents {

uint8 decimals = token.decimals();

nttManager.setPeer(chainId, toWormholeFormat(address(0x1)), 9, type(uint64).max);
nttManager.setPeer(chainId2, toWormholeFormat(address(0x1)), 9, type(uint64).max);
nttManager.setOutboundLimit(packTrimmedAmount(type(uint64).max, 8).untrim(decimals));

token.mintDummy(address(user_A), 5 * 10 ** decimals);
Expand All @@ -1217,7 +1228,7 @@ contract TestNttManager is Test, IRateLimiterEvents {
);
nttManager.transfer(
1 * 10 ** decimals,
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
false,
Expand Down
37 changes: 19 additions & 18 deletions evm/test/RateLimit.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ contract TestRateLimit is Test, IRateLimiterEvents {
using BytesParsing for bytes;

uint16 constant chainId = 7;
uint16 constant chainId2 = 8;

uint256 constant DEVNET_GUARDIAN_PK =
0xcfb12303a19cde580bb4dd771639b0d26bc68353645571a8cff516ab2ee113a0;
Expand All @@ -45,7 +46,7 @@ contract TestRateLimit is Test, IRateLimiterEvents {
nttManager = MockNttManagerContract(address(new ERC1967Proxy(address(implementation), "")));
nttManager.initialize();

nttManager.setPeer(chainId, toWormholeFormat(address(0x1)), 9, type(uint64).max);
nttManager.setPeer(chainId2, toWormholeFormat(address(0x1)), 9, type(uint64).max);

DummyTransceiver e = new DummyTransceiver(address(nttManager));
nttManager.setTransceiver(address(e));
Expand Down Expand Up @@ -88,7 +89,7 @@ contract TestRateLimit is Test, IRateLimiterEvents {
token.approve(address(nttManager), transferAmount);
nttManager.transfer(
transferAmount,
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
false,
Expand All @@ -109,7 +110,7 @@ contract TestRateLimit is Test, IRateLimiterEvents {
// assert inbound rate limit for destination chain is still at the max.
// the backflow should not override the limit.
IRateLimiter.RateLimitParams memory inboundLimitParams =
nttManager.getInboundLimitParams(chainId);
nttManager.getInboundLimitParams(chainId2);
assertEq(
inboundLimitParams.currentCapacity.getAmount(), inboundLimitParams.limit.getAmount()
);
Expand All @@ -135,7 +136,7 @@ contract TestRateLimit is Test, IRateLimiterEvents {
token.approve(address(nttManager), transferAmount);
nttManager.transfer(
transferAmount,
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
false,
Expand Down Expand Up @@ -182,7 +183,7 @@ contract TestRateLimit is Test, IRateLimiterEvents {
token.approve(address(nttManager), transferAmount);
nttManager.transfer(
transferAmount,
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
false,
Expand Down Expand Up @@ -224,7 +225,7 @@ contract TestRateLimit is Test, IRateLimiterEvents {
token.approve(address(nttManager), transferAmount);
nttManager.transfer(
transferAmount,
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
false,
Expand Down Expand Up @@ -280,7 +281,7 @@ contract TestRateLimit is Test, IRateLimiterEvents {
token.approve(address(nttManager), transferAmount);
nttManager.transfer(
transferAmount,
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
false,
Expand Down Expand Up @@ -330,7 +331,7 @@ contract TestRateLimit is Test, IRateLimiterEvents {
token.approve(address(nttManager), transferAmount);
nttManager.transfer(
transferAmount,
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
false,
Expand Down Expand Up @@ -391,7 +392,7 @@ contract TestRateLimit is Test, IRateLimiterEvents {
);
nttManager.transfer(
transferAmount,
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
false,
Expand All @@ -417,7 +418,7 @@ contract TestRateLimit is Test, IRateLimiterEvents {
token.approve(address(nttManager), transferAmount);
nttManager.transfer(
transferAmount,
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
false,
Expand Down Expand Up @@ -445,7 +446,7 @@ contract TestRateLimit is Test, IRateLimiterEvents {
);
nttManager.transfer(
badTransferAmount,
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
false,
Expand Down Expand Up @@ -478,7 +479,7 @@ contract TestRateLimit is Test, IRateLimiterEvents {
// transfer with shouldQueue == true
uint64 qSeq = nttManager.transfer(
transferAmount,
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
true,
Expand All @@ -489,7 +490,7 @@ contract TestRateLimit is Test, IRateLimiterEvents {
assertEq(qSeq, 0);
IRateLimiter.OutboundQueuedTransfer memory qt = nttManager.getOutboundQueuedTransfer(0);
assertEq(qt.amount.getAmount(), transferAmount.trim(decimals, decimals).getAmount());
assertEq(qt.recipientChain, chainId);
assertEq(qt.recipientChain, chainId2);
assertEq(qt.recipient, toWormholeFormat(user_B));
assertEq(qt.txTimestamp, initialBlockTimestamp);

Expand Down Expand Up @@ -686,7 +687,7 @@ contract TestRateLimit is Test, IRateLimiterEvents {
// transfer 10 tokens from user_A -> user_B via the nttManager
nttManager.transfer(
transferAmount.untrim(decimals),
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
false,
Expand Down Expand Up @@ -875,7 +876,7 @@ contract TestRateLimit is Test, IRateLimiterEvents {
vm.expectRevert(abi.encodeWithSelector(INttManager.ZeroAmount.selector));
nttManager.transfer(
transferAmount.untrim(decimals),
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
false,
Expand All @@ -888,7 +889,7 @@ contract TestRateLimit is Test, IRateLimiterEvents {
// transfer tokens from user_A -> user_B via the nttManager
nttManager.transfer(
transferAmount.untrim(decimals),
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
false,
Expand Down Expand Up @@ -1018,7 +1019,7 @@ contract TestRateLimit is Test, IRateLimiterEvents {
// shouldQueue == true
uint64 qSeq = nttManager.transfer(
transferAmount,
chainId,
chainId2,
toWormholeFormat(user_B),
toWormholeFormat(user_A),
true,
Expand All @@ -1029,7 +1030,7 @@ contract TestRateLimit is Test, IRateLimiterEvents {
assertEq(qSeq, 0);
IRateLimiter.OutboundQueuedTransfer memory qt = nttManager.getOutboundQueuedTransfer(0);
assertEq(qt.amount.getAmount(), transferAmount.trim(decimals, decimals).getAmount());
assertEq(qt.recipientChain, chainId);
assertEq(qt.recipientChain, chainId2);
assertEq(qt.recipient, toWormholeFormat(user_B));
assertEq(qt.txTimestamp, initialBlockTimestamp);

Expand Down

0 comments on commit 4086052

Please sign in to comment.