Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create cugraph-equivariant package #4036

Merged
merged 27 commits into from
Jan 29, 2024

Conversation

tingyu66
Copy link
Member

@tingyu66 tingyu66 commented Dec 4, 2023

Bring up cugraph-equivariant package and add TensorProduct conv layers.

@tingyu66 tingyu66 added this to the 24.02 milestone Dec 4, 2023
@tingyu66 tingyu66 self-assigned this Dec 4, 2023
@tingyu66 tingyu66 added improvement Improvement / enhancement to an existing function non-breaking Non-breaking change labels Jan 12, 2024
@tingyu66 tingyu66 requested a review from stadlmax January 19, 2024 14:47
@tingyu66
Copy link
Member Author

@stadlmax I've updated the wrapper based on the new fused_tp kernels. Can you take a look? Tests won't pass until we have the new kernels reflected in pylibcugraphops nightly.

Copy link

@mariogeiger mariogeiger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a problem if mlp_fast_first_layer=False but src_scalars and dst_scalars are provided.

(optional) It would be nice if we can remove that mlp_fast_first_layer and simply do the thing when src_scalars and dst_scalars are provided.

batch_norm: bool = True,
mlp_channels: Optional[Sequence[int]] = None,
mlp_activation: Optional[Callable[..., nn.Module]] = nn.GELU,
mlp_fast_first_layer: bool = False,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this argument need to be here? Can't the special just happen when src_scalars and dst_scalars are provided?

Copy link

@DejunL DejunL Jan 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with Mario here. There are a few branches when it comes to computing the tp weights:

  1. user direct input the precomputed weights to forward
  2. feed only the edge_emb to the mlp to compute the weight. This could include the case that the node embeddings are already concatenated into the edge_emb, which we can't distinguish
  3. As in point 2 but additionally have one node embedding tensor, i.e., the graph is non-directed
  4. As in point 2 but additionally have separate src and dst node embedding tensors

Edit: we already handled cases 1, 2 and 4, despite some argument against the need for mlp_fast_first_layer in the init(). Do we want to handle case 3 explicitly in the API? i.e., if only src_scalars but not dst_scalars is given, do we want to assume the dst embedding are also index from src_scalars or do we always want the user to supply the same tensor to both src_scalars=node_attr, dst_scalars=node_attr? I guess it can't hurt to do the latter so it's up to you

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated the code to support arbitrary number of scalar features in [src, dst]_scalars. They can also be None if needed.

Regarding use case 3, users should input src_scalars=node_attr, dst_scalars=node_attr.


from .tensor_product_conv import FullyConnectedTensorProductConv

DiffDockTensorProductConv = FullyConnectedTensorProductConv
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this alias?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is from one of our discussions on Monday, but it's totally optional. @mariogeiger Do you think we need the alias?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for me, the name FullyConnectedTensorProductConv is perfect by its own

\sum_{b \in \mathcal{N}_a} Y\left(\hat{r}_{a b}\right)
\otimes_{\psi_{a b}} \mathbf{h}_b

where the path weights :math:`\psi_{a b}` are from user input to the forward()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the Diffdock paper, 1) \psi is itself the MLP for computing the weights from the edge and node embeddings. 2) This implementation has the option to either directly input the weights to the forward() or compute the weights using the MLP (\psi) from the edge and/or the node embeddings

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

\psi denotes the weights while \Psi is the MLP. I have added a brief note to show the two options here

mlp = []
for i in range(len(dims) - 1):
mlp.append(nn.Linear(dims[i], dims[i + 1]))
if mlp_activation is not None and i != len(dims) - 2:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When would mlp_activation be None?

@tingyu66
Copy link
Member Author

@mariogeiger @DejunL
I referred to the diffdock repo when implementing it and thought edge, src and dst sharing the same num_scalers would the only use case . If we do need support for various num_scalers for different component, I can make the change accordingly.

I will also remove the mlp_fast_first_layer flag, since it is only used to validate the dimensionality.

in_irreps: o3.Irreps,
sh_irreps: o3.Irreps,
out_irreps: o3.Irreps,
batch_norm: bool = True,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the Diffdock training code, we have customized batch_norm function. Instead of being a boolean option, can this take a callable that defaults to e3nn.nn.BatchNorm?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Will make the change.

Copy link
Member Author

@tingyu66 tingyu66 Jan 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't come up with a neat way to support this because if we change the argument to be a callable (with the default value =Batchnorm), the customized batch_norm function must have the same function signature as e3nn.nn.Batchnorm. With that, I would suggest applying customized function outside of the conv layer.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both the e3nn.nn.BatchNorm and the modified version we used are nn.module so I should have suggested change the type of the batch_norm argument to nn.module. But then it becomes a moot point since we can apply the BatchNorm outside of the conv layer like you said. Let me double check with Guoqing, who authored the modified version of BatchNorm

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just confirmed with Guoqing that the current API is OK. So it's up to you if you want to keep it as it is

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it make more sense to just not have BN in the API then?

@DejunL
Copy link

DejunL commented Jan 22, 2024

@mariogeiger @DejunL I referred to the diffdock repo when implementing it and thought edge, src and dst sharing the same num_scalers would the only use case . If we do need support for various num_scalers for different component, I can make the change accordingly.

I will also remove the mlp_fast_first_layer flag, since it is only used to validate the dimensionality.

Yes, for now in the Diffdock model, edge_emb and src_scalars are really scalars. But in general, we don't want to constrain them to be scalar or of same dimensionality. Thank you!

@tingyu66 tingyu66 marked this pull request as ready for review January 23, 2024 19:32
@tingyu66 tingyu66 requested review from a team as code owners January 23, 2024 19:32
@tingyu66 tingyu66 changed the title [DRAFT] Create cugraph-equivariant package Create cugraph-equivariant package Jan 23, 2024
--channel pytorch \
--channel nvidia \
cugraph-equivariant
pip install e3nn==0.5.1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure, but would this be better off as a dependency in the py_test_cugraph_equivariant section of the dependencies.yaml file?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally yes, but e3nn depends on PyTorch. Having that in pyproject.toml might pull wrong versions of pytorch for users. cugraph-dgl and -pyg's pyproject.toml does not have pytorch either (I guess for the same reason).

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tingyu66 do you want me to remove pytorch from the dependencies of e3nn?

@tingyu66 tingyu66 requested a review from a team January 26, 2024 17:27
Copy link
Contributor

@stadlmax stadlmax left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

besides the minor comment w.r.t. having BN in the API, no further comments

Copy link

@mariogeiger mariogeiger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same, looks good except that BN that I'm not super super happy about

@BradReesWork
Copy link
Member

/merge

@rapids-bot rapids-bot bot merged commit 3ff2abd into rapidsai:branch-24.02 Jan 29, 2024
109 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ci conda improvement Improvement / enhancement to an existing function non-breaking Non-breaking change python
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants