
Commit

fix lint
jslhcl committed Feb 21, 2024
1 parent 3e50a00 commit 45747fd
Showing 2 changed files with 21 additions and 8 deletions.
4 changes: 2 additions & 2 deletions onnxruntime/core/framework/allocation_planner.cc
@@ -529,7 +529,7 @@ class PlannerImpl {

// Initialize allocation plan:
plan_.allocation_plan.resize(num_ml_values);
-    for (size_t i = 0; i < num_ml_values; i++) plan_.allocation_plan[i].reused_buffer = static_cast<OrtValueIndex>(i);
+    for (int i = 0; static_cast<size_t>(i) < num_ml_values; i++) AllocPlan(i).reused_buffer = i;
}

bool HasExternalOutputs(const Node& node) const {
@@ -1345,7 +1345,7 @@ class PlannerImpl {
for (size_t i = 0; i < stream_nodes_.size(); ++i) {
// compute use count first. TODO(leca): call ComputeReuseCount() only once is enough
ORT_RETURN_IF_ERROR(ComputeReuseCount());
-      for (size_t j = 0; j < ort_value_info_.size(); j++) Buffer(static_cast<OrtValueIndex>(j)) = static_cast<OrtValueIndex>(j);
+      for (int j = 0; static_cast<size_t>(j) < ort_value_info_.size(); j++) Buffer(j) = j;
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
if (i == 0) {
for (auto ort_value_info : ort_value_info_) {
25 changes: 19 additions & 6 deletions onnxruntime/test/framework/allocation_planner_test.cc
@@ -1982,8 +1982,8 @@ TEST_F(PlannerTest, TestCpuIf) {
// |-----------> Gather
// Shape ----------------> Reshape --> Shape ------------------> Reshape
// ^ ^
-// InstanceNormalization ----|          InstanceNormalization ------|
-//
+// InstanceNormalization ----|          InstanceNormalization ------|
+//
// Python script to create this model:
// def CreateModelFor19480():
// #shape->reshape->shape->reshape, 4 gather
@@ -1998,10 +1998,23 @@ TEST_F(PlannerTest, TestCpuIf) {
// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices2'], outputs=['output2']))
// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices3'], outputs=['output3']))
// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices4'], outputs=['output4']))
-// g = h.make_graph(graphNodes, 'issue_19480', [h.make_tensor_value_info('shape_input', tp.FLOAT, ['batch', 128, None, None]), h.make_tensor_value_info('in0_input', tp.FLOAT, ['batch', 32, None]), h.make_tensor_value_info('scale0', tp.FLOAT, [32]), h.make_tensor_value_info('B0', tp.FLOAT, [32]),
-// h.make_tensor_value_info('in1_input', tp.FLOAT, ['batch', 32, None]), h.make_tensor_value_info('scale1', tp.FLOAT, [32]), h.make_tensor_value_info('B1', tp.FLOAT, [32]),
-// h.make_tensor_value_info('indices1', tp.INT32, []), h.make_tensor_value_info('indices2', tp.INT32, []), h.make_tensor_value_info('indices3', tp.INT32, []), h.make_tensor_value_info('indices4', tp.INT32, [])],
-// [h.make_tensor_value_info('output0', tp.FLOAT, None), h.make_tensor_value_info('output1', tp.INT64, None), h.make_tensor_value_info('output2', tp.INT64, None), h.make_tensor_value_info('output3', tp.INT64, None), h.make_tensor_value_info('output4', tp.INT64, None)])
+// g = h.make_graph(graphNodes, 'issue_19480',
+//                  [h.make_tensor_value_info('shape_input', tp.FLOAT, ['batch', 128, None, None]),
+//                   h.make_tensor_value_info('in0_input', tp.FLOAT, ['batch', 32, None]),
+//                   h.make_tensor_value_info('scale0', tp.FLOAT, [32]),
+//                   h.make_tensor_value_info('B0', tp.FLOAT, [32]),
+//                   h.make_tensor_value_info('in1_input', tp.FLOAT, ['batch', 32, None]),
+//                   h.make_tensor_value_info('scale1', tp.FLOAT, [32]),
+//                   h.make_tensor_value_info('B1', tp.FLOAT, [32]),
+//                   h.make_tensor_value_info('indices1', tp.INT32, []),
+//                   h.make_tensor_value_info('indices2', tp.INT32, []),
+//                   h.make_tensor_value_info('indices3', tp.INT32, []),
+//                   h.make_tensor_value_info('indices4', tp.INT32, [])],
+//                  [h.make_tensor_value_info('output0', tp.FLOAT, None),
+//                   h.make_tensor_value_info('output1', tp.INT64, None),
+//                   h.make_tensor_value_info('output2', tp.INT64, None),
+//                   h.make_tensor_value_info('output3', tp.INT64, None),
+//                   h.make_tensor_value_info('output4', tp.INT64, None)])
// model = h.make_model(g, opset_imports=[h.make_operatorsetid("", 17)], producer_name='producer_name')
// onnx.save(model, 'issue_19480.onnx')
//
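Note: the Python script in the test comment above assumes the aliases h = onnx.helper and tp = onnx.TensorProto, and the node-construction lines that precede this hunk are not shown here. The following is a minimal, self-contained sketch of the same onnx.helper pattern (make_node, make_tensor_value_info, make_graph, make_model, onnx.save) on a small stand-in Shape -> Gather graph; it is not the issue_19480 model built by the test's script.

# Minimal sketch of the onnx.helper pattern used in the comment above.
# Assumes the comment's aliases: h = onnx.helper, tp = onnx.TensorProto.
# Builds a tiny stand-in graph, not the issue_19480 model.
import onnx
from onnx import helper as h
from onnx import TensorProto as tp

graph_nodes = [
    # Shape emits the INT64 shape of its input tensor.
    h.make_node('Shape', inputs=['x'], outputs=['x_shape']),
    # Gather picks one dimension of that shape via a scalar index.
    h.make_node('Gather', inputs=['x_shape', 'idx'], outputs=['dim']),
]

g = h.make_graph(
    graph_nodes, 'sketch_graph',
    [h.make_tensor_value_info('x', tp.FLOAT, ['batch', 32, None]),
     h.make_tensor_value_info('idx', tp.INT32, [])],
    [h.make_tensor_value_info('dim', tp.INT64, None)])

model = h.make_model(g, opset_imports=[h.make_operatorsetid("", 17)],
                     producer_name='producer_name')
onnx.checker.check_model(model)
onnx.save(model, 'sketch.onnx')

Running the sketch writes sketch.onnx; the script in the test comment saves issue_19480.onnx instead.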
