Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 501652739
Change-Id: I5232dd2c1dbc5a3b204d076be8f3d3fe51abf2b1
  • Loading branch information
Brax Team authored and btaba committed Jan 13, 2023
1 parent 29d751e commit 6211f46
Show file tree
Hide file tree
Showing 45 changed files with 1,581 additions and 559 deletions.
6 changes: 3 additions & 3 deletions brax/tools/urdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,11 @@ def expand_node(self,

else:
warnings.warn(f'No collider found on link {node}.')
if inertia.find("mass") is None:

if inertia.find('mass') is None:
body.mass += 1.
else:
body.mass = float(inertia.find("mass").get("value"))
body.mass = float(inertia.find('mass').get('value'))

# TODO: load inertia
body.inertia.x, body.inertia.y, body.inertia.z = 1., 1., 1.
Expand Down
17 changes: 15 additions & 2 deletions brax/v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,12 +349,23 @@ class Mesh(Geometry):
The mesh is expected to be in the counter-clockwise winding order.
Attributes:
face: (num_faces, num_face_vertices) vertices associated with each face
vert: (num_verts, 3) spatial coordinates associated with each vertex
face: (num_faces, num_face_vertices) vertices associated with each face
"""

face: jp.ndarray
vert: jp.ndarray
face: jp.ndarray


@struct.dataclass
class Convex(Mesh):
"""A convex mesh geometry.
Attributes:
unique_edge: (num_unique, 2) vert index associated with each unique edge
"""

unique_edge: jp.ndarray


@struct.dataclass
Expand Down Expand Up @@ -405,12 +416,14 @@ class State:
qd: (qd_size,) joint velocity vector
x: (num_links,) link position in world frame
xd: (num_links,) link velocity in world frame
contact: calculated contacts
"""

q: jp.ndarray
qd: jp.ndarray
x: jp.ndarray
xd: jp.ndarray
contact: Optional[Contact]


@struct.dataclass
Expand Down
87 changes: 87 additions & 0 deletions brax/v2/generalized/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright 2022 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint:disable=g-multiple-import
"""Base types for generalized pipeline."""

from brax.v2 import base
from brax.v2.base import Inertia, Motion, System, Transform
from flax import struct
from jax import numpy as jp


@struct.dataclass
class State(base.State):
"""Dynamic state that changes after every step.
Attributes:
com: center of mass position
cinr: inertia in com frame
cd: body velocities in com frame
cdof: dofs in com frame
cdofd: cdof velocity
mass_mx: (qd_size, qd_size) mass matrix
mass_mx_inv: (qd_size, qd_size) inverse mass matrix
contact: calculated contacts
con_jac: constraint jacobian
con_pos: constraint position
con_diag: constraint A diagonal
qf_smooth: smooth dynamics force
qf_constraint: (qd_size,) force from constraints (collision etc)
qdd: (qd_size,) joint acceleration vector
"""

# position/velocity based terms are updated at the end of each step:
com: jp.ndarray
cinr: Inertia
cd: Motion
cdof: Motion
cdofd: Motion
mass_mx: jp.ndarray
mass_mx_inv: jp.ndarray
con_jac: jp.ndarray
con_pos: jp.ndarray
con_diag: jp.ndarray
# acceleration based terms are calculated using terms from the previous step:
qf_smooth: jp.ndarray
qf_constraint: jp.ndarray
qdd: jp.ndarray

@classmethod
def zero(cls, sys: System) -> 'State':
"""Returns an initial State given a brax system."""
return State(
q=jp.zeros(sys.q_size()),
qd=jp.zeros(sys.qd_size()),
x=Transform.zero((sys.num_links(),)),
xd=Motion.zero((sys.num_links(),)),
contact=None,
com=jp.zeros((sys.num_links(), 3)),
cinr=Inertia(
Transform.zero((sys.num_links(),)),
jp.zeros((sys.num_links(), 3, 3)),
jp.zeros((sys.num_links(),)),
),
cd=Motion.zero((sys.num_links(),)),
cdof=Motion.zero((sys.num_links(),)),
cdofd=Motion.zero((sys.num_links(),)),
mass_mx=jp.zeros((sys.num_links(), sys.num_links())),
mass_mx_inv=jp.zeros((sys.num_links(), sys.num_links())),
con_jac=jp.zeros(()),
con_pos=jp.zeros(()),
con_diag=jp.zeros(()),
qf_smooth=jp.zeros((sys.qd_size(),)),
qf_constraint=jp.zeros((sys.qd_size(),)),
qdd=jp.zeros(sys.qd_size()),
)
97 changes: 41 additions & 56 deletions brax/v2/generalized/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@

# pylint:disable=g-multiple-import
"""Functions for constraint satisfaction."""
from typing import Optional, Tuple
from typing import Tuple

from brax.v2 import math
from brax.v2 import scan
from brax.v2.base import Contact, Motion, System, Transform
from brax.v2.base import Motion, System, Transform
from brax.v2.generalized.base import State
import jax
from jax import numpy as jp
import jaxopt
from jaxopt.projection import projection_non_negative


def _pt_jac(
Expand Down Expand Up @@ -92,13 +92,13 @@ def _imp_aref(pos: jp.ndarray, vel: jp.ndarray) -> jp.ndarray:


def jac_limit(
sys: System, q: jp.ndarray
sys: System, state: State
) -> Tuple[jp.ndarray, jp.ndarray, jp.ndarray]:
"""Calculates the jacobian for angle limits in dof frame.
Args:
sys: the brax system
q: joint angle vector
state: generalized state
Returns:
jac: the angle limit jacobian
Expand All @@ -111,8 +111,8 @@ def jac_limit(
# determine q and qd indices for non-free joints
q_idx, qd_idx = sys.q_idx('123'), sys.qd_idx('123')

pos_min = q[q_idx] - sys.dof.limit[0][qd_idx]
pos_max = sys.dof.limit[1][qd_idx] - q[q_idx]
pos_min = state.q[q_idx] - sys.dof.limit[0][qd_idx]
pos_max = sys.dof.limit[1][qd_idx] - state.q[q_idx]
pos = jp.minimum(jp.minimum(pos_min, pos_max), 0)

side = ((pos_min < pos_max) * 2 - 1) * (pos < 0)
Expand All @@ -123,28 +123,26 @@ def jac_limit(


def jac_contact(
sys: System, com: jp.ndarray, cdof: Motion, contact: Optional[Contact]
sys: System, state: State
) -> Tuple[jp.ndarray, jp.ndarray, jp.ndarray]:
"""Calculates the jacobian for contact constraints.
Args:
sys: the brax system
com: center of mass position
cdof: dofs in com frame
contact: contacts computed for link geometries
state: generalized state
Returns:
jac: the contact jacobian
pos: contact position in constraint frame
diag: approximate diagonal of A matrix
"""
if contact is None:
if state.contact is None:
return jp.zeros((0, sys.qd_size())), jp.zeros((0,)), jp.zeros((0,))

def row_fn(contact):
link_a, link_b = contact.link_idx
a = _pt_jac(sys, com, cdof, contact.pos, link_a)
b = _pt_jac(sys, com, cdof, contact.pos, link_b)
a = _pt_jac(sys, state.com, state.cdof, contact.pos, link_a)
b = _pt_jac(sys, state.com, state.cdof, contact.pos, link_b)
diff = b - a

# 4 pyramidal friction directions
Expand All @@ -163,82 +161,69 @@ def row_fn(contact):
lambda x: x * (contact.penetration > 0), (jac, pos, diag)
)

return jax.tree_map(jp.concatenate, jax.vmap(row_fn)(contact))
return jax.tree_map(jp.concatenate, jax.vmap(row_fn)(state.contact))


def jacobian(
sys: System,
q: jp.ndarray,
com: jp.ndarray,
cdof: Motion,
contact: Optional[Contact],
) -> Tuple[jp.ndarray, jp.ndarray, jp.ndarray]:
"""Calculates the full constraint jacobian and constraint position.
def jacobian(sys: System, state: State) -> State:
"""Calculates the constraint jacobian, position, and A matrix diagonal.
Args:
sys: a brax system
q: joint position vector
com: center of mass position
cdof: dofs in com frame
contact: contacts computed for link geometries
state: generalized state
Returns:
jac: the constraint jacobian
pos: position in constraint frame
diag: approximate diagonal of A matrix
state: generalized state with jac, pos, diag updated
"""
jpds = jac_contact(sys, com, cdof, contact), jac_limit(sys, q)
jpds = jac_contact(sys, state), jac_limit(sys, state)
jac, pos, diag = jax.tree_map(lambda *x: jp.concatenate(x), *jpds)
return state.replace(con_jac=jac, con_pos=pos, con_diag=diag)

return jax.tree_map(lambda *x: jp.concatenate(x), *jpds)


def force(
sys: System,
qd: jp.ndarray,
qf_smooth: jp.ndarray,
mass_mx_inv: jp.ndarray,
jac: jp.ndarray,
pos: jp.ndarray,
diag: jp.ndarray,
) -> jp.ndarray:
def force(sys: System, state: State) -> jp.ndarray:
"""Calculates forces that satisfy joint, collision constraints.
Args:
sys: a brax system
qd: joint velocity vector
qf_smooth: joint force vector for smooth dynamics
mass_mx_inv: inverse mass matrix
jac: the constraint jacobian
pos: position in constraint frame
diag: approximate diagonal of A matrix
state: generalized state
Returns:
qf_constraint: (qd_size,) constraint force
"""
if jac.shape[0] == 0:
return jp.zeros_like(qd)
if state.con_jac.shape[0] == 0:
return jp.zeros(sys.qd_size())

# calculate A matrix and b vector
imp, aref = _imp_aref(pos, jac @ qd)
a = jac @ mass_mx_inv @ jac.T
a = a + jp.diag(diag * (1 - imp) / imp)
b = jac @ mass_mx_inv @ qf_smooth - aref
imp, aref = _imp_aref(state.con_pos, state.con_jac @ state.qd)
a = state.con_jac @ state.mass_mx_inv @ state.con_jac.T
a = a + jp.diag(state.con_diag * (1 - imp) / imp)
b = state.con_jac @ state.mass_mx_inv @ state.qf_smooth - aref

# solve for forces in constraint frame, Ax + b = 0 s.t. x >= 0
def objective(x):
residual = a @ x + b
return jp.sum(0.5 * residual**2)

# profiling a physics step, most of the time is spent running this solver:
#
# there might still be some opportunities to speed this up. consider that
# the A matrix is positive definite. could possibly use conjugate gradient?
# we made a jax version of this: https://github.com/david-cortes/nonneg_cg
# but it was still not as fast as the projected gradient solver below.
# this is possibly due to the line search method, which is a big part
# of the cost. jaxopt uses FISTA which we did not implement
#
# another avenue worth pursuing is that these A matrices are often
# fairly sparse. perhaps worth trying some kind of random or
# learned projection to solve a smaller dense matrix at each step
pg = jaxopt.ProjectedGradient(
objective,
projection_non_negative,
jaxopt.projection.projection_non_negative,
maxiter=sys.solver_iterations,
implicit_diff=False,
maxls=5,
)

# solve and convert back to q coordinates
qf_constraint = jac.T @ pg.run(jp.zeros_like(b)).params
qf_constraint = state.con_jac.T @ pg.run(jp.zeros_like(b)).params

return qf_constraint
Loading

0 comments on commit 6211f46

Please sign in to comment.