From 481aac5654e13fa20ef4c7e58405d135c8fafb65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20Paradzi=C5=84ski?= Date: Mon, 18 Dec 2023 16:08:27 +0100 Subject: [PATCH] Add RustExpr to split prettyprint and parse (#20) * add RustExpr to separate pretty printing and creating expression from Agda internals * rename after refactor to RustExpr * rename after refactor to RustExpr - fix * swap unless to when * rename back to PrettyPrintingUtils for easier review --- agda2rust.cabal | 3 +- .../{ToRustCompiler.hs => AgdaToRustExpr.hs} | 98 ++++++------------- src/Agda/Compiler/Rust/Backend.hs | 9 +- src/Agda/Compiler/Rust/CommonTypes.hs | 4 +- src/Agda/Compiler/Rust/PrettyPrintingUtils.hs | 49 ++++++++-- src/Agda/Compiler/Rust/RustExpr.hs | 21 ++++ test/RustBackendTest.hs | 2 +- 7 files changed, 103 insertions(+), 83 deletions(-) rename src/Agda/Compiler/Rust/{ToRustCompiler.hs => AgdaToRustExpr.hs} (56%) create mode 100644 src/Agda/Compiler/Rust/RustExpr.hs diff --git a/agda2rust.cabal b/agda2rust.cabal index 5b3255f..9c3e241 100644 --- a/agda2rust.cabal +++ b/agda2rust.cabal @@ -28,9 +28,10 @@ common warnings library hs-source-dirs: src exposed-modules: Agda.Compiler.Rust.Backend + Agda.Compiler.Rust.RustExpr Agda.Compiler.Rust.CommonTypes Agda.Compiler.Rust.PrettyPrintingUtils - Agda.Compiler.Rust.ToRustCompiler + Agda.Compiler.Rust.AgdaToRustExpr Paths_agda2rust autogen-modules: Paths_agda2rust build-depends: base >= 4.10 && < 4.20, diff --git a/src/Agda/Compiler/Rust/ToRustCompiler.hs b/src/Agda/Compiler/Rust/AgdaToRustExpr.hs similarity index 56% rename from src/Agda/Compiler/Rust/ToRustCompiler.hs rename to src/Agda/Compiler/Rust/AgdaToRustExpr.hs index d56bc78..3e63c95 100644 --- a/src/Agda/Compiler/Rust/ToRustCompiler.hs +++ b/src/Agda/Compiler/Rust/AgdaToRustExpr.hs @@ -1,9 +1,8 @@ {-# LANGUAGE LambdaCase, RecordWildCards #-} -module Agda.Compiler.Rust.ToRustCompiler ( compile, compileModule, moduleHeader ) where +module Agda.Compiler.Rust.AgdaToRustExpr ( compile, compileModule ) where import Control.Monad.IO.Class ( MonadIO(liftIO) ) -import Data.List ( intersperse ) import qualified Data.List.NonEmpty as Nel import Agda.Compiler.Backend ( IsMain ) @@ -19,27 +18,17 @@ import Agda.TypeChecking.Monad import Agda.TypeChecking.CompiledClause ( CompiledClauses(..), CompiledClauses'(..) ) import Agda.Compiler.Rust.CommonTypes ( Options, CompiledDef, ModuleEnv ) -import Agda.Compiler.Rust.PrettyPrintingUtils ( - argList, - bracket, - combineLines, - defsSeparator, - exprSeparator, - funReturnTypeSeparator, - indent, - typeSeparator ) +import Agda.Compiler.Rust.RustExpr ( RustExpr(..), RustName, RustType, RustElem(..), FunBody ) compile :: Options -> ModuleEnv -> IsMain -> Definition -> TCM CompiledDef compile _ _ _ Defn{..} = withCurrentModule (qnameModule defName) $ getUniqueCompilerPragma "AGDA2RUST" defName >>= \case - Nothing -> return [] - Just (CompilerPragma _ _) -> + Nothing -> return $ Unhandled "compile" "" + Just (CompilerPragma _ _) -> return $ compileDefn defName theDef -compileDefn :: QName - -> Defn - -> CompiledDef +compileDefn :: QName -> Defn -> CompiledDef compileDefn defName theDef = case theDef of Datatype{dataCons = fields} -> @@ -47,84 +36,66 @@ compileDefn defName theDef = Function{funCompiled = funDef, funClauses = fc} -> compileFunction defName funDef fc _ -> - "Unsupported compileDefn" <> showName defName <> " = " <> prettyShow theDef + Unhandled "compileDefn" (show defName ++ " = " ++ show theDef) compileDataType :: QName -> [QName] -> CompiledDef -compileDataType defName fields = "enum" <> exprSeparator - <> showName defName - <> exprSeparator - <> bracket ( - indent - <> concat (intersperse ", " (map showName fields))) +compileDataType defName fields = TeEnum (showName defName) (map showName fields) compileFunction :: QName -> Maybe CompiledClauses -> [Clause] -> CompiledDef -compileFunction defName funDef fc = - "pub fn" <> exprSeparator - <> showName defName - <> argList ( - -- TODO handle multiple function clauses and arguments - compileFunctionArgument fc - <> typeSeparator <> exprSeparator - <> compileFunctionArgType fc ) - <> exprSeparator <> funReturnTypeSeparator <> exprSeparator <> compileFunctionResultType fc - <> exprSeparator <> bracket ( - -- TODO proper indentation for every line of function body - -- including nested expressions - -- build intermediate AST and pretty printer for it - indent - <> compileFunctionBody funDef) - <> defsSeparator +compileFunction defName funDef fc = TeFun + (showName defName) + (RustElem (compileFunctionArgument fc) (compileFunctionArgType fc)) + (compileFunctionResultType fc) + (compileFunctionBody funDef) -- TODO this is hacky way to reach find first argument name, assuming function has 1 argument -- TODO proper way is to handle deBruijn indices --- TODO read docs for `data Clause` section in https://hackage.haskell.org/package/Agda-2.6.4.1/docs/Agda-Syntax-Internal.html --- TODO start from uncommenting line below and figure out the path to match indices with name and type --- compileFunctionArgument fc = show fc -compileFunctionArgument :: [Clause] -> CompiledDef +-- TODO see `data Clause` in https://hackage.haskell.org/package/Agda-2.6.4.1/docs/Agda-Syntax-Internal.html +compileFunctionArgument :: [Clause] -> RustName compileFunctionArgument [] = "" compileFunctionArgument [fc] = fromDeBruijnPattern (namedThing (unArg (head (namedClausePats fc)))) -compileFunctionArgument xs = error "unsupported compileFunctionArgument" ++ (show xs) +compileFunctionArgument xs = error "unsupported compileFunctionArgument" ++ (show xs) -- show xs -compileFunctionArgType :: [Clause] -> CompiledDef +compileFunctionArgType :: [Clause] -> RustType compileFunctionArgType [ Clause{clauseTel = ct} ] = fromTelescope ct compileFunctionArgType xs = error "unsupported compileFunctionArgType" ++ (show xs) -fromTelescope :: Telescope -> CompiledDef +fromTelescope :: Telescope -> RustName fromTelescope = \case ExtendTel a _ -> fromDom a other -> error ("unhandled fromType" ++ show other) -fromDom :: Dom Type -> CompiledDef +fromDom :: Dom Type -> RustName fromDom x = fromType (unDom x) -compileFunctionResultType :: [Clause] -> CompiledDef +compileFunctionResultType :: [Clause] -> RustType compileFunctionResultType [Clause{clauseType = ct}] = fromMaybeType ct compileFunctionResultType other = error ("unhandled compileFunctionResultType" ++ show other) -fromMaybeType :: Maybe (Arg Type) -> CompiledDef +fromMaybeType :: Maybe (Arg Type) -> RustName fromMaybeType (Just argType) = fromArgType argType fromMaybeType other = error ("unhandled fromMaybeType" ++ show other) -fromArgType :: Arg Type -> CompiledDef +fromArgType :: Arg Type -> RustName fromArgType arg = fromType (unArg arg) -fromType :: Type -> CompiledDef +fromType :: Type -> RustName fromType = \case a@(El _ ue) -> fromTerm ue other -> error ("unhandled fromType" ++ show other) -fromTerm :: Term -> CompiledDef +fromTerm :: Term -> RustName fromTerm = \case Def qname el -> fromQName qname other -> error ("unhandled fromTerm" ++ show other) -fromQName :: QName -> CompiledDef +fromQName :: QName -> RustName fromQName x = prettyShow (qnameName x) -fromDeBruijnPattern :: DeBruijnPattern -> CompiledDef +fromDeBruijnPattern :: DeBruijnPattern -> RustName fromDeBruijnPattern = \case VarP a b -> (dbPatVarName b) a@(ConP x y z) -> show a @@ -132,29 +103,24 @@ fromDeBruijnPattern = \case -- TODO this is wrong for function other than identity -- see asFriday in Hello.agda vs Hello.rs -compileFunctionBody :: Maybe CompiledClauses -> CompiledDef -compileFunctionBody (Just funDef) = "return" <> exprSeparator <> fromCompiledClauses funDef +compileFunctionBody :: Maybe CompiledClauses -> FunBody +compileFunctionBody (Just funDef) = fromCompiledClauses funDef compileFunctionBody funDef = error ("unhandled compileFunctionBody " ++ show funDef) -fromCompiledClauses :: CompiledClauses -> CompiledDef +fromCompiledClauses :: CompiledClauses -> FunBody fromCompiledClauses = \case (Done (x:xs) term) -> fromArgName x other -> error ("unhandled fromCompiledClauses " ++ show other) -fromArgName :: Arg ArgName -> CompiledDef +fromArgName :: Arg ArgName -> FunBody fromArgName = unArg -showName :: QName -> CompiledDef +showName :: QName -> RustName showName = prettyShow . qnameName -compileModule :: TopLevelModuleName -> [CompiledDef] -> String +compileModule :: TopLevelModuleName -> [CompiledDef] -> CompiledDef compileModule mName cdefs = - moduleHeader (moduleName mName) - <> bracket (combineLines (map prettyShow cdefs)) - <> defsSeparator + TeMod (moduleName mName) cdefs moduleName :: TopLevelModuleName -> String moduleName n = prettyShow (Nel.last (moduleNameParts n)) - -moduleHeader :: String -> String -moduleHeader mName = "mod" <> exprSeparator <> mName <> exprSeparator diff --git a/src/Agda/Compiler/Rust/Backend.hs b/src/Agda/Compiler/Rust/Backend.hs index 9fbe07e..054d680 100644 --- a/src/Agda/Compiler/Rust/Backend.hs +++ b/src/Agda/Compiler/Rust/Backend.hs @@ -3,7 +3,7 @@ module Agda.Compiler.Rust.Backend ( backend, defaultOptions ) where -import Control.Monad ( unless ) +import Control.Monad ( when ) import Control.Monad.IO.Class ( MonadIO(liftIO) ) import Control.DeepSeq ( NFData(..) ) import Data.Maybe ( fromMaybe ) @@ -23,7 +23,8 @@ import Agda.TypeChecking.Monad ( setScope ) import Agda.Compiler.Rust.CommonTypes ( Options(..), CompiledDef, ModuleEnv ) -import Agda.Compiler.Rust.ToRustCompiler ( compile, compileModule ) +import Agda.Compiler.Rust.AgdaToRustExpr ( compile, compileModule ) +import Agda.Compiler.Rust.PrettyPrintingUtils ( prettyPrintRustExpr ) runRustBackend :: IO () runRustBackend = runAgda [Backend backend] @@ -79,9 +80,9 @@ writeModule :: Options writeModule opts _ _ mName cdefs = do outDir <- compileDir compileLog $ "compiling " <> fileName - unless (all null cdefs) $ liftIO + when (null cdefs) $ liftIO $ writeFile (outFile outDir) - $ compileModule mName cdefs + $ prettyPrintRustExpr (compileModule mName cdefs) where fileName = rustFileName mName dirName outDir = fromMaybe outDir (optOutDir opts) diff --git a/src/Agda/Compiler/Rust/CommonTypes.hs b/src/Agda/Compiler/Rust/CommonTypes.hs index 87e7637..c0c7488 100644 --- a/src/Agda/Compiler/Rust/CommonTypes.hs +++ b/src/Agda/Compiler/Rust/CommonTypes.hs @@ -3,8 +3,10 @@ module Agda.Compiler.Rust.CommonTypes ( CompiledDef, ModuleEnv ) where +import Agda.Compiler.Rust.RustExpr ( RustExpr ) + data Options = Options { optOutDir :: Maybe FilePath } -type CompiledDef = String +type CompiledDef = RustExpr type ModuleEnv = () diff --git a/src/Agda/Compiler/Rust/PrettyPrintingUtils.hs b/src/Agda/Compiler/Rust/PrettyPrintingUtils.hs index f9d7a3b..b2036ef 100644 --- a/src/Agda/Compiler/Rust/PrettyPrintingUtils.hs +++ b/src/Agda/Compiler/Rust/PrettyPrintingUtils.hs @@ -1,13 +1,36 @@ -module Agda.Compiler.Rust.PrettyPrintingUtils ( - argList, - bracket, - combineLines, - defsSeparator, - exprSeparator, - funReturnTypeSeparator, - indent, - typeSeparator -) where +module Agda.Compiler.Rust.PrettyPrintingUtils ( prettyPrintRustExpr, moduleHeader ) where + +import Data.List ( intersperse ) +import Agda.Compiler.Rust.CommonTypes ( CompiledDef ) +import Agda.Compiler.Rust.RustExpr ( RustExpr(..), RustElem(..), FunBody ) + +prettyPrintRustExpr :: CompiledDef -> String +prettyPrintRustExpr def = case def of + (TeEnum name fields) -> + "enum" <> exprSeparator + <> name + <> exprSeparator + <> bracket ( + indent -- TODO this to siplistic indentation + <> concat (intersperse ", " fields)) + (TeFun fName (RustElem aName aType) resType fBody) -> + "pub fn" <> exprSeparator + <> fName + <> argList ( + aName + <> typeSeparator <> exprSeparator + <> aType ) + <> exprSeparator <> funReturnTypeSeparator <> exprSeparator <> resType + <> exprSeparator <> bracket ( + -- TODO proper indentation for every line of function body + -- including nested expressions + indent + <> (prettyPrintFunctionBody fBody)) + <> defsSeparator + (TeMod mName defs) -> + moduleHeader mName + <> bracket (combineLines (map prettyPrintRustExpr defs)) + <> defsSeparator bracket :: String -> String bracket str = "{\n" <> str <> "\n}" @@ -32,3 +55,9 @@ funReturnTypeSeparator = "->" combineLines :: [String] -> String combineLines xs = unlines (filter (not . null) xs) + +prettyPrintFunctionBody :: FunBody -> String +prettyPrintFunctionBody fBody = "return" <> exprSeparator <> fBody <> ";" + +moduleHeader :: String -> String +moduleHeader mName = "mod" <> exprSeparator <> mName <> exprSeparator diff --git a/src/Agda/Compiler/Rust/RustExpr.hs b/src/Agda/Compiler/Rust/RustExpr.hs new file mode 100644 index 0000000..3ee8ccc --- /dev/null +++ b/src/Agda/Compiler/Rust/RustExpr.hs @@ -0,0 +1,21 @@ +module Agda.Compiler.Rust.RustExpr ( + RustName, + RustType, + RustExpr(..), + RustElem(..), + FunBody + ) where + +type RustName = String +type RustType = String +type FunBody = String + +data RustElem = RustElem RustName RustType + deriving ( Show ) + +data RustExpr + = TeMod RustName [RustExpr] + | TeEnum RustName [RustName] + | TeFun RustName RustElem RustType FunBody + | Unhandled RustName String + deriving ( Show ) diff --git a/test/RustBackendTest.hs b/test/RustBackendTest.hs index 0765d23..ca9c9c5 100644 --- a/test/RustBackendTest.hs +++ b/test/RustBackendTest.hs @@ -7,7 +7,7 @@ import Test.HUnit ( , runTestTT) import System.Exit ( exitFailure , exitSuccess ) import Agda.Compiler.Rust.Backend ( backend, defaultOptions ) -import Agda.Compiler.Rust.ToRustCompiler ( moduleHeader ) +import Agda.Compiler.Rust.PrettyPrintingUtils ( moduleHeader ) import Agda.Compiler.Backend ( isEnabled )