The staged model approach assumes that we have modules which are each composed of one or more stages, and that we might want to intervene between these stages. However, this has some basic limitations. The most important limitation is that it is difficult to encode arbitrary DAG structures in a linear sequence of stages. The alternative is to define the model as a DAG of modules connected by wires, where the modules are the nodes of the graph and the wires carry state between them.
Again, there is no free lunch: we cannot intervene at a level lower than the granularity of the leaf modules, so whether this is unproblematic depends on good design decisions having been made -- decisions which need to be made anyway (and are probably more difficult to make) in the staged model approach. I think the DAG approach is obviously the more elegant of the two. It's also what I had originally intended when I set out to design Feedbax. However, in trying to balance other design considerations whose solutions were previously unclear to me -- such as how far I should go with structural typing re: state data, or how to manage dynamic/ODE vs. kinematic updates -- and also not having a clear idea of which patterns to use to implement the DAG approach, I ended up converging on the less elegant solution.

A couple of months ago I learned of Collimator. In some ways it is very similar to what I wished to achieve with Feedbax. Clearly Collimator is better designed and more mature, and I'm learning from it. Importantly, it uses the DAG approach, and in that respect it serves as a proof of concept for me. As far as I can tell after a brief review of the code, the important aspect of Collimator's approach is that models are graphs of systems with input and output ports, where each output port is associated with a callback that computes its value from the values at the system's input ports.
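To make the ports-and-callbacks idea concrete, here is a toy sketch of the general pattern. This is not Collimator's actual API; all of the names (`Block`, `connect`, `eval_output`) are invented for illustration.

```python
from collections.abc import Callable


class Block:
    """Toy node in a model DAG. Output ports are computed by callbacks,
    which here depend on the values at the block's input ports."""

    def __init__(self, outputs: dict[str, Callable]):
        self.outputs = outputs  # output port name -> callback
        self.wires: dict[str, tuple["Block", str]] = {}  # input port -> source

    def connect(self, input_port: str, source: "Block", output_port: str):
        """Wire one of this block's input ports to another block's output port."""
        self.wires[input_port] = (source, output_port)

    def eval_output(self, port: str):
        # Pull values through the wires, then run this port's callback.
        inputs = {
            name: src.eval_output(src_port)
            for name, (src, src_port) in self.wires.items()
        }
        return self.outputs[port](inputs)


# Usage: a constant source feeding a gain block.
source = Block(outputs={"out": lambda inputs: 3.0})
gain = Block(outputs={"out": lambda inputs: 2.0 * inputs["x"]})
gain.connect("x", source, "out")
print(gain.eval_output("out"))  # 6.0
```

In a scheme like this, an intervention would presumably just be another block spliced into a wire, which is part of what makes the DAG approach attractive.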
This description is simplified; for example, output port callbacks can be set not to depend on input ports, and can also be made dependent on other values throughout the system.

In the next few months, I intend to rework Feedbax to a DAG approach. This will take a lot of work, and complications will likely arise. For example, how do we intervene inside a neural network? Do we need to define individual units as leaf modules in a larger graph -- and what are the implications for compilation or runtime performance? I won't speculate any more, for now. I am just registering my intention to do this.

---

I've briefly commented on/tagged several issues to indicate how I think the switch to a DAG approach would alter, simplify, or obviate them.

---

One course of action is to switch completely to Collimator for the DAG-structuring and simulation of dynamical systems.

---

A major design motivation for Feedbax is the common use case where a researcher wants to intervene on an existing optimal control experiment. In this issue, I describe the approach I've taken to this problem, and my uncertainties about it.
Currently, Feedbax implements the following solution: models are defined by Equinox modules of type `AbstractStagedModel`. Each type of model is treated as a series of operations performed on a shared PyTree of states. All state operations (AKA stages) are defined in a consistent way, each as a collection of three things: 1) a model component to be called, 2) a function that selects the subset of model inputs/states to pass to the component, and 3) a function that selects the subset of the state that the component returns/updates.

To define a new staged model, we subclass `AbstractStagedModel` and implement the property `model_spec`, where those three things are defined for each of the model's stages. `AbstractStagedModel` implements `__call__` itself, to perform the state operations defined in `model_spec`. For a more in-depth description, see the documentation.
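As a rough sketch of what this can look like (the field names and the `PointMass` example are illustrative, not Feedbax's actual definitions):

```python
from collections import OrderedDict
from collections.abc import Callable

import equinox as eqx


class ModelStage(eqx.Module):
    """One state operation: the "three things" that define a stage."""
    callable_: Callable    # 1) the model component to call
    where_input: Callable  # 2) selects the inputs/states passed to the component
    where_state: Callable  # 3) selects the substate the component updates


class PointMass(eqx.Module):  # stands in for an AbstractStagedModel subclass
    net: eqx.Module
    mechanics: eqx.Module

    @property
    def model_spec(self) -> OrderedDict[str, ModelStage]:
        return OrderedDict(
            net=ModelStage(
                callable_=lambda model: model.net,
                where_input=lambda input_, state: state.feedback,
                where_state=lambda state: state.net_output,
            ),
            mechanics=ModelStage(
                callable_=lambda model: model.mechanics,
                where_input=lambda input_, state: state.net_output,
                where_state=lambda state: state.mechanics,
            ),
        )
```

`AbstractStagedModel.__call__` can then iterate over this mapping in order, passing each stage's selected substate to its component and writing the result back into the state PyTree.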
### What kind of PyTree is `model_spec`?

Currently, `model_spec` is defined as a property of type `OrderedDict[str, ModelStage]`. We use a mapping because it's nice for the stages to have names which can be referred to by the user. However, we cannot use a `dict`, because—while its entries maintain their insertion order since Python 3.7—its keys get sorted during PyTree flatten/unflatten operations. `OrderedDict` doesn't have the same problem.
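This behavior is easy to verify directly:

```python
from collections import OrderedDict

import jax.tree_util as jtu

# A dict's keys are sorted during flattening, so the leaf order ignores
# insertion order...
print(jtu.tree_leaves({"b": 1, "a": 2}))  # [2, 1]

# ...whereas an OrderedDict flattens in insertion order.
print(jtu.tree_leaves(OrderedDict([("b", 1), ("a", 2)])))  # [1, 2]
```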
`ModelStage` is an Equinox module whose fields describe the "three things" that define a stage. Using a module rather than (say) a tuple makes `model_spec` a little more readable. However, there have been some typing issues with `ModelStage` (see #23, "Typing `ModelStage`").

### Model state objects
`AbstractStagedModel` is generic, and each of its final subclasses has a type argument that's some final subclass of `equinox.Module`. This is the type of state PyTree operated on by the model. Different staged models may operate on the same type of state object.

A subclass of `AbstractStagedModel` may be composed of other types of `AbstractStagedModel`, in which case the state PyTrees associated with the higher-level model tend to be composites of the state PyTrees associated with the components.

To subclass `AbstractStagedModel`, we also have to implement an `init` method which takes a key and returns a default instance of the model's state PyTree. I refer to this as "default state" and not "initial state" to distinguish it from the state that has been updated (e.g. placing the arm at its starting position) at the beginning of a trial, based on the specifications provided by a task. See the documentation for a description of how these initial states are specified.
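Continuing the hypothetical `PointMass` sketch from above, the state PyTree and its default constructor might look like this:

```python
import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, PRNGKeyArray


class PointMassState(eqx.Module):
    """Hypothetical state PyTree operated on by the PointMass sketch above."""
    feedback: Array
    net_output: Array
    mechanics: Array


class PointMass(eqx.Module):
    net: eqx.Module
    mechanics: eqx.Module

    def init(self, *, key: PRNGKeyArray) -> PointMassState:
        # Default state: all zeros. The key is accepted by convention but
        # unused here. Task-specified initial states (e.g. the arm's starting
        # position) are applied separately, at the start of a trial.
        return PointMassState(
            feedback=jnp.zeros(2),
            net_output=jnp.zeros(2),
            mechanics=jnp.zeros(4),
        )
```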
Having defined the model's computation as a series of state operations, the user can now insert interventions between the stages of an existing model, without needing to alter its source code. How?

All subclasses of `AbstractStagedModel` must include (#20) a field `intervenors: Mapping[str, Sequence[AbstractIntervenor]]`, which maps from the names of model stages to one or more instances of `AbstractIntervenor`. By performing surgery on this field, we can modify an existing model with interventions. `AbstractStagedModel.__call__` automatically interleaves the state operations defined in `intervenors` with those in `model_spec`.

For more on what a subclass of `AbstractIntervenor` looks like, see the docs.
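This kind of surgery can be done with `equinox.tree_at`. Here is a runnable toy version, with floats standing in for `AbstractIntervenor` instances:

```python
import equinox as eqx


class Model(eqx.Module):  # stand-in for an AbstractStagedModel subclass
    intervenors: dict


model = Model(intervenors={"mechanics": (1.0,)})

# "Surgery": get back a copy of the model with an extra intervenor appended
# to the "mechanics" stage, without modifying the model's definition.
model2 = eqx.tree_at(
    lambda m: m.intervenors["mechanics"],
    model,
    model.intervenors["mechanics"] + (2.0,),
)
print(model2.intervenors["mechanics"])  # (1.0, 2.0)
```

One caveat: if a stage's existing sequence of intervenors is empty, it contains no leaves, so `eqx.tree_at` may need an `is_leaf` argument to locate it.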
### What issues might there be with this approach?

- Defining a model's computation declaratively in `model_spec`, instead of in an imperative `__call__`, is probably a little confusing at first.
- Anything we want to make available to an `AbstractIntervenor`, we have to (re)write as an `AbstractStagedModel` with a `model_spec`. All the states we expect the user might want to intervene on must be included as fields in the respective state PyTree.
- Intervenors may have their own "states". A field `intervenors: dict[str, PyTree[Array]]` could be added to the state PyTree of the model an intervenor belongs to, into which intervenor "states" could be inserted... however, this might lead to issues with inconsistent PyTree structure.

Is there some other solution that would allow users to insert interventions into arbitrary points in a model, without needing to modify the model's source at intervention time? Perhaps there is a solution with hooks/callbacks that could work, especially if our models were stateful objects like they might be in PyTorch, and if we didn't need to pass around state PyTrees. But I'm not sure a solution like that is desirable in a JAX library, or what it would look like.
Returning now to the general design philosophy: consider that, in principle, there need only be a single, final class `StagedModel` that has a single, trivial model stage which on its own does nothing to the state. Any potential subclass of `AbstractStagedModel` we might want to build could be replaced by a constructor that returns instances of this hypothetical `StagedModel`, but with an appropriate sequence of interventions inserted before each instance's single stage. That is, interventions and model stages both define operations on a model's state, and in principle they are interchangeable, though they are (currently) represented differently. (A toy sketch follows below.)

So, which state operations do we include in a model to begin with, and which do we leave to potentially be defined as interventions? That's an important tradeoff our approach leaves us with. I suspect there's no avoiding that problem—no free lunch. The people designing models will always need to rely on their domain expertise not to presume too much, or too little.
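Here is that toy sketch (nothing here is real Feedbax code; plain callables stand in for intervenors):

```python
import equinox as eqx


class StagedModel(eqx.Module):
    """Hypothetical single final class with one trivial, no-op stage."""
    intervenors: dict

    def __call__(self, input_, state):
        # All actual state operations arrive as interventions registered
        # before the single "noop" stage; the stage itself does nothing.
        for intervenor in self.intervenors.get("noop", ()):
            state = intervenor(input_, state)
        return state


# A "model" built entirely out of interventions:
model = StagedModel(intervenors={
    "noop": (
        lambda input_, state: state + input_,  # plays the role of a model stage
        lambda input_, state: 2 * state,       # plays the role of an intervention
    )
})
print(model(3.0, 1.0))  # (1.0 + 3.0) * 2 = 8.0
```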