Skip to content

Commit

Permalink
Implementing union-find for types (WIP) (apache#38)
Browse files Browse the repository at this point in the history
* First pass at union-find

* Misc. style

* Simplify union-find interface; will keep the matter of unifying TYPES rather than type vars in a separate visitor

* Whitespace

* Whitespace for lint

* Use pointer equality

* Move union-find to separate files

* Whitespace for linter

* It was actually a mistake to compare the references at all
  • Loading branch information
slyubomirsky authored and jroesch committed Aug 16, 2018
1 parent 7babba8 commit 4d6315f
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 145 deletions.
45 changes: 45 additions & 0 deletions relay/include/relay/unifier.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*!
* Copyright (c) 2018 by Contributors
* \file nnvm/relay/unifier.h
* \brief Type unification data structures
*/
#ifndef NNVM_RELAY_UNIFIER_H_
#define NNVM_RELAY_UNIFIER_H_

#include <nnvm/relay/node.h>

namespace nnvm {
namespace relay {

/*! \brief a union-find data structure for the type-checker */
class UnionFind;
class UnionFindNode : public ExprNode {
public:
tvm::Map<TypeVar, TypeVar> uf_map;

UnionFindNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("uf_map", &uf_map);
}

TVM_DLL static UnionFind make(tvm::Map<TypeVar, TypeVar> uf_map);

// insert v into UF
void insert(const TypeVar& v);

// infers that v1 and v2 must be of the smae type
void unify(const TypeVar& v1, const TypeVar& v2);

// returns representative of v's UF-group
TypeVar find(const TypeVar& v);

static constexpr const char* _type_key = "nnvm.UnionFind";
TVM_DECLARE_NODE_TYPE_INFO(UnionFindNode, ExprNode);
};

TVM_DEFINE_NODE_REF(UnionFind, UnionFindNode);

} // namespace relay
} // namespace nnvm
#endif // NNVM_RELAY_UNIFIER_H_
Loading

0 comments on commit 4d6315f

Please sign in to comment.