Skip to content

Commit

Permalink
Safeguard default argument against mutation
Browse files Browse the repository at this point in the history
  • Loading branch information
kthui committed Oct 24, 2023
1 parent 4d1bcc6 commit f25ffdb
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions src/pb_stub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,27 @@ SignalHandler(int signum)
// Skip the SIGINT and SIGTERM
}

template <typename PYTYPE>
PYTYPE
PyDefaultArgumentToMutableType(const py::object& argument)
{
// The default argument on Python functions always reference the same copy,
// meaning if the default argument is changed by the function, then it is
// changed for all subsequent calls to the function. Thus, default arguments
// should be limited to basic types (i.e. None). This helper function returns
// an empty expected type, if the argument is None (i.e. default initialized).
// If the argument is neither None nor expected type, an exception is thrown.
if (py::isinstance<py::none>(argument)) {
return PYTYPE();
}
if (py::isinstance<PYTYPE>(argument)) {
return argument;
}
throw PythonBackendException(
std::string("Expect ") + typeid(PYTYPE).name() + ", got " +
std::string(py::str(argument.get_type())));
}

void
Stub::Instantiate(
int64_t shm_growth_size, int64_t shm_default_size,
Expand Down Expand Up @@ -1464,7 +1485,10 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module)
const int64_t model_version, const uint32_t flags,
const int32_t timeout,
const PreferredMemory& preferred_memory,
const InferenceTrace& trace, const py::dict& parameters) {
const InferenceTrace& trace,
const py::object& parameters_) {
py::dict parameters =
PyDefaultArgumentToMutableType<py::dict>(parameters_);
std::set<std::string> requested_outputs;
for (auto& requested_output_name : requested_output_names) {
requested_outputs.emplace(requested_output_name);
Expand Down Expand Up @@ -1503,7 +1527,7 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module)
py::arg("preferred_memory").none(false) =
PreferredMemory(PreferredMemory::DEFAULT, 0),
py::arg("trace").none(false) = InferenceTrace(),
py::arg("parameters").none(false) = py::dict())
py::arg("parameters").none(true) = py::none())
.def(
"inputs", &InferRequest::Inputs,
py::return_value_policy::reference_internal)
Expand Down

0 comments on commit f25ffdb

Please sign in to comment.