From fd477bc29335e0b288d0dfce4dc56933018b08d2 Mon Sep 17 00:00:00 2001 From: Seunghun Lee Date: Mon, 14 Aug 2017 17:57:41 +0900 Subject: [PATCH] Defaulting None for optional fields --- src/Nirum/Targets/Python.hs | 90 +++++++++++++++++++++++------- test/nirum_fixture/fixture/foo.nrm | 13 ++++- test/python/primitive_test.py | 10 +++- 3 files changed, 88 insertions(+), 25 deletions(-) diff --git a/src/Nirum/Targets/Python.hs b/src/Nirum/Targets/Python.hs index 5a3b845..e624a46 100644 --- a/src/Nirum/Targets/Python.hs +++ b/src/Nirum/Targets/Python.hs @@ -320,16 +320,20 @@ toIndentedCodes f traversable concatenator = T.intercalate concatenator $ map f traversable compileParameters :: (ParameterName -> ParameterType -> Code) - -> [(T.Text, Code)] + -> [(T.Text, Code, Bool)] -> Code -compileParameters gen nameTypePairs = - toIndentedCodes (uncurry gen) nameTypePairs ", " +compileParameters gen nameTypeTriples = + toIndentedCodes + (\(n, t, o) -> gen n t `T.append` if o then "=None" else "") + nameTypeTriples ", " -compileFieldInitializers :: DS.DeclarationSet Field -> CodeGen Code -compileFieldInitializers fields = do +compileFieldInitializers :: DS.DeclarationSet Field -> Int -> CodeGen Code +compileFieldInitializers fields depth = do initializers <- forM (toList fields) compileFieldInitializer - return $ T.intercalate "\n " initializers + return $ T.intercalate indentSpaces initializers where + indentSpaces :: T.Text + indentSpaces = "\n" `T.append` T.replicate depth " " compileFieldInitializer :: Field -> CodeGen Code compileFieldInitializer (Field fieldName' fieldType' _) = case fieldType' of @@ -432,11 +436,16 @@ compileUnionTag :: Source -> Name -> Tag -> CodeGen Code compileUnionTag source parentname d@(Tag typename' fields _) = do typeExprCodes <- mapM (compileTypeExpression source) [typeExpr | (Field _ typeExpr _) <- toList fields] - let className = toClassName' typename' + let optionFlags = [ case typeExpr of + OptionModifier _ -> True + _ -> False + | (Field _ typeExpr _) <- toList fields + ] + className = toClassName' typename' tagNames = map (toAttributeName' . fieldName) (toList fields) - nameNTypes = zip tagNames typeExprCodes + nameTypeTriples = zip3 tagNames typeExprCodes optionFlags slotTypes = toIndentedCodes - (\ (n, t) -> [qq|('{n}', {t})|]) nameNTypes ",\n " + (\ (n, t, _) -> [qq|('{n}', {t})|]) nameTypeTriples ",\n " slots = if length tagNames == 1 then [qq|'{head tagNames}'|] `T.snoc` ',' else toIndentedCodes (\ n -> [qq|'{n}'|]) tagNames ",\n " @@ -458,7 +467,27 @@ compileUnionTag source parentname d@(Tag typename' fields _) = do typeRepr <- typeReprCompiler arg <- parameterCompiler ret <- returnCompiler - initializers <- compileFieldInitializers fields + pyVer <- getPythonVersion + initializers <- compileFieldInitializers fields $ case pyVer of + Python3 -> 2 + Python2 -> 3 + let initParams = compileParameters arg nameTypeTriples + inits = case pyVer of + Python2 -> [qq| + def __init__(self, **kwargs): + def __init__($initParams): + $initializers + pass + __init__(**kwargs) + validate_record_type(self) + |] + Python3 -> [qq| + def __init__(self{ if null nameTypeTriples + then T.empty + else ", *, " `T.append` initParams }) -> None: + $initializers + validate_union_type(self) + |] return [qq| class $className($parentClass): {compileDocstringWithFields " " d fields} @@ -474,9 +503,7 @@ class $className($parentClass): def __nirum_tag_types__(): return [$slotTypes] - def __init__(self, {compileParameters arg nameNTypes}){ ret "None" }: - $initializers - validate_union_type(self) + { inits :: T.Text } def __repr__(self){ ret "str" }: return '\{0\}(\{1\})'.format( @@ -669,12 +696,17 @@ compileTypeDeclaration src d@TypeDeclaration { typename = typename' fieldList = toList fields typeExprCodes <- mapM (compileTypeExpression src) [typeExpr | (Field _ typeExpr _) <- fieldList] - let fieldNames = map toAttributeName' [ name' + let optionFlags = [ case typeExpr of + OptionModifier _ -> True + _ -> False + | (Field _ typeExpr _) <- fieldList + ] + fieldNames = map toAttributeName' [ name' | (Field name' _ _) <- fieldList ] - nameTypePairs = zip fieldNames typeExprCodes + nameTypeTriples = zip3 fieldNames typeExprCodes optionFlags slotTypes = toIndentedCodes - (\ (n, t) -> [qq|'{n}': {t}|]) nameTypePairs ",\n " + (\ (n, t, _) -> [qq|'{n}': {t}|]) nameTypeTriples",\n " slots = toIndentedCodes (\ n -> [qq|'{n}'|]) fieldNames ",\n " nameMaps = toIndentedCodes toNamePair @@ -693,7 +725,27 @@ compileTypeDeclaration src d@TypeDeclaration { typename = typename' arg <- parameterCompiler ret <- returnCompiler typeRepr <- typeReprCompiler - initializers <- compileFieldInitializers fields + pyVer <- getPythonVersion + initializers <- compileFieldInitializers fields $ case pyVer of + Python3 -> 2 + Python2 -> 3 + let initParams = compileParameters arg nameTypeTriples + inits = case pyVer of + Python2 -> [qq| + def __init__(self, **kwargs): + def __init__($initParams): + $initializers + pass + __init__(**kwargs) + validate_record_type(self) + |] + Python3 -> [qq| + def __init__(self{ if null nameTypeTriples + then T.empty + else ", *, " `T.append` initParams }) -> None: + $initializers + validate_record_type(self) + |] let clsType = arg "cls" "type" return [qq| class $className(object): @@ -710,9 +762,7 @@ class $className(object): def __nirum_field_types__(): return \{$slotTypes\} - def __init__(self, {compileParameters arg nameTypePairs}){ret "None"}: - $initializers - validate_record_type(self) + {inits :: T.Text} def __repr__(self){ret "bool"}: return '\{0\}(\{1\})'.format( diff --git a/test/nirum_fixture/fixture/foo.nrm b/test/nirum_fixture/fixture/foo.nrm index e04cad4..2a8fb98 100644 --- a/test/nirum_fixture/fixture/foo.nrm +++ b/test/nirum_fixture/fixture/foo.nrm @@ -93,10 +93,17 @@ service ping-service ( ); record person ( - irum first-name, - irum last-name, + irum first-name, + irum last-name, ); record people ( - {person} people + {person} people +); + +record product ( + text name, + int64? price, + bool sale, + uri? url, ); diff --git a/test/python/primitive_test.py b/test/python/primitive_test.py index 87b5198..8bb2f8a 100644 --- a/test/python/primitive_test.py +++ b/test/python/primitive_test.py @@ -8,8 +8,8 @@ from fixture.foo import (CultureAgnosticName, EastAsianName, EvaChar, FloatUnbox, Gender, ImportedTypeUnbox, Irum, Line, MixedName, NullService, - Point1, Point2, Point3d, Pop, PingService, Rnb, - Run, Stop, Way, WesternName) + Point1, Point2, Point3d, Pop, PingService, Product, + Rnb, Run, Stop, Way, WesternName) from fixture.foo.bar import PathUnbox, IntUnbox, Point from fixture.qux import Path, Name @@ -249,3 +249,9 @@ def test_service(): PingService().ping(nonce=u'nonce') with raises(TypeError): PingService().ping(wrongkwd=u'a') + + +def test_optional_initializer_test(): + product = Product(name='coffee', sale=False) + assert product.price is None + assert product.url is None