diff --git a/src/tensorflow.cc b/src/tensorflow.cc index 440d172..0f2144e 100644 --- a/src/tensorflow.cc +++ b/src/tensorflow.cc @@ -317,6 +317,24 @@ ValidateTRITONTFModel( // Verify that the model configuration input and outputs match what // is expected by the model. + // Check the name of each input first before checking the count to ensure that + // the error message returned includes the names of any unexpected extra + // inputs, instead of just a count mismatch error. + for (size_t i = 0; i < config_inputs.ArraySize(); i++) { + triton::common::TritonJson::Value io; + RETURN_IF_ERROR(config_inputs.IndexAsObject(i, &io)); + RETURN_IF_ERROR(CheckAllowedModelInput(io, expected_inputs)); + + std::string io_name; + RETURN_IF_ERROR(io.MemberAsString("name", &io_name)); + const TRITONTF_IO* input = FindIOByName(inputs, io_name); + if (input == nullptr) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + std::string("unexpected inference input '" + io_name + "'").c_str()); + } + } + if (expected_inputs.size() != expected_input_cnt) { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, @@ -330,16 +348,10 @@ ValidateTRITONTFModel( for (size_t i = 0; i < config_inputs.ArraySize(); i++) { triton::common::TritonJson::Value io; RETURN_IF_ERROR(config_inputs.IndexAsObject(i, &io)); - RETURN_IF_ERROR(CheckAllowedModelInput(io, expected_inputs)); std::string io_name; RETURN_IF_ERROR(io.MemberAsString("name", &io_name)); const TRITONTF_IO* input = FindIOByName(inputs, io_name); - if (input == nullptr) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - std::string("unexpected inference input '" + io_name + "'").c_str()); - } // If a reshape is provided for the input then use that when // validating that the TF model matches what is expected.