Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[py framework] Setting context values now clones instead of owns #22455

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bindings/pydrake/systems/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ drake_py_unittest(
":analysis_py",
":framework_py",
":primitives_py",
":test_util_py",
"//bindings/pydrake/common/test_utilities",
"//bindings/pydrake/examples",
],
Expand Down
134 changes: 93 additions & 41 deletions bindings/pydrake/systems/framework_py_semantics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,29 @@
#include "drake/systems/framework/system_output.h"

using std::string;
using std::unique_ptr;
using std::vector;

namespace drake {
namespace pydrake {

namespace {

using AbstractValuePtrList = vector<unique_ptr<AbstractValue>>;

// NOLINTNEXTLINE(build/namespaces): Emulate placement in namespace.
using namespace drake::systems;
constexpr auto& doc = pydrake_doc.drake.systems;

// Given a vector of (possibly null) pointers, returns a vector formed by
// calling Clone() elementwise.
template <typename T>
std::vector<std::unique_ptr<T>> CloneVectorOfPointers(
const std::vector<const T*>& input) {
std::vector<std::unique_ptr<T>> result;
result.reserve(input.size());
for (const T* item : input) {
result.push_back(item != nullptr ? item->Clone() : nullptr);
}
return result;
}

// Given an InputPort or OutputPort as self, return self.Eval(context). In
// python, always returns either a numpy.ndarray (when vector-valued) or the
// unwrapped T in a Value<T> (when abstract-valued).
Expand Down Expand Up @@ -104,7 +113,10 @@ void DoScalarIndependentDefinitions(py::module m) {
DefClone(&abstract_values);
abstract_values // BR
.def(py::init<>(), doc.AbstractValues.ctor.doc_0args)
.def(py::init<AbstractValuePtrList>(), doc.AbstractValues.ctor.doc_1args)
.def(py::init([](const std::vector<const AbstractValue*>& data) {
return std::make_unique<AbstractValues>(CloneVectorOfPointers(data));
}),
py::arg("data"), doc.AbstractValues.ctor.doc_1args)
.def("size", &AbstractValues::size, doc.AbstractValues.size.doc)
.def("get_value", &AbstractValues::get_value, py::arg("index"),
py_rvp::reference_internal, doc.AbstractValues.get_value.doc)
Expand Down Expand Up @@ -939,24 +951,31 @@ void DefineParameters(py::module m) {
auto parameters = DefineTemplateClassWithDefault<Parameters<T>>(
m, "Parameters", GetPyParam<T>(), doc.Parameters.doc);
DefClone(&parameters);
using BasicVectorPtrList = vector<unique_ptr<BasicVector<T>>>;
parameters
parameters // BR
.def(py::init<>(), doc.Parameters.ctor.doc_0args)
// TODO(eric.cousineau): Ensure that we can respect keep alive behavior
// with lists of pointers.
.def(py::init<BasicVectorPtrList, AbstractValuePtrList>(),
.def(py::init([](const std::vector<const BasicVector<T>*>& numeric,
const std::vector<const AbstractValue*>& abstract) {
return std::make_unique<Parameters<T>>(
CloneVectorOfPointers(numeric), CloneVectorOfPointers(abstract));
}),
py::arg("numeric"), py::arg("abstract"),
doc.Parameters.ctor.doc_2args_numeric_abstract)
.def(py::init<BasicVectorPtrList>(), py::arg("numeric"),
doc.Parameters.ctor.doc_1args_numeric)
.def(py::init<AbstractValuePtrList>(), py::arg("abstract"),
doc.Parameters.ctor.doc_1args_abstract)
.def(py::init<unique_ptr<BasicVector<T>>>(), py::arg("vec"),
// Keep alive, ownership: `vec` keeps `self` alive.
py::keep_alive<2, 1>(), doc.Parameters.ctor.doc_1args_vec)
.def(py::init<unique_ptr<AbstractValue>>(), py::arg("value"),
// Keep alive, ownership: `value` keeps `self` alive.
py::keep_alive<2, 1>(), doc.Parameters.ctor.doc_1args_value)
.def(py::init([](const std::vector<const BasicVector<T>*>& numeric) {
return std::make_unique<Parameters<T>>(CloneVectorOfPointers(numeric));
}),
py::arg("numeric"), doc.Parameters.ctor.doc_1args_numeric)
.def(py::init([](const std::vector<const AbstractValue*>& abstract) {
return std::make_unique<Parameters<T>>(CloneVectorOfPointers(abstract));
}),
py::arg("abstract"), doc.Parameters.ctor.doc_1args_abstract)
.def(py::init([](const BasicVector<T>& vec) {
return std::make_unique<Parameters<T>>(vec.Clone());
}),
py::arg("vec"), doc.Parameters.ctor.doc_1args_vec)
.def(py::init([](const AbstractValue& value) {
return std::make_unique<Parameters<T>>(value.Clone());
}),
py::arg("value"), doc.Parameters.ctor.doc_1args_value)
.def("num_numeric_parameter_groups",
&Parameters<T>::num_numeric_parameter_groups,
doc.Parameters.num_numeric_parameter_groups.doc)
Expand All @@ -971,14 +990,17 @@ void DefineParameters(py::module m) {
doc.Parameters.get_mutable_numeric_parameter.doc)
.def("get_numeric_parameters", &Parameters<T>::get_numeric_parameters,
py_rvp::reference_internal, doc.Parameters.get_numeric_parameters.doc)
// TODO(eric.cousineau): Should this C++ code constrain the number of
// parameters???
.def("set_numeric_parameters", &Parameters<T>::set_numeric_parameters,
// WARNING: This will DELETE the existing parameters. See C++
// `AddValueInstantiation` for more information.
// Keep alive, ownership: `value` keeps `self` alive.
py::keep_alive<2, 1>(), py::arg("numeric_params"),
doc.Parameters.set_numeric_parameters.doc)
.def(
"set_numeric_parameters",
[](Parameters<T>& self, const DiscreteValues<T>& numeric_params) {
// TODO(eric.cousineau): Should this C++ code constrain the number
// of parameters???
//
// WARNING: This will DELETE the existing parameters. See C++
// `AddValueInstantiation` for more information.
self.set_numeric_parameters(numeric_params.Clone());
},
py::arg("numeric_params"), doc.Parameters.set_numeric_parameters.doc)
.def(
"get_abstract_parameter",
[](const Parameters<T>* self, int index) -> auto& {
Expand All @@ -996,11 +1018,14 @@ void DefineParameters(py::module m) {
.def("get_abstract_parameters", &Parameters<T>::get_abstract_parameters,
py_rvp::reference_internal,
doc.Parameters.get_abstract_parameters.doc)
.def("set_abstract_parameters", &Parameters<T>::set_abstract_parameters,
// WARNING: This will DELETE the existing parameters. See C++
// `AddValueInstantiation` for more information.
// Keep alive, ownership: `value` keeps `self` alive.
py::keep_alive<2, 1>(), py::arg("abstract_params"),
.def(
"set_abstract_parameters",
[](Parameters<T>& self, const AbstractValues& abstract_params) {
// WARNING: This will DELETE the existing parameters. See C++
// `AddValueInstantiation` for more information.
self.set_abstract_parameters(abstract_params.Clone());
},
py::arg("abstract_params"),
doc.Parameters.set_abstract_parameters.doc)
.def(
"SetFrom",
Expand Down Expand Up @@ -1065,14 +1090,37 @@ void DefineContinuousState(py::module m) {
auto continuous_state = DefineTemplateClassWithDefault<ContinuousState<T>>(
m, "ContinuousState", GetPyParam<T>(), doc.ContinuousState.doc);
DefClone(&continuous_state);
continuous_state
.def(py::init<unique_ptr<VectorBase<T>>>(), py::arg("state"),
doc.ContinuousState.ctor.doc_1args_state)
.def(py::init<unique_ptr<VectorBase<T>>, int, int, int>(),
continuous_state // BR
.def(py::init<>(), doc.ContinuousState.ctor.doc_0args)
// In the next pair of overloads, we'll try matching on BasicVector in
// order to preserve its subtype across cloning. In the subsequent pair
// of overloads, we'll also allow VectorBase.
.def(py::init([](const BasicVector<T>& state) {
return std::make_unique<ContinuousState<T>>(state.Clone());
}),
py::arg("state"), doc.ContinuousState.ctor.doc_1args_state)
.def(py::init([](const BasicVector<T>& state, int num_q, int num_v,
int num_z) {
return std::make_unique<ContinuousState<T>>(
state.Clone(), num_q, num_v, num_z);
}),
py::arg("state"), py::arg("num_q"), py::arg("num_v"),
py::arg("num_z"),
doc.ContinuousState.ctor.doc_4args_state_num_q_num_v_num_z)
.def(py::init([](const VectorBase<T>& state) {
return std::make_unique<ContinuousState<T>>(
std::make_unique<BasicVector<T>>(state.CopyToVector()));
}),
py::arg("state"), doc.ContinuousState.ctor.doc_1args_state)
.def(py::init(
[](const VectorBase<T>& state, int num_q, int num_v, int num_z) {
return std::make_unique<ContinuousState<T>>(
std::make_unique<BasicVector<T>>(state.CopyToVector()),
num_q, num_v, num_z);
}),
py::arg("state"), py::arg("num_q"), py::arg("num_v"),
py::arg("num_z"),
doc.ContinuousState.ctor.doc_4args_state_num_q_num_v_num_z)
.def(py::init<>(), doc.ContinuousState.ctor.doc_0args)
.def("size", &ContinuousState<T>::size, doc.ContinuousState.size.doc)
.def("num_q", &ContinuousState<T>::num_q, doc.ContinuousState.num_q.doc)
.def("num_v", &ContinuousState<T>::num_v, doc.ContinuousState.num_v.doc)
Expand Down Expand Up @@ -1134,9 +1182,13 @@ void DefineDiscreteValues(py::module m) {
m, "DiscreteValues", GetPyParam<T>(), doc.DiscreteValues.doc);
DefClone(&discrete_values);
discrete_values
.def(py::init<unique_ptr<BasicVector<T>>>(), py::arg("datum"),
doc.DiscreteValues.ctor.doc_1args_datum)
.def(py::init<std::vector<std::unique_ptr<BasicVector<T>>>&&>(),
.def(py::init([](const BasicVector<T>& datum) {
return std::make_unique<DiscreteValues<T>>(datum.Clone());
}),
py::arg("datum"), doc.DiscreteValues.ctor.doc_1args_datum)
.def(py::init([](const std::vector<const BasicVector<T>*>& data) {
return std::make_unique<DiscreteValues<T>>(CloneVectorOfPointers(data));
}),
py::arg("data"), doc.DiscreteValues.ctor.doc_1args_data)
.def(py::init<>(), doc.DiscreteValues.ctor.doc_0args)
.def("num_groups", &DiscreteValues<T>::num_groups,
Expand Down
25 changes: 25 additions & 0 deletions bindings/pydrake/systems/test/custom_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
CacheEntryValue,
CacheIndex,
Context,
ContinuousState_,
ContinuousStateIndex,
DependencyTicket,
Diagram,
DiagramBuilder,
DiagramBuilder_,
DiscreteStateIndex,
DiscreteValues,
EventStatus,
Expand Down Expand Up @@ -909,6 +911,29 @@ def DoCalcTimeDerivatives(self, context, derivatives):
self.assertEqual(
system.EvalTimeDerivatives(context=context).size(), 6)

# The constructors for ContinuousState(state: VectorBase, ...)
# used when diagrams are in play receives special treatment in
# the bindings for ContinuousState. We'll exercise it here.
builder = DiagramBuilder_[T]()
n = 2
for _ in range(n):
builder.AddSystem(TrivialSystem(index))
diagram = builder.Build()
diagram_context = diagram.CreateDefaultContext()
diagram_state_copy = ContinuousState_[T](
state=diagram_context.get_continuous_state().get_vector())
self.assertEqual(diagram_state_copy.size(), 6*n)
diagram_state_copy = ContinuousState_[T](
state=diagram_context.get_continuous_state().get_vector(),
num_q=2*n,
num_v=1*n,
num_z=3*n,
)
self.assertEqual(diagram_state_copy.num_q(), 2*n)
self.assertEqual(diagram_state_copy.num_v(), 1*n)
self.assertEqual(diagram_state_copy.num_z(), 3*n)
self.assertEqual(diagram_state_copy.size(), 6*n)

def test_discrete_state_api(self):
# N.B. Since this has trivial operations, we can test all scalar types.
for T in [float, AutoDiffXd, Expression]:
Expand Down
17 changes: 16 additions & 1 deletion bindings/pydrake/systems/test/general_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
PassThrough, PassThrough_,
ZeroOrderHold,
)
from pydrake.systems.test.test_util import MyVector2

# TODO(eric.cousineau): The scope of this test file and `custom_test.py`
# is poor. Move these tests into `framework_test` and `analysis_test`, and
Expand Down Expand Up @@ -318,7 +319,21 @@ def system_callback(system, context, event): pass

def test_continuous_state_api(self):
self.assertEqual(ContinuousState().size(), 0)
self.assertEqual(ContinuousState(state=BasicVector(2)).size(), 2)
custom_vector = MyVector2(np.ones(2))
state = ContinuousState(state=custom_vector)
self.assertIsInstance(state.get_vector(), MyVector2)
self.assertEqual(state.size(), 2)
self.assertEqual(state.num_q(), 0)
self.assertEqual(state.num_v(), 0)
self.assertEqual(state.num_z(), 2)
state = ContinuousState(state=custom_vector, num_q=1, num_v=1, num_z=0)
self.assertIsInstance(state.get_vector(), MyVector2)
self.assertEqual(state.size(), 2)
self.assertEqual(state.num_q(), 1)
self.assertEqual(state.num_v(), 1)
self.assertEqual(state.num_z(), 0)
state = ContinuousState(state=BasicVector(2))
self.assertEqual(state.size(), 2)
state = ContinuousState(state=BasicVector(np.arange(6)), num_q=3,
num_v=2, num_z=1)
state_clone = state.Clone()
Expand Down