(Japanese version of this document: docs/internal.ja.md)
This document describes the internal implementation of Jikka.
Internally, Jikka performs the following steps in order:
- Read Python code
- Parse the Python code to get a Python AST
- Convert the Python AST to an AST of our restricted Python
- Preprocess the AST of our restricted Python
- Convert the AST of our restricted Python to an AST of our core language
- Optimize the AST of our core language
- Convert the AST of our core language to a C++ AST
- Postprocess the C++ AST
- Convert the C++ AST to C++ code
- Write the C++ code
That is, Jikka translates a program through four languages in this order: (standard) Python, our restricted Python, our core language, and C++. Here, our restricted Python is the language specified in docs/language.md. Our core language is similar to Haskell and is almost the same as GHC Core, the intermediate language of GHC (the Haskell compiler). The core language is described in docs/core.md.
- List of Jikka's modules
- File: src/Jikka/Main/Subcommand/Convert.hs (Jikka.Main.Subcommand.Convert)
After reading the Python source as a string, Jikka parses it according to Python's grammar specification. We use lex (its Haskell counterpart, alex) and yacc (its Haskell counterpart, happy) to generate an LALR(1) parser.
- File (lex): src/Jikka/Python/Parse/Alex.x (Jikka.Python.Parse.Alex)
- File (yacc): src/Jikka/Python/Parse/Happy.y (Jikka.Python.Parse.Happy)
- Reference: Modern Compiler Implementation in ML
For example, consider the following Python code:
def f(a, b) -> int:
    return a + b
From this code, we get the following syntax tree.
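You can obtain it by running the command $ python3 -c 'import ast; print(ast.dump(ast.parse("def f(a, b) -> int: return a + b")))' (the output below has been indented for readability, and the exact fields printed vary slightly between Python versions):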
Module(
    body=[
        FunctionDef(
            name='f',
            args=arguments(
                posonlyargs=[],
                args=[
                    arg(arg='a', annotation=None, type_comment=None),
                    arg(arg='b', annotation=None, type_comment=None)
                ],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=[],
                kwarg=None,
                defaults=[]
            ),
            body=[
                Return(
                    value=BinOp(
                        left=Name(id='a', ctx=Load()),
                        op=Add(),
                        right=Name(id='b', ctx=Load()))
                )
            ],
            decorator_list=[],
            returns=Name(id='int', ctx=Load()),
            type_comment=None
        )
    ],
    type_ignores=[])
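You can also inspect the same tree programmatically. The short check below uses only CPython's standard ast module and is not Jikka code. It confirms the parts of the dump that matter later: the parameter names, the return annotation, and the BinOp in the body.

```python
import ast

# Parse the example with CPython's own parser (Jikka uses its own alex/happy parser).
tree = ast.parse("def f(a, b) -> int: return a + b")

func = tree.body[0]
assert isinstance(func, ast.FunctionDef) and func.name == "f"

# Parameter names and the return annotation, i.e. Name(id='int').
params = [arg.arg for arg in func.args.args]
returns = func.returns.id

# The only statement in the body is `return a + b`: a Return wrapping a BinOp.
ret = func.body[0]
binop = ret.value
operands = (binop.left.id, type(binop.op).__name__, binop.right.id)

print(params, returns, operands)  # ['a', 'b'] int ('a', 'Add', 'b')

# On Python 3.9+, ast.dump(tree, indent=4) prints an indented tree like the one above.
```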
After parsing, Jikka holds a complete AST equivalent to the one provided by Python's ast module (data Expr).
It then removes the parts it does not need and converts the result into a more convenient AST for our restricted Python (data Expr).
Jikka performs the following preprocessing on the AST of our restricted Python:
- Checking and renaming variable names
- Type inference
- Other miscellaneous checks
Type inference uses the Hindley-Milner algorithm, which reconstructs types by collecting equations over type variables and solving them.
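The equation-solving part of this algorithm fits in a few dozen lines of Python. This is an illustration, not Jikka's TypeInfer module. TVar, TCon, unify and the other names are made up for the sketch. Let-polymorphism (generalization) is omitted, so only the unification step of Hindley-Milner is shown. The equations are the ones one would collect for def f(a, b) -> int: return a + b, assuming + has type int -> int -> int.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TVar:
    """An unknown type (type variable)."""
    name: str


@dataclass(frozen=True)
class TCon:
    """A type constructor applied to arguments, e.g. int or (->) a b."""
    name: str
    args: tuple = ()


def fun(a, b):
    """The function type a -> b."""
    return TCon("->", (a, b))


INT = TCon("int")


def resolve(t, subst):
    # Follow bindings in the substitution until t is unbound or a constructor.
    while isinstance(t, TVar) and t in subst:
        t = subst[t]
    return t


def occurs(v, t, subst):
    # Occurs check: binding v to a type that contains v would be an infinite type.
    t = resolve(t, subst)
    if t == v:
        return True
    return isinstance(t, TCon) and any(occurs(v, a, subst) for a in t.args)


def unify(t1, t2, subst):
    """Solve the equation t1 = t2, extending subst in place."""
    t1, t2 = resolve(t1, subst), resolve(t2, subst)
    if t1 == t2:
        return
    if isinstance(t1, TVar):
        if occurs(t1, t2, subst):
            raise TypeError(f"infinite type: {t1} = {t2}")
        subst[t1] = t2
    elif isinstance(t2, TVar):
        unify(t2, t1, subst)
    elif t1.name == t2.name and len(t1.args) == len(t2.args):
        for a, b in zip(t1.args, t2.args):
            unify(a, b, subst)
    else:
        raise TypeError(f"cannot unify {t1} with {t2}")


def apply(t, subst):
    """Replace every solved type variable in t by its solution."""
    t = resolve(t, subst)
    if isinstance(t, TCon):
        return TCon(t.name, tuple(apply(a, subst) for a in t.args))
    return t


# Unknown types of a, b, the expression a + b, and f's return value.
ta, tb, tsum, tret = TVar("ta"), TVar("tb"), TVar("tsum"), TVar("tret")
equations = [
    (fun(ta, fun(tb, tsum)), fun(INT, fun(INT, INT))),  # a + b, with (+) : int -> int -> int
    (tret, tsum),                                       # return a + b
    (tret, INT),                                        # the "-> int" annotation
]
subst = {}
for lhs, rhs in equations:
    unify(lhs, rhs, subst)

f_type = apply(fun(ta, fun(tb, tret)), subst)
print(f_type == fun(INT, fun(INT, INT)))  # True
```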
- File: src/Jikka/RestrictedPython/Convert.hs (Jikka.RestrictedPython.Convert)
- File (checking and renaming variable names): src/Jikka/RestrictedPython/Convert/Alpha.hs (Jikka.RestrictedPython.Convert.Alpha)
- File (type inference): src/Jikka/RestrictedPython/Convert/TypeInfer.hs (Jikka.RestrictedPython.Convert.TypeInfer)
- Reference: Types and Programming Languages
Next, Jikka converts the AST of our restricted Python into an AST of the core language.
For example, Python has assignment statements and for loops, whereas the core language (like Haskell) has neither.
Therefore, all assignment statements are converted to let expressions, and for loops are converted to foldl.
For example, consider an AST from the following Python code:
def solve(n: int) -> int:
    a = 0
    b = 1
    for _ in range(n):
        c = a + b
        a = b
        b = c
    return a
This becomes an AST which corresponds to the following Haskell code:
solve :: Int -> Int
solve n =
    let a0 = 0
    in let b0 = 1
    in let (a3, b3) =
            foldl (\(a1, b1) _ ->
                let c = a1 + b1
                in let a2 = b1
                in let b2 = c
                in (a2, b2)
            ) (a0, b0) [0..n - 1]
    in a3
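You can check this correspondence in Python itself. In the sketch below, functools.reduce plays the role of foldl (its argument order is function, iterable, initial value). The loop state (a, b) is threaded through as a tuple, and the variable names follow the Haskell code above.

```python
from functools import reduce


def solve_imperative(n: int) -> int:
    a = 0
    b = 1
    for _ in range(n):
        c = a + b
        a = b
        b = c
    return a


def solve_foldl(n: int) -> int:
    a0, b0 = 0, 1

    # The loop body becomes a step function from the old state (a1, b1)
    # to the new state (a2, b2); the loop counter `_` is unused.
    def step(state, _):
        a1, b1 = state
        c = a1 + b1
        a2 = b1
        b2 = c
        return (a2, b2)

    # foldl step (a0, b0) [0..n - 1]
    a3, b3 = reduce(step, range(n), (a0, b0))
    return a3


print([solve_foldl(n) for n in range(10)])  # [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
```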
This step is where Jikka does most of its optimization. Jikka tries every optimization we can think of, and most of them are implemented as rewrite rules.
At the moment, optimization is greedy: Jikka looks for sub-expressions that match a rewrite rule and always applies the rule. In other words, Jikka doesn't perform any search such as DFS or beam search. Such more advanced optimization is future work.
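Here is a toy version of this greedy strategy, sketched in Python. It works on a made-up expression type (nested tuples) with three hypothetical arithmetic rules rather than Jikka's actual rules. Each rule returns a new term or None, much like a Jikka RewriteRule returns Nothing when it does not apply. The driver keeps rewriting until no rule fires. Like Jikka's current approach, it never backtracks or searches over alternative rewrite orders. It terminates here only because every rule makes the term smaller.

```python
from typing import Callable, List, Optional, Tuple

# A term is ("lit", int), ("var", str), ("add", t1, t2) or ("mul", t1, t2).
Term = tuple
Rule = Callable[[Term], Optional[Term]]


def add_zero(t: Term) -> Optional[Term]:
    """x + 0 -> x and 0 + x -> x"""
    if t[0] == "add":
        if t[2] == ("lit", 0):
            return t[1]
        if t[1] == ("lit", 0):
            return t[2]
    return None


def mul_one(t: Term) -> Optional[Term]:
    """x * 1 -> x and 1 * x -> x"""
    if t[0] == "mul":
        if t[2] == ("lit", 1):
            return t[1]
        if t[1] == ("lit", 1):
            return t[2]
    return None


def fold_constants(t: Term) -> Optional[Term]:
    """lit a + lit b -> lit (a + b), and likewise for *"""
    if t[0] in ("add", "mul") and t[1][0] == "lit" and t[2][0] == "lit":
        a, b = t[1][1], t[2][1]
        return ("lit", a + b if t[0] == "add" else a * b)
    return None


RULES: List[Rule] = [add_zero, mul_one, fold_constants]


def rewrite_once(t: Term, rules: List[Rule]) -> Tuple[Term, bool]:
    """One bottom-up pass: rewrite the children, then apply the first rule
    that matches at this node. Returns the new term and whether anything fired."""
    changed = False
    if t[0] in ("add", "mul"):
        left, c1 = rewrite_once(t[1], rules)
        right, c2 = rewrite_once(t[2], rules)
        t = (t[0], left, right)
        changed = c1 or c2
    for rule in rules:
        result = rule(t)
        if result is not None:
            return result, True
    return t, changed


def simplify(t: Term, rules: List[Rule] = RULES) -> Term:
    """Greedy rewriting: keep applying matching rules until none fires."""
    changed = True
    while changed:
        t, changed = rewrite_once(t, rules)
    return t


# (x * (0 + 1)) + 0
expr = ("add", ("mul", ("var", "x"), ("add", ("lit", 0), ("lit", 1))), ("lit", 0))
print(simplify(expr))  # ('var', 'x')
```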
- File: src/Jikka/Core/Convert.hs (Jikka.Core.Convert)
- Directory: src/Jikka/Core/Convert/
For example, consider the following O(N²) Python code:
def solve(n: int, a: List[int]) -> int:
    b = 0
    for i in range(n):
        b += sum(a[:i])
    return b
Before the optimization step, this Python code has already been converted to the following Haskell code:
solve :: Int -> [Int] -> Int
solve n a =
    foldl (\b i ->
        b + sum (map (\j -> a !! j) [0..i - 1])
    ) 0 [0..n - 1]
First, a rewrite rule for cumulative sums, "replace a sub-expression of the form sum (map (\i -> xs !! i) [0..k - 1]) with the expression let ys = scanl (+) 0 xs in ys !! k", fires, and the code above becomes the following code using scanl:
solve :: Int -> [Int] -> Int
solve n a =
    foldl (\b i ->
        let c = scanl (+) 0 a
        in b + c !! i
    ) 0 [0..n - 1]
Then a rewrite rule for foldl and let expressions, "if the variables y and x are not used in an expression c, and the variable a is not used in the expressions y0 and xs, then replace a sub-expression foldl (\y x -> let a = c in e) y0 xs with the expression let a = c in foldl (\y x -> e) y0 xs", fires. The code becomes the following:
solve :: Int -> [Int] -> Int
solve n a =
    let c = scanl (+) 0 a
    in foldl (\b i ->
        b + c !! i
    ) 0 [0..n - 1]
Through the remaining steps, this resulting code becomes the following C++ code, which runs in O(N):
int solve(int n, vector<int> a) {
    vector<int> c;
    c.push_back(0);
    for (int i = 0; i < a.size(); ++ i) {
        c.push_back(c[i] + a[i]);
    }
    int b = 0;
    for (int i = 0; i < n; ++ i) {
        b += c[i];
    }
    return b;
}
- File: src/Jikka/Core/Convert/CumulativeSum.hs (Jikka.Core.Convert.CumulativeSum)
- File: src/Jikka/Core/Convert/BubbleLet.hs (Jikka.Core.Convert.BubbleLet)
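You can check the effect of these two rewrites in plain Python. This sketch is not generated by Jikka. Here itertools.accumulate with initial=0 (Python 3.8+) plays the role of scanl (+) 0, so c[k] == sum(a[:k]). The list c is computed once outside the loop, as after the let has been moved out of foldl. The sketch assumes n <= len(a).

```python
from itertools import accumulate


def solve_naive(n, a):
    # O(N^2): re-sums the prefix a[:i] on every iteration.
    b = 0
    for i in range(n):
        b += sum(a[:i])
    return b


def solve_optimized(n, a):
    # c = scanl (+) 0 a, hoisted out of the loop: c[k] == sum(a[:k]).
    c = list(accumulate(a, initial=0))
    b = 0
    for i in range(n):
        b += c[i]
    return b


print(solve_optimized(5, [3, 1, 4, 1, 5]))  # 24
```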
Let's look at the module Jikka.Core.Convert.ShortCutFusion, which implements short cut fusion.
For example, the rewrite rule reduceFoldBuild is defined as follows in v5.1.0.0:
reduceFoldBuild :: MonadAlpha m => RewriteRule m
reduceFoldBuild =
  let return' = return . Just
   in RewriteRule $ \_ -> \case
        -- reduce `Foldl`
        Foldl' _ _ _ init (Nil' _) -> return' init
        Foldl' t1 t2 g init (Cons' _ x xs) -> return' $ Foldl' t1 t2 g (App2 g init x) xs
        -- reduce `Len`
        Len' _ (Nil' _) -> return' Lit0
        Len' t (Cons' _ _ xs) -> return' $ Plus' Lit1 (Len' t xs)
        Len' _ (Range1' n) -> return' n
        -- reduce `At`
        At' t (Nil' _) i -> return' $ Bottom' t $ "cannot subscript empty list: index = " ++ formatExpr i
        At' t (Cons' _ x xs) i -> return' $ If' t (Equal' IntTy i Lit0) x (At' t xs (Minus' i Lit1))
        At' _ (Range1' _) i -> return' i
        -- reduce `Elem`
        Elem' _ _ (Nil' _) -> return' LitFalse
        Elem' t y (Cons' _ x xs) -> return' $ And' (Equal' t x y) (Elem' t y xs)
        Elem' _ x (Range1' n) -> return' $ And' (LessEqual' IntTy Lit0 x) (LessThan' IntTy x n)
        -- others
        Len' t (Build' _ _ base n) -> return' $ Plus' (Len' t base) n
        _ -> return Nothing
For example, the line Len' _ (Nil' _) -> return' Lit0 represents a rewrite rule that replaces a sub-expression length [] with the expression 0.
The line Len' t (Cons' _ _ xs) -> return' $ Plus' Lit1 (Len' t xs) represents a rewrite rule that replaces a sub-expression length (cons x xs) with the expression 1 + length xs.
In v5.2.0.0, the rewrite rules in this module are rewritten with Template Haskell, a macro feature of Haskell (GHC).
Their content remains the same; for example, the rule reduceFoldMap is written as follows:
reduceFoldMap :: MonadAlpha m => RewriteRule m
reduceFoldMap =
  mconcat
    [ -- reduce `Reversed`
      [r| "len/reversed" forall xs. len (reversed xs) = len xs |],
      [r| "elem/reversed" forall x xs. elem x (reversed xs) = elem x xs |],
      [r| "at/reversed" forall xs i. (reversed xs)[i] = xs[len(xs) - i - 1] |],
      -- reduce `Sorted`
      [r| "len/sorted" forall xs. len (sorted xs) = len xs |],
      [r| "elem/sorted" forall x xs. elem x (sorted xs) = elem x xs |],
      -- reduce `Map`
      [r| "len/map" forall f xs. len (map f xs) = len xs |],
      [r| "at/map" forall f xs i. (map f xs)[i] = f xs[i] |],
      [r| "foldl/map" forall g init f xs. foldl g init (map f xs) = foldl (fun y x -> g y (f x)) init xs |],
      -- others
      [r| "len/setat" forall xs i x. len xs[i <- x] = len xs |],
      [r| "len/scanl" forall f init xs. len (scanl f init xs) = len xs + 1 |],
      [r| "at/setat" forall xs i x j. xs[i <- x][j] = if i == j then x else xs[j] |]
    ]
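Each of these rules is an equation that should hold for every input. A cheap sanity check (evidence, not a proof) is to evaluate both sides on sample values. The Python sketch below does this for several of the rules. It uses Python's built-in len, reversed, sorted and map, plus the helpers foldl, setat and scanl written for this sketch (hypothetical names standing in for the core-language builtins).

```python
from functools import reduce


def foldl(g, init, xs):
    """foldl g init xs; functools.reduce takes (function, iterable, initial)."""
    return reduce(g, xs, init)


def setat(xs, i, x):
    """xs[i <- x]: a copy of xs with index i replaced; xs itself is unchanged."""
    ys = list(xs)
    ys[i] = x
    return ys


def scanl(f, init, xs):
    """scanl f init xs: every intermediate accumulator, starting with init."""
    out = [init]
    for x in xs:
        out.append(f(out[-1], x))
    return out


xs = [3, 1, 4, 1, 5]
n = len(xs)
f = lambda x: 2 * x + 1
g = lambda y, x: 10 * y + x

# (rule name, left-hand side, right-hand side), evaluated on the sample values above.
checks = [
    ("len/reversed", len(list(reversed(xs))), len(xs)),
    ("at/reversed", [list(reversed(xs))[i] for i in range(n)], [xs[n - i - 1] for i in range(n)]),
    ("elem/sorted", [v in sorted(xs) for v in range(7)], [v in xs for v in range(7)]),
    ("len/map", len(list(map(f, xs))), len(xs)),
    ("at/map", [list(map(f, xs))[i] for i in range(n)], [f(xs[i]) for i in range(n)]),
    ("foldl/map", foldl(g, 0, map(f, xs)), foldl(lambda y, x: g(y, f(x)), 0, xs)),
    ("len/setat", len(setat(xs, 2, 9)), len(xs)),
    ("len/scanl", len(scanl(lambda y, x: y + x, 0, xs)), len(xs) + 1),
    ("at/setat", [setat(xs, 2, 9)[j] for j in range(n)], [9 if 2 == j else xs[j] for j in range(n)]),
]
failed = [name for name, lhs, rhs in checks if lhs != rhs]
print(failed)  # []
```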
As an example of an optimization involving data structures, let's look at the implementation for segment trees.
The module Jikka.Core.Convert.SegmentTree has a function reduceCumulativeSum.
This function introduces segment trees when cumulative sums are used inside a foldl loop but the array they are taken over is updated in the same loop, so the cumulative sum cannot be moved out of the loop. Consider, for example, the following code:
def solve(n: int, a: List[int], q: int, l: List[int], r: List[int]) -> List[int]:
    for i in range(q):
        # a[l[i]] = sum(a[:r[i]])
        b = [0]
        for j in range(n):
            b.append(b[j] + a[j])
        a[l[i]] = b[r[i]]
    return a
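The payoff of this rewrite can be illustrated in Python. This sketch is not Jikka's generated code, and the SegmentTree class and its method names are hypothetical. It compares the loop above with a version that keeps a sum segment tree next to the array. Each iteration then does one O(log N) range query and one O(log N) point update instead of rebuilding the prefix sums in O(N). The sketch assumes n == len(a). Unlike the semigroup-based builtins described below, it also needs an identity element (0 for sums).

```python
from typing import Callable, List


class SegmentTree:
    """Point update / range fold over an associative `op` with identity `unit`."""

    def __init__(self, values: List[int], op: Callable[[int, int], int], unit: int):
        self.n = len(values)
        self.op = op
        self.unit = unit
        # Leaves live at tree[n:2n]; internal node k combines children 2k and 2k+1.
        self.tree = [unit] * self.n + list(values)
        for k in range(self.n - 1, 0, -1):
            self.tree[k] = op(self.tree[2 * k], self.tree[2 * k + 1])

    def set_point(self, i: int, value: int) -> None:
        """values[i] = value, then recompute the ancestors of leaf i: O(log n)."""
        k = i + self.n
        self.tree[k] = value
        k //= 2
        while k >= 1:
            self.tree[k] = self.op(self.tree[2 * k], self.tree[2 * k + 1])
            k //= 2

    def get_range(self, l: int, r: int) -> int:
        """Fold of values[l:r] (half-open) in O(log n)."""
        left, right = self.unit, self.unit
        l += self.n
        r += self.n
        while l < r:
            if l & 1:
                left = self.op(left, self.tree[l])
                l += 1
            if r & 1:
                r -= 1
                right = self.op(self.tree[r], right)
            l //= 2
            r //= 2
        return self.op(left, right)


def solve_naive(n, a, q, l, r):
    """The loop above: rebuilds the prefix sums b for every query, O(NQ)."""
    a = list(a)  # work on a copy instead of mutating the caller's list
    for i in range(q):
        b = [0]
        for j in range(n):
            b.append(b[j] + a[j])
        a[l[i]] = b[r[i]]
    return a


def solve_segtree(n, a, q, l, r):
    """Same result with a sum segment tree kept in sync with a, O(N + Q log N)."""
    a = list(a)
    seg = SegmentTree(a, lambda x, y: x + y, 0)
    for i in range(q):
        value = seg.get_range(0, r[i])  # == b[r[i]] == sum(a[:r[i]])
        a[l[i]] = value
        seg.set_point(l[i], value)      # update the tree alongside the array
    return a


print(solve_segtree(5, [3, 1, 4, 1, 5], 3, [0, 4, 2], [3, 5, 1]))  # [8, 1, 8, 1, 19]
```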
The function reduceCumulativeSum is implemented as follows in v5.1.0.0:
-- | `reduceCumulativeSum` converts combinations of cumulative sums and array assignments to segment trees.
reduceCumulativeSum :: (MonadAlpha m, MonadError Error m) => RewriteRule m
reduceCumulativeSum = RewriteRule $ \_ -> \case
  -- foldl (fun a i -> setat a index(i) e(a, i)) base indices
  Foldl' t1 t2 (Lam2 a _ i _ (SetAt' t (Var a') index e)) base indices | a' == a && a `isUnusedVar` index -> runMaybeT $ do
    let sums = listCumulativeSum (Var a) e -- (A)
    guard $ not (null sums)
    let semigrps = nub (sort (map fst sums))
    let ts = t2 : map SegmentTreeTy semigrps
    c <- lift $ genVarName a
    let proj i = Proj' ts i (Var c)
    let e' = replaceWithSegtrees a (zip semigrps (map proj [1 ..])) e -- (B)
    guard $ e' /= e
    e' <- lift $ substitute a (proj 0) e'
    b' <- lift $ genVarName a
    let updateSegtrees i semigrp = SegmentTreeSetPoint' semigrp (proj i) index (At' t (Var b') index) -- (C)
    let step = Lam2 c (TupleTy ts) i t1 (Let b' t2 (SetAt' t (proj 0) index e') (uncurryApp (Tuple' ts) (Var b' : zipWith updateSegtrees [1 ..] semigrps))) -- (D)
    b <- lift $ genVarName a
    let base' = Var b : map (\semigrp -> SegmentTreeInitList' semigrp (Var b)) semigrps -- (E)
    return $ Let b t2 base (Proj' ts 0 (Foldl' t1 (TupleTy ts) step (uncurryApp (Tuple' ts) base') indices)) -- (F)
  _ -> return Nothing
The function reduceCumulativeSum looks for expressions of the form foldl (\a i -> setat a index(i) e(a, i)) base indices, made of the following entities:
- a type t
- an expression base (of type [t])
- an expression indices (of type [Int])
- a variable a (of type [t])
- a variable i (of type Int)
- the builtin function setat (of type [t] -> Int -> t -> [t])
- an expression index(i) (may contain the variable i but not the variable a; its type is Int)
- an expression e(a, i) (may contain the variables a and i; its type is t)
Having found such an expression, reduceCumulativeSum calls listCumulativeSum at (A) to list the places where cumulative sums are used in e(a, i).
It then collects the corresponding semigroups and calls replaceWithSegtrees at (B) to replace the cumulative sums in e(a, i) with expressions that use segment trees.
It builds an expression that updates the segment trees at (C), and the body of the function passed to foldl at (D).
Finally, it builds the initial state base', consisting of the array and its segment trees, at (E), and returns the resulting expression at (F).
To use segment trees here, the core language has data-structure types and builtin functions such as SegmentTreeInitList, SegmentTreeGetRange, and SegmentTreeSetPoint.
For example, the builtin function SegmentTreeSetPoint has the type segment-tree(S) → int → S → segment-tree(S) for each semigroup S.
Similarly, the C++ AST that the core language is translated into has types and builtin functions for segment trees.
- File: src/Jikka/Core/Convert/SegmentTree.hs (Jikka.Core.Convert.SegmentTree)
- File: src/Jikka/Core/Language/Expr.hs (Jikka.Core.Language.Expr)
- File: src/Jikka/CPlusPlus/Language/Expr.hs (Jikka.CPlusPlus.Language.Expr)
After optimizations, Jikka converts the AST of the core language to a C++ AST.
For example, consider the following code:
solve :: Int -> Int
solve n =
    let a0 = 0
    in let b0 = 1
    in let (a3, b3) =
            foldl (\(a1, b1) _ ->
                let c = a1 + b1
                in let a2 = b1
                in let b2 = c
                in (a2, b2)
            ) (a0, b0) [0..n - 1]
    in a3
This is converted to the following C++ code:
int solve(int n) {
    int a0 = 0;
    int b0 = 1;
    pair<int, int> x = make_pair(a0, b0);
    for (int i = 0; i < n; ++ i) {
        auto [a1, b1] = x;
        int c = a1 + b1;
        int a2 = b1;
        int b2 = c;
        x = make_pair(a2, b2);
    }
    auto [a3, b3] = x;
    return a3;
}
Jikka then postprocesses the C++ AST to eliminate inefficiencies introduced by the conversion from the core language.
Mainly, it converts unnecessary copies into moves.
It also inserts the necessary #include statements.
- File: src/Jikka/CPlusPlus/Convert.hs (Jikka.CPlusPlus.Convert)
- File (conversion from copy to move): src/Jikka/CPlusPlus/Convert/MoveSemantics.hs (Jikka.CPlusPlus.Convert.MoveSemantics)
A C++ AST just converted from the core language looks like the following C++ code:
int solve(int n) {
    int a0 = 0;
    int b0 = 1;
    pair<int, int> x = make_pair(a0, b0);
    for (int i = 0; i < n; ++ i) {
        auto [a1, b1] = x;
        int c = a1 + b1;
        int a2 = b1;
        int b2 = c;
        x = make_pair(a2, b2);
    }
    auto [a3, b3] = x;
    return a3;
}
This will be converted into an AST that corresponds to the following C++ code:
int solve(int n) {
    int a = 0;
    int b = 1;
    for (int i = 0; i < n; ++ i) {
        int c = a + b;
        a = b;
        b = c;
    }
    return a;
}
Finally, Jikka converts the C++ AST to C++ code.
Parentheses are inserted with the precedence-value method, as in Haskell's Text.Show (showsPrec).
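Here is a minimal sketch of the idea in Python, on a toy expression type rather than Jikka's C++ AST (Num, BinOp, PRECEDENCE and show are hypothetical). As with showsPrec, each sub-expression is printed knowing the precedence of its context. It adds parentheses only when it binds more loosely than that context. The precedence values 6 and 7 follow Haskell's infixl 6 + and infixl 7 *; only their relative order matters.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Num:
    value: int


@dataclass(frozen=True)
class BinOp:
    op: str
    left: "Expr"
    right: "Expr"


Expr = Union[Num, BinOp]

# Higher numbers bind tighter; every operator here is left-associative, as in C++.
PRECEDENCE = {"+": 6, "-": 6, "*": 7, "/": 7}


def show(e: Expr, context: int = 0) -> str:
    """Render e inside a context of precedence `context`, adding parentheses
    only when e binds more loosely than the context requires."""
    if isinstance(e, Num):
        return str(e.value)
    prec = PRECEDENCE[e.op]
    # Left operand: the same precedence is fine (left-associative).
    # Right operand: needs strictly higher precedence, so 1 - (2 - 3) keeps its parentheses.
    text = f"{show(e.left, prec)} {e.op} {show(e.right, prec + 1)}"
    return f"({text})" if prec < context else text


print(show(BinOp("*", BinOp("+", Num(1), Num(2)), Num(3))))  # (1 + 2) * 3
print(show(BinOp("-", BinOp("-", Num(1), Num(2)), Num(3))))  # 1 - 2 - 3
```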