Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add text2sql tasks #1414

Open
wants to merge 118 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
118 commits
Select commit Hold shift + click to select a range
8b26a9e
add text2sql templates
perlitz Dec 4, 2024
4752c56
add data managment utility for text2sql
perlitz Dec 4, 2024
0713ff3
add basic template
perlitz Dec 4, 2024
7909077
add sql execution accuracy metric
perlitz Dec 4, 2024
4fdab71
Merge branch 'main' into add-text2sql
perlitz Dec 13, 2024
61a9232
add text2sql execution accuracy metric
perlitz Dec 4, 2024
94f10c4
add text2sql task
perlitz Dec 6, 2024
9a90f90
condition download in presence of a cache dir
perlitz Dec 6, 2024
b37b467
add init fille
perlitz Dec 6, 2024
8d4894d
add processors
perlitz Dec 6, 2024
bc0d165
add processors
perlitz Dec 6, 2024
a185342
add basic template
perlitz Dec 6, 2024
f93eee9
change id to int
perlitz Dec 13, 2024
97c1bef
change notations in templates
perlitz Dec 13, 2024
3927e20
push to catalog
perlitz Dec 13, 2024
6a50032
add evidence, remove SL
perlitz Dec 13, 2024
cec65fd
remove unued function, fix
perlitz Dec 13, 2024
94c9c1e
fix imports from unitxt.text2sql
perlitz Dec 13, 2024
e5eb4a3
push to catalog
perlitz Dec 13, 2024
77eab83
fix cache location
perlitz Dec 13, 2024
2239ed6
add example
perlitz Dec 13, 2024
982d54d
fix imports
perlitz Dec 16, 2024
9a321e1
Merge branch 'main' into add-text2sql
perlitz Dec 16, 2024
2e337ad
add func_timeout to test reqs
perlitz Dec 16, 2024
c132d7d
fix typing
perlitz Dec 16, 2024
0cec726
change template name
perlitz Dec 16, 2024
dfa1af8
push to catalog
perlitz Dec 16, 2024
c857513
add req
perlitz Dec 18, 2024
9c566cc
add local model option
perlitz Dec 18, 2024
57f41f1
Merge branch 'main' into add-text2sql
perlitz Dec 18, 2024
67c9b4e
fix databases download
perlitz Dec 18, 2024
4a013aa
fix databases download
perlitz Dec 18, 2024
2c5fe5d
add loader limit ot make example faster
perlitz Dec 18, 2024
02f1b23
fix cache paths, avoid re-download
perlitz Dec 18, 2024
1854c25
add type schema
perlitz Dec 18, 2024
c83c319
remove inports from inits
perlitz Dec 18, 2024
d51a6d7
add text2sql to inits
perlitz Dec 18, 2024
82e1fe8
update card to use serializers
perlitz Dec 18, 2024
98bc231
add schema serializer
perlitz Dec 18, 2024
2bce256
add text2sql serializer to default template
perlitz Dec 18, 2024
3b4c23a
add schema to task
perlitz Dec 18, 2024
5d9112f
adjust templates to using serializer
perlitz Dec 18, 2024
3a9bccc
adjust templates to using serializer
perlitz Dec 18, 2024
9fda158
fix processor
perlitz Dec 18, 2024
ac3ebee
remove target prefix from template
perlitz Dec 19, 2024
f313a8b
add shuffle to bird
perlitz Dec 19, 2024
e333d27
add shuffle to bird
perlitz Dec 19, 2024
3e23e4c
edit template
perlitz Dec 19, 2024
0d18070
remove comment from init
perlitz Dec 19, 2024
9fce798
clear processors code
perlitz Dec 19, 2024
ce38e3a
add option with ticks
perlitz Dec 19, 2024
38639c1
add anls metric
perlitz Dec 19, 2024
2d7aa81
Merge branch 'main' into add-text2sql
perlitz Jan 6, 2025
40f3a56
add template
perlitz Dec 20, 2024
980556c
drop comment
perlitz Jan 6, 2025
84e4695
remove recursion limit
perlitz Jan 6, 2025
4793e7c
add loader_limit to example
perlitz Jan 6, 2025
a68ead5
fix recursion error
perlitz Jan 6, 2025
29f2505
move import to withing metric
perlitz Jan 6, 2025
fccbfd3
remove catalog files wo prepare
perlitz Jan 6, 2025
543f716
fix typing
perlitz Jan 6, 2025
5512c9e
change template im example
perlitz Jan 6, 2025
aa4cac5
moving text2sql implementaion to the main src dir
perlitz Jan 6, 2025
92aec0c
fix imports
perlitz Jan 6, 2025
a1a197a
fix imports
perlitz Jan 6, 2025
0aaac1d
fix imports
perlitz Jan 6, 2025
b0a4c7b
fix imports
perlitz Jan 6, 2025
fe9cd1e
import data_utils
perlitz Jan 6, 2025
342b7c5
Merge branch 'main' into add-text2sql
perlitz Jan 7, 2025
b6da498
Merge branch 'main' into add-text2sql
perlitz Jan 8, 2025
3a8de12
fix formatting
perlitz Jan 8, 2025
89b0ce0
refactor names
perlitz Jan 8, 2025
52982a6
add processors tests
perlitz Jan 8, 2025
cac3983
Merge branch 'main' into add-text2sql
perlitz Jan 8, 2025
0966eb7
add more tests
perlitz Jan 8, 2025
f5f4b50
add tests
perlitz Jan 8, 2025
2ed595e
refactor: allow more data sources
perlitz Jan 9, 2025
32c834b
allow db source input
perlitz Jan 9, 2025
b499715
organize imports
perlitz Jan 9, 2025
d75ecc6
update example
perlitz Jan 9, 2025
2fdab7b
add db_type to task
perlitz Jan 10, 2025
c02096b
format
perlitz Jan 10, 2025
5d689b5
add db_type to task
perlitz Jan 10, 2025
1317158
add local db definition ability
perlitz Jan 10, 2025
c3d5a2a
add EE tests
perlitz Jan 10, 2025
0fbbbb3
Merge branch 'main' into add-text2sql
perlitz Jan 10, 2025
b21c124
add tests
perlitz Jan 14, 2025
1f95e5a
rename file
perlitz Jan 14, 2025
4b8f029
rename file
perlitz Jan 14, 2025
52d8b84
update sql metric
perlitz Jan 14, 2025
e7222bf
rename file
perlitz Jan 14, 2025
9bd79cf
refactor types, serializers and metric
perlitz Jan 15, 2025
d651e9a
Merge branch 'main' into add-text2sql
perlitz Jan 15, 2025
9d41b4b
remove format_table
perlitz Jan 15, 2025
afe4121
add get schema for remove connector
perlitz Jan 15, 2025
a7f9f13
add tests for LocalConnector
perlitz Jan 15, 2025
fbb42fa
add tests for InMemoryDatabaseConnector
perlitz Jan 15, 2025
0ee81ea
add serializer tests
perlitz Jan 15, 2025
74525cb
remove fp test
perlitz Jan 15, 2025
14ddb22
fix serializer
perlitz Jan 15, 2025
ed2a422
make remote connector more robust
perlitz Jan 15, 2025
c6dae8d
Add schema serializer
perlitz Jan 15, 2025
69db33e
fix tests
perlitz Jan 15, 2025
09ba273
change error
perlitz Jan 15, 2025
59cf7ca
add data to bird card
perlitz Jan 16, 2025
24eb41d
fix tests
perlitz Jan 16, 2025
e0e82d6
add tests for db utils
perlitz Jan 16, 2025
540d3f8
fix metric test
perlitz Jan 16, 2025
cbd1fd7
pre-commit
perlitz Jan 16, 2025
baff237
Merge branch 'main' into add-text2sql
perlitz Jan 16, 2025
49f7548
Merge branch 'main' into add-text2sql
perlitz Jan 16, 2025
a434e9d
delete temp
perlitz Jan 16, 2025
e27e996
make id an str
perlitz Jan 17, 2025
7a2fcd9
fix acess to db
perlitz Jan 17, 2025
e274dbd
add empty template
perlitz Jan 17, 2025
22d9511
compare the results entry from the meric
perlitz Jan 17, 2025
daa28bb
reformat
perlitz Jan 17, 2025
f061bb8
reformat
perlitz Jan 17, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,4 @@ benchmark_output/*
.litellm_cache

docs/_static/data.js
cache
62 changes: 62 additions & 0 deletions examples/evaluate_text2sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from unitxt import evaluate, load_dataset, settings
from unitxt.inference import CrossProviderInferenceEngine
from unitxt.text_utils import print_dict

with settings.context(
disable_hf_datasets_cache=False,
allow_unverified_code=True,
):
test_dataset = load_dataset(
"card=cards.text2sql.bird"
",template=templates.text2sql.you_are_given_with_hint_with_sql_prefix,loader_limit=10",
# ",template=templates.text2sql.you_are_given_with_hint_with_sql_prefix",
split="validation",
)

# Infer
inference_model = CrossProviderInferenceEngine(
model="llama-3-70b-instruct",
max_tokens=256,
)

predictions = inference_model.infer(test_dataset)
evaluated_dataset = evaluate(predictions=predictions, data=test_dataset)

print_dict(
evaluated_dataset[0],
keys_to_print=[
"source",
"prediction",
"subset",
],
)
print_dict(
evaluated_dataset[0]["score"]["global"],
)

# with llama-3-70b-instruct
# num_of_instances (int):
# 1534
# execution_accuracy (float):
# 0.482

# like GPT4 (rank 40 in the benchmark https://bird-bench.github.io/)

# from transformers import AutoModelForCausalLM, AutoTokenizer

# DEBUG_NUM_EXAMPLES = 2
# model_name = "meta-llama/Llama-3.2-1B-Instruct"
# model = AutoModelForCausalLM.from_pretrained(model_name)
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# tokenizer.pad_token = tokenizer.eos_token
# test_dataset = test_dataset.select(range(DEBUG_NUM_EXAMPLES))
# predictions = tokenizer.batch_decode(
# model.generate(
# **tokenizer.batch_encode_plus(
# test_dataset["source"], return_tensors="pt", padding=True
# ),
# max_length=2048,
# ),
# skip_special_tokens=True,
# clean_up_tokenization_spaces=True,
# )
54 changes: 54 additions & 0 deletions prepare/cards/text2sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import sys

from unitxt import add_to_catalog
from unitxt.blocks import Copy, Rename, Set, TaskCard
from unitxt.loaders import LoadHF
from unitxt.operators import ExecuteExpression, Shuffle

card = TaskCard(
loader=LoadHF(path="premai-io/birdbench", split="validation"),
preprocess_steps=[
Shuffle(page_size=sys.maxsize),
Rename(
field_to_field={
"question_id": "id",
"question": "utterance",
"SQL": "query",
"db_id": "db_id",
"evidence": "hint",
}
),
Set(
fields={
"dbms": "sqlite",
"db_type": "local",
"use_oracle_knowledge": True,
"num_table_rows_to_add": 0,
"data": None,
}
),
ExecuteExpression(
expression="'bird/'+db_id",
to_field="db_id",
),
Copy(field="db_id", to_field="db/db_id"),
Copy(field="db_type", to_field="db/db_type"),
Copy(field="dbms", to_field="db/dbms"),
Copy(field="data", to_field="db/data"),
],
task="tasks.text2sql",
templates="templates.text2sql.all",
)

# test_card(card, num_demos=0, demos_pool_size=0, )

add_to_catalog(
card,
"cards.text2sql.bird",
overwrite=True,
)

# from unitxt import evaluate, load_dataset

# ds = load_dataset("card=cards.text2sql.bird,template_card_index=0")
# scores = evaluate(predictions=ds["validation"]["target"], data=ds["validation"])
65 changes: 65 additions & 0 deletions prepare/metrics/text2sql_execution_accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from unitxt.catalog import add_to_catalog
from unitxt.metrics import ExecutionAccuracy
from unitxt.test_utils.metrics import test_metric

metric = ExecutionAccuracy()

predictions = [
"SELECT nme FROM employees WHERE department = 'Sales'",
"SELECT name FROM employees WHERE department = 'Sales'",
] # Incorrect column name 'nme'
references = [["SELECT name FROM employees WHERE department = 'Sales';"]] * 2
task_data = [
{
"db": {
"db_id": "mock_db",
"db_type": "in_memory",
"data": {
"employees": {
"columns": ["id", "name", "department", "salary"],
"rows": [
(1, "Alice", "Sales", 50000),
(2, "Bob", "Engineering", 60000),
(3, "Charlie", "Sales", 55000),
],
}
},
}
}
] * 2

instance_targets = [
{
"execution_accuracy": 0.0,
"score": 0.0,
"score_name": "execution_accuracy",
},
{
"execution_accuracy": 1.0,
"score": 1.0,
"score_name": "execution_accuracy",
},
]


global_target = {
"execution_accuracy": 0.5,
"execution_accuracy_ci_high": 1.0,
"execution_accuracy_ci_low": 0.0,
"num_of_instances": 2,
"score": 0.5,
"score_ci_high": 1.0,
"score_ci_low": 0.0,
"score_name": "execution_accuracy",
}

outputs = test_metric(
metric=metric,
predictions=predictions,
references=references,
instance_targets=instance_targets,
global_target=global_target,
task_data=task_data,
)

add_to_catalog(metric, "metrics.text2sql.execution_accuracy", overwrite=True)
13 changes: 13 additions & 0 deletions prepare/processors/text2sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from unitxt import add_to_catalog
from unitxt.operator import SequentialOperator
from unitxt.processors import GetSQL

add_to_catalog(
SequentialOperator(
steps=[
GetSQL(field="prediction"),
]
),
"processors.text2sql.get_sql",
overwrite=True,
)
6 changes: 6 additions & 0 deletions prepare/serializers/text2sql_serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from unitxt import add_to_catalog
from unitxt.serializers import SQLDatabaseAsSchemaSerializer

add_to_catalog(
SQLDatabaseAsSchemaSerializer(), "serializers.text2sql.schema", overwrite=True
)
19 changes: 19 additions & 0 deletions prepare/tasks/text2sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from unitxt.blocks import Task
from unitxt.catalog import add_to_catalog
from unitxt.types import SQLDatabase

add_to_catalog(
Task(
input_fields={
"id": str,
"utterance": str,
"hint": str,
"db": SQLDatabase,
},
reference_fields={"query": str},
prediction_type=str,
metrics=["metrics.text2sql.execution_accuracy", "metrics.anls"],
),
"tasks.text2sql",
overwrite=True,
)
64 changes: 64 additions & 0 deletions prepare/templates/text2sql/templates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from unitxt import add_to_catalog
from unitxt.blocks import TemplatesList
from unitxt.templates import InputOutputTemplate

template_details = [
(
"templates.text2sql.you_are_given_with_sql_prefix",
"You are given the following question:\n\n{utterance}\n\nAn SQL schema\n\n```sql\n\n{db}\n```\n\nAnswer the following question:\n\n{utterance}\n\n",
"You are a Text2SQL generation model, in your answer, only have SQL code.\nStart your query with 'SELECT' and end it with ';'\n\n",
"```sql\nSELECT ",
),
(
"templates.text2sql.you_are_given_with_hint_with_sql_prefix",
"You are given the following question:\n\n{utterance}\n\nAn SQL schema\n\n```sql\n\n{db}\n```\n\nAnd hint:\n\n{hint}\n\nAnswer the following question:\n\n{utterance}\n\n",
"You are a Text2SQL generation model, in your answer, only have SQL code.\nMake sure you start your query with 'SELECT' and end it with ';'\n\n",
"```sql\nSELECT ",
),
(
"templates.text2sql.you_are_given",
"You are given the following question:\n\n{utterance}\n\nAn SQL schema\n\n```sql\n\n{db}\n```\n\nAnswer the following question:\n\n{utterance}\n\n",
"You are a Text2SQL generation model, in your answer, only have SQL code.\nStart your query with 'SELECT' and end it with ';'\n\n",
"",
),
(
"templates.text2sql.you_are_given_with_hint_with_sql_prefix",
"You are given the following question:\n\n{utterance}\n\nAn SQL schema\n\n```sql\n\n{db}\n```\n\nAnd hint:\n\n{hint}\n\nAnswer the following question:\n\n{utterance}\n\n",
"You are a Text2SQL generation model, in your answer, only have SQL code.\nMake sure you start your query with 'SELECT' and end it with ';'\n\n",
"",
),
(
"templates.text2sql.you_are_given_with_hint_answer_sql_prefix_no_inst",
"Question:\nYou are given the following SQL schema\n\n```sql\n{db}\n```\n\n{utterance}\n\n",
"",
"Answer:\n```sql\n",
),
(
"templates.text2sql.empty",
"{utterance}",
"",
"",
),
]

template_names = []
for name, input_format, instruction, target_prefix in template_details:
add_to_catalog(
InputOutputTemplate(
input_format=input_format,
instruction=instruction,
target_prefix=target_prefix,
output_format="{query}",
postprocessors=["processors.text2sql.get_sql"],
),
name,
overwrite=True,
)
template_names.append(name)


add_to_catalog(
TemplatesList(template_names),
"templates.text2sql.all",
overwrite=True,
)
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ tests = [
"bs4",
"tenacity==8.3.0",
"accelerate",
"spacy",
"spacy",
"func_timeout==4.3.5",
"Wikipedia-API"
]
ui = [
Expand Down
61 changes: 61 additions & 0 deletions src/unitxt/catalog/cards/text2sql/bird.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
{
"__type__": "task_card",
"loader": {
"__type__": "load_hf",
"path": "premai-io/birdbench",
"split": "validation"
},
"preprocess_steps": [
{
"__type__": "shuffle",
"page_size": 9223372036854775807
},
{
"__type__": "rename",
"field_to_field": {
"question_id": "id",
"question": "utterance",
"SQL": "query",
"db_id": "db_id",
"evidence": "hint"
}
},
{
"__type__": "set",
"fields": {
"dbms": "sqlite",
"db_type": "local",
"use_oracle_knowledge": true,
"num_table_rows_to_add": 0,
"data": null
}
},
{
"__type__": "execute_expression",
"expression": "'bird/'+db_id",
"to_field": "db_id"
},
{
"__type__": "copy",
"field": "db_id",
"to_field": "db/db_id"
},
{
"__type__": "copy",
"field": "db_type",
"to_field": "db/db_type"
},
{
"__type__": "copy",
"field": "dbms",
"to_field": "db/dbms"
},
{
"__type__": "copy",
"field": "data",
"to_field": "db/data"
}
],
"task": "tasks.text2sql",
"templates": "templates.text2sql.all"
}
3 changes: 3 additions & 0 deletions src/unitxt/catalog/metrics/text2sql/execution_accuracy.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"__type__": "execution_accuracy"
}
9 changes: 9 additions & 0 deletions src/unitxt/catalog/processors/text2sql/get_sql.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"__type__": "sequential_operator",
"steps": [
{
"__type__": "get_sql",
"field": "prediction"
}
]
}
3 changes: 3 additions & 0 deletions src/unitxt/catalog/serializers/text2sql/schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"__type__": "sql_database_as_schema_serializer"
}
Loading
Loading