Skip to content

Commit

Permalink
✨ LibSort groupSum (#1308)
Browse files Browse the repository at this point in the history
  • Loading branch information
Vectorized authored Jan 16, 2025
1 parent 4169bef commit 00ce639
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 0 deletions.
59 changes: 59 additions & 0 deletions src/utils/LibSort.sol
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,65 @@ library LibSort {
}
}

/// @dev Performs a sum on `values`, grouped by `keys`.
/// `keys` will be insertion-sorted and uniquified,
/// `values` will be re-populated with the group sums.
/// The arrays must have the same length.
function groupSum(uint256[] memory keys, uint256[] memory values) internal pure {
uint256 m;
/// @solidity memory-safe-assembly
assembly {
m := mload(0x40) // Cache the free memory pointer, for freeing the memory.
if iszero(eq(mload(keys), mload(values))) {
mstore(0x00, 0x4e487b71)
mstore(0x20, 0x32) // Array out of bounds panic.
revert(0x1c, 0x24)
}
}
if (keys.length == uint256(0)) return;
uint256[] memory oriKeys = copy(keys);
uint256[] memory oriValues = copy(values);
insertionSort(keys);
uniquifySorted(keys);
/// @solidity memory-safe-assembly
assembly {
mstore(values, mload(keys)) // Truncate.
calldatacopy(add(values, 0x20), calldatasize(), shl(5, mload(keys))) // Zeroize.
let end := add(0x20, shl(5, mload(oriKeys)))
for { let i := 0x20 } 1 {} {
let k := mload(add(oriKeys, i))
let v := mload(add(oriValues, i))
let j := 0x20
for {} iszero(eq(mload(add(keys, j)), k)) {} { j := add(j, 0x20) }
let s := add(mload(add(values, j)), v)
if lt(s, v) {
mstore(0x00, 0x4e487b71)
mstore(0x20, 0x11) // Overflow panic.
revert(0x1c, 0x24)
}
mstore(add(values, j), s)
i := add(i, 0x20)
if eq(i, end) { break }
}
mstore(0x40, m) // Frees the temporary memory.
}
}

/// @dev Performs a sum on `values`, grouped by `keys`.
function groupSum(address[] memory keys, uint256[] memory values) internal pure {
groupSum(_toUints(keys), values);
}

/// @dev Performs a sum on `values`, grouped by `keys`.
function groupSum(bytes32[] memory keys, uint256[] memory values) internal pure {
groupSum(_toUints(keys), values);
}

/// @dev Performs a sum on `values`, grouped by `keys`.
function groupSum(int256[] memory keys, uint256[] memory values) internal pure {
groupSum(_toUints(keys), values);
}

/*´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:*/
/* PRIVATE HELPERS */
/*.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•*/
Expand Down
42 changes: 42 additions & 0 deletions test/LibSort.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -1297,4 +1297,46 @@ contract LibSortTest is SoladyTest {
}
assertEq(orAll >> 160, 0);
}

function testGroupSum(bytes32) public {
if (_randomChance(2)) {
_misalignFreeMemoryPointer();
_brutalizeMemory();
}
uint256 n = _random() & 0x1f;
uint256[] memory keys = new uint256[](n);
uint256[] memory values = new uint256[](n);
unchecked {
for (uint256 i; i < n; ++i) {
uint256 k = _randomUniform() & 0xf;
uint256 v = _randomUniform() & 0xff;
keys[i] = k;
values[i] = v;
}
}
uint256 oriSum = _sum(values);
uint256[] memory uniqueKeys = LibSort.copy(keys);
LibSort.insertionSort(uniqueKeys);
LibSort.uniquifySorted(uniqueKeys);
uint256[] memory sums = new uint256[](uniqueKeys.length);
unchecked {
for (uint256 i; i < n; ++i) {
(, uint256 j) = LibSort.searchSorted(uniqueKeys, keys[i]);
sums[j] += values[i];
}
}
LibSort.groupSum(keys, values);
_checkMemory(sums);
assertEq(keys, uniqueKeys);
assertEq(values, sums);
assertEq(_sum(sums), oriSum);
}

function _sum(uint256[] memory a) internal pure returns (uint256 result) {
unchecked {
for (uint256 i; i < a.length; ++i) {
result += a[i];
}
}
}
}

0 comments on commit 00ce639

Please sign in to comment.