Skip to content

Commit

Permalink
feat(api): adds migration for tracking vector store indexing status (#…
Browse files Browse the repository at this point in the history
…830)

* Enables Supabase realtime for the vector_store_file table
* Adds new Python dependency to test supabase-realtime
* Status is currently being set during indexing, so listening to this table should provide status updates on the file indexing process.
  • Loading branch information
CollectiveUnicorn authored Sep 5, 2024
1 parent 8a1d61e commit eee3ed7
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 5 deletions.
1 change: 1 addition & 0 deletions .github/actions/lfai-core/action.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ runs:
id: set-env-var
run: |
echo "ANON_KEY=$(uds zarf tools kubectl get secret supabase-bootstrap-jwt -n leapfrogai -o jsonpath='{.data.anon-key}' | base64 -d)" >> "$GITHUB_ENV"
echo "SERVICE_KEY=$(uds zarf tools kubectl get secret supabase-bootstrap-jwt -n leapfrogai -o jsonpath='{.data.service-key}' | base64 -d)" >> "$GITHUB_ENV"
- name: Deploy LFAI-API
shell: bash
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
-- Add updated_at to vector_store_file; new and existing rows default to the current UTC time
ALTER TABLE vector_store_file
    ADD COLUMN updated_at timestamp NOT NULL DEFAULT timezone('utc', now());

-- Index user_id to speed up queries that filter on it
CREATE INDEX idx_vector_store_file_user_id
    ON vector_store_file (user_id);

-- Trigger function: stamp the row being written with the current UTC time
CREATE OR REPLACE FUNCTION update_modified_column()
RETURNS trigger
LANGUAGE plpgsql
AS $$
BEGIN
    -- NEW is the row version about to be stored; overwrite its timestamp
    NEW.updated_at := timezone('utc', now());
    RETURN NEW;
END;
$$;

-- Run update_modified_column() before each row update so updated_at is refreshed automatically
CREATE TRIGGER update_vector_store_file_modtime
    BEFORE UPDATE
    ON vector_store_file
    FOR EACH ROW
    EXECUTE FUNCTION update_modified_column();

-- Add vector_store_file to the supabase_realtime publication (enables Supabase realtime for the table)
ALTER PUBLICATION supabase_realtime ADD TABLE vector_store_file;
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dev = [
"requests",
"requests-toolbelt",
"pytest",
"supabase == 2.6.0",
"huggingface_hub[cli,hf_transfer] == 0.24.5",
"fastapi == 0.109.1",
]
Expand Down
165 changes: 161 additions & 4 deletions tests/e2e/test_supabase.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,24 @@
import asyncio
import io
import threading
import uuid
from fastapi import UploadFile
import requests
from openai.types.beta.vector_stores import VectorStoreFile
from openai.types.beta import VectorStore
from openai.types.beta.vector_store import FileCounts
import _thread

from .utils import ANON_KEY
from supabase import AClient as AsyncClient, acreate_client
from realtime import Socket
from leapfrogai_api.data.crud_file_bucket import CRUDFileBucket
from leapfrogai_api.data.crud_file_object import CRUDFileObject
from leapfrogai_api.data.crud_vector_store import CRUDVectorStore

from leapfrogai_api.data.crud_vector_store_file import CRUDVectorStoreFile

from .utils import ANON_KEY, create_test_user, SERVICE_KEY
from openai.types import FileObject

