[FEA] Use edge_ids directly in uniform sampling call to prevent cost of edge_id lookup #2520

Closed
VibhuJawa opened this issue Aug 9, 2022 · 0 comments · Fixed by #2550

Describe the solution you'd like and any additional context

Currently, 74% of the time in sample_neighbors (gnn/graph_store.py) is spent looking up edge_ids. If we provide edge_ids directly in the uniform sampling call, we can skip that lookup and remove most of this cost.
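
The difference between the two patterns can be sketched in plain Python. This is only an illustration on a made-up edge list, not the cuGraph implementation: `sample_per_seed`, `lookup_after_sampling`, and `carry_ids_through_sampling` are hypothetical helpers, and tuples stand in for cudf dataframes.

```python
import random

# Toy edge list of (src, dst, eid) records. Values are made up for illustration.
EDGES = [(0, 1, 10), (0, 2, 11), (1, 2, 12), (2, 0, 13), (2, 3, 14), (3, 0, 15)]


def sample_per_seed(records, seeds, fanout, rng):
    """Uniformly sample up to `fanout` outgoing records for each seed node."""
    sampled = []
    for seed in seeds:
        neighbors = [r for r in records if r[0] == seed]
        sampled.extend(rng.sample(neighbors, min(fanout, len(neighbors))))
    return sampled


def lookup_after_sampling(edges, seeds, fanout, rng):
    """Pattern from the issue: sample (src, dst) pairs, then join to get eids."""
    pairs = [(s, d) for s, d, _ in edges]
    sampled_pairs = sample_per_seed(pairs, seeds, fanout, rng)
    # The extra step: build a lookup over every edge just to recover the ids.
    eid_of = {(s, d): e for s, d, e in edges}
    return [(s, d, eid_of[(s, d)]) for s, d in sampled_pairs]


def carry_ids_through_sampling(edges, seeds, fanout, rng):
    """Proposed pattern: each sampled record already carries its eid."""
    return sample_per_seed(edges, seeds, fanout, rng)


if __name__ == "__main__":
    print(lookup_after_sampling(EDGES, [0, 2], 1, random.Random(0)))
    print(carry_ids_through_sampling(EDGES, [0, 2], 1, random.Random(0)))
```

Both functions return the same edges when given identically seeded generators. The second one never touches the full edge table after sampling, which is the saving this issue asks for. The lookup version also relies on each (src, dst) pair mapping to a single edge id, which the toy data satisfies.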

Additional Context:
Fixing this will be critical for us to get acceptable performance for upstreaming DGL work.

Benchmarks

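The line profiler output for sample_neighbors shows that the merge on line 259 accounts for 74% of the run time:

   259         1      73755.0  73755.0     74.0          sampled_df = edge_df.merge(sampled_df)

Full profile (columns: Line #, Hits, Time, Per Hit, % Time, Line Contents):

   181                                               def sample_neighbors(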
   182                                                   self, nodes, fanout=-1, edge_dir="in", prob=None, replace=False
   183                                               ):
   184                                                   """
   185                                                   Sample neighboring edges of the given nodes and return the subgraph.
   186                                           
   187                                                   Parameters
   188                                                   ----------
   189                                                   nodes : DLPack capsule of node IDs (single dimension)
   190                                                       Node IDs to sample neighbors from.
   191                                                   fanout : int
   192                                                       The number of edges to be sampled for each node on each edge type.
   193                                                       If -1 is given all the neighboring edges for each node on
   194                                                       each edge type will be selected.
   195                                                   edge_dir : str {"in" or "out"}
   196                                                       Determines whether to sample inbound or outbound edges.
   197                                                       Can take either in for inbound edges or out for outbound edges.
   198                                                   prob : str
   199                                                       Feature name used as the (unnormalized) probabilities associated
   200                                                       with each neighboring edge of a node. Each feature must be a
   201                                                       scalar. The features must be non-negative floats, and the sum of
   202                                                       the features of inbound/outbound edges for every node must be
   203                                                       positive (though they don't have to sum up to one). Otherwise,
   204                                                       the result will be undefined. If not specified, sample uniformly.
   205                                                   replace : bool
   206                                                       If True, sample with replacement.
   207                                           
   208                                                   Returns
   209                                                   -------
   210                                                   DLPack capsule
   211                                                       The src nodes for the sampled bipartite graph.
   212                                                   DLPack capsule
   213                                                       The sampled dst nodes for the sampled bipartite graph.
   214                                                   DLPack capsule
   215                                                       The corresponding eids for the sampled bipartite graph
   216                                                   """
   217                                           
   218         1          2.0      2.0      0.0          if edge_dir not in ["in", "out"]:
   219                                                       raise ValueError(
   220                                                           f"edge_dir must be either 'in' or 'out' got {edge_dir} instead"
   221                                                       )
   222                                           
   223         1          1.0      1.0      0.0          if edge_dir == "in":
   224         1          2.0      2.0      0.0              sg = self.extracted_reverse_subgraph_without_renumbering
   225                                                   else:
   226                                                       sg = self.extracted_subgraph_without_renumbering
   227                                           
   228         1          1.0      1.0      0.0          if not hasattr(self, '_sg_node_dtype'):
   229                                                       self._sg_node_dtype = sg.edgelist.edgelist_df['src'].dtype
   230                                           
   231                                                   # Uniform sampling fails when the seed dtype
   232                                                   # is not the same as the node dtype
   233         1        413.0    413.0      0.4          nodes = cudf.from_dlpack(nodes).astype(self._sg_node_dtype)
   234                                           
   235         2      21310.0  10655.0     21.4          sampled_df = uniform_neighbor_sample(
   236         1          1.0      1.0      0.0              sg, start_list=nodes, fanout_vals=[fanout],
   237         1          0.0      0.0      0.0              with_replacement=replace
   238                                                   )
   239                                           
   240         1        379.0    379.0      0.4          sampled_df.drop(columns=["indices"], inplace=True)
   241                                           
   242                                                   # handle empty graph case
   243         1         13.0     13.0      0.0          if len(sampled_df) == 0:
   244                                                       return None, None, None
   245                                           
   246                                                   # we reverse directions when directions=='in'
   247         1          1.0      1.0      0.0          if edge_dir == "in":
   248         2        177.0     88.5      0.2              sampled_df.rename(
   249         1          1.0      1.0      0.0                  columns={"destinations": src_n, "sources": dst_n}, inplace=True
   250                                                       )
   251                                                   else:
   252                                                       sampled_df.rename(
   253                                                           columns={"sources": src_n, "destinations": dst_n}, inplace=True
   254                                                       )
   255                                           
   256                                                   # FIXME: Remove once below lands
   257                                                   # https://github.com/rapidsai/cugraph/issues/2444
   258         1       1226.0   1226.0      1.2          edge_df = self.gdata._edge_prop_dataframe[[src_n, dst_n, eid_n]]
   259         1      73755.0  73755.0     74.0          sampled_df = edge_df.merge(sampled_df)
   260                                           
   261         1          2.0      2.0      0.0          return (
   262         1        929.0    929.0      0.9              sampled_df[src_n].to_dlpack(),
   263         1        714.0    714.0      0.7              sampled_df[dst_n].to_dlpack(),
   264         1        688.0    688.0      0.7              sampled_df[eid_n].to_dlpack(),
   265                                                   )
BradReesWork added this to the 22.10 milestone Aug 17, 2022
rapids-bot bot pushed a commit that referenced this issue Aug 17, 2022
…e_id lookup (#2550)

This PR fixes #2520

**Speedup Details**
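We see a speedup of about 2.6x on average across the benchmarks in the table below, ranging from 0.8x (a slowdown on some ogbn-arxiv runs) to 10.6x (reddit, fanout 5, 6400 seed nodes).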

**Benchmarking Gist:**
Benchmark Link: https://gist.github.com/VibhuJawa/38da2f151141c0582a0532a364458602

**Benchmarking Table:**

| dataset     | fanout | seednodes | PR cugraph\_t (ms) | Main cugraph\_t (ms) | Speedup      |
| ----------- | ------ | --------- | ------------------ | -------------------- | ------------ |
| livejournal | 5      | 6400      | 9.77469367         | 36.14                | 3.697743722  |
| livejournal | 5      | 12800     | 10.24105188        | 37.04                | 3.617198402  |
| livejournal | 5      | 25600     | 11.25398077        | 39.31                | 3.492790318  |
| livejournal | 5      | 51200     | 19.90233963        | 48.31                | 2.427492542  |
| livejournal | 20     | 6400      | 11.08045933        | 37.40                | 3.375111171  |
| livejournal | 20     | 12800     | 12.41813744        | 39.78                | 3.203001674  |
| livejournal | 20     | 25600     | 20.01964133        | 48.59                | 2.426926934  |
| livejournal | 20     | 51200     | 20.479394          | 51.75                | 2.526783655  |
| livejournal | 40     | 6400      | 18.02444187        | 38.42                | 2.13166189   |
| livejournal | 40     | 12800     | 15.95887286        | 41.13                | 2.577490516  |
| livejournal | 40     | 25600     | 30.42667777        | 49.21                | 1.617178892  |
| livejournal | 40     | 51200     | 31.27987486        | 56.83                | 1.816870032  |
| ogbn-arxiv  | 5      | 6400      | 7.269433069        | 6.81                 | 0.9363815769 |
| ogbn-arxiv  | 5      | 12800     | 3.700939559        | 6.48                 | 1.750559107  |
| ogbn-arxiv  | 5      | 25600     | 7.43439748         | 6.74                 | 0.9070057901 |
| ogbn-arxiv  | 5      | 51200     | 8.364707041        | 8.92                 | 1.06631151   |
| ogbn-arxiv  | 20     | 6400      | 3.526507211        | 6.01                 | 1.704996136  |
| ogbn-arxiv  | 20     | 12800     | 7.11795785         | 6.35                 | 0.8917298112 |
| ogbn-arxiv  | 20     | 25600     | 9.83814247         | 8.87                 | 0.9015745857 |
| ogbn-arxiv  | 20     | 51200     | 19.16898326        | 15.28                | 0.797070347  |
| ogbn-arxiv  | 40     | 6400      | 7.47879348         | 6.11                 | 0.8169812813 |
| ogbn-arxiv  | 40     | 12800     | 8.980390432        | 7.44                 | 0.828701598  |
| ogbn-arxiv  | 40     | 25600     | 9.939847551        | 9.78                 | 0.9838518889 |
| ogbn-arxiv  | 40     | 51200     | 21.65015471        | 17.39                | 0.8032186603 |
| reddit      | 5      | 6400      | 4.485681872        | 47.60                | 10.61118206  |
| reddit      | 5      | 12800     | 8.203881669        | 48.36                | 5.894866842  |
| reddit      | 5      | 25600     | 10.19984847        | 51.61                | 5.05981494   |
| reddit      | 5      | 51200     | 25.52061113        | 61.15                | 2.39617171   |
| reddit      | 20     | 6400      | 9.60336474         | 51.21                | 5.333003796  |
| reddit      | 20     | 12800     | 22.43147231        | 60.14                | 2.681092588  |
| reddit      | 20     | 25600     | 23.204309          | 70.10                | 3.021163687  |
| reddit      | 20     | 51200     | 27.07365799        | 76.18                | 2.813953476  |
| reddit      | 40     | 6400      | 24.64297758        | 60.25                | 2.445081387  |
| reddit      | 40     | 12800     | 23.05950785        | 68.38                | 2.965428975  |
| reddit      | 40     | 25600     | 24.84033842        | 74.12                | 2.983957307  |
| reddit      | 40     | 51200     | 30.75342988        | 87.18                | 2.834787134  |
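
The Speedup column is the main-branch time divided by the PR time, so values below 1 mean the PR was slower for that configuration. A minimal sketch that recomputes it for three rows copied from the table above (the `speedup` helper and the `rows` layout are illustrative choices, not part of the benchmark gist):

```python
# (dataset, fanout, seed nodes, PR time in ms, main time in ms),
# copied from three rows of the benchmarking table.
rows = [
    ("livejournal", 5, 6400, 9.77469367, 36.14),
    ("ogbn-arxiv", 20, 51200, 19.16898326, 15.28),
    ("reddit", 5, 6400, 4.485681872, 47.60),
]


def speedup(pr_ms, main_ms):
    """Ratio of main-branch time to PR time; > 1 means the PR is faster."""
    return main_ms / pr_ms


for dataset, fanout, seeds, pr_ms, main_ms in rows:
    ratio = speedup(pr_ms, main_ms)
    print(f"{dataset:12s} fanout={fanout:<3d} seeds={seeds:<6d} speedup={ratio:.2f}x")
```

Because the main-branch times in the table are rounded to two decimals, the recomputed ratios can differ from the listed Speedup values in the third or fourth significant digit.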

**Bottleneck after the PR**

```python
Timer unit: 1e-06 s

Total time: 0.022579 s
File: /datasets/vjawa/miniconda3/envs/cugraph_dev_aug_10/lib/python3.9/site-packages/cugraph-22.10.0a0+45.g3ff5b53ff.dirty-py3.9-linux-x86_64.egg/cugraph/gnn/graph_store.py
Function: sample_neighbors at line 181

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   181                                               def sample_neighbors(
   182                                                   self, nodes, fanout=-1, edge_dir="in", prob=None, replace=False
   183                                               ):
  ................
   216                                                   """
   217                                           
   218         1          2.0      2.0      0.0          if edge_dir not in ["in", "out"]:
   219                                                       raise ValueError(
   220                                                           f"edge_dir must be either 'in' or 'out' got {edge_dir} instead"
   221                                                       )
   222                                           
   223         1          1.0      1.0      0.0          if edge_dir == "in":
   224         1          1.0      1.0      0.0              sg = self.extracted_reverse_subgraph_without_renumbering
   225                                                   else:
   226                                                       sg = self.extracted_subgraph_without_renumbering
   227                                           
   228         1          1.0      1.0      0.0          if not hasattr(self, '_sg_node_dtype'):
   229                                                       self._sg_node_dtype = sg.edgelist.edgelist_df['src'].dtype
   230                                           
   231                                                   # Uniform sampling fails when the seed dtype
   232                                                   # is not the same as the node dtype
   233         1        774.0    774.0      3.4          nodes = cudf.from_dlpack(nodes).astype(self._sg_node_dtype)
   234                                           
   235         2      19303.0   9651.5     85.5          sampled_df = uniform_neighbor_sample(
   236         1          1.0      1.0      0.0              sg, start_list=nodes, fanout_vals=[fanout],
   237         1          0.0      0.0      0.0              with_replacement=replace,
   238         1          1.0      1.0      0.0              is_edge_ids=True  # FIXME: Does not seem to do anything
   239                                                   )
   240                                           
   241                                                   # handle empty graph case
   242         1         17.0     17.0      0.1          if len(sampled_df) == 0:
   243                                                       return None, None, None
   244                                           
   245                                                   # we reverse directions when directions=='in'
   246         1          1.0      1.0      0.0          if edge_dir == "in":
   247         2        136.0     68.0      0.6              sampled_df.rename(
   248         1          1.0      1.0      0.0                  columns={"destinations": src_n, "sources": dst_n}, inplace=True
   249                                                       )
   250                                                   else:
   251                                                       sampled_df.rename(
   252                                                           columns={"sources": src_n, "destinations": dst_n}, inplace=True
   253                                                       )
   254                                           
   255         1          2.0      2.0      0.0          return (
   256         1        786.0    786.0      3.5              sampled_df[src_n].to_dlpack(),
   257         1        776.0    776.0      3.4              sampled_df[dst_n].to_dlpack(),
   258         1        776.0    776.0      3.4              sampled_df['indices'].to_dlpack(),
   259                                                   )
```
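
The percentages in the profile above can be recomputed from the raw times, which makes it easy to see that the sampling call itself is now the dominant cost. A small sketch using numbers copied from that output; the dictionary labels are shorthand chosen here, and the three `to_dlpack()` lines are summed into one entry.

```python
# "Total time: 0.022579 s" expressed in the profiler's 1e-06 s timer unit.
TOTAL_US = 22579.0

# Time in microseconds per profiled line, copied from the output above.
line_times_us = {
    "cudf.from_dlpack(...).astype(...)": 774.0,
    "uniform_neighbor_sample(...)": 19303.0,
    "sampled_df.rename(...)": 136.0,
    "to_dlpack() x3": 786.0 + 776.0 + 776.0,
}


def percent_of_total(time_us, total_us=TOTAL_US):
    """Share of the function's total time spent on one line, in percent."""
    return 100.0 * time_us / total_us


# Print the lines from most to least expensive.
for label, time_us in sorted(line_times_us.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{percent_of_total(time_us):5.1f}%  {label}")
```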

Authors:
  - Vibhu Jawa (https://github.com/VibhuJawa)

Approvers:
  - Brad Rees (https://github.com/BradReesWork)
  - Rick Ratzel (https://github.com/rlratzel)
  - Alex Barghi (https://github.com/alexbarghi-nv)

URL: #2550