Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 610906071
Change-Id: I6cacdc7f0df1a8a2cb3fe2330671ea8897e70c1b
  • Loading branch information
Brax Team authored and btaba committed Feb 28, 2024
1 parent 3c109cf commit 0825bcb
Show file tree
Hide file tree
Showing 45 changed files with 66 additions and 1,386 deletions.
2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
include brax/envs/assets/*.xml
recursive-include brax/experimental/barkour_v0 *.csv *.stl *.xml
recursive-include brax/experimental/barkour_vb *.csv *.stl *.xml
recursive-include brax/test_data *.xml *.stl *.obj *.urdf
recursive-include brax/visualizer *
6 changes: 3 additions & 3 deletions brax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,15 +531,15 @@ def qd_idx(self, link_type: str) -> jax.Array:

def q_size(self) -> int:
"""Returns the size of the q vector (joint position) for this system."""
return sum([Q_WIDTHS[t] for t in self.link_types])
return self.nq

def qd_size(self) -> int:
"""Returns the size of the qd vector (joint velocity) for this system."""
return sum([QD_WIDTHS[t] for t in self.link_types])
return self.nv

def act_size(self) -> int:
"""Returns the act dimension for the system."""
return self.actuator.q_id.shape[0]
return self.nu


# below are some operation dispatch derivations
Expand Down
1 change: 1 addition & 0 deletions brax/envs/ant.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def reset(self, rng: jax.Array) -> State:
def step(self, state: State, action: jax.Array) -> State:
"""Run one timestep of the environment's dynamics."""
pipeline_state0 = state.pipeline_state
assert pipeline_state0 is not None
pipeline_state = self.pipeline_step(pipeline_state0, action)

velocity = (pipeline_state.x.pos[0] - pipeline_state0.x.pos[0]) / self.dt
Expand Down
1 change: 1 addition & 0 deletions brax/envs/fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def reset(self, rng: jax.Array) -> State:
return State(pipeline_state, obs, reward, done)

def step(self, state: State, action: jax.Array) -> State:
assert state.pipeline_state is not None
self._step_count += 1
vel = state.pipeline_state.xd.vel + (action > 0) * self._dt
pos = state.pipeline_state.x.pos + vel * self._dt
Expand Down
1 change: 1 addition & 0 deletions brax/envs/half_cheetah.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def reset(self, rng: jax.Array) -> State:
def step(self, state: State, action: jax.Array) -> State:
"""Runs one timestep of the environment's dynamics."""
pipeline_state0 = state.pipeline_state
assert pipeline_state0 is not None
pipeline_state = self.pipeline_step(pipeline_state0, action)

x_velocity = (
Expand Down
1 change: 1 addition & 0 deletions brax/envs/hopper.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def reset(self, rng: jax.Array) -> State:
def step(self, state: State, action: jax.Array) -> State:
"""Runs one timestep of the environment's dynamics."""
pipeline_state0 = state.pipeline_state
assert pipeline_state0 is not None
pipeline_state = self.pipeline_step(pipeline_state0, action)

x_velocity = (
Expand Down
1 change: 1 addition & 0 deletions brax/envs/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def reset(self, rng: jax.Array) -> State:
return State(pipeline_state, obs, reward, done, metrics)

def step(self, state: State, action: jax.Array) -> State:
assert state.pipeline_state is not None
x_i = state.pipeline_state.x.vmap().do(
base.Transform.create(pos=self.sys.link.inertia.transform.pos)
)
Expand Down
1 change: 1 addition & 0 deletions brax/envs/walker2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def reset(self, rng: jax.Array) -> State:
def step(self, state: State, action: jax.Array) -> State:
"""Runs one timestep of the environment's dynamics."""
pipeline_state0 = state.pipeline_state
assert pipeline_state0 is not None
pipeline_state = self.pipeline_step(pipeline_state0, action)

x_velocity = (
Expand Down
66 changes: 0 additions & 66 deletions brax/experimental/barkour_v0/README.md

This file was deleted.

Binary file removed brax/experimental/barkour_v0/assets/abduction.stl
Binary file not shown.
Loading

0 comments on commit 0825bcb

Please sign in to comment.