Skip to content

Commit

Permalink
wip on lt
Browse files Browse the repository at this point in the history
  • Loading branch information
thedavidmeister committed Jul 23, 2024
1 parent cb141f0 commit 902a15b
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 22 deletions.
110 changes: 101 additions & 9 deletions src/lib/LibDecimalFloat.sol
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,105 @@ library LibDecimalFloat {
return (signedCoefficient, exponent);
}

function compareRescale(int256 signedCoefficientA, int256 exponentA, int256 signedCoefficientB, int256 exponentB) internal pure returns (bool, bool, int256, int256, int256, int256) {
unchecked {
bool didSwap = false;
if (exponentB > exponentA) {
int256 tmp = signedCoefficientA;
signedCoefficientA = signedCoefficientB;
signedCoefficientB = tmp;

tmp = exponentA;
exponentA = exponentB;
exponentB = tmp;

didSwap = true;
}

int256 exponentDiff = exponentA - exponentB;
bool didOverflow;
assembly ("memory-safe") {
didOverflow := or(slt(exponentDiff, 0), sgt(exponentDiff, 76))
}
if (didOverflow) {
return (didOverflow, didSwap, 0, 0, 0, 0);
}
int256 scale = int256(10 ** uint256(exponentDiff));
int256 rescaled = signedCoefficientA * scale;

if (rescaled / scale != signedCoefficientA) {
return (true, didSwap, 0, 0, 0, 0);
}
else {
return (false, didSwap, rescaled, exponentA - exponentDiff, signedCoefficientB, exponentB);
}
}
}

function equal(int256 signedCoefficientA, int256 exponentA, int256 signedCoefficientB, int256 exponentB) internal pure returns (bool) {
// Trivially true if both coefficient and exponent are the same.
bool trivialTrue;
assembly ("memory-safe") {
trivialTrue := and(eq(signedCoefficientA, signedCoefficientB), eq(exponentA, exponentB))
}
if (trivialTrue) {
return true;
}

// Handle zeros that don't rescale well.
if (signedCoefficientA == 0 || signedCoefficientB == 0) {
return signedCoefficientA == signedCoefficientB;
}

bool didOverflow;
bool didSwap;
(didOverflow, didSwap, signedCoefficientA, exponentA, signedCoefficientB, exponentB) = compareRescale(signedCoefficientA, exponentA, signedCoefficientB, exponentB);

if (didOverflow) {
return false;
}

return signedCoefficientA == signedCoefficientB;
}

function lt(int256 signedCoefficientA, int256 exponentA, int256 signedCoefficientB, int256 exponentB) internal pure returns (bool) {
bool straightLt;
// If either coefficients are 0 then do straight lt.
assembly ("memory-safe") {
straightLt := or(iszero(signedCoefficientA), iszero(signedCoefficientB))
}
if (straightLt) {
return signedCoefficientA < signedCoefficientB;
}

// If the exponents are the same or the signs of the coefficients differ then lt is just a normal lt on the coefficients.
assembly ("memory-safe") {
straightLt := or(
slt(sdiv(signedCoefficientA, signedCoefficientB), 0),
slt(sdiv(signedCoefficientB, signedCoefficientA), 0)
)
}

if (straightLt) {
return signedCoefficientA < signedCoefficientB;
}

bool didOverflow;
bool didSwap;
(didOverflow, didSwap, signedCoefficientA, exponentA, signedCoefficientB, exponentB) = compareRescale(signedCoefficientA, exponentA, signedCoefficientB, exponentB);

if (didOverflow) {
return !didSwap;
}

if (didSwap) {
return signedCoefficientB < signedCoefficientA;
}
else {
return signedCoefficientA < signedCoefficientB;
}
}

