diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index bf462b4b6..507cab7d1 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -2348,6 +2348,24 @@ def masked_fill(self, mask: Tensor, value: float | bool) -> TensorDictBase: """ raise NotImplementedError + def masked_select_(self, mask: Tensor) -> TensorDictBase: + """Masks all tensors of the TensorDict. + + Args: + mask (torch.Tensor): boolean mask to be used for the tensors. + Shape must match the TensorDict batch_size. + + Examples: + >>> td = TensorDict(source={'a': torch.zeros(3, 4)}, + ... batch_size=[3]) + >>> mask = torch.tensor([True, False, False]) + >>> td.masked_select_(mask) + >>> td.get("a") + tensor([[0., 0., 0., 0.]]) + + """ + raise NotImplementedError + def masked_select(self, mask: Tensor) -> TensorDictBase: """Masks all tensors of the TensorDict and return a new TensorDict instance with similar keys pointing to masked values. @@ -4006,6 +4024,55 @@ def to(tensor): f"instance, {dest} not allowed" ) + def masked_select_(self, mask: Tensor) -> TensorDictBase: + """Masks all tensors of the TensorDict. + + Args: + mask (torch.Tensor): boolean mask to be used for the tensors. + Shape must match the TensorDict batch_size. + + Examples: + >>> td = TensorDict(source={'a': torch.zeros(3, 4)}, + ... batch_size=[3]) + >>> mask = torch.tensor([True, False, False]) + >>> td.masked_select_(mask) + >>> td_mask.get("a") + tensor([[0., 0., 0., 0.]]) + + """ + d={} + for key, val in self.items(): + if hasattr(val, "masked_select_"): # modify inplace supported, or nested TensorDict + val_sel = val.masked_select_(mask) # val_sel should be val + else: + val_sel = val[mask] + d[key] = val_sel + dim = int(mask.sum().item()) + other_dim = self.shape[mask.ndim :] + new_batch_size = torch.Size([dim, *other_dim]) + for key, val in d.items(): + self._set(key, val) + self.batch_size = new_batch_size + return self + + # def masked_select(self, mask: Tensor) -> TensorDictBase: + # """Masks all tensors of the TensorDict and return a new TensorDict instance with similar keys pointing to masked values. + + # Args: + # mask (torch.Tensor): boolean mask to be used for the tensors. + # Shape must match the TensorDict batch_size. + + # Examples: + # >>> td = TensorDict(source={'a': torch.zeros(3, 4)}, + # ... batch_size=[3]) + # >>> mask = torch.tensor([True, False, False]) + # >>> td_mask = td.masked_select(mask) + # >>> td_mask.get("a") + # tensor([[0., 0., 0., 0.]]) + + # """ + # return self.clone().masked_select_(mask) + def masked_fill_(self, mask: Tensor, value: float | int | bool) -> TensorDictBase: for item in self.values(): mask_expand = expand_as_right(mask, item)