Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
joelberkeley committed Dec 13, 2024
1 parent 4f66569 commit cf0d80f
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 30 deletions.
15 changes: 3 additions & 12 deletions spidr/backend/src/stablehlo/dialect/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,17 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "stablehlo/dialect/Serialization.h"
#include "stablehlo/dialect/Version.h"

#include "../../mlir/IR/BuiltinOps.h"
#include "../../ffi.h"

extern "C" {
int serializePortableArtifact(ModuleOp& module, string& str) {
int serializePortableArtifact(ModuleOp& module, string& version, string& str) {
auto& module_ = reinterpret_cast<mlir::ModuleOp&>(module);
auto& version_ = reinterpret_cast<std::string&>(str);
auto& str_ = reinterpret_cast<std::string&>(str);
llvm::raw_string_ostream os(str_);
auto version = mlir::vhlo::Version::getMinimumVersion().toString();
auto result = mlir::stablehlo::serializePortableArtifact(module_, version, os);
auto result = mlir::stablehlo::serializePortableArtifact(module_, version_, os);
return (int) result.succeeded();
}

string* printModule(ModuleOp& module) {
auto& module_ = reinterpret_cast<mlir::ModuleOp&>(module);
auto str = new std::string();
llvm::raw_string_ostream os(*str);
module_.print(os);
return reinterpret_cast<string*>(str);
}
}
36 changes: 36 additions & 0 deletions spidr/backend/src/stablehlo/dialect/Version.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
Copyright 2024 Joel Berkeley
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "stablehlo/dialect/Version.h"

#include "../../ffi.h"

extern "C" {
struct Version;

void Version_delete(Version* s) {
delete reinterpret_cast<mlir::vhlo::Version*>(s);
}

Version* Version_getMinimumVersion() {
auto version = mlir::vhlo::Version::getMinimumVersion();
return reinterpret_cast<Version*>(new mlir::vhlo::Version(version));
}

string* Version_toString(Version& s) {
auto& s_ = reinterpret_cast<mlir::vhlo::Version&>(s);
return reinterpret_cast<string*>(new std::string(s_.toString()));
}
}
1 change: 1 addition & 0 deletions spidr/spidr.ipkg
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ modules =
Compiler.MLIR.IR.MLIRContext,
Compiler.StableHLO.Dialect.Register,
Compiler.StableHLO.Dialect.Serialization,
Compiler.StableHLO.Dialect.Version,
Compiler.Xla.Client.ExecutableBuildOptions,
Compiler.Xla.HLO.Builder.Lib.Arithmetic,
Compiler.Xla.HLO.Builder.Lib.Constants,
Expand Down
22 changes: 14 additions & 8 deletions spidr/src/Compiler/Eval.idr
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import Compiler.MLIR.IR.DialectRegistry
import Compiler.MLIR.IR.MLIRContext
import Compiler.StableHLO.Dialect.Register
import Compiler.StableHLO.Dialect.Serialization
import Compiler.StableHLO.Dialect.Version
import Compiler.Xla.Client.ExecutableBuildOptions
import Compiler.Xla.HLO.Builder.Lib.Arithmetic
import Compiler.Xla.HLO.Builder.Lib.Constants
Expand Down Expand Up @@ -229,20 +230,25 @@ interpret @{cache} xlaBuilder (MkFn params root env) = do
!(interpretE key) !(interpretE initialState) ThreeFry !(mkShape {dtype = F64} shape)
tuple xlaBuilder [value rngOutput, state rngOutput]

hloModuleProtoToStableHLO : HloModuleProto -> ErrIO CharArray
hloModuleProtoToStableHLO proto = do
dialectRegistry <- mkDialectRegistry
registerAllMhloDialects dialectRegistry
registerAllDialects dialectRegistry
mlirCtx <- mkMLIRContext
stablehlo <- convertHloToStablehlo mlirCtx proto
appendDialectRegistry mlirCtx dialectRegistry
Just code <- serializePortableArtifact stablehlo !(toString !getMinimumVersion)
| Nothing => throwE (SerializationError "Failed to serialize StableHLO")
pure code

||| It is up to the caller to free the `Literal`s.
export covering
execute : Device -> Fn 0 -> {outputs : _} -> Vect outputs Xla.Shape -> ErrIO $ Vect outputs Literal
execute (MkDevice api client) f@(MkFn _ _ env) shapes = do
xlaBuilder <- mkXlaBuilder "root"
computation <- compile @{!(newArray $ cast $ counter env)} xlaBuilder f
dialectRegistry <- mkDialectRegistry
registerAllMhloDialects dialectRegistry
registerAllDialects dialectRegistry
mlirCtx <- mkMLIRContext
stablehlo <- convertHloToStablehlo mlirCtx !(proto computation)
appendDialectRegistry mlirCtx dialectRegistry
Just code <- serializePortableArtifact stablehlo | Nothing => throwE (SerializationError "Failed to serialize StableHLO")
-- code <- printModule stablehlo
code <- hloModuleProtoToStableHLO !(proto computation)
executableBuildOptions <- mkExecutableBuildOptions
compileOptions <- serializeAsString !(mkCompileOptions executableBuildOptions)
program <- mkPjrtProgram code
Expand Down
3 changes: 3 additions & 0 deletions spidr/src/Compiler/FFI.idr
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ libxla fname = "C:" ++ fname ++ ",libc_xla"
public export
data CharArray = MkCharArray (Ptr Char) Bits64

public export
data CppString = MkCppString GCAnyPtr

namespace CharArray
export
free : HasIO io => CharArray -> io ()
Expand Down
13 changes: 3 additions & 10 deletions spidr/src/Compiler/StableHLO/Dialect/Serialization.idr
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,10 @@ import Compiler.FFI
prim__serializePortableArtifact : AnyPtr -> AnyPtr -> PrimIO Int

export
serializePortableArtifact : HasIO io => ModuleOp -> io (Maybe CharArray)
serializePortableArtifact (MkModuleOp moduleOp) = do
serializePortableArtifact : HasIO io => ModuleOp -> CppString -> io (Maybe CharArray)
serializePortableArtifact (MkModuleOp moduleOp) (MkCppString version) = do
str <- primIO prim__stringNew
ok <- primIO $ prim__serializePortableArtifact moduleOp str
ok <- primIO $ prim__serializePortableArtifact moduleOp version str
case cIntToBool ok of
True => Just <$> stringToCharArray str
False => free str >> pure Nothing

%foreign (libxla "printModule")
prim__printModule : AnyPtr -> PrimIO AnyPtr

export
printModule : HasIO io => ModuleOp -> io CharArray
printModule (MkModuleOp moduleOp) = primIO (prim__printModule moduleOp) >>= stringToCharArray
45 changes: 45 additions & 0 deletions spidr/src/Compiler/StableHLO/Dialect/Version.idr
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
{--
Copyright 2024 Joel Berkeley
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--}
||| For internal spidr use only.
module Compiler.StableHLO.Dialect.Version

import Compiler.FFI

export
data Version = MkVersion GCAnyPtr

%foreign (libxla "Version_delete")
prim__delete : AnyPtr -> PrimIO ()

%foreign (libxla "Version_getMinimumVersion")
prim__versionGetMinimumVersion : PrimIO AnyPtr

export
getMinimumVersion : HasIO io => io Version
getMinimumVersion = do
version <- primIO prim__versionGetMinimumVersion
version <- onCollectAny version (primIO . prim__delete)
pure (MkVersion version)

%foreign (libxla "Version_toString")
prim__versionToString : GCAnyPtr -> PrimIO AnyPtr

export
toString : HasIO io => Version -> io CppString
toString (MkVersion version) = do
str <- primIO $ prim__versionToString version
str <- onCollectAny str (primIO . prim__stringDelete)
pure (MkCppString str)

0 comments on commit cf0d80f

Please sign in to comment.