Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use IOArray in interpreter for linear complexity #378

Merged
merged 14 commits into from
Dec 28, 2023
6 changes: 3 additions & 3 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ jobs:
run: |
./install_bazel.sh
./build.sh cpu
- name: Build tests
run: pack --no-prompt build test.ipkg
- name: Run tests
run: |
pack --no-prompt build test.ipkg
pack run test.ipkg
run: pack run test.ipkg
env:
LD_LIBRARY_PATH: $LD_LIBRARY_PATH:backend/bazel-bin
readme:
Expand Down
350 changes: 176 additions & 174 deletions src/Compiler/Eval.idr
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@ limitations under the License.
--}
module Compiler.Eval

import Control.Monad.Error.Either
import Control.Monad.Maybe
import Control.Monad.State
import Control.Monad.Either
import Control.Monad.Reader
import Control.Monad.Trans
import Data.IOArray
import Data.List
import Data.List.Elem
import Data.SortedMap
import Decidable.Equality

import Compiler.Expr
import Compiler.LiteralRW
Expand All @@ -48,186 +47,189 @@ import Primitive
import Types
import Util

%hide Util.List.All.map

export
data Err = IndexErr String
data Err = OutOfBounds Nat Nat
| ValueNotFound Nat

export
Show Err where
show (IndexErr msg) = "IndexErr: \{msg}"

0 Computation : Type -> Type
Computation = StateT (SortedMap Nat XlaOp) (EitherT Err IO)

||| Look up the `XlaOp` at `position` in the graph.
lookup : (position : Nat) -> Computation XlaOp
lookup n = do
cache <- get
case lookup n cache of
Nothing =>
lift $ left (IndexErr "Tried to look up value at index \{show n} but found keys \{show $ toList (keys cache)}")
Just op => pure op

interpret : XlaBuilder -> Nat -> Env -> Computation XlaOp

||| Build a computation from an inner function
|||
||| @xlaBuilder The enclosing XLA builder in which this function is built.
||| This is not the XLA builder used to build the computation itself.
||| @computationName The name of the computation.
||| @arity The function arity.
||| @f The function to build.
buildSub : (xlaBuilder : XlaBuilder) ->
(computationName : String) ->
(f : Fn arity) ->
Computation XlaComputation
buildSub builder name (MkFn params result env) = do
subBuilder <- createSubBuilder builder name
traverse_ (interpretParameter subBuilder) (enumerate params)
root <- assert_total $ interpret subBuilder result env
build subBuilder root
show (OutOfBounds idx size) = "Index \{show idx} is out of bounds for array of size \{show size}"
show (ValueNotFound idx) = "Value requested but not found at index \{show idx}"

where
public export 0
ErrIO : Type -> Type
ErrIO = EitherT Err IO

interpretParameter : XlaBuilder -> (Nat, Nat, ShapeAndType) -> Computation ()
interpretParameter builder (positionInFnParams, positionInGraph, MkShapeAndType shape dtype) = do
xlaShape <- mkShape {dtype} shape
param <- parameter builder positionInFnParams xlaShape name
put $ insert positionInGraph param !get
covering
interpret : XlaBuilder -> Fn arity -> ErrIO XlaOp

