Skip to content

Commit

Permalink
Merge pull request #64 from SherylHYX/K_convention
Browse files Browse the repository at this point in the history
K_convention to represent cheb polynomial order
  • Loading branch information
SherylHYX authored Feb 9, 2025
2 parents 10f54d8 + 86f9f33 commit 029ed9f
Show file tree
Hide file tree
Showing 10 changed files with 26 additions and 30 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
python-version: [3.8, 3.9]
os: [ubuntu-latest]
torch-version: [2.3.0]
include:
Expand Down Expand Up @@ -51,7 +51,7 @@ jobs:
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test coverage.
run: |
python setup.py test
pytest
- name: Run codecov
if: success()
env:
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ pip install torch-geometric-signed-directed
**Running tests**

```
$ python setup.py test
$ pytest
```
--------------------------------------------------------------------------------

Expand Down
5 changes: 1 addition & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
[metadata]
description-file = README.md

[aliases]
test=pytest
description_file = README.md

[tool:pytest]
addopts = --capture=no --cov
9 changes: 4 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
"scipy"
]

setup_requires = ["pytest-runner"]

tests_require = ["pytest", "pytest-cov", "mock"]
extras_require = {
"test": ["pytest", "pytest-cov", "mock"]
}

keywords = [
"machine-learning",
Expand Down Expand Up @@ -53,8 +53,7 @@
download_url='{}/archive/{}.tar.gz'.format(url, __version__),
keywords=keywords,
install_requires=install_requires,
setup_requires=setup_requires,
tests_require=tests_require,
extras_require=extras_require,
python_requires=">=3.7",
classifiers=[
"Intended Audience :: Developers",
Expand Down
8 changes: 4 additions & 4 deletions test/directed_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_MagNet():
num_nodes, num_classes
)

model = MagNet_node_classification(X.shape[1], K=3, label_dim=num_classes, layer=3, trainable_q=True,
model = MagNet_node_classification(X.shape[1], K=2, label_dim=num_classes, layer=3, trainable_q=True,
activation=True, hidden=2, dropout=0.5, cached=True).to(device)
preds = model(X, X, edge_index, edge_weight)

Expand All @@ -71,7 +71,7 @@ def test_MagNet():
num_nodes, num_classes
)
assert model.Chebs[0].__repr__(
) == 'MagNetConv(3, 2, K=3, normalization=sym)'
) == 'MagNetConv(3, 2, filter size=3, normalization=sym)'

model.reset_parameters()

Expand All @@ -98,7 +98,7 @@ def test_MagNet_Link():
len(link_data[0]['train']['edges']), num_classes
)

model = MagNet_link_prediction(data.x.shape[1], K=3, label_dim=num_classes, layer=3, trainable_q=True,
model = MagNet_link_prediction(data.x.shape[1], K=2, label_dim=num_classes, layer=3, trainable_q=True,
activation=True, hidden=2, dropout=0.5).to(device)
preds = model(data.x, data.x, link_data[0]['graph'], query_edges=link_data[0]['train']['edges'],
edge_weight=link_data[0]['weights'])
Expand All @@ -107,7 +107,7 @@ def test_MagNet_Link():
len(link_data[0]['train']['edges']), num_classes
)
assert model.Chebs[0].__repr__(
) == 'MagNetConv(3, 2, K=3, normalization=sym)'
) == 'MagNetConv(3, 2, filter size=3, normalization=sym)'

num_classes = 3
link_data = link_class_split(
Expand Down
8 changes: 4 additions & 4 deletions test/general_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_MSGNN():
num_nodes, num_classes
)

model = MSGNN_node_classification(q=0.25, K=3, num_features=X.shape[1], hidden=2, label_dim=num_classes,
model = MSGNN_node_classification(q=0.25, K=2, num_features=X.shape[1], hidden=2, label_dim=num_classes,
dropout=0.5, cached=True, normalization=None).to(device)
_, _, _, preds = model(X, X, edge_index=edge_index,
edge_weight=edge_weight)
Expand All @@ -150,7 +150,7 @@ def test_MSGNN():
num_nodes, num_classes
)
assert model.Chebs[0].__repr__(
) == 'MSConv(3, 2, K=3, normalization=None)'
) == 'MSConv(3, 2, filter size=3, normalization=None)'

model.reset_parameters()

Expand All @@ -167,7 +167,7 @@ def test_MSGNN_Link():
data = SignedData(x=X, edge_index=edge_index, edge_weight=edge_weight)
link_data = link_class_split(data, splits=2, task="four_class_signed_digraph", prob_val=0.15, prob_test=0.1, seed=10, device=device)

model = MSGNN_link_prediction(q=0.25, K=3, num_features=num_features, hidden=2, label_dim=num_classes, \
model = MSGNN_link_prediction(q=0.25, K=2, num_features=num_features, hidden=2, label_dim=num_classes, \
trainable_q = False, dropout=0.5, cached=True).to(device)
preds = model(data.x, data.x, edge_index=link_data[0]['graph'], query_edges=link_data[0]['train']['edges'],
edge_weight=link_data[0]['weights'])
Expand All @@ -183,7 +183,7 @@ def test_MSGNN_Link():
len(link_data[0]['train']['edges']), num_classes
)
assert model.Chebs[0].__repr__(
) == 'MSConv(3, 2, K=3, normalization=sym)'
) == 'MSConv(3, 2, filter size=3, normalization=sym)'

