From 4b2e7f45e6c88b84e2b3021775df39d3add2a399 Mon Sep 17 00:00:00 2001
From: simonsays1980
Date: Wed, 4 Dec 2024 20:14:26 +0100
Subject: [PATCH 1/2] Modified deterministic distribution for
 'TorchMultiCategorical' such that we get a proper (i.e. stacked) tensor
 instead of a list of single-dimensional tensors. Used padding so this also
 works with different action distribution lengths.

Signed-off-by: simonsays1980
---
 rllib/models/torch/torch_distributions.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/rllib/models/torch/torch_distributions.py b/rllib/models/torch/torch_distributions.py
index 178e92fb64bf0..fe2b29b6de984 100644
--- a/rllib/models/torch/torch_distributions.py
+++ b/rllib/models/torch/torch_distributions.py
@@ -498,7 +498,16 @@ def from_logits(
         return TorchMultiCategorical(categoricals=categoricals)
 
     def to_deterministic(self) -> "TorchMultiDistribution":
-        return TorchMultiDistribution([cat.to_deterministic() for cat in self._cats])
+        if self._cats[0].probs is not None:
+            probs_or_logits = nn.utils.rnn.pad_sequence(
+                [cat.probs.t() for cat in self._cats], padding_value=-torch.inf
+            )
+        else:
+            probs_or_logits = nn.utils.rnn.pad_sequence(
+                [cat.logits.t() for cat in self._cats], padding_value=-torch.inf
+            )
+
+        return TorchDeterministic(loc=torch.argmax(probs_or_logits, dim=0))
 
 
 @DeveloperAPI

From 6897578cc2f52b1b625dab7395f7954aafd2ebc1 Mon Sep 17 00:00:00 2001
From: Sven Mika
Date: Fri, 6 Dec 2024 12:27:59 +0100
Subject: [PATCH 2/2] Apply suggestions from code review

Signed-off-by: Sven Mika
---
 rllib/models/torch/torch_distributions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/rllib/models/torch/torch_distributions.py b/rllib/models/torch/torch_distributions.py
index fe2b29b6de984..e4a87b53ebc9e 100644
--- a/rllib/models/torch/torch_distributions.py
+++ b/rllib/models/torch/torch_distributions.py
@@ -497,7 +497,7 @@ def from_logits(
 
         return TorchMultiCategorical(categoricals=categoricals)
 
-    def to_deterministic(self) -> "TorchMultiDistribution":
+    def to_deterministic(self) -> "TorchDeterministic":
         if self._cats[0].probs is not None:
             probs_or_logits = nn.utils.rnn.pad_sequence(
                 [cat.probs.t() for cat in self._cats], padding_value=-torch.inf
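
Note (not part of the patch series): below is a minimal, self-contained sketch of the padding-plus-argmax approach that patch 1 introduces, showing how sub-distributions with different numbers of categories end up as one stacked tensor of greedy actions. All shapes and variable names here are illustrative assumptions, not RLlib code.

import torch
from torch import nn

# Two action components with different numbers of categories (3 and 5) and a
# batch size of 4. These shapes are made up purely for illustration.
batch_size = 4
logits_a = torch.randn(batch_size, 3)
logits_b = torch.randn(batch_size, 5)

# Transpose each to (num_categories, batch) and pad the category dimension
# with -inf so that padded entries can never win the argmax.
padded = nn.utils.rnn.pad_sequence(
    [logits_a.t(), logits_b.t()], padding_value=-torch.inf
)
# padded has shape (max_num_categories=5, num_components=2, batch_size=4).

# Greedy (deterministic) action per component, returned as one stacked tensor
# instead of a list of per-component tensors.
actions = torch.argmax(padded, dim=0)  # shape: (num_components, batch_size)
print(actions.shape)  # torch.Size([2, 4])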