Skip to content

Commit

Permalink
torch: add conjugate and update split
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Oct 26, 2022
1 parent 481809f commit 5a80323
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions autoray/autoray.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,9 @@ def import_lib_fn(backend, fn):
return import_lib_fn(backend_alt, fn)

raise ImportError(
f"autoray couldn't find function '{fn}' for backend '{backend}'.")
f"autoray couldn't find function '{fn}' for "
f"backend '{backend.replace('[alt]', '')}'."
)

return lib_fn

Expand Down Expand Up @@ -1242,6 +1244,10 @@ def numpy_like(ary, indices_or_sections, axis=0, **kwargs):
split_size = ary.shape[axis] // indices_or_sections
return fn(ary, split_size, dim=axis, **kwargs)
else:
# torch.split doesn't support empty splits
if len(indices_or_sections) == 0:
return (ary,)

diff = do(
"diff",
indices_or_sections,
Expand Down Expand Up @@ -1296,6 +1302,8 @@ def numpy_like(N, M=None, dtype=None, **kwargs):
_FUNC_ALIASES["torch", "random.uniform"] = "rand"
_FUNC_ALIASES["torch", "take"] = "index_select"
_FUNC_ALIASES["torch", "linalg.expm"] = "matrix_exp"
_FUNC_ALIASES["torch", "conjugate"] = "conj"
_FUNC_ALIASES["torch", "split"] = "tensor_split"

_SUBMODULE_ALIASES["torch", "linalg.expm"] = "torch"
_SUBMODULE_ALIASES["torch", "random.normal"] = "torch"
Expand All @@ -1304,7 +1312,6 @@ def numpy_like(N, M=None, dtype=None, **kwargs):
_CUSTOM_WRAPPERS["torch", "linalg.svd"] = svd_not_full_matrices_wrapper
_CUSTOM_WRAPPERS["torch", "random.normal"] = scale_random_normal_manually
_CUSTOM_WRAPPERS["torch", "random.uniform"] = scale_random_uniform_manually
_CUSTOM_WRAPPERS["torch", "split"] = torch_split_wrap
_CUSTOM_WRAPPERS["torch", "tensordot"] = torch_tensordot_wrap
_CUSTOM_WRAPPERS["torch", "stack"] = make_translator(
[("arrays", ("tensors",)), ("axis", ("dim", 0))]
Expand All @@ -1329,7 +1336,7 @@ def numpy_like(N, M=None, dtype=None, **kwargs):
[("a", ("input",)), ("indices", ("index",)), ("axis", ("dim",))]
)

# for torch < 1.9
# for older versions of torch, can provide some alternative implementations
_MODULE_ALIASES['torch[alt]'] = 'torch'

_FUNCS["torch[alt]", "linalg.eigh"] = torch_linalg_eigh
Expand All @@ -1340,6 +1347,7 @@ def numpy_like(N, M=None, dtype=None, **kwargs):
_SUBMODULE_ALIASES["torch[alt]", "linalg.norm"] = "torch"
_SUBMODULE_ALIASES["torch[alt]", "linalg.solve"] = "torch"

_CUSTOM_WRAPPERS["torch[alt]", "split"] = torch_split_wrap
_CUSTOM_WRAPPERS["torch[alt]", "linalg.svd"] = svd_UsV_to_UsVH_wrapper
_CUSTOM_WRAPPERS["torch[alt]", "linalg.qr"] = qr_allow_fat
_CUSTOM_WRAPPERS["torch[alt]", "linalg.solve"] = torch_linalg_solve_wrap
Expand Down

0 comments on commit 5a80323

Please sign in to comment.