health_urls = {
"auth_health_url": "http://supabase-kong.uds.dev/auth/v1/health",
Expand All @@ -12,9 +30,148 @@
def test_studio():
    """Check that every endpoint listed in ``health_urls`` answers successfully.

    Each URL is requested with the anon key in the ``apikey`` header. On the
    first failing request the error is printed and the run exits with
    status 1.
    """
    try:
        for url in health_urls.values():
            resp = requests.get(url, headers={"apikey": ANON_KEY})
            resp.raise_for_status()
    except requests.exceptions.RequestException as e:
        # Take the status from the exception, not from ``resp``: when the
        # request itself fails (e.g. a connection error) ``resp`` is either
        # unbound or still holds the response of the previous URL.
        # ``is not None`` matters because a Response with an error status is falsy.
        status_code = e.response.status_code if e.response is not None else None
        print(f"Error: Request failed with status code {status_code}")
        print(e)
        exit(1)


def test_supabase_realtime_vector_store_indexing():
    """Verify that creating a vector store file emits a Supabase realtime event.

    The main thread listens on the ``realtime:public:vector_store_file``
    channel. A timer thread, fired 5 seconds after start, uploads a file and
    creates a vector store file row. The test passes once the listener sees the
    matching INSERT event and fails if none arrives within 10 seconds.
    """

    class TestCompleteException(Exception):
        """Raised by the realtime callback to signal that the expected event arrived."""

    # Set by timeout_handler so its interrupt can be told apart from a real Ctrl-C.
    timed_out = threading.Event()
    # Exceptions raised on the timer thread; without this they never reach the test.
    worker_errors = []

    def timeout_handler():
        print("Test timed out after 10 seconds")
        timed_out.set()
        # This is necessary to stop the thread from hanging forever
        _thread.interrupt_main()

    async def postgres_db_changes():
        """
        Create a vector store and upload a file to it.
        """
        client: AsyncClient = await acreate_client(
            supabase_key=ANON_KEY,
            supabase_url="https://supabase-kong.uds.dev",
        )
        await client.auth.set_session(access_token=access_token, refresh_token="dummy")

        upload_file_id = await upload_file(client)
        assert upload_file_id is not None, "Failed to upload file"

        vector_store = VectorStore(
            id="",
            created_at=0,
            file_counts=FileCounts(
                cancelled=0,
                completed=0,
                failed=0,
                in_progress=0,
                total=0,
            ),
            name="test_vector_store",
            object="vector_store",
            status="completed",
            usage_bytes=0,
        )

        new_vector_store = await CRUDVectorStore(client).create(vector_store)
        assert new_vector_store is not None, "Failed to create vector store"

        vector_store_file = VectorStoreFile(
            id=upload_file_id,
            vector_store_id=new_vector_store.id,
            created_at=0,
            object="vector_store.file",
            status="completed",
            usage_bytes=0,
        )

        await CRUDVectorStoreFile(client).create(vector_store_file)

    def postgres_changes_callback(payload):
        """
        Inspect a realtime payload and signal success when it is the expected
        INSERT on the vector_store_file table.
        """
        expected_record = {
            "object": "vector_store.file",
            "status": "completed",
            "usage_bytes": 0,
        }

        all_records_match = all(
            payload.get("record", {}).get(key) == value
            for key, value in expected_record.items()
        )
        event_information_match = (
            payload.get("table") == "vector_store_file"
            and payload.get("type") == "INSERT"
        )

        if event_information_match and all_records_match:
            raise TestCompleteException("Test completed successfully")

    async def upload_file(client: AsyncClient) -> str:
        """
        Create an empty file object, upload it to the file bucket and return its id.
        """
        empty_file_object = FileObject(
            id="",
            bytes=0,
            created_at=0,
            filename="",
            object="file",
            purpose="assistants",
            status="uploaded",
            status_details=None,
        )

        file_object = await CRUDFileObject(client).create(object_=empty_file_object)
        assert file_object is not None, "Failed to create file object"

        crud_file_bucket = CRUDFileBucket(db=client, model=UploadFile)
        await crud_file_bucket.upload(
            file=UploadFile(filename="", file=io.BytesIO(b"")), id_=file_object.id
        )
        return file_object.id

    def run_postgres_db_changes():
        """
        Run postgres_db_changes to completion on the timer thread.
        """
        try:
            asyncio.run(postgres_db_changes())
        except Exception as err:
            # Keep the error so a later timeout can report the real cause,
            # then re-raise so the thread's default error reporting still runs.
            worker_errors.append(err)
            raise

    db_changes_timer = None
    timeout_timer = None
    try:
        random_name = str(uuid.uuid4())
        access_token = create_test_user(email=f"{random_name}@fake.com")

        # Schedule postgres_db_changes to run after 5 seconds
        db_changes_timer = threading.Timer(5.0, run_postgres_db_changes)
        db_changes_timer.start()

        # Set a timeout of 10 seconds
        timeout_timer = threading.Timer(10.0, timeout_handler)
        timeout_timer.start()

        # Listening socket
        # The service key is needed for proper permission to listen to realtime events
        # At the time of writing this, the Supabase realtime library does not support RLS
        realtime_url = f"wss://supabase-kong.uds.dev/realtime/v1/websocket?apikey={SERVICE_KEY}&vsn=1.0.0"
        socket = Socket(realtime_url)
        socket.connect()

        # Set channel to listen for changes to the vector_store_file table
        channel = socket.set_channel("realtime:public:vector_store_file")
        # Listen for all events on the channel ex: INSERT, UPDATE, DELETE
        channel.join().on("*", postgres_changes_callback)

        # Start listening; success is reported by TestCompleteException from the callback
        socket.listen()

        # Reaching this point means listening ended without the expected event,
        # which must not count as a pass.
        raise AssertionError(
            "Realtime listener stopped before the expected INSERT event was received"
        )
    except TestCompleteException:
        # The expected realtime event was received.
        pass
    except KeyboardInterrupt:
        # Only translate the interrupt raised by timeout_handler; a genuine
        # user interrupt must keep propagating.
        if not timed_out.is_set():
            raise
        cause = worker_errors[0] if worker_errors else None
        raise AssertionError(
            "Timed out after 10 seconds waiting for the vector_store_file realtime event"
        ) from cause
    finally:
        # Always stop pending timers: a timeout timer left running would call
        # interrupt_main() later, in the middle of an unrelated test.
        for timer in (db_changes_timer, timeout_timer):
            if timer is not None:
                timer.cancel()
3 changes: 2 additions & 1 deletion tests/e2e/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import requests

# This is the anon_key for supabase, it provides access to the endpoints that would otherwise be inaccessible
ANON_KEY = os.getenv("ANON_KEY")
ANON_KEY = os.environ["ANON_KEY"]
SERVICE_KEY = os.environ["SERVICE_KEY"]
DEFAULT_TEST_EMAIL = "[email protected]"
DEFAULT_TEST_PASSWORD = "password"

Expand Down

0 comments on commit eee3ed7

Please sign in to comment.