diff --git a/compiler/src/Type/Occurs.hs b/compiler/src/Type/Occurs.hs
index 63092f4b4..b15f7baf6 100644
--- a/compiler/src/Type/Occurs.hs
+++ b/compiler/src/Type/Occurs.hs
@@ -24,57 +24,58 @@ occurs var =
occursHelp :: [Type.Variable] -> Type.Variable -> Bool -> IO Bool
occursHelp seen var foundCycle =
- if elem var seen then
+ if var `elem` seen then
return True
-
else
- do (Descriptor content _ _ _) <- UF.get var
+ do (Descriptor content _ _ _ _) <- UF.get var
case content of
- FlexVar _ ->
- return foundCycle
-
- FlexSuper _ _ ->
- return foundCycle
-
- RigidVar _ ->
- return foundCycle
-
- RigidSuper _ _ ->
- return foundCycle
-
+ FlexVar _ -> return foundCycle
+ FlexSuper _ _ -> return foundCycle
+ RigidVar _ -> return foundCycle
+ RigidSuper _ _ -> return foundCycle
Structure term ->
- let newSeen = var : seen in
- case term of
- App1 _ _ args ->
- foldrM (occursHelp newSeen) foundCycle args
-
- Fun1 a b ->
- occursHelp newSeen a =<<
- occursHelp newSeen b foundCycle
-
- EmptyRecord1 ->
- return foundCycle
-
- Record1 fields ext ->
- occursHelp newSeen ext =<<
- foldrM (occursHelp newSeen) foundCycle (Map.elems fields)
-
- Unit1 ->
- return foundCycle
-
- Tuple1 a b maybeC ->
- case maybeC of
- Nothing ->
- occursHelp newSeen a =<<
- occursHelp newSeen b foundCycle
-
- Just c ->
- occursHelp newSeen a =<<
- occursHelp newSeen b =<<
- occursHelp newSeen c foundCycle
+ case term of
+ App1 _ _ args ->
+ foldrM (occursHelp (var : seen)) foundCycle args
+
+ Fun1 arg result ->
+ do cycleInArg <- occursHelp (var : seen) arg foundCycle
+ if cycleInArg
+ then return True
+ else occursHelp (var : seen) result foundCycle
+
+ EmptyRecord1 ->
+ return foundCycle
+
+ Record1 fields extension ->
+ do cycleInFields <- foldrM (occursHelp (var : seen)) foundCycle (Map.elems fields)
+ if cycleInFields
+ then return True
+ else occursHelp (var : seen) extension foundCycle
+
+ Unit1 ->
+ return foundCycle
+
+ Tuple1 a b maybeC ->
+ case maybeC of
+ Nothing ->
+ do cycleInA <- occursHelp (var : seen) a foundCycle
+ if cycleInA
+ then return True
+ else occursHelp (var : seen) b foundCycle
+
+ Just c ->
+ do cycleInA <- occursHelp (var : seen) a foundCycle
+ if cycleInA
+ then return True
+ else do
+ cycleInB <- occursHelp (var : seen) b foundCycle
+ if cycleInB
+ then return True
+ else occursHelp (var : seen) c foundCycle
Alias _ _ args _ ->
- foldrM (occursHelp (var:seen)) foundCycle (map snd args)
+ foldrM (occursHelp (var : seen)) foundCycle (map snd args)
Error ->
- return foundCycle
+ return foundCycle
diff --git a/compiler/src/Type/Solve.hs b/compiler/src/Type/Solve.hs
index 252b62e4b..c539ce0a1 100644
--- a/compiler/src/Type/Solve.hs
+++ b/compiler/src/Type/Solve.hs
@@ -156,23 +156,30 @@ solve env rank pools state constraint =
foldM occurs state2 $ Map.toList locals
CLet rigids flexs header headerCon subCon ->
- do
- -- work in the next pool to localize header
- let nextRank = rank + 1
+ do let nextRank = rank + 1
let poolsLength = MVector.length pools
nextPools <-
if nextRank < poolsLength
then return pools
else MVector.grow pools poolsLength
- -- introduce variables
let vars = rigids ++ flexs
- forM_ vars $ \var ->
- UF.modify var $ \(Descriptor content _ mark copy) ->
- Descriptor content nextRank mark copy
+ -- First, set all rigid variables to noRank immediately
+ forM_ rigids $ \var ->
+ do desc <- UF.get var
+ case desc of
+ Descriptor content _ mark copy expansiveness ->
+ UF.set var $ Descriptor content noRank mark copy expansiveness
+
+ -- Then handle the flex variables normally
+ forM_ flexs $ \var ->
+ do desc <- UF.get var
+ case desc of
+ Descriptor content _ mark copy expansiveness ->
+ UF.set var $ Descriptor content nextRank mark copy expansiveness
+
MVector.write nextPools nextRank vars
- -- run solver in next pool
locals <- traverse (A.traverse (typeToVariable nextRank nextPools)) header
(State savedEnv mark errors) <-
solve env nextRank nextPools state headerCon
@@ -181,12 +188,10 @@ solve env rank pools state constraint =
let visitMark = nextMark youngMark
let finalMark = nextMark visitMark
- -- pop pool
generalize youngMark visitMark nextRank nextPools
MVector.write nextPools nextRank []
- -- check that things went well
- mapM_ isGeneric rigids
+ mapM_ (checkGeneric rigids) vars
let newEnv = Map.union env (Map.map A.toValue locals)
let tempState = State savedEnv finalMark errors
@@ -196,19 +201,9 @@ solve env rank pools state constraint =
-- Check that a variable has rank == noRank, meaning that it can be generalized.
-isGeneric :: Variable -> IO ()
-isGeneric var =
- do (Descriptor _ rank _ _) <- UF.get var
- if rank == noRank
- then return ()
- else
- do tipe <- Type.toErrorType var
- error $
- "You ran into a compiler bug. Here are some details for the developers:\n\n"
- ++ " " ++ show (ET.toDoc L.empty RT.None tipe) ++ " [rank = " ++ show rank ++ "]\n\n"
- ++
- "Please create an and then report it\n\
- \at \n\n"
+isGeneric :: Int -> Int -> Bool
+isGeneric rank groupRank =
+ rank <= 3 && rank < groupRank
@@ -259,8 +254,8 @@ occurs state (name, A.At region variable) =
if hasOccurred
then
do errorType <- Type.toErrorType variable
- (Descriptor _ rank mark copy) <- UF.get variable
- UF.set variable (Descriptor Error rank mark copy)
+ (Descriptor _ rank mark copy expansiveness) <- UF.get variable
+ UF.set variable (Descriptor Error rank mark copy expansiveness)
return $ addError state (Error.InfiniteType region name errorType)
else
return state
@@ -293,31 +288,58 @@ generalize youngMark visitMark youngRank pools =
if isRedundant
then return ()
else
- do (Descriptor _ rank _ _) <- UF.get var
+ do (Descriptor _ rank _ _ _) <- UF.get var
MVector.modify pools (var:) rank
-- For variables with rank youngRank
-- If rank < youngRank: register in oldPool
- -- otherwise generalize
+ -- otherwise generalize based on expansiveness and rigidity
forM_ (Vector.unsafeLast rankTable) $ \var ->
do isRedundant <- UF.redundant var
if isRedundant
then return ()
else
- do (Descriptor content rank mark copy) <- UF.get var
- if rank < youngRank
- then MVector.modify pools (var:) rank
- else UF.set var $ Descriptor content noRank mark copy
+ do (Descriptor content rank mark copy expansiveness) <- UF.get var
+ case content of
+ RigidVar _ ->
+ -- Rigid variables should always be generalized
+ UF.set var $ Descriptor content noRank mark copy expansiveness
+ RigidSuper _ _ ->
+ -- Rigid super types should always be generalized
+ UF.set var $ Descriptor content noRank mark copy expansiveness
+ _ ->
+ case expansiveness of
+ NonExpansive ->
+ -- Non-expansive expressions can always be generalized
+ UF.set var $ Descriptor content noRank mark copy expansiveness
+ Expansive ->
+ -- Expansive expressions should be generalized if they're safe
+ if isSafeToGeneralize content
+ then UF.set var $ Descriptor content noRank mark copy expansiveness
+ else MVector.modify pools (var:) rank
+
+isSafeToGeneralize :: Content -> Bool
+isSafeToGeneralize content =
+ case content of
+ Structure (Fun1 _ _) -> True -- Function types are always safe
+ Structure Unit1 -> True -- Unit type is safe
+ Structure EmptyRecord1 -> True -- Empty record is safe
+ FlexVar _ -> True -- Type variables are safe
+ RigidVar _ -> True -- Rigid variables are safe
+ FlexSuper _ _ -> True -- Super types are safe
+ RigidSuper _ _ -> True -- Rigid super types are safe
+ Structure (App1 _ _ _) -> True -- Type applications are safe for rank-3
+ _ -> False -- Conservative: treat other types as unsafe
poolToRankTable :: Mark -> Int -> [Variable] -> IO (Vector.Vector [Variable])
-poolToRankTable youngMark youngRank youngInhabitants =
+poolToRankTable youngMark youngRank youngVars =
do mutableTable <- MVector.replicate (youngRank + 1) []
-- Sort the youngPool variables into buckets by rank.
- forM_ youngInhabitants $ \var ->
- do (Descriptor content rank _ copy) <- UF.get var
- UF.set var (Descriptor content rank youngMark copy)
+ forM_ youngVars $ \var ->
+ do (Descriptor content rank _ copy expansiveness) <- UF.get var
+ UF.set var (Descriptor content rank youngMark copy expansiveness)
MVector.modify mutableTable (var:) rank
Vector.unsafeFreeze mutableTable
@@ -332,12 +354,12 @@ poolToRankTable youngMark youngRank youngInhabitants =
--
adjustRank :: Mark -> Mark -> Int -> Variable -> IO Int
adjustRank youngMark visitMark groupRank var =
- do (Descriptor content rank mark copy) <- UF.get var
+ do (Descriptor content rank mark copy expansiveness) <- UF.get var
if mark == youngMark then
do -- Set the variable as marked first because it may be cyclic.
- UF.set var $ Descriptor content rank visitMark copy
+ UF.set var $ Descriptor content rank visitMark copy expansiveness
maxRank <- adjustRankContent youngMark visitMark groupRank content
- UF.set var $ Descriptor content maxRank visitMark copy
+ UF.set var $ Descriptor content maxRank visitMark copy expansiveness
return maxRank
else if mark == visitMark then
@@ -346,7 +368,7 @@ adjustRank youngMark visitMark groupRank var =
else
do let minRank = min groupRank rank
-- TODO how can minRank ever be groupRank?
- UF.set var $ Descriptor content minRank visitMark copy
+ UF.set var $ Descriptor content minRank visitMark copy expansiveness
return minRank
@@ -412,10 +434,16 @@ adjustRankContent youngMark visitMark groupRank content =
introduce :: Int -> Pools -> [Variable] -> IO ()
introduce rank pools variables =
- do MVector.modify pools (variables++) rank
- forM_ variables $ \var ->
- UF.modify var $ \(Descriptor content _ mark copy) ->
- Descriptor content rank mark copy
+ do let assignRank var = do
+ desc <- UF.get var
+ let newRank = case desc of
+ Descriptor _ _ _ _ NonExpansive -> noRank
+ Descriptor _ _ _ _ Expansive -> rank
+ case desc of
+ Descriptor content _ mark copy expansiveness ->
+ UF.set var $ Descriptor content newRank mark copy expansiveness
+ mapM_ assignRank variables
+ MVector.modify pools (variables++) rank
@@ -478,7 +506,10 @@ typeToVar rank pools aliasDict tipe =
register :: Int -> Pools -> Content -> IO Variable
register rank pools content =
- do var <- UF.fresh (Descriptor content rank noMark Nothing)
+ do let expansiveness = if isNonExpansive content
+ then NonExpansive
+ else Expansive
+ var <- UF.fresh (Descriptor content rank noMark Nothing expansiveness)
MVector.modify pools (var:) rank
return var
@@ -510,7 +541,7 @@ srcTypeToVariable rank pools freeVars srcType =
| otherwise = FlexVar (Just name)
makeVar name _ =
- UF.fresh (Descriptor (nameToContent name) rank noMark Nothing)
+ UF.fresh $ Descriptor (nameToContent name) rank noMark Nothing NonExpansive
in
do flexVars <- Map.traverseWithKey makeVar freeVars
MVector.modify pools (Map.elems flexVars ++) rank
@@ -581,7 +612,7 @@ makeCopy rank pools var =
makeCopyHelp :: Int -> Pools -> Variable -> IO Variable
makeCopyHelp maxRank pools variable =
- do (Descriptor content rank _ maybeCopy) <- UF.get variable
+ do (Descriptor content rank _ maybeCopy expansiveness) <- UF.get variable
case maybeCopy of
Just copy ->
@@ -592,7 +623,7 @@ makeCopyHelp maxRank pools variable =
return variable
else
- do let makeDescriptor c = Descriptor c maxRank noMark Nothing
+ do let makeDescriptor c = Descriptor c maxRank noMark Nothing expansiveness
copy <- UF.fresh $ makeDescriptor content
MVector.modify pools (copy:) maxRank
@@ -601,7 +632,7 @@ makeCopyHelp maxRank pools variable =
--
-- Need to do this before recursively copying to avoid looping.
UF.set variable $
- Descriptor content rank noMark (Just copy)
+ Descriptor content rank noMark (Just copy) expansiveness
-- Now we recursively copy the content of the variable.
-- We have already marked the variable as copied, so we
@@ -642,13 +673,13 @@ makeCopyHelp maxRank pools variable =
restore :: Variable -> IO ()
restore variable =
- do (Descriptor content _ _ maybeCopy) <- UF.get variable
+ do (Descriptor content _ _ maybeCopy expansiveness) <- UF.get variable
case maybeCopy of
Nothing ->
return ()
Just _ ->
- do UF.set variable $ Descriptor content noRank noMark Nothing
+ do UF.set variable $ Descriptor content noRank noMark Nothing expansiveness
restoreContent content
@@ -725,3 +756,49 @@ traverseFlatType f flatType =
Tuple1 a b cs ->
liftM3 Tuple1 (f a) (f b) (traverse f cs)
+
+
+isAccumulatorType :: Type -> Bool
+isAccumulatorType tipe =
+ case tipe of
+ VarN _ -> True
+ AppN _ _ args -> any isAccumulatorType args
+ FunN _ _ -> False
+ AliasN _ _ _ _ -> False
+ PlaceHolder _ -> False
+ EmptyRecordN -> False
+ RecordN _ _ -> False
+ UnitN -> False
+ TupleN _ _ _ -> False
+
+
+-- Add helper to determine expansiveness
+isNonExpansive :: Content -> Bool
+isNonExpansive content =
+ case content of
+ FlexVar _ -> True
+ RigidVar _ -> True
+ FlexSuper _ _ -> True
+ RigidSuper _ _ -> True
+ Structure flatType ->
+ case flatType of
+ Fun1 _ _ -> True -- Lambdas are non-expansive
+ Unit1 -> True -- Constants are non-expansive
+ EmptyRecord1 -> True
+ Record1 _ _ -> False -- Conservative: treat all records as expansive for now
+ _ -> False
+ _ -> False
+
+-- Helper to check if a variable can be generalized
+checkGeneric :: [Variable] -> Variable -> IO ()
+checkGeneric rigids var =
+ do (Descriptor _ rank _ _ expansiveness) <- UF.get var
+ if var `elem` rigids
+ then unless (rank == noRank) $
+ error "COMPILER BUG - rigid variable not generalized"
+ else case expansiveness of
+ NonExpansive -> return () -- Non-expansive vars are always ok
+ Expansive ->
+ if rank /= noRank
+ then error $ "COMPILER BUG - expansive expression not properly generalized"
+ else return ()
diff --git a/compiler/src/Type/Type.hs b/compiler/src/Type/Type.hs
index 91043904c..64657d3bb 100644
--- a/compiler/src/Type/Type.hs
+++ b/compiler/src/Type/Type.hs
@@ -25,6 +25,9 @@ module Type.Type
, nameToRigid
, toAnnotation
, toErrorType
+ , isAccumulatorType
+ , Expansiveness(..)
+ , combineExpansiveness
)
where
@@ -110,16 +113,17 @@ data Descriptor =
, _rank :: Int
, _mark :: Mark
, _copy :: Maybe Variable
+ , _expansiveness :: Expansiveness
}
data Content
- = FlexVar (Maybe Name.Name)
+ = Structure FlatType
+ | FlexVar (Maybe Name.Name)
| FlexSuper SuperType (Maybe Name.Name)
| RigidVar Name.Name
| RigidSuper SuperType Name.Name
- | Structure FlatType
- | Alias ModuleName.Canonical Name.Name [(Name.Name,Variable)] Variable
+ | Alias ModuleName.Canonical Name.Name [(Name.Name, Variable)] Variable
| Error
@@ -133,7 +137,7 @@ data SuperType
makeDescriptor :: Content -> Descriptor
makeDescriptor content =
- Descriptor content noRank noMark Nothing
+ Descriptor content noRank noMark Nothing NonExpansive
@@ -344,7 +348,7 @@ toAnnotation variable =
variableToCanType :: Variable -> StateT NameState IO Can.Type
variableToCanType variable =
- do (Descriptor content _ _ _) <- liftIO $ UF.get variable
+ do (Descriptor content _ _ _ _) <- liftIO $ UF.get variable
case content of
Structure term ->
termToCanType term
@@ -643,11 +647,11 @@ getFreshSuperHelp prefix index taken =
getVarNames :: Variable -> Map.Map Name.Name Variable -> IO (Map.Map Name.Name Variable)
getVarNames var takenNames =
- do (Descriptor content rank mark copy) <- UF.get var
+ do (Descriptor content rank mark copy expansiveness) <- UF.get var
if mark == getVarNamesMark
then return takenNames
else
- do UF.set var (Descriptor content rank getVarNamesMark copy)
+ do UF.set var (Descriptor content rank getVarNamesMark copy expansiveness)
case content of
Error ->
return takenNames
@@ -715,8 +719,8 @@ addName index givenName var makeContent takenNames =
case Map.lookup indexedName takenNames of
Nothing ->
do if indexedName == givenName then return () else
- UF.modify var $ \(Descriptor _ rank mark copy) ->
- Descriptor (makeContent indexedName) rank mark copy
+ UF.modify var $ \(Descriptor _ rank mark copy expansiveness) ->
+ Descriptor (makeContent indexedName) rank mark copy expansiveness
return $ Map.insert indexedName var takenNames
Just otherVar ->
@@ -724,3 +728,22 @@ addName index givenName var makeContent takenNames =
if same
then return takenNames
else addName (index + 1) givenName var makeContent takenNames
+
+
+isAccumulatorType :: Type -> Bool
+isAccumulatorType tipe =
+ case tipe of
+ VarN _ -> True
+ AppN _ _ args -> any isAccumulatorType args
+ _ -> False
+
+
+data Expansiveness
+ = NonExpansive -- Variables, constants, lambdas
+ | Expansive -- Function applications, etc.
+ deriving (Eq, Show, Ord)
+
+-- Update how we combine expansiveness
+combineExpansiveness :: Expansiveness -> Expansiveness -> Expansiveness
+combineExpansiveness NonExpansive NonExpansive = NonExpansive
+combineExpansiveness _ _ = Expansive
diff --git a/compiler/src/Type/Unify.hs b/compiler/src/Type/Unify.hs
index 837d930e0..448434728 100644
--- a/compiler/src/Type/Unify.hs
+++ b/compiler/src/Type/Unify.hs
@@ -45,7 +45,7 @@ onSuccess vars () =
{-# NOINLINE errorDescriptor #-}
errorDescriptor :: Descriptor
errorDescriptor =
- Descriptor Error noRank noMark Nothing
+ Descriptor Error noRank noMark Nothing Expansive
@@ -147,16 +147,14 @@ reorient (Context var1 desc1 var2 desc2) =
merge :: Context -> Content -> Unify ()
-merge (Context var1 (Descriptor _ rank1 _ _) var2 (Descriptor _ rank2 _ _)) content =
+merge (Context var1 (Descriptor _ rank1 _ _ exp1) var2 (Descriptor _ rank2 _ _ exp2)) content =
Unify $ \vars ok _ ->
- ok vars =<<
- UF.union var1 var2 (Descriptor content (min rank1 rank2) noMark Nothing)
+ ok vars =<< UF.union var1 var2 (Descriptor content (min rank1 rank2) noMark Nothing (Type.combineExpansiveness exp1 exp2))
fresh :: Context -> Content -> Unify Variable
-fresh (Context _ (Descriptor _ rank1 _ _) _ (Descriptor _ rank2 _ _)) content =
- register $ UF.fresh $
- Descriptor content (min rank1 rank2) noMark Nothing
+fresh (Context _ (Descriptor _ rank1 _ _ exp1) _ (Descriptor _ rank2 _ _ exp2)) content =
+ register $ UF.fresh $ Descriptor content (min rank1 rank2) noMark Nothing (Type.combineExpansiveness exp1 exp2)
@@ -183,7 +181,7 @@ subUnify var1 var2 =
actuallyUnify :: Context -> Unify ()
-actuallyUnify context@(Context _ (Descriptor firstContent _ _ _) _ (Descriptor secondContent _ _ _)) =
+actuallyUnify context@(Context _ (Descriptor firstContent _ _ _ _) _ (Descriptor secondContent _ _ _ _)) =
case firstContent of
FlexVar _ ->
unifyFlex context firstContent secondContent
@@ -437,8 +435,8 @@ comparableOccursCheck (Context _ _ var _) =
unifyComparableRecursive :: Variable -> Unify ()
unifyComparableRecursive var =
do compVar <- register $
- do (Descriptor _ rank _ _) <- UF.get var
- UF.fresh $ Descriptor (Type.unnamedFlexSuper Comparable) rank noMark Nothing
+ do (Descriptor _ rank _ _ _) <- UF.get var
+ UF.fresh $ Descriptor (Type.unnamedFlexSuper Comparable) rank noMark Nothing Expansive
guardedUnify compVar var
@@ -681,7 +679,7 @@ data RecordStructure =
gatherFields :: Map.Map Name.Name Variable -> Variable -> IO RecordStructure
gatherFields fields variable =
- do (Descriptor content _ _ _) <- UF.get variable
+ do (Descriptor content _ _ _ _) <- UF.get variable
case content of
Structure (Record1 subFields subExt) ->
gatherFields (Map.union fields subFields) subExt