Add simple extraction workflow (#124)
This PR adds a simple extraction workflow. It makes it easier to concurrently process a set of documents using a given chain and collect the results.
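For context, here is a minimal end-to-end sketch of the intended usage. The schema, the `Text` attribute node, and the `ChatOpenAI` model are illustrative assumptions rather than part of this diff; only `create_extraction_chain` and `extract_from_documents` come from the code added below.

import asyncio

from langchain.chat_models import ChatOpenAI  # assumed model; any BaseLanguageModel should work
from langchain.docstore.document import Document

from kor.extraction import create_extraction_chain, extract_from_documents
from kor.nodes import Object, Text  # Text is assumed here as a simple attribute node

# Hypothetical schema describing what to extract from each document.
schema = Object(
    id="person",
    description="Personal information about a person",
    attributes=[Text(id="name", description="The person's first name")],
    examples=[("Alice went to the market.", [{"name": "Alice"}])],
)

chain = create_extraction_chain(ChatOpenAI(temperature=0), schema)

docs = [
    Document(page_content="Bob met Carol at noon.", metadata={"uid": "doc-1"}),
    Document(page_content="Dana phoned Eve.", metadata={"uid": "doc-2"}),
]

# Run extraction over all documents, at most two requests in flight at a time.
results = asyncio.run(extract_from_documents(chain, docs, max_concurrency=2))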
Showing 13 changed files with 684 additions and 369 deletions.
(One of the changed files was deleted; its contents are not shown.)
@@ -0,0 +1,11 @@
from kor.extraction.api import create_extraction_chain, extract_from_documents
from kor.extraction.parser import KorParser
from kor.extraction.typedefs import DocumentExtraction, Extraction

__all__ = [
    "Extraction",
    "KorParser",
    "extract_from_documents",
    "create_extraction_chain",
    "DocumentExtraction",
]
@@ -0,0 +1,172 @@
"""Kor API for extraction related functionality."""
import asyncio
from typing import Any, Callable, List, Optional, Sequence, Type, Union, cast

from langchain import PromptTemplate
from langchain.chains import LLMChain
from langchain.docstore.document import Document
from langchain.schema import BaseLanguageModel

from kor.encoders import Encoder, InputFormatter, initialize_encoder
from kor.extraction.typedefs import DocumentExtraction, Extraction
from kor.nodes import Object
from kor.prompts import create_langchain_prompt
from kor.type_descriptors import TypeDescriptor, initialize_type_descriptors
from kor.validators import Validator


async def _extract_from_document_with_semaphore(
    semaphore: asyncio.Semaphore,
    chain: LLMChain,
    document: Document,
    uid: str,
    source_uid: str,
) -> DocumentExtraction:
    """Extract from a document with a semaphore to limit concurrency."""
    async with semaphore:
        extraction_result: Extraction = cast(
            Extraction, await chain.apredict_and_parse(text=document.page_content)
        )
        return {
            "uid": uid,
            "source_uid": source_uid,
            "data": extraction_result["data"],
            "raw": extraction_result["raw"],
            "validated_data": extraction_result["validated_data"],
            "errors": extraction_result["errors"],
        }


# PUBLIC API


def create_extraction_chain(
    llm: BaseLanguageModel,
    node: Object,
    *,
    encoder_or_encoder_class: Union[Type[Encoder], Encoder, str] = "csv",
    type_descriptor: Union[TypeDescriptor, str] = "typescript",
    validator: Optional[Validator] = None,
    input_formatter: InputFormatter = None,
    instruction_template: Optional[PromptTemplate] = None,
    **encoder_kwargs: Any,
) -> LLMChain:
    """Create an extraction chain.

    Args:
        llm: the language model used for extraction
        node: the schematic description of what to extract from text
        encoder_or_encoder_class: Either an encoder instance, an encoder class,
            or a string representing the encoder class
        type_descriptor: either a TypeDescriptor or a string representing the type \
            descriptor name
        validator: optional validator to use for validation
        input_formatter: the formatter to use for encoding the input. Used for \
            both input examples and the text to be analyzed.
            * `None`: use for single sentences or a single paragraph, no formatting
            * `triple_quotes`: for long text, surround input with \"\"\"
            * `text_prefix`: for long text, triple_quote with `TEXT: ` prefix
            * `Callable`: user provided function
        instruction_template: optional prompt template, used to override the prompt
            used for generating the instruction section of the prompt.
            It accepts 2 optional input variables:
            * "type_description": type description of the node (from TypeDescriptor)
            * "format_instructions": information on how to format the output
              (from Encoder)
        encoder_kwargs: Keyword arguments to pass to the encoder class

    Returns:
        A langchain chain

    Examples:

    .. code-block:: python

        # For CSV encoding
        chain = create_extraction_chain(llm, node, encoder_or_encoder_class="csv")

        # For JSON encoding
        chain = create_extraction_chain(llm, node, encoder_or_encoder_class="JSON",
                                        input_formatter="triple_quotes")
    """
    if not isinstance(node, Object):
        raise ValueError(f"node must be an Object, got {type(node)}")
    encoder = initialize_encoder(encoder_or_encoder_class, node, **encoder_kwargs)
    type_descriptor_to_use = initialize_type_descriptors(type_descriptor)
    return LLMChain(
        llm=llm,
        prompt=create_langchain_prompt(
            node,
            encoder,
            type_descriptor_to_use,
            validator=validator,
            instruction_template=instruction_template,
            input_formatter=input_formatter,
        ),
    )


async def extract_from_documents(
    chain: LLMChain,
    documents: Sequence[Document],
    *,
    max_concurrency: int = 1,
    use_uid: bool = True,
    extraction_uid_function: Optional[Callable[[Document], str]] = None,
    return_exceptions: bool = False,
) -> List[DocumentExtraction]:
    """Run extraction through all the given documents.

    Attention: When using this function with a large number of documents, mind the
    bill since this can use a lot of tokens!

    Concurrency is currently limited using a semaphore. This is a temporary
    measure and can be changed to a queue implementation to support a
    non-materialized stream of documents.

    Args:
        chain: the extraction chain to use for extraction
        documents: the documents to run extraction on
        max_concurrency: the maximum number of concurrent requests to make,
            uses a semaphore to limit concurrency
        use_uid: If True, will use the `uid` attribute in the document metadata
            if it exists, and will raise an error if it does not exist.
            If False, will use the index of the document in the list as the uid.
        extraction_uid_function: Optional function to use to generate the uid for
            a given DocumentExtraction. If not provided, will use the uid
            of the document.
        return_exceptions: named argument passed to asyncio.gather

    Returns:
        A list of extraction results
    """
    semaphore = asyncio.Semaphore(value=max_concurrency)

    tasks = []
    for idx, doc in enumerate(documents):
        if use_uid:
            source_uid = doc.metadata.get("uid")
            if source_uid is None:
                raise ValueError(
                    f"uid not found in document metadata for document {idx}"
                )
            # Coerce the uid to a string.
            source_uid = str(source_uid)
        else:
            source_uid = str(idx)

        extraction_uid = (
            extraction_uid_function(doc) if extraction_uid_function else source_uid
        )

        tasks.append(
            asyncio.ensure_future(
                _extract_from_document_with_semaphore(
                    semaphore, chain, doc, extraction_uid, source_uid
                )
            )
        )

    results = await asyncio.gather(*tasks, return_exceptions=return_exceptions)
    return results
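Since `return_exceptions` is forwarded straight to `asyncio.gather`, a caller can keep one failing document from aborting the whole batch. A minimal sketch of that pattern, assuming it runs inside an async context with `chain` and `docs` already defined:

results = await extract_from_documents(
    chain, docs, max_concurrency=5, return_exceptions=True
)
# With return_exceptions=True, failed tasks come back as exception objects
# mixed into the result list, so separate them before further processing.
extractions = [r for r in results if not isinstance(r, Exception)]
failures = [r for r in results if isinstance(r, Exception)]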
@@ -0,0 +1,31 @@
"""Type definitions for the extraction package."""
from typing import Any, Dict, List, TypedDict


class Extraction(TypedDict):
    """Type-definition for an extraction result."""

    raw: str
    """The raw output from the LLM."""
    data: Dict[str, Any]
    """The decoding of the raw output from the LLM without any further processing."""
    validated_data: Dict[str, Any]
    """The validated data if a validator was provided."""
    errors: List[Exception]
    """Any errors encountered during decoding or validation."""


class DocumentExtraction(Extraction):
    """Type-definition for a document extraction result.

    The original extraction typedef together with unique identifiers for the
    result itself as well as the source document.

    Identifiers are included to make it easier to link the extraction result
    to the source content.
    """

    uid: str
    """The uid of the extraction result."""
    source_uid: str
    """The source uid of the document from which data was extracted."""