Skip to content

Commit

Permalink
Modify default index configuration, closes #64
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Jan 21, 2023
1 parent 6d6aa0c commit 734feeb
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 19 deletions.
11 changes: 1 addition & 10 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,8 @@ RUN ln -sf /usr/bin/python3.7 /usr/bin/python && \
# Cleanup build packages
RUN apt-get -y purge gcc g++ python3-dev && apt-get -y autoremove

# Copy paperai scripts
RUN mkdir -p scripts
COPY scripts/ ./scripts/

# Create paperetl directories
RUN mkdir -p cord19/data cord19/report && \
mkdir -p paperetl/data paperetl/report

# Install vector model
RUN scripts/getvectors.sh cord19/models && \
scripts/getvectors.sh paperetl/models
RUN mkdir -p paperetl/data paperetl/report

# Start script
ENTRYPOINT /bin/bash
6 changes: 0 additions & 6 deletions scripts/getvectors.sh

This file was deleted.

11 changes: 8 additions & 3 deletions src/python/paperai/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from txtai.embeddings import Embeddings
from txtai.pipeline import Tokenizer
from txtai.vectors import WordVectors


class Index:
Expand Down Expand Up @@ -83,7 +84,7 @@ def config(vectors):
Builds embeddings configuration.
Args:
vectors: path to word vectors or configuration
vectors: vector model path or configuration
Returns:
configuration
Expand All @@ -98,8 +99,12 @@ def config(vectors):
with open(vectors, "r", encoding="utf-8") as f:
return yaml.safe_load(f)

# Default configuration
return {"path": vectors, "scoring": "bm25", "pca": 3, "quantize": True}
# Configuration for word vectors model
if WordVectors.isdatabase(vectors):
return {"path": vectors, "scoring": "bm25", "pca": 3, "quantize": True}

# Use vector path if provided, else use default txtai configuration
return {"path": vectors} if vectors else None

@staticmethod
def embeddings(dbfile, vectors, maxsize):
Expand Down
30 changes: 30 additions & 0 deletions test/python/testindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Index module tests
"""

import os
import tempfile
import unittest

from paperai.index import Index
Expand All @@ -15,6 +17,34 @@ class TestIndex(unittest.TestCase):
Index tests
"""

def testConfig(self):
    """
    Test configuration.

    Checks the three results of Index.config exercised here: a configuration
    loaded from a YAML file, the word vectors configuration and the fallback
    used for any other vector path (or no path at all).
    """

    # Test YAML config. The file lives in a private temporary directory so it
    # cannot collide with other test runs and is removed when the block exits.
    with tempfile.TemporaryDirectory() as directory:
        config = os.path.join(directory, "testconfig.yml")

        with open(config, "w", encoding="utf-8") as output:
            output.write("path: sentence-transformers/all-MiniLM-L6-v2")

        self.assertEqual(
            Index.config(config), {"path": "sentence-transformers/all-MiniLM-L6-v2"}
        )

    # Test word vectors
    # NOTE(review): assumes Utils.VECTORFILE points to a word vectors database - confirm
    self.assertEqual(
        Index.config(Utils.VECTORFILE),
        {"path": Utils.VECTORFILE, "scoring": "bm25", "pca": 3, "quantize": True},
    )

    # Test default: a non word vectors path is passed through as-is
    self.assertEqual(
        Index.config("sentence-transformers/all-MiniLM-L6-v2"),
        {"path": "sentence-transformers/all-MiniLM-L6-v2"},
    )

    # No vectors provided, no configuration is returned
    self.assertIsNone(Index.config(None))

def testStream(self):
"""
Test row streaming
Expand Down
8 changes: 8 additions & 0 deletions test/python/testquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ class TestQuery(unittest.TestCase):
Query tests
"""

def testEmpty(self):
    """
    Test empty fields.

    Query.authors and Query.date are both expected to return None when
    given None as input.
    """

    # assertIsNone checks identity with None and reports a clearer failure
    self.assertIsNone(Query.authors(None))
    self.assertIsNone(Query.date(None))

def testRun(self):
"""
Test query execution
Expand Down

0 comments on commit 734feeb

Please sign in to comment.