Commit

Update project files for Python 3.9 and coding standards. Update example app, closes #69. Fix error with config.json files, closes #77.

davidmezzetti committed Dec 28, 2024
1 parent fed0584 commit 2df7541
Showing 28 changed files with 170 additions and 225 deletions.
1 change: 1 addition & 0 deletions .coveragerc
@@ -2,6 +2,7 @@
 source = src/python
 concurrency = multiprocessing,thread
 disable_warnings = no-data-collected
+omit = **/__main__.py
 
 [combine]
 disable_warnings = no-data-collected

2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -15,7 +15,7 @@ jobs:
     - uses: actions/checkout@v2
     - uses: actions/setup-python@v2
       with:
-        python-version: 3.8
+        python-version: 3.9
 
     - name: Install dependencies - Windows
       run: choco install wget

5 changes: 3 additions & 2 deletions .pre-commit-config.yaml
@@ -1,13 +1,14 @@
 repos:
   - repo: https://github.com/pycqa/pylint
-    rev: v2.12.1
+    rev: v3.3.1
     hooks:
       - id: pylint
         args:
           - -d import-error
           - -d duplicate-code
+          - -d too-many-positional-arguments
   - repo: https://github.com/ambv/black
-    rev: 22.3.0
+    rev: 24.10.0
     hooks:
       - id: black
         language_version: python3

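Note: pylint 3.3 added the too-many-positional-arguments check (R0917), which is why it is newly disabled here; it flags signatures whose positional-parameter count exceeds pylint's configured limit. A minimal illustration with a hypothetical function (not paperai code):

# R0917 (too-many-positional-arguments) fires once a signature takes more
# positional parameters than pylint's limit allows. Hypothetical example:
def build_report(query, topn, threshold, render, path, qa):
    ...

Relatedly, the R0201 disable removed from .pylintrc below is obsolete: no-self-use moved out of core pylint into an optional extension in pylint 2.14.
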
3 changes: 0 additions & 3 deletions .pylintrc
@@ -15,6 +15,3 @@ min-public-methods=0
 
 [FORMAT]
 max-line-length=150
-
-[MESSAGES CONTROL]
-disable=R0201

2 changes: 1 addition & 1 deletion README.md
@@ -52,7 +52,7 @@ The easiest way to install is via pip and PyPI
 pip install paperai
 ```
 
-Python 3.8+ is supported. Using a Python [virtual environment](https://docs.python.org/3/library/venv.html) is recommended.
+Python 3.9+ is supported. Using a Python [virtual environment](https://docs.python.org/3/library/venv.html) is recommended.
 
 paperai can also be installed directly from GitHub to access the latest, unreleased features.

27 changes: 7 additions & 20 deletions examples/search.py
@@ -2,7 +2,7 @@
 Search a paperai index.
 
 Requires streamlit and lxml to be installed.
-pip install streamlit lxml
+pip install streamlit lxml[html_clean]
 """
 
 import os

@@ -64,12 +64,9 @@ def search(self, query, topn, threshold):
         articles = []
 
         # Print each result, sorted by max score descending
-        for uid in sorted(
-            documents, key=lambda k: sum([x[0] for x in documents[k]]), reverse=True
-        ):
+        for uid in sorted(documents, key=lambda k: sum(x[0] for x in documents[k]), reverse=True):
             cur.execute(
-                "SELECT Title, Published, Publication, Entry, Id, Reference "
-                + "FROM articles WHERE id = ?",
+                "SELECT Title, Published, Publication, Entry, Id, Reference " + "FROM articles WHERE id = ?",
                 [uid],
             )
             article = cur.fetchone()

@@ -96,9 +93,7 @@ def run(self):
         Runs Streamlit application.
         """
 
-        st.sidebar.image(
-            "https://github.com/neuml/paperai/raw/master/logo.png", width=256
-        )
+        st.sidebar.image("https://github.com/neuml/paperai/raw/master/logo.png", width=256)
         st.sidebar.markdown("## Search parameters")
 
         # Search parameters

@@ -110,19 +105,11 @@ def run(self):
             "<style>.small-font { font-size: 0.8rem !important;}</style>",
             unsafe_allow_html=True,
         )
-        st.sidebar.markdown(
-            "<p class='small-font'>Select columns</p>", unsafe_allow_html=True
-        )
-        columns = [
-            column
-            for column, enabled in self.columns
-            if st.sidebar.checkbox(column, enabled)
-        ]
+        st.sidebar.markdown("<p class='small-font'>Select columns</p>", unsafe_allow_html=True)
+        columns = [column for column, enabled in self.columns if st.sidebar.checkbox(column, enabled)]
         if self.embeddings and query:
             df = self.search(query, topn, threshold)
-            st.markdown(
-                f"<p class='small-font'>{len(df)} results</p>", unsafe_allow_html=True
-            )
+            st.markdown(f"<p class='small-font'>{len(df)} results</p>", unsafe_allow_html=True)
 
             if not df.empty:
                 html = df[columns].to_html(escape=False, index=False)

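Note on the install-instructions change: lxml 5.2 moved the HTML cleaner into the separate lxml_html_clean package, and the lxml[html_clean] extra pulls that package in. Without it, the cleaner import fails on current lxml releases:

from lxml.html.clean import Cleaner  # ImportError on lxml>=5.2 without lxml[html_clean]
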
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -0,0 +1,2 @@
+[tool.black]
+line-length = 150

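black reads [tool.black] from pyproject.toml automatically, so this new file pins the formatter's line length to 150, matching the max-line-length=150 that .pylintrc already enforces; that shared limit is what lets the many single-line rewrites in this commit pass both tools.
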
4 changes: 2 additions & 2 deletions setup.py
@@ -28,7 +28,7 @@
     packages=find_packages(where="src/python"),
     package_dir={"": "src/python"},
     keywords="search embedding machine-learning nlp covid-19 medical scientific papers",
-    python_requires=">=3.8",
+    python_requires=">=3.9",
     entry_points={
         "console_scripts": [
             "paperai = paperai.shell:main",

@@ -41,7 +41,7 @@
         "regex>=2020.5.14",
         "rich>=12.0.1",
         "text2digits>=0.1.0",
-        "txtai[api,similarity]>=6.0.0",
+        "txtai[api,similarity]>=8.1.0",
         "txtmarker>=1.0.0",
     ],
     extras_require=extras,

13 changes: 4 additions & 9 deletions src/python/paperai/api.py
@@ -30,11 +30,7 @@ def search(self, query, request=None):
         if self.embeddings:
             dbfile = os.path.join(self.config["path"], "articles.sqlite")
             limit = self.limit(request.query_params.get("limit")) if request else 10
-            threshold = (
-                float(request.query_params["threshold"])
-                if request and "threshold" in request.query_params
-                else None
-            )
+            threshold = float(request.query_params["threshold"]) if request and "threshold" in request.query_params else None
 
             with sqlite3.connect(dbfile) as db:
                 cur = db.cursor()

@@ -50,17 +46,16 @@
                 # Print each result, sorted by max score descending
                 for uid in sorted(
                     documents,
-                    key=lambda k: sum([x[0] for x in documents[k]]),
+                    key=lambda k: sum(x[0] for x in documents[k]),
                     reverse=True,
                 ):
                     cur.execute(
-                        "SELECT Title, Published, Publication, Entry, Id, Reference "
-                        + "FROM articles WHERE id = ?",
+                        "SELECT Title, Published, Publication, Entry, Id, Reference " + "FROM articles WHERE id = ?",
                         [uid],
                     )
                     article = cur.fetchone()
 
-                    score = max([score for score, text in documents[uid]])
+                    score = max(score for score, text in documents[uid])
                     matches = [text for _, text in documents[uid]]
 
                     article = {

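The bracket removals above swap list comprehensions inside sum() and max() for generator expressions; the results are identical, the intermediate list just isn't materialized. A standalone sketch of the ranking logic with illustrative data:

# documents maps each article id to its (score, text) matches; articles
# are ranked by total match score, descending. Data is illustrative.
documents = {"doc1": [(0.9, "a"), (0.4, "b")], "doc2": [(0.7, "c")]}

ranked = sorted(documents, key=lambda k: sum(x[0] for x in documents[k]), reverse=True)
print(ranked)  # ['doc1', 'doc2'] since 1.3 > 0.7
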
12 changes: 3 additions & 9 deletions src/python/paperai/highlights.py
@@ -133,7 +133,7 @@ def buildGraph(nodes):
 
         # Tokenize nodes, store uid and tokens
         vectors = []
-        for (uid, text) in nodes:
+        for uid, text in nodes:
             # Custom tokenization that works best with textrank matching
             tokens = Highlights.tokenize(text)
 

@@ -148,9 +148,7 @@
             node2, tokens2 = pair[1]
 
             # Add a graph edge and compute the cosine similarity for the weight
-            graph.add_edge(
-                node1, node2, weight=Highlights.jaccardIndex(tokens1, tokens2)
-            )
+            graph.add_edge(node1, node2, weight=Highlights.jaccardIndex(tokens1, tokens2))
 
         return graph
 

@@ -183,8 +181,4 @@ def tokenize(text):
         """
 
         # Remove additional stop words to improve highlighting results
-        return {
-            token
-            for token in Tokenizer.tokenize(text)
-            if token not in Highlights.STOP_WORDS
-        }
+        return {token for token in Tokenizer.tokenize(text) if token not in Highlights.STOP_WORDS}

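The edge weight above comes from Highlights.jaccardIndex, whose body is outside this diff (note the adjacent comment still says cosine similarity). A minimal sketch of a Jaccard index over the token sets that tokenize returns, assuming that is all the helper computes:

def jaccardindex(tokens1, tokens2):
    # Jaccard index: size of the intersection over size of the union
    union = tokens1 | tokens2
    return len(tokens1 & tokens2) / len(union) if union else 0.0

print(jaccardindex({"viral", "load"}, {"viral", "titer"}))  # 0.3333...
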
11 changes: 2 additions & 9 deletions src/python/paperai/index.py
@@ -39,10 +39,7 @@ def stream(dbfile, maxsize, scoring):
         cur = db.cursor()
 
         # Select sentences from tagged articles
-        query = (
-            Index.SECTION_QUERY
-            + " WHERE article in (SELECT article FROM articles a WHERE a.id = article AND a.tags IS NOT NULL)"
-        )
+        query = Index.SECTION_QUERY + " WHERE article in (SELECT article FROM articles a WHERE a.id = article AND a.tags IS NOT NULL)"
 
         if maxsize > 0:
             query += f" AND article in (SELECT id FROM articles ORDER BY entry DESC LIMIT {maxsize})"

@@ -55,11 +52,7 @@
             # Unpack row
             uid, name, text = row
 
-            if (
-                not scoring
-                or not name
-                or not re.search(Index.SECTION_FILTER, name.lower())
-            ):
+            if not scoring or not name or not re.search(Index.SECTION_FILTER, name.lower()):
                 # Tokenize text
                 text = Tokenizer.tokenize(text) if scoring else text
 

5 changes: 2 additions & 3 deletions src/python/paperai/models.py
@@ -28,12 +28,11 @@ def load(path):
 
         dbfile = os.path.join(path, "articles.sqlite")
 
-        if os.path.isfile(os.path.join(path, "config")):
+        embeddings = None
+        if any(os.path.isfile(os.path.join(path, x)) for x in ["config", "config.json"]):
             print(f"Loading model from {path}")
             embeddings = Embeddings()
             embeddings.load(path)
-        else:
-            embeddings = None
 
         # Connect to database file
         db = sqlite3.connect(dbfile)

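This models.py change is the config.json fix from the commit message (closes #77): depending on version and save format, txtai persists index settings as a pickled config file or as config.json, so the loader now probes for either name before loading. A standalone sketch of the check (hasindex is an illustrative helper, not paperai code):

import os

def hasindex(path):
    # txtai writes index settings as "config" (pickle) or "config.json"
    # (JSON) depending on how the index was saved; treat either as an index
    return any(os.path.isfile(os.path.join(path, name)) for name in ["config", "config.json"])

print(hasindex("paperai-index"))  # True if either config file exists
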
33 changes: 9 additions & 24 deletions src/python/paperai/query.py
@@ -40,30 +40,21 @@ def search(embeddings, cur, query, topn, threshold):
             return []
 
         # Default threshold if None
-        threshold = threshold if threshold is not None else 0.6
+        threshold = threshold if threshold is not None else 0.25
 
         results = []
 
         # Get list of required and prohibited tokens
-        must = [
-            token.strip("+")
-            for token in query.split()
-            if token.startswith("+") and len(token) > 1
-        ]
-        mnot = [
-            token.strip("-")
-            for token in query.split()
-            if token.startswith("-") and len(token) > 1
-        ]
+        must = [token.strip("+") for token in query.split() if token.startswith("+") and len(token) > 1]
+        mnot = [token.strip("-") for token in query.split() if token.startswith("-") and len(token) > 1]
 
         # Tokenize search query, if necessary
         query = Tokenizer.tokenize(query) if embeddings.isweighted() else query
 
         # Retrieve topn * 5 to account for duplicate matches
        for result in embeddings.search(query, topn * 5):
-            uid, score = (
-                (result["id"], result["score"]) if isinstance(result, dict) else result
-            )
+            uid, score = (result["id"], result["score"]) if isinstance(result, dict) else result
 
             if score >= threshold:
                 cur.execute("SELECT Article, Text FROM sections WHERE id = ?", [uid])
 

@@ -73,9 +64,7 @@
                 # Add result if:
                 # - all required tokens are present or there are not required tokens AND
                 # - all prohibited tokens are not present or there are not prohibited tokens
-                if (
-                    not must or all(token.lower() in text.lower() for token in must)
-                ) and (
+                if (not must or all(token.lower() in text.lower() for token in must)) and (
                     not mnot or all(token.lower() not in text.lower() for token in mnot)
                 ):
                     # Save result

@@ -100,7 +89,7 @@ def highlights(results, topn):
         sections = {}
         for uid, score, _, text in results:
             # Filter out lower scored results
-            if score >= 0.35:
+            if score >= 0.1:
                 sections[text] = (uid, text)
 
         # Return up to 5 highlights

@@ -133,9 +122,7 @@ def documents(results, topn):
             documents[uid] = sorted(list(documents[uid]), reverse=True)
 
         # Get documents with top n best sections
-        topn = sorted(
-            documents, key=lambda k: max([x[0] for x in documents[k]]), reverse=True
-        )[:topn]
+        topn = sorted(documents, key=lambda k: max(x[0] for x in documents[k]), reverse=True)[:topn]
         return {uid: documents[uid] for uid in topn}
 
     @staticmethod

@@ -271,9 +258,7 @@ def query(embeddings, db, query, topn, threshold):
         console.print()
 
         # Print each result, sorted by max score descending
-        for uid in sorted(
-            documents, key=lambda k: sum(x[0] for x in documents[k]), reverse=True
-        ):
+        for uid in sorted(documents, key=lambda k: sum(x[0] for x in documents[k]), reverse=True):
             cur.execute(
                 "SELECT Title, Published, Publication, Entry, Id, Reference FROM articles WHERE id = ?",
                 [uid],

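Besides the reformatting, query.py retunes two defaults: the search threshold drops from 0.6 to 0.25 and the highlights cutoff from 0.35 to 0.1, loosening both relevance filters. The commit does not state why, though it lands alongside the txtai 6.0 to 8.1 upgrade, so read any connection as inference. The +/- query syntax itself is unchanged; a standalone sketch of that token parsing:

def parse(query):
    # +token marks a required term, -token a prohibited term; a bare
    # "+" or "-" is ignored via the len(token) > 1 guard
    must = [token.strip("+") for token in query.split() if token.startswith("+") and len(token) > 1]
    mnot = [token.strip("-") for token in query.split() if token.startswith("-") and len(token) > 1]
    return must, mnot

print(parse("risk factors +hypertension -pediatric"))
# (['hypertension'], ['pediatric'])
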
6 changes: 1 addition & 5 deletions src/python/paperai/report/column.py
@@ -103,11 +103,7 @@ def duration(text, dtype):
         if len(data) > 1 and not data[1].endswith("s"):
             data[1] = data[1] + "s"
 
-        if len(data) == 2 and (
-            data[0].replace(".", "", 1).isdigit()
-            and data[1] in (["days", "weeks", "months", "years"])
-        ):
-
+        if len(data) == 2 and (data[0].replace(".", "", 1).isdigit() and data[1] in (["days", "weeks", "months", "years"])):
             value, suffix = sorted(data)
             value = float(value)
 

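The collapsed condition above validates a parsed duration: a numeric value plus a pluralized time unit. A standalone sketch (isduration is an illustrative name, not from the module):

def isduration(data):
    # Accepts a numeric value (int or one-decimal-point float) followed by
    # a pluralized time unit, e.g. ["2.5", "years"]
    return len(data) == 2 and data[0].replace(".", "", 1).isdigit() and data[1] in ["days", "weeks", "months", "years"]

print(isduration(["2.5", "years"]))  # True
print(isduration(["ten", "days"]))   # False
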
(Diff truncated: 14 of 28 changed files shown.)