-
Notifications
You must be signed in to change notification settings - Fork 52
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added key value extraction evaluation and example with images (#1529)
* Added example for image key value extraction evaluation Signed-off-by: Yoav Katz <[email protected]> * Removed unneeded comments Signed-off-by: Yoav Katz <[email protected]> * Model key value extraction as a first class task in catalog Signed-off-by: Yoav Katz <[email protected]> * Added documentation link Signed-off-by: Yoav Katz <[email protected]> * Added default template to extraction task Signed-off-by: Yoav Katz <[email protected]> * Changed order of results printout for clarity Signed-off-by: Yoav Katz <[email protected]> * Update docs/docs/examples.rst Co-authored-by: Elron Bandel <[email protected]> * Ensure key values are strings as expected by the metric Signed-off-by: Yoav Katz <[email protected]> * Moved to more standard font across Linux and Mac Signed-off-by: Yoav Katz <[email protected]> * Moved to default font Signed-off-by: Yoav Katz <[email protected]> * Doc improvements Signed-off-by: Yoav Katz <[email protected]> --------- Signed-off-by: Yoav Katz <[email protected]> Co-authored-by: Elron Bandel <[email protected]>
- Loading branch information
1 parent
7b27d6e
commit 189c482
Showing
10 changed files
with
237 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import json | ||
|
||
from unitxt import get_logger | ||
from unitxt.api import create_dataset, evaluate | ||
from unitxt.inference import ( | ||
CrossProviderInferenceEngine, | ||
) | ||
|
||
logger = get_logger() | ||
keys = ["Worker", "LivesIn", "WorksAt"] | ||
|
||
|
||
def text_to_image(text: str) -> dict:
    """Render ``text`` onto a blank image and return it in a dict.

    The text is drawn in black, starting at the top-left corner, on a
    white 1000x1000 RGB canvas using Pillow's default font at size 10.
    Newlines in ``text`` are honoured because the drawing call is
    ``multiline_text``.

    Args:
        text: The text to draw on the image.

    Returns:
        A dict with two entries: ``"image"`` holding the ``PIL.Image``
        object and ``"format"`` set to the string ``"png"``.
    """
    # Imported lazily so Pillow is only required when an image is built.
    from PIL import Image, ImageDraw, ImageFont

    bg_color = (255, 255, 255)
    text_color = (0, 0, 0)
    font_size = 10
    font = ImageFont.load_default(size=font_size)

    # Fixed-size canvas; the text is not measured, so anything that does
    # not fit inside 1000x1000 pixels is clipped.
    img = Image.new("RGB", (1000, 1000), bg_color)
    draw = ImageDraw.Draw(img)

    # Draw the text on the image, anchored at the top-left corner.
    draw.multiline_text((0, 0), text, fill=text_color, font=font)
    return {"image": img, "format": "png"}
|
||
|
||
# Each instance pairs a rendered sentence with the key/value pairs that
# are expected to be extracted from it; every instance shares `keys`.
test_set = [
    {
        "input": text_to_image(sentence),
        "keys": keys,
        "key_value_pairs_answer": expected_pairs,
    }
    for sentence, expected_pairs in (
        (
            "John lives in Texas.",
            {"Worker": "John", "LivesIn": "Texas"},
        ),
        (
            "Phil works at Apple and eats an apple.",
            {"Worker": "Phil", "WorksAt": "Apple"},
        ),
    )
]
|
||
|
||
dataset = create_dataset( | ||
task="tasks.key_value_extraction", | ||
template="templates.key_value_extraction.extract_in_json_format", | ||
test_set=test_set, | ||
split="test", | ||
format="formats.chat_api", | ||
) | ||
|
||
model = CrossProviderInferenceEngine( | ||
model="llama-3-2-11b-vision-instruct", provider="watsonx" | ||
) | ||
|
||
predictions = model(dataset) | ||
results = evaluate(predictions=predictions, data=dataset) | ||
|
||
print("Example prompt:") | ||
|
||
print(json.dumps(results.instance_scores[0]["source"], indent=4)) | ||
|
||
print("Instance Results:") | ||
print(results.instance_scores) | ||
|
||
print("Global Results:") | ||
print(results.global_scores.summary) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from typing import Any, Dict, List, Tuple | ||
|
||
from unitxt.blocks import Task | ||
from unitxt.catalog import add_to_catalog | ||
|
||
add_to_catalog( | ||
Task( | ||
__description__="This is a key value extraction task, where a specific list of possible 'keys' need to be extracted from the input. The ground truth is provided key-value pairs in the form of the dictionary. The results are evaluating using F1 score metric, that expects the predictions to be converted into a list of (key,value) pairs. ", | ||
input_fields={"input": Any, "keys": List[str]}, | ||
reference_fields={"key_value_pairs_answer": Dict[str, str]}, | ||
prediction_type=List[Tuple[str, str]], | ||
metrics=["metrics.key_value_extraction"], | ||
default_template="templates.key_value_extraction.extract_in_json_format", | ||
), | ||
"tasks.key_value_extraction", | ||
overwrite=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from unitxt import add_to_catalog | ||
from unitxt.processors import PostProcess | ||
from unitxt.struct_data_operators import JsonStrToListOfKeyValuePairs | ||
from unitxt.templates import ( | ||
InputOutputTemplate, | ||
) | ||
|
||
add_to_catalog( | ||
InputOutputTemplate( | ||
instruction="Extract the key value pairs from the input. Return a valid json object with the following keys: {keys}. Return only the json representation, no additional text or explanations.", | ||
input_format="{input}", | ||
output_format="{key_value_pairs_answer}", | ||
postprocessors=[PostProcess(JsonStrToListOfKeyValuePairs())], | ||
), | ||
"templates.key_value_extraction.extract_in_json_format", | ||
overwrite=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
{ | ||
"__type__": "key_value_extraction" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
{ | ||
"__type__": "task", | ||
"__description__": "This is a key value extraction task, where a specific list of possible 'keys' need to be extracted from the input. The ground truth is provided key-value pairs in the form of the dictionary. The results are evaluating using F1 score metric, that expects the predictions to be converted into a list of (key,value) pairs. ", | ||
"input_fields": { | ||
"input": "Any", | ||
"keys": "List[str]" | ||
}, | ||
"reference_fields": { | ||
"key_value_pairs_answer": "Dict[str, str]" | ||
}, | ||
"prediction_type": "List[Tuple[str, str]]", | ||
"metrics": [ | ||
"metrics.key_value_extraction" | ||
], | ||
"default_template": "templates.key_value_extraction.extract_in_json_format" | ||
} |
14 changes: 14 additions & 0 deletions
14
src/unitxt/catalog/templates/key_value_extraction/extract_in_json_format.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
{ | ||
"__type__": "input_output_template", | ||
"instruction": "Extract the key value pairs from the input. Return a valid json object with the following keys: {keys}. Return only the json representation, no additional text or explanations.", | ||
"input_format": "{input}", | ||
"output_format": "{key_value_pairs_answer}", | ||
"postprocessors": [ | ||
{ | ||
"__type__": "post_process", | ||
"operator": { | ||
"__type__": "json_str_to_list_of_key_value_pairs" | ||
} | ||
} | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters