-
Notifications
You must be signed in to change notification settings - Fork 236
/
Copy pathlocal_training_backing.h
53 lines (43 loc) · 2.02 KB
/
local_training_backing.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
// Include guard renamed: the previous macro began with `_` + uppercase,
// which is an identifier reserved to the implementation ([lex.name]).
#ifndef FLEXFLOW_LOCAL_EXECUTION_LOCAL_TRAINING_BACKING_H
#define FLEXFLOW_LOCAL_EXECUTION_LOCAL_TRAINING_BACKING_H

#include "local-execution/local_slots_backing.h"
#include "local-execution/task_registry.h"
#include "op-attrs/ops/loss_functions/loss_attrs.dtg.h"
#include "pcg/computation_graph.dtg.h"
#include "pcg/optimizer_attrs.dtg.h"

// Include what we use: this header names std::optional and
// std::unordered_map directly; don't rely on transitive includes.
#include <optional>
#include <unordered_map>

namespace FlexFlow {

/// Per-layer timing results: maps each layer to an optional elapsed
/// time (nullopt when no measurement was produced for that layer).
/// NOTE(review): unit (ms vs s) is not visible from this header — confirm
/// against the implementation before relying on it.
using PerLayerElapsedTime =
    std::unordered_map<layer_guid_t, std::optional<float>>;

/// Local (single-process) execution backing for training a
/// ComputationGraph: registers layer tasks, allocates their tensors via
/// the given Allocator, and drives the init/forward/loss/backward/update
/// phases layer by layer.
struct LocalTrainingBacking {
  /// @param allocated_forward_tensors  forward tensors already allocated
  ///        by the caller, keyed per layer
  /// @param allocated_non_graph_tensors  pre-allocated tensors that are
  ///        not part of the computation graph (e.g. labels)
  LocalTrainingBacking(Allocator const &,
                       ComputationGraph const &,
                       LayerTensorBackingMap const &allocated_forward_tensors,
                       TensorBackingMap const &allocated_non_graph_tensors,
                       RuntimeArgConfig const &);

  /// Register the layer's tasks in the task registry and allocate its
  /// backing tensors.
  void register_and_allocate_layer(layer_guid_t const &);

  /// Allocate the optimizer-state tensors (e.g. momentum buffers) that the
  /// given optimizer configuration requires for this layer.
  void allocate_layer_optimizer_tensors(layer_guid_t const &,
                                        OptimizerAttrs const &);

  /// Run the layer's init task (produces per-device state).
  void execute_init(layer_guid_t const &);

  /// Run the layer's forward task; returns the elapsed time, or nullopt
  /// when no timing is available (presumably when the layer has no
  /// forward task — confirm in the implementation).
  std::optional<float> execute_forward(layer_guid_t const &);

  /// Compute the loss from logits against labels, seeding the gradients
  /// consumed by the backward pass.
  void compute_loss(LossAttrs const &loss_attrs,
                    reduced_tensor_t const &logit_tensor,
                    reduced_tensor_t const &label_tensor);

  /// Run the layer's backward task; return value mirrors execute_forward.
  std::optional<float> execute_backward(layer_guid_t const &);

  /// Apply the optimizer update for the layer's weights.
  void execute_update(layer_guid_t const &, OptimizerAttrs const &);

  /// Build an accessor exposing the invocation's tensor/argument slots;
  /// the optional layer scopes the lookup when present.
  TaskArgumentAccessor
      get_task_arg_accessor(TaskInvocation const &,
                            std::optional<layer_guid_t> const &) const;

  /// Same as above for operator task invocations, which are always
  /// scoped to a concrete layer.
  TaskArgumentAccessor get_op_task_arg_accessor(OpTaskInvocation const &,
                                                layer_guid_t const &) const;

private:
  // Dispatch helpers: resolve a task id to its implementation and run it.
  DeviceSpecificDeviceStates call_init_task_impl(task_id_t,
                                                 TaskArgumentAccessor const &);
  std::optional<float> call_task_impl(task_id_t, TaskArgumentAccessor);

private:
  Allocator allocator;
  ComputationGraph computation_graph;
  TaskRegistry task_registry;
  LocalSlotsBacking local_slots_backing;
};

} // namespace FlexFlow

#endif // FLEXFLOW_LOCAL_EXECUTION_LOCAL_TRAINING_BACKING_H