num_classes = 5
link_data = link_class_split(data, splits=2, task="five_class_signed_digraph", prob_val=0.15, prob_test=0.1, seed=10, device=device)
Expand Down
6 changes: 3 additions & 3 deletions torch_geometric_signed_directed/nn/directed/MagNetConv.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class MagNetConv(MessagePassing):
Args:
in_channels (int): Size of each input sample.
out_channels (int): Size of each output sample.
K (int): Order of the Chebyshev polynomial plus 1, i.e., Chebyshev filter size :math:`K`.
K (int): Order of the Chebyshev polynomial, i.e., Chebyshev filter size minus 1 :math:`K`.
q (float, optional): Initial value of the phase parameter, 0 <= q <= 0.25. Default: 0.25.
trainable_q (bool, optional): whether to set q to be trainable or not. (default: :obj:`False`)
normalization (str, optional): The normalization scheme for the magnetic
Expand Down Expand Up @@ -59,7 +59,7 @@ def __init__(self, in_channels: int, out_channels: int, K: int, q: float, traina
self.q = Parameter(torch.Tensor(1).fill_(q))
else:
self.q = q
self.weight = Parameter(torch.Tensor(K, in_channels, out_channels))
self.weight = Parameter(torch.Tensor(K+1, in_channels, out_channels))

if bias:
self.bias = Parameter(torch.Tensor(out_channels))
Expand Down Expand Up @@ -252,6 +252,6 @@ def message(self, x_j, norm):
return norm.view(-1, 1) * x_j

def __repr__(self):
return '{}({}, {}, K={}, normalization={})'.format(
return '{}({}, {}, filter size={}, normalization={})'.format(
self.__class__.__name__, self.in_channels, self.out_channels,
self.weight.size(0), self.normalization)
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class MagNet_link_prediction(nn.Module):
Args:
num_features (int): Size of each input sample.
hidden (int, optional): Number of hidden channels. Default: 2.
K (int, optional): Order of the Chebyshev polynomial plus 1, i.e., Chebyshev filter size :math:`K`. Default: 2.
K (int, optional): Order of the Chebyshev polynomial, i.e., Chebyshev filter size minus 1 :math:`K`. Default: 1.
q (float, optional): Initial value of the phase parameter, 0 <= q <= 0.25. Default: 0.25.
label_dim (int, optional): Number of output classes. Default: 2.
activation (bool, optional): whether to use activation function or not. (default: :obj:`True`)
Expand All @@ -36,7 +36,7 @@ class MagNet_link_prediction(nn.Module):
learning scenarios. (default: :obj:`False`)
"""

def __init__(self, num_features: int, hidden: int = 2, q: float = 0.25, K: int = 2, label_dim: int = 2,
def __init__(self, num_features: int, hidden: int = 2, q: float = 0.25, K: int = 1, label_dim: int = 2,
activation: bool = True, trainable_q: bool = False, layer: int = 2, dropout: float = 0.5, normalization: str = 'sym', cached: bool = False):
super(MagNet_link_prediction, self).__init__()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class MagNet_node_classification(nn.Module):
Args:
num_features (int): Size of each input sample.
hidden (int, optional): Number of hidden channels. Default: 2.
K (int, optional): Order of the Chebyshev polynomial plus 1, i.e., Chebyshev filter size :math:`K`. Default: 2.
K (int, optional): Order of the Chebyshev polynomial, i.e., Chebyshev filter size minus 1 :math:`K`. Default: 1.
q (float, optional): Initial value of the phase parameter, 0 <= q <= 0.25. Default: 0.25.
label_dim (int, optional): Number of output classes. Default: 2.
activation (bool, optional): whether to use activation function or not. (default: :obj:`False`)
Expand All @@ -37,7 +37,7 @@ class MagNet_node_classification(nn.Module):
learning scenarios. (default: :obj:`False`)
"""

def __init__(self, num_features: int, hidden: int = 2, q: float = 0.25, K: int = 2, label_dim: int = 2,
def __init__(self, num_features: int, hidden: int = 2, q: float = 0.25, K: int = 1, label_dim: int = 2,
activation: bool = False, trainable_q: bool = False, layer: int = 2, dropout: float = False, normalization: str = 'sym', cached: bool = False):
super(MagNet_node_classification, self).__init__()

Expand Down
6 changes: 3 additions & 3 deletions torch_geometric_signed_directed/nn/general/MSConv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class MSConv(MessagePassing):
Args:
in_channels (int): Size of each input sample.
out_channels (int): Size of each output sample.
K (int): Order of the Chebyshev polynomial plus 1, i.e., Chebyshev filter size :math:`K`.
K (int): Order of the Chebyshev polynomial, i.e., Chebyshev filter size minus 1 :math:`K`.
q (float, optional): Initial value of the phase parameter, 0 <= q <= 0.25. Default: 0.25.
trainable_q (bool, optional): whether to set q to be trainable or not. (default: :obj:`False`)
normalization (str, optional): The normalization scheme for the magnetic
Expand Down Expand Up @@ -59,7 +59,7 @@ def __init__(self, in_channels:int, out_channels:int, K:int, q:float, trainable_
self.q = Parameter(torch.Tensor(1).fill_(q))
else:
self.q = q
self.weight = Parameter(torch.Tensor(K, in_channels, out_channels))
self.weight = Parameter(torch.Tensor(K+1, in_channels, out_channels))

if bias:
self.bias = Parameter(torch.Tensor(out_channels))
Expand Down Expand Up @@ -234,6 +234,6 @@ def message(self, x_j, norm):
return norm.view(-1, 1) * x_j

def __repr__(self):
return '{}({}, {}, K={}, normalization={})'.format(
return '{}({}, {}, filter size={}, normalization={})'.format(
self.__class__.__name__, self.in_channels, self.out_channels,
self.weight.size(0), self.normalization)

0 comments on commit 029ed9f

Please sign in to comment.