From 2df75415d96b7c70a249bd6bac1e91acebb14805 Mon Sep 17 00:00:00 2001
From: davidmezzetti <561939+davidmezzetti@users.noreply.github.com>
Date: Sat, 28 Dec 2024 12:11:21 -0500
Subject: [PATCH] Update project files for Python 3.9 and coding standards.
 Update example app, closes #69. Fix error with config.json files, closes
 #77.

---
 .coveragerc                           |  1 +
 .github/workflows/build.yml           |  2 +-
 .pre-commit-config.yaml               |  5 ++-
 .pylintrc                             |  3 --
 README.md                             |  2 +-
 examples/search.py                    | 27 +++---------
 pyproject.toml                        |  2 +
 setup.py                              |  4 +-
 src/python/paperai/api.py             | 13 ++----
 src/python/paperai/highlights.py      | 12 ++----
 src/python/paperai/index.py           | 11 +----
 src/python/paperai/models.py          |  5 +--
 src/python/paperai/query.py           | 33 ++++----------
 src/python/paperai/report/column.py   |  6 +--
 src/python/paperai/report/common.py   | 62 +++++++--------------------
 src/python/paperai/report/csvr.py     |  8 +---
 src/python/paperai/report/execute.py  | 10 ++---
 src/python/paperai/report/markdown.py |  6 +--
 src/python/paperai/vectors.py         |  4 +-
 test/python/testapi.py                | 37 +++++++++++++++-
 test/python/testcolumn.py             | 49 +++++++++++++--------
 test/python/testexport.py             |  5 +--
 test/python/testindex.py              |  4 +-
 test/python/testquery.py              |  9 +---
 test/python/testreport.py             | 56 +++++++++++++-----------
 test/python/testshell.py              |  2 +-
 test/python/testvectors.py            |  4 +-
 test/python/utils.py                  | 13 +++---
 28 files changed, 170 insertions(+), 225 deletions(-)
 create mode 100644 pyproject.toml

diff --git a/.coveragerc b/.coveragerc
index 2483412..0cb653f 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -2,6 +2,7 @@
 source = src/python
 concurrency = multiprocessing,thread
 disable_warnings = no-data-collected
+omit = **/__main__.py
 
 [combine]
 disable_warnings = no-data-collected
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 362f720..38c3183 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -15,7 +15,7 @@ jobs:
     - uses: actions/checkout@v2
     - uses: actions/setup-python@v2
       with:
-        python-version: 3.8
+        python-version: 3.9
 
     - name: Install dependencies - Windows
       run: choco install wget
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index aaa9711..79f065d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,13 +1,14 @@
 repos:
 - repo: https://github.com/pycqa/pylint
-  rev: v2.12.1
+  rev: v3.3.1
   hooks:
     - id: pylint
       args:
        - -d import-error
        - -d duplicate-code
+       - -d too-many-positional-arguments
 - repo: https://github.com/ambv/black
-  rev: 22.3.0
+  rev: 24.10.0
   hooks:
     - id: black
       language_version: python3
diff --git a/.pylintrc b/.pylintrc
index dcb45dd..3bb9f57 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -15,6 +15,3 @@ min-public-methods=0
 
 [FORMAT]
 max-line-length=150
-
-[MESSAGES CONTROL]
-disable=R0201
diff --git a/README.md b/README.md
index 820fc36..aa7addb 100644
--- a/README.md
+++ b/README.md
@@ -52,7 +52,7 @@ The easiest way to install is via pip and PyPI
 ```
 pip install paperai
 ```
 
