diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index d3831057b..dcd78e395 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -145,6 +145,9 @@ if(BUILD_PROTOBUF)
     add_executable(onnx_gradients "onnx/4_onnx_test_gradients.cpp")
     target_link_libraries(onnx_gradients eddl)
 
+    add_executable(onnx_gradients_recurrent "onnx/7_onnx_test_gradients_recurrent.cpp")
+    target_link_libraries(onnx_gradients_recurrent eddl)
+
     add_executable(onnx_import_reshape "onnx/5_onnx_import_net_and_reshape.cpp")
     target_link_libraries(onnx_import_reshape eddl)
 
@@ -303,4 +306,4 @@ add_executable(test3 "test_internals/test3.cpp")
 target_link_libraries(test3 eddl)
 
 add_executable(test4 "test_internals/test4.cpp")
-target_link_libraries(test4 eddl)
\ No newline at end of file
+target_link_libraries(test4 eddl)
diff --git a/examples/applications/app_test_model.cpp b/examples/applications/app_test_model.cpp
index 8d9b4aabc..a9a68ccc2 100644
--- a/examples/applications/app_test_model.cpp
+++ b/examples/applications/app_test_model.cpp
@@ -6,10 +6,10 @@
 
 using namespace eddl;
 
-Tensor* preprocess_input(Tensor* input, const vector<int> &target_size, bool normalize=true, bool standarize=true){
+Tensor* preprocess_input(Tensor* input, const vector<int> &target_size, const vector<float>& mean, const vector<float>& std, bool normalize=true, bool standarize=true, const string& channels_order="rgb"){
     // Define preprocessing constants
-    auto* mean_vec = new Tensor( {0.485, 0.456, 0.406}, {3}, input->device);
-    auto* std_vec = new Tensor( {0.229, 0.224, 0.225}, {3}, input->device);
+    auto* mean_vec = new Tensor( mean, {3}, input->device);
+    auto* std_vec = new Tensor( std, {3}, input->device);
 
     // ==========================================================================
     // ====== SANITY CHECKS =====================================================
@@ -51,6 +51,27 @@ Tensor* preprocess_input(Tensor* input, const vector<int> &target_size, bool nor
     }
     // ==========================================================================
 
+    // ==========================================================================
+    // ====== RE-ORDER CHANNELS =================================================
+    // ==========================================================================
+
+    if (channels_order == "bgr"){
+        // Take each channel of the image
+        Tensor *r_channel = new_input->select({":", "0", ":", ":"});
+        Tensor *g_channel = new_input->select({":", "1", ":", ":"});
+        Tensor *b_channel = new_input->select({":", "2", ":", ":"});
+
+        // Concat the channels reordering them
+        Tensor *bgr_input = Tensor::concat({b_channel, g_channel, r_channel}, 1);
+        Tensor::copy(bgr_input, new_input);
+
+        delete bgr_input;
+        delete r_channel;
+        delete g_channel;
+        delete b_channel;
+    }
+
+    // ==========================================================================
+
     // Free memory
     delete mean_vec;
     delete std_vec;
@@ -69,12 +92,12 @@ int main(int argc, char **argv) {
     string class_names_file = "../../examples/data/imagenet_class_names.txt";
 
     // Image Classification
-    string model_path = "models/resnet34-v1-7.onnx";  // 3x224x224 // okay
-//    string model_path = "models/mobilenetv2-7.onnx";  // 3x224x224 // Signal: SIGSEGV (Segmentation fault)
+//    string model_path = "models/resnet34-v1-7.onnx";  // 3x224x224 // okay
+    string model_path = "models/mobilenetv2-7.onnx";  // 3x224x224 // Signal: SIGSEGV (Segmentation fault)
 //    string model_path = "models/vgg16-7.onnx";  // 3xHxW // okay
 //    string model_path = "models/bvlcalexnet-3.onnx";  // 3x224x224 // The onnx node 'LRN' is not supported yet
-//    string model_path = "models/bvlcalexnet-12.onnx";  // 3x224x224 // The onnx node 'LRN' is not supported yet
-//    string model_path = "models/googlenet-3_simp.onnx";  // 3x224x224 // The onnx node 'LRN' is not supported yet
+//    string model_path = "models/bvlcalexnet-12.onnx";  // 3x224x224 // okay
+//    string model_path = "models/googlenet-3.onnx";  // 3x224x224 // **okay**. bad predictions
 //    string model_path = "models/densenet-3.onnx";  // 3x224x224 // okay
 //    string model_path = "models/inception-v1-3.onnx";  // 3x224x224 // The onnx node 'LRN' is not supported yet
 //    string model_path = "models/efficientnet-lite4-11.onnx";  // 224x224x3 // The onnx node 'LRN' is not supported yet
@@ -98,8 +121,14 @@ int main(int argc, char **argv) {
     int in_channels = 3;
     int in_height = 224;
     int in_width = 224;
+    string channels_order = "rgb";
     vector<int> input_shape = {in_channels, in_height, in_width};
+    vector<float> mean = {0.485, 0.456, 0.406};
+    vector<float> std = {0.229, 0.224, 0.225};
+    bool normalize = true;   // Between [0..1]?
+    bool standarize = true;  // X = (X-mean)/std
 //    vector<int> dimensions_order = {0, 3, 1, 2};
+//
 //    // ==========================================================================
 //    input_shape = {input_shape[dimensions_order[1]], input_shape[dimensions_order[2]], input_shape[dimensions_order[3]]};
@@ -109,7 +138,7 @@ int main(int argc, char **argv) {
 
     // Import ONNX model
     std::cout << "Importing ONNX..." << std::endl;
-    Net *net = import_net_from_onnx_file(model_path, input_shape);
+    Net *net = import_net_from_onnx_file(model_path, input_shape, 0, LOG_LEVEL::DEBUG);
 
     // ==========================================================================
     // Print and plot our model
@@ -152,7 +181,7 @@ int main(int argc, char **argv) {
     Tensor *image = Tensor::load(image_fname);
 
     // Step 3: Preprocess input. (Look up the preprocessing required at the model's page)
-    Tensor* image_preprocessed = preprocess_input(image, {in_height, in_width});
+    Tensor* image_preprocessed = preprocess_input(image, {in_height, in_width}, mean, std, normalize, standarize, channels_order);
 //    image_preprocessed->permute_(dimensions_order);
 
     // Predict image. Returns a vector of tensors (here one).
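Note on the new preprocess_input API: the mean/std statistics and the channel order are now parameters instead of hard-coded ImageNet constants. A minimal usage sketch, assuming only the helper defined above (the statistics and file name here are illustrative, not taken from any model card):

    vector<float> mean = {0.485, 0.456, 0.406};
    vector<float> std  = {0.229, 0.224, 0.225};
    Tensor *image = Tensor::load("input.jpg");
    // "bgr" swaps the R and B channels via the select/concat block above,
    // which is what models exported from BGR pipelines (e.g. Caffe-style) expect
    Tensor *ready = preprocess_input(image, {224, 224}, mean, std,
                                     true /*normalize*/, true /*standarize*/, "bgr");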
diff --git a/examples/onnx/4_onnx_test_gradients.cpp b/examples/onnx/4_onnx_test_gradients.cpp
index 1e0e66351..4236fd282 100644
--- a/examples/onnx/4_onnx_test_gradients.cpp
+++ b/examples/onnx/4_onnx_test_gradients.cpp
@@ -28,8 +28,10 @@ int main(int argc, char **argv) {
   bool export_cpu = false;
   bool import_cpu = false;
   for (int i = 1; i < argc; ++i) {
-    if (strcmp(argv[i], "--export-cpu") == 0) export_cpu = true;
-    else if (strcmp(argv[i], "--import-cpu") == 0) import_cpu = true;
+    if (strcmp(argv[i], "--export-cpu") == 0)
+      export_cpu = true;
+    else if (strcmp(argv[i], "--import-cpu") == 0)
+      import_cpu = true;
   }
 
   // Download mnist
@@ -52,9 +54,9 @@ int main(int argc, char **argv) {
   l = ReLu(Conv(l, 32, {3, 3}, {1, 1}));
   l = MaxPool(l, {2, 2});
 
-  l = Reshape(l, {-1});
+  l = Flatten(l);
   l = Dense(l, 128, false);
-  layer out = Activation(Dense(l, num_classes), "softmax");
+  layer out = Softmax(Dense(l, num_classes));
 
   cout << "Creating model" << endl;
   model net = Model({in}, {out});
@@ -63,7 +65,7 @@ int main(int argc, char **argv) {
   // Build model
   cout << "Building the model" << endl;
   build(net,
-        rmsprop(0.01),            // Optimizer
+        adam(0.001),              // Optimizer
         {"soft_cross_entropy"},   // Losses
         {"categorical_accuracy"}, // Metrics
         export_CS,                // Computing service
@@ -112,13 +114,9 @@ int main(int argc, char **argv) {
   // Export trained model
   void *serialized_net_once_trained;
   cout << "Exporting trained weights" << endl;
-  size_t snot_size = serialize_net_to_onnx_pointer(net, serialized_net_once_trained, false);
+  size_t snet_size = serialize_net_to_onnx_pointer(net, serialized_net_once_trained, false);
   cout << "Trained weights exported" << endl;
 
-  // Reset the counter of the layers index
-  LConv::reset_name_counter();
-  LDense::reset_name_counter();
-
   // Import net topology without trained weights
   cout << "Importing original net topology (without training)" << endl;
   Net *imported_net = import_net_from_onnx_pointer(serialized_net, model_size);
@@ -127,7 +125,7 @@ int main(int argc, char **argv) {
   // Build model
   cout << "Building the loaded topology" << endl;
   build(imported_net,
-        rmsprop(0.01),            // Optimizer
+        adam(0.001),              // Optimizer
         {"soft_cross_entropy"},   // Losses
         {"categorical_accuracy"}, // Metrics
         import_CS,                // Computing service
@@ -153,7 +151,7 @@ int main(int argc, char **argv) {
 
   // Set trained weights
   cout << "Putting the trained weights" << endl;
-  set_weights_from_onnx_pointer(imported_net, serialized_net_once_trained, snot_size);
+  set_weights_from_onnx_pointer(imported_net, serialized_net_once_trained, snet_size);
   cout << "Trained weights set" << endl;
 
   // Evaluate with trained weights
diff --git a/examples/onnx/7_onnx_test_gradients_recurrent.cpp b/examples/onnx/7_onnx_test_gradients_recurrent.cpp
new file mode 100644
index 000000000..8dbdb0ac9
--- /dev/null
+++ b/examples/onnx/7_onnx_test_gradients_recurrent.cpp
@@ -0,0 +1,164 @@
+/*
+ * EDDL Library - European Distributed Deep Learning Library.
+ * Version: 1.0
+ * copyright (c) 2021, Universitat Politècnica de València (UPV), PRHLT Research
+ * Centre Date: November 2021 Author: PRHLT Research Centre, UPV,
+ * (rparedes@prhlt.upv.es), (jon@prhlt.upv.es) All rights reserved
+ */
+
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+
+#include "eddl/apis/eddl.h"
+
+#include "eddl/serialization/onnx/eddl_onnx.h" // Not allowed
+
+using namespace eddl;
+
+///////////////////////////////////////////
+// 7_onnx_test_gradients_recurrent.cpp:
+// An example on how to use the functions
+// for exporting weights and gradients
+// using the ONNX format, with a recurrent
+// network
+///////////////////////////////////////////
+
+int main(int argc, char **argv) {
+  // Read arguments
+  bool export_cpu = false;
+  bool import_cpu = false;
+  for (int i = 1; i < argc; ++i) {
+    if (strcmp(argv[i], "--export-cpu") == 0)
+      export_cpu = true;
+    else if (strcmp(argv[i], "--import-cpu") == 0)
+      import_cpu = true;
+  }
+
+  // Download dataset
+  download_imdb_2000();
+
+  // Settings
+  int epochs = 2;
+  int batch_size = 64;
+  CompServ *export_CS = export_cpu ? CS_CPU() : CS_GPU({1});
+  CompServ *import_CS = import_cpu ? CS_CPU() : CS_GPU({1});
+
+  int length = 250;
+  int embed_dim = 33;
+  int vocsize = 2000;
+
+  // Define network
+  layer in = Input({1}); // 1 word
+  layer l = in;
+
+  layer l_embed = RandomUniform(Embedding(l, vocsize, 1, embed_dim), -0.05, 0.05);
+
+  l = LSTM(l_embed, 37);
+  l = ReLu(Dense(l, 256));
+  layer out = Sigmoid(Dense(l, 1));
+
+  cout << "Creating model" << endl;
+  model net = Model({in}, {out});
+  cout << "Model created" << endl;
+
+  // Build model
+  cout << "Building the model" << endl;
+  build(net,
+        adam(0.001),              // Optimizer
+        {"binary_cross_entropy"}, // Losses
+        {"binary_accuracy"},      // Metrics
+        export_CS,                // Computing service
+        true                      // Enable parameters initialization
+  );
+  cout << "Model is correctly built" << endl;
+
+  cout << "Enabling distributed training" << endl;
+  net->enable_distributed();
+  cout << "Distributed training enabled" << endl;
+
+  // Export the net before training
+  void *serialized_net;
+  cout << "Serializing net (without training) to pointer" << endl;
+  size_t model_size = serialize_net_to_onnx_pointer(net, serialized_net, false);
+  cout << "Net serialized to pointer" << endl;
+
+  // View model
+  summary(net);
+
+  // Load dataset
+  Tensor *x_train = Tensor::load("imdb_2000_trX.bin");
+  Tensor *y_train = Tensor::load("imdb_2000_trY.bin");
+  Tensor *x_test = Tensor::load("imdb_2000_tsX.bin");
+  Tensor *y_test = Tensor::load("imdb_2000_tsY.bin");
+
+  x_train->reshape_({x_train->shape[0], length, 1}); // batch x timesteps x input_dim
+  x_test->reshape_({x_test->shape[0], length, 1});   // batch x timesteps x input_dim
+
+  y_train->reshape_({y_train->shape[0], 1, 1}); // batch x timesteps x input_dim
+  y_test->reshape_({y_test->shape[0], 1, 1});   // batch x timesteps x input_dim
+
+  // Train model
+  cout << "Training the first model" << endl;
+  fit(net, {x_train}, {y_train}, batch_size, epochs);
+
+  // Evaluate
+  cout << "Evaluating the first model" << endl;
+  evaluate(net, {x_test}, {y_test}, batch_size);
+
+  // Export gradients
+  void *serialized_gradients;
+  string path("mnist.onnx");
+  cout << "Exporting gradients" << endl;
+  size_t gradients_size = serialize_net_to_onnx_pointer(net, serialized_gradients, true);
+  cout << "Gradients exported" << endl;
+
+  // Export trained model
+  void *serialized_net_once_trained;
+  cout << "Exporting trained weights" << endl;
+  size_t snet_size = serialize_net_to_onnx_pointer(net, serialized_net_once_trained, false);
+  cout << "Trained weights exported" << endl;
+
+  // Import net topology without trained weights
+  cout << "Importing original net topology (without training)" << endl;
+  Net *imported_net = import_net_from_onnx_pointer(serialized_net, model_size);
+  cout << "Untrained net imported" << endl;
+
+  // Build model
+  cout << "Building the loaded topology" << endl;
+  build(imported_net,
+        adam(0.001),              // Optimizer
+        {"binary_cross_entropy"}, // Losses
+        {"binary_accuracy"},      // Metrics
+        import_CS,                // Computing service
+        false                     // Disable parameters initialization
+  );
+  cout << "Model is correctly built" << endl;
+
+  // View loaded model
+  summary(imported_net);
+
+  // Evaluate with untrained model
+  cout << "Evaluating test with the untrained weights" << endl;
+  evaluate(imported_net, {x_test}, {y_test}, batch_size);
+
+  // Apply grads
+  cout << "Applying grads from training" << endl;
+  apply_grads_from_onnx_pointer(imported_net, serialized_gradients, gradients_size);
+  cout << "Grads applied" << endl;
+
+  // Evaluate net with accumulated gradients applied
+  cout << "Evaluating test after applying gradients" << endl;
+  evaluate(imported_net, {x_test}, {y_test}, batch_size);
+
+  // Set trained weights
+  cout << "Putting the trained weights" << endl;
+  set_weights_from_onnx_pointer(imported_net, serialized_net_once_trained, snet_size);
+  cout << "Trained weights set" << endl;
+
+  // Evaluate with trained weights
+  cout << "Evaluating test after putting the trained weights" << endl;
+  evaluate(imported_net, {x_test}, {y_test}, batch_size);
+
+  return 0;
+}
diff --git a/include/eddl/layers/core/layer_core.h b/include/eddl/layers/core/layer_core.h
index 037b8efa0..323365480 100644
--- a/include/eddl/layers/core/layer_core.h
+++ b/include/eddl/layers/core/layer_core.h
@@ -88,6 +88,7 @@ class LEmbedding : public LinLayer {
     bool mask_zeros;
     Tensor *E;
    Tensor *gE;
+    Tensor *acc_gE;
     vector<int> sind;
     static int total_layers;
 
@@ -103,6 +104,16 @@ class LEmbedding : public LinLayer {
 
     void backward() override;
 
+    void update_weights(vector<Tensor *> weights) override;
+
+    void accumulate_accumulated_gradients(vector<Tensor *> grads) override;
+
+    void reset_accumulated_gradients() override;
+
+    void apply_accumulated_gradients() override;
+
+    void enable_distributed() override;
+
     string plot(int c) override;
 };
diff --git a/include/eddl/layers/layer.h b/include/eddl/layers/layer.h
index 7f421424d..168c00375 100644
--- a/include/eddl/layers/layer.h
+++ b/include/eddl/layers/layer.h
@@ -197,10 +197,19 @@ class MLayer : public Layer {
 
     void backward() override {}
 
+    void update_weights(vector<Tensor *> weights) override {}
+
+    void accumulate_accumulated_gradients(vector<Tensor *> grads) override {}
+
+    void reset_accumulated_gradients() override {}
+
+    void apply_accumulated_gradients() override {}
+
     Layer *share(int c, int bs, vector<Layer *> p) override { return nullptr; }
 
     Layer *clone(int c, int bs, vector<Layer *> p, int todev) override { return nullptr; }
 
+    void enable_distributed() override {};
 };
 
 #endif //EDDL_LAYER_H
diff --git a/include/eddl/layers/recurrent/layer_recurrent.h b/include/eddl/layers/recurrent/layer_recurrent.h
index c0cf53b7a..3fc8d1830 100644
--- a/include/eddl/layers/recurrent/layer_recurrent.h
+++ b/include/eddl/layers/recurrent/layer_recurrent.h
@@ -77,11 +77,14 @@ class LRNN : public MLayer {
 
     Tensor *Wx;
     Tensor *gWx;
+    Tensor *acc_gWx;
     Tensor *bias;
     Tensor *gbias;
+    Tensor *acc_gbias;
 
     Tensor *Wy;
     Tensor *gWy;
+    Tensor *acc_gWy;
 
     Tensor *biasy;
 
@@ -97,7 +100,17 @@ class LRNN : public MLayer {
 
     void backward() override;
 
+    void update_weights(vector<Tensor *> weights) override;
+
+    void accumulate_accumulated_gradients(vector<Tensor *> grads) override;
+
+    void reset_accumulated_gradients() override;
+
+    void apply_accumulated_gradients() override;
+
     string plot(int c) override;
+
+    void enable_distributed() override;
 };
 
@@ -127,9 +140,15 @@ class LLSTM : public MLayer {
     Tensor *gWoh,*gWox;
     Tensor *gWch,*gWcx;
 
+    Tensor *acc_gWih,*acc_gWix;
+    Tensor *acc_gWfh,*acc_gWfx;
+    Tensor *acc_gWoh,*acc_gWox;
+    Tensor *acc_gWch,*acc_gWcx;
+
     Tensor *in,*fn,*on,*cn;
     Tensor *inbias,*fnbias,*onbias,*cnbias;
     Tensor *ginbias,*gfnbias,*gonbias,*gcnbias;
+    Tensor *acc_ginbias,*acc_gfnbias,*acc_gonbias,*acc_gcnbias;
 
     Tensor *incn,*cn1fn;
     Tensor *sh;
@@ -155,7 +174,17 @@ class LLSTM : public MLayer {
 
     void backward() override;
 
+    void update_weights(vector<Tensor *> weights) override;
+
+    void accumulate_accumulated_gradients(vector<Tensor *> grads) override;
+
+    void reset_accumulated_gradients() override;
+
+    void apply_accumulated_gradients() override;
+
     string plot(int c) override;
+
+    void enable_distributed() override;
 };
 
@@ -186,6 +215,12 @@ class LGRU : public MLayer {
     Tensor *gUn_h, *gWn_x;
     Tensor *g_bias_z_t, *g_bias_r_t, *g_bias_n_t, *g_bias_n_t_hidden;
 
+    // Accumulated gradient tensors for distributed training
+    Tensor *acc_gUz_h, *acc_gWz_x;
+    Tensor *acc_gUr_h, *acc_gWr_x;
+    Tensor *acc_gUn_h, *acc_gWn_x;
+    Tensor *acc_g_bias_z_t, *acc_g_bias_r_t, *acc_g_bias_n_t, *acc_g_bias_n_t_hidden;
+
     // Intermediate outputs of the cell
     Tensor *z_t, *r_t, *n_t;            // Gates outputs
     Tensor *n_t_hidden, *one_minus_z_t; // Gates interoperations
@@ -212,7 +247,17 @@ class LGRU : public MLayer {
 
     void backward() override;
 
+    void update_weights(vector<Tensor *> weights) override;
+
+    void accumulate_accumulated_gradients(vector<Tensor *> grads) override;
+
+    void reset_accumulated_gradients() override;
+
+    void apply_accumulated_gradients() override;
+
     string plot(int c) override;
+
+    void enable_distributed() override;
 };
 
 void reduced_abs_sum(Tensor * input, Tensor *output);
diff --git a/include/eddl/serialization/onnx/layers/conv/convT_onnx.h b/include/eddl/serialization/onnx/layers/conv/convT_onnx.h
index 8124c9f31..f1f579723 100644
--- a/include/eddl/serialization/onnx/layers/conv/convT_onnx.h
+++ b/include/eddl/serialization/onnx/layers/conv/convT_onnx.h
@@ -25,5 +25,13 @@ Layer* build_convT_layer(onnx::NodeProto *node,
 // OPSET: 11, 1
 void build_convT_node(LConvT2D *layer, onnx::GraphProto *graph, bool gradients);
 
+/*
+ * DISTRIBUTED TRAINING
+ */
+
+vector<Tensor *> get_convT_tensors(onnx::NodeProto &node,
+                                   map<string, vector<float>> &map_init_values,
+                                   map<string, vector<int>> &map_init_dims);
+
 #endif // EDDL_CONVT_ONNX_H
 #endif // cPROTO
diff --git a/include/eddl/serialization/onnx/layers/conv/conv_onnx.h b/include/eddl/serialization/onnx/layers/conv/conv_onnx.h
index 5460b2d60..d490444ad 100644
--- a/include/eddl/serialization/onnx/layers/conv/conv_onnx.h
+++ b/include/eddl/serialization/onnx/layers/conv/conv_onnx.h
@@ -31,10 +31,6 @@ void build_conv_node(LConv *layer, onnx::GraphProto *graph, bool gradients);
  * DISTRIBUTED TRAINING
  */
 
-void update_conv_weights(LConv *layer, vector<Tensor *> weights);
-
-void apply_grads_to_conv(LConv *layer, vector<Tensor *> grads);
-
 vector<Tensor *> get_conv_tensors(onnx::NodeProto &node,
                                   map<string, vector<float>> &map_init_values,
                                   map<string, vector<int>> &map_init_dims);
diff --git a/include/eddl/serialization/onnx/layers/core/bypass_onnx.h b/include/eddl/serialization/onnx/layers/core/bypass_onnx.h
new file mode 100644
index 000000000..fa19f0ff0
--- /dev/null
+++ b/include/eddl/serialization/onnx/layers/core/bypass_onnx.h
@@ -0,0 +1,25 @@
+#if defined(cPROTO)
+#ifndef EDDL_BYPASS_ONNX_H
+#define EDDL_BYPASS_ONNX_H
+#include "eddl/serialization/onnx/onnx.pb.h"
+#include "eddl/serialization/onnx/utils_onnx.h"
+#include "eddl/layers/core/layer_core.h"
+
+/*
+ * ONNX IMPORT
+ */
+
+Layer* build_lrn_layer(onnx::NodeProto *node,
+                       map<string, Layer *> &output_node_map,
+                       LOG_LEVEL log_level,
+                       int dev,
+                       int mem);
+/*
+ * ONNX EXPORT
+ */
+
+// OPSET: 16, 14, 13, 1
+void build_identity_node(LBypass *layer, onnx::GraphProto *graph);
+
+#endif // EDDL_BYPASS_ONNX_H
+#endif // cPROTO
diff --git a/include/eddl/serialization/onnx/layers/core/dense_onnx.h b/include/eddl/serialization/onnx/layers/core/dense_onnx.h
index f941d6312..f94e479fc 100644
--- a/include/eddl/serialization/onnx/layers/core/dense_onnx.h
+++ b/include/eddl/serialization/onnx/layers/core/dense_onnx.h
@@ -34,10 +34,6 @@ void build_dense_with_matmul_node(LDense *layer, onnx::GraphProto *graph, bool g
  * DISTRIBUTED TRAINING
 */
 
-void update_dense_weights(LDense *layer, vector<Tensor *> weights);
-
-void apply_grads_to_dense(LDense *layer, vector<Tensor *> grads);
-
 vector<Tensor *> get_dense_tensors(onnx::NodeProto &node,
                                    map<string, vector<float>> &map_init_values,
                                    map<string, vector<int>> &map_init_dims);
diff --git a/include/eddl/serialization/onnx/layers/core/embedding_onnx.h b/include/eddl/serialization/onnx/layers/core/embedding_onnx.h
index d22e30827..679394896 100644
--- a/include/eddl/serialization/onnx/layers/core/embedding_onnx.h
+++ b/include/eddl/serialization/onnx/layers/core/embedding_onnx.h
@@ -9,7 +9,15 @@
  */
 
 // Implemented with Gather Op for OPSET: 13, 11, 1
-void build_embedding_node(LEmbedding *layer, onnx::GraphProto *graph);
+void build_embedding_node(LEmbedding *layer, onnx::GraphProto *graph, bool gradients = false);
+
+/*
+ * DISTRIBUTED TRAINING
+ */
+
+vector<Tensor *> get_embedding_tensors(onnx::NodeProto &node,
+                                       map<string, vector<float>> &map_init_values,
+                                       map<string, vector<int>> &map_init_dims);
 
 #endif // EDDL_EMBEDDING_ONNX_H
 #endif // cPROTO
diff --git a/include/eddl/serialization/onnx/layers/layers_onnx.h b/include/eddl/serialization/onnx/layers/layers_onnx.h
index 793be9c0e..592c46d0b 100644
--- a/include/eddl/serialization/onnx/layers/layers_onnx.h
+++ b/include/eddl/serialization/onnx/layers/layers_onnx.h
@@ -69,7 +69,8 @@ enum ONNX_LAYERS {
   SPLIT,    // OPSET: 13, 11, 2
   EXPAND,   // OPSET: 13, 8
   CONSTANT, // OPSET: 13, 12, 11, 9, 1
-  REPEAT    // OPSET: 13, 6
+  REPEAT,   // OPSET: 13, 6
+  LRN       // Skipped with LBypass
 };
 
 map<string, ONNX_LAYERS> create_enum_map();
@@ -99,10 +100,6 @@ void build_node_from_layer(Layer *layer, onnx::GraphProto *graph, bool gradients
  * DISTRIBUTED TRAINING
 */
 
-void update_layer_weights(Layer *layer, vector<Tensor *> weights);
-
-void apply_grads_to_layer(Layer *layer, vector<Tensor *> grads);
-
 map<string, vector<Tensor *>> get_tensors_from_onnx_nodes(vector<onnx::NodeProto> &nodes,
                                                           map<string, vector<float>> &map_init_values,
                                                           map<string, vector<int>> &map_init_dims);
diff --git a/include/eddl/serialization/onnx/layers/merge/add_onnx.h b/include/eddl/serialization/onnx/layers/merge/add_onnx.h
index 1d0911d70..1dc97cd8c 100644
--- a/include/eddl/serialization/onnx/layers/merge/add_onnx.h
+++ b/include/eddl/serialization/onnx/layers/merge/add_onnx.h
@@ -24,5 +24,13 @@ Layer* build_add_layer(onnx::NodeProto *node,
 // OPSET: 13, 7
 void build_add_node(LAdd *layer, onnx::GraphProto *graph);
 
+/*
+ * DISTRIBUTED TRAINING
+ */
+
+vector<Tensor *> get_add_tensors(onnx::NodeProto &node,
+                                 map<string, vector<float>> &map_init_values,
+                                 map<string, vector<int>> &map_init_dims);
+
 #endif // EDDL_ADD_ONNX_H
 #endif // cPROTO
diff --git a/include/eddl/serialization/onnx/layers/merge/matmul_onnx.h b/include/eddl/serialization/onnx/layers/merge/matmul_onnx.h
index a33a2232a..b4b3a7d1d 100644
--- a/include/eddl/serialization/onnx/layers/merge/matmul_onnx.h
+++ b/include/eddl/serialization/onnx/layers/merge/matmul_onnx.h
@@ -18,10 +18,11 @@ Layer* build_matmul_layer(onnx::NodeProto *node,
                           int mem);
 
 /*
- * ONNX EXPORT
+ * DISTRIBUTED TRAINING
 */
 
-// TODO export of layer LMatMul
-
+vector<Tensor *> get_matmul_tensors(onnx::NodeProto &node,
+                                    map<string, vector<float>> &map_init_values,
+                                    map<string, vector<int>> &map_init_dims);
 
 #endif // EDDL_MATMUL_ONNX_H
 #endif // cPROTO
diff --git a/include/eddl/serialization/onnx/layers/onnx_nodes/onnx_node_conversion.h b/include/eddl/serialization/onnx/layers/onnx_nodes/onnx_node_conversion.h
index 3af18c535..4b6c82eb3 100644
--- a/include/eddl/serialization/onnx/layers/onnx_nodes/onnx_node_conversion.h
+++ b/include/eddl/serialization/onnx/layers/onnx_nodes/onnx_node_conversion.h
@@ -43,7 +43,7 @@ void build_identity_node(string node_name, string input, string output, onnx::Gr
 void build_cast_node(string node_name, string input, string output, int cast_type, onnx::GraphProto *graph);
 
 // OPSET: 13, 11, 1
-void build_gather_node(string node_name, string input, string output, LEmbedding *layer, onnx::GraphProto *graph);
+void build_gather_node(string node_name, string input, string output, LEmbedding *layer, onnx::GraphProto *graph, bool gradients = false);
 
 #endif // EDDL_EXPORT_NODES_H
 #endif // cPROTO
diff --git a/include/eddl/serialization/onnx/layers/pool/avgpool_onnx.h b/include/eddl/serialization/onnx/layers/pool/avgpool_onnx.h
index 229dc9514..91e9929ec 100644
--- a/include/eddl/serialization/onnx/layers/pool/avgpool_onnx.h
+++ b/include/eddl/serialization/onnx/layers/pool/avgpool_onnx.h
@@ -3,6 +3,7 @@
 #define EDDL_AVGPOOL_ONNX_H
 #include "eddl/serialization/onnx/onnx.pb.h"
 #include "eddl/layers/pool/layer_pool.h"
+#include "eddl/serialization/onnx/utils_onnx.h"
 
 /*
  * ONNX EXPORT
@@ -11,6 +12,7 @@
 // OPSET: 11, 10, 7, 1
 Layer* build_averagepool_layer(onnx::NodeProto *node,
                                map<string, Layer *> &output_node_map,
+                               LOG_LEVEL log_level,
                                int dev,
                                int mem);
diff --git a/include/eddl/serialization/onnx/layers/pool/maxpool_onnx.h b/include/eddl/serialization/onnx/layers/pool/maxpool_onnx.h
index ae43dd262..f14a2fccb 100644
--- a/include/eddl/serialization/onnx/layers/pool/maxpool_onnx.h
+++ b/include/eddl/serialization/onnx/layers/pool/maxpool_onnx.h
@@ -3,6 +3,7 @@
 #define EDDL_MAXPOOL_ONNX_H
 #include "eddl/serialization/onnx/onnx.pb.h"
 #include "eddl/layers/pool/layer_pool.h"
+#include "eddl/serialization/onnx/utils_onnx.h"
 
 /*
  * ONNX IMPORT
@@ -11,6 +12,7 @@
 // OPSET: 12, 11, 10, 8, 1
 Layer* build_maxpool_layer(onnx::NodeProto *node,
                            map<string, Layer *> &output_node_map,
+                           LOG_LEVEL log_level,
                            int dev,
                            int mem);
diff --git a/include/eddl/serialization/onnx/layers/recurrent/gru_onnx.h b/include/eddl/serialization/onnx/layers/recurrent/gru_onnx.h
index 5b07f7636..02cc9af3f 100644
--- a/include/eddl/serialization/onnx/layers/recurrent/gru_onnx.h
+++ b/include/eddl/serialization/onnx/layers/recurrent/gru_onnx.h
@@ -25,7 +25,15 @@ Layer* build_gru_layer(onnx::NodeProto *node,
 */
 
 // OPSET: 7, 3, 1
-void build_gru_node(LGRU *layer, onnx::GraphProto *graph);
+void build_gru_node(LGRU *layer, onnx::GraphProto *graph, bool gradients = false);
+
+/*
+ * DISTRIBUTED TRAINING
+ */
+
+vector<Tensor *> get_gru_tensors(onnx::NodeProto &node,
+                                 map<string, vector<float>> &map_init_values,
+                                 map<string, vector<int>> &map_init_dims);
 
 #endif // EDDL_GRU_ONNX_H
 #endif // cPROTO
diff --git a/include/eddl/serialization/onnx/layers/recurrent/lstm_onnx.h b/include/eddl/serialization/onnx/layers/recurrent/lstm_onnx.h
index 44c5dc651..9e32c7e17 100644
--- a/include/eddl/serialization/onnx/layers/recurrent/lstm_onnx.h
+++ b/include/eddl/serialization/onnx/layers/recurrent/lstm_onnx.h
@@ -25,7 +25,15 @@ Layer* build_lstm_layer(onnx::NodeProto *node,
 */
 
 // OPSET: 7, 1
-void build_lstm_node(LLSTM *layer, onnx::GraphProto *graph);
+void build_lstm_node(LLSTM *layer, onnx::GraphProto *graph, bool gradients = false);
+
+/*
+ * DISTRIBUTED TRAINING
+ */
+
+vector<Tensor *> get_lstm_tensors(onnx::NodeProto &node,
+                                  map<string, vector<float>> &map_init_values,
+                                  map<string, vector<int>> &map_init_dims);
 
 #endif // EDDL_LSTM_ONNX_H
 #endif // cPROTO
diff --git a/include/eddl/serialization/onnx/layers/recurrent/rnn_onnx.h b/include/eddl/serialization/onnx/layers/recurrent/rnn_onnx.h
index 3fb57da2a..04e1e23b8 100644
--- a/include/eddl/serialization/onnx/layers/recurrent/rnn_onnx.h
+++ b/include/eddl/serialization/onnx/layers/recurrent/rnn_onnx.h
@@ -25,7 +25,15 @@ Layer* build_rnn_layer(onnx::NodeProto *node,
 */
 
 // OPSET: 7, 1
-void build_rnn_node(LRNN *layer, onnx::GraphProto *graph);
+void build_rnn_node(LRNN *layer, onnx::GraphProto *graph, bool gradients = false);
+
+/*
+ * DISTRIBUTED TRAINING
+ */
+
+vector<Tensor *> get_rnn_tensors(onnx::NodeProto &node,
+                                 map<string, vector<float>> &map_init_values,
+                                 map<string, vector<int>> &map_init_dims);
 
 #endif // EDDL_RNN_ONNX_H
 #endif // cPROTO
diff --git a/src/layers/core/layer_embedding.cpp b/src/layers/core/layer_embedding.cpp
index ce61b7ca2..f54000d6b 100644
--- a/src/layers/core/layer_embedding.cpp
+++ b/src/layers/core/layer_embedding.cpp
@@ -61,6 +61,8 @@ LEmbedding::LEmbedding(Layer *parent, int vocsize, int length, int dim, bool mas
     gE=new Tensor({vocsize,dim},dev);
     gradients.push_back(gE);
 
+    distributed_training = false;
+    acc_gE = nullptr;
 
     parent->addchild(this);
     addparent(parent);
@@ -123,8 +125,39 @@ void LEmbedding::backward()
   }
 }
 
+void LEmbedding::update_weights(vector<Tensor *> weights) {
+  if (weights.size() == 1) {
+    Tensor::copy(weights[0], E);
+  } else {
+    cerr << "[WARNING - LEmbedding::update_weights] "
+         << "Unexpected number of weights tensors received "
+         << "(weights.size()=" << weights.size() << ")" << endl;
+  }
+}
+
+void LEmbedding::accumulate_accumulated_gradients(vector<Tensor *> grads) {
+  if (grads.size() == 1) {
+    E->add_(grads[0]);
+  } else {
+    cerr << "[WARNING - LEmbedding::accumulate_accumulated_gradients] "
+         << "Unexpected number of gradient tensors received "
+         << "(grads.size()=" << grads.size() << ")" << endl;
+  }
+}
+
+void LEmbedding::reset_accumulated_gradients() {
+  acc_gE->fill_(0.0);
+}
+
+void LEmbedding::apply_accumulated_gradients() {
+  E->add_(acc_gE);
+}
+
+void LEmbedding::enable_distributed() {
+  distributed_training = true;
+  acc_gE = new Tensor({vocsize, dim}, dev);
+  acc_gradients.push_back(acc_gE);
+}
 
 Layer *LEmbedding::share(int c, int bs, vector<Layer *> p) {
     LEmbedding *n = new LEmbedding(p[0],vocsize, length, dim, mask_zeros, "share_"+to_string(c)+this->name, this->dev, this->mem_level);
@@ -148,6 +181,12 @@ Layer *LEmbedding::share(int c, int bs, vector<Layer *> p) {
     n->params.push_back(this->E);
     n->gradients.push_back(this->gE);
 
+    if (distributed_training) {
+        n->acc_gradients.clear();
+        n->acc_gE = acc_gE;
+        n->acc_gradients.push_back(acc_gE);
+    }
+
     if (n->reg != nullptr) delete n->reg;
     n->reg = reg;
     if (n->init != nullptr) delete n->init;
@@ -168,6 +207,9 @@ Layer *LEmbedding::clone(int c, int bs, vector<Layer *> p, int todev) {
     if (n->init != nullptr) delete n->init;
     n->init = this->init;
 
+    if (distributed_training)
+        n->enable_distributed();
+
     return n;
 }
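Note: LEmbedding now implements the same distributed-training hooks as LConv and LDense. A rough sketch of the life cycle these overrides support; the `emb` pointer and the `synced_E` tensor are hypothetical stand-ins for whatever the distributed runtime provides:

    // `emb` points to an LEmbedding of an already-built net (hypothetical)
    emb->enable_distributed();          // allocates acc_gE and registers it
    emb->reset_accumulated_gradients(); // acc_gE <- 0
    // ... training fills the accumulated gradients ...
    emb->apply_accumulated_gradients(); // E += acc_gE
    emb->update_weights({synced_E});    // replace E with a synced copy (exactly one tensor)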
diff --git a/src/layers/pool/layer_avgpool2D.cpp b/src/layers/pool/layer_avgpool2D.cpp
index b3304d46e..47bbffbeb 100644
--- a/src/layers/pool/layer_avgpool2D.cpp
+++ b/src/layers/pool/layer_avgpool2D.cpp
@@ -42,7 +42,9 @@ if(!D->I->isCPU()){
 
     // Check padding asymmetries
     if(D->pad[0] != D->pad[1] || D->pad[2] != D->pad[3]){
-        msg("Padding asymmetry detected. (top=" + to_string(D->pad[0]) + ", bottom=" + to_string(D->pad[1]) + ", left=" + to_string(D->pad[2]) + ", right=" + to_string(D->pad[3]) + ").\nLayer name: " + this->name, "LAveragePool::LAveragePool");
+        string err_msg = "In layer " + this->name + ": Padding asymmetry detected (top=" + to_string(D->pad[0]) + ", bottom=" + to_string(D->pad[1]) + ", left=" + to_string(D->pad[2]) + ", right=" + to_string(D->pad[3]) + "). "
+                         + "Asymmetric padding is not allowed in an AveragePool layer; we suggest using an explicit padding layer before this layer to fix the asymmetry.";
+        throw AsymmetricPaddingException(err_msg, D->pad);
     }
 }
diff --git a/src/layers/pool/layer_maxpool.cpp b/src/layers/pool/layer_maxpool.cpp
index a3df903f0..7db8c845e 100644
--- a/src/layers/pool/layer_maxpool.cpp
+++ b/src/layers/pool/layer_maxpool.cpp
@@ -45,7 +45,9 @@ LMaxPool::LMaxPool(Layer *parent, PoolDescriptor *D, const string& name, int dev
 
     // Check padding asymmetries
     if(D->pad[0] != D->pad[1] || D->pad[2] != D->pad[3]){
-        msg("Padding asymmetry detected. (top=" + to_string(D->pad[0]) + ", bottom=" + to_string(D->pad[1]) + ", left=" + to_string(D->pad[2]) + ", right=" + to_string(D->pad[3]) + ").\nLayer name: " + this->name, "LMaxPool::LMaxPool");
+        string err_msg = "In layer " + this->name + ": Padding asymmetry detected (top=" + to_string(D->pad[0]) + ", bottom=" + to_string(D->pad[1]) + ", left=" + to_string(D->pad[2]) + ", right=" + to_string(D->pad[3]) + "). "
+                         + "Asymmetric padding is not allowed in a MaxPool layer; we suggest using an explicit padding layer before this layer to fix the asymmetry.";
+        throw AsymmetricPaddingException(err_msg, D->pad);
    }
 }
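Note: both pool constructors now throw AsymmetricPaddingException instead of aborting through msg(), so a caller (for example the ONNX importer) can recover by padding explicitly and retrying with symmetric padding. A sketch of such a recovery, under two stated assumptions: that the exception exposes the pads it was constructed with via a getter (`get_asymmetric_pads()` is an assumed name) and that eddl's `Pad` layer accepts the pads in {top, right, bottom, left} order:

    try {
        l = MaxPool(l, {3, 3}, {2, 2}, "same");
    } catch (AsymmetricPaddingException &e) {
        vector<int> p = e.get_asymmetric_pads(); // assumed getter; thrown as {top, bottom, left, right}
        l = Pad(l, {p[0], p[3], p[1], p[2]});    // assumed Pad order: {top, right, bottom, left}
        l = MaxPool(l, {3, 3}, {2, 2}, "none");  // padding already applied explicitly
    }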
diff --git a/src/layers/recurrent/layer_gru.cpp b/src/layers/recurrent/layer_gru.cpp
index b03e93f58..fe9bd1af2 100644
--- a/src/layers/recurrent/layer_gru.cpp
+++ b/src/layers/recurrent/layer_gru.cpp
@@ -91,6 +91,12 @@ LGRU::LGRU(vector<Layer *> parent, int units, bool mask_zeros, bool bidirectiona
         parent[i]->addchild(this);
         addparent(parent[i]);
     }
+
+    distributed_training = false;
+    acc_gUz_h = acc_gWz_x = nullptr;
+    acc_gUr_h = acc_gWr_x = nullptr;
+    acc_gUn_h = acc_gWn_x = nullptr;
+    acc_g_bias_z_t = acc_g_bias_r_t = acc_g_bias_n_t = acc_g_bias_n_t_hidden = nullptr;
 }
 
 LGRU::~LGRU() {
@@ -400,6 +406,99 @@ void LGRU::backward() {
     }
 }
 
+void LGRU::update_weights(vector<Tensor *> weights) {
+    if (weights.size() == 10) {
+        Tensor::copy(weights[0], Wz_x);
+        Tensor::copy(weights[1], Wr_x);
+        Tensor::copy(weights[2], Wn_x);
+        Tensor::copy(weights[3], Uz_h);
+        Tensor::copy(weights[4], Ur_h);
+        Tensor::copy(weights[5], Un_h);
+        Tensor::copy(weights[6], bias_z_t);
+        Tensor::copy(weights[7], bias_r_t);
+        Tensor::copy(weights[8], bias_n_t);
+        Tensor::copy(weights[9], bias_n_t_hidden);
+    } else {
+        cerr << "[WARNING - LGRU::update_weights] "
+             << "Unexpected number of weights tensors received "
+             << "(weights.size()=" << weights.size() << ")" << endl;
+    }
+}
+
+void LGRU::accumulate_accumulated_gradients(vector<Tensor *> grads) {
+    if (grads.size() == 10) {
+        Wz_x->add_(grads[0]);
+        Wr_x->add_(grads[1]);
+        Wn_x->add_(grads[2]);
+        Uz_h->add_(grads[3]);
+        Ur_h->add_(grads[4]);
+        Un_h->add_(grads[5]);
+        bias_z_t->add_(grads[6]);
+        bias_r_t->add_(grads[7]);
+        bias_n_t->add_(grads[8]);
+        bias_n_t_hidden->add_(grads[9]);
+    } else {
+        cerr << "[WARNING - LGRU::accumulate_accumulated_gradients] "
+             << "Unexpected number of gradient tensors received "
+             << "(grads.size()=" << grads.size() << ")" << endl;
+    }
+}
+
+void LGRU::reset_accumulated_gradients() {
+    acc_gWz_x->fill_(0.0);
+    acc_gWr_x->fill_(0.0);
+    acc_gWn_x->fill_(0.0);
+    acc_gUz_h->fill_(0.0);
+    acc_gUr_h->fill_(0.0);
+    acc_gUn_h->fill_(0.0);
+    acc_g_bias_z_t->fill_(0.0);
+    acc_g_bias_r_t->fill_(0.0);
+    acc_g_bias_n_t->fill_(0.0);
+    acc_g_bias_n_t_hidden->fill_(0.0);
+}
+
+void LGRU::apply_accumulated_gradients() {
+    Wz_x->add_(acc_gWz_x);
+    Wr_x->add_(acc_gWr_x);
+    Wn_x->add_(acc_gWn_x);
+    Uz_h->add_(acc_gUz_h);
+    Ur_h->add_(acc_gUr_h);
+    Un_h->add_(acc_gUn_h);
+    bias_z_t->add_(acc_g_bias_z_t);
+    bias_r_t->add_(acc_g_bias_r_t);
+    bias_n_t->add_(acc_g_bias_n_t);
+    bias_n_t_hidden->add_(acc_g_bias_n_t_hidden);
+}
+
+void LGRU::enable_distributed() {
+    distributed_training = true;
+
+    // Initialize the accumulated gradients tensors
+    acc_gWz_x = new Tensor(vector<int>{input->shape[1], units}, output->device);
+    acc_gWr_x = new Tensor(vector<int>{input->shape[1], units}, output->device);
+    acc_gWn_x = new Tensor(vector<int>{input->shape[1], units}, output->device);
+    acc_gUz_h = new Tensor(vector<int>{units, units}, output->device);
+    acc_gUr_h = new Tensor(vector<int>{units, units}, output->device);
+    acc_gUn_h = new Tensor(vector<int>{units, units}, output->device);
+    acc_g_bias_z_t = new Tensor(vector<int>{units}, output->device);
+    acc_g_bias_r_t = new Tensor(vector<int>{units}, output->device);
+    acc_g_bias_n_t = new Tensor(vector<int>{units}, output->device);
+    acc_g_bias_n_t_hidden = new Tensor(vector<int>{units}, output->device);
+
+    // Set accumulated gradients to zero
+    reset_accumulated_gradients();
+
+    acc_gradients.push_back(acc_gWz_x);
+    acc_gradients.push_back(acc_gWr_x);
+    acc_gradients.push_back(acc_gWn_x);
+    acc_gradients.push_back(acc_gUz_h);
+    acc_gradients.push_back(acc_gUr_h);
+    acc_gradients.push_back(acc_gUn_h);
+    acc_gradients.push_back(acc_g_bias_z_t);
+    acc_gradients.push_back(acc_g_bias_r_t);
+    acc_gradients.push_back(acc_g_bias_n_t);
+    acc_gradients.push_back(acc_g_bias_n_t_hidden);
+}
 
 Layer *LGRU::share(int c, int bs, vector<Layer *> p) {
     LGRU *n = new LGRU(p, units, mask_zeros, bidirectional, "share_"+to_string(c)+this->name, this->dev, this->mem_level);
 
@@ -461,6 +560,32 @@ Layer *LGRU::share(int c, int bs, vector<Layer *> p) {
         n->gradients.push_back(n->gUn_h);
     }
 
+    if (distributed_training) {
+        n->acc_gradients.clear();
+
+        n->acc_gWz_x = acc_gWz_x;
+        n->acc_gWr_x = acc_gWr_x;
+        n->acc_gWn_x = acc_gWn_x;
+        n->acc_gUz_h = acc_gUz_h;
+        n->acc_gUr_h = acc_gUr_h;
+        n->acc_gUn_h = acc_gUn_h;
+        n->acc_g_bias_z_t = acc_g_bias_z_t;
+        n->acc_g_bias_r_t = acc_g_bias_r_t;
+        n->acc_g_bias_n_t = acc_g_bias_n_t;
+        n->acc_g_bias_n_t_hidden = acc_g_bias_n_t_hidden;
+
+        n->acc_gradients.push_back(acc_gWz_x);
+        n->acc_gradients.push_back(acc_gWr_x);
+        n->acc_gradients.push_back(acc_gWn_x);
+        n->acc_gradients.push_back(acc_gUz_h);
+        n->acc_gradients.push_back(acc_gUr_h);
+        n->acc_gradients.push_back(acc_gUn_h);
+        n->acc_gradients.push_back(acc_g_bias_z_t);
+        n->acc_gradients.push_back(acc_g_bias_r_t);
+        n->acc_gradients.push_back(acc_g_bias_n_t);
+        n->acc_gradients.push_back(acc_g_bias_n_t_hidden);
+    }
+
     n->do_deletes = false;
 
     if (n->reg != nullptr) delete n->reg;
     n->reg = reg;
@@ -474,6 +599,9 @@ Layer *LGRU::clone(int c, int bs, vector<Layer *> p, int todev) {
     LGRU *n = new LGRU(p, units, mask_zeros, bidirectional, name, todev, this->mem_level);
     n->orig = this;
 
+    if (distributed_training)
+        n->enable_distributed();
+
     // TODO: Implement
 
     return n;
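Note: LGRU::update_weights and LGRU::accumulate_accumulated_gradients expect exactly ten tensors, in the order used above: {Wz_x, Wr_x, Wn_x, Uz_h, Ur_h, Un_h, bias_z_t, bias_r_t, bias_n_t, bias_n_t_hidden}. A sketch of the intended call site in the ONNX import path, assuming get_gru_tensors (declared in gru_onnx.h) returns the tensors in that same order and that `node`, `map_init_values`, `map_init_dims` and `gru` are the importer's locals:

    vector<Tensor *> w = get_gru_tensors(node, map_init_values, map_init_dims);
    if (w.size() == 10)
        gru->update_weights(w); // Tensor::copy() of each entry onto the layer's params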
diff --git a/src/layers/recurrent/layer_lstm.cpp b/src/layers/recurrent/layer_lstm.cpp
index 921bd486c..ef09612cd 100644
--- a/src/layers/recurrent/layer_lstm.cpp
+++ b/src/layers/recurrent/layer_lstm.cpp
@@ -103,8 +103,12 @@ LLSTM::LLSTM(vector<Layer *> parent, int units, bool mask_zeros, bool bidirectio
         addparent(parent[i]);
     }
 
-
-
+    distributed_training = false;
+    acc_gWih = acc_gWix = nullptr;
+    acc_gWfh = acc_gWfx = nullptr;
+    acc_gWoh = acc_gWox = nullptr;
+    acc_gWch = acc_gWcx = nullptr;
+    acc_ginbias = acc_gfnbias = acc_gonbias = acc_gcnbias = nullptr;
 }
 
 LLSTM::~LLSTM(){
@@ -437,6 +441,103 @@ void LLSTM::backward() {
 
 }
 
+void LLSTM::update_weights(vector<Tensor *> weights) {
+    if (weights.size() == 12) {
+        Tensor::copy(weights[0], Wix);
+        Tensor::copy(weights[1], Wfx);
+        Tensor::copy(weights[2], Wox);
+        Tensor::copy(weights[3], Wcx);
+        Tensor::copy(weights[4], Wih);
+        Tensor::copy(weights[5], Wfh);
+        Tensor::copy(weights[6], Woh);
+        Tensor::copy(weights[7], Wch);
+        Tensor::copy(weights[8], inbias);
+        Tensor::copy(weights[9], fnbias);
+        Tensor::copy(weights[10], onbias);
+        Tensor::copy(weights[11], cnbias);
+    } else {
+        cerr << "[WARNING - LLSTM::update_weights] "
+             << "Unexpected number of weights tensors received "
+             << "(weights.size()=" << weights.size() << ")" << endl;
+    }
+}
+
+void LLSTM::accumulate_accumulated_gradients(vector<Tensor *> grads) {
+    if (grads.size() == 12) {
+        Wix->add_(grads[0]);
+        Wfx->add_(grads[1]);
+        Wox->add_(grads[2]);
+        Wcx->add_(grads[3]);
+        Wih->add_(grads[4]);
+        Wfh->add_(grads[5]);
+        Woh->add_(grads[6]);
+        Wch->add_(grads[7]);
+        inbias->add_(grads[8]);
+        fnbias->add_(grads[9]);
+        onbias->add_(grads[10]);
+        cnbias->add_(grads[11]);
+    } else {
+        cerr << "[WARNING - LLSTM::accumulate_accumulated_gradients] "
+             << "Unexpected number of gradient tensors received "
+             << "(grads.size()=" << grads.size() << ")" << endl;
+    }
+}
+
+void LLSTM::reset_accumulated_gradients() {
+    acc_gWih->fill_(0.0); acc_gWix->fill_(0.0);
+    acc_gWfh->fill_(0.0); acc_gWfx->fill_(0.0);
+    acc_gWoh->fill_(0.0); acc_gWox->fill_(0.0);
+    acc_gWch->fill_(0.0); acc_gWcx->fill_(0.0);
+    acc_ginbias->fill_(0.0);
+    acc_gfnbias->fill_(0.0);
+    acc_gonbias->fill_(0.0);
+    acc_gcnbias->fill_(0.0);
+}
+
+void LLSTM::apply_accumulated_gradients() {
+    Wih->add_(acc_gWih); Wix->add_(acc_gWix);
+    Wfh->add_(acc_gWfh); Wfx->add_(acc_gWfx);
+    Woh->add_(acc_gWoh); Wox->add_(acc_gWox);
+    Wch->add_(acc_gWch); Wcx->add_(acc_gWcx);
+    inbias->add_(acc_ginbias);
+    fnbias->add_(acc_gfnbias);
+    onbias->add_(acc_gonbias);
+    cnbias->add_(acc_gcnbias);
+}
+
+void LLSTM::enable_distributed() {
+    distributed_training = true;
+
+    // Initialize the accumulated gradients tensors
+    acc_gWix = new Tensor(vector<int>{input->shape[1], units}, output->device);
+    acc_gWfx = new Tensor(vector<int>{input->shape[1], units}, output->device);
+    acc_gWox = new Tensor(vector<int>{input->shape[1], units}, output->device);
+    acc_gWcx = new Tensor(vector<int>{input->shape[1], units}, output->device);
+    acc_gWih = new Tensor(vector<int>{units, units}, output->device);
+    acc_gWfh = new Tensor(vector<int>{units, units}, output->device);
+    acc_gWoh = new Tensor(vector<int>{units, units}, output->device);
+    acc_gWch = new Tensor(vector<int>{units, units}, output->device);
+    acc_ginbias = new Tensor(vector<int>{units}, output->device);
+    acc_gfnbias = new Tensor(vector<int>{units}, output->device);
+    acc_gonbias = new Tensor(vector<int>{units}, output->device);
+    acc_gcnbias = new Tensor(vector<int>{units}, output->device);
+
+    // Set accumulated gradients to zero
+    reset_accumulated_gradients();
+
+    acc_gradients.push_back(acc_gWix);
+    acc_gradients.push_back(acc_gWfx);
+    acc_gradients.push_back(acc_gWox);
+    acc_gradients.push_back(acc_gWcx);
+    acc_gradients.push_back(acc_gWih);
+    acc_gradients.push_back(acc_gWfh);
+    acc_gradients.push_back(acc_gWoh);
+    acc_gradients.push_back(acc_gWch);
+    acc_gradients.push_back(acc_ginbias);
+    acc_gradients.push_back(acc_gfnbias);
+    acc_gradients.push_back(acc_gonbias);
+    acc_gradients.push_back(acc_gcnbias);
+}
 
 Layer *LLSTM::share(int c, int bs, vector<Layer *> p) {
     LLSTM *n = new LLSTM(p, units, mask_zeros, bidirectional, "share_"+to_string(c)+this->name, this->dev, this->mem_level);
 
@@ -508,6 +609,35 @@ Layer *LLSTM::share(int c, int bs, vector<Layer *> p) {
         n->gradients.push_back(gWch);
     }
 
+    if ( distributed_training ) {
+        n->acc_gradients.clear();
+
+        n->acc_gWix = acc_gWix;
+        n->acc_gWfx = acc_gWfx;
+        n->acc_gWox = acc_gWox;
+        n->acc_gWcx = acc_gWcx;
+        n->acc_gWih = acc_gWih;
+        n->acc_gWfh = acc_gWfh;
+        n->acc_gWoh = acc_gWoh;
+        n->acc_gWch = acc_gWch;
+        n->acc_ginbias = acc_ginbias;
+        n->acc_gfnbias = acc_gfnbias;
+        n->acc_gonbias = acc_gonbias;
+        n->acc_gcnbias = acc_gcnbias;
+
+        n->acc_gradients.push_back(acc_gWix);
+        n->acc_gradients.push_back(acc_gWfx);
+        n->acc_gradients.push_back(acc_gWox);
+        n->acc_gradients.push_back(acc_gWcx);
+        n->acc_gradients.push_back(acc_gWih);
+        n->acc_gradients.push_back(acc_gWfh);
+        n->acc_gradients.push_back(acc_gWoh);
+        n->acc_gradients.push_back(acc_gWch);
+        n->acc_gradients.push_back(acc_ginbias);
+        n->acc_gradients.push_back(acc_gfnbias);
+        n->acc_gradients.push_back(acc_gonbias);
+        n->acc_gradients.push_back(acc_gcnbias);
+    }
 
     if (n->reg != nullptr) delete n->reg;
     n->reg = reg;
@@ -521,6 +651,9 @@ Layer *LLSTM::clone(int c, int bs, vector<Layer *> p, int todev) {
     LLSTM *n = new LLSTM(p, units, mask_zeros, bidirectional, name, todev, this->mem_level);
     n->orig = this;
 
+    if (distributed_training)
+        n->enable_distributed();
+
     // TODO: Implement
 
     return n;
diff --git a/src/layers/recurrent/layer_rnn.cpp b/src/layers/recurrent/layer_rnn.cpp
index 5a7f5bc75..ca73e86ad 100644
--- a/src/layers/recurrent/layer_rnn.cpp
+++ b/src/layers/recurrent/layer_rnn.cpp
@@ -50,7 +50,6 @@ LRNN::LRNN(vector<Layer *> parent, int units, string activation, bool use_bias,
     gWy = new Tensor(vector<int>{units, units}, dev);
     gradients.push_back(gWy);
 
-
     if (use_bias) {
         bias = new Tensor(vector<int>{units}, dev);
         params.push_back(bias);
@@ -58,12 +57,15 @@ LRNN::LRNN(vector<Layer *> parent, int units, string activation, bool use_bias,
         gradients.push_back(gbias);
     }
 
-
     for (int i = 0; i < parent.size(); ++i) {
         parent[i]->addchild(this);
         addparent(parent[i]);
     }
 
+    distributed_training = false;
+    acc_gWx = nullptr;
+    acc_gWy = nullptr;
+    acc_gbias = nullptr;
 }
 
 LRNN::~LRNN(){
@@ -135,6 +137,63 @@ void LRNN::backward() {
 
 }
 
+void LRNN::update_weights(vector<Tensor *> weights) {
+    if (weights.size() == 3) {
+        Tensor::copy(weights[0], Wx);
+        Tensor::copy(weights[1], Wy);
+        Tensor::copy(weights[2], bias);
+    } else if (weights.size() == 2) {
+        Tensor::copy(weights[0], Wx);
+        Tensor::copy(weights[1], Wy);
+    } else {
+        cerr << "[WARNING - LRNN::update_weights] "
+             << "Unexpected number of weights tensors received "
+             << "(weights.size()=" << weights.size() << ")" << endl;
+    }
+}
+
+void LRNN::accumulate_accumulated_gradients(vector<Tensor *> grads) {
+    if (grads.size() == 3) {
+        Wx->add_(grads[0]);
+        Wy->add_(grads[1]);
+        bias->add_(grads[2]);
+    } else if (grads.size() == 2) {
+        Wx->add_(grads[0]);
+        Wy->add_(grads[1]);
+    } else {
+        cerr << "[WARNING - LRNN::accumulate_accumulated_gradients] "
+             << "Unexpected number of gradient tensors received "
+             << "(grads.size()=" << grads.size() << ")" << endl;
+    }
+}
+
+void LRNN::reset_accumulated_gradients() {
+    acc_gWx->fill_(0.0);
+    acc_gWy->fill_(0.0);
+    if (use_bias) acc_gbias->fill_(0.0);
+}
+
+void LRNN::apply_accumulated_gradients() {
+    Wx->add_(acc_gWx);
+    Wy->add_(acc_gWy);
+    if (use_bias) bias->add_(acc_gbias);
+}
+
+void LRNN::enable_distributed() {
+    distributed_training = true;
+
+    // Initialize the accumulated gradients tensors
+    acc_gWx = new Tensor(vector<int>{input->shape[1], units}, output->device);
+    acc_gWy = new Tensor(vector<int>{units, units}, output->device);
+    if (use_bias) acc_gbias = new Tensor(vector<int>{units}, output->device);
+
+    // Set accumulated gradients to zero
+    reset_accumulated_gradients();
+
+    acc_gradients.push_back(acc_gWx);
+    acc_gradients.push_back(acc_gWy);
+    if (use_bias) acc_gradients.push_back(acc_gbias);
+}
 
 Layer *LRNN::share(int c, int bs, vector<Layer *> p) {
     LRNN *n = new LRNN(p, units, activation, use_bias, bidirectional, "share_"+to_string(c)+this->name, this->dev, this->mem_level);
 
@@ -166,6 +225,18 @@ Layer *LRNN::share(int c, int bs, vector<Layer *> p) {
     n->gradients.push_back(n->gWy);
     if (use_bias) n->gradients.push_back(n->gbias);
 
+    if (distributed_training) {
+        n->acc_gradients.clear();
+
+        n->acc_gWx = acc_gWx;
+        n->acc_gWy = acc_gWy;
+        if (use_bias) n->acc_gbias = acc_gbias;
+
+        n->acc_gradients.push_back(acc_gWx);
+        n->acc_gradients.push_back(acc_gWy);
+        if (use_bias) n->acc_gradients.push_back(acc_gbias);
+    }
+
     if (n->reg != nullptr) delete n->reg;
     n->reg = reg;
     if (n->init != nullptr) delete n->init;
@@ -178,6 +249,9 @@ Layer *LRNN::clone(int c, int bs, vector<Layer *> p, int todev) {
     LRNN *n = new LRNN(p, units, activation, use_bias, bidirectional, "clone_" + name, todev, this->mem_level);
     n->orig = this;
 
+    if (distributed_training)
+        n->enable_distributed();
+
     // TODO: Implement
 
     return n;
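Note: with update_weights and accumulate_accumulated_gradients now virtual on Layer, the per-type helpers (update_conv_weights, update_dense_weights, apply_grads_to_conv, apply_grads_to_dense, and the dispatching update_layer_weights/apply_grads_to_layer) become unnecessary, and the import helpers below reduce to plain virtual dispatch. A condensed sketch of the resulting pattern in set_weights_from_model_proto, using its locals `net` and `tensors`:

    for (Layer *l : net->layers) {
        vector<Tensor *> new_weights = tensors[l->name];
        if (new_weights.empty())
            continue;                   // warned about and skipped, as below
        l->update_weights(new_weights); // dispatches to LDense, LConv, LLSTM, LGRU, ...
    }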
gradient tensors recieved " + << "(grads.size()=" << grads.size() << ")" << endl; + } +} + +void LRNN::reset_accumulated_gradients() { + acc_gWx->fill_(0.0); + acc_gWy->fill_(0.0); + if (use_bias) acc_gbias->fill_(0.0); +} + +void LRNN::apply_accumulated_gradients() { + Wx->add_(acc_gWx); + Wy->add_(acc_gWy); + if (use_bias) bias->add_(acc_gbias); +} + +void LRNN::enable_distributed() { + distributed_training = true; + + // Initialize the accumlated gradients tensors + acc_gWx = new Tensor(vector{input->shape[1], units}, output->device); + acc_gWy = new Tensor(vector{units, units}, output->device); + if (use_bias) acc_gbias = new Tensor(vector{units}, output->device); + + // Set accumlated gradients to zero + reset_accumulated_gradients(); + + acc_gradients.push_back(acc_gWx); + acc_gradients.push_back(acc_gWy); + if (use_bias) acc_gradients.push_back(acc_gbias); +} Layer *LRNN::share(int c, int bs, vector p) { LRNN *n = new LRNN(p, units, activation, use_bias, bidirectional, "share_"+to_string(c)+this->name, this->dev, this->mem_level); @@ -166,6 +225,18 @@ Layer *LRNN::share(int c, int bs, vector p) { n->gradients.push_back(n->gWy); if (use_bias) n->gradients.push_back(n->gbias); + if (distributed_training) { + n->acc_gradients.clear(); + + n->acc_gWx = acc_gWx; + n->acc_gWy = acc_gWy; + if (use_bias) n->acc_gbias = acc_gbias; + + n->acc_gradients.push_back(acc_gWx); + n->acc_gradients.push_back(acc_gWy); + if (use_bias) n->acc_gradients.push_back(acc_gbias); + } + if (n->reg != nullptr) delete n->reg; n->reg = reg; if (n->init != nullptr) delete n->init; @@ -178,6 +249,9 @@ Layer *LRNN::clone(int c, int bs, vector p, int todev) { LRNN *n = new LRNN(p, units, activation, use_bias, bidirectional, "clone_" + name, todev, this->mem_level); n->orig = this; + if (distributed_training) + n->enable_distributed(); + // TODO: Implement return n; diff --git a/src/serialization/onnx/net/import_helpers.cpp b/src/serialization/onnx/net/import_helpers.cpp index a95b90aff..552a1c810 100644 --- a/src/serialization/onnx/net/import_helpers.cpp +++ b/src/serialization/onnx/net/import_helpers.cpp @@ -614,10 +614,16 @@ void set_weights_from_model_proto(Net *net, onnx::ModelProto model_proto) continue; // Get the layer weights - vector layer_tensors = tensors[l->name]; + vector new_weights = tensors[l->name]; + if (new_weights.size() == 0) + { + cerr << "[ONNX::WARNING] Trying to update the weights of the layer \"" + << l->name << "\" with an empty list of tensors." << endl; + continue; + } // Apply the new weights - update_layer_weights(l, layer_tensors); + l->update_weights(new_weights); } // Copy the new weights to devices @@ -644,10 +650,16 @@ void apply_grads_from_model_proto(Net *net, onnx::ModelProto model_proto) continue; // Get the layer gradients - vector layer_tensors = tensors[l->name]; + vector acc_grads = tensors[l->name]; + if (acc_grads.size() == 0) + { + cerr << "[ONNX::WARNING] Trying to apply gradients to the layer \"" + << l->name << "\" with an empty list of tensors." 
<< endl; + continue; + } // Apply the gradients - apply_grads_to_layer(l, layer_tensors); + l->accumulate_accumulated_gradients(acc_grads); } // Erase the map we used to free the memory diff --git a/src/serialization/onnx/net/layers/conv/convT_onnx.cpp b/src/serialization/onnx/net/layers/conv/convT_onnx.cpp index d2286f6c2..c7199bd34 100644 --- a/src/serialization/onnx/net/layers/conv/convT_onnx.cpp +++ b/src/serialization/onnx/net/layers/conv/convT_onnx.cpp @@ -279,4 +279,39 @@ void build_convT_node(LConvT2D *layer, onnx::GraphProto *graph, bool gradients) } } +/* + * DISTRIBUTED TRAINING + */ + +vector get_convT_tensors(onnx::NodeProto &node, + map> &map_init_values, + map> &map_init_dims) +{ + vector conv_tensors; + + string weights_name = node.input(1); // Get weights and dims + vector *weights = &(map_init_values[weights_name]); + vector dims = map_init_dims[weights_name]; + + if (dims.size() == 3) + msg("Error: ConvT1D is not supported.", "[ONNX::get_convT_tensors]"); + + Tensor * temp = new Tensor(dims, nullptr, DEV_CPU); + COPY_FROM_VECTOR_PTR_TO_TENSOR(weights, temp); + conv_tensors.push_back(temp); + + if (node.input_size() > 2) + { // This means we also have a bias + string bias_name = node.input(2); + vector *bias = &(map_init_values[bias_name]); + vector bias_shape; + bias_shape.push_back(bias->size()); + temp = new Tensor(bias_shape, nullptr, DEV_CPU); + COPY_FROM_VECTOR_PTR_TO_TENSOR(bias, temp); + conv_tensors.push_back(temp); + } + + return conv_tensors; +} + #endif // defined(cPROTO) diff --git a/src/serialization/onnx/net/layers/conv/conv_onnx.cpp b/src/serialization/onnx/net/layers/conv/conv_onnx.cpp index 1dc7e2a38..78558a9c1 100644 --- a/src/serialization/onnx/net/layers/conv/conv_onnx.cpp +++ b/src/serialization/onnx/net/layers/conv/conv_onnx.cpp @@ -325,16 +325,6 @@ void build_conv_node(LConv *layer, onnx::GraphProto *graph, bool gradients) * DISTRIBUTED TRAINING */ -void update_conv_weights(LConv *layer, vector weights) -{ - layer->update_weights(weights); -} - -void apply_grads_to_conv(LConv *layer, vector grads) -{ - layer->accumulate_accumulated_gradients(grads); -} - vector get_conv_tensors(onnx::NodeProto &node, map> &map_init_values, map> &map_init_dims) @@ -345,6 +335,11 @@ vector get_conv_tensors(onnx::NodeProto &node, vector *weights = &(map_init_values[weights_name]); vector dims = map_init_dims[weights_name]; + // Our Conv1D layers are computed using the backend of the Conv2D, so we + // need to add one extra dimension to have the shape of the kernels of a Conv2D + if (dims.size() == 3) + dims.push_back(1); + Tensor * temp = new Tensor(dims, nullptr, DEV_CPU); COPY_FROM_VECTOR_PTR_TO_TENSOR(weights, temp); conv_tensors.push_back(temp); diff --git a/src/serialization/onnx/net/layers/core/bypass_onnx.cpp b/src/serialization/onnx/net/layers/core/bypass_onnx.cpp new file mode 100644 index 000000000..8d95bb48b --- /dev/null +++ b/src/serialization/onnx/net/layers/core/bypass_onnx.cpp @@ -0,0 +1,35 @@ +#if defined(cPROTO) +#include "eddl/serialization/onnx/layers/core/bypass_onnx.h" + +// ONNX import +Layer* build_lrn_layer(onnx::NodeProto *node, + map &output_node_map, + LOG_LEVEL log_level, + int dev, + int mem) +{ + string name = node->name(); + log_string("Going to use a Bypass layer to skip the LRN node \"" + name + "\"", log_level, LOG_LEVEL::WARN); + string parent_name = node->input(0); // Get parent + Layer *parent = output_node_map[parent_name]; + + return new LBypass(parent, name, name, dev, mem); +} + +// ONNX export +void 
build_identity_node(LBypass *layer, onnx::GraphProto *graph) +{ + // Add an empty node to the graph + onnx::NodeProto *node = graph->add_node(); + node->set_op_type("Identity"); + node->set_name(layer->name); + // Set the inputs of the node from the parents of the layer + for (Layer *parentl : layer->parent) + { + node->add_input(parentl->name); + } + // Set the name of the output of the node to link with other nodes + node->add_output(layer->name); +} + +#endif // defined(cPROTO) diff --git a/src/serialization/onnx/net/layers/core/dense_onnx.cpp b/src/serialization/onnx/net/layers/core/dense_onnx.cpp index 4c6e73b5f..b43c8b4ff 100644 --- a/src/serialization/onnx/net/layers/core/dense_onnx.cpp +++ b/src/serialization/onnx/net/layers/core/dense_onnx.cpp @@ -191,7 +191,7 @@ void build_dense_with_matmul_node(LDense *layer, onnx::GraphProto *graph, bool g // Add an empty node to the graph onnx::NodeProto *node = graph->add_node(); node->set_op_type("MatMul"); - node->set_name(layer->name + "_MatMul"); + node->set_name(layer->name); // Set the inputs of the node from the parents of the layer for (Layer *parentl : layer->parent) { @@ -211,8 +211,13 @@ void build_dense_with_matmul_node(LDense *layer, onnx::GraphProto *graph, bool g onnx::TensorProto *weight = graph->add_initializer(); weight->set_name(layer->name + "_W"); weight->set_data_type(onnx::TensorProto::FLOAT); - weight->mutable_dims()->Add(layer->W->shape.begin(), layer->W->shape.end()); // Set the shape of the weights - weight->mutable_float_data()->Add(layer->W->ptr, layer->W->ptr + layer->W->size); // Set the weights values + if (!gradients) { + weight->mutable_dims()->Add(layer->W->shape.begin(), layer->W->shape.end()); // Set the shape of the weights + weight->mutable_float_data()->Add(layer->W->ptr, layer->W->ptr + layer->W->size); // Set the weights values + } else { + weight->mutable_dims()->Add(layer->acc_gW->shape.begin(), layer->acc_gW->shape.end()); // Set the shape of the weights + weight->mutable_float_data()->Add(layer->acc_gW->ptr, layer->acc_gW->ptr + layer->acc_gW->size); // Set the weights values + } // Create the Add node in case of using bias in the Dense layer if (layer->use_bias) @@ -220,7 +225,7 @@ void build_dense_with_matmul_node(LDense *layer, onnx::GraphProto *graph, bool g // Add an empty node to the graph onnx::NodeProto *node_bias = graph->add_node(); node_bias->set_op_type("Add"); - node_bias->set_name(layer->name + "_Add"); + node_bias->set_name(layer->name); // Take the input from the previous MatMul node_bias->add_input(layer->name + "_MatMul"); // Set the input param name of the Bias matrix @@ -231,8 +236,13 @@ void build_dense_with_matmul_node(LDense *layer, onnx::GraphProto *graph, bool g onnx::TensorProto *bias = graph->add_initializer(); bias->set_name(layer->name + "_b"); bias->set_data_type(onnx::TensorProto::FLOAT); - bias->mutable_dims()->Add(layer->bias->shape.begin(), layer->bias->shape.end()); // Set the bias shape - bias->mutable_float_data()->Add(layer->bias->ptr, layer->bias->ptr + layer->bias->size); // Set the bias values + if (!gradients) { + bias->mutable_dims()->Add(layer->bias->shape.begin(), layer->bias->shape.end()); // Set the bias shape + bias->mutable_float_data()->Add(layer->bias->ptr, layer->bias->ptr + layer->bias->size); // Set the bias values + } else { + bias->mutable_dims()->Add(layer->acc_gbias->shape.begin(), layer->acc_gbias->shape.end()); // Set the bias shape + bias->mutable_float_data()->Add(layer->acc_gbias->ptr, layer->acc_gbias->ptr + layer->acc_gbias->size); 
// Set the bias values + } } } @@ -240,16 +250,6 @@ void build_dense_with_matmul_node(LDense *layer, onnx::GraphProto *graph, bool g * DISTRIBUTED TRAINING */ -void update_dense_weights(LDense *layer, vector weights) -{ - layer->update_weights(weights); -} - -void apply_grads_to_dense(LDense *layer, vector grads) -{ - layer->accumulate_accumulated_gradients(grads); -} - vector get_dense_tensors(onnx::NodeProto &node, map> &map_init_values, map> &map_init_dims) diff --git a/src/serialization/onnx/net/layers/core/embedding_onnx.cpp b/src/serialization/onnx/net/layers/core/embedding_onnx.cpp index e78e64fd0..3a103483d 100644 --- a/src/serialization/onnx/net/layers/core/embedding_onnx.cpp +++ b/src/serialization/onnx/net/layers/core/embedding_onnx.cpp @@ -3,7 +3,7 @@ #include "eddl/serialization/onnx/layers/core/squeeze_onnx.h" #include "eddl/serialization/onnx/layers/onnx_nodes/onnx_node_conversion.h" -void build_embedding_node(LEmbedding *layer, onnx::GraphProto *graph) +void build_embedding_node(LEmbedding *layer, onnx::GraphProto *graph, bool gradients) { /* * To create the embedding operation in ONNX we have to use the following steps: @@ -48,7 +48,27 @@ void build_embedding_node(LEmbedding *layer, onnx::GraphProto *graph) cast_node_output, // node input from cast node layer->name, // node output name layer, - graph); + graph, + gradients); +} + +/* + * DISTRIBUTED TRAINING + */ + +vector get_embedding_tensors(onnx::NodeProto &node, + map> &map_init_values, + map> &map_init_dims) +{ + // Get weights and dims + string weights_name = node.input(0); + vector *weights = &(map_init_values[weights_name]); + vector dims = map_init_dims[weights_name]; + + Tensor *weights_tensor = new Tensor(dims, nullptr, DEV_CPU); + COPY_FROM_VECTOR_PTR_TO_TENSOR(weights, weights_tensor); + + return {weights_tensor}; } #endif // defined(cPROTO) diff --git a/src/serialization/onnx/net/layers/layers_onnx.cpp b/src/serialization/onnx/net/layers/layers_onnx.cpp index 2949cd15b..196b3773d 100644 --- a/src/serialization/onnx/net/layers/layers_onnx.cpp +++ b/src/serialization/onnx/net/layers/layers_onnx.cpp @@ -12,6 +12,7 @@ #include "eddl/serialization/onnx/layers/core/split_onnx.h" #include "eddl/serialization/onnx/layers/core/resize_onnx.h" #include "eddl/serialization/onnx/layers/core/repeat_onnx.h" +#include "eddl/serialization/onnx/layers/core/bypass_onnx.h" #include "eddl/serialization/onnx/layers/conv/conv_onnx.h" #include "eddl/serialization/onnx/layers/conv/conv1D_onnx.h" #include "eddl/serialization/onnx/layers/conv/conv3D_onnx.h" @@ -125,6 +126,7 @@ map create_enum_map() map_layers["Expand"] = ONNX_LAYERS::EXPAND; map_layers["Constant"] = ONNX_LAYERS::CONSTANT; map_layers["Tile"] = ONNX_LAYERS::REPEAT; + map_layers["LRN"] = ONNX_LAYERS::LRN; return map_layers; } @@ -177,13 +179,13 @@ Layer* build_layer_from_node(onnx::NodeProto *node, new_layer = build_dropout_layer(node, output_node_map, dev, mem); break; case ONNX_LAYERS::MAXPOOL: - new_layer = build_maxpool_layer(node, output_node_map, dev, mem); + new_layer = build_maxpool_layer(node, output_node_map, log_level, dev, mem); break; case ONNX_LAYERS::GLOBMAXPOOL: new_layer = build_globalmaxpool_layer(node, output_node_map, dev, mem); break; case ONNX_LAYERS::AVGPOOL: - new_layer = build_averagepool_layer(node, output_node_map, dev, mem); + new_layer = build_averagepool_layer(node, output_node_map, log_level, dev, mem); break; case ONNX_LAYERS::GLOBAVGPOOL: new_layer = build_globalaveragegpool_layer(node, output_node_map, dev, mem); @@ -335,6 +337,9 @@ 
diff --git a/src/serialization/onnx/net/layers/layers_onnx.cpp b/src/serialization/onnx/net/layers/layers_onnx.cpp index 2949cd15b..196b3773d 100644 --- a/src/serialization/onnx/net/layers/layers_onnx.cpp +++ b/src/serialization/onnx/net/layers/layers_onnx.cpp @@ -12,6 +12,7 @@ #include "eddl/serialization/onnx/layers/core/split_onnx.h" #include "eddl/serialization/onnx/layers/core/resize_onnx.h" #include "eddl/serialization/onnx/layers/core/repeat_onnx.h" +#include "eddl/serialization/onnx/layers/core/bypass_onnx.h" #include "eddl/serialization/onnx/layers/conv/conv_onnx.h" #include "eddl/serialization/onnx/layers/conv/conv1D_onnx.h" #include "eddl/serialization/onnx/layers/conv/conv3D_onnx.h" @@ -125,6 +126,7 @@ map<string, ONNX_LAYERS> create_enum_map() map_layers["Expand"] = ONNX_LAYERS::EXPAND; map_layers["Constant"] = ONNX_LAYERS::CONSTANT; map_layers["Tile"] = ONNX_LAYERS::REPEAT; + map_layers["LRN"] = ONNX_LAYERS::LRN; return map_layers; } @@ -177,13 +179,13 @@ Layer* build_layer_from_node(onnx::NodeProto *node, new_layer = build_dropout_layer(node, output_node_map, dev, mem); break; case ONNX_LAYERS::MAXPOOL: - new_layer = build_maxpool_layer(node, output_node_map, dev, mem); + new_layer = build_maxpool_layer(node, output_node_map, log_level, dev, mem); break; case ONNX_LAYERS::GLOBMAXPOOL: new_layer = build_globalmaxpool_layer(node, output_node_map, dev, mem); break; case ONNX_LAYERS::AVGPOOL: - new_layer = build_averagepool_layer(node, output_node_map, dev, mem); + new_layer = build_averagepool_layer(node, output_node_map, log_level, dev, mem); break; case ONNX_LAYERS::GLOBAVGPOOL: new_layer = build_globalaveragegpool_layer(node, output_node_map, dev, mem); break; @@ -335,6 +337,9 @@ Layer* build_layer_from_node(onnx::NodeProto *node, case ONNX_LAYERS::REPEAT: new_layer = build_repeat_layer(node, constant_node_map, map_init_values, output_node_map, log_level, dev, mem); break; + case ONNX_LAYERS::LRN: + new_layer = build_lrn_layer(node, output_node_map, log_level, dev, mem); + break; default: { std::cerr << "==================================================================" << std::endl; std::cerr << "[ONNX IMPORTING ERROR]: " << "The onnx node '" << layer_type_name << "' is not supported yet" << std::endl; @@ -471,15 +476,15 @@ void build_node_from_layer(Layer *layer, onnx::GraphProto *graph, bool gradients else if (LDropout *l = dynamic_cast<LDropout *>(layer)) build_dropout_node(l, graph); else if (LLSTM *l = dynamic_cast<LLSTM *>(layer)) - build_lstm_node(l, graph); + build_lstm_node(l, graph, gradients); else if (LGRU *l = dynamic_cast<LGRU *>(layer)) - build_gru_node(l, graph); + build_gru_node(l, graph, gradients); else if (LRNN *l = dynamic_cast<LRNN *>(layer)) - build_rnn_node(l, graph); + build_rnn_node(l, graph, gradients); else if (LCopyStates *l = dynamic_cast<LCopyStates *>(layer)) handle_copy_states(l, graph); else if (LEmbedding *l = dynamic_cast<LEmbedding *>(layer)) - build_embedding_node(l, graph); + build_embedding_node(l, graph, gradients); else if (LResize *l = dynamic_cast<LResize *>(layer)) build_resize_node(l, graph); else if (LScale *l = dynamic_cast<LScale *>(layer)) @@ -494,6 +499,8 @@ void build_node_from_layer(Layer *layer, onnx::GraphProto *graph, bool gradients build_constant_node(l, graph); else if (LRepeat *l = dynamic_cast<LRepeat *>(layer)) build_tile_node(l, graph); + else if (LBypass *l = dynamic_cast<LBypass *>(layer)) + build_identity_node(l, graph); else { cerr << "[ONNX EXPORTING ERROR]: The layer " << layer->name << " has no OpType in Onnx." << endl; @@ -505,40 +512,6 @@ /* * DISTRIBUTED TRAINING */ -void update_layer_weights(Layer *layer, vector<Tensor *> weights) -{ - if (weights.size() == 0) - { - cerr << "[ONNX::WARNING] Trying to update the weights of the layer \"" - << layer->name << "\" with an empty list of tensors." << endl; - return; - } - - if (LConv *l = dynamic_cast<LConv *>(layer)) - update_conv_weights(l, weights); - else if (LDense *l = dynamic_cast<LDense *>(layer)) - update_dense_weights(l, weights); - else - cerr << "The layer " << layer->name << " has no support for setting weights" << endl; -} - -void apply_grads_to_layer(Layer *layer, vector<Tensor *> grads) -{ - if (grads.size() == 0) - { - cerr << "[ONNX::WARNING] Trying to apply gradients to the layer \"" - << layer->name << "\" with an empty list of tensors."
<< endl; - return; - } - - if (LConv *l = dynamic_cast<LConv *>(layer)) - apply_grads_to_conv(l, grads); - else if (LDense *l = dynamic_cast<LDense *>(layer)) - apply_grads_to_dense(l, grads); - else - cerr << "The layer " << layer->name << " has no support for applying gradients" << endl; -} - map<string, vector<Tensor *>> get_tensors_from_onnx_nodes(vector<onnx::NodeProto> &nodes, map<string, vector<float>> &map_init_values, map<string, vector<int>> &map_init_dims) @@ -560,11 +533,52 @@ map<string, vector<Tensor *>> get_tensors_from_onnx_nodes(vector<onnx::NodeProto> &nodes, + case ONNX_LAYERS::MATMUL: + { + vector<Tensor *> matmul_tensors = get_matmul_tensors(node, map_init_values, map_init_dims); + if (matmul_tensors.size()) + tensors[name] = matmul_tensors; + break; + } + case ONNX_LAYERS::ADD: + { + // The Add operator can be used to simulate a Dense layer bias; in that case we take the weights + vector<Tensor *> add_tensors = get_add_tensors(node, map_init_values, map_init_dims); + if (add_tensors.size() && tensors.count(name)) + tensors[name].push_back(add_tensors[0]); + break; + } default: // This layer has no trainable parameters continue; diff --git a/src/serialization/onnx/net/layers/merge/add_onnx.cpp b/src/serialization/onnx/net/layers/merge/add_onnx.cpp index a19d28af5..31e2f81b4 100644 --- a/src/serialization/onnx/net/layers/merge/add_onnx.cpp +++ b/src/serialization/onnx/net/layers/merge/add_onnx.cpp @@ -157,4 +157,40 @@ void build_add_node(LAdd *layer, onnx::GraphProto *graph) node->add_output(layer->name); } +/* + * DISTRIBUTED TRAINING + */ + +vector<Tensor *> get_add_tensors(onnx::NodeProto &node, + map<string, vector<float>> &map_init_values, + map<string, vector<int>> &map_init_dims) +{ + vector<Tensor *> add_tensors; + + bool parameter_input = false; + int index_parameter = -1; // Possible values are 0 and 1; we don't expect parameters in an Add with more than two parents + for (int j = 0; j < node.input_size(); j++) + { + string parent_name = node.input(j); + if (map_init_values.count(parent_name)) + { + parameter_input = true; + index_parameter = j; + break; + } + } + + if (parameter_input) + { + string weights_name = node.input(index_parameter); + vector<float> *weights = &(map_init_values[weights_name]); + vector<int> dims = map_init_dims[weights_name]; + Tensor *weights_tensor = new Tensor(dims, nullptr, DEV_CPU); + COPY_FROM_VECTOR_PTR_TO_TENSOR(weights, weights_tensor); + add_tensors.push_back(weights_tensor); + } + + return add_tensors; +} + #endif // defined(cPROTO) diff --git a/src/serialization/onnx/net/layers/merge/matmul_onnx.cpp b/src/serialization/onnx/net/layers/merge/matmul_onnx.cpp index f356313ca..dfd0e0bbb 100644 --- a/src/serialization/onnx/net/layers/merge/matmul_onnx.cpp +++ b/src/serialization/onnx/net/layers/merge/matmul_onnx.cpp @@ -49,4 +49,40 @@ Layer* build_matmul_layer(onnx::NodeProto *node, return new LMatMul(parents, node->name(), dev, mem); } +/* + * DISTRIBUTED TRAINING + */ + +vector<Tensor *> get_matmul_tensors(onnx::NodeProto &node, + map<string, vector<float>> &map_init_values, + map<string, vector<int>> &map_init_dims) +{ + vector<Tensor *> dense_tensors; + + bool dense_detected = false; + int index_parameter = -1; + for (int j = 0; j < node.input_size(); j++) + { + string parent_name = node.input(j); + if (map_init_values.count(parent_name)) + { + // Dense detected + dense_detected = true; + index_parameter = j; + break; + } + } + if (dense_detected) + { + string weights_name = node.input(index_parameter); + vector<float> *weights = &(map_init_values[weights_name]); + vector<int> dims = map_init_dims[weights_name]; + Tensor *weights_tensor = new Tensor(dims, nullptr, DEV_CPU); + COPY_FROM_VECTOR_PTR_TO_TENSOR(weights, weights_tensor); + dense_tensors.push_back(weights_tensor); + } + + return dense_tensors; +} + #endif // defined(cPROTO)
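Renaming both the MatMul node and the bias Add node to the plain layer name (earlier in this patch) is what makes the tensor collection above work: the MatMul weights are stored under the node name first, and the Add handler then appends its bias to the same map entry. A condensed sketch of that pairing; `matmul_node` and `add_node` are hypothetical stand-ins for the nodes the real loop in get_tensors_from_onnx_nodes iterates over:

// Hedged sketch of the name-based pairing used for Dense layers
// exported as MatMul + Add.
map<string, vector<Tensor *>> tensors;
string name = matmul_node.name();                       // e.g. the Dense layer name
tensors[name] = get_matmul_tensors(matmul_node, map_init_values, map_init_dims);
// The Add node carries the same name, so its bias lands in the same entry
vector<Tensor *> add_tensors = get_add_tensors(add_node, map_init_values, map_init_dims);
if (!add_tensors.empty() && tensors.count(name))
    tensors[name].push_back(add_tensors[0]);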
diff --git a/src/serialization/onnx/net/layers/onnx_nodes/onnx_node_conversion.cpp b/src/serialization/onnx/net/layers/onnx_nodes/onnx_node_conversion.cpp index a90bea3ba..78528edba 100644 --- a/src/serialization/onnx/net/layers/onnx_nodes/onnx_node_conversion.cpp +++ b/src/serialization/onnx/net/layers/onnx_nodes/onnx_node_conversion.cpp @@ -106,7 +106,7 @@ void build_cast_node(string node_name, string input, string output, int cast_typ to_attr->set_i(cast_type); } -void build_gather_node(string node_name, string input, string output, LEmbedding *layer, onnx::GraphProto *graph) +void build_gather_node(string node_name, string input, string output, LEmbedding *layer, onnx::GraphProto *graph, bool gradients) { onnx::NodeProto *node = graph->add_node(); node->set_op_type("Gather"); @@ -122,7 +122,10 @@ void build_gather_node(string node_name, string input, string output, LEmbedding embed_data->set_data_type(onnx::TensorProto::FLOAT); vector<int> embed_data_dims{layer->vocsize, layer->dim}; embed_data->mutable_dims()->Add(embed_data_dims.begin(), embed_data_dims.end()); // Set the shape of the weights - embed_data->mutable_float_data()->Add(layer->E->ptr, layer->E->ptr + layer->E->size); // Set the data values + if (!gradients) + embed_data->mutable_float_data()->Add(layer->E->ptr, layer->E->ptr + layer->E->size); // Set the data values + else + embed_data->mutable_float_data()->Add(layer->acc_gE->ptr, layer->acc_gE->ptr + layer->acc_gE->size); // Set the data values } #endif // defined(cPROTO) diff --git a/src/serialization/onnx/net/layers/pool/avgpool_onnx.cpp b/src/serialization/onnx/net/layers/pool/avgpool_onnx.cpp index 2f1a920e1..57e7fded7 100644 --- a/src/serialization/onnx/net/layers/pool/avgpool_onnx.cpp +++ b/src/serialization/onnx/net/layers/pool/avgpool_onnx.cpp @@ -1,9 +1,11 @@ #if defined(cPROTO) #include "eddl/serialization/onnx/layers/pool/avgpool_onnx.h" +#include "eddl/layers/da/layer_da.h" // ONNX import Layer* build_averagepool_layer(onnx::NodeProto *node, map<string, Layer *> &output_node_map, + LOG_LEVEL log_level, int dev, int mem) { @@ -89,7 +91,33 @@ Layer* build_averagepool_layer(onnx::NodeProto *node, return new LAveragePool1D(parent, new PoolDescriptor(kernel_shape, strides, pads), name, dev, mem); } else - return new LAveragePool(parent, new PoolDescriptor(kernel_shape, strides, pads), name, dev, mem); + { + LAveragePool *pool_layer; + PoolDescriptor *pd; + try + { + pd = new PoolDescriptor(kernel_shape, strides, pads); + pool_layer = new LAveragePool(parent, pd, name, dev, mem); + } + catch (AsymmetricPaddingException& e) + { + log_string("Detected a padding asymmetry in the AveragePool layer \"" + name + "\".
Going to add an explicit Pad layer before it to fix this.", log_level, LOG_LEVEL::INFO); + // Remove the invalid pool layer from the parent child vector + parent->child.pop_back(); + parent->lout--; + + vector<int> asym_pads = e.get_asymmetric_pads(); // Asymmetric paddings to fix + string pad_layer_name = name + "__asymmetric_padding"; + // Create a parent layer to fix the padding asymmetry + parent = new LPad(parent, {asym_pads[0], asym_pads[3], asym_pads[1], asym_pads[2]}, 0.0, pad_layer_name, dev, mem); + // Create the full AveragePool layer again + vector<int> new_pads = {0, 0, 0, 0}; + pd = new PoolDescriptor(kernel_shape, strides, new_pads, mem); + pool_layer = new LAveragePool(parent, pd, name, dev, mem); + } + + return pool_layer; + } } // ONNX import
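The same recovery is applied to MaxPool below. To see why the reindexing {asym_pads[0], asym_pads[3], asym_pads[1], asym_pads[2]} appears, here is a worked sketch; it assumes the exception reports pads in {top, bottom, left, right} order and that LPad expects {top, right, bottom, left}, which is consistent with the reindexing but is an assumption, not something the patch states:

// Hedged example: a 2x2 kernel with stride 2 on a 5x5 input needs one extra
// row/column at the bottom/right only, which PoolDescriptor cannot express.
vector<int> asym_pads = {0, 1, 0, 1};  // top, bottom, left, right (assumed order)
// Reindexed for LPad, assumed to take {top, right, bottom, left}:
vector<int> lpad_order = {asym_pads[0], asym_pads[3], asym_pads[1], asym_pads[2]};
// -> {0, 1, 1, 0}: pad bottom and right, then pool with pads = {0, 0, 0, 0}.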
diff --git a/src/serialization/onnx/net/layers/pool/maxpool_onnx.cpp b/src/serialization/onnx/net/layers/pool/maxpool_onnx.cpp index 5e43b3f98..6a996d4e6 100644 --- a/src/serialization/onnx/net/layers/pool/maxpool_onnx.cpp +++ b/src/serialization/onnx/net/layers/pool/maxpool_onnx.cpp @@ -1,9 +1,11 @@ #if defined(cPROTO) #include "eddl/serialization/onnx/layers/pool/maxpool_onnx.h" +#include "eddl/layers/da/layer_da.h" // ONNX import Layer* build_maxpool_layer(onnx::NodeProto *node, map<string, Layer *> &output_node_map, + LOG_LEVEL log_level, int dev, int mem) { @@ -97,7 +99,33 @@ Layer* build_maxpool_layer(onnx::NodeProto *node, return new LMaxPool1D(parent, new PoolDescriptor(kernel_shape, strides, pads), name, dev, mem); } else - return new LMaxPool(parent, new PoolDescriptor(kernel_shape, strides, pads), name, dev, mem); + { + LMaxPool *pool_layer; + PoolDescriptor *pd; + try + { + pd = new PoolDescriptor(kernel_shape, strides, pads); + pool_layer = new LMaxPool(parent, pd, name, dev, mem); + } + catch (AsymmetricPaddingException& e) + { + log_string("Detected a padding asymmetry in the MaxPool layer \"" + name + "\". Going to add an explicit Pad layer before it to fix this.", log_level, LOG_LEVEL::INFO); + // Remove the invalid pool layer from the parent child vector + parent->child.pop_back(); + parent->lout--; + + vector<int> asym_pads = e.get_asymmetric_pads(); // Asymmetric paddings to fix + string pad_layer_name = name + "__asymmetric_padding"; + // Create a parent layer to fix the padding asymmetry + parent = new LPad(parent, {asym_pads[0], asym_pads[3], asym_pads[1], asym_pads[2]}, 0.0, pad_layer_name, dev, mem); + // Create the full MaxPool layer again + vector<int> new_pads = {0, 0, 0, 0}; + pd = new PoolDescriptor(kernel_shape, strides, new_pads, mem); + pool_layer = new LMaxPool(parent, pd, name, dev, mem); + } + + return pool_layer; + } } // ONNX import diff --git a/src/serialization/onnx/net/layers/recurrent/gru_onnx.cpp b/src/serialization/onnx/net/layers/recurrent/gru_onnx.cpp index f0e9e60b8..5c9873788 100644 --- a/src/serialization/onnx/net/layers/recurrent/gru_onnx.cpp +++ b/src/serialization/onnx/net/layers/recurrent/gru_onnx.cpp @@ -68,7 +68,7 @@ } if (hidden_size < 0) - msg("GRU layer " + name + " doesn't have the number of neurons.", "ONNX::ImportNet"); + msg("The layer " + name + " (GRU) does not provide the hidden_size attribute.", "[ONNX::ImportNet]"); string parent_name = node->input(0); // Get parent Layer *parent = output_node_map[parent_name]; @@ -262,7 +262,7 @@ } // ONNX export -void build_gru_node(LGRU *layer, onnx::GraphProto *graph) +void build_gru_node(LGRU *layer, onnx::GraphProto *graph, bool gradients) { // Add an empty node to the graph onnx::NodeProto *node = graph->add_node(); @@ -343,15 +343,27 @@ void build_gru_node(LGRU *layer, onnx::GraphProto *graph) /* * The Weights are permuted before saving them (required by the ONNX standard) */ - Tensor *Wz_x = layer->Wz_x->permute({1, 0}); - w->mutable_float_data()->Add(Wz_x->ptr, Wz_x->ptr + Wz_x->size); // z weights - delete Wz_x; - Tensor *Wr_x = layer->Wr_x->permute({1, 0}); - w->mutable_float_data()->Add(Wr_x->ptr, Wr_x->ptr + Wr_x->size); // r weights - delete Wr_x; - Tensor *Wn_x = layer->Wn_x->permute({1, 0}); - w->mutable_float_data()->Add(Wn_x->ptr, Wn_x->ptr + Wn_x->size); // n weights - delete Wn_x; + if (!gradients) { + Tensor *Wz_x = layer->Wz_x->permute({1, 0}); + w->mutable_float_data()->Add(Wz_x->ptr, Wz_x->ptr + Wz_x->size); // z weights + delete Wz_x; + Tensor *Wr_x = layer->Wr_x->permute({1, 0}); + w->mutable_float_data()->Add(Wr_x->ptr, Wr_x->ptr + Wr_x->size); // r weights + delete Wr_x; + Tensor *Wn_x = layer->Wn_x->permute({1, 0}); + w->mutable_float_data()->Add(Wn_x->ptr, Wn_x->ptr + Wn_x->size); // n weights + delete Wn_x; + } else { + Tensor *Wz_x = layer->acc_gWz_x->permute({1, 0}); + w->mutable_float_data()->Add(Wz_x->ptr, Wz_x->ptr + Wz_x->size); // z weights + delete Wz_x; + Tensor *Wr_x = layer->acc_gWr_x->permute({1, 0}); + w->mutable_float_data()->Add(Wr_x->ptr, Wr_x->ptr + Wr_x->size); // r weights + delete Wr_x; + Tensor *Wn_x = layer->acc_gWn_x->permute({1, 0}); + w->mutable_float_data()->Add(Wn_x->ptr, Wn_x->ptr + Wn_x->size); // n weights + delete Wn_x; + } // R input (recurrent weights for all the layers W[zrh]) onnx::TensorProto *r = graph->add_initializer(); @@ -362,15 +374,27 @@ void build_gru_node(LGRU *layer, onnx::GraphProto *graph) /* * The Weights are permuted before saving them (required by the ONNX standard) */ - Tensor *Wz_hidden = layer->Uz_h->permute({1, 0}); -
r->mutable_float_data()->Add(Wz_hidden->ptr, Wz_hidden->ptr + Wz_hidden->size); // z recurrent weights - delete Wz_hidden; - Tensor *Wr_hidden = layer->Ur_h->permute({1, 0}); - r->mutable_float_data()->Add(Wr_hidden->ptr, Wr_hidden->ptr + Wr_hidden->size); // r recurrent weights - delete Wr_hidden; - Tensor *Wn_hidden = layer->Un_h->permute({1, 0}); - r->mutable_float_data()->Add(Wn_hidden->ptr, Wn_hidden->ptr + Wn_hidden->size); // n recurrent weights - delete Wn_hidden; + if (!gradients) { + Tensor *Wz_hidden = layer->Uz_h->permute({1, 0}); + r->mutable_float_data()->Add(Wz_hidden->ptr, Wz_hidden->ptr + Wz_hidden->size); // z recurrent weights + delete Wz_hidden; + Tensor *Wr_hidden = layer->Ur_h->permute({1, 0}); + r->mutable_float_data()->Add(Wr_hidden->ptr, Wr_hidden->ptr + Wr_hidden->size); // r recurrent weights + delete Wr_hidden; + Tensor *Wn_hidden = layer->Un_h->permute({1, 0}); + r->mutable_float_data()->Add(Wn_hidden->ptr, Wn_hidden->ptr + Wn_hidden->size); // n recurrent weights + delete Wn_hidden; + } else { + Tensor *Wz_hidden = layer->acc_gUz_h->permute({1, 0}); + r->mutable_float_data()->Add(Wz_hidden->ptr, Wz_hidden->ptr + Wz_hidden->size); // z recurrent weights + delete Wz_hidden; + Tensor *Wr_hidden = layer->acc_gUr_h->permute({1, 0}); + r->mutable_float_data()->Add(Wr_hidden->ptr, Wr_hidden->ptr + Wr_hidden->size); // r recurrent weights + delete Wr_hidden; + Tensor *Wn_hidden = layer->acc_gUn_h->permute({1, 0}); + r->mutable_float_data()->Add(Wn_hidden->ptr, Wn_hidden->ptr + Wn_hidden->size); // n recurrent weights + delete Wn_hidden; + } // B input (biases for all the layers) onnx::TensorProto *b = graph->add_initializer(); @@ -379,9 +403,15 @@ void build_gru_node(LGRU *layer, onnx::GraphProto *graph) vector<int> b_dims{1, 6 * layer->units}; // b_dims shape[0] = 1 for weights in one direction b->mutable_dims()->Add(b_dims.begin(), b_dims.end()); // Set the shape of the weights - b->mutable_float_data()->Add(layer->bias_z_t->ptr, layer->bias_z_t->ptr + layer->bias_z_t->size); // z bias - b->mutable_float_data()->Add(layer->bias_r_t->ptr, layer->bias_r_t->ptr + layer->bias_r_t->size); // r bias - b->mutable_float_data()->Add(layer->bias_n_t->ptr, layer->bias_n_t->ptr + layer->bias_n_t->size); // n bias + if (!gradients) { + b->mutable_float_data()->Add(layer->bias_z_t->ptr, layer->bias_z_t->ptr + layer->bias_z_t->size); // z bias + b->mutable_float_data()->Add(layer->bias_r_t->ptr, layer->bias_r_t->ptr + layer->bias_r_t->size); // r bias + b->mutable_float_data()->Add(layer->bias_n_t->ptr, layer->bias_n_t->ptr + layer->bias_n_t->size); // n bias + } else { + b->mutable_float_data()->Add(layer->acc_g_bias_z_t->ptr, layer->acc_g_bias_z_t->ptr + layer->acc_g_bias_z_t->size); // z bias + b->mutable_float_data()->Add(layer->acc_g_bias_r_t->ptr, layer->acc_g_bias_r_t->ptr + layer->acc_g_bias_r_t->size); // r bias + b->mutable_float_data()->Add(layer->acc_g_bias_n_t->ptr, layer->acc_g_bias_n_t->ptr + layer->acc_g_bias_n_t->size); // n bias + } // Set recurrent forward biases to 0 for gates z and r for (int i = 0; i < 2 * layer->units; ++i) @@ -389,7 +419,10 @@ // The recurrent bias for n is set because we need it to apply the linear transformation before the // r gate. See the "linear_before_reset" attribute in https://github.com/onnx/onnx/blob/master/docs/Operators.md#GRU - b->mutable_float_data()->Add(layer->bias_n_t_hidden->ptr, layer->bias_n_t_hidden->ptr + layer->bias_n_t_hidden->size); // n recurrent bias + if (!gradients) + b->mutable_float_data()->Add(layer->bias_n_t_hidden->ptr, layer->bias_n_t_hidden->ptr + layer->bias_n_t_hidden->size); // n recurrent bias + else + b->mutable_float_data()->Add(layer->acc_g_bias_n_t_hidden->ptr, layer->acc_g_bias_n_t_hidden->ptr + layer->acc_g_bias_n_t_hidden->size); // n recurrent bias /* Set the outputs of the node to link with the other nodes * - In ONNX the GRU operator can have up to 2 outputs: @@ -428,4 +461,177 @@ void build_gru_node(LGRU *layer, onnx::GraphProto *graph) } }
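For reference when reading the export code above and the import code below: ONNX packs the GRU parameters gate by gate, so slicing a single gate back out is plain offset arithmetic. A short sketch under the layout the patch relies on (W is [num_directions, 3*hidden_size, input_size] with gates in z, r, n order; the sizes are example values):

// Hedged sketch: recover the per-gate buffers from the packed ONNX "W"
// initializer (one direction assumed, matching the exporter above).
int hidden_size = 128, input_size = 64;             // example sizes
vector<float> packed(3 * hidden_size * input_size); // initializer data
int w_size = input_size * hidden_size;              // elements per gate
vector<float> wz(packed.begin() + 0 * w_size, packed.begin() + 1 * w_size); // z
vector<float> wr(packed.begin() + 1 * w_size, packed.begin() + 2 * w_size); // r
vector<float> wn(packed.begin() + 2 * w_size, packed.begin() + 3 * w_size); // n
// Each gate is stored as [hidden_size, input_size] and must still be
// transposed to EDDL's [input_size, hidden_size].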
See "linear_before_reset" attribute in https://github.com/onnx/onnx/blob/master/docs/Operators.md#GRU - b->mutable_float_data()->Add(layer->bias_n_t_hidden->ptr, layer->bias_n_t_hidden->ptr + layer->bias_n_t_hidden->size); // n recurrent bias + if (!gradients) + b->mutable_float_data()->Add(layer->bias_n_t_hidden->ptr, layer->bias_n_t_hidden->ptr + layer->bias_n_t_hidden->size); // n recurrent bias + else + b->mutable_float_data()->Add(layer->acc_g_bias_n_t_hidden->ptr, layer->acc_g_bias_n_t_hidden->ptr + layer->acc_g_bias_n_t_hidden->size); // n recurrent bias /* Set the outputs of the node to link with the other nodes * - In ONNX the GRU operator can have up to 2 outputs: @@ -428,4 +461,177 @@ void build_gru_node(LGRU *layer, onnx::GraphProto *graph) } } +/* + * DISTRIBUTED TRAINING + */ + +vector get_gru_tensors(onnx::NodeProto &node, + map> &map_init_values, + map> &map_init_dims) +{ + vector gru_tensors; + int hidden_size = -1; // Number of neurons in the hidden layer + + for (int j = 0; j < node.attribute_size(); j++) + { // Set the attributes + onnx::AttributeProto attribute = node.attribute(j); + string attr_name = attribute.name(); + if (!attr_name.compare("hidden_size")) { + hidden_size = attribute.i(); + break; + } + } + + if (hidden_size < 0) + msg("The layer " + node.name() + " (GRU) does not provide the hidden_size attribute.", "[ONNX::ImportNet]"); + + string weights_gates = node.input(1); // Get weights and dims + vector *weights_g = &(map_init_values[weights_gates]); + vector dims_g = map_init_dims[weights_gates]; + int input_size = dims_g[2]; + + // Load input weights with shape [hidden_size, input_size]. After load we transpose + // Note: EDDL input weights are of shape [input_size, hidden_size] + vector dims_input_gru = {dims_g[1] / 3, input_size}; + + vector *weights_z_g = new vector; + vector *weights_r_g = new vector; + vector *weights_n_g = new vector; + int w_size = input_size * hidden_size; + weights_z_g->assign(weights_g->begin() + w_size * 0, weights_g->begin() + w_size * 1); + weights_r_g->assign(weights_g->begin() + w_size * 1, weights_g->begin() + w_size * 2); + weights_n_g->assign(weights_g->begin() + w_size * 2, weights_g->begin() + w_size * 3); + + string recurrence_weights_gates = node.input(2); // Get weights and dims + vector *recurrence_weights_g = &(map_init_values[recurrence_weights_gates]); + vector recurrence_dims_g = map_init_dims[recurrence_weights_gates]; + + vector dims_recurrent_gru = {recurrence_dims_g[2], recurrence_dims_g[2]}; + + vector *recurrence_weights_z_g = new vector; + vector *recurrence_weights_r_g = new vector; + vector *recurrence_weights_n_g = new vector; + w_size = hidden_size * hidden_size; + recurrence_weights_z_g->assign(recurrence_weights_g->begin() + w_size * 0, recurrence_weights_g->begin() + w_size * 1); + recurrence_weights_r_g->assign(recurrence_weights_g->begin() + w_size * 1, recurrence_weights_g->begin() + w_size * 2); + recurrence_weights_n_g->assign(recurrence_weights_g->begin() + w_size * 2, recurrence_weights_g->begin() + w_size * 3); + + /* + * The Weights are permuted before copying them to the GRU layer (mismatch between ONNX standad and EDDL implementation) + */ + int dev = DEV_CPU; + Tensor *weights_z_tensor = new Tensor(dims_input_gru, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(weights_z_g, weights_z_tensor); + weights_z_tensor->permute_({1, 0}); + delete weights_z_g; + + Tensor *weights_r_tensor = new Tensor(dims_input_gru, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(weights_r_g, 
weights_r_tensor); + weights_r_tensor->permute_({1, 0}); + delete weights_r_g; + + Tensor *weights_n_tensor = new Tensor(dims_input_gru, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(weights_n_g, weights_n_tensor); + weights_n_tensor->permute_({1, 0}); + delete weights_n_g; + + Tensor *recurrence_weights_z_tensor = new Tensor(dims_recurrent_gru, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(recurrence_weights_z_g, recurrence_weights_z_tensor); + recurrence_weights_z_tensor->permute_({1, 0}); + delete recurrence_weights_z_g; + + Tensor *recurrence_weights_r_tensor = new Tensor(dims_recurrent_gru, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(recurrence_weights_r_g, recurrence_weights_r_tensor); + recurrence_weights_r_tensor->permute_({1, 0}); + delete recurrence_weights_r_g; + + Tensor *recurrence_weights_n_tensor = new Tensor(dims_recurrent_gru, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(recurrence_weights_n_g, recurrence_weights_n_tensor); + recurrence_weights_n_tensor->permute_({1, 0}); + delete recurrence_weights_n_g; + + gru_tensors.push_back(weights_z_tensor); + gru_tensors.push_back(weights_r_tensor); + gru_tensors.push_back(weights_n_tensor); + gru_tensors.push_back(recurrence_weights_z_tensor); + gru_tensors.push_back(recurrence_weights_r_tensor); + gru_tensors.push_back(recurrence_weights_n_tensor); + + /* + * Set bias values + */ + vector<int> bias_dims = {hidden_size}; + // Vectors to store the imported bias values + vector<float> *bias_z = new vector<float>; + vector<float> *bias_r = new vector<float>; + vector<float> *bias_n = new vector<float>; + vector<float> *bias_recurrence_z = new vector<float>; + vector<float> *bias_recurrence_r = new vector<float>; + vector<float> *bias_recurrence_n = new vector<float>; + + if (node.input_size() > 3) { // Check that we have bias + string biases_name = node.input(3); + vector<float> *biases = &(map_init_values[biases_name]); + // Forward bias (zrh) + bias_z->assign(biases->begin() + hidden_size * 0, biases->begin() + hidden_size * 1); + bias_r->assign(biases->begin() + hidden_size * 1, biases->begin() + hidden_size * 2); + bias_n->assign(biases->begin() + hidden_size * 2, biases->begin() + hidden_size * 3); + // Recurrent bias (zrh) + bias_recurrence_z->assign(biases->begin() + hidden_size * 3, biases->begin() + hidden_size * 4); + bias_recurrence_r->assign(biases->begin() + hidden_size * 4, biases->begin() + hidden_size * 5); + bias_recurrence_n->assign(biases->begin() + hidden_size * 5, biases->begin() + hidden_size * 6); + } else { + // Set bias values to 0.0 + // Note: In EDDL there is no use_bias option for GRU, so to achieve the same + // result we set the bias values to 0.0 + vector<float> zero_bias(hidden_size, 0.0); + bias_z->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_r->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_n->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_recurrence_z->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_recurrence_r->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_recurrence_n->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + } + + Tensor *bias_z_tensor = new Tensor(bias_dims, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(bias_z, bias_z_tensor); + delete bias_z; + + Tensor *bias_r_tensor = new Tensor(bias_dims, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(bias_r, bias_r_tensor); + delete bias_r; + + Tensor *bias_n_tensor = new Tensor(bias_dims, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(bias_n, bias_n_tensor); + delete bias_n; + + gru_tensors.push_back(bias_z_tensor); +
gru_tensors.push_back(bias_r_tensor); + gru_tensors.push_back(bias_n_tensor); + + // Add the recurrent bias values for gates z and r + /* Not needed for importing gradients + Tensor *bias_recurrence_z_tensor = new Tensor(bias_dims, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(bias_recurrence_z, bias_recurrence_z_tensor); + Tensor::add(bias_recurrence_z_tensor, gru->bias_z_t, gru->bias_z_t); + delete bias_recurrence_z_tensor; + delete bias_recurrence_z; + + Tensor *bias_recurrence_r_tensor = new Tensor(bias_dims, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(bias_recurrence_r, bias_recurrence_r_tensor); + Tensor::add(bias_recurrence_r_tensor, gru->bias_r_t, gru->bias_r_t); + delete bias_recurrence_r_tensor; + delete bias_recurrence_r; + */ + // The recurrent bias vectors for z and r are not used when importing gradients, so free them + delete bias_recurrence_z; + delete bias_recurrence_r; + + // The recurrent bias for n goes to its own tensor because we need it for applying the linear transformation + // before the r gate. See the "linear_before_reset" attribute in https://github.com/onnx/onnx/blob/master/docs/Operators.md#GRU + Tensor *bias_recurrence_n_tensor = new Tensor(bias_dims, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(bias_recurrence_n, bias_recurrence_n_tensor); + delete bias_recurrence_n; + gru_tensors.push_back(bias_recurrence_n_tensor); + + return gru_tensors; +} + #endif // defined(cPROTO)
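Since callers must unpack the returned vector positionally, the order established by the push_back calls above is effectively part of get_gru_tensors' contract; summarized here for convenience:

// Return order of get_gru_tensors (read off the pushes above):
//   [0..2] Wz_x, Wr_x, Wn_x        input weights, transposed to [input, hidden]
//   [3..5] Uz_h, Ur_h, Un_h        recurrent weights, transposed
//   [6..8] bias_z, bias_r, bias_n  forward biases
//   [9]    recurrent bias for n    kept separate for linear_before_reset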
diff --git a/src/serialization/onnx/net/layers/recurrent/lstm_onnx.cpp b/src/serialization/onnx/net/layers/recurrent/lstm_onnx.cpp index 5dcd974ec..b850a0ddb 100644 --- a/src/serialization/onnx/net/layers/recurrent/lstm_onnx.cpp +++ b/src/serialization/onnx/net/layers/recurrent/lstm_onnx.cpp @@ -93,9 +93,7 @@ } if (hidden_size < 0) - { - cerr << "Model contains a LSTM without the number of neurons" << endl; - } + msg("The layer " + name + " (LSTM) does not provide the hidden_size attribute.", "[ONNX::ImportNet]"); string weights_gates = node->input(1); // Get weights and dims vector<float> *weights_g = &(map_init_values[weights_gates]); @@ -303,7 +301,7 @@ } // ONNX export -void build_lstm_node(LLSTM *layer, onnx::GraphProto *graph) +void build_lstm_node(LLSTM *layer, onnx::GraphProto *graph, bool gradients) { // Add an empty node to the graph onnx::NodeProto *node = graph->add_node(); @@ -384,18 +382,33 @@ void build_lstm_node(LLSTM *layer, onnx::GraphProto *graph) /* * The Weights are permuted before saving them (required by the ONNX standard) */ - Tensor *Wix = layer->Wix->permute({1, 0}); - w->mutable_float_data()->Add(Wix->ptr, Wix->ptr + Wix->size); // i weights - delete Wix; - Tensor *Wox = layer->Wox->permute({1, 0}); - w->mutable_float_data()->Add(Wox->ptr, Wox->ptr + Wox->size); // o weights - delete Wox; - Tensor *Wfx = layer->Wfx->permute({1, 0}); - w->mutable_float_data()->Add(Wfx->ptr, Wfx->ptr + Wfx->size); // f weights - delete Wfx; - Tensor *Wcx = layer->Wcx->permute({1, 0}); - w->mutable_float_data()->Add(Wcx->ptr, Wcx->ptr + Wcx->size); // c weights - delete Wcx; + if (!gradients) { + Tensor *Wix = layer->Wix->permute({1, 0}); + w->mutable_float_data()->Add(Wix->ptr, Wix->ptr + Wix->size); // i weights + delete Wix; + Tensor *Wox = layer->Wox->permute({1, 0}); + w->mutable_float_data()->Add(Wox->ptr, Wox->ptr + Wox->size); // o weights + delete Wox; + Tensor *Wfx = layer->Wfx->permute({1, 0}); + w->mutable_float_data()->Add(Wfx->ptr, Wfx->ptr + Wfx->size); // f weights + delete Wfx; + Tensor *Wcx = layer->Wcx->permute({1, 0}); + w->mutable_float_data()->Add(Wcx->ptr, Wcx->ptr + Wcx->size); // c weights + delete Wcx; + } else { + Tensor *Wix = layer->acc_gWix->permute({1, 0}); + w->mutable_float_data()->Add(Wix->ptr, Wix->ptr + Wix->size); // i weights + delete Wix; + Tensor *Wox = layer->acc_gWox->permute({1, 0}); + w->mutable_float_data()->Add(Wox->ptr, Wox->ptr + Wox->size); // o weights + delete Wox; + Tensor *Wfx = layer->acc_gWfx->permute({1, 0}); + w->mutable_float_data()->Add(Wfx->ptr, Wfx->ptr + Wfx->size); // f weights + delete Wfx; + Tensor *Wcx = layer->acc_gWcx->permute({1, 0}); + w->mutable_float_data()->Add(Wcx->ptr, Wcx->ptr + Wcx->size); // c weights + delete Wcx; + } // R input (recurrent weights for all the layers W[iofc]) onnx::TensorProto *r = graph->add_initializer(); @@ -406,18 +419,33 @@ void build_lstm_node(LLSTM *layer, onnx::GraphProto *graph) /* * The Weights are permuted before saving them (required by the ONNX standard) */ - Tensor *Wih = layer->Wih->permute({1, 0}); - r->mutable_float_data()->Add(Wih->ptr, Wih->ptr + Wih->size); // i recurrent weights - delete Wih; - Tensor *Woh = layer->Woh->permute({1, 0}); - r->mutable_float_data()->Add(Woh->ptr, Woh->ptr + Woh->size); // o recurrent weights - delete Woh; - Tensor *Wfh = layer->Wfh->permute({1, 0}); - r->mutable_float_data()->Add(Wfh->ptr, Wfh->ptr + Wfh->size); // f recurrent weights - delete Wfh; - Tensor *Wch = layer->Wch->permute({1, 0}); - r->mutable_float_data()->Add(Wch->ptr, Wch->ptr + Wch->size); // c recurrent weights - delete Wch; + if (!gradients) { + Tensor *Wih = layer->Wih->permute({1, 0}); + r->mutable_float_data()->Add(Wih->ptr, Wih->ptr + Wih->size); // i recurrent weights + delete Wih; + Tensor *Woh = layer->Woh->permute({1, 0}); + r->mutable_float_data()->Add(Woh->ptr, Woh->ptr + Woh->size); // o recurrent weights + delete Woh; + Tensor *Wfh = layer->Wfh->permute({1, 0}); + r->mutable_float_data()->Add(Wfh->ptr, Wfh->ptr + Wfh->size); // f recurrent weights + delete Wfh; + Tensor *Wch = layer->Wch->permute({1, 0}); + r->mutable_float_data()->Add(Wch->ptr, Wch->ptr + Wch->size); // c recurrent weights + delete Wch; + } else { + Tensor *Wih = layer->acc_gWih->permute({1, 0}); + r->mutable_float_data()->Add(Wih->ptr, Wih->ptr + Wih->size); // i recurrent weights + delete Wih; + Tensor *Woh = layer->acc_gWoh->permute({1, 0}); + r->mutable_float_data()->Add(Woh->ptr, Woh->ptr + Woh->size); // o recurrent weights + delete Woh; + Tensor *Wfh = layer->acc_gWfh->permute({1, 0}); + r->mutable_float_data()->Add(Wfh->ptr, Wfh->ptr + Wfh->size); // f recurrent weights + delete Wfh; + Tensor *Wch = layer->acc_gWch->permute({1, 0}); + r->mutable_float_data()->Add(Wch->ptr, Wch->ptr + Wch->size); // c recurrent weights + delete Wch; + } // B input (biases for all the layers) onnx::TensorProto *b = graph->add_initializer(); @@ -426,10 +454,17 @@ void build_lstm_node(LLSTM *layer, onnx::GraphProto *graph) vector<int> b_dims{1, 8 * layer->units}; // b_dims shape[0] = 1 for weights in one direction b->mutable_dims()->Add(b_dims.begin(), b_dims.end()); // Set the shape of the weights - b->mutable_float_data()->Add(layer->inbias->ptr, layer->inbias->ptr + layer->inbias->size); // i bias - b->mutable_float_data()->Add(layer->onbias->ptr, layer->onbias->ptr + layer->onbias->size); // o bias - b->mutable_float_data()->Add(layer->fnbias->ptr, layer->fnbias->ptr + layer->fnbias->size); // f bias - b->mutable_float_data()->Add(layer->cnbias->ptr, layer->cnbias->ptr + layer->cnbias->size); // c bias + if (!gradients) { + b->mutable_float_data()->Add(layer->inbias->ptr, layer->inbias->ptr + layer->inbias->size); // i bias +
b->mutable_float_data()->Add(layer->onbias->ptr, layer->onbias->ptr + layer->onbias->size); // o bias + b->mutable_float_data()->Add(layer->fnbias->ptr, layer->fnbias->ptr + layer->fnbias->size); // f bias + b->mutable_float_data()->Add(layer->cnbias->ptr, layer->cnbias->ptr + layer->cnbias->size); // c bias + } else { + b->mutable_float_data()->Add(layer->acc_ginbias->ptr, layer->acc_ginbias->ptr + layer->acc_ginbias->size); // i bias + b->mutable_float_data()->Add(layer->acc_gonbias->ptr, layer->acc_gonbias->ptr + layer->acc_gonbias->size); // o bias + b->mutable_float_data()->Add(layer->acc_gfnbias->ptr, layer->acc_gfnbias->ptr + layer->acc_gfnbias->size); // f bias + b->mutable_float_data()->Add(layer->acc_gcnbias->ptr, layer->acc_gcnbias->ptr + layer->acc_gcnbias->size); // c bias + } // Set recurrent forward biases to 0 (only one bias used, not one for x and another for h) for (int i = 0; i < 4 * layer->units; ++i) @@ -478,4 +513,210 @@ void build_lstm_node(LLSTM *layer, onnx::GraphProto *graph) } } +/* + * DISTRIBUTED TRAINING + */ + +vector<Tensor *> get_lstm_tensors(onnx::NodeProto &node, + map<string, vector<float>> &map_init_values, + map<string, vector<int>> &map_init_dims) +{ + vector<Tensor *> lstm_tensors; + + int hidden_size = -1; + for (int j = 0; j < node.attribute_size(); j++) + { // Set the attributes + onnx::AttributeProto attribute = node.attribute(j); + string attr_name = attribute.name(); + if (!attr_name.compare("hidden_size")) + { + hidden_size = attribute.i(); + } + } + + if (hidden_size < 0) + msg("The layer " + node.name() + " (LSTM) does not provide the hidden_size attribute.", "[ONNX::ImportNet]"); + + string weights_gates = node.input(1); // Get weights and dims + vector<float> *weights_g = &(map_init_values[weights_gates]); + vector<int> dims_g = map_init_dims[weights_gates]; + int input_size = dims_g[2]; + + // Load input weights with shape [hidden_size, input_size].
After loading we transpose them + // Note: EDDL input weights are of shape [input_size, hidden_size] + vector<int> dims_input_lstm = {dims_g[1] / 4, dims_g[2]}; + + vector<float> *weights_input_g = new vector<float>; + vector<float> *weights_output_g = new vector<float>; + vector<float> *weights_forget_g = new vector<float>; + vector<float> *weights_cell_g = new vector<float>; + int w_size = input_size * hidden_size; + weights_input_g->assign(weights_g->begin() + w_size * 0, weights_g->begin() + w_size * 1); + weights_output_g->assign(weights_g->begin() + w_size * 1, weights_g->begin() + w_size * 2); + weights_forget_g->assign(weights_g->begin() + w_size * 2, weights_g->begin() + w_size * 3); + weights_cell_g->assign(weights_g->begin() + w_size * 3, weights_g->begin() + w_size * 4); + + string recurrence_weights_gates = node.input(2); // Get weights and dims + vector<float> *recurrence_weights_g = &(map_init_values[recurrence_weights_gates]); + vector<int> recurrence_dims_g = map_init_dims[recurrence_weights_gates]; + + vector<int> dims_recurrent_lstm = {recurrence_dims_g[2], recurrence_dims_g[2]}; + + vector<float> *recurrence_weights_input_g = new vector<float>; + vector<float> *recurrence_weights_output_g = new vector<float>; + vector<float> *recurrence_weights_forget_g = new vector<float>; + vector<float> *recurrence_weights_cell_g = new vector<float>; + w_size = hidden_size * hidden_size; + recurrence_weights_input_g->assign(recurrence_weights_g->begin() + w_size * 0, recurrence_weights_g->begin() + w_size * 1); + recurrence_weights_output_g->assign(recurrence_weights_g->begin() + w_size * 1, recurrence_weights_g->begin() + w_size * 2); + recurrence_weights_forget_g->assign(recurrence_weights_g->begin() + w_size * 2, recurrence_weights_g->begin() + w_size * 3); + recurrence_weights_cell_g->assign(recurrence_weights_g->begin() + w_size * 3, recurrence_weights_g->begin() + w_size * 4); + + /* + * The Weights are permuted before copying them to the LSTM layer (mismatch between the ONNX standard and the EDDL implementation) + */ + int dev = DEV_CPU; + Tensor *weights_input_tensor = new Tensor(dims_input_lstm, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(weights_input_g, weights_input_tensor); + delete weights_input_g; + weights_input_tensor->permute_({1, 0}); + + Tensor *weights_output_tensor = new Tensor(dims_input_lstm, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(weights_output_g, weights_output_tensor); + delete weights_output_g; + weights_output_tensor->permute_({1, 0}); + + Tensor *weights_forget_tensor = new Tensor(dims_input_lstm, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(weights_forget_g, weights_forget_tensor); + delete weights_forget_g; + weights_forget_tensor->permute_({1, 0}); + + Tensor *weights_cell_tensor = new Tensor(dims_input_lstm, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(weights_cell_g, weights_cell_tensor); + delete weights_cell_g; + weights_cell_tensor->permute_({1, 0}); + + Tensor *recurrence_weights_input_tensor = new Tensor(dims_recurrent_lstm, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(recurrence_weights_input_g, recurrence_weights_input_tensor); + delete recurrence_weights_input_g; + recurrence_weights_input_tensor->permute_({1, 0}); + + Tensor *recurrence_weights_output_tensor = new Tensor(dims_recurrent_lstm, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(recurrence_weights_output_g, recurrence_weights_output_tensor); + delete recurrence_weights_output_g; + recurrence_weights_output_tensor->permute_({1, 0}); + + Tensor *recurrence_weights_forget_tensor = new Tensor(dims_recurrent_lstm, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(recurrence_weights_forget_g,
recurrence_weights_forget_tensor); + delete recurrence_weights_forget_g; + recurrence_weights_forget_tensor->permute_({1, 0}); + + Tensor *recurrence_weights_cell_tensor = new Tensor(dims_recurrent_lstm, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(recurrence_weights_cell_g, recurrence_weights_cell_tensor); + delete recurrence_weights_cell_g; + recurrence_weights_cell_tensor->permute_({1, 0}); + + // EDDL LSTM uses the order ifoc not iofc + lstm_tensors.push_back(weights_input_tensor); + lstm_tensors.push_back(weights_forget_tensor); + lstm_tensors.push_back(weights_output_tensor); + lstm_tensors.push_back(weights_cell_tensor); + lstm_tensors.push_back(recurrence_weights_input_tensor); + lstm_tensors.push_back(recurrence_weights_forget_tensor); + lstm_tensors.push_back(recurrence_weights_output_tensor); + lstm_tensors.push_back(recurrence_weights_cell_tensor); + + /* + * Set bias values + */ + vector<int> bias_dims = {hidden_size}; + // Vectors to store the imported bias values + vector<float> *bias_input = new vector<float>; + vector<float> *bias_output = new vector<float>; + vector<float> *bias_forget = new vector<float>; + vector<float> *bias_cell = new vector<float>; + vector<float> *bias_recurrence_input = new vector<float>; + vector<float> *bias_recurrence_output = new vector<float>; + vector<float> *bias_recurrence_forget = new vector<float>; + vector<float> *bias_recurrence_cell = new vector<float>; + + if (node.input_size() > 3) { + string biases_name = node.input(3); // Get weights and dims + vector<float> *biases = &(map_init_values[biases_name]); + + bias_input->assign(biases->begin() + hidden_size * 0, biases->begin() + hidden_size * 1); + bias_output->assign(biases->begin() + hidden_size * 1, biases->begin() + hidden_size * 2); + bias_forget->assign(biases->begin() + hidden_size * 2, biases->begin() + hidden_size * 3); + bias_cell->assign(biases->begin() + hidden_size * 3, biases->begin() + hidden_size * 4); + bias_recurrence_input->assign(biases->begin() + hidden_size * 4, biases->begin() + hidden_size * 5); + bias_recurrence_output->assign(biases->begin() + hidden_size * 5, biases->begin() + hidden_size * 6); + bias_recurrence_forget->assign(biases->begin() + hidden_size * 6, biases->begin() + hidden_size * 7); + bias_recurrence_cell->assign(biases->begin() + hidden_size * 7, biases->begin() + hidden_size * 8); + } else { + // Set bias values to 0.0 + // Note: In EDDL there is no use_bias option for LSTM, so to achieve the same + // result we set the bias values to 0.0 + vector<float> zero_bias(hidden_size, 0.0); + bias_input->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_output->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_forget->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_cell->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_recurrence_input->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_recurrence_output->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_recurrence_forget->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + bias_recurrence_cell->assign(zero_bias.begin(), zero_bias.begin() + hidden_size); + } + + Tensor *bias_input_tensor = new Tensor(bias_dims, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(bias_input, bias_input_tensor); + delete bias_input; + + Tensor *bias_output_tensor = new Tensor(bias_dims, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(bias_output, bias_output_tensor); + delete bias_output; + + Tensor *bias_forget_tensor = new Tensor(bias_dims, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(bias_forget, bias_forget_tensor); + delete
bias_forget; + + Tensor *bias_cell_tensor = new Tensor(bias_dims, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(bias_cell, bias_cell_tensor); + delete bias_cell; + + // EDDL LSTM uses the order ifoc not iofc + lstm_tensors.push_back(bias_input_tensor); + lstm_tensors.push_back(bias_forget_tensor); + lstm_tensors.push_back(bias_output_tensor); + lstm_tensors.push_back(bias_cell_tensor); + + // Add the recurrent bias values + /* Not needed for importing gradients + Tensor *bias_recurrence_input_tensor = new Tensor(bias_dims, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(bias_recurrence_input, bias_recurrence_input_tensor); + Tensor::add(bias_recurrence_input_tensor, lstm->inbias, lstm->inbias); + delete bias_recurrence_input_tensor; + delete bias_recurrence_input; + + Tensor *bias_recurrence_output_tensor = new Tensor(bias_dims, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(bias_recurrence_output, bias_recurrence_output_tensor); + Tensor::add(bias_recurrence_output_tensor, lstm->onbias, lstm->onbias); + delete bias_recurrence_output_tensor; + delete bias_recurrence_output; + + Tensor *bias_recurrence_forget_tensor = new Tensor(bias_dims, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(bias_recurrence_forget, bias_recurrence_forget_tensor); + Tensor::add(bias_recurrence_forget_tensor, lstm->fnbias, lstm->fnbias); + delete bias_recurrence_forget_tensor; + delete bias_recurrence_forget; + + Tensor *bias_recurrence_cell_tensor = new Tensor(bias_dims, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(bias_recurrence_cell, bias_recurrence_cell_tensor); + Tensor::add(bias_recurrence_cell_tensor, lstm->cnbias, lstm->cnbias); + delete bias_recurrence_cell_tensor; + delete bias_recurrence_cell; + */ + // The recurrent bias vectors are not used when importing gradients, so free them + delete bias_recurrence_input; + delete bias_recurrence_output; + delete bias_recurrence_forget; + delete bias_recurrence_cell; + + return lstm_tensors; +} + #endif // defined(cPROTO)
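The "ifoc not iofc" comments above encode the only reordering between the two formats. A tiny sketch of the index mapping, in case the slicing ever needs to be generalized (the array is illustrative, not EDDL API):

// Hedged sketch: ONNX packs LSTM gates as i, o, f, c; EDDL stores i, f, o, c.
// onnx_to_eddl[g] gives the EDDL slot for the ONNX gate at index g.
const int onnx_to_eddl[4] = {0, 2, 1, 3};  // i -> 0, o -> 2, f -> 1, c -> 3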
diff --git a/src/serialization/onnx/net/layers/recurrent/rnn_onnx.cpp b/src/serialization/onnx/net/layers/recurrent/rnn_onnx.cpp index e70980bb0..6bc339a1d 100644 --- a/src/serialization/onnx/net/layers/recurrent/rnn_onnx.cpp +++ b/src/serialization/onnx/net/layers/recurrent/rnn_onnx.cpp @@ -114,7 +114,7 @@ } if (hidden_size < 0) - msg("RNN layer " + name + " doesn't have the number of neurons.", "ONNX::ImportNet"); + msg("The layer " + name + " (RNN) does not provide the hidden_size attribute.", "[ONNX::ImportNet]"); string parent_name = node->input(0); // Get parent Layer *parent = output_node_map[parent_name]; @@ -222,7 +222,7 @@ } // ONNX export -void build_rnn_node(LRNN *layer, onnx::GraphProto *graph) +void build_rnn_node(LRNN *layer, onnx::GraphProto *graph, bool gradients) { // Add an empty node to the graph onnx::NodeProto *node = graph->add_node(); @@ -233,11 +233,12 @@ node->add_input(layer->name + "_W"); node->add_input(layer->name + "_R"); if (layer->use_bias) node->add_input(layer->name + "_B"); - else node->add_input(""); - node->add_input(""); // Empty str to skip the sequence_lens input + // Check if we have to copy states for a decoder RNN if (layer->parent.size() > 1 && layer->isdecoder) { + if (!layer->use_bias) node->add_input(""); // Empty str to skip the bias input + node->add_input(""); // Empty str to skip the sequence_lens input string l_copyStates_name = layer->parent[1]->name; node->add_input(l_copyStates_name + "_h"); } @@ -310,9 +311,15 @@ /* * The Weights are permuted before saving them (required by the ONNX standard) */ - Tensor *Wx = layer->Wx->permute({1, 0}); - w->mutable_float_data()->Add(Wx->ptr, Wx->ptr + Wx->size); - delete Wx; + if (!gradients) { + Tensor *Wx = layer->Wx->permute({1, 0}); + w->mutable_float_data()->Add(Wx->ptr, Wx->ptr + Wx->size); + delete Wx; + } else { + Tensor *Wx = layer->acc_gWx->permute({1, 0}); + w->mutable_float_data()->Add(Wx->ptr, Wx->ptr + Wx->size); + delete Wx; + } // Recurrent weights onnx::TensorProto *r = graph->add_initializer(); @@ -323,9 +330,15 @@ /* * The Weights are permuted before saving them (required by the ONNX standard) */ - Tensor *Wy = layer->Wy->permute({1, 0}); - r->mutable_float_data()->Add(Wy->ptr, Wy->ptr + Wy->size); - delete Wy; + if (!gradients) { + Tensor *Wy = layer->Wy->permute({1, 0}); + r->mutable_float_data()->Add(Wy->ptr, Wy->ptr + Wy->size); + delete Wy; + } else { + Tensor *Wy = layer->acc_gWy->permute({1, 0}); + r->mutable_float_data()->Add(Wy->ptr, Wy->ptr + Wy->size); + delete Wy; + } // Bias if (layer->use_bias) { @@ -334,7 +347,10 @@ b->set_data_type(onnx::TensorProto::FLOAT); vector<int> b_dims{1, 2 * layer->units}; // b_dims shape[0] = 1 for weights in one direction b->mutable_dims()->Add(b_dims.begin(), b_dims.end()); // Set the shape of the weights - b->mutable_float_data()->Add(layer->bias->ptr, layer->bias->ptr + layer->bias->size); + if (!gradients) + b->mutable_float_data()->Add(layer->bias->ptr, layer->bias->ptr + layer->bias->size); + else + b->mutable_float_data()->Add(layer->acc_gbias->ptr, layer->acc_gbias->ptr + layer->acc_gbias->size); // Set recurrent biases to 0 for (int i = 0; i < layer->units; ++i) b->add_float_data(0.0); @@ -377,4 +393,93 @@ } } +/* + * DISTRIBUTED TRAINING + */ + +vector<Tensor *> get_rnn_tensors(onnx::NodeProto &node, + map<string, vector<float>> &map_init_values, + map<string, vector<int>> &map_init_dims) +{ + vector<Tensor *> rnn_tensors; + int hidden_size = -1; + + for (int j = 0; j < node.attribute_size(); j++) + { // Set the attributes + onnx::AttributeProto attribute = node.attribute(j); + string attr_name = attribute.name(); + if (!attr_name.compare("hidden_size")) + hidden_size = attribute.i(); + } + + if (hidden_size < 0) + msg("The layer " + node.name() + " (RNN) does not provide the hidden_size attribute.", "[ONNX::ImportNet]"); + + string weights_gates = node.input(1); // Get weights and dims + vector<float> *weights_g = &(map_init_values[weights_gates]); + vector<int> dims_g = map_init_dims[weights_gates]; + int input_size = dims_g[2]; + + // Load input weights with shape [hidden_size, input_size].
After loading we transpose them + // Note: EDDL input weights are of shape [input_size, hidden_size] + vector<int> dims_input_rnn = {dims_g[1], input_size}; + + vector<float> *weights_x = new vector<float>; + int w_size = input_size * hidden_size; + weights_x->assign(weights_g->begin(), weights_g->begin() + w_size); + + string recurrence_weights_gates = node.input(2); // Get weights and dims + vector<float> *recurrence_weights_g = &(map_init_values[recurrence_weights_gates]); + vector<int> recurrence_dims_g = map_init_dims[recurrence_weights_gates]; + + vector<int> dims_recurrent_rnn = {recurrence_dims_g[2], recurrence_dims_g[2]}; + + vector<float> *weights_h = new vector<float>; + w_size = hidden_size * hidden_size; + weights_h->assign(recurrence_weights_g->begin(), recurrence_weights_g->begin() + w_size); + + /* + * The Weights are permuted before copying them to the RNN layer (mismatch between the ONNX standard and the EDDL implementation) + */ + int dev = DEV_CPU; + Tensor *weights_x_tensor = new Tensor(dims_input_rnn, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(weights_x, weights_x_tensor); + weights_x_tensor->permute_({1, 0}); + delete weights_x; + + Tensor *weights_h_tensor = new Tensor(dims_recurrent_rnn, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(weights_h, weights_h_tensor); + weights_h_tensor->permute_({1, 0}); + delete weights_h; + + rnn_tensors.push_back(weights_x_tensor); + rnn_tensors.push_back(weights_h_tensor); + + if (node.input_size() > 3) { // If use_bias + string biases_name = node.input(3); + vector<float> *biases = &(map_init_values[biases_name]); + vector<int> bias_dims = {hidden_size}; + + vector<float> *bias_x = new vector<float>; + bias_x->assign(biases->begin() + hidden_size * 0, biases->begin() + hidden_size * 1); + Tensor *bias_x_tensor = new Tensor(bias_dims, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(bias_x, bias_x_tensor); + delete bias_x; + + rnn_tensors.push_back(bias_x_tensor); + + /* Not needed for importing gradients + vector<float> *bias_h = new vector<float>; + bias_h->assign(biases->begin() + hidden_size * 1, biases->begin() + hidden_size * 2); + Tensor *bias_h_tensor = new Tensor(bias_dims, nullptr, dev); + COPY_FROM_VECTOR_PTR_TO_TENSOR(bias_h, bias_h_tensor); + Tensor::add(bias_h_tensor, bias_x_tensor, bias_x_tensor); + delete bias_h_tensor; + delete bias_h; + */ + } + + return rnn_tensors; +} + #endif // defined(cPROTO)
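Taken together, these per-operator getters give the distributed-training code a uniform way to turn an exported gradient graph back into per-layer tensors. A hedged end-to-end sketch; apply_gradients_by_name and the surrounding driver are hypothetical, not part of this patch:

// Hedged sketch: parse an exported gradient graph and hand each node's
// tensors to the net, keyed by node/layer name.
map<string, vector<Tensor *>> grads =
    get_tensors_from_onnx_nodes(nodes, map_init_values, map_init_dims);
for (auto &kv : grads)                                  // kv.first = node name
    apply_gradients_by_name(net, kv.first, kv.second);  // hypothetical helper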