-
Notifications
You must be signed in to change notification settings - Fork 54
/
Copy pathexample2_main.cc
66 lines (57 loc) · 1.7 KB
/
example2_main.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
// Copyright 2019 the deepx authors.
// Author: Yafei Zhang ([email protected])
//
#include <deepx_core/dx_log.h>
#include <deepx_core/graph/graph.h>
#include <deepx_core/graph/graph_node.h>
#include <deepx_core/graph/op_context.h>
#include <deepx_core/graph/tensor_map.h>
#include <deepx_core/tensor/data_type.h>
#include <iostream>
#include <vector>
namespace deepx_core {
namespace {
class Main : public DataType {
public:
static int main() {
Graph graph;
TensorMap param;
// Initialize graph: Z = X * W + B.
InstanceNode X("X", Shape(1), TENSOR_TYPE_TSR);
VariableNode W("W", Shape(1), TENSOR_TYPE_TSR);
VariableNode B("B", Shape(1), TENSOR_TYPE_TSR);
MulNode XW("XW", &X, &W);
AddNode Z("Z", &XW, &B);
DXCHECK_THROW(graph.Compile({&Z}, 0));
// Initialize param.
auto& _W = param.insert<tsr_t>(W.name());
_W.resize(W.shape());
_W.data(0) = 2;
auto& _B = param.insert<tsr_t>(B.name());
_B.resize(B.shape());
_B.data(0) = 3;
// Initialize op context.
OpContext op_context;
op_context.Init(&graph, ¶m);
DXCHECK_THROW(op_context.InitOp(std::vector<int>{0}, -1));
auto& _X = op_context.mutable_inst()->insert<tsr_t>(X.name());
_X.resize(X.shape());
op_context.InitForward();
// Input, forward, output.
auto compute = [&op_context, &_X, &Z](float_t x) {
_X.data(0) = x;
op_context.Forward();
const auto& _Z = op_context.hidden().get<tsr_t>(Z.name());
float_t z = _Z.data(0);
std::cout << "Z=" << z << std::endl;
};
compute(1);
compute(2);
compute(3);
compute(10);
return 0;
}
};
} // namespace
} // namespace deepx_core
// Process entry point: delegates to the example's Main::main() and uses its
// return value as the process exit status.
int main() {
  const int status = deepx_core::Main::main();
  return status;
}