Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Sep 2, 2021
1 parent d47dad4 commit 7c08f7d
Showing 1 changed file with 7 additions and 19 deletions.
26 changes: 7 additions & 19 deletions torch_geometric/utils/convert.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Optional, Union, Tuple, List

from collections import defaultdict

import torch
import scipy.sparse
from torch import Tensor
Expand Down Expand Up @@ -132,7 +134,7 @@ def from_networkx(G, group_node_attrs: Optional[Union[List[str], all]] = None,
G = G.to_directed() if not nx.is_directed(G) else G
edge_index = torch.LongTensor(list(G.edges)).t().contiguous()

data = {}
data = defaultdict(list)

if G.number_of_nodes() > 0:
node_attrs = list(next(iter(G.nodes(data=True)))[-1].keys())
Expand All @@ -148,17 +150,17 @@ def from_networkx(G, group_node_attrs: Optional[Union[List[str], all]] = None,
if set(feat_dict.keys()) != set(node_attrs):
raise ValueError('Not all nodes contain the same attributes')
for key, value in feat_dict.items():
data[str(key)] = [value] if i == 0 else data[str(key)] + [value]
data[str(key)].append(value)

for i, (_, _, feat_dict) in enumerate(G.edges(data=True)):
if set(feat_dict.keys()) != set(edge_attrs):
raise ValueError('Not all edges contain the same attributes')
for key, value in feat_dict.items():
data[str(key)] = [value] if i == 0 else data[str(key)] + [value]
data[str(key)].append(value)

for key, item in data.items():
for key, value in data.items():
try:
data[key] = torch.tensor(item)
data[key] = torch.tensor(value)
except ValueError:
pass

Expand All @@ -181,20 +183,6 @@ def from_networkx(G, group_node_attrs: Optional[Union[List[str], all]] = None,
edge_attrs = [x.view(-1, 1) if x.dim() <= 1 else x for x in edge_attrs]
data.edge_attr = torch.cat(edge_attrs, dim=-1)

if group_node_attrs is all:
group_node_attrs = list(node_attrs)
if group_node_attrs is not None:
xs = [data[key] for key in group_node_attrs]
xs = [x.view(-1, 1) if x.dim() <= 1 else x for x in xs]
data.x = torch.cat(xs, dim=-1)

if group_edge_attrs is all:
group_edge_attrs = list(edge_attrs)
if group_edge_attrs is not None:
edge_attrs = [data[key] for key in group_edge_attrs]
edge_attrs = [x.view(-1, 1) if x.dim() <= 1 else x for x in edge_attrs]
data.edge_attr = torch.cat(edge_attrs, dim=-1)

return data


Expand Down

0 comments on commit 7c08f7d

Please sign in to comment.