Skip to content

Commit

Permalink
feat: shape ops (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
0xThemis authored Jan 4, 2024
1 parent 526bcf7 commit 5399202
Show file tree
Hide file tree
Showing 19 changed files with 107 additions and 19 deletions.
48 changes: 29 additions & 19 deletions mlir-assigner/include/mlir-assigner/parser/evaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,29 +543,39 @@ namespace zk_ml_toolchain {
// Create the global at the entry of the module.
assert(operation.getValue().has_value() && "Krnl Global must always have a value");
auto value = operation.getValue().value();
//TODO check other bit sizes. Also no range constraint is this necessary????
if (DenseElementsAttr attr = llvm::dyn_cast<DenseElementsAttr>(value)) {

// TODO handle other types
auto floats = attr.tryGetValues<APFloat>();
if (mlir::failed(floats)) {
UNREACHABLE("Unsupported attribute type");
}
size_t idx = 0;
for (auto a : floats.value()) {
double d;
if (&a.getSemantics() == &llvm::APFloat::IEEEdouble()) {
d = a.convertToDouble();
} else if (&a.getSemantics() == &llvm::APFloat::IEEEsingle()) {
d = a.convertToFloat();
} else {
UNREACHABLE("unsupported float semantics");
mlir::Type attrType = attr.getElementType();
if (attrType.isa<mlir::IntegerType>()) {
auto ints = attr.tryGetValues<APInt>();
assert(!mlir::failed(ints) && "must work as we checked above");
size_t idx = 0;
for (auto a : ints.value()) {
auto var = put_into_assignment(a.getSExtValue());
m.put_flat(idx++, var);
}
nil::blueprint::components::FixedPoint<BlueprintFieldType, 1, 1> fixed(d);
auto var = put_into_assignment(fixed.get_value());
m.put_flat(idx++, var);
} else if (attrType.isa<mlir::FloatType>()) {
auto floats = attr.tryGetValues<APFloat>();
assert(!mlir::failed(floats) && "must work as we checked above");
size_t idx = 0;
for (auto a : floats.value()) {
double d;
if (&a.getSemantics() == &llvm::APFloat::IEEEdouble()) {
d = a.convertToDouble();
} else if (&a.getSemantics() == &llvm::APFloat::IEEEsingle()) {
d = a.convertToFloat();
} else {
UNREACHABLE("unsupported float semantics");
}
nil::blueprint::components::FixedPoint<BlueprintFieldType, 1, 1> fixed(d);
auto var = put_into_assignment(fixed.get_value());
m.put_flat(idx++, var);
}
} else {
UNREACHABLE("Unsupported attribute type");
}
} else {
UNREACHABLE("Unsupported attribute type");
UNREACHABLE("Expected a DenseElementsAttr");
}
frames.back().memrefs.insert({mlir::hash_value(operation.getOutput()), m});
return;
Expand Down
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Flatten/FlattenSimple.json

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Flatten/FlattenSimple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
 :n
#
in_aout_a"Flatten*
axis�FlattenSimpleZ
in_a




b
out_a


�1B
Expand Down
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Flatten/FlattenSimple.res

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Flatten/FlattenSimple2.json

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Flatten/FlattenSimple2.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
 :o
#
in_aout_a"Flatten*
axis�FlattenSimple2Z
in_a




b
out_a


�B
Expand Down
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Flatten/FlattenSimple2.res

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Shape/ShapeLarge.json

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Shape/ShapeLarge.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
 :W

in_aout_a"Shape
ShapeLargeZ
in_a




b
out_a


B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Shape/ShapeLarge.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<4xi64>[1, 8, 28, 28]
6278
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Shape/ShapeSimple.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.6541748046875, 0.20391845703125, 0.4686126708984375, 0.5786590576171875, 0.13482666015625, 0.66265869140625, 0.332122802734375, 0.9350128173828125, 0.280731201171875, 0.10223388671875], "dims": [1, 10], "type": "f32"}}]
12 changes: 12 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Shape/ShapeSimple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
 :P

in_aout_a"Shape ShapeSimpleZ
in_a



b
out_a


B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Shape/ShapeSimple.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<2xi64>[1, 10]
14
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Size/SizeLarge.json

Large diffs are not rendered by default.

Binary file added mlir-assigner/tests/Ops/Onnx/Size/SizeLarge.onnx
Binary file not shown.
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Size/SizeLarge.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<i64>[6272]
6275
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Size/SizeSimple.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.7107086181640625, 0.2982940673828125, 0.28741455078125, 0.364959716796875, 0.1244964599609375, 0.36456298828125, 0.82781982421875, 0.031829833984375, 0.291259765625, 0.97467041015625], "dims": [1, 10], "type": "f32"}}]
Binary file added mlir-assigner/tests/Ops/Onnx/Size/SizeSimple.onnx
Binary file not shown.
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Size/SizeSimple.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<i64>[10]
13

0 comments on commit 5399202

Please sign in to comment.