Skip to content

Commit

Permalink
HeteroData.is_undirected() (#4604)
Browse files Browse the repository at this point in the history
* initial commit

* changelog

* update
  • Loading branch information
rusty1s authored May 10, 2022
1 parent cd6c1f7 commit 8fdf895
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 54 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added `HeteroData.is_undirected()` support ([#4604](https://github.com/pyg-team/pytorch_geometric/pull/4604))
- Added the `Genius` and `Wiki` datasets to `nn.datasets.LINKXDataset` ([#4570](https://github.com/pyg-team/pytorch_geometric/pull/4570), [#4600](https://github.com/pyg-team/pytorch_geometric/pull/4600))
- Added `nn.glob.GlobalPooling` module with support for multiple aggregations ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
- Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
- Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
### Changed
- The `bias` argument in `TAGConv` is now actually apllied ([#4597](https://github.com/pyg-team/pytorch_geometric/pull/4597))
- The `bias` argument in `TAGConv` is now actually applied ([#4597](https://github.com/pyg-team/pytorch_geometric/pull/4597))
- Fixed subclass behaviour of `process` and `download` in `Dataset` ([#4586](https://github.com/pyg-team/pytorch_geometric/pull/4586))
### Removed
4 changes: 4 additions & 0 deletions test/transforms/test_to_undirected.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ def test_hetero_to_undirected():
data['v', 'w'].edge_attr = edge_attr

from torch_geometric.transforms import ToUndirected

assert not data.is_undirected()
data = ToUndirected()(data)
assert data.is_undirected()

assert data['v', 'v'].edge_index.tolist() == [[0, 1, 2, 3], [1, 0, 3, 2]]
assert data['v', 'v'].edge_weight.tolist() == edge_weight[perm].tolist()
assert data['v', 'v'].edge_attr.tolist() == edge_attr[perm].tolist()
Expand Down
116 changes: 64 additions & 52 deletions torch_geometric/data/hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch_geometric.data.data import BaseData, Data, size_repr
from torch_geometric.data.storage import BaseStorage, EdgeStorage, NodeStorage
from torch_geometric.typing import EdgeType, NodeType, QueryType
from torch_geometric.utils import is_undirected

NodeOrEdgeType = Union[NodeType, EdgeType]
NodeOrEdgeStorage = Union[NodeStorage, EdgeStorage]
Expand Down Expand Up @@ -295,6 +296,11 @@ def num_edge_features(self) -> Dict[EdgeType, int]:
for key, store in self._edge_store_dict.items()
}

def is_undirected(self) -> bool:
    r"""Returns :obj:`True` if graph edges are undirected.

    The check is performed on the homogeneous representation of the
    graph, so a reverse edge may be stored in a different edge type than
    its forward edge (*e.g.*, :obj:`('a', 'to', 'b')` and
    :obj:`('b', 'to', 'a')`) and still count as undirected.
    """
    edge_index, _, _ = to_homogeneous_edge_index(self)
    # A graph without any edges is trivially undirected; the utility
    # `is_undirected` cannot handle a `None` edge index.
    if edge_index is None:
        return True
    return is_undirected(edge_index, num_nodes=self.num_nodes)

def debug(self):
pass # TODO

Expand Down Expand Up @@ -481,26 +487,14 @@ def _consistent_size(stores: List[BaseStorage]) -> List[str]:
if len(sizes) == len(stores) and len(set(sizes)) == 1
]

data = Data(**self._global_store.to_dict())
edge_index, node_slices, edge_slices = to_homogeneous_edge_index(self)
device = edge_index.device if edge_index is not None else None

# Iterate over all node stores and record the slice information:
node_slices, cumsum = {}, 0
node_type_names, node_types = [], []
for i, (node_type, store) in enumerate(self._node_store_dict.items()):
num_nodes = store.num_nodes
node_slices[node_type] = (cumsum, cumsum + num_nodes)
node_type_names.append(node_type)
cumsum += num_nodes

if add_node_type:
kwargs = {'dtype': torch.long}
node_types.append(torch.full((num_nodes, ), i, **kwargs))
data._node_type_names = node_type_names

if len(node_types) > 1:
data.node_type = torch.cat(node_types, dim=0)
elif len(node_types) == 1:
data.node_type = node_types[0]
data = Data(**self._global_store.to_dict())
if edge_index is not None:
data.edge_index = edge_index
data._node_type_names = list(node_slices.keys())
data._edge_type_names = list(edge_slices.keys())

# Combine node attributes into a single tensor:
if node_attrs is None:
Expand All @@ -511,39 +505,8 @@ def _consistent_size(stores: List[BaseStorage]) -> List[str]:
value = torch.cat(values, dim) if len(values) > 1 else values[0]
data[key] = value

if len([
key for key in node_attrs
if (key in {'x', 'pos', 'batch'} or 'node' in key)
]) == 0 and not add_node_type:
data.num_nodes = cumsum

# Iterate over all edge stores and record the slice information:
edge_slices, cumsum = {}, 0
edge_indices, edge_type_names, edge_types = [], [], []
for i, (edge_type, store) in enumerate(self._edge_store_dict.items()):
src, _, dst = edge_type
num_edges = store.num_edges
edge_slices[edge_type] = (cumsum, cumsum + num_edges)
edge_type_names.append(edge_type)
cumsum += num_edges

kwargs = {'dtype': torch.long, 'device': store.edge_index.device}
offset = [[node_slices[src][0]], [node_slices[dst][0]]]
offset = torch.tensor(offset, **kwargs)
edge_indices.append(store.edge_index + offset)
if add_edge_type:
edge_types.append(torch.full((num_edges, ), i, **kwargs))
data._edge_type_names = edge_type_names

if len(edge_indices) > 1:
data.edge_index = torch.cat(edge_indices, dim=-1)
elif len(edge_indices) == 1:
data.edge_index = edge_indices[0]

if len(edge_types) > 1:
data.edge_type = torch.cat(edge_types, dim=0)
elif len(edge_types) == 1:
data.edge_type = edge_types[0]
if not data.can_infer_num_nodes:
data.num_nodes = list(node_slices.values())[-1][1]

# Combine edge attributes into a single tensor:
if edge_attrs is None:
Expand All @@ -554,4 +517,53 @@ def _consistent_size(stores: List[BaseStorage]) -> List[str]:
value = torch.cat(values, dim) if len(values) > 1 else values[0]
data[key] = value

if add_node_type:
sizes = [offset[1] - offset[0] for offset in node_slices.values()]
sizes = torch.tensor(sizes, dtype=torch.long, device=device)
node_type = torch.arange(len(sizes), device=device)
data.node_type = node_type.repeat_interleave(sizes)

if add_edge_type and edge_index is not None:
sizes = [offset[1] - offset[0] for offset in edge_slices.values()]
sizes = torch.tensor(sizes, dtype=torch.long, device=device)
edge_type = torch.arange(len(sizes), device=device)
data.edge_type = edge_type.repeat_interleave(sizes)

return data


# Helper functions ############################################################


def to_homogeneous_edge_index(
    data: HeteroData,
) -> Tuple[Optional[Tensor], Dict[NodeType, Any], Dict[EdgeType, Any]]:
    """Maps the edge indices of all edge types of :obj:`data` into a single
    homogeneous node space, returning the concatenated edge index together
    with the per-type node and edge slice ranges."""
    # Assign each node type a contiguous (start, end) range in the
    # homogeneous node space:
    node_slices: Dict[NodeType, Tuple[int, int]] = {}
    start = 0
    for node_type, store in data._node_store_dict.items():
        end = start + store.num_nodes
        node_slices[node_type] = (start, end)
        start = end

    # Shift every edge type's `edge_index` by its endpoints' node offsets
    # and record its (start, end) range in the concatenated edge dimension:
    shifted: List[Tensor] = []
    edge_slices: Dict[EdgeType, Tuple[int, int]] = {}
    start = 0
    for edge_type, store in data._edge_store_dict.items():
        src, _, dst = edge_type
        shift = torch.tensor(
            [[node_slices[src][0]], [node_slices[dst][0]]],
            device=store.edge_index.device,
        )
        shifted.append(store.edge_index + shift)

        end = start + store.num_edges
        edge_slices[edge_type] = (start, end)
        start = end

    if not shifted:
        edge_index = None
    elif len(shifted) == 1:
        # Single edge type: reuse the shifted tensor to avoid a copy.
        edge_index = shifted[0]
    else:
        edge_index = torch.cat(shifted, dim=-1)

    return edge_index, node_slices, edge_slices
2 changes: 1 addition & 1 deletion torch_geometric/data/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def has_self_loops(self) -> bool:
return int((edge_index[0] == edge_index[1]).sum()) > 0

def is_undirected(self) -> bool:
if self.is_bipartite(): # TODO check for inverse storage.
if self.is_bipartite():
return False

for value in self.values('adj', 'adj_t'):
Expand Down

0 comments on commit 8fdf895

Please sign in to comment.