covering
enqueue : XlaBuilder -> Expr -> Computation XlaOp
enqueue builder (FromLiteral {dtype} lit) = constantLiteral builder !(write {dtype} lit)
enqueue _ (Arg x) = lookup x
enqueue builder (Tuple xs) = tuple builder !(traverse lookup xs)
enqueue builder (GetTupleElement idx x) = getTupleElement !(lookup x) idx
enqueue builder (MinValue {dtype}) = minValue {dtype} builder
enqueue builder (MaxValue {dtype}) = maxValue {dtype} builder
enqueue builder (MinFiniteValue {dtype}) = minFiniteValue {dtype} builder
enqueue builder (MaxFiniteValue {dtype}) = maxFiniteValue {dtype} builder
enqueue _ (ConvertElementType x) = convertElementType {dtype = F64} !(lookup x)
enqueue _ (Reshape from to x) = reshape !(lookup x) (range $ length from) to
enqueue _ (Slice starts stops strides x) = slice !(lookup x) starts stops strides
enqueue _ (DynamicSlice starts sizes x) =
dynamicSlice !(lookup x) !(traverse lookup starts) sizes
enqueue builder (Concat axis x y) = concatInDim builder [!(lookup x), !(lookup y)] (cast axis)
enqueue _ (Diag x) = getMatrixDiagonal !(lookup x)
enqueue _ (Triangle tri x) = triangle !(lookup x) tri
enqueue _ (Transpose ordering x) = transpose !(lookup x) ordering
enqueue builder (Identity {dtype} n) = let n = cast n in identityMatrix {dtype} builder n n
enqueue builder (Broadcast {dtype} from to x) =
if elem 0 to && from /= to
then do
literal <- allocLiteral {dtype} to
constantLiteral builder literal
else
let broadcastDims = map (+ length to `minus` length from) $ range $ length from
in broadcastInDim !(lookup x) to broadcastDims
enqueue builder (Map f xs dims) = do
computation <- buildSub builder "computation" f
map builder (toList !(traverse lookup xs)) computation dims
enqueue builder (Reduce f neutral axes x) = do
computation <- buildSub builder "computation" f
reduce !(lookup x) !(lookup neutral) computation axes
enqueue builder (Sort f axis isStable xs) = do
comparator <- buildSub builder "comparator" f
sort !(traverse lookup xs) comparator axis isStable
enqueue _ (Reverse axes x) = rev !(lookup x) axes
enqueue _ (BinaryElementwise f x y) = toXla f !(lookup x) !(lookup y)
where
toXla : BinaryOp -> HasIO io => XlaOp -> XlaOp -> io XlaOp
toXla = \case
Eq => eq
Ne => ne
Add => add
Sub => sub
Mul => mul
Div => div
Rem => rem
Pow => pow
Lt => lt
Gt => gt
Le => le
Ge => ge
And => and
Or => or
Min => min
Max => max
enqueue _ (UnaryElementwise f x) = toXla f !(lookup x)
where
toXla : UnaryOp -> HasIO io => XlaOp -> io XlaOp
toXla = \case
Not => not
Neg => neg
Reciprocal => reciprocal
Ceil => ceil
Floor => floor
Abs => abs
Log => log
Exp => exp
Logistic => logistic
Erf => erf
Square => square
Sqrt => sqrt
Sin => sin
Cos => cos
Tan => tan
Asin => asin
Acos => acos
Atan => atan
Sinh => sinh
Cosh => cosh
Tanh => tanh
Asinh => asinh
Acosh => acosh
Atanh => atanh
enqueue _ (Argmin {out} axis x) = argMin {outputType=out} !(lookup x) axis
enqueue _ (Argmax {out} axis x) = argMax {outputType=out} !(lookup x) axis
enqueue _ (Select pred true false) = select !(lookup pred) !(lookup true) !(lookup false)
enqueue builder (Cond pred fTrue true fFalse false) = do
trueComp <- buildSub builder "truthy computation" fTrue
falseComp <- buildSub builder "falsy computation" fFalse
conditional !(lookup pred) !(lookup true) trueComp !(lookup false) falseComp
enqueue _ (Dot l r) = dot !(lookup l) !(lookup r)
enqueue _ (Cholesky x) = cholesky !(lookup x) True
enqueue _ (TriangularSolve a b lower) =
triangularSolve !(lookup a) !(lookup b) True lower False NoTranspose
enqueue builder (UniformFloatingPoint key initialState minval maxval shape) = do
rngOutput <- uniformFloatingPointDistribution
!(lookup key)
!(lookup initialState)
ThreeFry
!(lookup minval)
!(lookup maxval)
!(mkShape {dtype=F64} shape)
tuple builder [value rngOutput, state rngOutput]
enqueue builder (NormalFloatingPoint key initialState shape) = do
rngOutput <- normalFloatingPointDistribution
!(lookup key) !(lookup initialState) ThreeFry !(mkShape {dtype=F64} shape)
tuple builder [value rngOutput, state rngOutput]

