Skip to content

Commit

Permalink
[CONTAINER] Struct Hash/Equal and JSON support for ShapeTuple (#13671)
Browse files Browse the repository at this point in the history
This PR add struct equal/hash and json serialization support
for shape tuple. Testcases added.
  • Loading branch information
tqchen authored Dec 29, 2022
1 parent 8551a5c commit d582b7e
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 2 deletions.
44 changes: 44 additions & 0 deletions src/node/structural_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,50 @@ TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait)
return ::tvm::runtime::make_object<ArrayNode>();
});

struct ShapeTupleObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;

static void SHashReduce(const ShapeTupleObj* self, SHashReducer hash_reduce) {
hash_reduce(self->size);
for (size_t i = 0; i < self->size; ++i) {
hash_reduce(self->data[i]);
}
}

static bool SEqualReduce(const ShapeTupleObj* lhs, const ShapeTupleObj* rhs,
SEqualReducer equal) {
if (lhs->size != rhs->size) return false;
for (size_t i = 0; i < lhs->size; ++i) {
if (!equal(lhs->data[i], rhs->data[i])) return false;
}
return true;
}
};

TVM_REGISTER_REFLECTION_VTABLE(ShapeTupleObj, ShapeTupleObjTrait)
.set_creator([](const std::string& blob) {
// Store shape tuple in blob to avoid large integer overflow in JSON.
dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob));
support::Base64InStream b64strm(&mstrm);
b64strm.InitPosition();
uint64_t size;
b64strm.Read<uint64_t>(&size);
std::vector<int64_t> data(size);
b64strm.ReadArray(data.data(), size);
ShapeTuple shape(data);
return RefToObjectPtr::Get(shape);
})
.set_repr_bytes([](const Object* n) -> std::string {
std::string blob;
dmlc::MemoryStringStream mstrm(&blob);
support::Base64OutStream b64strm(&mstrm);
const auto* shape = static_cast<const runtime::ShapeTupleObj*>(n);
b64strm.Write<uint64_t>(shape->size);
b64strm.WriteArray(shape->data, shape->size);
b64strm.Finish();
return blob;
});

struct MapNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;

Expand Down
9 changes: 7 additions & 2 deletions src/support/base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,10 @@ class Base64InStream : public dmlc::Stream {
}
/*! \brief whether current position is end of a base64 stream */
bool IsEOF(void) const { return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_)); }

using dmlc::Stream::Read;
// override read function.
virtual size_t Read(void* ptr, size_t size) {
size_t Read(void* ptr, size_t size) final {
using base64::DecodeTable;
if (size == 0) return 0;
// use tlen to record left size
Expand Down Expand Up @@ -224,7 +226,10 @@ class Base64InStream : public dmlc::Stream {
class Base64OutStream : public dmlc::Stream {
public:
explicit Base64OutStream(dmlc::Stream* fp) : fp_(fp) {}
virtual void Write(const void* ptr, size_t size) {

using dmlc::Stream::Write;

void Write(const void* ptr, size_t size) final {
using base64::EncodeTable;
size_t tlen = size;
const unsigned char* cptr = static_cast<const unsigned char*>(ptr);
Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_container_structural_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,20 @@ def test_array_structural_equal_to_self(contents):
assert get_first_mismatch_ensure_symmetry(a, b) is None


@pytest.mark.parametrize(
"contents",
[
[],
[1],
[1, 2, 3],
],
)
def test_shape_tuple_structural_equal_to_self(contents):
a = tvm.runtime.ShapeTuple(list(contents))
b = tvm.runtime.ShapeTuple(list(contents))
assert get_first_mismatch_ensure_symmetry(a, b) is None


@pytest.mark.parametrize(
"a, b, expected_a_path, expected_b_path",
[
Expand Down
5 changes: 5 additions & 0 deletions tests/python/unittest/test_runtime_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ def test_shape_tuple():
# ShapleTuple vs. ShapeTuple
assert stuple == _container.ShapeTuple(shape)

# test pickle
z = pickle.loads(pickle.dumps(stuple))
assert isinstance(z, tvm.runtime.ShapeTuple)
assert stuple == z


if __name__ == "__main__":
test_string()
Expand Down

0 comments on commit d582b7e

Please sign in to comment.