diff --git a/crates/cheatcodes/assets/cheatcodes.json b/crates/cheatcodes/assets/cheatcodes.json index 41797106d6bb..347d738e01ec 100644 --- a/crates/cheatcodes/assets/cheatcodes.json +++ b/crates/cheatcodes/assets/cheatcodes.json @@ -7182,6 +7182,46 @@ "status": "stable", "safety": "unsafe" }, + { + "func": { + "id": "prank_2", + "description": "Sets the *next* delegate call's `msg.sender` to be the input address.", + "declaration": "function prank(address msgSender, bool delegateCall) external;", + "visibility": "external", + "mutability": "", + "signature": "prank(address,bool)", + "selector": "0xa7f8bf5c", + "selectorBytes": [ + 167, + 248, + 191, + 92 + ] + }, + "group": "evm", + "status": "stable", + "safety": "unsafe" + }, + { + "func": { + "id": "prank_3", + "description": "Sets the *next* delegate call's `msg.sender` to be the input address, and the `tx.origin` to be the second input.", + "declaration": "function prank(address msgSender, address txOrigin, bool delegateCall) external;", + "visibility": "external", + "mutability": "", + "signature": "prank(address,address,bool)", + "selector": "0x7d73d042", + "selectorBytes": [ + 125, + 115, + 208, + 66 + ] + }, + "group": "evm", + "status": "stable", + "safety": "unsafe" + }, { "func": { "id": "prevrandao_0", @@ -9288,6 +9328,46 @@ "status": "stable", "safety": "unsafe" }, + { + "func": { + "id": "startPrank_2", + "description": "Sets all subsequent delegate calls' `msg.sender` to be the input address until `stopPrank` is called.", + "declaration": "function startPrank(address msgSender, bool delegateCall) external;", + "visibility": "external", + "mutability": "", + "signature": "startPrank(address,bool)", + "selector": "0x1cc0b435", + "selectorBytes": [ + 28, + 192, + 180, + 53 + ] + }, + "group": "evm", + "status": "stable", + "safety": "unsafe" + }, + { + "func": { + "id": "startPrank_3", + "description": "Sets all subsequent delegate calls' `msg.sender` to be the input address until `stopPrank` is called, and the `tx.origin` to be the second input.", + "declaration": "function startPrank(address msgSender, address txOrigin, bool delegateCall) external;", + "visibility": "external", + "mutability": "", + "signature": "startPrank(address,address,bool)", + "selector": "0x4eb859b5", + "selectorBytes": [ + 78, + 184, + 89, + 181 + ] + }, + "group": "evm", + "status": "stable", + "safety": "unsafe" + }, { "func": { "id": "startSnapshotGas_0", diff --git a/crates/cheatcodes/spec/src/vm.rs b/crates/cheatcodes/spec/src/vm.rs index 3c11cd53aca3..d703f48e7ae5 100644 --- a/crates/cheatcodes/spec/src/vm.rs +++ b/crates/cheatcodes/spec/src/vm.rs @@ -600,6 +600,22 @@ interface Vm { #[cheatcode(group = Evm, safety = Unsafe)] function startPrank(address msgSender, address txOrigin) external; + /// Sets the *next* delegate call's `msg.sender` to be the input address. + #[cheatcode(group = Evm, safety = Unsafe)] + function prank(address msgSender, bool delegateCall) external; + + /// Sets all subsequent delegate calls' `msg.sender` to be the input address until `stopPrank` is called. + #[cheatcode(group = Evm, safety = Unsafe)] + function startPrank(address msgSender, bool delegateCall) external; + + /// Sets the *next* delegate call's `msg.sender` to be the input address, and the `tx.origin` to be the second input. + #[cheatcode(group = Evm, safety = Unsafe)] + function prank(address msgSender, address txOrigin, bool delegateCall) external; + + /// Sets all subsequent delegate calls' `msg.sender` to be the input address until `stopPrank` is called, and the `tx.origin` to be the second input. + #[cheatcode(group = Evm, safety = Unsafe)] + function startPrank(address msgSender, address txOrigin, bool delegateCall) external; + /// Resets subsequent calls' `msg.sender` to be `address(this)`. #[cheatcode(group = Evm, safety = Unsafe)] function stopPrank() external; diff --git a/crates/cheatcodes/src/evm/prank.rs b/crates/cheatcodes/src/evm/prank.rs index a310e28e515b..1d7ca5a07947 100644 --- a/crates/cheatcodes/src/evm/prank.rs +++ b/crates/cheatcodes/src/evm/prank.rs @@ -16,6 +16,8 @@ pub struct Prank { pub depth: u64, /// Whether the prank stops by itself after the next call pub single_call: bool, + /// Whether the prank should be be applied to delegate call + pub delegate_call: bool, /// Whether the prank has been used yet (false if unused) pub used: bool, } @@ -29,8 +31,18 @@ impl Prank { new_origin: Option
, depth: u64, single_call: bool, + delegate_call: bool, ) -> Self { - Self { prank_caller, prank_origin, new_caller, new_origin, depth, single_call, used: false } + Self { + prank_caller, + prank_origin, + new_caller, + new_origin, + depth, + single_call, + delegate_call, + used: false, + } } /// Apply the prank by setting `used` to true iff it is false @@ -47,28 +59,56 @@ impl Prank { impl Cheatcode for prank_0Call { fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { let Self { msgSender } = self; - prank(ccx, msgSender, None, true) + prank(ccx, msgSender, None, true, false) } } impl Cheatcode for startPrank_0Call { fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { let Self { msgSender } = self; - prank(ccx, msgSender, None, false) + prank(ccx, msgSender, None, false, false) } } impl Cheatcode for prank_1Call { fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { let Self { msgSender, txOrigin } = self; - prank(ccx, msgSender, Some(txOrigin), true) + prank(ccx, msgSender, Some(txOrigin), true, false) } } impl Cheatcode for startPrank_1Call { fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { let Self { msgSender, txOrigin } = self; - prank(ccx, msgSender, Some(txOrigin), false) + prank(ccx, msgSender, Some(txOrigin), false, false) + } +} + +impl Cheatcode for prank_2Call { + fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { + let Self { msgSender, delegateCall } = self; + prank(ccx, msgSender, None, true, *delegateCall) + } +} + +impl Cheatcode for startPrank_2Call { + fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { + let Self { msgSender, delegateCall } = self; + prank(ccx, msgSender, None, false, *delegateCall) + } +} + +impl Cheatcode for prank_3Call { + fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { + let Self { msgSender, txOrigin, delegateCall } = self; + prank(ccx, msgSender, Some(txOrigin), true, *delegateCall) + } +} + +impl Cheatcode for startPrank_3Call { + fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { + let Self { msgSender, txOrigin, delegateCall } = self; + prank(ccx, msgSender, Some(txOrigin), false, *delegateCall) } } @@ -85,6 +125,7 @@ fn prank( new_caller: &Address, new_origin: Option<&Address>, single_call: bool, + delegate_call: bool, ) -> Result { let prank = Prank::new( ccx.caller, @@ -93,8 +134,15 @@ fn prank( new_origin.copied(), ccx.ecx.journaled_state.depth(), single_call, + delegate_call, ); + // Ensure that code exists at `msg.sender` if delegate calling. + if delegate_call { + let code = ccx.code(*new_caller)?; + ensure!(!code.is_empty(), "cannot `prank` delegate call from an EOA"); + } + if let Some(Prank { used, single_call: current_single_call, .. }) = ccx.state.prank { ensure!(used, "cannot overwrite a prank until it is applied at least once"); // This case can only fail if the user calls `vm.startPrank` and then `vm.prank` later on. diff --git a/crates/cheatcodes/src/inspector.rs b/crates/cheatcodes/src/inspector.rs index 684f5eb110d2..046a495a10e8 100644 --- a/crates/cheatcodes/src/inspector.rs +++ b/crates/cheatcodes/src/inspector.rs @@ -42,7 +42,7 @@ use proptest::test_runner::{RngAlgorithm, TestRng, TestRunner}; use rand::Rng; use revm::{ interpreter::{ - opcode as op, CallInputs, CallOutcome, CallScheme, CreateInputs, CreateOutcome, + opcode as op, CallInputs, CallOutcome, CallScheme, CallValue, CreateInputs, CreateOutcome, EOFCreateInputs, EOFCreateKind, Gas, InstructionResult, Interpreter, InterpreterAction, InterpreterResult, }, @@ -941,6 +941,19 @@ where { // Apply our prank if let Some(prank) = &self.prank { + // Apply delegate call, `call.caller`` will not equal `prank.prank_caller` + if let CallScheme::DelegateCall | CallScheme::ExtDelegateCall = call.scheme { + if prank.delegate_call { + call.target_address = prank.new_caller; + call.caller = prank.new_caller; + let acc = ecx.journaled_state.account(prank.new_caller); + call.value = CallValue::Apparent(acc.info.balance); + if let Some(new_origin) = prank.new_origin { + ecx.env.tx.caller = new_origin; + } + } + } + if ecx.journaled_state.depth() >= prank.depth && call.caller == prank.prank_caller { let mut prank_applied = false; diff --git a/testdata/cheats/Vm.sol b/testdata/cheats/Vm.sol index d3a377f403ef..2004c44563d0 100644 --- a/testdata/cheats/Vm.sol +++ b/testdata/cheats/Vm.sol @@ -353,6 +353,8 @@ interface Vm { function pauseTracing() external view; function prank(address msgSender) external; function prank(address msgSender, address txOrigin) external; + function prank(address msgSender, bool delegateCall) external; + function prank(address msgSender, address txOrigin, bool delegateCall) external; function prevrandao(bytes32 newPrevrandao) external; function prevrandao(uint256 newPrevrandao) external; function projectRoot() external view returns (string memory path); @@ -458,6 +460,8 @@ interface Vm { function startMappingRecording() external; function startPrank(address msgSender) external; function startPrank(address msgSender, address txOrigin) external; + function startPrank(address msgSender, bool delegateCall) external; + function startPrank(address msgSender, address txOrigin, bool delegateCall) external; function startSnapshotGas(string calldata name) external; function startSnapshotGas(string calldata group, string calldata name) external; function startStateDiffRecording() external; diff --git a/testdata/default/cheats/Prank.t.sol b/testdata/default/cheats/Prank.t.sol index d833c0513d83..130e819606a2 100644 --- a/testdata/default/cheats/Prank.t.sol +++ b/testdata/default/cheats/Prank.t.sol @@ -85,9 +85,123 @@ contract NestedPranker { } } +contract ImplementationTest { + uint256 public num; + address public sender; + + function assertCorrectCaller(address expectedSender) public { + require(msg.sender == expectedSender); + } + + function assertCorrectOrigin(address expectedOrigin) public { + require(tx.origin == expectedOrigin); + } + + function setNum(uint256 _num) public { + num = _num; + } +} + +contract ProxyTest { + uint256 public num; + address public sender; +} + contract PrankTest is DSTest { Vm constant vm = Vm(HEVM_ADDRESS); + function testPrankDelegateCallPrank2() public { + ProxyTest proxy = new ProxyTest(); + ImplementationTest impl = new ImplementationTest(); + vm.prank(address(proxy), true); + + // Assert correct `msg.sender` + (bool success,) = + address(impl).delegatecall(abi.encodeWithSignature("assertCorrectCaller(address)", address(proxy))); + require(success, "prank2: delegate call failed assertCorrectCaller"); + + // Assert storage updates + uint256 num = 42; + vm.prank(address(proxy), true); + (bool successTwo,) = address(impl).delegatecall(abi.encodeWithSignature("setNum(uint256)", num)); + require(successTwo, "prank2: delegate call failed setNum"); + require(proxy.num() == num, "prank2: proxy's storage was not set correctly"); + vm.stopPrank(); + } + + function testPrankDelegateCallStartPrank2() public { + ProxyTest proxy = new ProxyTest(); + ImplementationTest impl = new ImplementationTest(); + vm.startPrank(address(proxy), true); + + // Assert correct `msg.sender` + (bool success,) = + address(impl).delegatecall(abi.encodeWithSignature("assertCorrectCaller(address)", address(proxy))); + require(success, "startPrank2: delegate call failed assertCorrectCaller"); + + // Assert storage updates + uint256 num = 42; + (bool successTwo,) = address(impl).delegatecall(abi.encodeWithSignature("setNum(uint256)", num)); + require(successTwo, "startPrank2: delegate call failed setNum"); + require(proxy.num() == num, "startPrank2: proxy's storage was not set correctly"); + vm.stopPrank(); + } + + function testPrankDelegateCallPrank3(address origin) public { + ProxyTest proxy = new ProxyTest(); + ImplementationTest impl = new ImplementationTest(); + vm.prank(address(proxy), origin, true); + + // Assert correct `msg.sender` + (bool success,) = + address(impl).delegatecall(abi.encodeWithSignature("assertCorrectCaller(address)", address(proxy))); + require(success, "prank3: delegate call failed assertCorrectCaller"); + + // Assert correct `tx.origin` + vm.prank(address(proxy), origin, true); + (bool successTwo,) = address(impl).delegatecall(abi.encodeWithSignature("assertCorrectOrigin(address)", origin)); + require(successTwo, "prank3: delegate call failed assertCorrectOrigin"); + + // Assert storage updates + uint256 num = 42; + vm.prank(address(proxy), address(origin), true); + (bool successThree,) = address(impl).delegatecall(abi.encodeWithSignature("setNum(uint256)", num)); + require(successThree, "prank3: delegate call failed setNum"); + require(proxy.num() == num, "prank3: proxy's storage was not set correctly"); + vm.stopPrank(); + } + + function testPrankDelegateCallStartPrank3(address origin) public { + ProxyTest proxy = new ProxyTest(); + ImplementationTest impl = new ImplementationTest(); + vm.startPrank(address(proxy), origin, true); + + // Assert correct `msg.sender` + (bool success,) = + address(impl).delegatecall(abi.encodeWithSignature("assertCorrectCaller(address)", address(proxy))); + require(success, "startPrank3: delegate call failed assertCorrectCaller"); + + // Assert correct `tx.origin` + (bool successTwo,) = address(impl).delegatecall(abi.encodeWithSignature("assertCorrectOrigin(address)", origin)); + require(successTwo, "startPrank3: delegate call failed assertCorrectOrigin"); + + // Assert storage updates + uint256 num = 42; + (bool successThree,) = address(impl).delegatecall(abi.encodeWithSignature("setNum(uint256)", num)); + require(successThree, "startPrank3: delegate call failed setNum"); + require(proxy.num() == num, "startPrank3: proxy's storage was not set correctly"); + vm.stopPrank(); + } + + function testFailPrankDelegateCallToEOA() public { + uint256 privateKey = uint256(keccak256(abi.encodePacked("alice"))); + address alice = vm.addr(privateKey); + ImplementationTest impl = new ImplementationTest(); + vm.prank(alice, true); + // Should fail when EOA pranked with delegatecall. + address(impl).delegatecall(abi.encodeWithSignature("assertCorrectCaller(address)", alice)); + } + function testPrankSender(address sender) public { // Perform the prank Victim victim = new Victim();