Skip to content

Commit

Permalink
Merge branch 'decimal-amount'
Browse files Browse the repository at this point in the history
  • Loading branch information
iljakuklic committed Jan 9, 2024
2 parents e3dc5b4 + 173dd8c commit 5695a1a
Show file tree
Hide file tree
Showing 11 changed files with 401 additions and 117 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

70 changes: 3 additions & 67 deletions common/src/primitives/amount.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
use serialization::{Decode, Encode};
use std::iter::Sum;

use super::signed_amount::SignedAmount;
use super::{signed_amount::SignedAmount, DecimalAmount};

pub type UnsignedIntType = u128;

Expand All @@ -46,16 +46,6 @@ pub struct Amount {
val: UnsignedIntType,
}

fn remove_right_most_zeros_and_decimal_point(s: String) -> String {
let point_pos = s.chars().position(|c| c == '.');
if point_pos.is_none() {
return s;
}
let s = s.trim_end_matches('0');
let s = s.trim_end_matches('.');
s.to_owned()
}

impl Amount {
pub const MAX: Self = Self::from_atoms(UnsignedIntType::MAX);
pub const ZERO: Self = Self::from_atoms(0);
Expand All @@ -81,65 +71,11 @@ impl Amount {
}

pub fn into_fixedpoint_str(self, decimals: u8) -> String {
let amount_str = self.val.to_string();
let decimals = decimals as usize;
if amount_str.len() <= decimals {
let zeros = "0".repeat(decimals - amount_str.len());
let result = "0.".to_owned() + &zeros + &amount_str;

remove_right_most_zeros_and_decimal_point(result)
} else {
let ten: UnsignedIntType = 10;
let unit = ten.pow(decimals as u32);
let whole = self.val / unit;
let fraction = self.val % unit;
let result = format!("{whole}.{fraction:0decimals$}");

remove_right_most_zeros_and_decimal_point(result)
}
DecimalAmount::from_amount_minimal(self, decimals).to_string()
}

pub fn from_fixedpoint_str(amount_str: &str, decimals: u8) -> Option<Self> {
let decimals = decimals as usize;
let amount_str = amount_str.trim_matches(' '); // trim spaces
let amount_str = amount_str.replace('_', "");

// empty not allowed
if amount_str.is_empty() {
return None;
}
// too long
if amount_str.len() > 100 {
return None;
}
// must be only numbers or decimal point
if !amount_str.chars().all(|c| char::is_numeric(c) || c == '.') {
return None;
}

if amount_str.matches('.').count() > 1 {
// only 1 decimal point allowed
None
} else if amount_str.matches('.').count() == 0 {
// if there is no decimal point, then just add N zeros to the right and we're done
let zeros = "0".repeat(decimals);
let amount_str = amount_str + &zeros;

amount_str.parse::<UnsignedIntType>().ok().map(|v| Amount { val: v })
} else {
// if there's 1 decimal point, split, join the numbers, then add zeros to the right
let amount_split = amount_str.split('.').collect::<Vec<&str>>();
debug_assert!(amount_split.len() == 2); // we already checked we have 1 decimal exactly
if amount_split[1].len() > decimals {
// there cannot be more decimals than the assumed amount
return None;
}
let zeros = "0".repeat(decimals - amount_split[1].len());
let atoms_str = amount_split[0].to_owned() + amount_split[1] + &zeros;
let atoms_str = atoms_str.trim_start_matches('0');

atoms_str.parse::<UnsignedIntType>().ok().map(|v| Amount { val: v })
}
amount_str.parse::<DecimalAmount>().ok()?.to_amount(decimals)
}

pub fn abs_diff(self, other: Amount) -> Amount {
Expand Down
261 changes: 261 additions & 0 deletions common/src/primitives/decimal_amount.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
// Copyright (c) 2024 RBB S.r.l
// [email protected]
// SPDX-License-Identifier: MIT
// Licensed under the MIT License;
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://github.com/mintlayer/mintlayer-core/blob/master/LICENSE
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt::Write;

use utils::ensure;

pub use super::amount::{Amount, UnsignedIntType};

const TEN: UnsignedIntType = 10;

/// Amount in fractional representation
///
/// Keeps track of the number of decimal digits that should be presented to the user. This is
/// mostly for presentation purposes so does not define any arithmetic operations. Convert to
/// `Amount` if arithmetic is needed.
///
/// Comparison operators are deliberately left out too. The reason is that there are two sensible
/// ways to compare `DecimalAmount`s:
/// 1. Compare the numerical values that they signify
/// 2. Compare the implied textual representation, e.g. "1.0" and "1.000" are considered different
/// The user is expected to convert to a number or a string before comparing to explicitly state
/// which for of comparison is desired in any given situation.
#[derive(Clone, Copy, Debug)]
pub struct DecimalAmount {
mantissa: UnsignedIntType,
decimals: u8,
}

impl DecimalAmount {
pub const ZERO: Self = Self::from_uint_integral(0);

/// Convert from integer with no decimals
pub const fn from_uint_integral(number: u128) -> Self {
Self::from_uint_decimal(number, 0)
}

/// Convert from integer, interpreting the last N digits as the fractional part
pub const fn from_uint_decimal(mantissa: UnsignedIntType, decimals: u8) -> Self {
Self { mantissa, decimals }
}

/// Convert from amount, keeping all decimal digits
pub const fn from_amount_full(amount: Amount, decimals: u8) -> Self {
Self::from_uint_decimal(amount.into_atoms(), decimals)
}

/// Convert from amount, keeping as few decimal digits as possible (without losing precision)
pub const fn from_amount_minimal(amount: Amount, decimals: u8) -> Self {
Self::from_amount_full(amount, decimals).minimize()
}

/// Convert to amount using given number of decimals
pub fn to_amount(self, decimals: u8) -> Option<Amount> {
Some(Amount::from_atoms(self.with_decimals(decimals)?.mantissa))
}

/// Change the number of decimals. Can only increase decimals, otherwise we risk losing digits.
pub fn with_decimals(self, decimals: u8) -> Option<Self> {
let extra_decimals = decimals.checked_sub(self.decimals)?;
let mantissa = self.mantissa.checked_mul(TEN.checked_pow(extra_decimals as u32)?)?;
Some(Self::from_uint_decimal(mantissa, decimals))
}

/// Trim trailing zeroes in the fractional part
pub const fn minimize(mut self) -> Self {
while self.decimals > 0 && self.mantissa % TEN == 0 {
self.mantissa /= TEN;
self.decimals -= 1;
}
self
}

/// Check this is the same number presented with the same precision
pub fn is_same(&self, other: &Self) -> bool {
(self.mantissa, self.decimals) == (other.mantissa, other.decimals)
}
}

fn empty_to_zero(s: &str) -> &str {
match s {
"" => "0",
s => s,
}
}

impl std::str::FromStr for DecimalAmount {
type Err = ParseError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
ensure!(s.len() <= 100, ParseError::StringTooLong);

let s = s.trim_matches(' ');
let s = s.replace('_', "");
ensure!(!s.is_empty(), ParseError::EmptyString);

let (int_str, frac_str) = s.split_once('.').unwrap_or((&s, ""));

let mut chars = int_str.chars().chain(frac_str.chars());
ensure!(chars.all(|c| c.is_ascii_digit()), ParseError::IllegalChar);
ensure!(int_str.len() + frac_str.len() > 0, ParseError::NoDigits);

let int: UnsignedIntType = empty_to_zero(int_str).parse()?;
let frac: UnsignedIntType = empty_to_zero(frac_str).parse()?;

let decimals: u8 = frac_str.len().try_into().expect("Checked string length above");

let mantissa = TEN
.checked_pow(decimals as u32)
.and_then(|mul| int.checked_mul(mul))
.and_then(|shifted| shifted.checked_add(frac))
.ok_or(ParseError::OutOfRange)?;

Ok(Self::from_uint_decimal(mantissa, decimals))
}
}

impl std::fmt::Display for DecimalAmount {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mantissa = self.mantissa;
let decimals = self.decimals as usize;

if decimals > 0 {
// Max string length: ceil(log10(u128::MAX)) + 1 for decimal point = 40
let mut buffer = String::with_capacity(40);
write!(&mut buffer, "{mantissa:0>width$}", width = decimals + 1)?;
assert!(buffer.len() > decimals);
buffer.insert(buffer.len() - decimals, '.');
f.pad(&buffer)
} else {
mantissa.fmt(f)
}
}
}

