forked from e3nn/e3nn
Commit
o3.experimental.FullTensorProductv2 for `torch.compile(...., fullgraph=True)` (e3nn#436)
Showing 8 changed files with 199 additions and 1 deletion.
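In short, the commit adds a torch.compile-friendly full tensor product under `o3.experimental`. A minimal usage sketch, assuming e3nn is installed from this branch (the irreps and batch size below are arbitrary illustration, not taken from the commit):

import torch
from e3nn import o3

irreps_x = o3.Irreps("2x0e + 1x1e")
irreps_y = o3.Irreps("1x1e")
x = irreps_x.randn(16, -1)
y = irreps_y.randn(16, -1)

tp = o3.experimental.FullTensorProductv2(irreps_x, irreps_y)
tp_c = torch.compile(tp, fullgraph=True)  # compiling with fullgraph=True is the point of this commit
out = tp_c(x, y)                          # shape: (16, tp.irreps_out.dim)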
@@ -0,0 +1 @@ (re-export in the o3.experimental namespace)
from ._full_tp import FullTensorProduct as FullTensorProductv2
@@ -0,0 +1,94 @@ (the experimental FullTensorProduct module)
from typing import Optional

from e3nn.util.datatypes import Path, Chunk
from e3nn import o3

import torch
from torch import nn
import numpy as np


def _prepare_inputs(input1, input2):
    # Promote both inputs to a common dtype and broadcast their leading
    # (batch) dimensions against each other.
    dtype = torch.promote_types(input1.dtype, input2.dtype)

    input1 = input1.to(dtype=dtype)
    input2 = input2.to(dtype=dtype)

    leading_shape = torch.broadcast_shapes(input1.shape[:-1], input2.shape[:-1])
    input1 = input1.broadcast_to(leading_shape + (-1,))
    input2 = input2.broadcast_to(leading_shape + (-1,))
    return input1, input2, leading_shape


class FullTensorProduct(nn.Module):
    def __init__(
        self,
        irreps_in1: o3.Irreps,
        irreps_in2: o3.Irreps,
        *,
        filter_ir_out: Optional[o3.Irreps] = None,
        irrep_normalization: str = "component",
        regroup_output: bool = True,
    ):
        """Tensor product adapted from
        https://github.com/e3nn/e3nn-jax/blob/cf37f3e95264b34587b3a202ea4c3eb82597307e/e3nn_jax/_src/tensor_products.py#L40-L135
        """
        super().__init__()

        if regroup_output:
            irreps_in1 = o3.Irreps(irreps_in1).regroup()
            irreps_in2 = o3.Irreps(irreps_in2).regroup()

        # Enumerate every (ir_1, ir_2) -> ir_out path allowed by the
        # Clebsch-Gordan selection rules, registering the CG coefficients
        # as buffers so the module carries no Python state at trace time.
        paths = {}
        irreps_out = []
        for (mul_1, ir_1), slice_1 in zip(irreps_in1, irreps_in1.slices()):
            for (mul_2, ir_2), slice_2 in zip(irreps_in2, irreps_in2.slices()):
                for ir_out in ir_1 * ir_2:
                    if filter_ir_out is not None and ir_out not in filter_ir_out:
                        continue
                    cg = o3.wigner_3j(ir_1.l, ir_2.l, ir_out.l)
                    if irrep_normalization == "component":
                        cg *= np.sqrt(ir_out.dim)
                    elif irrep_normalization == "norm":
                        cg *= np.sqrt(ir_1.dim * ir_2.dim)
                    else:
                        raise ValueError(f"irrep_normalization={irrep_normalization} not supported")
                    self.register_buffer(f"cg_{ir_1.l}_{ir_2.l}_{ir_out.l}", cg)
                    paths[(ir_1.l, ir_2.l, ir_out.l)] = Path(
                        Chunk(mul_1, ir_1.dim, slice_1),
                        Chunk(mul_2, ir_2.dim, slice_2),
                        Chunk(mul_1 * mul_2, ir_out.dim),
                    )
                    irreps_out.append((mul_1 * mul_2, ir_out))
        self.paths = paths
        irreps_out = o3.Irreps(irreps_out)
        self.irreps_out, _, self.inv = irreps_out.sort()
        self.irreps_in1 = irreps_in1
        self.irreps_in2 = irreps_in2

    def forward(
        self,
        input1: torch.Tensor,
        input2: torch.Tensor,
    ) -> torch.Tensor:
        input1, input2, leading_shape = _prepare_inputs(input1, input2)
        chunks = []
        for (l1, l2, l3), (
            (mul_1, input_dim1, slice_1),
            (mul_2, input_dim2, slice_2),
            (output_mul, output_dim, _),
        ) in self.paths.items():
            x1 = input1[..., slice_1].reshape(leading_shape + (mul_1, input_dim1))
            x2 = input2[..., slice_2].reshape(leading_shape + (mul_2, input_dim2))
            cg = getattr(self, f"cg_{l1}_{l2}_{l3}")
            # Contract both inputs with the CG coefficients:
            # u, v index multiplicities; i, j, k index irrep components.
            chunk = torch.einsum("...ui, ...vj, ijk -> ...uvk", x1, x2, cg)
            chunk = torch.reshape(chunk, chunk.shape[:-3] + (output_mul * output_dim,))
            chunks.append(chunk)

        # Reorder the per-path chunks to match the sorted output irreps.
        return torch.cat([chunks[i] for i in self.inv], dim=-1)
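A quick worked check of the path enumeration in `__init__` (a sketch using only the o3 calls that appear above; the values follow from the Clebsch-Gordan selection rules):

from e3nn import o3

ir_1, ir_2 = o3.Irrep("1e"), o3.Irrep("1e")

# ir_1 * ir_2 yields 0e, 1e, 2e: |l1 - l2| <= l_out <= l1 + l2, parity p1 * p2.
print(list(ir_1 * ir_2))

# wigner_3j(l1, l2, l3) has shape (2*l1 + 1, 2*l2 + 1, 2*l3 + 1),
# i.e. (ir_1.dim, ir_2.dim, ir_out.dim) -- the "ijk" axes of the einsum in forward().
print(o3.wigner_3j(1, 1, 2).shape)  # torch.Size([3, 3, 5])

So for a 2x1e ⊗ 3x1e input pair, each of the three allowed output irreps gets a chunk of multiplicity 2 * 3 = 6, matching `Chunk(mul_1 * mul_2, ir_out.dim)` above.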
@@ -0,0 +1,13 @@ (Chunk/Path datatypes in e3nn.util.datatypes)
from typing import NamedTuple, Optional


class Chunk(NamedTuple):
    """One block of a flattened irreps tensor: multiplicity, irrep dimension,
    and (optionally) where the block sits in the last axis."""

    mul: int
    dim: int
    slice: Optional[slice] = None


class Path(NamedTuple):
    """A single (ir_1, ir_2) -> ir_out tensor-product path: the two input
    chunks it reads and the output chunk it writes."""

    input_1_slice: Chunk
    input_2_slice: Chunk
    output_slice: Chunk
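For illustration, the Path that `__init__` in the module above would build for a 2x1e ⊗ 3x1e → 2e path (a hand-worked sketch; the slices assume each irrep block starts at component 0):

from e3nn.util.datatypes import Path, Chunk

path = Path(
    Chunk(2, 3, slice(0, 6)),   # input 1: 2x1e -> mul=2, ir.dim=3, 6 components
    Chunk(3, 3, slice(0, 9)),   # input 2: 3x1e -> mul=3, ir.dim=3, 9 components
    Chunk(6, 5),                # output: 6x2e -> mul=2*3, ir.dim=5; slice defaults to None
)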
@@ -0,0 +1,55 @@ (benchmark script: TorchScript vs. torch.compile vs. the fullgraph-compiled experimental module)
import torch
from torch._inductor.utils import print_performance

# Borrowed from https://github.com/pytorch-labs/gpt-fast/blob/db7b273ab86b75358bd3b014f1f022a19aba4797/generate.py#L16-L18
torch.set_float32_matmul_precision("high")
import torch._dynamo.config
import torch._inductor.config

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

device = "cuda"
compile_mode = "max-autotune"  # Brings out all of the tricks that Torch 2.0 has, but "reduce-overhead" should work as well

from e3nn import o3, util
import numpy as np
from torch import nn
import time

LMAX = 8
CHANNEL = 128
BATCH = 100


def main():
    for lmax in range(1, LMAX + 1):
        irreps = o3.Irreps.spherical_harmonics(lmax)
        irreps_x = (CHANNEL * irreps).regroup()
        x = irreps_x.randn(BATCH, -1).to(device=device)
        irreps_y = irreps
        y = irreps_y.randn(BATCH, -1).to(device=device)
        print(f"{irreps_x} \\otimes {irreps_y}")

        tp = o3.FullTensorProduct(irreps_x, irreps_y)  # Doesn't work with fullgraph=True

        tp_jit_compile = util.jit.compile(tp).to(device=device)

        tp_compile = torch.compile(tp, mode=compile_mode).to(device=device)
        print(
            f"TP JIT lmax {lmax} channel {CHANNEL} batch {BATCH}: "
            f"{print_performance(lambda: tp_jit_compile(x, y), times=100, repeat=10) * 1000:.3f}ms"
        )

        print(
            f"TP Torch 2.0 lmax {lmax} channel {CHANNEL} batch {BATCH}: "
            f"{print_performance(lambda: tp_compile(x, y), times=100, repeat=10) * 1000:.3f}ms"
        )

        tp_experimental = o3.experimental.FullTensorProductv2(irreps_x, irreps_y)
        tp_experimental_compile = torch.compile(tp_experimental, mode=compile_mode, fullgraph=True).to(device=device)
        print(
            f"TP Experimental Torch 2.0 lmax {lmax} channel {CHANNEL} batch {BATCH}: "
            f"{print_performance(lambda: tp_experimental_compile(x, y), times=100, repeat=10) * 1000:.3f}ms"
        )


if __name__ == "__main__":
    main()
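Note the asymmetry this script exercises: the existing o3.FullTensorProduct is timed via TorchScript (`util.jit.compile`) and via `torch.compile` without `fullgraph`, because, per the comment above, it does not survive `fullgraph=True`; only the experimental FullTensorProductv2 is compiled with `fullgraph=True`, which is exactly what this commit enables.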
@@ -0,0 +1,15 @@ (pytest check that the new module matches o3.FullTensorProduct)
import torch
from e3nn import o3
import pytest


@pytest.mark.parametrize("irreps_in1", ["0e", "0e + 1e"])
@pytest.mark.parametrize("irreps_in2", ["2x0e", "2x0e + 3x1e"])
def test_fulltp(irreps_in1, irreps_in2):
    x = o3.Irreps(irreps_in1).randn(10, -1)
    y = o3.Irreps(irreps_in2).randn(10, -1)

    tp_pt2 = torch.compile(o3.experimental.FullTensorProductv2(irreps_in1, irreps_in2), fullgraph=True)
    tp = o3.FullTensorProduct(irreps_in1, irreps_in2)

    torch.testing.assert_close(tp_pt2(x, y), tp(x, y))