Skip to content

Commit

Permalink
Add sharing (#38)
Browse files Browse the repository at this point in the history
Add sharing of extractors
  • Loading branch information
eyurtsev authored Mar 18, 2024
1 parent bdfeea5 commit 1a12200
Show file tree
Hide file tree
Showing 5 changed files with 239 additions and 36 deletions.
110 changes: 78 additions & 32 deletions backend/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@
from datetime import datetime
from typing import Generator

from sqlalchemy import Column, DateTime, ForeignKey, String, Text, create_engine
from sqlalchemy import (
Column,
DateTime,
ForeignKey,
String,
Text,
UniqueConstraint,
create_engine,
)
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship, sessionmaker
Expand Down Expand Up @@ -56,37 +64,6 @@ class TimestampedModel(Base):
)


class Extractor(TimestampedModel):
    """ORM model for the ``extractors`` table.

    A row holds a named extraction definition: the JSON Schema describing
    what to extract, a short description shown in the UI, and the
    instruction (prompt) given to the language model.
    """

    __tablename__ = "extractors"

    # Short display name; defaults to the empty string at the database level.
    name = Column(
        String(100),
        comment="The name of the extractor.",
        server_default="",
        nullable=False,
    )
    # The JSON Schema is stored as JSONB.
    schema = Column(
        JSONB,
        comment=(
            "JSON Schema that describes what content will be extracted from "
            "the document"
        ),
        nullable=False,
    )
    description = Column(
        String(100),
        comment="Surfaced via UI to the users.",
        server_default="",
        nullable=False,
    )
    # TODO: This will need to evolve
    instruction = Column(
        Text,
        comment="The prompt to the language model.",
        nullable=False,
    )

    # One extractor has many examples; each Example gets an ``extractor``
    # attribute through the backref.
    examples = relationship("Example", backref="extractor")

    def __repr__(self) -> str:
        """Return a short debug representation of the extractor."""
        return "<Extractor(id={}, description={})>".format(
            self.uuid, self.description
        )


class Example(TimestampedModel):
"""A representation of an example.
Expand Down Expand Up @@ -122,3 +99,72 @@ class Example(TimestampedModel):

def __repr__(self) -> str:
return f"<Example(uuid={self.uuid}, content={self.content[:20]}>"


class SharedExtractors(TimestampedModel):
    """A table for managing sharing of extractors.

    Each row links one extractor to the ``share_token`` through which it
    can be accessed. The foreign key is declared with ``ondelete="CASCADE"``,
    so the database is asked to remove share rows when the referenced
    extractor row is deleted.
    """

    __tablename__ = "shared_extractors"

    extractor_id = Column(
        UUID(as_uuid=True),
        ForeignKey("extractors.uuid", ondelete="CASCADE"),
        index=True,
        nullable=False,
        comment="The extractor that is being shared.",
    )

    share_token = Column(
        UUID(as_uuid=True),
        index=True,
        nullable=False,
        unique=True,
        comment="The token that is used to access the shared extractor.",
    )

    # Unique constraint for (extractor_id, share_token).
    # NOTE(review): share_token is already unique on its own, so this pair
    # constraint adds no extra guarantee; in particular it does NOT limit an
    # extractor to a single share row. Kept as-is to avoid a schema change.
    __table_args__ = (
        UniqueConstraint("extractor_id", "share_token", name="unique_share_token"),
    )

    def __repr__(self) -> str:
        """Return a string representation of the object.

        The share token is deliberately left out so that it does not end up
        in logs through ``repr()``.
        """
        # The previous implementation read ``self.id`` and ``self.run_id``,
        # which are not defined on this model and raised AttributeError.
        # ``uuid`` is the primary identifier the sibling models use in their
        # own ``__repr__`` (presumably declared on TimestampedModel).
        return (
            f"<SharedExtractor(uuid={self.uuid}, "
            f"extractor_id={self.extractor_id})>"
        )


class Extractor(TimestampedModel):
    """ORM model for the ``extractors`` table.

    A row holds a named extraction definition: the JSON Schema describing
    what to extract, a short description shown in the UI, the instruction
    (prompt) given to the language model, and an optional share UUID.
    """

    __tablename__ = "extractors"

    # Short display name; defaults to the empty string at the database level.
    name = Column(
        String(100),
        comment="The name of the extractor.",
        server_default="",
        nullable=False,
    )
    # The JSON Schema is stored as JSONB.
    schema = Column(
        JSONB,
        comment=(
            "JSON Schema that describes what content will be extracted from "
            "the document"
        ),
        nullable=False,
    )
    description = Column(
        String(100),
        comment="Surfaced via UI to the users.",
        server_default="",
        nullable=False,
    )
    # TODO: This will need to evolve
    instruction = Column(
        Text,
        comment="The prompt to the language model.",
        nullable=False,
    )

    # One extractor has many examples; each Example gets an ``extractor``
    # attribute through the backref.
    examples = relationship("Example", backref="extractor")

    # Used for sharing the extractor with others; nullable, so an extractor
    # may have no value here.
    share_uuid = Column(
        UUID(as_uuid=True),
        comment="The uuid of the shareable link.",
        nullable=True,
    )

    def __repr__(self) -> str:
        """Return a short debug representation of the extractor."""
        return "<Extractor(id={}, description={})>".format(
            self.uuid, self.description
        )
67 changes: 64 additions & 3 deletions backend/server/api/extractors.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Endpoints for managing definition of extractors."""
from typing import Any, Dict, List
from uuid import UUID
from uuid import UUID, uuid4

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field, validator
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from db.models import Extractor, get_session
from db.models import Extractor, SharedExtractors, get_session
from server.validators import validate_json_schema

