[Hotfix] fix url in dblp_new.py #76

Merged: 1 commit, May 9, 2022
26 changes: 10 additions & 16 deletions federatedscope/gfl/dataset/dblp_new.py
@@ -1,27 +1,25 @@
 import os
 import os.path as osp

 import numpy as np
 import networkx as nx
 import torch
 from torch_geometric.data import InMemoryDataset, download_url
 from torch_geometric.utils import from_networkx
 from sklearn.feature_extraction.text import CountVectorizer
 from sklearn.feature_extraction._stop_words import ENGLISH_STOP_WORDS as sklearn_stopwords
-from nltk import word_tokenize
-from nltk.stem import WordNetLemmatizer
-from nltk.corpus import stopwords as nltk_stopwords


 class LemmaTokenizer(object):
     def __init__(self):
+        from nltk.stem import WordNetLemmatizer
         self.wnl = WordNetLemmatizer()

     def __call__(self, doc):
+        from nltk import word_tokenize
         return [self.wnl.lemmatize(t) for t in word_tokenize(doc)]


 def build_feature(words, threshold):
+    from nltk.corpus import stopwords as nltk_stopwords
     # use bag-of-words representation of paper titles as the features of papers
     stopwords = sklearn_stopwords.union(set(nltk_stopwords.words('english')))
     vectorizer = CountVectorizer(min_df=int(threshold),
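Besides the URL fix, this hunk moves the three nltk imports from module scope into the functions that use them, so importing dblp_new.py no longer requires nltk until the tokenizer or feature builder actually runs. A minimal standalone sketch of the same lazy-import pattern (the function name is illustrative, not from this PR; assumes nltk plus its punkt and wordnet data are installed):

def lemmatize_doc(doc):
    # nltk is imported on first call rather than at module load time,
    # so the dependency is only paid on this code path.
    from nltk import word_tokenize
    from nltk.stem import WordNetLemmatizer
    wnl = WordNetLemmatizer()
    return [wnl.lemmatize(t) for t in word_tokenize(doc)]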
@@ -32,10 +30,8 @@ def build_feature(words, threshold):
     return features_paper


-def build_graph(path, FL=0, threshold=15):
-
-    filename = 'gfl%2Fpaper_classification_dataset.tsv'
-    with open(os.path.join(path, filename), 'r') as f:
+def build_graph(path, filename, FL=0, threshold=15):
+    with open(osp.join(path, filename), 'r') as f:
         node_cnt = sum([1 for line in f])

     G = nx.DiGraph()
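build_graph now takes the raw file name as a parameter instead of hardcoding the URL-encoded 'gfl%2Fpaper_classification_dataset.tsv'. A hypothetical direct call with the new signature (the path here is an example; inside the dataset class the caller passes self.raw_dir and self.raw_file_names[0], as shown in the last hunk below):

data_list = build_graph('data/dblp_new/raw', 'dblp_new.tsv', FL=1)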
@@ -47,7 +43,7 @@ def build_graph(path, FL=0, threshold=15):
     org2paper = dict()

     # Build node feature from title
-    with open(os.path.join(path, filename), 'r') as f:
+    with open(osp.join(path, filename), 'r') as f:
         for line in f:
             cols = line.strip().split('\t')
             nid, title = int(cols[0]), cols[3]
@@ -57,7 +53,7 @@ def build_graph(path, FL=0, threshold=15):
                                 dtype=np.float32)

     # Build graph structure
-    with open(os.path.join(path, filename), 'r') as f:
+    with open(osp.join(path, filename), 'r') as f:
         for line in f:
             cols = line.strip().split('\t')
             nid, conf, org, label = int(cols[0]), cols[1], cols[2], int(
@@ -132,9 +128,7 @@ def __init__(self,

     @property
     def raw_file_names(self):
-        names = [
-            'gfl%2Fpaper_classification_dataset.tsv',
-        ]
+        names = ['dblp_new.tsv']
         return names

     @property
@@ -151,13 +145,13 @@ def processed_dir(self):

     def download(self):
         # Download to `self.raw_dir`.
-        url = 'xxx.com'
+        url = 'https://federatedscope.oss-cn-beijing.aliyuncs.com'
         for name in self.raw_file_names:
             download_url(osp.join(url, name), self.raw_dir)

     def process(self):
         # Read data into huge `Data` list.
-        data_list = build_graph(self.raw_dir, self.FL)
+        data_list = build_graph(self.raw_dir, self.raw_file_names[0], self.FL)

         data_list_w_masks = []
         for data in data_list:
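With the 'xxx.com' placeholder replaced by the real OSS endpoint and the raw file renamed to plain dblp_new.tsv, download() now assembles a valid URL. A quick sketch of what the fixed code requests (this works because os.path.join on POSIX simply inserts '/' between its arguments):

import os.path as osp

url = 'https://federatedscope.oss-cn-beijing.aliyuncs.com'
for name in ['dblp_new.tsv']:
    # Prints https://federatedscope.oss-cn-beijing.aliyuncs.com/dblp_new.tsv,
    # which torch_geometric's download_url then fetches into raw_dir.
    print(osp.join(url, name))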