-
Notifications
You must be signed in to change notification settings - Fork 54
/
Copy pathexample4_main.cc
66 lines (56 loc) · 1.8 KB
/
example4_main.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
// Copyright 2019 the deepx authors.
// Author: Yafei Zhang ([email protected])
//
#include <deepx_core/dx_log.h>
#include <deepx_core/graph/graph.h>
#include <deepx_core/graph/graph_node.h>
#include <deepx_core/graph/op_context.h>
#include <deepx_core/graph/tensor_map.h>
#include <deepx_core/tensor/data_type.h>
#include <iostream>
#include <random>
#include <vector>
namespace deepx_core {
namespace {

// Example 4: builds the tiny graph Z = X * W + B, randomly initializes the
// variables W and B, then runs three forward passes on randomly generated
// instance tensors with 2, 3 and 4 rows and prints the hidden tensors XW and Z.
//
// Main derives from DataType so that the tensor alias tsr_t is usable
// unqualified inside the class.
class Main : public DataType {
 public:
  // Runs the example and returns 0. DXCHECK_THROW propagates a failure of
  // graph compilation or op initialization (presumably as an exception, per
  // the macro name — confirm in dx_log.h).
  static int main() {
    // Default-constructed engine: same default seed on every run, so the
    // printed numbers are reproducible.
    std::default_random_engine engine;
    Graph graph;
    TensorMap param;

    // Initialize graph: Z = X * W + B.
    // X has a -1 first dimension; the loop below feeds 2, 3 and 4 rows, so
    // -1 presumably marks a variable (batch) dimension.
    InstanceNode X("X", Shape(-1, 10), TENSOR_TYPE_TSR);
    VariableNode W("W", Shape(10, 1), TENSOR_TYPE_TSR);
    VariableNode B("B", Shape(1), TENSOR_TYPE_TSR);
    MatmulNode XW("XW", &X, &W);
    BroadcastAddNode Z("Z", &XW, &B);
    // XW and Z are the compile targets; the meaning of the trailing 0 is
    // defined by Graph::Compile (NOTE(review): confirm).
    DXCHECK_THROW(graph.Compile({&XW, &Z}, 0));

    // Initialize param: allocate each variable with its node's shape and fill
    // it via randn(engine).
    // Names avoid a leading underscore followed by an uppercase letter
    // (e.g. `_W`); C++ reserves such identifiers in every scope.
    auto& W_value = param.insert<tsr_t>(W.name());
    W_value.resize(W.shape());
    W_value.randn(engine);
    auto& B_value = param.insert<tsr_t>(B.name());
    B_value.resize(B.shape());
    B_value.randn(engine);

    // Initialize op context over the compiled graph and the parameters.
    // The original text had the mis-encoded token `¶m` here (an HTML
    // `&para;` artifact), which does not compile; the intended argument is
    // the address of `param`.
    OpContext op_context;
    op_context.Init(&graph, &param);
    // {0, 1} presumably selects the two compile targets (XW, Z); the meaning
    // of -1 is defined by OpContext::InitOp (NOTE(review): confirm).
    DXCHECK_THROW(op_context.InitOp(std::vector<int>{0, 1}, -1));

    // Input, forward, output: each iteration feeds a fresh random X with a
    // different row count (2, 3, 4) and 10 columns to match Shape(-1, 10).
    for (int i = 0; i < 3; ++i) {
      auto& X_value = op_context.mutable_inst()->insert<tsr_t>(X.name());
      X_value.resize(2 + i, 10);
      X_value.randn(engine);
      op_context.InitForward();
      op_context.Forward();
      const auto& XW_value = op_context.hidden().get<tsr_t>(XW.name());
      const auto& Z_value = op_context.hidden().get<tsr_t>(Z.name());
      std::cout << "XW=" << XW_value << std::endl;
      std::cout << "Z=" << Z_value << std::endl;
    }
    return 0;
  }
};

}  // namespace
}  // namespace deepx_core
// Program entry point: delegates to the example driver and uses its status
// as the process exit code.
int main() {
  const int exit_code = deepx_core::Main::main();
  return exit_code;
}