Add simple extraction workflow (#124)
This PR adds a simple extraction workflow.

It makes it easier to concurrently process a set of documents
using a given chain and collect the results.
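A minimal sketch of the intended workflow, assuming an OpenAI model through langchain; the schema, documents, and model choice below are illustrative, not part of this PR:

import asyncio

from langchain.docstore.document import Document
from langchain.llms import OpenAI

from kor import Object, Text, create_extraction_chain, extract_from_documents

# Hypothetical schema: pull person names out of free text.
schema = Object(
    id="person",
    description="Personal information",
    attributes=[Text(id="name", description="The person's name")],
    examples=[("Alice and Bob are friends.", [{"name": "Alice"}, {"name": "Bob"}])],
)

llm = OpenAI(temperature=0)
chain = create_extraction_chain(llm, schema)

documents = [
    Document(page_content="My name is Carol.", metadata={"uid": "doc-1"}),
    Document(page_content="Dave waved back.", metadata={"uid": "doc-2"}),
]

# Run all documents through the chain with at most two LLM calls in flight.
results = asyncio.run(extract_from_documents(chain, documents, max_concurrency=2))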
eyurtsev authored Apr 7, 2023
1 parent b218678 commit 7ebd1df
Showing 13 changed files with 684 additions and 369 deletions.
12 changes: 10 additions & 2 deletions kor/__init__.py
@@ -1,6 +1,11 @@
 from .adapters import from_pydantic
 from .encoders import CSVEncoder, JSONEncoder, XMLEncoder
-from .extraction import create_extraction_chain
+from .extraction import (
+    DocumentExtraction,
+    Extraction,
+    create_extraction_chain,
+    extract_from_documents,
+)
 from .nodes import Number, Object, Option, Selection, Text
 from .type_descriptors import (
     BulletPointDescriptor,
@@ -13,6 +18,8 @@
     "BulletPointDescriptor",
     "create_extraction_chain",
     "CSVEncoder",
+    "DocumentExtraction",
+    "Extraction",
     "from_pydantic",
     "JSONEncoder",
     "Number",
@@ -22,6 +29,7 @@
     "Text",
     "TypeDescriptor",
     "TypeScriptDescriptor",
-    "XMLEncoder",
+    "extract_from_documents",
     "__version__",
+    "XMLEncoder",
 )
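With this change the new symbols are importable straight from the package root:

from kor import DocumentExtraction, Extraction, extract_from_documents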
80 changes: 0 additions & 80 deletions kor/extraction.py

This file was deleted.

11 changes: 11 additions & 0 deletions kor/extraction/__init__.py
@@ -0,0 +1,11 @@
from kor.extraction.api import create_extraction_chain, extract_from_documents
from kor.extraction.parser import KorParser
from kor.extraction.typedefs import DocumentExtraction, Extraction

__all__ = [
    "Extraction",
    "KorParser",
    "extract_from_documents",
    "create_extraction_chain",
    "DocumentExtraction",
]
172 changes: 172 additions & 0 deletions kor/extraction/api.py
@@ -0,0 +1,172 @@
"""Kor API for extraction related functionality."""
import asyncio
from typing import Any, Callable, List, Optional, Sequence, Type, Union, cast

from langchain import PromptTemplate
from langchain.chains import LLMChain
from langchain.docstore.document import Document
from langchain.schema import BaseLanguageModel

from kor.encoders import Encoder, InputFormatter, initialize_encoder
from kor.extraction.typedefs import DocumentExtraction, Extraction
from kor.nodes import Object
from kor.prompts import create_langchain_prompt
from kor.type_descriptors import TypeDescriptor, initialize_type_descriptors
from kor.validators import Validator


async def _extract_from_document_with_semaphore(
    semaphore: asyncio.Semaphore,
    chain: LLMChain,
    document: Document,
    uid: str,
    source_uid: str,
) -> DocumentExtraction:
    """Extract from document with a semaphore to limit concurrency."""
    async with semaphore:
        extraction_result: Extraction = cast(
            Extraction, await chain.apredict_and_parse(text=document.page_content)
        )
        return {
            "uid": uid,
            "source_uid": source_uid,
            "data": extraction_result["data"],
            "raw": extraction_result["raw"],
            "validated_data": extraction_result["validated_data"],
            "errors": extraction_result["errors"],
        }


# PUBLIC API


def create_extraction_chain(
    llm: BaseLanguageModel,
    node: Object,
    *,
    encoder_or_encoder_class: Union[Type[Encoder], Encoder, str] = "csv",
    type_descriptor: Union[TypeDescriptor, str] = "typescript",
    validator: Optional[Validator] = None,
    input_formatter: InputFormatter = None,
    instruction_template: Optional[PromptTemplate] = None,
    **encoder_kwargs: Any,
) -> LLMChain:
    """Create an extraction chain.

    Args:
        llm: the language model used for extraction
        node: the schematic description of what to extract from text
        encoder_or_encoder_class: Either an encoder instance, an encoder class
            or a string representing the encoder class
        type_descriptor: either a TypeDescriptor or a string representing the type \
            descriptor name
        validator: optional validator to use for validation
        input_formatter: the formatter to use for encoding the input. Used for \
            both input examples and the text to be analyzed.

            * `None`: use for single sentences or a single paragraph, no formatting
            * `triple_quotes`: for long text, surround input with \"\"\"
            * `text_prefix`: for long text, triple_quotes with a `TEXT: ` prefix
            * `Callable`: user-provided function
        instruction_template: optional prompt template used to override the prompt
            that generates the instruction section of the prompt.
            It accepts 2 optional input variables:

            * "type_description": type description of the node (from TypeDescriptor)
            * "format_instructions": information on how to format the output
              (from Encoder)
        encoder_kwargs: Keyword arguments to pass to the encoder class

    Returns:
        A langchain chain

    Examples:

    .. code-block:: python

        # For CSV encoding
        chain = create_extraction_chain(llm, node, encoder_or_encoder_class="csv")

        # For JSON encoding
        chain = create_extraction_chain(llm, node, encoder_or_encoder_class="JSON",
                                        input_formatter="triple_quotes")
    """
    if not isinstance(node, Object):
        raise ValueError(f"node must be an Object, got {type(node)}")
    encoder = initialize_encoder(encoder_or_encoder_class, node, **encoder_kwargs)
    type_descriptor_to_use = initialize_type_descriptors(type_descriptor)
    return LLMChain(
        llm=llm,
        prompt=create_langchain_prompt(
            node,
            encoder,
            type_descriptor_to_use,
            validator=validator,
            instruction_template=instruction_template,
            input_formatter=input_formatter,
        ),
    )


async def extract_from_documents(
    chain: LLMChain,
    documents: Sequence[Document],
    *,
    max_concurrency: int = 1,
    use_uid: bool = True,
    extraction_uid_function: Optional[Callable[[Document], str]] = None,
    return_exceptions: bool = False,
) -> List[DocumentExtraction]:
    """Run extraction through all the given documents.

    Attention: When using this function with a large number of documents, mind the
    bill since this can use a lot of tokens!

    Concurrency is currently limited using a semaphore. This is a temporary measure
    and can be changed to a queue implementation to support a non-materialized
    stream of documents.

    Args:
        chain: the extraction chain to use for extraction
        documents: the documents to run extraction on
        max_concurrency: the maximum number of concurrent requests to make,
            uses a semaphore to limit concurrency
        use_uid: If True, will use the `uid` attribute in the document metadata
            if it exists, and will raise an error if the attribute does not exist.
            If False, will use the index of the document in the list as the uid.
        extraction_uid_function: Optional function to use to generate the uid for
            a given DocumentExtraction. If not provided, will use the uid
            of the document.
        return_exceptions: named argument passed to asyncio.gather

    Returns:
        A list of extraction results
    """
    semaphore = asyncio.Semaphore(value=max_concurrency)

    tasks = []
    for idx, doc in enumerate(documents):
        if use_uid:
            source_uid = doc.metadata.get("uid")
            if source_uid is None:
                raise ValueError(
                    f"uid not found in document metadata for document {idx}"
                )
            # Convert to a string in case the uid is not a string.
            source_uid = str(source_uid)
        else:
            source_uid = str(idx)

        extraction_uid = (
            extraction_uid_function(doc) if extraction_uid_function else source_uid
        )

        tasks.append(
            asyncio.ensure_future(
                _extract_from_document_with_semaphore(
                    semaphore, chain, doc, extraction_uid, source_uid
                )
            )
        )

    results = await asyncio.gather(*tasks, return_exceptions=return_exceptions)
    return results
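A hedged sketch of how the new keyword arguments compose; the helper name and the print-based handling are illustrative:

from kor import extract_from_documents


async def run_batch(chain, documents):
    # With use_uid=False the list index becomes the source_uid, so documents
    # need no "uid" metadata key; return_exceptions=True surfaces per-document
    # failures in the result list instead of aborting the whole batch.
    results = await extract_from_documents(
        chain,
        documents,
        max_concurrency=4,
        use_uid=False,
        return_exceptions=True,
    )
    for result in results:
        if isinstance(result, Exception):
            print(f"extraction failed: {result}")
        elif result["errors"]:
            print(f"{result['uid']}: completed with errors: {result['errors']}")
        else:
            print(f"{result['uid']}: {result['data']}")
    return results

Because the results come from asyncio.gather, a raised exception occupies the failing document's position in the list, so ordering still lines up with the input documents.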
13 changes: 8 additions & 5 deletions kor/encoders/parser.py → kor/extraction/parser.py
@@ -1,11 +1,12 @@
 from __future__ import annotations
 
-from typing import Any, Dict, Optional
+from typing import List, Optional
 
 from pydantic import Extra
 
 from kor.encoders import Encoder
 from kor.exceptions import ParseError
+from kor.extraction.typedefs import Extraction
 from kor.nodes import Object
 from kor.validators import Validator
 
@@ -31,7 +32,7 @@ def _type(self) -> str:
         """Declare the type property."""
         return "KorEncoder"
 
-    def parse(self, text: str) -> Dict[str, Any]:
+    def parse(self, text: str) -> Extraction:
         """Parse the text."""
         try:
             data = self.encoder.decode(text)
@@ -40,6 +41,8 @@ def parse(self, text: str) -> Dict[str, Any]:
 
         key_id = self.schema_.id
 
+        errors: List[Exception]
+
         if key_id not in data:
             if data:  # We got something parsed, but it doesn't match the schema.
                 errors = [
@@ -56,14 +59,14 @@
         obj_data = data[key_id]
 
         if self.validator:
-            validated_data, exceptions = self.validator.clean_data(obj_data)
+            validated_data, errors = self.validator.clean_data(obj_data)
         else:
-            validated_data, exceptions = {}, []
+            validated_data, errors = {}, []
 
         return {
             "data": data,
             "raw": text,
-            "errors": exceptions,
+            "errors": errors,
             "validated_data": validated_data,
         }
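Since parse now returns the Extraction typed dict, the parser can be exercised on its own. A small sketch, assuming KorParser accepts its encoder and schema_ fields as keyword arguments and that the JSON encoder decodes a plain JSON payload; the raw completion string is made up:

from kor import Object, Text
from kor.encoders import initialize_encoder
from kor.extraction.parser import KorParser

schema = Object(id="person", attributes=[Text(id="name")])
encoder = initialize_encoder("json", schema)
parser = KorParser(encoder=encoder, schema_=schema)

# Hypothetical raw LLM completion in the encoder's JSON format.
extraction = parser.parse('{"person": {"name": ["Alice"]}}')
print(extraction["data"])            # decoded output, keyed by the schema id
print(extraction["errors"])          # [] when decoding and validation succeed
print(extraction["validated_data"])  # {} unless a validator was configured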
31 changes: 31 additions & 0 deletions kor/extraction/typedefs.py
@@ -0,0 +1,31 @@
"""Type definitions for the extraction package."""
from typing import Any, Dict, List, TypedDict


class Extraction(TypedDict):
"""Type-definition for an extraction result."""

raw: str
"""The raw output from the LLM."""
data: Dict[str, Any]
"""The decoding of the raw output from the LLM without any further processing."""
validated_data: Dict[str, Any]
"""The validated data if a validator was provided."""
errors: List[Exception]
"""Any errors encountered during decoding or validation."""


class DocumentExtraction(Extraction):
"""Type-definition for a document extraction result.
The original extraction typedefs together with the unique identifiers for the result
itself as well as the source document.
Identifiers are included to make it easier to link the extraction result
to the source content.
"""

uid: str
"""The uid of the extraction result."""
source_uid: str
"""The source uid of the document from which data was extracted."""
2 changes: 1 addition & 1 deletion kor/prompts.py
@@ -16,8 +16,8 @@
 
 from kor.encoders import Encoder
 from kor.encoders.encode import InputFormatter, encode_examples, format_text
-from kor.encoders.parser import KorParser
 from kor.examples import generate_examples
+from kor.extraction.parser import KorParser
 from kor.nodes import Object
 from kor.type_descriptors import TypeDescriptor
 