diff --git a/CHANGELOG.md b/CHANGELOG.md index aa11b32c5ffe..877a367b28b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,11 +5,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.0.5] - 2022-MM-DD ### Added +- Added `HeteroData.is_undirected()` support ([#4604](https://github.com/pyg-team/pytorch_geometric/pull/4604)) - Added the `Genius` and `Wiki` datasets to `nn.datasets.LINKXDataset` ([#4570](https://github.com/pyg-team/pytorch_geometric/pull/4570), [#4600](https://github.com/pyg-team/pytorch_geometric/pull/4600)) - Added `nn.glob.GlobalPooling` module with support for multiple aggregations ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582)) - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582)) - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581)) ### Changed -- The `bias` argument in `TAGConv` is now actually apllied ([#4597](https://github.com/pyg-team/pytorch_geometric/pull/4597)) +- The `bias` argument in `TAGConv` is now actually applied ([#4597](https://github.com/pyg-team/pytorch_geometric/pull/4597)) - Fixed subclass behaviour of `process` and `download` in `Datsaet` ([#4586](https://github.com/pyg-team/pytorch_geometric/pull/4586)) ### Removed diff --git a/test/transforms/test_to_undirected.py b/test/transforms/test_to_undirected.py index 167f6535bf8c..3e7070b5c3c1 100644 --- a/test/transforms/test_to_undirected.py +++ b/test/transforms/test_to_undirected.py @@ -53,7 +53,11 @@ def test_hetero_to_undirected(): data['v', 'w'].edge_attr = edge_attr from torch_geometric.transforms import ToUndirected + + assert not data.is_undirected() data = ToUndirected()(data) + assert data.is_undirected() + assert data['v', 'v'].edge_index.tolist() == [[0, 1, 2, 3], [1, 0, 3, 2]] assert data['v', 'v'].edge_weight.tolist() == edge_weight[perm].tolist() assert data['v', 'v'].edge_attr.tolist() == edge_attr[perm].tolist() diff --git a/torch_geometric/data/hetero_data.py b/torch_geometric/data/hetero_data.py index 41d011248840..bf52ad2b2c1b 100644 --- a/torch_geometric/data/hetero_data.py +++ b/torch_geometric/data/hetero_data.py @@ -12,6 +12,7 @@ from torch_geometric.data.data import BaseData, Data, size_repr from torch_geometric.data.storage import BaseStorage, EdgeStorage, NodeStorage from torch_geometric.typing import EdgeType, NodeType, QueryType +from torch_geometric.utils import is_undirected NodeOrEdgeType = Union[NodeType, EdgeType] NodeOrEdgeStorage = Union[NodeStorage, EdgeStorage] @@ -295,6 +296,11 @@ def num_edge_features(self) -> Dict[EdgeType, int]: for key, store in self._edge_store_dict.items() } + def is_undirected(self) -> bool: + r"""Returns :obj:`True` if graph edges are undirected.""" + edge_index, _, _ = to_homogeneous_edge_index(self) + return is_undirected(edge_index, num_nodes=self.num_nodes) + def debug(self): pass # TODO @@ -481,26 +487,14 @@ def _consistent_size(stores: List[BaseStorage]) -> List[str]: if len(sizes) == len(stores) and len(set(sizes)) == 1 ] - data = Data(**self._global_store.to_dict()) + edge_index, node_slices, edge_slices = to_homogeneous_edge_index(self) + device = edge_index.device if edge_index is not None else None - # Iterate over all node stores and record the slice information: - node_slices, cumsum = {}, 0 - node_type_names, node_types = [], [] - for i, (node_type, store) in 
enumerate(self._node_store_dict.items()): - num_nodes = store.num_nodes - node_slices[node_type] = (cumsum, cumsum + num_nodes) - node_type_names.append(node_type) - cumsum += num_nodes - - if add_node_type: - kwargs = {'dtype': torch.long} - node_types.append(torch.full((num_nodes, ), i, **kwargs)) - data._node_type_names = node_type_names - - if len(node_types) > 1: - data.node_type = torch.cat(node_types, dim=0) - elif len(node_types) == 1: - data.node_type = node_types[0] + data = Data(**self._global_store.to_dict()) + if edge_index is not None: + data.edge_index = edge_index + data._node_type_names = list(node_slices.keys()) + data._edge_type_names = list(edge_slices.keys()) # Combine node attributes into a single tensor: if node_attrs is None: @@ -511,39 +505,8 @@ def _consistent_size(stores: List[BaseStorage]) -> List[str]: value = torch.cat(values, dim) if len(values) > 1 else values[0] data[key] = value - if len([ - key for key in node_attrs - if (key in {'x', 'pos', 'batch'} or 'node' in key) - ]) == 0 and not add_node_type: - data.num_nodes = cumsum - - # Iterate over all edge stores and record the slice information: - edge_slices, cumsum = {}, 0 - edge_indices, edge_type_names, edge_types = [], [], [] - for i, (edge_type, store) in enumerate(self._edge_store_dict.items()): - src, _, dst = edge_type - num_edges = store.num_edges - edge_slices[edge_type] = (cumsum, cumsum + num_edges) - edge_type_names.append(edge_type) - cumsum += num_edges - - kwargs = {'dtype': torch.long, 'device': store.edge_index.device} - offset = [[node_slices[src][0]], [node_slices[dst][0]]] - offset = torch.tensor(offset, **kwargs) - edge_indices.append(store.edge_index + offset) - if add_edge_type: - edge_types.append(torch.full((num_edges, ), i, **kwargs)) - data._edge_type_names = edge_type_names - - if len(edge_indices) > 1: - data.edge_index = torch.cat(edge_indices, dim=-1) - elif len(edge_indices) == 1: - data.edge_index = edge_indices[0] - - if len(edge_types) > 1: - data.edge_type = torch.cat(edge_types, dim=0) - elif len(edge_types) == 1: - data.edge_type = edge_types[0] + if not data.can_infer_num_nodes: + data.num_nodes = list(node_slices.values())[-1][1] # Combine edge attributes into a single tensor: if edge_attrs is None: @@ -554,4 +517,53 @@ def _consistent_size(stores: List[BaseStorage]) -> List[str]: value = torch.cat(values, dim) if len(values) > 1 else values[0] data[key] = value + if add_node_type: + sizes = [offset[1] - offset[0] for offset in node_slices.values()] + sizes = torch.tensor(sizes, dtype=torch.long, device=device) + node_type = torch.arange(len(sizes), device=device) + data.node_type = node_type.repeat_interleave(sizes) + + if add_edge_type and edge_index is not None: + sizes = [offset[1] - offset[0] for offset in edge_slices.values()] + sizes = torch.tensor(sizes, dtype=torch.long, device=device) + edge_type = torch.arange(len(sizes), device=device) + data.edge_type = edge_type.repeat_interleave(sizes) + return data + + +# Helper functions ############################################################ + + +def to_homogeneous_edge_index( + data: HeteroData, +) -> Tuple[Optional[Tensor], Dict[NodeType, Any], Dict[EdgeType, Any]]: + # Record slice information per node type: + cumsum = 0 + node_slices: Dict[NodeType, Tuple[int, int]] = {} + for node_type, store in data._node_store_dict.items(): + num_nodes = store.num_nodes + node_slices[node_type] = (cumsum, cumsum + num_nodes) + cumsum += num_nodes + + # Record edge indices and slice information per edge type: + cumsum = 
0 + edge_indices: List[Tensor] = [] + edge_slices: Dict[EdgeType, Tuple[int, int]] = {} + for edge_type, store in data._edge_store_dict.items(): + src, _, dst = edge_type + offset = [[node_slices[src][0]], [node_slices[dst][0]]] + offset = torch.tensor(offset, device=store.edge_index.device) + edge_indices.append(store.edge_index + offset) + + num_edges = store.num_edges + edge_slices[edge_type] = (cumsum, cumsum + num_edges) + cumsum += num_edges + + edge_index = None + if len(edge_indices) == 1: # Memory-efficient `torch.cat`: + edge_index = edge_indices[0] + elif len(edge_indices) > 0: + edge_index = torch.cat(edge_indices, dim=-1) + + return edge_index, node_slices, edge_slices diff --git a/torch_geometric/data/storage.py b/torch_geometric/data/storage.py index b3469eeab3a3..1b51571f277b 100644 --- a/torch_geometric/data/storage.py +++ b/torch_geometric/data/storage.py @@ -427,7 +427,7 @@ def has_self_loops(self) -> bool: return int((edge_index[0] == edge_index[1]).sum()) > 0 def is_undirected(self) -> bool: - if self.is_bipartite(): # TODO check for inverse storage. + if self.is_bipartite(): return False for value in self.values('adj', 'adj_t'):
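
For context, here is a minimal usage sketch of the `HeteroData.is_undirected()` check this diff introduces, mirroring the updated test above. The node features, node counts, and edge indices below are illustrative assumptions, not values taken from the PR.

```python
import torch
from torch_geometric.data import HeteroData
from torch_geometric.transforms import ToUndirected

# Toy heterogeneous graph (illustrative values only):
data = HeteroData()
data['paper'].x = torch.randn(4, 16)
data['author'].x = torch.randn(3, 16)
data['paper', 'cites', 'paper'].edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
data['author', 'writes', 'paper'].edge_index = torch.tensor([[0, 1, 2], [0, 2, 3]])

print(data.is_undirected())  # False: reverse edges are missing.

data = ToUndirected()(data)  # Adds reverse edges / reverse edge types.
print(data.is_undirected())  # True: every edge now has a matching reverse.
```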
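
The refactored `to_homogeneous()` now derives the `node_type` and `edge_type` vectors by repeating each type id over the size of its slice. A small standalone sketch of that `repeat_interleave` pattern, using hypothetical slice values rather than ones from the PR:

```python
import torch

# Hypothetical per-type node slices, e.g. 4 'paper' nodes and 3 'author' nodes:
node_slices = {'paper': (0, 4), 'author': (4, 7)}

# Repeat each type id once per node in its slice:
sizes = torch.tensor([hi - lo for lo, hi in node_slices.values()], dtype=torch.long)
node_type = torch.arange(len(sizes)).repeat_interleave(sizes)
print(node_type)  # tensor([0, 0, 0, 0, 1, 1, 1])
```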