-Python 3.8+ is supported. Using a Python [virtual environment](https://docs.python.org/3/library/venv.html) is recommended.
+Python 3.9+ is supported. Using a Python [virtual environment](https://docs.python.org/3/library/venv.html) is recommended.
 
 paperai can also be installed directly from GitHub to access the latest, unreleased features.
diff --git a/examples/search.py b/examples/search.py
index 618e876..20afd5b 100644
--- a/examples/search.py
+++ b/examples/search.py
@@ -2,7 +2,7 @@
 Search a paperai index.
 
 Requires streamlit and lxml to be installed.
-  pip install streamlit lxml
+  pip install streamlit lxml[html_clean]
 """
 
 import os
@@ -64,12 +64,9 @@ def search(self, query, topn, threshold):
         articles = []
 
         # Print each result, sorted by max score descending
-        for uid in sorted(
-            documents, key=lambda k: sum([x[0] for x in documents[k]]), reverse=True
-        ):
+        for uid in sorted(documents, key=lambda k: sum(x[0] for x in documents[k]), reverse=True):
             cur.execute(
-                "SELECT Title, Published, Publication, Entry, Id, Reference "
-                + "FROM articles WHERE id = ?",
+                "SELECT Title, Published, Publication, Entry, Id, Reference " + "FROM articles WHERE id = ?",
                 [uid],
             )
             article = cur.fetchone()
@@ -96,9 +93,7 @@ def run(self):
         Runs Streamlit application.
         """
 
-        st.sidebar.image(
-            "https://github.com/neuml/paperai/raw/master/logo.png", width=256
-        )
+        st.sidebar.image("https://github.com/neuml/paperai/raw/master/logo.png", width=256)
         st.sidebar.markdown("## Search parameters")
 
         # Search parameters
@@ -110,19 +105,11 @@ def run(self):
             "",
             unsafe_allow_html=True,
         )
-        st.sidebar.markdown(
-            "Select columns", unsafe_allow_html=True
-        )
-        columns = [
-            column
-            for column, enabled in self.columns
-            if st.sidebar.checkbox(column, enabled)
-        ]
+        st.sidebar.markdown("Select columns", unsafe_allow_html=True)
+        columns = [column for column, enabled in self.columns if st.sidebar.checkbox(column, enabled)]
         if self.embeddings and query:
             df = self.search(query, topn, threshold)
-            st.markdown(
-                f"{len(df)} results", unsafe_allow_html=True
-            )
+            st.markdown(f"{len(df)} results", unsafe_allow_html=True)
 
             if not df.empty:
                 html = df[columns].to_html(escape=False, index=False)
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..6b313bc
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,2 @@
+[tool.black]
+line-length = 150
diff --git a/setup.py b/setup.py
index 499a210..60f2403 100644
--- a/setup.py
+++ b/setup.py
@@ -28,7 +28,7 @@
     packages=find_packages(where="src/python"),
     package_dir={"": "src/python"},
     keywords="search embedding machine-learning nlp covid-19 medical scientific papers",
-    python_requires=">=3.8",
+    python_requires=">=3.9",
     entry_points={
         "console_scripts": [
             "paperai = paperai.shell:main",
@@ -41,7 +41,7 @@
         "regex>=2020.5.14",
         "rich>=12.0.1",
         "text2digits>=0.1.0",
-        "txtai[api,similarity]>=6.0.0",
+        "txtai[api,similarity]>=8.1.0",
         "txtmarker>=1.0.0",
     ],
     extras_require=extras,
diff --git a/src/python/paperai/api.py b/src/python/paperai/api.py
index 0b7bc1a..a320d93 100644
--- a/src/python/paperai/api.py
+++ b/src/python/paperai/api.py
@@ -30,11 +30,7 @@ def search(self, query, request=None):
         if self.embeddings:
             dbfile = os.path.join(self.config["path"], "articles.sqlite")
             limit = self.limit(request.query_params.get("limit")) if request else 10
-            threshold = (
-                float(request.query_params["threshold"])
-                if request and "threshold" in request.query_params
-                else None
-            )
+            threshold = float(request.query_params["threshold"]) if request and "threshold" in request.query_params else None
 
             with sqlite3.connect(dbfile) as db:
                 cur = db.cursor()
@@ -50,17 +46,16 @@
                 # Print each result, sorted by max score descending
                 for uid in sorted(
                     documents,
-                    key=lambda k: sum([x[0] for x in documents[k]]),
+                    key=lambda k: sum(x[0] for x in documents[k]),
                     reverse=True,
                 ):
                     cur.execute(
-                        "SELECT Title, Published, Publication, Entry, Id, Reference "
-                        + "FROM articles WHERE id = ?",
+                        "SELECT Title, Published, Publication, Entry, Id, Reference " + "FROM articles WHERE id = ?",
                         [uid],
                     )
                     article = cur.fetchone()
 
-                    score = max([score for score, text in documents[uid]])
+                    score = max(score for score, text in documents[uid])
                     matches = [text for _, text in documents[uid]]
 
                     article = {
diff --git a/src/python/paperai/highlights.py b/src/python/paperai/highlights.py
index 2446f0f..739b3f6 100644
--- a/src/python/paperai/highlights.py
+++ b/src/python/paperai/highlights.py
@@ -133,7 +133,7 @@ def buildGraph(nodes):
 
         # Tokenize nodes, store uid and tokens
         vectors = []
-        for (uid, text) in nodes:
+        for uid, text in nodes:
             # Custom tokenization that works best with textrank matching
             tokens = Highlights.tokenize(text)
 
@@ -148,9 +148,7 @@
             node2, tokens2 = pair[1]
 
             # Add a graph edge and compute the cosine similarity for the weight
-            graph.add_edge(
-                node1, node2, weight=Highlights.jaccardIndex(tokens1, tokens2)
-            )
+            graph.add_edge(node1, node2, weight=Highlights.jaccardIndex(tokens1, tokens2))
 
         return graph
 
@@ -183,8 +181,4 @@ def tokenize(text):
         """
 
         # Remove additional stop words to improve highlighting results
-        return {
-            token
-            for token in Tokenizer.tokenize(text)
-            if token not in Highlights.STOP_WORDS
-        }
+        return {token for token in Tokenizer.tokenize(text) if token not in Highlights.STOP_WORDS}
diff --git a/src/python/paperai/index.py b/src/python/paperai/index.py
index be0266d..f08ebdb 100644
--- a/src/python/paperai/index.py
+++ b/src/python/paperai/index.py
@@ -39,10 +39,7 @@ def stream(dbfile, maxsize, scoring):
         cur = db.cursor()
 
         # Select sentences from tagged articles
-        query = (
-            Index.SECTION_QUERY
-            + " WHERE article in (SELECT article FROM articles a WHERE a.id = article AND a.tags IS NOT NULL)"
-        )
+        query = Index.SECTION_QUERY + " WHERE article in (SELECT article FROM articles a WHERE a.id = article AND a.tags IS NOT NULL)"
         if maxsize > 0:
             query += f" AND article in (SELECT id FROM articles ORDER BY entry DESC LIMIT {maxsize})"
 
@@ -55,11 +52,7 @@
             # Unpack row
             uid, name, text = row
 
-            if (
-                not scoring
-                or not name
-                or not re.search(Index.SECTION_FILTER, name.lower())
-            ):
+            if not scoring or not name or not re.search(Index.SECTION_FILTER, name.lower()):
                 # Tokenize text
                 text = Tokenizer.tokenize(text) if scoring else text
 
diff --git a/src/python/paperai/models.py b/src/python/paperai/models.py
index 0df4004..8c8133a 100644
--- a/src/python/paperai/models.py
+++ b/src/python/paperai/models.py
@@ -28,12 +28,11 @@ def load(path):
 
         dbfile = os.path.join(path, "articles.sqlite")
 
-        if os.path.isfile(os.path.join(path, "config")):
+        embeddings = None
+        if any(os.path.isfile(os.path.join(path, x)) for x in ["config", "config.json"]):
             print(f"Loading model from {path}")
             embeddings = Embeddings()
             embeddings.load(path)
-        else:
-            embeddings = None
 
         # Connect to database file
         db = sqlite3.connect(dbfile)
diff --git a/src/python/paperai/query.py b/src/python/paperai/query.py
index cb56668..1b2f566 100644
--- a/src/python/paperai/query.py
+++ b/src/python/paperai/query.py
@@ -40,30 +40,21 @@ def search(embeddings, cur, query, topn, threshold):
             return []
 
         # Default threshold if None
-        threshold = threshold if threshold is not None else 0.6
+        threshold = threshold if threshold is not None else 0.25
 
         results = []
 
         # Get list of required and prohibited tokens
-        must = [
-            token.strip("+")
-            for token in query.split()
-            if token.startswith("+") and len(token) > 1
-        ]
-        mnot = [
-            token.strip("-")
-            for token in query.split()
-            if token.startswith("-") and len(token) > 1
-        ]
+        must = [token.strip("+") for token in query.split() if token.startswith("+") and len(token) > 1]
+        mnot = [token.strip("-") for token in query.split() if token.startswith("-") and len(token) > 1]
 
         # Tokenize search query, if necessary
         query = Tokenizer.tokenize(query) if embeddings.isweighted() else query
 
         # Retrieve topn * 5 to account for duplicate matches
         for result in embeddings.search(query, topn * 5):
-            uid, score = (
-                (result["id"], result["score"]) if isinstance(result, dict) else result
-            )
+            uid, score = (result["id"], result["score"]) if isinstance(result, dict) else result
+
             if score >= threshold:
                 cur.execute("SELECT Article, Text FROM sections WHERE id = ?", [uid])
 
@@ -73,9 +64,7 @@
                 # Add result if:
                 #   - all required tokens are present or there are not required tokens AND
                 #   - all prohibited tokens are not present or there are not prohibited tokens
-                if (
-                    not must or all(token.lower() in text.lower() for token in must)
-                ) and (
+                if (not must or all(token.lower() in text.lower() for token in must)) and (
                     not mnot or all(token.lower() not in text.lower() for token in mnot)
                 ):
                     # Save result
@@ -100,7 +89,7 @@ def highlights(results, topn):
         sections = {}
         for uid, score, _, text in results:
             # Filter out lower scored results
-            if score >= 0.35:
+            if score >= 0.1:
                 sections[text] = (uid, text)
 
         # Return up to 5 highlights
@@ -133,9 +122,7 @@ def documents(results, topn):
             documents[uid] = sorted(list(documents[uid]), reverse=True)
 
         # Get documents with top n best sections
-        topn = sorted(
-            documents, key=lambda k: max([x[0] for x in documents[k]]), reverse=True
-        )[:topn]
+        topn = sorted(documents, key=lambda k: max(x[0] for x in documents[k]), reverse=True)[:topn]
         return {uid: documents[uid] for uid in topn}
 
     @staticmethod
@@ -271,9 +258,7 @@ def query(embeddings, db, query, topn, threshold):
         console.print()
 
         # Print each result, sorted by max score descending
-        for uid in sorted(
-            documents, key=lambda k: sum(x[0] for x in documents[k]), reverse=True
-        ):
+        for uid in sorted(documents, key=lambda k: sum(x[0] for x in documents[k]), reverse=True):
             cur.execute(
                 "SELECT Title, Published, Publication, Entry, Id, Reference FROM articles WHERE id = ?",
                 [uid],
diff --git a/src/python/paperai/report/column.py b/src/python/paperai/report/column.py
index d2ba567..ae239cb 100644
--- a/src/python/paperai/report/column.py
+++ b/src/python/paperai/report/column.py
@@ -103,11 +103,7 @@ def duration(text, dtype):
         if len(data) > 1 and not data[1].endswith("s"):
             data[1] = data[1] + "s"
 
-        if len(data) == 2 and (
-            data[0].replace(".", "", 1).isdigit()
-            and data[1] in (["days", "weeks", "months", "years"])
-        ):
-
+        if len(data) == 2 and (data[0].replace(".", "", 1).isdigit() and data[1] in (["days", "weeks", "months", "years"])):
             value, suffix = sorted(data)
             value = float(value)
 
diff --git a/src/python/paperai/report/common.py b/src/python/paperai/report/common.py
index b0eb30f..b98ad36 100644
--- a/src/python/paperai/report/common.py
+++ b/src/python/paperai/report/common.py
@@ -4,7 +4,7 @@
 
 import regex as re
 
-from txtai.pipeline import Extractor, Labels, Similarity, Tokenizer
+from txtai.pipeline import Labels, RAG, Similarity, Tokenizer
 
 from ..index import Index
 from ..query import Query
@@ -37,16 +37,14 @@ def __init__(self, embeddings, db, options):
         # Column names
         self.names = []
 
-        self.similarity = (
-            Similarity(options["similarity"]) if "similarity" in options else None
-        )
+        self.similarity = Similarity(options["similarity"]) if "similarity" in options else None
         self.labels = Labels(model=self.similarity) if self.similarity else None
 
         # Question-answering model
         # Determine if embeddings or a custom similarity model should be used to build question context
-        self.extractor = Extractor(
+        self.rag = RAG(
             self.similarity if self.similarity else self.embeddings,
             options["qa"] if options["qa"] else "NeuML/bert-small-cord19qa",
             minscore=options.get("minscore"),
             mintokens=options.get("mintokens"),
             context=options.get("context"),
@@ -76,9 +74,7 @@ def build(self, queries, options, output):
             self.separator(output)
 
             # Query for best matches
-            results = Query.search(
-                self.embeddings, self.cur, query, topn, options.get("threshold")
-            )
+            results = Query.search(self.embeddings, self.cur, query, topn, options.get("threshold"))
 
             # Generate highlights section
             self.section(output, "Highlights")
@@ -115,9 +111,7 @@ def highlights(self, output, results, topn):
         for highlight in Query.highlights(results, topn):
             # Get matching article
             uid = [article for _, _, article, text in results if text == highlight][0]
-            self.cur.execute(
-                "SELECT Authors, Reference FROM articles WHERE id = ?", [uid]
-            )
+            self.cur.execute("SELECT Authors, Reference FROM articles WHERE id = ?", [uid])
             article = self.cur.fetchone()
 
             # Write out highlight row
@@ -138,9 +132,7 @@ def articles(self, output, topn, metadata, results):
         _, query, _ = metadata
 
         # Retrieve list of documents
-        documents = (
-            Query.all(self.cur) if query == "*" else Query.documents(results, topn)
-        )
+        documents = Query.all(self.cur) if query == "*" else Query.documents(results, topn)
 
         # Collect matching rows
         rows = []
@@ -207,44 +199,29 @@ def calculate(self, uid, metadata):
                 extractions.append((name, query, question, snippet))
 
         # Run all extractor queries against document text
-        results = self.extractor.query([query for _, query, _ in queries], texts)
+        results = self.rag.query([query for _, query, _ in queries], texts)
 
         # Only execute embeddings queries for columns with matches set
         for x, (name, query, matches) in enumerate(queries):
             if matches:
                 # Get topn text matches
                 topn = [text for _, text, _ in results[x]][:matches]
 
                 # Join results into String and return
-                value = [
-                    self.resolve(params, sections, uid, name, value) for value in topn
-                ]
+                value = [self.resolve(params, sections, uid, name, value) for value in topn]
                 fields[name] = "\n\n".join(value) if value else ""
             else:
                 fields[name] = ""
 
         # Add extraction fields
         if extractions:
-            for name, value in self.extractor(extractions, texts):
+            for name, value in self.rag(extractions, texts):
                 # Resolves the full value based on column parameters
-                fields[name] = (
-                    self.resolve(params, sections, uid, name, value) if value else ""
-                )
+                fields[name] = self.resolve(params, sections, uid, name, value) if value else ""
 
         # Add question fields
-        names, qa, contexts, snippets = [], [], [], []
-        for name, query, question, snippet in questions:
-            names.append(name)
-            qa.append(question)
-            contexts.append(fields[query])
-            snippets.append(snippet)
-
-        for name, value in self.extractor.answers(
-            names, qa, contexts, contexts, snippets
-        ):
+        for name, value in self.rag(questions, texts):
             # Resolves the full value based on column parameters
-            fields[name] = (
-                self.resolve(params, sections, uid, name, value) if value else ""
-            )
+            fields[name] = self.resolve(params, sections, uid, name, value) if value else ""
 
         return fields
@@ -277,11 +254,7 @@ def params(self, metadata):
             elif "query" in column:
                 # Query variable substitutions
                 query = self.variables(column["query"], metadata)
-                question = (
-                    self.variables(column["question"], metadata)
-                    if "question" in column
-                    else query
-                )
+                question = self.variables(column["question"], metadata) if "question" in column else query
 
                 # Additional context parameters
                 section = column.get("section", False)
@@ -346,12 +319,7 @@ def sections(self, uid):
         # Get list of document text sections
         sections = []
         for sid, name, text in self.cur.fetchall():
-            if (
-                not self.embeddings.isweighted()
-                or not name
-                or not re.search(Index.SECTION_FILTER, name.lower())
-                or self.options.get("allsections")
-            ):
+            if not self.embeddings.isweighted() or not name or not re.search(Index.SECTION_FILTER, name.lower()) or self.options.get("allsections"):
                 # Check that section has at least 1 token
                 if Tokenizer.tokenize(text):
                     sections.append((sid, text))
diff --git a/src/python/paperai/report/csvr.py b/src/python/paperai/report/csvr.py
index 1addd5e..d9385ba 100644
--- a/src/python/paperai/report/csvr.py
+++ b/src/python/paperai/report/csvr.py
@@ -40,9 +40,7 @@ def query(self, output, task, query):
             encoding="utf-8",
         )
 
-        self.writer = csv.writer(
-            self.csvout, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL
-        )
+        self.writer = csv.writer(self.csvout, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
 
     def write(self, row):
         """
@@ -80,9 +78,7 @@ def buildRow(self, article, sections, calculated):
         row["Source"] = article[4]
 
         # Top Matches
-        row["Matches"] = (
-            "\n\n".join([Query.text(text) for _, text in sections]) if sections else ""
-        )
+        row["Matches"] = "\n\n".join([Query.text(text) for _, text in sections]) if sections else ""
 
         # Entry Date
         row["Entry"] = article[5] if article[5] else ""
diff --git a/src/python/paperai/report/execute.py b/src/python/paperai/report/execute.py
index 28bd14e..e964cc4 100644
--- a/src/python/paperai/report/execute.py
+++ b/src/python/paperai/report/execute.py
@@ -39,12 +39,10 @@ def create(render, embeddings, db, options):
         if render == "md":
             return Markdown(embeddings, db, options)
 
-        return None
+        raise ValueError(f"Invalid report format: {render}")
 
     @staticmethod
-    def run(
-        task, topn=None, render=None, path=None, qa=None, indir=None, threshold=None
-    ):
+    def run(task, topn=None, render=None, path=None, qa=None, indir=None, threshold=None):
         """
         Reads a list of queries from a task file and builds a report.
 
@@ -110,8 +108,6 @@ def options(options, topn, render, path, qa, indir, threshold):
         options["path"] = path if path else options.get("path")
         options["qa"] = qa if qa else options.get("qa")
         options["indir"] = indir if indir else options.get("indir")
-        options["threshold"] = (
-            threshold if threshold is not None else options.get("threshold")
-        )
+        options["threshold"] = threshold if threshold is not None else options.get("threshold")
 
         return options
diff --git a/src/python/paperai/report/markdown.py b/src/python/paperai/report/markdown.py
index 6017c5f..af48e55 100644
--- a/src/python/paperai/report/markdown.py
+++ b/src/python/paperai/report/markdown.py
@@ -99,11 +99,7 @@ def buildRow(self, article, sections, calculated):
             row["Study"] = title
 
         # Top Matches
-        row["Matches"] = (
-            "