diff --git a/.gitignore b/.gitignore index bf2ad57..3e152bb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ build/ dist/ +htmlcov/ *egg-info/ __pycache__/ .coverage +.coverage.* *.pyc diff --git a/.pylintrc b/.pylintrc index 7d65d67..dcb45dd 100644 --- a/.pylintrc +++ b/.pylintrc @@ -17,4 +17,4 @@ min-public-methods=0 max-line-length=150 [MESSAGES CONTROL] -disable=I0011,R0201,W0105,W0108,W0110,W0141,W0621,W0640 +disable=R0201 diff --git a/Makefile b/Makefile index 711d53c..e0e41a0 100644 --- a/Makefile +++ b/Makefile @@ -14,7 +14,7 @@ PYTHON ?= python # Download test data data: mkdir -p /tmp/paperai - wget -N https://github.com/neuml/paperai/releases/download/v1.3.0/tests.tar.gz -P /tmp + wget -N https://github.com/neuml/paperai/releases/download/v1.10.0/tests.tar.gz -P /tmp tar -xvzf /tmp/tests.tar.gz -C /tmp # Unit tests diff --git a/README.md b/README.md index 6d22e8c..3dc633c 100644 --- a/README.md +++ b/README.md @@ -46,11 +46,11 @@ The easiest way to install is via pip and PyPI pip install paperai -You can also install paperai directly from GitHub. Using a Python Virtual Environment is recommended. +Python 3.6+ is supported. Using a Python [virtual environment](https://docs.python.org/3/library/venv.html) is recommended. - pip install git+https://github.com/neuml/paperai +paperai can also be installed directly from GitHub to access the latest, unreleased features. -Python 3.6+ is supported + pip install git+https://github.com/neuml/paperai See [this link](https://github.com/neuml/txtai#installation) to help resolve environment-specific install issues. @@ -125,7 +125,7 @@ no parameters are passed in. ## Building a report file Reports support generating output in multiple formats. An example report call: - python -m paperai.report tasks/risks.yml 50 md cord19/models + python -m paperai.report report.yml 50 md cord19/models The following report formats are supported: @@ -133,7 +133,7 @@ The following report formats are supported: - CSV - Renders a CSV report. Columns and answers are extracted from articles with the results stored in a CSV file. - Annotation - Columns and answers are extracted from articles with the results annotated over the original PDF files. Requires passing in a path with the original PDF files. -In the example above, a file named tasks/risk_factors.md will be created. Example report configuration files can be found [here](https://github.com/neuml/cord19q/tree/master/tasks). +In the example above, a file named report.md will be created. Example report configuration files can be found [here](https://github.com/neuml/cord19q/tree/master/tasks). ## Running queries The fastest way to run queries is to start a paperai shell diff --git a/demo.png b/demo.png index 88efd1b..6ec9969 100644 Binary files a/demo.png and b/demo.png differ diff --git a/examples/search.py b/examples/search.py index 0d7b550..3339563 100644 --- a/examples/search.py +++ b/examples/search.py @@ -28,8 +28,8 @@ def __init__(self, path): """ # Default list of columns - self.columns = [("Title", True), ("Published", False), ("Publication", False), ("Design", False), ("Sample", False), - ("Method", False), ("Entry", False), ("Id", False), ("Content", True)] + self.columns = [("Title", True), ("Published", False), ("Publication", False), ("Entry", False), + ("Id", False), ("Content", True)] # Load model self.path = path @@ -58,17 +58,16 @@ def search(self, query, topn, threshold): # Print each result, sorted by max score descending for uid in sorted(documents, key=lambda k: sum([x[0] for x in documents[k]]), reverse=True): - cur.execute("SELECT Title, Published, Publication, Design, Size, Sample, Method, Entry, Id, Reference " + + cur.execute("SELECT Title, Published, Publication, Entry, Id, Reference " + "FROM articles WHERE id = ?", [uid]) article = cur.fetchone() matches = "
".join([text for _, text in documents[uid]]) - title = "%s" % (article[9], article[0]) + title = f"{article[0]}" - article = {"Title": title, "Published": Query.date(article[1]), "Publication": article[2], "Design": Query.design(article[3]), - "Sample": Query.sample(article[4], article[5]), "Method": Query.text(article[6]), "Entry": article[7], - "Id": article[8], "Content": matches} + article = {"Title": title, "Published": Query.date(article[1]), "Publication": article[2], "Entry": article[3], + "Id": article[4], "Content": matches} articles.append(article) diff --git a/setup.py b/setup.py index 77f9a0b..f247072 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,11 @@ # pylint: disable = C0111 from setuptools import find_packages, setup -with open("README.md", "r") as f: +with open("README.md", "r", encoding="utf-8") as f: DESCRIPTION = f.read() setup(name="paperai", - version="1.11.0", + version="2.0.0", author="NeuML", description="AI-powered literature discovery and review engine for medical/scientific papers", long_description=DESCRIPTION, @@ -32,6 +32,7 @@ "networkx>=2.4", "PyYAML>=5.3", "regex>=2020.5.14", + "text2digits>=0.1.0", "txtai[api,similarity]>=3.4.0", "txtmarker>=1.0.0" ], diff --git a/src/python/paperai/api.py b/src/python/paperai/api.py index 3e8c909..6b65c62 100644 --- a/src/python/paperai/api.py +++ b/src/python/paperai/api.py @@ -14,7 +14,7 @@ class API(txtai.api.API): Extended API on top of txtai to return enriched query results. """ - def search(self, query, request): + def search(self, query, request=None): """ Extends txtai API to enrich results with content. @@ -28,8 +28,8 @@ def search(self, query, request): if self.embeddings: dbfile = os.path.join(self.config["path"], "articles.sqlite") - limit = self.limit(request.query_params.get("limit")) - threshold = float(request.query_params["threshold"]) if "threshold" in request.query_params else None + limit = self.limit(request.query_params.get("limit")) if request else 10 + threshold = float(request.query_params["threshold"]) if request and "threshold" in request.query_params else None with sqlite3.connect(dbfile) as db: cur = db.cursor() @@ -44,16 +44,15 @@ def search(self, query, request): # Print each result, sorted by max score descending for uid in sorted(documents, key=lambda k: sum([x[0] for x in documents[k]]), reverse=True): - cur.execute("SELECT Title, Published, Publication, Design, Size, Sample, Method, Entry, Id, Reference " + + cur.execute("SELECT Title, Published, Publication, Entry, Id, Reference " + "FROM articles WHERE id = ?", [uid]) article = cur.fetchone() score = max([score for score, text in documents[uid]]) matches = [text for _, text in documents[uid]] - article = {"id": article[8], "score": score, "title": article[0], "published": Query.date(article[1]), "publication": article[2], - "design": Query.design(article[3]), "sample": Query.sample(article[4], article[5]), "method": Query.text(article[6]), - "entry": article[7], "reference": article[9], "matches": matches} + article = {"id": article[4], "score": score, "title": article[0], "published": Query.date(article[1]), "publication": article[2], + "entry": article[3], "reference": article[5], "matches": matches} articles.append(article) diff --git a/src/python/paperai/export.py b/src/python/paperai/export.py index 1ab1582..d18670a 100644 --- a/src/python/paperai/export.py +++ b/src/python/paperai/export.py @@ -14,7 +14,7 @@ from .index import Index from .models import Models -class Export(object): +class Export: """ Exports database rows into a text file line-by-line. """ @@ -29,26 +29,26 @@ def stream(dbfile, output): output: output file to store text """ - with open(output, "w") as output: + with open(output, "w", encoding="utf-8") as output: # Connection to database file db = sqlite3.connect(dbfile) cur = db.cursor() - # Get all indexed text, with a detected study design, excluding modeling designs - cur.execute(Index.SECTION_QUERY + " AND design NOT IN (0, 9)") + # Get all indexed text + cur.execute(Index.SECTION_QUERY) count = 0 for _, name, text in cur: if not name or not re.search(Index.SECTION_FILTER, name.lower()): count += 1 if count % 1000 == 0: - print("Streamed %d documents" % (count), end="\r") + print(f"Streamed {count} documents", end="\r") # Write row if text: output.write(text + "\n") - print("Iterated over %d total rows" % (count)) + print(f"Iterated over {count} total rows") # Free database resources db.close() diff --git a/src/python/paperai/highlights.py b/src/python/paperai/highlights.py index 13805d8..461c70c 100644 --- a/src/python/paperai/highlights.py +++ b/src/python/paperai/highlights.py @@ -8,7 +8,7 @@ from txtai.pipeline import Tokenizer -class Highlights(object): +class Highlights: """ Methods to extract highlights from a list of text sections. """ diff --git a/src/python/paperai/index.py b/src/python/paperai/index.py index a0729af..1d713b0 100644 --- a/src/python/paperai/index.py +++ b/src/python/paperai/index.py @@ -14,14 +14,14 @@ from .models import Models -class Index(object): +class Index: """ Methods to build a new sentence embeddings index. """ # Section query and filtering logic constants SECTION_FILTER = r"background|(? 0: - query += " AND article in (SELECT id FROM articles ORDER BY entry DESC LIMIT %d)" % maxsize + query += f" AND article in (SELECT id FROM articles ORDER BY entry DESC LIMIT {maxsize})" # Run the query cur.execute(query) @@ -59,13 +59,13 @@ def stream(dbfile, maxsize): count += 1 if count % 1000 == 0: - print("Streamed %d documents" % (count), end="\r") + print(f"Streamed {count} documents", end="\r") # Skip documents with no tokens parsed if tokens: yield document - print("Iterated over %d total rows" % (count)) + print(f"Iterated over {count} total rows") # Free database resources db.close() @@ -88,7 +88,7 @@ def config(vectors): # Read YAML index configuration if vectors.endswith(".yml"): - with open(vectors, "r") as f: + with open(vectors, "r", encoding="utf-8") as f: return yaml.safe_load(f) return {"path": vectors, "scoring": "bm25", "pca": 3, "quantize": True} diff --git a/src/python/paperai/models.py b/src/python/paperai/models.py index 075308a..770d6bd 100644 --- a/src/python/paperai/models.py +++ b/src/python/paperai/models.py @@ -8,7 +8,7 @@ from txtai.embeddings import Embeddings -class Models(object): +class Models: """ Common methods for generating data paths. """ @@ -95,7 +95,7 @@ def load(path): dbfile = os.path.join(path, "articles.sqlite") if os.path.isfile(os.path.join(path, "config")): - print("Loading model from %s" % path) + print(f"Loading model from {path}") embeddings = Embeddings() embeddings.load(path) else: diff --git a/src/python/paperai/query.py b/src/python/paperai/query.py index 07e9cc3..7c97f75 100644 --- a/src/python/paperai/query.py +++ b/src/python/paperai/query.py @@ -14,7 +14,7 @@ from .highlights import Highlights from .models import Models -class Query(object): +class Query: """ Methods to query an embeddings index. """ @@ -238,7 +238,7 @@ def authors(authors): else: authors = authors.split()[-1] - return "%s et al" % authors + return f"{authors} et al" return None @@ -289,40 +289,6 @@ def text(text): return text - @staticmethod - def design(design): - """ - Formats a study design field. - - Args: - design: study design integer - - Returns: - Study Design string - """ - - # Study design type mapping - mapping = {1:"Systematic review", 2:"Randomized control trial", 3:"Non-randomized trial", - 4:"Prospective observational", 5:"Time-to-event analysis", 6:"Retrospective observational", - 7:"Cross-sectional", 8:"Case series", 9:"Modeling", 0:"Other"} - - return mapping[design] - - @staticmethod - def sample(size, text): - """ - Formats a sample string. - - Args: - size: Sample size - text: Sample text - - Returns: - Formatted sample text - """ - - return "[%s] %s" % (size, Query.text(text)) if size else Query.text(text) - @staticmethod def query(embeddings, db, query, topn, threshold): """ @@ -341,7 +307,7 @@ def query(embeddings, db, query, topn, threshold): cur = db.cursor() - print(Query.render("#Query: %s" % query, theme="729.8953") + "\n") + print(Query.render(f"#Query: {query}", theme="729.8953") + "\n") # Query for best matches results = Query.search(embeddings, cur, query, topn, threshold) @@ -349,7 +315,7 @@ def query(embeddings, db, query, topn, threshold): # Extract top sections as highlights print(Query.render("# Highlights")) for highlight in Query.highlights(results, int(topn / 5)): - print(Query.render("## - %s" % Query.text(highlight))) + print(Query.render(f"## - {Query.text(highlight)}")) print() @@ -360,22 +326,19 @@ def query(embeddings, db, query, topn, threshold): # Print each result, sorted by max score descending for uid in sorted(documents, key=lambda k: sum([x[0] for x in documents[k]]), reverse=True): - cur.execute("SELECT Title, Published, Publication, Design, Size, Sample, Method, Entry, Id, Reference FROM articles WHERE id = ?", [uid]) + cur.execute("SELECT Title, Published, Publication, Entry, Id, Reference FROM articles WHERE id = ?", [uid]) article = cur.fetchone() - print("Title: %s" % article[0]) - print("Published: %s" % Query.date(article[1])) - print("Publication: %s" % article[2]) - print("Design: %s" % Query.design(article[3])) - print("Sample: %s" % Query.sample(article[4], article[5])) - print("Method: %s" % Query.text(article[6])) - print("Entry: %s" % article[7]) - print("Id: %s" % article[8]) - print("Reference: %s" % article[9]) + print(f"Title: {article[0]}") + print(f"Published: {Query.date(article[1])}") + print(f"Publication: {article[2]}") + print(f"Entry: {article[3]}") + print(f"Id: {article[4]}") + print(f"Reference: {article[5]}") # Print top matches for score, text in documents[uid]: - print(Query.render("## - (%.4f): %s" % (score, Query.text(text)), html=False)) + print(Query.render(f"## - ({score:.4f}): {Query.text(text)}", html=False)) print() diff --git a/src/python/paperai/report/annotate.py b/src/python/paperai/report/annotate.py index de0c521..43faf44 100644 --- a/src/python/paperai/report/annotate.py +++ b/src/python/paperai/report/annotate.py @@ -45,7 +45,7 @@ def headers(self, columns, output): self.names = columns # Do not annotate following columns - for field in ["Date", "Study", "Study Link", "Journal", "Study Type", "Sample Size", "Matches", "Entry"]: + for field in ["Date", "Study", "Study Link", "Journal", "Matches", "Entry", "Id"]: if field in self.names: self.names.remove(field) @@ -62,12 +62,6 @@ def buildRow(self, article, sections, calculated): # Source row["Source"] = article[4] - # Sample Text - row["Sample Text"] = article[7] - - # Study Population - row["Study Population"] = Query.text(article[8] if article[8] else article[7]) - # Merge in calculated fields row.update(calculated) @@ -156,7 +150,7 @@ def formatter(self, text): patterns.append(r"(\(\d+\)\s){3,}") # Build regex pattern - pattern = re.compile("|".join(["(%s)" % p for p in patterns])) + pattern = re.compile("|".join([f"({p})" for p in patterns])) text = pattern.sub(" ", text) diff --git a/src/python/paperai/report/column.py b/src/python/paperai/report/column.py new file mode 100644 index 0000000..6fd74f1 --- /dev/null +++ b/src/python/paperai/report/column.py @@ -0,0 +1,168 @@ +""" +Column module +""" + +import regex as re + +from dateutil.parser import parse +from text2digits.text2digits import Text2Digits + +class Column: + """ + Column formatting functions for reports. + """ + + @staticmethod + def integer(text): + """ + Format text as a string. This method also converts text describing a number to a number. + For example, twenty three is converted to 23. + + Args: + text: input text + + Returns: + number if parsed, otherwise None + """ + + # Format text for numeric parsing + text = text.replace(",", "") + text = re.sub(r"(\d+)\s+(\d+)", r"\1\2", text) + + try: + # Convert numeric words to numbers + text = text if text.isnumeric() else Text2Digits().convert(text) + # pylint: disable=W0702 + except: + pass + + return text if text.isnumeric() else None + + @staticmethod + def categorical(model, text, labels): + """ + Applies a text classification model to text using labels. + + Args: + model: text classification model + text: input text + labels: labels to use + + Returns: + categorical label, if model not None, otherwise original text returned + """ + + if model: + index = model(text, labels)[0][0] if model else text + return labels[index] + + return text + + @staticmethod + def duration(text, dtype): + """ + Attempts to standardize a date duration string to a format specified by dtype. + If dtype is days and the duration string specifies a month range, the duration is converted + to days. If the duration is in years and the dtype is months, the duration is converted to months. + + Examples: + 2021-01-01 to 2021-01-31. In days = 30, in months = 1, in years = 0.083 + Jan 2021 to Mar 2021. In days = 60, in months = 2, in years = 0.167 + + Args: + text: date duration string + dtype: target duration type [supports days, weeks, months, years] + + Returns: + duration as number + """ + + try: + data = re.sub(r"(?i)\s*(between|over|up to)\s+", "", text) + data = re.split(r"\s+(?:and|to|through)\s+", data) + + d1, d2 = parse(data[0]), parse(data[1]) + value = (d2 - d1).days + + # Handle case where no year specified for first date + if value < 0: + d1 = d1.replace(year=d2.year) + value = (d2 - d1).days + + return Column.convert(value, "days", dtype) + # pylint: disable=W0702 + except: + pass + + data = re.sub(r"\(.*?\)", "", text) + data = re.split(r"\s+|\-", data) + data = sorted(data) + + data[0] = Text2Digits().convert(data[0]) + if len(data) > 1 and not data[1].endswith("s"): + data[1] = data[1] + "s" + + if len(data) == 2 and \ + (data[0].replace(".", "", 1).isnumeric() and data[1] in (["days", "weeks", "months", "years"])): + + value, suffix = sorted(data) + value = float(value) + + return Column.convert(value, suffix, dtype) + + return None + + @staticmethod + def convert(value, itype, otype): + """ + Attempts to convert a numeric duration from itype to otype. + + Args: + value: numeric duration + itype: input type [days, weeks, months, years] + otype: output type [days, weeks, months, years] + + Returns: + converted numeric duration + """ + + if itype == otype: + return value + + if itype == "days": + if otype == "weeks": + return value / 7 + if otype == "months": + return value / 30 + + # Years + return value / 365 + + if itype == "weeks": + if otype == "days": + return value * 7 + if otype == "months": + return value / 4 + + # Years + return value / 52 + + if itype == "months": + if otype == "days": + return value * 30 + if otype == "weeks": + return value * 4 + + # Years + return value / 12 + + if itype == "years": + if otype == "days": + return value * 365 + if otype == "weeks": + return value * 52 + + # Months + return value * 12 + + return value diff --git a/src/python/paperai/report/common.py b/src/python/paperai/report/common.py index e0d6837..b31f70c 100644 --- a/src/python/paperai/report/common.py +++ b/src/python/paperai/report/common.py @@ -4,12 +4,14 @@ import regex as re -from txtai.pipeline import Extractor, Similarity +from txtai.pipeline import Extractor, Labels, Similarity from ..index import Index from ..query import Query -class Report(object): +from .column import Column + +class Report: """ Methods to build reports from a series of queries """ @@ -34,12 +36,16 @@ def __init__(self, embeddings, db, options): # Column names self.names = [] + self.similarity = Similarity(options["similarity"]) if "similarity" in options else None + self.labels = Labels(model=self.similarity) if self.similarity else None + # Extractive question-answering model # Determine if embeddings or a custom similarity model should be used to build question context - self.extractor = Extractor(Similarity(options["similarity"]) if "similarity" in options else self.embeddings, + self.extractor = Extractor(self.similarity if self.similarity else self.embeddings, options["qa"] if options["qa"] else "NeuML/bert-small-cord19qa", minscore=options.get("minscore"), - mintokens=options.get("mintokens")) + mintokens=options.get("mintokens"), + context=options.get("context")) def build(self, queries, options, output): """ @@ -128,12 +134,14 @@ def articles(self, output, topn, metadata, results): # Collect matching rows rows = [] - for uid in documents: + for x, uid in enumerate(documents): # Get article metadata - self.cur.execute("SELECT Published, Title, Reference, Publication, Source, Design, Size, Sample, Method, Entry " + - "FROM articles WHERE id = ?", [uid]) + self.cur.execute("SELECT Published, Title, Reference, Publication, Source, Entry, Id FROM articles WHERE id = ?", [uid]) article = self.cur.fetchone() + if x and x % 100 == 0: + print(f"Processed {x} documents", end="\r") + # Calculate derived fields calculated = self.calculate(uid, metadata) @@ -166,18 +174,23 @@ def calculate(self, uid, metadata): # Parse column parameters fields, params = self.params(metadata) - # Stores embedding query only and full QA column definitions - queries, questions = [], [] + # Different type of calculations + # 1. Similarity query + # 2. Extractor query (similarity + question) + # 3. Question-answering on other field + queries, extractions, questions = [], [], [] # Retrieve indexed document text for article sections = self.sections(uid) texts = [text for _, text in sections] - for name, query, question, snippet, _, _, matches in params: - if matches: + for name, query, question, snippet, _, _, matches, _ in params: + if query.startswith("$"): + questions.append((name, query.replace("$", ""), question, snippet)) + elif matches: queries.append((name, query, matches)) else: - questions.append((name, query, question, snippet)) + extractions.append((name, query, question, snippet)) # Run all extractor queries against document text results = self.extractor.query([query for _, query, _ in queries], texts) @@ -189,12 +202,25 @@ def calculate(self, uid, metadata): topn = [text for _, text, _ in results[x]][:matches] # Join results into String and return - fields[name] = "\n\n".join([self.resolve(params, sections, uid, name, value) for value in topn]) + value = [self.resolve(params, sections, uid, name, value) for value in topn] + fields[name] = "\n\n".join(value) if value else None else: fields[name] = None # Add extraction fields - for name, value in self.extractor(questions, texts): + for name, value in self.extractor(extractions, texts): + # Resolves the full value based on column parameters + fields[name] = self.resolve(params, sections, uid, name, value) if value else "" + + # Add question fields + names, qa, contexts, snippets = [], [], [], [] + for name, query, question, snippet in questions: + names.append(name) + qa.append(question) + contexts.append(fields[query]) + snippets.append(snippet) + + for name, value in self.extractor.answers(names, qa, contexts, contexts, snippets): # Resolves the full value based on column parameters fields[name] = self.resolve(params, sections, uid, name, value) if value else "" @@ -231,13 +257,14 @@ def params(self, metadata): question = self.variables(column["question"], metadata) if "question" in column else query # Additional context parameters - section = column["section"] if "section" in column else False - surround = column["surround"] if "surround" in column else 0 - matches = column["matches"] if "matches" in column else 0 - snippet = column["snippet"] if "snippet" in column else False + section = column.get("section", False) + surround = column.get("surround", 0) + matches = column.get("matches", 0) + dtype = column.get("dtype") + snippet = column.get("snippet", False) snippet = True if section or surround else snippet - params.append((column["name"], query, question, snippet, section, surround, matches)) + params.append((column["name"], query, question, snippet, section, surround, matches, dtype)) return fields, params @@ -276,7 +303,7 @@ def sections(self, uid): """ # Retrieve indexed document text for article - self.cur.execute(Index.SECTION_QUERY + " AND article = ? ORDER BY id", [uid]) + self.cur.execute(Index.SECTION_QUERY + " WHERE article = ? ORDER BY id", [uid]) # Get list of document text sections sections = [] @@ -307,7 +334,7 @@ def resolve(self, params, sections, uid, name, value): # Get all column parameters index = [params.index(x) for x in params if x[0] == name][0] - _, _, _, _, section, surround, _ = params[index] + _, _, _, _, section, surround, _, dtype = params[index] if value: # Find matching section @@ -322,6 +349,14 @@ def resolve(self, params, sections, uid, name, value): elif surround: value = self.surround(uid, sid, surround) + # Column dtype formatting + if dtype == "int": + value = Column.integer(value) + elif isinstance(dtype, list): + value = Column.categorical(self.labels, value, dtype) + elif dtype in ["days", "weeks", "months", "years"]: + value = Column.duration(value, dtype) + return value def subsection(self, uid, sid): diff --git a/src/python/paperai/report/csvr.py b/src/python/paperai/report/csvr.py index d480c43..3f8aa06 100644 --- a/src/python/paperai/report/csvr.py +++ b/src/python/paperai/report/csvr.py @@ -31,7 +31,7 @@ def query(self, output, task, query): if self.csvout: self.csvout.close() - self.csvout = open(os.path.join(os.path.dirname(output.name), "%s.csv" % task), "w", newline="") + self.csvout = open(os.path.join(os.path.dirname(output.name), f"{task}.csv"), "w", newline="", encoding="utf-8") self.writer = csv.writer(self.csvout, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL) def write(self, row): @@ -69,23 +69,14 @@ def buildRow(self, article, sections, calculated): # Source row["Source"] = article[4] - # Study Type - row["Study Type"] = Query.design(article[5]) - - # Sample Size - row["Sample Size"] = article[6] - - # Study Population - row["Study Population"] = Query.text(article[8] if article[8] else article[7]) - - # Sample Text - row["Sample Text"] = article[7] - # Top Matches row["Matches"] = "\n\n".join([Query.text(text) for _, text in sections]) if sections else "" # Entry Date - row["Entry"] = article[9] if article[9] else "" + row["Entry"] = article[5] if article[5] else "" + + # Id + row["Id"] = article[6] # Merge in calculated fields row.update(calculated) diff --git a/src/python/paperai/report/execute.py b/src/python/paperai/report/execute.py index 94770ef..9e58588 100644 --- a/src/python/paperai/report/execute.py +++ b/src/python/paperai/report/execute.py @@ -11,7 +11,7 @@ from ..models import Models -class Execute(object): +class Execute: """ Creates a Report """ @@ -71,10 +71,10 @@ def run(task, topn=None, render=None, path=None, qa=None, indir=None, threshold= report = Execute.create(render, embeddings, db, options) # Generate output filename - outfile = os.path.join(outdir, "%s.%s" % (name, render)) + outfile = os.path.join(outdir, f"{name}.{render}") # Stream report to file - with open(outfile, "w") as output: + with open(outfile, "w", encoding="utf-8") as output: # Build the report report.build(queries, options, output) diff --git a/src/python/paperai/report/markdown.py b/src/python/paperai/report/markdown.py index 4815a9e..87a3914 100644 --- a/src/python/paperai/report/markdown.py +++ b/src/python/paperai/report/markdown.py @@ -48,20 +48,20 @@ def write(self, output, line): line: line to write """ - output.write("%s\n" % line) + output.write(f"{line}\n") def query(self, output, task, query): - self.write(output, "# %s" % query) + self.write(output, f"# {query}") def section(self, output, name): - self.write(output, "#### %s
" % name) + self.write(output, f"#### {name}
") def highlight(self, output, article, highlight): # Build citation link - link = "[%s](%s)" % (Query.authors(article[0]) if article[0] else "Source", self.encode(article[1])) + link = f"[{Query.authors(article[0]) if article[0] else 'Source'}]({self.encode(article[1])})" # Build highlight row with citation link - self.write(output, "- %s %s
" % (Query.text(highlight), link)) + self.write(output, f"- {Query.text(highlight)} {link}
") def headers(self, columns, output): self.names = columns @@ -73,11 +73,11 @@ def headers(self, columns, output): # Write table header headers = "|".join(self.names) - self.write(output, "|%s|" % headers) + self.write(output, f"|{headers}|") # Write markdown separator for headers headers = "|".join(["----"] * len(self.names)) - self.write(output, "|%s|" % headers) + self.write(output, f"|{headers}|") def buildRow(self, article, sections, calculated): row = {} @@ -86,10 +86,10 @@ def buildRow(self, article, sections, calculated): row["Date"] = Query.date(article[0]) if article[0] else "" # Title - title = "[%s](%s)" % (article[1], self.encode(article[2])) + title = f"[{article[1]}]({self.encode(article[2])})" # Append Publication if available. Assume preprint otherwise and show preprint source. - title += "
%s" % (article[3] if article[3] else article[4]) + title += f"
{article[3] if article[3] else article[4]}" # Source row["Source"] = article[4] @@ -97,30 +97,23 @@ def buildRow(self, article, sections, calculated): # Title + Publication if available row["Study"] = title - # Study Type - row["Study Type"] = Query.design(article[5]) - - # Sample Size - sample = Query.sample(article[6], article[7]) - row["Sample Size"] = sample if sample else "" - - # Study Population - row["Study Population"] = Query.text(article[8]) if article[8] else "" - # Top Matches row["Matches"] = "

