From 734feeb29f623ac15cc04023b01e4dcf7f19508e Mon Sep 17 00:00:00 2001 From: davidmezzetti <561939+davidmezzetti@users.noreply.github.com> Date: Sat, 21 Jan 2023 15:48:06 -0500 Subject: [PATCH] Modify default index configuration, closes #64 --- docker/Dockerfile | 11 +---------- scripts/getvectors.sh | 6 ------ src/python/paperai/index.py | 11 ++++++++--- test/python/testindex.py | 30 ++++++++++++++++++++++++++++++ test/python/testquery.py | 8 ++++++++ 5 files changed, 47 insertions(+), 19 deletions(-) delete mode 100755 scripts/getvectors.sh diff --git a/docker/Dockerfile b/docker/Dockerfile index db2cce5..7cab8ea 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -20,17 +20,8 @@ RUN ln -sf /usr/bin/python3.7 /usr/bin/python && \ # Cleanup build packages RUN apt-get -y purge gcc g++ python3-dev && apt-get -y autoremove -# Copy paperai scripts -RUN mkdir -p scripts -COPY scripts/ ./scripts/ - # Create paperetl directories -RUN mkdir -p cord19/data cord19/report && \ - mkdir -p paperetl/data paperetl/report - -# Install vector model -RUN scripts/getvectors.sh cord19/models && \ - scripts/getvectors.sh paperetl/models +RUN mkdir -p paperetl/data paperetl/report # Start script ENTRYPOINT /bin/bash diff --git a/scripts/getvectors.sh b/scripts/getvectors.sh deleted file mode 100755 index ca832d7..0000000 --- a/scripts/getvectors.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash - -# Download and unpack vector model -mkdir -p $1 -wget -N https://github.com/neuml/paperai/releases/download/v1.3.0/cord19-300d.magnitude.gz -P $1 -gunzip $1/cord19-300d.magnitude.gz diff --git a/src/python/paperai/index.py b/src/python/paperai/index.py index d04e2f4..d23a0ac 100644 --- a/src/python/paperai/index.py +++ b/src/python/paperai/index.py @@ -11,6 +11,7 @@ from txtai.embeddings import Embeddings from txtai.pipeline import Tokenizer +from txtai.vectors import WordVectors class Index: @@ -83,7 +84,7 @@ def config(vectors): Builds embeddings configuration. Args: - vectors: path to word vectors or configuration + vectors: vector model path or configuration Returns: configuration @@ -98,8 +99,12 @@ def config(vectors): with open(vectors, "r", encoding="utf-8") as f: return yaml.safe_load(f) - # Default configuration - return {"path": vectors, "scoring": "bm25", "pca": 3, "quantize": True} + # Configuration for word vectors model + if WordVectors.isdatabase(vectors): + return {"path": vectors, "scoring": "bm25", "pca": 3, "quantize": True} + + # Use vector path if provided, else use default txtai configuration + return {"path": vectors} if vectors else None @staticmethod def embeddings(dbfile, vectors, maxsize): diff --git a/test/python/testindex.py b/test/python/testindex.py index eb42274..cf5f75a 100644 --- a/test/python/testindex.py +++ b/test/python/testindex.py @@ -2,6 +2,8 @@ Index module tests """ +import os +import tempfile import unittest from paperai.index import Index @@ -15,6 +17,34 @@ class TestIndex(unittest.TestCase): Index tests """ + def testConfig(self): + """ + Test configuration + """ + + # Test YAML config + config = os.path.join(tempfile.gettempdir(), "testconfig.yml") + + with open(config, "w", encoding="utf-8") as output: + output.write("path: sentence-transformers/all-MiniLM-L6-v2") + + self.assertEqual( + Index.config(config), {"path": "sentence-transformers/all-MiniLM-L6-v2"} + ) + + # Test word vectors + self.assertEqual( + Index.config(Utils.VECTORFILE), + {"path": Utils.VECTORFILE, "scoring": "bm25", "pca": 3, "quantize": True}, + ) + + # Test default + self.assertEqual( + Index.config("sentence-transformers/all-MiniLM-L6-v2"), + {"path": "sentence-transformers/all-MiniLM-L6-v2"}, + ) + self.assertEqual(Index.config(None), None) + def testStream(self): """ Test row streaming diff --git a/test/python/testquery.py b/test/python/testquery.py index 1211664..d345756 100644 --- a/test/python/testquery.py +++ b/test/python/testquery.py @@ -17,6 +17,14 @@ class TestQuery(unittest.TestCase): Query tests """ + def testEmpty(self): + """ + Test empty fields. + """ + + self.assertEqual(Query.authors(None), None) + self.assertEqual(Query.date(None), None) + def testRun(self): """ Test query execution