interpret builder root env = do
traverse_ interpretExpr (toList env)
lookup root
compile : XlaBuilder -> Fn arity -> ErrIO XlaComputation
compile xlaBuilder f = do
root <- interpret xlaBuilder f
build xlaBuilder root

interpret xlaBuilder (MkFn params root env) = do
let (max, exprs) = toList env
cache <- newArray (cast max)
runReaderT cache $ do
traverse_ interpretParameter (enumerate params)
traverse_ (\(i, expr) => do set i !(interpretE expr)) exprs
get root

where
interpretExpr : (Nat, Expr) -> Computation ()
interpretExpr (n, expr) = put (insert n !(enqueue builder expr) !get)

export
toString : Nat -> Env -> EitherT Err IO String
toString root env = do
builder <- mkXlaBuilder "toString"
xlaOp <- evalStateT empty (interpret builder root env)
pure $ opToString builder xlaOp

export
run : PrimitiveRW dtype a => Nat -> Env -> {shape : _} -> EitherT Err IO (Literal shape a)
run root env = do
builder <- mkXlaBuilder "root"
root <- evalStateT empty (interpret builder root env)
computation <- XlaBuilder.build builder root
0 Builder : Type -> Type
Builder = ReaderT (IOArray XlaOp) ErrIO

set : Nat -> XlaOp -> Builder ()
set idx xlaOp = do
cache <- ask
True <- lift $ writeArray cache (cast idx) xlaOp
| False => lift $ left $ OutOfBounds idx (cast $ max cache)
pure ()

get : Nat -> Builder XlaOp
get idx = do
cache <- ask
Just xlaOp <- lift $ readArray cache (cast idx)
| _ => lift $ left $ let max = cast (max cache)
in if idx >= max then OutOfBounds idx max else ValueNotFound idx
pure xlaOp

interpretParameter : (Nat, Nat, ShapeAndType) -> Builder ()
interpretParameter (posInFnParams, posInGraph, MkShapeAndType shape dtype) = do
xlaShape <- mkShape {dtype} shape
param <- parameter xlaBuilder posInFnParams xlaShape (show posInFnParams)
set posInGraph param

