Skip to content

Commit

Permalink
o3.experimental.FullTensorProductv2 for `torch.compile(...., fullgr…
Browse files Browse the repository at this point in the history
…aph=True)` (e3nn#436)
  • Loading branch information
mitkotak authored Jun 13, 2024
1 parent ac3528f commit 19dca23
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 1 deletion.
1 change: 1 addition & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]
### Added
- `o3.experimental.FullTensorProductv2` for compatibility with `torch.compile(..., fullgraph=True)`
- enable `pip` caching in CI
- refactor to use `pyproject.toml` for packaging
- refactor `gh` community files
Expand Down
4 changes: 3 additions & 1 deletion e3nn/o3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
FullTensorProduct,
TensorSquare,
)
from .experimental import FullTensorProductv2

from ._spherical_harmonics import SphericalHarmonics, spherical_harmonics
from ._angular_spherical_harmonics import (
SphericalHarmonicsAlphaBeta,
Expand Down Expand Up @@ -101,7 +103,7 @@
"FullyConnectedTensorProduct",
"ElementwiseTensorProduct",
"FullTensorProduct",
"TensorSquare",
"FullTensorProductv2",
"TensorSquare",
"SphericalHarmonics",
"spherical_harmonics",
"SphericalHarmonicsAlphaBeta",
Expand Down
17 changes: 17 additions & 0 deletions e3nn/o3/_irreps.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,23 @@ def sort(self):
irreps = Irreps([(mul, ir) for ir, _, mul in out])
return Ret(irreps, p, inv)

def regroup(self) -> "Irreps":
    r"""Regroup the same irreps together.

    Equivalent to :meth:`sort` followed by :meth:`simplify`.

    Returns
    -------
    irreps: `e3nn.o3.Irreps`

    Examples
    --------
    >>> Irreps("1e + 0e + 1e + 0x2e").regroup()
    1x0e+2x1e
    """
    # Only the reordered irreps of the sort result are needed here.
    ordered = self.sort().irreps
    return ordered.simplify()

@property
def dim(self) -> int:
    """Total dimension: the sum of ``mul * ir.dim`` over every ``(mul, ir)`` entry."""
    total = 0
    for mul, ir in self:
        total += mul * ir.dim
    return total
Expand Down
1 change: 1 addition & 0 deletions e3nn/o3/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ._full_tp import FullTensorProduct as FullTensorProductv2
94 changes: 94 additions & 0 deletions e3nn/o3/experimental/_full_tp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from e3nn.util.datatypes import Path, Chunk
from e3nn import o3

import torch
from torch import nn
import numpy as np


def _prepare_inputs(input1, input2):
dtype = torch.promote_types(input1.dtype, input2.dtype)

input1 = input1.to(dtype=dtype)
input2 = input2.to(dtype=dtype)

leading_shape = torch.broadcast_shapes(input1.shape[:-1], input2.shape[:-1])
input1 = input1.broadcast_to(leading_shape + (-1,))
input2 = input2.broadcast_to(leading_shape + (-1,))
return input1, input2, leading_shape


class FullTensorProduct(nn.Module):
    """Full tensor product of two irreps-structured inputs.

    Tensor Product adapted from
    https://github.com/e3nn/e3nn-jax/blob/cf37f3e95264b34587b3a202ea4c3eb82597307e/e3nn_jax/_src/tensor_products.py#L40-L135

    For every pair of blocks ``(mul_1, ir_1)`` / ``(mul_2, ir_2)`` of the two
    inputs and every ``ir_out`` in ``ir_1 * ir_2``, the corresponding slices of
    the inputs are contracted with the ``o3.wigner_3j(l1, l2, l_out)`` tensor,
    which is stored as a module buffer. The resulting chunks are concatenated
    in the order given by sorting the output irreps.

    Parameters
    ----------
    irreps_in1, irreps_in2
        Irreps of the two inputs; anything accepted by ``o3.Irreps``.
    filter_ir_out
        Optional container of output irreps to keep. Output irreps that are
        not ``in`` this container are skipped.
    irrep_normalization : {"component", "norm"}
        Scaling applied to the ``wigner_3j`` tensor.
    regroup_output
        If True, both *input* irreps are replaced by their ``regroup()``-ed
        form (sorted and simplified). ``forward`` slices the input tensors
        according to ``self.irreps_in1`` / ``self.irreps_in2``, so the tensors
        passed in must follow that regrouped layout.

    Attributes
    ----------
    irreps_in1, irreps_in2, irreps_out
        Irreps actually used by the module. ``irreps_out`` is sorted but not
        simplified.
    paths
        Dict keyed by ``(l1, l2, l_out)`` mapping to a ``Path``. Kept for
        backward compatibility; entries sharing the same triple of ``l``
        values overwrite each other, so ``forward`` does not rely on it.
    inv
        Third element of ``irreps_out.sort()``; ``forward`` uses it to pick
        the chunks in output order.

    Raises
    ------
    ValueError
        If ``irrep_normalization`` is neither ``"component"`` nor ``"norm"``.
    """

    def __init__(
        self,
        irreps_in1: o3.Irreps,
        irreps_in2: o3.Irreps,
        *,
        filter_ir_out: o3.Irreps = None,
        irrep_normalization: str = "component",
        regroup_output: bool = True,
    ):
        super().__init__()

        if irrep_normalization not in ("component", "norm"):
            raise ValueError(f"irrep_normalization={irrep_normalization} not supported")

        # Always normalise to Irreps so that strings also work when
        # regroup_output is False (``.slices()`` is needed below).
        irreps_in1 = o3.Irreps(irreps_in1)
        irreps_in2 = o3.Irreps(irreps_in2)
        if regroup_output:
            irreps_in1 = irreps_in1.regroup()
            irreps_in2 = irreps_in2.regroup()

        paths = {}
        # One entry per computed chunk, in the same order as ``irreps_out``.
        # A list is required because several paths can share the same
        # (l1, l2, l_out) triple (different parities, or repeated blocks),
        # which would collide in the ``paths`` dict.
        path_list = []
        irreps_out = []
        registered = set()
        for (mul_1, ir_1), slice_1 in zip(irreps_in1, irreps_in1.slices()):
            for (mul_2, ir_2), slice_2 in zip(irreps_in2, irreps_in2.slices()):
                for ir_out in ir_1 * ir_2:
                    if filter_ir_out is not None and ir_out not in filter_ir_out:
                        continue

                    # The buffer only depends on the three l values, so paths
                    # with the same triple share a single buffer.
                    cg_name = f"cg_{ir_1.l}_{ir_2.l}_{ir_out.l}"
                    if cg_name not in registered:
                        cg = o3.wigner_3j(ir_1.l, ir_2.l, ir_out.l)
                        if irrep_normalization == "component":
                            cg = cg * np.sqrt(ir_out.dim)
                        else:  # "norm"
                            cg = cg * np.sqrt(ir_1.dim * ir_2.dim)
                        self.register_buffer(cg_name, cg)
                        registered.add(cg_name)

                    path = Path(
                        Chunk(mul_1, ir_1.dim, slice_1),
                        Chunk(mul_2, ir_2.dim, slice_2),
                        Chunk(mul_1 * mul_2, ir_out.dim),
                    )
                    paths[(ir_1.l, ir_2.l, ir_out.l)] = path
                    path_list.append((cg_name, path))
                    irreps_out.append((mul_1 * mul_2, ir_out))

        self.paths = paths
        self._path_list = path_list
        irreps_out = o3.Irreps(irreps_out)
        self.irreps_out, _, self.inv = irreps_out.sort()
        self.irreps_in1 = irreps_in1
        self.irreps_in2 = irreps_in2

    def forward(
        self,
        input1: torch.Tensor,
        input2: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the tensor product of ``input1`` and ``input2``.

        The last axis of each input is sliced according to ``irreps_in1`` /
        ``irreps_in2``; leading axes are broadcast against each other.
        Returns the chunks concatenated along the last axis in the order of
        ``irreps_out``.
        """
        input1, input2, leading_shape = _prepare_inputs(input1, input2)
        chunks = []
        for cg_name, (
            (mul_1, input_dim1, slice_1),
            (mul_2, input_dim2, slice_2),
            (output_mul, output_dim, _),
        ) in self._path_list:
            # Split the flat irreps axis into (multiplicity, irrep dim).
            x1 = input1[..., slice_1].reshape(leading_shape + (mul_1, input_dim1))
            x2 = input2[..., slice_2].reshape(leading_shape + (mul_2, input_dim2))
            cg = getattr(self, cg_name)
            chunk = torch.einsum("...ui, ...vj, ijk -> ...uvk", x1, x2, cg)
            # Flatten (u, v, k) back into one axis of size output_mul * output_dim.
            chunk = torch.reshape(chunk, chunk.shape[:-3] + (output_mul * output_dim,))
            chunks.append(chunk)

        return torch.cat([chunks[i] for i in self.inv], dim=-1)
13 changes: 13 additions & 0 deletions e3nn/util/datatypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import NamedTuple, Optional


class Chunk(NamedTuple):
    """Describes one block of an irreps-structured tensor axis."""

    # Multiplicity of the block (number of copies of the irrep).
    mul: int
    # Dimension of a single copy of the irrep.
    dim: int
    # Location of the block along the tensor's last axis; left as None when
    # the chunk is not tied to a position (e.g. an output chunk).
    slice: Optional[slice] = None


class Path(NamedTuple):
    """Groups the two input chunks and the output chunk of one tensor-product path."""

    # Chunk describing the block taken from the first input.
    input_1_slice: Chunk
    # Chunk describing the block taken from the second input.
    input_2_slice: Chunk
    # Chunk describing the block produced for the output.
    output_slice: Chunk
55 changes: 55 additions & 0 deletions tests/o3/experimental/benchmark_pt2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch
from torch._inductor.utils import print_performance

# Borrowed from https://github.com/pytorch-labs/gpt-fast/blob/db7b273ab86b75358bd3b014f1f022a19aba4797/generate.py#L16-L18
torch.set_float32_matmul_precision("high")
import torch._dynamo.config
import torch._inductor.config

# Inductor configuration flags enabled for every compiled module in this benchmark.
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

# Device that main() moves all benchmarked tensors and modules to.
device = "cuda"
# Value passed as `mode=` to torch.compile in main().
compile_mode = "max-autotune" # Bringing out all of the tricks that Torch 2.0 has but "reduce-overhead" should work as well

from e3nn import o3, util
import numpy as np
from torch import nn
import time

# Largest lmax benchmarked; main() sweeps lmax = 1 .. LMAX.
LMAX = 8
# Multiplier applied to the spherical-harmonics irreps of the first input.
CHANNEL = 128
# Leading (batch) size of the random inputs.
BATCH = 100


def _report(label, lmax, fn):
    # Time `fn` with print_performance and print one formatted result line.
    # The returned value is multiplied by 1000 and labelled as milliseconds.
    elapsed = print_performance(fn, times=100, repeat=10) * 1000
    print(f"{label} lmax {lmax} channel {CHANNEL} batch {BATCH}: {elapsed:.3f}ms")


def main():
    """Benchmark three variants of the full tensor product for lmax = 1..LMAX.

    For each lmax, random inputs are created on `device` and the following
    are timed with `print_performance`:

    * `o3.FullTensorProduct` compiled with `util.jit.compile`,
    * `o3.FullTensorProduct` compiled with `torch.compile`,
    * `o3.experimental.FullTensorProductv2` compiled with
      `torch.compile(..., fullgraph=True)`.
    """
    for lmax in range(1, LMAX + 1):
        irreps = o3.Irreps.spherical_harmonics(lmax)
        irreps_x = (CHANNEL * irreps).regroup()
        x = irreps_x.randn(BATCH, -1).to(device=device)
        irreps_y = irreps
        y = irreps_y.randn(BATCH, -1).to(device=device)
        # The backslash is escaped: "\o" is not a valid escape sequence.
        print(f"{irreps_x} \\otimes {irreps_y}")

        tp = o3.FullTensorProduct(irreps_x, irreps_y)  # Doesn't work with fullgraph=True

        tp_jit_compile = util.jit.compile(tp).to(device=device)
        tp_compile = torch.compile(tp, mode=compile_mode).to(device=device)

        _report("TP JIT", lmax, lambda: tp_jit_compile(x, y))
        _report("TP Torch 2.0", lmax, lambda: tp_compile(x, y))

        tp_experimental = o3.experimental.FullTensorProductv2(irreps_x, irreps_y)
        tp_experimental_compile = torch.compile(tp_experimental, mode=compile_mode, fullgraph=True).to(device=device)
        _report("TP Experimental Torch 2.0", lmax, lambda: tp_experimental_compile(x, y))


if __name__ == "__main__":
main()
15 changes: 15 additions & 0 deletions tests/o3/experimental/test_fulltp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch
from e3nn import o3
import pytest


@pytest.mark.parametrize("irreps_in1", ["0e", "0e + 1e"])
@pytest.mark.parametrize("irreps_in2", ["2x0e", "2x0e + 3x1e"])
def test_fulltp(irreps_in1, irreps_in2):
    """Check the compiled experimental tensor product against ``o3.FullTensorProduct``."""
    # Random inputs with a leading batch axis of size 10.
    lhs = o3.Irreps(irreps_in1).randn(10, -1)
    rhs = o3.Irreps(irreps_in2).randn(10, -1)

    experimental = o3.experimental.FullTensorProductv2(irreps_in1, irreps_in2)
    compiled = torch.compile(experimental, fullgraph=True)
    reference = o3.FullTensorProduct(irreps_in1, irreps_in2)

    actual = compiled(lhs, rhs)
    expected = reference(lhs, rhs)
    torch.testing.assert_close(actual, expected)

0 comments on commit 19dca23

Please sign in to comment.