Skip to content

Commit

Permalink
Integrate Shapelets
Browse files Browse the repository at this point in the history
  • Loading branch information
adrian-carrio committed Nov 6, 2024
1 parent 1ab46dd commit 5b10f25
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 0 deletions.
10 changes: 10 additions & 0 deletions vectordb_bench/backend/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class DB(Enum):
Chroma = "Chroma"
AWSOpenSearch = "OpenSearch"
Test = "test"
Shapelets = "Shapelets"


@property
Expand Down Expand Up @@ -97,6 +98,10 @@ def init_cls(self) -> Type[VectorDB]:
if self == DB.AWSOpenSearch:
from .aws_opensearch.aws_opensearch import AWSOpenSearch
return AWSOpenSearch

if self == DB.Shapelets:
from .shapelets.shapelets import ShapeletsClient
return ShapeletsClient

@property
def config_cls(self) -> Type[DBConfig]:
Expand Down Expand Up @@ -156,6 +161,11 @@ def config_cls(self) -> Type[DBConfig]:
if self == DB.AWSOpenSearch:
from .aws_opensearch.config import AWSOpenSearchConfig
return AWSOpenSearchConfig

if self == DB.Shapelets:
from .shapelets.config import ShapeletsConfig
return ShapeletsConfig


def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
if self == DB.Milvus:
Expand Down
6 changes: 6 additions & 0 deletions vectordb_bench/backend/clients/shapelets/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from ..api import DBConfig

class ShapeletsConfig(DBConfig):
    """Connection settings for the Shapelets backend.

    The class declares no fields of its own and exposes no connection
    parameters.
    """

    def to_dict(self) -> dict:
        """Return the connection parameters, which is always an empty dict."""
        connection_params: dict = {}
        return connection_params
104 changes: 104 additions & 0 deletions vectordb_bench/backend/clients/shapelets/shapelets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import logging
from contextlib import contextmanager
from typing import Any
from ..api import VectorDB, DBCaseConfig
from shapelets.storage import RecordStore, KnnOptions, MetricType
from shapelets.data import DataType

# Module-level logger named after this module.
log = logging.getLogger(__name__)
# NOTE: the record store is started as an import-time side effect; the single
# `service` handle is shared by every ShapeletsClient in this process.
service = RecordStore.start()
# Shared k-NN options passed to every `knn` call made by search_embedding.
knnopt = KnnOptions()
# search_embedding reads `record['id']` from each result, so records are requested.
knnopt.include_record = True
# presumably keeps the stored vectors out of the results (they are never read
# in this module) — TODO confirm against the Shapelets KnnOptions documentation
knnopt.include_embedding = False

class ShapeletsClient(VectorDB):
    """VectorDB benchmark client backed by a Shapelets ``RecordStore`` catalog.

    Every instance uses the module-level ``service`` handle and reads from /
    writes to the catalog named by ``collection_name``.
    """

    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: DBCaseConfig,
        drop_old: bool = False,
        **kwargs: Any,
    ):
        """Create the catalog that stores the benchmark embeddings.

        Args:
            dim: dimensionality of the embeddings stored in the catalog.
            db_config: connection settings; accepted for interface
                compatibility and not used.
            db_case_config: case settings; accepted for interface
                compatibility and not used.
            drop_old: accepted for interface compatibility and not used.
                NOTE(review): an existing catalog is never dropped here;
                confirm how ``create_catalog`` behaves when the catalog
                already exists.
            kwargs: ignored.
        """
        self.collection_name = 'example'
        service.create_catalog(
            self.collection_name,
            {'embedding': DataType.embedding(dim, MetricType.Cosine)},
        )

    @contextmanager
    def init(self) -> None:
        """Scope a benchmark session; no connection is opened here.

        On exit the ``client`` and ``collection`` attributes are reset to
        ``None``. The reset runs in ``finally`` so it also happens when the
        body of the ``with`` block raises.
        """
        try:
            yield
        finally:
            self.client = None
            self.collection = None

    def ready_to_search(self) -> bool:
        """Report readiness for searching; no preparation step is performed."""
        return True

    def ready_to_load(self) -> bool:
        """Report readiness for loading; no preparation step is performed."""
        return True

    def optimize(self) -> None:
        """No-op: this client performs no post-load optimization."""

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs: Any,
    ) -> tuple[int, Exception | None]:
        """Insert embeddings into the catalog.

        Each embedding is appended together with a record ``{'id': <int>}``
        built from the matching ``metadata`` entry, and the loader is
        finalized once after all records were appended.

        Args:
            embeddings: vectors to insert.
            metadata: ids paired positionally with ``embeddings``.
            kwargs: ignored.

        Returns:
            ``(len(embeddings), None)``. No error is ever returned in the
            second slot: failures raised by the store propagate to the caller.
        """
        # Nothing to load: avoid opening a catalog and a loader for no data.
        if len(embeddings) == 0:
            return 0, None

        catalog = service.open_catalog(self.collection_name)
        loader = catalog.create_loader()
        # Records are appended one at a time, so no intermediate batching or
        # id string round trip is needed.
        for vector, record_id in zip(embeddings, metadata):
            loader.append({'id': int(record_id)}, {'embedding': vector})
        loader.finalize()
        return len(embeddings), None

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
        **kwargs: Any,
    ) -> list[int]:
        """Return the ids of the nearest neighbours of ``query``.

        Args:
            query: embedding to search for.
            k: number of results requested from the store.
            filters: accepted for interface compatibility and not used.
            timeout: accepted for interface compatibility and not used.
            kwargs: ignored.

        Returns:
            The ``'id'`` field of each returned record, in the order the
            store yields the results.
        """
        catalog = service.open_catalog(self.collection_name)
        neighbours = catalog.knn(query, k, options=knnopt)
        return [hit.record['id'] for hit in neighbours]



2 changes: 2 additions & 0 deletions vectordb_bench/frontend/config/styles.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def getPatternShape(i):
DB.Redis: "https://assets.zilliz.com/Redis_Cloud_74b8bfef39.png",
DB.Chroma: "https://assets.zilliz.com/chroma_ceb3f06ed7.png",
DB.AWSOpenSearch: "https://assets.zilliz.com/opensearch_1eee37584e.jpeg",
DB.Shapelets: "https://agile-data-science-project-data.s3.eu-west-3.amazonaws.com/logo_.png"
}

# RedisCloud color: #0D6EFD
Expand All @@ -61,4 +62,5 @@ def getPatternShape(i):
DB.PgVector.value: "#4C779A",
DB.Redis.value: "#0D6EFD",
DB.AWSOpenSearch.value: "#0DCAF0",
DB.Shapelets.value: "#2ECADA",
}

0 comments on commit 5b10f25

Please sign in to comment.