interpretE : Expr -> Builder XlaOp
interpretE (FromLiteral {dtype} lit) = constantLiteral xlaBuilder !(write {dtype} lit)
interpretE (Arg x) = get x
interpretE (Tuple xs) = tuple xlaBuilder !(traverse get xs)
interpretE (GetTupleElement idx x) = getTupleElement !(get x) idx
interpretE (MinValue {dtype}) = minValue {dtype} xlaBuilder
interpretE (MaxValue {dtype}) = maxValue {dtype} xlaBuilder
interpretE (MinFiniteValue {dtype}) = minFiniteValue {dtype} xlaBuilder
interpretE (MaxFiniteValue {dtype}) = maxFiniteValue {dtype} xlaBuilder
interpretE (ConvertElementType x) = convertElementType {dtype = F64} !(get x)
interpretE (Reshape from to x) = reshape !(get x) (range $ length from) to
interpretE (Slice starts stops strides x) = slice !(get x) starts stops strides
interpretE (DynamicSlice starts sizes x) =
dynamicSlice !(get x) !(traverse get starts) sizes
interpretE (Concat axis x y) = concatInDim xlaBuilder [!(get x), !(get y)] (cast axis)
interpretE (Diag x) = getMatrixDiagonal !(get x)
interpretE (Triangle tri x) = triangle !(get x) tri
interpretE (Transpose ordering x) = transpose !(get x) ordering
interpretE (Identity {dtype} n) = let n = cast n in identityMatrix {dtype} xlaBuilder n n
interpretE (Broadcast {dtype} from to x) =
if elem 0 to && from /= to
then do
literal <- allocLiteral {dtype} to
constantLiteral xlaBuilder literal
else
let broadcastDims = Prelude.map (+ length to `minus` length from) $ range $ length from
in broadcastInDim !(get x) to broadcastDims
interpretE (Map f xs dims) = do
subBuilder <- createSubBuilder xlaBuilder "computation"
computation <- lift $ compile subBuilder f
map xlaBuilder (toList !(traverse get xs)) computation dims
interpretE (Reduce f neutral axes x) = do
subBuilder <- createSubBuilder xlaBuilder "monoid binary op"
computation <- lift $ compile subBuilder f
reduce !(get x) !(get neutral) computation axes
interpretE (Sort f axis isStable xs) = do
subBuilder <- createSubBuilder xlaBuilder "comparator"
computation <- lift $ compile subBuilder f
sort !(traverse get xs) computation axis isStable
interpretE (Reverse axes x) = rev !(get x) axes
interpretE (BinaryElementwise f x y) = toXla f !(get x) !(get y)
where
toXla : BinaryOp -> HasIO io => XlaOp -> XlaOp -> io XlaOp
toXla = \case
Eq => eq
Ne => ne
Add => add
Sub => sub
Mul => mul
Div => div
Rem => rem
Pow => pow
Lt => lt
Gt => gt
Le => le
Ge => ge
And => and
Or => or
Min => min
Max => max
interpretE (UnaryElementwise f x) = toXla f !(get x)
where
toXla : UnaryOp -> HasIO io => XlaOp -> io XlaOp
toXla = \case
Not => not
Neg => neg
Reciprocal => reciprocal
Ceil => ceil
Floor => floor
Abs => abs
Log => log
Exp => exp
Logistic => logistic
Erf => erf
Square => square
Sqrt => sqrt
Sin => sin
Cos => cos
Tan => tan
Asin => asin
Acos => acos
Atan => atan
Sinh => sinh
Cosh => cosh
Tanh => tanh
Asinh => asinh
Acosh => acosh
Atanh => atanh
interpretE (Argmin {out} axis x) = argMin {outputType=out} !(get x) axis
interpretE (Argmax {out} axis x) = argMax {outputType=out} !(get x) axis
interpretE (Select pred true false) = select !(get pred) !(get true) !(get false)
interpretE (Cond pred fTrue true fFalse false) = do
subBuilderT <- createSubBuilder xlaBuilder "truthy computation"
subBuilderF <- createSubBuilder xlaBuilder "falsy computation"
compTrue <- lift $ compile subBuilderT fTrue
compFalse <- lift $ compile subBuilderF fFalse
conditional !(get pred) !(get true) compTrue !(get false) compFalse
interpretE (Dot l r) = dot !(get l) !(get r)
interpretE (Cholesky x) = cholesky !(get x) True
interpretE (TriangularSolve a b lower) =
triangularSolve !(get a) !(get b) True lower False NoTranspose
interpretE (UniformFloatingPoint key initialState minval maxval shape) = do
rngOutput <- uniformFloatingPointDistribution
!(get key)
!(get initialState)
ThreeFry
!(get minval)
!(get maxval)
!(mkShape {dtype=F64} shape)
tuple xlaBuilder [value rngOutput, state rngOutput]
interpretE (NormalFloatingPoint key initialState shape) = do
rngOutput <- normalFloatingPointDistribution
!(get key) !(get initialState) ThreeFry !(mkShape {dtype=F64} shape)
tuple xlaBuilder [value rngOutput, state rngOutput]

export covering
toString : Fn 0 -> ErrIO String
toString f = do
xlaBuilder <- mkXlaBuilder "toString"
root <- interpret xlaBuilder f
pure $ opToString xlaBuilder root

export covering
execute : PrimitiveRW dtype a => Fn 0 -> {shape : _} -> ErrIO $ Literal shape a
execute f = do
xlaBuilder <- mkXlaBuilder "root"
computation <- compile xlaBuilder f
gpuStatus <- validateGPUMachineManager
platform <- if ok gpuStatus then gpuMachineManager else getPlatform "Host"
client <- getOrCreateLocalClient platform
Expand Down
Loading