impl serde::Serialize for DecimalAmount {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
serializer.serialize_str(&self.to_string())
}
}

#[derive(serde::Serialize, serde::Deserialize)]
#[serde(untagged)]
enum StringOrUint {
String(String),
UInt(u128),
}

impl<'de> serde::Deserialize<'de> for DecimalAmount {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
match StringOrUint::deserialize(deserializer)? {
StringOrUint::String(s) => s.parse().map_err(<D::Error as serde::de::Error>::custom),
StringOrUint::UInt(i) => Ok(Self::from_uint_integral(i)),
}
}
}

#[derive(thiserror::Error, Debug, PartialEq, Eq)]
pub enum ParseError {
#[error("Resulting number is too big")]
OutOfRange,

#[error("The number string is too long")]
StringTooLong,

#[error("Empty input")]
EmptyString,

#[error("Invalid character used in number")]
IllegalChar,

#[error("Number does not contain any digits")]
NoDigits,

#[error(transparent)]
IntParse(#[from] std::num::ParseIntError),
}

#[cfg(test)]
mod test {
use super::*;

#[rstest::rstest]
// Zero decimals
#[case("0", DecimalAmount::from_uint_integral(0))]
#[case("00", DecimalAmount::from_uint_integral(0))]
#[case("5", DecimalAmount::from_uint_integral(5))]
#[case("123", DecimalAmount::from_uint_integral(123))]
#[case("0123", DecimalAmount::from_uint_integral(123))]
#[case("55555", DecimalAmount::from_uint_integral(55555))]
#[case("9999", DecimalAmount::from_uint_integral(9999))]
#[case(
"340282366920938463463374607431768211455",
DecimalAmount::from_uint_integral(u128::MAX)
)]
// Trailing dot
#[case("0123.", DecimalAmount::from_uint_integral(123))]
#[case("55555.", DecimalAmount::from_uint_integral(55555))]
#[case("9999.", DecimalAmount::from_uint_integral(9999))]
#[case(
"340282366920938463463374607431768211455.",
DecimalAmount::from_uint_integral(u128::MAX)
)]
// One decimal
#[case("0.0", DecimalAmount::from_uint_decimal(0, 1))]
#[case("00.0", DecimalAmount::from_uint_decimal(0, 1))]
#[case("5.3", DecimalAmount::from_uint_decimal(53, 1))]
#[case("123.0", DecimalAmount::from_uint_decimal(1230, 1))]
#[case("0123.4", DecimalAmount::from_uint_decimal(1234, 1))]
#[case("55555.0", DecimalAmount::from_uint_decimal(555550, 1))]
#[case("9999.0", DecimalAmount::from_uint_decimal(99990, 1))]
#[case("0123.7", DecimalAmount::from_uint_decimal(1237, 1))]
#[case("55555.6", DecimalAmount::from_uint_decimal(555556, 1))]
#[case("9999.9", DecimalAmount::from_uint_decimal(99999, 1))]
#[case(
"34028236692093846346337460743176821.1455",
DecimalAmount::from_uint_decimal(u128::MAX, 4)
)]
fn parse_ok(#[case] s: &str, #[case] amt: DecimalAmount) {
assert!(amt.is_same(&s.parse().expect("parsing failed")));

let roundtrip = amt.to_string().parse().expect("parsing failed");
assert!(amt.is_same(&roundtrip));
}

