Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add various arithmetic operations #39

Merged
merged 7 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
- Support `Step` so that arbitrary-int can be used in a range expression, e.g. `for n in u3::MIN..=u3::MAX { println!("{n}") }`. Note this trait is currently unstable, and so is only usable in nightly. Enable this feature with `step_trait`.
- Support formatting via [defmt](https://crates.io/crates/defmt). Enable the option `defmt` feature
- Support serializing and deserializing via [serde](https://crates.io/crates/serde). Enable the option `serde` feature
- Implement `Mul`, `MulAssign`, `Div`, `DivAssign`
- Implement `wrapping_add`, `wrapping_sub`, `wrapping_mul`, `wrapping_div`, `wrapping_shl`, `wrapping_shr`
- Implement `saturating_add`, `saturating_sub`, `saturating_mul`, `saturating_div`, `saturating_pow`

### Changed
- In debug builds, `<<` (`Shl`, `ShlAssign`) and `>>` (`Shr`, `ShrAssign`) now bounds-check the shift amount using the same semantics as built-in shifts. For example, shifting a u5 by 5 or more bits will now panic as expected.

## arbitrary-int 1.2.6

Expand Down
244 changes: 242 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ use core::iter::Step;
#[cfg(feature = "num-traits")]
use core::num::Wrapping;
use core::ops::{
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl,
ShlAssign, Shr, ShrAssign, Sub, SubAssign,
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign,
Mul, MulAssign, Not, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
};
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};
Expand Down Expand Up @@ -340,6 +340,121 @@ macro_rules! uint_impl {
UInt::<$type, BITS_RESULT> { value: self.value }
}

pub const fn wrapping_add(self, rhs: Self) -> Self {
let sum = self.value.wrapping_add(rhs.value);
Self {
value: sum & Self::MASK,
}
}

pub const fn wrapping_sub(self, rhs: Self) -> Self {
let sum = self.value.wrapping_sub(rhs.value);
Self {
value: sum & Self::MASK,
}
}

pub const fn wrapping_mul(self, rhs: Self) -> Self {
let sum = self.value.wrapping_mul(rhs.value);
Self {
value: sum & Self::MASK,
}
}

pub const fn wrapping_div(self, rhs: Self) -> Self {
let sum = self.value.wrapping_div(rhs.value);
Self {
// No need to mask here - divisions always produce a result that is <= self
value: sum,
}
}

pub const fn wrapping_shl(self, rhs: u32) -> Self {
// modulo is expensive on some platforms, so only do it when necessary
let shift_amount = if rhs >= (BITS as u32) {
rhs % (BITS as u32)
} else {
rhs
};

Self {
value: (self.value << shift_amount) & Self::MASK,
}
}

pub const fn wrapping_shr(self, rhs: u32) -> Self {
// modulo is expensive on some platforms, so only do it when necessary
let shift_amount = if rhs >= (BITS as u32) {
rhs % (BITS as u32)
} else {
rhs
};

Self {
value: (self.value >> shift_amount) & Self::MASK,
}
}

pub const fn saturating_add(self, rhs: Self) -> Self {
let saturated = if core::mem::size_of::<$type>() << 3 == BITS {
// We are something like a UInt::<u8; 8>. We can fallback to the base implementation
self.value.saturating_add(rhs.value)
} else {
// We're dealing with fewer bits than the underlying type (e.g. u7).
// That means the addition can never overflow the underlying type
let sum = self.value.wrapping_add(rhs.value);
let max = Self::MAX.value();
if sum > max { max } else { sum }
};
Self {
value: saturated,
}
}

pub const fn saturating_sub(self, rhs: Self) -> Self {
// For unsigned numbers, the only difference is when we reach 0 - which is the same
// no matter the data size
Self {
value: self.value.saturating_sub(rhs.value),
}
}

pub const fn saturating_mul(self, rhs: Self) -> Self {
let product = if BITS << 1 <= (core::mem::size_of::<$type>() << 3) {
// We have half the bits (e.g. u4 * u4) of the base type, so we can't overflow the base type
self.value * rhs.value
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not have to go through the “saturated” logic below, does it?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, don’t you want to use wrapping_mul to avoid redundant checks in debug?

(Same up in shl/shr)?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the first question: Yes it does. An u4 is backed by an u8. u4 * u4 can't overflow u8 (hence we skip the first check), but it can still overflow an u4, so we need the second check.

For the second part: I guess - I didn't care too much about debug foot print, but changing to wrapping_mul shouldn't hurt

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For shl/shr: I decided to use << and >> as I'd expect that they perform better on cpus that don't by default do wrapping_shl. Slight hit in Debug builds on "normal" cpus, but optimal code everywhere in Release build

} else {
// We have more than half the bits (e.g. u6 * u6)
self.value.saturating_mul(rhs.value)
};

let max = Self::MAX.value();
let saturated = if product > max { max } else { product };
Self {
value: saturated,
}
}

pub const fn saturating_div(self, rhs: Self) -> Self {
// When dividing unsigned numbers, we never need to saturate.
// Divison by zero in saturating_div throws an exception (in debug and release mode),
// so no need to do anything special there either
Self {
value: self.value.saturating_div(rhs.value),
}
}

pub const fn saturating_pow(self, exp: u32) -> Self {
// It might be possible to handwrite this to be slightly faster as both
// saturating_pow has to do a bounds-check and then we do second one
let powed = self.value.saturating_pow(exp);
let max = Self::MAX.value();
let saturated = if powed > max { max } else { powed };
Self {
value: saturated,
}
}

/// Reverses the order of bits in the integer. The least significant bit becomes the most significant bit, second least-significant bit becomes second most-significant bit, etc.
pub const fn reverse_bits(self) -> Self {
let shift_right = (core::mem::size_of::<$type>() << 3) - BITS;
Expand Down Expand Up @@ -501,6 +616,102 @@ where
}
}

