Skip to content
This repository has been archived by the owner on Sep 1, 2024. It is now read-only.

PID code and Update Readme #165

Merged
merged 25 commits into from
Sep 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
77e34b2
init PR for trajectory based model
natolambert Jul 26, 2022
9c6a241
clean model file, add replay buffer class
natolambert Jul 26, 2022
64e5245
setup data collection in notebook, will do precommits soon
natolambert Aug 9, 2022
18f7b34
Merge remote-tracking branch 'origin' into traj-model
natolambert Aug 9, 2022
bd9aa62
notebook loss goes down
natolambert Aug 10, 2022
4ead491
initial notebook added
natolambert Aug 10, 2022
2818f46
add ensemble support to notebook
natolambert Aug 15, 2022
85f3a16
substantially clean notebook, add text
natolambert Aug 15, 2022
28308ab
remove unused changes
natolambert Aug 16, 2022
196431c
clean PID implementation
natolambert Aug 16, 2022
2cb94ac
minor text changes
natolambert Aug 16, 2022
4b25138
make batch friendly, add tests
natolambert Aug 16, 2022
ebd5751
lint tests
natolambert Aug 16, 2022
55fdec4
reset precommits to main
natolambert Aug 16, 2022
8a28b01
make tests deterministic
natolambert Aug 19, 2022
96876e2
make tests deterministic
natolambert Aug 19, 2022
e708945
fix docstring
natolambert Aug 22, 2022
baa128c
fix notebook (some content wasn't saved on final save)
natolambert Aug 24, 2022
8f1682a
add colab to readme
natolambert Aug 24, 2022
f6d3745
clean PR for mbrl-lib
natolambert Aug 31, 2022
bc5fe1f
reset precommit file
natolambert Aug 31, 2022
1e971a8
more permament colab link via githubtocolab.com
natolambert Aug 31, 2022
9ac08b7
run precommits
natolambert Sep 8, 2022
fe672ba
fix line length
natolambert Sep 8, 2022
6a5bd9d
reset precommits
natolambert Sep 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ installation and are specific to models of type
We are planning to extend this in the future; if you have useful suggestions
don't hesitate to raise an issue or submit a pull request!

## Advanced Examples
MBRL-Lib can be used for many different research projects in the subject area.
We have support for the following projects:
* [Trajectory-based Dynamics Model](https://arxiv.org/abs/2012.09156) Training [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/natolambert/mbrl-lib-dev/blob/main/notebooks/traj_based_model.ipynb)

## Documentation
Please check out our **[documentation](https://facebookresearch.github.io/mbrl-lib/)**
and don't hesitate to raise issues or contribute if anything is unclear!
Expand Down
1 change: 1 addition & 0 deletions mbrl/planning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .core import Agent, RandomAgent, complete_agent_cfg, load_agent
from .linear_feedback import PIDAgent
from .trajectory_opt import (
CEMOptimizer,
ICEMOptimizer,
Expand Down
122 changes: 122 additions & 0 deletions mbrl/planning/linear_feedback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import numpy as np

from .core import Agent


class PIDAgent(Agent):
"""
Agent that reacts via an internal set of proportional–integral–derivative controllers.

A broad history of the PID controller can be found here:
https://en.wikipedia.org/wiki/PID_controller.

Args:
k_p (np.ndarry): proportional control coeff (Nx1)
k_i (np.ndarry): integral control coeff (Nx1)
k_d (np.ndarry): derivative control coeff (Nx1)
target (np.ndarry): setpoint (Nx1)
state_mapping (np.ndarry): indices of the state vector to apply the PID control to.
E.g. for a system with states [angle, angle_vel, position, position_vel], state_mapping
of [1, 3] and dim of 2 will apply the PID to angle_vel and position_vel variables.
batch_dim (int): number of samples to compute actions for simultaneously
"""

def __init__(
self,
k_p: np.ndarray,
k_i: np.ndarray,
k_d: np.ndarray,
target: np.ndarray,
state_mapping: Optional[np.ndarray] = None,
batch_dim: Optional[int] = 1,
):
super().__init__()
assert len(k_p) == len(k_i) == len(k_d) == len(target)
self.n_dof = len(k_p)

# State mapping defaults to first N states
if state_mapping is not None:
assert len(state_mapping) == len(target)
self.state_mapping = state_mapping
else:
self.state_mapping = np.arange(0, self.n_dof)

self.batch_dim = batch_dim

self._prev_error = np.zeros((self.n_dof, self.batch_dim))
self._cum_error = np.zeros((self.n_dof, self.batch_dim))

self.k_p = np.repeat(k_p[:, np.newaxis], self.batch_dim, axis=1)
self.k_i = np.repeat(k_i[:, np.newaxis], self.batch_dim, axis=1)
self.k_d = np.repeat(k_d[:, np.newaxis], self.batch_dim, axis=1)
self.target = np.repeat(target[:, np.newaxis], self.batch_dim, axis=1)

def act(self, obs: np.ndarray, **_kwargs) -> np.ndarray:
"""Issues an action given an observation.

This method optimizes a given observation or batch of observations for a
one-step action choice.


Args:
obs (np.ndarray): the observation for which the action is needed either N x 1 or N x B,
where N is the state dim and B is the batch size.

Returns:
(np.ndarray): the action outputted from the PID, either shape n_dof x 1 or n_dof x B.
"""
if obs.ndim == 1:
obs = np.expand_dims(obs, -1)
if len(obs) > self.n_dof:
pos = obs[self.state_mapping]
else:
pos = obs

error = self.target - pos
self._cum_error += error

P_value = np.multiply(self.k_p, error)
I_value = np.multiply(self.k_i, self._cum_error)
D_value = np.multiply(self.k_d, (error - self._prev_error))
self._prev_error = error
action = P_value + I_value + D_value
return action

def reset(self):
"""
Reset internal errors.
"""
self._prev_error = np.zeros((self.n_dof, self.batch_dim))
self._cum_error = np.zeros((self.n_dof, self.batch_dim))

def get_errors(self):
return self._prev_error, self._cum_error

def _get_P(self):
return self.k_p

def _get_I(self):
return self.k_i

def _get_D(self):
return self.k_d

def _get_targets(self):
return self.target

def get_parameters(self):
"""
Returns the parameters of the PID agent concatenated.

Returns:
(np.ndarray): the parameters.
"""
return np.stack(
(self._get_P(), self._get_I(), self._get_D(), self._get_targets())
).flatten()
81 changes: 81 additions & 0 deletions tests/core/test_planning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import pytest

import mbrl.planning as planning


def create_pid_agent(dim,
state_mapping=None,
batch_dim=1):
agent = planning.PIDAgent(k_p=np.ones(dim, ),
k_i=np.ones(dim, ),
k_d=np.ones(dim, ),
target=np.zeros(dim, ),
state_mapping=state_mapping,
batch_dim=batch_dim,
)
return agent


def test_pid_agent_one_dim():
"""
This test covers the creation of PID agents in the most basic form.
"""
pid = create_pid_agent(dim=1)
pid.reset()
init_obs = np.array([2.2408932])
act = pid.act(init_obs)

# check action computation
assert act == pytest.approx(-6.722, 0.1)

# check reset
pid.reset()
prev_error, cum_error = pid.get_errors()
assert np.sum(prev_error) == np.sum(cum_error) == 0


def test_pid_agent_multi_dim():
"""
This test covers regular updates for the multi-dim PID agent.
"""
pid = create_pid_agent(dim=2, state_mapping=np.array([1, 3]), )
init_obs = np.array([ 0.95008842, -0.15135721, -0.10321885, 0.4105985 ])
act1 = pid.act(init_obs)
next_obs = np.array([0.14404357, 1.45427351, 0.76103773, 0.12167502])
act2 = pid.act(next_obs)
assert act1 + act2 == pytest.approx([-3.908, -1.596], 0.1)

# check reset
pid.reset()
prev_error, cum_error = pid.get_errors()
assert np.sum(prev_error) == np.sum(cum_error) == 0


def test_pid_agent_batch(batch_dim=5):
"""
Tests the agent for batch-mode computation of actions.
"""
pid = create_pid_agent(dim=2, state_mapping=np.array([1, 3]), batch_dim=batch_dim)

init_obs = np.array([[ 0.95008842, -0.15135721, -0.10321885, 0.4105985 , 0.14404357],
[ 1.45427351, 0.76103773, 0.12167502, 0.44386323, 0.33367433],
[ 1.49407907, -0.20515826, 0.3130677 , -0.85409574, -2.55298982],
[ 0.6536186 , 0.8644362 , -0.74216502, 2.26975462, -1.45436567]])
act1 = pid.act(init_obs)
next_obs = np.array([[ 0.04575852, -0.18718385, 1.53277921, 1.46935877, 0.15494743],
[ 0.37816252, -0.88778575, -1.98079647, -0.34791215, 0.15634897],
[ 1.23029068, 1.20237985, -0.38732682, -0.30230275, -1.04855297],
[-1.42001794, -1.70627019, 1.9507754 , -0.50965218, -0.4380743 ]])
act2 = pid.act(next_obs)
assert (act1 + act2)[0] == pytest.approx([-5.497, 0.380, 5.577, -0.287, -1.470], 0.1)

# check reset
pid.reset()
prev_error, cum_error = pid.get_errors()
assert np.sum(prev_error) == np.sum(cum_error) == 0