/// https://speleotrove.com/decimal/daops.html#refnumco
/// > compare takes two operands and compares their values numerically. If
/// > either operand is a special value then the general rules apply. No
Expand Down Expand Up @@ -646,15 +745,8 @@ library LibDecimalFloat {

// Almost always we won't overflow by normalizing the exponents in one step.
int256 exponentDiff = exponentA - exponentB;
if (exponentDiff + exponentB != exponentA) {
// We overflowed the diff of the exponents in signed 256 bit space, so we're gazillions of OOMs apart.
// The coefficient with the larger exponent is the larger number.
// I.e. the current A is larger than B.
return didSwap ? COMPARE_LESS_THAN : COMPARE_GREATER_THAN;
}

if (exponentDiff > 76) {
// We didn't overflow but we're still too far apart to compare.
if (exponentDiff < 0 || exponentDiff > 76) {
// Our exponentDiff is overflowing what we can handle.
// The coefficient with the larger exponent is the larger number.
// I.e. the current A is larger than B.
return didSwap ? COMPARE_LESS_THAN : COMPARE_GREATER_THAN;
Expand Down
31 changes: 18 additions & 13 deletions test/src/LibDecimalFloat.compare.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ contract LibDecimalFloatCompareTest is Test {
) external pure {
signedCoefficientNeg = bound(signedCoefficientNeg, type(int256).min, -1);
signedCoefficientPos = bound(signedCoefficientPos, 0, type(int256).max);
exponentNeg = bound(exponentNeg, EXPONENT_MIN, EXPONENT_MAX);
exponentPos = bound(exponentPos, EXPONENT_MIN, EXPONENT_MAX);

int256 compare = LibDecimalFloat.compare(signedCoefficientNeg, exponentNeg, signedCoefficientPos, exponentPos);
assertEq(compare, COMPARE_LESS_THAN);
Expand All @@ -54,9 +52,6 @@ contract LibDecimalFloatCompareTest is Test {
int256 signedCoefficientB,
int256 exponentB
) external pure {
exponentA = bound(exponentA, EXPONENT_MIN, EXPONENT_MAX);
exponentB = bound(exponentB, EXPONENT_MIN, EXPONENT_MAX);

int256 compare0 = LibDecimalFloat.compare(signedCoefficientA, exponentA, signedCoefficientB, exponentB);
int256 compare1 = LibDecimalFloat.compare(signedCoefficientB, exponentB, signedCoefficientA, exponentA);

Expand All @@ -74,17 +69,12 @@ contract LibDecimalFloatCompareTest is Test {
external
pure
{
exponentA = bound(exponentA, EXPONENT_MIN, EXPONENT_MAX);
exponentB = bound(exponentB, EXPONENT_MIN, EXPONENT_MAX);

int256 compare = LibDecimalFloat.compare(signedCoefficientA, exponentA, signedCoefficientB, exponentB);
assert(compare == COMPARE_LESS_THAN || compare == COMPARE_EQUAL || compare == COMPARE_GREATER_THAN);
}

/// Comparing something to itself is always equal.
function testCompareSelf(int256 signedCoefficient, int256 exponent) external pure {
exponent = bound(exponent, EXPONENT_MIN, EXPONENT_MAX);

int256 compare = LibDecimalFloat.compare(signedCoefficient, exponent, signedCoefficient, exponent);
assertEq(compare, COMPARE_EQUAL);
}
Expand All @@ -94,7 +84,6 @@ contract LibDecimalFloatCompareTest is Test {
external
pure
{
exponent = bound(exponent, EXPONENT_MIN, EXPONENT_MAX);
vm.assume(signedCoefficientA != signedCoefficientB);

int256 compare = LibDecimalFloat.compare(signedCoefficientA, exponent, signedCoefficientB, exponent);
Expand All @@ -110,8 +99,6 @@ contract LibDecimalFloatCompareTest is Test {
/// Anything 0 is always less than anything positive.
function testCompareZero(int256 signedCoefficient, int256 exponent, int256 exponentZero) external pure {
signedCoefficient = bound(signedCoefficient, 1, type(int256).max);
exponent = bound(exponent, EXPONENT_MIN, EXPONENT_MAX);
exponentZero = bound(exponentZero, EXPONENT_MIN, EXPONENT_MAX);

int256 compare = LibDecimalFloat.compare(0, exponentZero, signedCoefficient, exponent);
assertEq(compare, COMPARE_LESS_THAN);
Expand All @@ -136,4 +123,22 @@ contract LibDecimalFloatCompareTest is Test {
function testCompareGasExponentDiffOverflow() external pure {
LibDecimalFloat.compare(1, type(int256).max, 1, type(int256).min);
}

function testCompareExponentDiffOverflowMax(int256 signedCoefficientA, int256 signedCoefficientB) external pure {
if (signedCoefficientA < 0) {
vm.assume(signedCoefficientB < 0);
}
if (signedCoefficientA > 0) {
vm.assume(signedCoefficientB > 0);
}
vm.assume(signedCoefficientA != 0);
vm.assume(signedCoefficientB != 0);

int256 compare = LibDecimalFloat.compare(signedCoefficientA, type(int256).max, signedCoefficientB, type(int256).min);
assertEq(compare, COMPARE_GREATER_THAN);
}

function testNeverRevert(int256 signedCoeffficientA, int256 exponentA, int256 signedCoefficientB, int256 exponentB) external pure {
LibDecimalFloat.compare(signedCoeffficientA, exponentA, signedCoefficientB, exponentB);
}
}

0 comments on commit 902a15b

Please sign in to comment.