impl<T, const BITS: usize> Mul for UInt<T, BITS>
where
Self: Number,
T: PartialEq
+ Copy
+ BitAnd<T, Output = T>
+ Not<Output = T>
+ Mul<T, Output = T>
+ Sub<T, Output = T>
+ Shr<usize, Output = T>
+ Shl<usize, Output = T>
+ From<u8>,
{
type Output = UInt<T, BITS>;

fn mul(self, rhs: Self) -> Self::Output {
let product = self.value * rhs.value;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has another overflow check in debug, right? We need it, but should have add a comment that there is another check that we can’t avoid?

#[cfg(debug_assertions)]
if (product & !Self::MASK) != T::from(0) {
panic!("attempt to multiply with overflow");
}
Self {
value: product & Self::MASK,
}
}
}

impl<T, const BITS: usize> MulAssign for UInt<T, BITS>
where
Self: Number,
T: PartialEq
+ Eq
+ Not<Output = T>
+ Copy
+ MulAssign<T>
+ BitAnd<T, Output = T>
+ BitAndAssign<T>
+ Sub<T, Output = T>
+ Shr<usize, Output = T>
+ Shl<usize, Output = T>
+ From<u8>,
{
fn mul_assign(&mut self, rhs: Self) {
self.value *= rhs.value;
#[cfg(debug_assertions)]
if (self.value & !Self::MASK) != T::from(0) {
panic!("attempt to multiply with overflow");
}
self.value &= Self::MASK;
}
}

