Skip to content

Commit

Permalink
feat(cheatcodes): count assertion for expectRevert (#9484)
Browse files Browse the repository at this point in the history
* expectRevert count overload boilerplate

* introduce `count` variable

* populate `ExpectedRevert` for count overloads

* intro `actual_count` and make ExpectedRevert mut

* increment `actual_account` on success and tests

* handle non-zero count reverts separately

* handle count for specific reverts

* nit

* more tests

* fix: handle count > 1 with reverter specified

* test: ExpectRevertCountWithReverter

* expectRevert with reverter and count 0

* nit

* reverter count with data

* nit

* cleanup

* nit

* nit

* clippy

* nit

* cargo cheats
  • Loading branch information
yash-atreya authored Dec 6, 2024
1 parent e520767 commit 63484d0
Show file tree
Hide file tree
Showing 6 changed files with 609 additions and 68 deletions.
120 changes: 120 additions & 0 deletions crates/cheatcodes/assets/cheatcodes.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions crates/cheatcodes/spec/src/vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,30 @@ interface Vm {
#[cheatcode(group = Testing, safety = Unsafe)]
function expectRevert(bytes calldata revertData, address reverter) external;

/// Expects a `count` number of reverts from the upcoming calls with any revert data or reverter.
#[cheatcode(group = Testing, safety = Unsafe)]
function expectRevert(uint64 count) external;

/// Expects a `count` number of reverts from the upcoming calls that match the revert data.
#[cheatcode(group = Testing, safety = Unsafe)]
function expectRevert(bytes4 revertData, uint64 count) external;

/// Expects a `count` number of reverts from the upcoming calls that exactly match the revert data.
#[cheatcode(group = Testing, safety = Unsafe)]
function expectRevert(bytes calldata revertData, uint64 count) external;

/// Expects a `count` number of reverts from the upcoming calls from the reverter address.
#[cheatcode(group = Testing, safety = Unsafe)]
function expectRevert(address reverter, uint64 count) external;

/// Expects a `count` number of reverts from the upcoming calls from the reverter address that match the revert data.
#[cheatcode(group = Testing, safety = Unsafe)]
function expectRevert(bytes4 revertData, address reverter, uint64 count) external;

/// Expects a `count` number of reverts from the upcoming calls from the reverter address that exactly match the revert data.
#[cheatcode(group = Testing, safety = Unsafe)]
function expectRevert(bytes calldata revertData, address reverter, uint64 count) external;

/// Expects an error on next call that starts with the revert data.
#[cheatcode(group = Testing, safety = Unsafe)]
function expectPartialRevert(bytes4 revertData) external;
Expand Down
40 changes: 32 additions & 8 deletions crates/cheatcodes/src/inspector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -754,16 +754,23 @@ where {
if ecx.journaled_state.depth() <= expected_revert.depth &&
matches!(expected_revert.kind, ExpectedRevertKind::Default)
{
let expected_revert = std::mem::take(&mut self.expected_revert).unwrap();
return match expect::handle_expect_revert(
let mut expected_revert = std::mem::take(&mut self.expected_revert).unwrap();
let handler_result = expect::handle_expect_revert(
false,
true,
&expected_revert,
&mut expected_revert,
outcome.result.result,
outcome.result.output.clone(),
&self.config.available_artifacts,
) {
);

return match handler_result {
Ok((address, retdata)) => {
expected_revert.actual_count += 1;
if expected_revert.actual_count < expected_revert.count {
self.expected_revert = Some(expected_revert.clone());
}

outcome.result.result = InstructionResult::Return;
outcome.result.output = retdata;
outcome.address = address;
Expand Down Expand Up @@ -1302,6 +1309,14 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes {
expected_revert.reverted_by.is_none()
{
expected_revert.reverted_by = Some(call.target_address);
} else if outcome.result.is_revert() &&
expected_revert.reverter.is_some() &&
expected_revert.reverted_by.is_some() &&
expected_revert.count > 1
{
// If we're expecting more than one revert, we need to reset the reverted_by address
// to latest reverter.
expected_revert.reverted_by = Some(call.target_address);
}

if ecx.journaled_state.depth() <= expected_revert.depth {
Expand All @@ -1315,22 +1330,31 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes {
};

if needs_processing {
let expected_revert = std::mem::take(&mut self.expected_revert).unwrap();
return match expect::handle_expect_revert(
// Only `remove` the expected revert from state if `expected_revert.count` ==
// `expected_revert.actual_count`
let mut expected_revert = std::mem::take(&mut self.expected_revert).unwrap();

let handler_result = expect::handle_expect_revert(
cheatcode_call,
false,
&expected_revert,
&mut expected_revert,
outcome.result.result,
outcome.result.output.clone(),
&self.config.available_artifacts,
) {
);

return match handler_result {
Err(error) => {
trace!(expected=?expected_revert, ?error, status=?outcome.result.result, "Expected revert mismatch");
outcome.result.result = InstructionResult::Revert;
outcome.result.output = error.abi_encode().into();
outcome
}
Ok((_, retdata)) => {
expected_revert.actual_count += 1;
if expected_revert.actual_count < expected_revert.count {
self.expected_revert = Some(expected_revert.clone());
}
outcome.result.result = InstructionResult::Return;
outcome.result.output = retdata;
outcome
Expand Down
Loading

0 comments on commit 63484d0

Please sign in to comment.