Skip to content

Commit

Permalink
Optimize groupSum
Browse files Browse the repository at this point in the history
  • Loading branch information
Vectorized committed Jan 18, 2025
1 parent 70231d3 commit e33cd9e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 13 deletions.
27 changes: 14 additions & 13 deletions src/utils/LibSort.sol
Original file line number Diff line number Diff line change
Expand Up @@ -597,29 +597,30 @@ library LibSort {
}
}
if (keys.length == uint256(0)) return;
uint256[] memory oriKeys = copy(keys);
uint256[] memory oriValues = copy(values);
insertionSort(keys);
(uint256[] memory oriKeys, uint256[] memory oriValues) = (copy(keys), copy(values));
insertionSort(keys); // Optimize for bytecode size.
uniquifySorted(keys);
/// @solidity memory-safe-assembly
assembly {
let d := sub(values, keys)
let w := not(0x1f)
let s := add(keys, 0x20)
mstore(values, mload(keys)) // Truncate.
calldatacopy(add(values, 0x20), calldatasize(), shl(5, mload(keys))) // Zeroize.
let end := add(0x20, shl(5, mload(oriKeys)))
for { let i := 0x20 } 1 {} {
calldatacopy(add(s, d), calldatasize(), shl(5, mload(keys))) // Zeroize.
for { let i := shl(5, mload(oriKeys)) } 1 {} {
let k := mload(add(oriKeys, i))
let v := mload(add(oriValues, i))
let j := 0x20
for {} iszero(eq(mload(add(keys, j)), k)) {} { j := add(j, 0x20) }
let s := add(mload(add(values, j)), v)
if lt(s, v) {
let j := s
for {} iszero(eq(mload(j), k)) {} { j := add(j, 0x20) }
j := add(j, d)
mstore(j, add(mload(j), v))
if lt(mload(j), v) {
mstore(0x00, 0x4e487b71)
mstore(0x20, 0x11) // Overflow panic.
revert(0x1c, 0x24)
}
mstore(add(values, j), s)
i := add(i, 0x20)
if eq(i, end) { break }
i := add(i, w)
if iszero(i) { break }
}
mstore(0x40, m) // Frees the temporary memory.
}
Expand Down
18 changes: 18 additions & 0 deletions test/LibSort.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -1298,6 +1298,24 @@ contract LibSortTest is SoladyTest {
assertEq(orAll >> 160, 0);
}

function testGroupSum() public {
uint256 n = 32;
uint256[] memory keys = new uint256[](n);
uint256[] memory values = new uint256[](n);
uint256 total;
unchecked {
for (uint256 i; i < n; ++i) {
keys[i] = (i + 1) % 7;
values[i] = i;
total += i;
}
}
LibSort.groupSum(keys, values);
assertEq(keys.length, 7);
assertEq(values.length, 7);
assertEq(_sum(values), total);
}

function testGroupSum(bytes32) public {
if (_randomChance(2)) {
_misalignFreeMemoryPointer();
Expand Down

0 comments on commit e33cd9e

Please sign in to comment.