impl<T, const BITS: usize> Div for UInt<T, BITS>
where
Self: Number,
T: PartialEq
+ Copy
+ BitAnd<T, Output = T>
+ Not<Output = T>
+ Div<T, Output = T>
+ Sub<T, Output = T>
+ Shr<usize, Output = T>
+ Shl<usize, Output = T>
+ From<u8>,
{
type Output = UInt<T, BITS>;

fn div(self, rhs: Self) -> Self::Output {
// Integer division can only make the value smaller. And as the result is same type as
// Self, there's no need to range-check or mask
Self {
value: self.value / rhs.value,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the compiler know this though? Should we use an unchecked version to guarantee we avoid checks?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's unfortunately not a Trait for wrapping_div in regular Rust, so we can't express the generic in this way (num_traits has that, but that's an optional feature in this crate)

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it matters too much though as I don't think the compiler checks anything here (what would it even test?)

}
}
}

impl<T, const BITS: usize> DivAssign for UInt<T, BITS>
where
Self: Number,
T: PartialEq
+ Eq
+ Not<Output = T>
+ Copy
+ DivAssign<T>
+ BitAnd<T, Output = T>
+ BitAndAssign<T>
+ Sub<T, Output = T>
+ Shr<usize, Output = T>
+ Shl<usize, Output = T>
+ From<u8>,
{
fn div_assign(&mut self, rhs: Self) {
self.value /= rhs.value;
}
}

impl<T, const BITS: usize> BitAnd for UInt<T, BITS>
where
Self: Number,
Expand Down Expand Up @@ -603,10 +814,18 @@ where
+ Shl<usize, Output = T>
+ Shr<usize, Output = T>
+ From<u8>,
TSHIFTBITS: TryInto<usize> + Copy,
{
type Output = UInt<T, BITS>;

fn shl(self, rhs: TSHIFTBITS) -> Self::Output {
// With debug assertions, the << and >> operators throw an exception if the shift amount
// is larger than the number of bits (in which case the result would always be 0)
#[cfg(debug_assertions)]
if rhs.try_into().unwrap_or(usize::MAX) >= BITS {
panic!("attempt to shift left with overflow")
}

Self {
value: (self.value << rhs) & Self::MASK,
}
Expand All @@ -624,8 +843,15 @@ where
+ Shr<usize, Output = T>
+ Shl<usize, Output = T>
+ From<u8>,
TSHIFTBITS: TryInto<usize> + Copy,
{
fn shl_assign(&mut self, rhs: TSHIFTBITS) {
// With debug assertions, the << and >> operators throw an exception if the shift amount
// is larger than the number of bits (in which case the result would always be 0)
#[cfg(debug_assertions)]
if rhs.try_into().unwrap_or(usize::MAX) >= BITS {
panic!("attempt to shift left with overflow")
}
self.value <<= rhs;
self.value &= Self::MASK;
}
Expand All @@ -634,10 +860,17 @@ where
impl<T, TSHIFTBITS, const BITS: usize> Shr<TSHIFTBITS> for UInt<T, BITS>
where
T: Copy + Shr<TSHIFTBITS, Output = T> + Sub<T, Output = T> + Shl<usize, Output = T> + From<u8>,
TSHIFTBITS: TryInto<usize> + Copy,
{
type Output = UInt<T, BITS>;

fn shr(self, rhs: TSHIFTBITS) -> Self::Output {
// With debug assertions, the << and >> operators throw an exception if the shift amount
// is larger than the number of bits (in which case the result would always be 0)
#[cfg(debug_assertions)]
if rhs.try_into().unwrap_or(usize::MAX) >= BITS {
panic!("attempt to shift left with overflow")
}
Self {
value: self.value >> rhs,
}
Expand All @@ -647,8 +880,15 @@ where
impl<T, TSHIFTBITS, const BITS: usize> ShrAssign<TSHIFTBITS> for UInt<T, BITS>
where
T: Copy + ShrAssign<TSHIFTBITS> + Sub<T, Output = T> + Shl<usize, Output = T> + From<u8>,
TSHIFTBITS: TryInto<usize> + Copy,
{
fn shr_assign(&mut self, rhs: TSHIFTBITS) {
// With debug assertions, the << and >> operators throw an exception if the shift amount
// is larger than the number of bits (in which case the result would always be 0)
#[cfg(debug_assertions)]
if rhs.try_into().unwrap_or(usize::MAX) >= BITS {
panic!("attempt to shift left with overflow")
}
self.value >>= rhs;
}
}
Expand Down
Loading
Loading