Merge branch 'master' into padarn/linkx-datasets
rusty1s authored May 6, 2022
2 parents 7046efd + 926b5dc commit 631fdd7
Showing 5 changed files with 78 additions and 11 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -10,4 +10,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
- Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
### Changed
- The `bias` argument in `TAGConv` is now actually applied ([#4597](https://github.com/pyg-team/pytorch_geometric/pull/4597))
- Fixed subclass behaviour of `process` and `download` in `Dataset` ([#4586](https://github.com/pyg-team/pytorch_geometric/pull/4586))
### Removed
61 changes: 60 additions & 1 deletion test/data/test_dataset.py
@@ -1,6 +1,6 @@
import torch

from torch_geometric.data import Data, HeteroData, InMemoryDataset
from torch_geometric.data import Data, Dataset, HeteroData, InMemoryDataset


class MyTestDataset(InMemoryDataset):
@@ -99,3 +99,62 @@ def test_hetero_in_memory_dataset():
assert dataset[1]['paper'].x.tolist() == data2['paper'].x.tolist()
assert (dataset[1]['paper', 'paper'].edge_index.tolist() == data2[
'paper', 'paper'].edge_index.tolist())


def test_override_behaviour():
class DS(Dataset):
def __init__(self):
self.enter_download = False
self.enter_process = False
super().__init__()

def _download(self):
self.enter_download = True

def _process(self):
self.enter_process = True

def download(self):
pass

def process(self):
pass

class DS2(Dataset):
def __init__(self):
self.enter_download = False
self.enter_process = False
super().__init__()

def _download(self):
self.enter_download = True

def _process(self):
self.enter_process = True

def process(self):
pass

class DS3(Dataset):
def __init__(self):
self.enter_download = False
self.enter_process = False
super().__init__()

def _download(self):
self.enter_download = True

def _process(self):
self.enter_process = True

ds = DS()
assert ds.enter_download
assert ds.enter_process

ds = DS2()
assert not ds.enter_download
assert ds.enter_process

ds = DS3()
assert not ds.enter_download
assert not ds.enter_process
4 changes: 2 additions & 2 deletions torch_geometric/data/dataset.py
@@ -80,10 +80,10 @@ def __init__(self, root: Optional[str] = None,
self.pre_filter = pre_filter
self._indices: Optional[Sequence] = None

if 'download' in self.__class__.__dict__:
if self.download.__qualname__.split('.')[0] != 'Dataset':
self._download()

if 'process' in self.__class__.__dict__:
if self.process.__qualname__.split('.')[0] != 'Dataset':
self._process()

def indices(self) -> Sequence:
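For context on the check above: resolving `self.download.__qualname__` means an override inherited from an intermediate subclass still triggers `_download()`, whereas the previous `'download' in self.__class__.__dict__` test only looked at the most derived class. A minimal standalone sketch of the idea, using hypothetical `Base`/`Child` classes rather than the library code:

```python
# Standalone sketch (hypothetical Base/Child classes, not the library code)
# of how the __qualname__ check detects overrides, including inherited ones.
class Base:
    def __init__(self):
        # Run the hook only if `download` was defined somewhere below `Base`.
        if self.download.__qualname__.split('.')[0] != 'Base':
            print('download overridden -> hook runs')

    def download(self):
        pass


class Child(Base):
    def download(self):
        pass


class GrandChild(Child):   # defines nothing itself, but inherits Child.download
    pass


Base()        # no output: download comes from Base itself
Child()       # hook runs
GrandChild()  # hook runs too; the old `__class__.__dict__` test would miss this
```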
6 changes: 0 additions & 6 deletions torch_geometric/data/in_memory_dataset.py
@@ -43,12 +43,6 @@ def raw_file_names(self) -> Union[str, List[str], Tuple]:
def processed_file_names(self) -> Union[str, List[str], Tuple]:
raise NotImplementedError

def download(self):
raise NotImplementedError

def process(self):
raise NotImplementedError

def __init__(self, root: Optional[str] = None,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
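Removing these stubs ties in with the `__qualname__` check above: if `InMemoryDataset` still defined `download` and `process`, their qualified names would start with `InMemoryDataset` rather than `Dataset`, and `_download()`/`_process()` would be triggered for every in-memory dataset, whether or not a user overrode them. A rough illustration with hypothetical classes (not the library code):

```python
# Rough illustration (hypothetical classes, not the library code) of why the
# base-class stubs had to go once the __qualname__ check was introduced.
class Dataset:
    def __init__(self):
        self.download_triggered = False
        # Stand-in for the check in torch_geometric.data.Dataset.__init__:
        if self.download.__qualname__.split('.')[0] != 'Dataset':
            self.download_triggered = True  # stands in for self._download()

    def download(self):
        pass


class WithStub(Dataset):          # like InMemoryDataset before this diff
    def download(self):
        raise NotImplementedError


class WithoutStub(Dataset):       # like InMemoryDataset after this diff
    pass


print(WithStub().download_triggered)     # True  -> download would be forced
print(WithoutStub().download_triggered)  # False -> only real overrides trigger it
```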
16 changes: 14 additions & 2 deletions torch_geometric/nn/conv/tag_conv.py
@@ -5,6 +5,7 @@
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import Adj, OptTensor


@@ -51,14 +52,21 @@ def __init__(self, in_channels: int, out_channels: int, K: int = 3,
self.K = K
self.normalize = normalize

self.lins = torch.nn.ModuleList(
[Linear(in_channels, out_channels) for _ in range(K + 1)])
self.lins = torch.nn.ModuleList([
Linear(in_channels, out_channels, bias=False) for _ in range(K + 1)
])

if bias:
self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)

self.reset_parameters()

def reset_parameters(self):
for lin in self.lins:
lin.reset_parameters()
zeros(self.bias)

def forward(self, x: Tensor, edge_index: Adj,
edge_weight: OptTensor = None) -> Tensor:
@@ -80,6 +88,10 @@ def forward(self, x: Tensor, edge_index: Adj,
x = self.propagate(edge_index, x=x, edge_weight=edge_weight,
size=None)
out += lin.forward(x)

if self.bias is not None:
out += self.bias

return out

def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
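With this change, `TAGConv(..., bias=True)` holds a single shared `bias` parameter that is added once to the summed per-hop outputs, while the per-hop `Linear` layers are created with `bias=False`. A quick sanity-check sketch, assuming the post-merge API:

```python
# Quick sanity check of the bias handling after this change
# (a sketch assuming the post-merge API, not part of the diff).
import torch
from torch_geometric.nn import TAGConv

conv = TAGConv(in_channels=16, out_channels=32, K=3, bias=True)
assert conv.bias is not None and conv.bias.numel() == 32   # single shared bias
assert all(lin.bias is None for lin in conv.lins)          # per-hop linears carry no bias

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
out = conv(x, edge_index)
assert out.shape == (4, 32)

# With bias=False the parameter is registered as None and nothing extra is added.
assert TAGConv(16, 32, K=3, bias=False).bias is None
```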
