From 49c462b90521c67c9937fc1547ae6ebe7641038d Mon Sep 17 00:00:00 2001 From: Alice Schwarze Date: Tue, 20 Dec 2022 08:03:56 -0500 Subject: [PATCH] Update xgi_data.py (#254) * Update xgi_data.py * made helper function unavailable and changed name and docstring of _clean_file_path to _make_unix_file_path * Update xgi/readwrite/xgi_data.py Co-authored-by: Maxime Lucas * added new function to docs * Update xgi/readwrite/xgi_data.py Co-authored-by: Maxime Lucas * Update xgi/readwrite/xgi_data.py Co-authored-by: Maxime Lucas * Update xgi/readwrite/xgi_data.py Co-authored-by: Maxime Lucas * replaced _make_unix_file_path with os.path.join * added ".json" to file paths Co-authored-by: Maxime Lucas --- .../api/readwrite/xgi.readwrite.xgi_data.rst | 3 +- xgi/readwrite/xgi_data.py | 84 ++++++++++++++++--- 2 files changed, 75 insertions(+), 12 deletions(-) diff --git a/docs/source/api/readwrite/xgi.readwrite.xgi_data.rst b/docs/source/api/readwrite/xgi.readwrite.xgi_data.rst index 822dd51db..f754059ef 100644 --- a/docs/source/api/readwrite/xgi.readwrite.xgi_data.rst +++ b/docs/source/api/readwrite/xgi.readwrite.xgi_data.rst @@ -7,4 +7,5 @@ xgi.readwrite.xgi_data .. rubric:: Functions - .. autofunction:: load_xgi_data \ No newline at end of file + .. autofunction:: load_xgi_data + .. autofunction:: download_xgi_data diff --git a/xgi/readwrite/xgi_data.py b/xgi/readwrite/xgi_data.py index e17023e0c..e6f94ac11 100644 --- a/xgi/readwrite/xgi_data.py +++ b/xgi/readwrite/xgi_data.py @@ -1,13 +1,15 @@ import requests - +import json +import os +from warnings import warn from .. import convert from ..exception import XGIError -__all__ = ["load_xgi_data"] - +__all__ = ["load_xgi_data", "download_xgi_data"] -def load_xgi_data(dataset, nodetype=None, edgetype=None, max_order=None): - """_summary_ +def load_xgi_data(dataset, path='', read=True, nodetype=None, edgetype=None, + max_order=None): + """Load a data set from the xgi-data repository or a local file. Parameters ---------- @@ -15,6 +17,11 @@ def load_xgi_data(dataset, nodetype=None, edgetype=None, max_order=None): Dataset name. Valid options are the top-level tags of the index.json file in the xgi-data repository. + path : str, optional + Path to a local copy of the data set + read : bool, optional + If read==True, search for a local copy of the data set. Use the local + copy if it exists, otherwise use the xgi-data repository. nodetype : type, optional Type to cast the node ID to edgetype : type, optional @@ -32,15 +39,70 @@ def load_xgi_data(dataset, nodetype=None, edgetype=None, max_order=None): XGIError The specified dataset does not exist. """ - index_url = "https://raw.githubusercontent.com/ComplexGroupInteractions/xgi-data/main/index.json" + + if read: + cfp = os.path.join(path, dataset+'.json') + if os.path.exists(cfp): + data = json.load(open(cfp, 'r')) + else: + warn(f"No local copy was found at {cfp}. The data is requested from the xgi-data repository instead. To download a local copy, use `download_xgi_data`.") + data = _request_from_xgi_data(dataset) + else: + data = _request_from_xgi_data(dataset) + + return convert.dict_to_hypergraph( + data, nodetype=nodetype, edgetype=edgetype, max_order=max_order + ) + + +def download_xgi_data(dataset, path=''): + """Make a local copy of a dataset in the xgi-data repository. + + Parameters + ---------- + dataset : str + Dataset name. Valid options are the top-level tags of the + index.json file in the xgi-data repository. + + path : str, optional + Path to where the local copy should be saved. If none is given, save + file to local directory. + """ + + jsondata = _request_from_xgi_data(dataset) + jsonfile = open(os.path.join(path, dataset+'.json'), 'w') + json.dump(jsondata, jsonfile) + jsonfile.close() + + +def _request_from_xgi_data(dataset): + """Request a dataset from xgi-data. + + Parameters + ---------- + dataset : str + Dataset name. Valid options are the top-level tags of the + index.json file in the xgi-data repository. + + Returns + ------- + Data + The requested data loaded from a json file. + + See also + --------- + load_xgi_data + """ + + index_url = "https://gitlab.com/complexgroupinteractions/xgi-data/-/raw/main/index.json?inline=false" index = requests.get(index_url).json() - if dataset not in index: + + key = dataset.lower() + if key not in index: print("Valid dataset names:") print(*index, sep="\n") raise XGIError("Must choose a valid dataset name!") - r = requests.get(index[dataset]["url"]) + data = requests.get(index[key]["url"]).json() - return convert.dict_to_hypergraph( - r.json(), nodetype=nodetype, edgetype=edgetype, max_order=max_order - ) + return data