#pragma once
/*
  Nothing particularly complex here, just a way to construct training graphs
  for NNC.

  Skips all layers above NNC on the stack, which is useful for performance
  ablation studies.
*/
#include <torch/csrc/jit/tensorexpr/tensor.h>

#include <functional>
#include <list>
#include <map>
#include <string>
#include <tuple>
#include <vector>
// Virtual "graph" for testing/benchmarking full training in NNC
struct VTensor; // Virtual tensors with symbolic shapes
struct VOp; // Virtual operators
struct VGraph; // Owner of the bipartite VTensor/VOp graph

// VOps reference VMethods, or "virtual" methods, which store:
//  1) a TensorExpr construction function (for lowering)
//  2) a grad construction function (for differentiating)
//  3) shape functions (TODO: this actually comes for free from TE)
struct VMethod;
// Utility for graph construction by op
std::vector<VTensor*> call(
    const std::string& name,
    const std::vector<VTensor*>& vs);
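
// A minimal usage sketch (assumes a method named "mul" has been registered
// via REGISTER_METHOD below; all names here are illustrative):
//
//   VGraph g;
//   VTensor* a = g.create_tensor({"N"});
//   VTensor* b = g.create_tensor({"N"});
//   VTensor* c = call("mul", {a, b})[0]; // adds a VOp plus its output VTensor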

// Utility for graph construction by differentiation; j is the incoming
// (seed) gradient
VTensor* grad(VTensor* y, VTensor* x, VTensor* j);
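
// Differentiation sketch, continuing the example above (assuming `one` is a
// seed tensor of matching shape; names are illustrative):
//
//   VTensor* one = g.create_tensor({"N"});
//   VTensor* da = grad(c, a, one); // builds the d(c)/d(a) subgraph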

// Render the graph as a Graphviz "dot" string (useful for debugging)
std::string dot(const VGraph& g);

// Lower the virtual graph to a TensorExpr Stmt, plus the bookkeeping needed
// to run it: Placeholders for graph inputs, the lowered Tensor for each
// VTensor, and a VarHandle for each symbolic dimension name.
std::tuple<
    torch::jit::tensorexpr::Stmt*,
    std::map<const VTensor*, torch::jit::tensorexpr::Placeholder>,
    std::map<const VTensor*, torch::jit::tensorexpr::Tensor*>,
    std::map<std::string, torch::jit::tensorexpr::VarHandle>>
to_tensorexpr(const VGraph& graph, std::vector<VTensor*> outputs = {});
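
// Lowering sketch (the structured bindings mirror the tuple above; handing
// the Stmt to a concrete TensorExpr codegen is up to the caller):
//
//   auto [stmt, placeholders, lowered, dim_vars] = to_tensorexpr(g, {c});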
/* IMPL */

struct VMethod {
  // (lowered inputs, virtual inputs, dim name -> VarHandle) -> lowered outputs
  using LowerFn = std::function<std::vector<torch::jit::tensorexpr::Tensor*>(
      const std::vector<torch::jit::tensorexpr::Tensor*>&,
      const std::vector<VTensor*>&,
      const std::map<std::string, torch::jit::tensorexpr::VarHandle>&)>;
  // (inputs, incoming grads) -> grads w.r.t. each input
  using GradFn = std::function<std::vector<VTensor*>(
      const std::vector<VTensor*>&,
      const std::vector<VTensor*>&)>;
  // inputs -> symbolic output shapes
  using ShapeFn = std::function<std::vector<std::vector<std::string>>(
      const std::vector<VTensor*>&)>;

  // Look up a registered method by name (see the registration sketch after
  // REGISTER_METHOD below)
  static const VMethod& get(const std::string& name);

  LowerFn lower;
  GradFn grad;
  ShapeFn shape;
  std::string name;
  size_t num_outputs;
};

struct VTensor {
  VTensor(std::vector<std::string> shape_) : shape(std::move(shape_)) {}
  std::vector<std::string> shape; // symbolic dimension names, e.g. {"N", "C"}
  VOp* op = nullptr; // producing op; nullptr marks a graph input
  std::vector<VOp*> consumers;
  VGraph* graph = nullptr;
};

struct VOp {
  VOp(const std::string& method_name,
      const std::vector<VTensor*>& inputs_,
      size_t num_outputs,
      VGraph* graph_);
  std::vector<VTensor*> inputs = {};
  std::vector<VTensor*> outputs = {};
  const VMethod* method = nullptr;
  VGraph* graph = nullptr;
};

struct VGraph {
  inline VTensor* create_tensor(std::vector<std::string> dims) {
    vtensors.emplace_back(std::move(dims));
    auto* v = &vtensors.back();
    v->graph = this;
    return v;
  }

  inline VOp* create_op(
      std::string method,
      const std::vector<VTensor*>& inputs,
      size_t num_outputs) {
    vops.emplace_back(method, inputs, num_outputs, this);
    auto* o = &vops.back();
    o->graph = this;
    return o;
  }

  // std::list keeps element addresses stable, so VTensor*/VOp* handles stay
  // valid as the graph grows.
  std::list<VTensor> vtensors;
  std::list<VOp> vops;
};
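
// Construction sketch: create_tensor makes a node that starts as a graph
// input (op == nullptr); create_op is normally reached through call(), which
// looks up the VMethod by name. The "relu" name below is illustrative:
//
//   VGraph g;
//   VTensor* x = g.create_tensor({"B", "F"}); // symbolic dims "B" and "F"
//   VOp* op = g.create_op("relu", {x}, /*num_outputs=*/1);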

// Registers a VMethod at static-initialization time; used by the
// REGISTER_METHOD macro below.
class RegMethod {
 public:
  RegMethod(
      std::string name,
      VMethod::LowerFn lower,
      VMethod::GradFn grad,
      VMethod::ShapeFn shape,
      size_t num_out = 1);
};

#define REGISTER_METHOD(name, ...) \
  static RegMethod _reg_method_##name(#name, __VA_ARGS__);
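
// Registration sketch for a hypothetical element-wise "add" method (the
// TensorExpr lowering body is elided; the lambda signatures follow the
// aliases in VMethod above):
//
//   REGISTER_METHOD(
//       add,
//       [](const std::vector<torch::jit::tensorexpr::Tensor*>& inputs,
//          const std::vector<VTensor*>& vinputs,
//          const std::map<std::string, torch::jit::tensorexpr::VarHandle>&
//              dims) -> std::vector<torch::jit::tensorexpr::Tensor*> {
//         /* build and return the lowered output Tensor(s) here */
//       },
//       [](const std::vector<VTensor*>& inputs,
//          const std::vector<VTensor*>& ginputs) -> std::vector<VTensor*> {
//         // d(a + b)/da = d(a + b)/db = 1, so the seed grad flows to both
//         return {ginputs[0], ginputs[0]};
//       },
//       [](const std::vector<VTensor*>& inputs) {
//         return std::vector<std::vector<std::string>>{inputs[0]->shape};
//       });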