diff --git a/src/Cryptol/TypeCheck/SimpType.hs b/src/Cryptol/TypeCheck/SimpType.hs index af3317f62..1892d8c3c 100644 --- a/src/Cryptol/TypeCheck/SimpType.hs +++ b/src/Cryptol/TypeCheck/SimpType.hs @@ -142,7 +142,8 @@ tMul x y | Just n <- tIsNum x = mulK n y | Just n <- tIsNum y = mulK n x | Just v <- matchMaybe swapVars = v - | otherwise = tf2 TCMul x y + | otherwise = checkExpMul x y + where mulK 0 _ = tNum (0 :: Int) mulK 1 t = t @@ -154,8 +155,11 @@ tMul x y , Just b' <- tIsNum b -- XXX: similar for a = b * k? , n == b' = tSub a (tMod a b) - - + -- c * c ^ x = c ^ (1 + x) + | TCon (TF TCExp) [a,b] <- t' + , Just n' <- tIsNum a + , n == n' = tf2 TCExp a (tAdd (tNum (1::Int)) b) + -- c^x * c^y = c ^ (y + x) | otherwise = tf2 TCMul (tNum n) t where t' = tNoUser t @@ -163,6 +167,14 @@ tMul x y b <- aTVar y guard (b < a) return (tf2 TCMul y x) + + -- Check if (K^a * K^b) => K^(a + b) otherwise default to standard mul + checkExpMul s t | TCon (TF TCExp) [a,aExp] <- s + , Just a' <- tIsNum a + , TCon (TF TCExp) [b,bExp] <- t + , Just b' <- tIsNum b + , (a' >= 2 && a' == b') = tf2 TCExp a (tAdd aExp bExp) + | otherwise = tf2 TCMul x y diff --git a/src/Cryptol/TypeCheck/Solver/Numeric.hs b/src/Cryptol/TypeCheck/Solver/Numeric.hs index 60123183a..1ea8dea23 100644 --- a/src/Cryptol/TypeCheck/Solver/Numeric.hs +++ b/src/Cryptol/TypeCheck/Solver/Numeric.hs @@ -48,6 +48,7 @@ cryIsEqual ctxt t1 t2 = <|> tryCancelVar ctxt (=#=) t1 t2 <|> tryLinearSolution t1 t2 <|> tryLinearSolution t2 t1 + <|> tryEqExp t1 t2 -- | Try to solve @t1 /= t2@ cryIsNotEqual :: Ctxt -> Type -> Type -> Solved @@ -67,6 +68,7 @@ cryIsGeq i t1 t2 = <|> tryAddConst (>==) t1 t2 <|> tryCancelVar i (>==) t1 t2 <|> tryMinIsGeq t1 t2 + <|> tryGeqExp i t1 t2 -- XXX: k >= width e -- XXX: width e >= k @@ -137,6 +139,17 @@ tryGeqThanK _ t (Nat k) = -- XXX: K1 ^^ n >= K2 +-- (K >= 2 && K^a >= K^b) => a >= b +tryGeqExp :: Ctxt -> Type -> Type -> Match Solved +tryGeqExp _ x y = + do (k_1, a) <- (|^|) x + n <- aNat k_1 + guard (n >= 2) + (k_2, b) <- (|^|) y + guard (k_1 == k_2) + return $ SolvedIf [ a >== b ] + + tryGeqThanSub :: Ctxt -> Type -> Type -> Match Solved tryGeqThanSub _ x y = @@ -223,6 +236,19 @@ tryCancelVar ctxt p t1 t2 = +-- if (K >= 2) && K^a = K^b => a = b +tryEqExp :: Type -> Type -> Match Solved +tryEqExp x y = check x y <|> check y x + where + check i j = + do + (k_1, a) <- (|^|) i + n <- aNat k_1 + guard (n >= 2) + (k_2, b) <- (|^|) j + guard (k_1 == k_2) + return $ SolvedIf [ a =#= b ] + -- min t1 t2 = t1 ~> t1 <= t2 tryEqMin :: Type -> Type -> Match Solved tryEqMin x y = diff --git a/tests/issues/issue1489/issue1489.cry b/tests/issues/issue1489/issue1489.cry new file mode 100644 index 000000000..bac33e1a9 --- /dev/null +++ b/tests/issues/issue1489/issue1489.cry @@ -0,0 +1,54 @@ +module ID where + +id : {k} (fin k, k > 0) => [2^^k] -> [2^^k] +id x = join(split`{2,2^^(k-1)}x) + +type q = 3329 + +ct_butterfly : + {m, hm} + (m >= 2, m <= 8, hm >= 1, hm <= 7, hm == m - 1) => + [2^^m](Z q) -> (Z q) -> [2^^m](Z q) +ct_butterfly v z = new_v + where + halflen = 2^^`hm + lower, upper : [2^^hm](Z q) + lower@x = v@x + z * v@(x + halflen) + upper@x = v@x - z * v@(x + halflen) + new_v = lower # upper + +zeta_expc : [128](Z q) +zeta_expc = [ 1, 1729, 2580, 3289, 2642, 630, 1897, 848, + 1062, 1919, 193, 797, 2786, 3260, 569, 1746, + 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, + 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, + 289, 331, 3253, 1756, 1197, 2304, 2277, 2055, + 650, 1977, 2513, 632, 2865, 33, 1320, 1915, + 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, + 2647, 2617, 1481, 648, 2474, 3110, 1227, 910, + 17, 2761, 583, 2649, 1637, 723, 2288, 1100, + 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, + 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, + 939, 2308, 2437, 2388, 733, 2337, 268, 641, + 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, + 1063, 319, 2773, 757, 2099, 561, 2466, 2594, + 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, + 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154 ] + +fast_nttl : + {lv} // Length of v is a member of {256,128,64,32,16,8,4} + (lv >= 2, lv <= 8) => + [2^^lv](Z q) -> [8] -> [2^^lv](Z q) +fast_nttl v k + // Base case. lv==2 so just compute the butterfly and return + | lv == 2 => ct_butterfly`{lv,lv-1} v (zeta_expc@k) + + // Recursive case. Butterfly what we have, then recurse on each half, + // concatenate the results and return. As above, we need coerceSize + // here (twice) to satisfy the type checker. + | lv > 2 => (fast_nttl`{lv-1} s0 (k * 2)) # + (fast_nttl`{lv-1} s1 (k * 2 + 1)) + where + t = ct_butterfly`{lv,lv-1} v (zeta_expc@k) + // Split t into two halves s0 and s1 + [s0, s1] = split t \ No newline at end of file diff --git a/tests/issues/issue1489/issue1489.icry b/tests/issues/issue1489/issue1489.icry new file mode 100644 index 000000000..972ee633e --- /dev/null +++ b/tests/issues/issue1489/issue1489.icry @@ -0,0 +1 @@ +:load ./issue1489.cry \ No newline at end of file diff --git a/tests/issues/issue1489/issue1489.icry.stdout b/tests/issues/issue1489/issue1489.icry.stdout new file mode 100644 index 000000000..0f044dbc8 --- /dev/null +++ b/tests/issues/issue1489/issue1489.icry.stdout @@ -0,0 +1,3 @@ +Loading module Cryptol +Loading module Cryptol +Loading module ID