#[rstest::rstest]
#[case("", ParseError::EmptyString)]
#[case(" ", ParseError::EmptyString)]
#[case(" _ _ ", ParseError::IllegalChar)]
#[case(".", ParseError::NoDigits)]
#[case("._", ParseError::NoDigits)]
#[case("_.", ParseError::NoDigits)]
#[case("_._", ParseError::NoDigits)]
#[case("_.___", ParseError::NoDigits)]
#[case("x", ParseError::IllegalChar)]
#[case("-", ParseError::IllegalChar)]
#[case("%", ParseError::IllegalChar)]
#[case("13.5e2", ParseError::IllegalChar)]
#[case("34028236692093846346337460743176821145.6", ParseError::OutOfRange)]
#[case("3.40282366920938463463374607431768211456", ParseError::OutOfRange)]
#[case(
"99999_99999_99999_99999_99999.99999_99999_99999_99999_99999",
ParseError::OutOfRange
)]
fn parse_err(#[case] s: &str, #[case] expected_err: ParseError) {
let err = s.parse::<DecimalAmount>().expect_err("parsing succeeded");
assert_eq!(err, expected_err);
}
}
2 changes: 2 additions & 0 deletions common/src/primitives/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

pub mod amount;
pub mod compact;
pub mod decimal_amount;
pub mod encoding;
pub mod error;
pub mod height;
Expand All @@ -31,6 +32,7 @@ mod hash_encoded;

pub use amount::Amount;
pub use compact::Compact;
pub use decimal_amount::DecimalAmount;
pub use encoding::{Bech32Error, DecodedArbitraryDataFromBech32};
pub use height::{BlockCount, BlockDistance, BlockHeight};
pub use id::{Id, Idable, H256};
Expand Down
4 changes: 0 additions & 4 deletions wallet/wallet-cli-lib/src/commands/helper_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,6 @@ pub fn print_coin_amount(chain_config: &ChainConfig, value: Amount) -> String {
value.into_fixedpoint_str(chain_config.coin_decimals())
}

pub fn print_token_amount(token_number_of_decimals: u8, value: Amount) -> String {
value.into_fixedpoint_str(token_number_of_decimals)
}

#[cfg(test)]
mod tests {
use rstest::rstest;
Expand Down
Loading

0 comments on commit 5695a1a

Please sign in to comment.