Skip to content

Commit

Permalink
reverts modularization of gather.
Browse files Browse the repository at this point in the history
  • Loading branch information
azane committed Aug 9, 2024
1 parent 60515f0 commit 8dd39e5
Showing 1 changed file with 4 additions and 19 deletions.
23 changes: 4 additions & 19 deletions chirho/indexed/internals.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import functools
import numbers
from typing import Dict, Hashable, List, Optional, TypeVar, Union
from typing import Dict, Hashable, Optional, TypeVar, Union

import pyro
import pyro.infer.reparam
Expand Down Expand Up @@ -38,21 +37,6 @@ def _gather_number(
)


@functools.singledispatch
def index_select_from_array_like(arr, dim: int, indices: List[int]):
raise NotImplementedError(f"index_select not implemented for {type(arr)}")


@index_select_from_array_like.register
def _index_select_from_array_like_tensor(
arr: torch.Tensor, dim: int, indices: List[int]
) -> torch.Tensor:
indices_tensor = torch.tensor(indices, device=arr.device, dtype=torch.long)
return arr.index_select(dim=dim, index=indices_tensor)


# TODO _gather_tensor now works for any array like with an index_select_from_array_like implementation.
# Can we dispatch for array like objects generally?
@gather.register
def _gather_tensor(
value: torch.Tensor,
Expand All @@ -75,8 +59,9 @@ def _gather_tensor(
dim = name_to_dim[name] - event_dim
if len(result.shape) < -dim or result.shape[dim] == 1:
continue
result = index_select_from_array_like(
result, name_to_dim[name] - event_dim, list(sorted(indices))
result = result.index_select(
name_to_dim[name] - event_dim,
torch.tensor(list(sorted(indices)), device=value.device, dtype=torch.long),
)
return result

Expand Down

0 comments on commit 8dd39e5

Please sign in to comment.