Skip to content

Commit

Permalink
#13 Release 0.0.0.0
Browse files Browse the repository at this point in the history
  • Loading branch information
Bodigrim committed Jan 20, 2022
1 parent c22c048 commit 0c9bd46
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 436 deletions.
327 changes: 18 additions & 309 deletions Data/Mod.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
Expand All @@ -32,31 +31,15 @@ module Data.Mod

import Control.Exception
import Control.DeepSeq
import Control.Monad
import Data.Bits
import Data.Ratio
import Data.Word (Word8)
#ifdef MIN_VERSION_semirings
import Data.Euclidean (GcdDomain(..), Euclidean(..), Field)
import Data.Semiring (Semiring(..), Ring(..))
#endif
#ifdef MIN_VERSION_vector
import Control.Monad.Primitive
import Control.Monad.ST
import qualified Data.Primitive.Types as P
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as M
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Primitive as P
import Foreign (copyBytes)
import GHC.IO.Unsafe (unsafeDupablePerformIO)
#endif
import Foreign.Storable (Storable(..))
import GHC.Exts
import GHC.Generics
import GHC.Integer.GMP.Internals
import GHC.Natural (Natural(..), powModNatural)
import GHC.TypeNats (Nat, KnownNat, natVal, natVal')
import GHC.TypeNats (Nat, KnownNat, natVal)

-- | This data type represents
-- <https://en.wikipedia.org/wiki/Modular_arithmetic#Integers_modulo_n integers modulo m>,
Expand Down Expand Up @@ -106,82 +89,18 @@ instance KnownNat m => Bounded (Mod m) where
minBound = Mod 0
maxBound = let mx = Mod (natVal mx - 1) in mx

bigNatToNat :: BigNat -> Natural
bigNatToNat r# =
if isTrue# (sizeofBigNat# r# <=# 1#) then NatS# (bigNatToWord r#) else NatJ# r#

subIfGe :: BigNat -> BigNat -> Natural
subIfGe z# m# = case z# `compareBigNat` m# of
LT -> NatJ# z#
EQ -> NatS# 0##
GT -> bigNatToNat $ z# `minusBigNat` m#

#if !MIN_VERSION_base(4,12,0)
addWordC# :: Word# -> Word# -> (# Word#, Int# #)
addWordC# x# y# = (# z#, word2Int# c# #)
where
!(# c#, z# #) = x# `plusWord2#` y#
#endif

addMod :: Natural -> Natural -> Natural -> Natural
addMod (NatS# m#) (NatS# x#) (NatS# y#) =
if isTrue# c# || isTrue# (z# `geWord#` m#) then NatS# (z# `minusWord#` m#) else NatS# z#
where
!(# z#, c# #) = x# `addWordC#` y#
addMod NatS#{} _ _ = brokenInvariant
addMod (NatJ# m#) (NatS# x#) (NatS# y#) =
if isTrue# c# then subIfGe (wordToBigNat2 1## z#) m# else NatS# z#
where
!(# z#, c# #) = x# `addWordC#` y#
addMod (NatJ# m#) (NatS# x#) (NatJ# y#) = subIfGe (y# `plusBigNatWord` x#) m#
addMod (NatJ# m#) (NatJ# x#) (NatS# y#) = subIfGe (x# `plusBigNatWord` y#) m#
addMod (NatJ# m#) (NatJ# x#) (NatJ# y#) = subIfGe (x# `plusBigNat` y#) m#
addMod m x y = let z = x + y in if z >= m then z - m else z

subMod :: Natural -> Natural -> Natural -> Natural
subMod (NatS# m#) (NatS# x#) (NatS# y#) =
if isTrue# (x# `geWord#` y#) then NatS# z# else NatS# (z# `plusWord#` m#)
where
z# = x# `minusWord#` y#
subMod NatS#{} _ _ = brokenInvariant
subMod (NatJ# m#) (NatS# x#) (NatS# y#) =
if isTrue# (x# `geWord#` y#)
then NatS# (x# `minusWord#` y#)
else bigNatToNat $ m# `minusBigNatWord` (y# `minusWord#` x#)
subMod (NatJ# m#) (NatS# x#) (NatJ# y#) =
bigNatToNat $ (m# `minusBigNat` y#) `plusBigNatWord` x#
subMod NatJ#{} (NatJ# x#) (NatS# y#) =
bigNatToNat $ x# `minusBigNatWord` y#
subMod (NatJ# m#) (NatJ# x#) (NatJ# y#) = case x# `compareBigNat` y# of
LT -> bigNatToNat $ (m# `minusBigNat` y#) `plusBigNat` x#
EQ -> NatS# 0##
GT -> bigNatToNat $ x# `minusBigNat` y#
subMod m x y = if x >= y then x - y else m + x - y

negateMod :: Natural -> Natural -> Natural
negateMod _ (NatS# 0##) = NatS# 0##
negateMod (NatS# m#) (NatS# x#) = NatS# (m# `minusWord#` x#)
negateMod NatS#{} _ = brokenInvariant
negateMod (NatJ# m#) (NatS# x#) = bigNatToNat $ m# `minusBigNatWord` x#
negateMod (NatJ# m#) (NatJ# x#) = bigNatToNat $ m# `minusBigNat` x#
negateMod !_ 0 = 0
negateMod m x = m - x

mulMod :: Natural -> Natural -> Natural -> Natural
mulMod (NatS# m#) (NatS# x#) (NatS# y#) = NatS# r#
where
!(# z1#, z2# #) = timesWord2# x# y#
!(# _, r# #) = quotRemWord2# z1# z2# m#
mulMod NatS#{} _ _ = brokenInvariant
mulMod (NatJ# m#) (NatS# x#) (NatS# y#) =
bigNatToNat $ wordToBigNat2 z1# z2# `remBigNat` m#
where
!(# z1#, z2# #) = timesWord2# x# y#
mulMod (NatJ# m#) (NatS# x#) (NatJ# y#) =
bigNatToNat $ (y# `timesBigNatWord` x#) `remBigNat` m#
mulMod (NatJ# m#) (NatJ# x#) (NatS# y#) =
bigNatToNat $ (x# `timesBigNatWord` y#) `remBigNat` m#
mulMod (NatJ# m#) (NatJ# x#) (NatJ# y#) =
bigNatToNat $ (x# `timesBigNat` y#) `remBigNat` m#

brokenInvariant :: a
brokenInvariant = error "argument is larger than modulo"
mulMod m x y = (x * y) `Prelude.rem` m

instance KnownNat m => Num (Mod m) where
mx@(Mod !x) + (Mod !y) = Mod $ addMod (natVal mx) x y
Expand Down Expand Up @@ -276,6 +195,18 @@ invertMod mx
y = recipModInteger (toInteger (unMod mx)) (toInteger (natVal mx))
{-# INLINABLE invertMod #-}

recipModInteger :: Integer -> Integer -> Integer
recipModInteger x m = case gcdExt x m of
(1, s) -> s `mod` m
_ -> -1

gcdExt :: Integer -> Integer -> (Integer, Integer)
gcdExt = go 1 0
where
go s !_ r 0 = (r, s)
go s s' r r' = case Prelude.quotRem r r' of
(q, r'') -> go s' (s - q * s') r' r''

-- | Drop-in replacement for 'Prelude.^' with much better performance.
-- Negative powers are allowed, but may throw 'DivideByZero', if an argument
-- is not <https://en.wikipedia.org/wiki/Coprime_integers coprime> with the modulo.
Expand Down Expand Up @@ -321,225 +252,3 @@ mx ^% a
"powMod/3/Word" forall x. x ^% (3 :: Word) = let u = x in u*u*u #-}

infixr 8 ^%

wordSize :: Int
wordSize = finiteBitSize (0 :: Word)

lgWordSize :: Int
lgWordSize = case wordSize of
32 -> 2 -- 2^2 bytes in word
64 -> 3 -- 2^3 bytes in word
_ -> error "lgWordSize: unknown architecture"

instance KnownNat m => Storable (Mod m) where
sizeOf _ = case natVal' (proxy# :: Proxy# m) of
NatS#{} -> sizeOf (0 :: Word)
NatJ# m# -> I# (sizeofBigNat# m#) `shiftL` lgWordSize
{-# INLINE sizeOf #-}

alignment _ = alignment (0 :: Word)
{-# INLINE alignment #-}

peek (Ptr addr#) = case natVal' (proxy# :: Proxy# m) of
NatS#{} -> do
W# w# <- peek (Ptr addr#)
pure . Mod $! NatS# w#
NatJ# m# -> do
let !(I# lgWordSize#) = lgWordSize
sz# = sizeofBigNat# m# `iShiftL#` lgWordSize#
bn <- importBigNatFromAddr addr# (int2Word# sz#) 0#
pure . Mod $! bigNatToNat bn
{-# INLINE peek #-}

poke (Ptr addr#) (Mod x) = case natVal' (proxy# :: Proxy# m) of
NatS#{} -> case x of
NatS# x# -> poke (Ptr addr#) (W# x#)
_ -> brokenInvariant
NatJ# m# -> case x of
NatS# x# -> do
poke (Ptr addr#) (W# x#)
forM_ [1 .. sz - 1] $ \off ->
pokeElemOff (Ptr addr#) off (0 :: Word)
NatJ# bn -> do
l <- exportBigNatToAddr bn addr# 0#
forM_ [(fromIntegral :: Word -> Int) l .. (sz `shiftL` lgWordSize) - 1] $ \off ->
pokeElemOff (Ptr addr#) off (0 :: Word8)
where
sz = I# (sizeofBigNat# m#)
{-# INLINE poke #-}

#ifdef MIN_VERSION_vector

instance KnownNat m => P.Prim (Mod m) where
sizeOf# x = let !(I# sz#) = sizeOf x in sz#
{-# INLINE sizeOf# #-}

alignment# x = let !(I# a#) = alignment x in a#
{-# INLINE alignment# #-}

indexByteArray# arr# i' = case natVal' (proxy# :: Proxy# m) of
NatS#{} -> Mod (NatS# w#)
where
!(W# w#) = P.indexByteArray# arr# i'
NatJ# m# -> Mod $ bigNatToNat $ importBigNatFromByteArray arr# (int2Word# i#) (int2Word# sz#) 0#
where
!(I# lgWordSize#) = lgWordSize
sz# = sizeofBigNat# m# `iShiftL#` lgWordSize#
i# = i' *# sz#
{-# INLINE indexByteArray# #-}

indexOffAddr# arr# i' = case natVal' (proxy# :: Proxy# m) of
NatS#{} -> Mod (NatS# w#)
where
!(W# w#) = P.indexOffAddr# arr# i'
NatJ# m# -> Mod $ bigNatToNat $ unsafeDupablePerformIO $ importBigNatFromAddr (arr# `plusAddr#` i#) (int2Word# sz#) 0#
where
!(I# lgWordSize#) = lgWordSize
sz# = sizeofBigNat# m# `iShiftL#` lgWordSize#
i# = i' *# sz#
{-# INLINE indexOffAddr# #-}

readByteArray# marr !i' token = case natVal' (proxy# :: Proxy# m) of
NatS#{} -> case P.readByteArray# marr i' token of
(# newToken, W# w# #) -> (# newToken, Mod (NatS# w#) #)
NatJ# m# -> case unsafeFreezeByteArray# marr token of
(# newToken, arr #) -> (# newToken, Mod (bigNatToNat (importBigNatFromByteArray arr (int2Word# i#) (int2Word# sz#) 0#)) #)
where
!(I# lgWordSize#) = lgWordSize
sz# = sizeofBigNat# m# `iShiftL#` lgWordSize#
i# = i' *# sz#
{-# INLINE readByteArray# #-}

readOffAddr# marr !i' token = case natVal' (proxy# :: Proxy# m) of
NatS#{} -> case P.readOffAddr# marr i' token of
(# newToken, W# w# #) -> (# newToken, Mod (NatS# w#) #)
NatJ# m# -> case internal (unsafeIOToPrim (importBigNatFromAddr (marr `plusAddr#` i#) (int2Word# sz#) 0#) :: ST s BigNat) token of
(# newToken, bn #) -> (# newToken, Mod (bigNatToNat bn) #)
where
!(I# lgWordSize#) = lgWordSize
sz# = sizeofBigNat# m# `iShiftL#` lgWordSize#
i# = i' *# sz#
{-# INLINE readOffAddr# #-}

writeByteArray# marr !i' !(Mod x) token = case natVal' (proxy# :: Proxy# m) of
NatS#{} -> case x of
NatS# x# -> P.writeByteArray# marr i' (W# x#) token
_ -> error "argument is larger than modulo"
NatJ# m# -> case x of
NatS# x# -> case P.writeByteArray# marr i# (W# x#) token of
newToken -> P.setByteArray# marr (i# +# 1#) (sz# -# 1#) (0 :: Word) newToken
NatJ# bn -> case internal (unsafeIOToPrim (exportBigNatToMutableByteArray bn (unsafeCoerce# marr) (int2Word# (i# `iShiftL#` lgWordSize#)) 0#) :: ST s Word) token of
(# newToken, W# l# #) -> P.setByteArray# marr (i# `iShiftL#` lgWordSize# +# word2Int# l#) (sz# `iShiftL#` lgWordSize# -# word2Int# l#) (0 :: Word8) newToken
where
!(I# lgWordSize#) = lgWordSize
!sz@(I# sz#) = I# (sizeofBigNat# m#)
!(I# i#) = I# i' * sz
{-# INLINE writeByteArray# #-}

writeOffAddr# marr !i' !(Mod x) token = case natVal' (proxy# :: Proxy# m) of
NatS#{} -> case x of
NatS# x# -> P.writeOffAddr# marr i' (W# x#) token
_ -> error "argument is larger than modulo"
NatJ# m# -> case x of
NatS# x# -> case P.writeOffAddr# marr i# (W# x#) token of
newToken -> P.setOffAddr# marr (i# +# 1#) (sz# -# 1#) (0 :: Word) newToken
NatJ# bn -> case internal (unsafeIOToPrim (exportBigNatToAddr bn (marr `plusAddr#` (i# `iShiftL#` lgWordSize#)) 0#) :: ST s Word) token of
(# newToken, W# l# #) -> P.setOffAddr# marr (i# `iShiftL#` lgWordSize# +# word2Int# l#) (sz# `iShiftL#` lgWordSize# -# word2Int# l#) (0 :: Word8) newToken
where
!(I# lgWordSize#) = lgWordSize
!sz@(I# sz#) = I# (sizeofBigNat# m#)
!(I# i#) = I# i' * sz
{-# INLINE writeOffAddr# #-}

setByteArray# !_ !_ 0# !_ token = token
setByteArray# marr off len mx@(Mod x) token = case natVal' (proxy# :: Proxy# m) of
NatS#{} -> case x of
NatS# x# -> P.setByteArray# marr off len (W# x#) token
_ -> error "argument is larger than modulo"
NatJ# m# -> case P.writeByteArray# marr off mx token of
newToken -> doSet (sz `iShiftL#` lgWordSize#) newToken
where
!(I# lgWordSize#) = lgWordSize
sz = sizeofBigNat# m#
off' = (off *# sz) `iShiftL#` lgWordSize#
len' = (len *# sz) `iShiftL#` lgWordSize#
doSet i tkn
| isTrue# (2# *# i <# len') = case copyMutableByteArray# marr off' marr (off' +# i) i tkn of
tkn' -> doSet (2# *# i) tkn'
| otherwise = copyMutableByteArray# marr off' marr (off' +# i) (len' -# i) tkn
{-# INLINE setByteArray# #-}

setOffAddr# !_ !_ 0# !_ token = token
setOffAddr# marr off len mx@(Mod x) token = case natVal' (proxy# :: Proxy# m) of
NatS#{} -> case x of
NatS# x# -> P.setOffAddr# marr off len (W# x#) token
_ -> error "argument is larger than modulo"
NatJ# m# -> case P.writeOffAddr# marr off mx token of
newToken -> doSet (sz `iShiftL#` lgWordSize#) newToken
where
!(I# lgWordSize#) = lgWordSize
sz = sizeofBigNat# m#
off' = (off *# sz) `iShiftL#` lgWordSize#
len' = (len *# sz) `iShiftL#` lgWordSize#
doSet i tkn -- = tkn
| isTrue# (2# *# i <# len') = case internal (unsafeIOToPrim (copyBytes (Ptr (marr `plusAddr#` (off' +# i))) (Ptr (marr `plusAddr#` off')) (I# i)) :: ST s ()) tkn of
(# tkn', () #) -> doSet (2# *# i) tkn'
| otherwise = case internal (unsafeIOToPrim (copyBytes (Ptr (marr `plusAddr#` (off' +# i))) (Ptr (marr `plusAddr#` off')) (I# (len' -# i))) :: ST s ()) tkn of
(# tkn', () #) -> tkn'
{-# INLINE setOffAddr# #-}

-- | Unboxed vectors of 'Mod' cause more nursery allocations
-- than boxed ones, but reduce pressure on garbage collector,
-- especially for large vectors.
newtype instance U.MVector s (Mod m) = ModMVec (P.MVector s (Mod m))

-- | Unboxed vectors of 'Mod' cause more nursery allocations
-- than boxed ones, but reduce pressure on garbage collector,
-- especially for large vectors.
newtype instance U.Vector (Mod m) = ModVec (P.Vector (Mod m))

instance KnownNat m => U.Unbox (Mod m)

instance KnownNat m => M.MVector U.MVector (Mod m) where
{-# INLINE basicLength #-}
{-# INLINE basicUnsafeSlice #-}
{-# INLINE basicOverlaps #-}
{-# INLINE basicUnsafeNew #-}
{-# INLINE basicInitialize #-}
{-# INLINE basicUnsafeReplicate #-}
{-# INLINE basicUnsafeRead #-}
{-# INLINE basicUnsafeWrite #-}
{-# INLINE basicClear #-}
{-# INLINE basicSet #-}
{-# INLINE basicUnsafeCopy #-}
{-# INLINE basicUnsafeGrow #-}
basicLength (ModMVec v) = M.basicLength v
basicUnsafeSlice i n (ModMVec v) = ModMVec $ M.basicUnsafeSlice i n v
basicOverlaps (ModMVec v1) (ModMVec v2) = M.basicOverlaps v1 v2
basicUnsafeNew n = ModMVec `liftM` M.basicUnsafeNew n
basicInitialize (ModMVec v) = M.basicInitialize v
basicUnsafeReplicate n x = ModMVec `liftM` M.basicUnsafeReplicate n x
basicUnsafeRead (ModMVec v) i = M.basicUnsafeRead v i
basicUnsafeWrite (ModMVec v) i x = M.basicUnsafeWrite v i x
basicClear (ModMVec v) = M.basicClear v
basicSet (ModMVec v) x = M.basicSet v x
basicUnsafeCopy (ModMVec v1) (ModMVec v2) = M.basicUnsafeCopy v1 v2
basicUnsafeMove (ModMVec v1) (ModMVec v2) = M.basicUnsafeMove v1 v2
basicUnsafeGrow (ModMVec v) n = ModMVec `liftM` M.basicUnsafeGrow v n

instance KnownNat m => G.Vector U.Vector (Mod m) where
{-# INLINE basicUnsafeFreeze #-}
{-# INLINE basicUnsafeThaw #-}
{-# INLINE basicLength #-}
{-# INLINE basicUnsafeSlice #-}
{-# INLINE basicUnsafeIndexM #-}
{-# INLINE elemseq #-}
basicUnsafeFreeze (ModMVec v) = ModVec `liftM` G.basicUnsafeFreeze v
basicUnsafeThaw (ModVec v) = ModMVec `liftM` G.basicUnsafeThaw v
basicLength (ModVec v) = G.basicLength v
basicUnsafeSlice i n (ModVec v) = ModVec $ G.basicUnsafeSlice i n v
basicUnsafeIndexM (ModVec v) i = G.basicUnsafeIndexM v i
basicUnsafeCopy (ModMVec mv) (ModVec v) = G.basicUnsafeCopy mv v
elemseq _ = seq

#endif
Loading

0 comments on commit 0c9bd46

Please sign in to comment.