router = APIRouter(
Expand Down Expand Up @@ -39,14 +40,74 @@ def validate_schema(cls, v: Any) -> Dict[str, Any]:
class CreateExtractorResponse(BaseModel):
    """Response for creating an extractor."""

    # Declared once: a bare ``uuid: UUID`` annotation followed by a second
    # declaration of the same field is redundant, the later one wins.
    uuid: UUID = Field(..., description="The UUID of the created extractor.")


class ShareExtractorRequest(BaseModel):
    """Request for sharing an extractor."""

    uuid: UUID = Field(..., description="The UUID of the extractor to share.")


class ShareExtractorResponse(BaseModel):
    """Response for sharing an extractor."""

    # Required field (no default): the token under which the extractor is
    # shared.
    share_uuid: UUID = Field(
        default=...,
        description="The UUID for the shared extractor.",
    )


@router.post("/{uuid}/share", response_model=ShareExtractorResponse)
def share(
    uuid: UUID,
    *,
    session: Session = Depends(get_session),
) -> ShareExtractorResponse:
    """Endpoint to share an extractor.

    Look up a shared extractor by UUID and return the share UUID if it exists.
    If not shared, create a new shared extractor entry and return the new
    share UUID.

    Args:
        uuid: The UUID of the extractor to share.
        session: The database session.

    Returns:
        The UUID for the shared extractor.

    Raises:
        HTTPException: 404 if no extractor with the given UUID exists;
            400 if the share row could not be committed because of an
            integrity error.
    """
    # Check if the extractor is already shared. ``first()`` is used instead
    # of ``scalar()`` because nothing in the schema limits an extractor to a
    # single share row, and ``scalar()`` raises when more than one row
    # matches.
    shared_extractor = (
        session.query(SharedExtractors)
        .filter(SharedExtractors.extractor_id == uuid)
        .first()
    )
    if shared_extractor is not None:
        # The extractor is already shared, return the existing share_uuid
        return ShareExtractorResponse(share_uuid=shared_extractor.share_token)

    # Refuse to create a share entry for an extractor that does not exist,
    # rather than relying on the foreign key to reject the insert.
    extractor = session.query(Extractor).filter(Extractor.uuid == uuid).first()
    if extractor is None:
        raise HTTPException(status_code=404, detail="Extractor not found.")

    # Generate the token up front so it can be returned without reading the
    # attribute back from the ORM instance after the commit.
    share_token = uuid4()
    session.add(SharedExtractors(extractor_id=uuid, share_token=share_token))
    try:
        session.commit()
    except IntegrityError as err:
        session.rollback()
        raise HTTPException(
            status_code=400, detail="Failed to share the extractor."
        ) from err

    # Return the new share_uuid
    return ShareExtractorResponse(share_uuid=share_token)


@router.post("")
def create(
create_request: CreateExtractor, *, session: Session = Depends(get_session)
) -> CreateExtractorResponse:
"""Endpoint to create an extractor."""

instance = Extractor(
name=create_request.name,
schema=create_request.json_schema,
Expand Down
51 changes: 51 additions & 0 deletions backend/server/api/shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Endpoints for working with shared resources."""
from typing import Any, Dict
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session

from db.models import Extractor, SharedExtractors, get_session

router = APIRouter(
prefix="/s",
tags=["extractor definitions"],
responses={404: {"description": "Not found"}},
)


class SharedExtractorResponse(BaseModel):
    """Response for retrieving a shared extractor."""

    # UUID should not be included in the response since it is not a public
    # identifier!
    name: str
    description: str
    # The field is declared as ``schema_`` and exposed under the alias
    # "schema" because the plain name ``schema`` is reserved by pydantic.
    schema_: Dict[str, Any] = Field(..., alias="schema")
    instruction: str


@router.get("/{uuid}")
def get(
    uuid: UUID,
    *,
    session: Session = Depends(get_session),
) -> SharedExtractorResponse:
    """Get a shared extractor.

    Args:
        uuid: The share token taken from the URL path.
        session: The database session.

    Returns:
        The name, description, schema and instruction of the extractor that
        the share token points to.

    Raises:
        HTTPException: 404 if no shared extractor matches the token.
    """
    # Join extractors to their share rows, then narrow down by token.
    joined = session.query(Extractor).join(
        SharedExtractors, Extractor.uuid == SharedExtractors.extractor_id
    )
    match = joined.filter(SharedExtractors.share_token == uuid).first()

    if match:
        return SharedExtractorResponse(
            schema=match.schema,
            instruction=match.instruction,
            name=match.name,
            description=match.description,
        )

    raise HTTPException(status_code=404, detail="Extractor not found.")
3 changes: 2 additions & 1 deletion backend/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fastapi.middleware.cors import CORSMiddleware
from langserve import add_routes

from server.api import examples, extract, extractors, suggest
from server.api import examples, extract, extractors, shared, suggest
from server.extraction_runnable import (
ExtractRequest,
ExtractResponse,
Expand Down Expand Up @@ -46,6 +46,7 @@ def ready() -> str:
app.include_router(examples.router)
app.include_router(extract.router)
app.include_router(suggest.router)
app.include_router(shared.router)

add_routes(
app,
Expand Down
44 changes: 44 additions & 0 deletions backend/tests/unit_tests/api/test_api_defining_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,47 @@ async def test_extractors_api() -> None:
}
response = await client.post("/extractors", json=create_request)
assert response.status_code == 200


async def test_sharing_extractor() -> None:
    """Test sharing an extractor."""
    payload = {
        "name": "Test Name",
        "description": "Test Description",
        "schema": {"type": "object"},
        "instruction": "Test Instruction",
    }

    async with get_async_client() as client:
        # Start from an empty listing.
        listing = await client.get("/extractors")
        assert listing.status_code == 200
        assert listing.json() == []

        # Create the extractor that will be shared.
        created = await client.post("/extractors", json=payload)
        assert created.status_code == 200
        extractor_uuid = created.json()["uuid"]
        share_url = f"/extractors/{extractor_uuid}/share"

        # Share it and remember the token that comes back.
        first_share = await client.post(share_url)
        assert first_share.status_code == 200
        assert "share_uuid" in first_share.json()
        share_uuid = first_share.json()["share_uuid"]

        # Sharing again must hand back the same token (idempotency).
        second_share = await client.post(share_url)
        assert second_share.status_code == 200
        assert "share_uuid" in second_share.json()
        assert second_share.json()["share_uuid"] == share_uuid

        # The shared extractor is retrievable through the token.
        fetched = await client.get(f"/s/{share_uuid}")
        assert fetched.status_code == 200
        body = fetched.json()
        assert sorted(body) == ["description", "instruction", "name", "schema"]

        assert body == {
            "description": "Test Description",
            "instruction": "Test Instruction",
            "name": "Test Name",
            "schema": {"type": "object"},
        }

0 comments on commit 1a12200

Please sign in to comment.