diff --git a/chirho/indexed/internals.py b/chirho/indexed/internals.py index b3038148..d684b8ad 100644 --- a/chirho/indexed/internals.py +++ b/chirho/indexed/internals.py @@ -1,6 +1,5 @@ -import functools import numbers -from typing import Dict, Hashable, List, Optional, TypeVar, Union +from typing import Dict, Hashable, Optional, TypeVar, Union import pyro import pyro.infer.reparam @@ -38,21 +37,6 @@ def _gather_number( ) -@functools.singledispatch -def index_select_from_array_like(arr, dim: int, indices: List[int]): - raise NotImplementedError(f"index_select not implemented for {type(arr)}") - - -@index_select_from_array_like.register -def _index_select_from_array_like_tensor( - arr: torch.Tensor, dim: int, indices: List[int] -) -> torch.Tensor: - indices_tensor = torch.tensor(indices, device=arr.device, dtype=torch.long) - return arr.index_select(dim=dim, index=indices_tensor) - - -# TODO _gather_tensor now works for any array like with an index_select_from_array_like implementation. -# Can we dispatch for array like objects generally? @gather.register def _gather_tensor( value: torch.Tensor, @@ -75,8 +59,9 @@ def _gather_tensor( dim = name_to_dim[name] - event_dim if len(result.shape) < -dim or result.shape[dim] == 1: continue - result = index_select_from_array_like( - result, name_to_dim[name] - event_dim, list(sorted(indices)) + result = result.index_select( + name_to_dim[name] - event_dim, + torch.tensor(list(sorted(indices)), device=value.device, dtype=torch.long), ) return result