Skip to content

Commit

Permalink
add dtype
Browse files (browse the repository at this point in the history)
  • Loading branch information
thomashirtz committed Nov 5, 2024
1 parent 5ab7166 commit 2b92b55
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions jumanji/types.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -95,6 +95,7 @@ def restart(
observation: Observation,
extras: Optional[Dict] = None,
shape: Union[int, Sequence[int]] = (),
+ dtype: Union[jnp.dtype, type] = float,
) -> TimeStep:
"""Returns a `TimeStep` with `step_type` set to `StepType.FIRST`.
Expand All @@ -114,8 +115,8 @@ def restart(
extras = extras or {}
return TimeStep(
step_type=StepType.FIRST,
- reward=jnp.zeros(shape, dtype=float),
- discount=jnp.ones(shape, dtype=float),
+ reward=jnp.zeros(shape, dtype=dtype),
+ discount=jnp.ones(shape, dtype=dtype),
observation=observation,
extras=extras,
)
Expand All @@ -127,6 +128,7 @@ def transition(
discount: Optional[Array] = None,
extras: Optional[Dict] = None,
shape: Union[int, Sequence[int]] = (),
+ dtype: Union[jnp.dtype, type] = float,
) -> TimeStep:
"""Returns a `TimeStep` with `step_type` set to `StepType.MID`.
Expand All @@ -145,7 +147,7 @@ def transition(
Returns:
TimeStep identified as a transition.
"""
- discount = discount if discount is not None else jnp.ones(shape, dtype=float)
+ discount = discount if discount is not None else jnp.ones(shape, dtype=dtype)
extras = extras or {}
return TimeStep(
step_type=StepType.MID,
Expand All @@ -161,6 +163,7 @@ def termination(
observation: Observation,
extras: Optional[Dict] = None,
shape: Union[int, Sequence[int]] = (),
+ dtype: Union[jnp.dtype, type] = float,
) -> TimeStep:
"""Returns a `TimeStep` with `step_type` set to `StepType.LAST`.
Expand All @@ -182,7 +185,7 @@ def termination(
return TimeStep(
step_type=StepType.LAST,
reward=reward,
- discount=jnp.zeros(shape, dtype=float),
+ discount=jnp.zeros(shape, dtype=dtype),
observation=observation,
extras=extras,
)
Expand All @@ -194,6 +197,7 @@ def truncation(
discount: Optional[Array] = None,
extras: Optional[Dict] = None,
shape: Union[int, Sequence[int]] = (),
+ dtype: Union[jnp.dtype, type] = float,
) -> TimeStep:
"""Returns a `TimeStep` with `step_type` set to `StepType.LAST`.
Expand All @@ -211,7 +215,7 @@ def truncation(
Returns:
TimeStep identified as the truncation of an episode.
"""
- discount = discount if discount is not None else jnp.ones(shape, dtype=float)
+ discount = discount if discount is not None else jnp.ones(shape, dtype=dtype)
extras = extras or {}
return TimeStep(
step_type=StepType.LAST,
Expand Down

0 comments on commit 2b92b55

Please sign in to comment.