Skip to content

Commit

Permalink
RustExpr to separate pretty printing from traverse Agda internals (#18)
Browse files Browse the repository at this point in the history
* add RustExpr to separate pretty printing and creating expression from Agda internals

* rename after refactor to RustExpr

* rename after refactor to RustExpr - fix

* swap unless to when
  • Loading branch information
lemastero authored Dec 18, 2023
1 parent 4db314b commit 2f56f9a
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 108 deletions.
5 changes: 3 additions & 2 deletions agda2rust.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ common warnings
library
hs-source-dirs: src
exposed-modules: Agda.Compiler.Rust.Backend
Agda.Compiler.Rust.RustExpr
Agda.Compiler.Rust.CommonTypes
Agda.Compiler.Rust.PrettyPrintingUtils
Agda.Compiler.Rust.ToRustCompiler
Agda.Compiler.Rust.PrettyPrintRustExpr
Agda.Compiler.Rust.AgdaToRustExpr
Paths_agda2rust
autogen-modules: Paths_agda2rust
build-depends: base >= 4.10 && < 4.20,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
{-# LANGUAGE LambdaCase, RecordWildCards #-}

module Agda.Compiler.Rust.ToRustCompiler ( compile, compileModule, moduleHeader ) where
module Agda.Compiler.Rust.AgdaToRustExpr ( compile, compileModule ) where

import Control.Monad.IO.Class ( MonadIO(liftIO) )
import Data.List ( intersperse )
import qualified Data.List.NonEmpty as Nel

import Agda.Compiler.Backend ( IsMain )
Expand All @@ -19,142 +18,109 @@ import Agda.TypeChecking.Monad
import Agda.TypeChecking.CompiledClause ( CompiledClauses(..), CompiledClauses'(..) )

import Agda.Compiler.Rust.CommonTypes ( Options, CompiledDef, ModuleEnv )
import Agda.Compiler.Rust.PrettyPrintingUtils (
argList,
bracket,
combineLines,
defsSeparator,
exprSeparator,
funReturnTypeSeparator,
indent,
typeSeparator )
import Agda.Compiler.Rust.RustExpr ( RustExpr(..), RustName, RustType, RustElem(..), FunBody )

compile :: Options -> ModuleEnv -> IsMain -> Definition -> TCM CompiledDef
compile _ _ _ Defn{..}
= withCurrentModule (qnameModule defName)
$ getUniqueCompilerPragma "AGDA2RUST" defName >>= \case
Nothing -> return []
Just (CompilerPragma _ _) ->
Nothing -> return $ Unhandled "compile" ""
Just (CompilerPragma _ _) ->
return $ compileDefn defName theDef

compileDefn :: QName
-> Defn
-> CompiledDef
compileDefn :: QName -> Defn -> CompiledDef
compileDefn defName theDef =
case theDef of
Datatype{dataCons = fields} ->
compileDataType defName fields
Function{funCompiled = funDef, funClauses = fc} ->
compileFunction defName funDef fc
_ ->
"Unsupported compileDefn" <> showName defName <> " = " <> prettyShow theDef
Unhandled "compileDefn" (show defName ++ " = " ++ show theDef)

compileDataType :: QName -> [QName] -> CompiledDef
compileDataType defName fields = "enum" <> exprSeparator
<> showName defName
<> exprSeparator
<> bracket (
indent
<> concat (intersperse ", " (map showName fields)))
compileDataType defName fields = TeEnum (showName defName) (map showName fields)

compileFunction :: QName
-> Maybe CompiledClauses
-> [Clause]
-> CompiledDef
compileFunction defName funDef fc =
"pub fn" <> exprSeparator
<> showName defName
<> argList (
-- TODO handle multiple function clauses and arguments
compileFunctionArgument fc
<> typeSeparator <> exprSeparator
<> compileFunctionArgType fc )
<> exprSeparator <> funReturnTypeSeparator <> exprSeparator <> compileFunctionResultType fc
<> exprSeparator <> bracket (
-- TODO proper indentation for every line of function body
-- including nested expressions
-- build intermediate AST and pretty printer for it
indent
<> compileFunctionBody funDef)
<> defsSeparator
compileFunction defName funDef fc = TeFun
(showName defName)
(RustElem (compileFunctionArgument fc) (compileFunctionArgType fc))
(compileFunctionResultType fc)
(compileFunctionBody funDef)

-- TODO this is hacky way to reach find first argument name, assuming function has 1 argument
-- TODO proper way is to handle deBruijn indices
-- TODO read docs for `data Clause` section in https://hackage.haskell.org/package/Agda-2.6.4.1/docs/Agda-Syntax-Internal.html
-- TODO start from uncommenting line below and figure out the path to match indices with name and type
-- compileFunctionArgument fc = show fc
compileFunctionArgument :: [Clause] -> CompiledDef
-- TODO see `data Clause` in https://hackage.haskell.org/package/Agda-2.6.4.1/docs/Agda-Syntax-Internal.html
compileFunctionArgument :: [Clause] -> RustName
compileFunctionArgument [] = ""
compileFunctionArgument [fc] = fromDeBruijnPattern (namedThing (unArg (head (namedClausePats fc))))

Check warning on line 59 in src/Agda/Compiler/Rust/AgdaToRustExpr.hs

View workflow job for this annotation

GitHub Actions / agda2rust

In the use of ‘head’
compileFunctionArgument xs = error "unsupported compileFunctionArgument" ++ (show xs)
compileFunctionArgument xs = error "unsupported compileFunctionArgument" ++ (show xs) -- show xs

compileFunctionArgType :: [Clause] -> CompiledDef
compileFunctionArgType :: [Clause] -> RustType
compileFunctionArgType [ Clause{clauseTel = ct} ] = fromTelescope ct
compileFunctionArgType xs = error "unsupported compileFunctionArgType" ++ (show xs)

fromTelescope :: Telescope -> CompiledDef
fromTelescope :: Telescope -> RustName
fromTelescope = \case
ExtendTel a _ -> fromDom a
other -> error ("unhandled fromType" ++ show other)

fromDom :: Dom Type -> CompiledDef
fromDom :: Dom Type -> RustName
fromDom x = fromType (unDom x)

compileFunctionResultType :: [Clause] -> CompiledDef
compileFunctionResultType :: [Clause] -> RustType
compileFunctionResultType [Clause{clauseType = ct}] = fromMaybeType ct
compileFunctionResultType other = error ("unhandled compileFunctionResultType" ++ show other)

fromMaybeType :: Maybe (Arg Type) -> CompiledDef
fromMaybeType :: Maybe (Arg Type) -> RustName
fromMaybeType (Just argType) = fromArgType argType
fromMaybeType other = error ("unhandled fromMaybeType" ++ show other)

fromArgType :: Arg Type -> CompiledDef
fromArgType :: Arg Type -> RustName
fromArgType arg = fromType (unArg arg)

fromType :: Type -> CompiledDef
fromType :: Type -> RustName
fromType = \case
a@(El _ ue) -> fromTerm ue
other -> error ("unhandled fromType" ++ show other)

Check warning on line 88 in src/Agda/Compiler/Rust/AgdaToRustExpr.hs

View workflow job for this annotation

GitHub Actions / agda2rust

Pattern match is redundant

fromTerm :: Term -> CompiledDef
fromTerm :: Term -> RustName
fromTerm = \case
Def qname el -> fromQName qname
other -> error ("unhandled fromTerm" ++ show other)

fromQName :: QName -> CompiledDef
fromQName :: QName -> RustName
fromQName x = prettyShow (qnameName x)

fromDeBruijnPattern :: DeBruijnPattern -> CompiledDef
fromDeBruijnPattern :: DeBruijnPattern -> RustName
fromDeBruijnPattern = \case
VarP a b -> (dbPatVarName b)
a@(ConP x y z) -> show a
other -> error ("unhandled fromDeBruijnPattern" ++ show other)

-- TODO this is wrong for function other than identity
-- see asFriday in Hello.agda vs Hello.rs
compileFunctionBody :: Maybe CompiledClauses -> CompiledDef
compileFunctionBody (Just funDef) = "return" <> exprSeparator <> fromCompiledClauses funDef
compileFunctionBody :: Maybe CompiledClauses -> FunBody
compileFunctionBody (Just funDef) = fromCompiledClauses funDef
compileFunctionBody funDef = error ("unhandled compileFunctionBody " ++ show funDef)

fromCompiledClauses :: CompiledClauses -> CompiledDef
fromCompiledClauses :: CompiledClauses -> FunBody
fromCompiledClauses = \case
(Done (x:xs) term) -> fromArgName x
other -> error ("unhandled fromCompiledClauses " ++ show other)

fromArgName :: Arg ArgName -> CompiledDef
fromArgName :: Arg ArgName -> FunBody
fromArgName = unArg

showName :: QName -> CompiledDef
showName :: QName -> RustName
showName = prettyShow . qnameName

compileModule :: TopLevelModuleName -> [CompiledDef] -> String
compileModule :: TopLevelModuleName -> [CompiledDef] -> CompiledDef
compileModule mName cdefs =
moduleHeader (moduleName mName)
<> bracket (combineLines (map prettyShow cdefs))
<> defsSeparator
TeMod (moduleName mName) cdefs

moduleName :: TopLevelModuleName -> String
moduleName n = prettyShow (Nel.last (moduleNameParts n))

moduleHeader :: String -> String
moduleHeader mName = "mod" <> exprSeparator <> mName <> exprSeparator
9 changes: 5 additions & 4 deletions src/Agda/Compiler/Rust/Backend.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module Agda.Compiler.Rust.Backend (
backend,
defaultOptions ) where

import Control.Monad ( unless )
import Control.Monad ( when )
import Control.Monad.IO.Class ( MonadIO(liftIO) )
import Control.DeepSeq ( NFData(..) )
import Data.Maybe ( fromMaybe )
Expand All @@ -23,7 +23,8 @@ import Agda.TypeChecking.Monad (
setScope )

import Agda.Compiler.Rust.CommonTypes ( Options(..), CompiledDef, ModuleEnv )
import Agda.Compiler.Rust.ToRustCompiler ( compile, compileModule )
import Agda.Compiler.Rust.AgdaToRustExpr ( compile, compileModule )
import Agda.Compiler.Rust.PrettyPrintRustExpr ( prettyPrintRustExpr )

runRustBackend :: IO ()
runRustBackend = runAgda [Backend backend]
Expand Down Expand Up @@ -79,9 +80,9 @@ writeModule :: Options
writeModule opts _ _ mName cdefs = do
outDir <- compileDir
compileLog $ "compiling " <> fileName
unless (all null cdefs) $ liftIO
when (null cdefs) $ liftIO
$ writeFile (outFile outDir)
$ compileModule mName cdefs
$ prettyPrintRustExpr (compileModule mName cdefs)
where
fileName = rustFileName mName
dirName outDir = fromMaybe outDir (optOutDir opts)
Expand Down
4 changes: 3 additions & 1 deletion src/Agda/Compiler/Rust/CommonTypes.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ module Agda.Compiler.Rust.CommonTypes (
CompiledDef,
ModuleEnv ) where

import Agda.Compiler.Rust.RustExpr ( RustExpr )

data Options = Options { optOutDir :: Maybe FilePath }

type CompiledDef = String
type CompiledDef = RustExpr

type ModuleEnv = ()
63 changes: 63 additions & 0 deletions src/Agda/Compiler/Rust/PrettyPrintRustExpr.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
module Agda.Compiler.Rust.PrettyPrintRustExpr ( prettyPrintRustExpr, moduleHeader ) where

import Data.List ( intersperse )
import Agda.Compiler.Rust.CommonTypes ( CompiledDef )
import Agda.Compiler.Rust.RustExpr ( RustExpr(..), RustElem(..), FunBody )

prettyPrintRustExpr :: CompiledDef -> String
prettyPrintRustExpr def = case def of
(TeEnum name fields) ->
"enum" <> exprSeparator
<> name
<> exprSeparator
<> bracket (
indent -- TODO this to siplistic indentation
<> concat (intersperse ", " fields))
(TeFun fName (RustElem aName aType) resType fBody) ->
"pub fn" <> exprSeparator
<> fName
<> argList (
aName
<> typeSeparator <> exprSeparator
<> aType )
<> exprSeparator <> funReturnTypeSeparator <> exprSeparator <> resType
<> exprSeparator <> bracket (
-- TODO proper indentation for every line of function body
-- including nested expressions
indent
<> (prettyPrintFunctionBody fBody))
<> defsSeparator
(TeMod mName defs) ->
moduleHeader mName
<> bracket (combineLines (map prettyPrintRustExpr defs))
<> defsSeparator

bracket :: String -> String
bracket str = "{\n" <> str <> "\n}"

argList :: String -> String
argList str = "(" <> str <> ")"

indent :: String
indent = " "

exprSeparator :: String
exprSeparator = " "

defsSeparator :: String
defsSeparator = "\n"

typeSeparator :: String
typeSeparator = ":"

funReturnTypeSeparator :: String
funReturnTypeSeparator = "->"

combineLines :: [String] -> String
combineLines xs = unlines (filter (not . null) xs)

prettyPrintFunctionBody :: FunBody -> String
prettyPrintFunctionBody fBody = "return" <> exprSeparator <> fBody <> ";"

moduleHeader :: String -> String
moduleHeader mName = "mod" <> exprSeparator <> mName <> exprSeparator
34 changes: 0 additions & 34 deletions src/Agda/Compiler/Rust/PrettyPrintingUtils.hs

This file was deleted.

21 changes: 21 additions & 0 deletions src/Agda/Compiler/Rust/RustExpr.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module Agda.Compiler.Rust.RustExpr (
RustName,
RustType,
RustExpr(..),
RustElem(..),
FunBody
) where

type RustName = String
type RustType = String
type FunBody = String

data RustElem = RustElem RustName RustType
deriving ( Show )

data RustExpr
= TeMod RustName [RustExpr]
| TeEnum RustName [RustName]
| TeFun RustName RustElem RustType FunBody
| Unhandled RustName String
deriving ( Show )
2 changes: 1 addition & 1 deletion test/RustBackendTest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import Test.HUnit (
, runTestTT)
import System.Exit ( exitFailure , exitSuccess )
import Agda.Compiler.Rust.Backend ( backend, defaultOptions )
import Agda.Compiler.Rust.ToRustCompiler ( moduleHeader )
import Agda.Compiler.Rust.PrettyPrintRustExpr ( moduleHeader )

import Agda.Compiler.Backend ( isEnabled )

Expand Down

0 comments on commit 2f56f9a

Please sign in to comment.