From 63484d0a65c56e3378cc3f282ed962d5d499a490 Mon Sep 17 00:00:00 2001
From: Yash Atreya <44857776+yash-atreya@users.noreply.github.com>
Date: Fri, 6 Dec 2024 12:04:05 +0530
Subject: [PATCH] feat(`cheatcodes`): count assertion for `expectRevert`
(#9484)
* expectRevert count overload boilerplate
* introduce `count` variable
* populate `ExpectedRevert` for count overloads
* intro `actual_count` and make ExpectedRevert mut
* increment `actual_account` on success and tests
* handle non-zero count reverts separately
* handle count for specific reverts
* nit
* more tests
* fix: handle count > 1 with reverter specified
* test: ExpectRevertCountWithReverter
* expectRevert with reverter and count 0
* nit
* reverter count with data
* nit
* cleanup
* nit
* nit
* clippy
* nit
* cargo cheats
---
crates/cheatcodes/assets/cheatcodes.json | 120 +++++++++
crates/cheatcodes/spec/src/vm.rs | 24 ++
crates/cheatcodes/src/inspector.rs | 40 ++-
crates/cheatcodes/src/test/expect.rs | 273 ++++++++++++++++-----
testdata/cheats/Vm.sol | 6 +
testdata/default/cheats/ExpectRevert.t.sol | 214 +++++++++++++++-
6 files changed, 609 insertions(+), 68 deletions(-)
diff --git a/crates/cheatcodes/assets/cheatcodes.json b/crates/cheatcodes/assets/cheatcodes.json
index d8f8d21df67b..54556043ac9d 100644
--- a/crates/cheatcodes/assets/cheatcodes.json
+++ b/crates/cheatcodes/assets/cheatcodes.json
@@ -5051,6 +5051,46 @@
"status": "stable",
"safety": "unsafe"
},
+ {
+ "func": {
+ "id": "expectRevert_10",
+ "description": "Expects a `count` number of reverts from the upcoming calls from the reverter address that match the revert data.",
+ "declaration": "function expectRevert(bytes4 revertData, address reverter, uint64 count) external;",
+ "visibility": "external",
+ "mutability": "",
+ "signature": "expectRevert(bytes4,address,uint64)",
+ "selector": "0xb0762d73",
+ "selectorBytes": [
+ 176,
+ 118,
+ 45,
+ 115
+ ]
+ },
+ "group": "testing",
+ "status": "stable",
+ "safety": "unsafe"
+ },
+ {
+ "func": {
+ "id": "expectRevert_11",
+ "description": "Expects a `count` number of reverts from the upcoming calls from the reverter address that exactly match the revert data.",
+ "declaration": "function expectRevert(bytes calldata revertData, address reverter, uint64 count) external;",
+ "visibility": "external",
+ "mutability": "",
+ "signature": "expectRevert(bytes,address,uint64)",
+ "selector": "0xd345fb1f",
+ "selectorBytes": [
+ 211,
+ 69,
+ 251,
+ 31
+ ]
+ },
+ "group": "testing",
+ "status": "stable",
+ "safety": "unsafe"
+ },
{
"func": {
"id": "expectRevert_2",
@@ -5131,6 +5171,86 @@
"status": "stable",
"safety": "unsafe"
},
+ {
+ "func": {
+ "id": "expectRevert_6",
+ "description": "Expects a `count` number of reverts from the upcoming calls with any revert data or reverter.",
+ "declaration": "function expectRevert(uint64 count) external;",
+ "visibility": "external",
+ "mutability": "",
+ "signature": "expectRevert(uint64)",
+ "selector": "0x4ee38244",
+ "selectorBytes": [
+ 78,
+ 227,
+ 130,
+ 68
+ ]
+ },
+ "group": "testing",
+ "status": "stable",
+ "safety": "unsafe"
+ },
+ {
+ "func": {
+ "id": "expectRevert_7",
+ "description": "Expects a `count` number of reverts from the upcoming calls that match the revert data.",
+ "declaration": "function expectRevert(bytes4 revertData, uint64 count) external;",
+ "visibility": "external",
+ "mutability": "",
+ "signature": "expectRevert(bytes4,uint64)",
+ "selector": "0xe45ca72d",
+ "selectorBytes": [
+ 228,
+ 92,
+ 167,
+ 45
+ ]
+ },
+ "group": "testing",
+ "status": "stable",
+ "safety": "unsafe"
+ },
+ {
+ "func": {
+ "id": "expectRevert_8",
+ "description": "Expects a `count` number of reverts from the upcoming calls that exactly match the revert data.",
+ "declaration": "function expectRevert(bytes calldata revertData, uint64 count) external;",
+ "visibility": "external",
+ "mutability": "",
+ "signature": "expectRevert(bytes,uint64)",
+ "selector": "0x4994c273",
+ "selectorBytes": [
+ 73,
+ 148,
+ 194,
+ 115
+ ]
+ },
+ "group": "testing",
+ "status": "stable",
+ "safety": "unsafe"
+ },
+ {
+ "func": {
+ "id": "expectRevert_9",
+ "description": "Expects a `count` number of reverts from the upcoming calls from the reverter address.",
+ "declaration": "function expectRevert(address reverter, uint64 count) external;",
+ "visibility": "external",
+ "mutability": "",
+ "signature": "expectRevert(address,uint64)",
+ "selector": "0x1ff5f952",
+ "selectorBytes": [
+ 31,
+ 245,
+ 249,
+ 82
+ ]
+ },
+ "group": "testing",
+ "status": "stable",
+ "safety": "unsafe"
+ },
{
"func": {
"id": "expectSafeMemory",
diff --git a/crates/cheatcodes/spec/src/vm.rs b/crates/cheatcodes/spec/src/vm.rs
index 6b66d31ddb50..6954fd1e526d 100644
--- a/crates/cheatcodes/spec/src/vm.rs
+++ b/crates/cheatcodes/spec/src/vm.rs
@@ -1019,6 +1019,30 @@ interface Vm {
#[cheatcode(group = Testing, safety = Unsafe)]
function expectRevert(bytes calldata revertData, address reverter) external;
+ /// Expects a `count` number of reverts from the upcoming calls with any revert data or reverter.
+ #[cheatcode(group = Testing, safety = Unsafe)]
+ function expectRevert(uint64 count) external;
+
+ /// Expects a `count` number of reverts from the upcoming calls that match the revert data.
+ #[cheatcode(group = Testing, safety = Unsafe)]
+ function expectRevert(bytes4 revertData, uint64 count) external;
+
+ /// Expects a `count` number of reverts from the upcoming calls that exactly match the revert data.
+ #[cheatcode(group = Testing, safety = Unsafe)]
+ function expectRevert(bytes calldata revertData, uint64 count) external;
+
+ /// Expects a `count` number of reverts from the upcoming calls from the reverter address.
+ #[cheatcode(group = Testing, safety = Unsafe)]
+ function expectRevert(address reverter, uint64 count) external;
+
+ /// Expects a `count` number of reverts from the upcoming calls from the reverter address that match the revert data.
+ #[cheatcode(group = Testing, safety = Unsafe)]
+ function expectRevert(bytes4 revertData, address reverter, uint64 count) external;
+
+ /// Expects a `count` number of reverts from the upcoming calls from the reverter address that exactly match the revert data.
+ #[cheatcode(group = Testing, safety = Unsafe)]
+ function expectRevert(bytes calldata revertData, address reverter, uint64 count) external;
+
/// Expects an error on next call that starts with the revert data.
#[cheatcode(group = Testing, safety = Unsafe)]
function expectPartialRevert(bytes4 revertData) external;
diff --git a/crates/cheatcodes/src/inspector.rs b/crates/cheatcodes/src/inspector.rs
index 4e2b91b9dd2d..329b89d0ebc2 100644
--- a/crates/cheatcodes/src/inspector.rs
+++ b/crates/cheatcodes/src/inspector.rs
@@ -754,16 +754,23 @@ where {
if ecx.journaled_state.depth() <= expected_revert.depth &&
matches!(expected_revert.kind, ExpectedRevertKind::Default)
{
- let expected_revert = std::mem::take(&mut self.expected_revert).unwrap();
- return match expect::handle_expect_revert(
+ let mut expected_revert = std::mem::take(&mut self.expected_revert).unwrap();
+ let handler_result = expect::handle_expect_revert(
false,
true,
- &expected_revert,
+ &mut expected_revert,
outcome.result.result,
outcome.result.output.clone(),
&self.config.available_artifacts,
- ) {
+ );
+
+ return match handler_result {
Ok((address, retdata)) => {
+ expected_revert.actual_count += 1;
+ if expected_revert.actual_count < expected_revert.count {
+ self.expected_revert = Some(expected_revert.clone());
+ }
+
outcome.result.result = InstructionResult::Return;
outcome.result.output = retdata;
outcome.address = address;
@@ -1302,6 +1309,14 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes {
expected_revert.reverted_by.is_none()
{
expected_revert.reverted_by = Some(call.target_address);
+ } else if outcome.result.is_revert() &&
+ expected_revert.reverter.is_some() &&
+ expected_revert.reverted_by.is_some() &&
+ expected_revert.count > 1
+ {
+ // If we're expecting more than one revert, we need to reset the reverted_by address
+ // to latest reverter.
+ expected_revert.reverted_by = Some(call.target_address);
}
if ecx.journaled_state.depth() <= expected_revert.depth {
@@ -1315,15 +1330,20 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes {
};
if needs_processing {
- let expected_revert = std::mem::take(&mut self.expected_revert).unwrap();
- return match expect::handle_expect_revert(
+ // Only `remove` the expected revert from state if `expected_revert.count` ==
+ // `expected_revert.actual_count`
+ let mut expected_revert = std::mem::take(&mut self.expected_revert).unwrap();
+
+ let handler_result = expect::handle_expect_revert(
cheatcode_call,
false,
- &expected_revert,
+ &mut expected_revert,
outcome.result.result,
outcome.result.output.clone(),
&self.config.available_artifacts,
- ) {
+ );
+
+ return match handler_result {
Err(error) => {
trace!(expected=?expected_revert, ?error, status=?outcome.result.result, "Expected revert mismatch");
outcome.result.result = InstructionResult::Revert;
@@ -1331,6 +1351,10 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes {
outcome
}
Ok((_, retdata)) => {
+ expected_revert.actual_count += 1;
+ if expected_revert.actual_count < expected_revert.count {
+ self.expected_revert = Some(expected_revert.clone());
+ }
outcome.result.result = InstructionResult::Return;
outcome.result.output = retdata;
outcome
diff --git a/crates/cheatcodes/src/test/expect.rs b/crates/cheatcodes/src/test/expect.rs
index 3ee58407a7e6..a3ddef8c16c9 100644
--- a/crates/cheatcodes/src/test/expect.rs
+++ b/crates/cheatcodes/src/test/expect.rs
@@ -87,6 +87,10 @@ pub struct ExpectedRevert {
pub reverter: Option
,
/// Actual reverter of the call.
pub reverted_by: Option,
+ /// Number of times this revert is expected.
+ pub count: u64,
+ /// Actual number of times this revert has been seen.
+ pub actual_count: u64,
}
#[derive(Clone, Debug)]
@@ -295,7 +299,7 @@ impl Cheatcode for expectEmitAnonymous_3Call {
impl Cheatcode for expectRevert_0Call {
fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result {
let Self {} = self;
- expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), false, false, None)
+ expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), false, false, None, 1)
}
}
@@ -309,6 +313,7 @@ impl Cheatcode for expectRevert_1Call {
false,
false,
None,
+ 1,
)
}
}
@@ -323,6 +328,7 @@ impl Cheatcode for expectRevert_2Call {
false,
false,
None,
+ 1,
)
}
}
@@ -337,6 +343,7 @@ impl Cheatcode for expectRevert_3Call {
false,
false,
Some(*reverter),
+ 1,
)
}
}
@@ -351,6 +358,7 @@ impl Cheatcode for expectRevert_4Call {
false,
false,
Some(*reverter),
+ 1,
)
}
}
@@ -365,6 +373,89 @@ impl Cheatcode for expectRevert_5Call {
false,
false,
Some(*reverter),
+ 1,
+ )
+ }
+}
+
+impl Cheatcode for expectRevert_6Call {
+ fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result {
+ let Self { count } = self;
+ expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), false, false, None, *count)
+ }
+}
+
+impl Cheatcode for expectRevert_7Call {
+ fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result {
+ let Self { revertData, count } = self;
+ expect_revert(
+ ccx.state,
+ Some(revertData.as_ref()),
+ ccx.ecx.journaled_state.depth(),
+ false,
+ false,
+ None,
+ *count,
+ )
+ }
+}
+
+impl Cheatcode for expectRevert_8Call {
+ fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result {
+ let Self { revertData, count } = self;
+ expect_revert(
+ ccx.state,
+ Some(revertData),
+ ccx.ecx.journaled_state.depth(),
+ false,
+ false,
+ None,
+ *count,
+ )
+ }
+}
+
+impl Cheatcode for expectRevert_9Call {
+ fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result {
+ let Self { reverter, count } = self;
+ expect_revert(
+ ccx.state,
+ None,
+ ccx.ecx.journaled_state.depth(),
+ false,
+ false,
+ Some(*reverter),
+ *count,
+ )
+ }
+}
+
+impl Cheatcode for expectRevert_10Call {
+ fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result {
+ let Self { revertData, reverter, count } = self;
+ expect_revert(
+ ccx.state,
+ Some(revertData.as_ref()),
+ ccx.ecx.journaled_state.depth(),
+ false,
+ false,
+ Some(*reverter),
+ *count,
+ )
+ }
+}
+
+impl Cheatcode for expectRevert_11Call {
+ fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result {
+ let Self { revertData, reverter, count } = self;
+ expect_revert(
+ ccx.state,
+ Some(revertData),
+ ccx.ecx.journaled_state.depth(),
+ false,
+ false,
+ Some(*reverter),
+ *count,
)
}
}
@@ -379,6 +470,7 @@ impl Cheatcode for expectPartialRevert_0Call {
false,
true,
None,
+ 1,
)
}
}
@@ -393,13 +485,14 @@ impl Cheatcode for expectPartialRevert_1Call {
false,
true,
Some(*reverter),
+ 1,
)
}
}
impl Cheatcode for _expectCheatcodeRevert_0Call {
fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result {
- expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), true, false, None)
+ expect_revert(ccx.state, None, ccx.ecx.journaled_state.depth(), true, false, None, 1)
}
}
@@ -413,6 +506,7 @@ impl Cheatcode for _expectCheatcodeRevert_1Call {
true,
false,
None,
+ 1,
)
}
}
@@ -427,6 +521,7 @@ impl Cheatcode for _expectCheatcodeRevert_2Call {
true,
false,
None,
+ 1,
)
}
}
@@ -662,6 +757,7 @@ fn expect_revert(
cheatcode: bool,
partial_match: bool,
reverter: Option,
+ count: u64,
) -> Result {
ensure!(
state.expected_revert.is_none(),
@@ -678,6 +774,8 @@ fn expect_revert(
partial_match,
reverter,
reverted_by: None,
+ count,
+ actual_count: 0,
});
Ok(Default::default())
}
@@ -685,7 +783,7 @@ fn expect_revert(
pub(crate) fn handle_expect_revert(
is_cheatcode: bool,
is_create: bool,
- expected_revert: &ExpectedRevert,
+ expected_revert: &mut ExpectedRevert,
status: InstructionResult,
retdata: Bytes,
known_contracts: &Option,
@@ -698,72 +796,117 @@ pub(crate) fn handle_expect_revert(
}
};
- ensure!(!matches!(status, return_ok!()), "next call did not revert as expected");
-
- // If expected reverter address is set then check it matches the actual reverter.
- if let (Some(expected_reverter), Some(actual_reverter)) =
- (expected_revert.reverter, expected_revert.reverted_by)
- {
- if expected_reverter != actual_reverter {
- return Err(fmt_err!(
- "Reverter != expected reverter: {} != {}",
- actual_reverter,
- expected_reverter
- ));
+ let stringify = |data: &[u8]| {
+ if let Ok(s) = String::abi_decode(data, true) {
+ return s;
}
- }
-
- let expected_reason = expected_revert.reason.as_deref();
- // If None, accept any revert.
- let Some(expected_reason) = expected_reason else {
- return Ok(success_return());
+ if data.is_ascii() {
+ return std::str::from_utf8(data).unwrap().to_owned();
+ }
+ hex::encode_prefixed(data)
};
- if !expected_reason.is_empty() && retdata.is_empty() {
- bail!("call reverted as expected, but without data");
- }
-
- let mut actual_revert: Vec = retdata.into();
+ if expected_revert.count == 0 {
+ if expected_revert.reverter.is_none() && expected_revert.reason.is_none() {
+ ensure!(
+ matches!(status, return_ok!()),
+ "call reverted when it was expected not to revert"
+ );
+ return Ok(success_return());
+ }
- // Compare only the first 4 bytes if partial match.
- if expected_revert.partial_match && actual_revert.get(..4) == expected_reason.get(..4) {
- return Ok(success_return())
- }
+ // Flags to track if the reason and reverter match.
+ let mut reason_match = expected_revert.reason.as_ref().map(|_| false);
+ let mut reverter_match = expected_revert.reverter.as_ref().map(|_| false);
- // Try decoding as known errors.
- if matches!(
- actual_revert.get(..4).map(|s| s.try_into().unwrap()),
- Some(Vm::CheatcodeError::SELECTOR | alloy_sol_types::Revert::SELECTOR)
- ) {
- if let Ok(decoded) = Vec::::abi_decode(&actual_revert[4..], false) {
- actual_revert = decoded;
+ // Reverter check
+ if let (Some(expected_reverter), Some(actual_reverter)) =
+ (expected_revert.reverter, expected_revert.reverted_by)
+ {
+ if expected_reverter == actual_reverter {
+ reverter_match = Some(true);
+ }
}
- }
- if actual_revert == expected_reason ||
- (is_cheatcode && memchr::memmem::find(&actual_revert, expected_reason).is_some())
- {
- Ok(success_return())
+ // Reason check
+ let expected_reason = expected_revert.reason.as_deref();
+ if let Some(expected_reason) = expected_reason {
+ let mut actual_revert: Vec = retdata.into();
+ actual_revert = decode_revert(actual_revert);
+
+ if actual_revert == expected_reason {
+ reason_match = Some(true);
+ }
+ };
+
+ match (reason_match, reverter_match) {
+ (Some(true), Some(true)) => Err(fmt_err!(
+ "expected 0 reverts with reason: {}, from address: {}, but got one",
+ &stringify(expected_reason.unwrap_or_default()),
+ expected_revert.reverter.unwrap()
+ )),
+ (Some(true), None) => Err(fmt_err!(
+ "expected 0 reverts with reason: {}, but got one",
+ &stringify(expected_reason.unwrap_or_default())
+ )),
+ (None, Some(true)) => Err(fmt_err!(
+ "expected 0 reverts from address: {}, but got one",
+ expected_revert.reverter.unwrap()
+ )),
+ _ => Ok(success_return()),
+ }
} else {
- let (actual, expected) = if let Some(contracts) = known_contracts {
- let decoder = RevertDecoder::new().with_abis(contracts.iter().map(|(_, c)| &c.abi));
- (
- &decoder.decode(actual_revert.as_slice(), Some(status)),
- &decoder.decode(expected_reason, Some(status)),
- )
+ ensure!(!matches!(status, return_ok!()), "next call did not revert as expected");
+
+ // If expected reverter address is set then check it matches the actual reverter.
+ if let (Some(expected_reverter), Some(actual_reverter)) =
+ (expected_revert.reverter, expected_revert.reverted_by)
+ {
+ if expected_reverter != actual_reverter {
+ return Err(fmt_err!(
+ "Reverter != expected reverter: {} != {}",
+ actual_reverter,
+ expected_reverter
+ ));
+ }
+ }
+
+ let expected_reason = expected_revert.reason.as_deref();
+ // If None, accept any revert.
+ let Some(expected_reason) = expected_reason else {
+ return Ok(success_return());
+ };
+
+ if !expected_reason.is_empty() && retdata.is_empty() {
+ bail!("call reverted as expected, but without data");
+ }
+
+ let mut actual_revert: Vec = retdata.into();
+
+ // Compare only the first 4 bytes if partial match.
+ if expected_revert.partial_match && actual_revert.get(..4) == expected_reason.get(..4) {
+ return Ok(success_return())
+ }
+
+ // Try decoding as known errors.
+ actual_revert = decode_revert(actual_revert);
+
+ if actual_revert == expected_reason ||
+ (is_cheatcode && memchr::memmem::find(&actual_revert, expected_reason).is_some())
+ {
+ Ok(success_return())
} else {
- let stringify = |data: &[u8]| {
- if let Ok(s) = String::abi_decode(data, true) {
- return s;
- }
- if data.is_ascii() {
- return std::str::from_utf8(data).unwrap().to_owned();
- }
- hex::encode_prefixed(data)
+ let (actual, expected) = if let Some(contracts) = known_contracts {
+ let decoder = RevertDecoder::new().with_abis(contracts.iter().map(|(_, c)| &c.abi));
+ (
+ &decoder.decode(actual_revert.as_slice(), Some(status)),
+ &decoder.decode(expected_reason, Some(status)),
+ )
+ } else {
+ (&stringify(&actual_revert), &stringify(expected_reason))
};
- (&stringify(&actual_revert), &stringify(expected_reason))
- };
- Err(fmt_err!("Error != expected error: {} != {}", actual, expected,))
+ Err(fmt_err!("Error != expected error: {} != {}", actual, expected,))
+ }
}
}
@@ -774,3 +917,15 @@ fn expect_safe_memory(state: &mut Cheatcodes, start: u64, end: u64, depth: u64)
offsets.push(start..end);
Ok(Default::default())
}
+
+fn decode_revert(revert: Vec) -> Vec {
+ if matches!(
+ revert.get(..4).map(|s| s.try_into().unwrap()),
+ Some(Vm::CheatcodeError::SELECTOR | alloy_sol_types::Revert::SELECTOR)
+ ) {
+ if let Ok(decoded) = Vec::::abi_decode(&revert[4..], false) {
+ return decoded;
+ }
+ }
+ revert
+}
diff --git a/testdata/cheats/Vm.sol b/testdata/cheats/Vm.sol
index f6f66969f99f..b3746b6bc319 100644
--- a/testdata/cheats/Vm.sol
+++ b/testdata/cheats/Vm.sol
@@ -246,10 +246,16 @@ interface Vm {
function expectPartialRevert(bytes4 revertData, address reverter) external;
function expectRevert() external;
function expectRevert(bytes4 revertData) external;
+ function expectRevert(bytes4 revertData, address reverter, uint64 count) external;
+ function expectRevert(bytes calldata revertData, address reverter, uint64 count) external;
function expectRevert(bytes calldata revertData) external;
function expectRevert(address reverter) external;
function expectRevert(bytes4 revertData, address reverter) external;
function expectRevert(bytes calldata revertData, address reverter) external;
+ function expectRevert(uint64 count) external;
+ function expectRevert(bytes4 revertData, uint64 count) external;
+ function expectRevert(bytes calldata revertData, uint64 count) external;
+ function expectRevert(address reverter, uint64 count) external;
function expectSafeMemory(uint64 min, uint64 max) external;
function expectSafeMemoryCall(uint64 min, uint64 max) external;
function fee(uint256 newBasefee) external;
diff --git a/testdata/default/cheats/ExpectRevert.t.sol b/testdata/default/cheats/ExpectRevert.t.sol
index 18a90bac6e29..fef4ebaf5790 100644
--- a/testdata/default/cheats/ExpectRevert.t.sol
+++ b/testdata/default/cheats/ExpectRevert.t.sol
@@ -30,6 +30,10 @@ contract Reverter {
revert(message);
}
+ function callThenNoRevert(Dummy dummy) public pure {
+ dummy.callMe();
+ }
+
function revertWithoutReason() public pure {
revert();
}
@@ -188,7 +192,7 @@ contract ExpectRevertTest is DSTest {
}
function testexpectCheatcodeRevert() public {
- vm._expectCheatcodeRevert("JSON value at \".a\" is not an object");
+ vm._expectCheatcodeRevert('JSON value at ".a" is not an object');
vm.parseJsonKeys('{"a": "b"}', ".a");
}
@@ -351,3 +355,211 @@ contract ExpectRevertWithReverterTest is DSTest {
aContract.callAndRevert();
}
}
+
+contract ExpectRevertCount is DSTest {
+ Vm constant vm = Vm(HEVM_ADDRESS);
+
+ function testRevertCountAny() public {
+ uint64 count = 3;
+ Reverter reverter = new Reverter();
+ vm.expectRevert(count);
+ reverter.revertWithMessage("revert");
+ reverter.revertWithMessage("revert2");
+ reverter.revertWithMessage("revert3");
+
+ vm.expectRevert("revert");
+ reverter.revertWithMessage("revert");
+ }
+
+ function testFailRevertCountAny() public {
+ uint64 count = 3;
+ Reverter reverter = new Reverter();
+ vm.expectRevert(count);
+ reverter.revertWithMessage("revert");
+ reverter.revertWithMessage("revert2");
+ }
+
+ function testNoRevert() public {
+ uint64 count = 0;
+ Reverter reverter = new Reverter();
+ vm.expectRevert(count);
+ reverter.doNotRevert();
+ }
+
+ function testFailNoRevert() public {
+ uint64 count = 0;
+ Reverter reverter = new Reverter();
+ vm.expectRevert(count);
+ reverter.revertWithMessage("revert");
+ }
+
+ function testRevertCountSpecific() public {
+ uint64 count = 2;
+ Reverter reverter = new Reverter();
+ vm.expectRevert("revert", count);
+ reverter.revertWithMessage("revert");
+ reverter.revertWithMessage("revert");
+ }
+
+ function testFailReverCountSpecifc() public {
+ uint64 count = 2;
+ Reverter reverter = new Reverter();
+ vm.expectRevert("revert", count);
+ reverter.revertWithMessage("revert");
+ reverter.revertWithMessage("second-revert");
+ }
+
+ function testNoRevertSpecific() public {
+ uint64 count = 0;
+ Reverter reverter = new Reverter();
+ vm.expectRevert("revert", count);
+ reverter.doNotRevert();
+ }
+
+ function testFailNoRevertSpecific() public {
+ uint64 count = 0;
+ Reverter reverter = new Reverter();
+ vm.expectRevert("revert", count);
+ reverter.revertWithMessage("revert");
+ }
+
+ function testNoRevertSpecificButDiffRevert() public {
+ uint64 count = 0;
+ Reverter reverter = new Reverter();
+ vm.expectRevert("revert", count);
+ reverter.revertWithMessage("revert2");
+ }
+
+ function testRevertCountWithConstructor() public {
+ uint64 count = 1;
+ vm.expectRevert("constructor revert", count);
+ new ConstructorReverter("constructor revert");
+ }
+
+ function testNoRevertWithConstructor() public {
+ uint64 count = 0;
+ vm.expectRevert("constructor revert", count);
+ new CContract();
+ }
+
+ function testRevertCountNestedSpecific() public {
+ uint64 count = 2;
+ Reverter reverter = new Reverter();
+ Reverter inner = new Reverter();
+
+ vm.expectRevert("nested revert", count);
+ reverter.revertWithMessage("nested revert");
+ reverter.nestedRevert(inner, "nested revert");
+
+ vm.expectRevert("nested revert", count);
+ reverter.nestedRevert(inner, "nested revert");
+ reverter.nestedRevert(inner, "nested revert");
+ }
+
+ function testRevertCountCallsThenReverts() public {
+ uint64 count = 2;
+ Reverter reverter = new Reverter();
+ Dummy dummy = new Dummy();
+
+ vm.expectRevert("called a function and then reverted", count);
+ reverter.callThenRevert(dummy, "called a function and then reverted");
+ reverter.callThenRevert(dummy, "called a function and then reverted");
+ }
+
+ function testFailRevertCountCallsThenReverts() public {
+ uint64 count = 2;
+ Reverter reverter = new Reverter();
+ Dummy dummy = new Dummy();
+
+ vm.expectRevert("called a function and then reverted", count);
+ reverter.callThenRevert(dummy, "called a function and then reverted");
+ reverter.callThenRevert(dummy, "wrong revert");
+ }
+
+ function testNoRevertCall() public {
+ uint64 count = 0;
+ Reverter reverter = new Reverter();
+ Dummy dummy = new Dummy();
+
+ vm.expectRevert("called a function and then reverted", count);
+ reverter.callThenNoRevert(dummy);
+ }
+}
+
+contract ExpectRevertCountWithReverter is DSTest {
+ Vm constant vm = Vm(HEVM_ADDRESS);
+
+ function testRevertCountWithReverter() public {
+ uint64 count = 2;
+ Reverter reverter = new Reverter();
+ vm.expectRevert(address(reverter), count);
+ reverter.revertWithMessage("revert");
+ reverter.revertWithMessage("revert");
+ }
+
+ function testFailRevertCountWithReverter() public {
+ uint64 count = 2;
+ Reverter reverter = new Reverter();
+ Reverter reverter2 = new Reverter();
+ vm.expectRevert(address(reverter), count);
+ reverter.revertWithMessage("revert");
+ reverter2.revertWithMessage("revert");
+ }
+
+ function testNoRevertWithReverter() public {
+ uint64 count = 0;
+ Reverter reverter = new Reverter();
+ vm.expectRevert(address(reverter), count);
+ reverter.doNotRevert();
+ }
+
+ function testNoRevertWithWrongReverter() public {
+ uint64 count = 0;
+ Reverter reverter = new Reverter();
+ Reverter reverter2 = new Reverter();
+ vm.expectRevert(address(reverter), count);
+ reverter2.revertWithMessage("revert"); // revert from wrong reverter
+ }
+
+ function testFailNoRevertWithReverter() public {
+ uint64 count = 0;
+ Reverter reverter = new Reverter();
+ vm.expectRevert(address(reverter), count);
+ reverter.revertWithMessage("revert");
+ }
+
+ function testReverterCountWithData() public {
+ uint64 count = 2;
+ Reverter reverter = new Reverter();
+ vm.expectRevert("revert", address(reverter), count);
+ reverter.revertWithMessage("revert");
+ reverter.revertWithMessage("revert");
+ }
+
+ function testFailReverterCountWithWrongData() public {
+ uint64 count = 2;
+ Reverter reverter = new Reverter();
+ vm.expectRevert("revert", address(reverter), count);
+ reverter.revertWithMessage("revert");
+ reverter.revertWithMessage("wrong revert");
+ }
+
+ function testFailWrongReverterCountWithData() public {
+ uint64 count = 2;
+ Reverter reverter = new Reverter();
+ Reverter reverter2 = new Reverter();
+ vm.expectRevert("revert", address(reverter), count);
+ reverter.revertWithMessage("revert");
+ reverter2.revertWithMessage("revert");
+ }
+
+ function testNoReverterCountWithData() public {
+ uint64 count = 0;
+ Reverter reverter = new Reverter();
+ vm.expectRevert("revert", address(reverter), count);
+ reverter.doNotRevert();
+
+ vm.expectRevert("revert", address(reverter), count);
+ reverter.revertWithMessage("revert2");
+ }
+}