diff --git a/hepcrawl/spiders/arxiv_spider.py b/hepcrawl/spiders/arxiv_spider.py index 1337b84f..28cd032d 100644 --- a/hepcrawl/spiders/arxiv_spider.py +++ b/hepcrawl/spiders/arxiv_spider.py @@ -24,6 +24,7 @@ get_licenses, split_fullname, ParsedItem, + strict_kwargs, ) RE_CONFERENCE = re.compile( @@ -47,15 +48,22 @@ class ArxivSpider(OAIPMHSpider): Using OAI-PMH XML files:: $ scrapy crawl arXiv \\ - -a "oai_set=physics:hep-th" -a "from_date=2017-12-13" + -a "sets=physics:hep-th" -a "from_date=2017-12-13" """ name = 'arXiv' - def __init__(self, *args, **kwargs): - kwargs.setdefault('url', 'http://export.arxiv.org/oai2') - kwargs.setdefault('format', 'arXiv') - super(ArxivSpider, self).__init__(*args, **kwargs) + @strict_kwargs + def __init__( + self, + url='http://export.arxiv.org/oai2', + format='arXiv', + sets=None, + from_date=None, + until_date=None, + **kwargs + ): + super(ArxivSpider, self).__init__(**self._all_kwargs) def parse_record(self, selector): """Parse an arXiv XML exported file into a HEP record.""" diff --git a/hepcrawl/spiders/common/oaipmh_spider.py b/hepcrawl/spiders/common/oaipmh_spider.py index fb69c741..282759ee 100644 --- a/hepcrawl/spiders/common/oaipmh_spider.py +++ b/hepcrawl/spiders/common/oaipmh_spider.py @@ -12,6 +12,7 @@ import abc import logging from datetime import datetime + from six import string_types from sickle import Sickle @@ -21,6 +22,7 @@ from scrapy.selector import Selector from .last_run_store import LastRunStoreSpider +from ...utils import strict_kwargs LOGGER = logging.getLogger(__name__) @@ -47,6 +49,7 @@ class OAIPMHSpider(LastRunStoreSpider): __metaclass__ = abc.ABCMeta name = 'OAI-PMH' + @strict_kwargs def __init__( self, url, @@ -55,9 +58,9 @@ def __init__( alias=None, from_date=None, until_date=None, - *args, **kwargs + **kwargs ): - super(OAIPMHSpider, self).__init__(*args, **kwargs) + super(OAIPMHSpider, self).__init__(**self._all_kwargs) self.url = url self.format = format if isinstance(sets, string_types): diff --git a/hepcrawl/utils.py b/hepcrawl/utils.py index e504c4a6..ae6b8b35 100644 --- a/hepcrawl/utils.py +++ b/hepcrawl/utils.py @@ -9,10 +9,12 @@ from __future__ import absolute_import, division, print_function +import inspect import fnmatch import os import pprint import re +from functools import wraps from operator import itemgetter from itertools import groupby from netrc import netrc @@ -382,6 +384,62 @@ def get_licenses( return license +def strict_kwargs(func): + """This decorator will disallow any keyword arguments that + do not begin with an underscore sign in the decorated method. + This is mainly to make errors while passing arguments to spiders + immediately visible. As we cannot remove kwargs from there altogether + (used by scrapyd), with this we can ensure that we are not passing unwanted + data by mistake. + + Additionaly this will add all of the 'public' kwargs to an `_init_kwargs` + field in the object for easier passing and all of the arguments (including + non-overloaded ones) to `_all_kwargs`. (To make passing them forward + easier.) + + Args: + func (function): a spider method + + Returns: + function: method which will disallow any keyword arguments that + do not begin with an underscore sign. + """ + argspec = inspect.getargspec(func) + defined_arguments = argspec.args[1:] + spider_fields = ['settings'] + + allowed_arguments = defined_arguments + spider_fields + + if argspec.defaults: + defaults = dict( + zip(argspec.args[-len(argspec.defaults):], argspec.defaults) + ) + else: + defaults = {} + + @wraps(func) + def wrapper(self, *args, **kwargs): + disallowed_kwargs = [ + key for key in kwargs + if not key.startswith('_') and key not in allowed_arguments + ] + + if disallowed_kwargs: + raise TypeError( + 'Only underscored kwargs allowed in {}. ' + 'Check {} for typos.'.format(func, ', '.join(disallowed_kwargs)) + ) + + defaults.update(kwargs) + self._init_kwargs = { + k: v for k, v in defaults.items() + if not k.startswith('_') and k not in spider_fields + } + self._all_kwargs = defaults + return func(self, *args, **kwargs) + return wrapper + + class RecordFile(object): """Metadata of a file needed for a record.