Skip to content

Commit

Permalink
Add more debug
Browse files Browse the repository at this point in the history
  • Loading branch information
Ted Themistokleous committed Jul 4, 2024
1 parent 868fc72 commit f186212
Showing 1 changed file with 20 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -1210,7 +1210,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
}
}

LOGS_DEFAULT(WARNING) << "Compile Done" << std::endl;
LOGS_DEFAULT(WARNING) << "Initial Compile Done" << std::endl;

// compile the program
map_progs_[fused_node.Name()] = prog;
Expand Down Expand Up @@ -1282,7 +1282,9 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
// migraphx::onnx_options cmp_options;
if (param_shapes.size() > 0) {
for (auto&& name : param_shapes.names()) {
LOGS_DEFAULT(WARNING) << "Input Map:" << name << std::endl;
if (map_input_name_index.count(name) > 0) {
LOGS_DEFAULT(WARNING) << "Input Map Found:" << name << std::endl;
auto input_tensor = ctx.GetInput(map_input_name_index[name]);
auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
const auto tensor_shape = tensor_info.GetShape();
Expand All @@ -1293,12 +1295,24 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
auto mgx_strides = mgx_s.strides();
if (mgx_lens.size() == 1 and mgx_lens[0] == 1 and
mgx_strides.size() == 1 and mgx_strides[0] == 0) {
LOGS_DEFAULT(WARNING) << "mgx_lens clear!" << std::endl;
mgx_lens.clear();
}

if (mgx_lens != ort_lens) {
cmp_options.set_input_parameter_shape(name, ort_lens);
LOGS_DEFAULT(WARNING) << "mgx_lens != ort_lens" << std::endl;
std::string migx_out_lens;
std::string ort_out_lens;
for( auto &m_len: mgx_lens){
LOGS_DEFAULT(WARNING) << m_len << ",";
}
LOGS_DEFAULT(WARNING) << std::endl;
for( auto &o_len: ort_lens){
LOGS_DEFAULT(WARNING) << o_len << ",";
}
LOGS_DEFAULT(WARNING) << std::endl;

input_shape_match = false;
}
}
Expand Down Expand Up @@ -1403,6 +1417,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
}
// It is a output argument
else {
LOGS_DEFAULT(WARNING) << "output"<< name << std::endl;
auto compute_output_index = [](const std::string& name) -> int {
std::string out_name_prefix = "#output_";
auto pos = name.find(out_name_prefix);
Expand All @@ -1416,6 +1431,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&

int output_index = compute_output_index(name);
if (output_index != -1) {
LOGS_DEFAULT(WARNING) << "Set output"<< name << std::endl;
prog_output_indices.push_back(output_index);
auto mgx_output_shape = prog_output_shapes[output_index];
auto lens = mgx_output_shape.lengths();
Expand All @@ -1425,12 +1441,14 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&

// argument shape
auto mgx_arg_shape = param_shapes[name];
LOGS_DEFAULT(WARNING) << "add output arg"<< name << std::endl;
m.add(name, migraphx::argument(mgx_arg_shape, output_data));
}
}
}
}

LOGS_DEFAULT(WARNING) << "Before run" << std::endl;
{
// lock to avoid race condition
std::lock_guard<OrtMutex> lock(*(mgx_state->mgx_mu_ptr));
Expand All @@ -1456,6 +1474,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
}
};

LOGS_DEFAULT(WARNING) << "After Run" << std::endl;
return Status::OK();
};

Expand Down

0 comments on commit f186212

Please sign in to comment.