".join([Query.text(text) for _, text in sections]) if sections else "" # Entry Date - row["Entry"] = article[9] if article[9] else "" + row["Entry"] = article[5] if article[5] else "" + + # Id + row["Id"] = article[6] # Merge in calculated fields row.update(calculated) # Escape | characters embedded within columns - return {column: self.column(row[column]) for column in row} + return {name: self.column(value) for name, value in row.items()} def writeRow(self, output, row): - self.write(output, "|%s|" % "|".join(row)) + self.write(output, f"|{'|'.join(row)}|") def separator(self, output): # Write section separator diff --git a/src/python/paperai/report/task.py b/src/python/paperai/report/task.py index 2637217..9a5a697 100644 --- a/src/python/paperai/report/task.py +++ b/src/python/paperai/report/task.py @@ -6,7 +6,7 @@ import yaml -class Task(object): +class Task: """ YAML task configuration loader """ @@ -25,7 +25,7 @@ def load(task): if os.path.exists(task): # Load tasks yml file - with open(task, "r") as f: + with open(task, "r", encoding="utf-8") as f: # Read configuration config = yaml.safe_load(f) diff --git a/src/python/paperai/vectors.py b/src/python/paperai/vectors.py index 7d96a93..c51ebd5 100644 --- a/src/python/paperai/vectors.py +++ b/src/python/paperai/vectors.py @@ -13,7 +13,7 @@ from .models import Models -class RowIterator(object): +class RowIterator: """ Iterates over rows in a database query. Allows for multiple iterations. """ @@ -78,18 +78,18 @@ def stream(self, dbfile): count += 1 if count % 1000 == 0: - print("Streamed %d documents" % (count), end="\r") + print(f"Streamed {count} documents", end="\r") # Skip documents with no tokens parsed if tokens: yield tokens - print("Iterated over %d total rows" % (count)) + print(f"Iterated over {count} total rows") # Free database resources db.close() -class Vectors(object): +class Vectors: """ Methods to build a FastText model. """ @@ -142,7 +142,7 @@ def run(path, size, mincount, output): if not output: # Output file path - output = Models.vectorPath("cord19-%dd" % size, True) + output = Models.vectorPath(f"cord19-{size}d", True) # Build word vectors model WordVectors.build(tokens, size, mincount, output) diff --git a/test/python/testapi.py b/test/python/testapi.py index 52cdd1b..f735e7f 100644 --- a/test/python/testapi.py +++ b/test/python/testapi.py @@ -38,7 +38,7 @@ def start(): config = os.path.join(tempfile.gettempdir(), "testapi.yml") - with open(config, "w") as output: + with open(config, "w", encoding="utf-8") as output: output.write(INDEX % Utils.PATH) client = TestClient(app) @@ -59,7 +59,7 @@ def testSearch(self): # Run search params = urllib.parse.urlencode({"query": "+hypertension ci", "limit": 1}) - results= client.get("search?%s" % params).json() + results= client.get(f"search?{params}").json() # Check number of results self.assertEqual(len(results), 1) diff --git a/test/python/testcolumn.py b/test/python/testcolumn.py new file mode 100644 index 0000000..43a53f9 --- /dev/null +++ b/test/python/testcolumn.py @@ -0,0 +1,73 @@ +""" +Column module tests +""" + +import unittest + +from paperai.report.column import Column + +class TestColumn(unittest.TestCase): + """ + Column tests + """ + + def testInteger(self): + """ + Tests parsing integers from strings. + """ + + self.assertEqual(Column.integer("Twenty Three"), "23") + self.assertEqual(Column.integer("Two hundred and twelve"), "212") + self.assertEqual(Column.integer("4,000,234"), "4000234") + self.assertEqual(Column.integer("23"), "23") + self.assertEqual(Column.integer("30 days"), None) + + def testCategorical(self): + """ + Tests generating categorical strings. + """ + + def model(text, labels): + return [(0, 0.9), (1, 0.85), (text, labels)] + + self.assertEqual(Column.categorical(model, "text", ["labels"]), "labels") + self.assertEqual(Column.categorical(None, "text", ["labels"]), "text") + + def testDurationStartEnd(self): + """ + Test duration ranges with start and end + """ + + self.assertEqual(Column.duration("2021-01-01 to 2021-01-31", "days"), 30) + self.assertEqual(Column.duration("2021-01-01 to 2021-01-31", "months"), 1) + self.assertEqual(round(Column.duration("2021-01-01 to 2021-01-31", "weeks"), 2), 4.29) + self.assertEqual(round(Column.duration("2021-01-01 to 2021-01-31", "years"), 2), 0.08) + + def testDurationStartEndNoYear(self): + """ + Test duration ranges with start and end but no year for first date + """ + + self.assertEqual(Column.duration("January to March 2020", "days"), 60) + + def testDurationRelative(self): + """ + Test relative duration ranges + """ + + self.assertEqual(Column.duration("30 day", "days"), 30) + + self.assertEqual(Column.duration("1 week", "days"), 7) + self.assertEqual(Column.duration("1 week", "months"), 0.25) + self.assertEqual(round(Column.duration("1 week", "years"), 2), 0.02) + + self.assertEqual(Column.duration("2 months", "days"), 60) + self.assertEqual(Column.duration("2 months", "weeks"), 8) + self.assertEqual(round(Column.duration("2 months", "years"), 2), 0.17) + + self.assertEqual(Column.duration("1 year", "days"), 365) + self.assertEqual(Column.duration("1 year", "weeks"), 52) + self.assertEqual(Column.duration("1 year", "months"), 12) + + self.assertEqual(Column.duration("30 moons", "days"), None) + self.assertEqual(Column.convert("30", "moons", "days"), "30") diff --git a/test/python/testexport.py b/test/python/testexport.py index 580ecd6..2fc90bb 100644 --- a/test/python/testexport.py +++ b/test/python/testexport.py @@ -20,4 +20,4 @@ def testRun(self): """ Export.run(Utils.PATH + "/export.txt", Utils.PATH) - self.assertEqual(Utils.hashfile(Utils.PATH + "/export.txt"), "a6f85df295a19f2d3c1a10ec8edce6ae") + self.assertEqual(Utils.hashfile(Utils.PATH + "/export.txt"), "ac15a3ece486c3035ef861f6706c3e1b") diff --git a/test/python/testindex.py b/test/python/testindex.py index 0d4c3f5..60a9c2e 100644 --- a/test/python/testindex.py +++ b/test/python/testindex.py @@ -20,7 +20,7 @@ def testStream(self): """ # Full index stream - self.assertEqual(len(list(Index.stream(Utils.DBFILE, 0))), 21478) + self.assertEqual(len(list(Index.stream(Utils.DBFILE, 0))), 29218) # Partial index stream - top n documents by entry date - self.assertEqual(len(list(Index.stream(Utils.DBFILE, 10))), 224) + self.assertEqual(len(list(Index.stream(Utils.DBFILE, 10))), 287) diff --git a/test/python/testquery.py b/test/python/testquery.py index 944fec3..231e5ab 100644 --- a/test/python/testquery.py +++ b/test/python/testquery.py @@ -2,7 +2,6 @@ Query module tests """ -import os import unittest from contextlib import redirect_stdout @@ -17,15 +16,14 @@ class TestQuery(unittest.TestCase): Query tests """ - @unittest.skipIf(os.name == "nt", "Faiss not installed on Windows") def testRun(self): """ Test query execution """ # Execute query - with open(Utils.PATH + "/query.txt", "w", newline="\n") as query: + with open(Utils.PATH + "/query.txt", "w", newline="\n", encoding="utf-8") as query: with redirect_stdout(query): Query.run("risk factors studied", 10, Utils.PATH) - self.assertEqual(Utils.hashfile(Utils.PATH + "/query.txt"), "b1932b9ceb6c2ea2b626ebb44d89340b") + self.assertEqual(Utils.hashfile(Utils.PATH + "/query.txt"), "b7ba65adc0aacccf161d61da8616bfca") diff --git a/test/python/testreport.py b/test/python/testreport.py index b40eb48..e4cb818 100644 --- a/test/python/testreport.py +++ b/test/python/testreport.py @@ -2,7 +2,6 @@ Report module tests """ -import os import unittest # pylint: disable=E0401 @@ -15,7 +14,6 @@ class TestReport(unittest.TestCase): Report tests """ - @unittest.skipIf(os.name == "nt", "Faiss not installed on Windows") def testReport1(self): """ Runs test queries from report1.yml test file @@ -25,16 +23,15 @@ def testReport1(self): Execute.run(Utils.PATH + "/report1.yml", 10, "csv", Utils.PATH, None) Execute.run(Utils.PATH + "/report1.yml", 10, "md", Utils.PATH, None) - hashes = [("Age.csv", "ed2b9c761dc949708cd6254e6207ff83"), - ("Heart Disease.csv", "90f2dede871c545dd1492aef8ed84645"), - ("Heart Failure.csv", "2152a8187ff53e9c4224e3c9891b5b33"), - ("Report1.md", "2da3dbcde55153ddaed1e709314eac07")] + hashes = [("Age.csv", "e5d589d131dce3293532e3fd19593637"), + ("Heart Disease.csv", "96b144fc1566e2c0aa774d098e203922"), + ("Heart Failure.csv", "afd812f7875c4fcb45bf800952327dba"), + ("Report1.md", "f1f3b70dd6242488f8d58e1e5d2faea6")] # Check file hashes for name, value in hashes: self.assertEqual(Utils.hashfile(Utils.PATH + "/" + name), value) - @unittest.skipIf(os.name == "nt", "Faiss not installed on Windows") def testReport2(self): """ Runs test queries from report2.yml test file @@ -44,17 +41,16 @@ def testReport2(self): Execute.run(Utils.PATH + "/report2.yml", 10, "csv", Utils.PATH, None) Execute.run(Utils.PATH + "/report2.yml", 10, "md", Utils.PATH, None) - hashes = [("Match.csv", "2def38a008f33f25d7ab4a763d159e80"), - ("MatchSurround.csv", "e9f581d19b8802822f47261bce0e91b1"), - ("Section.csv", "7ae8b295f0d959ba12410807db7b7e48"), - ("Surround.csv", "fed9fb4249bf2f73fa9822753d359207"), - ("Report2.md", "9218fe80fe5e9fdd50c5719f54c52061")] + hashes = [("Match.csv", "9cb5ce8896355d049084d61fae13d97f"), + ("MatchSurround.csv", "47e4d2ec7ae8fda30a78d628d124f204"), + ("Section.csv", "7113d5af95542193fc3dc21dc785b014"), + ("Surround.csv", "14d124f85c140077d58ae3636ba8557f"), + ("Report2.md", "7813d253d7a792f93915c2dccfb78483")] # Check file hashes for name, value in hashes: self.assertEqual(Utils.hashfile(Utils.PATH + "/" + name), value) - @unittest.skipIf(os.name == "nt", "Faiss not installed on Windows") def testReport3(self): """ Runs test queries from report3.yml test file @@ -65,8 +61,9 @@ def testReport3(self): Execute.run(Utils.PATH + "/report3.yml", 1, "md", Utils.PATH, None) Execute.run(Utils.PATH + "/report3.yml", 1, "ant", Utils.PATH, None, Utils.PATH) - hashes = [("AI.csv", "b47e96639a210d2089a5bd4e7e7bfc98"), - ("Report3.md", "1a47340bc135fc086160c62f8731edee")] + hashes = [("AI.csv", "94f0bead413eb71835c3f27881b29c91"), + ("All.csv", "3bca7a39a541fa68b3ef457625fb0120"), + ("Report3.md", "0fc53703dace57e3403294fb8ea7e9d1")] # Check file hashes for name, value in hashes: diff --git a/test/python/utils.py b/test/python/utils.py index 214ff83..417dbd6 100644 --- a/test/python/utils.py +++ b/test/python/utils.py @@ -25,6 +25,6 @@ def hashfile(path): MD5 hash """ - with open(path, "r") as data: + with open(path, "r", encoding="utf-8") as data: # Read file into string and build MD5 hash return hashlib.md5(data.read().encode()).hexdigest()