diff --git a/rllib/models/torch/torch_distributions.py b/rllib/models/torch/torch_distributions.py
index 178e92fb64bf0..e4a87b53ebc9e 100644
--- a/rllib/models/torch/torch_distributions.py
+++ b/rllib/models/torch/torch_distributions.py
@@ -497,8 +497,28 @@ def from_logits(
 
         return TorchMultiCategorical(categoricals=categoricals)
 
-    def to_deterministic(self) -> "TorchMultiDistribution":
-        return TorchMultiDistribution([cat.to_deterministic() for cat in self._cats])
+    def to_deterministic(self) -> "TorchDeterministic":
+        """Return the greedy (argmax) choice of every sub-categorical.
+
+        The parameter tensor of each sub-distribution is transposed, all of them
+        are padded to a common number of categories and reduced with a single
+        ``argmax``. Padding uses ``-inf`` so it never exceeds a real entry.
+
+        Returns:
+            A ``TorchDeterministic`` whose ``loc`` holds the argmax indices; its
+            leading dimension indexes the sub-distributions in ``self._cats``.
+        """
+        # Assumes every sub-distribution is parameterized like the first one.
+        use_probs = self._cats[0].probs is not None
+        # Rank by whichever parameter is set: `probs` if present, else `logits`.
+        params = [cat.probs if use_probs else cat.logits for cat in self._cats]
+        # NOTE(review): `.t()` only accepts tensors with <= 2 dims -- confirm
+        # the parameters are always (batch, num_categories).
+        padded = nn.utils.rnn.pad_sequence(
+            [param.t() for param in params], padding_value=-torch.inf
+        )
+
+        return TorchDeterministic(loc=torch.argmax(padded, dim=0))
 
 
 @DeveloperAPI