Skip to content

Commit

Permalink
Implement and use a new RecordMap type.
Browse files Browse the repository at this point in the history
This type stores records as a finite map from field names to
values, while also remembering the original order of the fields
from when the record was generated (usually, from the program source).
For all "semantic" purposes, the fields are treated as appearing in
a canoical order (in sorted order of the field names).  However, for
user display purposes, records are presented in the order in which
the fields were originally stated.

In the course of implementing this, I discovered that we were not
previously checking for repeated fields in the parser or typechecker,
which would result in some rather strange situations and could probably
be used to break the type safety. This is now fixed and repeated fields
will result in either a parse error or a panic (for records generated
internally).

Fixes #706
  • Loading branch information
robdockins committed Jun 30, 2020
1 parent 87d5eda commit 6c6cb94
Show file tree
Hide file tree
Showing 43 changed files with 474 additions and 277 deletions.
1 change: 1 addition & 0 deletions cryptol.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ library

Cryptol.Utils.Fixity,
Cryptol.Utils.Ident,
Cryptol.Utils.RecordMap,
Cryptol.Utils.PP,
Cryptol.Utils.Panic,
Cryptol.Utils.Debug,
Expand Down
27 changes: 16 additions & 11 deletions src/Cryptol/Eval.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ import Cryptol.TypeCheck.Solver.InfNat(Nat'(..))
import Cryptol.Utils.Ident
import Cryptol.Utils.Panic (panic)
import Cryptol.Utils.PP

import Cryptol.Utils.RecordMap

import Control.Monad
import Data.Functor.Identity
import Data.List
import Data.Maybe
import qualified Data.Map.Strict as Map
Expand Down Expand Up @@ -131,7 +132,7 @@ evalExpr sym env expr = case expr of
return $ VTuple xs

ERec fields -> {-# SCC "evalExpr->ERec" #-} do
xs <- mapM (sDelay sym Nothing . eval) (Map.fromList fields)
xs <- traverse (sDelay sym Nothing . eval) fields
return $ VRecord xs

ESel e sel -> {-# SCC "evalExpr->ESel" #-} do
Expand Down Expand Up @@ -336,7 +337,7 @@ isValueType env Forall{ sVars = [], sProps = [], sType = t0 }
go TVBit = True
go (TVSeq _ x) = go x
go (TVTuple xs) = and (map go xs)
go (TVRec xs) = and (map (go . snd) xs)
go (TVRec xs) = and (fmap go xs)
go _ = False

isValueType _ _ = False
Expand Down Expand Up @@ -414,7 +415,11 @@ etaDelay sym msg env0 Forall{ sVars = vs0, sType = tp0 } = goTpVars env0 vs0

VRecord fs
| TVRec fts <- tp
-> return $ VRecord (Map.intersectionWith go (Map.fromList fts) fs)
-> do let res = zipRecords (\_ v t -> go t v) fs fts
case res of
Left (Left f) -> evalPanic "type mismatch during eta-expansion" ["missing field " ++ show f]
Left (Right f) -> evalPanic "type mismatch during eta-expansion" ["unexpected field " ++ show f]
Right fs' -> return (VRecord fs')

VFun f
| TVFun _t1 t2 <- tp
Expand Down Expand Up @@ -461,10 +466,9 @@ etaDelay sym msg env0 Forall{ sVars = vs0, sType = tp0 } = goTpVars env0 vs0
TVRec fs ->
do v' <- sDelay sym (Just msg) (fromVRecord <$> v)
let err f = evalPanic "expected record value with field" [show f]
return $ VRecord $ Map.fromList
[ (f, go t =<< (fromMaybe (err f) . Map.lookup f <$> v'))
| (f,t) <- fs
]
let eta f t = Identity (go t =<< (fromMaybe (err f) . lookupField f <$> v'))
let fs' = runIdentity (traverseRecordMap eta fs)
return $ VRecord fs'

TVAbstract {} -> v

Expand Down Expand Up @@ -601,9 +605,10 @@ evalSetSel sym e sel v =

setRecord n =
case e of
VRecord xs
| Map.member n xs -> pure (VRecord (Map.insert n v xs))
| otherwise -> bad "Missing field in record update."
VRecord xs ->
case adjustField n (\_ -> v) xs of
Just xs' -> pure (VRecord xs')
Nothing -> bad "Missing field in record update."
_ -> bad "Record update on a non-record."

setList n =
Expand Down
17 changes: 8 additions & 9 deletions src/Cryptol/Eval/Concrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ module Cryptol.Eval.Concrete
, toExpr
) where

import Control.Monad (join,guard,zipWithM)
import Data.List(sortBy)
import Data.Ord(comparing)
import Control.Monad (join,guard,zipWithM,mzero)
import Data.Bits (Bits(..))
import Data.Ratio(numerator,denominator)
import MonadLib( ChoiceT, findOne, lift )
Expand All @@ -50,7 +48,7 @@ import Cryptol.Utils.Panic (panic)
import Cryptol.Utils.Ident (PrimIdent,prelPrim,floatPrim)
import Cryptol.Utils.PP
import Cryptol.Utils.Logger(logPrint)

import Cryptol.Utils.RecordMap


-- Value to Expression conversion ----------------------------------------------
Expand All @@ -68,11 +66,12 @@ toExpr prims t0 v0 = findOne (go t0 v0)

go :: AST.Type -> Value -> ChoiceT Eval Expr
go ty val = case (tNoUser ty, val) of
(TRec (sortBy (comparing fst) -> tfs), VRecord vfs) -> do
let fns = Map.keys vfs
guard (map fst tfs == fns)
fes <- zipWithM go (map snd tfs) =<< lift (sequence (Map.elems vfs))
return $ ERec (zip fns fes)
(TRec tfs, VRecord vfs) -> do
-- NB, vfs first argument to keep their display order
res <- zipRecordsM (\_lbl v t -> go t =<< lift v) vfs tfs
case res of
Left _ -> mzero -- different fields
Right efs -> pure (ERec efs)
(TCon (TC (TCTuple tl)) ts, VTuple tvs) -> do
guard (tl == (length tvs))
ETuple `fmap` (zipWithM go ts =<< lift (sequence tvs))
Expand Down
75 changes: 32 additions & 43 deletions src/Cryptol/Eval/Generic.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ import Cryptol.Eval.Monad
import Cryptol.Eval.Type
import Cryptol.Eval.Value
import Cryptol.Utils.Panic (panic)

import qualified Data.Map.Strict as Map

import Cryptol.Utils.RecordMap



Expand Down Expand Up @@ -220,11 +218,10 @@ ringBinary sym opw opi opz opq opfp = loop

-- records
TVRec fs ->
do fs' <- sequence
[ (f,) <$> sDelay sym Nothing (loop' fty (lookupRecord f l) (lookupRecord f r))
| (f,fty) <- fs
]
return $ VRecord (Map.fromList fs')
do VRecord <$>
traverseRecordMap
(\f fty -> sDelay sym Nothing (loop' fty (lookupRecord f l) (lookupRecord f r)))
fs

TVAbstract {} ->
evalPanic "ringBinary" ["Abstract type not in `Ring`"]
Expand Down Expand Up @@ -297,11 +294,10 @@ ringUnary sym opw opi opz opq opfp = loop

-- records
TVRec fs ->
do fs' <- sequence
[ (f,) <$> sDelay sym Nothing (loop' fty (lookupRecord f v))
| (f,fty) <- fs
]
return $ VRecord (Map.fromList fs')
VRecord <$>
traverseRecordMap
(\f fty -> sDelay sym Nothing (loop' fty (lookupRecord f v)))
fs

TVAbstract {} -> evalPanic "ringUnary" ["Abstract type not in `Ring`"]

Expand Down Expand Up @@ -363,11 +359,8 @@ ringNullary sym opw opi opz opq opfp = loop
pure $ VTuple xs

TVRec fs ->
do xs <- sequence [ do v <- sDelay sym Nothing (loop a)
return (f, v)
| (f,a) <- fs
]
pure $ VRecord $ Map.fromList xs
do xs <- traverse (sDelay sym Nothing . loop) fs
pure $ VRecord xs

TVAbstract {} ->
evalPanic "ringNullary" ["Abstract type not in `Ring`"]
Expand Down Expand Up @@ -721,11 +714,11 @@ cmpValue sym fb fw fi fz fq ff = cmp
TVFun _ _ -> panic "Cryptol.Prims.Value.cmpValue"
[ "Functions are not comparable" ]
TVTuple tys -> cmpValues tys (fromVTuple v1) (fromVTuple v2) k
TVRec fields -> do let tys = Map.elems (Map.fromList fields)
cmpValues tys
(Map.elems (fromVRecord v1))
(Map.elems (fromVRecord v2))
k
TVRec fields -> cmpValues
(map snd (canonicalFields fields))
(map snd (canonicalFields (fromVRecord v1)))
(map snd (canonicalFields (fromVRecord v2)))
k
TVAbstract {} -> evalPanic "cmpValue"
[ "Abstract type not in `Cmp`" ]

Expand Down Expand Up @@ -894,11 +887,8 @@ zeroV sym ty = case ty of

-- records
TVRec fields ->
do xs <- sequence [ do z <- sDelay sym Nothing (zeroV sym fty)
pure (f, z)
| (f,fty) <- fields
]
pure $ VRecord (Map.fromList xs)
do xs <- traverse (sDelay sym Nothing . zeroV sym) fields
pure $ VRecord xs

TVAbstract {} -> evalPanic "zeroV" [ "Abstract type not in `Zero`" ]

Expand Down Expand Up @@ -1306,13 +1296,10 @@ logicBinary sym opb opw = loop
return $ lam $ \ a -> loop' bty (fromVFun l a) (fromVFun r a)

TVRec fields ->
do fs <- sequence
[ (f,) <$> sDelay sym Nothing (loop' fty a b)
| (f,fty) <- fields
, let a = lookupRecord f l
b = lookupRecord f r
]
return $ VRecord $ Map.fromList fs
VRecord <$>
traverseRecordMap
(\f fty -> sDelay sym Nothing (loop' fty (lookupRecord f l) (lookupRecord f r)))
fields

TVAbstract {} -> evalPanic "logicBinary"
[ "Abstract type not in `Logic`" ]
Expand Down Expand Up @@ -1378,11 +1365,10 @@ logicUnary sym opb opw = loop
return $ lam $ \ a -> loop' bty (fromVFun val a)

TVRec fields ->
do fs <- sequence
[ (f,) <$> sDelay sym Nothing (loop' fty a)
| (f,fty) <- fields, let a = lookupRecord f val
]
return $ VRecord $ Map.fromList fs
VRecord <$>
traverseRecordMap
(\f fty -> sDelay sym Nothing (loop' fty (lookupRecord f val)))
fields

TVAbstract {} -> evalPanic "logicUnary" [ "Abstract type not in `Logic`" ]

Expand Down Expand Up @@ -1818,7 +1804,7 @@ errorV sym ty msg = case ty of

-- records
TVRec fields ->
return $ VRecord $ fmap (\t -> errorV sym t msg) $ Map.fromList fields
return $ VRecord $ fmap (\t -> errorV sym t msg) $ fields

TVAbstract {} -> cryUserError sym msg

Expand Down Expand Up @@ -1892,8 +1878,11 @@ mergeValue :: Backend sym =>
SEval sym (GenValue sym)
mergeValue sym c v1 v2 =
case (v1, v2) of
(VRecord fs1 , VRecord fs2 ) | Map.keys fs1 == Map.keys fs2 ->
pure $ VRecord $ Map.intersectionWith (mergeValue' sym c) fs1 fs2
(VRecord fs1 , VRecord fs2 ) ->
do let res = zipRecords (\_lbl -> mergeValue' sym c) fs1 fs2
case res of
Left f -> panic "Cryptol.Eval.Generic" [ "mergeValue: incompatible record values", show f ]
Right r -> pure (VRecord r)
(VTuple vs1 , VTuple vs2 ) | length vs1 == length vs2 ->
pure $ VTuple $ zipWith (mergeValue' sym c) vs1 vs2
(VBit b1 , VBit b2 ) -> VBit <$> iteBit sym c b1 b2
Expand Down
23 changes: 12 additions & 11 deletions src/Cryptol/Eval/Reference.lhs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
> import Cryptol.Utils.Ident (Ident,PrimIdent, prelPrim, floatPrim)
> import Cryptol.Utils.Panic (panic)
> import Cryptol.Utils.PP
> import Cryptol.Utils.RecordMap
>
> import qualified Cryptol.ModuleSystem as M
> import qualified Cryptol.ModuleSystem.Env as M (loadedModules)
Expand Down Expand Up @@ -190,7 +191,7 @@ cpo that represents any given schema.
> TVSeq w ety -> VList (Nat w) (map (go ety) (copyList w (fromVList val)))
> TVStream ety -> VList Inf (map (go ety) (copyStream (fromVList val)))
> TVTuple etys -> VTuple (zipWith go etys (copyList (genericLength etys) (fromVTuple val)))
> TVRec fields -> VRecord [ (f, go fty (lookupRecord f val)) | (f, fty) <- fields ]
> TVRec fields -> VRecord [ (f, go fty (lookupRecord f val)) | (f, fty) <- canonicalFields fields ]
> TVFun _ bty -> VFun (\v -> go bty (fromVFun val v))
> TVAbstract {} -> val
>
Expand Down Expand Up @@ -326,7 +327,7 @@ assigns values to those variables.
>
> EList es _ty -> VList (Nat (genericLength es)) [ evalExpr env e | e <- es ]
> ETuple es -> VTuple [ evalExpr env e | e <- es ]
> ERec fields -> VRecord [ (f, evalExpr env e) | (f, e) <- fields ]
> ERec fields -> VRecord [ (f, evalExpr env e) | (f, e) <- canonicalFields fields ]
> ESel e sel -> evalSel (evalExpr env e) sel
> ESet e sel v -> evalSet (evalExpr env e) sel (evalExpr env v)
>
Expand Down Expand Up @@ -890,7 +891,7 @@ where the given error is "pushed down" into the leaf types.
> cryError e (TVSeq n ety) = VList (Nat n) (genericReplicate n (cryError e ety))
> cryError e (TVStream ety) = VList Inf (repeat (cryError e ety))
> cryError e (TVTuple tys) = VTuple (map (cryError e) tys)
> cryError e (TVRec fields) = VRecord [ (f, cryError e fty) | (f, fty) <- fields ]
> cryError e (TVRec fields) = VRecord [ (f, cryError e fty) | (f, fty) <- canonicalFields fields ]
> cryError e (TVFun _ bty) = VFun (\_ -> cryError e bty)
> cryError _ (TVAbstract{}) = evalPanic "error" ["Abstract type encountered in `error`"]
Expand All @@ -916,7 +917,7 @@ For functions, `zero` returns the constant function that returns
> zero (TVSeq n ety) = VList (Nat n) (genericReplicate n (zero ety))
> zero (TVStream ety) = VList Inf (repeat (zero ety))
> zero (TVTuple tys) = VTuple (map zero tys)
> zero (TVRec fields) = VRecord [ (f, zero fty) | (f, fty) <- fields ]
> zero (TVRec fields) = VRecord [ (f, zero fty) | (f, fty) <- canonicalFields fields ]
> zero (TVFun _ bty) = VFun (\_ -> zero bty)
> zero (TVAbstract{}) = evalPanic "zero" ["Abstract type not in `Zero`"]
Expand Down Expand Up @@ -975,7 +976,7 @@ at the same positions.
> TVSeq w ety -> VList (Nat w) (map (go ety) (fromVList val))
> TVStream ety -> VList Inf (map (go ety) (fromVList val))
> TVTuple etys -> VTuple (zipWith go etys (fromVTuple val))
> TVRec fields -> VRecord [ (f, go fty (lookupRecord f val)) | (f, fty) <- fields ]
> TVRec fields -> VRecord [ (f, go fty (lookupRecord f val)) | (f, fty) <- canonicalFields fields ]
> TVFun _ bty -> VFun (\v -> go bty (fromVFun val v))
> TVInteger -> evalPanic "logicUnary" ["Integer not in class Logic"]
> TVIntMod _ -> evalPanic "logicUnary" ["Z not in class Logic"]
Expand All @@ -995,7 +996,7 @@ at the same positions.
> TVStream ety -> VList Inf (zipWith (go ety) (fromVList l) (fromVList r))
> TVTuple etys -> VTuple (zipWith3 go etys (fromVTuple l) (fromVTuple r))
> TVRec fields -> VRecord [ (f, go fty (lookupRecord f l) (lookupRecord f r))
> | (f, fty) <- fields ]
> | (f, fty) <- canonicalFields fields ]
> TVFun _ bty -> VFun (\v -> go bty (fromVFun l v) (fromVFun r v))
> TVInteger -> evalPanic "logicBinary" ["Integer not in class Logic"]
> TVIntMod _ -> evalPanic "logicBinary" ["Z not in class Logic"]
Expand Down Expand Up @@ -1047,7 +1048,7 @@ False]`, but to `[error "foo", error "foo"]`.
> TVTuple tys ->
> VTuple (map go tys)
> TVRec fs ->
> VRecord [ (f, go fty) | (f, fty) <- fs ]
> VRecord [ (f, go fty) | (f, fty) <- canonicalFields fs ]
> TVAbstract {} ->
> evalPanic "arithNullary" ["Abstract type not in `Ring`"]
Expand Down Expand Up @@ -1083,7 +1084,7 @@ False]`, but to `[error "foo", error "foo"]`.
> TVTuple tys ->
> VTuple (zipWith go tys (fromVTuple val))
> TVRec fs ->
> VRecord [ (f, go fty (lookupRecord f val)) | (f, fty) <- fs ]
> VRecord [ (f, go fty (lookupRecord f val)) | (f, fty) <- canonicalFields fs ]
> TVAbstract {} ->
> evalPanic "arithUnary" ["Abstract type not in `Ring`"]
Expand Down Expand Up @@ -1119,7 +1120,7 @@ False]`, but to `[error "foo", error "foo"]`.
> TVTuple tys ->
> VTuple (zipWith3 go tys (fromVTuple l) (fromVTuple r))
> TVRec fs ->
> VRecord [ (f, go fty (lookupRecord f l) (lookupRecord f r)) | (f, fty) <- fs ]
> VRecord [ (f, go fty (lookupRecord f l) (lookupRecord f r)) | (f, fty) <- canonicalFields fs ]
> TVAbstract {} ->
> evalPanic "arithBinary" ["Abstract type not in class `Ring`"]
Expand Down Expand Up @@ -1277,7 +1278,7 @@ bits to the *left* of that position are equal.
> TVTuple etys ->
> lexList (zipWith3 lexCompare etys (fromVTuple l) (fromVTuple r))
> TVRec fields ->
> let tys = map snd (sortBy (comparing fst) fields)
> let tys = map snd (canonicalFields fields)
> ls = map snd (sortBy (comparing fst) (fromVRecord l))
> rs = map snd (sortBy (comparing fst) (fromVRecord r))
> in lexList (zipWith3 lexCompare tys ls rs)
Expand Down Expand Up @@ -1328,7 +1329,7 @@ fields are compared in alphabetical order.
> TVTuple etys ->
> lexList (zipWith3 lexSignedCompare etys (fromVTuple l) (fromVTuple r))
> TVRec fields ->
> let tys = map snd (sortBy (comparing fst) fields)
> let tys = map snd (canonicalFields fields)
> ls = map snd (sortBy (comparing fst) (fromVRecord l))
> rs = map snd (sortBy (comparing fst) (fromVRecord r))
> in lexList (zipWith3 lexSignedCompare tys ls rs)
Expand Down
Loading

0 comments on commit 6c6cb94

Please sign in to comment.