diff --git a/.gitignore b/.gitignore
index 5eb727462..2a0cc07d6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -159,3 +159,4 @@ benchmark_output/*
 .litellm_cache
 docs/_static/data.js
+cache
diff --git a/examples/evaluate_text2sql.py b/examples/evaluate_text2sql.py
new file mode 100644
index 000000000..7b34c888b
--- /dev/null
+++ b/examples/evaluate_text2sql.py
@@ -0,0 +1,62 @@
+from unitxt import evaluate, load_dataset, settings
+from unitxt.inference import CrossProviderInferenceEngine
+from unitxt.text_utils import print_dict
+
+with settings.context(
+    disable_hf_datasets_cache=False,
+    allow_unverified_code=True,
+):
+    test_dataset = load_dataset(
+        "card=cards.text2sql.bird"
+        ",template=templates.text2sql.you_are_given_with_hint_with_sql_prefix,loader_limit=10",
+        # drop loader_limit for the full validation run:
+        # ",template=templates.text2sql.you_are_given_with_hint_with_sql_prefix",
+        split="validation",
+    )
+
+# Infer
+inference_model = CrossProviderInferenceEngine(
+    model="llama-3-70b-instruct",
+    max_tokens=256,
+)
+
+predictions = inference_model.infer(test_dataset)
+evaluated_dataset = evaluate(predictions=predictions, data=test_dataset)
+
+print_dict(
+    evaluated_dataset[0],
+    keys_to_print=[
+        "source",
+        "prediction",
+        "subset",
+    ],
+)
+print_dict(
+    evaluated_dataset[0]["score"]["global"],
+)
+
+# with llama-3-70b-instruct
+# num_of_instances (int):
+#     1534
+# execution_accuracy (float):
+#     0.482
+
+# roughly GPT-4-level (rank 40 on the BIRD leaderboard: https://bird-bench.github.io/)
+
+# from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# DEBUG_NUM_EXAMPLES = 2
+# model_name = "meta-llama/Llama-3.2-1B-Instruct"
+# model = AutoModelForCausalLM.from_pretrained(model_name)
+# tokenizer = AutoTokenizer.from_pretrained(model_name)
+# tokenizer.pad_token = tokenizer.eos_token
+# test_dataset = test_dataset.select(range(DEBUG_NUM_EXAMPLES))
+# predictions = tokenizer.batch_decode(
+#     model.generate(
+#         **tokenizer.batch_encode_plus(
+#             test_dataset["source"], return_tensors="pt", padding=True
+#         ),
+#         max_length=2048,
+#     ),
+#     skip_special_tokens=True,
+#     clean_up_tokenization_spaces=True,
+# )
diff --git a/prepare/cards/text2sql.py b/prepare/cards/text2sql.py
new file mode 100644
index 000000000..725532bf6
--- /dev/null
+++ b/prepare/cards/text2sql.py
@@ -0,0 +1,54 @@
+import sys
+
+from unitxt import add_to_catalog
+from unitxt.blocks import Copy, Rename, Set, TaskCard
+from unitxt.loaders import LoadHF
+from unitxt.operators import ExecuteExpression, Shuffle
+
+card = TaskCard(
+    loader=LoadHF(path="premai-io/birdbench", split="validation"),
+    preprocess_steps=[
+        Shuffle(page_size=sys.maxsize),
+        Rename(
+            field_to_field={
+                "question_id": "id",
+                "question": "utterance",
+                "SQL": "query",
+                "db_id": "db_id",
+                "evidence": "hint",
+            }
+        ),
+        Set(
+            fields={
+                "dbms": "sqlite",
+                "db_type": "local",
+                "use_oracle_knowledge": True,
+                "num_table_rows_to_add": 0,
+                "data": None,
+            }
+        ),
+        ExecuteExpression(
+            expression="'bird/'+db_id",
+            to_field="db_id",
+        ),
+        Copy(field="db_id", to_field="db/db_id"),
+        Copy(field="db_type", to_field="db/db_type"),
+        Copy(field="dbms", to_field="db/dbms"),
+        Copy(field="data", to_field="db/data"),
+    ],
+    task="tasks.text2sql",
+    templates="templates.text2sql.all",
+)
+
+# test_card(card, num_demos=0, demos_pool_size=0)
+
+add_to_catalog(
+    card,
+    "cards.text2sql.bird",
+    overwrite=True,
+)
+
+# from unitxt import evaluate, load_dataset
+
+# ds = load_dataset("card=cards.text2sql.bird,template_card_index=0")
+# scores = evaluate(predictions=ds["validation"]["target"], data=ds["validation"])
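+
+# Illustrative only (the db_id value below is a made-up example): after the
+# Copy steps above, every instance carries a "db" field matching the
+# SQLDatabase type, which the schema serializer and the execution metric consume:
+# {"db_id": "bird/california_schools", "db_type": "local", "dbms": "sqlite", "data": None}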
diff --git a/prepare/metrics/text2sql_execution_accuracy.py b/prepare/metrics/text2sql_execution_accuracy.py new file mode 100644 index 000000000..ca15409a5 --- /dev/null +++ b/prepare/metrics/text2sql_execution_accuracy.py @@ -0,0 +1,65 @@ +from unitxt.catalog import add_to_catalog +from unitxt.metrics import ExecutionAccuracy +from unitxt.test_utils.metrics import test_metric + +metric = ExecutionAccuracy() + +predictions = [ + "SELECT nme FROM employees WHERE department = 'Sales'", + "SELECT name FROM employees WHERE department = 'Sales'", +] # Incorrect column name 'nme' +references = [["SELECT name FROM employees WHERE department = 'Sales';"]] * 2 +task_data = [ + { + "db": { + "db_id": "mock_db", + "db_type": "in_memory", + "data": { + "employees": { + "columns": ["id", "name", "department", "salary"], + "rows": [ + (1, "Alice", "Sales", 50000), + (2, "Bob", "Engineering", 60000), + (3, "Charlie", "Sales", 55000), + ], + } + }, + } + } +] * 2 + +instance_targets = [ + { + "execution_accuracy": 0.0, + "score": 0.0, + "score_name": "execution_accuracy", + }, + { + "execution_accuracy": 1.0, + "score": 1.0, + "score_name": "execution_accuracy", + }, +] + + +global_target = { + "execution_accuracy": 0.5, + "execution_accuracy_ci_high": 1.0, + "execution_accuracy_ci_low": 0.0, + "num_of_instances": 2, + "score": 0.5, + "score_ci_high": 1.0, + "score_ci_low": 0.0, + "score_name": "execution_accuracy", +} + +outputs = test_metric( + metric=metric, + predictions=predictions, + references=references, + instance_targets=instance_targets, + global_target=global_target, + task_data=task_data, +) + +add_to_catalog(metric, "metrics.text2sql.execution_accuracy", overwrite=True) diff --git a/prepare/processors/text2sql.py b/prepare/processors/text2sql.py new file mode 100644 index 000000000..85fb4c663 --- /dev/null +++ b/prepare/processors/text2sql.py @@ -0,0 +1,13 @@ +from unitxt import add_to_catalog +from unitxt.operator import SequentialOperator +from unitxt.processors import GetSQL + +add_to_catalog( + SequentialOperator( + steps=[ + GetSQL(field="prediction"), + ] + ), + "processors.text2sql.get_sql", + overwrite=True, +) diff --git a/prepare/serializers/text2sql_serializers.py b/prepare/serializers/text2sql_serializers.py new file mode 100644 index 000000000..2db1923cf --- /dev/null +++ b/prepare/serializers/text2sql_serializers.py @@ -0,0 +1,6 @@ +from unitxt import add_to_catalog +from unitxt.serializers import SQLDatabaseAsSchemaSerializer + +add_to_catalog( + SQLDatabaseAsSchemaSerializer(), "serializers.text2sql.schema", overwrite=True +) diff --git a/prepare/tasks/text2sql.py b/prepare/tasks/text2sql.py new file mode 100644 index 000000000..b4be3f743 --- /dev/null +++ b/prepare/tasks/text2sql.py @@ -0,0 +1,19 @@ +from unitxt.blocks import Task +from unitxt.catalog import add_to_catalog +from unitxt.types import SQLDatabase + +add_to_catalog( + Task( + input_fields={ + "id": str, + "utterance": str, + "hint": str, + "db": SQLDatabase, + }, + reference_fields={"query": str}, + prediction_type=str, + metrics=["metrics.text2sql.execution_accuracy", "metrics.anls"], + ), + "tasks.text2sql", + overwrite=True, +) diff --git a/prepare/templates/text2sql/templates.py b/prepare/templates/text2sql/templates.py new file mode 100644 index 000000000..9895c4dd1 --- /dev/null +++ b/prepare/templates/text2sql/templates.py @@ -0,0 +1,64 @@ +from unitxt import add_to_catalog +from unitxt.blocks import TemplatesList +from unitxt.templates import InputOutputTemplate + +template_details = [ + ( + 
"templates.text2sql.you_are_given_with_sql_prefix", + "You are given the following question:\n\n{utterance}\n\nAn SQL schema\n\n```sql\n\n{db}\n```\n\nAnswer the following question:\n\n{utterance}\n\n", + "You are a Text2SQL generation model, in your answer, only have SQL code.\nStart your query with 'SELECT' and end it with ';'\n\n", + "```sql\nSELECT ", + ), + ( + "templates.text2sql.you_are_given_with_hint_with_sql_prefix", + "You are given the following question:\n\n{utterance}\n\nAn SQL schema\n\n```sql\n\n{db}\n```\n\nAnd hint:\n\n{hint}\n\nAnswer the following question:\n\n{utterance}\n\n", + "You are a Text2SQL generation model, in your answer, only have SQL code.\nMake sure you start your query with 'SELECT' and end it with ';'\n\n", + "```sql\nSELECT ", + ), + ( + "templates.text2sql.you_are_given", + "You are given the following question:\n\n{utterance}\n\nAn SQL schema\n\n```sql\n\n{db}\n```\n\nAnswer the following question:\n\n{utterance}\n\n", + "You are a Text2SQL generation model, in your answer, only have SQL code.\nStart your query with 'SELECT' and end it with ';'\n\n", + "", + ), + ( + "templates.text2sql.you_are_given_with_hint_with_sql_prefix", + "You are given the following question:\n\n{utterance}\n\nAn SQL schema\n\n```sql\n\n{db}\n```\n\nAnd hint:\n\n{hint}\n\nAnswer the following question:\n\n{utterance}\n\n", + "You are a Text2SQL generation model, in your answer, only have SQL code.\nMake sure you start your query with 'SELECT' and end it with ';'\n\n", + "", + ), + ( + "templates.text2sql.you_are_given_with_hint_answer_sql_prefix_no_inst", + "Question:\nYou are given the following SQL schema\n\n```sql\n{db}\n```\n\n{utterance}\n\n", + "", + "Answer:\n```sql\n", + ), + ( + "templates.text2sql.empty", + "{utterance}", + "", + "", + ), +] + +template_names = [] +for name, input_format, instruction, target_prefix in template_details: + add_to_catalog( + InputOutputTemplate( + input_format=input_format, + instruction=instruction, + target_prefix=target_prefix, + output_format="{query}", + postprocessors=["processors.text2sql.get_sql"], + ), + name, + overwrite=True, + ) + template_names.append(name) + + +add_to_catalog( + TemplatesList(template_names), + "templates.text2sql.all", + overwrite=True, +) diff --git a/pyproject.toml b/pyproject.toml index 16b47a89b..5ca59df22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,7 +104,8 @@ tests = [ "bs4", "tenacity==8.3.0", "accelerate", - "spacy", + "spacy", + "func_timeout==4.3.5", "Wikipedia-API" ] ui = [ diff --git a/src/unitxt/catalog/cards/text2sql/bird.json b/src/unitxt/catalog/cards/text2sql/bird.json new file mode 100644 index 000000000..6117a59a4 --- /dev/null +++ b/src/unitxt/catalog/cards/text2sql/bird.json @@ -0,0 +1,61 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_hf", + "path": "premai-io/birdbench", + "split": "validation" + }, + "preprocess_steps": [ + { + "__type__": "shuffle", + "page_size": 9223372036854775807 + }, + { + "__type__": "rename", + "field_to_field": { + "question_id": "id", + "question": "utterance", + "SQL": "query", + "db_id": "db_id", + "evidence": "hint" + } + }, + { + "__type__": "set", + "fields": { + "dbms": "sqlite", + "db_type": "local", + "use_oracle_knowledge": true, + "num_table_rows_to_add": 0, + "data": null + } + }, + { + "__type__": "execute_expression", + "expression": "'bird/'+db_id", + "to_field": "db_id" + }, + { + "__type__": "copy", + "field": "db_id", + "to_field": "db/db_id" + }, + { + "__type__": "copy", + "field": "db_type", + 
"to_field": "db/db_type" + }, + { + "__type__": "copy", + "field": "dbms", + "to_field": "db/dbms" + }, + { + "__type__": "copy", + "field": "data", + "to_field": "db/data" + } + ], + "task": "tasks.text2sql", + "templates": "templates.text2sql.all" +} diff --git a/src/unitxt/catalog/metrics/text2sql/execution_accuracy.json b/src/unitxt/catalog/metrics/text2sql/execution_accuracy.json new file mode 100644 index 000000000..cd503301f --- /dev/null +++ b/src/unitxt/catalog/metrics/text2sql/execution_accuracy.json @@ -0,0 +1,3 @@ +{ + "__type__": "execution_accuracy" +} diff --git a/src/unitxt/catalog/processors/text2sql/get_sql.json b/src/unitxt/catalog/processors/text2sql/get_sql.json new file mode 100644 index 000000000..d54969ebd --- /dev/null +++ b/src/unitxt/catalog/processors/text2sql/get_sql.json @@ -0,0 +1,9 @@ +{ + "__type__": "sequential_operator", + "steps": [ + { + "__type__": "get_sql", + "field": "prediction" + } + ] +} diff --git a/src/unitxt/catalog/serializers/text2sql/schema.json b/src/unitxt/catalog/serializers/text2sql/schema.json new file mode 100644 index 000000000..093b2efdd --- /dev/null +++ b/src/unitxt/catalog/serializers/text2sql/schema.json @@ -0,0 +1,3 @@ +{ + "__type__": "sql_database_as_schema_serializer" +} diff --git a/src/unitxt/catalog/tasks/text2sql.json b/src/unitxt/catalog/tasks/text2sql.json new file mode 100644 index 000000000..9cf2723db --- /dev/null +++ b/src/unitxt/catalog/tasks/text2sql.json @@ -0,0 +1,17 @@ +{ + "__type__": "task", + "input_fields": { + "id": "str", + "utterance": "str", + "hint": "str", + "db": "SQLDatabase" + }, + "reference_fields": { + "query": "str" + }, + "prediction_type": "str", + "metrics": [ + "metrics.text2sql.execution_accuracy", + "metrics.anls" + ] +} diff --git a/src/unitxt/catalog/templates/text2sql/all.json b/src/unitxt/catalog/templates/text2sql/all.json new file mode 100644 index 000000000..e9814e42c --- /dev/null +++ b/src/unitxt/catalog/templates/text2sql/all.json @@ -0,0 +1,11 @@ +{ + "__type__": "templates_list", + "items": [ + "templates.text2sql.you_are_given_with_sql_prefix", + "templates.text2sql.you_are_given_with_hint_with_sql_prefix", + "templates.text2sql.you_are_given", + "templates.text2sql.you_are_given_with_hint_with_sql_prefix", + "templates.text2sql.you_are_given_with_hint_answer_sql_prefix_no_inst", + "templates.text2sql.empty" + ] +} diff --git a/src/unitxt/catalog/templates/text2sql/empty.json b/src/unitxt/catalog/templates/text2sql/empty.json new file mode 100644 index 000000000..a405d9f6f --- /dev/null +++ b/src/unitxt/catalog/templates/text2sql/empty.json @@ -0,0 +1,10 @@ +{ + "__type__": "input_output_template", + "input_format": "{utterance}", + "instruction": "", + "target_prefix": "", + "output_format": "{query}", + "postprocessors": [ + "processors.text2sql.get_sql" + ] +} diff --git a/src/unitxt/catalog/templates/text2sql/you_are_given.json b/src/unitxt/catalog/templates/text2sql/you_are_given.json new file mode 100644 index 000000000..e92dc8995 --- /dev/null +++ b/src/unitxt/catalog/templates/text2sql/you_are_given.json @@ -0,0 +1,10 @@ +{ + "__type__": "input_output_template", + "input_format": "You are given the following question:\n\n{utterance}\n\nAn SQL schema\n\n```sql\n\n{db}\n```\n\nAnswer the following question:\n\n{utterance}\n\n", + "instruction": "You are a Text2SQL generation model, in your answer, only have SQL code.\nStart your query with 'SELECT' and end it with ';'\n\n", + "target_prefix": "", + "output_format": "{query}", + "postprocessors": [ + 
"processors.text2sql.get_sql" + ] +} diff --git a/src/unitxt/catalog/templates/text2sql/you_are_given_with_hint_answer_sql_prefix_no_inst.json b/src/unitxt/catalog/templates/text2sql/you_are_given_with_hint_answer_sql_prefix_no_inst.json new file mode 100644 index 000000000..d07ac9028 --- /dev/null +++ b/src/unitxt/catalog/templates/text2sql/you_are_given_with_hint_answer_sql_prefix_no_inst.json @@ -0,0 +1,10 @@ +{ + "__type__": "input_output_template", + "input_format": "Question:\nYou are given the following SQL schema\n\n```sql\n{db}\n```\n\n{utterance}\n\n", + "instruction": "", + "target_prefix": "Answer:\n```sql\n", + "output_format": "{query}", + "postprocessors": [ + "processors.text2sql.get_sql" + ] +} diff --git a/src/unitxt/catalog/templates/text2sql/you_are_given_with_hint_with_sql_prefix.json b/src/unitxt/catalog/templates/text2sql/you_are_given_with_hint_with_sql_prefix.json new file mode 100644 index 000000000..1ffa1e5ea --- /dev/null +++ b/src/unitxt/catalog/templates/text2sql/you_are_given_with_hint_with_sql_prefix.json @@ -0,0 +1,10 @@ +{ + "__type__": "input_output_template", + "input_format": "You are given the following question:\n\n{utterance}\n\nAn SQL schema\n\n```sql\n\n{db}\n```\n\nAnd hint:\n\n{hint}\n\nAnswer the following question:\n\n{utterance}\n\n", + "instruction": "You are a Text2SQL generation model, in your answer, only have SQL code.\nMake sure you start your query with 'SELECT' and end it with ';'\n\n", + "target_prefix": "", + "output_format": "{query}", + "postprocessors": [ + "processors.text2sql.get_sql" + ] +} diff --git a/src/unitxt/catalog/templates/text2sql/you_are_given_with_sql_prefix.json b/src/unitxt/catalog/templates/text2sql/you_are_given_with_sql_prefix.json new file mode 100644 index 000000000..2723700ed --- /dev/null +++ b/src/unitxt/catalog/templates/text2sql/you_are_given_with_sql_prefix.json @@ -0,0 +1,10 @@ +{ + "__type__": "input_output_template", + "input_format": "You are given the following question:\n\n{utterance}\n\nAn SQL schema\n\n```sql\n\n{db}\n```\n\nAnswer the following question:\n\n{utterance}\n\n", + "instruction": "You are a Text2SQL generation model, in your answer, only have SQL code.\nStart your query with 'SELECT' and end it with ';'\n\n", + "target_prefix": "```sql\nSELECT ", + "output_format": "{query}", + "postprocessors": [ + "processors.text2sql.get_sql" + ] +} diff --git a/src/unitxt/dataset.py b/src/unitxt/dataset.py index 9fc23f467..cf1b52219 100644 --- a/src/unitxt/dataset.py +++ b/src/unitxt/dataset.py @@ -14,6 +14,7 @@ from .collections_operators import __file__ as _ from .dataclass import __file__ as _ from .dataset_utils import get_dataset_artifact +from .db_utils import __file__ as _ from .deprecation_utils import __file__ as _ from .dialog_operators import __file__ as _ from .dict_utils import __file__ as _ diff --git a/src/unitxt/db_utils.py b/src/unitxt/db_utils.py new file mode 100644 index 000000000..63bc98833 --- /dev/null +++ b/src/unitxt/db_utils.py @@ -0,0 +1,307 @@ +import glob +import os +import sqlite3 +import time +from abc import ABC, abstractmethod +from typing import Any, List, Optional + +import evaluate +import requests +from huggingface_hub import snapshot_download +from requests.exceptions import ConnectionError, ReadTimeout + +from .types import SQLDatabase + +# Path to the user's databases cache directory. +# Logger instance. 
+
+logger = evaluate.logging.get_logger(__name__)
+
+
+class DatabaseConnector(ABC):
+    """Abstract base class for database connectors."""
+
+    def __init__(self, db_config: SQLDatabase):
+        self.db_config = db_config
+        self.databases_folder = os.path.join(
+            os.environ.get("UNITXT_TEXT2SQL_CACHE", "cache/text2sql"), "databases"
+        )
+        os.makedirs(self.databases_folder, exist_ok=True)
+
+    @abstractmethod
+    def get_table_schema(
+        self,
+    ) -> str:
+        """Abstract method to get database schema."""
+        pass
+
+    @abstractmethod
+    def execute_query(self, query: str) -> Any:
+        """Abstract method to execute a query against the database."""
+        pass
+
+
+class LocalSQLiteConnector(DatabaseConnector):
+    """Database connector for SQLite databases."""
+
+    def __init__(self, db_config: SQLDatabase):
+        super().__init__(db_config)
+        db_id = self.db_config.get("db_id")
+        if not db_id:
+            raise ValueError("db_id is required for SQLiteConnector.")
+        self.db_path = self.get_db_file_path(db_id)
+        self.conn: sqlite3.Connection = sqlite3.connect(self.db_path)
+        self.cursor: sqlite3.Cursor = self.conn.cursor()
+
+    def download_database(self, db_id):
+        """Downloads the database from huggingface if needed."""
+        done_file_path = os.path.join(self.databases_folder, "download_done")
+        if "bird/" in db_id:
+            if not os.path.exists(done_file_path):
+                snapshot_download(
+                    repo_id="premai-io/birdbench",
+                    repo_type="dataset",
+                    local_dir=self.databases_folder,
+                    force_download=False,
+                    allow_patterns="*validation*",
+                )
+                open(done_file_path, "w").close()
+        else:
+            raise NotImplementedError(
+                f"Local database '{db_id}' is not supported; only BIRD databases are."
+            )
+
+    def get_db_file_path(self, db_id):
+        """Gets the local path of a downloaded database file."""
+        self.download_database(db_id)
+        db_id = db_id.split("/")[-1]
+
+        db_file_pattern = os.path.join(self.databases_folder, "**", db_id + ".sqlite")
+        db_file_paths = glob.glob(db_file_pattern, recursive=True)
+
+        if not db_file_paths:
+            raise FileNotFoundError(f"Database file {db_id} not found.")
+        if len(db_file_paths) > 1:
+            raise FileExistsError(f"More than one file matched for {db_id}")
+        return db_file_paths[0]
+
+    def get_table_schema(
+        self,
+    ) -> str:
+        """Extracts schema from an SQLite database."""
+        self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
+        tables: list[tuple[str]] = self.cursor.fetchall()
+        schemas: dict[str, str] = {}
+
+        for table in tables:
+            if isinstance(table, tuple):
+                table = table[0]
+            if table == "sqlite_sequence":
+                continue
+            sql_query: str = (
+                f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table}';"
+            )
+            self.cursor.execute(sql_query)
+            schema_prompt: str = self.cursor.fetchone()[0]
+
+            schemas[table] = schema_prompt
+
+        schema_prompt: str = "\n\n".join(list(schemas.values()))
+        return schema_prompt
+
+    def execute_query(self, query: str) -> Any:
+        """Executes a query against the SQLite database."""
+        conn = None  # Initialize conn to None outside the try block
+        try:
+            conn = sqlite3.connect(self.db_path)
+            cursor = conn.cursor()
+            cursor.execute(query)
+            return cursor.fetchall()
+        except sqlite3.Error as e:
+            logger.error(f"Error executing SQL: {e}")
+            return None
+        finally:
+            if conn:
+                conn.close()
+
+
+class InMemoryDatabaseConnector(DatabaseConnector):
+    """Database connector for mocking databases with in-memory data structures."""
+
+    def __init__(self, db_config: SQLDatabase):
+        super().__init__(db_config)
+        self.tables = db_config.get("data", None)
+
+        if not self.tables:
+            raise ValueError("data is required for InMemoryDatabaseConnector.")
+
+    def get_table_schema(
+        self,
+        select_tables: Optional[List[str]] = None,
+    ) -> str:
+        """Generates a mock schema from the tables structure."""
+        schemas = {}
+        for table_name, table_data in self.tables.items():
+            if select_tables and table_name.lower() not in select_tables:
+                continue
+            columns = ", ".join([f"`{col}` TEXT" for col in table_data["columns"]])
+            schema = f"CREATE TABLE `{table_name}` ({columns});"
+
+            schemas[table_name] = schema
+
+        return "\n\n".join(list(schemas.values()))
+
+    def execute_query(self, query: str) -> Any:
+        """Simulates executing a query against the mock database."""
+        # Initialize in-memory database from the 'tables' dictionary
+        conn = sqlite3.connect(":memory:")
+        cursor = conn.cursor()
+        logger.debug("Running SQL query over in-memory DB")
+
+        # Create tables and insert data from the 'tables' dictionary
+        for table_name, table_data in self.tables.items():
+            columns = table_data["columns"]
+            rows = table_data["rows"]
+
+            # Create table
+            cursor.execute(f"CREATE TABLE {table_name} ({', '.join(columns)})")
+
+            # Insert data
+            placeholders = ", ".join(["?"] * len(columns))
+            cursor.executemany(
+                f"INSERT INTO {table_name} VALUES ({placeholders})", rows
+            )
+
+        try:
+            cursor.execute(query)
+            return cursor.fetchall()
+        except sqlite3.Error as e:
+            logger.error(f"Error executing SQL: {e}")
+            return None
+        finally:
+            conn.close()
+
+
+class RemoteDatabaseConnector(DatabaseConnector):
+    """Database connector for remote databases accessed via HTTP."""
+
+    RETRYABLE_EXCEPTIONS = (ConnectionError, ReadTimeout)
+    MAX_RETRIES = 3
+    RETRY_DELAY = 5  # seconds
+    TIMEOUT = 30  # seconds
+
+    def __init__(self, db_config: SQLDatabase):
+        super().__init__(db_config)
+
+        assert db_config[
+            "db_id"
+        ], "db_id must be in db_config for RemoteDatabaseConnector"
+        self.api_url, self.database_id = (
+            db_config["db_id"].split(",")[0],
+            db_config["db_id"].split("db_id=")[-1].split(",")[0],
+        )
+
+        if not self.api_url or not self.database_id:
+            raise ValueError(
+                "Both 'api_url' and 'database_id' are required for RemoteDatabaseConnector."
+            )
+
+        self.api_key = os.getenv("SQL_API_KEY", None)
+        if not self.api_key:
+            raise ValueError(
+                "The environment variable 'SQL_API_KEY' must be set to use the RemoteDatabaseConnector."
+ ) + + self.base_headers = { + "Content-Type": "application/json", + "accept": "application/json", + "Authorization": f"Bearer {self.api_key}", + } + + def get_table_schema( + self, + ) -> str: + """Retrieves the schema of a database.""" + cur_api_url = f"{self.api_url}/datasource/{self.database_id}" + response = requests.get( + cur_api_url, + headers=self.base_headers, + verify=True, + timeout=self.TIMEOUT, + ) + if response.status_code == 200: + schema = response.json()["schema"] + else: + raise OSError(f"Could not fetch schema from {cur_api_url}") + + schema_text = "" + for table in schema["tables"]: + schema_text += f"Table: {table['table_name']} has columns: {[col['column_name'] for col in table['columns']]}\n" + + return schema_text + + def execute_query(self, query: str) -> Any: + """Executes a query against the remote database, with retries for certain exceptions.""" + retries = 0 + while retries <= self.MAX_RETRIES: + try: + response = requests.post( + f"{self.api_url}/sql", + headers=self.base_headers, + json={"sql": query, "dataSourceId": self.database_id}, + verify=True, + timeout=self.TIMEOUT, + ) + response.raise_for_status() + return response.json() + + except self.RETRYABLE_EXCEPTIONS as e: + retries += 1 + logger.warning( + f"Attempt {retries} failed with error: {e}. Retrying in {self.RETRY_DELAY} seconds." + ) + if retries <= self.MAX_RETRIES: + time.sleep(self.RETRY_DELAY) + else: + logger.error( + f"Max retries ({self.MAX_RETRIES}) exceeded for query: {query}" + ) + return None + + except requests.exceptions.HTTPError as e: + if e.response.status_code >= 500: + retries += 1 + logger.warning( + f"Server error, attempt {retries} failed with error: {e}. Retrying in {self.RETRY_DELAY} seconds." + ) + if retries <= self.MAX_RETRIES: + time.sleep(self.RETRY_DELAY) + else: + logger.error( + f"Max retries ({self.MAX_RETRIES}) exceeded for query: {query}" + ) + return None + else: + logger.error(f"HTTP Error on attempt {retries}: {e}") + return None + + except Exception as e: + logger.error(f"Unexpected error on attempt {retries}: {e}") + return None + + return None + + +def get_db_connector(db_type: str): + """Creates and returns the appropriate DatabaseConnector instance based on db_type.""" + if db_type == "local": + connector = LocalSQLiteConnector + elif db_type == "in_memory": + connector = InMemoryDatabaseConnector + elif db_type == "remote": + connector = RemoteDatabaseConnector + + else: + raise ValueError(f"Unsupported database type: {db_type}") + + return connector diff --git a/src/unitxt/metric.py b/src/unitxt/metric.py index 06665a531..37d740bb1 100644 --- a/src/unitxt/metric.py +++ b/src/unitxt/metric.py @@ -13,6 +13,7 @@ from .collections_operators import __file__ as _ from .dataclass import __file__ as _ from .dataset_utils import __file__ as _ +from .db_utils import __file__ as _ from .deprecation_utils import __file__ as _ from .dialog_operators import __file__ as _ from .dict_utils import __file__ as _ diff --git a/src/unitxt/metrics.py b/src/unitxt/metrics.py index ec771c327..2fee2518e 100644 --- a/src/unitxt/metrics.py +++ b/src/unitxt/metrics.py @@ -7,14 +7,16 @@ import uuid import warnings from abc import ABC, abstractmethod -from collections import Counter, defaultdict, namedtuple +from collections import Counter, defaultdict from dataclasses import field from functools import lru_cache from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Union +import evaluate import numpy import numpy as np import pandas as pd +import requests from 
scipy.stats import bootstrap
 from scipy.stats._warnings_errors import DegenerateDataWarning
 
@@ -51,6 +53,8 @@
 from .type_utils import Type, isoftype, parse_type_string, to_type_string
 from .utils import deep_copy, recursive_copy
 
+logger = evaluate.logging.get_logger(__name__)
+
 logger = get_logger()
 settings = get_settings()
 
@@ -374,8 +378,7 @@ def bootstrap(self, data: List[Any], score_names: List[str]):
 
         return result
 
 
-from typing import Generic, TypeVar, NamedTuple
-from dataclasses import dataclass
+from typing import Generic, TypeVar
 
 IntermediateType = TypeVar("IntermediateType")
 PredictionType = TypeVar("PredictionType")
@@ -627,9 +630,10 @@ def prepare(self):
         from sklearn.metrics import f1_score
 
         self._metric = f1_score
-        import regex
         from functools import partial
 
+        import regex
+
         self.remove_punc = partial(regex.compile(r"\p{P}+").sub, "")
 
     def get_str_id(self, str):
@@ -1781,13 +1785,13 @@ def exact_match(pred, gt):
         try:
             if answer == predict[0]:
                 return 1.0
-            elif predict[0] == "(" and answer == predict[1]:
+            if predict[0] == "(" and answer == predict[1]:
                 return 1.0
-            elif predict[0:7] == "option " and answer == predict[7]:
+            if predict[0:7] == "option " and answer == predict[7]:
                 return 1.0
-            elif predict[0:14] == "the answer is " and answer == predict[14]:
+            if predict[0:14] == "the answer is " and answer == predict[14]:
                 return 1.0
-        except Exception as e:
+        except Exception:
             return 0.0
         return 0.0
@@ -1904,8 +1908,7 @@ def _to_float(text: str):
             if text.endswith("%"):
                 # Convert percentages to floats.
                 return float(text.rstrip("%")) / 100.0
-            else:
-                return float(text)
+            return float(text)
         except ValueError:
             return None
@@ -1936,8 +1939,7 @@ def relaxed_correctness(
         if prediction_float is not None and target_float:
             relative_change = abs(prediction_float - target_float) / abs(target_float)
             return relative_change <= max_relative_change
-        else:
-            return prediction.lower() == target.lower()
+        return prediction.lower() == target.lower()
 
 
 class WebsrcSquadF1(GlobalMetric):
@@ -2727,8 +2729,6 @@ def prepare(self):
         import importlib.util as iua
         import os
 
-        import requests
-
         # download finqa evaluation script, load as a module and use it on the fly
         def download_finqa_eval_script_file(url, local_path, hash_of_script):
             if not os.path.exists(local_path):
@@ -4612,8 +4612,6 @@ def create_metric_request(predictions, references, additional_inputs):
         return MetricRequest(instance_inputs=instance_inputs)
 
     def get_metric_response(self, metric_request: MetricRequest) -> MetricResponse:
-        import requests
-
         response = requests.post(
             url=self.get_metric_url(),
             json=metric_request.to_dict(),
@@ -5947,3 +5945,66 @@ def get_probabilities(self, top_tokens_list):
             torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]),
             dim=0,
         ).numpy()
+
+
+class ExecutionAccuracy(InstanceMetric):
+    reduction_map = {"mean": ["execution_accuracy"]}
+    main_score = "execution_accuracy"
+    ci_scores = ["execution_accuracy"]
+
+    prediction_type = "Any"  # string representation is compared
+    sql_timeout = 100.0
+
+    def run_sql_and_match(self, predicted_sql: str, gold_sql: str, connector) -> int:
+        """Runs SQL queries using the provided connector and checks if the results match."""
+        if predicted_sql.strip() == gold_sql.strip():
+            return 1
+
+        try:
+            pred_res = connector.execute_query(predicted_sql)
+            gold_res = connector.execute_query(gold_sql)
+
+            if pred_res is None or gold_res is None:
+                return 0  # Treat execution error as mismatch
+
+            # The remote connector returns a JSON payload; the local and
+            # in-memory connectors return lists of rows. Unwrap before comparing.
+            if isinstance(pred_res, dict) and "results" in pred_res:
+                pred_res = pred_res["results"]
+            if isinstance(gold_res, dict) and "results" in gold_res:
+                gold_res = gold_res["results"]
+
+            return int(pred_res == gold_res)
+        except Exception as e:
+            logger.error(f"Error in run_sql_and_match: {e}")
+            return 0
+
+    def compute(self, references: List[Any], prediction: str, task_data: Dict) -> dict:
+        from .db_utils import get_db_connector
+
+        try:
+            from func_timeout import FunctionTimedOut, func_timeout
+        except ImportError as err:
+            raise ImportError(
+                "func_timeout should be installed for this metric"
+            ) from err
+
+        predicted_sql = prediction
+        execution_result: float = 0.0
+
+        if predicted_sql and predicted_sql.strip() != "":
+            if not predicted_sql.startswith("SELECT") and "SELECT" in predicted_sql:
+                predicted_sql = predicted_sql[predicted_sql.find("SELECT") :]
+            if ";" in predicted_sql:
+                predicted_sql = predicted_sql[: predicted_sql.find(";") + 1]
+
+        db_connector = get_db_connector(task_data["db"]["db_type"])(task_data["db"])
+
+        try:
+            execution_result = func_timeout(
+                self.sql_timeout,
+                self.run_sql_and_match,
+                args=(predicted_sql, references[0], db_connector),
+            )  # type: ignore
+        except FunctionTimedOut:
+            logger.error("QUERY TIMEOUT, returning score=0 for this instance")
+            execution_result = 0.0
+
+        result = {self.main_score: float(execution_result)}
+        logger.debug(f"Result: {result}")
+        result["score"] = result[self.main_score]
+        result["score_name"] = self.main_score
+        return result
diff --git a/src/unitxt/processors.py b/src/unitxt/processors.py
index f6a23f6f3..f9858f0a6 100644
--- a/src/unitxt/processors.py
+++ b/src/unitxt/processors.py
@@ -412,6 +412,45 @@ def process_value(self, text: Any) -> Any:
         return " ".join(text.split())
 
 
+class AddPrefix(FieldOperator):
+    prefix: str
+
+    def process_value(self, text: str) -> str:
+        text = text.strip()
+        if text.startswith(self.prefix):
+            return text
+        return self.prefix + text
+
+
+class GetSQL(FieldOperator):
+    def process_value(self, text: str) -> str:
+        """Extracts the first SQL query from a given text.
+
+        Args:
+            text: The input string containing the SQL query.
+
+        Returns:
+            The first SQL query found in the text, or the string
+            "No query found in generation" if no query is found.
+        """
+        match = re.search(
+            r"(?:```)?.*?(SELECT.*?(?:FROM|WITH|;|$).*?)(?:```|;|$)",
+            text,
+            re.IGNORECASE | re.DOTALL,
+        )
+
+        if match:
+            out = (
+                text[match.start() : match.end()]
+                .replace("```", "")
+                .replace(";", "")
+                .strip()
+            )
+        else:
+            out = "No query found in generation"
+
+        return out
+
+
 class ScaleNumberToZeroOneReturnZeroIfFails(FieldOperator):
     max_val = 10
     min_val = 0
diff --git a/src/unitxt/serializers.py b/src/unitxt/serializers.py
index fa2a716d9..3ca9e767f 100644
--- a/src/unitxt/serializers.py
+++ b/src/unitxt/serializers.py
@@ -4,10 +4,20 @@
 from typing import Any, Dict, List, Union
 
 from .dataclass import AbstractField, Field
+from .db_utils import get_db_connector
 from .operators import InstanceFieldOperator
 from .settings_utils import get_constants
 from .type_utils import isoftype, to_type_string
-from .types import Dialog, Document, Image, MultiDocument, Number, Table, Video
+from .types import (
+    Dialog,
+    Document,
+    Image,
+    MultiDocument,
+    Number,
+    SQLDatabase,
+    Table,
+    Video,
+)
 
 constants = get_constants()
 
@@ -176,3 +186,18 @@ def serialize(self, value: Any, instance: Dict[str, Any]) -> Any:
                 return serializer.serialize(value, instance)
         return str(value)
+
+
+class SQLDatabaseAsSchemaSerializer(SingleTypeSerializer):
+    """Serializes a database schema into a string representation."""
+
+    serialized_type = SQLDatabase
+
+    def serialize(self, value: SQLDatabase, instance: Dict[str, Any]) -> str:
+        connector = get_db_connector(value["db_type"])(value)
+        try:
+            return connector.get_table_schema()
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to serialize SQL schema for database '{value['db_id']}' using connector {connector.__class__.__name__}: {e}"
+            ) from e
diff --git a/src/unitxt/templates.py b/src/unitxt/templates.py
index a1322c0da..fa50d61d9 100644
--- a/src/unitxt/templates.py
+++ b/src/unitxt/templates.py
@@ -17,6 +17,7 @@
     MultiTypeSerializer,
     NumberQuantizingSerializer,
     Serializer,
+    SQLDatabaseAsSchemaSerializer,
     TableSerializer,
     VideoSerializer,
 )
@@ -64,6 +65,7 @@ class Template(InstanceOperator):
                     TableSerializer(),
                     DialogSerializer(),
                     ListSerializer(),
+                    SQLDatabaseAsSchemaSerializer(),
                 ]
             )
         )
diff --git a/src/unitxt/types.py b/src/unitxt/types.py
index e6ca47d1c..e0ef69212 100644
--- a/src/unitxt/types.py
+++ b/src/unitxt/types.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Literal, NewType, TypedDict, Union
+from typing import Any, Dict, List, Literal, NewType, Optional, TypedDict, Union
 
 from .type_utils import register_type
 
@@ -45,6 +45,13 @@ class Table(TypedDict):
     rows: List[List[Any]]
 
 
+class SQLDatabase(TypedDict):
+    db_id: Optional[str]
+    db_type: Literal["local", "in_memory", "remote"]
+    dbms: Optional[str]
+    data: Optional[Dict[str, Dict]]
+
+
 register_type(Text)
 register_type(Number)
 register_type(Turn)
@@ -56,3 +63,4 @@ class Table(TypedDict):
 register_type(Document)
 register_type(MultiDocument)
 register_type(RagResponse)
+register_type(SQLDatabase)
diff --git a/tests/library/test_collections_operators.py b/tests/library/test_collections_operators.py
index 8af98b716..a2964e2df 100644
--- a/tests/library/test_collections_operators.py
+++ b/tests/library/test_collections_operators.py
@@ -7,6 +7,13 @@
     Slice,
     Wrap,
 )
+from unitxt.processors import (
+    AddPrefix,
+    FixWhiteSpace,
+    GetSQL,
+    RemoveArticles,
+    RemovePunctuations,
+)
 from unitxt.test_utils.operators import check_operator
 
 from tests.utils import UnitxtTestCase
@@ -113,3 +120,157 @@ def test_chunk(self):
         inputs = [{"x": [0, 1, 2]}, {"x": [0, 1]}, {"x": [3]}]
 
targets = [{"x": [[0, 1], [2]]}, {"x": [[0, 1]]}, {"x": [[3]]}] check_operator(operator=operator, inputs=inputs, targets=targets, tester=self) + + def test_add_prefix(self): + operator = AddPrefix(field="text", prefix="Hello ") + inputs = [{"text": "World"}, {"text": "Hello there"}, {"text": " Hello again"}] + targets = [ + {"text": "Hello World"}, + {"text": "Hello there"}, + {"text": "Hello again"}, + ] + check_operator(operator=operator, inputs=inputs, targets=targets, tester=self) + + def test_get_sql_with_simple_query(self): + operator = GetSQL(field="text") + inputs = [{"text": "SELECT * FROM table;"}] + targets = [{"text": "SELECT * FROM table"}] + check_operator(operator=operator, inputs=inputs, targets=targets, tester=self) + + def test_get_sql_with_code_block(self): + operator = GetSQL(field="text") + inputs = [{"text": "```SELECT id FROM table```"}] + targets = [{"text": "SELECT id FROM table"}] + check_operator(operator=operator, inputs=inputs, targets=targets, tester=self) + + def test_get_sql_with_no_query(self): + operator = GetSQL(field="text") + inputs = [{"text": "No SQL query here"}] + targets = [{"text": "No query found in generation"}] + check_operator(operator=operator, inputs=inputs, targets=targets, tester=self) + + def test_get_sql_with_complex_query(self): + operator = GetSQL(field="text") + inputs = [ + { + "text": "```SELECT column1, column2, column3\n" + "FROM table_name\n" + "WHERE condition1 = 'value1'\n" + " OR condition2 BETWEEN 'value2' AND 'value3'\n" + " AND condition3 IN ('value4', 'value5', 'value6')\n" + "ORDER BY column1 ASC, column2 DESC\n" + "LIMIT 10 OFFSET 5;```" + } + ] + targets = [ + { + "text": "SELECT column1, column2, column3\n" + "FROM table_name\n" + "WHERE condition1 = 'value1'\n" + " OR condition2 BETWEEN 'value2' AND 'value3'\n" + " AND condition3 IN ('value4', 'value5', 'value6')\n" + "ORDER BY column1 ASC, column2 DESC\n" + "LIMIT 10 OFFSET 5" + } + ] + check_operator(operator=operator, inputs=inputs, targets=targets, tester=self) + + def test_get_sql_with_multiple_queries(self): + operator = GetSQL(field="text") + inputs = [{"text": "SELECT * FROM table1; SELECT * FROM table2;"}] + targets = [{"text": "SELECT * FROM table1"}] + check_operator(operator=operator, inputs=inputs, targets=targets, tester=self) + + def test_get_sql_with_no_semicolon(self): + operator = GetSQL(field="text") + inputs = [{"text": "SELECT * FROM table"}] + targets = [{"text": "SELECT * FROM table"}] + check_operator(operator=operator, inputs=inputs, targets=targets, tester=self) + + def test_get_sql_with_multiple_selects(self): + operator = GetSQL(field="text") + inputs = [ + { + "text": "SELECT column1 FROM table1; \n Some text in the middle \n SELECT column2 FROM table2" + } + ] + targets = [{"text": "SELECT column1 FROM table1"}] + check_operator(operator=operator, inputs=inputs, targets=targets, tester=self) + + def test_get_sql_with_with_clause(self): + operator = GetSQL(field="text") + inputs = [ + { + "text": "WITH regional_sales AS (SELECT region, SUM(amount) AS total_sales FROM sales_data GROUP BY region) SELECT region FROM regional_sales" + } + ] + targets = [ + { + "text": "WITH regional_sales AS (SELECT region, SUM(amount) AS total_sales FROM sales_data GROUP BY region) SELECT region FROM regional_sales" + } + ] + check_operator(operator=operator, inputs=inputs, targets=targets, tester=self) + + def test_remove_articles_with_empty_input(self): + operator = RemoveArticles(field="text") + inputs = [{"text": ""}] + targets = [{"text": ""}] + 
check_operator(operator=operator, inputs=inputs, targets=targets, tester=self)
+
+    def test_remove_articles_with_no_articles(self):
+        operator = RemoveArticles(field="text")
+        inputs = [{"text": "Hello world!"}]
+        targets = [{"text": "Hello world!"}]
+        check_operator(operator=operator, inputs=inputs, targets=targets, tester=self)
+
+    def test_remove_punctuations(self):
+        operator = RemovePunctuations(field="text")
+        inputs = [
+            {"text": "Hello, world!"},
+            {"text": "This is a sentence with punctuation: .,;!?"},
+            {"text": "No punctuation here"},
+        ]
+        targets = [
+            {"text": "Hello world"},
+            {"text": "This is a sentence with punctuation "},
+            {"text": "No punctuation here"},
+        ]
+        check_operator(operator=operator, inputs=inputs, targets=targets, tester=self)
+
+    def test_remove_punctuations_with_empty_input(self):
+        operator = RemovePunctuations(field="text")
+        inputs = [{"text": ""}]
+        targets = [{"text": ""}]
+        check_operator(operator=operator, inputs=inputs, targets=targets, tester=self)
+
+    def test_remove_punctuations_with_only_punctuations(self):
+        operator = RemovePunctuations(field="text")
+        inputs = [{"text": ".,;!?"}]
+        targets = [{"text": ""}]
+        check_operator(operator=operator, inputs=inputs, targets=targets, tester=self)
+
+    def test_fix_white_space(self):
+        operator = FixWhiteSpace(field="text")
+        inputs = [
+            {"text": " This is a test "},
+            {"text": "NoExtraSpacesHere"},
+            {"text": " "},
+        ]
+        targets = [
+            {"text": "This is a test"},
+            {"text": "NoExtraSpacesHere"},
+            {"text": ""},
+        ]
+        check_operator(operator=operator, inputs=inputs, targets=targets, tester=self)
+
+    def test_fix_white_space_with_empty_input(self):
+        operator = FixWhiteSpace(field="text")
+        inputs = [{"text": ""}]
+        targets = [{"text": ""}]
+        check_operator(operator=operator, inputs=inputs, targets=targets, tester=self)
+
+    def test_fix_white_space_with_newline_and_tabs(self):
+        operator = FixWhiteSpace(field="text")
+        inputs = [{"text": " \tThis is a\n test with \t\nspaces."}]
+        targets = [{"text": "This is a test with spaces."}]
+        check_operator(operator=operator, inputs=inputs, targets=targets, tester=self)
diff --git a/tests/library/test_db_utils.py b/tests/library/test_db_utils.py
new file mode 100644
index 000000000..22b5970c6
--- /dev/null
+++ b/tests/library/test_db_utils.py
@@ -0,0 +1,718 @@
("Eva", "Laptop", 1), +# ("Eva", "Printer", 1), +# ("Eva", "Webcam", 1), +# ] +# self.assertEqual(result, expected_result) + +# def test_execute_query_with_aggregation(self): +# connector = InMemoryDatabaseConnector(self.db_config) +# query = """ +# SELECT u.city, AVG(u.age) +# FROM users u +# GROUP BY u.city +# ORDER BY u.city ASC +# """ +# result = connector.execute_query(query) +# expected_result = [ +# ("Chicago", 40.0), +# ("Los Angeles", 26.5), +# ("New York", 32.5), +# ] +# self.assertEqual(result, expected_result) + +# def test_execute_query_with_sum_and_having(self): +# connector = InMemoryDatabaseConnector(self.db_config) +# query = """ +# SELECT u.name, SUM(o.price) +# FROM users u +# JOIN orders o on u.user_id = o.user_id +# GROUP BY u.name +# HAVING SUM(o.price) > 300 +# ORDER BY u.name DESC +# """ +# result = connector.execute_query(query) +# expected_result = [("Eva", 1654.0), ("Charlie", 315.0), ("Alice", 1225.5)] + +# self.assertEqual(result, expected_result) + +import os +import sqlite3 +import tempfile +import unittest +from unittest.mock import MagicMock, patch + +import requests +from unitxt.db_utils import ( + InMemoryDatabaseConnector, + LocalSQLiteConnector, + RemoteDatabaseConnector, +) +from unitxt.types import SQLDatabase + + +class TestRemoteDatabaseConnector(unittest.TestCase): + def setUp(self): + # Set up any necessary environment variables or configurations + self.env_patcher = patch.dict( + os.environ, + {"SQL_API_KEY": "test_api_key"}, # pragma: allowlist-secret + clear=True, + ) + self.env_patcher.start() + + self.db_config: SQLDatabase = { + "db_type": "remote", + "db_id": "https://testapi.com/api,db_id=test_db_id", + "dbms": None, + "data": None, + } + + def tearDown(self): + # Clean up any resources or configurations + self.env_patcher.stop() + + def test_init_success(self): + connector = RemoteDatabaseConnector(self.db_config) + self.assertEqual(connector.api_url, "https://testapi.com/api") + self.assertEqual(connector.database_id, "test_db_id") + self.assertEqual(connector.api_key, "test_api_key") + self.assertEqual(connector.base_headers["Authorization"], "Bearer test_api_key") + + def test_init_missing_api_url(self): + self.db_config["db_id"] = ",db_id=test_db_id" + with self.assertRaises(ValueError): + RemoteDatabaseConnector(self.db_config) + + def test_init_missing_api_key(self): + os.environ.pop("SQL_API_KEY") + with self.assertRaises(ValueError): + RemoteDatabaseConnector(self.db_config) + + @patch("requests.post") + def test_get_table_schema_success(self, mock_post): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "schema": { + "tables": [ + {"table_name": "table1", "columns": [{"column_name": "tab1col1"}]}, + {"table_name": "table2", "columns": [{"column_name": "tab2col1"}]}, + ] + } + } + mock_post.return_value = mock_response + + connector = RemoteDatabaseConnector(self.db_config) + schema_text = connector.get_table_schema() + + self.assertEqual( + schema_text, + "Table: table1 has columns: ['tab1col1']\nTable: table2 has columns: ['tab2col1']\n", + ) + + @patch("requests.post") + def test_get_table_schema_failure(self, mock_post): + mock_response = MagicMock() + mock_response.status_code = 400 + mock_post.return_value = mock_response + + connector = RemoteDatabaseConnector(self.db_config) + + with self.assertRaises(OSError): + connector.get_table_schema() + + @patch("requests.post") + def test_execute_query_success(self, mock_post): + mock_response = MagicMock() + 
+        mock_response.status_code = 200
+        mock_response.json.return_value = {"result": "success"}
+        mock_post.return_value = mock_response
+
+        connector = RemoteDatabaseConnector(self.db_config)
+        result = connector.execute_query("SELECT * FROM table1")
+
+        self.assertEqual(result, {"result": "success"})
+        mock_post.assert_called_once_with(
+            "https://testapi.com/api/sql",
+            headers=connector.base_headers,
+            json={"sql": "SELECT * FROM table1", "dataSourceId": "test_db_id"},
+            verify=True,
+            timeout=RemoteDatabaseConnector.TIMEOUT,
+        )
+
+    @patch("requests.post")
+    def test_execute_query_failure(self, mock_post):
+        mock_post.side_effect = requests.exceptions.RequestException("API Error")
+
+        connector = RemoteDatabaseConnector(self.db_config)
+        result = connector.execute_query("SELECT * FROM table1")
+
+        self.assertIsNone(result)
+
+    @patch("requests.post")
+    def test_execute_query_retries_on_connection_error(self, mock_post):
+        mock_post.side_effect = [
+            requests.exceptions.ConnectionError("Connection Error"),
+            MagicMock(status_code=200, json=lambda: {"result": "success"}),
+        ]
+
+        connector = RemoteDatabaseConnector(self.db_config)
+        result = connector.execute_query("SELECT * FROM table1")
+
+        self.assertEqual(result, {"result": "success"})
+        self.assertEqual(mock_post.call_count, 2)
+
+    @patch("requests.post")
+    def test_execute_query_retries_on_timeout(self, mock_post):
+        mock_post.side_effect = [
+            requests.exceptions.ReadTimeout("Read Timeout"),
+            MagicMock(status_code=200, json=lambda: {"result": "success"}),
+        ]
+
+        connector = RemoteDatabaseConnector(self.db_config)
+        result = connector.execute_query("SELECT * FROM table1")
+
+        self.assertEqual(result, {"result": "success"})
+        self.assertEqual(mock_post.call_count, 2)
+
+    @patch("requests.post")
+    def test_execute_query_retries_on_server_error(self, mock_post):
+        mock_response = MagicMock()
+        mock_response.status_code = 500
+        mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
+            "Server Error", response=mock_response
+        )
+        mock_post.side_effect = [
+            mock_response,
+            MagicMock(status_code=200, json=lambda: {"result": "success"}),
+        ]
+
+        connector = RemoteDatabaseConnector(self.db_config)
+        result = connector.execute_query("SELECT * FROM table1")
+
+        self.assertEqual(result, {"result": "success"})
+        self.assertEqual(mock_post.call_count, 2)
+
+    @patch("requests.post")
+    def test_execute_query_max_retries_exceeded(self, mock_post):
+        mock_post.side_effect = requests.exceptions.ConnectionError("Connection Error")
+
+        connector = RemoteDatabaseConnector(self.db_config)
+        result = connector.execute_query("SELECT * FROM table1")
+
+        self.assertIsNone(result)
+        self.assertEqual(mock_post.call_count, 4)  # 1 initial call + 3 retries
+
+
+class TestLocalSQLiteConnector(unittest.TestCase):
+    def setUp(self):
+        # Create a temporary directory for testing
+        self.temp_dir = tempfile.TemporaryDirectory()
+
+        # Create a dummy SQLite database
+        self.db_id = "test_db"
+        self.db_path = os.path.join(self.temp_dir.name, self.db_id + ".sqlite")
+        conn = sqlite3.connect(self.db_path)
+        cursor = conn.cursor()
+        cursor.execute("CREATE TABLE table1 (tab1col1 TEXT, tab1col2 INTEGER)")
+        cursor.execute(
+            "INSERT INTO table1 VALUES ('value1', 1), ('value2', 2)"
+        )  # Insert data into table1
+        cursor.execute("CREATE TABLE table2 (tab2col1 REAL, tab2col2 TEXT)")
+        cursor.execute(
+            "INSERT INTO table2 VALUES (3.14, 'pi'), (2.71, 'e')"
+        )  # Insert data into table2
+        cursor.execute("CREATE TABLE sequence (name,seq)")
+        cursor.execute("INSERT INTO sequence VALUES ('table1', 2), ('table2', 2)")
+        conn.commit()
+        conn.close()
+
+        self.db_config: SQLDatabase = {
+            "db_type": "local",
+            "db_id": self.db_id,
+            "dbms": "sqlite",
+            "data": None,
+        }
+
+    def tearDown(self):
+        # Clean up the temporary directory
+        self.temp_dir.cleanup()
+
+    @patch(
+        "unitxt.db_utils.LocalSQLiteConnector.get_db_file_path",
+        side_effect=FileNotFoundError("Database file not found."),
+    )
+    def test_init_database_not_found(self, mock_get_db_file_path):
+        with self.assertRaises(FileNotFoundError):
+            LocalSQLiteConnector(self.db_config)
+
+    @patch(
+        "unitxt.db_utils.LocalSQLiteConnector.get_db_file_path",
+        side_effect=FileExistsError("More than one file matched for db_id"),
+    )
+    def test_init_multiple_databases_found(self, mock_get_db_file_path):
+        with self.assertRaises(FileExistsError):
+            LocalSQLiteConnector(self.db_config)
+
+    def test_init_success(self):
+        with patch(
+            "unitxt.db_utils.LocalSQLiteConnector.get_db_file_path",
+            return_value=self.db_path,
+        ):
+            connector = LocalSQLiteConnector(self.db_config)
+            self.assertEqual(connector.db_path, self.db_path)
+            self.assertIsInstance(connector.conn, sqlite3.Connection)
+            self.assertIsInstance(connector.cursor, sqlite3.Cursor)
+
+    def test_init_no_db_id(self):
+        self.db_config["db_id"] = None
+        with self.assertRaises(ValueError):
+            LocalSQLiteConnector(self.db_config)
+
+    # def test_get_table_schema(self):
+    #     with patch(
+    #         "unitxt.db_utils.LocalSQLiteConnector.get_db_file_path",
+    #         return_value=self.db_path,
+    #     ):
+    #         connector = LocalSQLiteConnector(self.db_config)
+    #         schema_text = connector.get_table_schema()
+
+    #         self.assertIn("CREATE TABLE table1(tab1col1 TEXT, tab1col2 INTEGER)", schema_text)
+    #         self.assertIn("CREATE TABLE table2(tab2col1 REAL, tab2col2 TEXT)", schema_text)
+    #         self.assertNotIn("CREATE TABLE sequence", schema_text)
+
+    def test_execute_query(self):
+        with patch(
+            "unitxt.db_utils.LocalSQLiteConnector.get_db_file_path",
+            return_value=self.db_path,
+        ):
+            connector = LocalSQLiteConnector(self.db_config)
+            result = connector.execute_query("SELECT * FROM table1")
+            self.assertEqual(len(result), 2)
+            self.assertEqual(result[0], ("value1", 1))
+            self.assertEqual(result[1], ("value2", 2))
+
+    def test_execute_query_error(self):
+        with patch(
+            "unitxt.db_utils.LocalSQLiteConnector.get_db_file_path",
+            return_value=self.db_path,
+        ):
+            connector = LocalSQLiteConnector(self.db_config)
+            result = connector.execute_query("SELECT * FROM non_existent_table")
+            self.assertIsNone(result)
+
+    # @patch(
+    #     "unitxt.db_utils.LocalSQLiteConnector.download_database",
+    #     side_effect=lambda x: None,
+    # )
+    # def test_download_database(self, mock_download):
+    #     with tempfile.TemporaryDirectory() as temp_dir:
+    #         os.environ["UNITXT_ARTIFACTORIES"] = temp_dir
+    #         connector = LocalSQLiteConnector(self.db_config)
+    #         connector.download_database("bird/dev")
+    #         mock_download.assert_called_with("bird/dev")
+
+    #         repo_dir = os.path.join(
+    #             connector.databases_folder, "datasets", "premai-io--birdbench"
+    #         )
+    #         hfapi = HfApi()
+    #         hfapi.snapshot_download.assert_called_with(
+    #             repo_id="premai-io/birdbench",
+    #             repo_type="dataset",
+    #             local_dir=repo_dir,
+    #             force_download=False,
+    #             allow_patterns="*validation*",
+    #         )
+
+    def test_download_database_unsupported_db(self):
+        with self.assertRaises(NotImplementedError):
+            connector = LocalSQLiteConnector(self.db_config)
+            connector.download_database("unknown_db")
+
+
+class TestInMemoryDatabaseConnector(unittest.TestCase):
+    def setUp(self):
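+        # Fixture: an in-memory database with two related tables (users and
+        # orders) that the join, aggregation, and HAVING tests below run against.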
+        self.db_config: SQLDatabase = {
+            "db_type": "in_memory",
+            "db_id": None,
+            "dbms": None,
+            "data": {
+                "users": {
+                    "columns": ["user_id", "name", "email", "age", "city"],
+                    "rows": [
+                        [1, "Alice", "alice@example.com", 30, "New York"],
+                        [2, "Bob", "bob@example.com", 25, "Los Angeles"],
+                        [3, "Charlie", "charlie@example.com", 40, "Chicago"],
+                        [4, "David", "david@example.com", 35, "New York"],
+                        [5, "Eva", "eva@example.com", 28, "Los Angeles"],
+                    ],
+                },
+                "orders": {
+                    "columns": ["order_id", "user_id", "product", "quantity", "price"],
+                    "rows": [
+                        [101, 1, "Laptop", 2, 1200.00],
+                        [102, 1, "Mouse", 5, 25.50],
+                        [103, 2, "Keyboard", 3, 75.00],
+                        [104, 3, "Monitor", 1, 300.00],
+                        [105, 3, "USB Drive", 10, 15.00],
+                        [106, 4, "Headphones", 2, 100.00],
+                        [107, 5, "Webcam", 1, 80.00],
+                        [108, 5, "Printer", 1, 250.00],
+                        [109, 5, "Laptop", 1, 1300.00],
+                        [110, 5, "Mouse", 2, 24.00],
+                    ],
+                },
+            },
+        }
+
+    def test_init_success(self):
+        connector = InMemoryDatabaseConnector(self.db_config)
+        self.assertEqual(connector.tables, self.db_config["data"])
+
+    def test_init_missing_tables(self):
+        self.db_config["data"] = None
+        with self.assertRaises(ValueError):
+            InMemoryDatabaseConnector(self.db_config)
+
+    def test_get_table_schema(self):
+        connector = InMemoryDatabaseConnector(self.db_config)
+        schema_text = connector.get_table_schema()
+        expected_schema = (
+            "CREATE TABLE `users` (`user_id` TEXT, `name` TEXT, `email` TEXT, `age` TEXT, `city` TEXT);\n\n"
+            "CREATE TABLE `orders` (`order_id` TEXT, `user_id` TEXT, `product` TEXT, `quantity` TEXT, `price` TEXT);"
+        )
+        self.assertEqual(schema_text, expected_schema)
+
+    def test_get_table_schema_with_selected_tables(self):
+        connector = InMemoryDatabaseConnector(self.db_config)
+        schema_text = connector.get_table_schema(select_tables=["orders"])
+        expected_schema = "CREATE TABLE `orders` (`order_id` TEXT, `user_id` TEXT, `product` TEXT, `quantity` TEXT, `price` TEXT);"
+
+        self.assertEqual(schema_text, expected_schema)
+
+    def test_execute_query_success(self):
+        connector = InMemoryDatabaseConnector(self.db_config)
+        result = connector.execute_query("SELECT * FROM users WHERE age > 30")
+        expected_result = [
+            (3, "Charlie", "charlie@example.com", 40, "Chicago"),
+            (4, "David", "david@example.com", 35, "New York"),
+        ]
+
+        self.assertEqual(result, expected_result)
+
+    def test_execute_query_failure(self):
+        connector = InMemoryDatabaseConnector(self.db_config)
+        result = connector.execute_query("SELECT * FROM non_existent_table")
+
+        self.assertIsNone(result)
+
+    def test_execute_complex_query(self):
+        connector = InMemoryDatabaseConnector(self.db_config)
+        query = """
+        SELECT u.name, o.product, o.quantity
+        FROM users u
+        JOIN orders o ON u.user_id = o.user_id
+        WHERE u.city = 'Los Angeles'
+        ORDER BY o.quantity DESC
+        """
+        result = connector.execute_query(query)
+        expected_result = [
+            ("Bob", "Keyboard", 3),
+            ("Eva", "Mouse", 2),
+            ("Eva", "Laptop", 1),
+            ("Eva", "Printer", 1),
+            ("Eva", "Webcam", 1),
+        ]
+        self.assertEqual(result, expected_result)
+
+    def test_execute_query_with_aggregation(self):
+        connector = InMemoryDatabaseConnector(self.db_config)
+        query = """
+        SELECT u.city, AVG(u.age)
+        FROM users u
+        GROUP BY u.city
+        ORDER BY u.city ASC
+        """
+        result = connector.execute_query(query)
+        expected_result = [
+            ("Chicago", 40.0),
+            ("Los Angeles", 26.5),
+            ("New York", 32.5),
+        ]
+        self.assertEqual(result, expected_result)
+
+    def test_execute_query_with_sum_and_having(self):
+        connector = InMemoryDatabaseConnector(self.db_config)
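+        # Expected totals from the fixture data: Eva 80+250+1300+24=1654.0,
+        # Charlie 300+15=315.0, Alice 1200+25.5=1225.5; Bob (75.0) and
+        # David (100.0) fall below the HAVING threshold of 300.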
+ query = """ + SELECT u.name, SUM(o.price) + FROM users u + JOIN orders o on u.user_id = o.user_id + GROUP BY u.name + HAVING SUM(o.price) > 300 + ORDER BY u.name DESC + """ + result = connector.execute_query(query) + expected_result = [("Eva", 1654.0), ("Charlie", 315.0), ("Alice", 1225.5)] + + self.assertEqual(result, expected_result) + + def test_execute_query_empty_table(self): + self.db_config["data"]["empty_table"] = {"columns": ["id"], "rows": []} + connector = InMemoryDatabaseConnector(self.db_config) + result = connector.execute_query("SELECT * FROM empty_table") + self.assertEqual(result, []) diff --git a/tests/library/test_metrics.py b/tests/library/test_metrics.py index 57bb8044f..dbca7fdad 100644 --- a/tests/library/test_metrics.py +++ b/tests/library/test_metrics.py @@ -12,6 +12,7 @@ BinaryMaxAccuracy, BinaryMaxF1, Detector, + ExecutionAccuracy, F1Binary, F1BinaryPosOnly, F1Fast, @@ -2113,3 +2114,183 @@ def test_metrics_ensemble(self): instance_targets=instance_targets, global_target=global_target, ) + + def test_execution_accuracy_correct_query_mock_db(self): + metric = ExecutionAccuracy() + predictions = ["SELECT name FROM employees WHERE department = 'Sales'"] + references = ["SELECT name FROM employees WHERE department = 'Sales';"] + task_data = [ + { + "db": { + "db_id": "mock_db", + "db_type": "in_memory", + "data": { + "employees": { + "columns": ["id", "name", "department", "salary"], + "rows": [ + (1, "Alice", "Sales", 50000), + (2, "Bob", "Engineering", 60000), + (3, "Charlie", "Sales", 55000), + ], + } + }, + } + } + ] + + outputs = metric.compute(references, predictions[0], task_data[0]) + self.assertEqual(1.0, outputs["score"]) + + def test_execution_accuracy_different_db_schema(self): + metric = ExecutionAccuracy() + predictions = [ + "SELECT product_name, price FROM products WHERE category = 'Electronics'" + ] + references = [ + "SELECT product_name, price FROM products WHERE category = 'Electronics';" + ] + task_data = [ + { + "db": { + "db_id": "products_db", + "db_type": "in_memory", + "data": { + "products": { + "columns": [ + "product_id", + "product_name", + "category", + "price", + ], + "rows": [ + (1, "Laptop", "Electronics", 1200), + (2, "Mouse", "Electronics", 25), + (3, "Shirt", "Clothing", 50), + (4, "Monitor", "Electronics", 300), + ], + } + }, + } + } + ] + + outputs = metric.compute(references, predictions[0], task_data[0]) + self.assertEqual(1.0, outputs["score"]) + + def test_execution_accuracy_multiple_tables(self): + metric = ExecutionAccuracy() + predictions = [ + "SELECT o.order_id, c.name FROM orders AS o JOIN customers AS c ON o.customer_id = c.customer_id WHERE o.status = 'Shipped'" + ] + references = [ + "SELECT o.order_id, c.name FROM orders AS o INNER JOIN customers AS c ON o.customer_id = c.customer_id WHERE o.status = 'Shipped';" + ] + task_data = [ + { + "db": { + "db_id": "sales_db", + "db_type": "in_memory", + "data": { + "customers": { + "columns": ["customer_id", "name", "city"], + "rows": [ + (1, "John Doe", "New York"), + (2, "Jane Smith", "Los Angeles"), + (3, "David Lee", "Chicago"), + ], + }, + "orders": { + "columns": ["order_id", "customer_id", "status"], + "rows": [ + (101, 1, "Shipped"), + (102, 2, "Pending"), + (103, 1, "Shipped"), + ], + }, + }, + } + } + ] + + outputs = metric.compute(references, predictions[0], task_data[0]) + self.assertEqual(1.0, outputs["score"]) + + def test_execution_accuracy_empty_result(self): + metric = ExecutionAccuracy() + predictions = ["SELECT name FROM employees WHERE department = 
'HR'"] + references = ["SELECT name FROM employees WHERE department = 'HR';"] + task_data = [ + { + "db": { + "db_id": "mock_db", + "db_type": "in_memory", + "data": { + "employees": { + "columns": ["id", "name", "department", "salary"], + "rows": [ + (1, "Alice", "Sales", 50000), + (2, "Bob", "Engineering", 60000), + (3, "Charlie", "Sales", 55000), + ], + } + }, + } + } + ] + + outputs = metric.compute(references, predictions[0], task_data[0]) + self.assertEqual(1.0, outputs["score"]) + + def test_execution_accuracy_aggregation_query(self): + metric = ExecutionAccuracy() + predictions = ["SELECT AVG(salary) FROM employees"] + references = ["SELECT AVG(salary) FROM employees;"] + task_data = [ + { + "db": { + "db_id": "mock_db", + "db_type": "in_memory", + "data": { + "employees": { + "columns": ["id", "name", "department", "salary"], + "rows": [ + (1, "Alice", "Sales", 50000), + (2, "Bob", "Engineering", 60000), + (3, "Charlie", "Sales", 55000), + ], + } + }, + } + } + ] + + outputs = metric.compute(references, predictions[0], task_data[0]) + self.assertEqual(1.0, outputs["score"]) + + def test_execution_accuracy_incorrect_query(self): + metric = ExecutionAccuracy() + predictions = [ + "SELECT nme FROM employees WHERE department = 'Sales'" + ] # Incorrect column name 'nme' + references = ["SELECT name FROM employees WHERE department = 'Sales';"] + task_data = [ + { + "db": { + "db_id": "mock_db", + "db_type": "in_memory", + "data": { + "employees": { + "columns": ["id", "name", "department", "salary"], + "rows": [ + (1, "Alice", "Sales", 50000), + (2, "Bob", "Engineering", 60000), + (3, "Charlie", "Sales", 55000), + ], + } + }, + } + } + ] + + outputs = metric.compute(references, predictions[0], task_data[0]) + self.assertEqual(0.0, outputs["score"]) diff --git a/tests/library/test_serializers.py b/tests/library/test_serializers.py index 533c39992..c8a4ee8a2 100644 --- a/tests/library/test_serializers.py +++ b/tests/library/test_serializers.py @@ -1,13 +1,17 @@ +import unittest +from unittest.mock import MagicMock, patch + from unitxt.serializers import ( DefaultSerializer, DialogSerializer, MultiTypeSerializer, NumberQuantizingSerializer, NumberSerializer, + SQLDatabaseAsSchemaSerializer, TableSerializer, ) from unitxt.settings_utils import get_constants -from unitxt.types import Dialog, Number, Table, Text, Turn +from unitxt.types import Dialog, Number, SQLDatabase, Table, Text, Turn from tests.library.test_image_operators import create_random_jpeg_image from tests.utils import UnitxtTestCase @@ -139,3 +143,71 @@ def test_custom_serializer_with_number(self): number_data = Number(42) result = self.custom_serializer.serialize(number_data, {}) self.assertEqual(result, "42") # Should return the number as a string + + +class TestSQLDatabaseAsSchemaSerializer(unittest.TestCase): + def test_serialize_in_memory_success(self): + db_config: SQLDatabase = { + "db_type": "in_memory", + "db_id": None, + "dbms": None, + "data": { + "table1": {"columns": ["col1", "col2"], "rows": [[1, "a"], [2, "b"]]}, + "table2": {"columns": ["name", "age"], "rows": [["Alice", 30]]}, + }, + } + + serializer = SQLDatabaseAsSchemaSerializer() + result = serializer.serialize(db_config, {}) + expected_schema = ( + "CREATE TABLE `table1` (`col1` TEXT, `col2` TEXT);\n\n" + "CREATE TABLE `table2` (`name` TEXT, `age` TEXT);" + ) + self.assertEqual(result, expected_schema) + + @patch.dict( + "os.environ", + {"SQL_API_KEY": "test_api_key"}, # pragma: allowlist secret + clear=True, + ) # pragma: allowlist-secret + 
@patch("requests.post") + def test_serialize_remote_success(self, mock_post): + db_config: SQLDatabase = { + "db_type": "remote", + "db_id": "https://testapi.com/api,db_id=test_db_id", + "dbms": None, + "data": None, + } + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "schema": { + "tables": [ + {"table_name": "table1", "columns": [{"column_name": "col1"}]}, + {"table_name": "table2", "columns": [{"column_name": "name"}]}, + ] + } + } + mock_post.return_value = mock_response + + serializer = SQLDatabaseAsSchemaSerializer() + result = serializer.serialize(db_config, {}) + + expected_schema = ( + "Table: table1 has columns: ['col1']\n" + "Table: table2 has columns: ['name']\n" + ) + self.assertEqual(result, expected_schema) + + def test_serialize_unsupported_db_type(self): + db_config: SQLDatabase = { + "db_type": "unsupported", + "db_id": "test_db_id", + "dbms": None, + "data": None, + } + + serializer = SQLDatabaseAsSchemaSerializer() + with self.assertRaises(ValueError): + serializer.serialize(db_config, {}) diff --git a/utils/.secrets.baseline b/utils/.secrets.baseline index dbad38b82..d8906e302 100644 --- a/utils/.secrets.baseline +++ b/utils/.secrets.baseline @@ -184,5 +184,5 @@ } ] }, - "generated_at": "2025-01-15T12:35:17Z" + "generated_at": "2025-01-17T15:22:31Z" }