From 08001d18ac41ee2fe95ce9d4d064c2fb725e583f Mon Sep 17 00:00:00 2001
From: pengwa
Date: Thu, 25 Jul 2024 08:25:22 +0800
Subject: [PATCH] Fix security issue #22016 #22017 #22018 (#21333)

### Description
Build the static recomputable-op table in recompute_analysis.cc once, via a
dedicated InitializeRecomputableOpTable() helper, instead of lazily inserting
into the function-local static map on each new probe level, and reject unknown
probe levels with ORT_ENFORCE. Also hold the group optimizer state as a
std::shared_ptr copy, rather than a reference into the containing map, in the
training API test and in checkpoint saving.

### Motivation and Context
Fixes security issues #22016, #22017 and #22018.
---
 .../memory_optimizer/recompute_analysis.cc    | 814 +++++++++---------
 .../training_api/core/training_api_tests.cc   |   3 +-
 .../orttraining/training_api/checkpoint.cc    |   2 +-
 3 files changed, 416 insertions(+), 403 deletions(-)

diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc
index 8d110c692751e..1135ef41cfc47 100644
--- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc
+++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc
@@ -67,410 +67,422 @@ using OpsetToIgnorableIndicesMap = InlinedHashMap<int, InlinedHashSet<int>>;
  * or not.
  * 3. Some ops are not supported in older opsets, we need to check whether it is applicable to recompute or not.
  */
-const InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>& GetAllowedRecomputeOps(int probe_op_level) {
-  static InlinedHashMap<int, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>> recomputable_op_table_map;
-  if (recomputable_op_table_map.find(probe_op_level) != recomputable_op_table_map.end()) {
-    return recomputable_op_table_map.at(probe_op_level);
-  }
+InlinedHashMap<int, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>> InitializeRecomputableOpTable() {
+  InlinedHashMap<int, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>> recomputable_op_table_map;
+
+  constexpr const int basic_op_level = static_cast<int>(ProbeLevel::Basic);
+  recomputable_op_table_map.insert({basic_op_level, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>()});
+  auto& basic_recomputable_op_table = recomputable_op_table_map.at(basic_op_level);
+
+  basic_recomputable_op_table.insert({
+      {
+          utils::GetFullQualifiedOpName("Add", kOnnxDomain),
+          {
+              {1, {}},
+              {6, {}},
+              {7, {}},
+              {13, {}},
+              {14, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("BatchNormalization", kOnnxDomain),
+          {
+              {1, {}},
+              {6, {}},
+              {7, {}},
+              {9, {}},
+              {14, {}},
+              {15, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("BiasGelu", kMSDomain),
+          {
+              {1, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("BiasDropout", kMSDomain),
+          {
+              {1, {3, 4}},  // ignore ratio (optional) and training mode (optional)
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("BitmaskBiasDropout", kMSDomain),
+          {
+              {1, {3, 4}},  // ignore ratio (optional) and training mode (optional)
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("BitmaskDropout", kMSDomain),
+          {
+              {1, {1, 2}},  // ignore ratio (optional) and training mode (optional)
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Cast", kOnnxDomain),
+          {
+              {1, {}},
+              {6, {}},
+              {9, {}},
+              {13, {}},
+              {19, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("ConcatTraining", kMSDomain),
+          {
+              {1, {}},
+
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("ConstantOfShape", kOnnxDomain),
+          {
+              {9, {0}},  // ignore the `input`, e.g. the shape of the expected output tensor
+              {20, {0}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Cos", kOnnxDomain),
+          {
+              {7, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("CumSum", kOnnxDomain),
+          {
+              // The axis input is trivial
+              {11, {1}},
+              {14, {1}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Dropout", kOnnxDomain),
+          {
+              // ONNX Dropout 1, 6, 7, 10 do not have seed attribute, so we remove them from the recompute support.
+              {12, {1, 2}},  // ignore ratio and training_mode
+              {13, {1, 2}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Div", kOnnxDomain),
+          {
+              {1, {}},
+              {6, {}},
+              {7, {}},
+              {13, {}},
+              {14, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Einsum", kOnnxDomain),
+          {
+              {12, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Equal", kOnnxDomain),
+          {
+              {1, {}},
+              {7, {}},
+              {11, {}},
+              {13, {}},
+              {19, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Expand", kOnnxDomain),
+          {
+              {8, {1}},  // Ignore the shape.
+              {13, {1}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("FastGelu", kMSDomain),
+          {
+              {1, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("FlattenAndUnpad", kMSDomain),
+          {
+              {1, {1}},  // ignore the indices
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Gather", kOnnxDomain),
+          {
+              {1, {1}},  // ignore the indices
+              {11, {1}},
+              {13, {1}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Gelu", kOnnxDomain),
+          {
+              {20, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Gelu", kMSDomain),
+          {
+              {1, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Gemm", kOnnxDomain),
+          {
+              {1, {}},
+              {6, {}},
+              {7, {}},
+              {9, {}},
+              {11, {}},
+              {13, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Less", kOnnxDomain),
+          {
+              {1, {}},
+              {7, {}},
+              {9, {}},
+              {13, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("MemcpyFromHost", kOnnxDomain),
+          {
+              {1, {0}},  // Ignore CPU input.
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Mul", kOnnxDomain),
+          {
+              {1, {}},
+              {6, {}},
+              {7, {}},
+              {13, {}},
+              {14, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Neg", kOnnxDomain),
+          {
+              {1, {}},
+              {6, {}},
+              {13, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("NonZero", kOnnxDomain),
+          {
+              {9, {}},
+              {13, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("PadAndUnflatten", kMSDomain),
+          {
+              {1, {1, 2}},  // ignore the indices and unflatten_dims
+          },
+      },
+      {
+          // Be noted, NOT all PythonOp will be allowed to recompute, there will be further check.
+          utils::GetFullQualifiedOpName("PythonOp", kMSDomain),
+          {
+              {1, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Range", kOnnxDomain),
+          {
+              {11, {0, 1, 2}},  // ignore start, end, delta, because they are scalars.
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Reshape", kOnnxDomain),
+          {
+              {1, {}},
+              {5, {}},  // ignore the shape.
+              {13, {}},
+              {14, {}},
+              {19, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Sin", kOnnxDomain),
+          {
+              {7, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Slice", kOnnxDomain),
+          {
+              {1, {}},
+              {10, {1, 2, 3, 4}},  // ignore starts, ends, axes (optional) and steps (optional)
+              {11, {1, 2, 3, 4}},
+              {13, {1, 2, 3, 4}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Split", kOnnxDomain),
+          {
+              {1, {1}},  // ignore split (optional)
+              {2, {}},
+              {11, {}},
+              {13, {1}},  // ignore the split (optional)
+              {18, {1}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Squeeze", kOnnxDomain),
+          {
+              {1, {}},
+              {11, {}},
+              {13, {1}},  // ignore the axes (optional)
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Sub", kOnnxDomain),
+          {
+              {1, {}},
+              {6, {}},
+              {7, {}},
+              {13, {}},
+              {14, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Tile", kOnnxDomain),
+          {
+              {1, {1, 2}},
+              {6, {1}},
+              {13, {1}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Transpose", kOnnxDomain),
+          {
+              {1, {}},
+              {13, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Trilu", kOnnxDomain),
+          {
+              {14, {1}},  // ignore k (optional)
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("QuickGelu", kMSDomain),
+          {
+              {1, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Unsqueeze", kOnnxDomain),
+          {
+              {1, {}},
+              {11, {}},
+              {13, {1}},  // ignore the axes (optional)
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Where", kOnnxDomain),
+          {
+              {9, {}},
+              {16, {}},
+          },
+      },
+
+  });
+
+  constexpr const int advanced_op_level = static_cast<int>(ProbeLevel::Advanced);
+  recomputable_op_table_map.insert({advanced_op_level, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>()});
+  auto& advanced_recomputable_op_table = recomputable_op_table_map.at(advanced_op_level);
+  // Append basic_recomputable_op_table to advanced_recomputable_op_table.
+  advanced_recomputable_op_table.insert(recomputable_op_table_map.at(basic_op_level).begin(),
+                                        recomputable_op_table_map.at(basic_op_level).end());
+
+  advanced_recomputable_op_table.insert({
+      {
+          utils::GetFullQualifiedOpName("BiasSoftmax", kMSDomain),
+          {
+              {1, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("BiasSoftmaxDropout", kMSDomain),
+          {
+              {1, {2}},  // ignore ratio (optional)
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain),
+          {
+              // Opset 1 in ONNX official does not have LayerNormalization,
+              // while our contrib op defined LayerNormalization in opset 1 in ONNX domain.
+              {1, {}},
+              {17, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("MatMul", kOnnxDomain),
+          {
+              {1, {}},
+              {9, {}},
+              {13, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("FusedMatMul", kMSDomain),
+          {
+              {1, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("SimplifiedLayerNormalization", kOnnxDomain),
+          {
+              // Opset 1 in ONNX official does not have SimplifiedLayerNormalization,
+              // while our contrib op defined SimplifiedLayerNormalization in opset 1 in ONNX domain.
+              {1, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("SkipLayerNormalization", kMSDomain),
+          {
+              {1, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("SkipSimplifiedLayerNormalization", kMSDomain),
+          {
+              {1, {}},
+          },
+      },
+      {
+          utils::GetFullQualifiedOpName("Softmax", kOnnxDomain),
+          {
+              {1, {}},
+              {11, {}},
+              {13, {}},
+          },
+      },
+  });
+
+  return recomputable_op_table_map;
+}
 
-  recomputable_op_table_map.insert({probe_op_level, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>()});
-  auto& recomputable_op_table = recomputable_op_table_map.at(probe_op_level);
-  if (probe_op_level >= static_cast<int>(ProbeLevel::Basic)) {
-    recomputable_op_table.insert({
-        {
-            utils::GetFullQualifiedOpName("Add", kOnnxDomain),
-            {
-                {1, {}},
-                {6, {}},
-                {7, {}},
-                {13, {}},
-                {14, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("BatchNormalization", kOnnxDomain),
-            {
-                {1, {}},
-                {6, {}},
-                {7, {}},
-                {9, {}},
-                {14, {}},
-                {15, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("BiasGelu", kMSDomain),
-            {
-                {1, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("BiasDropout", kMSDomain),
-            {
-                {1, {3, 4}},  // ignore ratio (optional) and training mode (optional)
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("BitmaskBiasDropout", kMSDomain),
-            {
-                {1, {3, 4}},  // ignore ratio (optional) and training mode (optional)
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("BitmaskDropout", kMSDomain),
-            {
-                {1, {1, 2}},  // ignore ratio (optional) and training mode (optional)
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Cast", kOnnxDomain),
-            {
-                {1, {}},
-                {6, {}},
-                {9, {}},
-                {13, {}},
-                {19, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("ConcatTraining", kMSDomain),
-            {
-                {1, {}},
-
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("ConstantOfShape", kOnnxDomain),
-            {
-                {9, {0}},  // ignore the `input`, e.g. the shape of the expected output tensor
-                {20, {0}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Cos", kOnnxDomain),
-            {
-                {7, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("CumSum", kOnnxDomain),
-            {
-                // The axis input is trivial
-                {11, {1}},
-                {14, {1}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Dropout", kOnnxDomain),
-            {
-                // ONNX Dropout 1, 6, 7, 10 do not have seed attribute, so we remove them from the recompute support.
-                {12, {1, 2}},  // ignore ratio and training_mode
-                {13, {1, 2}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Div", kOnnxDomain),
-            {
-                {1, {}},
-                {6, {}},
-                {7, {}},
-                {13, {}},
-                {14, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Einsum", kOnnxDomain),
-            {
-                {12, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Equal", kOnnxDomain),
-            {
-                {1, {}},
-                {7, {}},
-                {11, {}},
-                {13, {}},
-                {19, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Expand", kOnnxDomain),
-            {
-                {8, {1}},  // Ignore the shape.
-                {13, {1}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("FastGelu", kMSDomain),
-            {
-                {1, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("FlattenAndUnpad", kMSDomain),
-            {
-                {1, {1}},  // ignore the indices
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Gather", kOnnxDomain),
-            {
-                {1, {1}},  // ignore the indices
-                {11, {1}},
-                {13, {1}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Gelu", kOnnxDomain),
-            {
-                {20, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Gelu", kMSDomain),
-            {
-                {1, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Gemm", kOnnxDomain),
-            {
-                {1, {}},
-                {6, {}},
-                {7, {}},
-                {9, {}},
-                {11, {}},
-                {13, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Less", kOnnxDomain),
-            {
-                {1, {}},
-                {7, {}},
-                {9, {}},
-                {13, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("MemcpyFromHost", kOnnxDomain),
-            {
-                {1, {0}},  // Ignore CPU input.
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Mul", kOnnxDomain),
-            {
-                {1, {}},
-                {6, {}},
-                {7, {}},
-                {13, {}},
-                {14, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Neg", kOnnxDomain),
-            {
-                {1, {}},
-                {6, {}},
-                {13, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("NonZero", kOnnxDomain),
-            {
-                {9, {}},
-                {13, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("PadAndUnflatten", kMSDomain),
-            {
-                {1, {1, 2}},  // ignore the indices and unflatten_dims
-            },
-        },
-        {
-            // Be noted, NOT all PythonOp will be allowed to recompute, there will be further check.
-            utils::GetFullQualifiedOpName("PythonOp", kMSDomain),
-            {
-                {1, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Range", kOnnxDomain),
-            {
-                {11, {0, 1, 2}},  // ignore start, end, delta, because they are scalars.
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Reshape", kOnnxDomain),
-            {
-                {1, {}},
-                {5, {}},  // ignore the shape.
-                {13, {}},
-                {14, {}},
-                {19, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Sin", kOnnxDomain),
-            {
-                {7, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Slice", kOnnxDomain),
-            {
-                {1, {}},
-                {10, {1, 2, 3, 4}},  // ignore starts, ends, axes (optional) and steps (optional)
-                {11, {1, 2, 3, 4}},
-                {13, {1, 2, 3, 4}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Split", kOnnxDomain),
-            {
-                {1, {1}},  // ignore split (optional)
-                {2, {}},
-                {11, {}},
-                {13, {1}},  // ignore the split (optional)
-                {18, {1}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Squeeze", kOnnxDomain),
-            {
-                {1, {}},
-                {11, {}},
-                {13, {1}},  // ignore the axes (optional)
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Sub", kOnnxDomain),
-            {
-                {1, {}},
-                {6, {}},
-                {7, {}},
-                {13, {}},
-                {14, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Tile", kOnnxDomain),
-            {
-                {1, {1, 2}},
-                {6, {1}},
-                {13, {1}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Transpose", kOnnxDomain),
-            {
-                {1, {}},
-                {13, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Trilu", kOnnxDomain),
-            {
-                {14, {1}},  // ignore k (optional)
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("QuickGelu", kMSDomain),
-            {
-                {1, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Unsqueeze", kOnnxDomain),
-            {
-                {1, {}},
-                {11, {}},
-                {13, {1}},  // ignore the axes (optional)
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Where", kOnnxDomain),
-            {
-                {9, {}},
-                {16, {}},
-            },
-        },
-
-    });
-  }
+const InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>& GetAllowedRecomputeOps(int probe_op_level) {
+  static InlinedHashMap<int, InlinedHashMap<std::string, OpsetToIgnorableIndicesMap>>
+      recomputable_op_table_map = InitializeRecomputableOpTable();
 
-  if (probe_op_level >= static_cast<int>(ProbeLevel::Advanced)) {
-    recomputable_op_table.insert({
-        {
-            utils::GetFullQualifiedOpName("BiasSoftmax", kMSDomain),
-            {
-                {1, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("BiasSoftmaxDropout", kMSDomain),
-            {
-                {1, {2}},  // ignore ratio (optional)
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain),
-            {
-                // Opset 1 in ONNX official does not have LayerNormalization,
-                // while our contrib op defined LayerNormalization in opset 1 in ONNX domain.
-                {1, {}},
-                {17, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("MatMul", kOnnxDomain),
-            {
-                {1, {}},
-                {9, {}},
-                {13, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("FusedMatMul", kMSDomain),
-            {
-                {1, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("SimplifiedLayerNormalization", kOnnxDomain),
-            {
-                // Opset 1 in ONNX official does not have SimplifiedLayerNormalization,
-                // while our contrib op defined SimplifiedLayerNormalization in opset 1 in ONNX domain.
-                {1, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("SkipLayerNormalization", kMSDomain),
-            {
-                {1, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("SkipSimplifiedLayerNormalization", kMSDomain),
-            {
-                {1, {}},
-            },
-        },
-        {
-            utils::GetFullQualifiedOpName("Softmax", kOnnxDomain),
-            {
-                {1, {}},
-                {11, {}},
-                {13, {}},
-            },
-        },
-    });
-  }
+  ORT_ENFORCE(recomputable_op_table_map.find(probe_op_level) != recomputable_op_table_map.end(),
+              "Cannot get recomputable op table, probe level: ", probe_op_level);
 
-  return recomputable_op_table;
+  return recomputable_op_table_map.at(probe_op_level);
 }
 
 /**
diff --git a/orttraining/orttraining/test/training_api/core/training_api_tests.cc b/orttraining/orttraining/test/training_api/core/training_api_tests.cc
index 90c97eed0c6d3..be25eefb201da 100644
--- a/orttraining/orttraining/test/training_api/core/training_api_tests.cc
+++ b/orttraining/orttraining/test/training_api/core/training_api_tests.cc
@@ -542,8 +542,9 @@ TEST(TrainingApiTest, OptimStep) {
   std::string param_name = "fc2.weight";
   // before training, check if optim state is initialized to 0
   onnxruntime::training::api::OptimizerCheckpointState& optimizer_states = state.optimizer_checkpoint_state;
+  std::shared_ptr<onnxruntime::training::api::GroupOptimizerState> group0_states = optimizer_states.group_named_optimizer_states["group0"];
   onnxruntime::training::api::ParameterOptimizerState& param_state =
-      optimizer_states.group_named_optimizer_states["group0"]->param_named_optimizer_states.at(param_name);
+      group0_states->param_named_optimizer_states.at(param_name);
   OrtValue& moment_1 = param_state.at("momentum0");
 
   std::vector<float> param_vec_before_optimizer_step;
diff --git a/orttraining/orttraining/training_api/checkpoint.cc b/orttraining/orttraining/training_api/checkpoint.cc
index 56029b34c24d7..cbff1891b8c84 100644
--- a/orttraining/orttraining/training_api/checkpoint.cc
+++ b/orttraining/orttraining/training_api/checkpoint.cc
@@ -449,7 +449,7 @@ Status FromOptimizerState(const OptimizerCheckpointState& optimizer_state,
   fbs_optimizer_groups.reserve(optimizer_state.group_named_optimizer_states.size());
 
   for (const auto& group_name : SortedKeys(optimizer_state.group_named_optimizer_states)) {
-    const std::shared_ptr<GroupOptimizerState>& group_optimizer_state_ptr =
+    std::shared_ptr<GroupOptimizerState> group_optimizer_state_ptr =
         optimizer_state.group_named_optimizer_states.at(group_name);
 
     std::vector<flatbuffers::Offset<fbs::ParameterOptimizerState>> optimizer_states;
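
Note on the recompute_analysis.cc hunk: the patch replaces lazy, per-call insertion into a function-local static map with a helper that fully builds the table before the static is first initialized. In the old code, the first caller for a new probe level mutated the shared static map while other threads might be reading it, and a rehash could invalidate references handed out earlier; building the map once removes both hazards. Below is a minimal sketch of the same initialize-once idiom; OpTable, BuildTables, and GetTable are hypothetical names for illustration, not code from this patch.

    #include <string>
    #include <unordered_map>

    using OpTable = std::unordered_map<std::string, int>;  // stand-in for the patch's op tables

    // Build every level of the table up front, in a single place.
    std::unordered_map<int, OpTable> BuildTables() {
      std::unordered_map<int, OpTable> tables;
      tables[0] = OpTable{{"Add", 1}};
      tables[1] = tables[0];  // a higher probe level includes the lower one
      tables[1].emplace("MatMul", 1);
      return tables;
    }

    const OpTable& GetTable(int level) {
      // Since C++11, initialization of a function-local static runs exactly
      // once even with concurrent callers; afterwards the map is only read.
      static const std::unordered_map<int, OpTable> tables = BuildTables();
      return tables.at(level);
    }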
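Note on the training_api_tests.cc and checkpoint.cc hunks: both make the matching change on the consumer side, copying the std::shared_ptr out of group_named_optimizer_states instead of binding a reference into the container. A minimal sketch of why the by-value copy is the safer pattern follows; GroupState, ReadParam, and the literal keys are hypothetical.

    #include <map>
    #include <memory>
    #include <string>

    struct GroupState {
      std::map<std::string, int> param_states;
    };

    using GroupMap = std::map<std::string, std::shared_ptr<GroupState>>;

    int ReadParam(GroupMap& groups) {
      // Copying the shared_ptr bumps the reference count, so the GroupState
      // outlives any later mutation of `groups` (erase, reassignment). A
      // `const std::shared_ptr<GroupState>&` bound to the map's element
      // would dangle in those cases.
      std::shared_ptr<GroupState> group = groups.at("group0");
      return group->param_states.at("fc2.weight");
    }

The copy costs one atomic reference-count increment, which is negligible next to the use-after-free it rules out.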