Skip to content

Commit

Permalink
move dotdimnums
Browse files Browse the repository at this point in the history
  • Loading branch information
joelberkeley committed Dec 30, 2023
1 parent a34f95b commit 8723e5b
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 123 deletions.
1 change: 0 additions & 1 deletion backend/src/tensorflow/compiler/xla/client/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ cc_library(
"@xla_extension//:xla_extension",
"//src:src",
"//src/tensorflow/compiler/xla:xla",
"//src/tensorflow/compiler/xla/service/gpu/runtime:runtime",
"//src/tensorflow/stream_executor:stream_executor",
],
visibility = ["//visibility:public"],
Expand Down
1 change: 0 additions & 1 deletion backend/src/tensorflow/compiler/xla/client/xla_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"

#include "../literal.h"
#include "../service/gpu/runtime/support.h"
#include "xla_builder.h"
#include "xla_computation.h"

Expand Down
2 changes: 1 addition & 1 deletion backend/src/tensorflow/compiler/xla/client/xla_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"

#include "../literal.h"
#include "../service/gpu/runtime/support.h"
#include "../xla_data.pb.h"
#include "xla_computation.h"

extern "C" {
Expand Down
25 changes: 0 additions & 25 deletions backend/src/tensorflow/compiler/xla/service/gpu/runtime/BUILD

This file was deleted.

37 changes: 0 additions & 37 deletions backend/src/tensorflow/compiler/xla/service/gpu/runtime/support.h

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ limitations under the License.
*/
#include "tensorflow/compiler/xla/xla_data.pb.h"

#include "support.h"

extern "C" {
struct DotDimensionNumbers;

DotDimensionNumbers* DotDimensionNumbers_new() {
return reinterpret_cast<DotDimensionNumbers*>(new xla::DotDimensionNumbers());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--}
module Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Service.GPU.Runtime.Support
module Compiler.Xla.Prim.TensorFlow.Compiler.Xla.XlaData

import System.FFI

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Client.XlaBuilder
import Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaComputation
import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData
import Compiler.Xla.TensorFlow.Compiler.Xla.Literal
import Compiler.Xla.TensorFlow.Compiler.Xla.Service.GPU.Runtime.Support
import Compiler.Xla.TensorFlow.Compiler.Xla.Shape
import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData
import Compiler.Xla.Util
import Types
import Util
Expand Down

This file was deleted.

38 changes: 38 additions & 0 deletions src/Compiler/Xla/TensorFlow/Compiler/Xla/XlaData.idr
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ limitations under the License.
--}
module Compiler.Xla.TensorFlow.Compiler.Xla.XlaData

import Compiler.Xla.Prim.TensorFlow.Compiler.Xla.XlaData

export
interface Primitive dtype where
xlaIdentifier : Int
Expand Down Expand Up @@ -60,3 +62,39 @@ export data F64 : Type where
export
Primitive F64 where
xlaIdentifier = 12

namespace Xla
public export
data DotDimensionNumbers : Type where
MkDotDimensionNumbers : GCAnyPtr -> DotDimensionNumbers

export
delete : HasIO io => AnyPtr -> io ()
delete = primIO . prim__dotDimensionNumbersDelete

export
allocDotDimensionNumbers : HasIO io => io DotDimensionNumbers
allocDotDimensionNumbers = do
ptr <- primIO prim__dotDimensionNumbersNew
ptr <- onCollectAny ptr delete
pure (MkDotDimensionNumbers ptr)

export
addLhsContractingDimensions : HasIO io => DotDimensionNumbers -> Nat -> io ()
addLhsContractingDimensions (MkDotDimensionNumbers dimension_numbers) n =
primIO $ prim__addLhsContractingDimensions dimension_numbers (cast n)

export
addRhsContractingDimensions : HasIO io => DotDimensionNumbers -> Nat -> io ()
addRhsContractingDimensions (MkDotDimensionNumbers dimension_numbers) n =
primIO $ prim__addRhsContractingDimensions dimension_numbers (cast n)

export
addLhsBatchDimensions : HasIO io => DotDimensionNumbers -> Nat -> io ()
addLhsBatchDimensions (MkDotDimensionNumbers dimension_numbers) n =
primIO $ prim__addLhsBatchDimensions dimension_numbers (cast n)

export
addRhsBatchDimensions : HasIO io => DotDimensionNumbers -> Nat -> io ()
addRhsBatchDimensions (MkDotDimensionNumbers dimension_numbers) n =
primIO $ prim__addRhsBatchDimensions dimension_numbers (cast n)

0 comments on commit 8723e5b

Please sign in to comment.