diff --git a/src/utils/LibSort.sol b/src/utils/LibSort.sol index ac2df8336..a5af33f71 100644 --- a/src/utils/LibSort.sol +++ b/src/utils/LibSort.sol @@ -597,29 +597,30 @@ library LibSort { } } if (keys.length == uint256(0)) return; - uint256[] memory oriKeys = copy(keys); - uint256[] memory oriValues = copy(values); - insertionSort(keys); + (uint256[] memory oriKeys, uint256[] memory oriValues) = (copy(keys), copy(values)); + insertionSort(keys); // Optimize for bytecode size. uniquifySorted(keys); /// @solidity memory-safe-assembly assembly { + let d := sub(values, keys) + let w := not(0x1f) + let s := add(keys, 0x20) mstore(values, mload(keys)) // Truncate. - calldatacopy(add(values, 0x20), calldatasize(), shl(5, mload(keys))) // Zeroize. - let end := add(0x20, shl(5, mload(oriKeys))) - for { let i := 0x20 } 1 {} { + calldatacopy(add(s, d), calldatasize(), shl(5, mload(keys))) // Zeroize. + for { let i := shl(5, mload(oriKeys)) } 1 {} { let k := mload(add(oriKeys, i)) let v := mload(add(oriValues, i)) - let j := 0x20 - for {} iszero(eq(mload(add(keys, j)), k)) {} { j := add(j, 0x20) } - let s := add(mload(add(values, j)), v) - if lt(s, v) { + let j := s + for {} iszero(eq(mload(j), k)) {} { j := add(j, 0x20) } + j := add(j, d) + mstore(j, add(mload(j), v)) + if lt(mload(j), v) { mstore(0x00, 0x4e487b71) mstore(0x20, 0x11) // Overflow panic. revert(0x1c, 0x24) } - mstore(add(values, j), s) - i := add(i, 0x20) - if eq(i, end) { break } + i := add(i, w) + if iszero(i) { break } } mstore(0x40, m) // Frees the temporary memory. } diff --git a/test/LibSort.t.sol b/test/LibSort.t.sol index 5d2370176..3ca410492 100644 --- a/test/LibSort.t.sol +++ b/test/LibSort.t.sol @@ -1298,6 +1298,24 @@ contract LibSortTest is SoladyTest { assertEq(orAll >> 160, 0); } + function testGroupSum() public { + uint256 n = 32; + uint256[] memory keys = new uint256[](n); + uint256[] memory values = new uint256[](n); + uint256 total; + unchecked { + for (uint256 i; i < n; ++i) { + keys[i] = (i + 1) % 7; + values[i] = i; + total += i; + } + } + LibSort.groupSum(keys, values); + assertEq(keys.length, 7); + assertEq(values.length, 7); + assertEq(_sum(values), total); + } + function testGroupSum(bytes32) public { if (_randomChance(2)) { _misalignFreeMemoryPointer();