Skip to content

Commit

Permalink
[BUGFIX] copy graph index to shared memory. (#634)
Browse files Browse the repository at this point in the history
* copy graph index to shared memory.

* fix.

* fix.

* fix.

* use a diff name for in-csr and out-csr.

* fix lint.

* remove print.

* add test.

* add comments.
  • Loading branch information
zheng-da authored Jun 12, 2019
1 parent 16ec2a8 commit 94ecb8e
Show file tree
Hide file tree
Showing 7 changed files with 260 additions and 35 deletions.
8 changes: 3 additions & 5 deletions examples/mxnet/sampling/run_store_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@ class GraphData:
def __init__(self, csr, num_feats, graph_name):
num_nodes = csr.shape[0]
num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0]
edge_ids = np.arange(0, num_edges, step=1, dtype=np.int64)
self.graph = dgl.graph_index.from_csr_matrix(
dgl.utils.toindex(csr.indptr), dgl.utils.toindex(csr.indices), False,
"in", dgl.contrib.graph_store._get_graph_path(graph_name))
self.graph = dgl.graph_index.from_csr(csr.indptr, csr.indices, False,
'in', dgl.contrib.graph_store._get_graph_path(graph_name))
self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats))
self.num_labels = 10
self.labels = mx.nd.floor(mx.nd.random.uniform(low=0, high=self.num_labels,
Expand Down Expand Up @@ -69,7 +67,7 @@ def main(args):
# create GCN model
print('graph name: ' + graph_name)
g = dgl.contrib.graph_store.create_graph_store_server(data.graph, graph_name, "shared_mem",
args.num_workers, False)
args.num_workers, False, edge_dir='in')
g.ndata['features'] = features
g.ndata['labels'] = labels
g.ndata['train_mask'] = train_mask
Expand Down
128 changes: 128 additions & 0 deletions include/dgl/immutable_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,11 @@ class CSR : public GraphInterface {
return {indptr_, indices_, edge_ids_};
}

/*!
 * \brief Indicate whether this CSR uses shared memory.
 * \return true iff shared_mem_name_ is non-empty.
 */
bool IsSharedMem() const {
  const bool has_no_name = shared_mem_name_.empty();
  return !has_no_name;
}

/*! \brief Return the reverse of this CSR graph (i.e, a CSC graph) */
CSRPtr Transpose() const;

Expand All @@ -230,6 +235,13 @@ class CSR : public GraphInterface {
*/
CSR CopyTo(const DLContext& ctx) const;

/*!
* \brief Copy data to shared memory.
* \param name The name of the shared memory.
* \return The graph in the shared memory
*/
CSR CopyToSharedMem(const std::string &name) const;

/*!
* \brief Convert the graph to use the given number of bits for storage.
* \param bits The new number of integer bits (32 or 64).
Expand Down Expand Up @@ -262,6 +274,10 @@ class CSR : public GraphInterface {

// whether the graph is a multi-graph
LazyObject<bool> is_multigraph_;

// The name of the shared memory to store data.
// If it's empty, data isn't stored in shared memory.
std::string shared_mem_name_;
};

class COO : public GraphInterface {
Expand Down Expand Up @@ -478,13 +494,25 @@ class COO : public GraphInterface {
*/
COO CopyTo(const DLContext& ctx) const;

/*!
* \brief Copy data to shared memory.
* \param name The name of the shared memory.
* \return The graph in the shared memory
*/
COO CopyToSharedMem(const std::string &name) const;

/*!
* \brief Convert the graph to use the given number of bits for storage.
* \param bits The new number of integer bits (32 or 64).
* \return The graph with new bit size storage.
*/
COO AsNumBits(uint8_t bits) const;

/*!
 * \brief Indicate whether this uses shared memory.
 * \return Always false in this implementation.
 */
bool IsSharedMem() const {
return false;
}

// member getters

IdArray src() const { return src_; }
Expand Down Expand Up @@ -512,6 +540,7 @@ class ImmutableGraph: public GraphInterface {
public:
/*!
 * \brief Construct an immutable graph from the COO format.
 * \param coo The COO structure; it is stored in coo_.
 */
explicit ImmutableGraph(COOPtr coo): coo_(coo) { }

/*!
* \brief Construct an immutable graph from the CSR format.
*
Expand Down Expand Up @@ -889,6 +918,9 @@ class ImmutableGraph: public GraphInterface {
if (!in_csr_) {
if (out_csr_) {
const_cast<ImmutableGraph*>(this)->in_csr_ = out_csr_->Transpose();
if (out_csr_->IsSharedMem())
LOG(WARNING) << "We just construct an in-CSR from a shared-memory out CSR. "
<< "It may dramatically increase memory consumption.";
} else {
CHECK(coo_) << "None of CSR, COO exist";
const_cast<ImmutableGraph*>(this)->in_csr_ = coo_->Transpose()->ToCSR();
Expand All @@ -902,6 +934,9 @@ class ImmutableGraph: public GraphInterface {
if (!out_csr_) {
if (in_csr_) {
const_cast<ImmutableGraph*>(this)->out_csr_ = in_csr_->Transpose();
if (in_csr_->IsSharedMem())
LOG(WARNING) << "We just construct an out-CSR from a shared-memory in CSR. "
<< "It may dramatically increase memory consumption.";
} else {
CHECK(coo_) << "None of CSR, COO exist";
const_cast<ImmutableGraph*>(this)->out_csr_ = coo_->ToCSR();
Expand Down Expand Up @@ -941,13 +976,92 @@ class ImmutableGraph: public GraphInterface {
*/
ImmutableGraph CopyTo(const DLContext& ctx) const;

/*!
* \brief Copy data to shared memory.
* \param edge_dir the graph of the specific edge direction to be copied.
* \param name The name of the shared memory.
* \return The graph in the shared memory
*/
ImmutableGraph CopyToSharedMem(const std::string &edge_dir, const std::string &name) const;

/*!
* \brief Convert the graph to use the given number of bits for storage.
* \param bits The new number of integer bits (32 or 64).
* \return The graph with new bit size storage.
*/
ImmutableGraph AsNumBits(uint8_t bits) const;

/*!
 * \brief Create an immutable graph from CSR.
 *
 * A CSR is built from the three arrays and handed to the graph constructor in
 * the position selected by edge_dir; the other position is null.
 * \param indptr Forwarded unchanged to the CSR constructor.
 * \param indices Forwarded unchanged to the CSR constructor.
 * \param edge_ids Forwarded unchanged to the CSR constructor.
 * \param edge_dir "in" passes the CSR as the first constructor argument,
 *        "out" as the second. Any other value is reported with LOG(FATAL).
 * \return The new immutable graph.
 */
static ImmutableGraph CreateFromCSR(IdArray indptr, IdArray indices, IdArray edge_ids,
                                    const std::string &edge_dir) {
  CSRPtr new_csr(new CSR(indptr, indices, edge_ids));
  if (edge_dir == "in")
    return ImmutableGraph(new_csr, nullptr);
  if (edge_dir == "out")
    return ImmutableGraph(nullptr, new_csr);
  // Neither direction matched; the trailing return only gives the function a value on every path.
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraph();
}

/*!
 * \brief Create an immutable graph from CSR with a known multigraph flag.
 * \param indptr Forwarded unchanged to the CSR constructor.
 * \param indices Forwarded unchanged to the CSR constructor.
 * \param edge_ids Forwarded unchanged to the CSR constructor.
 * \param multigraph Forwarded unchanged to the CSR constructor.
 * \param edge_dir "in" passes the CSR as the first constructor argument,
 *        "out" as the second. Any other value is reported with LOG(FATAL).
 * \return The new immutable graph.
 */
static ImmutableGraph CreateFromCSR(IdArray indptr, IdArray indices, IdArray edge_ids,
                                    bool multigraph, const std::string &edge_dir) {
  CSRPtr new_csr(new CSR(indptr, indices, edge_ids, multigraph));
  if (edge_dir == "in")
    return ImmutableGraph(new_csr, nullptr);
  if (edge_dir == "out")
    return ImmutableGraph(nullptr, new_csr);
  // Neither direction matched; the trailing return only gives the function a value on every path.
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraph();
}

/*!
 * \brief Create an immutable graph from CSR, associating it with a shared memory name.
 *
 * The CSR receives the direction-specific name produced by GetSharedMemName
 * (shared_mem_name + "_" + edge_dir), while the graph itself records the
 * plain shared_mem_name.
 * \param indptr Forwarded unchanged to the CSR constructor.
 * \param indices Forwarded unchanged to the CSR constructor.
 * \param edge_ids Forwarded unchanged to the CSR constructor.
 * \param edge_dir "in" passes the CSR as the in-CSR, "out" as the out-CSR.
 *        Any other value is reported with LOG(FATAL).
 * \param shared_mem_name The base shared memory name of the graph.
 * \return The new immutable graph.
 */
static ImmutableGraph CreateFromCSR(IdArray indptr, IdArray indices, IdArray edge_ids,
                                    const std::string &edge_dir,
                                    const std::string &shared_mem_name) {
  const std::string csr_mem_name = GetSharedMemName(shared_mem_name, edge_dir);
  CSRPtr new_csr(new CSR(indptr, indices, edge_ids, csr_mem_name));
  if (edge_dir == "in")
    return ImmutableGraph(new_csr, nullptr, shared_mem_name);
  if (edge_dir == "out")
    return ImmutableGraph(nullptr, new_csr, shared_mem_name);
  // Neither direction matched; the trailing return only gives the function a value on every path.
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraph();
}

/*!
 * \brief Create an immutable graph from CSR with a known multigraph flag,
 *        associating it with a shared memory name.
 *
 * The CSR receives the direction-specific name produced by GetSharedMemName
 * (shared_mem_name + "_" + edge_dir), while the graph itself records the
 * plain shared_mem_name.
 * \param indptr Forwarded unchanged to the CSR constructor.
 * \param indices Forwarded unchanged to the CSR constructor.
 * \param edge_ids Forwarded unchanged to the CSR constructor.
 * \param multigraph Forwarded unchanged to the CSR constructor.
 * \param edge_dir "in" passes the CSR as the in-CSR, "out" as the out-CSR.
 *        Any other value is reported with LOG(FATAL).
 * \param shared_mem_name The base shared memory name of the graph.
 * \return The new immutable graph.
 */
static ImmutableGraph CreateFromCSR(IdArray indptr, IdArray indices, IdArray edge_ids,
                                    bool multigraph, const std::string &edge_dir,
                                    const std::string &shared_mem_name) {
  const std::string csr_mem_name = GetSharedMemName(shared_mem_name, edge_dir);
  CSRPtr new_csr(new CSR(indptr, indices, edge_ids, multigraph, csr_mem_name));
  if (edge_dir == "in")
    return ImmutableGraph(new_csr, nullptr, shared_mem_name);
  if (edge_dir == "out")
    return ImmutableGraph(nullptr, new_csr, shared_mem_name);
  // Neither direction matched; the trailing return only gives the function a value on every path.
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraph();
}

/*!
 * \brief Create an immutable graph whose CSR is built from a shared memory name and sizes.
 *
 * The CSR is constructed from the direction-specific name produced by
 * GetSharedMemName (shared_mem_name + "_" + edge_dir) together with the
 * vertex and edge counts; the graph itself records the plain shared_mem_name.
 * NOTE(review): presumably the CSR constructor attaches to already-populated
 * shared memory -- confirm against the CSR implementation.
 * \param shared_mem_name The base shared memory name of the graph.
 * \param num_vertices Forwarded unchanged to the CSR constructor.
 * \param num_edges Forwarded unchanged to the CSR constructor.
 * \param multigraph Forwarded unchanged to the CSR constructor.
 * \param edge_dir "in" passes the CSR as the in-CSR, "out" as the out-CSR.
 *        Any other value is reported with LOG(FATAL).
 * \return The new immutable graph.
 */
static ImmutableGraph CreateFromCSR(const std::string &shared_mem_name, size_t num_vertices,
                                    size_t num_edges, bool multigraph,
                                    const std::string &edge_dir) {
  const std::string csr_mem_name = GetSharedMemName(shared_mem_name, edge_dir);
  CSRPtr new_csr(new CSR(csr_mem_name, num_vertices, num_edges, multigraph));
  if (edge_dir == "in")
    return ImmutableGraph(new_csr, nullptr, shared_mem_name);
  if (edge_dir == "out")
    return ImmutableGraph(nullptr, new_csr, shared_mem_name);
  // Neither direction matched; the trailing return only gives the function a value on every path.
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraph();
}

protected:
/*! \brief internal default constructor */
ImmutableGraph() {}
Expand All @@ -958,6 +1072,16 @@ class ImmutableGraph: public GraphInterface {
CHECK(AnyGraph()) << "At least one graph structure should exist.";
}

/*!
 * \brief Construct an immutable graph from CSRs and record its shared memory name.
 * \param in_csr The in-CSR; may be null when out_csr is given.
 * \param out_csr The out-CSR; may be null when in_csr is given.
 * \param shared_mem_name The name stored in shared_mem_name_.
 *
 * At least one of the two CSRs must be non-null (enforced with CHECK).
 */
ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr, const std::string &shared_mem_name)
  : in_csr_(in_csr), out_csr_(out_csr), shared_mem_name_(shared_mem_name) {
  // The name is taken by const reference and initialized directly, avoiding
  // the extra copy-then-assign of a by-value const parameter.
  CHECK(in_csr_ || out_csr_) << "Both CSR are missing.";
}

/*!
 * \brief Build the direction-specific shared memory name for a CSR.
 * \param name The base shared memory name.
 * \param edge_dir The edge direction appended as a suffix.
 * \return name, an underscore, then edge_dir.
 */
static std::string GetSharedMemName(const std::string &name, const std::string &edge_dir) {
  std::string full_name(name);
  full_name += "_";
  full_name += edge_dir;
  return full_name;
}

/*! \brief return pointer to any available graph structure */
GraphPtr AnyGraph() const {
if (in_csr_) {
Expand All @@ -975,6 +1099,10 @@ class ImmutableGraph: public GraphInterface {
CSRPtr out_csr_;
// Store the edge list indexed by edge id (COO)
COOPtr coo_;

// The name of shared memory for this graph.
// If it's empty, the graph isn't stored in shared memory.
std::string shared_mem_name_;
};

// inline implementations
Expand Down
6 changes: 5 additions & 1 deletion python/dgl/contrib/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,11 @@ class SharedMemoryStoreServer(object):
"""
def __init__(self, graph_data, edge_dir, graph_name, multigraph, num_workers, port):
self.server = None
if isinstance(graph_data, (GraphIndex, DGLGraph)):
if isinstance(graph_data, GraphIndex):
graph_data = graph_data.copyto_shared_mem(edge_dir, _get_graph_path(graph_name))
self._graph = DGLGraph(graph_data, multigraph=multigraph, readonly=True)
elif isinstance(graph_data, DGLGraph):
graph_data = graph_data._graph.copyto_shared_mem(edge_dir, _get_graph_path(graph_name))
self._graph = DGLGraph(graph_data, multigraph=multigraph, readonly=True)
else:
indptr, indices = _to_csr(graph_data, edge_dir, multigraph)
Expand Down
20 changes: 20 additions & 0 deletions python/dgl/graph_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,26 @@ def copy_to(self, ctx):
handle = _CAPI_DGLImmutableGraphCopyTo(self._handle, ctx.device_type, ctx.device_id)
return GraphIndex(handle)

def copyto_shared_mem(self, edge_dir, shared_mem_name):
"""Copy this immutable graph index to shared memory.
NOTE: this method only works for immutable graph index
Parameters
----------
edge_dir : string
Indicate which CSR should be copied ("in", "out", "both").
shared_mem_name : string
The name of the shared memory.
Returns
-------
GraphIndex
A new graph index wrapping the handle returned by the shared-memory copy.
"""
handle = _CAPI_DGLImmutableGraphCopyToSharedMem(self._handle, edge_dir, shared_mem_name)
return GraphIndex(handle)

def nbits(self):
"""Return the number of integer bits used in the storage (32 or 64).
Expand Down
43 changes: 25 additions & 18 deletions src/graph/graph_apis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,33 +158,32 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate")
const std::string shared_mem_name = args[2];
const int multigraph = args[3];
const std::string edge_dir = args[4];
CSRPtr csr;

IdArray edge_ids = IdArray::Empty({indices->shape[0]},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t *edge_data = static_cast<int64_t *>(edge_ids->data);
for (size_t i = 0; i < edge_ids->shape[0]; i++)
edge_data[i] = i;
ImmutableGraph *g = nullptr;
if (shared_mem_name.empty()) {
if (multigraph == kBoolUnknown) {
csr.reset(new CSR(indptr, indices, edge_ids));
g = new ImmutableGraph(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids,
edge_dir));
} else {
csr.reset(new CSR(indptr, indices, edge_ids, multigraph));
g = new ImmutableGraph(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids,
multigraph, edge_dir));
}
} else {
if (multigraph == kBoolUnknown) {
csr.reset(new CSR(indptr, indices, edge_ids, shared_mem_name));
g = new ImmutableGraph(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids,
edge_dir, shared_mem_name));
} else {
csr.reset(new CSR(indptr, indices, edge_ids, multigraph, shared_mem_name));
g = new ImmutableGraph(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids,
multigraph, edge_dir,
shared_mem_name));
}
}

GraphHandle ghandle;
if (edge_dir == "in")
ghandle = new ImmutableGraph(csr, nullptr);
else
ghandle = new ImmutableGraph(nullptr, csr);
*rv = ghandle;
*rv = g;
});

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreateMMap")
Expand All @@ -195,12 +194,8 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreateMMap")
const bool multigraph = args[3];
const std::string edge_dir = args[4];
// TODO(minjie): how to know multigraph
CSRPtr csr(new CSR(shared_mem_name, num_vertices, num_edges, multigraph));
GraphHandle ghandle;
if (edge_dir == "in")
ghandle = new ImmutableGraph(csr, nullptr);
else
ghandle = new ImmutableGraph(nullptr, csr);
GraphHandle ghandle = new ImmutableGraph(ImmutableGraph::CreateFromCSR(
shared_mem_name, num_vertices, num_edges, multigraph, edge_dir));
*rv = ghandle;
});

Expand Down Expand Up @@ -546,6 +541,18 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
*rv = newhandle;
});

// Copies an immutable graph to shared memory.
// Arguments: (graph handle, edge direction, shared memory name).
// Returns a handle to a newly allocated ImmutableGraph holding the copy.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyToSharedMem")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle src_handle = args[0];
    const std::string edge_dir = args[1];
    const std::string shm_name = args[2];
    // Only ImmutableGraph provides CopyToSharedMem, so reject any other graph type.
    const GraphInterface *src_iface = static_cast<GraphInterface *>(src_handle);
    const ImmutableGraph *src_graph = dynamic_cast<const ImmutableGraph*>(src_iface);
    CHECK(src_graph) << "Invalid argument: must be an immutable graph object.";
    GraphHandle copied = new ImmutableGraph(src_graph->CopyToSharedMem(edge_dir, shm_name));
    *rv = copied;
  });

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
Expand Down
Loading

0 comments on commit 94ecb8e

Please sign in to comment.