place disabling tf32 (and reenabling) into contextmanager
lucidrains committed Aug 20, 2021
1 parent de1b902 commit 63ec484
Showing 2 changed files with 20 additions and 20 deletions.
38 changes: 19 additions & 19 deletions invariant_point_attention/invariant_point_attention.py
@@ -1,5 +1,6 @@
import torch
import torch.nn.functional as F
+from contextlib import contextmanager
from torch import nn, einsum

from einops.layers.torch import Rearrange
@@ -16,8 +17,12 @@ def default(val, d):
def max_neg_value(t):
    return -torch.finfo(t.dtype).max

-def switch_tf32(target):
-    torch.backends.cuda.matmul.allow_tf32 = target
+@contextmanager
+def disable_tf32():
+    orig_value = torch.backends.cuda.matmul.allow_tf32
+    torch.backends.cuda.matmul.allow_tf32 = False
+    yield
+    torch.backends.cuda.matmul.allow_tf32 = orig_value

# classes

@@ -146,31 +151,26 @@ def forward(

        attn = attn_logits.softmax(dim = - 1)

-        # ENTER SENSITIVE PART: disable TF32 for precision
-
-        switch_tf32(False)
-
-        # aggregate values
-
-        results_scalar = einsum('b i j, b j d -> b i d', attn, v_scalar)
-
-        attn_with_heads = rearrange(attn, '(b h) i j -> b h i j', h = h)
-
-        if require_pairwise_repr:
-            results_pairwise = einsum('b h i j, b i j d -> b h i d', attn_with_heads, pairwise_repr)
-
-        # aggregate point values
-
-        results_points = einsum('b i j, b j d c -> b i d c', attn, v_point)
-
-        # rotate aggregated point values back into local frame
-
-        results_points = einsum('b n d c, b n c r -> b n d r', results_points - translations, rotations.transpose(-1, -2))
-        results_points_norm = torch.sqrt( torch.square(results_points).sum(dim=-1) + eps )
-
-        # EXIT SENSITIVE PART: enable TF32 for speed
-
-        switch_tf32(True)
+        with disable_tf32():
+            # ENTER SENSITIVE PART: disable TF32 for precision
+
+            # aggregate values
+
+            results_scalar = einsum('b i j, b j d -> b i d', attn, v_scalar)
+
+            attn_with_heads = rearrange(attn, '(b h) i j -> b h i j', h = h)
+
+            if require_pairwise_repr:
+                results_pairwise = einsum('b h i j, b i j d -> b h i d', attn_with_heads, pairwise_repr)
+
+            # aggregate point values
+
+            results_points = einsum('b i j, b j d c -> b i d c', attn, v_point)
+
+            # rotate aggregated point values back into local frame
+
+            results_points = einsum('b n d c, b n c r -> b n d r', results_points - translations, rotations.transpose(-1, -2))
+            results_points_norm = torch.sqrt( torch.square(results_points).sum(dim=-1) + eps )

        # merge back heads

2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'invariant-point-attention',
packages = find_packages(),
-version = '0.1.0',
+version = '0.1.1',
license='MIT',
description = 'Invariant Point Attention',
author = 'Phil Wang',
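For reference, a minimal, self-contained sketch of the pattern this commit adopts: a context manager that switches torch.backends.cuda.matmul.allow_tf32 off and restores the previous value on exit. The try/finally and the toy matmul usage below are illustrative additions, not part of the commit.

import torch
from contextlib import contextmanager

@contextmanager
def disable_tf32():
    # save the current TF32 matmul setting, switch it off, then restore it
    orig_value = torch.backends.cuda.matmul.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = False
    try:
        yield
    finally:
        # try/finally (an extra safeguard not in the commit) restores the flag even if the block raises
        torch.backends.cuda.matmul.allow_tf32 = orig_value

# usage: matmuls inside the block accumulate in full fp32 on Ampere GPUs
q, k = torch.randn(8, 64, 32), torch.randn(8, 64, 32)
with disable_tf32():
    attn_logits = q @ k.transpose(-1, -2)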
