diff --git a/.gitignore b/.gitignore
index bf2ad57..3e152bb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,8 @@
build/
dist/
+htmlcov/
*egg-info/
__pycache__/
.coverage
+.coverage.*
*.pyc
diff --git a/.pylintrc b/.pylintrc
index 7d65d67..dcb45dd 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -17,4 +17,4 @@ min-public-methods=0
max-line-length=150
[MESSAGES CONTROL]
-disable=I0011,R0201,W0105,W0108,W0110,W0141,W0621,W0640
+disable=R0201
diff --git a/Makefile b/Makefile
index 711d53c..e0e41a0 100644
--- a/Makefile
+++ b/Makefile
@@ -14,7 +14,7 @@ PYTHON ?= python
# Download test data
data:
mkdir -p /tmp/paperai
- wget -N https://github.com/neuml/paperai/releases/download/v1.3.0/tests.tar.gz -P /tmp
+ wget -N https://github.com/neuml/paperai/releases/download/v1.10.0/tests.tar.gz -P /tmp
tar -xvzf /tmp/tests.tar.gz -C /tmp
# Unit tests
diff --git a/README.md b/README.md
index 6d22e8c..3dc633c 100644
--- a/README.md
+++ b/README.md
@@ -46,11 +46,11 @@ The easiest way to install is via pip and PyPI
pip install paperai
-You can also install paperai directly from GitHub. Using a Python Virtual Environment is recommended.
+Python 3.6+ is supported. Using a Python [virtual environment](https://docs.python.org/3/library/venv.html) is recommended.
- pip install git+https://github.com/neuml/paperai
+paperai can also be installed directly from GitHub to access the latest, unreleased features.
-Python 3.6+ is supported
+ pip install git+https://github.com/neuml/paperai
See [this link](https://github.com/neuml/txtai#installation) to help resolve environment-specific install issues.
@@ -125,7 +125,7 @@ no parameters are passed in.
## Building a report file
Reports support generating output in multiple formats. An example report call:
- python -m paperai.report tasks/risks.yml 50 md cord19/models
+ python -m paperai.report report.yml 50 md cord19/models
The following report formats are supported:
@@ -133,7 +133,7 @@ The following report formats are supported:
- CSV - Renders a CSV report. Columns and answers are extracted from articles with the results stored in a CSV file.
- Annotation - Columns and answers are extracted from articles with the results annotated over the original PDF files. Requires passing in a path with the original PDF files.
-In the example above, a file named tasks/risk_factors.md will be created. Example report configuration files can be found [here](https://github.com/neuml/cord19q/tree/master/tasks).
+In the example above, a file named report.md will be created. Example report configuration files can be found [here](https://github.com/neuml/cord19q/tree/master/tasks).
## Running queries
The fastest way to run queries is to start a paperai shell
diff --git a/demo.png b/demo.png
index 88efd1b..6ec9969 100644
Binary files a/demo.png and b/demo.png differ
diff --git a/examples/search.py b/examples/search.py
index 0d7b550..3339563 100644
--- a/examples/search.py
+++ b/examples/search.py
@@ -28,8 +28,8 @@ def __init__(self, path):
"""
# Default list of columns
- self.columns = [("Title", True), ("Published", False), ("Publication", False), ("Design", False), ("Sample", False),
- ("Method", False), ("Entry", False), ("Id", False), ("Content", True)]
+ self.columns = [("Title", True), ("Published", False), ("Publication", False), ("Entry", False),
+ ("Id", False), ("Content", True)]
# Load model
self.path = path
@@ -58,17 +58,16 @@ def search(self, query, topn, threshold):
# Print each result, sorted by max score descending
for uid in sorted(documents, key=lambda k: sum([x[0] for x in documents[k]]), reverse=True):
- cur.execute("SELECT Title, Published, Publication, Design, Size, Sample, Method, Entry, Id, Reference " +
+ cur.execute("SELECT Title, Published, Publication, Entry, Id, Reference " +
"FROM articles WHERE id = ?", [uid])
article = cur.fetchone()
matches = "<br/>".join([text for _, text in documents[uid]])
- title = "<a target='_blank' href='%s'>%s</a>" % (article[9], article[0])
+ title = f"<a target='_blank' href='{article[5]}'>{article[0]}</a>"
- article = {"Title": title, "Published": Query.date(article[1]), "Publication": article[2], "Design": Query.design(article[3]),
- "Sample": Query.sample(article[4], article[5]), "Method": Query.text(article[6]), "Entry": article[7],
- "Id": article[8], "Content": matches}
+ article = {"Title": title, "Published": Query.date(article[1]), "Publication": article[2], "Entry": article[3],
+ "Id": article[4], "Content": matches}
articles.append(article)
diff --git a/setup.py b/setup.py
index 77f9a0b..f247072 100644
--- a/setup.py
+++ b/setup.py
@@ -1,11 +1,11 @@
# pylint: disable = C0111
from setuptools import find_packages, setup
-with open("README.md", "r") as f:
+with open("README.md", "r", encoding="utf-8") as f:
DESCRIPTION = f.read()
setup(name="paperai",
- version="1.11.0",
+ version="2.0.0",
author="NeuML",
description="AI-powered literature discovery and review engine for medical/scientific papers",
long_description=DESCRIPTION,
@@ -32,6 +32,7 @@
"networkx>=2.4",
"PyYAML>=5.3",
"regex>=2020.5.14",
+ "text2digits>=0.1.0",
"txtai[api,similarity]>=3.4.0",
"txtmarker>=1.0.0"
],
diff --git a/src/python/paperai/api.py b/src/python/paperai/api.py
index 3e8c909..6b65c62 100644
--- a/src/python/paperai/api.py
+++ b/src/python/paperai/api.py
@@ -14,7 +14,7 @@ class API(txtai.api.API):
Extended API on top of txtai to return enriched query results.
"""
- def search(self, query, request):
+ def search(self, query, request=None):
"""
Extends txtai API to enrich results with content.
@@ -28,8 +28,8 @@ def search(self, query, request):
if self.embeddings:
dbfile = os.path.join(self.config["path"], "articles.sqlite")
- limit = self.limit(request.query_params.get("limit"))
- threshold = float(request.query_params["threshold"]) if "threshold" in request.query_params else None
+ limit = self.limit(request.query_params.get("limit")) if request else 10
+ threshold = float(request.query_params["threshold"]) if request and "threshold" in request.query_params else None
with sqlite3.connect(dbfile) as db:
cur = db.cursor()
@@ -44,16 +44,15 @@ def search(self, query, request):
# Print each result, sorted by max score descending
for uid in sorted(documents, key=lambda k: sum([x[0] for x in documents[k]]), reverse=True):
- cur.execute("SELECT Title, Published, Publication, Design, Size, Sample, Method, Entry, Id, Reference " +
+ cur.execute("SELECT Title, Published, Publication, Entry, Id, Reference " +
"FROM articles WHERE id = ?", [uid])
article = cur.fetchone()
score = max([score for score, text in documents[uid]])
matches = [text for _, text in documents[uid]]
- article = {"id": article[8], "score": score, "title": article[0], "published": Query.date(article[1]), "publication": article[2],
- "design": Query.design(article[3]), "sample": Query.sample(article[4], article[5]), "method": Query.text(article[6]),
- "entry": article[7], "reference": article[9], "matches": matches}
+ article = {"id": article[4], "score": score, "title": article[0], "published": Query.date(article[1]), "publication": article[2],
+ "entry": article[3], "reference": article[5], "matches": matches}
articles.append(article)
diff --git a/src/python/paperai/export.py b/src/python/paperai/export.py
index 1ab1582..d18670a 100644
--- a/src/python/paperai/export.py
+++ b/src/python/paperai/export.py
@@ -14,7 +14,7 @@
from .index import Index
from .models import Models
-class Export(object):
+class Export:
"""
Exports database rows into a text file line-by-line.
"""
@@ -29,26 +29,26 @@ def stream(dbfile, output):
output: output file to store text
"""
- with open(output, "w") as output:
+ with open(output, "w", encoding="utf-8") as output:
# Connection to database file
db = sqlite3.connect(dbfile)
cur = db.cursor()
- # Get all indexed text, with a detected study design, excluding modeling designs
- cur.execute(Index.SECTION_QUERY + " AND design NOT IN (0, 9)")
+ # Get all indexed text
+ cur.execute(Index.SECTION_QUERY)
count = 0
for _, name, text in cur:
if not name or not re.search(Index.SECTION_FILTER, name.lower()):
count += 1
if count % 1000 == 0:
- print("Streamed %d documents" % (count), end="\r")
+ print(f"Streamed {count} documents", end="\r")
# Write row
if text:
output.write(text + "\n")
- print("Iterated over %d total rows" % (count))
+ print(f"Iterated over {count} total rows")
# Free database resources
db.close()
diff --git a/src/python/paperai/highlights.py b/src/python/paperai/highlights.py
index 13805d8..461c70c 100644
--- a/src/python/paperai/highlights.py
+++ b/src/python/paperai/highlights.py
@@ -8,7 +8,7 @@
from txtai.pipeline import Tokenizer
-class Highlights(object):
+class Highlights:
"""
Methods to extract highlights from a list of text sections.
"""
diff --git a/src/python/paperai/index.py b/src/python/paperai/index.py
index a0729af..1d713b0 100644
--- a/src/python/paperai/index.py
+++ b/src/python/paperai/index.py
@@ -14,14 +14,14 @@
from .models import Models
-class Index(object):
+class Index:
"""
Methods to build a new sentence embeddings index.
"""
# Section query and filtering logic constants
SECTION_FILTER = r"background|(? 0:
- query += " AND article in (SELECT id FROM articles ORDER BY entry DESC LIMIT %d)" % maxsize
+ query += f" AND article in (SELECT id FROM articles ORDER BY entry DESC LIMIT {maxsize})"
# Run the query
cur.execute(query)
@@ -59,13 +59,13 @@ def stream(dbfile, maxsize):
count += 1
if count % 1000 == 0:
- print("Streamed %d documents" % (count), end="\r")
+ print(f"Streamed {count} documents", end="\r")
# Skip documents with no tokens parsed
if tokens:
yield document
- print("Iterated over %d total rows" % (count))
+ print(f"Iterated over {count} total rows")
# Free database resources
db.close()
@@ -88,7 +88,7 @@ def config(vectors):
# Read YAML index configuration
if vectors.endswith(".yml"):
- with open(vectors, "r") as f:
+ with open(vectors, "r", encoding="utf-8") as f:
return yaml.safe_load(f)
return {"path": vectors, "scoring": "bm25", "pca": 3, "quantize": True}
diff --git a/src/python/paperai/models.py b/src/python/paperai/models.py
index 075308a..770d6bd 100644
--- a/src/python/paperai/models.py
+++ b/src/python/paperai/models.py
@@ -8,7 +8,7 @@
from txtai.embeddings import Embeddings
-class Models(object):
+class Models:
"""
Common methods for generating data paths.
"""
@@ -95,7 +95,7 @@ def load(path):
dbfile = os.path.join(path, "articles.sqlite")
if os.path.isfile(os.path.join(path, "config")):
- print("Loading model from %s" % path)
+ print(f"Loading model from {path}")
embeddings = Embeddings()
embeddings.load(path)
else:
diff --git a/src/python/paperai/query.py b/src/python/paperai/query.py
index 07e9cc3..7c97f75 100644
--- a/src/python/paperai/query.py
+++ b/src/python/paperai/query.py
@@ -14,7 +14,7 @@
from .highlights import Highlights
from .models import Models
-class Query(object):
+class Query:
"""
Methods to query an embeddings index.
"""
@@ -238,7 +238,7 @@ def authors(authors):
else:
authors = authors.split()[-1]
- return "%s et al" % authors
+ return f"{authors} et al"
return None
@@ -289,40 +289,6 @@ def text(text):
return text
- @staticmethod
- def design(design):
- """
- Formats a study design field.
-
- Args:
- design: study design integer
-
- Returns:
- Study Design string
- """
-
- # Study design type mapping
- mapping = {1:"Systematic review", 2:"Randomized control trial", 3:"Non-randomized trial",
- 4:"Prospective observational", 5:"Time-to-event analysis", 6:"Retrospective observational",
- 7:"Cross-sectional", 8:"Case series", 9:"Modeling", 0:"Other"}
-
- return mapping[design]
-
- @staticmethod
- def sample(size, text):
- """
- Formats a sample string.
-
- Args:
- size: Sample size
- text: Sample text
-
- Returns:
- Formatted sample text
- """
-
- return "[%s] %s" % (size, Query.text(text)) if size else Query.text(text)
-
@staticmethod
def query(embeddings, db, query, topn, threshold):
"""
@@ -341,7 +307,7 @@ def query(embeddings, db, query, topn, threshold):
cur = db.cursor()
- print(Query.render("#Query: %s" % query, theme="729.8953") + "\n")
+ print(Query.render(f"#Query: {query}", theme="729.8953") + "\n")
# Query for best matches
results = Query.search(embeddings, cur, query, topn, threshold)
@@ -349,7 +315,7 @@ def query(embeddings, db, query, topn, threshold):
# Extract top sections as highlights
print(Query.render("# Highlights"))
for highlight in Query.highlights(results, int(topn / 5)):
- print(Query.render("## - %s" % Query.text(highlight)))
+ print(Query.render(f"## - {Query.text(highlight)}"))
print()
@@ -360,22 +326,19 @@ def query(embeddings, db, query, topn, threshold):
# Print each result, sorted by max score descending
for uid in sorted(documents, key=lambda k: sum([x[0] for x in documents[k]]), reverse=True):
- cur.execute("SELECT Title, Published, Publication, Design, Size, Sample, Method, Entry, Id, Reference FROM articles WHERE id = ?", [uid])
+ cur.execute("SELECT Title, Published, Publication, Entry, Id, Reference FROM articles WHERE id = ?", [uid])
article = cur.fetchone()
- print("Title: %s" % article[0])
- print("Published: %s" % Query.date(article[1]))
- print("Publication: %s" % article[2])
- print("Design: %s" % Query.design(article[3]))
- print("Sample: %s" % Query.sample(article[4], article[5]))
- print("Method: %s" % Query.text(article[6]))
- print("Entry: %s" % article[7])
- print("Id: %s" % article[8])
- print("Reference: %s" % article[9])
+ print(f"Title: {article[0]}")
+ print(f"Published: {Query.date(article[1])}")
+ print(f"Publication: {article[2]}")
+ print(f"Entry: {article[3]}")
+ print(f"Id: {article[4]}")
+ print(f"Reference: {article[5]}")
# Print top matches
for score, text in documents[uid]:
- print(Query.render("## - (%.4f): %s" % (score, Query.text(text)), html=False))
+ print(Query.render(f"## - ({score:.4f}): {Query.text(text)}", html=False))
print()
diff --git a/src/python/paperai/report/annotate.py b/src/python/paperai/report/annotate.py
index de0c521..43faf44 100644
--- a/src/python/paperai/report/annotate.py
+++ b/src/python/paperai/report/annotate.py
@@ -45,7 +45,7 @@ def headers(self, columns, output):
self.names = columns
# Do not annotate following columns
- for field in ["Date", "Study", "Study Link", "Journal", "Study Type", "Sample Size", "Matches", "Entry"]:
+ for field in ["Date", "Study", "Study Link", "Journal", "Matches", "Entry", "Id"]:
if field in self.names:
self.names.remove(field)
@@ -62,12 +62,6 @@ def buildRow(self, article, sections, calculated):
# Source
row["Source"] = article[4]
- # Sample Text
- row["Sample Text"] = article[7]
-
- # Study Population
- row["Study Population"] = Query.text(article[8] if article[8] else article[7])
-
# Merge in calculated fields
row.update(calculated)
@@ -156,7 +150,7 @@ def formatter(self, text):
patterns.append(r"(\(\d+\)\s){3,}")
# Build regex pattern
- pattern = re.compile("|".join(["(%s)" % p for p in patterns]))
+ pattern = re.compile("|".join([f"({p})" for p in patterns]))
text = pattern.sub(" ", text)
diff --git a/src/python/paperai/report/column.py b/src/python/paperai/report/column.py
new file mode 100644
index 0000000..6fd74f1
--- /dev/null
+++ b/src/python/paperai/report/column.py
@@ -0,0 +1,168 @@
+"""
+Column module
+"""
+
+import regex as re
+
+from dateutil.parser import parse
+from text2digits.text2digits import Text2Digits
+
+class Column:
+ """
+ Column formatting functions for reports.
+ """
+
+ @staticmethod
+ def integer(text):
+ """
+ Format text as a string. This method also converts text describing a number to a number.
+ For example, twenty three is converted to 23.
+
+ Args:
+ text: input text
+
+ Returns:
+ number if parsed, otherwise None
+ """
+
+ # Format text for numeric parsing
+ text = text.replace(",", "")
+ text = re.sub(r"(\d+)\s+(\d+)", r"\1\2", text)
+
+ try:
+ # Convert numeric words to numbers
+ text = text if text.isnumeric() else Text2Digits().convert(text)
+ # pylint: disable=W0702
+ except:
+ pass
+
+ return text if text.isnumeric() else None
+
+ @staticmethod
+ def categorical(model, text, labels):
+ """
+ Applies a text classification model to text using labels.
+
+ Args:
+ model: text classification model
+ text: input text
+ labels: labels to use
+
+ Returns:
+ categorical label, if model not None, otherwise original text returned
+ """
+
+ if model:
+ index = model(text, labels)[0][0] if model else text
+ return labels[index]
+
+ return text
+
+ @staticmethod
+ def duration(text, dtype):
+ """
+ Attempts to standardize a date duration string to a format specified by dtype.
+ If dtype is days and the duration string specifies a month range, the duration is converted
+ to days. If the duration is in years and the dtype is months, the duration is converted to months.
+
+ Examples:
+ 2021-01-01 to 2021-01-31. In days = 30, in months = 1, in years = 0.083
+ Jan 2021 to Mar 2021. In days = 60, in months = 2, in years = 0.167
+
+ Args:
+ text: date duration string
+ dtype: target duration type [supports days, weeks, months, years]
+
+ Returns:
+ duration as number
+ """
+
+ try:
+ data = re.sub(r"(?i)\s*(between|over|up to)\s+", "", text)
+ data = re.split(r"\s+(?:and|to|through)\s+", data)
+
+ d1, d2 = parse(data[0]), parse(data[1])
+ value = (d2 - d1).days
+
+ # Handle case where no year specified for first date
+ if value < 0:
+ d1 = d1.replace(year=d2.year)
+ value = (d2 - d1).days
+
+ return Column.convert(value, "days", dtype)
+ # pylint: disable=W0702
+ except:
+ pass
+
+ data = re.sub(r"\(.*?\)", "", text)
+ data = re.split(r"\s+|\-", data)
+ data = sorted(data)
+
+ data[0] = Text2Digits().convert(data[0])
+ if len(data) > 1 and not data[1].endswith("s"):
+ data[1] = data[1] + "s"
+
+ if len(data) == 2 and \
+ (data[0].replace(".", "", 1).isnumeric() and data[1] in (["days", "weeks", "months", "years"])):
+
+ value, suffix = sorted(data)
+ value = float(value)
+
+ return Column.convert(value, suffix, dtype)
+
+ return None
+
+ @staticmethod
+ def convert(value, itype, otype):
+ """
+ Attempts to convert a numeric duration from itype to otype.
+
+ Args:
+ value: numeric duration
+ itype: input type [days, weeks, months, years]
+ otype: output type [days, weeks, months, years]
+
+ Returns:
+ converted numeric duration
+ """
+
+ if itype == otype:
+ return value
+
+ if itype == "days":
+ if otype == "weeks":
+ return value / 7
+ if otype == "months":
+ return value / 30
+
+ # Years
+ return value / 365
+
+ if itype == "weeks":
+ if otype == "days":
+ return value * 7
+ if otype == "months":
+ return value / 4
+
+ # Years
+ return value / 52
+
+ if itype == "months":
+ if otype == "days":
+ return value * 30
+ if otype == "weeks":
+ return value * 4
+
+ # Years
+ return value / 12
+
+ if itype == "years":
+ if otype == "days":
+ return value * 365
+ if otype == "weeks":
+ return value * 52
+
+ # Months
+ return value * 12
+
+ return value
diff --git a/src/python/paperai/report/common.py b/src/python/paperai/report/common.py
index e0d6837..b31f70c 100644
--- a/src/python/paperai/report/common.py
+++ b/src/python/paperai/report/common.py
@@ -4,12 +4,14 @@
import regex as re
-from txtai.pipeline import Extractor, Similarity
+from txtai.pipeline import Extractor, Labels, Similarity
from ..index import Index
from ..query import Query
-class Report(object):
+from .column import Column
+
+class Report:
"""
Methods to build reports from a series of queries
"""
@@ -34,12 +36,16 @@ def __init__(self, embeddings, db, options):
# Column names
self.names = []
+ self.similarity = Similarity(options["similarity"]) if "similarity" in options else None
+ self.labels = Labels(model=self.similarity) if self.similarity else None
+
# Extractive question-answering model
# Determine if embeddings or a custom similarity model should be used to build question context
- self.extractor = Extractor(Similarity(options["similarity"]) if "similarity" in options else self.embeddings,
+ self.extractor = Extractor(self.similarity if self.similarity else self.embeddings,
options["qa"] if options["qa"] else "NeuML/bert-small-cord19qa",
minscore=options.get("minscore"),
- mintokens=options.get("mintokens"))
+ mintokens=options.get("mintokens"),
+ context=options.get("context"))
def build(self, queries, options, output):
"""
@@ -128,12 +134,14 @@ def articles(self, output, topn, metadata, results):
# Collect matching rows
rows = []
- for uid in documents:
+ for x, uid in enumerate(documents):
# Get article metadata
- self.cur.execute("SELECT Published, Title, Reference, Publication, Source, Design, Size, Sample, Method, Entry " +
- "FROM articles WHERE id = ?", [uid])
+ self.cur.execute("SELECT Published, Title, Reference, Publication, Source, Entry, Id FROM articles WHERE id = ?", [uid])
article = self.cur.fetchone()
+ if x and x % 100 == 0:
+ print(f"Processed {x} documents", end="\r")
+
# Calculate derived fields
calculated = self.calculate(uid, metadata)
@@ -166,18 +174,23 @@ def calculate(self, uid, metadata):
# Parse column parameters
fields, params = self.params(metadata)
- # Stores embedding query only and full QA column definitions
- queries, questions = [], []
+ # Different type of calculations
+ # 1. Similarity query
+ # 2. Extractor query (similarity + question)
+ # 3. Question-answering on other field
+ queries, extractions, questions = [], [], []
# Retrieve indexed document text for article
sections = self.sections(uid)
texts = [text for _, text in sections]
- for name, query, question, snippet, _, _, matches in params:
- if matches:
+ for name, query, question, snippet, _, _, matches, _ in params:
+ if query.startswith("$"):
+ questions.append((name, query.replace("$", ""), question, snippet))
+ elif matches:
queries.append((name, query, matches))
else:
- questions.append((name, query, question, snippet))
+ extractions.append((name, query, question, snippet))
# Run all extractor queries against document text
results = self.extractor.query([query for _, query, _ in queries], texts)
@@ -189,12 +202,25 @@ def calculate(self, uid, metadata):
topn = [text for _, text, _ in results[x]][:matches]
# Join results into String and return
- fields[name] = "\n\n".join([self.resolve(params, sections, uid, name, value) for value in topn])
+ value = [self.resolve(params, sections, uid, name, value) for value in topn]
+ fields[name] = "\n\n".join(value) if value else None
else:
fields[name] = None
# Add extraction fields
- for name, value in self.extractor(questions, texts):
+ for name, value in self.extractor(extractions, texts):
+ # Resolves the full value based on column parameters
+ fields[name] = self.resolve(params, sections, uid, name, value) if value else ""
+
+ # Add question fields
+ names, qa, contexts, snippets = [], [], [], []
+ for name, query, question, snippet in questions:
+ names.append(name)
+ qa.append(question)
+ contexts.append(fields[query])
+ snippets.append(snippet)
+
+ for name, value in self.extractor.answers(names, qa, contexts, contexts, snippets):
# Resolves the full value based on column parameters
fields[name] = self.resolve(params, sections, uid, name, value) if value else ""
@@ -231,13 +257,14 @@ def params(self, metadata):
question = self.variables(column["question"], metadata) if "question" in column else query
# Additional context parameters
- section = column["section"] if "section" in column else False
- surround = column["surround"] if "surround" in column else 0
- matches = column["matches"] if "matches" in column else 0
- snippet = column["snippet"] if "snippet" in column else False
+ section = column.get("section", False)
+ surround = column.get("surround", 0)
+ matches = column.get("matches", 0)
+ dtype = column.get("dtype")
+ snippet = column.get("snippet", False)
snippet = True if section or surround else snippet
- params.append((column["name"], query, question, snippet, section, surround, matches))
+ params.append((column["name"], query, question, snippet, section, surround, matches, dtype))
return fields, params
@@ -276,7 +303,7 @@ def sections(self, uid):
"""
# Retrieve indexed document text for article
- self.cur.execute(Index.SECTION_QUERY + " AND article = ? ORDER BY id", [uid])
+ self.cur.execute(Index.SECTION_QUERY + " WHERE article = ? ORDER BY id", [uid])
# Get list of document text sections
sections = []
@@ -307,7 +334,7 @@ def resolve(self, params, sections, uid, name, value):
# Get all column parameters
index = [params.index(x) for x in params if x[0] == name][0]
- _, _, _, _, section, surround, _ = params[index]
+ _, _, _, _, section, surround, _, dtype = params[index]
if value:
# Find matching section
@@ -322,6 +349,14 @@ def resolve(self, params, sections, uid, name, value):
elif surround:
value = self.surround(uid, sid, surround)
+ # Column dtype formatting
+ if dtype == "int":
+ value = Column.integer(value)
+ elif isinstance(dtype, list):
+ value = Column.categorical(self.labels, value, dtype)
+ elif dtype in ["days", "weeks", "months", "years"]:
+ value = Column.duration(value, dtype)
+
return value
def subsection(self, uid, sid):
diff --git a/src/python/paperai/report/csvr.py b/src/python/paperai/report/csvr.py
index d480c43..3f8aa06 100644
--- a/src/python/paperai/report/csvr.py
+++ b/src/python/paperai/report/csvr.py
@@ -31,7 +31,7 @@ def query(self, output, task, query):
if self.csvout:
self.csvout.close()
- self.csvout = open(os.path.join(os.path.dirname(output.name), "%s.csv" % task), "w", newline="")
+ self.csvout = open(os.path.join(os.path.dirname(output.name), f"{task}.csv"), "w", newline="", encoding="utf-8")
self.writer = csv.writer(self.csvout, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
def write(self, row):
@@ -69,23 +69,14 @@ def buildRow(self, article, sections, calculated):
# Source
row["Source"] = article[4]
- # Study Type
- row["Study Type"] = Query.design(article[5])
-
- # Sample Size
- row["Sample Size"] = article[6]
-
- # Study Population
- row["Study Population"] = Query.text(article[8] if article[8] else article[7])
-
- # Sample Text
- row["Sample Text"] = article[7]
-
# Top Matches
row["Matches"] = "\n\n".join([Query.text(text) for _, text in sections]) if sections else ""
# Entry Date
- row["Entry"] = article[9] if article[9] else ""
+ row["Entry"] = article[5] if article[5] else ""
+
+ # Id
+ row["Id"] = article[6]
# Merge in calculated fields
row.update(calculated)
diff --git a/src/python/paperai/report/execute.py b/src/python/paperai/report/execute.py
index 94770ef..9e58588 100644
--- a/src/python/paperai/report/execute.py
+++ b/src/python/paperai/report/execute.py
@@ -11,7 +11,7 @@
from ..models import Models
-class Execute(object):
+class Execute:
"""
Creates a Report
"""
@@ -71,10 +71,10 @@ def run(task, topn=None, render=None, path=None, qa=None, indir=None, threshold=
report = Execute.create(render, embeddings, db, options)
# Generate output filename
- outfile = os.path.join(outdir, "%s.%s" % (name, render))
+ outfile = os.path.join(outdir, f"{name}.{render}")
# Stream report to file
- with open(outfile, "w") as output:
+ with open(outfile, "w", encoding="utf-8") as output:
# Build the report
report.build(queries, options, output)
diff --git a/src/python/paperai/report/markdown.py b/src/python/paperai/report/markdown.py
index 4815a9e..87a3914 100644
--- a/src/python/paperai/report/markdown.py
+++ b/src/python/paperai/report/markdown.py
@@ -48,20 +48,20 @@ def write(self, output, line):
line: line to write
"""
- output.write("%s\n" % line)
+ output.write(f"{line}\n")
def query(self, output, task, query):
- self.write(output, "# %s" % query)
+ self.write(output, f"# {query}")
def section(self, output, name):
- self.write(output, "#### %s<br/>" % name)
+ self.write(output, f"#### {name}<br/>")
def highlight(self, output, article, highlight):
# Build citation link
- link = "[%s](%s)" % (Query.authors(article[0]) if article[0] else "Source", self.encode(article[1]))
+ link = f"[{Query.authors(article[0]) if article[0] else 'Source'}]({self.encode(article[1])})"
# Build highlight row with citation link
- self.write(output, "- %s %s<br/>" % (Query.text(highlight), link))
+ self.write(output, f"- {Query.text(highlight)} {link}<br/>")
def headers(self, columns, output):
self.names = columns
@@ -73,11 +73,11 @@ def headers(self, columns, output):
# Write table header
headers = "|".join(self.names)
- self.write(output, "|%s|" % headers)
+ self.write(output, f"|{headers}|")
# Write markdown separator for headers
headers = "|".join(["----"] * len(self.names))
- self.write(output, "|%s|" % headers)
+ self.write(output, f"|{headers}|")
def buildRow(self, article, sections, calculated):
row = {}
@@ -86,10 +86,10 @@ def buildRow(self, article, sections, calculated):
row["Date"] = Query.date(article[0]) if article[0] else ""
# Title
- title = "[%s](%s)" % (article[1], self.encode(article[2]))
+ title = f"[{article[1]}]({self.encode(article[2])})"
# Append Publication if available. Assume preprint otherwise and show preprint source.
- title += "<br/>%s" % (article[3] if article[3] else article[4])
+ title += f"<br/>{article[3] if article[3] else article[4]}"
# Source
row["Source"] = article[4]
@@ -97,30 +97,23 @@ def buildRow(self, article, sections, calculated):
# Title + Publication if available
row["Study"] = title
- # Study Type
- row["Study Type"] = Query.design(article[5])
-
- # Sample Size
- sample = Query.sample(article[6], article[7])
- row["Sample Size"] = sample if sample else ""
-
- # Study Population
- row["Study Population"] = Query.text(article[8]) if article[8] else ""
-
# Top Matches
row["Matches"] = "<br/>".join([Query.text(text) for _, text in sections]) if sections else ""
# Entry Date
- row["Entry"] = article[9] if article[9] else ""
+ row["Entry"] = article[5] if article[5] else ""
+
+ # Id
+ row["Id"] = article[6]
# Merge in calculated fields
row.update(calculated)
# Escape | characters embedded within columns
- return {column: self.column(row[column]) for column in row}
+ return {name: self.column(value) for name, value in row.items()}
def writeRow(self, output, row):
- self.write(output, "|%s|" % "|".join(row))
+ self.write(output, f"|{'|'.join(row)}|")
def separator(self, output):
# Write section separator
diff --git a/src/python/paperai/report/task.py b/src/python/paperai/report/task.py
index 2637217..9a5a697 100644
--- a/src/python/paperai/report/task.py
+++ b/src/python/paperai/report/task.py
@@ -6,7 +6,7 @@
import yaml
-class Task(object):
+class Task:
"""
YAML task configuration loader
"""
@@ -25,7 +25,7 @@ def load(task):
if os.path.exists(task):
# Load tasks yml file
- with open(task, "r") as f:
+ with open(task, "r", encoding="utf-8") as f:
# Read configuration
config = yaml.safe_load(f)
diff --git a/src/python/paperai/vectors.py b/src/python/paperai/vectors.py
index 7d96a93..c51ebd5 100644
--- a/src/python/paperai/vectors.py
+++ b/src/python/paperai/vectors.py
@@ -13,7 +13,7 @@
from .models import Models
-class RowIterator(object):
+class RowIterator:
"""
Iterates over rows in a database query. Allows for multiple iterations.
"""
@@ -78,18 +78,18 @@ def stream(self, dbfile):
count += 1
if count % 1000 == 0:
- print("Streamed %d documents" % (count), end="\r")
+ print(f"Streamed {count} documents", end="\r")
# Skip documents with no tokens parsed
if tokens:
yield tokens
- print("Iterated over %d total rows" % (count))
+ print(f"Iterated over {count} total rows")
# Free database resources
db.close()
-class Vectors(object):
+class Vectors:
"""
Methods to build a FastText model.
"""
@@ -142,7 +142,7 @@ def run(path, size, mincount, output):
if not output:
# Output file path
- output = Models.vectorPath("cord19-%dd" % size, True)
+ output = Models.vectorPath(f"cord19-{size}d", True)
# Build word vectors model
WordVectors.build(tokens, size, mincount, output)
diff --git a/test/python/testapi.py b/test/python/testapi.py
index 52cdd1b..f735e7f 100644
--- a/test/python/testapi.py
+++ b/test/python/testapi.py
@@ -38,7 +38,7 @@ def start():
config = os.path.join(tempfile.gettempdir(), "testapi.yml")
- with open(config, "w") as output:
+ with open(config, "w", encoding="utf-8") as output:
output.write(INDEX % Utils.PATH)
client = TestClient(app)
@@ -59,7 +59,7 @@ def testSearch(self):
# Run search
params = urllib.parse.urlencode({"query": "+hypertension ci", "limit": 1})
- results= client.get("search?%s" % params).json()
+ results= client.get(f"search?{params}").json()
# Check number of results
self.assertEqual(len(results), 1)
diff --git a/test/python/testcolumn.py b/test/python/testcolumn.py
new file mode 100644
index 0000000..43a53f9
--- /dev/null
+++ b/test/python/testcolumn.py
@@ -0,0 +1,73 @@
+"""
+Column module tests
+"""
+
+import unittest
+
+from paperai.report.column import Column
+
+class TestColumn(unittest.TestCase):
+ """
+ Column tests
+ """
+
+ def testInteger(self):
+ """
+ Tests parsing integers from strings.
+ """
+
+ self.assertEqual(Column.integer("Twenty Three"), "23")
+ self.assertEqual(Column.integer("Two hundred and twelve"), "212")
+ self.assertEqual(Column.integer("4,000,234"), "4000234")
+ self.assertEqual(Column.integer("23"), "23")
+ self.assertEqual(Column.integer("30 days"), None)
+
+ def testCategorical(self):
+ """
+ Tests generating categorical strings.
+ """
+
+ def model(text, labels):
+ return [(0, 0.9), (1, 0.85), (text, labels)]
+
+ self.assertEqual(Column.categorical(model, "text", ["labels"]), "labels")
+ self.assertEqual(Column.categorical(None, "text", ["labels"]), "text")
+
+ def testDurationStartEnd(self):
+ """
+ Test duration ranges with start and end
+ """
+
+ self.assertEqual(Column.duration("2021-01-01 to 2021-01-31", "days"), 30)
+ self.assertEqual(Column.duration("2021-01-01 to 2021-01-31", "months"), 1)
+ self.assertEqual(round(Column.duration("2021-01-01 to 2021-01-31", "weeks"), 2), 4.29)
+ self.assertEqual(round(Column.duration("2021-01-01 to 2021-01-31", "years"), 2), 0.08)
+
+ def testDurationStartEndNoYear(self):
+ """
+ Test duration ranges with start and end but no year for first date
+ """
+
+ self.assertEqual(Column.duration("January to March 2020", "days"), 60)
+
+ def testDurationRelative(self):
+ """
+ Test relative duration ranges
+ """
+
+ self.assertEqual(Column.duration("30 day", "days"), 30)
+
+ self.assertEqual(Column.duration("1 week", "days"), 7)
+ self.assertEqual(Column.duration("1 week", "months"), 0.25)
+ self.assertEqual(round(Column.duration("1 week", "years"), 2), 0.02)
+
+ self.assertEqual(Column.duration("2 months", "days"), 60)
+ self.assertEqual(Column.duration("2 months", "weeks"), 8)
+ self.assertEqual(round(Column.duration("2 months", "years"), 2), 0.17)
+
+ self.assertEqual(Column.duration("1 year", "days"), 365)
+ self.assertEqual(Column.duration("1 year", "weeks"), 52)
+ self.assertEqual(Column.duration("1 year", "months"), 12)
+
+ self.assertEqual(Column.duration("30 moons", "days"), None)
+ self.assertEqual(Column.convert("30", "moons", "days"), "30")
diff --git a/test/python/testexport.py b/test/python/testexport.py
index 580ecd6..2fc90bb 100644
--- a/test/python/testexport.py
+++ b/test/python/testexport.py
@@ -20,4 +20,4 @@ def testRun(self):
"""
Export.run(Utils.PATH + "/export.txt", Utils.PATH)
- self.assertEqual(Utils.hashfile(Utils.PATH + "/export.txt"), "a6f85df295a19f2d3c1a10ec8edce6ae")
+ self.assertEqual(Utils.hashfile(Utils.PATH + "/export.txt"), "ac15a3ece486c3035ef861f6706c3e1b")
diff --git a/test/python/testindex.py b/test/python/testindex.py
index 0d4c3f5..60a9c2e 100644
--- a/test/python/testindex.py
+++ b/test/python/testindex.py
@@ -20,7 +20,7 @@ def testStream(self):
"""
# Full index stream
- self.assertEqual(len(list(Index.stream(Utils.DBFILE, 0))), 21478)
+ self.assertEqual(len(list(Index.stream(Utils.DBFILE, 0))), 29218)
# Partial index stream - top n documents by entry date
- self.assertEqual(len(list(Index.stream(Utils.DBFILE, 10))), 224)
+ self.assertEqual(len(list(Index.stream(Utils.DBFILE, 10))), 287)
diff --git a/test/python/testquery.py b/test/python/testquery.py
index 944fec3..231e5ab 100644
--- a/test/python/testquery.py
+++ b/test/python/testquery.py
@@ -2,7 +2,6 @@
Query module tests
"""
-import os
import unittest
from contextlib import redirect_stdout
@@ -17,15 +16,14 @@ class TestQuery(unittest.TestCase):
Query tests
"""
- @unittest.skipIf(os.name == "nt", "Faiss not installed on Windows")
def testRun(self):
"""
Test query execution
"""
# Execute query
- with open(Utils.PATH + "/query.txt", "w", newline="\n") as query:
+ with open(Utils.PATH + "/query.txt", "w", newline="\n", encoding="utf-8") as query:
with redirect_stdout(query):
Query.run("risk factors studied", 10, Utils.PATH)
- self.assertEqual(Utils.hashfile(Utils.PATH + "/query.txt"), "b1932b9ceb6c2ea2b626ebb44d89340b")
+ self.assertEqual(Utils.hashfile(Utils.PATH + "/query.txt"), "b7ba65adc0aacccf161d61da8616bfca")
diff --git a/test/python/testreport.py b/test/python/testreport.py
index b40eb48..e4cb818 100644
--- a/test/python/testreport.py
+++ b/test/python/testreport.py
@@ -2,7 +2,6 @@
Report module tests
"""
-import os
import unittest
# pylint: disable=E0401
@@ -15,7 +14,6 @@ class TestReport(unittest.TestCase):
Report tests
"""
- @unittest.skipIf(os.name == "nt", "Faiss not installed on Windows")
def testReport1(self):
"""
Runs test queries from report1.yml test file
@@ -25,16 +23,15 @@ def testReport1(self):
Execute.run(Utils.PATH + "/report1.yml", 10, "csv", Utils.PATH, None)
Execute.run(Utils.PATH + "/report1.yml", 10, "md", Utils.PATH, None)
- hashes = [("Age.csv", "ed2b9c761dc949708cd6254e6207ff83"),
- ("Heart Disease.csv", "90f2dede871c545dd1492aef8ed84645"),
- ("Heart Failure.csv", "2152a8187ff53e9c4224e3c9891b5b33"),
- ("Report1.md", "2da3dbcde55153ddaed1e709314eac07")]
+ hashes = [("Age.csv", "e5d589d131dce3293532e3fd19593637"),
+ ("Heart Disease.csv", "96b144fc1566e2c0aa774d098e203922"),
+ ("Heart Failure.csv", "afd812f7875c4fcb45bf800952327dba"),
+ ("Report1.md", "f1f3b70dd6242488f8d58e1e5d2faea6")]
# Check file hashes
for name, value in hashes:
self.assertEqual(Utils.hashfile(Utils.PATH + "/" + name), value)
- @unittest.skipIf(os.name == "nt", "Faiss not installed on Windows")
def testReport2(self):
"""
Runs test queries from report2.yml test file
@@ -44,17 +41,16 @@ def testReport2(self):
Execute.run(Utils.PATH + "/report2.yml", 10, "csv", Utils.PATH, None)
Execute.run(Utils.PATH + "/report2.yml", 10, "md", Utils.PATH, None)
- hashes = [("Match.csv", "2def38a008f33f25d7ab4a763d159e80"),
- ("MatchSurround.csv", "e9f581d19b8802822f47261bce0e91b1"),
- ("Section.csv", "7ae8b295f0d959ba12410807db7b7e48"),
- ("Surround.csv", "fed9fb4249bf2f73fa9822753d359207"),
- ("Report2.md", "9218fe80fe5e9fdd50c5719f54c52061")]
+ hashes = [("Match.csv", "9cb5ce8896355d049084d61fae13d97f"),
+ ("MatchSurround.csv", "47e4d2ec7ae8fda30a78d628d124f204"),
+ ("Section.csv", "7113d5af95542193fc3dc21dc785b014"),
+ ("Surround.csv", "14d124f85c140077d58ae3636ba8557f"),
+ ("Report2.md", "7813d253d7a792f93915c2dccfb78483")]
# Check file hashes
for name, value in hashes:
self.assertEqual(Utils.hashfile(Utils.PATH + "/" + name), value)
- @unittest.skipIf(os.name == "nt", "Faiss not installed on Windows")
def testReport3(self):
"""
Runs test queries from report3.yml test file
@@ -65,8 +61,9 @@ def testReport3(self):
Execute.run(Utils.PATH + "/report3.yml", 1, "md", Utils.PATH, None)
Execute.run(Utils.PATH + "/report3.yml", 1, "ant", Utils.PATH, None, Utils.PATH)
- hashes = [("AI.csv", "b47e96639a210d2089a5bd4e7e7bfc98"),
- ("Report3.md", "1a47340bc135fc086160c62f8731edee")]
+ hashes = [("AI.csv", "94f0bead413eb71835c3f27881b29c91"),
+ ("All.csv", "3bca7a39a541fa68b3ef457625fb0120"),
+ ("Report3.md", "0fc53703dace57e3403294fb8ea7e9d1")]
# Check file hashes
for name, value in hashes:
diff --git a/test/python/utils.py b/test/python/utils.py
index 214ff83..417dbd6 100644
--- a/test/python/utils.py
+++ b/test/python/utils.py
@@ -25,6 +25,6 @@ def hashfile(path):
MD5 hash
"""
- with open(path, "r") as data:
+ with open(path, "r", encoding="utf-8") as data:
# Read file into string and build MD5 hash
return hashlib.md5(data.read().encode()).hexdigest()