From 19248dd55eaa6af1dadbd406c4b2efb2c9044b67 Mon Sep 17 00:00:00 2001 From: Paul Abumov Date: Wed, 24 Apr 2024 11:40:35 -0400 Subject: [PATCH 1/3] Add Data Porter feature --- .../guides/how_to_contribute/db_migrations.md | 117 ++++ .../guides/how_to_contribute/documentation.md | 2 +- .../how_to_use/merge_dbs/_category_.yml | 7 + .../merge_dbs/custom_conflict_resolver.md | 38 ++ .../guides/how_to_use/merge_dbs/reference.md | 115 ++++ .../how_to_use/merge_dbs/simple_usage.md | 136 ++++ .../how_to_use/review_app/server_api.md | 4 +- mephisto/abstractions/database.py | 41 +- .../abstractions/databases/local_database.py | 634 ++++++++---------- .../databases/local_database_tables.py | 223 ++++++ .../databases/local_singleton_database.py | 35 +- ...0325_preparing_db_for_merge_dbs_command.py | 245 +++++++ .../databases/migrations/__init__.py | 12 + .../providers/mock/mock_datastore.py | 172 +++-- .../providers/mock/mock_datastore_export.py | 41 ++ .../providers/mock/mock_datastore_tables.py | 36 + .../providers/mturk/mturk_datastore.py | 155 ++--- .../providers/mturk/mturk_datastore_export.py | 56 ++ .../providers/mturk/mturk_datastore_tables.py | 63 ++ ...0325_preparing_db_for_merge_dbs_command.py | 164 +++++ .../providers/prolific/migrations/__init__.py | 12 + .../providers/prolific/prolific_datastore.py | 574 ++++++++++------ .../prolific/prolific_datastore_export.py | 68 ++ .../prolific/prolific_datastore_tables.py | 36 +- .../providers/prolific/prolific_provider.py | 6 +- .../providers/prolific/prolific_unit.py | 8 +- .../test/crowd_provider_tester.py | 19 +- .../test/data_model_database_tester.py | 51 +- mephisto/client/cli.py | 258 ++++++- .../form_composer/config_validation/utils.py | 11 +- mephisto/operations/operator.py | 7 +- mephisto/review_app/server/__init__.py | 2 +- .../api/views/qualification_workers_view.py | 10 +- .../server/api/views/qualifications_view.py | 2 +- .../review_app/server/api/views/stats_view.py | 8 +- .../review_app/server/api/views/tasks_view.py | 2 +- .../server/api/views/worker_block_view.py | 6 +- .../worker_granted_qualifications_view.py | 10 +- mephisto/review_app/server/db_queries.py | 2 +- mephisto/tools/db_data_porter/__init__.py | 7 + mephisto/tools/db_data_porter/backups.py | 183 +++++ .../conflict_resolvers/__init__.py | 28 + .../base_merge_conflict_resolver.py | 240 +++++++ .../default_merge_conflict_resolver.py | 45 ++ mephisto/tools/db_data_porter/constants.py | 215 ++++++ .../tools/db_data_porter/db_data_porter.py | 436 ++++++++++++ mephisto/tools/db_data_porter/dumps.py | 170 +++++ mephisto/tools/db_data_porter/import_dump.py | 334 +++++++++ .../tools/db_data_porter/randomize_ids.py | 189 ++++++ mephisto/tools/db_data_porter/validation.py | 91 +++ mephisto/utils/console_writer.py | 47 ++ mephisto/utils/db.py | 629 +++++++++++++++++ mephisto/utils/misc.py | 25 + mephisto/utils/testing.py | 8 +- 54 files changed, 5166 insertions(+), 869 deletions(-) create mode 100644 docs/web/docs/guides/how_to_contribute/db_migrations.md create mode 100644 docs/web/docs/guides/how_to_use/merge_dbs/_category_.yml create mode 100644 docs/web/docs/guides/how_to_use/merge_dbs/custom_conflict_resolver.md create mode 100644 docs/web/docs/guides/how_to_use/merge_dbs/reference.md create mode 100644 docs/web/docs/guides/how_to_use/merge_dbs/simple_usage.md create mode 100644 mephisto/abstractions/databases/local_database_tables.py create mode 100644 mephisto/abstractions/databases/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py create mode 100644 
mephisto/abstractions/databases/migrations/__init__.py
 create mode 100644 mephisto/abstractions/providers/mock/mock_datastore_export.py
 create mode 100644 mephisto/abstractions/providers/mock/mock_datastore_tables.py
 create mode 100644 mephisto/abstractions/providers/mturk/mturk_datastore_export.py
 create mode 100644 mephisto/abstractions/providers/mturk/mturk_datastore_tables.py
 create mode 100644 mephisto/abstractions/providers/prolific/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py
 create mode 100644 mephisto/abstractions/providers/prolific/migrations/__init__.py
 create mode 100644 mephisto/abstractions/providers/prolific/prolific_datastore_export.py
 create mode 100644 mephisto/tools/db_data_porter/__init__.py
 create mode 100644 mephisto/tools/db_data_porter/backups.py
 create mode 100644 mephisto/tools/db_data_porter/conflict_resolvers/__init__.py
 create mode 100644 mephisto/tools/db_data_porter/conflict_resolvers/base_merge_conflict_resolver.py
 create mode 100644 mephisto/tools/db_data_porter/conflict_resolvers/default_merge_conflict_resolver.py
 create mode 100644 mephisto/tools/db_data_porter/constants.py
 create mode 100644 mephisto/tools/db_data_porter/db_data_porter.py
 create mode 100644 mephisto/tools/db_data_porter/dumps.py
 create mode 100644 mephisto/tools/db_data_porter/import_dump.py
 create mode 100644 mephisto/tools/db_data_porter/randomize_ids.py
 create mode 100644 mephisto/tools/db_data_porter/validation.py
 create mode 100644 mephisto/utils/console_writer.py
 create mode 100644 mephisto/utils/db.py
 create mode 100644 mephisto/utils/misc.py

diff --git a/docs/web/docs/guides/how_to_contribute/db_migrations.md b/docs/web/docs/guides/how_to_contribute/db_migrations.md
new file mode 100644
index 000000000..373fd3d86
--- /dev/null
+++ b/docs/web/docs/guides/how_to_contribute/db_migrations.md
@@ -0,0 +1,117 @@
---

# Copyright (c) Meta Platforms and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

sidebar_position: 4
---

# Database migrations

## Overview

Currently we are not using any special framework for updating the Mephisto database or provider-specific datastores.
Instead, migrations work as follows:

1. Each database should have a `migrations` table where we store all applied and failed migrations
2. Every run of any Mephisto command automatically attempts to apply any unapplied migrations
3. Each migration is a Python module that contains one constant (a raw SQL query string)
4. After adding a migration, its constant must be imported and added to the migrations dict
   under a readable name (the dict key) that will be used in the `migrations` table
5. Every database implementation must call the function `apply_migrations` in its `init_tables` method (after creating all tables).
   NOTE: Migrations must be applied before creating DB indices, as migrations may erase them without restoring them
6. When a migration fails, you will see a log message in the console.
   The error is also written to the `migrations` table under the `error_message` column, with status `"errored"`
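For orientation, here is a minimal sketch of what the `migrations` bookkeeping table could look like, expressed as a SQL string constant in the codebase's style. Only the `error_message` column and the `"errored"` status are taken from the notes above; the remaining column names are illustrative assumptions, not the actual schema:

```python
# A sketch only; the real DDL is defined in Mephisto's table modules.
CREATE_IF_NOT_EXISTS_MIGRATIONS_TABLE = """
    CREATE TABLE IF NOT EXISTS migrations (
        id INTEGER PRIMARY KEY,
        name TEXT NOT NULL UNIQUE,  -- readable name, the key of the migrations dict
        status TEXT NOT NULL,       -- e.g. "completed" or "errored"
        error_message TEXT,         -- populated only for failed migrations
        creation_date DATETIME DEFAULT CURRENT_TIMESTAMP
    );
"""
```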
## Details

Let's look at how exactly DB migrations should be created.

We'll use the Mephisto DB as an example; the same set of steps applies to provider-specific databases.

### Add migration package

To add a new migration package, follow these steps (a concrete example follows this list):

1. Create a Python package `migrations` next to `mephisto/abstractions/databases/local_database.py`.
2. Create a migration module in that package, e.g. `_001_20240101_add__column_name__in__table_name.py`.
   Note the leading underscore - Python does not allow importing modules whose names start with a number.
3. Populate the module with a SQL query constant:
   ```python
   #

   """
   This migration introduces the following changes:
   - ...
   """

   MY_SQL_MIGRATION_QUERY_NAME = """

   """
   ```
4. Include this SQL query constant in the `__init__.py` module (located next to the migration module):
   ```python
   #
   from ._001_20240101_add__column_name__in__table_name import *


   migrations = {
       "20240101_add__column_name__in__table_name": MY_SQL_MIGRATION_QUERY_NAME,
   }
   ```

5. Note that for now we support only forward migrations.
   If you do need a backward migration, simply add it as a forward migration that undoes the undesired changes.
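To make this concrete, here is a hypothetical migration that adds a nullable column. The file name, table name, and column name below are purely illustrative, not part of the actual Mephisto schema:

```python
# _002_20240102_add__notes__in__projects.py (hypothetical example)

"""
This migration introduces the following changes:
- adds a nullable `notes` column to the `projects` table
"""

ADD_NOTES_IN_PROJECTS = """
    ALTER TABLE projects ADD COLUMN notes TEXT;
"""
```

In `__init__.py` it would then be registered as:

```python
from ._002_20240102_add__notes__in__projects import *

migrations = {
    "20240102_add__notes__in__projects": ADD_NOTES_IN_PROJECTS,
}
```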
### Call `apply_migrations` function

1. Import the migrations in `mephisto/abstractions/databases/local_database.py`:
   ```python
   ...
   from .migrations import migrations
   ...
   ```
2. Apply the migrations in `LocalMephistoDB`:
   ```python
   class LocalMephistoDB(MephistoDB):
       ...
       def init_tables(self) -> None:
           with self.table_access_condition:
               conn = self.get_connection()
               conn.execute("PRAGMA foreign_keys = on;")

               with conn:
                   c = conn.cursor()
                   c.execute(tables.CREATE_IF_NOT_EXISTS_PROJECTS_TABLE)
                   ...

               apply_migrations(self, migrations)
               ...

               with conn:
                   c.executescript(tables.CREATE_IF_NOT_EXISTS_CORE_INDICES)
           ...
   ```

## Maintenance of related code

Changes to the databases must be carefully thought through and tested.

This is a list of places that will most likely need to be kept in sync with your DB change:

1. All queries (involving the tables you have updated) in the database class, e.g. `LocalMephistoDB`
2. The module with common database queries, `mephisto/utils/db.py`
3. Queries in the __Review App__ (`mephisto/review_app/server`) - it has its own set of specific queries
4. Names/relationships for tables and columns in __DBDataPorter__ (they are hardcoded in many places there),
   within the Mephisto DB and provider-specific databases. For instance:
   - `mephisto/tools/db_data_porter/constants.py`
   - `mephisto/tools/db_data_porter/import_dump.py`
   - ...
5. Data processing within Mephisto itself (obviously)

While we did our best to abstract away the particular structure of tables and fields,
it still has to be spelled out in some places.
Please run the tests and manually check all Mephisto applications after performing database changes.

diff --git a/docs/web/docs/guides/how_to_contribute/documentation.md b/docs/web/docs/guides/how_to_contribute/documentation.md
index 1f435b337..328afccf7 100644
--- a/docs/web/docs/guides/how_to_contribute/documentation.md
+++ b/docs/web/docs/guides/how_to_contribute/documentation.md
@@ -4,7 +4,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-sidebar_position: 4
+sidebar_position: 5
 ---
 
 # Updating documentation

diff --git a/docs/web/docs/guides/how_to_use/merge_dbs/_category_.yml b/docs/web/docs/guides/how_to_use/merge_dbs/_category_.yml
new file mode 100644
index 000000000..919eb213e
--- /dev/null
+++ b/docs/web/docs/guides/how_to_use/merge_dbs/_category_.yml
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

label: "Merge databases"
collapsed: false
position: 9

diff --git a/docs/web/docs/guides/how_to_use/merge_dbs/custom_conflict_resolver.md b/docs/web/docs/guides/how_to_use/merge_dbs/custom_conflict_resolver.md
new file mode 100644
index 000000000..77fd228d0
--- /dev/null
+++ b/docs/web/docs/guides/how_to_use/merge_dbs/custom_conflict_resolver.md
@@ -0,0 +1,38 @@
---

# Copyright (c) Meta Platforms and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

sidebar_position: 3
---

# Custom conflict resolver

When importing dump data into your local DB, some rows may refer to the same object
(e.g. two Task rows with the same value in the "name" column). The default resolver class contains logic
to resolve such merging conflicts (implemented for all currently supported DBs).

To change this default behavior, you can write your own conflict resolver class:

1. Add a new Python module next to the default resolver's module (e.g. `my_conflict_resolver`)

2. This module must contain a class (e.g. `MyMergeConflictResolver`)
   that inherits from either `BaseMergeConflictResolver`
   or the default resolver `DefaultMergeConflictResolver` (also in this directory)
   ```python
   from .base_merge_conflict_resolver import BaseMergeConflictResolver

   class CustomMergeConflictResolver(BaseMergeConflictResolver):
       default_strategy_name = "..."
       strategies_config = {...}
   ```

3. To use this newly created class, specify its name in the import command:
   `mephisto db import ... --conflict-resolver MyMergeConflictResolver`

The easiest way to start customization is to modify the `strategies_config` property,
and perhaps the `default_strategy_name` value (see `DefaultMergeConflictResolver` as an example).

NOTE: All available providers must be present in `strategies_config`.
Table names (under each provider key) are optional; if a table is missing, `default_strategy_name`
will be used for all conflicts related to that table.
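For illustration, a hedged sketch of such a subclass is shown below. The strategy name `"pick_row_from_db"` and the exact shape of the per-table entries are assumptions for this example; consult `DefaultMergeConflictResolver` for the real strategy names and config layout:

```python
from .base_merge_conflict_resolver import BaseMergeConflictResolver


class MyMergeConflictResolver(BaseMergeConflictResolver):
    # Assumed strategy name, for illustration only
    default_strategy_name = "pick_row_from_db"

    # Every available provider must be present;
    # tables not listed under a provider fall back to `default_strategy_name`
    strategies_config = {
        "mephisto": {
            "tasks": {...},  # per-table strategy override (placeholder)
        },
        "mock": {},
        "mturk": {},
        "prolific": {},
    }
```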
diff --git a/docs/web/docs/guides/how_to_use/merge_dbs/reference.md b/docs/web/docs/guides/how_to_use/merge_dbs/reference.md
new file mode 100644
index 000000000..5cf58becb
--- /dev/null
+++ b/docs/web/docs/guides/how_to_use/merge_dbs/reference.md
@@ -0,0 +1,115 @@
---

# Copyright (c) Meta Platforms and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

sidebar_position: 2
---

# Reference

This is a reference describing the set of commands under the `mephisto db` command group.

## Export

This command exports data from the Mephisto DB and provider-specific datastores
as a combination of (i) a JSON file, and (ii) an archived `data` catalog with related files.

If no parameters are passed, a full data dump (i.e. a backup) will be created.

To pass a list of values for one command option, simply repeat that option name before each value.

Examples:
```
mephisto db export
mephisto db export --verbosity
mephisto db export --export-tasks-by-names "My first Task"
mephisto db export --export-tasks-by-ids 1 --export-tasks-by-ids 2
mephisto db export --export-task-runs-by-ids 3 --export-task-runs-by-ids 4
mephisto db export --export-task-runs-since-date 2024-01-01
mephisto db export --export-task-runs-since-date 2023-01-01T00:00:00
mephisto db export --export-labels first_dump --export-labels second_dump
mephisto db export --export-tasks-by-ids 1 --delete-exported-data --randomize-legacy-ids --export-indent 2
```

Options (all optional):

- `-tn/--export-tasks-by-names` - names of Tasks that will be exported
- `-ti/--export-tasks-by-ids` - ids of Tasks that will be exported
- `-tr/--export-task-runs-by-ids` - ids of TaskRuns that will be exported
- `-trs/--export-task-runs-since-date` - only objects created after this ISO8601 datetime will be exported
- `-tl/--export-labels` - only data imported under these labels will be exported
- `-de/--delete-exported-data` - after exporting data, delete it from the local DB
- `-r/--randomize-legacy-ids` - replace legacy autoincremented ids with
  new pseudo-random ids to avoid conflicts during data merging
- `-i/--export-indent` - make the dump easy to read by formatting the JSON with indentation
- `-v/--verbosity` - write more informative messages about progress (default: 0; values: 0, 1)

Note that the following options cannot be used together:
`--export-tasks-by-names`, `--export-tasks-by-ids`, `--export-task-runs-by-ids`, `--export-task-runs-since-date`, `--export-labels`.


## Import

This command imports data from a dump file created by the `mephisto db export` command.

Examples:
```
mephisto db import --dump-file

mephisto db import --dump-file 2024_01_01_00_00_01_mephisto_dump.json --verbosity
mephisto db import --dump-file 2024_01_01_00_00_01_mephisto_dump.json --label-name my_first_dump
mephisto db import --dump-file 2024_01_01_00_00_01_mephisto_dump.json --conflict-resolver MyCustomMergeConflictResolver
mephisto db import --dump-file 2024_01_01_00_00_01_mephisto_dump.json --keep-import-metadata
```

Options:
- `-d/--dump-file` - location of the __***.json__ dump file (filename if created in the
  `/outputs/export` folder, or an absolute filepath)
- `-cr/--conflict-resolver` (Optional) - name of the Python class to be used for resolving merging conflicts
  (when your local DB already has a row with the same unique field value as a DB row in the dump data)
- `-l/--label-name` - a short string serving as a reference for the ported data (stored in the `imported_data` table),
  so that later you can export the imported data with the `--export-labels` export option
- `-k/--keep-import-metadata` - write data from the `imported_data` table of the dump (by default it is not imported)
- `-v/--verbosity` - level of logging (default: 0; values: 0, 1)

Note that before every import we create a full snapshot copy of your local data by
archiving the content of your `data` directory. If any data gets corrupted during the import,
you can always return to the original state by replacing your `data` folder with the snapshot.

## Backup

Creates a full backup of all current data (Mephisto DB, provider-specific datastores, and related files) on the local machine.

```
mephisto db backup
```


## Restore

Restores all data (Mephisto DB, provider-specific datastores, and related files) from a backup archive.

Note that it will erase all current data, so you may want to run the `mephisto db backup` command beforehand.

Examples:
```
mephisto db restore --backup-file

mephisto db restore --backup-file 2024_01_01_00_10_01.zip
```

Options:
- `-b/--backup-file` - location of the __*.zip__ backup file (filename if created in the
  `/outputs/backup` folder, or an absolute filepath)
- `-v/--verbosity` - level of logging (default: 0; values: 0, 1)


## Note on legacy PKs

Prior to release `v1.4` of Mephisto, its DB schemas used auto-incremented integer primary keys. While convenient for debugging, these PKs cause problems during data import/export.

As of `v1.4` we have replaced these "legacy" PKs with quasi-random integers (for backward compatibility, their values are designed to be above 1,000,000).

If you do wish to use the import/export commands with your "legacy" data, include the `--randomize-legacy-ids` option. It prevents data corruption when merging two sets of "legacy" data (which would otherwise contain the same integer PKs `1, 2, 3,...` for completely unrelated objects).
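To illustrate the idea (a sketch only, not the actual implementation, which lives in `mephisto/utils/db.py`), a randomized replacement PK could be generated like this. The upper bound here is an assumption; only the 1,000,000 floor is documented above:

```python
import random

LEGACY_PK_CEILING = 1_000_000  # legacy autoincremented ids stay below this


def make_randomized_int_id() -> int:
    """Return a pseudo-random integer PK above the legacy range, so that
    independently exported dumps are unlikely to collide when merged."""
    # The upper bound is an illustrative assumption; SQLite INTEGER
    # columns can hold much larger values.
    return random.randint(LEGACY_PK_CEILING + 1, 2**31 - 1)
```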
diff --git a/docs/web/docs/guides/how_to_use/merge_dbs/simple_usage.md b/docs/web/docs/guides/how_to_use/merge_dbs/simple_usage.md
new file mode 100644
index 000000000..3ebf9020a
--- /dev/null
+++ b/docs/web/docs/guides/how_to_use/merge_dbs/simple_usage.md
@@ -0,0 +1,136 @@
---

# Copyright (c) Meta Platforms and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

sidebar_position: 1
---

# Simple usage


## Introduction

Companies can be big: they may run many tasks on different computers/servers,
or even one task across several departments or school classes.
Later on, it is much easier to review all of those tasks together.

The solution is to merge the data from several tasks into a single database.


## How it works

1. Create a full backup of all your data, so you can roll all changes back if something goes wrong
2. Export tasks into a JSON dump file, with related files in a ZIP archive
3. Send or collect all dumps together
4. Import all of these dumps into your main (or a brand-new) Mephisto project
5. Restore from the backup if you change your mind or want to start from scratch


## Most common usage scenario

### Back up your main project

If you already have some kind of main Mephisto project where all your tasks were processed,
you may want to merge a dump into this exact project.
We strongly recommend making a backup of all your data manually and saving it somewhere you can easily find.

The command is:

```shell
mephisto db backup
```

You will see output like this:

```
Started making backup
Finished successfully! File: '//outputs/backup/2024_01_01_00_00_01_mephisto_backup.zip
```

Find and copy this file.

### Export data into a dump

To make a dump with all your data, use this simple command:

```shell
mephisto db export
```

If you want to export just 2 of your 10 tasks, add an option:

```shell
mephisto db export --export-tasks-by-names "My first Task" --export-tasks-by-names "My second Task"
```

If you ran tasks before June 2024, you should use the `--randomize-legacy-ids` parameter.
Why do you need this? We modified the Primary Keys in our databases.
They used to be autoincremented, starting from 1 in every project,
which would lead to conflicts across all databases during merging.
This parameter regenerates all Primary Keys randomly and replaces the corresponding Foreign Keys with them as well.
Note that it does not affect your databases; the Primary Keys are new only in the dump.

```shell
mephisto db export --randomize-legacy-ids
```

You will see output like this:

```
Started exporting
Run command for all TaskRuns.
Finished successfully!
Files created:
  - Database dump - //outputs/export/2024_01_01_00_00_01_mephisto_dump.json
  - Data files dump - //outputs/export/2024_01_01_00_00_01_mephisto_dump.zip
```

### Import the newly created dump into the main project

Put your dump into the export directory `/mephisto/outputs/export/`, so that you can use just the dump name in the command,
or use a full path to the file.
Assuming you put the file in the export directory:

```shell
mephisto db import --dump-file 2024_01_01_00_00_01_mephisto_dump.json
```

You will see output like this:

```
Started importing from dump '2024_01_01_00_00_01_mephisto_dump.json'
Are you sure? It will affect your databases and related files. Type 'yes' and press Enter if you want to proceed: yes
Just in case, we are making a backup of all your local data. If something went wrong during import, we will restore all your data from this backup
Backup was created successfully! File: '/mephisto/outputs/backup/2024_01_01_00_10_01_mephisto_backup.zip'
Finished successfully
```

Note that the process will pause and wait for your __yes__ answer.
Also, just in case, we automatically create a backup right before any changes are made.

### Restoring from a backup

"OMG! I imported the wrong dump! What have I done!" - you may cry.

No worries: just restore everything from your manual backup, or from the automatic one:

```shell
mephisto db restore --backup-file 2024_01_01_00_10_01.zip
```

You will see output like this:

```
Started restoring from backup '2024_01_01_00_10_01.zip'
Are you sure? It will affect your databases and related files. Type 'yes' and press Enter if you want to proceed: yes
Finished successfully
```

Note that the process will pause and wait for your __yes__ answer.

### Conclusion

Now that you have merged your two projects, you can easily start
[reviewing your tasks](/docs/guides/how_to_use/review_app/overview/).

diff --git a/docs/web/docs/guides/how_to_use/review_app/server_api.md b/docs/web/docs/guides/how_to_use/review_app/server_api.md
index 2b5714d12..59b826082 100644
--- a/docs/web/docs/guides/how_to_use/review_app/server_api.md
+++ b/docs/web/docs/guides/how_to_use/review_app/server_api.md
@@ -149,7 +149,7 @@ Get list of all bearers of a qualification.
     "worker_id": ,
     "value": ,
     "unit_review_id": ,  // latest grant of this qualification
-    "granted_at": ,  // maps to `unit_review.created_at` column
+    "granted_at": ,  // maps to `unit_review.creation_date` column
   },
   ...  // more qualified workers
 ]
@@ -301,7 +301,7 @@ Get list of all granted qualifications for a worker
       "worker_id": ,
       "qualification_id": ,
       "value": ,
-      "granted_at": ,  // maps to `unit_review.created_at` column
+      "granted_at": ,  // maps to `unit_review.creation_date` column
     }
   ],
   ...
// more granted qualifications diff --git a/mephisto/abstractions/database.py b/mephisto/abstractions/database.py index 7f7716dc8..2fa77f477 100644 --- a/mephisto/abstractions/database.py +++ b/mephisto/abstractions/database.py @@ -6,44 +6,37 @@ import os -import sqlite3 import warnings +from abc import ABC +from abc import abstractmethod +from typing import Any +from typing import Dict +from typing import List +from typing import Mapping +from typing import Optional +from typing import Union + from prometheus_client import Histogram # type: ignore -from abc import ABC, abstractmethod -from mephisto.utils.dirs import get_data_dir -from mephisto.operations.registry import ( - get_crowd_provider_from_type, - get_valid_provider_types, -) -from typing import Mapping, Optional, Any, List, Dict, Union -import enum -from mephisto.data_model.agent import Agent, OnboardingAgent -from mephisto.data_model.unit import Unit +from mephisto.data_model.agent import Agent +from mephisto.data_model.agent import OnboardingAgent from mephisto.data_model.assignment import Assignment from mephisto.data_model.project import Project +from mephisto.data_model.qualification import GrantedQualification +from mephisto.data_model.qualification import Qualification from mephisto.data_model.requester import Requester from mephisto.data_model.task import Task from mephisto.data_model.task_run import TaskRun +from mephisto.data_model.unit import Unit from mephisto.data_model.worker import Worker -from mephisto.data_model.qualification import Qualification, GrantedQualification +from mephisto.operations.registry import get_crowd_provider_from_type +from mephisto.operations.registry import get_valid_provider_types +from mephisto.utils.dirs import get_data_dir # TODO(#101) investigate cursors for DB queries as the project scales -class MephistoDBException(Exception): - pass - - -class EntryAlreadyExistsException(MephistoDBException): - pass - - -class EntryDoesNotExistException(MephistoDBException): - pass - - # Initialize histogram for database latency DATABASE_LATENCY = Histogram("database_latency_seconds", "Logging for db requests", ["method"]) # Need all the specific decorators b/c cascading is not allowed in decorators diff --git a/mephisto/abstractions/databases/local_database.py b/mephisto/abstractions/databases/local_database.py index c26b8c745..c6b0ef676 100644 --- a/mephisto/abstractions/databases/local_database.py +++ b/mephisto/abstractions/databases/local_database.py @@ -4,32 +4,44 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from mephisto.abstractions.database import (
-    MephistoDB,
-    MephistoDBException,
-    EntryAlreadyExistsException,
-    EntryDoesNotExistException,
-)
-from typing import Mapping, Optional, Any, List, Dict, Tuple, Union
-from mephisto.operations.registry import get_valid_provider_types
-from mephisto.data_model.agent import Agent, AgentState, OnboardingAgent
-from mephisto.data_model.unit import Unit
-from mephisto.data_model.assignment import Assignment, AssignmentState
+import json
+import os
+import sqlite3
+import threading
+from sqlite3 import Connection
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Mapping
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+from mephisto.abstractions.database import MephistoDB
+from mephisto.data_model.agent import Agent
+from mephisto.data_model.agent import AgentState
+from mephisto.data_model.agent import OnboardingAgent
+from mephisto.data_model.assignment import Assignment
+from mephisto.data_model.assignment import AssignmentState
 from mephisto.data_model.constants import NO_PROJECT_NAME
 from mephisto.data_model.project import Project
+from mephisto.data_model.qualification import GrantedQualification
+from mephisto.data_model.qualification import Qualification
 from mephisto.data_model.requester import Requester
 from mephisto.data_model.task import Task
 from mephisto.data_model.task_run import TaskRun
+from mephisto.data_model.unit import Unit
 from mephisto.data_model.worker import Worker
-from mephisto.data_model.qualification import Qualification, GrantedQualification
-
-import sqlite3
-from sqlite3 import Connection
-import threading
-import os
-import json
-
+from mephisto.operations.registry import get_valid_provider_types
+from mephisto.utils.db import apply_migrations
+from mephisto.utils.db import EntryAlreadyExistsException
+from mephisto.utils.db import EntryDoesNotExistException
+from mephisto.utils.db import make_randomized_int_id
+from mephisto.utils.db import MephistoDBException
+from mephisto.utils.db import retry_generate_id
 from mephisto.utils.logger_core import get_logger
+from . import local_database_tables as tables
+from .migrations import migrations
 
 logger = get_logger(name=__name__)
 
@@ -46,209 +58,27 @@ def assert_valid_provider(provider_type: str) -> None:
     valid_types = get_valid_provider_types()
     if provider_type not in valid_types:
         raise MephistoDBException(
-            f"Supplied provider {provider_type} is not in supported list of providers {valid_types}."
+            f"Supplied provider {provider_type} is not in supported list of "
+            f"providers {valid_types}."
         )
 
 
 def is_key_failure(e: sqlite3.IntegrityError) -> bool:
     """
-    Return if the given error is representing a foreign key
-    failure, where an insertion was expecting something to
-    exist already in the DB but it didn't.
+    Return if the given error is representing a foreign key failure,
+    where an insertion was expecting something to exist already in the DB, but it didn't.
     """
     return str(e) == "FOREIGN KEY constraint failed"
 
 
 def is_unique_failure(e: sqlite3.IntegrityError) -> bool:
     """
-    Return if the given error is representing a foreign key
-    failure, where an insertion was expecting something to
-    exist already in the DB but it didn't.
+    Return if the given error is representing a unique constraint failure,
+    where an insertion tried to add a value that already exists in the DB.
""" return str(e).startswith("UNIQUE constraint") -CREATE_PROJECTS_TABLE = """CREATE TABLE IF NOT EXISTS projects ( - project_id INTEGER PRIMARY KEY AUTOINCREMENT, - project_name TEXT NOT NULL UNIQUE, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP -); -""" - -CREATE_TASKS_TABLE = """CREATE TABLE IF NOT EXISTS tasks ( - task_id INTEGER PRIMARY KEY AUTOINCREMENT, - task_name TEXT NOT NULL UNIQUE, - task_type TEXT NOT NULL, - project_id INTEGER, - parent_task_id INTEGER, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (parent_task_id) REFERENCES tasks (task_id), - FOREIGN KEY (project_id) REFERENCES projects (project_id) -); -""" - -CREATE_REQUESTERS_TABLE = """CREATE TABLE IF NOT EXISTS requesters ( - requester_id INTEGER PRIMARY KEY AUTOINCREMENT, - requester_name TEXT NOT NULL UNIQUE, - provider_type TEXT NOT NULL, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP -); -""" - -CREATE_TASK_RUNS_TABLE = """ - CREATE TABLE IF NOT EXISTS task_runs ( - task_run_id INTEGER PRIMARY KEY AUTOINCREMENT, - task_id INTEGER NOT NULL, - requester_id INTEGER NOT NULL, - init_params TEXT NOT NULL, - is_completed BOOLEAN NOT NULL, - provider_type TEXT NOT NULL, - task_type TEXT NOT NULL, - sandbox BOOLEAN NOT NULL, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (task_id) REFERENCES tasks (task_id), - FOREIGN KEY (requester_id) REFERENCES requesters (requester_id) -); -""" - -CREATE_ASSIGNMENTS_TABLE = """CREATE TABLE IF NOT EXISTS assignments ( - assignment_id INTEGER PRIMARY KEY AUTOINCREMENT, - task_id INTEGER NOT NULL, - task_run_id INTEGER NOT NULL, - requester_id INTEGER NOT NULL, - task_type TEXT NOT NULL, - provider_type TEXT NOT NULL, - sandbox BOOLEAN NOT NULL, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (task_id) REFERENCES tasks (task_id), - FOREIGN KEY (task_run_id) REFERENCES task_runs (task_run_id), - FOREIGN KEY (requester_id) REFERENCES requesters (requester_id) -); -""" - -CREATE_UNITS_TABLE = """CREATE TABLE IF NOT EXISTS units ( - unit_id INTEGER PRIMARY KEY AUTOINCREMENT, - assignment_id INTEGER NOT NULL, - unit_index INTEGER NOT NULL, - pay_amount FLOAT NOT NULL, - provider_type TEXT NOT NULL, - status TEXT NOT NULL, - agent_id INTEGER, - worker_id INTEGER, - task_type TEXT NOT NULL, - task_id INTEGER NOT NULL, - task_run_id INTEGER NOT NULL, - sandbox BOOLEAN NOT NULL, - requester_id INTEGER NOT NULL, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (assignment_id) REFERENCES assignments (assignment_id), - FOREIGN KEY (agent_id) REFERENCES agents (agent_id), - FOREIGN KEY (task_run_id) REFERENCES task_runs (task_run_id), - FOREIGN KEY (task_id) REFERENCES tasks (task_id), - FOREIGN KEY (requester_id) REFERENCES requesters (requester_id), - FOREIGN KEY (worker_id) REFERENCES workers (worker_id), - UNIQUE (assignment_id, unit_index) -); -""" - -CREATE_WORKERS_TABLE = """CREATE TABLE IF NOT EXISTS workers ( - worker_id INTEGER PRIMARY KEY AUTOINCREMENT, - worker_name TEXT NOT NULL UNIQUE, - provider_type TEXT NOT NULL, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP -); -""" - -CREATE_AGENTS_TABLE = """CREATE TABLE IF NOT EXISTS agents ( - agent_id INTEGER PRIMARY KEY AUTOINCREMENT, - worker_id INTEGER NOT NULL, - unit_id INTEGER NOT NULL, - task_id INTEGER NOT NULL, - task_run_id INTEGER NOT NULL, - assignment_id INTEGER NOT NULL, - task_type TEXT NOT NULL, - provider_type TEXT NOT NULL, - status TEXT NOT NULL, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (worker_id) REFERENCES 
workers (worker_id), - FOREIGN KEY (unit_id) REFERENCES units (unit_id) -); -""" - -CREATE_ONBOARDING_AGENTS_TABLE = """CREATE TABLE IF NOT EXISTS onboarding_agents ( - onboarding_agent_id INTEGER PRIMARY KEY AUTOINCREMENT, - worker_id INTEGER NOT NULL, - task_id INTEGER NOT NULL, - task_run_id INTEGER NOT NULL, - task_type TEXT NOT NULL, - status TEXT NOT NULL, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (worker_id) REFERENCES workers (worker_id), - FOREIGN KEY (task_run_id) REFERENCES task_runs (task_run_id) -); -""" - -CREATE_QUALIFICATIONS_TABLE = """CREATE TABLE IF NOT EXISTS qualifications ( - qualification_id INTEGER PRIMARY KEY AUTOINCREMENT, - qualification_name TEXT NOT NULL UNIQUE, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP -); -""" - -CREATE_GRANTED_QUALIFICATIONS_TABLE = """ -CREATE TABLE IF NOT EXISTS granted_qualifications ( - granted_qualification_id INTEGER PRIMARY KEY AUTOINCREMENT, - worker_id INTEGER NOT NULL, - qualification_id INTEGER NOT NULL, - value INTEGER NOT NULL, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (worker_id) REFERENCES workers (worker_id), - FOREIGN KEY (qualification_id) REFERENCES qualifications (qualification_id), - UNIQUE (worker_id, qualification_id) -); -""" - -CREATE_UNIT_REVIEW_TABLE = """ - CREATE TABLE IF NOT EXISTS unit_review ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - unit_id INTEGER NOT NULL, - worker_id INTEGER NOT NULL, - task_id INTEGER NOT NULL, - status TEXT NOT NULL, - review_note TEXT, - bonus INTEGER, - blocked_worker BOOLEAN DEFAULT false, - /* ID of `db.qualifications` (not `db.granted_qualifications`) */ - updated_qualification_id INTEGER, - updated_qualification_value INTEGER, - /* ID of `db.qualifications` (not `db.granted_qualifications`) */ - revoked_qualification_id INTEGER, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - - FOREIGN KEY (unit_id) REFERENCES units (unit_id), - FOREIGN KEY (worker_id) REFERENCES workers (worker_id), - FOREIGN KEY (task_id) REFERENCES tasks (task_id) - ); -""" - -# Indices that are used by system-specific calls across Mephisto during live tasks -# that improve the runtime of the system as a whole -CREATE_CORE_INDEXES = """ -CREATE INDEX IF NOT EXISTS requesters_by_provider_index ON requesters(provider_type); -CREATE INDEX IF NOT EXISTS unit_by_status_index ON units(status); -CREATE INDEX IF NOT EXISTS unit_by_assignment_id_index ON units(assignment_id); -CREATE INDEX IF NOT EXISTS unit_by_task_run_index ON units(task_run_id); -CREATE INDEX IF NOT EXISTS unit_by_task_run_by_worker_by_status_index ON units(task_run_id, worker_id, status); -CREATE INDEX IF NOT EXISTS unit_by_task_by_worker_index ON units(task_id, worker_id); -CREATE INDEX IF NOT EXISTS agent_by_worker_by_status_index ON agents(worker_id, status); -CREATE INDEX IF NOT EXISTS agent_by_task_run_index ON agents(task_run_id); -CREATE INDEX IF NOT EXISTS assignment_by_task_run_index ON assignments(task_run_id); -CREATE INDEX IF NOT EXISTS task_run_by_requester_index ON task_runs(requester_id); -CREATE INDEX IF NOT EXISTS task_run_by_task_index ON task_runs(task_id); -CREATE INDEX IF NOT EXISTS unit_review_by_unit_index ON unit_review(unit_id); -""" - - class StringIDRow(sqlite3.Row): def __getitem__(self, key: str) -> Any: val = super().__getitem__(key) @@ -261,7 +91,7 @@ def __getitem__(self, key: str) -> Any: class LocalMephistoDB(MephistoDB): """ Local database for core Mephisto data storage, the LocalMephistoDatabase handles - grounding all of the python interactions with the 
Mephisto architecture to
+    grounding all the python interactions with the Mephisto architecture to
     local files and a database.
     """
 
@@ -271,7 +101,7 @@ def __init__(self, database_path=None):
         self.table_access_condition = threading.Condition()
         super().__init__(database_path)
 
-    def _get_connection(self) -> Connection:
+    def get_connection(self) -> Connection:
         """Returns a singular database connection to be shared amongst all
         calls for a given thread.
         """
@@ -297,23 +127,34 @@ def init_tables(self) -> None:
         Run all the table creation SQL queries to ensure the expected tables exist
         """
         with self.table_access_condition:
-            conn = self._get_connection()
-            conn.execute("PRAGMA foreign_keys = 1")
+            conn = self.get_connection()
+            conn.execute("PRAGMA foreign_keys = on;")
+
             with conn:
                 c = conn.cursor()
-                c.execute(CREATE_PROJECTS_TABLE)
-                c.execute(CREATE_TASKS_TABLE)
-                c.execute(CREATE_REQUESTERS_TABLE)
-                c.execute(CREATE_TASK_RUNS_TABLE)
-                c.execute(CREATE_ASSIGNMENTS_TABLE)
-                c.execute(CREATE_UNITS_TABLE)
-                c.execute(CREATE_WORKERS_TABLE)
-                c.execute(CREATE_AGENTS_TABLE)
-                c.execute(CREATE_QUALIFICATIONS_TABLE)
-                c.execute(CREATE_GRANTED_QUALIFICATIONS_TABLE)
-                c.execute(CREATE_ONBOARDING_AGENTS_TABLE)
-                c.execute(CREATE_UNIT_REVIEW_TABLE)
-                c.executescript(CREATE_CORE_INDEXES)
+                c.execute(tables.CREATE_IF_NOT_EXISTS_PROJECTS_TABLE)
+                c.execute(tables.CREATE_IF_NOT_EXISTS_TASKS_TABLE)
+                c.execute(tables.CREATE_IF_NOT_EXISTS_REQUESTERS_TABLE)
+                c.execute(tables.CREATE_IF_NOT_EXISTS_TASK_RUNS_TABLE)
+                c.execute(tables.CREATE_IF_NOT_EXISTS_ASSIGNMENTS_TABLE)
+                c.execute(tables.CREATE_IF_NOT_EXISTS_UNITS_TABLE)
+                c.execute(tables.CREATE_IF_NOT_EXISTS_WORKERS_TABLE)
+                c.execute(tables.CREATE_IF_NOT_EXISTS_AGENTS_TABLE)
+                c.execute(tables.CREATE_IF_NOT_EXISTS_QUALIFICATIONS_TABLE)
+                c.execute(tables.CREATE_IF_NOT_EXISTS_GRANTED_QUALIFICATIONS_TABLE)
+                c.execute(tables.CREATE_IF_NOT_EXISTS_ONBOARDING_AGENTS_TABLE)
+                c.execute(tables.CREATE_IF_NOT_EXISTS_UNIT_REVIEW_TABLE)
+                c.execute(tables.CREATE_IF_NOT_EXISTS_IMPORT_DATA_TABLE)
+                c.execute(tables.CREATE_IF_NOT_EXISTS_MIGRATIONS_TABLE)
+
+            apply_migrations(self, migrations)
+
+            # Creating indices must happen after migrations.
+            # SQLite has fewer features than other databases and, e.g., changing
+            # a constraint requires recreating the whole table. We would lose the
+            # indices in that case, or would have to recreate them in the migration
+            with conn:
+                c.executescript(tables.CREATE_IF_NOT_EXISTS_CORE_INDICES)
 
     def __get_one_by_id(self, table_name: str, id_name: str, db_id: str) -> Mapping[str, Any]:
         """
@@ -321,7 +162,7 @@
             raise EntryDoesNotExistException if it isn't present
         """
         with self.table_access_condition:
-            conn = self._get_connection()
+            conn = self.get_connection()
             c = conn.cursor()
             c.execute(
                 f"""
@@ -335,8 +176,8 @@
                 raise EntryDoesNotExistException(f"Table {table_name} has no {id_name} {db_id}")
             return results[0]
 
+    @staticmethod
     def __create_query_and_tuple(
-        self,
         arg_list: List[str],
         arg_vals: List[Optional[Union[str, int, bool]]],
     ) -> Tuple[str, tuple]:
@@ -362,24 +203,41 @@
 
         return "".join(query_lines), tuple(fin_vals)
 
+    @retry_generate_id(caught_excs=[EntryAlreadyExistsException])
     def _new_project(self, project_name: str) -> str:
         """
-        Create a new project with the given project name.
Raise EntryAlreadyExistsException if a project - with this name has already been created. + Create a new project with the given project name. + Raise EntryAlreadyExistsException if a project with this name has already been created. """ if project_name in [NO_PROJECT_NAME, ""]: raise MephistoDBException(f'Invalid project name "{project_name}') - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() try: - c.execute("INSERT INTO projects(project_name) VALUES (?);", (project_name,)) + c.execute( + """ + INSERT INTO projects( + project_id, + project_name + ) VALUES (?, ?); + """, + ( + make_randomized_int_id(), + project_name, + ) + ) project_id = str(c.lastrowid) return project_id except sqlite3.IntegrityError as e: if is_key_failure(e): raise EntryDoesNotExistException() elif is_unique_failure(e): - raise EntryAlreadyExistsException(f"Project {project_name} already exists") + raise EntryAlreadyExistsException( + f"Project {project_name} already exists", + db=self, + table_name="projects", + original_exc=e, + ) raise MephistoDBException(e) def _get_project(self, project_id: str) -> Mapping[str, Any]: @@ -397,7 +255,7 @@ def _find_projects(self, project_name: Optional[str] = None) -> List[Project]: return all projects. """ with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() additional_query, arg_tuple = self.__create_query_and_tuple( ["project_name"], [project_name] @@ -412,6 +270,7 @@ def _find_projects(self, project_name: Optional[str] = None) -> List[Project]: rows = c.fetchall() return [Project(self, str(r["project_id"]), row=r, _used_new_call=True) for r in rows] + @retry_generate_id(caught_excs=[EntryAlreadyExistsException]) def _new_task( self, task_name: str, @@ -424,17 +283,21 @@ def _new_task( """ if task_name in [""]: raise MephistoDBException(f'Invalid task name "{task_name}') - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() try: c.execute( - """INSERT INTO tasks( + """ + INSERT INTO tasks( + task_id, task_name, task_type, project_id, parent_task_id - ) VALUES (?, ?, ?, ?);""", + ) VALUES (?, ?, ?, ?, ?); + """, ( + make_randomized_int_id(), task_name, task_type, nonesafe_int(project_id), @@ -447,7 +310,9 @@ def _new_task( if is_key_failure(e): raise EntryDoesNotExistException(e) elif is_unique_failure(e): - raise EntryAlreadyExistsException(e) + raise EntryAlreadyExistsException( + e, db=self, table_name="tasks", original_exc=e, + ) raise MephistoDBException(e) def _get_task(self, task_id: str) -> Mapping[str, Any]: @@ -469,7 +334,7 @@ def _find_tasks( return all tasks. """ with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() additional_query, arg_tuple = self.__create_query_and_tuple( ["task_name", "project_id", "parent_task_id"], @@ -492,10 +357,11 @@ def _update_task( project_id: Optional[str] = None, ) -> None: """ - Update the given task with the given parameters if possible, raise appropriate exception otherwise. + Update the given task with the given parameters if possible, + raise appropriate exception otherwise. - Tasks can only be updated if no runs exist for this task yet, otherwise there's too much state - and we shouldn't make changes. 
+ Tasks can only be updated if no runs exist for this task yet, + otherwise there's too much state, and we shouldn't make changes. """ if len(self.find_task_runs(task_id=task_id)) != 0: raise MephistoDBException( @@ -503,7 +369,7 @@ def _update_task( ) if task_name in [""]: raise MephistoDBException(f'Invalid task name "{task_name}') - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() try: if task_name is not None: @@ -528,9 +394,15 @@ def _update_task( if is_key_failure(e): raise EntryDoesNotExistException(e) elif is_unique_failure(e): - raise EntryAlreadyExistsException(f"Task name {task_name} is already in use") + raise EntryAlreadyExistsException( + f"Task name {task_name} is already in use", + db=self, + table_name="units", + original_exc=e, + ) raise MephistoDBException(e) + @retry_generate_id(caught_excs=[EntryAlreadyExistsException]) def _new_task_run( self, task_id: str, @@ -541,13 +413,14 @@ def _new_task_run( sandbox: bool = True, ) -> str: """Create a new task_run for the given task.""" - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: # Ensure given ids are valid c = conn.cursor() try: c.execute( """ INSERT INTO task_runs( + task_run_id, task_id, requester_id, init_params, @@ -556,8 +429,10 @@ def _new_task_run( task_type, sandbox ) - VALUES (?, ?, ?, ?, ?, ?, ?);""", + VALUES (?, ?, ?, ?, ?, ?, ?, ?); + """, ( + make_randomized_int_id(), int(task_id), int(requester_id), init_params, @@ -572,12 +447,16 @@ def _new_task_run( except sqlite3.IntegrityError as e: if is_key_failure(e): raise EntryDoesNotExistException(e) + elif is_unique_failure(e): + raise EntryAlreadyExistsException( + e, db=self, table_name="task_runs", original_exc=e, + ) raise MephistoDBException(e) def _get_task_run(self, task_run_id: str) -> Mapping[str, Any]: """ - Return the given task_run's fields by task_run_id, raise EntryDoesNotExistException if no id exists - in task_runs. + Return the given task_run's fields by task_run_id, raise EntryDoesNotExistException + if no id exists in task_runs. Returns a SQLite Row object with the expected fields """ @@ -594,7 +473,7 @@ def _find_task_runs( return all task_runs. """ with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() additional_query, arg_tuple = self.__create_query_and_tuple( ["task_id", "requester_id", "is_completed"], @@ -614,7 +493,7 @@ def _update_task_run(self, task_run_id: str, is_completed: bool): """ Update a task run. 
At the moment, can only update completion status """ - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() try: c.execute( @@ -630,6 +509,7 @@ def _update_task_run(self, task_run_id: str, is_completed: bool): raise EntryDoesNotExistException(e) raise MephistoDBException(e) + @retry_generate_id(caught_excs=[EntryAlreadyExistsException]) def _new_assignment( self, task_id: str, @@ -642,29 +522,39 @@ def _new_assignment( """Create a new assignment for the given task""" # Ensure task run exists self.get_task_run(task_run_id) - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() - c.execute( - """ - INSERT INTO assignments( - task_id, - task_run_id, - requester_id, - task_type, - provider_type, - sandbox - ) VALUES (?, ?, ?, ?, ?, ?);""", - ( - int(task_id), - int(task_run_id), - int(requester_id), - task_type, - provider_type, - sandbox, - ), - ) - assignment_id = str(c.lastrowid) - return assignment_id + try: + c.execute( + """ + INSERT INTO assignments( + assignment_id, + task_id, + task_run_id, + requester_id, + task_type, + provider_type, + sandbox + ) VALUES (?, ?, ?, ?, ?, ?, ?); + """, + ( + make_randomized_int_id(), + int(task_id), + int(task_run_id), + int(requester_id), + task_type, + provider_type, + sandbox, + ), + ) + assignment_id = str(c.lastrowid) + return assignment_id + except sqlite3.IntegrityError as e: + if is_unique_failure(e): + raise EntryAlreadyExistsException( + e, db=self, table_name="assignments", original_exc=e, + ) + raise MephistoDBException(e) def _get_assignment(self, assignment_id: str) -> Mapping[str, Any]: """ @@ -689,7 +579,7 @@ def _find_assignments( return all tasks. """ with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() additional_query, arg_tuple = self.__create_query_and_tuple( [ @@ -721,6 +611,7 @@ def _find_assignments( Assignment(self, str(r["assignment_id"]), row=r, _used_new_call=True) for r in rows ] + @retry_generate_id(caught_excs=[EntryAlreadyExistsException]) def _new_unit( self, task_id: str, @@ -737,11 +628,13 @@ def _new_unit( Create a new unit with the given index. Raises EntryAlreadyExistsException if there is already a unit for the given assignment with the given index. """ - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() try: c.execute( - """INSERT INTO units( + """ + INSERT INTO units( + unit_id, task_id, task_run_id, requester_id, @@ -752,8 +645,10 @@ def _new_unit( task_type, sandbox, status - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);""", + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); + """, ( + make_randomized_int_id(), int(task_id), int(task_run_id), int(requester_id), @@ -772,7 +667,9 @@ def _new_unit( if is_key_failure(e): raise EntryDoesNotExistException(e) elif is_unique_failure(e): - raise EntryAlreadyExistsException(e) + raise EntryAlreadyExistsException( + e, db=self, table_name="units", original_exc=e, + ) raise MephistoDBException(e) def _get_unit(self, unit_id: str) -> Mapping[str, Any]: @@ -803,7 +700,7 @@ def _find_units( return all units. 
""" with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() additional_query, arg_tuple = self.__create_query_and_tuple( [ @@ -848,7 +745,7 @@ def _clear_unit_agent_assignment(self, unit_id: str) -> None: Update the given unit by removing the agent that is assigned to it, thus updating the status to assignable. """ - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() try: c.execute( @@ -870,11 +767,12 @@ def _update_unit( self, unit_id: str, agent_id: Optional[str] = None, status: Optional[str] = None ) -> None: """ - Update the given task with the given parameters if possible, raise appropriate exception otherwise. + Update the given task with the given parameters if possible, + raise appropriate exception otherwise. """ if status not in AssignmentState.valid_unit(): raise MephistoDBException(f"Invalid status {status} for a unit") - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() try: if agent_id is not None: @@ -902,6 +800,7 @@ def _update_unit( ) raise MephistoDBException(e) + @retry_generate_id(caught_excs=[EntryAlreadyExistsException]) def _new_requester(self, requester_name: str, provider_type: str) -> str: """ Create a new requester with the given name and provider type. @@ -911,18 +810,30 @@ def _new_requester(self, requester_name: str, provider_type: str) -> str: if requester_name == "": raise MephistoDBException("Empty string is not a valid requester name") assert_valid_provider(provider_type) - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() try: c.execute( - "INSERT INTO requesters(requester_name, provider_type) VALUES (?, ?);", - (requester_name, provider_type), + """ + INSERT INTO requesters( + requester_id, + requester_name, + provider_type + ) VALUES (?, ?, ?); + """, + ( + make_randomized_int_id(), + requester_name, + provider_type, + ), ) requester_id = str(c.lastrowid) return requester_id except sqlite3.IntegrityError as e: if is_unique_failure(e): - raise EntryAlreadyExistsException() + raise EntryAlreadyExistsException( + e, db=self, table_name="requesters", original_exc=e, + ) raise MephistoDBException(e) def _get_requester(self, requester_id: str) -> Mapping[str, Any]: @@ -942,7 +853,7 @@ def _find_requesters( return all requesters. """ with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() additional_query, arg_tuple = self.__create_query_and_tuple( ["requester_name", "provider_type"], [requester_name, provider_type] @@ -959,6 +870,7 @@ def _find_requesters( Requester(self, str(r["requester_id"]), row=r, _used_new_call=True) for r in rows ] + @retry_generate_id(caught_excs=[EntryAlreadyExistsException]) def _new_worker(self, worker_name: str, provider_type: str) -> str: """ Create a new worker with the given name and provider type. 
@@ -971,18 +883,30 @@ def _new_worker(self, worker_name: str, provider_type: str) -> str: if worker_name == "": raise MephistoDBException("Empty string is not a valid requester name") assert_valid_provider(provider_type) - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() try: c.execute( - "INSERT INTO workers(worker_name, provider_type) VALUES (?, ?);", - (worker_name, provider_type), + """ + INSERT INTO workers( + worker_id, + worker_name, + provider_type + ) VALUES (?, ?, ?); + """, + ( + make_randomized_int_id(), + worker_name, + provider_type, + ), ) worker_id = str(c.lastrowid) return worker_id except sqlite3.IntegrityError as e: if is_unique_failure(e): - raise EntryAlreadyExistsException() + raise EntryAlreadyExistsException( + e, db=self, table_name="workers", original_exc=e, + ) raise MephistoDBException(e) def _get_worker(self, worker_id: str) -> Mapping[str, Any]: @@ -1002,7 +926,7 @@ def _find_workers( return all workers. """ with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() additional_query, arg_tuple = self.__create_query_and_tuple( ["worker_name", "provider_type"], [worker_name, provider_type] @@ -1017,6 +941,7 @@ def _find_workers( rows = c.fetchall() return [Worker(self, str(r["worker_id"]), row=r, _used_new_call=True) for r in rows] + @retry_generate_id(caught_excs=[EntryAlreadyExistsException]) def _new_agent( self, worker_id: str, @@ -1029,15 +954,16 @@ def _new_agent( ) -> str: """ Create a new agent with the given name and provider type. - Raises EntryAlreadyExistsException - if there is already a agent with this name + Raises EntryAlreadyExistsException if there is already an agent with this name """ assert_valid_provider(provider_type) - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() try: c.execute( - """INSERT INTO agents( + """ + INSERT INTO agents( + agent_id, worker_id, unit_id, task_id, @@ -1046,8 +972,10 @@ def _new_agent( task_type, provider_type, status - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?);""", + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?); + """, ( + make_randomized_int_id(), int(worker_id), int(unit_id), int(task_id), @@ -1076,12 +1004,16 @@ def _new_agent( except sqlite3.IntegrityError as e: if is_key_failure(e): raise EntryDoesNotExistException(e) + elif is_unique_failure(e): + raise EntryAlreadyExistsException( + e, db=self, table_name="agents", original_exc=e, + ) raise MephistoDBException(e) def _get_agent(self, agent_id: str) -> Mapping[str, Any]: """ Return agent's fields by agent_id, raise EntryDoesNotExistException - if no id exists in agents + if no id exists in agents. Returns a SQLite Row object with the expected fields """ @@ -1089,12 +1021,13 @@ def _get_agent(self, agent_id: str) -> Mapping[str, Any]: def _update_agent(self, agent_id: str, status: Optional[str] = None) -> None: """ - Update the given task with the given parameters if possible, raise appropriate exception otherwise. + Update the given task with the given parameters if possible, + raise appropriate exception otherwise. 
""" if status not in AgentState.valid(): raise MephistoDBException(f"Invalid status {status} for an agent") - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() c.execute( """ @@ -1121,7 +1054,7 @@ def _find_agents( return all agents. """ with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() additional_query, arg_tuple = self.__create_query_and_tuple( [ @@ -1162,7 +1095,7 @@ def _make_qualification(self, qualification_name: str) -> str: """ if qualification_name == "": raise MephistoDBException("Empty string is not a valid qualification name") - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() try: c.execute( @@ -1173,7 +1106,9 @@ def _make_qualification(self, qualification_name: str) -> str: return qualification_id except sqlite3.IntegrityError as e: if is_unique_failure(e): - raise EntryAlreadyExistsException() + raise EntryAlreadyExistsException( + e, db=self, table_name="units", original_exc=e, + ) raise MephistoDBException(e) def _find_qualifications(self, qualification_name: Optional[str] = None) -> List[Qualification]: @@ -1181,7 +1116,7 @@ def _find_qualifications(self, qualification_name: Optional[str] = None) -> List Find a qualification. If no name is supplied, returns all qualifications. """ with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() additional_query, arg_tuple = self.__create_query_and_tuple( ["qualification_name"], [qualification_name] @@ -1216,7 +1151,7 @@ def _delete_qualification(self, qualification_name: str) -> None: if len(qualifications) == 0: raise EntryDoesNotExistException(f"No qualification found by name {qualification_name}") qualification = qualifications[0] - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() c.execute( "DELETE FROM granted_qualifications WHERE qualification_id = ?1;", @@ -1236,7 +1171,7 @@ def _grant_qualification(self, qualification_id: str, worker_id: str, value: int try: # Update existing entry qual_row = self.get_granted_qualification(qualification_id, worker_id) - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: if value != qual_row["value"]: c = conn.cursor() c.execute( @@ -1251,7 +1186,7 @@ def _grant_qualification(self, qualification_id: str, worker_id: str, value: int conn.commit() return None except EntryDoesNotExistException: - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() try: c.execute( @@ -1264,12 +1199,13 @@ def _grant_qualification(self, qualification_id: str, worker_id: str, value: int """, (int(qualification_id), int(worker_id), value), ) - qualification_id = str(c.lastrowid) conn.commit() return None except sqlite3.IntegrityError as e: if is_unique_failure(e): - raise EntryAlreadyExistsException() + raise EntryAlreadyExistsException( + e, db=self, table_name="units", original_exc=e, + ) raise MephistoDBException(e) def _check_granted_qualifications( @@ -1282,7 +1218,7 @@ def _check_granted_qualifications( Find granted qualifications that match the given specifications """ with self.table_access_condition: - conn = 
self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ @@ -1314,7 +1250,7 @@ def _get_granted_qualification( See GrantedQualification for the expected fields for the returned mapping """ with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( f""" @@ -1335,16 +1271,18 @@ def _revoke_qualification(self, qualification_id: str, worker_id: str) -> None: """ Remove the given qualification from the given worker """ - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() c.execute( - """DELETE FROM granted_qualifications + """ + DELETE FROM granted_qualifications WHERE (qualification_id = ?1) AND (worker_id = ?2); """, (int(qualification_id), int(worker_id)), ) + @retry_generate_id(caught_excs=[EntryAlreadyExistsException]) def _new_onboarding_agent( self, worker_id: str, task_id: str, task_run_id: str, task_type: str ) -> str: @@ -1352,18 +1290,22 @@ def _new_onboarding_agent( Create a new agent for the given worker id to assign to the given unit Raises EntryAlreadyExistsException """ - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() try: c.execute( - """INSERT INTO onboarding_agents( + """ + INSERT INTO onboarding_agents( + onboarding_agent_id, worker_id, task_id, task_run_id, task_type, status - ) VALUES (?, ?, ?, ?, ?);""", + ) VALUES (?, ?, ?, ?, ?, ?); + """, ( + make_randomized_int_id(), int(worker_id), int(task_id), int(task_run_id), @@ -1375,6 +1317,10 @@ def _new_onboarding_agent( except sqlite3.IntegrityError as e: if is_key_failure(e): raise EntryDoesNotExistException(e) + elif is_unique_failure(e): + raise EntryAlreadyExistsException( + e, db=self, table_name="onboarding_agents", original_exc=e, + ) raise MephistoDBException(e) def _get_onboarding_agent(self, onboarding_agent_id: str) -> Mapping[str, Any]: @@ -1395,7 +1341,7 @@ def _update_onboarding_agent( """ if status not in AgentState.valid(): raise MephistoDBException(f"Invalid status {status} for an agent") - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() if status is not None: c.execute( @@ -1420,7 +1366,7 @@ def _find_onboarding_agents( return all onboarding agents. 
""" with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() additional_query, arg_tuple = self.__create_query_and_tuple( [ @@ -1451,6 +1397,7 @@ def _find_onboarding_agents( for r in rows ] + @retry_generate_id(caught_excs=[EntryAlreadyExistsException]) def _new_unit_review( self, unit_id: Union[int, str], @@ -1463,29 +1410,38 @@ def _new_unit_review( """Create unit review""" with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() - c.execute( - """ - INSERT INTO unit_review ( - unit_id, - worker_id, - task_id, - status, - review_note, - bonus - ) VALUES (?, ?, ?, ?, ?, ?); - """, - ( - nonesafe_int(unit_id), - nonesafe_int(worker_id), - nonesafe_int(task_id), - status, - review_note, - bonus, - ), - ) - conn.commit() + try: + c.execute( + """ + INSERT INTO unit_review ( + id, + unit_id, + worker_id, + task_id, + status, + review_note, + bonus + ) VALUES (?, ?, ?, ?, ?, ?, ?); + """, + ( + make_randomized_int_id(), + nonesafe_int(unit_id), + nonesafe_int(worker_id), + nonesafe_int(task_id), + status, + review_note, + bonus, + ), + ) + conn.commit() + except sqlite3.IntegrityError as e: + if is_unique_failure(e): + raise EntryAlreadyExistsException( + e, db=self, table_name="unit_review", original_exc=e, + ) + raise MephistoDBException(e) def _update_unit_review( self, @@ -1500,14 +1456,14 @@ def _update_unit_review( raise appropriate exception otherwise. """ with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ SELECT * FROM unit_review WHERE (unit_id = ?) AND (worker_id = ?) - ORDER BY created_at ASC; + ORDER BY creation_date ASC; """, (unit_id, worker_id), ) diff --git a/mephisto/abstractions/databases/local_database_tables.py b/mephisto/abstractions/databases/local_database_tables.py new file mode 100644 index 000000000..ceeccb9ea --- /dev/null +++ b/mephisto/abstractions/databases/local_database_tables.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +WARNING: In this module you can find initial table structures, but not final. +There are can be changes in migrations. 
To see actual fields, constraints, etc., +see information in databases or look through all migrations for current database +""" + +CREATE_IF_NOT_EXISTS_PROJECTS_TABLE = """ + CREATE TABLE IF NOT EXISTS projects ( + project_id INTEGER PRIMARY KEY AUTOINCREMENT, + project_name TEXT NOT NULL UNIQUE, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + ); +""" + +CREATE_IF_NOT_EXISTS_TASKS_TABLE = """ + CREATE TABLE IF NOT EXISTS tasks ( + task_id INTEGER PRIMARY KEY AUTOINCREMENT, + task_name TEXT NOT NULL UNIQUE, + task_type TEXT NOT NULL, + project_id INTEGER, + parent_task_id INTEGER, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (parent_task_id) REFERENCES tasks (task_id), + FOREIGN KEY (project_id) REFERENCES projects (project_id) + ); +""" + +CREATE_IF_NOT_EXISTS_REQUESTERS_TABLE = """ + CREATE TABLE IF NOT EXISTS requesters ( + requester_id INTEGER PRIMARY KEY AUTOINCREMENT, + requester_name TEXT NOT NULL UNIQUE, + provider_type TEXT NOT NULL, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + ); +""" + +CREATE_IF_NOT_EXISTS_TASK_RUNS_TABLE = """ + CREATE TABLE IF NOT EXISTS task_runs ( + task_run_id INTEGER PRIMARY KEY AUTOINCREMENT, + task_id INTEGER NOT NULL, + requester_id INTEGER NOT NULL, + init_params TEXT NOT NULL, + is_completed BOOLEAN NOT NULL, + provider_type TEXT NOT NULL, + task_type TEXT NOT NULL, + sandbox BOOLEAN NOT NULL, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (task_id) REFERENCES tasks (task_id), + FOREIGN KEY (requester_id) REFERENCES requesters (requester_id) + ); +""" + +CREATE_IF_NOT_EXISTS_ASSIGNMENTS_TABLE = """ + CREATE TABLE IF NOT EXISTS assignments ( + assignment_id INTEGER PRIMARY KEY AUTOINCREMENT, + task_id INTEGER NOT NULL, + task_run_id INTEGER NOT NULL, + requester_id INTEGER NOT NULL, + task_type TEXT NOT NULL, + provider_type TEXT NOT NULL, + sandbox BOOLEAN NOT NULL, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (task_id) REFERENCES tasks (task_id), + FOREIGN KEY (task_run_id) REFERENCES task_runs (task_run_id), + FOREIGN KEY (requester_id) REFERENCES requesters (requester_id) + ); +""" + +CREATE_IF_NOT_EXISTS_UNITS_TABLE = """ + CREATE TABLE IF NOT EXISTS units ( + unit_id INTEGER PRIMARY KEY AUTOINCREMENT, + assignment_id INTEGER NOT NULL, + unit_index INTEGER NOT NULL, + pay_amount FLOAT NOT NULL, + provider_type TEXT NOT NULL, + status TEXT NOT NULL, + agent_id INTEGER, + worker_id INTEGER, + task_type TEXT NOT NULL, + task_id INTEGER NOT NULL, + task_run_id INTEGER NOT NULL, + sandbox BOOLEAN NOT NULL, + requester_id INTEGER NOT NULL, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (assignment_id) REFERENCES assignments (assignment_id), + FOREIGN KEY (agent_id) REFERENCES agents (agent_id), + FOREIGN KEY (task_run_id) REFERENCES task_runs (task_run_id), + FOREIGN KEY (task_id) REFERENCES tasks (task_id), + FOREIGN KEY (requester_id) REFERENCES requesters (requester_id), + FOREIGN KEY (worker_id) REFERENCES workers (worker_id), + UNIQUE (assignment_id, unit_index) + ); +""" + +CREATE_IF_NOT_EXISTS_WORKERS_TABLE = """ + CREATE TABLE IF NOT EXISTS workers ( + worker_id INTEGER PRIMARY KEY AUTOINCREMENT, + worker_name TEXT NOT NULL UNIQUE, + provider_type TEXT NOT NULL, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + ); +""" + +CREATE_IF_NOT_EXISTS_AGENTS_TABLE = """ + CREATE TABLE IF NOT EXISTS agents ( + agent_id INTEGER PRIMARY KEY AUTOINCREMENT, + worker_id INTEGER NOT NULL, + unit_id INTEGER NOT NULL, + task_id INTEGER NOT NULL, + task_run_id INTEGER 
NOT NULL, + assignment_id INTEGER NOT NULL, + task_type TEXT NOT NULL, + provider_type TEXT NOT NULL, + status TEXT NOT NULL, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (worker_id) REFERENCES workers (worker_id), + FOREIGN KEY (unit_id) REFERENCES units (unit_id) + ); +""" + +CREATE_IF_NOT_EXISTS_ONBOARDING_AGENTS_TABLE = """ + CREATE TABLE IF NOT EXISTS onboarding_agents ( + onboarding_agent_id INTEGER PRIMARY KEY AUTOINCREMENT, + worker_id INTEGER NOT NULL, + task_id INTEGER NOT NULL, + task_run_id INTEGER NOT NULL, + task_type TEXT NOT NULL, + status TEXT NOT NULL, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (worker_id) REFERENCES workers (worker_id), + FOREIGN KEY (task_run_id) REFERENCES task_runs (task_run_id) + ); +""" + +CREATE_IF_NOT_EXISTS_QUALIFICATIONS_TABLE = """ + CREATE TABLE IF NOT EXISTS qualifications ( + qualification_id INTEGER PRIMARY KEY AUTOINCREMENT, + qualification_name TEXT NOT NULL UNIQUE, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + ); +""" + +CREATE_IF_NOT_EXISTS_GRANTED_QUALIFICATIONS_TABLE = """ + CREATE TABLE IF NOT EXISTS granted_qualifications ( + granted_qualification_id INTEGER PRIMARY KEY AUTOINCREMENT, + worker_id INTEGER NOT NULL, + qualification_id INTEGER NOT NULL, + value INTEGER NOT NULL, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (worker_id) REFERENCES workers (worker_id), + FOREIGN KEY (qualification_id) REFERENCES qualifications (qualification_id), + UNIQUE (worker_id, qualification_id) + ); +""" + +CREATE_IF_NOT_EXISTS_UNIT_REVIEW_TABLE = """ + CREATE TABLE IF NOT EXISTS unit_review ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + unit_id INTEGER NOT NULL, + worker_id INTEGER NOT NULL, + task_id INTEGER NOT NULL, + status TEXT NOT NULL, + review_note TEXT, + bonus INTEGER, + blocked_worker BOOLEAN DEFAULT false, + /* ID of `db.qualifications` (not `db.granted_qualifications`) */ + updated_qualification_id INTEGER, + updated_qualification_value INTEGER, + /* ID of `db.qualifications` (not `db.granted_qualifications`) */ + revoked_qualification_id INTEGER, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + + FOREIGN KEY (unit_id) REFERENCES units (unit_id), + FOREIGN KEY (worker_id) REFERENCES workers (worker_id), + FOREIGN KEY (task_id) REFERENCES tasks (task_id) + ); +""" + +CREATE_IF_NOT_EXISTS_IMPORT_DATA_TABLE = """ + CREATE TABLE IF NOT EXISTS imported_data ( + id INTEGER PRIMARY KEY, + source_file_name TEXT NOT NULL, + data_labels TEXT NOT NULL, + table_name TEXT NOT NULL, + unique_field_names TEXT NOT NULL, /* JSON */ + unique_field_values TEXT NOT NULL, /* JSON */ + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + ); +""" + +# WARNING: Changing this table, be careful, it will affect all datastores too +CREATE_IF_NOT_EXISTS_MIGRATIONS_TABLE = """ + CREATE TABLE IF NOT EXISTS migrations ( + id INTEGER PRIMARY KEY AUTOINCREMENT , + name TEXT NOT NULL, + status TEXT NOT NULL, + error_message TEXT NULL, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + ); +""" + +# Indices that are used by system-specific calls across Mephisto during live tasks +# that improve the runtime of the system as a whole +CREATE_IF_NOT_EXISTS_CORE_INDICES = """ + CREATE INDEX IF NOT EXISTS requesters_by_provider_index ON requesters(provider_type); + CREATE INDEX IF NOT EXISTS unit_by_status_index ON units(status); + CREATE INDEX IF NOT EXISTS unit_by_assignment_id_index ON units(assignment_id); + CREATE INDEX IF NOT EXISTS unit_by_task_run_index ON units(task_run_id); + CREATE INDEX IF NOT 
EXISTS unit_by_task_run_by_worker_by_status_index ON units(task_run_id, worker_id, status); + CREATE INDEX IF NOT EXISTS unit_by_task_by_worker_index ON units(task_id, worker_id); + CREATE INDEX IF NOT EXISTS agent_by_worker_by_status_index ON agents(worker_id, status); + CREATE INDEX IF NOT EXISTS agent_by_task_run_index ON agents(task_run_id); + CREATE INDEX IF NOT EXISTS assignment_by_task_run_index ON assignments(task_run_id); + CREATE INDEX IF NOT EXISTS task_run_by_requester_index ON task_runs(requester_id); + CREATE INDEX IF NOT EXISTS task_run_by_task_index ON task_runs(task_id); + CREATE INDEX IF NOT EXISTS unit_review_by_unit_index ON unit_review(unit_id); +""" # noqa: E501 diff --git a/mephisto/abstractions/databases/local_singleton_database.py b/mephisto/abstractions/databases/local_singleton_database.py index 27728cacf..5f113b91e 100644 --- a/mephisto/abstractions/databases/local_singleton_database.py +++ b/mephisto/abstractions/databases/local_singleton_database.py @@ -4,39 +4,34 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from mephisto.abstractions.database import ( - MephistoDB, - MephistoDBException, - EntryAlreadyExistsException, - EntryDoesNotExistException, -) +import threading +from typing import Any +from typing import Dict +from typing import List +from typing import Mapping +from typing import Optional + from mephisto.abstractions.databases.local_database import LocalMephistoDB -from typing import Mapping, Optional, Any, List, Dict -from mephisto.utils.dirs import get_data_dir -from mephisto.operations.registry import get_valid_provider_types -from mephisto.data_model.agent import Agent, AgentState, OnboardingAgent -from mephisto.data_model.unit import Unit -from mephisto.data_model.assignment import Assignment, AssignmentState -from mephisto.data_model.constants import NO_PROJECT_NAME +from mephisto.data_model.agent import Agent +from mephisto.data_model.agent import OnboardingAgent +from mephisto.data_model.assignment import Assignment +from mephisto.data_model.assignment import AssignmentState from mephisto.data_model.project import Project +from mephisto.data_model.qualification import Qualification from mephisto.data_model.requester import Requester from mephisto.data_model.task import Task from mephisto.data_model.task_run import TaskRun +from mephisto.data_model.unit import Unit from mephisto.data_model.worker import Worker -from mephisto.data_model.qualification import Qualification, GrantedQualification - -import sqlite3 -from sqlite3 import Connection, Cursor -import threading +from mephisto.utils.logger_core import get_logger # We should be using WeakValueDictionary rather than a full dict once # we're better able to trade-off between memory and space. # from weakref import WeakValueDictionary -from mephisto.utils.logger_core import get_logger - logger = get_logger(name=__name__) + # Note: This class could be a generic factory around any MephistoDB, converting # the system to a singleton implementation. It requires all of the data being # updated locally though, so binding to LocalMephistoDB makes sense for now. 
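The migration pattern used by the files that follow is compact enough to sketch up front. Below is a minimal illustrative example (not part of this patch) of how a raw-SQL migration constant, the `migrations` registry dict, and `apply_migrations` are meant to fit together; `apply_migrations` and the registry shape are taken from this diff, while `MyDatastore` and the example SQL are hypothetical:

    from mephisto.utils.db import apply_migrations

    # Each migration module exposes one raw-SQL constant.
    ADD_EXAMPLE_COLUMN = """
        ALTER TABLE examples ADD COLUMN creation_date DATETIME DEFAULT CURRENT_TIMESTAMP;
    """

    # The constant is registered under a readable key; that key is what ends up
    # in the `name` column of the shared `migrations` table, next to `status`
    # and, on failure, `error_message`.
    migrations = {
        "20240101_add_example_column": ADD_EXAMPLE_COLUMN,
    }

    class MyDatastore:
        def init_tables(self) -> None:
            ...  # create tables first
            apply_migrations(self, migrations)
            ...  # indices and the rest of setup

This mirrors the `apply_migrations(self, migrations)` call added to `ProlificDatastore.init_tables()` later in this patch.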
diff --git a/mephisto/abstractions/databases/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py b/mephisto/abstractions/databases/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py new file mode 100644 index 000000000..0f8094fe2 --- /dev/null +++ b/mephisto/abstractions/databases/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +1. Rename `unit_review.created_at` -> `unit_review.creation_date` +2. Remove autoincrement parameter for all Primary Keys +3. Add missed Foreign Keys in `agents` table +4. Add `granted_qualifications.update_date` +""" + + +PREPARING_DB_FOR_MERGE_DBS_COMMAND = """ + ALTER TABLE unit_review RENAME COLUMN created_at TO creation_date; + + /* Disable FK constraints */ + PRAGMA foreign_keys = off; + + + /* Projects */ + CREATE TABLE IF NOT EXISTS _projects ( + project_id INTEGER PRIMARY KEY, + project_name TEXT NOT NULL UNIQUE, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + ); + INSERT INTO _projects SELECT * FROM projects; + DROP TABLE projects; + ALTER TABLE _projects RENAME TO projects; + + + /* Tasks */ + CREATE TABLE IF NOT EXISTS _tasks ( + task_id INTEGER PRIMARY KEY, + task_name TEXT NOT NULL UNIQUE, + task_type TEXT NOT NULL, + project_id INTEGER, + parent_task_id INTEGER, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (parent_task_id) REFERENCES tasks (task_id), + FOREIGN KEY (project_id) REFERENCES projects (project_id) + ); + INSERT INTO _tasks SELECT * FROM tasks; + DROP TABLE tasks; + ALTER TABLE _tasks RENAME TO tasks; + + + /* Requesters */ + CREATE TABLE IF NOT EXISTS _requesters ( + requester_id INTEGER PRIMARY KEY, + requester_name TEXT NOT NULL UNIQUE, + provider_type TEXT NOT NULL, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + ); + INSERT INTO _requesters SELECT * FROM requesters; + DROP TABLE requesters; + ALTER TABLE _requesters RENAME TO requesters; + + + /* Task Runs */ + CREATE TABLE IF NOT EXISTS _task_runs ( + task_run_id INTEGER PRIMARY KEY, + task_id INTEGER NOT NULL, + requester_id INTEGER NOT NULL, + init_params TEXT NOT NULL, + is_completed BOOLEAN NOT NULL, + provider_type TEXT NOT NULL, + task_type TEXT NOT NULL, + sandbox BOOLEAN NOT NULL, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (task_id) REFERENCES tasks (task_id), + FOREIGN KEY (requester_id) REFERENCES requesters (requester_id) + ); + INSERT INTO _task_runs SELECT * FROM task_runs; + DROP TABLE task_runs; + ALTER TABLE _task_runs RENAME TO task_runs; + + + /* Assignments */ + CREATE TABLE IF NOT EXISTS _assignments ( + assignment_id INTEGER PRIMARY KEY, + task_id INTEGER NOT NULL, + task_run_id INTEGER NOT NULL, + requester_id INTEGER NOT NULL, + task_type TEXT NOT NULL, + provider_type TEXT NOT NULL, + sandbox BOOLEAN NOT NULL, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (task_id) REFERENCES tasks (task_id), + FOREIGN KEY (task_run_id) REFERENCES task_runs (task_run_id), + FOREIGN KEY (requester_id) REFERENCES requesters (requester_id) + ); + INSERT INTO _assignments SELECT * FROM assignments; + DROP TABLE assignments; + ALTER TABLE _assignments RENAME TO assignments; + + + /* Units */ + CREATE TABLE IF NOT EXISTS _units ( + unit_id INTEGER PRIMARY KEY, + assignment_id INTEGER NOT NULL, + unit_index INTEGER NOT NULL, + pay_amount FLOAT NOT 
NULL, + provider_type TEXT NOT NULL, + status TEXT NOT NULL, + agent_id INTEGER, + worker_id INTEGER, + task_type TEXT NOT NULL, + task_id INTEGER NOT NULL, + task_run_id INTEGER NOT NULL, + sandbox BOOLEAN NOT NULL, + requester_id INTEGER NOT NULL, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (assignment_id) REFERENCES assignments (assignment_id), + FOREIGN KEY (agent_id) REFERENCES agents (agent_id), + FOREIGN KEY (task_run_id) REFERENCES task_runs (task_run_id), + FOREIGN KEY (task_id) REFERENCES tasks (task_id), + FOREIGN KEY (requester_id) REFERENCES requesters (requester_id), + FOREIGN KEY (worker_id) REFERENCES workers (worker_id), + UNIQUE (assignment_id, unit_index) + ); + INSERT INTO _units SELECT * FROM units; + DROP TABLE units; + ALTER TABLE _units RENAME TO units; + + + /* Workers */ + CREATE TABLE IF NOT EXISTS _workers ( + worker_id INTEGER PRIMARY KEY, + worker_name TEXT NOT NULL UNIQUE, + provider_type TEXT NOT NULL, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + ); + INSERT INTO _workers SELECT * FROM workers; + DROP TABLE workers; + ALTER TABLE _workers RENAME TO workers; + + + /* Agents */ + CREATE TABLE IF NOT EXISTS _agents ( + agent_id INTEGER PRIMARY KEY, + worker_id INTEGER NOT NULL, + unit_id INTEGER NOT NULL, + task_id INTEGER NOT NULL, + task_run_id INTEGER NOT NULL, + assignment_id INTEGER NOT NULL, + task_type TEXT NOT NULL, + provider_type TEXT NOT NULL, + status TEXT NOT NULL, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (worker_id) REFERENCES workers (worker_id), + FOREIGN KEY (unit_id) REFERENCES units (unit_id), + FOREIGN KEY (task_id) REFERENCES tasks (task_id) ON DELETE NO ACTION, + FOREIGN KEY (task_run_id) REFERENCES task_runs (task_run_id) ON DELETE NO ACTION, + FOREIGN KEY (assignment_id) REFERENCES assignments (assignment_id) ON DELETE NO ACTION + ); + INSERT INTO _agents SELECT * FROM agents; + DROP TABLE agents; + ALTER TABLE _agents RENAME TO agents; + + + /* Onboarding Agents */ + CREATE TABLE IF NOT EXISTS _onboarding_agents ( + onboarding_agent_id INTEGER PRIMARY KEY, + worker_id INTEGER NOT NULL, + task_id INTEGER NOT NULL, + task_run_id INTEGER NOT NULL, + task_type TEXT NOT NULL, + status TEXT NOT NULL, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (worker_id) REFERENCES workers (worker_id), + FOREIGN KEY (task_run_id) REFERENCES task_runs (task_run_id) + ); + INSERT INTO _onboarding_agents SELECT * FROM onboarding_agents; + DROP TABLE onboarding_agents; + ALTER TABLE _onboarding_agents RENAME TO onboarding_agents; + + + /* Qualifications */ + CREATE TABLE IF NOT EXISTS _qualifications ( + qualification_id INTEGER PRIMARY KEY, + qualification_name TEXT NOT NULL UNIQUE, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + ); + INSERT INTO _qualifications SELECT * FROM qualifications; + DROP TABLE qualifications; + ALTER TABLE _qualifications RENAME TO qualifications; + + + /* Granted Qualifications */ + CREATE TABLE IF NOT EXISTS _granted_qualifications ( + granted_qualification_id INTEGER PRIMARY KEY, + worker_id INTEGER NOT NULL, + qualification_id INTEGER NOT NULL, + value INTEGER NOT NULL, + update_date DATETIME DEFAULT CURRENT_TIMESTAMP, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (worker_id) REFERENCES workers (worker_id), + FOREIGN KEY (qualification_id) REFERENCES qualifications (qualification_id), + UNIQUE (worker_id, qualification_id) + ); + /* Copy data from backed up table and set value from `creation_date` to `update_date` */ + INSERT 
INTO _granted_qualifications + SELECT + granted_qualification_id, + worker_id, + qualification_id, + value, + creation_date, + creation_date + FROM granted_qualifications; + DROP TABLE granted_qualifications; + ALTER TABLE _granted_qualifications RENAME TO granted_qualifications; + + + /* Unit Review */ + CREATE TABLE IF NOT EXISTS _unit_review ( + id INTEGER PRIMARY KEY, + unit_id INTEGER NOT NULL, + worker_id INTEGER NOT NULL, + task_id INTEGER NOT NULL, + status TEXT NOT NULL, + review_note TEXT, + bonus INTEGER, + blocked_worker BOOLEAN DEFAULT false, + /* ID of `db.qualifications` (not `db.granted_qualifications`) */ + updated_qualification_id INTEGER, + updated_qualification_value INTEGER, + /* ID of `db.qualifications` (not `db.granted_qualifications`) */ + revoked_qualification_id INTEGER, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + + FOREIGN KEY (unit_id) REFERENCES units (unit_id), + FOREIGN KEY (worker_id) REFERENCES workers (worker_id), + FOREIGN KEY (task_id) REFERENCES tasks (task_id) + ); + INSERT INTO _unit_review SELECT * FROM unit_review; + DROP TABLE unit_review; + ALTER TABLE _unit_review RENAME TO unit_review; + + + /* Enable FK constraints back */ + PRAGMA foreign_keys = on; +""" diff --git a/mephisto/abstractions/databases/migrations/__init__.py b/mephisto/abstractions/databases/migrations/__init__.py new file mode 100644 index 000000000..092965e1b --- /dev/null +++ b/mephisto/abstractions/databases/migrations/__init__.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from ._001_20240325_preparing_db_for_merge_dbs_command import * + + +migrations = { + "20240418_preparing_db_for_merge_dbs_command": PREPARING_DB_FOR_MERGE_DBS_COMMAND, +} diff --git a/mephisto/abstractions/providers/mock/mock_datastore.py b/mephisto/abstractions/providers/mock/mock_datastore.py index 5737d8d22..da17a832d 100644 --- a/mephisto/abstractions/providers/mock/mock_datastore.py +++ b/mephisto/abstractions/providers/mock/mock_datastore.py @@ -4,34 +4,18 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import sqlite3 import os +import sqlite3 import threading +from typing import Any +from typing import Dict -from datetime import datetime - -from typing import Dict, Any, Optional +from mephisto.utils.db import check_if_row_with_params_exists +from . import mock_datastore_tables as tables +from .mock_datastore_export import export_datastore MTURK_REGION_NAME = "us-east-1" -CREATE_REQUESTERS_TABLE = """CREATE TABLE IF NOT EXISTS requesters ( - requester_id TEXT PRIMARY KEY UNIQUE, - is_registered BOOLEAN -); -""" - -CREATE_UNITS_TABLE = """CREATE TABLE IF NOT EXISTS units ( - unit_id TEXT PRIMARY KEY UNIQUE, - is_expired BOOLEAN -); -""" - -CREATE_WORKERS_TABLE = """CREATE TABLE IF NOT EXISTS workers ( - worker_id TEXT PRIMARY KEY UNIQUE, - is_blocked BOOLEAN -); -""" - class MockDatastore: """ @@ -48,8 +32,9 @@ def __init__(self, datastore_root: str): self.init_tables() self.datastore_root = datastore_root - def _get_connection(self) -> sqlite3.Connection: - """Returns a singular database connection to be shared amongst all + def get_connection(self) -> sqlite3.Connection: + """ + Returns a singular database connection to be shared amongst all calls for a given thread. 
""" curr_thread = threading.get_ident() @@ -64,37 +49,58 @@ def init_tables(self) -> None: Run all the table creation SQL queries to ensure the expected tables exist """ with self.table_access_condition: - conn = self._get_connection() - conn.execute("PRAGMA foreign_keys = 1") - c = conn.cursor() - c.execute(CREATE_REQUESTERS_TABLE) - c.execute(CREATE_UNITS_TABLE) - c.execute(CREATE_WORKERS_TABLE) - conn.commit() + conn = self.get_connection() + conn.execute("PRAGMA foreign_keys = on;") + + with conn: + c = conn.cursor() + c.execute(tables.CREATE_IF_NOT_EXISTS_REQUESTERS_TABLE) + c.execute(tables.CREATE_IF_NOT_EXISTS_UNITS_TABLE) + c.execute(tables.CREATE_IF_NOT_EXISTS_WORKERS_TABLE) + c.execute(tables.CREATE_IF_NOT_EXISTS_MIGRATIONS_TABLE) + + def get_export_data(self, **kwargs) -> dict: + return export_datastore(self, **kwargs) def ensure_requester_exists(self, requester_id: str) -> None: """Create a record of this requester if it doesn't exist""" + already_exists = check_if_row_with_params_exists( + db=self, + table_name="requesters", + params={ + "requester_id": requester_id, + "is_registered": False, + }, + select_field="requester_id", + ) + with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() - c.execute( - """INSERT OR IGNORE INTO requesters( - requester_id, - is_registered - ) VALUES (?, ?);""", - (requester_id, False), - ) - conn.commit() + + if not already_exists: + c.execute( + """ + INSERT INTO requesters( + requester_id, + is_registered + ) VALUES (?, ?); + """, + (requester_id, False), + ) + conn.commit() + return None def set_requester_registered(self, requester_id: str, val: bool) -> None: """Set the requester registration status for the given id""" self.ensure_requester_exists(requester_id) with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( - """UPDATE requesters + """ + UPDATE requesters SET is_registered = ? WHERE requester_id = ? """, @@ -107,7 +113,7 @@ def get_requester_registered(self, requester_id: str) -> bool: """Get the registration status of a requester""" self.ensure_requester_exists(requester_id) with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ @@ -121,27 +127,43 @@ def get_requester_registered(self, requester_id: str) -> bool: def ensure_worker_exists(self, worker_id: str) -> None: """Create a record of this worker if it doesn't exist""" + already_exists = check_if_row_with_params_exists( + db=self, + table_name="workers", + params={ + "worker_id": worker_id, + "is_blocked": False, + }, + select_field="worker_id", + ) + with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() - c.execute( - """INSERT OR IGNORE INTO workers( - worker_id, - is_blocked - ) VALUES (?, ?);""", - (worker_id, False), - ) - conn.commit() + + if not already_exists: + c.execute( + """ + INSERT INTO workers( + worker_id, + is_blocked + ) VALUES (?, ?); + """, + (worker_id, False), + ) + conn.commit() + return None def set_worker_blocked(self, worker_id: str, val: bool) -> None: """Set the worker registration status for the given id""" self.ensure_worker_exists(worker_id) with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( - """UPDATE workers + """ + UPDATE workers SET is_blocked = ? WHERE worker_id = ? 
""", @@ -154,7 +176,7 @@ def get_worker_blocked(self, worker_id: str) -> bool: """Get the registration status of a worker""" self.ensure_worker_exists(worker_id) with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ @@ -168,27 +190,43 @@ def get_worker_blocked(self, worker_id: str) -> bool: def ensure_unit_exists(self, unit_id: str) -> None: """Create a record of this unit if it doesn't exist""" + already_exists = check_if_row_with_params_exists( + db=self, + table_name="units", + params={ + "unit_id": unit_id, + "is_expired": False, + }, + select_field="unit_id", + ) + with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() - c.execute( - """INSERT OR IGNORE INTO units( - unit_id, - is_expired - ) VALUES (?, ?);""", - (unit_id, False), - ) - conn.commit() + + if not already_exists: + c.execute( + """ + INSERT INTO units( + unit_id, + is_expired + ) VALUES (?, ?); + """, + (unit_id, False), + ) + conn.commit() + return None def set_unit_expired(self, unit_id: str, val: bool) -> None: """Set the unit registration status for the given id""" self.ensure_unit_exists(unit_id) with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( - """UPDATE units + """ + UPDATE units SET is_expired = ? WHERE unit_id = ? """, @@ -201,7 +239,7 @@ def get_unit_expired(self, unit_id: str) -> bool: """Get the registration status of a unit""" self.ensure_unit_exists(unit_id) with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ diff --git a/mephisto/abstractions/providers/mock/mock_datastore_export.py b/mephisto/abstractions/providers/mock/mock_datastore_export.py new file mode 100644 index 000000000..f3fc14ecd --- /dev/null +++ b/mephisto/abstractions/providers/mock/mock_datastore_export.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+
+from typing import List
+from typing import Optional
+
+from mephisto.utils import db as db_utils
+
+
+def export_datastore(
+    datastore: "MockDatastore",
+    mephisto_db_data: dict,
+    task_run_ids: Optional[List[str]] = None,
+    **kwargs,
+) -> dict:
+    """Logic of collecting export data from Mock datastore"""
+
+    dump_data = db_utils.db_or_datastore_to_dict(datastore)
+
+    if not task_run_ids:
+        # Exporting the entire DB
+        return dump_data
+
+    # Find and serialize `units`
+    unit_ids = [i["unit_id"] for i in mephisto_db_data["units"]]
+    unit_rows = db_utils.select_rows_by_list_of_field_values(
+        datastore, "units", ["unit_id"], [unit_ids],
+    )
+    dump_data["units"] = db_utils.serialize_data_for_table(unit_rows)
+
+    # Find and serialize `workers`
+    worker_ids = [i["worker_id"] for i in mephisto_db_data["workers"]]
+    workers_rows = db_utils.select_rows_by_list_of_field_values(
+        datastore, "workers", ["worker_id"], [worker_ids],
+    )
+    dump_data["workers"] = db_utils.serialize_data_for_table(workers_rows)
+
+    return dump_data
diff --git a/mephisto/abstractions/providers/mock/mock_datastore_tables.py b/mephisto/abstractions/providers/mock/mock_datastore_tables.py
new file mode 100644
index 000000000..de4e2aa08
--- /dev/null
+++ b/mephisto/abstractions/providers/mock/mock_datastore_tables.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python3

+# Copyright (c) Meta Platforms and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+WARNING: In this module you can find initial table structures, but not final.
+There can be changes in migrations. To see actual fields, constraints, etc.,
+see information in databases or look through all migrations for current database
+"""
+
+from mephisto.abstractions.databases import local_database_tables
+
+CREATE_IF_NOT_EXISTS_REQUESTERS_TABLE = """
+    CREATE TABLE IF NOT EXISTS requesters (
+        requester_id TEXT PRIMARY KEY UNIQUE,
+        is_registered BOOLEAN
+    );
+"""
+
+CREATE_IF_NOT_EXISTS_UNITS_TABLE = """
+    CREATE TABLE IF NOT EXISTS units (
+        unit_id TEXT PRIMARY KEY UNIQUE,
+        is_expired BOOLEAN
+    );
+"""
+
+CREATE_IF_NOT_EXISTS_WORKERS_TABLE = """
+    CREATE TABLE IF NOT EXISTS workers (
+        worker_id TEXT PRIMARY KEY UNIQUE,
+        is_blocked BOOLEAN
+    );
+"""
+
+CREATE_IF_NOT_EXISTS_MIGRATIONS_TABLE = local_database_tables.CREATE_IF_NOT_EXISTS_MIGRATIONS_TABLE
diff --git a/mephisto/abstractions/providers/mturk/mturk_datastore.py b/mephisto/abstractions/providers/mturk/mturk_datastore.py
index 06d0917eb..517cae0b0 100644
--- a/mephisto/abstractions/providers/mturk/mturk_datastore.py
+++ b/mephisto/abstractions/providers/mturk/mturk_datastore.py
@@ -4,66 +4,27 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-import boto3  # type: ignore
-import sqlite3
 import os
+import sqlite3
 import threading
 import time
-
-from datetime import datetime
 from collections import defaultdict
+from typing import Any
+from typing import Dict
+from typing import Optional

-
+import boto3  # type: ignore
 from botocore.exceptions import ClientError  # type: ignore
 from botocore.exceptions import ProfileNotFound  # type: ignore

-from mephisto.abstractions.databases.local_database import is_unique_failure
-
-from typing import Dict, Any, Optional
+from mephisto.abstractions.databases.local_database import is_unique_failure
 from mephisto.utils.logger_core import get_logger
-
-logger = get_logger(name=__name__)
+from . 
import mturk_datastore_tables as tables +from .mturk_datastore_export import export_datastore MTURK_REGION_NAME = "us-east-1" -CREATE_HITS_TABLE = """CREATE TABLE IF NOT EXISTS hits ( - hit_id TEXT PRIMARY KEY UNIQUE, - unit_id TEXT, - assignment_id TEXT, - link TEXT, - assignment_time_in_seconds INTEGER NOT NULL, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP -); -""" - -CREATE_RUN_MAP_TABLE = """CREATE TABLE IF NOT EXISTS run_mappings ( - hit_id TEXT, - run_id TEXT -); -""" - -CREATE_RUNS_TABLE = """CREATE TABLE IF NOT EXISTS runs ( - run_id TEXT PRIMARY KEY UNIQUE, - arn_id TEXT, - hit_type_id TEXT NOT NULL, - hit_config_path TEXT NOT NULL, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, - frame_height INTEGER NOT NULL DEFAULT 650 -); -""" - -UPDATE_RUNS_TABLE_1 = """ALTER TABLE runs - ADD COLUMN frame_height INTEGER NOT NULL DEFAULT 650; -""" - -CREATE_QUALIFICATIONS_TABLE = """CREATE TABLE IF NOT EXISTS qualifications ( - qualification_name TEXT PRIMARY KEY UNIQUE, - requester_id TEXT, - mturk_qualification_name TEXT, - mturk_qualification_id TEXT, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP -); -""" +logger = get_logger(name=__name__) class MTurkDatastore: @@ -86,8 +47,9 @@ def __init__(self, datastore_root: str): lambda: time.monotonic() ) - def _get_connection(self) -> sqlite3.Connection: - """Returns a singular database connection to be shared amongst all + def get_connection(self) -> sqlite3.Connection: + """ + Returns a singular database connection to be shared amongst all calls for a given thread. """ curr_thread = threading.get_ident() @@ -109,21 +71,28 @@ def init_tables(self) -> None: Run all the table creation SQL queries to ensure the expected tables exist """ with self.table_access_condition: - conn = self._get_connection() - conn.execute("PRAGMA foreign_keys = 1") + conn = self.get_connection() + conn.execute("PRAGMA foreign_keys = on;") + with conn: c = conn.cursor() - c.execute(CREATE_HITS_TABLE) - c.execute(CREATE_RUNS_TABLE) - c.execute(CREATE_RUN_MAP_TABLE) - c.execute(CREATE_QUALIFICATIONS_TABLE) + c.execute(tables.CREATE_IF_NOT_EXISTS_HITS_TABLE) + c.execute(tables.CREATE_IF_NOT_EXISTS_RUNS_TABLE) + c.execute(tables.CREATE_IF_NOT_EXISTS_RUN_MAP_TABLE) + c.execute(tables.CREATE_IF_NOT_EXISTS_QUALIFICATIONS_TABLE) + c.execute(tables.CREATE_IF_NOT_EXISTS_MIGRATIONS_TABLE) + + # Migrations with conn: try: c = conn.cursor() - c.execute(UPDATE_RUNS_TABLE_1) - except Exception as _e: + c.execute(tables.UPDATE_RUNS_TABLE_1) + except Exception: pass # extra column already exists + def get_export_data(self, **kwargs) -> dict: + return export_datastore(self, **kwargs) + def is_hit_mapping_in_sync(self, unit_id: str, compare_time: float): """ Determine if a cached value from the given compare time is still valid @@ -132,21 +101,25 @@ def is_hit_mapping_in_sync(self, unit_id: str, compare_time: float): def new_hit(self, hit_id: str, hit_link: str, duration: int, run_id: str) -> None: """Register a new HIT mapping in the table""" - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() c.execute( - """INSERT INTO hits( + """ + INSERT INTO hits( hit_id, link, assignment_time_in_seconds - ) VALUES (?, ?, ?);""", + ) VALUES (?, ?, ?); + """, (hit_id, hit_link, duration), ) c.execute( - """INSERT INTO run_mappings( + """ + INSERT INTO run_mappings( hit_id, run_id - ) VALUES (?, ?);""", + ) VALUES (?, ?); + """, (hit_id, run_id), ) @@ -155,7 +128,7 @@ def 
get_unassigned_hit_ids(self, run_id: str): Return a list of all HIT ids that haven't been assigned """ with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ @@ -188,7 +161,7 @@ def register_assignment_to_hit( logger.debug( f"Attempting to assign HIT {hit_id}, Unit {unit_id}, Assignment {assignment_id}." ) - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() c.execute( """ @@ -204,7 +177,8 @@ def register_assignment_to_hit( logger.debug(f"Cleared HIT mapping cache for previous unit, {old_unit_id}") c.execute( - """UPDATE hits + """ + UPDATE hits SET assignment_id = ?, unit_id = ? WHERE hit_id = ? """, @@ -218,7 +192,7 @@ def clear_hit_from_unit(self, unit_id: str) -> None: Clear the hit mapping that maps the given unit, if such a unit-hit map exists """ - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() c.execute( """ @@ -238,7 +212,8 @@ def clear_hit_from_unit(self, unit_id: str) -> None: ) result_hit_id = results[0]["hit_id"] c.execute( - """UPDATE hits + """ + UPDATE hits SET assignment_id = ?, unit_id = ? WHERE hit_id = ? """, @@ -249,7 +224,7 @@ def clear_hit_from_unit(self, unit_id: str) -> None: def get_hit_mapping(self, unit_id: str) -> sqlite3.Row: """Get the mapping between Mephisto IDs and MTurk ids""" with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ @@ -269,23 +244,31 @@ def register_run( frame_height: int = 0, ) -> None: """Register a new task run in the mturk table""" - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() c.execute( - """INSERT INTO runs( + """ + INSERT INTO runs( run_id, arn_id, hit_type_id, hit_config_path, frame_height - ) VALUES (?, ?, ?, ?, ?);""", - (run_id, "unused", hit_type_id, hit_config_path, frame_height), + ) VALUES (?, ?, ?, ?, ?); + """, + ( + run_id, + "unused", + hit_type_id, + hit_config_path, + frame_height, + ), ) def get_run(self, run_id: str) -> sqlite3.Row: """Get the details for a run by task_run_id""" with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ @@ -311,15 +294,17 @@ def create_qualification_mapping( Repeat entries with the same `qualification_name` will be idempotent """ try: - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() c.execute( - """INSERT INTO qualifications( + """ + INSERT INTO qualifications( qualification_name, requester_id, mturk_qualification_name, mturk_qualification_id - ) VALUES (?, ?, ?, ?);""", + ) VALUES (?, ?, ?, ?); + """, ( qualification_name, requester_id, @@ -337,19 +322,23 @@ def create_qualification_mapping( f"Found existing one: {qual}. " ) assert qual is not None, "Cannot be none given is_unique_failure on insert" + cur_requester_id = qual["requester_id"] cur_mturk_qualification_name = qual["mturk_qualification_name"] - cur_mturk_qualification_id = qual["mturk_qualification_id"] if cur_requester_id != requester_id: logger.warning( - f"MTurk Qualification mapping create for {qualification_name} under requester " - f"{requester_id}, already exists under {cur_requester_id}." 
+ f"MTurk Qualification mapping create for {qualification_name} " + f"under requester {requester_id}, already exists under {cur_requester_id}." ) + if cur_mturk_qualification_name != mturk_qualification_name: logger.warning( - f"MTurk Qualification mapping create for {qualification_name} with mturk name " - f"{mturk_qualification_name}, already exists under {cur_mturk_qualification_name}." + f"MTurk Qualification mapping create " + f"for {qualification_name} with mturk name " + f"{mturk_qualification_name}, already exists " + f"under {cur_mturk_qualification_name}." ) + return None else: raise e @@ -357,7 +346,7 @@ def create_qualification_mapping( def get_qualification_mapping(self, qualification_name: str) -> Optional[sqlite3.Row]: """Get the mapping between Mephisto qualifications and MTurk qualifications""" with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ diff --git a/mephisto/abstractions/providers/mturk/mturk_datastore_export.py b/mephisto/abstractions/providers/mturk/mturk_datastore_export.py new file mode 100644 index 000000000..f538801a5 --- /dev/null +++ b/mephisto/abstractions/providers/mturk/mturk_datastore_export.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List +from typing import Optional + +from mephisto.utils import db as db_utils + + +def export_datastore( + datastore: "MTurkDatastore", + mephisto_db_data: dict, + task_run_ids: Optional[List[str]] = None, + **kwargs, +) -> dict: + """Logic of collecting export data from MTurk datastore""" + + dump_data = db_utils.db_or_datastore_to_dict(datastore) + + if not task_run_ids: + # Exporting the entire DB + return dump_data + + tables_with_task_run_relations = [ + "run_mappings", + "runs", + ] + + for table_name in tables_with_task_run_relations: + table_rows = db_utils.select_rows_by_list_of_field_values( + datastore, table_name, ["run_id"], [task_run_ids], + ) + runs_table_data = db_utils.serialize_data_for_table(table_rows) + dump_data[table_name] = runs_table_data + + # Find and serialize `hits` + hit_ids = list(set(filter(bool, [i["hit_id"] for i in dump_data["run_mappings"]]))) + hit_rows = db_utils.select_rows_by_list_of_field_values( + datastore, "hits", ["hit_id"], [hit_ids], + ) + dump_data["hits"] = db_utils.serialize_data_for_table(hit_rows) + + # Find and serialize `qualifications` + qualification_names = [i["qualification_name"] for i in mephisto_db_data["qualifications"]] + if qualification_names: + qualification_rows = db_utils.select_rows_by_list_of_field_values( + datastore, "qualifications", ["qualification_name"], [qualification_names], + ) + else: + qualification_rows = db_utils.select_all_table_rows(datastore, "qualifications") + dump_data["qualifications"] = db_utils.serialize_data_for_table(qualification_rows) + + return dump_data diff --git a/mephisto/abstractions/providers/mturk/mturk_datastore_tables.py b/mephisto/abstractions/providers/mturk/mturk_datastore_tables.py new file mode 100644 index 000000000..49a658d32 --- /dev/null +++ b/mephisto/abstractions/providers/mturk/mturk_datastore_tables.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
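+
+# Note on migrations in this module: the tracked `migrations` table DDL is
+# shared with the core DB (see `CREATE_IF_NOT_EXISTS_MIGRATIONS_TABLE` at the
+# bottom of this file), while the legacy `UPDATE_RUNS_TABLE_1` statement below
+# is still applied unconditionally on every startup and relies on the ALTER
+# failing once the column exists. A sketch of the existing call site in
+# `MTurkDatastore.init_tables()` (no additional behavior is implied):
+#
+#     try:
+#         c.execute(tables.UPDATE_RUNS_TABLE_1)
+#     except Exception:
+#         pass  # extra column already exists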
+
+"""
+WARNING: In this module you can find initial table structures, but not final.
+There can be changes in migrations. To see actual fields, constraints, etc.,
+see information in databases or look through all migrations for current database
+"""
+
+from mephisto.abstractions.databases import local_database_tables
+
+
+CREATE_IF_NOT_EXISTS_HITS_TABLE = """
+    CREATE TABLE IF NOT EXISTS hits (
+        hit_id TEXT PRIMARY KEY UNIQUE,
+        unit_id TEXT,
+        assignment_id TEXT,
+        link TEXT,
+        assignment_time_in_seconds INTEGER NOT NULL,
+        creation_date DATETIME DEFAULT CURRENT_TIMESTAMP
+    );
+"""
+
+CREATE_IF_NOT_EXISTS_RUN_MAP_TABLE = """
+    CREATE TABLE IF NOT EXISTS run_mappings (
+        hit_id TEXT,
+        run_id TEXT
+    );
+"""
+
+CREATE_IF_NOT_EXISTS_RUNS_TABLE = """
+    CREATE TABLE IF NOT EXISTS runs (
+        run_id TEXT PRIMARY KEY UNIQUE,
+        arn_id TEXT,
+        hit_type_id TEXT NOT NULL,
+        hit_config_path TEXT NOT NULL,
+        creation_date DATETIME DEFAULT CURRENT_TIMESTAMP,
+        frame_height INTEGER NOT NULL DEFAULT 650
+    );
+"""
+
+CREATE_IF_NOT_EXISTS_QUALIFICATIONS_TABLE = """
+    CREATE TABLE IF NOT EXISTS qualifications (
+        qualification_name TEXT PRIMARY KEY UNIQUE,
+        requester_id TEXT,
+        mturk_qualification_name TEXT,
+        mturk_qualification_id TEXT,
+        creation_date DATETIME DEFAULT CURRENT_TIMESTAMP
+    );
+"""
+
+CREATE_IF_NOT_EXISTS_MIGRATIONS_TABLE = local_database_tables.CREATE_IF_NOT_EXISTS_MIGRATIONS_TABLE
+
+
+# Migrations
+# TODO: Refactor this the way it was done for the other databases
+
+UPDATE_RUNS_TABLE_1 = """
+    ALTER TABLE runs ADD COLUMN frame_height INTEGER NOT NULL DEFAULT 650;
+"""
diff --git a/mephisto/abstractions/providers/prolific/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py b/mephisto/abstractions/providers/prolific/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py
new file mode 100644
index 000000000..6dbe1c116
--- /dev/null
+++ b/mephisto/abstractions/providers/prolific/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py
@@ -0,0 +1,164 @@
+#!/usr/bin/env python3

+# Copyright (c) Meta Platforms and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+1. Remove autoincrement parameter for all Primary Keys
+2. Add `update_date` and `creation_date` to `workers` table
+3. Add `creation_date` to `units` table
+4. Rename field `run_id` -> `task_run_id`
+5. Remove table `requesters`
+"""
+
+
+PREPARING_DB_FOR_MERGE_DBS_COMMAND = """
+    /* Disable FK constraints */
+    PRAGMA foreign_keys = off;
+
+
+    /* Studies */
+    CREATE TABLE IF NOT EXISTS _studies (
+        id INTEGER PRIMARY KEY,
+        prolific_study_id TEXT UNIQUE,
+        status TEXT,
+        link TEXT,
+        task_run_id TEXT UNIQUE,
+        assignment_time_in_seconds INTEGER NOT NULL,
+        creation_date DATETIME DEFAULT CURRENT_TIMESTAMP
+    );
+    INSERT INTO _studies SELECT * FROM studies;
+    DROP TABLE studies;
+    ALTER TABLE _studies RENAME TO studies;
+
+
+    /* Submissions */
+    CREATE TABLE IF NOT EXISTS _submissions (
+        id INTEGER PRIMARY KEY,
+        prolific_submission_id TEXT UNIQUE,
+        prolific_study_id TEXT,
+        status TEXT DEFAULT NULL,
+        creation_date DATETIME DEFAULT CURRENT_TIMESTAMP
+    );
+    INSERT INTO _submissions SELECT * FROM submissions;
+    DROP TABLE submissions;
+    ALTER TABLE _submissions RENAME TO submissions;
+
+
+    /* Run Mappings */
+    CREATE TABLE IF NOT EXISTS _run_mappings (
+        id INTEGER PRIMARY KEY,
+        prolific_study_id TEXT,
+        run_id TEXT
+    );
+    INSERT INTO _run_mappings SELECT * FROM run_mappings;
+    DROP TABLE run_mappings;
+    ALTER TABLE _run_mappings RENAME TO run_mappings;
+
+
+    /* Units */
+    CREATE TABLE IF NOT EXISTS _units (
+        id INTEGER PRIMARY KEY,
+        unit_id TEXT UNIQUE,
+        run_id TEXT,
+        prolific_study_id TEXT,
+        prolific_submission_id TEXT,
+        is_expired BOOLEAN DEFAULT false,
+        creation_date DATETIME DEFAULT CURRENT_TIMESTAMP
+    );
+    /* Copy data from backed up table and set values for `creation_date` */
+    INSERT INTO _units
+    SELECT
+        id,
+        unit_id,
+        run_id,
+        prolific_study_id,
+        prolific_submission_id,
+        is_expired,
+        datetime('now', 'localtime')
+    FROM units;
+    DROP TABLE units;
+    ALTER TABLE _units RENAME TO units;
+
+
+    /* Workers */
+    CREATE TABLE IF NOT EXISTS _workers (
+        id INTEGER PRIMARY KEY,
+        worker_id TEXT UNIQUE,
+        is_blocked BOOLEAN default false,
+        update_date DATETIME DEFAULT CURRENT_TIMESTAMP,
+        creation_date DATETIME DEFAULT CURRENT_TIMESTAMP
+    );
+    /* Copy data from backed up table and set values for `creation_date` and `update_date` */
+    INSERT INTO _workers
+    SELECT
+        id,
+        worker_id,
+        is_blocked,
+        datetime('now', 'localtime'),
+        datetime('now', 'localtime')
+    FROM workers;
+    DROP TABLE workers;
+    ALTER TABLE _workers RENAME TO workers;
+
+
+    /* Runs */
+    CREATE TABLE IF NOT EXISTS _runs (
+        id INTEGER PRIMARY KEY,
+        run_id TEXT UNIQUE,
+        arn_id TEXT,
+        prolific_workspace_id TEXT NOT NULL,
+        prolific_project_id TEXT NOT NULL,
+        prolific_study_id TEXT,
+        prolific_study_config_path TEXT NOT NULL,
+        creation_date DATETIME DEFAULT CURRENT_TIMESTAMP,
+        frame_height INTEGER NOT NULL DEFAULT 650,
+        actual_available_places INTEGER DEFAULT NULL,
+        listed_available_places INTEGER DEFAULT NULL
+    );
+    INSERT INTO _runs SELECT * FROM runs;
+    DROP TABLE runs;
+    ALTER TABLE _runs RENAME TO runs;
+
+
+    /* Participant Groups */
+    CREATE TABLE IF NOT EXISTS _participant_groups (
+        id INTEGER PRIMARY KEY,
+        qualification_name TEXT UNIQUE,
+        requester_id TEXT,
+        prolific_project_id TEXT,
+        prolific_participant_group_name TEXT,
+        prolific_participant_group_id TEXT,
+        creation_date DATETIME DEFAULT CURRENT_TIMESTAMP
+    );
+    INSERT INTO _participant_groups SELECT * FROM participant_groups;
+    DROP TABLE participant_groups;
+    ALTER TABLE _participant_groups RENAME TO participant_groups;
+
+
+    /* Qualifications */
+    CREATE TABLE IF NOT EXISTS _qualifications (
+        id INTEGER PRIMARY KEY,
+        prolific_participant_group_id TEXT,
+        task_run_id TEXT,
+        json_qual_logic TEXT,
+        qualification_ids TEXT,
+        creation_date 
DATETIME DEFAULT CURRENT_TIMESTAMP + ); + INSERT INTO _qualifications SELECT * FROM qualifications; + DROP TABLE qualifications; + ALTER TABLE _qualifications RENAME TO qualifications; + + + /* Enable FK constraints back */ + PRAGMA foreign_keys = on; + + + ALTER TABLE run_mappings RENAME COLUMN run_id TO task_run_id; + ALTER TABLE units RENAME COLUMN run_id TO task_run_id; + ALTER TABLE runs RENAME COLUMN run_id TO task_run_id; + + + DROP TABLE IF EXISTS requesters; +""" diff --git a/mephisto/abstractions/providers/prolific/migrations/__init__.py b/mephisto/abstractions/providers/prolific/migrations/__init__.py new file mode 100644 index 000000000..092965e1b --- /dev/null +++ b/mephisto/abstractions/providers/prolific/migrations/__init__.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from ._001_20240325_preparing_db_for_merge_dbs_command import * + + +migrations = { + "20240418_preparing_db_for_merge_dbs_command": PREPARING_DB_FOR_MERGE_DBS_COMMAND, +} diff --git a/mephisto/abstractions/providers/prolific/prolific_datastore.py b/mephisto/abstractions/providers/prolific/prolific_datastore.py index ee273f8b1..784467d9e 100644 --- a/mephisto/abstractions/providers/prolific/prolific_datastore.py +++ b/mephisto/abstractions/providers/prolific/prolific_datastore.py @@ -18,10 +18,18 @@ from mephisto.abstractions.databases.local_database import is_unique_failure from mephisto.abstractions.providers.prolific.api.constants import StudyStatus from mephisto.abstractions.providers.prolific.provider_type import PROVIDER_TYPE +from mephisto.utils.db import apply_migrations +from mephisto.utils.db import check_if_row_with_params_exists +from mephisto.utils.db import EntryAlreadyExistsException +from mephisto.utils.db import make_randomized_int_id +from mephisto.utils.db import MephistoDBException +from mephisto.utils.db import retry_generate_id from mephisto.utils.logger_core import get_logger from mephisto.utils.qualifications import QualificationType from . import prolific_datastore_tables as tables from .api.client import ProlificClient +from .migrations import migrations +from .prolific_datastore_export import export_datastore from .prolific_utils import get_authenticated_client logger = get_logger(name=__name__) @@ -41,7 +49,7 @@ def __init__(self, datastore_root: str): lambda: time.monotonic() ) - def _get_connection(self) -> sqlite3.Connection: + def get_connection(self) -> sqlite3.Connection: """ Returns a singular database connection to be shared amongst all calls for a given thread. 
""" @@ -62,24 +70,31 @@ def _mark_study_mapping_update(self, unit_id: str) -> None: def init_tables(self) -> None: """Run all the table creation SQL queries to ensure the expected tables exist""" with self.table_access_condition: - conn = self._get_connection() - conn.execute("PRAGMA foreign_keys = 1") - c = conn.cursor() - c.execute(tables.CREATE_STUDIES_TABLE) - c.execute(tables.CREATE_SUBMISSIONS_TABLE) - c.execute(tables.CREATE_REQUESTERS_TABLE) - c.execute(tables.CREATE_UNITS_TABLE) - c.execute(tables.CREATE_WORKERS_TABLE) - c.execute(tables.CREATE_RUNS_TABLE) - c.execute(tables.CREATE_RUN_MAP_TABLE) - c.execute(tables.CREATE_PARTICIPANT_GROUPS_TABLE) - c.execute(tables.CREATE_PARTICIPANT_GROUP_QUALIFICATIONS_MAPPING_TABLE) - conn.commit() + conn = self.get_connection() + conn.execute("PRAGMA foreign_keys = on;") + + with conn: + c = conn.cursor() + c.execute(tables.CREATE_IF_NOT_EXISTS_STUDIES_TABLE) + c.execute(tables.CREATE_IF_NOT_EXISTS_SUBMISSIONS_TABLE) + c.execute(tables.CREATE_IF_NOT_EXISTS_UNITS_TABLE) + c.execute(tables.CREATE_IF_NOT_EXISTS_WORKERS_TABLE) + c.execute(tables.CREATE_IF_NOT_EXISTS_RUNS_TABLE) + c.execute(tables.CREATE_IF_NOT_EXISTS_RUN_MAP_TABLE) + c.execute(tables.CREATE_IF_NOT_EXISTS_PARTICIPANT_GROUPS_TABLE) + c.execute(tables.CREATE_IF_NOT_EXISTS_QUALIFICATIONS_TABLE) + c.execute(tables.CREATE_IF_NOT_EXISTS_MIGRATIONS_TABLE) + + apply_migrations(self, migrations) + + def get_export_data(self, **kwargs) -> dict: + return export_datastore(self, **kwargs) def is_study_mapping_in_sync(self, unit_id: str, compare_time: float): """Determine if a cached value from the given compare time is still valid""" return compare_time > self._last_study_mapping_update_times[unit_id] + @retry_generate_id(caught_excs=[EntryAlreadyExistsException]) def new_study( self, prolific_study_id: str, @@ -88,35 +103,74 @@ def new_study( task_run_id: str, status: str = StudyStatus.UNPUBLISHED, ) -> None: - """Register a new Study mapping in the table""" - with self.table_access_condition, self._get_connection() as conn: + """Register a new Study in the table""" + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() - c.execute( - """ - INSERT INTO studies( - prolific_study_id, - task_run_id, - link, - assignment_time_in_seconds, - status - ) VALUES (?, ?, ?, ?, ?); - """, - (prolific_study_id, task_run_id, study_link, duration_in_seconds, status), - ) - c.execute( - """ - INSERT INTO run_mappings( - prolific_study_id, - run_id - ) VALUES (?, ?); - """, - (prolific_study_id, task_run_id), - ) + try: + c.execute( + """ + INSERT INTO studies( + id, + prolific_study_id, + task_run_id, + link, + assignment_time_in_seconds, + status + ) VALUES (?, ?, ?, ?, ?, ?); + """, + ( + make_randomized_int_id(), + prolific_study_id, + task_run_id, + study_link, + duration_in_seconds, + status, + ), + ) + except sqlite3.IntegrityError as e: + if is_unique_failure(e): + raise EntryAlreadyExistsException( + e, + db=self, + table_name="studies", + original_exc=e, + ) + raise MephistoDBException(e) + + @retry_generate_id(caught_excs=[EntryAlreadyExistsException]) + def new_run_mapping(self, prolific_study_id: str, task_run_id: str) -> None: + """Register a new Run mapping in the table""" + with self.table_access_condition, self.get_connection() as conn: + c = conn.cursor() + try: + c.execute( + """ + INSERT INTO run_mappings( + id, + prolific_study_id, + task_run_id + ) VALUES (?, ?, ?); + """, + ( + make_randomized_int_id(), + prolific_study_id, + task_run_id, + ), + ) + except 
sqlite3.IntegrityError as e: + if is_unique_failure(e): + raise EntryAlreadyExistsException( + e, + db=self, + table_name="run_mappings", + original_exc=e, + ) + raise MephistoDBException(e) def update_study_status(self, study_id: str, status: str) -> None: """Set the study status in datastore""" with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ @@ -129,10 +183,10 @@ def update_study_status(self, study_id: str, status: str) -> None: conn.commit() return None - def all_study_units_are_expired(self, run_id: str) -> bool: + def all_study_units_are_expired(self, task_run_id: str) -> bool: """Return a list of all Study ids that haven't been assigned""" with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( @@ -150,15 +204,16 @@ def all_study_units_are_expired(self, run_id: str) -> bool: FROM studies INNER JOIN run_mappings USING (prolific_study_id) WHERE - run_mappings.run_id = ? AND + run_mappings.task_run_id = ? AND unexpired_units_count == 0 GROUP BY prolific_study_id; """, - (run_id,), + (task_run_id,), ) results = c.fetchall() return bool(results) + @retry_generate_id(caught_excs=[EntryAlreadyExistsException]) def register_submission_to_study( self, prolific_study_id: str, @@ -174,18 +229,44 @@ def register_submission_to_study( f"Unit {unit_id}, " f"Submission {prolific_submission_id}." ) - with self.table_access_condition, self._get_connection() as conn: + already_exists = check_if_row_with_params_exists( + db=self, + table_name="submissions", + params={ + "prolific_study_id": prolific_study_id, + "prolific_submission_id": prolific_submission_id, + }, + select_field="id", + ) + + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() - c.execute( - """ - INSERT OR IGNORE INTO submissions( - prolific_study_id, - prolific_submission_id - ) VALUES (?, ?); - """, - (prolific_study_id, prolific_submission_id), - ) + if not already_exists: + try: + c.execute( + """ + INSERT INTO submissions( + id, + prolific_study_id, + prolific_submission_id + ) VALUES (?, ?, ?); + """, + ( + make_randomized_int_id(), + prolific_study_id, + prolific_submission_id, + ), + ) + except sqlite3.IntegrityError as e: + if is_unique_failure(e): + raise EntryAlreadyExistsException( + e, + db=self, + table_name="submissions", + original_exc=e, + ) + raise MephistoDBException(e) if unit_id is not None: self._mark_study_mapping_update(unit_id) @@ -193,7 +274,7 @@ def register_submission_to_study( def update_submission_status(self, prolific_submission_id: str, status: str) -> None: """Set prolific_submission_id to unit""" with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ @@ -206,78 +287,57 @@ def update_submission_status(self, prolific_submission_id: str, status: str) -> conn.commit() return None - def ensure_requester_exists(self, requester_id: str) -> None: - """Create a record of this requester if it doesn't exist""" - with self.table_access_condition: - conn = self._get_connection() - c = conn.cursor() - c.execute( - """ - INSERT OR IGNORE INTO requesters( - requester_id, - is_registered - ) VALUES (?, ?); - """, - (requester_id, False), - ) - conn.commit() - return None + @retry_generate_id(caught_excs=[EntryAlreadyExistsException]) + def ensure_worker_exists(self, worker_id: str) -> None: + """Create a record of this worker if it doesn't exist""" + already_exists 
= check_if_row_with_params_exists( + db=self, + table_name="workers", + params={ + "worker_id": worker_id, + "is_blocked": False, + }, + select_field="id", + ) - def set_requester_registered(self, requester_id: str, val: bool) -> None: - """Set the requester registration status for the given id""" - self.ensure_requester_exists(requester_id) with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() - c.execute( - """ - UPDATE requesters - SET is_registered = ? - WHERE requester_id = ? - """, - (val, requester_id), - ) - conn.commit() - return None - def get_requester_registered(self, requester_id: str) -> bool: - """Get the registration status of a requester""" - self.ensure_requester_exists(requester_id) - with self.table_access_condition: - conn = self._get_connection() - c = conn.cursor() - c.execute( - """ - SELECT is_registered FROM requesters - WHERE requester_id = ? - """, - (requester_id,), - ) - results = c.fetchall() - return bool(results[0]["is_registered"]) + if not already_exists: + try: + c.execute( + """ + INSERT INTO workers( + id, + worker_id, + is_blocked + ) VALUES (?, ?, ?); + """, + ( + make_randomized_int_id(), + worker_id, + False, + ), + ) + conn.commit() + except sqlite3.IntegrityError as e: + if is_unique_failure(e): + raise EntryAlreadyExistsException( + e, + db=self, + table_name="workers", + original_exc=e, + ) + raise MephistoDBException(e) - def ensure_worker_exists(self, worker_id: str) -> None: - """Create a record of this worker if it doesn't exist""" - with self.table_access_condition: - conn = self._get_connection() - c = conn.cursor() - c.execute( - """ - INSERT OR IGNORE INTO workers( - worker_id, - is_blocked - ) VALUES (?, ?); - """, - (worker_id, False), - ) - conn.commit() return None def set_worker_blocked(self, worker_id: str, is_blocked: bool) -> None: """Set the worker registration status for the given id""" self.ensure_worker_exists(worker_id) with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ @@ -294,7 +354,7 @@ def get_worker_blocked(self, worker_id: str) -> bool: """Get the blocked status of a worker""" self.ensure_worker_exists(worker_id) with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ @@ -309,7 +369,7 @@ def get_worker_blocked(self, worker_id: str) -> bool: def get_blocked_workers(self) -> List[dict]: """Get all workers with blocked status""" with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ @@ -324,46 +384,108 @@ def get_blocked_workers(self) -> List[dict]: def get_bloked_participant_ids(self) -> List[str]: return [w["worker_id"] for w in self.get_blocked_workers()] + @retry_generate_id(caught_excs=[EntryAlreadyExistsException]) def ensure_unit_exists(self, unit_id: str) -> None: """Create a record of this unit if it doesn't exist""" + already_exists = check_if_row_with_params_exists( + db=self, + table_name="units", + params={ + "unit_id": unit_id, + "is_expired": False, + }, + select_field="id", + ) + with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() - c.execute( - """ - INSERT OR IGNORE INTO units( - unit_id, - is_expired - ) VALUES (?, ?); - """, - (unit_id, False), - ) - conn.commit() + + if not already_exists: + try: + c.execute( + """ + INSERT INTO units( + id, + unit_id, + 
is_expired + ) VALUES (?, ?, ?); + """, + ( + make_randomized_int_id(), + unit_id, + False, + ), + ) + conn.commit() + except sqlite3.IntegrityError as e: + if is_unique_failure(e): + raise EntryAlreadyExistsException( + e, + db=self, + table_name="units", + original_exc=e, + ) + raise MephistoDBException(e) + return None - def create_unit(self, unit_id: str, run_id: str, prolific_study_id: str) -> None: + @retry_generate_id(caught_excs=[EntryAlreadyExistsException]) + def create_unit(self, unit_id: str, task_run_id: str, prolific_study_id: str) -> None: """Create the unit if not exists""" + already_exists = check_if_row_with_params_exists( + db=self, + table_name="units", + params={ + "unit_id": unit_id, + "task_run_id": task_run_id, + "prolific_study_id": prolific_study_id, + "is_expired": False, + }, + select_field="id", + ) + with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() - c.execute( - """ - INSERT OR IGNORE INTO units( - unit_id, - run_id, - prolific_study_id, - is_expired - ) VALUES (?, ?, ?, ?); - """, - (unit_id, run_id, prolific_study_id, False), - ) - conn.commit() + + if not already_exists: + try: + c.execute( + """ + INSERT INTO units( + id, + unit_id, + task_run_id, + prolific_study_id, + is_expired + ) VALUES (?, ?, ?, ?, ?); + """, + ( + make_randomized_int_id(), + unit_id, + task_run_id, + prolific_study_id, + False, + ), + ) + conn.commit() + except sqlite3.IntegrityError as e: + if is_unique_failure(e): + raise EntryAlreadyExistsException( + e, + db=self, + table_name="units", + original_exc=e, + ) + raise MephistoDBException(e) + return None def get_unit(self, unit_id: str) -> sqlite3.Row: """Get the details for a unit by unit_id""" with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ @@ -379,7 +501,7 @@ def set_unit_expired(self, unit_id: str, val: bool) -> None: """Set the unit registration status for the given id""" self.ensure_unit_exists(unit_id) with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ @@ -396,7 +518,7 @@ def get_unit_expired(self, unit_id: str) -> bool: """Get the registration status of a unit""" self.ensure_unit_exists(unit_id) with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ @@ -412,7 +534,7 @@ def set_submission_for_unit(self, unit_id: str, prolific_submission_id: str) -> """Set prolific_submission_id to unit""" self.ensure_unit_exists(unit_id) with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ @@ -446,7 +568,7 @@ def get_client_for_requester(self, requester_name: str) -> ProlificClient: def get_qualification_mapping(self, qualification_name: str) -> Optional[sqlite3.Row]: """Get the mapping between Mephisto qualifications and Prolific Participant Group""" with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ @@ -460,6 +582,7 @@ def get_qualification_mapping(self, qualification_name: str) -> Optional[sqlite3 return None return results[0] + @retry_generate_id(caught_excs=[EntryAlreadyExistsException]) def create_participant_group_mapping( self, qualification_name: str, @@ -475,19 +598,21 @@ def create_participant_group_mapping( Repeat entries with the same `qualification_name` will be idempotent """ 
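+        # Shared insert pattern used throughout this datastore: a pseudo-random
+        # integer PK from `make_randomized_int_id()` is inserted explicitly, and a
+        # SQLite UNIQUE-constraint failure (detected via `is_unique_failure`) is
+        # converted into `EntryAlreadyExistsException`, which `@retry_generate_id`
+        # is configured to catch, so a colliding random id can be regenerated.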
try: - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() c.execute( """ INSERT INTO participant_groups( + id, qualification_name, requester_id, prolific_project_id, prolific_participant_group_name, prolific_participant_group_id - ) VALUES (?, ?, ?, ?, ?); + ) VALUES (?, ?, ?, ?, ?, ?); """, ( + make_randomized_int_id(), qualification_name, requester_id, prolific_project_id, @@ -528,6 +653,13 @@ def create_participant_group_mapping( ) return None + elif is_unique_failure(e): + raise EntryAlreadyExistsException( + e, + db=self, + table_name="participant_groups", + original_exc=e, + ) else: raise e @@ -539,7 +671,7 @@ def delete_participant_groups_by_participant_group_ids( if not participant_group_ids: return None - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() participant_group_ids_block = "" @@ -557,34 +689,47 @@ def delete_participant_groups_by_participant_group_ids( ) return None + @retry_generate_id(caught_excs=[EntryAlreadyExistsException]) def create_qualification_mapping( self, - run_id: str, + task_run_id: str, prolific_participant_group_id: str, qualifications: List[QualificationType], qualification_ids: List[int], ) -> None: """Register a new participant group mapping with qualifications""" - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() qualifications_json = json.dumps(qualifications) qualification_ids_json = json.dumps(qualification_ids) - c.execute( - """ - INSERT INTO qualifications( - prolific_participant_group_id, - task_run_id, - json_qual_logic, - qualification_ids - ) VALUES (?, ?, ?, ?); - """, - ( - prolific_participant_group_id, - run_id, - qualifications_json, - qualification_ids_json, - ), - ) + try: + c.execute( + """ + INSERT INTO qualifications( + id, + prolific_participant_group_id, + task_run_id, + json_qual_logic, + qualification_ids + ) VALUES (?, ?, ?, ?, ?); + """, + ( + make_randomized_int_id(), + prolific_participant_group_id, + task_run_id, + qualifications_json, + qualification_ids_json, + ), + ) + except sqlite3.IntegrityError as e: + if is_unique_failure(e): + raise EntryAlreadyExistsException( + e, + db=self, + table_name="qualifications", + original_exc=e, + ) + raise MephistoDBException(e) def find_studies_by_status(self, statuses: List[str], exclude: bool = False) -> List[dict]: """Find all studies having or excluding certain statuses""" @@ -594,7 +739,7 @@ def find_studies_by_status(self, statuses: List[str], exclude: bool = False) -> logic_str = "NOT" if exclude else "" statuses_str = ",".join([f'"{s}"' for s in statuses]) - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() c.execute( f""" @@ -632,7 +777,7 @@ def find_qualifications_by_ids( if not qualification_ids: return [] - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() qualification_ids_block = "" @@ -664,7 +809,7 @@ def delete_qualifications_by_participant_group_ids( if not participant_group_ids: return None - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() participant_group_ids_block 
= "" @@ -687,7 +832,7 @@ def clear_study_from_unit(self, unit_id: str) -> None: Clear the Study mapping that maps the given unit, if such a unit-study map exists """ - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() c.execute( """ @@ -720,7 +865,7 @@ def clear_study_from_unit(self, unit_id: str) -> None: def get_study_mapping(self, unit_id: str) -> sqlite3.Row: """Get the mapping between Mephisto IDs and Prolific IDs""" with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ @@ -733,9 +878,10 @@ def get_study_mapping(self, unit_id: str) -> sqlite3.Row: results = c.fetchall() return results[0] + @retry_generate_id(caught_excs=[EntryAlreadyExistsException]) def register_run( self, - run_id: str, + task_run_id: str, prolific_workspace_id: str, prolific_project_id: str, prolific_study_config_path: str, @@ -745,70 +891,82 @@ def register_run( prolific_study_id: Optional[str] = None, ) -> None: """Register a new task run in the Task Runs table""" - with self.table_access_condition, self._get_connection() as conn: + with self.table_access_condition, self.get_connection() as conn: c = conn.cursor() - c.execute( - """ - INSERT INTO runs( - run_id, - arn_id, - prolific_workspace_id, - prolific_project_id, - prolific_study_id, - prolific_study_config_path, - frame_height, - actual_available_places, - listed_available_places - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?); - """, - ( - run_id, - "unused", - prolific_workspace_id, - prolific_project_id, - prolific_study_id, - prolific_study_config_path, - frame_height, - actual_available_places, - listed_available_places, - ), - ) + try: + c.execute( + """ + INSERT INTO runs( + id, + task_run_id, + arn_id, + prolific_workspace_id, + prolific_project_id, + prolific_study_id, + prolific_study_config_path, + frame_height, + actual_available_places, + listed_available_places + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?); + """, + ( + make_randomized_int_id(), + task_run_id, + "unused", + prolific_workspace_id, + prolific_project_id, + prolific_study_id, + prolific_study_config_path, + frame_height, + actual_available_places, + listed_available_places, + ), + ) + except sqlite3.IntegrityError as e: + if is_unique_failure(e): + raise EntryAlreadyExistsException( + e, + db=self, + table_name="runs", + original_exc=e, + ) + raise MephistoDBException(e) - def get_run(self, run_id: str) -> sqlite3.Row: + def get_run(self, task_run_id: str) -> sqlite3.Row: """Get the details for a run by task_run_id""" with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ SELECT * from runs - WHERE run_id = ?; + WHERE task_run_id = ?; """, - (run_id,), + (task_run_id,), ) results = c.fetchall() return results[0] def set_available_places_for_run( self, - run_id: str, + task_run_id: str, actual_available_places: int, listed_available_places: int, ) -> None: """Set available places for a run by task_run_id""" with self.table_access_condition: - conn = self._get_connection() + conn = self.get_connection() c = conn.cursor() c.execute( """ UPDATE runs SET actual_available_places = ?, listed_available_places = ? - WHERE run_id = ? + WHERE task_run_id = ? 
""", ( actual_available_places, listed_available_places, - run_id, + task_run_id, ), ) conn.commit() diff --git a/mephisto/abstractions/providers/prolific/prolific_datastore_export.py b/mephisto/abstractions/providers/prolific/prolific_datastore_export.py new file mode 100644 index 000000000..cc1d85027 --- /dev/null +++ b/mephisto/abstractions/providers/prolific/prolific_datastore_export.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List +from typing import Optional + +from mephisto.utils import db as db_utils + + +def export_datastore( + datastore: "ProlificDatastore", + mephisto_db_data: dict, + task_run_ids: Optional[List[str]] = None, + **kwargs, +) -> dict: + """Logic of collecting export data from Prolific datastore""" + + dump_data = db_utils.db_or_datastore_to_dict(datastore) + + if not task_run_ids: + # Exporting the entire DB + return dump_data + + tables_with_task_run_relations = [ + "qualifications", + "run_mappings", + "runs", + "studies", + "units", + ] + + for table_name in tables_with_task_run_relations: + table_rows = db_utils.select_rows_from_table_related_to_task_run( + datastore, table_name, task_run_ids, + ) + runs_table_data = db_utils.serialize_data_for_table(table_rows) + dump_data[table_name] = runs_table_data + + # Find and serialize `submissions` + study_ids = list(set(filter(bool, [i["prolific_study_id"] for i in dump_data["studies"]]))) + submission_rows = db_utils.select_rows_by_list_of_field_values( + datastore, "submissions", ["prolific_study_id"], [study_ids], + ) + dump_data["submissions"] = db_utils.serialize_data_for_table(submission_rows) + + # Find and serialize `participant_groups` + participant_group_ids = list(set(filter(bool, [ + i["prolific_participant_group_id"] for i in dump_data["qualifications"] + ]))) + participant_group_rows = db_utils.select_rows_by_list_of_field_values( + datastore, "participant_groups", ["prolific_participant_group_id"], [participant_group_ids], + ) + dump_data["participant_groups"] = db_utils.serialize_data_for_table(participant_group_rows) + + # Find and serialize `workers` + worker_ids = [i["worker_name"] for i in mephisto_db_data["workers"]] + if worker_ids: + worker_rows = db_utils.select_rows_by_list_of_field_values( + datastore, "workers", ["worker_id"], [worker_ids], + ) + else: + worker_rows = db_utils.select_all_table_rows(datastore, "workers") + dump_data["workers"] = db_utils.serialize_data_for_table(worker_rows) + + return dump_data diff --git a/mephisto/abstractions/providers/prolific/prolific_datastore_tables.py b/mephisto/abstractions/providers/prolific/prolific_datastore_tables.py index 72817286c..fcdbcfc5c 100644 --- a/mephisto/abstractions/providers/prolific/prolific_datastore_tables.py +++ b/mephisto/abstractions/providers/prolific/prolific_datastore_tables.py @@ -4,7 +4,16 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -CREATE_STUDIES_TABLE = """ +""" +WARNING: In this module you can find initial table structures, but not final. +There are can be changes in migrations. 
To see actual fields, constraints, etc., +see information in databases or look through all migrations for current database +""" + +from mephisto.abstractions.databases import local_database_tables + + +CREATE_IF_NOT_EXISTS_STUDIES_TABLE = """ CREATE TABLE IF NOT EXISTS studies ( id INTEGER PRIMARY KEY AUTOINCREMENT, prolific_study_id TEXT UNIQUE, @@ -16,7 +25,7 @@ ); """ -CREATE_SUBMISSIONS_TABLE = """ +CREATE_IF_NOT_EXISTS_SUBMISSIONS_TABLE = """ CREATE TABLE IF NOT EXISTS submissions ( id INTEGER PRIMARY KEY AUTOINCREMENT, prolific_submission_id TEXT UNIQUE, @@ -26,7 +35,7 @@ ); """ -CREATE_RUN_MAP_TABLE = """ +CREATE_IF_NOT_EXISTS_RUN_MAP_TABLE = """ CREATE TABLE IF NOT EXISTS run_mappings ( id INTEGER PRIMARY KEY AUTOINCREMENT, prolific_study_id TEXT, @@ -34,15 +43,7 @@ ); """ -CREATE_REQUESTERS_TABLE = """ - CREATE TABLE IF NOT EXISTS requesters ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - requester_id TEXT UNIQUE, - is_registered BOOLEAN - ); -""" - -CREATE_UNITS_TABLE = """ +CREATE_IF_NOT_EXISTS_UNITS_TABLE = """ CREATE TABLE IF NOT EXISTS units ( id INTEGER PRIMARY KEY AUTOINCREMENT, unit_id TEXT UNIQUE, @@ -53,7 +54,7 @@ ); """ -CREATE_WORKERS_TABLE = """ +CREATE_IF_NOT_EXISTS_WORKERS_TABLE = """ CREATE TABLE IF NOT EXISTS workers ( id INTEGER PRIMARY KEY AUTOINCREMENT, worker_id TEXT UNIQUE, @@ -61,7 +62,7 @@ ); """ -CREATE_RUNS_TABLE = """ +CREATE_IF_NOT_EXISTS_RUNS_TABLE = """ CREATE TABLE IF NOT EXISTS runs ( id INTEGER PRIMARY KEY AUTOINCREMENT, run_id TEXT UNIQUE, @@ -77,7 +78,7 @@ ); """ -CREATE_PARTICIPANT_GROUPS_TABLE = """ +CREATE_IF_NOT_EXISTS_PARTICIPANT_GROUPS_TABLE = """ CREATE TABLE IF NOT EXISTS participant_groups ( id INTEGER PRIMARY KEY AUTOINCREMENT, qualification_name TEXT UNIQUE, @@ -89,7 +90,8 @@ ); """ -CREATE_PARTICIPANT_GROUP_QUALIFICATIONS_MAPPING_TABLE = """ +# Mapping between Mephisto qualifications and Prolific Participant Groups +CREATE_IF_NOT_EXISTS_QUALIFICATIONS_TABLE = """ CREATE TABLE IF NOT EXISTS qualifications ( id INTEGER PRIMARY KEY AUTOINCREMENT, prolific_participant_group_id TEXT, @@ -99,3 +101,5 @@ creation_date DATETIME DEFAULT CURRENT_TIMESTAMP ); """ + +CREATE_IF_NOT_EXISTS_MIGRATIONS_TABLE = local_database_tables.CREATE_IF_NOT_EXISTS_MIGRATIONS_TABLE diff --git a/mephisto/abstractions/providers/prolific/prolific_provider.py b/mephisto/abstractions/providers/prolific/prolific_provider.py index 09cab0d93..6efa30b45 100644 --- a/mephisto/abstractions/providers/prolific/prolific_provider.py +++ b/mephisto/abstractions/providers/prolific/prolific_provider.py @@ -332,7 +332,7 @@ def setup_resources_for_task_run( if q.qualification_name in qualification_names ] self.datastore.create_qualification_mapping( - run_id=task_run_id, + task_run_id=task_run_id, prolific_participant_group_id=prolific_participant_group.id, qualifications=qualifications, qualification_ids=qualifications_ids, @@ -361,7 +361,7 @@ def setup_resources_for_task_run( # Register TaskRun in Datastore self.datastore.register_run( - run_id=task_run_id, + task_run_id=task_run_id, prolific_workspace_id=prolific_workspace.id, prolific_project_id=prolific_project.id, prolific_study_config_path=config_dir, @@ -377,6 +377,8 @@ def setup_resources_for_task_run( task_run_id=task_run_id, status=StudyStatus.ACTIVE, ) + self.datastore.new_run_mapping(prolific_study.id, task_run_id) + logger.debug( f'{self.log_prefix}Prolific Study "{prolific_study.id}" has been saved into datastore' ) diff --git a/mephisto/abstractions/providers/prolific/prolific_unit.py 
b/mephisto/abstractions/providers/prolific/prolific_unit.py index 5b0b2d378..ac4814908 100644 --- a/mephisto/abstractions/providers/prolific/prolific_unit.py +++ b/mephisto/abstractions/providers/prolific/prolific_unit.py @@ -230,7 +230,7 @@ def set_db_status(self, status: str) -> None: task_run_id = self.get_task_run().db_id datastore_task_run = self.datastore.get_run(task_run_id) self.datastore.set_available_places_for_run( - run_id=task_run_id, + task_run_id=task_run_id, actual_available_places=datastore_task_run["actual_available_places"] - 1, listed_available_places=datastore_task_run["listed_available_places"] - 1, ) @@ -309,7 +309,7 @@ def launch(self, task_url: str) -> None: actual_available_places += 1 self.datastore.set_available_places_for_run( - run_id=task_run_id, + task_run_id=task_run_id, actual_available_places=actual_available_places, listed_available_places=listed_available_places, ) @@ -346,7 +346,7 @@ def expire(self) -> float: listed_places_decrement = 1 if task_run.get_is_completed() else 0 self.datastore.set_available_places_for_run( - run_id=task_run.db_id, + task_run_id=task_run.db_id, actual_available_places=actual_available_places - 1, listed_available_places=listed_available_places - listed_places_decrement, ) @@ -403,7 +403,7 @@ def new(db: "MephistoDB", assignment: "Assignment", index: int, pay_amount: floa ) datastore.create_unit( unit_id=unit.db_id, - run_id=assignment.task_run_id, + task_run_id=assignment.task_run_id, prolific_study_id=task_run_details["prolific_study_id"], ) logger.debug(f"{ProlificUnit.log_prefix}Unit was created in datastore successfully!") diff --git a/mephisto/abstractions/test/crowd_provider_tester.py b/mephisto/abstractions/test/crowd_provider_tester.py index 4cfc1ebf2..e6f8eb732 100644 --- a/mephisto/abstractions/test/crowd_provider_tester.py +++ b/mephisto/abstractions/test/crowd_provider_tester.py @@ -5,22 +5,17 @@ # LICENSE file in the root directory of this source tree. -import unittest -from typing import Optional, Tuple, Type -import tempfile -import mephisto import os import shutil +import tempfile +import unittest +from typing import Type + +from mephisto.abstractions.crowd_provider import CrowdProvider +from mephisto.abstractions.database import MephistoDB +from mephisto.abstractions.databases.local_database import LocalMephistoDB from mephisto.data_model.requester import Requester from mephisto.data_model.worker import Worker -from mephisto.abstractions.database import ( - MephistoDB, - MephistoDBException, - EntryAlreadyExistsException, - EntryDoesNotExistException, -) -from mephisto.abstractions.databases.local_database import LocalMephistoDB -from mephisto.abstractions.crowd_provider import CrowdProvider class CrowdProviderTests(unittest.TestCase): diff --git a/mephisto/abstractions/test/data_model_database_tester.py b/mephisto/abstractions/test/data_model_database_tester.py index 00032a81d..600567517 100644 --- a/mephisto/abstractions/test/data_model_database_tester.py +++ b/mephisto/abstractions/test/data_model_database_tester.py @@ -5,40 +5,39 @@ # LICENSE file in the root directory of this source tree. 
+import json import unittest -from typing import Optional, Tuple -from mephisto.utils.testing import ( - get_test_assignment, - get_test_project, - get_test_requester, - get_test_task, - get_test_task_run, - get_test_worker, - get_test_unit, - get_test_agent, -) -from mephisto.abstractions.providers.mock.provider_type import PROVIDER_TYPE -from mephisto.data_model.constants import NO_PROJECT_NAME -from mephisto.data_model.agent import Agent, OnboardingAgent +from typing import Optional + +from omegaconf import OmegaConf + from mephisto.abstractions.blueprint import AgentState -from mephisto.data_model.unit import Unit +from mephisto.abstractions.database import MephistoDB +from mephisto.abstractions.providers.mock.provider_type import PROVIDER_TYPE +from mephisto.data_model.agent import Agent +from mephisto.data_model.agent import OnboardingAgent from mephisto.data_model.assignment import Assignment +from mephisto.data_model.constants import NO_PROJECT_NAME from mephisto.data_model.constants.assignment_state import AssignmentState from mephisto.data_model.project import Project +from mephisto.data_model.qualification import Qualification from mephisto.data_model.requester import Requester from mephisto.data_model.task import Task -from mephisto.data_model.task_run import TaskRun, TaskRunArgs -from mephisto.data_model.qualification import Qualification +from mephisto.data_model.task_run import TaskRun +from mephisto.data_model.task_run import TaskRunArgs +from mephisto.data_model.unit import Unit from mephisto.data_model.worker import Worker -from mephisto.abstractions.database import ( - MephistoDB, - MephistoDBException, - EntryAlreadyExistsException, - EntryDoesNotExistException, -) - -from omegaconf import OmegaConf -import json +from mephisto.utils.db import EntryAlreadyExistsException +from mephisto.utils.db import EntryDoesNotExistException +from mephisto.utils.db import MephistoDBException +from mephisto.utils.testing import get_test_agent +from mephisto.utils.testing import get_test_assignment +from mephisto.utils.testing import get_test_project +from mephisto.utils.testing import get_test_requester +from mephisto.utils.testing import get_test_task +from mephisto.utils.testing import get_test_task_run +from mephisto.utils.testing import get_test_unit +from mephisto.utils.testing import get_test_worker class BaseDatabaseTests(unittest.TestCase): diff --git a/mephisto/client/cli.py b/mephisto/client/cli.py index 82212b995..ef42dca60 100644 --- a/mephisto/client/cli.py +++ b/mephisto/client/cli.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 + # Copyright (c) Meta Platforms and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
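The cli.py changes below replace bare `rich` `print` calls with a `ConsoleWriter` instance named `logger`. The new `mephisto/utils/console_writer.py` itself is not shown in this part of the patch, so the following is only a minimal sketch of such a writer, assuming it simply forwards rich-markup strings to the console:

```python
# Hypothetical sketch of mephisto/utils/console_writer.py; the real class may
# add verbosity handling or traceback logging.
from rich import print as rich_print


class ConsoleWriter:
    """Logger-like facade for CLI console output."""

    def info(self, message: str = "") -> None:
        rich_print(message)  # rich renders [green]...[/green]-style markup

    def warning(self, message: str = "") -> None:
        rich_print(message)

    def error(self, message: str = "") -> None:
        rich_print(message)

    def exception(self, message: str = "") -> None:
        rich_print(message)
```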
@@ -12,7 +13,6 @@ import rich_click as click # type: ignore from flask.cli import pass_script_info from flask.cli import ScriptInfo -from rich import print from rich.markdown import Markdown from rich_click import RichCommand from rich_click import RichGroup @@ -49,7 +49,10 @@ set_custom_validators_js_env_var, ) from mephisto.operations.registry import get_valid_provider_types +from mephisto.tools.db_data_porter import DBDataPorter +from mephisto.tools.db_data_porter.constants import DEFAULT_CONFLICT_RESOLVER from mephisto.tools.scripts import build_custom_bundle +from mephisto.utils.console_writer import ConsoleWriter from mephisto.utils.rich import console from mephisto.utils.rich import create_table @@ -59,6 +62,8 @@ FORM_COMPOSER__TOKEN_SETS_VALUES_CONFIG_NAME = "token_sets_values_config.json" FORM_COMPOSER__SEPARATE_TOKEN_VALUES_CONFIG_NAME = "separate_token_values_config.json" +logger = ConsoleWriter() + @click.group(cls=RichGroup) def cli(): @@ -104,7 +109,7 @@ def config(identifier, value): else: # Write mode: add_config_arg(section, key, value) - print(f"[green]{identifier} succesfully updated to: {value}[/green]") + logger.info(f"[green]{identifier} succesfully updated to: {value}[/green]") @cli.command("check", cls=RichCommand) @@ -117,10 +122,10 @@ def check(): db = LocalMephistoDB() get_mock_requester(db) except Exception as e: - print("\n[red]Something went wrong.[/red]") + logger.exception("\n[red]Something went wrong.[/red]") click.echo(e) return - print("\n[green]Mephisto seems to be set up correctly.[/green]\n") + logger.info("\n[green]Mephisto seems to be set up correctly.[/green]\n") @cli.command("requesters", cls=RichCommand) @@ -140,7 +145,7 @@ def list_requesters(): requester_table.add_row(*requester_vals) console.print(requester_table) else: - print("[red]No requesters found[/red]") + logger.error("[red]No requesters found[/red]") @cli.command("register", cls=RichCommand, context_settings={"ignore_unknown_options": True}) @@ -148,14 +153,14 @@ def list_requesters(): def register_provider(args): """Register a requester with a crowd provider""" if len(args) == 0: - print("\n[red]Usage: mephisto register arg1=value arg2=value[/red]") - print("\n[b]Valid Providers[/b]") + logger.error("\n[red]Usage: mephisto register arg1=value arg2=value[/red]") + logger.info("\n[b]Valid Providers[/b]") provider_text = """""" for provider in get_valid_provider_types(): provider_text += "\n* " + provider provider_text_markdown = Markdown(provider_text) console.print(provider_text_markdown) - print("") + logger.info("") return from mephisto.abstractions.databases.local_database import LocalMephistoDB @@ -186,7 +191,7 @@ def register_provider(args): requester_table.add_row(*arg_values) console.print(requester_table) else: - print("[red]Requester has no args[/red]") + logger.error("[red]Requester has no args[/red]") return try: @@ -195,7 +200,7 @@ def register_provider(args): click.echo(str(e)) if parsed_options.name is None: - print("[red]No name was specified for the requester.[/red]") + logger.error("[red]No name was specified for the requester.[/red]") db = LocalMephistoDB() requesters = db.find_requesters(requester_name=parsed_options.name) @@ -205,7 +210,7 @@ def register_provider(args): requester = requesters[0] try: requester.register(parsed_options) - print("[green]Registered successfully.[/green]") + logger.info("[green]Registered successfully.[/green]") except Exception as e: click.echo(str(e)) @@ -232,7 +237,7 @@ def print_non_markdown_list(items: List[str]): VALID_SCRIPT_TYPES 
= ["local_db", "heroku", "metrics", "mturk", "form_composer"] if script_type is None or script_type.strip() not in VALID_SCRIPT_TYPES: - print("") + logger.info("") raise click.UsageError( "You must specify a valid script_type from below. \n\nValid script types are:" + print_non_markdown_list(VALID_SCRIPT_TYPES) @@ -300,7 +305,7 @@ def print_non_markdown_list(items: List[str]): if script_name is None or ( script_name not in script_type_to_scripts_data[script_type]["valid_script_names"] ): - print("") + logger.info("") raise click.UsageError( "You must specify a valid script_name from below. \n\nValid script names are:" + print_non_markdown_list( @@ -324,7 +329,7 @@ def metrics_cli(args): ) if len(args) == 0 or args[0] not in ["install", "view", "cleanup"]: - print("\n[red]Usage: mephisto metrics [/red]") + logger.error("\n[red]Usage: mephisto metrics [/red]") metrics_table = create_table(["Property", "Value"], "Metrics Arguments") metrics_table.add_row("install", f"Installs Prometheus and Grafana to {METRICS_DIR}") metrics_table.add_row( @@ -388,7 +393,7 @@ def review_app( os.environ["HOST"] = host os.environ["PORT"] = str(port) - print(f'[green]Review APP will start on "{app_url}" address.[/green]') + logger.info(f'[green]Review APP will start on "{app_url}" address.[/green]') # Set up Review App Client if not skip_build: @@ -401,9 +406,9 @@ def review_app( # Install JS requirements if os.path.exists(os.path.join(client_path, "node_modules")): - print(f"[blue]JS requirements are already installed.[/blue]") + logger.info(f"[blue]JS requirements are already installed.[/blue]") else: - print(f"[blue]Installing JS requirements started.[/blue]") + logger.info(f"[blue]Installing JS requirements started.[/blue]") subprocess.call(["ls"], cwd=client_path) app_started = subprocess.call(["npm", "install"], cwd=client_path) if app_started != 0: @@ -411,19 +416,19 @@ def review_app( "Please make sure npm is installed, " "otherwise view the above error for more info." 
) - print(f"[blue]Installing JS requirements finished.[/blue]") + logger.info(f"[blue]Installing JS requirements finished.[/blue]") if os.path.exists(os.path.join(client_path, "build", "index.html")) and not force_rebuild: - print(f"[blue]React bundle is already built.[/blue]") + logger.info(f"[blue]React bundle is already built.[/blue]") else: - print(f"[blue]Building React bundle started.[/blue]") + logger.info(f"[blue]Building React bundle started.[/blue]") build_custom_bundle( review_app_path, force_rebuild=force_rebuild, webapp_name=client_dir, build_command="build", ) - print(f"[blue]Building React bundle finished.[/blue]") + logger.info(f"[blue]Building React bundle finished.[/blue]") # Set debug debug = debug if debug is not None else get_debug_flag() @@ -533,15 +538,15 @@ def form_composer_config( else: app_path = _get_form_composer_app_path() app_data_path = os.path.join(app_path, FORM_COMPOSER__DATA_DIR_NAME) - print(f"[blue]Using config directory: {app_data_path}[/blue]") + logger.info(f"[blue]Using config directory: {app_data_path}[/blue]") # Validate param values if not os.path.exists(app_data_path): - print(f"[red]Directory '{app_data_path}' does not exist[/red]") + logger.error(f"[red]Directory '{app_data_path}' does not exist[/red]") return None if use_presigned_urls and not update_file_location_values: - print( + logger.error( f"[red]Parameter `--use-presigned-urls` can be used " f"only with `--update-file-location-values` option[/red]" ) @@ -556,7 +561,7 @@ def form_composer_config( # Run the command if verify: - print(f"Started configs verification") + logger.info(f"Started configs verification") verify_form_composer_configs( task_data_config_path=task_data_config_path, form_config_path=form_config_path, @@ -565,10 +570,10 @@ def form_composer_config( task_data_config_only=False, data_path=app_data_path, ) - print(f"Finished configs verification") + logger.info(f"Finished configs verification") elif update_file_location_values: - print( + logger.info( f"[green]Started updating '{FORM_COMPOSER__SEPARATE_TOKEN_VALUES_CONFIG_NAME}' " f"with file URLs from '{update_file_location_values}'[/green]" ) @@ -578,12 +583,12 @@ def form_composer_config( separate_token_values_config_path=separate_token_values_config_path, use_presigned_urls=use_presigned_urls, ) - print(f"[green]Finished successfully[/green]") + logger.info(f"[green]Finished successfully[/green]") else: - print("`--update-file-location-values` must be a valid S3 URL") + logger.info("`--update-file-location-values` must be a valid S3 URL") elif permutate_separate_tokens: - print( + logger.info( f"[green]Started updating '{FORM_COMPOSER__TOKEN_SETS_VALUES_CONFIG_NAME}' " f"with permutated separate-token values[/green]" ) @@ -591,10 +596,10 @@ def form_composer_config( separate_token_values_config_path=separate_token_values_config_path, token_sets_values_config_path=token_sets_values_config_path, ) - print(f"[green]Finished successfully[/green]") + logger.info(f"[green]Finished successfully[/green]") elif extrapolate_token_sets: - print( + logger.info( f"[green]Started extrapolating token sets values " f"from '{FORM_COMPOSER__TOKEN_SETS_VALUES_CONFIG_NAME}' [/green]" ) @@ -604,10 +609,10 @@ def form_composer_config( task_data_config_path=task_data_config_path, data_path=app_data_path, ) - print(f"[green]Finished successfully[/green]") + logger.info(f"[green]Finished successfully[/green]") else: - print( + logger.error( f"[red]" f"This command must have one of following parameters:" f"\n-v/--verify" @@ -618,5 +623,190 
@@ def form_composer_config( ) +@cli.command("db", cls=RichCommand) +@click.argument("action_name", required=True, nargs=1) +@click.option("-d", "--dump-file", type=(str), default=None) +@click.option("-i", "--export-indent", type=(int), default=None) +@click.option("-tn", "--export-tasks-by-names", type=(str), multiple=True, default=None) +@click.option("-ti", "--export-tasks-by-ids", type=(str), multiple=True, default=None) +@click.option("-tr", "--export-task-runs-by-ids", type=(str), multiple=True, default=None) +@click.option("-trs", "--export-task-runs-since-date", type=(str), default=None) +@click.option("-tl", "--export-labels", type=(str), multiple=True, default=None) +@click.option("-de", "--delete-exported-data", type=(bool), default=False, is_flag=True) +@click.option("-r", "--randomize-legacy-ids", type=(bool), default=False, is_flag=True) +@click.option("-l", "--label-name", type=(str), default=None) +@click.option("-cr", "--conflict-resolver", type=(str), default=DEFAULT_CONFLICT_RESOLVER) +@click.option("-k", "--keep-import-metadata", type=(bool), default=False, is_flag=True) +@click.option("-b", "--backup-file", type=(str), default=None) +@click.option("-v", "--verbosity", type=(int), default=0) +def db( + action_name: str, + dump_file: Optional[str] = None, + export_indent: Optional[int] = None, + export_tasks_by_names: Optional[List[str]] = None, + export_tasks_by_ids: Optional[List[str]] = None, + export_task_runs_by_ids: Optional[List[str]] = None, + export_task_runs_since_date: Optional[str] = None, + export_labels: Optional[List[str]] = None, + delete_exported_data: bool = False, + randomize_legacy_ids: bool = False, + label_name: Optional[str] = None, + conflict_resolver: Optional[str] = DEFAULT_CONFLICT_RESOLVER, + keep_import_metadata: Optional[bool] = False, + backup_file: Optional[str] = None, + verbosity: int = 0, +): + """ + Operations with Mephisto DB and provider-specific datastores. + + Commands: + 1. mephisto db export + This command exports data from Mephisto DB and provider-specific datastores + as a combination of (i) a JSON file, and (ii) an archived `data` catalog with related files. + + If no parameter passed, full data dump (i.e. backup) will be created. + + To pass a list of values for one command option, + simply repeat that option name before each value. + + Options (all optional): + `-tn/--export-tasks-by-names` - names of Tasks that will be exported + `-ti/--export-tasks-by-ids` - ids of Tasks that will be exported + `-tr/--export-task-runs-by-ids` - ids of TaskRuns that will be exported + `-trs/--export-task-runs-since-date` - only objects created after this + ISO8601 datetime will be exported + `-tl/--export-labels` - only data imported under these labels will be exported + `-de/--delete-exported-data` - after exporting data, delete it from local DB + `-r/--randomize-legacy-ids` - replace legacy autoincremented ids with + new pseudo-random ids to avoid conflicts during data merging + `-i/--export-indent` - make dump easy to read via formatting JSON with indentations + `-v/--verbosity` - write more informative messages about progress + (Default 0. Values: 0, 1) + + + 2. mephisto db import --dump-file + + This command imports data from a dump file created by `mephisto db export` command. 
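+
+        Example (hypothetical dump file name):
+
+            mephisto db import --dump-file 2024_04_24_12_00_00_mephisto_dump.json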
+ + Options: + `-d/--dump-file` - location of the __***.json__ dump file (filename if created in + `/outputs/export` folder, or absolute filepath) + `-cr/--conflict-resolver` (Optional) - name of Python class + to be used for resolving merging conflicts (when your local DB already has a row + with same unique field value as a DB row in the dump data) + `-l/--label-name` - a short string serving as a reference for the ported data + (stored in `imported_data` table), so later you can export the imported data + with `--export-labels` export option + `-k/--keep-import-metadata` - write data from `imported_data` table of the dump + (by default it's not imported) + `-v/--verbosity` - level of logging (default: 0; values: 0, 1) + + 3. mephisto db backup + + Creates full backup of all current data (Mephisto DB, provider-specific datastores, + and related files) on local machine. + + 4. mephisto db restore --backup-file + + Restores all data (Mephisto DB, provider-specific datastores, and related files) + from a backup archive. + + Options: + `-b/--backup-file` - location of the __*.zip__ backup file (filename if created in + `/outputs/backup` folder, or absolute filepath) + `-v/--verbosity` - level of logging (default: 0; values: 0, 1) + """ + porter = DBDataPorter() + + # --- EXPORT --- + if action_name == "export": + has_conflicting_task_runs_options = len(list(filter(bool, [ + export_tasks_by_names, + export_tasks_by_ids, + export_task_runs_by_ids, + export_task_runs_since_date, + export_labels, + ]))) > 1 + + if has_conflicting_task_runs_options: + logger.warning( + "[yellow]" + "You cannot use following options together:" + "\n\t--export-tasks-by-names" + "\n\t--export-tasks-by-ids" + "\n\t--export-task-runs-by-ids" + "\n\t--export-task-runs-since-date" + "\n\t--export-labels" + "\nUse one of them or none of them to export all data." + "[/yellow]" + ) + exit() + + logger.info(f"Started exporting") + + export_results = porter.export_dump( + json_indent=export_indent, + task_names=export_tasks_by_names, + task_ids=export_tasks_by_ids, + task_run_ids=export_task_runs_by_ids, + task_runs_since_date=export_task_runs_since_date, + task_runs_labels=export_labels, + delete_exported_data=delete_exported_data, + randomize_legacy_ids=randomize_legacy_ids, + verbosity=verbosity, + ) + + data_files_line = "" + if export_results["data_path"]: + data_files_line = f"\n\t- Data files dump - {export_results['data_path']}" + + backup_line = "" + if export_results["backup_path"]: + backup_line = f"\n\t- Backup - {export_results['backup_path']}" + + logger.info( + f"[green]" + f"Finished successfully! " + f"\nFiles created:" + f"\n\t- Database dump - {export_results['db_path']}" + f"{data_files_line}" + f"{backup_line}" + f"[/green]" + ) + + # --- IMPORT --- + elif action_name == "import": + logger.info(f"Started importing from dump '{dump_file}'") + porter.import_dump( + dump_file_name_or_path=dump_file, + conflict_resolver_name=conflict_resolver, + label=label_name, + keep_import_metadata=keep_import_metadata, + verbosity=verbosity, + ) + logger.info(f"[green]Finished successfully[/green]") + + # --- BACKUP --- + elif action_name == "backup": + logger.info(f"Started making backup") + backup_path = porter.make_backup() + logger.info(f"[green]Finished successfully! 
File: '{backup_path}'[/green]") + + # --- RESTORE --- + elif action_name == "restore": + logger.info(f"Started restoring from backup '{backup_file}'") + porter.restore_from_backup(backup_file_name_or_path=backup_file, verbosity=verbosity) + logger.info(f"[green]Finished successfully[/green]") + + # Otherwise, error + else: + logger.error( + f"[red]" + f"Unexpected action name '{action_name}'. Available: export, import, backup, restore." + f"[/red]" + ) + exit() + + if __name__ == "__main__": cli() diff --git a/mephisto/generators/form_composer/config_validation/utils.py b/mephisto/generators/form_composer/config_validation/utils.py index 2e945d7df..5d9bf7333 100644 --- a/mephisto/generators/form_composer/config_validation/utils.py +++ b/mephisto/generators/form_composer/config_validation/utils.py @@ -9,6 +9,7 @@ from pathlib import Path from typing import Any from typing import List +from typing import Optional from typing import Tuple from typing import Union from urllib.parse import urljoin @@ -72,10 +73,16 @@ def read_config_file( return config_data -def make_error_message(main_message: str, error_list: List[str], indent: int = 2) -> str: +# TODO: Move this function and its tests into `mephisto.utils`, as it is useful beyond this app +def make_error_message( + main_message: str, + error_list: List[str], + indent: int = 2, + list_title: Optional[str] = "Errors", +) -> str: prefix = "\n" + (" " * indent) + "- " errors_bullets = prefix + prefix.join(map(str, error_list)) - error_title = f"{main_message.rstrip('.')}. Errors:" if main_message else "" + error_title = f"{main_message.rstrip('.')}. {list_title}:" if main_message else "" return error_title + errors_bullets diff --git a/mephisto/operations/operator.py b/mephisto/operations/operator.py index 202072173..a5b63bd1e 100644 --- a/mephisto/operations/operator.py +++ b/mephisto/operations/operator.py @@ -11,7 +11,6 @@ import threading import signal import asyncio -import traceback from mephisto.operations.datatypes import LiveTaskRun, LoopWrapper @@ -27,8 +26,9 @@ from mephisto.abstractions.blueprints.mixins.onboarding_required import ( OnboardingRequired, ) -from mephisto.abstractions.database import MephistoDB, EntryDoesNotExistException +from mephisto.abstractions.database import MephistoDB from mephisto.data_model.qualification import QUAL_NOT_EXIST +from mephisto.utils.db import EntryDoesNotExistException from mephisto.utils.qualifications import make_qualification_dict from mephisto.operations.task_launcher import TaskLauncher from mephisto.operations.client_io_handler import ClientIOHandler @@ -52,13 +52,12 @@ ) from omegaconf import DictConfig, OmegaConf -logger = get_logger(name=__name__) - if TYPE_CHECKING: from mephisto.abstractions.blueprint import Blueprint from mephisto.abstractions.crowd_provider import CrowdProvider from mephisto.abstractions.architect import Architect +logger = get_logger(name=__name__) RUN_STATUS_POLL_TIME = 10 diff --git a/mephisto/review_app/server/__init__.py b/mephisto/review_app/server/__init__.py index 5b4a2a971..cc7ac0768 100644 --- a/mephisto/review_app/server/__init__.py +++ b/mephisto/review_app/server/__init__.py @@ -18,11 +18,11 @@ from werkzeug.exceptions import HTTPException as WerkzeugHTTPException from werkzeug.utils import import_string -from mephisto.abstractions.database import EntryDoesNotExistException from mephisto.abstractions.databases.local_database import LocalMephistoDB from mephisto.abstractions.providers.prolific.api import status from 
mephisto.abstractions.providers.prolific.api.exceptions import ProlificException from mephisto.tools.data_browser import DataBrowser +from mephisto.utils.db import EntryDoesNotExistException from mephisto.utils.logger_core import get_logger from .urls import init_urls diff --git a/mephisto/review_app/server/api/views/qualification_workers_view.py b/mephisto/review_app/server/api/views/qualification_workers_view.py index 7cb1a6cd4..21e23f1be 100644 --- a/mephisto/review_app/server/api/views/qualification_workers_view.py +++ b/mephisto/review_app/server/api/views/qualification_workers_view.py @@ -20,7 +20,7 @@ def _find_granted_qualifications(db: LocalMephistoDB, qualification_id: str) -> """Return the granted qualifications in the database by the given qualification id""" with db.table_access_condition: - conn = db._get_connection() + conn = db.get_connection() c = conn.cursor() c.execute( f""" @@ -50,13 +50,13 @@ def _find_unit_reviews( params.append(nonesafe_int(task_id)) with db.table_access_condition: - conn = db._get_connection() + conn = db.get_connection() c = conn.cursor() c.execute( f""" SELECT * FROM unit_review WHERE (updated_qualification_id = ?1) AND (worker_id = ?2) {task_query} - ORDER BY created_at ASC; + ORDER BY creation_date ASC; """, params, ) @@ -89,7 +89,7 @@ def get(self, qualification_id: int) -> dict: if unit_reviews: latest_unit_review = unit_reviews[-1] unit_review_id = latest_unit_review["id"] - granted_at = latest_unit_review["created_at"] + granted_at = latest_unit_review["creation_date"] else: continue @@ -98,7 +98,7 @@ def get(self, qualification_id: int) -> dict: "worker_id": gq["worker_id"], "value": gq["value"], "unit_review_id": unit_review_id, # latest grant of this qualification - "granted_at": granted_at, # maps to `unit_review.created_at` column + "granted_at": granted_at, # maps to `unit_review.creation_date` column } ) diff --git a/mephisto/review_app/server/api/views/qualifications_view.py b/mephisto/review_app/server/api/views/qualifications_view.py index d5f2596b6..601559fb7 100644 --- a/mephisto/review_app/server/api/views/qualifications_view.py +++ b/mephisto/review_app/server/api/views/qualifications_view.py @@ -22,7 +22,7 @@ def _find_qualifications_by_ids( debug: bool = False, ) -> List[Qualification]: with db.table_access_condition: - conn = db._get_connection() + conn = db.get_connection() c = conn.cursor() diff --git a/mephisto/review_app/server/api/views/stats_view.py b/mephisto/review_app/server/api/views/stats_view.py index 192f8bb41..1629c6701 100644 --- a/mephisto/review_app/server/api/views/stats_view.py +++ b/mephisto/review_app/server/api/views/stats_view.py @@ -42,7 +42,7 @@ def _find_unit_reviews( if status: params.append(status) - since_query = "created_at >= ?" if since else "" + since_query = "creation_date >= ?" 
if since else "" if since: params.append(since) @@ -67,13 +67,13 @@ def _find_unit_reviews( params.append(nonesafe_int(limit)) with db.table_access_condition: - conn = db._get_connection() + conn = db.get_connection() c = conn.cursor() c.execute( f""" SELECT * FROM unit_review {where_query} - ORDER BY created_at ASC {limit_query}; + ORDER BY creation_date ASC {limit_query}; """, params, ) @@ -128,7 +128,7 @@ def _find_units_for_worker( params.append(nonesafe_int(limit)) with db.table_access_condition: - conn = db._get_connection() + conn = db.get_connection() c = conn.cursor() c.execute( f""" diff --git a/mephisto/review_app/server/api/views/tasks_view.py b/mephisto/review_app/server/api/views/tasks_view.py index 5abaa86d8..99c3772cd 100644 --- a/mephisto/review_app/server/api/views/tasks_view.py +++ b/mephisto/review_app/server/api/views/tasks_view.py @@ -17,7 +17,7 @@ def _find_tasks(db, debug: bool = False) -> List[StringIDRow]: with db.table_access_condition: - conn = db._get_connection() + conn = db.get_connection() c = conn.cursor() c.execute( diff --git a/mephisto/review_app/server/api/views/worker_block_view.py b/mephisto/review_app/server/api/views/worker_block_view.py index 7b3643fb7..a1d49e4d4 100644 --- a/mephisto/review_app/server/api/views/worker_block_view.py +++ b/mephisto/review_app/server/api/views/worker_block_view.py @@ -11,9 +11,9 @@ from flask.views import MethodView from werkzeug.exceptions import BadRequest -from mephisto.abstractions.database import EntryDoesNotExistException from mephisto.data_model.unit import Unit from mephisto.data_model.worker import Worker +from mephisto.utils.db import EntryDoesNotExistException def _update_blocked_worker_in_unit_review( @@ -25,14 +25,14 @@ def _update_blocked_worker_in_unit_review( """Update unit review in the db with blocking Worker value""" with db.table_access_condition: - conn = db._get_connection() + conn = db.get_connection() c = conn.cursor() c.execute( """ SELECT * FROM unit_review WHERE (unit_id = ?) AND (worker_id = ?) - ORDER BY created_at ASC; + ORDER BY creation_date ASC; """, (unit_id, worker_id), ) diff --git a/mephisto/review_app/server/api/views/worker_granted_qualifications_view.py b/mephisto/review_app/server/api/views/worker_granted_qualifications_view.py index 7d894990a..d756f2a2b 100644 --- a/mephisto/review_app/server/api/views/worker_granted_qualifications_view.py +++ b/mephisto/review_app/server/api/views/worker_granted_qualifications_view.py @@ -18,7 +18,7 @@ def _find_granted_qualifications(db: LocalMephistoDB, worker_id: str) -> List[St """Return the granted qualifications in the database by the given worker id""" with db.table_access_condition: - conn = db._get_connection() + conn = db.get_connection() c = conn.cursor() c.execute( """ @@ -27,14 +27,14 @@ def _find_granted_qualifications(db: LocalMephistoDB, worker_id: str) -> List[St gq.worker_id, gq.qualification_id, gq.granted_qualification_id, - ur.created_at AS granted_at + ur.creation_date AS granted_at FROM granted_qualifications AS gq LEFT JOIN ( SELECT updated_qualification_id, - created_at + creation_date FROM unit_review - ORDER BY created_at DESC + ORDER BY creation_date DESC /* We’re retrieving unit_review data only for the latest update of the worker-qualification pair. 
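These review-app queries all move from the old `created_at` column to the renamed `creation_date` column. As a sketch only (using the standard `sqlite3` module directly rather than Mephisto's DB wrappers), the "latest review per worker-qualification pair" lookup can equivalently be expressed with an aggregate:

```python
import sqlite3


def latest_review_dates(conn: sqlite3.Connection) -> list:
    """Latest unit_review creation_date per (qualification, worker) pair."""
    c = conn.cursor()
    c.execute(
        """
        SELECT updated_qualification_id, worker_id,
               MAX(creation_date) AS granted_at
        FROM unit_review
        GROUP BY updated_qualification_id, worker_id;
        """
    )
    return c.fetchall()
```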
@@ -70,7 +70,7 @@ def get(self, worker_id: int) -> dict: "worker_id": gq["worker_id"], "qualification_id": gq["qualification_id"], "value": int(gq["value"]), - "granted_at": gq["worker_id"], # maps to `unit_review.created_at` column + "granted_at": gq["granted_at"], # maps to `unit_review.creation_date` column }, ) diff --git a/mephisto/review_app/server/db_queries.py b/mephisto/review_app/server/db_queries.py index ebb84555b..d244d0c2d 100644 --- a/mephisto/review_app/server/db_queries.py +++ b/mephisto/review_app/server/db_queries.py @@ -18,7 +18,7 @@ def find_units( debug: bool = False, ) -> List[StringIDRow]: with db.table_access_condition: - conn = db._get_connection() + conn = db.get_connection() params = [] diff --git a/mephisto/tools/db_data_porter/__init__.py b/mephisto/tools/db_data_porter/__init__.py new file mode 100644 index 000000000..0dfc664a9 --- /dev/null +++ b/mephisto/tools/db_data_porter/__init__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .db_data_porter import DBDataPorter diff --git a/mephisto/tools/db_data_porter/backups.py b/mephisto/tools/db_data_porter/backups.py new file mode 100644 index 000000000..9222917ad --- /dev/null +++ b/mephisto/tools/db_data_porter/backups.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import shutil +from distutils.dir_util import copy_tree +from pathlib import Path +from typing import List + +from mephisto.abstractions.database import MephistoDB +from mephisto.data_model.task_run import TaskRun +from mephisto.tools.db_data_porter.constants import AGENTS_TABLE_NAME +from mephisto.tools.db_data_porter.constants import ASSIGNMENTS_TABLE_NAME +from mephisto.tools.db_data_porter.constants import MEPHISTO_DUMP_KEY +from mephisto.tools.db_data_porter.constants import TASK_RUNS_TABLE_NAME +from mephisto.tools.db_data_porter.randomize_ids import get_old_pk_from_substitutions +from mephisto.utils import db as db_utils +from mephisto.utils.console_writer import ConsoleWriter +from mephisto.utils.dirs import get_data_dir +from mephisto.utils.dirs import get_mephisto_tmp_dir + +DEFAULT_ARCHIVE_FORMAT = "zip" + +logger = ConsoleWriter() + + +def _rename_dirs_with_new_pks(task_run_dirs: List[str], pk_substitutions: dict): + def rename_dir_with_new_pk(dir_path: str, substitutions: dict) -> str: + dump_id = substitutions.get(os.path.basename(dir_path)) + renamed_dir_path = os.path.join(os.path.dirname(dir_path), dump_id) + os.rename(dir_path, renamed_dir_path) + return renamed_dir_path + + task_runs_subs = pk_substitutions.get(MEPHISTO_DUMP_KEY, {}).get(TASK_RUNS_TABLE_NAME, {}) + if not task_runs_subs: + # Nothing to rename + return + + assignment_subs = pk_substitutions.get(MEPHISTO_DUMP_KEY, {}).get(ASSIGNMENTS_TABLE_NAME, {}) + agent_subs = pk_substitutions.get(MEPHISTO_DUMP_KEY, {}).get(AGENTS_TABLE_NAME, {}) + + task_run_dirs = [ + d for d in task_run_dirs if os.path.basename(d) in task_runs_subs.keys() + ] + for task_run_dir in task_run_dirs: + # Rename TaskRun dir + renamed_task_run_dir = rename_dir_with_new_pk(task_run_dir, task_runs_subs) + + # Rename Assignments dirs + assignments_dirs = [ + os.path.join(renamed_task_run_dir, d) for d in 
os.listdir(renamed_task_run_dir)
+            if d in assignment_subs.keys()
+        ]
+        for assignment_dir in assignments_dirs:
+            renamed_assignment_dir = rename_dir_with_new_pk(assignment_dir, assignment_subs)
+
+            # Rename Agents dirs
+            agents_dirs = [
+                os.path.join(renamed_assignment_dir, d) for d in os.listdir(renamed_assignment_dir)
+                if d in agent_subs.keys()
+            ]
+            for agent_dir in agents_dirs:
+                rename_dir_with_new_pk(agent_dir, agent_subs)
+
+
+def _export_data_dir_for_task_runs(
+    input_dir_path: str,
+    archive_file_path_without_ext: str,
+    task_runs: List[TaskRun],
+    pk_substitutions: dict,
+    _format: str = DEFAULT_ARCHIVE_FORMAT,
+    verbosity: int = 0,
+) -> bool:
+    tmp_dir = get_mephisto_tmp_dir()
+    tmp_export_dir = os.path.join(tmp_dir, "export")
+
+    task_run_data_dirs = [i.get_run_dir() for i in task_runs]
+    if not task_run_data_dirs:
+        return False
+
+    try:
+        tmp_task_run_dirs = []
+
+        # Copy all files for passed TaskRuns into tmp dir
+        for task_run_data_dir in task_run_data_dirs:
+            relative_dir = Path(task_run_data_dir).relative_to(input_dir_path)
+            tmp_task_run_dir = os.path.join(tmp_export_dir, relative_dir)
+
+            tmp_task_run_dirs.append(tmp_task_run_dir)
+
+            os.makedirs(tmp_task_run_dir, exist_ok=True)
+            copy_tree(task_run_data_dir, tmp_task_run_dir, verbose=verbosity)
+
+        _rename_dirs_with_new_pks(tmp_task_run_dirs, pk_substitutions)
+
+        # Create archive in export dir (use the passed `_format`, not a hard-coded "zip")
+        shutil.make_archive(
+            base_name=archive_file_path_without_ext,
+            format=_format,
+            root_dir=tmp_export_dir,
+        )
+    finally:
+        # Remove tmp dir
+        if os.path.exists(tmp_export_dir):
+            shutil.rmtree(tmp_export_dir)
+
+    return True
+
+
+def make_backup_file_path_by_timestamp(
+    backup_dir: str, timestamp: str, _format: str = DEFAULT_ARCHIVE_FORMAT,
+) -> str:
+    return os.path.join(backup_dir, f"{timestamp}_mephisto_backup.{_format}")
+
+
+def make_full_data_dir_backup(
+    backup_dir: str, timestamp: str, _format: str = DEFAULT_ARCHIVE_FORMAT,
+) -> str:
+    mephisto_data_dir = get_data_dir()
+    file_name_without_ext = f"{timestamp}_mephisto_backup"
+    archive_file_path_without_ext = os.path.join(backup_dir, file_name_without_ext)
+
+    shutil.make_archive(
+        base_name=archive_file_path_without_ext,
+        format=_format,
+        root_dir=mephisto_data_dir,
+    )
+
+    # Pass the bare timestamp here; passing `file_name_without_ext` would double the suffix
+    return make_backup_file_path_by_timestamp(backup_dir, timestamp, _format)
+
+
+def archive_and_copy_data_files(
+    db: "MephistoDB",
+    export_dir: str,
+    dump_name: str,
+    dump_data: dict,
+    pk_substitutions: dict,
+    verbosity: int = 0,
+    _format: str = DEFAULT_ARCHIVE_FORMAT,
+) -> bool:
+    mephisto_data_files_path = os.path.join(get_data_dir(), "data")
+    output_zip_file_base_name = os.path.join(export_dir, dump_name)  # name without extension
+
+    # Get TaskRuns for PKs in dump
+    task_runs: List[TaskRun] = []
+    for dump_task_run in dump_data[MEPHISTO_DUMP_KEY][TASK_RUNS_TABLE_NAME]:
+        task_runs_pk_field_name = db_utils.get_table_pk_field_name(db, TASK_RUNS_TABLE_NAME)
+        dump_pk = dump_task_run[task_runs_pk_field_name]
+        db_pk = get_old_pk_from_substitutions(dump_pk, pk_substitutions, TASK_RUNS_TABLE_NAME)
+        db_pk = db_pk or dump_pk
+        task_run: TaskRun = TaskRun.get(db, db_pk)
+        task_runs.append(task_run)
+
+    # Archive the data files related to the TaskRuns from the dump
+    exported = _export_data_dir_for_task_runs(
+        input_dir_path=mephisto_data_files_path,
+        archive_file_path_without_ext=output_zip_file_base_name,
+        task_runs=task_runs,
+        pk_substitutions=pk_substitutions,
+        _format=_format,
+        verbosity=verbosity,
+    )
+
+    return exported
+
+
+def restore_from_backup(
+    backup_file_path: str,
+    extract_dir: str,
+    _format: str = DEFAULT_ARCHIVE_FORMAT,
+    remove_backup: bool = False,
+):
+    try:
+        shutil.unpack_archive(filename=backup_file_path, extract_dir=extract_dir, format=_format)
+
+        if remove_backup:
+            Path(backup_file_path).unlink(missing_ok=True)
+    except Exception as e:
+        logger.exception(f"[red]Could not restore backup '{backup_file_path}'. Error: {e}[/red]")
+        raise
diff --git a/mephisto/tools/db_data_porter/conflict_resolvers/__init__.py b/mephisto/tools/db_data_porter/conflict_resolvers/__init__.py
new file mode 100644
index 000000000..86a1a7d02
--- /dev/null
+++ b/mephisto/tools/db_data_porter/conflict_resolvers/__init__.py
@@ -0,0 +1,28 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Meta Platforms and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from importlib import import_module
+from inspect import isclass
+from pkgutil import iter_modules
+
+from .base_merge_conflict_resolver import BaseMergeConflictResolver
+
+# Import all conflict resolver classes except the base class.
+# This is needed in case the user decides to write a custom class:
+# this way its name will be available as a parameter for the import command
+current_dir = os.path.dirname(os.path.abspath(__file__))
+for (_, module_name, _) in iter_modules([current_dir]):
+    module = import_module(f"{__name__}.{module_name}")
+    for attribute_name in dir(module):
+        attribute = getattr(module, attribute_name)
+
+        if (
+            isclass(attribute) and
+            issubclass(attribute, BaseMergeConflictResolver) and
+            attribute is not BaseMergeConflictResolver
+        ):
+            globals().update({attribute.__name__: attribute})
diff --git a/mephisto/tools/db_data_porter/conflict_resolvers/base_merge_conflict_resolver.py b/mephisto/tools/db_data_porter/conflict_resolvers/base_merge_conflict_resolver.py
new file mode 100644
index 000000000..da5dbb5e0
--- /dev/null
+++ b/mephisto/tools/db_data_porter/conflict_resolvers/base_merge_conflict_resolver.py
@@ -0,0 +1,240 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Meta Platforms and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from copy import deepcopy
+from types import MethodType
+from typing import Optional
+from typing import Tuple
+
+from mephisto.abstractions.database import MephistoDB
+from mephisto.utils.console_writer import ConsoleWriter
+from mephisto.utils.misc import serialize_date_to_python
+
+logger = ConsoleWriter()
+
+
+class BaseMergeConflictResolver:
+    """
+    When importing dump data into the local DB, some rows may refer to the same object
+    (e.g. two Task rows with the same value of the "name" column). This class contains
+    the default logic to resolve such merging conflicts
+    (implemented for all currently present DBs).
+
+    To change this default behavior, you can write your own conflict resolver class:
+    1. Add a new Python module next to this module (e.g. `my_conflict_resolver`)
+    2. This module must contain a class (e.g. `MyMergeConflictResolver`)
+        that inherits from either `BaseMergeConflictResolver`
+        or default resolver `DefaultMergeConflictResolver` (also in this directory)
+    3. To use this newly created class, specify its name in the import command:
+        `mephisto db import ... --conflict-resolver MyMergeConflictResolver`
+
+    The easiest place to start customization is to modify the `strategies_config` property,
+    and perhaps the `default_strategy_name` value
+    (see `DefaultMergeConflictResolver` as an example).
+
+    NOTE: All available providers must be present in `strategies_config`.
+    Table names (under each provider key) are optional, and if missing, `default_strategy_name`
+    will be used for all conflicts related to that table.
+    """
+
+    default_strategy_name = "pick_row_from_db"
+    strategies_config = {}
+
+    def __init__(self, db: "MephistoDB", provider_type: str):
+        self.db = db
+        self.provider_type = provider_type
+
+    @staticmethod
+    def _merge_rows_after_resolving(
+        table_pk_field_name: str, db_row: dict, dump_row: dict, resolved_row: dict,
+    ) -> dict:
+        """
+        After we've resolved merging conflicts in row fields,
+        we also need to select the resulting values for some standard fields:
+        1. Primary Key (choose DB row)
+        2. Creation date (choose the earliest)
+        3. Update date (choose the latest)
+        """
+
+        # 1. Save original PK from current DB
+        merged_row = deepcopy(resolved_row)
+        merged_row[table_pk_field_name] = db_row[table_pk_field_name]
+
+        # 2. Choose the earliest creation date if table has this field
+        if "creation_date" in resolved_row:
+            min_creation_date = min(
+                serialize_date_to_python(db_row["creation_date"]),
+                serialize_date_to_python(dump_row["creation_date"]),
+            )
+            merged_row["creation_date"] = min_creation_date
+
+        # 3. Choose the latest update date if table has this field
+        if "update_date" in resolved_row:
+            max_update_date = max(
+                serialize_date_to_python(db_row["update_date"]),
+                serialize_date_to_python(dump_row["update_date"]),
+            )
+            merged_row["update_date"] = max_update_date
+
+        return merged_row
+
+    @staticmethod
+    def _serialize_compared_fields_in_rows(
+        db_row: dict, dump_row: dict, compared_field_name: str,
+    ) -> Tuple[dict, dict]:
+        db_value = db_row[compared_field_name]
+        dump_value = dump_row[compared_field_name]
+
+        # Date fields
+        if compared_field_name.endswith("_at") or compared_field_name.endswith("_date"):
+            db_row[compared_field_name] = serialize_date_to_python(db_value)
+            dump_row[compared_field_name] = serialize_date_to_python(dump_value)
+
+        # Numeric fields (integer or float)
+        # Note: We cast both compared values to a numeric type
+        # ONLY when one value is numeric, and another one is a string
+        # (to avoid, for example, casting float to integer)
+        for _type in [int, float]:
+            if (
+                (isinstance(db_value, _type) and isinstance(dump_value, str)) or
+                (isinstance(db_value, str) and isinstance(dump_value, _type))
+            ):
+                db_row[compared_field_name] = _type(db_value)
+                dump_row[compared_field_name] = _type(dump_value)
+
+        return db_row, dump_row
+
+    def resolve(
+        self, table_name: str, table_pk_field_name: str, db_row: dict, dump_row: dict,
+    ) -> dict:
+        """
+        Default logic of validating `strategies_config`,
+        and resolving conflicts between database/datastore and dump rows
+        """
+        # Validate strategies
+
+        # 1. Providers must be set
+        provider_strategies = self.strategies_config.get(self.provider_type)
+        if not provider_strategies:
+            error_message = f"Could not find strategies for provider '{self.provider_type}'"
+            logger.error(f"[red]{error_message}[/red]")
+            raise ValueError(error_message)
+
+        # 2. If no table-specific strategy is set, use the default strategy - `pick_row_from_db`
+        table_strategy = provider_strategies.get(table_name)
+        strategy_method_name = self.default_strategy_name
+        # Custom strategy
+        if table_strategy:
+            strategy_method_name = table_strategy.get("method")
+            strategy_method: MethodType = getattr(self, strategy_method_name, None)
+            strategy_method_kwargs = table_strategy.get("kwargs", {})
+        # Default strategy
+        else:
+            strategy_method: MethodType = getattr(self, strategy_method_name, None)
+            strategy_method_kwargs = {}
+
+        if not strategy_method:
+            error_message = f"Could not find method for strategy with name '{strategy_method_name}'"
+            logger.error(f"[red]{error_message}[/red]")
+            raise ValueError(error_message)
+
+        # 3. Resolve conflicts
+        resolved_row = strategy_method(db_row, dump_row, **strategy_method_kwargs)
+
+        # 4. Merge data
+        merged_row = self._merge_rows_after_resolving(
+            table_pk_field_name, db_row, dump_row, resolved_row,
+        )
+
+        # 5. Return merged row
+        return merged_row
+
+    # --- Prepared most common strategies ---
+    def pick_row_with_smaller_value(
+        self, db_row: dict, dump_row: dict, compared_field_name: str,
+    ) -> dict:
+        db_row, dump_row = self._serialize_compared_fields_in_rows(
+            db_row, dump_row, compared_field_name,
+        )
+        db_value = db_row[compared_field_name]
+        dump_value = dump_row[compared_field_name]
+
+        # None cannot be compared with anything
+        if db_value is None:
+            return dump_row
+        if dump_value is None:
+            return db_row
+
+        min_value = min(db_value, dump_value)
+        if min_value == db_value:
+            return db_row
+        return dump_row
+
+    def pick_row_with_larger_value(
+        self, db_row: dict, dump_row: dict, compared_field_name: str,
+    ) -> dict:
+        db_row, dump_row = self._serialize_compared_fields_in_rows(
+            db_row, dump_row, compared_field_name,
+        )
+        db_value = db_row[compared_field_name]
+        dump_value = dump_row[compared_field_name]
+
+        # None cannot be compared with anything
+        if db_value is None:
+            return dump_row
+        if dump_value is None:
+            return db_row
+
+        max_value = max(db_value, dump_value)
+        if max_value == db_value:
+            return db_row
+        return dump_row
+
+    def pick_row_from_db(
+        self, db_row: dict, dump_row: dict, compared_field_name: Optional[str] = None,
+    ) -> dict:
+        return db_row
+
+    def pick_row_from_dump(
+        self, db_row: dict, dump_row: dict, compared_field_name: Optional[str] = None,
+    ) -> dict:
+        return dump_row
+
+    def pick_row_with_earlier_value(
+        self, db_row: dict, dump_row: dict, compared_field_name: str = "creation_date",
+    ) -> dict:
+        db_row, dump_row = self._serialize_compared_fields_in_rows(
+            db_row, dump_row, compared_field_name,
+        )
+        db_value = db_row[compared_field_name]
+        dump_value = dump_row[compared_field_name]
+
+        # None cannot be compared with anything
+        if db_value is None:
+            return dump_row
+        if dump_value is None:
+            return db_row
+
+        if dump_value > db_value:
+            return db_row
+        return dump_row
+
+    def pick_row_with_later_value(
+        self, db_row: dict, dump_row: dict, compared_field_name: str = "creation_date",
+    ) -> dict:
+        db_row, dump_row = self._serialize_compared_fields_in_rows(
+            db_row, dump_row, compared_field_name,
+        )
+        db_value = db_row[compared_field_name]
+        dump_value = dump_row[compared_field_name]
+
+        # None cannot be compared with anything
+        if db_value is None:
+            return dump_row
+        if dump_value is None:
+            return db_row
+
+        if dump_value < db_value:
+            return db_row
+        return dump_row
diff --git a/mephisto/tools/db_data_porter/conflict_resolvers/default_merge_conflict_resolver.py b/mephisto/tools/db_data_porter/conflict_resolvers/default_merge_conflict_resolver.py
new file mode 100644
index 000000000..85a2bce68
--- /dev/null
+++ b/mephisto/tools/db_data_porter/conflict_resolvers/default_merge_conflict_resolver.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Meta Platforms and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from mephisto.tools.db_data_porter.constants import MEPHISTO_DUMP_KEY
+from mephisto.tools.db_data_porter.constants import MOCK_PROVIDER_TYPE
+from mephisto.tools.db_data_porter.constants import MTURK_PROVIDER_TYPE
+from mephisto.tools.db_data_porter.constants import PROLIFIC_PROVIDER_TYPE
+from .base_merge_conflict_resolver import BaseMergeConflictResolver
+
+
+class DefaultMergeConflictResolver(BaseMergeConflictResolver):
+    """
+    Default conflict resolver for importing JSON DB dumps.
+    If a table is not listed here, the default resolver strategy will be used
+    for all of its conflicting fields.
+
+    For more detailed information, see the docstring in `BaseMergeConflictResolver`.
+    """
+
+    strategies_config = {
+        MEPHISTO_DUMP_KEY: {
+            "granted_qualifications": {
+                # Go with the more restrictive value
+                "method": "pick_row_with_smaller_value",
+                "kwargs": {
+                    "compared_field_name": "value",
+                },
+            },
+        },
+        PROLIFIC_PROVIDER_TYPE: {
+            "workers": {
+                # Go with the more restrictive value
+                # Note that `is_blocked` is an SQLite boolean, which is an integer in Python
+                "method": "pick_row_with_larger_value",
+                "kwargs": {
+                    "compared_field_name": "is_blocked",
+                },
+            },
+        },
+        MOCK_PROVIDER_TYPE: {},
+        MTURK_PROVIDER_TYPE: {},
+    }
diff --git a/mephisto/tools/db_data_porter/constants.py b/mephisto/tools/db_data_porter/constants.py
new file mode 100644
index 000000000..bea3047d4
--- /dev/null
+++ b/mephisto/tools/db_data_porter/constants.py
@@ -0,0 +1,215 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Meta Platforms and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from mephisto.abstractions.providers.mock.provider_type import PROVIDER_TYPE as MOCK_PROVIDER_TYPE
+from mephisto.abstractions.providers.mturk.provider_type import PROVIDER_TYPE as MTURK_PROVIDER_TYPE
+from mephisto.abstractions.providers.prolific.provider_type import (
+    PROVIDER_TYPE as PROLIFIC_PROVIDER_TYPE
+)
+
+
+BACKUP_OUTPUT_DIR = "outputs/backup"
+EXPORT_OUTPUT_DIR = "outputs/export"
+MEPHISTO_DUMP_KEY = "mephisto"
+METADATA_DUMP_KEY = "dump_metadata"
+AVAILABLE_PROVIDER_TYPES = [
+    MEPHISTO_DUMP_KEY,
+    MOCK_PROVIDER_TYPE,
+    MTURK_PROVIDER_TYPE,
+    PROLIFIC_PROVIDER_TYPE,
+]
+DATASTORE_EXPORT_METHOD_NAME = "get_export_data"
+DEFAULT_CONFLICT_RESOLVER = "DefaultMergeConflictResolver"
+IMPORTED_DATA_TABLE_NAME = "imported_data"
+MIGRATIONS_TABLE_NAME = "migrations"
+TASK_RUNS_TABLE_NAME = "task_runs"
+UNITS_TABLE_NAME = "units"
+ASSIGNMENTS_TABLE_NAME = "assignments"
+AGENTS_TABLE_NAME = "agents"
+
+# Format of mappings:
+# {
+#     <provider type>: {
+#         <table name>: {
+#             "from": <FK field name in this table>,
+#             "to": <PK field name in the related table>,
+#         },
+#         ...  # Table can have several FKs
+#     },
+#     ...
+# }
+#
+# If FK is related to a Mephisto DB table, use `FK_MEPHISTO_TABLE_PREFIX` before the table name.
+# If FK is related to the same DB, simply use the table name without any prefixes.
+FK_MEPHISTO_TABLE_PREFIX = "mephisto."
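+# A hypothetical example of one mapping entry (names are illustrative, not a real provider):
+# a provider "acme" whose datastore table "jobs" references Mephisto task runs would declare
+# {
+#     "acme": {
+#         "jobs": {
+#             f"{FK_MEPHISTO_TABLE_PREFIX}task_runs": {"from": "task_run_id", "to": "task_run_id"},
+#         },
+#     },
+# }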
+PROVIDER_DATASTORES__MEPHISTO_FK__MAPPINGS = { + PROLIFIC_PROVIDER_TYPE: { + "studies": { + f"{FK_MEPHISTO_TABLE_PREFIX}{TASK_RUNS_TABLE_NAME}": { + "from": "task_run_id", + "to": "task_run_id", + }, + }, + "run_mappings": { + f"{FK_MEPHISTO_TABLE_PREFIX}{TASK_RUNS_TABLE_NAME}": { + "from": "task_run_id", + "to": "task_run_id", + }, + }, + "units": { + f"{FK_MEPHISTO_TABLE_PREFIX}{TASK_RUNS_TABLE_NAME}": { + "from": "task_run_id", + "to": "task_run_id", + }, + f"{FK_MEPHISTO_TABLE_PREFIX}{UNITS_TABLE_NAME}": { + "from": "unit_id", + "to": "unit_id", + }, + }, + "runs": { + f"{FK_MEPHISTO_TABLE_PREFIX}{TASK_RUNS_TABLE_NAME}": { + "from": "task_run_id", + "to": "task_run_id", + }, + }, + "participant_groups": { + f"{FK_MEPHISTO_TABLE_PREFIX}requesters": { + "from": "requester_id", + "to": "requester_id", + }, + }, + "qualifications": { + f"{FK_MEPHISTO_TABLE_PREFIX}{TASK_RUNS_TABLE_NAME}": { + "from": "task_run_id", + "to": "task_run_id", + }, + }, + }, + MOCK_PROVIDER_TYPE: { + "requesters": { + f"{FK_MEPHISTO_TABLE_PREFIX}requesters": { + "from": "requester_id", + "to": "requester_id", + }, + }, + "units": { + f"{FK_MEPHISTO_TABLE_PREFIX}{UNITS_TABLE_NAME}": { + "from": "unit_id", + "to": "unit_id", + }, + }, + "workers": { + f"{FK_MEPHISTO_TABLE_PREFIX}workers": { + "from": "worker_id", + "to": "worker_id", + }, + }, + }, + MTURK_PROVIDER_TYPE: { + "hits": { + f"{FK_MEPHISTO_TABLE_PREFIX}{UNITS_TABLE_NAME}": { + "from": "unit_id", + "to": "unit_id", + }, + f"{FK_MEPHISTO_TABLE_PREFIX}{ASSIGNMENTS_TABLE_NAME}": { + "from": "assignment_id", + "to": "assignment_id", + }, + }, + "run_mappings": { + f"{FK_MEPHISTO_TABLE_PREFIX}{TASK_RUNS_TABLE_NAME}": { + "from": "run_id", + "to": "task_run_id", + }, + }, + "runs": { + f"{FK_MEPHISTO_TABLE_PREFIX}{TASK_RUNS_TABLE_NAME}": { + "from": "run_id", + "to": "task_run_id", + }, + }, + "qualifications": { + f"{FK_MEPHISTO_TABLE_PREFIX}requesters": { + "from": "requester_id", + "to": "requester_id", + }, + }, + }, +} + +# As Mock or MTurk do not have real PKs (they use Mephisto PKs as provider tables' PKs), +# we cannot rely on auto getting PKs from db. +# Map tables with their PKs manually +PROVIDER_DATASTORES__RANDOMIZABLE_PK__MAPPINGS = { + PROLIFIC_PROVIDER_TYPE: { + "participant_groups": "id", + "qualifications": "id", + "run_mappings": "id", + "runs": "id", + "studies": "id", + "submissions": "id", + "units": "id", + "workers": "id", + }, +} + +# Tables must be in specific order to satisfy constraints of Foreign Keys. 
+# NOTE: field names are lists, because fields can be UNIQUE TOGETHER +TABLES_UNIQUE_LOOKUP_FIELDS = { + MEPHISTO_DUMP_KEY: { + "projects": ["project_name"], + "requesters": ["requester_name"], + "tasks": ["task_name"], + "qualifications": ["qualification_name"], + "workers": ["worker_name"], + TASK_RUNS_TABLE_NAME: None, + ASSIGNMENTS_TABLE_NAME: None, + UNITS_TABLE_NAME: None, + AGENTS_TABLE_NAME: None, + "onboarding_agents": None, + "granted_qualifications": ["worker_id", "qualification_id"], + "unit_review": None, + }, + PROLIFIC_PROVIDER_TYPE: { + "workers": ["worker_id"], + "participant_groups": ["qualification_name"], + "qualifications": None, + "studies": None, + "runs": ["task_run_id"], + "run_mappings": None, + "submissions": None, + "units": ["unit_id"], + }, + MOCK_PROVIDER_TYPE: { + "requesters": None, + "workers": None, + "units": None, + }, + MTURK_PROVIDER_TYPE: { + "hits": None, + "run_mappings": None, + "runs": None, + "qualifications": None, + }, +} + +# Tables that we need to write into `imported_data` during importing from dump file +IMPORTED_DATA_TABLE_NAMES = [ + "projects", + "requesters", + "tasks", + "qualifications", + "granted_qualifications", + "workers", + # TaskRuns cannot conflict, + # but we write them into `imported_data` to know which TaskRuns were imported and + # have fast access to export by labels + TASK_RUNS_TABLE_NAME, +] + +# We mark rows in `imported_data` with labels and this label is used +# if conflicted row was already presented in local DB +LOCAL_DB_LABEL = "_" diff --git a/mephisto/tools/db_data_porter/db_data_porter.py b/mephisto/tools/db_data_porter/db_data_porter.py new file mode 100644 index 000000000..237cfb235 --- /dev/null +++ b/mephisto/tools/db_data_porter/db_data_porter.py @@ -0,0 +1,436 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +from datetime import datetime +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + +from mephisto.abstractions.database import MephistoDB +from mephisto.abstractions.databases.local_database import LocalMephistoDB +from mephisto.generators.form_composer.config_validation.utils import make_error_message +from mephisto.tools.db_data_porter import backups +from mephisto.tools.db_data_porter import dumps +from mephisto.tools.db_data_porter import import_dump +from mephisto.tools.db_data_porter.constants import BACKUP_OUTPUT_DIR +from mephisto.tools.db_data_porter.constants import EXPORT_OUTPUT_DIR +from mephisto.tools.db_data_porter.constants import IMPORTED_DATA_TABLE_NAME +from mephisto.tools.db_data_porter.constants import MEPHISTO_DUMP_KEY +from mephisto.tools.db_data_porter.constants import METADATA_DUMP_KEY +from mephisto.tools.db_data_porter.constants import MIGRATIONS_TABLE_NAME +from mephisto.tools.db_data_porter.randomize_ids import randomize_ids +from mephisto.tools.db_data_porter.validation import validate_dump_data +from mephisto.utils import db as db_utils +from mephisto.utils.dirs import get_data_dir +from mephisto.utils.misc import serialize_date_to_python +from mephisto.utils.console_writer import ConsoleWriter + +logger = ConsoleWriter() + + +class DBDataPorter: + """ + Import, export, backup and restore DB data. + + This class contains the main logic of commands `mephisto db ...`. 
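+
+    A rough usage sketch (the `mephisto db ...` CLI is the intended entry point;
+    the argument values below are illustrative):
+
+        porter = DBDataPorter()
+        result_paths = porter.export_dump(randomize_legacy_ids=True)
+        porter.import_dump(result_paths["db_path"], "DefaultMergeConflictResolver")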
+    """
+
+    def __init__(self, db=None):
+        # Load Mephisto DB and providers' datastores
+        if db is None:
+            db = LocalMephistoDB()
+        self.db = db
+        self.provider_datastores: Dict[str, "MephistoDB"] = db_utils.get_providers_datastores(
+            self.db,
+        )
+
+        # Cached Primary Key substitutions
+        self._pk_substitutions = {}
+
+    @staticmethod
+    def _get_root_mephisto_repo_dir() -> str:
+        return os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(
+            os.path.abspath(__file__)
+        ))))
+
+    def _get_export_dir(self) -> str:
+        root_dir = self._get_root_mephisto_repo_dir()
+        export_path = os.path.join(root_dir, EXPORT_OUTPUT_DIR)
+        # Create dirs if needed
+        os.makedirs(export_path, exist_ok=True)
+        return export_path
+
+    def _get_backup_dir(self) -> str:
+        root_dir = self._get_root_mephisto_repo_dir()
+        backup_path = os.path.join(root_dir, BACKUP_OUTPUT_DIR)
+        # Create dirs if needed
+        os.makedirs(backup_path, exist_ok=True)
+        return backup_path
+
+    @staticmethod
+    def _make_export_timestamp() -> str:
+        return datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
+
+    @staticmethod
+    def _make_dump_name(timestamp: str) -> str:
+        return f"{timestamp}_mephisto_dump"
+
+    @staticmethod
+    def _make_export_dump_file_path(export_path: str, dump_name: str) -> str:
+        file_name = f"{dump_name}.json"
+        file_path = os.path.join(export_path, file_name)
+        return file_path
+
+    def _prepare_dump_data(
+        self,
+        task_names: Optional[List[str]] = None,
+        task_ids: Optional[List[str]] = None,
+        task_run_ids: Optional[List[str]] = None,
+        task_runs_labels: Optional[List[str]] = None,
+        since_datetime: Optional[datetime] = None,
+        randomize_legacy_ids: Optional[bool] = False,
+    ) -> dict:
+        partial = bool(task_names or task_ids or task_run_ids or task_runs_labels or since_datetime)
+        if not partial:
+            dump_data = dumps.prepare_full_dump_data(self.db, self.provider_datastores)
+        else:
+            dump_data = dumps.prepare_partial_dump_data(
+                self.db,
+                task_names=task_names,
+                task_ids=task_ids,
+                task_run_ids=task_run_ids,
+                task_runs_labels=task_runs_labels,
+                since_datetime=since_datetime,
+            )
+
+        if randomize_legacy_ids:
+            randomize_ids_results = randomize_ids(self.db, dump_data, legacy_only=True)
+            dump_data = randomize_ids_results["updated_dump"]
+            self._pk_substitutions = randomize_ids_results["pk_substitutions"]
+
+        return dump_data
+
+    def _get_latest_migrations(self) -> Dict[str, Union[None, str]]:
+        db_and_datastores = {
+            MEPHISTO_DUMP_KEY: self.db,
+            **self.provider_datastores,
+        }
+        latest_migrations = {}
+        for db_name, db in db_and_datastores.items():
+            try:
+                latest_migration = db_utils.get_latest_row_from_table(db, MIGRATIONS_TABLE_NAME)
+            except Exception as e:
+                # This scenario is nearly impossible, but it should be covered anyway.
+                # Any run of a Mephisto command creates this table in all DBs and datastores
+                # at the very beginning, long before this part of the code is reached
+                logger.warning(f"[yellow]No 'migrations' table found. {e}[/yellow]")
+                latest_migration = None
+
+            migration_name = latest_migration["name"] if latest_migration else None
+            latest_migrations[db_name] = migration_name
+
+        return latest_migrations
+
+    @staticmethod
+    def _ask_user_if_they_are_sure() -> bool:
+        question = input(
+            "Are you sure? "
+            "It will affect your databases and related files. "
+            "Type 'yes' and press Enter if you want to proceed: "
+        )
+        if question != "yes":
Bye") + exit() + + return True + + @staticmethod + def _get_label_from_file_path(file_path: str) -> str: + base_name = os.path.basename(file_path) + name_without_ext = base_name.split(".")[0] + return name_without_ext + + def export_dump( + self, + json_indent: Optional[int] = None, + task_names: Optional[List[str]] = None, + task_ids: Optional[List[str]] = None, + task_run_ids: Optional[List[str]] = None, + task_runs_since_date: Optional[str] = None, + task_runs_labels: Optional[List[str]] = None, + delete_exported_data: bool = False, + randomize_legacy_ids: bool = False, + verbosity: int = 0, + ) -> dict: + # 1. Protect from accidental launches + if delete_exported_data: + self._ask_user_if_they_are_sure() + + # 2. Prepare dump data with Mephisto DB and provider datastores + since_datetime = None + if task_runs_since_date: + try: + since_datetime = serialize_date_to_python(task_runs_since_date) + except Exception: + error_message = f"Could not parse date '{task_runs_since_date}'." + logger.exception(f"[red]{error_message}[/red]") + exit() + + dump_data_to_export = self._prepare_dump_data( + task_names=task_names, + task_ids=task_ids, + task_run_ids=task_run_ids, + task_runs_labels=task_runs_labels, + since_datetime=since_datetime, + randomize_legacy_ids=randomize_legacy_ids, + ) + + # 3. Prepare export dirs and get dump file path + export_dir = self._get_export_dir() + dump_timestamp = self._make_export_timestamp() + dump_name = self._make_dump_name(dump_timestamp) + file_path = self._make_export_dump_file_path(export_dir, dump_name) + + # 4. Prepare metadata + metadata = { + "migrations": self._get_latest_migrations(), + "export_parameters": { + "--export-indent": json_indent, + "--export-tasks-by-names": task_names, + "--export-tasks-by-ids": task_ids, + "--export-task-runs-by-ids": task_run_ids, + "--export-task-runs-since-date": task_runs_since_date, + "--verbosity": verbosity, + }, + "timestamp": dump_timestamp, + "pk_substitutions": self._pk_substitutions, + } + dump_data_to_export[METADATA_DUMP_KEY] = metadata + + # 5. Save JSON file + try: + with open(file_path, "w") as f: + f.write(json.dumps(dump_data_to_export, indent=json_indent)) + except Exception as e: + # Remove file to not make a mess in export directory + error_message = f"Could not create dump file {dump_data_to_export}. Reason: {str(e)}." + + if verbosity: + logger.exception(f"[red]{error_message}[/red]") + os.remove(file_path) + exit() + + # 6. Archive files in file system + exported = backups.archive_and_copy_data_files( + self.db, + export_dir, + dump_name, + dump_data_to_export, + pk_substitutions=self._pk_substitutions, + verbosity=verbosity, + ) + + # 7. 
+        backup_path = None
+        if delete_exported_data:
+            backup_dir = self._get_backup_dir()
+            backup_path = backups.make_full_data_dir_backup(backup_dir, dump_timestamp)
+            delete_tasks = bool(task_names or task_ids)
+            is_partial_dump = bool(task_names or task_ids or task_run_ids or task_runs_since_date)
+            dumps.delete_exported_data(
+                db=self.db,
+                dump_data_to_export=dump_data_to_export,
+                pk_substitutions=self._pk_substitutions,
+                partial=is_partial_dump,
+                delete_tasks=delete_tasks,
+            )
+
+        data_path = None
+        if exported:
+            data_path = os.path.join(
+                export_dir, f"{dump_name}.{backups.DEFAULT_ARCHIVE_FORMAT}",
+            )
+
+        return {
+            "db_path": file_path,
+            "data_path": data_path,
+            "backup_path": backup_path,
+        }
+
+    def import_dump(
+        self,
+        dump_file_name_or_path: str,
+        conflict_resolver_name: str,
+        label: Optional[str] = None,
+        keep_import_metadata: Optional[bool] = None,
+        verbosity: int = 0,
+    ):
+        # 1. Check dump file path
+        is_dump_path_full = os.path.isabs(dump_file_name_or_path)
+        if not is_dump_path_full:
+            root_dir = self._get_root_mephisto_repo_dir()
+            dump_file_name_or_path = os.path.join(
+                root_dir, EXPORT_OUTPUT_DIR, dump_file_name_or_path,
+            )
+
+        if not os.path.exists(dump_file_name_or_path):
+            error_message = (
+                f"Could not find dump file '{dump_file_name_or_path}'. "
+                f"Please specify the full path to an existing file, or "
+                f"just a file name if it is located in `/{EXPORT_OUTPUT_DIR}`."
+            )
+
+            if verbosity:
+                logger.exception(f"[red]{error_message}[/red]")
+            exit()
+
+        # 2. Read dump file
+        with open(dump_file_name_or_path, "r") as f:
+            try:
+                dump_file_data: dict = json.loads(f.read())
+            except Exception as e:
+                error_message = (
+                    f"Could not read JSON from dump file '{dump_file_name_or_path}'. "
+                    f"Please check that it has the correct format. Reason: {str(e)}"
+                )
+                logger.exception(f"[red]{error_message}[/red]")
+                exit()
+
+        # 3. Validate dump
+        dump_data_errors = validate_dump_data(self.db, dump_file_data)
+        if dump_data_errors:
+            error_message = make_error_message(
+                "Your dump file has incorrect format", dump_data_errors, indent=4,
+            )
+            logger.error(f"[red]{error_message}[/red]")
+            exit()
+
+        # 4. Protect from accidental launches
+        self._ask_user_if_they_are_sure()
+
+        # 5. Extract metadata (we do not use it for now, but it needs to be popped)
+        metadata = dump_file_data.pop(METADATA_DUMP_KEY, {})
+
+        # 6. Make a backup of the full local `data` path with databases and files.
+        # This simulates transactional writing into several databases: if something goes wrong,
+        # we have the ability to roll back everything we've just done
+        logger.info(
+            "Just in case, we are making a backup of all your local data. "
+            "If something goes wrong during import, we will restore all your data from this backup"
+        )
+        backup_dir = self._get_backup_dir()
+        dump_timestamp = self._make_export_timestamp()
+        backup_path = backups.make_full_data_dir_backup(backup_dir, dump_timestamp)
+        logger.info(f"Backup was created successfully! File: '{backup_path}'")
+
+        # 7. Write dump data into local DBs
+        for db_or_datastore_name, db_or_datastore_data in dump_file_data.items():
+            imported_data_from_dump = []
+
+            if db_or_datastore_name == MEPHISTO_DUMP_KEY:
+                # Main Mephisto database
+                db = self.db
+                imported_data_from_dump = dump_file_data.get(MEPHISTO_DUMP_KEY, {}).pop(
+                    IMPORTED_DATA_TABLE_NAME, [],
+                )
+            else:
+                # Provider's datastore.
+                # NOTE: It is being created here if it does not exist (yes, magically)
+                datastore = self.provider_datastores.get(db_or_datastore_name)
+
+                if not datastore:
+                    logger.error(
+                        f"Current version of Mephisto does not support "
+                        f"the '{db_or_datastore_name}' provider."
+                    )
+                    exit()
+
+                db = datastore
+
+            if verbosity:
+                logger.info(f"Start importing into `{db_or_datastore_name}` database")
+
+            label = label or self._get_label_from_file_path(dump_file_name_or_path)
+            import_single_db_results = import_dump.import_single_db(
+                db=db,
+                provider_type=db_or_datastore_name,
+                dump_data=db_or_datastore_data,
+                conflict_resolver_name=conflict_resolver_name,
+                label=label,
+                verbosity=verbosity,
+            )
+
+            errors = import_single_db_results["errors"]
+
+            if errors:
+                error_message = make_error_message("Import was not processed", errors, indent=4)
+                logger.error(f"[red]{error_message}[/red]")
+
+                # Simulating rollback for all databases/datastores and related data files
+                mephisto_data_path = get_data_dir()
+                backup_path = backups.make_backup_file_path_by_timestamp(backup_dir, dump_timestamp)
+
+                if verbosity:
+                    logger.info(f"Rolling back all changes from backup '{backup_path}'")
+
+                backups.restore_from_backup(backup_path, mephisto_data_path)
+
+                exit()
+
+            # Write information into `imported_data`
+            if db_or_datastore_name == MEPHISTO_DUMP_KEY:
+                # Fill `imported_data` table with imported dump
+                import_dump.fill_imported_data_with_imported_dump(
+                    db=db,
+                    imported_data=import_single_db_results["imported_data"],
+                    source_file_name=os.path.basename(dump_file_name_or_path),
+                )
+
+                # Fill `imported_data` with information from `imported_data` from dump
+                if keep_import_metadata and imported_data_from_dump:
+                    import_dump.import_table_imported_data_from_dump(db, imported_data_from_dump)
+
+            if verbosity:
+                logger.info(
+                    f"Finished importing into `{db_or_datastore_name}` database successfully!"
+                )
+
+    def make_backup(self) -> str:
+        backup_dir = self._get_backup_dir()
+        dump_timestamp = self._make_export_timestamp()
+        backup_path = backups.make_full_data_dir_backup(backup_dir, dump_timestamp)
+        return backup_path
+
+    def restore_from_backup(self, backup_file_name_or_path: str, verbosity: int = 0):
+        # 1. Protect from accidental launches
+        self._ask_user_if_they_are_sure()
+
+        # 2. Check backup file path
+        is_backup_path_full = os.path.isabs(backup_file_name_or_path)
+        if not is_backup_path_full:
+            root_dir = self._get_root_mephisto_repo_dir()
+            backup_file_name_or_path = os.path.join(
+                root_dir, BACKUP_OUTPUT_DIR, backup_file_name_or_path,
+            )
+
+        if not os.path.exists(backup_file_name_or_path):
+            error_message = (
+                f"Could not find backup file '{backup_file_name_or_path}'. "
+                f"Please specify the full path to an existing file, or "
+                f"just a file name if it is located in `/{BACKUP_OUTPUT_DIR}`."
+            )
+            logger.exception(f"[red]{error_message}[/red]")
+            exit()
+
+        if verbosity and not is_backup_path_full:
+            logger.info(f"Found backup file '{backup_file_name_or_path}'")
+
+        # 3. Restore
+        mephisto_data_path = get_data_dir()
+        backups.restore_from_backup(backup_file_name_or_path, mephisto_data_path)
diff --git a/mephisto/tools/db_data_porter/dumps.py b/mephisto/tools/db_data_porter/dumps.py
new file mode 100644
index 000000000..34618ace0
--- /dev/null
+++ b/mephisto/tools/db_data_porter/dumps.py
@@ -0,0 +1,170 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Meta Platforms and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
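+
+# For orientation, a dump prepared by this module is roughly shaped like this
+# (an illustrative sketch, not an exhaustive schema):
+# {
+#     "mephisto": {"projects": [...], "tasks": [...], "task_runs": [...], ...},
+#     "prolific": {<table name>: [<row dict>, ...], ...},
+#     ...  # one key per provider datastore
+# }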
+
+import os
+import shutil
+from datetime import datetime
+from pathlib import Path
+from types import MethodType
+from typing import Dict
+from typing import List
+from typing import Optional
+
+from mephisto.abstractions.database import MephistoDB
+from mephisto.data_model.task_run import TaskRun
+from mephisto.tools.db_data_porter.constants import DATASTORE_EXPORT_METHOD_NAME
+from mephisto.tools.db_data_porter.constants import MEPHISTO_DUMP_KEY
+from mephisto.tools.db_data_porter.constants import TASK_RUNS_TABLE_NAME
+from mephisto.tools.db_data_porter.randomize_ids import get_old_pk_from_substitutions
+from mephisto.utils import db as db_utils
+from mephisto.utils.console_writer import ConsoleWriter
+from mephisto.utils.dirs import get_data_dir
+
+logger = ConsoleWriter()
+
+
+def prepare_partial_dump_data(
+    db: "MephistoDB",
+    task_names: Optional[List[str]] = None,
+    task_ids: Optional[List[str]] = None,
+    task_run_ids: Optional[List[str]] = None,
+    task_runs_labels: Optional[List[str]] = None,
+    since_datetime: Optional[datetime] = None,
+) -> dict:
+    dump_data_to_export = {}
+
+    # Mephisto DB
+
+    # Convert all parameters to `task_run_ids` (TaskRun is the main object in DB)
+    if not task_run_ids:
+        if task_names or task_ids:
+            if task_names:
+                task_ids = db_utils.get_task_ids_by_task_names(db, task_names)
+            task_ids = task_ids or []
+            task_run_ids = db_utils.get_task_run_ids_ids_by_task_ids(db, task_ids)
+        elif task_runs_labels:
+            task_run_ids = db_utils.get_task_run_ids_ids_by_labels(db, task_runs_labels)
+        elif since_datetime:
+            task_run_ids = db_utils.select_task_run_ids_since_date(db, since_datetime)
+
+    logger.info(f"Run command for TaskRun IDs: {', '.join(task_run_ids)}.")
+
+    mephisto_db_data = db_utils.mephisto_db_to_dict_for_task_runs(db, task_run_ids)
+    dump_data_to_export[MEPHISTO_DUMP_KEY] = mephisto_db_data
+
+    # Providers' DBs
+    provider_types = [i["provider_type"] for i in mephisto_db_data["requesters"]]
+
+    for provider_type in provider_types:
+        provider_datastore = db.get_datastore_for_provider(provider_type)
+        dump_data_to_export[provider_type] = db_utils.db_or_datastore_to_dict(
+            provider_datastore,
+        )
+
+        # Get a method-function from the provider datastore.
+        # There is provider-specific logic for exporting DB data, as a datastore can have
+        # any schema. It may be missing, i.e. not implemented at all
+        datastore_export_method: MethodType = getattr(
+            provider_datastore, DATASTORE_EXPORT_METHOD_NAME, None,
+        )
+        if datastore_export_method:
+            datastore_export_data = datastore_export_method(
+                task_run_ids=task_run_ids, mephisto_db_data=mephisto_db_data,
+            )
+        else:
+            # If the method was not implemented in the provider datastore, we export all tables fully.
+            error_message = (
+                f"You did not implement "
+                f"{provider_datastore.__class__.__name__}.{DATASTORE_EXPORT_METHOD_NAME}. "
+                f"Exporting full datastore. Specify the logic for selecting related table rows "
+                f"in your provider datastore, or leave it as it is."
+            )
+            logger.error(f"[red]{error_message}[/red]")
+            datastore_export_data = db_utils.db_or_datastore_to_dict(provider_datastore)
+
+        dump_data_to_export[provider_type] = datastore_export_data
+
+    return dump_data_to_export
+
+
+def prepare_full_dump_data(db: "MephistoDB", provider_datastores: Dict[str, "MephistoDB"]) -> dict:
+    dump_data_to_export = {}
+
+    logger.info("Run command for all TaskRuns.")
+
+    # Mephisto DB
+    dump_data_to_export[MEPHISTO_DUMP_KEY] = db_utils.db_or_datastore_to_dict(db)
+
+    # Providers' DBs
+    for provider_type, provider_datastore in provider_datastores.items():
+        dump_data_to_export[provider_type] = db_utils.db_or_datastore_to_dict(provider_datastore)
+
+    return dump_data_to_export
+
+
+def delete_exported_data(
+    db: "MephistoDB",
+    dump_data_to_export: dict,
+    pk_substitutions: dict,
+    partial: bool = False,
+    delete_tasks: bool = False,
+):
+    # 1. Mephisto DB
+    if not partial:
+        # Clean DB
+        db_utils.delete_entire_exported_data(db)
+
+        # Clean related files
+        mephisto_data_files_path = os.path.join(get_data_dir(), "data")
+        for filename in os.listdir(mephisto_data_files_path):
+            file_path = os.path.join(mephisto_data_files_path, filename)
+            try:
+                if os.path.isfile(file_path) or os.path.islink(file_path):
+                    Path(file_path).unlink(missing_ok=True)
+                elif os.path.isdir(file_path):
+                    shutil.rmtree(file_path)
+            except Exception as e:
+                logger.warning(f"Failed to delete '{file_path}'. Reason: {e}")
+    else:
+        # NOTE: We cannot remove all exported rows from the DB that are present in the dump,
+        # because some objects can be common among other Projects, Tasks, etc.
+        # E.g. you cannot remove a Task, Worker, or Qualification just because it was
+        # related to the exported TaskRuns; they are used, or can be used in the future,
+        # for other Projects, Tasks, or TaskRuns.
+        names_of_tables_to_cleanup = [
+            "agents",
+            "assignments",
+            "task_runs",
+            "unit_review",
+            "units",
+        ]
+        if delete_tasks:
+            names_of_tables_to_cleanup.append("tasks")
+
+        # Get directories related to dumped TaskRuns
+        task_run_rows = dump_data_to_export.get(MEPHISTO_DUMP_KEY, {}).get(
+            TASK_RUNS_TABLE_NAME, [],
+        )
+        task_runs_pk_field_name = db_utils.get_table_pk_field_name(db, TASK_RUNS_TABLE_NAME)
+        task_run_ids = [r[task_runs_pk_field_name] for r in task_run_rows]
+        task_run_ids = [
+            get_old_pk_from_substitutions(i, pk_substitutions, TASK_RUNS_TABLE_NAME) or i
+            for i in task_run_ids
+        ]
+
+        task_run_data_dirs = [TaskRun.get(db, i).get_run_dir() for i in task_run_ids]
+
+        # Clean DB
+        db_utils.delete_exported_data_without_fk_constraints(
+            db, dump_data_to_export[MEPHISTO_DUMP_KEY], names_of_tables_to_cleanup,
+        )
+
+        # Clean related files
+        for task_run_data_dir in task_run_data_dirs:
+            try:
+                shutil.rmtree(task_run_data_dir)
+            except Exception as e:
+                logger.warning(f"Failed to delete '{task_run_data_dir}'. Reason: {e}")
diff --git a/mephisto/tools/db_data_porter/import_dump.py b/mephisto/tools/db_data_porter/import_dump.py
new file mode 100644
index 000000000..7390d928c
--- /dev/null
+++ b/mephisto/tools/db_data_porter/import_dump.py
@@ -0,0 +1,334 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Meta Platforms and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
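+
+# For orientation: `import_single_db` below returns roughly the following structure
+# (an illustrative sketch, not an exhaustive schema):
+# {
+#     "errors": ["<human-readable error>", ...],
+#     "imported_data": {<table name>: {<labels JSON>: [<unique fields info>, ...], ...}},
+# }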
+
+import json
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import TypedDict
+
+from mephisto.abstractions.database import MephistoDB
+from mephisto.tools.db_data_porter import conflict_resolvers
+from mephisto.tools.db_data_porter.constants import IMPORTED_DATA_TABLE_NAME
+from mephisto.tools.db_data_porter.constants import IMPORTED_DATA_TABLE_NAMES
+from mephisto.tools.db_data_porter.constants import LOCAL_DB_LABEL
+from mephisto.tools.db_data_porter.constants import MEPHISTO_DUMP_KEY
+from mephisto.tools.db_data_porter.constants import TABLES_UNIQUE_LOOKUP_FIELDS
+from mephisto.utils import db as db_utils
+from mephisto.utils.console_writer import ConsoleWriter
+
+UNIQUE_FIELD_NAMES = "unique_field_names"
+UNIQUE_FIELD_VALUES = "unique_field_values"
+
+logger = ConsoleWriter()
+
+TableNameType = str
+PKSubstitutionsType = Dict[
+    str,  # Importing PK
+    str,  # Existing PK in local DB
+]
+MappingResolvingsType = Dict[TableNameType, PKSubstitutionsType]
+
+
+class ImportSingleDBsType(TypedDict):
+    errors: Optional[List[str]]
+    imported_data: Optional[dict]
+
+
+def _update_row_with_pks_from_resolvings_mappings(
+    db: "MephistoDB",
+    table_name: str,
+    row: dict,
+    resolvings_mapping: MappingResolvingsType,
+) -> dict:
+    table_fks = db_utils.select_fk_mappings_for_table(db, table_name)
+
+    # Update FK fields from resolving mappings if needed
+    for fk_table, fk_table_fields in table_fks.items():
+        relating_table_mapping = resolvings_mapping.get(fk_table)
+        if not relating_table_mapping:
+            continue
+
+        relating_table_row_pk_mapping = relating_table_mapping.get(row[fk_table_fields["from"]])
+        if not relating_table_row_pk_mapping:
+            continue
+
+        row[fk_table_fields["from"]] = relating_table_row_pk_mapping
+
+    return row
+
+
+def import_single_db(
+    db: "MephistoDB",
+    provider_type: str,
+    dump_data: dict,
+    conflict_resolver_name: str,
+    label: str,
+    verbosity: int = 0,
+) -> ImportSingleDBsType:
+    # Results of the function
+    imported_data = {}
+    errors = []
+
+    # Variables to save the current (intermediate) working values in case of exceptions,
+    # to make a comprehensive error message, because SQLite's own messages lack detail
+    in_progress_table_name = None
+    in_progress_dump_row = None
+    in_progress_table_pk_field_name = None
+
+    # Mappings between conflicted PKs and the PKs chosen after resolving a conflict
+    resolvings_mapping: MappingResolvingsType = {}
+
+    # --- HACK (#UNIT.AGENT_ID) START #1:
+    # In Mephisto DB we have a problem with inserting into the `units` and `agents` tables.
+    # Both tables have a relation to each other (FK `units.agent_id` and `agents.unit_id`) and
+    # if both fields are filled in, we catch an FK constraint error.
+    # Changing the order of importing tables will not help us in this case.
+    # The solution is to create all rows in the `units` table with `agent_id = NULL` and
+    # save these IDs in the following dict.
+    # As soon as we complete all `units` and `agents`,
+    # we will update all rows in `units` with the saved `agent_id` values.
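+    # Illustrative shape with made-up IDs: {"<unit_id>": "<agent_id>", ...}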
+    units_agents = {}
+    # --- HACK (#UNIT.AGENT_ID) END #1:
+
+    # Import conflict resolver class and instantiate it
+    conflict_resolver_class = getattr(conflict_resolvers, conflict_resolver_name, None)
+    if not conflict_resolver_class:
+        error_message = f"Conflict resolver with name '{conflict_resolver_name}' was not found"
+        logger.error(f"[red]{error_message}[/red]")
+        raise ImportError(error_message)
+    conflict_resolver = conflict_resolver_class(db, provider_type)
+
+    try:
+        # Independent tables with their non-PK unique field names where conflicts can occur.
+        # They must be imported before other tables
+        tables_with_special_unique_field = TABLES_UNIQUE_LOOKUP_FIELDS.get(provider_type)
+        for table_name, unique_field_names in tables_with_special_unique_field.items():
+            dump_table_rows = dump_data[table_name]
+            table_pk_field_name = db_utils.get_table_pk_field_name(db, table_name)
+            is_table_with_special_unique_field = unique_field_names is not None
+
+            # Save in-progress data for better logging
+            in_progress_table_name = table_name
+            in_progress_table_pk_field_name = table_pk_field_name
+
+            # Imported data vars
+            imported_data_needs_to_be_updated = (
+                provider_type == MEPHISTO_DUMP_KEY and
+                table_name in IMPORTED_DATA_TABLE_NAMES
+            )
+
+            newly_imported_labels = json.dumps([label])
+            conflicted_labels = json.dumps([LOCAL_DB_LABEL, label])
+            imported_data_for_table = {
+                newly_imported_labels: [],
+                conflicted_labels: [],
+            }
+
+            for dump_row in dump_table_rows:
+                # Save in-progress data for better logging
+                in_progress_dump_row = dump_row
+
+                # --- HACK (#UNIT.AGENT_ID) START #2:
+                # We save pairs `unit_id: agent_id` in case `agent_id` is not None, and
+                # replace `agent_id` with `None`
+                if provider_type == MEPHISTO_DUMP_KEY:
+                    if table_name == "units" and (unit_agent_id := dump_row.get("agent_id")):
+                        unit_id = dump_row[table_pk_field_name]
+                        units_agents[unit_id] = unit_agent_id
+                        dump_row["agent_id"] = None
+                # --- HACK (#UNIT.AGENT_ID) END #2:
+
+                imported_data_row_unique_field_values = [dump_row[table_pk_field_name]]
+                imported_data_conflicted_row = False
+
+                _update_row_with_pks_from_resolvings_mappings(
+                    db, table_name, dump_row, resolvings_mapping,
+                )
+
+                # Table with non-PK unique field
+                if is_table_with_special_unique_field:
+                    imported_data_row_unique_field_values = [
+                        dump_row[fn] for fn in unique_field_names
+                    ]
+
+                    unique_field_values: List[List[str]] = [
+                        [dump_row[fn]] for fn in unique_field_names
+                    ]
+                    existing_rows = db_utils.select_rows_by_list_of_field_values(
+                        db=db,
+                        table_name=table_name,
+                        field_names=unique_field_names,
+                        field_values=unique_field_values,
+                        order_by="creation_date",
+                    )
+
+                    # If local DB does not have this row
+                    if not existing_rows:
+                        if verbosity:
+                            logger.info(f"Inserting new row into table '{table_name}': {dump_row}")
+
+                        db_utils.insert_new_row_in_table(db, table_name, dump_row)
+
+                    # If local DB already has a row with the specified unique field value
+                    else:
+                        imported_data_conflicted_row = True
+
+                        existing_db_row = existing_rows[-1]
+
+                        if verbosity:
+                            logger.info(
+                                f"Conflict while inserting row into table '{table_name}': "
+                                f"{dump_row}. "
" + f"Existing row in your database: {existing_db_row}" + ) + + resolved_conflicting_row = conflict_resolver_name.resolve( + table_name, table_pk_field_name, existing_db_row, dump_row, + ) + db_utils.update_row_in_table( + db, table_name, resolved_conflicting_row, table_pk_field_name, + ) + + # Saving resolved a pair of PKs + existing_row_pk_value = resolved_conflicting_row[table_pk_field_name] + importing_row_pk_value = dump_row[table_pk_field_name] + + mappings_prev_value = resolvings_mapping.get(table_name, {}) + resolvings_mapping[table_name] = { + **mappings_prev_value, + **{importing_row_pk_value: existing_row_pk_value}, + } + + # Regular table. Create new row as is + else: + db_utils.insert_new_row_in_table(db, table_name, dump_row) + + # Update table lists of Imported data + if imported_data_needs_to_be_updated: + if imported_data_conflicted_row: + _label = conflicted_labels + else: + _label = newly_imported_labels + + imported_data_for_table[_label].append({ + UNIQUE_FIELD_NAMES: unique_field_names or [table_pk_field_name], + UNIQUE_FIELD_VALUES: imported_data_row_unique_field_values, + }) + + # Add table into Imported data + if imported_data_needs_to_be_updated: + imported_data[table_name] = imported_data_for_table + + # --- HACK (#UNIT.AGENT_ID) START #3: + # Update all created `units` rows in #2 with presaved `agent_id` values + if provider_type == MEPHISTO_DUMP_KEY: + for unit_id, agent_id in units_agents.items(): + db_utils.update_row_in_table( + db, "units", {"unit_id": unit_id, "agent_id": agent_id}, "unit_id", + ) + # --- HACK (#UNIT.AGENT_ID) END #3: + + except Exception as e: + # Custom error message in cases when we can guess what happens + # using small info SQLite gives us + possible_issue = "" + if in_progress_table_pk_field_name in str(e) and "UNIQUE constraint" in str(e): + pk_value = in_progress_dump_row[in_progress_table_pk_field_name] + possible_issue = ( + f"\nPossible issue: " + f"Local database already have Primary Key '{pk_value}' " + f"in table '{in_progress_table_name}'. " + f"Maybe you are trying to run already merged dump file. " + f"Or if you have old databases, you may bump into same Primary Keys. " + f"If you are sure that all data from this dump is unique and " + f"still have access to the dumped project, " + f"try to create dump with parameter `--randomize-legacy-ids` " + f"and start importing again." + ) + + default_error_message_beginning = "" + if not possible_issue: + default_error_message_beginning = "Unexpected error happened: " + + errors.append( + f"{default_error_message_beginning}{e}." + f"{possible_issue}" + f"\nProvider: {provider_type}." + f"\nTable: {in_progress_table_name}." + f"\nRow: {json.dumps(in_progress_dump_row, indent=2)}." 
+        )
+
+    return {
+        "errors": errors,
+        "imported_data": imported_data,
+    }
+
+
+def fill_imported_data_with_imported_dump(
+    db: "MephistoDB", imported_data: dict, source_file_name: str,
+):
+    for table_name, table_info in imported_data.items():
+        for labels, labels_rows in table_info.items():
+            for row in labels_rows:
+                if not row["unique_field_values"]:
+                    continue
+
+                unique_field_names = json.dumps(row["unique_field_names"])
+                unique_field_values = json.dumps(row["unique_field_values"])
+                db_utils.insert_new_row_in_table(
+                    db=db,
+                    table_name=IMPORTED_DATA_TABLE_NAME,
+                    row={
+                        "source_file_name": source_file_name,
+                        "data_labels": labels,
+                        "table_name": table_name,
+                        "unique_field_names": unique_field_names,
+                        "unique_field_values": unique_field_values,
+                    },
+                )
+
+
+def import_table_imported_data_from_dump(db: "MephistoDB", imported_data_rows: List[dict]):
+    for row in imported_data_rows:
+        table_name = row["table_name"]
+        unique_field_names = row["unique_field_names"]
+        unique_field_values = row["unique_field_values"]
+
+        # Check if the item from this row was already imported before from other dumps
+        existing_rows = db_utils.select_rows_by_list_of_field_values(
+            db=db,
+            table_name="imported_data",
+            field_names=["table_name", "unique_field_names", "unique_field_values"],
+            field_values=[[table_name], [unique_field_names], [unique_field_values]],
+            order_by="-creation_date",
+        )
+        existing_row = existing_rows[0] if existing_rows else None
+
+        # As we add this imported data from a dump, new rows cannot carry the label of the local DB:
+        # the current local DB is not the one that was local for the dumped data.
+        # We keep only labels from imported dumps
+        importing_data_labels = json.loads(row["data_labels"])
+        data_labels_without_local = [l for l in importing_data_labels if l != LOCAL_DB_LABEL]
+
+        # Update existing row
+        if existing_row:
+            # Merge existing labels with labels from the imported row
+            existing_data_labels = json.loads(existing_row["data_labels"])
+            existing_data_labels += importing_data_labels
+            existing_row["data_labels"] = json.dumps(list(set(existing_data_labels)))
+
+            db_utils.update_row_in_table(
+                db=db, table_name="imported_data", row=existing_row, pk_field_name="id",
+            )
+
+        # Create new row
+        else:
+            # Set labels and remove the PK field from the importing row
+            row.pop("id", None)
+            row["data_labels"] = json.dumps(data_labels_without_local)
+
+            db_utils.insert_new_row_in_table(db=db, table_name="imported_data", row=row)
diff --git a/mephisto/tools/db_data_porter/randomize_ids.py b/mephisto/tools/db_data_porter/randomize_ids.py
new file mode 100644
index 000000000..ab81b75a9
--- /dev/null
+++ b/mephisto/tools/db_data_porter/randomize_ids.py
@@ -0,0 +1,189 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Meta Platforms and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
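+
+# For orientation, the PK substitutions produced here are roughly shaped like this
+# (an illustrative sketch with made-up table names and IDs):
+# {
+#     "mephisto": {"task_runs": {"<old pk>": "<new pk>", ...}, ...},
+#     "prolific": {...},
+# }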
+
+from typing import Dict
+from typing import TypedDict
+from typing import Union
+
+from mephisto.abstractions.database import MephistoDB
+from mephisto.tools.db_data_porter.constants import FK_MEPHISTO_TABLE_PREFIX
+from mephisto.tools.db_data_porter.constants import IMPORTED_DATA_TABLE_NAME
+from mephisto.tools.db_data_porter.constants import MEPHISTO_DUMP_KEY
+from mephisto.tools.db_data_porter.constants import PROVIDER_DATASTORES__MEPHISTO_FK__MAPPINGS
+from mephisto.tools.db_data_porter.constants import PROVIDER_DATASTORES__RANDOMIZABLE_PK__MAPPINGS
+from mephisto.utils import db as db_utils
+from mephisto.utils.console_writer import ConsoleWriter
+
+logger = ConsoleWriter()
+
+TablePKSubstitutionsType = Dict[str, str]  # One table
+DBPKSubstitutionsType = Dict[str, TablePKSubstitutionsType]  # One DB
+PKSubstitutionsType = Dict[str, DBPKSubstitutionsType]  # Multiple DBs
+
+
+class RandomizedIDsType(TypedDict):
+    pk_substitutions: PKSubstitutionsType
+    updated_dump: dict
+
+
+def _randomize_ids_for_mephisto(
+    db: "MephistoDB", mephisto_dump: dict, legacy_only: bool = False,
+) -> DBPKSubstitutionsType:
+    table_names = [t for t in mephisto_dump.keys() if t not in [IMPORTED_DATA_TABLE_NAME]]
+
+    # Find Foreign Keys' field names for all tables in Mephisto DB
+    tables_fks = db_utils.select_fk_mappings_for_all_tables(db, table_names)
+
+    # Make new Primary Keys for all or legacy values
+    mephisto_pk_substitutions = {}
+    for table_name in table_names:
+        pk_field_name = db_utils.get_table_pk_field_name(db, table_name)
+        table_rows_from_mephisto_dump = mephisto_dump[table_name]
+
+        table_pk_substitutions = {}
+        for row in table_rows_from_mephisto_dump:
+            old_pk = row[pk_field_name]
+
+            is_legacy_value = int(old_pk) < db_utils.SQLITE_ID_MIN
+            if not legacy_only or (legacy_only and is_legacy_value):
+                new_pk = str(db_utils.make_randomized_int_id())
+                table_pk_substitutions.update({old_pk: new_pk})
+                row[pk_field_name] = new_pk
+
+        prev_value = mephisto_pk_substitutions.get(table_name, {})
+        mephisto_pk_substitutions[table_name] = {**prev_value, **table_pk_substitutions}
+
+    # Update Foreign Keys in related tables
+    for table_name, fks in tables_fks.items():
+        table_pk_substitutions = mephisto_pk_substitutions[table_name]
+        # If there is nothing to update in the related table, just skip it
+        if not table_pk_substitutions:
+            continue
+
+        table_fks = tables_fks[table_name]
+        # If table does not have any Foreign Keys, just skip it
+        if not table_fks:
+            continue
+
+        # Change values in related tables' rows
+        table_rows_from_mephisto_dump = mephisto_dump[table_name]
+        for row in table_rows_from_mephisto_dump:
+            for fk_table_name, relation_data in table_fks.items():
+                row_fk_value = row[relation_data["from"]]
+                substitution = mephisto_pk_substitutions[fk_table_name].get(row_fk_value)
+
+                if not substitution:
+                    continue
+
+                row[relation_data["from"]] = substitution
+
+    return mephisto_pk_substitutions
+
+
+def _randomize_ids_for_provider(
+    provider_type: str,
+    provider_dump: dict,
+    mephisto_pk_substitutions: DBPKSubstitutionsType,
+    legacy_only: bool = False,
+) -> Union[DBPKSubstitutionsType, None]:
+    # Nothing to process
+    if not provider_dump:
+        logger.warning(f"Dump for provider '{provider_type}' is empty, nothing to process")
+        return
+
+    provider_fks_mappings = PROVIDER_DATASTORES__MEPHISTO_FK__MAPPINGS.get(provider_type)
+    # If this is a new provider and the developer forgot to set FKs for export
+    if not provider_fks_mappings:
+        logger.warning(
+            f"No configuration found in PROVIDER_DATASTORES__MEPHISTO_FK__MAPPINGS "
+
+
+def _randomize_ids_for_provider(
+    provider_type: str,
+    provider_dump: dict,
+    mephisto_pk_substitutions: DBPKSubstitutionsType,
+    legacy_only: bool = False,
+) -> Union[DBPKSubstitutionsType, None]:
+    # Nothing to process
+    if not provider_dump:
+        logger.warning(f"Dump for provider '{provider_type}' is empty, nothing to process")
+        return None
+
+    provider_fks_mappings = PROVIDER_DATASTORES__MEPHISTO_FK__MAPPINGS.get(provider_type)
+    # This happens if a new provider was added, but the developer forgot to set FKs for export
+    if not provider_fks_mappings:
+        logger.warning(
+            f"No configuration found in PROVIDER_DATASTORES__MEPHISTO_FK__MAPPINGS "
+            f"for provider '{provider_type}'"
+        )
+        return None
+
+    # Make new Primary Keys for all values (or for legacy values only)
+    provider_pk_substitutions = {}
+    provider_pks_mappings = PROVIDER_DATASTORES__RANDOMIZABLE_PK__MAPPINGS.get(provider_type, {})
+
+    for table_name, pk_field_name in provider_pks_mappings.items():
+        table_rows_from_provider_dump = provider_dump[table_name]
+        table_pk_substitutions = {}
+        for row in table_rows_from_provider_dump:
+            old_pk = row[pk_field_name]
+
+            is_legacy_value = int(old_pk) < db_utils.SQLITE_ID_MIN
+            if not legacy_only or is_legacy_value:
+                new_pk = str(db_utils.make_randomized_int_id())
+                table_pk_substitutions.update({old_pk: new_pk})
+                row[pk_field_name] = new_pk
+
+        prev_value = provider_pk_substitutions.get(table_name, {})
+        provider_pk_substitutions[table_name] = {**prev_value, **table_pk_substitutions}
+
+    # Update Foreign Keys in related tables
+    for table_name, table_fks in provider_fks_mappings.items():
+        # If the table does not have any Foreign Keys, just skip it
+        if not table_fks:
+            continue
+
+        table_rows_from_provider_dump = provider_dump.get(table_name, [])
+        for row in table_rows_from_provider_dump:
+            for fk_table_name, relation_data in table_fks.items():
+                row_fk_value = row[relation_data["from"]]
+
+                # FKs pointing to the Mephisto DB
+                is_fk_to_mephisto_db = fk_table_name.startswith(FK_MEPHISTO_TABLE_PREFIX)
+                if is_fk_to_mephisto_db:
+                    fk_table_name = fk_table_name.split(FK_MEPHISTO_TABLE_PREFIX)[1]
+                    substitution = mephisto_pk_substitutions[fk_table_name].get(row_fk_value)
+                # FKs pointing to the provider DB
+                else:
+                    substitution = provider_pk_substitutions[fk_table_name].get(row_fk_value)
+
+                if not substitution:
+                    continue
+
+                row[relation_data["from"]] = substitution
+
+    return provider_pk_substitutions
+
+
+def randomize_ids(
+    db: "MephistoDB", full_dump: dict, legacy_only: bool = False,
+) -> RandomizedIDsType:
+    pk_substitutions: PKSubstitutionsType = {}
+
+    # Mephisto DB
+    mephisto_dump = full_dump[MEPHISTO_DUMP_KEY]
+    mephisto_pk_substitutions = _randomize_ids_for_mephisto(db, mephisto_dump, legacy_only)
+    pk_substitutions[MEPHISTO_DUMP_KEY] = mephisto_pk_substitutions
+
+    # Providers' DBs (deduplicated, as several requesters can share one provider type)
+    provider_types = set(i["provider_type"] for i in mephisto_dump["requesters"])
+    for provider_type in provider_types:
+        provider_dump = full_dump[provider_type]
+        randomized_ids_for_provider = _randomize_ids_for_provider(
+            provider_type, provider_dump, mephisto_pk_substitutions, legacy_only,
+        )
+
+        if randomized_ids_for_provider:
+            pk_substitutions[provider_type] = randomized_ids_for_provider
+
+    return {
+        "pk_substitutions": pk_substitutions,
+        "updated_dump": full_dump,
+    }
+
+
+def get_old_pk_from_substitutions(
+    pk: str, substitutions: dict, table_name: str,
+) -> Optional[str]:
+    # By the time a dump file has been created, it may already contain new randomized PKs,
+    # while the local Mephisto DB still stores the old ones.
+    # Find the old PK via the reversed key-value mapping
+    pk_subs = substitutions.get(MEPHISTO_DUMP_KEY, {}).get(table_name, {})
+    pk_subs_reversed = dict((v, k) for k, v in pk_subs.items())
+    return pk_subs_reversed.get(pk)
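For orientation, here is a minimal sketch of how `randomize_ids` is meant to be driven (the `full_dump` shape is illustrative; in the porter it is produced by the dump helpers, and `legacy_only=True` corresponds to the `--randomize-legacy-ids` export option):

```python
from mephisto.abstractions.databases.local_database import LocalMephistoDB
from mephisto.tools.db_data_porter.randomize_ids import randomize_ids

db = LocalMephistoDB()

# A dump maps DB names to {table_name: [rows]}; "mephisto" is the main DB,
# and provider dumps sit under their provider-type keys (rows are illustrative)
full_dump = {
    "mephisto": {
        "projects": [{"project_id": "1", "project_name": "test_project"}],
        "requesters": [],  # provider types are read from this table
    },
}

result = randomize_ids(db, full_dump, legacy_only=True)
updated_dump = result["updated_dump"]  # same structure, legacy PKs/FKs replaced
substitutions = result["pk_substitutions"]  # e.g. {"mephisto": {"projects": {"1": "..."}}}
```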
diff --git a/mephisto/tools/db_data_porter/validation.py b/mephisto/tools/db_data_porter/validation.py
new file mode 100644
index 000000000..853ad749f
--- /dev/null
+++ b/mephisto/tools/db_data_porter/validation.py
@@ -0,0 +1,91 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Meta Platforms and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List
+from typing import Optional
+
+from mephisto.abstractions.database import MephistoDB
+from mephisto.generators.form_composer.config_validation.utils import make_error_message
+from mephisto.tools.db_data_porter.constants import AVAILABLE_PROVIDER_TYPES
+from mephisto.tools.db_data_porter.constants import MEPHISTO_DUMP_KEY
+from mephisto.tools.db_data_porter.constants import METADATA_DUMP_KEY
+from mephisto.utils import db as db_utils
+
+
+def validate_dump_data(db: "MephistoDB", dump_data: dict) -> Optional[List[str]]:
+    errors = []
+
+    db_dumps = {k: v for k, v in dump_data.items() if k != METADATA_DUMP_KEY}
+
+    # 1. Check provider names
+    incorrect_db_names = list(filter(lambda i: i not in AVAILABLE_PROVIDER_TYPES, db_dumps.keys()))
+    if incorrect_db_names:
+        errors.append(
+            f"Dump file cannot contain these database names: {', '.join(incorrect_db_names)}."
+        )
+
+    # 2. Check that all top-level values in the dump file are JSON-objects
+    db_values_are_not_dicts = list(filter(lambda i: not isinstance(i, dict), dump_data.values()))
+    if db_values_are_not_dicts:
+        errors.append(
+            f"We found {len(db_values_are_not_dicts)} values in the dump "
+            f"that are not JSON-objects."
+        )
+
+    # 3. Check dumps of DBs
+    for db_name, db_dump_data in db_dumps.items():
+        # Get or create a DB/Datastore to request its available tables
+        if db_name == MEPHISTO_DUMP_KEY:
+            db_or_datastore = db
+        else:
+            # Use this method here as it creates an empty datastore if it does not exist
+            db_or_datastore = db.get_datastore_for_provider(db_name)
+
+        available_table_names = db_utils.get_list_of_db_table_names(db_or_datastore)
+
+        # Check tables
+        for table_name, table_data in db_dump_data.items():
+            # Table name must be a string
+            if not isinstance(table_name, str):
+                errors.append(f"Expecting table name to be a string, not `{table_name}`.")
+
+            # Table data must be a list of rows
+            if not isinstance(table_data, list):
+                errors.append(f"Expecting table data to be a JSON-array, not `{table_data}`.")
+
+            # Local DB/Datastore must have the same tables as the dump
+            if table_name not in available_table_names:
+                error_message = make_error_message(
+                    f"Your local `{db_name}` database does not have table '{table_name}'.",
+                    [
+                        "local database has unapplied migrations",
+                        "dump is too old and not compatible with your local database",
+                    ],
+                    indent=8,
+                    list_title="Possible issues",
+                )
+                errors.append(error_message)
+
+            # Check table rows
+            for i, table_row in enumerate(table_data):
+                if not isinstance(table_row, dict):
+                    errors.append(
+                        f"Table `{table_name}`, row {i}: "
+                        f"expecting it to be a JSON-object, not `{table_row}`."
+                    )
+                    continue
+
+                incorrect_field_names = list(filter(
+                    lambda fn: not isinstance(fn, str), table_row.keys())
+                )
+                if incorrect_field_names:
+                    errors.append(
+                        f"Table `{table_name}`, row {i}: "
+                        f"names of these fields must be strings: "
+                        f"{', '.join(map(str, incorrect_field_names))}."
+                    )
+
+    return errors
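`validate_dump_data` collects every problem it can find into a flat list of messages instead of raising on the first one, so the `mephisto db import` command can show them all at once. A minimal sketch of the call pattern (the dump content here is deliberately broken to trigger an error):

```python
from mephisto.abstractions.databases.local_database import LocalMephistoDB
from mephisto.tools.db_data_porter.validation import validate_dump_data

db = LocalMephistoDB()

# "mephisto" is the main DB key; the table name below is intentionally unknown
dump_data = {"mephisto": {"definitely_not_a_table": []}}

errors = validate_dump_data(db, dump_data)
for error in errors or []:
    print(error)  # e.g. "Your local `mephisto` database does not have table ..."
```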
diff --git a/mephisto/utils/console_writer.py b/mephisto/utils/console_writer.py
new file mode 100644
index 000000000..9d32d68f0
--- /dev/null
+++ b/mephisto/utils/console_writer.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Meta Platforms and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import traceback
+from types import FunctionType
+from typing import Any
+from typing import Optional
+
+from rich import print as rich_print
+
+
+class ConsoleWriter:
+    """
+    This class writes logs to the console in a consistent way and can apply colored highlights.
+    To make it easily interchangeable with the standard Python logger, it uses the same method names.
+
+    Usage:
+        logger = ConsoleWriter()
+        logger.info("Some message")
+
+        # to interchange with the Python logger, just change one line in the module
+        logger = get_logger(name=__name__)
+    """
+
+    _writer = None
+
+    def __init__(self, printer: Optional[FunctionType] = None):
+        self._writer = printer or rich_print  # by default, we use `rich.print`
+
+    def info(self, value: Any):
+        self._writer(str(value))
+
+    def debug(self, value: Any):
+        self._writer(str(value))
+
+    def warning(self, value: Any):
+        self._writer(str(value))
+
+    def error(self, value: Any):
+        self._writer(str(value))
+
+    def exception(self, value: Any):
+        self.error(value)
+        rich_print(traceback.format_exc())
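Since `ConsoleWriter` mirrors the standard logger's method names, the porter modules can swap it for a real logger with a one-line change. A quick usage sketch (passing plain `print` as the printer is a hypothetical variation; by default `rich.print` renders the markup):

```python
from mephisto.utils.console_writer import ConsoleWriter

logger = ConsoleWriter()
logger.info("[green]Finished successfully[/green]")  # rendered with color by `rich`

plain_logger = ConsoleWriter(printer=print)  # any callable will do
plain_logger.warning("Dump for provider 'mock' is empty")  # markup printed verbatim
```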
diff --git a/mephisto/utils/db.py b/mephisto/utils/db.py
new file mode 100644
index 000000000..43eb7f374
--- /dev/null
+++ b/mephisto/utils/db.py
@@ -0,0 +1,629 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Meta Platforms and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import random
+from copy import deepcopy
+from datetime import datetime
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Type
+
+from dateutil.parser import ParserError
+
+from mephisto.abstractions.database import MephistoDB
+from mephisto.utils.console_writer import ConsoleWriter
+from mephisto.utils.misc import serialize_date_to_python
+
+SQLITE_ID_MIN = 1_000_000
+SQLITE_ID_MAX = 2**63 - 1
+
+logger = ConsoleWriter()
+
+
+# --- Exceptions ---
+
+class MephistoDBException(Exception):
+    pass
+
+
+class EntryAlreadyExistsException(MephistoDBException):
+    db = None
+    original_exc = None
+    table_name = None
+
+    def __init__(self, *args, **kwargs):
+        self.db = kwargs.pop("db", None)
+        self.table_name = kwargs.pop("table_name", None)
+        self.original_exc = kwargs.pop("original_exc", None)
+        super().__init__(*args, **kwargs)
+
+
+class EntryDoesNotExistException(MephistoDBException):
+    pass
+
+
+# --- Functions ---
+
+def _select_all_rows_from_table(db: "MephistoDB", table_name: str) -> List[dict]:
+    with db.table_access_condition, db.get_connection() as conn:
+        c = conn.cursor()
+        c.execute(f"SELECT * FROM {table_name};")
+        rows = c.fetchall()
+        return [dict(row) for row in rows]
+
+
+def _select_rows_from_table_related_to_task(
+    db: "MephistoDB", table_name: str, task_ids: List[str],
+) -> List[dict]:
+    return select_rows_by_list_of_field_values(db, table_name, ["task_id"], [task_ids])
+
+
+def select_rows_from_table_related_to_task_run(
+    db: "MephistoDB", table_name: str, task_run_ids: List[str],
+) -> List[dict]:
+    return select_rows_by_list_of_field_values(db, table_name, ["task_run_id"], [task_run_ids])
+
+
+def serialize_data_for_table(rows: List[dict]) -> List[dict]:
+    serialized_data = []
+    for row in rows:
+        _row = dict(row)
+        for field_name, field_value in _row.items():
+            # SQLite dates
+            if field_name.endswith("_at") or field_name.endswith("_date"):
+                try:
+                    python_datetime_value = serialize_date_to_python(field_value)
+                except (ParserError, OverflowError):
+                    logger.exception(
+                        f"[red]"
+                        f"Cannot convert value `{field_value}` of field `{field_name}` "
+                        f"to Python datetime. "
+                        f"It seems your DB is corrupted."
+                        f"[/red]"
+                    )
+                    exit()
+
+                _row[field_name] = python_datetime_value.isoformat()
+
+        serialized_data.append(_row)
+
+    return serialized_data
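+
+
+# NOTE: an illustrative run (hypothetical row): a `creation_date` stored by
+# SQLite as "2024-01-01 00:00:01" is returned in its ISO8601 form, which is
+# what ends up in the JSON dump:
+#
+#     serialize_data_for_table([{"unit_id": "1", "creation_date": "2024-01-01 00:00:01"}])
+#     # -> [{"unit_id": "1", "creation_date": "2024-01-01T00:00:01"}]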
+
+
+def make_randomized_int_id() -> int:
+    return random.randint(SQLITE_ID_MIN, SQLITE_ID_MAX)
+
+
+def get_task_ids_by_task_names(db: "MephistoDB", task_names: List[str]) -> List[str]:
+    with db.table_access_condition, db.get_connection() as conn:
+        c = conn.cursor()
+        task_names_string = ",".join([f"'{s}'" for s in task_names])
+        c.execute(
+            f"""
+            SELECT task_id FROM tasks
+            WHERE task_name IN ({task_names_string});
+            """
+        )
+        rows = c.fetchall()
+        return [r["task_id"] for r in rows]
+
+
+def get_task_run_ids_ids_by_task_ids(db: "MephistoDB", task_ids: List[str]) -> List[str]:
+    with db.table_access_condition, db.get_connection() as conn:
+        c = conn.cursor()
+        task_ids_string = ",".join([f"'{s}'" for s in task_ids])
+        c.execute(
+            f"""
+            SELECT task_run_id FROM task_runs
+            WHERE task_id IN ({task_ids_string});
+            """
+        )
+        rows = c.fetchall()
+        return [r["task_run_id"] for r in rows]
+
+
+def get_task_run_ids_ids_by_labels(db: "MephistoDB", labels: List[str]) -> List[str]:
+    if not labels:
+        return []
+
+    with db.table_access_condition, db.get_connection() as conn:
+        c = conn.cursor()
+
+        where_labels_string = " OR ".join([f"data_labels LIKE '%\"{l}\"%'" for l in labels])
+        where_labels_string = f" AND ({where_labels_string})"
+
+        c.execute(
+            f"""
+            SELECT unique_field_values FROM imported_data
+            WHERE table_name = 'task_runs' {where_labels_string};
+            """
+        )
+        rows = c.fetchall()
+
+        # Serialize data into a plain Python list of IDs
+        task_run_ids = []
+        for row in rows:
+            row_task_run_ids: List[List[str]] = json.loads(row["unique_field_values"])
+            task_run_ids += row_task_run_ids
+
+        return task_run_ids
+
+
+def get_table_pk_field_name(db: "MephistoDB", table_name: str):
+    """Make a request to the DB to get the name of the PK field of a table"""
+    with db.table_access_condition, db.get_connection() as conn:
+        c = conn.cursor()
+        c.execute(f"SELECT name FROM pragma_table_info('{table_name}') WHERE pk;")
+        table_unique_field_name = c.fetchone()["name"]
+        return table_unique_field_name
+
+
+def select_all_table_rows(db: "MephistoDB", table_name: str) -> List[dict]:
+    with db.table_access_condition, db.get_connection() as conn:
+        c = conn.cursor()
+        c.execute(f"SELECT * FROM {table_name};")
+        rows = c.fetchall()
+        return [dict(row) for row in rows]
+
+
+def select_rows_by_list_of_field_values(
+    db: "MephistoDB",
+    table_name: str,
+    field_names: List[str],
+    field_values: List[List[str]],
+    order_by: Optional[str] = None,
+) -> List[dict]:
+    """
+    Select all entries from a table, filtered by field names and lists of values for those fields.
+
+    `field_values` is a list of lists of values for each field name,
+    in the same order as `field_names`.
+
+    For instance:
+        table_name - granted_qualifications
+        field_names - ["qualification_id", "worker_id"]
+        field_values - [[<qualification_id_1>, <qualification_id_2>], [<worker_id>]]
+    In this case we will select all Granted Qualifications whose rows in the DB
+    match either of the two Qualification IDs _AND_ the one Worker ID
+    """
+    with db.table_access_condition, db.get_connection() as conn:
+        c = conn.cursor()
+
+        # Combine the WHERE statement
+        where_list = []
+        for i, field_name in enumerate(field_names):
+            _field_values = field_values[i]
+            field_values_string = ",".join([f"'{s}'" for s in _field_values])
+            where_list.append([field_name, field_values_string])
+        where_string = " AND ".join([
+            f"{field_name} IN ({field_values_string})"
+            for field_name, field_values_string in where_list
+        ])
+
+        # Combine the ORDER BY statement
+        order_by_string = ""
+        if order_by:
+            order_by_direction = "DESC" if order_by.startswith("-") else "ASC"
+            order_by_field_name = order_by[1:] if order_by.startswith("-") else order_by
+            order_by_string = f" ORDER BY {order_by_field_name} {order_by_direction}"
+
+        c.execute(
+            f"""
+            SELECT * FROM {table_name}
+            WHERE {where_string}
+            {order_by_string};
+            """
+        )
+
+        rows = c.fetchall()
+        return [dict(row) for row in rows]
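+
+
+# NOTE: with hypothetical IDs, the docstring example above corresponds to:
+#
+#     select_rows_by_list_of_field_values(
+#         db,
+#         "granted_qualifications",
+#         ["qualification_id", "worker_id"],
+#         [["1001", "1002"], ["2001"]],
+#         order_by="-creation_date",
+#     )
+#
+# which builds and runs the SQL:
+#
+#     SELECT * FROM granted_qualifications
+#     WHERE qualification_id IN ('1001','1002') AND worker_id IN ('2001')
+#     ORDER BY creation_date DESC;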
+
+
+def delete_exported_data_without_fk_constraints(
+    db: "MephistoDB", db_dump: dict, table_names_can_be_cleaned: Optional[List[str]] = None,
+):
+    table_names_can_be_cleaned = table_names_can_be_cleaned or []
+
+    with db.table_access_condition, db.get_connection() as conn:
+        c = conn.cursor()
+        c.execute("PRAGMA foreign_keys = off;")
+
+        delete_queries = []
+        for table_name, rows in db_dump.items():
+            if table_name not in table_names_can_be_cleaned:
+                continue
+
+            table_pk_name = get_table_pk_field_name(db, table_name)
+            table_pks = [r[table_pk_name] for r in rows]
+            table_pks_string = ",".join([f"'{s}'" for s in table_pks])
+            delete_queries.append(
+                f"DELETE FROM {table_name} WHERE {table_pk_name} IN ({table_pks_string});"
+            )
+        c.executescript("\n".join(delete_queries))
+
+        c.execute("PRAGMA foreign_keys = on;")
+
+
+def delete_entire_exported_data(db: "MephistoDB"):
+    """Delete all rows in tables without dropping the tables themselves"""
+    exclude_table_names = ["migrations"]
+    table_names = get_list_of_db_table_names(db)
+    table_names = [tn for tn in table_names if tn not in exclude_table_names]
+
+    with db.table_access_condition, db.get_connection() as conn:
+        c = conn.cursor()
+        c.execute("PRAGMA foreign_keys = off;")
+
+        delete_queries = []
+        for table_name in table_names:
+            delete_queries.append(
+                f"DELETE FROM {table_name};"
+                f"DELETE FROM sqlite_sequence WHERE name='{table_name}';"
+            )
+
+        c.executescript("\n".join(delete_queries))
+
+        c.execute("PRAGMA foreign_keys = on;")
+
+
+def get_list_of_provider_types(db: "MephistoDB") -> List[str]:
+    with db.table_access_condition, db.get_connection() as conn:
+        c = conn.cursor()
+        c.execute("SELECT provider_type FROM requesters;")
+        rows = c.fetchall()
+        return [r["provider_type"] for r in rows]
+
+
+def get_latest_row_from_table(
+    db: "MephistoDB", table_name: str, order_by: Optional[str] = "creation_date",
+) -> Optional[dict]:
+    with db.table_access_condition, db.get_connection() as conn:
+        c = conn.cursor()
+        c.execute(
+            f"""
+            SELECT *
+            FROM {table_name}
+            ORDER BY {order_by} DESC
+            LIMIT 1;
+            """,
+        )
+        latest_row = c.fetchone()
+
+        return dict(latest_row) if latest_row else None
+
+
+def apply_migrations(db: "MephistoDB", migrations: dict):
+    with db.table_access_condition, db.get_connection() as conn:
+        c = conn.cursor()
+        for migration_name, migration_sql in migrations.items():
+            try:
+                c.execute(
+                    """SELECT id FROM migrations WHERE name = ?1;""",
+                    (migration_name,),
+                )
+                migration_has_been_applied = c.fetchone()
+
+                if not migration_has_been_applied:
+                    c.executescript(migration_sql)
+                    c.execute(
+                        """
+                        INSERT INTO migrations(
+                            name, status
+                        ) VALUES (?, ?);
+                        """,
+                        (migration_name, "completed"),
+                    )
+            except Exception as e:
+                c.execute(
+                    """
+                    INSERT INTO migrations(
+                        name, status, error_message
+                    ) VALUES (?, ?, ?);
+                    """,
+                    (migration_name, "errored", str(e)),
+                )
+                logger.exception(
+                    f"Could not apply migration '{migration_name}' for database '{db.db_path}':\n"
+                    f"{migration_sql}.\n"
+                    f"Error: {e}. Marked it as errored in the 'migrations' table."
+                )
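+
+
+# NOTE: the `migrations` argument mirrors the shape of the `migrations/__init__.py`
+# modules in this patch: a dict mapping a readable migration name to a raw SQL
+# script; each script is applied once and its status is recorded in the
+# `migrations` table. For example:
+#
+#     migrations = {
+#         "20240418_preparing_db_for_merge_dbs_command": PREPARING_DB_FOR_MERGE_DBS_COMMAND,
+#     }
+#     apply_migrations(datastore, migrations)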
+
+
+def get_list_of_db_table_names(db: "MephistoDB") -> List[str]:
+    with db.table_access_condition, db.get_connection() as conn:
+        c = conn.cursor()
+        c.execute("SELECT name FROM sqlite_master WHERE type='table';")
+        rows = c.fetchall()
+        return [r["name"] for r in rows]
+
+
+def get_list_of_tables_to_export(db: "MephistoDB") -> List[str]:
+    table_names = get_list_of_db_table_names(db)
+
+    filtered_table_names = []
+    for table_name in table_names:
+        if not table_name.startswith("sqlite_") and table_name not in ["migrations"]:
+            filtered_table_names.append(table_name)
+
+    return filtered_table_names
+
+
+def check_if_row_with_params_exists(
+    db: "MephistoDB", table_name: str, params: dict, select_field: Optional[str] = "*",
+) -> bool:
+    """Check if a row matching the passed dict of `params` exists in `table_name`"""
+    with db.table_access_condition, db.get_connection() as conn:
+        c = conn.cursor()
+
+        where_args = []
+        execute_args = []
+
+        for i, (field_name, field_value) in enumerate(params.items(), start=1):
+            execute_args.append(field_value)
+            where_args.append(f"{field_name} = ?{i}")
+
+        where_string = "WHERE " + " AND ".join(where_args) if where_args else ""
+
+        c.execute(
+            f"""
+            SELECT {select_field}
+            FROM {table_name} {where_string}
+            LIMIT 1;
+            """,
+            execute_args,
+        )
+        existing_row_in_current_db = c.fetchone()
+
+        return bool(existing_row_in_current_db)
+
+
+def get_providers_datastores(db: "MephistoDB") -> Dict[str, "MephistoDB"]:
+    provider_types = get_list_of_provider_types(db)
+    provider_datastores = {t: db.get_datastore_for_provider(t) for t in provider_types}
+    return provider_datastores
+
+
+def db_or_datastore_to_dict(db: "MephistoDB") -> dict:
+    """Convert any kind of DB or datastore to a dict"""
+    dump_data = {}
+    table_names = get_list_of_tables_to_export(db)
+    for table_name in table_names:
+        table_rows = _select_all_rows_from_table(db, table_name)
+        table_data = serialize_data_for_table(table_rows)
+        dump_data[table_name] = table_data
+
+    return dump_data
+
+
+def mephisto_db_to_dict_for_task_runs(
+    db: "MephistoDB",
+    task_run_ids: Optional[List[str]] = None,
+) -> dict:
+    """
+    Partially convert the Mephisto DB into a dict, filtered by the given TaskRun IDs
+    NOTE: does not work with provider datastores, only the main database
+    """
+    dump_data = {}
+    table_names = get_list_of_tables_to_export(db)
+
+    tables_with_task_run_relations = [
+        "agents",
+        "assignments",
+        "onboarding_agents",
+        "task_runs",
+        "units",
+    ]
+
+    tables_with_task_relations = [
+        "tasks",
+        "unit_review",
+    ]
+
+    # Find and serialize tables with a `task_run_id` field
+    for table_name in table_names:
+        if table_name in tables_with_task_run_relations:
+            table_rows = select_rows_from_table_related_to_task_run(db, table_name, task_run_ids)
+            table_data = serialize_data_for_table(table_rows)
+            dump_data[table_name] = table_data
+
+    # Find and serialize tables with a `task_id` field
+    task_ids = list(set(filter(bool, [i["task_id"] for i in dump_data["task_runs"]])))
+    for table_name in table_names:
+        if table_name in tables_with_task_relations:
+            table_rows = _select_rows_from_table_related_to_task(db, table_name, task_ids)
+            table_data = serialize_data_for_table(table_rows)
+            dump_data[table_name] = table_data
+
+    # Find and serialize `projects`
+    project_ids = list(set(filter(bool, [i["project_id"] for i in dump_data["tasks"]])))
+    project_rows = select_rows_by_list_of_field_values(
+        db, "projects", ["project_id"], [project_ids],
+    )
+    dump_data["projects"] = serialize_data_for_table(project_rows)
+
+    # Find and serialize `requesters`
+    requester_ids = list(set(filter(bool, [i["requester_id"] for i in dump_data["task_runs"]])))
+    requester_rows = select_rows_by_list_of_field_values(
+        db, "requesters", ["requester_id"], [requester_ids],
+    )
+    dump_data["requesters"] = serialize_data_for_table(requester_rows)
+
+    # Find and serialize `workers`
+    worker_ids = list(set(filter(bool, [i["worker_id"] for i in dump_data["units"]])))
+    worker_rows = select_rows_by_list_of_field_values(
+        db, "workers", ["worker_id"], [worker_ids],
+    )
+    dump_data["workers"] = serialize_data_for_table(worker_rows)
+
+    # Find and serialize `granted_qualifications`
+    granted_qualification_rows = select_rows_by_list_of_field_values(
+        db, "granted_qualifications", ["worker_id"], [worker_ids],
+    )
+    dump_data["granted_qualifications"] = serialize_data_for_table(granted_qualification_rows)
+
+    # Find and serialize `qualifications`
+    qualification_ids = list(set(filter(
+        bool, [i["qualification_id"] for i in dump_data["granted_qualifications"]],
+    )))
+    qualification_rows = select_rows_by_list_of_field_values(
+        db, "qualifications", ["qualification_id"], [qualification_ids],
+    )
+    dump_data["qualifications"] = serialize_data_for_table(qualification_rows)
+
+    return dump_data
+
+
+def select_task_run_ids_since_date(db: "MephistoDB", since: datetime) -> List[str]:
+    # We are not doing this on the database level, because SQLite can store datetime
+    # fields in different formats, and it is more reliable to compare datetimes in Python
+    task_run_rows = select_all_table_rows(db, "task_runs")
+    task_run_ids_since = []
+    for row in task_run_rows:
+        creation_datetime = serialize_date_to_python(row["creation_date"])
+
+        if creation_datetime >= since:
+            task_run_ids_since.append(row["task_run_id"])
+
+    return task_run_ids_since
+
+
+def select_fk_mappings_for_table(db: "MephistoDB", table_name: str) -> dict:
+    with db.table_access_condition, db.get_connection() as conn:
+        c = conn.cursor()
+        c.execute(f"SELECT * FROM pragma_foreign_key_list('{table_name}');")
+        rows = c.fetchall()
+        table_fks = {}
+
+        for row in rows:
+            fk_table_name = row["table"]
+            current_table_field_name = row["from"]
+            relating_table_field_name = row["to"]
+
+            table_fks[fk_table_name] = {
+                "from": current_table_field_name,
+                "to": relating_table_field_name,
+            }
+
+        return table_fks
+
+
+def select_fk_mappings_for_all_tables(db: "MephistoDB", table_names: List[str]) -> dict:
+    tables_fks = {}
+    for table_name in table_names:
+        table_fks = select_fk_mappings_for_table(db, table_name)
+        tables_fks.update({table_name: table_fks})
+    return tables_fks
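+
+
+# NOTE: a sketch of what these FK helpers return for the main Mephisto DB
+# (table and field names from the schema; only two tables shown):
+#
+#     select_fk_mappings_for_all_tables(db, ["units", "agents"])
+#     # -> {
+#     #        "units": {
+#     #            "assignments": {"from": "assignment_id", "to": "assignment_id"},
+#     #            "workers": {"from": "worker_id", "to": "worker_id"},
+#     #            ...
+#     #        },
+#     #        "agents": {...},
+#     #    }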
+
+
+def insert_new_row_in_table(db: "MephistoDB", table_name: str, row: dict):
+    with db.table_access_condition, db.get_connection() as conn:
+        c = conn.cursor()
+
+        columns, values = zip(*row.items())
+
+        columns_string = ",".join(columns)
+        columns_questions_string = ",".join(["?"] * len(columns))
+
+        c.execute(
+            f"""
+            INSERT INTO {table_name}(
+                {columns_string}
+            ) VALUES ({columns_questions_string});
+            """,
+            values,
+        )
+
+
+def update_row_in_table(
+    db: "MephistoDB", table_name: str, row: dict, pk_field_name: Optional[str] = None,
+):
+    row = deepcopy(row)
+
+    if not pk_field_name:
+        pk_field_name = get_table_pk_field_name(db, table_name=table_name)
+
+    pk_field_value = row.pop(pk_field_name)
+
+    with db.table_access_condition, db.get_connection() as conn:
+        c = conn.cursor()
+
+        columns, values = zip(*row.items())
+
+        columns_set_string = ", ".join([f"{col} = ?" for col in columns])
+
+        # Bind the PK value as a parameter as well, so that TEXT Primary Keys
+        # are quoted correctly
+        c.execute(
+            f"""
+            UPDATE {table_name}
+            SET {columns_set_string}
+            WHERE {pk_field_name} = ?;
+            """,
+            (*values, pk_field_value),
+        )
+
+
+# --- Decorators ---
+
+def retry_generate_id(caught_excs: Optional[List[Type[Exception]]] = None):
+    """
+    A decorator that retries creating a DB entry until the generated ID is unique.
+
+    The caught exception object must have the following attributes:
+        - original_exc
+        - db
+        - table_name
+    """
+    def decorator(unreliable_fn: Callable):
+        def wrapped_fn(*args, **kwargs):
+            caught_excs_tuple = tuple(caught_excs or [Exception])
+
+            pk_exists = True
+            while pk_exists:
+                pk_exists = False
+
+                try:
+                    # Happy path
+                    result = unreliable_fn(*args, **kwargs)
+                    return result
+                except caught_excs_tuple as e:
+                    # We can check the constraint only if the exception was configured properly.
+                    # Otherwise, we just leave the error as is and re-raise it
+                    exc_message = str(getattr(e, "original_exc", None) or "")
+                    db = getattr(e, "db", None)
+                    table_name = getattr(e, "table_name", None)
+                    is_unique_constraint = exc_message.startswith("UNIQUE constraint")
+
+                    if db and table_name and is_unique_constraint:
+                        pk_field_name = get_table_pk_field_name(db, table_name=table_name)
+                        if pk_field_name in exc_message:
+                            pk_exists = True
+                            continue
+
+                    raise
+
+        # Set the original function's name on the wrapped one
+        wrapped_fn.__name__ = unreliable_fn.__name__
+
+        return wrapped_fn
+    return decorator
diff --git a/mephisto/utils/misc.py b/mephisto/utils/misc.py
new file mode 100644
index 000000000..ca5b35f7a
--- /dev/null
+++ b/mephisto/utils/misc.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Meta Platforms and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+ +from datetime import datetime +from typing import Any + +from dateutil.parser import parse as dateutil_parse + + +def serialize_date_to_python(value: Any) -> datetime: + """Convert string dates or integer timestamps into Python datetime format""" + # If integer timestamp + if isinstance(value, int): + timestamp_is_in_msec = len(str(value)) == 13 + datetime_value = datetime.fromtimestamp( + value / 1000 if timestamp_is_in_msec else value + ) + # If datetime string + else: + datetime_value = dateutil_parse(str(value)) + + return datetime_value diff --git a/mephisto/utils/testing.py b/mephisto/utils/testing.py index 3ef97eb53..aa811be42 100644 --- a/mephisto/utils/testing.py +++ b/mephisto/utils/testing.py @@ -215,17 +215,17 @@ def find_unit_reviews( params.append(nonesafe_int(task_id)) with db.table_access_condition: - conn = db._get_connection() + conn = db.get_connection() c = conn.cursor() c.execute( f""" SELECT * FROM unit_review WHERE (updated_qualification_id = ?1) OR - (revoked_qualification_id = ?1) - AND (worker_id = ?2) + (revoked_qualification_id = ?1) AND + (worker_id = ?2) {task_query} - ORDER BY created_at ASC; + ORDER BY creation_date ASC; """, params, ) From fe832376439bc90560a9c29d188a3e8917131a73 Mon Sep 17 00:00:00 2001 From: Paul Abumov Date: Mon, 29 Apr 2024 13:46:28 -0400 Subject: [PATCH 2/3] Fixes for Data Porter feature --- .../merge_dbs/custom_conflict_resolver.md | 4 + .../guides/how_to_use/merge_dbs/reference.md | 38 +-- .../how_to_use/merge_dbs/simple_usage.md | 23 +- .../abstractions/databases/local_database.py | 91 +++-- ...0325_preparing_db_for_merge_dbs_command.py | 97 +++--- .../providers/mock/mock_datastore.py | 1 - .../providers/mock/mock_datastore_export.py | 10 +- ...0325_preparing_db_for_merge_dbs_command.py | 70 ++++ .../providers/mturk/migrations/__init__.py | 12 + .../providers/mturk/mturk_datastore.py | 4 + .../providers/mturk/mturk_datastore_export.py | 15 +- ...0325_preparing_db_for_merge_dbs_command.py | 93 ++--- .../prolific/prolific_datastore_export.py | 27 +- mephisto/client/cli.py | 216 +----------- mephisto/client/cli_db_commands.py | 318 ++++++++++++++++++ mephisto/tools/db_data_porter/backups.py | 150 +-------- .../conflict_resolvers/__init__.py | 6 +- .../base_merge_conflict_resolver.py | 101 ++++-- .../default_merge_conflict_resolver.py | 4 +- .../example_merge_conflict_resolver.py | 72 ++++ mephisto/tools/db_data_porter/constants.py | 4 +- .../tools/db_data_porter/db_data_porter.py | 236 ++++++++----- mephisto/tools/db_data_porter/dumps.py | 118 ++++++- mephisto/tools/db_data_porter/export_dump.py | 226 +++++++++++++ mephisto/tools/db_data_porter/import_dump.py | 132 ++++++-- .../tools/db_data_porter/randomize_ids.py | 16 +- mephisto/tools/db_data_porter/validation.py | 4 +- mephisto/utils/db.py | 152 ++++++--- mephisto/utils/misc.py | 4 +- mephisto/utils/testing.py | 2 +- test/core/test_operator.py | 3 +- test/review_app/server/api/test_units_view.py | 32 +- 32 files changed, 1544 insertions(+), 737 deletions(-) create mode 100644 mephisto/abstractions/providers/mturk/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py create mode 100644 mephisto/abstractions/providers/mturk/migrations/__init__.py create mode 100644 mephisto/client/cli_db_commands.py create mode 100644 mephisto/tools/db_data_porter/conflict_resolvers/example_merge_conflict_resolver.py create mode 100644 mephisto/tools/db_data_porter/export_dump.py diff --git a/docs/web/docs/guides/how_to_use/merge_dbs/custom_conflict_resolver.md 
b/docs/web/docs/guides/how_to_use/merge_dbs/custom_conflict_resolver.md index 77fd228d0..01e077422 100644 --- a/docs/web/docs/guides/how_to_use/merge_dbs/custom_conflict_resolver.md +++ b/docs/web/docs/guides/how_to_use/merge_dbs/custom_conflict_resolver.md @@ -36,3 +36,7 @@ and perhaps `default_strategy_name` value (see `DefaultMergeConflictResolver` as NOTE: All available providers must be present in `strategies_config`. Table names (under each provider key) are optional, and if missing, `default_strategy_name` will be used for all conflicts related to this table. + +4. There is an example of a working custom conflict resolver in module `mephisto/tools/db_data_porter/conflict_resolvers/example_merge_conflict_resolver.py`. You can launch it like this: + +`mephisto db import ... --conflict-resolver ExampleMergeConflictResolver` diff --git a/docs/web/docs/guides/how_to_use/merge_dbs/reference.md b/docs/web/docs/guides/how_to_use/merge_dbs/reference.md index 5cf58becb..55677f590 100644 --- a/docs/web/docs/guides/how_to_use/merge_dbs/reference.md +++ b/docs/web/docs/guides/how_to_use/merge_dbs/reference.md @@ -14,7 +14,7 @@ This is a reference describing set of commands under the `mephisto db` command g ## Export This command exports data from Mephisto DB and provider-specific datastores -as a combination of (i) a JSON file, and (ii) an archived `data` catalog with related files. +as an archived combination of (i) a JSON file, and (ii) a `data` catalog with related files. If no parameter passed, full data dump (i.e. backup) will be created. @@ -29,25 +29,25 @@ mephisto db export --export-tasks-by-ids 1 --export-tasks-by-ids 2 mephisto db export --export-task-runs-by-ids 3 --export-task-runs-by-ids 4 mephisto db export --export-task-runs-since-date 2024-01-01 mephisto db export --export-task-runs-since-date 2023-01-01T00:00:00 -mephisto db export --export-labels first_dump --export-labels second_dump -mephisto db export --export-tasks-by-ids 1 --delete-exported-data --randomize-legacy-ids --export-indent 2 +mephisto db export --labels first_dump --labels second_dump +mephisto db export --export-tasks-by-ids 1 --delete-exported-data --randomize-legacy-ids --export-indent 4 ``` Options (all optional): - `-tn/--export-tasks-by-names` - names of Tasks that will be exported - `-ti/--export-tasks-by-ids` - ids of Tasks that will be exported -- `-tr/--export-task-runs-by-ids` - ids of TaskRuns that will be exported +- `-tri/--export-task-runs-by-ids` - ids of TaskRuns that will be exported - `-trs/--export-task-runs-since-date` - only objects created after this ISO8601 datetime will be exported -- `-tl/--export-labels` - only data imported under these labels will be exported -- `-de/--delete-exported-data` - after exporting data, delete it from local DB +- `-l/--labels` - only data imported under these labels will be exported +- `-del/--delete-exported-data` - after exporting data, delete it from local DB - `-r/--randomize-legacy-ids` - replace legacy autoincremented ids with new pseudo-random ids to avoid conflicts during data merging -- `-i/--export-indent` - make dump easy to read via formatting JSON with indentations +- `-i/--export-indent` - make dump easy to read via formatting JSON with indentations (Default 2) - `-v/--verbosity` - write more informative messages about progress (Default 0. Values: 0, 1) Note that the following options cannot be used together: -`--export-tasks-by-names`, `--export-tasks-by-ids`, `--export-task-runs-by-ids`, `--export-task-runs-since-date`, `--export-labels`. 
+`--export-tasks-by-names`, `--export-tasks-by-ids`, `--export-task-runs-by-ids`, `--export-task-runs-since-date`, `--labels`. ## Import @@ -56,21 +56,21 @@ This command imports data from a dump file created by `mephisto db export` comma Examples: ``` -mephisto db import --dump-file +mephisto db import --file -mephisto db import --dump-file 2024_01_01_00_00_01_mephisto_dump.json --verbosity -mephisto db import --dump-file 2024_01_01_00_00_01_mephisto_dump.json --label-name my_first_dump -mephisto db import --dump-file 2024_01_01_00_00_01_mephisto_dump.json --conflict-resolver MyCustomMergeConflictResolver -mephisto db import --dump-file 2024_01_01_00_00_01_mephisto_dump.json --keep-import-metadata +mephisto db import --file 2024_01_01_00_00_01_mephisto_dump.json --verbosity +mephisto db import --file 2024_01_01_00_00_01_mephisto_dump.json --labels my_first_dump +mephisto db import --file 2024_01_01_00_00_01_mephisto_dump.json --conflict-resolver MyCustomMergeConflictResolver +mephisto db import --file 2024_01_01_00_00_01_mephisto_dump.json --keep-import-metadata ``` Options: -- `-d/--dump-file` - location of the __***.json__ dump file (filename if created in +- `-f/--file` - location of the `***.zip` dump file (filename if created in `/outputs/export` folder, or absolute filepath) - `-cr/--conflict-resolver` (Optional) - name of Python class to be used for resolving merging conflicts (when your local DB already has a row with same unique field value as a DB row in the dump data) -- `-l/--label-name` - a short string serving as a reference for the ported data (stored in `imported_data` table), - so later you can export the imported data with `--export-labels` export option +- `-l/--labels` - one or more short strings serving as a reference for the ported data (stored in `imported_data` table), + so later you can export the imported data with `--labels` export option - `-k/--keep-import-metadata` - write data from `imported_data` table of the dump (by default it's not imported) - `-v/--verbosity` - level of logging (default: 0; values: 0, 1) @@ -95,13 +95,13 @@ Note that it will erase all current data, and you may want to run command `mephi Examples: ``` -mephisto db restore --backup-file +mephisto db restore --file -mephisto db restore --backup-file 2024_01_01_00_10_01.zip +mephisto db restore --file 2024_01_01_00_10_01.zip ``` Options: -- `-b/--backup-file` - location of the __*.zip__ backup file (filename if created in +- `-f/--file` - location of the `***.zip` backup file (filename if created in `/outputs/backup` folder, or absolute filepath) - `-v/--verbosity` - level of logging (default: 0; values: 0, 1) diff --git a/docs/web/docs/guides/how_to_use/merge_dbs/simple_usage.md b/docs/web/docs/guides/how_to_use/merge_dbs/simple_usage.md index 3ebf9020a..1c2089dc6 100644 --- a/docs/web/docs/guides/how_to_use/merge_dbs/simple_usage.md +++ b/docs/web/docs/guides/how_to_use/merge_dbs/simple_usage.md @@ -45,8 +45,8 @@ mephisto db backup And you will see text like this ``` -Started making backup -Finished successfully! File: '//outputs/backup/2024_01_01_00_00_01_mephisto_backup.zip +Started creating backup file ... +Finished successfully! File: //outputs/backup/2024_01_01_00_00_01_mephisto_backup.zip ``` Find and copy this file. @@ -79,31 +79,30 @@ mephisto db export --randomize-legacy-ids And you will see text like this ``` -Started exporting -Run command for all TaskRuns. +Started exporting data ... +No filter for TaskRun specified - exporting all TaskRuns. Finished successfully! 
Files created: - - Database dump - //outputs/export/2024_01_01_00_00_01_mephisto_dump.json - - Data files dump - //outputs/export/2024_01_01_00_00_01_mephisto_dump.zip + - Dump archive - //outputs/export/2024_01_01_00_00_01_mephisto_dump.zip ``` ### Import just created dump into main project -Put your dump into export directory `/mephisto/outputs/export/` and you can use just a dump name in the command, +Put your dump into export directory `//outputs/export/` and you can use just a dump name in the command, or use a full path to the file. Let's just imagine, you put file in export directory: ```shell -mephisto db import --dump-file 2024_01_01_00_00_01_mephisto_dump.json +mephisto db import --file 2024_01_01_00_00_01_mephisto_dump.zip ``` And you will see text like this ``` -Started importing from dump '2024_01_01_00_00_01_mephisto_dump.json' Are you sure? It will affect your databases and related files. Type 'yes' and press Enter if you want to proceed: yes Just in case, we are making a backup of all your local data. If something went wrong during import, we will restore all your data from this backup -Backup was created successfully! File: '/mephisto/outputs/backup/2024_01_01_00_10_01_mephisto_backup.zip' +Backup was created successfully! File: '//outputs/backup/2024_04_25_17_11_56_mephisto_backup.zip' +Started importing from dump file //outputs/export/2024_04_25_17_11_43_mephisto_dump.zip ... Finished successfully ``` @@ -117,14 +116,14 @@ Also, we create a backup automatically just in case too, just before all changes No worries, just restore everything from your or our backup: ```shell -mephisto db restore --backup-file 2024_01_01_00_10_01.zip +mephisto db restore --file 2024_01_01_00_10_01_mephisto_backup.zip ``` And you will see text like this ``` -Started restoring from backup '2024_01_01_00_10_01.zip' Are you sure? It will affect your databases and related files. Type 'yes' and press Enter if you want to proceed: yes +Started restoring from backup //outputs/backup/2024_01_01_00_10_01_mephisto_backup.zip ... 
Finished successfully ``` diff --git a/mephisto/abstractions/databases/local_database.py b/mephisto/abstractions/databases/local_database.py index c6b0ef676..bfca8f822 100644 --- a/mephisto/abstractions/databases/local_database.py +++ b/mephisto/abstractions/databases/local_database.py @@ -224,7 +224,7 @@ def _new_project(self, project_name: str) -> str: ( make_randomized_int_id(), project_name, - ) + ), ) project_id = str(c.lastrowid) return project_id @@ -264,7 +264,8 @@ def _find_projects(self, project_name: Optional[str] = None) -> List[Project]: """ SELECT * from projects """ - + additional_query, + + additional_query + + " ORDER BY creation_date ASC", arg_tuple, ) rows = c.fetchall() @@ -311,7 +312,10 @@ def _new_task( raise EntryDoesNotExistException(e) elif is_unique_failure(e): raise EntryAlreadyExistsException( - e, db=self, table_name="tasks", original_exc=e, + e, + db=self, + table_name="tasks", + original_exc=e, ) raise MephistoDBException(e) @@ -344,7 +348,8 @@ def _find_tasks( """ SELECT * from tasks """ - + additional_query, + + additional_query + + " ORDER BY creation_date ASC", arg_tuple, ) rows = c.fetchall() @@ -449,7 +454,10 @@ def _new_task_run( raise EntryDoesNotExistException(e) elif is_unique_failure(e): raise EntryAlreadyExistsException( - e, db=self, table_name="task_runs", original_exc=e, + e, + db=self, + table_name="task_runs", + original_exc=e, ) raise MephistoDBException(e) @@ -483,7 +491,8 @@ def _find_task_runs( """ SELECT * from task_runs """ - + additional_query, + + additional_query + + " ORDER BY creation_date ASC", arg_tuple, ) rows = c.fetchall() @@ -552,7 +561,10 @@ def _new_assignment( except sqlite3.IntegrityError as e: if is_unique_failure(e): raise EntryAlreadyExistsException( - e, db=self, table_name="assignments", original_exc=e, + e, + db=self, + table_name="assignments", + original_exc=e, ) raise MephistoDBException(e) @@ -603,7 +615,8 @@ def _find_assignments( """ SELECT * from assignments """ - + additional_query, + + additional_query + + " ORDER BY creation_date ASC", arg_tuple, ) rows = c.fetchall() @@ -668,7 +681,10 @@ def _new_unit( raise EntryDoesNotExistException(e) elif is_unique_failure(e): raise EntryAlreadyExistsException( - e, db=self, table_name="units", original_exc=e, + e, + db=self, + table_name="units", + original_exc=e, ) raise MephistoDBException(e) @@ -734,7 +750,8 @@ def _find_units( """ SELECT * from units """ - + additional_query, + + additional_query + + " ORDER BY creation_date ASC", arg_tuple, ) rows = c.fetchall() @@ -817,7 +834,7 @@ def _new_requester(self, requester_name: str, provider_type: str) -> str: """ INSERT INTO requesters( requester_id, - requester_name, + requester_name, provider_type ) VALUES (?, ?, ?); """, @@ -832,7 +849,10 @@ def _new_requester(self, requester_name: str, provider_type: str) -> str: except sqlite3.IntegrityError as e: if is_unique_failure(e): raise EntryAlreadyExistsException( - e, db=self, table_name="requesters", original_exc=e, + e, + db=self, + table_name="requesters", + original_exc=e, ) raise MephistoDBException(e) @@ -862,7 +882,8 @@ def _find_requesters( """ SELECT * from requesters """ - + additional_query, + + additional_query + + " ORDER BY creation_date ASC", arg_tuple, ) rows = c.fetchall() @@ -890,7 +911,7 @@ def _new_worker(self, worker_name: str, provider_type: str) -> str: """ INSERT INTO workers( worker_id, - worker_name, + worker_name, provider_type ) VALUES (?, ?, ?); """, @@ -905,7 +926,10 @@ def _new_worker(self, worker_name: str, provider_type: str) -> str: 
except sqlite3.IntegrityError as e: if is_unique_failure(e): raise EntryAlreadyExistsException( - e, db=self, table_name="workers", original_exc=e, + e, + db=self, + table_name="workers", + original_exc=e, ) raise MephistoDBException(e) @@ -935,7 +959,8 @@ def _find_workers( """ SELECT * from workers """ - + additional_query, + + additional_query + + " ORDER BY creation_date ASC", arg_tuple, ) rows = c.fetchall() @@ -1006,7 +1031,10 @@ def _new_agent( raise EntryDoesNotExistException(e) elif is_unique_failure(e): raise EntryAlreadyExistsException( - e, db=self, table_name="agents", original_exc=e, + e, + db=self, + table_name="agents", + original_exc=e, ) raise MephistoDBException(e) @@ -1082,7 +1110,8 @@ def _find_agents( """ SELECT * from agents """ - + additional_query, + + additional_query + + " ORDER BY creation_date ASC", arg_tuple, ) rows = c.fetchall() @@ -1107,7 +1136,10 @@ def _make_qualification(self, qualification_name: str) -> str: except sqlite3.IntegrityError as e: if is_unique_failure(e): raise EntryAlreadyExistsException( - e, db=self, table_name="units", original_exc=e, + e, + db=self, + table_name="units", + original_exc=e, ) raise MephistoDBException(e) @@ -1125,7 +1157,8 @@ def _find_qualifications(self, qualification_name: Optional[str] = None) -> List """ SELECT * from qualifications """ - + additional_query, + + additional_query + + " ORDER BY creation_date ASC", arg_tuple, ) rows = c.fetchall() @@ -1204,7 +1237,10 @@ def _grant_qualification(self, qualification_id: str, worker_id: str, value: int except sqlite3.IntegrityError as e: if is_unique_failure(e): raise EntryAlreadyExistsException( - e, db=self, table_name="units", original_exc=e, + e, + db=self, + table_name="units", + original_exc=e, ) raise MephistoDBException(e) @@ -1319,7 +1355,10 @@ def _new_onboarding_agent( raise EntryDoesNotExistException(e) elif is_unique_failure(e): raise EntryAlreadyExistsException( - e, db=self, table_name="onboarding_agents", original_exc=e, + e, + db=self, + table_name="onboarding_agents", + original_exc=e, ) raise MephistoDBException(e) @@ -1388,7 +1427,8 @@ def _find_onboarding_agents( """ SELECT * from onboarding_agents """ - + additional_query, + + additional_query + + " ORDER BY creation_date ASC", arg_tuple, ) rows = c.fetchall() @@ -1439,7 +1479,10 @@ def _new_unit_review( except sqlite3.IntegrityError as e: if is_unique_failure(e): raise EntryAlreadyExistsException( - e, db=self, table_name="unit_review", original_exc=e, + e, + db=self, + table_name="unit_review", + original_exc=e, ) raise MephistoDBException(e) diff --git a/mephisto/abstractions/databases/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py b/mephisto/abstractions/databases/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py index 0f8094fe2..20061a12b 100644 --- a/mephisto/abstractions/databases/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py +++ b/mephisto/abstractions/databases/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py @@ -9,27 +9,28 @@ 2. Remove autoincrement parameter for all Primary Keys 3. Add missed Foreign Keys in `agents` table 4. Add `granted_qualifications.update_date` +5. 
Modified default value for `creation_date` """ PREPARING_DB_FOR_MERGE_DBS_COMMAND = """ ALTER TABLE unit_review RENAME COLUMN created_at TO creation_date; - + /* Disable FK constraints */ PRAGMA foreign_keys = off; - - + + /* Projects */ CREATE TABLE IF NOT EXISTS _projects ( project_id INTEGER PRIMARY KEY, project_name TEXT NOT NULL UNIQUE, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) ); INSERT INTO _projects SELECT * FROM projects; DROP TABLE projects; ALTER TABLE _projects RENAME TO projects; - - + + /* Tasks */ CREATE TABLE IF NOT EXISTS _tasks ( task_id INTEGER PRIMARY KEY, @@ -37,27 +38,27 @@ task_type TEXT NOT NULL, project_id INTEGER, parent_task_id INTEGER, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), FOREIGN KEY (parent_task_id) REFERENCES tasks (task_id), FOREIGN KEY (project_id) REFERENCES projects (project_id) ); INSERT INTO _tasks SELECT * FROM tasks; DROP TABLE tasks; ALTER TABLE _tasks RENAME TO tasks; - - + + /* Requesters */ CREATE TABLE IF NOT EXISTS _requesters ( requester_id INTEGER PRIMARY KEY, requester_name TEXT NOT NULL UNIQUE, provider_type TEXT NOT NULL, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) ); INSERT INTO _requesters SELECT * FROM requesters; DROP TABLE requesters; ALTER TABLE _requesters RENAME TO requesters; - - + + /* Task Runs */ CREATE TABLE IF NOT EXISTS _task_runs ( task_run_id INTEGER PRIMARY KEY, @@ -68,15 +69,15 @@ provider_type TEXT NOT NULL, task_type TEXT NOT NULL, sandbox BOOLEAN NOT NULL, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), FOREIGN KEY (task_id) REFERENCES tasks (task_id), FOREIGN KEY (requester_id) REFERENCES requesters (requester_id) ); INSERT INTO _task_runs SELECT * FROM task_runs; DROP TABLE task_runs; ALTER TABLE _task_runs RENAME TO task_runs; - - + + /* Assignments */ CREATE TABLE IF NOT EXISTS _assignments ( assignment_id INTEGER PRIMARY KEY, @@ -86,7 +87,7 @@ task_type TEXT NOT NULL, provider_type TEXT NOT NULL, sandbox BOOLEAN NOT NULL, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), FOREIGN KEY (task_id) REFERENCES tasks (task_id), FOREIGN KEY (task_run_id) REFERENCES task_runs (task_run_id), FOREIGN KEY (requester_id) REFERENCES requesters (requester_id) @@ -94,8 +95,8 @@ INSERT INTO _assignments SELECT * FROM assignments; DROP TABLE assignments; ALTER TABLE _assignments RENAME TO assignments; - - + + /* Units */ CREATE TABLE IF NOT EXISTS _units ( unit_id INTEGER PRIMARY KEY, @@ -111,7 +112,7 @@ task_run_id INTEGER NOT NULL, sandbox BOOLEAN NOT NULL, requester_id INTEGER NOT NULL, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), FOREIGN KEY (assignment_id) REFERENCES assignments (assignment_id), FOREIGN KEY (agent_id) REFERENCES agents (agent_id), FOREIGN KEY (task_run_id) REFERENCES task_runs (task_run_id), @@ -123,20 +124,20 @@ INSERT INTO _units SELECT * FROM units; DROP TABLE units; ALTER TABLE _units RENAME TO units; - - + + /* Workers */ CREATE TABLE IF NOT EXISTS _workers ( worker_id INTEGER PRIMARY KEY, worker_name TEXT NOT NULL UNIQUE, provider_type TEXT NOT NULL, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + creation_date DATETIME 
DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) ); INSERT INTO _workers SELECT * FROM workers; DROP TABLE workers; ALTER TABLE _workers RENAME TO workers; - - + + /* Agents */ CREATE TABLE IF NOT EXISTS _agents ( agent_id INTEGER PRIMARY KEY, @@ -148,7 +149,7 @@ task_type TEXT NOT NULL, provider_type TEXT NOT NULL, status TEXT NOT NULL, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), FOREIGN KEY (worker_id) REFERENCES workers (worker_id), FOREIGN KEY (unit_id) REFERENCES units (unit_id), FOREIGN KEY (task_id) REFERENCES tasks (task_id) ON DELETE NO ACTION, @@ -158,8 +159,8 @@ INSERT INTO _agents SELECT * FROM agents; DROP TABLE agents; ALTER TABLE _agents RENAME TO agents; - - + + /* Onboarding Agents */ CREATE TABLE IF NOT EXISTS _onboarding_agents ( onboarding_agent_id INTEGER PRIMARY KEY, @@ -168,52 +169,52 @@ task_run_id INTEGER NOT NULL, task_type TEXT NOT NULL, status TEXT NOT NULL, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), FOREIGN KEY (worker_id) REFERENCES workers (worker_id), FOREIGN KEY (task_run_id) REFERENCES task_runs (task_run_id) ); INSERT INTO _onboarding_agents SELECT * FROM onboarding_agents; DROP TABLE onboarding_agents; ALTER TABLE _onboarding_agents RENAME TO onboarding_agents; - - + + /* Qualifications */ CREATE TABLE IF NOT EXISTS _qualifications ( qualification_id INTEGER PRIMARY KEY, qualification_name TEXT NOT NULL UNIQUE, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) ); INSERT INTO _qualifications SELECT * FROM qualifications; DROP TABLE qualifications; ALTER TABLE _qualifications RENAME TO qualifications; - - + + /* Granted Qualifications */ CREATE TABLE IF NOT EXISTS _granted_qualifications ( granted_qualification_id INTEGER PRIMARY KEY, worker_id INTEGER NOT NULL, qualification_id INTEGER NOT NULL, value INTEGER NOT NULL, - update_date DATETIME DEFAULT CURRENT_TIMESTAMP, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + update_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), FOREIGN KEY (worker_id) REFERENCES workers (worker_id), FOREIGN KEY (qualification_id) REFERENCES qualifications (qualification_id), UNIQUE (worker_id, qualification_id) ); /* Copy data from backed up table and set value from `creation_date` to `update_date` */ - INSERT INTO _granted_qualifications - SELECT - granted_qualification_id, - worker_id, - qualification_id, - value, - creation_date, - creation_date + INSERT INTO _granted_qualifications + SELECT + granted_qualification_id, + worker_id, + qualification_id, + value, + creation_date, + creation_date FROM granted_qualifications; DROP TABLE granted_qualifications; ALTER TABLE _granted_qualifications RENAME TO granted_qualifications; - - + + /* Unit Review */ CREATE TABLE IF NOT EXISTS _unit_review ( id INTEGER PRIMARY KEY, @@ -229,7 +230,7 @@ updated_qualification_value INTEGER, /* ID of `db.qualifications` (not `db.granted_qualifications`) */ revoked_qualification_id INTEGER, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), FOREIGN KEY (unit_id) REFERENCES units (unit_id), FOREIGN KEY (worker_id) REFERENCES workers (worker_id), @@ -238,8 +239,8 @@ INSERT INTO _unit_review SELECT * FROM unit_review; DROP TABLE unit_review; ALTER TABLE _unit_review 
RENAME TO unit_review; - - + + /* Enable FK constraints back */ PRAGMA foreign_keys = on; """ diff --git a/mephisto/abstractions/providers/mock/mock_datastore.py b/mephisto/abstractions/providers/mock/mock_datastore.py index da17a832d..1b4ac5912 100644 --- a/mephisto/abstractions/providers/mock/mock_datastore.py +++ b/mephisto/abstractions/providers/mock/mock_datastore.py @@ -132,7 +132,6 @@ def ensure_worker_exists(self, worker_id: str) -> None: table_name="workers", params={ "worker_id": worker_id, - "is_blocked": False, }, select_field="worker_id", ) diff --git a/mephisto/abstractions/providers/mock/mock_datastore_export.py b/mephisto/abstractions/providers/mock/mock_datastore_export.py index f3fc14ecd..a0fdb45dc 100644 --- a/mephisto/abstractions/providers/mock/mock_datastore_export.py +++ b/mephisto/abstractions/providers/mock/mock_datastore_export.py @@ -27,14 +27,20 @@ def export_datastore( # Find and serialize `units` unit_ids = [i["unit_id"] for i in mephisto_db_data["units"]] unit_rows = db_utils.select_rows_by_list_of_field_values( - datastore, "units", ["unit_id"], [unit_ids], + datastore, + "units", + ["unit_id"], + [unit_ids], ) dump_data["units"] = db_utils.serialize_data_for_table(unit_rows) # Find and serialize `workers` worker_ids = [i["worker_id"] for i in mephisto_db_data["workers"]] workers_rows = db_utils.select_rows_by_list_of_field_values( - datastore, "workers", ["worker_id"], [worker_ids], + datastore, + "workers", + ["worker_id"], + [worker_ids], ) dump_data["workers"] = db_utils.serialize_data_for_table(workers_rows) diff --git a/mephisto/abstractions/providers/mturk/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py b/mephisto/abstractions/providers/mturk/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py new file mode 100644 index 000000000..d38dd9d40 --- /dev/null +++ b/mephisto/abstractions/providers/mturk/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +1. 
Modified default value for `creation_date` +""" + + +PREPARING_DB_FOR_MERGE_DBS_COMMAND = """ + /* Disable FK constraints */ + PRAGMA foreign_keys = off; + + + /* Hits */ + CREATE TABLE IF NOT EXISTS _hits ( + hit_id TEXT PRIMARY KEY UNIQUE, + unit_id TEXT, + assignment_id TEXT, + link TEXT, + assignment_time_in_seconds INTEGER NOT NULL, + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) + ); + INSERT INTO _hits SELECT * FROM hits; + DROP TABLE hits; + ALTER TABLE _hits RENAME TO hits; + + + /* Run mappings */ + CREATE TABLE IF NOT EXISTS _run_mappings ( + hit_id TEXT, + run_id TEXT + ); + INSERT INTO _run_mappings SELECT * FROM run_mappings; + DROP TABLE run_mappings; + ALTER TABLE _run_mappings RENAME TO run_mappings; + + + /* Runs */ + CREATE TABLE IF NOT EXISTS _runs ( + run_id TEXT PRIMARY KEY UNIQUE, + arn_id TEXT, + hit_type_id TEXT NOT NULL, + hit_config_path TEXT NOT NULL, + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + frame_height INTEGER NOT NULL DEFAULT 650 + ); + INSERT INTO _runs SELECT * FROM runs; + DROP TABLE runs; + ALTER TABLE _runs RENAME TO runs; + + + /* Qualifications */ + CREATE TABLE IF NOT EXISTS _qualifications ( + qualification_name TEXT PRIMARY KEY UNIQUE, + requester_id TEXT, + mturk_qualification_name TEXT, + mturk_qualification_id TEXT, + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) + ); + INSERT INTO _qualifications SELECT * FROM qualifications; + DROP TABLE qualifications; + ALTER TABLE _qualifications RENAME TO qualifications; + + + /* Enable FK constraints back */ + PRAGMA foreign_keys = on; +""" diff --git a/mephisto/abstractions/providers/mturk/migrations/__init__.py b/mephisto/abstractions/providers/mturk/migrations/__init__.py new file mode 100644 index 000000000..092965e1b --- /dev/null +++ b/mephisto/abstractions/providers/mturk/migrations/__init__.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from ._001_20240325_preparing_db_for_merge_dbs_command import * + + +migrations = { + "20240418_preparing_db_for_merge_dbs_command": PREPARING_DB_FOR_MERGE_DBS_COMMAND, +} diff --git a/mephisto/abstractions/providers/mturk/mturk_datastore.py b/mephisto/abstractions/providers/mturk/mturk_datastore.py index 517cae0b0..714860e6b 100644 --- a/mephisto/abstractions/providers/mturk/mturk_datastore.py +++ b/mephisto/abstractions/providers/mturk/mturk_datastore.py @@ -18,8 +18,10 @@ from botocore.exceptions import ProfileNotFound # type: ignore from mephisto.abstractions.databases.local_database import is_unique_failure +from mephisto.utils.db import apply_migrations from mephisto.utils.logger_core import get_logger from . 
import mturk_datastore_tables as tables +from .migrations import migrations from .mturk_datastore_export import export_datastore MTURK_REGION_NAME = "us-east-1" @@ -90,6 +92,8 @@ def init_tables(self) -> None: except Exception: pass # extra column already exists + apply_migrations(self, migrations) + def get_export_data(self, **kwargs) -> dict: return export_datastore(self, **kwargs) diff --git a/mephisto/abstractions/providers/mturk/mturk_datastore_export.py b/mephisto/abstractions/providers/mturk/mturk_datastore_export.py index f538801a5..a558926cb 100644 --- a/mephisto/abstractions/providers/mturk/mturk_datastore_export.py +++ b/mephisto/abstractions/providers/mturk/mturk_datastore_export.py @@ -31,7 +31,10 @@ def export_datastore( for table_name in tables_with_task_run_relations: table_rows = db_utils.select_rows_by_list_of_field_values( - datastore, table_name, ["run_id"], [task_run_ids], + datastore, + table_name, + ["run_id"], + [task_run_ids], ) runs_table_data = db_utils.serialize_data_for_table(table_rows) dump_data[table_name] = runs_table_data @@ -39,7 +42,10 @@ def export_datastore( # Find and serialize `hits` hit_ids = list(set(filter(bool, [i["hit_id"] for i in dump_data["run_mappings"]]))) hit_rows = db_utils.select_rows_by_list_of_field_values( - datastore, "hits", ["hit_id"], [hit_ids], + datastore, + "hits", + ["hit_id"], + [hit_ids], ) dump_data["hits"] = db_utils.serialize_data_for_table(hit_rows) @@ -47,7 +53,10 @@ def export_datastore( qualification_names = [i["qualification_name"] for i in mephisto_db_data["qualifications"]] if qualification_names: qualification_rows = db_utils.select_rows_by_list_of_field_values( - datastore, "qualifications", ["qualification_name"], [qualification_names], + datastore, + "qualifications", + ["qualification_name"], + [qualification_names], ) else: qualification_rows = db_utils.select_all_table_rows(datastore, "qualifications") diff --git a/mephisto/abstractions/providers/prolific/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py b/mephisto/abstractions/providers/prolific/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py index 6dbe1c116..f792c223e 100644 --- a/mephisto/abstractions/providers/prolific/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py +++ b/mephisto/abstractions/providers/prolific/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py @@ -10,14 +10,15 @@ 3. Added `creation_date` in `units` table 4. Rename field `run_id` -> `task_run_id` 5. Remove table `requesters` +6. 
Modified default value for `creation_date` """ PREPARING_DB_FOR_MERGE_DBS_COMMAND = """ /* Disable FK constraints */ PRAGMA foreign_keys = off; - - + + /* Studies */ CREATE TABLE IF NOT EXISTS _studies ( id INTEGER PRIMARY KEY, @@ -26,26 +27,26 @@ link TEXT, task_run_id TEXT UNIQUE, assignment_time_in_seconds INTEGER NOT NULL, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) ); INSERT INTO _studies SELECT * FROM studies; DROP TABLE studies; ALTER TABLE _studies RENAME TO studies; - - + + /* Submissions */ CREATE TABLE IF NOT EXISTS _submissions ( id INTEGER PRIMARY KEY, prolific_submission_id TEXT UNIQUE, prolific_study_id TEXT, status TEXT DEFAULT NULL, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) ); INSERT INTO _submissions SELECT * FROM submissions; DROP TABLE submissions; ALTER TABLE _submissions RENAME TO submissions; - - + + /* Run Mappings */ CREATE TABLE IF NOT EXISTS _run_mappings ( id INTEGER PRIMARY KEY, @@ -55,8 +56,8 @@ INSERT INTO _run_mappings SELECT * FROM run_mappings; DROP TABLE run_mappings; ALTER TABLE _run_mappings RENAME TO run_mappings; - - + + /* Units */ CREATE TABLE IF NOT EXISTS _units ( id INTEGER PRIMARY KEY, @@ -65,44 +66,44 @@ prolific_study_id TEXT, prolific_submission_id TEXT, is_expired BOOLEAN DEFAULT false, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) ); /* Copy data from backed up table and set values for `creation_date` */ - INSERT INTO _units - SELECT - id, - unit_id, - run_id, - prolific_study_id, - prolific_submission_id, - is_expired, - datetime('now', 'localtime') + INSERT INTO _units + SELECT + id, + unit_id, + run_id, + prolific_study_id, + prolific_submission_id, + is_expired, + datetime('now', 'localtime') FROM units; DROP TABLE units; ALTER TABLE _units RENAME TO units; - - + + /* Workers */ CREATE TABLE IF NOT EXISTS _workers ( id INTEGER PRIMARY KEY, worker_id TEXT UNIQUE, is_blocked BOOLEAN default false, - update_date DATETIME DEFAULT CURRENT_TIMESTAMP, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + update_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) ); /* Copy data from backed up table and set values for `creation_date` and `update_date` */ - INSERT INTO _workers - SELECT - id, - worker_id, - is_blocked, - datetime('now', 'localtime'), - datetime('now', 'localtime') + INSERT INTO _workers + SELECT + id, + worker_id, + is_blocked, + datetime('now', 'localtime'), + datetime('now', 'localtime') FROM workers; DROP TABLE workers; ALTER TABLE _workers RENAME TO workers; - - + + /* Runs */ CREATE TABLE IF NOT EXISTS _runs ( id INTEGER PRIMARY KEY, @@ -112,7 +113,7 @@ prolific_project_id TEXT NOT NULL, prolific_study_id TEXT, prolific_study_config_path TEXT NOT NULL, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP, + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), frame_height INTEGER NOT NULL DEFAULT 650, actual_available_places INTEGER DEFAULT NULL, listed_available_places INTEGER DEFAULT NULL @@ -120,8 +121,8 @@ INSERT INTO _runs SELECT * FROM runs; DROP TABLE runs; ALTER TABLE _runs RENAME TO runs; - - + + /* Participant Groups */ CREATE TABLE IF NOT EXISTS _participant_groups ( id INTEGER PRIMARY KEY, @@ -130,13 +131,13 @@ prolific_project_id TEXT, prolific_participant_group_name TEXT, 
prolific_participant_group_id TEXT, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) ); INSERT INTO _participant_groups SELECT * FROM participant_groups; DROP TABLE participant_groups; ALTER TABLE _participant_groups RENAME TO participant_groups; - - + + /* Runs */ CREATE TABLE IF NOT EXISTS _qualifications ( id INTEGER PRIMARY KEY, @@ -144,21 +145,21 @@ task_run_id TEXT, json_qual_logic TEXT, qualification_ids TEXT, - creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + creation_date DATETIME DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) ); INSERT INTO _qualifications SELECT * FROM qualifications; DROP TABLE qualifications; ALTER TABLE _qualifications RENAME TO qualifications; - - + + /* Enable FK constraints back */ PRAGMA foreign_keys = on; - - + + ALTER TABLE run_mappings RENAME COLUMN run_id TO task_run_id; ALTER TABLE units RENAME COLUMN run_id TO task_run_id; ALTER TABLE runs RENAME COLUMN run_id TO task_run_id; - - + + DROP TABLE IF EXISTS requesters; """ diff --git a/mephisto/abstractions/providers/prolific/prolific_datastore_export.py b/mephisto/abstractions/providers/prolific/prolific_datastore_export.py index cc1d85027..502058041 100644 --- a/mephisto/abstractions/providers/prolific/prolific_datastore_export.py +++ b/mephisto/abstractions/providers/prolific/prolific_datastore_export.py @@ -34,7 +34,9 @@ def export_datastore( for table_name in tables_with_task_run_relations: table_rows = db_utils.select_rows_from_table_related_to_task_run( - datastore, table_name, task_run_ids, + datastore, + table_name, + task_run_ids, ) runs_table_data = db_utils.serialize_data_for_table(table_rows) dump_data[table_name] = runs_table_data @@ -42,16 +44,22 @@ def export_datastore( # Find and serialize `submissions` study_ids = list(set(filter(bool, [i["prolific_study_id"] for i in dump_data["studies"]]))) submission_rows = db_utils.select_rows_by_list_of_field_values( - datastore, "submissions", ["prolific_study_id"], [study_ids], + datastore, + "submissions", + ["prolific_study_id"], + [study_ids], ) dump_data["submissions"] = db_utils.serialize_data_for_table(submission_rows) # Find and serialize `participant_groups` - participant_group_ids = list(set(filter(bool, [ - i["prolific_participant_group_id"] for i in dump_data["qualifications"] - ]))) + participant_group_ids = list( + set(filter(bool, [i["prolific_participant_group_id"] for i in dump_data["qualifications"]])) + ) participant_group_rows = db_utils.select_rows_by_list_of_field_values( - datastore, "participant_groups", ["prolific_participant_group_id"], [participant_group_ids], + datastore, + "participant_groups", + ["prolific_participant_group_id"], + [participant_group_ids], ) dump_data["participant_groups"] = db_utils.serialize_data_for_table(participant_group_rows) @@ -59,10 +67,13 @@ def export_datastore( worker_ids = [i["worker_name"] for i in mephisto_db_data["workers"]] if worker_ids: worker_rows = db_utils.select_rows_by_list_of_field_values( - datastore, "workers", ["worker_id"], [worker_ids], + datastore, + "workers", + ["worker_id"], + [worker_ids], ) else: - worker_rows = db_utils.select_all_table_rows(datastore, "workers") + worker_rows = [] dump_data["workers"] = db_utils.serialize_data_for_table(worker_rows) return dump_data diff --git a/mephisto/client/cli.py b/mephisto/client/cli.py index ef42dca60..6f6c280ce 100644 --- a/mephisto/client/cli.py +++ b/mephisto/client/cli.py @@ -31,6 +31,7 @@ import mephisto.scripts.mturk.launch_makeup_hits as 
launch_makeup_hits_mturk import mephisto.scripts.mturk.print_outstanding_hit_status as soft_block_workers_by_mturk_id_mturk from mephisto.client.cli_commands import get_wut_arguments +from mephisto.client.cli_db_commands import db_cli from mephisto.generators.form_composer.config_validation.separate_token_values_config import ( update_separate_token_values_config_with_file_urls, ) @@ -49,8 +50,6 @@ set_custom_validators_js_env_var, ) from mephisto.operations.registry import get_valid_provider_types -from mephisto.tools.db_data_porter import DBDataPorter -from mephisto.tools.db_data_porter.constants import DEFAULT_CONFLICT_RESOLVER from mephisto.tools.scripts import build_custom_bundle from mephisto.utils.console_writer import ConsoleWriter from mephisto.utils.rich import console @@ -81,8 +80,8 @@ def cli(): @cli.command("config", cls=RichCommand) -@click.argument("identifier", type=(str), default=None, required=False) -@click.argument("value", type=(str), default=None, required=False) +@click.argument("identifier", type=str, default=None, required=False) +@click.argument("value", type=str, default=None, required=False) def config(identifier, value): from mephisto.operations.config_handler import ( get_config_arg, @@ -197,6 +196,7 @@ def register_provider(args): try: parsed_options = parse_arg_dict(RequesterClass, args_dict) except Exception as e: + parsed_options = None click.echo(str(e)) if parsed_options.name is None: @@ -364,11 +364,11 @@ def metrics_cli(args): @cli.command("review_app", cls=RichCommand) -@click.option("-h", "--host", type=(str), default="127.0.0.1") -@click.option("-p", "--port", type=(int), default=5000) -@click.option("-d", "--debug", type=(bool), default=False, is_flag=True) -@click.option("-f", "--force-rebuild", type=(bool), default=False, is_flag=True) -@click.option("-s", "--skip-build", type=(bool), default=False, is_flag=True) +@click.option("-h", "--host", type=str, default="127.0.0.1") +@click.option("-p", "--port", type=int, default=5000) +@click.option("-d", "--debug", type=bool, default=False, is_flag=True) +@click.option("-f", "--force-rebuild", type=bool, default=False, is_flag=True) +@click.option("-s", "--skip-build", type=bool, default=False, is_flag=True) @pass_script_info def review_app( info: ScriptInfo, @@ -462,7 +462,7 @@ def _get_form_composer_app_path() -> str: @cli.command("form_composer", cls=RichCommand) -@click.option("-o", "--task-data-config-only", type=(bool), default=True, is_flag=True) +@click.option("-o", "--task-data-config-only", type=bool, default=True, is_flag=True) def form_composer(task_data_config_only: bool = True): # Get app path to run Python script from there (instead of the current file's directory). 
# This is necessary, because the whole infrastructure is built relative to the location @@ -501,12 +501,12 @@ def form_composer(task_data_config_only: bool = True): @cli.command("form_composer_config", cls=RichCommand) -@click.option("-v", "--verify", type=(bool), default=False, is_flag=True) -@click.option("-f", "--update-file-location-values", type=(str), default=None) -@click.option("-e", "--extrapolate-token-sets", type=(bool), default=False, is_flag=True) -@click.option("-p", "--permutate-separate-tokens", type=(bool), default=False, is_flag=True) -@click.option("-d", "--directory", type=(str), default=None) -@click.option("-u", "--use-presigned-urls", type=(bool), default=False, is_flag=True) +@click.option("-v", "--verify", type=bool, default=False, is_flag=True) +@click.option("-f", "--update-file-location-values", type=str, default=None) +@click.option("-e", "--extrapolate-token-sets", type=bool, default=False, is_flag=True) +@click.option("-p", "--permutate-separate-tokens", type=bool, default=False, is_flag=True) +@click.option("-d", "--directory", type=str, default=None) +@click.option("-u", "--use-presigned-urls", type=bool, default=False, is_flag=True) def form_composer_config( verify: Optional[bool] = False, update_file_location_values: Optional[str] = None, @@ -623,189 +623,7 @@ def form_composer_config( ) -@cli.command("db", cls=RichCommand) -@click.argument("action_name", required=True, nargs=1) -@click.option("-d", "--dump-file", type=(str), default=None) -@click.option("-i", "--export-indent", type=(int), default=None) -@click.option("-tn", "--export-tasks-by-names", type=(str), multiple=True, default=None) -@click.option("-ti", "--export-tasks-by-ids", type=(str), multiple=True, default=None) -@click.option("-tr", "--export-task-runs-by-ids", type=(str), multiple=True, default=None) -@click.option("-trs", "--export-task-runs-since-date", type=(str), default=None) -@click.option("-tl", "--export-labels", type=(str), multiple=True, default=None) -@click.option("-de", "--delete-exported-data", type=(bool), default=False, is_flag=True) -@click.option("-r", "--randomize-legacy-ids", type=(bool), default=False, is_flag=True) -@click.option("-l", "--label-name", type=(str), default=None) -@click.option("-cr", "--conflict-resolver", type=(str), default=DEFAULT_CONFLICT_RESOLVER) -@click.option("-k", "--keep-import-metadata", type=(bool), default=False, is_flag=True) -@click.option("-b", "--backup-file", type=(str), default=None) -@click.option("-v", "--verbosity", type=(int), default=0) -def db( - action_name: str, - dump_file: Optional[str] = None, - export_indent: Optional[int] = None, - export_tasks_by_names: Optional[List[str]] = None, - export_tasks_by_ids: Optional[List[str]] = None, - export_task_runs_by_ids: Optional[List[str]] = None, - export_task_runs_since_date: Optional[str] = None, - export_labels: Optional[List[str]] = None, - delete_exported_data: bool = False, - randomize_legacy_ids: bool = False, - label_name: Optional[str] = None, - conflict_resolver: Optional[str] = DEFAULT_CONFLICT_RESOLVER, - keep_import_metadata: Optional[bool] = False, - backup_file: Optional[str] = None, - verbosity: int = 0, -): - """ - Operations with Mephisto DB and provider-specific datastores. - - Commands: - 1. mephisto db export - This command exports data from Mephisto DB and provider-specific datastores - as a combination of (i) a JSON file, and (ii) an archived `data` catalog with related files. - - If no parameter passed, full data dump (i.e. backup) will be created. 
- - To pass a list of values for one command option, - simply repeat that option name before each value. - - Options (all optional): - `-tn/--export-tasks-by-names` - names of Tasks that will be exported - `-ti/--export-tasks-by-ids` - ids of Tasks that will be exported - `-tr/--export-task-runs-by-ids` - ids of TaskRuns that will be exported - `-trs/--export-task-runs-since-date` - only objects created after this - ISO8601 datetime will be exported - `-tl/--export-labels` - only data imported under these labels will be exported - `-de/--delete-exported-data` - after exporting data, delete it from local DB - `-r/--randomize-legacy-ids` - replace legacy autoincremented ids with - new pseudo-random ids to avoid conflicts during data merging - `-i/--export-indent` - make dump easy to read via formatting JSON with indentations - `-v/--verbosity` - write more informative messages about progress - (Default 0. Values: 0, 1) - - - 2. mephisto db import --dump-file - - This command imports data from a dump file created by `mephisto db export` command. - - Options: - `-d/--dump-file` - location of the __***.json__ dump file (filename if created in - `/outputs/export` folder, or absolute filepath) - `-cr/--conflict-resolver` (Optional) - name of Python class - to be used for resolving merging conflicts (when your local DB already has a row - with same unique field value as a DB row in the dump data) - `-l/--label-name` - a short string serving as a reference for the ported data - (stored in `imported_data` table), so later you can export the imported data - with `--export-labels` export option - `-k/--keep-import-metadata` - write data from `imported_data` table of the dump - (by default it's not imported) - `-v/--verbosity` - level of logging (default: 0; values: 0, 1) - - 3. mephisto db backup - - Creates full backup of all current data (Mephisto DB, provider-specific datastores, - and related files) on local machine. - - 4. mephisto db restore --backup-file - - Restores all data (Mephisto DB, provider-specific datastores, and related files) - from a backup archive. - - Options: - `-b/--backup-file` - location of the __*.zip__ backup file (filename if created in - `/outputs/backup` folder, or absolute filepath) - `-v/--verbosity` - level of logging (default: 0; values: 0, 1) - """ - porter = DBDataPorter() - - # --- EXPORT --- - if action_name == "export": - has_conflicting_task_runs_options = len(list(filter(bool, [ - export_tasks_by_names, - export_tasks_by_ids, - export_task_runs_by_ids, - export_task_runs_since_date, - export_labels, - ]))) > 1 - - if has_conflicting_task_runs_options: - logger.warning( - "[yellow]" - "You cannot use following options together:" - "\n\t--export-tasks-by-names" - "\n\t--export-tasks-by-ids" - "\n\t--export-task-runs-by-ids" - "\n\t--export-task-runs-since-date" - "\n\t--export-labels" - "\nUse one of them or none of them to export all data." 
- "[/yellow]" - ) - exit() - - logger.info(f"Started exporting") - - export_results = porter.export_dump( - json_indent=export_indent, - task_names=export_tasks_by_names, - task_ids=export_tasks_by_ids, - task_run_ids=export_task_runs_by_ids, - task_runs_since_date=export_task_runs_since_date, - task_runs_labels=export_labels, - delete_exported_data=delete_exported_data, - randomize_legacy_ids=randomize_legacy_ids, - verbosity=verbosity, - ) - - data_files_line = "" - if export_results["data_path"]: - data_files_line = f"\n\t- Data files dump - {export_results['data_path']}" - - backup_line = "" - if export_results["backup_path"]: - backup_line = f"\n\t- Backup - {export_results['backup_path']}" - - logger.info( - f"[green]" - f"Finished successfully! " - f"\nFiles created:" - f"\n\t- Database dump - {export_results['db_path']}" - f"{data_files_line}" - f"{backup_line}" - f"[/green]" - ) - - # --- IMPORT --- - elif action_name == "import": - logger.info(f"Started importing from dump '{dump_file}'") - porter.import_dump( - dump_file_name_or_path=dump_file, - conflict_resolver_name=conflict_resolver, - label=label_name, - keep_import_metadata=keep_import_metadata, - verbosity=verbosity, - ) - logger.info(f"[green]Finished successfully[/green]") - - # --- BACKUP --- - elif action_name == "backup": - logger.info(f"Started making backup") - backup_path = porter.make_backup() - logger.info(f"[green]Finished successfully! File: '{backup_path}[/green]") - - # --- RESTORE --- - elif action_name == "restore": - logger.info(f"Started restoring from backup '{backup_file}'") - porter.restore_from_backup(backup_file_name_or_path=backup_file, verbosity=verbosity) - logger.info(f"[green]Finished successfully[/green]") - - # Otherwise, error - else: - logger.error( - f"[red]" - f"Unexpected action name '{action_name}'. Available: export, import, restore." - f"[/red]" - ) - exit() +cli.add_command(db_cli) if __name__ == "__main__": diff --git a/mephisto/client/cli_db_commands.py b/mephisto/client/cli_db_commands.py new file mode 100644 index 000000000..17790ec14 --- /dev/null +++ b/mephisto/client/cli_db_commands.py @@ -0,0 +1,318 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List +from typing import Optional + +import click +from rich_click import RichCommand +from rich_click import RichContext + +from mephisto.tools.db_data_porter import DBDataPorter +from mephisto.tools.db_data_porter.constants import DEFAULT_CONFLICT_RESOLVER +from mephisto.tools.db_data_porter.export_dump import get_export_options_for_metadata +from mephisto.utils.console_writer import ConsoleWriter + +VERBOSITY_HELP = "write more informative messages about progress (Default 0. 
Values: 0, 1)" +VERBOSITY_DEFAULT_VALUE = 0 + +logger = ConsoleWriter() + + +def _print_used_options_for_running_command_message(ctx: RichContext, options: dict): + message = "Running command with the following options:\n" + for p in ctx.command.params: + values = options[p.name] + + if isinstance(values, tuple): + values = list(values) + if not values: + values = None + + message += f"\t{'/'.join(p.opts)} = {values}\n" + + logger.debug(message) + + +@click.group(name="db", context_settings=dict(help_option_names=["-h", "--help"])) +def db_cli(): + """Operations with Mephisto DB and provider-specific datastores.""" + pass + + +# --- EXPORT --- +@db_cli.command("export", cls=RichCommand) +@click.pass_context +@click.option( + "-i", + "--export-indent", + type=int, + default=2, + help="make dump easy to read via formatting JSON with indentations (Default 2)", +) +@click.option( + "-tn", + "--export-tasks-by-names", + type=str, + multiple=True, + default=None, + help="names of Tasks that will be exported", +) +@click.option( + "-ti", + "--export-tasks-by-ids", + type=str, + multiple=True, + default=None, + help="ids of Tasks that will be exported", +) +@click.option( + "-tri", + "--export-task-runs-by-ids", + type=str, + multiple=True, + default=None, + help="ids of TaskRuns that will be exported", +) +@click.option( + "-trs", + "--export-task-runs-since-date", + type=str, + default=None, + help="only objects created after this ISO8601 datetime will be exported", +) +@click.option( + "-l", + "--labels", + type=str, + multiple=True, + default=None, + help="only data imported under these labels will be exported", +) +@click.option( + "-del", + "--delete-exported-data", + type=bool, + default=False, + is_flag=True, + help="after exporting data, delete it from local DB", +) +@click.option( + "-r", + "--randomize-legacy-ids", + type=bool, + default=False, + is_flag=True, + help=( + "replace legacy autoincremented ids with new pseudo-random ids " + "to avoid conflicts during data merging" + ), +) +@click.option("-v", "--verbosity", type=int, default=VERBOSITY_DEFAULT_VALUE, help=VERBOSITY_HELP) +def export(ctx: RichContext, **options: dict): + """ + This command exports data from Mephisto DB and provider-specific datastores + as an archived combination of (i) a JSON file, and (ii) a `data` catalog with related files. + If no parameter passed, full data dump (i.e. backup) will be created. + To pass a list of values for one command option, + simply repeat that option name before each value. 
+ + mephisto db export + """ + _print_used_options_for_running_command_message(ctx, options) + + export_indent: Optional[int] = options.get("export_indent", 2) + export_tasks_by_names: Optional[List[str]] = options.get("export_tasks_by_names") + export_tasks_by_ids: Optional[List[str]] = options.get("export_tasks_by_ids") + export_task_runs_by_ids: Optional[List[str]] = options.get("export_task_runs_by_ids") + export_task_runs_since_date: Optional[str] = options.get("export_task_runs_since_date") + export_labels: Optional[List[str]] = options.get("labels") + delete_exported_data: bool = options.get("delete_exported_data", False) + randomize_legacy_ids: bool = options.get("randomize_legacy_ids", False) + verbosity: int = options.get("verbosity", VERBOSITY_DEFAULT_VALUE) + + porter = DBDataPorter() + + has_conflicting_task_runs_options = ( + len( + list( + filter( + bool, + [ + export_tasks_by_names, + export_tasks_by_ids, + export_task_runs_by_ids, + export_task_runs_since_date, + export_labels, + ], + ) + ) + ) + > 1 + ) + + if has_conflicting_task_runs_options: + logger.warning( + "[yellow]" + "You cannot use the following options together:" + "\n\t--export-tasks-by-names" + "\n\t--export-tasks-by-ids" + "\n\t--export-task-runs-by-ids" + "\n\t--export-task-runs-since-date" + "\n\t--labels" + "\nUse one of them or none of them to export all data." + "[/yellow]" + ) + exit() + + export_results = porter.export_dump( + json_indent=export_indent, + task_names=export_tasks_by_names, + task_ids=export_tasks_by_ids, + task_run_ids=export_task_runs_by_ids, + task_runs_since_date=export_task_runs_since_date, + task_runs_labels=export_labels, + delete_exported_data=delete_exported_data, + randomize_legacy_ids=randomize_legacy_ids, + metadata_export_options=get_export_options_for_metadata(ctx, options), + verbosity=verbosity, + ) + + backup_line = "" + if export_results["backup_path"]: + backup_line = f"\nCreated backup file (just in case): {export_results['backup_path']}" + + logger.info( + f"[green]" + f"Finished successfully, saved to file: {export_results['dump_path']}" + f"{backup_line}" + f"[/green]" + ) + + +# --- IMPORT --- +@db_cli.command("import", cls=RichCommand) +@click.pass_context +@click.option( + "-f", + "--file", + type=str, + default=None, + help=( + "location of the `***.zip` dump file " + "(filename if created in `/outputs/export` folder, or absolute filepath)" + ), +) +@click.option( + "-l", + "--labels", + type=str, + multiple=True, + default=None, + help=( + "short strings serving as references for the ported data " + "(stored in `imported_data` table), " + "so later you can export the imported data with the `--labels` export option" + ), +) +@click.option( + "-cr", + "--conflict-resolver", + type=str, + default=DEFAULT_CONFLICT_RESOLVER, + help=( + "(Optional) name of Python class to be used for resolving merging conflicts " + "(when your local DB already has a row with the same unique field value " + "as a DB row in the dump data)" + ), +) +@click.option( + "-k", + "--keep-import-metadata", + type=bool, + default=False, + is_flag=True, + help="write data from `imported_data` table of the dump (by default it's not imported)", +) +@click.option("-v", "--verbosity", type=int, default=VERBOSITY_DEFAULT_VALUE, help=VERBOSITY_HELP) +def _import(ctx: RichContext, **options: dict): + """ + This command imports data from a dump file created by `mephisto db export` command. 
+ + mephisto db import --file + """ + _print_used_options_for_running_command_message(ctx, options) + + file: Optional[str] = options.get("file") + labels: Optional[List[str]] = options.get("labels") + conflict_resolver: Optional[str] = options.get("conflict_resolver", DEFAULT_CONFLICT_RESOLVER) + keep_import_metadata: Optional[bool] = options.get("keep_import_metadata", False) + verbosity: int = options.get("verbosity", VERBOSITY_DEFAULT_VALUE) + + porter = DBDataPorter() + results = porter.import_dump( + dump_archive_file_name_or_path=file, + conflict_resolver_name=conflict_resolver, + labels=labels, + keep_import_metadata=keep_import_metadata, + verbosity=verbosity, + ) + logger.info( + f"[green]" + f"Finished successfully. Imported {results['imported_task_runs_number']} TaskRuns" + f"[/green]" + ) + + +# --- BACKUP --- +@db_cli.command("backup", cls=RichCommand) +@click.pass_context +@click.option("-v", "--verbosity", type=int, default=VERBOSITY_DEFAULT_VALUE, help=VERBOSITY_HELP) +def backup(ctx: RichContext, **options: dict): + """ + Creates a full backup of all current data (Mephisto DB, provider-specific datastores, + and related files) on the local machine. + + mephisto db backup + """ + _print_used_options_for_running_command_message(ctx, options) + + verbosity: int = options.get("verbosity", VERBOSITY_DEFAULT_VALUE) + + porter = DBDataPorter() + backup_path = porter.create_backup(verbosity=verbosity) + logger.info(f"[green]Finished successfully, saved to file: {backup_path}[/green]") + + +# --- RESTORE --- +@db_cli.command("restore", cls=RichCommand) +@click.pass_context +@click.option( + "-f", + "--file", + type=str, + default=None, + help=( + "location of the `***.zip` backup file (filename if created in " + "`/outputs/backup` folder, or absolute filepath)" + ), +) +@click.option("-v", "--verbosity", type=int, default=VERBOSITY_DEFAULT_VALUE, help=VERBOSITY_HELP) +def restore(ctx: RichContext, **options): + """ + Restores all data (Mephisto DB, provider-specific datastores, and related files) + from a backup archive. 
+ + mephisto db restore --file + """ + _print_used_options_for_running_command_message(ctx, options) + + file: str = options.get("file") + verbosity: int = options.get("verbosity", VERBOSITY_DEFAULT_VALUE) + + porter = DBDataPorter() + porter.restore_from_backup(backup_file_name_or_path=file, verbosity=verbosity) + logger.info(f"[green]Finished successfully[/green]") diff --git a/mephisto/tools/db_data_porter/backups.py b/mephisto/tools/db_data_porter/backups.py index 9222917ad..a77e15a50 100644 --- a/mephisto/tools/db_data_porter/backups.py +++ b/mephisto/tools/db_data_porter/backups.py @@ -6,165 +6,31 @@ import os import shutil -from distutils.dir_util import copy_tree from pathlib import Path -from typing import List -from mephisto.abstractions.database import MephistoDB -from mephisto.data_model.task_run import TaskRun -from mephisto.tools.db_data_porter.constants import AGENTS_TABLE_NAME -from mephisto.tools.db_data_porter.constants import ASSIGNMENTS_TABLE_NAME -from mephisto.tools.db_data_porter.constants import MEPHISTO_DUMP_KEY -from mephisto.tools.db_data_porter.constants import TASK_RUNS_TABLE_NAME -from mephisto.tools.db_data_porter.randomize_ids import get_old_pk_from_substitutions -from mephisto.utils import db as db_utils +from mephisto.tools.db_data_porter.constants import DEFAULT_ARCHIVE_FORMAT from mephisto.utils.console_writer import ConsoleWriter from mephisto.utils.dirs import get_data_dir -from mephisto.utils.dirs import get_mephisto_tmp_dir - -DEFAULT_ARCHIVE_FORMAT = "zip" logger = ConsoleWriter() -def _rename_dirs_with_new_pks(task_run_dirs: List[str], pk_substitutions: dict): - def rename_dir_with_new_pk(dir_path: str, substitutions: dict) -> str: - dump_id = substitutions.get(os.path.basename(dir_path)) - renamed_dir_path = os.path.join(os.path.dirname(dir_path), dump_id) - os.rename(dir_path, renamed_dir_path) - return renamed_dir_path - - task_runs_subs = pk_substitutions.get(MEPHISTO_DUMP_KEY, {}).get(TASK_RUNS_TABLE_NAME, {}) - if not task_runs_subs: - # Nothing to rename - return - - assignment_subs = pk_substitutions.get(MEPHISTO_DUMP_KEY, {}).get(ASSIGNMENTS_TABLE_NAME, {}) - agent_subs = pk_substitutions.get(MEPHISTO_DUMP_KEY, {}).get(AGENTS_TABLE_NAME, {}) - - task_run_dirs = [ - d for d in task_run_dirs if os.path.basename(d) in task_runs_subs.keys() - ] - for task_run_dir in task_run_dirs: - # Rename TaskRun dir - renamed_task_run_dir = rename_dir_with_new_pk(task_run_dir, task_runs_subs) - - # Rename Assignments dirs - assignments_dirs = [ - os.path.join(renamed_task_run_dir, d) for d in os.listdir(renamed_task_run_dir) - if d in assignment_subs.keys() - ] - for assignment_dir in assignments_dirs: - renamed_assignment_dir = rename_dir_with_new_pk(assignment_dir, assignment_subs) - - # Rename Agents dirs - agents_dirs = [ - os.path.join(renamed_assignment_dir, d) for d in os.listdir(renamed_assignment_dir) - if d in agent_subs.keys() - ] - for agent_dir in agents_dirs: - rename_dir_with_new_pk(agent_dir, agent_subs) - - -def _export_data_dir_for_task_runs( - input_dir_path: str, - archive_file_path_without_ext: str, - task_runs: List[TaskRun], - pk_substitutions: dict, - _format: str = DEFAULT_ARCHIVE_FORMAT, - verbosity: int = 0, -) -> bool: - tmp_dir = get_mephisto_tmp_dir() - tmp_export_dir = os.path.join(tmp_dir, "export") - - task_run_data_dirs = [i.get_run_dir() for i in task_runs] - if not task_run_data_dirs: - return False - - try: - tmp_task_run_dirs = [] - - # Copy all files for passed TaskRuns into tmp dir - for task_run_data_dir in 
task_run_data_dirs: - relative_dir = Path(task_run_data_dir).relative_to(input_dir_path) - tmp_task_run_dir = os.path.join(tmp_export_dir, relative_dir) - - tmp_task_run_dirs.append(tmp_task_run_dir) - - os.makedirs(tmp_task_run_dir, exist_ok=True) - copy_tree(task_run_data_dir, tmp_task_run_dir, verbose=verbosity) - - _rename_dirs_with_new_pks(tmp_task_run_dirs, pk_substitutions) - - # Create archive in export dir - shutil.make_archive( - base_name=archive_file_path_without_ext, - format="zip", - root_dir=tmp_export_dir, - ) - finally: - # Remove tmp dir - if os.path.exists(tmp_export_dir): - shutil.rmtree(tmp_export_dir) - - return True - - def make_backup_file_path_by_timestamp( - backup_dir: str, timestamp: str, _format: str = DEFAULT_ARCHIVE_FORMAT, + backup_dir: str, + timestamp: str, + _format: str = DEFAULT_ARCHIVE_FORMAT, ) -> str: return os.path.join(backup_dir, f"{timestamp}_mephisto_backup.{_format}") -def make_full_data_dir_backup( - backup_dir: str, timestamp: str, _format: str = DEFAULT_ARCHIVE_FORMAT, -) -> str: +def make_full_data_dir_backup(backup_file_path: str, _format: str = DEFAULT_ARCHIVE_FORMAT) -> str: mephisto_data_dir = get_data_dir() - file_name_without_ext = f"{timestamp}_mephisto_backup" - archive_file_path_without_ext = os.path.join(backup_dir, file_name_without_ext) - shutil.make_archive( - base_name=archive_file_path_without_ext, + base_name=os.path.splitext(backup_file_path)[0], format=_format, root_dir=mephisto_data_dir, ) - - return make_backup_file_path_by_timestamp(backup_dir, file_name_without_ext, _format) - - -def archive_and_copy_data_files( - db: "MephistoDB", - export_dir: str, - dump_name: str, - dump_data: dict, - pk_substitutions: dict, - verbosity: int = 0, - _format: str = DEFAULT_ARCHIVE_FORMAT, -) -> bool: - mephisto_data_files_path = os.path.join(get_data_dir(), "data") - output_zip_file_base_name = os.path.join(export_dir, dump_name) # name without extension - - # Get TaskRuns for PKs in dump - task_runs: List[TaskRun] = [] - for dump_task_run in dump_data[MEPHISTO_DUMP_KEY][TASK_RUNS_TABLE_NAME]: - task_runs_pk_field_name = db_utils.get_table_pk_field_name(db, TASK_RUNS_TABLE_NAME) - dump_pk = dump_task_run[task_runs_pk_field_name] - db_pk = get_old_pk_from_substitutions(dump_pk, pk_substitutions, TASK_RUNS_TABLE_NAME) - db_pk = db_pk or dump_pk - task_run: TaskRun = TaskRun.get(db, db_pk) - task_runs.append(task_run) - - # Export archived related data files to TaskRuns from dump - exported = _export_data_dir_for_task_runs( - input_dir_path=mephisto_data_files_path, - archive_file_path_without_ext=output_zip_file_base_name, - task_runs=task_runs, - pk_substitutions=pk_substitutions, - _format=_format, - verbosity=verbosity, - ) - - return exported + return backup_file_path def restore_from_backup( @@ -180,4 +46,4 @@ def restore_from_backup( Path(backup_file_path).unlink(missing_ok=True) except Exception as e: logger.exception(f"[red]Could not restore backup '{backup_file_path}'. 
Error: {e}[/red]") - raise + exit() diff --git a/mephisto/tools/db_data_porter/conflict_resolvers/__init__.py b/mephisto/tools/db_data_porter/conflict_resolvers/__init__.py index 86a1a7d02..96190d3af 100644 --- a/mephisto/tools/db_data_porter/conflict_resolvers/__init__.py +++ b/mephisto/tools/db_data_porter/conflict_resolvers/__init__.py @@ -21,8 +21,8 @@ attribute = getattr(module, attribute_name) if ( - isclass(attribute) and - issubclass(attribute, BaseMergeConflictResolver) and - attribute is not BaseMergeConflictResolver + isclass(attribute) + and issubclass(attribute, BaseMergeConflictResolver) + and attribute is not BaseMergeConflictResolver ): globals().update({attribute.__name__: attribute}) diff --git a/mephisto/tools/db_data_porter/conflict_resolvers/base_merge_conflict_resolver.py b/mephisto/tools/db_data_porter/conflict_resolvers/base_merge_conflict_resolver.py index da5dbb5e0..46c12ecaa 100644 --- a/mephisto/tools/db_data_porter/conflict_resolvers/base_merge_conflict_resolver.py +++ b/mephisto/tools/db_data_porter/conflict_resolvers/base_merge_conflict_resolver.py @@ -47,7 +47,10 @@ def __init__(self, db: "MephistoDB", provider_type: str): @staticmethod def _merge_rows_after_resolving( - table_pk_field_name: str, db_row: dict, dump_row: dict, resolved_row: dict, + table_pk_field_name: str, + db_row: dict, + dump_row: dict, + resolved_row: dict, ) -> dict: """ After we've resolved merging conflicts with rows fields, @@ -81,32 +84,37 @@ def _merge_rows_after_resolving( @staticmethod def _serialize_compared_fields_in_rows( - db_row: dict, dump_row: dict, compared_field_name: str, + db_row: dict, + dump_row: dict, + row_field_name: str, ) -> Tuple[dict, dict]: - db_value = db_row[compared_field_name] - dump_value = dump_row[compared_field_name] + db_value = db_row[row_field_name] + dump_value = dump_row[row_field_name] # Date fields - if compared_field_name.endswith("_at") or compared_field_name.endswith("_date"): - db_row[compared_field_name] = serialize_date_to_python(db_value) - dump_row[compared_field_name] = serialize_date_to_python(dump_value) + if row_field_name.endswith("_at") or row_field_name.endswith("_date"): + db_row[row_field_name] = serialize_date_to_python(db_value) + dump_row[row_field_name] = serialize_date_to_python(dump_value) # Numeric fields (integer or float) # Note: We cast both compared values to a numeric type # ONLY when one value is numeric, and another one is a string # (to avoid, for example, casting float to integer) for _type in [int, float]: - if ( - (isinstance(db_value, _type) and isinstance(dump_value, str)) or - (isinstance(db_value, str) and isinstance(dump_value, _type)) + if (isinstance(db_value, _type) and isinstance(dump_value, str)) or ( + isinstance(db_value, str) and isinstance(dump_value, _type) ): - db_row[compared_field_name] = _type(db_value) - dump_row[compared_field_name] = _type(dump_value) + db_row[row_field_name] = _type(db_value) + dump_row[row_field_name] = _type(dump_value) return db_row, dump_row def resolve( - self, table_name: str, table_pk_field_name: str, db_row: dict, dump_row: dict, + self, + table_name: str, + table_pk_field_name: str, + db_row: dict, + dump_row: dict, ) -> dict: """ Default logic of validating `strategies_config`, @@ -144,7 +152,10 @@ def resolve( # 4. Merge data merged_row = self._merge_rows_after_resolving( - table_pk_field_name, db_row, dump_row, resolved_row, + table_pk_field_name, + db_row, + dump_row, + resolved_row, ) # 4. 
Return merged row @@ -152,13 +163,18 @@ def resolve( # --- Prepared most cummon strategies --- def pick_row_with_smaller_value( - self, db_row: dict, dump_row: dict, compared_field_name: str, + self, + db_row: dict, + dump_row: dict, + row_field_name: str, ) -> dict: db_row, dump_row = self._serialize_compared_fields_in_rows( - db_row, dump_row, compared_field_name, + db_row, + dump_row, + row_field_name, ) - db_value = db_row[compared_field_name] - dump_value = dump_row[compared_field_name] + db_value = db_row[row_field_name] + dump_value = dump_row[row_field_name] # None cannot be compared with anything if db_value is None: @@ -172,13 +188,18 @@ def pick_row_with_smaller_value( return dump_row def pick_row_with_larger_value( - self, db_row: dict, dump_row: dict, compared_field_name: str, + self, + db_row: dict, + dump_row: dict, + row_field_name: str, ) -> dict: db_row, dump_row = self._serialize_compared_fields_in_rows( - db_row, dump_row, compared_field_name, + db_row, + dump_row, + row_field_name, ) - db_value = db_row[compared_field_name] - dump_value = dump_row[compared_field_name] + db_value = db_row[row_field_name] + dump_value = dump_row[row_field_name] # None cannot be compared with anything if db_value is None: @@ -192,23 +213,34 @@ def pick_row_with_larger_value( return dump_row def pick_row_from_db( - self, db_row: dict, dump_row: dict, compared_field_name: Optional[str] = None, + self, + db_row: dict, + dump_row: dict, + row_field_name: Optional[str] = None, ) -> dict: return db_row def pick_row_from_dump( - self, db_row: dict, dump_row: dict, compared_field_name: Optional[str] = None, + self, + db_row: dict, + dump_row: dict, + row_field_name: Optional[str] = None, ) -> dict: return dump_row def pick_row_with_earlier_value( - self, db_row: dict, dump_row: dict, compared_field_name: str = "creation_date", + self, + db_row: dict, + dump_row: dict, + row_field_name: str = "creation_date", ) -> dict: db_row, dump_row = self._serialize_compared_fields_in_rows( - db_row, dump_row, compared_field_name, + db_row, + dump_row, + row_field_name, ) - db_value = db_row[compared_field_name] - dump_value = dump_row[compared_field_name] + db_value = db_row[row_field_name] + dump_value = dump_row[row_field_name] # None cannot be compared with anything if db_value is None: @@ -221,13 +253,18 @@ def pick_row_with_earlier_value( return dump_row def pick_row_with_later_value( - self, db_row: dict, dump_row: dict, compared_field_name: str = "creation_date", + self, + db_row: dict, + dump_row: dict, + row_field_name: str = "creation_date", ) -> dict: db_row, dump_row = self._serialize_compared_fields_in_rows( - db_row, dump_row, compared_field_name, + db_row, + dump_row, + row_field_name, ) - db_value = db_row[compared_field_name] - dump_value = dump_row[compared_field_name] + db_value = db_row[row_field_name] + dump_value = dump_row[row_field_name] # None cannot be compared with anything if db_value is None: diff --git a/mephisto/tools/db_data_porter/conflict_resolvers/default_merge_conflict_resolver.py b/mephisto/tools/db_data_porter/conflict_resolvers/default_merge_conflict_resolver.py index 85a2bce68..ab7743913 100644 --- a/mephisto/tools/db_data_porter/conflict_resolvers/default_merge_conflict_resolver.py +++ b/mephisto/tools/db_data_porter/conflict_resolvers/default_merge_conflict_resolver.py @@ -26,7 +26,7 @@ class DefaultMergeConflictResolver(BaseMergeConflictResolver): # Go with more restrictive value "method": "pick_row_with_smaller_value", "kwargs": { - "compared_field_name": "value", 
+ "row_field_name": "value", }, }, }, @@ -36,7 +36,7 @@ class DefaultMergeConflictResolver(BaseMergeConflictResolver): # Note that `is_blocked` is SQLite-boolean, which is an integer in Python "method": "pick_row_with_larger_value", "kwargs": { - "compared_field_name": "is_blocked", + "row_field_name": "is_blocked", }, }, }, diff --git a/mephisto/tools/db_data_porter/conflict_resolvers/example_merge_conflict_resolver.py b/mephisto/tools/db_data_porter/conflict_resolvers/example_merge_conflict_resolver.py new file mode 100644 index 000000000..6cf5c3ebb --- /dev/null +++ b/mephisto/tools/db_data_porter/conflict_resolvers/example_merge_conflict_resolver.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from datetime import datetime +from typing import Optional + +from rich import print as rich_print + +from mephisto.tools.db_data_porter.constants import MEPHISTO_DUMP_KEY +from mephisto.tools.db_data_porter.constants import MOCK_PROVIDER_TYPE +from mephisto.tools.db_data_porter.constants import MTURK_PROVIDER_TYPE +from mephisto.tools.db_data_porter.constants import PROLIFIC_PROVIDER_TYPE +from .base_merge_conflict_resolver import BaseMergeConflictResolver + + +class ExampleMergeConflictResolver(BaseMergeConflictResolver): + """ + Example how to write your own conflict resolver. + + NOTE: do not accidentally use this example resolver on your real data. + """ + + default_strategy_name = "pick_row_from_db_and_set_creation_date_to_y2k" + + def pick_row_from_db_and_set_creation_date_to_y2k( + self, + db_row: dict, + dump_row: dict, + row_field_name: Optional[str] = None, + ) -> dict: + if "creation_date" in db_row: + db_row["creation_date"] = datetime(2000, 1, 1) + rich_print(f"\tSet `creation_date` to y2k for row {db_row}") + + return db_row + + def concatenate_values( + self, + db_row: dict, + dump_row: dict, + row_field_name: str, + separator: str, + ) -> dict: + resulting_row = self.pick_row_from_db_and_set_creation_date_to_y2k(db_row, dump_row) + + db_value = db_row[row_field_name] or "" + dump_value = dump_row[row_field_name] or "" + + if dump_value and db_value: + resulting_row[row_field_name] = db_value + separator + dump_value + rich_print(f"\tConcatenated `{row_field_name}` values for row {resulting_row}") + + return resulting_row + + strategies_config = { + MEPHISTO_DUMP_KEY: { + "tasks": { + # Concatenate names + "method": "concatenate_values", + "kwargs": { + "row_field_name": "task_name", + "separator": " + ", + }, + }, + }, + PROLIFIC_PROVIDER_TYPE: {}, + MOCK_PROVIDER_TYPE: {}, + MTURK_PROVIDER_TYPE: {}, + } diff --git a/mephisto/tools/db_data_porter/constants.py b/mephisto/tools/db_data_porter/constants.py index bea3047d4..ba40f7884 100644 --- a/mephisto/tools/db_data_porter/constants.py +++ b/mephisto/tools/db_data_porter/constants.py @@ -7,7 +7,7 @@ from mephisto.abstractions.providers.mock.provider_type import PROVIDER_TYPE as MOCK_PROVIDER_TYPE from mephisto.abstractions.providers.mturk.provider_type import PROVIDER_TYPE as MTURK_PROVIDER_TYPE from mephisto.abstractions.providers.prolific.provider_type import ( - PROVIDER_TYPE as PROLIFIC_PROVIDER_TYPE + PROVIDER_TYPE as PROLIFIC_PROVIDER_TYPE, ) @@ -213,3 +213,5 @@ # We mark rows in `imported_data` with labels and this label is used # if conflicted row was already presented in local DB LOCAL_DB_LABEL = "_" + +DEFAULT_ARCHIVE_FORMAT = "zip" diff --git 
a/mephisto/tools/db_data_porter/db_data_porter.py b/mephisto/tools/db_data_porter/db_data_porter.py index 237cfb235..bc916a820 100644 --- a/mephisto/tools/db_data_porter/db_data_porter.py +++ b/mephisto/tools/db_data_porter/db_data_porter.py @@ -6,19 +6,24 @@ import json import os +import zipfile from datetime import datetime from typing import Dict from typing import List from typing import Optional from typing import Union +from rich.console import Console + from mephisto.abstractions.database import MephistoDB from mephisto.abstractions.databases.local_database import LocalMephistoDB from mephisto.generators.form_composer.config_validation.utils import make_error_message from mephisto.tools.db_data_porter import backups +from mephisto.tools.db_data_porter import export_dump from mephisto.tools.db_data_porter import dumps from mephisto.tools.db_data_porter import import_dump from mephisto.tools.db_data_porter.constants import BACKUP_OUTPUT_DIR +from mephisto.tools.db_data_porter.constants import DEFAULT_ARCHIVE_FORMAT from mephisto.tools.db_data_porter.constants import EXPORT_OUTPUT_DIR from mephisto.tools.db_data_porter.constants import IMPORTED_DATA_TABLE_NAME from mephisto.tools.db_data_porter.constants import MEPHISTO_DUMP_KEY @@ -27,9 +32,9 @@ from mephisto.tools.db_data_porter.randomize_ids import randomize_ids from mephisto.tools.db_data_porter.validation import validate_dump_data from mephisto.utils import db as db_utils +from mephisto.utils.console_writer import ConsoleWriter from mephisto.utils.dirs import get_data_dir from mephisto.utils.misc import serialize_date_to_python -from mephisto.utils.console_writer import ConsoleWriter logger = ConsoleWriter() @@ -55,9 +60,9 @@ def __init__(self, db=None): @staticmethod def _get_root_mephisto_repo_dir() -> str: - return os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname( - os.path.abspath(__file__) - )))) + return os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + ) def _get_export_dir(self) -> str: root_dir = self._get_root_mephisto_repo_dir() @@ -114,6 +119,14 @@ def _prepare_dump_data( dump_data = randomize_ids_results["updated_dump"] self._pk_substitutions = randomize_ids_results["pk_substitutions"] + legacy_ids_found = any([v for v in self._pk_substitutions.values()]) + if not legacy_ids_found: + logger.info( + "Note that there was no need to randomize any ids, " + "because your Mephisto DB and provider-specific " + "datastores do not contain any legacy ids." + ) + return dump_data def _get_latest_migrations(self) -> Dict[str, Union[None, str]]: @@ -139,13 +152,14 @@ def _get_latest_migrations(self) -> Dict[str, Union[None, str]]: @staticmethod def _ask_user_if_they_are_sure() -> bool: - question = input( + console = Console() + question = console.input( "Are you sure? " "It will affect your databases and related files. " - "Type 'yes' and press Enter if you want to proceed: " + "Type '[green]yes[/green]' and press Enter if you want to proceed: " ) if question != "yes": - logger.info("Ok. Bye") + logger.info("Exiting now ...") exit() return True @@ -166,6 +180,7 @@ def export_dump( task_runs_labels: Optional[List[str]] = None, delete_exported_data: bool = False, randomize_legacy_ids: bool = False, + metadata_export_options: dict = None, verbosity: int = 0, ) -> dict: # 1. Protect from accidental launches @@ -173,15 +188,25 @@ def export_dump( self._ask_user_if_they_are_sure() # 2. 
Prepare dump data with Mephisto DB and provider datastores + logger.info(f"Started exporting data ...") + since_datetime = None if task_runs_since_date: try: since_datetime = serialize_date_to_python(task_runs_since_date) except Exception: - error_message = f"Could not parse date '{task_runs_since_date}'." - logger.exception(f"[red]{error_message}[/red]") + error_message = ( + f"Could not parse date '{task_runs_since_date}'. " + f"Expected ISO 8601 format in UTC timezone." + f"\n For example:" + f"\n\t - 2024-01-24" + f"\n\t - 2024-01-24T01:10:30" + ) + logger.error(f"[red]{error_message}[/red]") exit() + logger.info(f"Copying database records ...") + dump_data_to_export = self._prepare_dump_data( task_names=task_names, task_ids=task_ids, @@ -191,23 +216,27 @@ randomize_legacy_ids=randomize_legacy_ids, ) - # 3. Prepare export dirs and get dump file path + # 3. Prepare export dirs and get dump file path. + # JSON file is going to be located in a tmp directory, + # where we add all related files and then archive them all together export_dir = self._get_export_dir() dump_timestamp = self._make_export_timestamp() dump_name = self._make_dump_name(dump_timestamp) - file_path = self._make_export_dump_file_path(export_dir, dump_name) + tmp_export_dir = export_dump.make_tmp_export_dir() + tmp_dump_json_file_path = self._make_export_dump_file_path(tmp_export_dir, dump_name) # 4. Prepare metadata metadata = { "migrations": self._get_latest_migrations(), - "export_parameters": { - "--export-indent": json_indent, - "--export-tasks-by-names": task_names, - "--export-tasks-by-ids": task_ids, - "--export-task-runs-by-ids": task_run_ids, - "--export-task-runs-since-date": task_runs_since_date, - "--verbosity": verbosity, - }, + "export_options": metadata_export_options, + "timestamp": dump_timestamp, + "pk_substitutions": self._pk_substitutions, } @@ -215,19 +244,19 @@ # 5. Save JSON file try: - with open(file_path, "w") as f: + with open(tmp_dump_json_file_path, "w") as f: f.write(json.dumps(dump_data_to_export, indent=json_indent)) except Exception as e: # Remove file to not make a mess in export directory error_message = f"Could not create dump file {dump_data_to_export}. Reason: {str(e)}." - - if verbosity: - logger.exception(f"[red]{error_message}[/red]") - os.remove(file_path) + logger.exception(f"[red]{error_message}[/red]") + os.remove(tmp_dump_json_file_path) exit() + logger.info(f"Copying database records finished") + # 6. Archive files in file system - exported = backups.archive_and_copy_data_files( + exported = export_dump.archive_and_copy_data_files( self.db, export_dir, dump_name, @@ -239,8 +268,13 @@ # 7. Delete exported data if needed after backing data up backup_path = None if delete_exported_data: + logger.info( + f"Backing up your current data and removing exported data from local DB ..." 
+ ) + + backup_dir = self._get_backup_dir() - backup_path = backups.make_full_data_dir_backup(backup_dir, dump_timestamp) + backup_path = backups.make_backup_file_path_by_timestamp(backup_dir, dump_timestamp) + backups.make_full_data_dir_backup(backup_path) delete_tasks = bool(task_names or task_ids) is_partial_dump = bool(task_names or task_ids or task_run_ids or task_runs_since_date) dumps.delete_exported_data( @@ -250,63 +284,77 @@ partial=is_partial_dump, delete_tasks=delete_tasks, ) + logger.info(f"Finished backing up your current data and removing exported data") data_path = None if exported: data_path = os.path.join( - export_dir, f"{dump_name}.{backups.DEFAULT_ARCHIVE_FORMAT}", + export_dir, + f"{dump_name}.{DEFAULT_ARCHIVE_FORMAT}", ) return { - "db_path": file_path, - "data_path": data_path, + "dump_path": data_path, "backup_path": backup_path, } def import_dump( self, - dump_file_name_or_path: str, + dump_archive_file_name_or_path: str, conflict_resolver_name: str, - label: Optional[str] = None, + labels: Optional[List[str]] = None, keep_import_metadata: Optional[bool] = None, verbosity: int = 0, ): # 1. Check dump file path - is_dump_path_full = os.path.isabs(dump_file_name_or_path) + if not dump_archive_file_name_or_path: + error_message = "Option `-f/--file` is required." + logger.error(f"[red]{error_message}[/red]") + exit() + + is_dump_path_full = os.path.isabs(dump_archive_file_name_or_path) if not is_dump_path_full: root_dir = self._get_root_mephisto_repo_dir() - dump_file_name_or_path = os.path.join( - root_dir, EXPORT_OUTPUT_DIR, dump_file_name_or_path, + dump_archive_file_name_or_path = os.path.join( + root_dir, + EXPORT_OUTPUT_DIR, + dump_archive_file_name_or_path, ) - if not os.path.exists(dump_file_name_or_path): + if not os.path.exists(dump_archive_file_name_or_path): error_message = ( - f"Could not find dump file '{dump_file_name_or_path}'. " - f"Please, specify full path to existing file or " - f"just file name that is located in `/{EXPORT_OUTPUT_DIR}`." + f"Could not find dump file '{dump_archive_file_name_or_path}'. " + f"Please specify the full path to an existing file, or " + f"only the filename if it is located in /{EXPORT_OUTPUT_DIR}." ) - if verbosity: - logger.exception(f"[red]{error_message}[/red]") + logger.error(f"[red]{error_message}[/red]") exit() - # 2. Read dump file - with open(dump_file_name_or_path, "r") as f: - try: - dump_file_data: dict = json.loads(f.read()) - except Exception as e: - error_message = ( - f"Could not read JSON from dump file '{dump_file_name_or_path}'. " - f"Please, check if it has the correct format. Reason: {str(e)}" - ) - logger.exception(f"[red]{error_message}[/red]") - exit() + # 2. Read JSON dump file from archive + with zipfile.ZipFile(dump_archive_file_name_or_path) as archive: + dump_name = os.path.basename(os.path.splitext(dump_archive_file_name_or_path)[0]) + json_dump_file_name = f"{dump_name}.json" + + with archive.open(json_dump_file_name) as f: + try: + dump_file_data: dict = json.loads(f.read()) + except Exception as e: + error_message = ( + f"Could not read JSON from dump file '{dump_archive_file_name_or_path}'. " + f"Please check that file '{json_dump_file_name}' inside it " + f"has the correct format. Reason: {str(e)}" + ) + logger.exception(f"[red]{error_message}[/red]") + exit() # 3. 
Validate dump dump_data_errors = validate_dump_data(self.db, dump_file_data) if dump_data_errors: error_message = make_error_message( - "Your dump file has incorrect format", dump_data_errors, indent=4, + "Your dump file has incorrect format", + dump_data_errors, + indent=4, ) logger.error(f"[red]{error_message}[/red]") exit() @@ -326,19 +374,28 @@ ) backup_dir = self._get_backup_dir() dump_timestamp = self._make_export_timestamp() - backup_path = backups.make_full_data_dir_backup(backup_dir, dump_timestamp) + backup_path = backups.make_backup_file_path_by_timestamp(backup_dir, dump_timestamp) + backups.make_full_data_dir_backup(backup_path) logger.info(f"Backup was created successfully! File: '{backup_path}'") # 7. Write dump data into local DBs + logger.info(f"Started importing from dump file {dump_archive_file_name_or_path} ...") + + imported_task_runs_number = 0 + for db_or_datastore_name, db_or_datastore_data in dump_file_data.items(): + # Pop `imported_data` from dump content, to merge it into local `imported_data` + # when option `--keep-import-metadata` is passed imported_data_from_dump = [] if db_or_datastore_name == MEPHISTO_DUMP_KEY: # Main Mephisto database db = self.db imported_data_from_dump = dump_file_data.get(MEPHISTO_DUMP_KEY, {}).pop( - IMPORTED_DATA_TABLE_NAME, [], + IMPORTED_DATA_TABLE_NAME, + [], ) + imported_task_runs_number = len(db_or_datastore_data.get("task_runs", [])) else: # Provider's datastore. # NOTE: It is being created if it does not exist (yes, here, magically) @@ -354,22 +411,22 @@ db = datastore if verbosity: - logger.info(f"Start importing into `{db_or_datastore_name}` database") + logger.debug(f"Start importing into `{db_or_datastore_name}` database") - label = label or self._get_label_from_file_path(dump_file_name_or_path) + labels = labels or [self._get_label_from_file_path(dump_archive_file_name_or_path)] import_single_db_results = import_dump.import_single_db( db=db, provider_type=db_or_datastore_name, dump_data=db_or_datastore_data, conflict_resolver_name=conflict_resolver_name, - label=label, + labels=labels, verbosity=verbosity, ) errors = import_single_db_results["errors"] if errors: - error_message = make_error_message("Import was not processed", errors, indent=4) + error_message = make_error_message("Nothing was imported", errors, indent=4) logger.error(f"[red]{error_message}[/red]") # Simulating rollback for all databases/datastores and related data files @@ -377,60 +434,89 @@ backup_path = backups.make_backup_file_path_by_timestamp(backup_dir, dump_timestamp) if verbosity: - logger.info(f"Rolling back all changed from backup '{backup_path}'") + logger.debug(f"Rolling back all changes from backup {backup_path} ...") backups.restore_from_backup(backup_path, mephisto_data_path) + if verbosity: + logger.debug(f"Rolling back finished") + exit() - # Write imformation in `imported_data` if db_or_datastore_name == MEPHISTO_DUMP_KEY: + # Unpack files related to the imported TaskRuns + dump_archive_file_path = ( + os.path.splitext(dump_archive_file_name_or_path)[0] + + f".{DEFAULT_ARCHIVE_FORMAT}" + ) + export_dump.unarchive_data_files(dump_archive_file_path, verbosity=verbosity) + + # Write information in `imported_data` # Fill `imported_data` table with imported dump import_dump.fill_imported_data_with_imported_dump( db=db, imported_data=import_single_db_results["imported_data"], - source_file_name=os.path.basename(dump_file_name_or_path), + 
source_file_name=os.path.basename(dump_archive_file_name_or_path), + verbosity=verbosity, ) # Fill `imported_data` with information from `imported_data` from dump if keep_import_metadata and imported_data_from_dump: - import_dump.import_table_imported_data_from_dump(db, imported_data_from_dump) + import_dump.import_table_imported_data_from_dump( + db, + imported_data_from_dump, + verbosity=verbosity, + ) if verbosity: - logger.info( + logger.debug( f"Finished importing into `{db_or_datastore_name}` database successfully!" ) - def make_backup(self) -> str: + return { + "imported_task_runs_number": imported_task_runs_number, + } + + def create_backup(self, verbosity: int = 0) -> str: backup_dir = self._get_backup_dir() dump_timestamp = self._make_export_timestamp() - backup_path = backups.make_full_data_dir_backup(backup_dir, dump_timestamp) + backup_path = backups.make_backup_file_path_by_timestamp(backup_dir, dump_timestamp) + + logger.info(f"Creating backup file ...") + + backups.make_full_data_dir_backup(backup_path) return backup_path def restore_from_backup(self, backup_file_name_or_path: str, verbosity: int = 0): - # 1. Protect from accidental launches - self._ask_user_if_they_are_sure() + # 1. Check backup file path + if not backup_file_name_or_path: + error_message = "Option `-f/--file` is required." + logger.error(f"[red]{error_message}[/red]") + exit() - # 2. Check backup file path is_backup_path_full = os.path.isabs(backup_file_name_or_path) if not is_backup_path_full: root_dir = self._get_root_mephisto_repo_dir() backup_file_name_or_path = os.path.join( - root_dir, BACKUP_OUTPUT_DIR, backup_file_name_or_path, + root_dir, + BACKUP_OUTPUT_DIR, + backup_file_name_or_path, ) if not os.path.exists(backup_file_name_or_path): error_message = ( - f"Could not find backup file '{backup_file_name_or_path}'. " - f"Please, specify full path to existing file or " - f"just file name that is located in `/{BACKUP_OUTPUT_DIR}`." + f"Could not find backup file {backup_file_name_or_path}. " + f"Please specify the full path to an existing file, or " + f"just the filename if it is located in /{BACKUP_OUTPUT_DIR}." ) - logger.exception(f"[red]{error_message}[/red]") + logger.error(f"[red]{error_message}[/red]") exit() - if verbosity and not is_backup_path_full: - logger.info(f"Found backup file '{backup_file_name_or_path}'") + # 2. Protect from accidental launches + self._ask_user_if_they_are_sure() # 3. Restore + logger.info(f"Started restoring from backup {backup_file_name_or_path} ...") + mephisto_data_path = get_data_dir() backups.restore_from_backup(backup_file_name_or_path, mephisto_data_path) diff --git a/mephisto/tools/db_data_porter/dumps.py b/mephisto/tools/db_data_porter/dumps.py index 34618ace0..dfb5952ba 100644 --- a/mephisto/tools/db_data_porter/dumps.py +++ b/mephisto/tools/db_data_porter/dumps.py @@ -26,6 +26,28 @@ logger = ConsoleWriter() +def _make_options_error_message( + title: str, + values: List[str], + not_found_values: List[str], + available_values: Optional[List[str]] = None, +) -> str: + available_values_string = "" + if available_values: + available_values_string = ( + f"\nThere are {len(available_values)} available values: {', '.join(available_values)}" + ) + + return ( + f"[red]" + f"You provided incorrect {title}. " + f"\nProvided {len(values)} values: {', '.join(values)}. " + f"\n{len(not_found_values)} values not found: {', '.join(not_found_values)}." 
+ f"{available_values_string}" + f"[/red]" + ) + + def prepare_partial_dump_data( db: "MephistoDB", task_names: Optional[List[str]] = None, @@ -42,19 +64,97 @@ def prepare_partial_dump_data( if not task_run_ids: if task_names or task_ids: if task_names: + # Validate on correct values of passed Task names + db_tasks = db_utils.select_rows_by_list_of_field_values( + db, + "tasks", + ["task_name"], + [task_names], + ) + if len(task_names) != len(db_tasks): + db_task_names = [t["task_name"] for t in db_tasks] + not_found_values = [t for t in task_names if t not in db_task_names] + logger.error( + _make_options_error_message("Task names", task_names, not_found_values) + ) + exit() + + # Get Task IDs by their names task_ids = db_utils.get_task_ids_by_task_names(db, task_names) + else: + # Validate on correct values of passed Task IDs + db_tasks = db_utils.select_rows_by_list_of_field_values( + db, + "tasks", + ["task_id"], + [task_ids], + ) + if len(task_ids) != len(db_tasks): + db_task_ids = [t["task_id"] for t in db_tasks] + not_found_values = [t for t in task_ids if t not in db_task_ids] + logger.error( + _make_options_error_message("Task IDs", task_ids, not_found_values) + ) + exit() + task_ids = task_ids or [] + + # Get TaskRun IDs by Task IDs task_run_ids = db_utils.get_task_run_ids_ids_by_task_ids(db, task_ids) elif task_runs_labels: + # Validate on correct values of passed TaskRun labels + db_labels = db_utils.get_list_of_available_labels(db) + not_found_values = [t for t in task_runs_labels if t not in db_labels] + if not_found_values: + logger.error( + _make_options_error_message( + "TaskRun labels", + task_runs_labels, + not_found_values, + db_labels, + ) + ) + exit() + + # Get TaskRun IDs task_run_ids = db_utils.get_task_run_ids_ids_by_labels(db, task_runs_labels) elif since_datetime: + # Get TaskRun IDs task_run_ids = db_utils.select_task_run_ids_since_date(db, since_datetime) - logger.info(f"Run command for TaskRun IDs: {', '.join(task_run_ids)}.") + if not task_run_ids: + logger.error( + f"Nothing to export - " + f"no TaskRuns found that were created after {since_datetime}" + ) + exit() + else: + # Validate on correct values of passed TaskRun IDs + db_task_runs = db_utils.select_rows_by_list_of_field_values( + db, + "task_runs", + ["task_run_id"], + [task_run_ids], + ) + if len(task_run_ids) != len(db_task_runs): + db_task_run_ids = [t["task_run_id"] for t in db_task_runs] + not_found_values = [t for t in task_run_ids if t not in db_task_run_ids] + logger.error(_make_options_error_message("TaskRun IDs", task_run_ids, not_found_values)) + exit() + + if task_run_ids: + logger.info(f"Run command for TaskRun IDs: {', '.join(task_run_ids)}.") + else: + logger.error("[yellow]Nothing to export - no TaskRuns found[/yellow]") + exit() mephisto_db_data = db_utils.mephisto_db_to_dict_for_task_runs(db, task_run_ids) dump_data_to_export[MEPHISTO_DUMP_KEY] = mephisto_db_data + if not mephisto_db_data.get("task_runs"): + logger.error("[yellow]Nothing to export - no TaskRuns found[/yellow]") + exit() + # Providers' DBs provider_types = [i["provider_type"] for i in mephisto_db_data["requesters"]] @@ -68,11 +168,14 @@ def prepare_partial_dump_data( # There is a provider-specific logic of exporting DB data as it can have any scheme. 
# It can be missed and not implemented at all datastore_export_method: MethodType = getattr( - provider_datastore, DATASTORE_EXPORT_METHOD_NAME, None, + provider_datastore, + DATASTORE_EXPORT_METHOD_NAME, + None, ) if datastore_export_method: datastore_export_data = datastore_export_method( - task_run_ids=task_run_ids, mephisto_db_data=mephisto_db_data, + task_run_ids=task_run_ids, + mephisto_db_data=mephisto_db_data, ) else: # If method was not implemented in provider datastore, we export all tables fully. @@ -93,7 +196,7 @@ def prepare_partial_dump_data( def prepare_full_dump_data(db: "MephistoDB", provider_datastores: Dict[str, "MephistoDB"]) -> dict: dump_data_to_export = {} - logger.info(f"Run command for all TaskRuns.") + logger.info(f"No filter for TaskRuns specified - exporting all TaskRuns.") # Mephisto DB dump_data_to_export[MEPHISTO_DUMP_KEY] = db_utils.db_or_datastore_to_dict(db) @@ -146,7 +249,8 @@ def delete_exported_data( # Get directories related to dumped TaskRuns task_run_rows = dump_data_to_export.get(MEPHISTO_DUMP_KEY, {}).get( - TASK_RUNS_TABLE_NAME, [], + TASK_RUNS_TABLE_NAME, + [], ) task_runs_pk_field_name = db_utils.get_table_pk_field_name(db, TASK_RUNS_TABLE_NAME) task_run_ids = [r[task_runs_pk_field_name] for r in task_run_rows] @@ -159,7 +263,9 @@ def delete_exported_data( # Clean DB db_utils.delete_exported_data_without_fk_constraints( - db, dump_data_to_export[MEPHISTO_DUMP_KEY], names_of_tables_to_cleanup, + db, + dump_data_to_export[MEPHISTO_DUMP_KEY], + names_of_tables_to_cleanup, ) # Clean related files diff --git a/mephisto/tools/db_data_porter/export_dump.py b/mephisto/tools/db_data_porter/export_dump.py new file mode 100644 index 000000000..98e030904 --- /dev/null +++ b/mephisto/tools/db_data_porter/export_dump.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
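+ +# This module handles the data files that accompany a DB dump: it copies TaskRun +# data directories into a tmp dir, renames them when their PKs were randomized, +# and packs them into (or unpacks them from) the dump's ZIP archive.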
+ +import os +import shutil +from distutils.dir_util import copy_tree +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List + +from rich_click import RichContext + +from mephisto.data_model.task_run import TaskRun +from mephisto.tools.db_data_porter.constants import AGENTS_TABLE_NAME +from mephisto.tools.db_data_porter.constants import ASSIGNMENTS_TABLE_NAME +from mephisto.tools.db_data_porter.constants import DEFAULT_ARCHIVE_FORMAT +from mephisto.tools.db_data_porter.constants import MEPHISTO_DUMP_KEY +from mephisto.tools.db_data_porter.constants import TASK_RUNS_TABLE_NAME +from mephisto.tools.db_data_porter.randomize_ids import get_old_pk_from_substitutions +from mephisto.utils import db as db_utils +from mephisto.utils.console_writer import ConsoleWriter +from mephisto.utils.dirs import get_data_dir +from mephisto.utils.dirs import get_mephisto_tmp_dir + +logger = ConsoleWriter() + + +def make_tmp_export_dir() -> str: + tmp_dir = get_mephisto_tmp_dir() + tmp_export_dir = os.path.join(tmp_dir, "export") + os.makedirs(tmp_export_dir, exist_ok=True) + return tmp_export_dir + + +def _rename_dirs_with_new_pks(task_run_dirs: List[str], pk_substitutions: dict): + def rename_dir_with_new_pk(dir_path: str, substitutions: dict) -> str: + dump_id = substitutions.get(os.path.basename(dir_path)) + renamed_dir_path = os.path.join(os.path.dirname(dir_path), dump_id) + os.rename(dir_path, renamed_dir_path) + return renamed_dir_path + + task_runs_subs = pk_substitutions.get(MEPHISTO_DUMP_KEY, {}).get(TASK_RUNS_TABLE_NAME, {}) + if not task_runs_subs: + # Nothing to rename + return + + assignment_subs = pk_substitutions.get(MEPHISTO_DUMP_KEY, {}).get(ASSIGNMENTS_TABLE_NAME, {}) + agent_subs = pk_substitutions.get(MEPHISTO_DUMP_KEY, {}).get(AGENTS_TABLE_NAME, {}) + + task_run_dirs = [d for d in task_run_dirs if os.path.basename(d) in task_runs_subs.keys()] + for task_run_dir in task_run_dirs: + # Rename TaskRun dir + renamed_task_run_dir = rename_dir_with_new_pk(task_run_dir, task_runs_subs) + + # Rename Assignments dirs + assignments_dirs = [ + os.path.join(renamed_task_run_dir, d) + for d in os.listdir(renamed_task_run_dir) + if d in assignment_subs.keys() + ] + for assignment_dir in assignments_dirs: + renamed_assignment_dir = rename_dir_with_new_pk(assignment_dir, assignment_subs) + + # Rename Agents dirs + agents_dirs = [ + os.path.join(renamed_assignment_dir, d) + for d in os.listdir(renamed_assignment_dir) + if d in agent_subs.keys() + ] + for agent_dir in agents_dirs: + rename_dir_with_new_pk(agent_dir, agent_subs) + + +def _export_data_dir_for_task_runs( + input_dir_path: str, + archive_file_path_without_ext: str, + task_runs: List[TaskRun], + pk_substitutions: dict, + _format: str = DEFAULT_ARCHIVE_FORMAT, + verbosity: int = 0, +) -> bool: + tmp_export_dir = make_tmp_export_dir() + + task_run_data_dirs = [i.get_run_dir() for i in task_runs] + if not task_run_data_dirs: + return False + + try: + tmp_task_run_dirs = [] + + # Copy all files for passed TaskRuns into tmp dir + for task_run_data_dir in task_run_data_dirs: + relative_dir = Path(task_run_data_dir).relative_to(input_dir_path) + tmp_task_run_dir = os.path.join(tmp_export_dir, relative_dir) + + tmp_task_run_dirs.append(tmp_task_run_dir) + + os.makedirs(tmp_task_run_dir, exist_ok=True) + copy_tree(task_run_data_dir, tmp_task_run_dir, verbose=0) + + _rename_dirs_with_new_pks(tmp_task_run_dirs, pk_substitutions) + + # Create archive in export dir + shutil.make_archive( + 
base_name=archive_file_path_without_ext, + format=_format, + root_dir=tmp_export_dir, + ) + finally: + # Remove tmp dir + if os.path.exists(tmp_export_dir): + shutil.rmtree(tmp_export_dir) + + return True + + +def archive_and_copy_data_files( + db: "MephistoDB", + export_dir: str, + dump_name: str, + dump_data: dict, + pk_substitutions: dict, + _format: str = DEFAULT_ARCHIVE_FORMAT, + verbosity: int = 0, +) -> bool: + mephisto_data_files_path = os.path.join(get_data_dir(), "data") + output_zip_file_base_name = os.path.join(export_dir, dump_name) # name without extension + + if verbosity: + logger.debug(f"Archiving data files started ...") + + # Get TaskRuns for PKs in dump + task_runs: List[TaskRun] = [] + task_runs_ids: List[str] = [] + for dump_task_run in dump_data[MEPHISTO_DUMP_KEY][TASK_RUNS_TABLE_NAME]: + task_runs_pk_field_name = db_utils.get_table_pk_field_name(db, TASK_RUNS_TABLE_NAME) + dump_pk = dump_task_run[task_runs_pk_field_name] + db_pk = get_old_pk_from_substitutions(dump_pk, pk_substitutions, TASK_RUNS_TABLE_NAME) + db_pk = db_pk or dump_pk + task_run: TaskRun = TaskRun.get(db, db_pk) + task_runs.append(task_run) + task_runs_ids.append(db_pk) + + if verbosity: + logger.debug(f"Archiving data files for TaskRuns: {', '.join(task_runs_ids)}") + + # Export archived related data files to TaskRuns from dump + exported = _export_data_dir_for_task_runs( + input_dir_path=mephisto_data_files_path, + archive_file_path_without_ext=output_zip_file_base_name, + task_runs=task_runs, + pk_substitutions=pk_substitutions, + _format=_format, + verbosity=verbosity, + ) + + if verbosity: + logger.debug(f"Archiving data files finished") + + return exported + + +def unarchive_data_files( + dump_file_path: str, + _format: str = DEFAULT_ARCHIVE_FORMAT, + verbosity: int = 0, +): + # Local directory with data files for TaskRuns + mephisto_data_files_path = os.path.join(get_data_dir(), "data") + mephisto_data_runs_path = os.path.join(mephisto_data_files_path, "runs") + + # Tmp directory where data files for TaskRuns will be unarchived from dump to + tmp_dir = get_mephisto_tmp_dir() + tmp_unarchive_dir = os.path.join(tmp_dir, "unarchive") + tmp_unarchive_task_runs_dir = os.path.join(tmp_unarchive_dir, "runs") + + try: + # Unarchive into tmp directory + if verbosity: + logger.debug("Unpacking TaskRuns files ...") + + shutil.unpack_archive( + filename=dump_file_path, + extract_dir=tmp_unarchive_dir, + format=_format, + ) + + if verbosity: + logger.debug("Unpacking TaskRuns files finished") + + # Copy files + if verbosity: + logger.debug(f"Copying TaskRuns files into {mephisto_data_runs_path} ...") + + copy_tree(tmp_unarchive_task_runs_dir, mephisto_data_runs_path, verbose=0) + + if verbosity: + logger.debug("Copying TaskRuns files finished") + except Exception as e: + logger.exception("Could not unpack TaskRuns files from dump") + exit() + finally: + # Remove tmp dir with dump data files + if verbosity: + logger.debug("Removing unpacked TaskRuns files ...") + + if os.path.exists(tmp_unarchive_dir): + shutil.rmtree(tmp_unarchive_dir) + + if verbosity: + logger.debug("Removing unpacked TaskRuns files finished") + + +def get_export_options_for_metadata(ctx: RichContext, options: dict) -> Dict[str, Any]: + export_options_for_metadata = {} + + for param in ctx.command.params: + option_name = "/".join(param.opts) # Concatenated option name variants (short/full) + values = options[param.name] + export_options_for_metadata[option_name] = values + + return export_options_for_metadata diff --git 
a/mephisto/tools/db_data_porter/import_dump.py b/mephisto/tools/db_data_porter/import_dump.py index 7390d928c..d2771f920 100644 --- a/mephisto/tools/db_data_porter/import_dump.py +++ b/mephisto/tools/db_data_porter/import_dump.py @@ -66,7 +66,7 @@ def import_single_db( provider_type: str, dump_data: dict, conflict_resolver_name: str, - label: str, + labels: List[str], verbosity: int = 0, ) -> ImportSingleDBsType: # Results of the function @@ -117,12 +117,11 @@ def import_single_db( # Imported data vars imported_data_needs_to_be_updated = ( - provider_type == MEPHISTO_DUMP_KEY and - table_name in IMPORTED_DATA_TABLE_NAMES + provider_type == MEPHISTO_DUMP_KEY and table_name in IMPORTED_DATA_TABLE_NAMES ) - newly_imported_labels = json.dumps([label]) - conflicted_labels = json.dumps([LOCAL_DB_LABEL, label]) + newly_imported_labels = json.dumps(sorted(labels)) + conflicted_labels = json.dumps(sorted([LOCAL_DB_LABEL, *labels])) imported_data_for_table = { newly_imported_labels: [], conflicted_labels: [], @@ -146,7 +145,10 @@ def import_single_db( imported_data_conflicted_row = False _update_row_with_pks_from_resolvings_mappings( - db, table_name, dump_row, resolvings_mapping, + db, + table_name, + dump_row, + resolvings_mapping, ) # Table with non-PK unique field @@ -169,7 +171,7 @@ def import_single_db( # If local DB does not have this row if not existing_rows: if verbosity: - logger.info(f"Inserting new row into table '{table_name}': {dump_row}") + logger.debug(f"Inserting new row into table '{table_name}': {dump_row}") db_utils.insert_new_row_in_table(db, table_name, dump_row) @@ -180,17 +182,30 @@ def import_single_db( existing_db_row = existing_rows[-1] if verbosity: - logger.info( + logger.debug( f"Conflicts during inserting row in table '{table_name}': " f"{dump_row}. " f"Existing row in your database: {existing_db_row}" ) resolved_conflicting_row = conflict_resolver_name.resolve( - table_name, table_pk_field_name, existing_db_row, dump_row, + table_name, + table_pk_field_name, + existing_db_row, + dump_row, ) + + if verbosity: + logger.debug( + f"Resolving finished successfully. " + f"Chosen row: {resolved_conflicting_row}" + ) + db_utils.update_row_in_table( - db, table_name, resolved_conflicting_row, table_pk_field_name, + db, + table_name, + resolved_conflicting_row, + table_pk_field_name, ) # Saving resolved a pair of PKs @@ -205,6 +220,9 @@ def import_single_db( # Regular table. 
Create new row as is else: + if verbosity: + logger.debug(f"Inserting new row into table '{table_name}': {dump_row}") + db_utils.insert_new_row_in_table(db, table_name, dump_row) # Update table lists of Imported data @@ -214,10 +232,12 @@ def import_single_db( else: _label = newly_imported_labels - imported_data_for_table[_label].append({ - UNIQUE_FIELD_NAMES: unique_field_names or [table_pk_field_name], - UNIQUE_FIELD_VALUES: imported_data_row_unique_field_values, - }) + imported_data_for_table[_label].append( + { + UNIQUE_FIELD_NAMES: unique_field_names or [table_pk_field_name], + UNIQUE_FIELD_VALUES: imported_data_row_unique_field_values, + } + ) # Add table into Imported data if imported_data_needs_to_be_updated: @@ -228,38 +248,45 @@ def import_single_db( if provider_type == MEPHISTO_DUMP_KEY: for unit_id, agent_id in units_agents.items(): db_utils.update_row_in_table( - db, "units", {"unit_id": unit_id, "agent_id": agent_id}, "unit_id", + db, + "units", + {"unit_id": unit_id, "agent_id": agent_id}, + "unit_id", ) # --- HACK (#UNIT.AGENT_ID) END #3: except Exception as e: + error_message_ending = "" + # Custom error message in cases when we can guess what happens # using small info SQLite gives us possible_issue = "" if in_progress_table_pk_field_name in str(e) and "UNIQUE constraint" in str(e): pk_value = in_progress_dump_row[in_progress_table_pk_field_name] + error_message_ending = ( + f". Local database already has Primary Key '{pk_value}' " + f"in table '{in_progress_table_name}'." + ) possible_issue = ( - f"\nPossible issue: " - f"Local database already have Primary Key '{pk_value}' " - f"in table '{in_progress_table_name}'. " - f"Maybe you are trying to run already merged dump file. " - f"Or if you have old databases, you may bump into same Primary Keys. " - f"If you are sure that all data from this dump is unique and " - f"still have access to the dumped project, " - f"try to create dump with parameter `--randomize-legacy-ids` " - f"and start importing again." + f"\n\n[bold]Possible issue:[/bold] " + f"You may be trying to import an already imported dump file. " + f"Or this could be related to legacy auto-increment database primary keys " + f"(in which case the dump needs to be re-created " + f"with -r/--randomize-legacy-ids option)." ) - default_error_message_beginning = "" + error_message_beginning = "" if not possible_issue: - default_error_message_beginning = "Unexpected error happened: " + error_message_beginning = "Unexpected error happened: " errors.append( - f"{default_error_message_beginning}{e}." + f"{error_message_beginning}{e}{error_message_ending}" f"{possible_issue}" - f"\nProvider: {provider_type}." - f"\nTable: {in_progress_table_name}." - f"\nRow: {json.dumps(in_progress_dump_row, indent=2)}." + f"\n" + f"\n[bold]Provider:[/bold] {provider_type}." + f"\n[bold]Table:[/bold] {in_progress_table_name}." + f"\n[bold]Row:[/bold]\n{json.dumps(in_progress_dump_row, indent=2)}." 
+ f"\n" ) return { @@ -269,8 +296,15 @@ def import_single_db( def fill_imported_data_with_imported_dump( - db: "MephistoDB", imported_data: dict, source_file_name: str, + db: "MephistoDB", + imported_data: dict, + source_file_name: str, + verbosity: int = 0, ): + if verbosity: + if imported_data: + logger.debug("Saving information about imported data ...") + for table_name, table_info in imported_data.items(): for labels, labels_rows in table_info.items(): for row in labels_rows: @@ -291,8 +325,22 @@ def fill_imported_data_with_imported_dump( }, ) + if verbosity: + if imported_data: + logger.debug("Saving information about imported data finished") + + +def import_table_imported_data_from_dump( + db: "MephistoDB", + imported_data_rows: List[dict], + verbosity: int = 0, +): + if verbosity: + if imported_data_rows: + logger.debug( + "Updating local information about imported data with imported data from dump ..." + ) -def import_table_imported_data_from_dump(db: "MephistoDB", imported_data_rows: List[dict]): for row in imported_data_rows: table_name = row["table_name"] unique_field_names = row["unique_field_names"] @@ -316,13 +364,19 @@ def import_table_imported_data_from_dump(db: "MephistoDB", imported_data_rows: L # Update existing row if existing_row: + if verbosity: + logger.debug(f"Updating already existing row for `{table_name}`: {existing_row}") + # Merge existing labels with from imported row existing_data_labels = json.loads(existing_row["data_labels"]) existing_data_labels += importing_data_labels - existing_row["data_labels"] = json.dumps(list(set(existing_data_labels))) + existing_row["data_labels"] = json.dumps(sorted(list(set(existing_data_labels)))) db_utils.update_row_in_table( - db=db, table_name="imported_data", row=existing_row, pk_field_name="id", + db=db, + table_name="imported_data", + row=existing_row, + pk_field_name="id", ) # Create new row @@ -331,4 +385,14 @@ def import_table_imported_data_from_dump(db: "MephistoDB", imported_data_rows: L row.pop("id", None) row["data_labels"] = json.dumps(data_labels_without_local) + if verbosity: + logger.debug(f"Inserting new row for `{table_name}`: {row}") + db_utils.insert_new_row_in_table(db=db, table_name="imported_data", row=row) + + if verbosity: + if imported_data_rows: + logger.debug( + "Updating local information about imported data " + "with imported data from dump finished" + ) diff --git a/mephisto/tools/db_data_porter/randomize_ids.py b/mephisto/tools/db_data_porter/randomize_ids.py index ab81b75a9..fbb48fbca 100644 --- a/mephisto/tools/db_data_porter/randomize_ids.py +++ b/mephisto/tools/db_data_porter/randomize_ids.py @@ -30,7 +30,9 @@ class RandomizedIDsType(TypedDict): def _randomize_ids_for_mephisto( - db: "MephistoDB", mephisto_dump: dict, legacy_only: bool = False, + db: "MephistoDB", + mephisto_dump: dict, + legacy_only: bool = False, ) -> DBPKSubstitutionsType: table_names = [t for t in mephisto_dump.keys() if t not in [IMPORTED_DATA_TABLE_NAME]] @@ -152,7 +154,9 @@ def _randomize_ids_for_provider( def randomize_ids( - db: "MephistoDB", full_dump: dict, legacy_only: bool = False, + db: "MephistoDB", + full_dump: dict, + legacy_only: bool = False, ) -> RandomizedIDsType: pk_substitutions: PKSubstitutionsType = {} @@ -166,7 +170,9 @@ def randomize_ids( for provider_type in provider_types: provider_dump = full_dump[provider_type] randomized_ids_for_provider = _randomize_ids_for_provider( - provider_type, provider_dump, mephisto_pk_substitutions, + provider_type, + provider_dump, + mephisto_pk_substitutions, ) 
if randomized_ids_for_provider: @@ -179,7 +185,9 @@ def randomize_ids( def get_old_pk_from_substitutions( - pk: str, substitutions: dict, table_name: str, + pk: str, + substitutions: dict, + table_name: str, ) -> str: # After we created a dump file, we already can have new randomized PKs. # But we still have old ones in Mephisto DB. diff --git a/mephisto/tools/db_data_porter/validation.py b/mephisto/tools/db_data_porter/validation.py index 853ad749f..64f3c67ae 100644 --- a/mephisto/tools/db_data_porter/validation.py +++ b/mephisto/tools/db_data_porter/validation.py @@ -78,8 +78,8 @@ def validate_dump_data(db: "MephistoDB", dump_data: dict) -> Optional[List[str]] ) continue - incorrect_field_names = list(filter( - lambda fn: not isinstance(fn, str), table_row.keys()) + incorrect_field_names = list( + filter(lambda fn: not isinstance(fn, str), table_row.keys()) ) if incorrect_field_names: errors.append( diff --git a/mephisto/utils/db.py b/mephisto/utils/db.py index 43eb7f374..3828edbff 100644 --- a/mephisto/utils/db.py +++ b/mephisto/utils/db.py @@ -29,6 +29,7 @@ # --- Exceptions --- + class MephistoDBException(Exception): pass @@ -51,6 +52,7 @@ class EntryDoesNotExistException(MephistoDBException): # --- Functions --- + def _select_all_rows_from_table(db: "MephistoDB", table_name: str) -> List[dict]: with db.table_access_condition, db.get_connection() as conn: c = conn.cursor() @@ -60,13 +62,17 @@ def _select_all_rows_from_table(db: "MephistoDB", table_name: str) -> List[dict] def _select_rows_from_table_related_to_task( - db: "MephistoDB", table_name: str, task_ids: List[str], + db: "MephistoDB", + table_name: str, + task_ids: List[str], ) -> List[dict]: return select_rows_by_list_of_field_values(db, table_name, ["task_id"], [task_ids]) def select_rows_from_table_related_to_task_run( - db: "MephistoDB", table_name: str, task_run_ids: List[str], + db: "MephistoDB", + table_name: str, + task_run_ids: List[str], ) -> List[dict]: return select_rows_by_list_of_field_values(db, table_name, ["task_run_id"], [task_run_ids]) @@ -107,7 +113,7 @@ def get_task_ids_by_task_names(db: "MephistoDB", task_names: List[str]) -> List[ task_names_string = ",".join([f"'{s}'" for s in task_names]) c.execute( f""" - SELECT task_id FROM tasks + SELECT task_id FROM tasks WHERE task_name IN ({task_names_string}); """ ) @@ -121,7 +127,7 @@ def get_task_run_ids_ids_by_task_ids(db: "MephistoDB", task_ids: List[str]) -> L task_ids_string = ",".join([f"'{s}'" for s in task_ids]) c.execute( f""" - SELECT task_run_id FROM task_runs + SELECT task_run_id FROM task_runs WHERE task_id IN ({task_ids_string}); """ ) @@ -141,7 +147,7 @@ def get_task_run_ids_ids_by_labels(db: "MephistoDB", labels: List[str]) -> List[ c.execute( f""" - SELECT unique_field_values FROM imported_data + SELECT unique_field_values FROM imported_data WHERE table_name = 'task_runs' {where_labels_string}; """ ) @@ -163,9 +169,7 @@ def get_table_pk_field_name(db: "MephistoDB", table_name: str): """ with db.table_access_condition, db.get_connection() as conn: c = conn.cursor() - c.execute( - f"SELECT name FROM pragma_table_info('{table_name}') WHERE pk;" - ) + c.execute(f"SELECT name FROM pragma_table_info('{table_name}') WHERE pk;") table_unique_field_name = c.fetchone()["name"] return table_unique_field_name @@ -173,9 +177,7 @@ def get_table_pk_field_name(db: "MephistoDB", table_name: str): def select_all_table_rows(db: "MephistoDB", table_name: str) -> List[dict]: with db.table_access_condition, db.get_connection() as conn: c = conn.cursor() - c.execute( 
- f"SELECT * FROM {table_name};" - ) + c.execute(f"SELECT * FROM {table_name};") rows = c.fetchall() return [dict(row) for row in rows] @@ -207,10 +209,12 @@ def select_rows_by_list_of_field_values( _field_values = field_values[i] field_values_string = ",".join([f"'{s}'" for s in _field_values]) where_list.append([field_name, field_values_string]) - where_string = " AND ".join([ - f"{field_name} IN ({field_values_string})" - for field_name, field_values_string in where_list - ]) + where_string = " AND ".join( + [ + f"{field_name} IN ({field_values_string})" + for field_name, field_values_string in where_list + ] + ) # Combine ORDER BY statement order_by_string = "" @@ -221,7 +225,7 @@ def select_rows_by_list_of_field_values( c.execute( f""" - SELECT * FROM {table_name} + SELECT * FROM {table_name} WHERE {where_string} {order_by_string}; """ @@ -232,15 +236,15 @@ def select_rows_by_list_of_field_values( def delete_exported_data_without_fk_constraints( - db: "MephistoDB", db_dump: dict, table_names_can_be_cleaned: Optional[List[str]] = None, + db: "MephistoDB", + db_dump: dict, + table_names_can_be_cleaned: Optional[List[str]] = None, ): table_names_can_be_cleaned = table_names_can_be_cleaned or [] with db.table_access_condition, db.get_connection() as conn: c = conn.cursor() - c.execute( - "PRAGMA foreign_keys = off;" - ) + c.execute("PRAGMA foreign_keys = off;") delete_queries = [] for table_name, rows in db_dump.items(): @@ -255,9 +259,7 @@ def delete_exported_data_without_fk_constraints( ) c.executescript("\n".join(delete_queries)) - c.execute( - "PRAGMA foreign_keys = on;" - ) + c.execute("PRAGMA foreign_keys = on;") def delete_entire_exported_data(db: "MephistoDB"): @@ -268,9 +270,7 @@ def delete_entire_exported_data(db: "MephistoDB"): with db.table_access_condition, db.get_connection() as conn: c = conn.cursor() - c.execute( - "PRAGMA foreign_keys = off;" - ) + c.execute("PRAGMA foreign_keys = off;") delete_queries = [] for table_name in table_names: @@ -281,31 +281,29 @@ def delete_entire_exported_data(db: "MephistoDB"): c.executescript("\n".join(delete_queries)) - c.execute( - "PRAGMA foreign_keys = on;" - ) + c.execute("PRAGMA foreign_keys = on;") def get_list_of_provider_types(db: "MephistoDB") -> List[str]: with db.table_access_condition, db.get_connection() as conn: c = conn.cursor() - c.execute( - "SELECT provider_type FROM requesters;" - ) + c.execute("SELECT provider_type FROM requesters;") rows = c.fetchall() return [r["provider_type"] for r in rows] def get_latest_row_from_table( - db: "MephistoDB", table_name: str, order_by: Optional[str] = "creation_date", + db: "MephistoDB", + table_name: str, + order_by: Optional[str] = "creation_date", ) -> Optional[dict]: with db.table_access_condition, db.get_connection() as conn: c = conn.cursor() c.execute( f""" SELECT * - FROM {table_name} - ORDER BY {order_by} DESC + FROM {table_name} + ORDER BY {order_by} DESC LIMIT 1; """, ) @@ -371,8 +369,25 @@ def get_list_of_tables_to_export(db: "MephistoDB") -> List[str]: return filtered_table_names +def get_list_of_available_labels(db: "MephistoDB") -> List[str]: + with db.table_access_condition, db.get_connection() as conn: + c = conn.cursor() + c.execute("SELECT data_labels FROM imported_data;") + rows = c.fetchall() + + labels = [] + for row in rows: + row_labels: List[List[str]] = json.loads(row["data_labels"]) + labels += row_labels + + return list(set(labels)) + + def check_if_row_with_params_exists( - db: "MephistoDB", table_name: str, params: dict, select_field: Optional[str] = 
"*", + db: "MephistoDB", + table_name: str, + params: dict, + select_field: Optional[str] = "*", ) -> bool: """ Check if row exists in `table_name` for passed dict of `params` @@ -391,8 +406,8 @@ def check_if_row_with_params_exists( c.execute( f""" - SELECT {select_field} - FROM {table_name} {where_string} + SELECT {select_field} + FROM {table_name} {where_string} LIMIT 1; """, execute_args, @@ -462,36 +477,56 @@ def mephisto_db_to_dict_for_task_runs( # Find and serialize `projects` project_ids = list(set(filter(bool, [i["project_id"] for i in dump_data["tasks"]]))) project_rows = select_rows_by_list_of_field_values( - db, "projects", ["project_id"], [project_ids], + db, + "projects", + ["project_id"], + [project_ids], ) dump_data["projects"] = serialize_data_for_table(project_rows) # Find and serialize `requesters` requester_ids = list(set(filter(bool, [i["requester_id"] for i in dump_data["task_runs"]]))) requester_rows = select_rows_by_list_of_field_values( - db, "requesters", ["requester_id"], [requester_ids], + db, + "requesters", + ["requester_id"], + [requester_ids], ) dump_data["requesters"] = serialize_data_for_table(requester_rows) # Find and serialize `workers` worker_ids = list(set(filter(bool, [i["worker_id"] for i in dump_data["units"]]))) worker_rows = select_rows_by_list_of_field_values( - db, "workers", ["worker_id"], [worker_ids], + db, + "workers", + ["worker_id"], + [worker_ids], ) dump_data["workers"] = serialize_data_for_table(worker_rows) # Find and serialize `granted_qualifications` granted_qualification_rows = select_rows_by_list_of_field_values( - db, "granted_qualifications", ["worker_id"], [worker_ids], + db, + "granted_qualifications", + ["worker_id"], + [worker_ids], ) dump_data["granted_qualifications"] = serialize_data_for_table(granted_qualification_rows) # Find and serialize `qualifications` - qualification_ids = list(set(filter( - bool, [i["qualification_id"] for i in dump_data["granted_qualifications"]], - ))) + qualification_ids = list( + set( + filter( + bool, + [i["qualification_id"] for i in dump_data["granted_qualifications"]], + ) + ) + ) qualification_rows = select_rows_by_list_of_field_values( - db, "qualifications", ["qualification_id"], [qualification_ids], + db, + "qualifications", + ["qualification_id"], + [qualification_ids], ) dump_data["qualifications"] = serialize_data_for_table(qualification_rows) @@ -560,7 +595,10 @@ def insert_new_row_in_table(db: "MephistoDB", table_name: str, row: dict): def update_row_in_table( - db: "MephistoDB", table_name: str, row: dict, pk_field_name: Optional[str] = None, + db: "MephistoDB", + table_name: str, + row: dict, + pk_field_name: Optional[str] = None, ): row = deepcopy(row) @@ -588,6 +626,7 @@ def update_row_in_table( # --- Decorators --- + def retry_generate_id(caught_excs: Optional[List[Type[Exception]]] = None): """ A decorator that attempts to call create DB entry until ID will be unique. 
@@ -597,6 +636,7 @@ def retry_generate_id(caught_excs: Optional[List[Type[Exception]]] = None): - db - table_name """ + def decorator(unreliable_fn: Callable): def wrapped_fn(*args, **kwargs): caught_excs_tuple = tuple(caught_excs or [Exception]) @@ -614,16 +654,22 @@ def wrapped_fn(*args, **kwargs): # Othervise, we just leave error as is exc_message = str(getattr(e, "original_exc", None) or "") db = getattr(e, "db", None) - table_name = getattr(e, "table_name", None) - is_unique_constraint = exc_message.startswith("UNIQUE constraint") + table_name = getattr(e, "table_name", "") + pk_fieldname = get_table_pk_field_name(db, table_name=table_name) if db and table_name else None + is_pk_unique_constraint = ( + exc_message.startswith("UNIQUE constraint") + and f"{table_name}.{pk_fieldname}" in exc_message + ) - if db and table_name and is_unique_constraint: - pk_field_name = get_table_pk_field_name(db, table_name=table_name) - if pk_field_name in exc_message: - pk_exists = True + if db and table_name and is_pk_unique_constraint: + pk_exists = True + else: + # If we caught some other unique constraint, re-raise the error as is + raise # Set original function name to wrapped one. wrapped_fn.__name__ = unreliable_fn.__name__ return wrapped_fn + return decorator diff --git a/mephisto/utils/misc.py b/mephisto/utils/misc.py index ca5b35f7a..df168fa8b 100644 --- a/mephisto/utils/misc.py +++ b/mephisto/utils/misc.py @@ -15,9 +15,7 @@ def serialize_date_to_python(value: Any) -> datetime: # If integer timestamp if isinstance(value, int): timestamp_is_in_msec = len(str(value)) == 13 - datetime_value = datetime.fromtimestamp( - value / 1000 if timestamp_is_in_msec else value - ) + datetime_value = datetime.fromtimestamp(value / 1000 if timestamp_is_in_msec else value) # If datetime string else: datetime_value = dateutil_parse(str(value)) diff --git a/mephisto/utils/testing.py b/mephisto/utils/testing.py index aa811be42..694ed3b20 100644 --- a/mephisto/utils/testing.py +++ b/mephisto/utils/testing.py @@ -222,7 +222,7 @@ def find_unit_reviews( SELECT * FROM unit_review WHERE (updated_qualification_id = ?1) OR - (revoked_qualification_id = ?1) AND + (revoked_qualification_id = ?1) AND (worker_id = ?2) {task_query} ORDER BY creation_date ASC; diff --git a/test/core/test_operator.py b/test/core/test_operator.py index 5a7078279..bf6a97da1 100644 --- a/test/core/test_operator.py +++ b/test/core/test_operator.py @@ -15,6 +15,7 @@ from unittest.mock import patch from tqdm import TMonitor # type: ignore +from mephisto.data_model.assignment import Assignment from mephisto.utils.testing import get_test_requester from mephisto.data_model.constants.assignment_state import AssignmentState from mephisto.abstractions.databases.local_database import LocalMephistoDB @@ -90,7 +91,7 @@ def tearDown(self): f"Expected only main thread at teardown after {SHUTDOWN_TIMEOUT} seconds, found {target_threads}", ) - def wait_for_complete_assignment(self, assignment, timeout: int): + def wait_for_complete_assignment(self, assignment: Assignment, timeout: int): start_time = time.time() while time.time() - start_time < timeout: if assignment.get_status() == AssignmentState.COMPLETED: diff --git a/test/review_app/server/api/test_units_view.py b/test/review_app/server/api/test_units_view.py index f9fc93d3b..3a4955388 100644 --- a/test/review_app/server/api/test_units_view.py +++ b/test/review_app/server/api/test_units_view.py @@ -76,14 +76,14 @@ def test_one_unit_with_unit_ids_success(self, *args, **kwargs): unit_1_id = get_test_unit(self.db) unit_1: Unit = Unit.get(self.db, 
unit_1_id) unit_2_id = self.db.new_unit( - unit_1.task_id, - unit_1.task_run_id, - unit_1.requester_id, - unit_1.db_id, - 2, - 1, - unit_1.provider_type, - unit_1.task_type, + task_id=unit_1.task_id, + task_run_id=unit_1.task_run_id, + requester_id=unit_1.requester_id, + assignment_id=unit_1.assignment_id, + unit_index=2, + pay_amount=1, + provider_type=unit_1.provider_type, + task_type=unit_1.task_type, ) unit_2: Unit = Unit.get(self.db, unit_2_id) unit_1.set_db_status(AssignmentState.COMPLETED) @@ -104,14 +104,14 @@ def test_two_units_with_unit_ids_success(self, *args, **kwargs): unit_1_id = get_test_unit(self.db) unit_1: Unit = Unit.get(self.db, unit_1_id) unit_2_id = self.db.new_unit( - unit_1.task_id, - unit_1.task_run_id, - unit_1.requester_id, - unit_1.db_id, - 2, - 1, - unit_1.provider_type, - unit_1.task_type, + task_id=unit_1.task_id, + task_run_id=unit_1.task_run_id, + requester_id=unit_1.requester_id, + assignment_id=unit_1.assignment_id, + unit_index=2, + pay_amount=1, + provider_type=unit_1.provider_type, + task_type=unit_1.task_type, ) unit_2: Unit = Unit.get(self.db, unit_2_id) From 09bcff1d36841ca0994fe32664bb2be0618c3922 Mon Sep 17 00:00:00 2001 From: Paul Abumov Date: Sun, 5 May 2024 21:51:53 -0400 Subject: [PATCH 3/3] Unittest coverage for Data Porter feature --- .../{merge_dbs => data_porter}/_category_.yml | 2 +- .../custom_conflict_resolver.md | 0 .../{merge_dbs => data_porter}/reference.md | 20 +- .../how_to_use/data_porter/simple_usage.md | 89 +++ .../how_to_use/merge_dbs/simple_usage.md | 135 ---- ...y => _001_20240325_data_porter_feature.py} | 3 +- .../databases/migrations/__init__.py | 4 +- ...y => _001_20240325_data_porter_feature.py} | 13 +- .../providers/mturk/migrations/__init__.py | 4 +- ...y => _001_20240325_data_porter_feature.py} | 9 +- .../providers/prolific/migrations/__init__.py | 4 +- mephisto/client/cli_db_commands.py | 11 +- mephisto/data_model/task_run.py | 3 + .../scripts/local_db/review_tips_for_task.py | 29 +- mephisto/tools/db_data_porter/backups.py | 2 +- .../tools/db_data_porter/db_data_porter.py | 5 +- mephisto/tools/db_data_porter/dumps.py | 4 +- mephisto/tools/db_data_porter/export_dump.py | 4 +- mephisto/tools/db_data_porter/import_dump.py | 2 +- .../tools/db_data_porter/randomize_ids.py | 2 +- mephisto/utils/db.py | 29 +- mephisto/utils/dirs.py | 16 +- mephisto/utils/testing.py | 38 +- pytest.ini | 2 + test/tools/db_data_porter/__init__.py | 0 .../conflict_resolvers/__init__.py | 0 .../test_default_merge_conflict_resolver.py | 132 ++++ test/tools/db_data_porter/test_backups.py | 106 +++ .../db_data_porter/test_db_data_porter.py | 258 +++++++ test/utils/test_db.py | 726 ++++++++++++++++++ test/utils/test_misc.py | 40 + 31 files changed, 1477 insertions(+), 215 deletions(-) rename docs/web/docs/guides/how_to_use/{merge_dbs => data_porter}/_category_.yml (88%) rename docs/web/docs/guides/how_to_use/{merge_dbs => data_porter}/custom_conflict_resolver.md (100%) rename docs/web/docs/guides/how_to_use/{merge_dbs => data_porter}/reference.md (84%) create mode 100644 docs/web/docs/guides/how_to_use/data_porter/simple_usage.md delete mode 100644 docs/web/docs/guides/how_to_use/merge_dbs/simple_usage.md rename mephisto/abstractions/databases/migrations/{_001_20240325_preparing_db_for_merge_dbs_command.py => _001_20240325_data_porter_feature.py} (99%) rename mephisto/abstractions/providers/mturk/migrations/{_001_20240325_preparing_db_for_merge_dbs_command.py => _001_20240325_data_porter_feature.py} (94%) rename 
mephisto/abstractions/providers/prolific/migrations/{_001_20240325_preparing_db_for_merge_dbs_command.py => _001_20240325_data_porter_feature.py} (96%) create mode 100644 test/tools/db_data_porter/__init__.py create mode 100644 test/tools/db_data_porter/conflict_resolvers/__init__.py create mode 100644 test/tools/db_data_porter/conflict_resolvers/test_default_merge_conflict_resolver.py create mode 100644 test/tools/db_data_porter/test_backups.py create mode 100644 test/tools/db_data_porter/test_db_data_porter.py create mode 100644 test/utils/test_db.py create mode 100644 test/utils/test_misc.py diff --git a/docs/web/docs/guides/how_to_use/merge_dbs/_category_.yml b/docs/web/docs/guides/how_to_use/data_porter/_category_.yml similarity index 88% rename from docs/web/docs/guides/how_to_use/merge_dbs/_category_.yml rename to docs/web/docs/guides/how_to_use/data_porter/_category_.yml index 919eb213e..2fbf9eb07 100644 --- a/docs/web/docs/guides/how_to_use/merge_dbs/_category_.yml +++ b/docs/web/docs/guides/how_to_use/data_porter/_category_.yml @@ -2,6 +2,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -label: "Merge databases" +label: "Move data around" collapsed: false position: 9 diff --git a/docs/web/docs/guides/how_to_use/merge_dbs/custom_conflict_resolver.md b/docs/web/docs/guides/how_to_use/data_porter/custom_conflict_resolver.md similarity index 100% rename from docs/web/docs/guides/how_to_use/merge_dbs/custom_conflict_resolver.md rename to docs/web/docs/guides/how_to_use/data_porter/custom_conflict_resolver.md diff --git a/docs/web/docs/guides/how_to_use/merge_dbs/reference.md b/docs/web/docs/guides/how_to_use/data_porter/reference.md similarity index 84% rename from docs/web/docs/guides/how_to_use/merge_dbs/reference.md rename to docs/web/docs/guides/how_to_use/data_porter/reference.md index 55677f590..e90ce9354 100644 --- a/docs/web/docs/guides/how_to_use/merge_dbs/reference.md +++ b/docs/web/docs/guides/how_to_use/data_porter/reference.md @@ -106,10 +106,28 @@ Options: - `-v/--verbosity` - level of logging (default: 0; values: 0, 1) -## Note on legacy PKs + +## Important notes + +### Data dump vs backup + +Mephisto stores local data in `outputs` and `data` folders. The safest way to back Mephisto up is to create a copy of the `data` folder - and that's what a Mephisto backup contains. + +On the other hand, partial data export is written into a data dump that contains: + +- a JSON file representing relevant data entries from DB tables +- a folder with all files related to the exported data entries + +With the export command, you **can** create a dump of the entire data as well, and here's when it's useful: +- Use `mephisto db backup` as the safest option, if you only intend to restore this data later in place of the previous state +- Use `mephisto db export` to dump complete data from a small Mephisto project, so it can be imported into a larger Mephisto project later. + + +### Legacy PKs Prior to release `v1.4` of Mephisto, its DB schemas used auto-incremented integer primary keys. While convenient for debugging, such keys cause problems during data import/export. As of `v1.4` we have replaced these "legacy" PKs with quasi-random integers (for backward compatibility their values are designed to be above 1,000,000). If you do wish to use import/export commands with your "legacy" data, include the `--randomize-legacy-ids` option. 
It prevents data corruption when merging 2 sets of "legacy" data (because they would contain the same integer PKs `1, 2, 3, ...` for completely unrelated objects). + +This handling of legacy PKs ensures that the Data Porter feature is backward compatible and will work with your pre-existing Mephisto data. diff --git a/docs/web/docs/guides/how_to_use/data_porter/simple_usage.md b/docs/web/docs/guides/how_to_use/data_porter/simple_usage.md new file mode 100644 index 000000000..03f63c0d4 --- /dev/null +++ b/docs/web/docs/guides/how_to_use/data_porter/simple_usage.md @@ -0,0 +1,89 @@ +--- + +# Copyright (c) Meta Platforms and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +sidebar_position: 1 +--- + +# Simple usage + + +## Introduction + +Sometimes you may want to run Mephisto on remote server(s) for data collection, since that stage takes a while. +The Data Porter feature allows you to move data collected by different Mephisto instances around, for ease of review and record keeping. + +Data Porter can do the following for you: + +- Backing up your local data +- Restoring your local data +- Exporting part of your local data (into a data dump) +- Importing data from a data dump (into your local data) + +Before making any changes to data, we recommend creating a backup of your local data +(so you can roll back the changes if anything goes wrong). + +--- + +## Common use scenarios + +### Backing up data + +The below backup command will create an archived version of your local `data` directory +(that contains all data for the project), and place it in the `outputs/backup/` directory: + +```shell +mephisto db backup +``` + +### Restoring a backup + +You can restore a snapshot of your local data from a backup archive (created with the `mephisto db backup` command): + +```shell +mephisto db restore --file +``` + +where `` can be either the full path to a file, or just the filename (if it's located in the `outputs/backup/` directory) + +Important notes: + +- Your current local data will be erased (to make room for the restored data) +- If the DB schema of the restored data is outdated, Mephisto will automatically try to apply all existing migrations when next launched + + +### Exporting data + +To export all local data (and make it importable later), run + +```shell +mephisto db export +``` + +To export only part of your data (i.e. a few selected TaskRuns), you have a few options for identifying the data to export. The simplest method is by using Task name(s): + +```shell +mephisto db export --export-tasks-by-names "My first Task" --export-tasks-by-names "My second Task" +``` + +This will generate an archive file in the `outputs/export/` directory. + +#### Legacy PKs note + +If you're exporting "legacy" data entries (i.e. created before May 2024), you should add the parameter `--randomize-legacy-ids` to your export command. This ensures there will be no PK conflicts when importing this dump into a database that also contains "legacy" data. +All this parameter does is change (within the dump) sequential integer PKs to random integer PKs, while preserving all object relations. 
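+ +For example (the Task name here is illustrative), a legacy-safe partial export could look like this: + +```shell +mephisto db export --export-tasks-by-names "My first Task" --randomize-legacy-ids +``` 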
+ + +### Importing data + +You can restore data dump content (created with `mephisto db export` command) into your local data like so: + +```shell +mephisto db import --file +``` + +where `` can be either full path to a file, or just the filename (if it's located in the `outputs/export/` directory) + +Note that before the import starts, a full backup of your local data will be automatically created and saved to `outputs/backup/` directory. diff --git a/docs/web/docs/guides/how_to_use/merge_dbs/simple_usage.md b/docs/web/docs/guides/how_to_use/merge_dbs/simple_usage.md deleted file mode 100644 index 1c2089dc6..000000000 --- a/docs/web/docs/guides/how_to_use/merge_dbs/simple_usage.md +++ /dev/null @@ -1,135 +0,0 @@ ---- - -# Copyright (c) Meta Platforms and its affiliates. -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -sidebar_position: 1 ---- - -# Simple usage - - -## Introduction - -We realized that caompanies can be big, and they can run many tasks on different computers/servers, -or even one task in several departments or school classes. -But later it is much easier to review all tasks together. - -And here is the solution - merging tasks data into simple one. - - -## How it works - -1. You create full backup to save all your data to have the ability to roll all changes back if somthing went wrong -2. Export tasks into JSON dump file with related files in ZIP archive -3. Send or collect all dumps together -4. Use your main project or new Mehisto project to import all these dumps into it -5. Restore from backup if changed your mind or start from scratch - - -## Most common scenario of usage - -### Backup your main project - -If you already have some kind of main Mephisto project where all your tasks were processed, -you may want to merge a dump into this exact project. -We strongly recommend to make a backup of all your data manually and save it somewhere you can easily find. - -The command is: - -```shell -mephisto db backup -``` - -And you will see text like this - -``` -Started creating backup file ... -Finished successfully! File: //outputs/backup/2024_01_01_00_00_01_mephisto_backup.zip -``` - -Find and copy this file. - -### Export data in dump - -To make a dump with all you data, use simple command: - -```shell -mephisto db export -``` - -if you want to export just 2 tasks from 10, you need to add an option: - -```shell -mephisto db export --export-tasks-by-names "My first Task" --export-tasks-by-names "My second Task" -``` - -If you run tasks before June 2024 you should use parameter `--randomize-legacy-ids`. -Why do you need this? We modified our Primary Keys in our databases. -They were autoincremented and in all you projects start from 1. -It will bring us into conflicts in all databases. -So, this parameter will regenerate randomly all Primary Keys and replace Foreign Keys with them as well. -Note that it will not affect databases, Primary Keys will be new only in dump. - -```shell -mephisto db export --randomize-legacy-ids -``` - -And you will see text like this - -``` -Started exporting data ... -No filter for TaskRun specified - exporting all TaskRuns. -Finished successfully! -Files created: - - Dump archive - //outputs/export/2024_01_01_00_00_01_mephisto_dump.zip -``` - -### Import just created dump into main project - -Put your dump into export directory `//outputs/export/` and you can use just a dump name in the command, -or use a full path to the file. 
-Let's just imagine, you put file in export directory: - -```shell -mephisto db import --file 2024_01_01_00_00_01_mephisto_dump.zip -``` - -And you will see text like this - -``` -Are you sure? It will affect your databases and related files. Type 'yes' and press Enter if you want to proceed: yes -Just in case, we are making a backup of all your local data. If something went wrong during import, we will restore all your data from this backup -Backup was created successfully! File: '//outputs/backup/2024_04_25_17_11_56_mephisto_backup.zip' -Started importing from dump file //outputs/export/2024_04_25_17_11_43_mephisto_dump.zip ... -Finished successfully -``` - -Note that the progress will stop and will be waiting for your answer __yes__. -Also, we create a backup automatically just in case too, just before all changes. - -### Restoring from backup - -"OMG! I imported wrong dump! What have I done!" - you may cry. - -No worries, just restore everything from your or our backup: - -```shell -mephisto db restore --file 2024_01_01_00_10_01_mephisto_backup.zip -``` - -And you will see text like this - -``` -Are you sure? It will affect your databases and related files. Type 'yes' and press Enter if you want to proceed: yes -Started restoring from backup //outputs/backup/2024_01_01_00_10_01_mephisto_backup.zip ... -Finished successfully -``` - -Note that the progress will stop and will be waiting for your answer __yes__. - -### Conclusion - -Now, after you merged your two projects, you can easily start -[reviewing your tasks](/docs/guides/how_to_use/review_app/overview/). diff --git a/mephisto/abstractions/databases/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py b/mephisto/abstractions/databases/migrations/_001_20240325_data_porter_feature.py similarity index 99% rename from mephisto/abstractions/databases/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py rename to mephisto/abstractions/databases/migrations/_001_20240325_data_porter_feature.py index 20061a12b..3261f0de1 100644 --- a/mephisto/abstractions/databases/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py +++ b/mephisto/abstractions/databases/migrations/_001_20240325_data_porter_feature.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. """ +List of changes: 1. Rename `unit_review.created_at` -> `unit_review.creation_date` 2. Remove autoincrement parameter for all Primary Keys 3. Add missed Foreign Keys in `agents` table @@ -13,7 +14,7 @@ """ -PREPARING_DB_FOR_MERGE_DBS_COMMAND = """ +MODIFICATIONS_FOR_DATA_PORTER = """ ALTER TABLE unit_review RENAME COLUMN created_at TO creation_date; /* Disable FK constraints */ diff --git a/mephisto/abstractions/databases/migrations/__init__.py b/mephisto/abstractions/databases/migrations/__init__.py index 092965e1b..5f1d516d6 100644 --- a/mephisto/abstractions/databases/migrations/__init__.py +++ b/mephisto/abstractions/databases/migrations/__init__.py @@ -4,9 +4,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from ._001_20240325_preparing_db_for_merge_dbs_command import * +from ._001_20240325_data_porter_feature import * migrations = { - "20240418_preparing_db_for_merge_dbs_command": PREPARING_DB_FOR_MERGE_DBS_COMMAND, + "20240418_data_porter_feature": MODIFICATIONS_FOR_DATA_PORTER, } diff --git a/mephisto/abstractions/providers/mturk/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py b/mephisto/abstractions/providers/mturk/migrations/_001_20240325_data_porter_feature.py similarity index 94% rename from mephisto/abstractions/providers/mturk/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py rename to mephisto/abstractions/providers/mturk/migrations/_001_20240325_data_porter_feature.py index d38dd9d40..8c4517b47 100644 --- a/mephisto/abstractions/providers/mturk/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py +++ b/mephisto/abstractions/providers/mturk/migrations/_001_20240325_data_porter_feature.py @@ -5,11 +5,12 @@ # LICENSE file in the root directory of this source tree. """ -1. Modified default value for `creation_date` +List of changes: +1. Modify default value for `creation_date` """ -PREPARING_DB_FOR_MERGE_DBS_COMMAND = """ +MODIFICATIONS_FOR_DATA_PORTER = """ /* Disable FK constraints */ PRAGMA foreign_keys = off; @@ -36,8 +37,8 @@ INSERT INTO _run_mappings SELECT * FROM run_mappings; DROP TABLE run_mappings; ALTER TABLE _run_mappings RENAME TO run_mappings; - - + + /* Runs */ CREATE TABLE IF NOT EXISTS _runs ( run_id TEXT PRIMARY KEY UNIQUE, @@ -50,8 +51,8 @@ INSERT INTO _runs SELECT * FROM runs; DROP TABLE runs; ALTER TABLE _runs RENAME TO runs; - - + + /* Qualifications */ CREATE TABLE IF NOT EXISTS _qualifications ( qualification_name TEXT PRIMARY KEY UNIQUE, diff --git a/mephisto/abstractions/providers/mturk/migrations/__init__.py b/mephisto/abstractions/providers/mturk/migrations/__init__.py index 092965e1b..5f1d516d6 100644 --- a/mephisto/abstractions/providers/mturk/migrations/__init__.py +++ b/mephisto/abstractions/providers/mturk/migrations/__init__.py @@ -4,9 +4,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from ._001_20240325_preparing_db_for_merge_dbs_command import * +from ._001_20240325_data_porter_feature import * migrations = { - "20240418_preparing_db_for_merge_dbs_command": PREPARING_DB_FOR_MERGE_DBS_COMMAND, + "20240418_data_porter_feature": MODIFICATIONS_FOR_DATA_PORTER, } diff --git a/mephisto/abstractions/providers/prolific/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py b/mephisto/abstractions/providers/prolific/migrations/_001_20240325_data_porter_feature.py similarity index 96% rename from mephisto/abstractions/providers/prolific/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py rename to mephisto/abstractions/providers/prolific/migrations/_001_20240325_data_porter_feature.py index f792c223e..8c70e92e4 100644 --- a/mephisto/abstractions/providers/prolific/migrations/_001_20240325_preparing_db_for_merge_dbs_command.py +++ b/mephisto/abstractions/providers/prolific/migrations/_001_20240325_data_porter_feature.py @@ -5,16 +5,17 @@ # LICENSE file in the root directory of this source tree. """ +List of changes: 1. Remove autoincrement parameter for all Primary Keys -2. Added `update_date` and `creation_date` in `workers` table -3. Added `creation_date` in `units` table +2. Add `update_date` and `creation_date` in `workers` table +3. Add `creation_date` in `units` table 4. Rename field `run_id` -> `task_run_id` 5. 
Remove table `requesters` -6. Modified default value for `creation_date` +6. Modify default value for `creation_date` """ -PREPARING_DB_FOR_MERGE_DBS_COMMAND = """ +MODIFICATIONS_FOR_DATA_PORTER = """ /* Disable FK constraints */ PRAGMA foreign_keys = off; diff --git a/mephisto/abstractions/providers/prolific/migrations/__init__.py b/mephisto/abstractions/providers/prolific/migrations/__init__.py index 092965e1b..5f1d516d6 100644 --- a/mephisto/abstractions/providers/prolific/migrations/__init__.py +++ b/mephisto/abstractions/providers/prolific/migrations/__init__.py @@ -4,9 +4,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from ._001_20240325_preparing_db_for_merge_dbs_command import * +from ._001_20240325_data_porter_feature import * migrations = { - "20240418_preparing_db_for_merge_dbs_command": PREPARING_DB_FOR_MERGE_DBS_COMMAND, + "20240418_data_porter_feature": MODIFICATIONS_FOR_DATA_PORTER, } diff --git a/mephisto/client/cli_db_commands.py b/mephisto/client/cli_db_commands.py index 17790ec14..e782ab9b7 100644 --- a/mephisto/client/cli_db_commands.py +++ b/mephisto/client/cli_db_commands.py @@ -9,7 +9,6 @@ import click from rich_click import RichCommand -from rich_click import RichContext from mephisto.tools.db_data_porter import DBDataPorter from mephisto.tools.db_data_porter.constants import DEFAULT_CONFLICT_RESOLVER @@ -22,7 +21,7 @@ logger = ConsoleWriter() -def _print_used_options_for_running_command_message(ctx: RichContext, options: dict): +def _print_used_options_for_running_command_message(ctx: click.Context, options: dict): message = "Running command with the following options:\n" for p in ctx.command.params: values = options[p.name] @@ -112,7 +111,7 @@ def db_cli(): ), ) @click.option("-v", "--verbosity", type=int, default=VERBOSITY_DEFAULT_VALUE, help=VERBOSITY_HELP) -def export(ctx: RichContext, **options: dict): +def export(ctx: click.Context, **options: dict): """ This command exports data from Mephisto DB and provider-specific datastores as an archived combination of (i) a JSON file, and (ii) a `data` catalog with related files. @@ -238,7 +237,7 @@ def export(ctx: RichContext, **options: dict): help="write data from `imported_data` table of the dump (by default it's not imported)", ) @click.option("-v", "--verbosity", type=int, default=VERBOSITY_DEFAULT_VALUE, help=VERBOSITY_HELP) -def _import(ctx: RichContext, **options: dict): +def _import(ctx: click.Context, **options: dict): """ This command imports data from a dump file created by `mephisto db export` command. @@ -271,7 +270,7 @@ def _import(ctx: RichContext, **options: dict): @db_cli.command("backup", cls=RichCommand) @click.pass_context @click.option("-v", "--verbosity", type=int, default=VERBOSITY_DEFAULT_VALUE, help=VERBOSITY_HELP) -def backup(ctx: RichContext, **options: dict): +def backup(ctx: click.Context, **options: dict): """ Creates full backup of all current data (Mephisto DB, provider-specific datastores, and related files) on local machine. @@ -301,7 +300,7 @@ def backup(ctx: RichContext, **options: dict): ), ) @click.option("-v", "--verbosity", type=int, default=VERBOSITY_DEFAULT_VALUE, help=VERBOSITY_HELP) -def restore(ctx: RichContext, **options): +def restore(ctx: click.Context, **options): """ Restores all data (Mephisto DB, provider-specific datastores, and related files) from a backup archive. 
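
Taken together, `export`, `import`, `backup`, and `restore` form the porter's CLI surface registered above. For quick reference, a typical round trip looks like the following sketch (command names come from the definitions above; the archive file names are placeholders, and `import`/`restore` pause for a __yes__ confirmation, with `import` also creating an automatic backup first):

```shell
# Dump Mephisto DB and provider-specific datastores into outputs/export
mephisto db export

# Load a previously created dump (a safety backup is made first)
mephisto db import --file 2024_01_01_00_00_01_mephisto_dump.zip

# Snapshot all local data into outputs/backup
mephisto db backup

# Roll everything back from a backup archive
mephisto db restore --file 2024_01_01_00_10_01_mephisto_backup.zip
```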
diff --git a/mephisto/data_model/task_run.py b/mephisto/data_model/task_run.py index 73a37d74f..d71974e7b 100644 --- a/mephisto/data_model/task_run.py +++ b/mephisto/data_model/task_run.py @@ -8,6 +8,8 @@ import os import json from dataclasses import dataclass, field +from datetime import datetime +from dateutil.parser import parse from mephisto.data_model.requester import Requester from mephisto.data_model.constants.assignment_state import AssignmentState @@ -202,6 +204,7 @@ def __init__( self.task_type: str = row["task_type"] self.sandbox: bool = row["sandbox"] self.assignments_generator_done: bool = False + self.creation_date: Optional[datetime] = parse(row["creation_date"]) # properties with deferred loading self.__is_completed = row["is_completed"] diff --git a/mephisto/scripts/local_db/review_tips_for_task.py b/mephisto/scripts/local_db/review_tips_for_task.py index 571269be3..8c4571ff0 100644 --- a/mephisto/scripts/local_db/review_tips_for_task.py +++ b/mephisto/scripts/local_db/review_tips_for_task.py @@ -117,6 +117,7 @@ def main(): if len(units) == 0: print("[red]No units were received[/red]") quit() + for unit in units: if unit.agent_id is not None: unit_data = mephisto_data_browser.get_data_from_unit(unit) @@ -125,15 +126,18 @@ def main(): if tips is not None and len(tips) > 0: tips_copy = tips.copy() for i in range(len(tips)): - if tips[i]["accepted"] == False: + if tips[i]["accepted"] is False: + title = ( + "\nTip {current_tip} of {total_number_of_tips} from Agent {agent_id}" + ).format( + current_tip=i + 1, + total_number_of_tips=len(tips), + agent_id=unit.agent_id, + ) current_tip_table = Table( "Property", "Value", - title="\nTip {current_tip} of {total_number_of_tips} From Agent {agent_id}".format( - current_tip=i + 1, - total_number_of_tips=len(tips), - agent_id=unit.agent_id, - ), + title=title, box=box.ROUNDED, expand=True, show_lines=True, @@ -157,17 +161,20 @@ def main(): print("[green]Tip Accepted[/green]") # given the option to pay a bonus to the worker who wrote the tip bonus = FloatPrompt.ask( - "\nHow much would you like to bonus the tip submitter? (Default: 0.0)", + "\nHow much would you like to bonus the tip submitter? " + "(Default: 0.0)", show_default=False, default=0.0, ) if bonus > 0: reason = Prompt.ask( - "\nWhat reason would you like to give the worker for this tip? NOTE: This will be shared with the worker.(Default: Thank you for submitting a tip!)", + "\nWhat reason would you like to give the worker for this tip? " + "NOTE: This will be shared with the worker. 
" + "(Default: Thank you for submitting a tip!)", default="Thank you for submitting a tip!", show_default=False, ) - worker_id = float(unit_data["worker_id"]) + worker_id = unit_data["worker_id"] worker = Worker.get(db, worker_id) if worker is not None: bonus_successfully_paid = worker.bonus_worker( @@ -177,7 +184,9 @@ def main(): print("\n[green]Bonus Successfully Paid![/green]\n") else: print( - "\n[red]There was an error when paying out your bonus[/red]\n" + "\n[red]" + "There was an error when paying out your bonus" + "[/red]\n" ) elif tip_response == TipsReviewType.REJECTED.value: diff --git a/mephisto/tools/db_data_porter/backups.py b/mephisto/tools/db_data_porter/backups.py index a77e15a50..f6ccd4a3b 100644 --- a/mephisto/tools/db_data_porter/backups.py +++ b/mephisto/tools/db_data_porter/backups.py @@ -45,5 +45,5 @@ def restore_from_backup( if remove_backup: Path(backup_file_path).unlink(missing_ok=True) except Exception as e: - logger.exception(f"[red]Could not restore backup '{backup_file_path}'. Error: {e}[/red]") + logger.exception(f"[red]Could not restore backup {backup_file_path}. Error: {e}[/red]") exit() diff --git a/mephisto/tools/db_data_porter/db_data_porter.py b/mephisto/tools/db_data_porter/db_data_porter.py index bc916a820..c195e90d8 100644 --- a/mephisto/tools/db_data_porter/db_data_porter.py +++ b/mephisto/tools/db_data_porter/db_data_porter.py @@ -19,11 +19,12 @@ from mephisto.abstractions.databases.local_database import LocalMephistoDB from mephisto.generators.form_composer.config_validation.utils import make_error_message from mephisto.tools.db_data_porter import backups -from mephisto.tools.db_data_porter import export_dump from mephisto.tools.db_data_porter import dumps +from mephisto.tools.db_data_porter import export_dump from mephisto.tools.db_data_porter import import_dump from mephisto.tools.db_data_porter.constants import BACKUP_OUTPUT_DIR from mephisto.tools.db_data_porter.constants import DEFAULT_ARCHIVE_FORMAT +from mephisto.tools.db_data_porter.constants import DEFAULT_CONFLICT_RESOLVER from mephisto.tools.db_data_porter.constants import EXPORT_OUTPUT_DIR from mephisto.tools.db_data_porter.constants import IMPORTED_DATA_TABLE_NAME from mephisto.tools.db_data_porter.constants import MEPHISTO_DUMP_KEY @@ -301,7 +302,7 @@ def export_dump( def import_dump( self, dump_archive_file_name_or_path: str, - conflict_resolver_name: str, + conflict_resolver_name: Optional[str] = DEFAULT_CONFLICT_RESOLVER, labels: Optional[List[str]] = None, keep_import_metadata: Optional[bool] = None, verbosity: int = 0, diff --git a/mephisto/tools/db_data_porter/dumps.py b/mephisto/tools/db_data_porter/dumps.py index dfb5952ba..735f82ad5 100644 --- a/mephisto/tools/db_data_porter/dumps.py +++ b/mephisto/tools/db_data_porter/dumps.py @@ -100,7 +100,7 @@ def prepare_partial_dump_data( task_ids = task_ids or [] # Get TaskRun IDs by Task IDs - task_run_ids = db_utils.get_task_run_ids_ids_by_task_ids(db, task_ids) + task_run_ids = db_utils.get_task_run_ids_by_task_ids(db, task_ids) elif task_runs_labels: # Validate on correct values of passed TaskRun labels db_labels = db_utils.get_list_of_available_labels(db) @@ -117,7 +117,7 @@ def prepare_partial_dump_data( exit() # Get TaskRun IDs - task_run_ids = db_utils.get_task_run_ids_ids_by_labels(db, task_runs_labels) + task_run_ids = db_utils.get_task_run_ids_by_labels(db, task_runs_labels) elif since_datetime: # Get TaskRun IDs task_run_ids = db_utils.select_task_run_ids_since_date(db, since_datetime) diff --git 
a/mephisto/tools/db_data_porter/export_dump.py b/mephisto/tools/db_data_porter/export_dump.py index 98e030904..cadc8c404 100644 --- a/mephisto/tools/db_data_porter/export_dump.py +++ b/mephisto/tools/db_data_porter/export_dump.py @@ -12,7 +12,7 @@ from typing import Dict from typing import List -from rich_click import RichContext +import click from mephisto.data_model.task_run import TaskRun from mephisto.tools.db_data_porter.constants import AGENTS_TABLE_NAME @@ -215,7 +215,7 @@ def unarchive_data_files( logger.debug("Removing unpacked TaskRuns files finished") -def get_export_options_for_metadata(ctx: RichContext, options: dict) -> Dict[str, Any]: +def get_export_options_for_metadata(ctx: click.Context, options: dict) -> Dict[str, Any]: export_options_for_metadata = {} for param in ctx.command.params: diff --git a/mephisto/tools/db_data_porter/import_dump.py b/mephisto/tools/db_data_porter/import_dump.py index d2771f920..33b001fab 100644 --- a/mephisto/tools/db_data_porter/import_dump.py +++ b/mephisto/tools/db_data_porter/import_dump.py @@ -44,7 +44,7 @@ def _update_row_with_pks_from_resolvings_mappings( row: dict, resolvings_mapping: MappingResolvingsType, ) -> dict: - table_fks = db_utils.select_fk_mappings_for_table(db, table_name) + table_fks = db_utils.select_fk_mappings_for_single_table(db, table_name) # Update FK fields from resolving mappings if needed for fk_table, fk_table_fields in table_fks.items(): diff --git a/mephisto/tools/db_data_porter/randomize_ids.py b/mephisto/tools/db_data_porter/randomize_ids.py index fbb48fbca..c8c942c0e 100644 --- a/mephisto/tools/db_data_porter/randomize_ids.py +++ b/mephisto/tools/db_data_porter/randomize_ids.py @@ -37,7 +37,7 @@ def _randomize_ids_for_mephisto( table_names = [t for t in mephisto_dump.keys() if t not in [IMPORTED_DATA_TABLE_NAME]] # Find Foreign Keys' field names for all tables in Mephist DB - tables_fks = db_utils.select_fk_mappings_for_all_tables(db, table_names) + tables_fks = db_utils.select_fk_mappings_for_tables(db, table_names) # Make new Primary Keys for all or legacy values mephisto_pk_substitutions = {} diff --git a/mephisto/utils/db.py b/mephisto/utils/db.py index 3828edbff..bd856bd36 100644 --- a/mephisto/utils/db.py +++ b/mephisto/utils/db.py @@ -53,14 +53,6 @@ class EntryDoesNotExistException(MephistoDBException): # --- Functions --- -def _select_all_rows_from_table(db: "MephistoDB", table_name: str) -> List[dict]: - with db.table_access_condition, db.get_connection() as conn: - c = conn.cursor() - c.execute(f"SELECT * FROM {table_name};") - rows = c.fetchall() - return [dict(row) for row in rows] - - def _select_rows_from_table_related_to_task( db: "MephistoDB", table_name: str, @@ -121,7 +113,7 @@ def get_task_ids_by_task_names(db: "MephistoDB", task_names: List[str]) -> List[ return [r["task_id"] for r in rows] -def get_task_run_ids_ids_by_task_ids(db: "MephistoDB", task_ids: List[str]) -> List[str]: +def get_task_run_ids_by_task_ids(db: "MephistoDB", task_ids: List[str]) -> List[str]: with db.table_access_condition, db.get_connection() as conn: c = conn.cursor() task_ids_string = ",".join([f"'{s}'" for s in task_ids]) @@ -135,7 +127,7 @@ def get_task_run_ids_ids_by_task_ids(db: "MephistoDB", task_ids: List[str]) -> L return [r["task_run_id"] for r in rows] -def get_task_run_ids_ids_by_labels(db: "MephistoDB", labels: List[str]) -> List[str]: +def get_task_run_ids_by_labels(db: "MephistoDB", labels: List[str]) -> List[str]: with db.table_access_condition, db.get_connection() as conn: if not labels: return [] 
@@ -174,10 +166,15 @@ def get_table_pk_field_name(db: "MephistoDB", table_name: str):
     return table_unique_field_name
 
 
-def select_all_table_rows(db: "MephistoDB", table_name: str) -> List[dict]:
+def select_all_table_rows(
+    db: "MephistoDB",
+    table_name: str,
+    order_by: Optional[str] = None,
+) -> List[dict]:
+    order_by_string = f" ORDER BY {order_by}" if order_by else ""
     with db.table_access_condition, db.get_connection() as conn:
         c = conn.cursor()
-        c.execute(f"SELECT * FROM {table_name};")
+        c.execute(f"SELECT * FROM {table_name}{order_by_string};")
         rows = c.fetchall()
         return [dict(row) for row in rows]
 
@@ -428,7 +425,7 @@ def db_or_datastore_to_dict(db: "MephistoDB") -> dict:
     dump_data = {}
     table_names = get_list_of_tables_to_export(db)
     for table_name in table_names:
-        table_rows = _select_all_rows_from_table(db, table_name)
+        table_rows = select_all_table_rows(db, table_name)
         table_data = serialize_data_for_table(table_rows)
         dump_data[table_name] = table_data
 
@@ -547,7 +544,7 @@ def select_task_run_ids_since_date(db: "MephistoDB", since: datetime) -> List[st
     return task_run_ids_since
 
 
-def select_fk_mappings_for_table(db: "MephistoDB", table_name: str) -> dict:
+def select_fk_mappings_for_single_table(db: "MephistoDB", table_name: str) -> dict:
     with db.table_access_condition, db.get_connection() as conn:
         c = conn.cursor()
         c.execute(f"SELECT * FROM pragma_foreign_key_list('{table_name}');")
@@ -567,10 +564,10 @@ def select_fk_mappings_for_table(db: "MephistoDB", table_name: str) -> dict:
     return table_fks
 
 
-def select_fk_mappings_for_all_tables(db: "MephistoDB", table_names: List[str]) -> dict:
+def select_fk_mappings_for_tables(db: "MephistoDB", table_names: List[str]) -> dict:
     tables_fks = {}
     for table_name in table_names:
-        table_fks = select_fk_mappings_for_table(db, table_name)
+        table_fks = select_fk_mappings_for_single_table(db, table_name)
         tables_fks.update({table_name: table_fks})
 
     return tables_fks
diff --git a/mephisto/utils/dirs.py b/mephisto/utils/dirs.py
index 0b9f4a936..def39a46f 100644
--- a/mephisto/utils/dirs.py
+++ b/mephisto/utils/dirs.py
@@ -72,9 +72,10 @@ def get_root_data_dir() -> str:
     actual_data_dir = get_config_arg(CORE_SECTION, DATA_STORAGE_KEY)
     if actual_data_dir is None:
         data_dir_location = input(
-            "Please enter the full path to a location to store Mephisto run data. By default this "
-            f"would be at '{default_data_dir}'. This dir should NOT be on a distributed file "
-            "store. Press enter to use the default: "
+            "Please enter the full path to a location to store Mephisto run data. "
+            f"By default this would be at '{default_data_dir}'. "
+            "This dir should NOT be on a distributed file store. "
+            "Press enter to use the default: "
         ).strip()
         if len(data_dir_location) == 0:
             data_dir_location = default_data_dir
@@ -85,8 +86,9 @@ def get_root_data_dir() -> str:
     if os.path.exists(database_loc) and data_dir_location != default_data_dir:
         should_migrate = (
             input(
-                "We have found an existing database in the default data directory, do you want to "
-                f"copy any existing data from the default location to {data_dir_location}? (y)es/no: "
+                f"We have found an existing database in the default data directory, "
+                f"do you want to copy any existing data from the default location to "
+                f"{data_dir_location}? (y)es/no: "
             )
             .lower()
             .strip()
@@ -94,8 +96,8 @@
         if len(should_migrate) == 0 or should_migrate[0] == "y":
             copy_tree(default_data_dir, data_dir_location)
             print(
-                "Mephisto data successfully copied, once you've confirmed the migration worked, "
-                "feel free to remove all of the contents in "
+                "Mephisto data successfully copied. Once you've confirmed "
+                "the migration worked, feel free to remove all of the contents in "
                 f"{default_data_dir} EXCEPT for `README.md`."
             )
             add_config_arg(CORE_SECTION, DATA_STORAGE_KEY, data_dir_location)
diff --git a/mephisto/utils/testing.py b/mephisto/utils/testing.py
index 694ed3b20..37ee2b73b 100644
--- a/mephisto/utils/testing.py
+++ b/mephisto/utils/testing.py
@@ -46,25 +46,29 @@
 )
 
 
-def get_test_project(db: MephistoDB) -> Tuple[str, str]:
+def get_test_project(db: MephistoDB, project_name: Optional[str] = None) -> Tuple[str, str]:
     """Helper to create a project for tests"""
-    project_name = "test_project"
+    project_name = project_name or "test_project"
     project_id = db.new_project(project_name)
     return project_name, project_id
 
 
-def get_test_worker(db: MephistoDB) -> Tuple[str, str]:
+def get_test_worker(db: MephistoDB, worker_name: Optional[str] = None) -> Tuple[str, str]:
     """Helper to create a worker for tests"""
-    worker_name = "test_worker"
+    worker_name = worker_name or "test_worker"
     provider_type = "mock"
     worker_id = db.new_worker(worker_name, provider_type)
     return worker_name, worker_id
 
 
-def get_test_requester(db: MephistoDB) -> Tuple[str, str]:
+def get_test_requester(
+    db: MephistoDB,
+    requester_name: Optional[str] = None,
+    provider_type: Optional[str] = None,
+) -> Tuple[str, str]:
     """Helper to create a requester for tests"""
-    requester_name = "test_requester"
-    provider_type = "mock"
+    requester_name = requester_name or "test_requester"
+    provider_type = provider_type or "mock"
     requester_id = db.new_requester(requester_name, provider_type)
     return requester_name, requester_id
 
@@ -78,18 +82,26 @@ def get_mock_requester(db) -> "Requester":
     return mock_requesters[0]
 
 
-def get_test_task(db: MephistoDB) -> Tuple[str, str]:
+def get_test_task(db: MephistoDB, task_name: Optional[str] = None) -> Tuple[str, str]:
     """Helper to create a task for tests"""
-    task_name = "test_task"
+    task_name = task_name or "test_task"
     task_type = "mock"
     task_id = db.new_task(task_name, task_type)
     return task_name, task_id
 
 
-def get_test_task_run(db: MephistoDB) -> str:
+def get_test_task_run(
+    db: MephistoDB,
+    task_id: Optional[str] = None,
+    requester_id: Optional[str] = None,
+) -> str:
     """Helper to create a task run for tests"""
-    task_name, task_id = get_test_task(db)
-    requester_name, requester_id = get_test_requester(db)
+    if not task_id:
+        _, task_id = get_test_task(db)
+
+    if not requester_id:
+        _, requester_id = get_test_requester(db)
+
     init_params = OmegaConf.to_yaml(OmegaConf.structured(MOCK_CONFIG))
     return db.new_task_run(task_id, requester_id, json.dumps(init_params), "mock", "mock")
 
@@ -191,7 +203,7 @@ def make_completed_unit(db: MephistoDB) -> str:
     return unit.db_id
 
 
-def get_test_qualification(db: MephistoDB, name="test_qualification") -> str:
+def get_test_qualification(db: MephistoDB, name: str = "test_qualification") -> str:
     return db.make_qualification(name)
 
 
diff --git a/pytest.ini b/pytest.ini
index 112e1c6a8..63b774de0 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -9,5 +9,7 @@ addopts = -ra -q -s
 markers =
     req_creds: test which requires credentials
    prolific: Prolific tests
+    utils: Mephisto utils
+    db_data_porter: DB Data
Porter tool
 testpaths = test
diff --git a/test/tools/db_data_porter/__init__.py b/test/tools/db_data_porter/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/tools/db_data_porter/conflict_resolvers/__init__.py b/test/tools/db_data_porter/conflict_resolvers/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/tools/db_data_porter/conflict_resolvers/test_default_merge_conflict_resolver.py b/test/tools/db_data_porter/conflict_resolvers/test_default_merge_conflict_resolver.py
new file mode 100644
index 000000000..09aaff52a
--- /dev/null
+++ b/test/tools/db_data_porter/conflict_resolvers/test_default_merge_conflict_resolver.py
@@ -0,0 +1,132 @@
+#!/usr/bin/env python3

+# Copyright (c) Meta Platforms and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import shutil
+import tempfile
+import unittest
+from copy import deepcopy
+from datetime import datetime
+from typing import ClassVar
+from typing import Type
+
+import pytest
+
+from mephisto.abstractions.database import MephistoDB
+from mephisto.abstractions.databases.local_database import LocalMephistoDB
+from mephisto.tools.db_data_porter.conflict_resolvers.default_merge_conflict_resolver import (
+    DefaultMergeConflictResolver,
+)
+
+
+@pytest.mark.db_data_porter
+class TestDefaultMergeConflictResolver(unittest.TestCase):
+    DB_CLASS: ClassVar[Type["MephistoDB"]] = LocalMephistoDB
+
+    def setUp(self):
+        # Configure test database
+        self.data_dir = tempfile.mkdtemp()
+        database_path = os.path.join(self.data_dir, "test_mephisto.db")
+
+        assert self.DB_CLASS is not None, "Did not specify db to use"
+        self.db = self.DB_CLASS(database_path)
+
+        # Init conflict resolver instance
+        self.conflict_resolver = DefaultMergeConflictResolver(self.db, "mephisto")
+
+    def tearDown(self):
+        # Clean test database
+        self.db.shutdown()
+        shutil.rmtree(self.data_dir, ignore_errors=True)
+
+    def test_resolve_with_default(self, *args):
+        db_row = {
+            "project_id": 1,
+            "project_name": "test_project_name",
+            "creation_date": "2001-01-01 01:01:01.001",
+        }
+        dump_row = {
+            "project_id": 2,
+            "project_name": "test_project_name",
+            "creation_date": "1999-01-01 01:01:01.001",
+        }
+        expecting_result = deepcopy(db_row)
+        # Earlier date of the two
+        expecting_result["creation_date"] = datetime(1999, 1, 1, 1, 1, 1, 1000)
+
+        result = self.conflict_resolver.resolve(
+            table_name="project",
+            table_pk_field_name="project_id",
+            db_row=db_row,
+            dump_row=dump_row,
+        )
+
+        self.assertEqual(result, expecting_result)
+
+    def test_resolve_with_granted_qualifications(self, *args):
+        db_row = {
+            "granted_qualification_id": 1,
+            "qualification_id": 1,
+            "worker_id": 1,
+            "value": 999,
+            "creation_date": "2001-01-01 01:01:01.001",
+            "update_date": "1999-01-01 01:01:01.001",
+        }
+        dump_row = {
+            "granted_qualification_id": 2,
+            "qualification_id": 1,
+            "worker_id": 1,
+            "value": 1,
+            "creation_date": "1999-01-01 01:01:01.001",
+            "update_date": "2001-01-01 01:01:01.001",
+        }
+        expecting_result = deepcopy(dump_row)
+        # Original id
+        expecting_result["granted_qualification_id"] = db_row["granted_qualification_id"]
+        # Earlier date of the two
+        expecting_result["creation_date"] = datetime(1999, 1, 1, 1, 1, 1, 1000)
+        # Later date of the two
+        expecting_result["update_date"] = datetime(2001, 1, 1, 1, 1, 1, 1000)
+
+        result = self.conflict_resolver.resolve(
+            table_name="granted_qualifications",
+            table_pk_field_name="granted_qualification_id",
+            db_row=db_row,
+            dump_row=dump_row,
+        )
+
+        self.assertEqual(result, expecting_result)
+
+    def test_resolve_with_workers(self, *args):
+        db_row = {
+            "worker_id": 1,
+            "worker_name": "test_worker_name",
+            "is_blocked": 0,  # False
+            "creation_date": "2001-01-01 01:01:01.001",
+        }
+        dump_row = {
+            "worker_id": 2,
+            "worker_name": "test_worker_name",
+            "is_blocked": 1,  # True
+            "creation_date": "1999-01-01 01:01:01.001",
+        }
+        expecting_result = deepcopy(dump_row)
+        # Original id
+        expecting_result["worker_id"] = db_row["worker_id"]
+        # Blocked one
+        expecting_result["is_blocked"] = dump_row["is_blocked"]
+        # Earlier date of the two
+        expecting_result["creation_date"] = datetime(1999, 1, 1, 1, 1, 1, 1000)
+
+        # Simulate Prolific datastore
+        result = DefaultMergeConflictResolver(self.db, "prolific").resolve(
+            table_name="workers",
+            table_pk_field_name="worker_id",
+            db_row=db_row,
+            dump_row=dump_row,
+        )
+
+        self.assertEqual(result, expecting_result)
diff --git a/test/tools/db_data_porter/test_backups.py b/test/tools/db_data_porter/test_backups.py
new file mode 100644
index 000000000..f74ec70a2
--- /dev/null
+++ b/test/tools/db_data_porter/test_backups.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Meta Platforms and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import io
+import os
+import shutil
+import sys
+import tempfile
+import unittest
+from unittest.mock import patch
+
+import pytest
+
+from mephisto.tools.db_data_porter.backups import make_backup_file_path_by_timestamp
+from mephisto.tools.db_data_porter.backups import make_full_data_dir_backup
+from mephisto.tools.db_data_porter.backups import restore_from_backup
+
+
+@pytest.mark.db_data_porter
+class TestBackups(unittest.TestCase):
+    def setUp(self):
+        self.data_dir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        shutil.rmtree(self.data_dir, ignore_errors=True)
+
+    def test_make_backup_file_path_by_timestamp(self, *args):
+        timestamp = "2001_01_01_01_01_01"
+        path = make_backup_file_path_by_timestamp(
+            backup_dir=self.data_dir,
+            timestamp=timestamp,
+        )
+
+        self.assertEqual(path, os.path.join(self.data_dir, f"{timestamp}_mephisto_backup.zip"))
+
+    @patch("mephisto.tools.db_data_porter.backups.get_data_dir")
+    def test_make_full_data_dir_backup(self, mock_get_data_dir, *args):
+        mock_get_data_dir.return_value = self.data_dir
+
+        test_file_name = "test_backup.zip"
+        backup_file_path = os.path.join(self.data_dir, test_file_name)
+
+        self.assertFalse(os.path.exists(backup_file_path))
+
+        make_full_data_dir_backup(backup_file_path)
+
+        self.assertTrue(os.path.exists(backup_file_path))
+
+    @patch("mephisto.tools.db_data_porter.backups.get_data_dir")
+    def test_restore_from_backup_without_deleting_backup(self, mock_get_data_dir, *args):
+        mock_get_data_dir.return_value = self.data_dir
+
+        test_file_name = "test_backup.zip"
+        restore_dir_name = "test_restore"
+        backup_file_path = os.path.join(self.data_dir, test_file_name)
+        extract_dir = os.path.join(self.data_dir, restore_dir_name)
+
+        make_full_data_dir_backup(backup_file_path)
+
+        self.assertFalse(os.path.exists(extract_dir))
+
+        restore_from_backup(backup_file_path, extract_dir)
+
+        self.assertTrue(os.path.exists(extract_dir))
+        self.assertTrue(os.path.exists(backup_file_path))
+
+    @patch("mephisto.tools.db_data_porter.backups.get_data_dir")
+    def test_restore_from_backup_with_deleting_backup(self, mock_get_data_dir, 
*args): + mock_get_data_dir.return_value = self.data_dir + + test_file_name = "test_backup.zip" + restore_dir_name = "test_restore" + backup_file_path = os.path.join(self.data_dir, test_file_name) + extract_dir = os.path.join(self.data_dir, restore_dir_name) + + make_full_data_dir_backup(backup_file_path) + + self.assertFalse(os.path.exists(extract_dir)) + + restore_from_backup(backup_file_path, extract_dir, remove_backup=True) + + self.assertTrue(os.path.exists(extract_dir)) + self.assertFalse(os.path.exists(backup_file_path)) + + def test_restore_from_backup_error(self, *args): + test_file_name = "test_backup.zip" + restore_dir_name = "test_restore" + backup_file_path = os.path.join(self.data_dir, test_file_name) + extract_dir = os.path.join(self.data_dir, restore_dir_name) + + self.assertFalse(os.path.exists(extract_dir)) + + with self.assertRaises(SystemExit) as cm: + captured_print_output = io.StringIO() + sys.stdout = captured_print_output + restore_from_backup(backup_file_path, extract_dir, remove_backup=True) + sys.stdout = sys.__stdout__ + + self.assertEqual(cm.exception.code, None) + self.assertFalse(os.path.exists(extract_dir)) + self.assertFalse(os.path.exists(backup_file_path)) + self.assertIn("Could not restore backup", captured_print_output.getvalue()) + self.assertIn(backup_file_path, captured_print_output.getvalue()) diff --git a/test/tools/db_data_porter/test_db_data_porter.py b/test/tools/db_data_porter/test_db_data_porter.py new file mode 100644 index 000000000..6ff605792 --- /dev/null +++ b/test/tools/db_data_porter/test_db_data_porter.py @@ -0,0 +1,258 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import json +import os +import shutil +import tempfile +import unittest +import zipfile +from typing import ClassVar +from typing import Type +from unittest.mock import patch + +import pytest + +from mephisto.abstractions.database import MephistoDB +from mephisto.abstractions.databases.local_database import LocalMephistoDB +from mephisto.tools.db_data_porter import DBDataPorter +from mephisto.utils import db as db_utils +from mephisto.utils.testing import get_test_qualification +from mephisto.utils.testing import get_test_requester +from mephisto.utils.testing import get_test_task_run +from mephisto.utils.testing import get_test_worker +from mephisto.utils.testing import grant_test_qualification +from mephisto.utils.testing import make_completed_unit + +FILE_TIMESTAMP = "2001_01_01_01_01_01" + + +@pytest.mark.db_data_porter +class TestDBDataPorter(unittest.TestCase): + DB_CLASS: ClassVar[Type["MephistoDB"]] = LocalMephistoDB + + def setUp(self): + # Configure test database + self.data_dir = tempfile.mkdtemp() + database_path = os.path.join(self.data_dir, "test_mephisto.db") + + assert self.DB_CLASS is not None, "Did not specify db to use" + self.db = self.DB_CLASS(database_path) + + # Backup dir + self.backup_dir = os.path.join(self.data_dir, "backup") + os.makedirs(self.backup_dir, exist_ok=True) + + # Restore dir + self.restore_dir = os.path.join(self.data_dir, "restore") + os.makedirs(self.restore_dir, exist_ok=True) + + # Export dir + self.export_dir = os.path.join(self.data_dir, "export") + os.makedirs(self.export_dir, exist_ok=True) + + # Init db data porter instance + self.porter = DBDataPorter(self.db) + + def tearDown(self): + # Clean test database + self.db.shutdown() + shutil.rmtree(self.data_dir, ignore_errors=True) + + @patch("mephisto.tools.db_data_porter.backups.get_data_dir") + @patch("mephisto.tools.db_data_porter.db_data_porter.DBDataPorter._make_export_timestamp") + def test_create_backup(self, mock__make_export_timestamp, mock_get_data_dir, *args): + mock__make_export_timestamp.return_value = FILE_TIMESTAMP + mock_get_data_dir.return_value = self.data_dir + + with patch( + "mephisto.tools.db_data_porter.db_data_porter.DBDataPorter._get_backup_dir" + ) as mock__get_backup_dir: + mock__get_backup_dir.return_value = self.backup_dir + + files_count_before = len([fn for fn in os.listdir(self.backup_dir)]) + + self.porter.create_backup() + + backup_filenames = os.listdir(self.backup_dir) + files_count_after = len([fn for fn in backup_filenames]) + + self.assertEqual(files_count_before, 0) + self.assertEqual(files_count_after, 1) + self.assertEqual(backup_filenames[0], f"{FILE_TIMESTAMP}_mephisto_backup.zip") + + @patch("mephisto.tools.db_data_porter.backups.get_data_dir") + @patch("mephisto.tools.db_data_porter.db_data_porter.get_data_dir") + @patch("mephisto.tools.db_data_porter.db_data_porter.DBDataPorter._ask_user_if_they_are_sure") + def test_restore_from_backup( + self, + mock__ask_user_if_they_are_sure, + mock_get_data_dir, + mock_backups_get_data_dir, + *args, + ): + mock__ask_user_if_they_are_sure.return_value = True + mock_get_data_dir.return_value = self.restore_dir + mock_backups_get_data_dir.return_value = self.data_dir + + files_count_before = len([fn for fn in os.listdir(self.restore_dir)]) + backup_file_path = self.porter.create_backup() + + self.porter.restore_from_backup(backup_file_name_or_path=backup_file_path) + + files_count_after = len([fn for fn in os.listdir(self.restore_dir)]) + + self.assertEqual(files_count_before, 0) + 
self.assertEqual(files_count_after, 4)
+
+    @patch("mephisto.tools.db_data_porter.db_data_porter.DBDataPorter._make_export_timestamp")
+    @patch("mephisto.tools.db_data_porter.export_dump.get_data_dir")
+    @patch("mephisto.tools.db_data_porter.db_data_porter.DBDataPorter._get_export_dir")
+    @patch("mephisto.tools.db_data_porter.db_data_porter.DBDataPorter._ask_user_if_they_are_sure")
+    def test_export_dump_full(
+        self,
+        mock__ask_user_if_they_are_sure,
+        mock__get_export_dir,
+        mock_get_data_dir,
+        mock__make_export_timestamp,
+        *args,
+    ):
+        mock__ask_user_if_they_are_sure.return_value = True
+        mock__get_export_dir.return_value = self.export_dir
+        mock_get_data_dir.return_value = self.data_dir
+        mock__make_export_timestamp.return_value = FILE_TIMESTAMP
+
+        # Create entries in Mephisto DB
+        _, requester_id = get_test_requester(self.db)
+        task_run_id_1 = get_test_task_run(self.db, requester_id=requester_id)
+        _, worker_id = get_test_worker(self.db)
+        make_completed_unit(self.db)
+        qualification_id = get_test_qualification(self.db, "qual_1")
+        grant_test_qualification(self.db, worker_id=worker_id, qualification_id=qualification_id)
+
+        files_count_before = len([fn for fn in os.listdir(self.export_dir)])
+
+        # Create dump
+        export_results = self.porter.export_dump()
+
+        files_count_after = len([fn for fn in os.listdir(self.export_dir)])
+
+        # Test files
+        self.assertEqual(files_count_before, 0)
+        self.assertEqual(files_count_after, 1)
+        self.assertIn(f"export/{FILE_TIMESTAMP}_mephisto_dump.zip", export_results["dump_path"])
+        self.assertEqual(export_results["backup_path"], None)
+
+        # Test dump archive
+        with zipfile.ZipFile(export_results["dump_path"]) as archive:
+            dump_name = os.path.basename(os.path.splitext(export_results["dump_path"])[0])
+            json_dump_file_name = f"{dump_name}.json"
+
+            with archive.open(json_dump_file_name) as f:
+                dump_file_data = json.loads(f.read())
+
+        # Test main keys
+        self.assertIn("dump_metadata", dump_file_data)
+        self.assertIn("mephisto", dump_file_data)
+
+        # Test `dump_metadata`
+        self.assertEqual(dump_file_data["dump_metadata"]["export_options"], None)
+        self.assertEqual(
+            dump_file_data["dump_metadata"]["migrations"],
+            {"mephisto": "20240418_data_porter_feature"},
+        )
+        self.assertEqual(dump_file_data["dump_metadata"]["pk_substitutions"], {})
+        self.assertEqual(dump_file_data["dump_metadata"]["timestamp"], FILE_TIMESTAMP)
+
+        # Test `mephisto`
+        mephisto_dump = dump_file_data["mephisto"]
+
+        tables_without_task_run_id = [
+            "workers",
+            "tasks",
+            "requesters",
+            "qualifications",
+            "granted_qualifications",
+        ]
+
+        for table_name in mephisto_dump.keys():
+            if table_name == "imported_data":
+                continue
+
+            table_data = mephisto_dump[table_name]
+            if table_name in ["onboarding_agents", "unit_review", "projects"]:
+                self.assertEqual(len(table_data), 0)
+            else:
+                if table_name not in tables_without_task_run_id:
+                    self.assertEqual(table_data[0]["task_run_id"], task_run_id_1)
+                self.assertEqual(len(table_data), 1)
+
+    @patch("mephisto.tools.db_data_porter.backups.get_data_dir")
+    @patch("mephisto.tools.db_data_porter.db_data_porter.DBDataPorter._make_export_timestamp")
+    @patch("mephisto.tools.db_data_porter.export_dump.get_data_dir")
+    @patch("mephisto.tools.db_data_porter.db_data_porter.DBDataPorter._get_export_dir")
+    @patch("mephisto.tools.db_data_porter.db_data_porter.DBDataPorter._ask_user_if_they_are_sure")
+    def test_import_dump_full(
+        self,
+        mock__ask_user_if_they_are_sure,
+        mock__get_export_dir,
+        mock_get_data_dir,
+        mock__make_export_timestamp,
mock_backups_get_data_dir, + *args, + ): + mock__ask_user_if_they_are_sure.return_value = True + mock__get_export_dir.return_value = self.export_dir + mock_get_data_dir.return_value = self.data_dir + mock__make_export_timestamp.return_value = FILE_TIMESTAMP + mock_backups_get_data_dir.return_value = self.data_dir + + # Create entries in Mephisto DB + _, requester_id = get_test_requester(self.db) + task_run_id_1 = get_test_task_run(self.db, requester_id=requester_id) + _, worker_id = get_test_worker(self.db) + unit_id = make_completed_unit(self.db) + qualification_id = get_test_qualification(self.db, "qual_1") + grant_test_qualification(self.db, worker_id=worker_id, qualification_id=qualification_id) + + # Create dump + export_results = self.porter.export_dump() + dump_archive_file_path = export_results["dump_path"] + + # Clear db + db_utils.delete_entire_exported_data(self.db) + + # Test clear database + table_names = db_utils.get_list_of_tables_to_export(self.db) + for table_name in table_names: + rows = db_utils.select_all_table_rows(self.db, table_name) + self.assertEqual(len(rows), 0) + + # Import dump + results = self.porter.import_dump(dump_archive_file_name_or_path=dump_archive_file_path) + + # Test imported data into clear database + self.assertEqual(results["imported_task_runs_number"], 1) + table_names = db_utils.get_list_of_tables_to_export(self.db) + for table_name in table_names: + rows = db_utils.select_all_table_rows(self.db, table_name) + if table_name == "imported_data": + self.assertEqual(len(rows), 6) + elif table_name in ["projects", "onboarding_agents", "unit_review"]: + self.assertEqual(len(rows), 0) + else: + self.assertEqual(len(rows), 1) + + if table_name == "task_runs": + self.assertEqual(rows[0]["task_run_id"], task_run_id_1) + elif table_name == "requesters": + self.assertEqual(rows[0]["requester_id"], requester_id) + elif table_name == "workers": + self.assertEqual(rows[0]["worker_id"], worker_id) + elif table_name == "qualifications": + self.assertEqual(rows[0]["qualification_id"], qualification_id) + elif table_name == "units": + self.assertEqual(rows[0]["unit_id"], unit_id) diff --git a/test/utils/test_db.py b/test/utils/test_db.py new file mode 100644 index 000000000..0630b2973 --- /dev/null +++ b/test/utils/test_db.py @@ -0,0 +1,726 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os +import shutil +import sqlite3 +import tempfile +import unittest +from datetime import timedelta +from typing import ClassVar +from typing import Type +from unittest.mock import patch + +import pytest + +from mephisto.abstractions.database import MephistoDB +from mephisto.abstractions.databases.local_database import is_unique_failure +from mephisto.abstractions.databases.local_database import LocalMephistoDB +from mephisto.abstractions.providers.mock.mock_datastore import MockDatastore +from mephisto.abstractions.providers.prolific.prolific_datastore import ProlificDatastore +from mephisto.data_model.requester import Requester +from mephisto.data_model.task import Task +from mephisto.data_model.task_run import TaskRun +from mephisto.utils import db as db_utils +from mephisto.utils.db import EntryAlreadyExistsException +from mephisto.utils.misc import serialize_date_to_python +from mephisto.utils.testing import get_test_assignment +from mephisto.utils.testing import get_test_project +from mephisto.utils.testing import get_test_qualification +from mephisto.utils.testing import get_test_requester +from mephisto.utils.testing import get_test_task +from mephisto.utils.testing import get_test_task_run +from mephisto.utils.testing import get_test_unit +from mephisto.utils.testing import get_test_worker +from mephisto.utils.testing import grant_test_qualification +from mephisto.utils.testing import make_completed_unit + + +@pytest.mark.utils +class TestUtilsDB(unittest.TestCase): + DB_CLASS: ClassVar[Type["MephistoDB"]] = LocalMephistoDB + + def setUp(self): + # Configure test database + self.data_dir = tempfile.mkdtemp() + database_path = os.path.join(self.data_dir, "test_mephisto.db") + assert self.DB_CLASS is not None, "Did not specify db to use" + self.db = self.DB_CLASS(database_path) + + def tearDown(self): + # Clean test database + self.db.shutdown() + shutil.rmtree(self.data_dir, ignore_errors=True) + + def test__select_rows_from_table_related_to_task(self, *args): + task_1_name = "task_1" + task_2_name = "task_2" + + _, requester_id = get_test_requester(self.db) + _, task_1_id = get_test_task(self.db, task_1_name) + _, task_2_id = get_test_task(self.db, task_2_name) + get_test_task_run(self.db, task_1_id, requester_id) + get_test_task_run(self.db, task_2_id, requester_id) + get_test_task_run(self.db, task_2_id, requester_id) + + rows_for_task_1 = db_utils._select_rows_from_table_related_to_task( + self.db, + "task_runs", + [task_1_id], + ) + rows_for_task_2 = db_utils._select_rows_from_table_related_to_task( + self.db, + "task_runs", + [task_2_id], + ) + + self.assertEqual(len(rows_for_task_1), 1) + self.assertEqual(len(rows_for_task_2), 2) + + def test_select_rows_from_table_related_to_task_run(self, *args): + task_1_name = "task_1" + task_2_name = "task_2" + + _, requester_id = get_test_requester(self.db) + _, task_1_id = get_test_task(self.db, task_1_name) + _, task_2_id = get_test_task(self.db, task_2_name) + task_run_1_id = get_test_task_run(self.db, task_1_id, requester_id) + task_run_2_id = get_test_task_run(self.db, task_2_id, requester_id) + + task_run_1 = TaskRun.get(self.db, task_run_1_id) + task_run_2 = TaskRun.get(self.db, task_run_2_id) + get_test_assignment(self.db, task_run_1) + get_test_assignment(self.db, task_run_1) + get_test_assignment(self.db, task_run_2) + + rows_for_task_run_1 = db_utils.select_rows_from_table_related_to_task_run( + self.db, + "assignments", + [task_run_1_id], + ) + rows_for_task_run_2 = 
db_utils.select_rows_from_table_related_to_task_run( + self.db, + "assignments", + [task_run_2_id], + ) + + self.assertEqual(len(rows_for_task_run_1), 2) + self.assertEqual(len(rows_for_task_run_2), 1) + + def test_serialize_data_for_table(self, *args): + task_id = "111111111111111111" + task_name = "task_1" + task_type = "mock" + rows = [ + { + "task_id": task_id, + "task_name": task_name, + "task_type": task_type, + "project_id": None, + "parent_task_id": None, + "creation_date": "2001-01-01 01:01:01.001", + }, + ] + + serialized_rows = db_utils.serialize_data_for_table(rows) + + self.assertEqual( + serialized_rows, + [ + { + "task_id": task_id, + "task_name": task_name, + "task_type": task_type, + "project_id": None, + "parent_task_id": None, + "creation_date": "2001-01-01T01:01:01.001000", + }, + ], + ) + + def test_make_randomized_int_id(self, *args): + value_1 = db_utils.make_randomized_int_id() + value_2 = db_utils.make_randomized_int_id() + + self.assertNotEqual(value_1, value_2) + self.assertGreater(value_1, db_utils.SQLITE_ID_MIN) + self.assertGreater(value_2, db_utils.SQLITE_ID_MIN) + self.assertLess(value_1, db_utils.SQLITE_ID_MAX) + self.assertLess(value_2, db_utils.SQLITE_ID_MAX) + + def test_get_task_ids_by_task_names(self, *args): + task_1_name = "task_1" + task_2_name = "task_2" + + _, task_1_id = get_test_task(self.db, task_1_name) + _, task_2_id = get_test_task(self.db, task_2_name) + + task_ids = db_utils.get_task_ids_by_task_names(self.db, [task_1_name, task_2_name]) + + self.assertEqual(task_ids, [task_1_id, task_2_id]) + + def test_get_task_run_ids_by_task_ids(self, *args): + task_1_name = "task_1" + task_2_name = "task_2" + + _, requester_id = get_test_requester(self.db) + _, task_1_id = get_test_task(self.db, task_1_name) + _, task_2_id = get_test_task(self.db, task_2_name) + task_run_1_id = get_test_task_run(self.db, task_1_id, requester_id) + task_run_2_id = get_test_task_run(self.db, task_2_id, requester_id) + task_run_3_id = get_test_task_run(self.db, task_2_id, requester_id) + + task_run_ids = db_utils.get_task_run_ids_by_task_ids(self.db, [task_1_id, task_2_id]) + + self.assertEqual( + sorted(task_run_ids), + sorted([task_run_1_id, task_run_2_id, task_run_3_id]), + ) + + def test_get_task_run_ids_ids_by_labels(self, *args): + test_label = "test_label" + + task_1_name = "task_1" + task_2_name = "task_2" + + _, requester_id = get_test_requester(self.db) + _, task_1_id = get_test_task(self.db, task_1_name) + _, task_2_id = get_test_task(self.db, task_2_name) + task_run_1_id = get_test_task_run(self.db, task_1_id, requester_id) + task_run_2_id = get_test_task_run(self.db, task_2_id, requester_id) + task_run_3_id = get_test_task_run(self.db, task_2_id, requester_id) + + db_utils.insert_new_row_in_table( + db=self.db, + table_name="imported_data", + row=dict( + source_file_name="test", + data_labels=f'["{test_label}"]', + table_name="task_runs", + unique_field_names='["task_run_id"]', + unique_field_values=f'["{task_run_1_id}", "{task_run_3_id}"]', + ), + ) + + task_run_ids = db_utils.get_task_run_ids_by_labels(self.db, [test_label]) + + self.assertEqual(sorted(task_run_ids), sorted([task_run_1_id, task_run_3_id])) + + def test_get_table_pk_field_name(self, *args): + table_names = ["tasks", "task_runs", "units", "unit_review"] + pk_fields = [db_utils.get_table_pk_field_name(self.db, t) for t in table_names] + + self.assertEqual(pk_fields, ["task_id", "task_run_id", "unit_id", "id"]) + + def test_select_all_table_rows(self, *args): + # Empty table + rows = 
db_utils.select_all_table_rows(self.db, "projects", order_by="creation_date") + self.assertEqual(len(rows), 0) + + # Table with 2 entries + get_test_project(self.db, "project_1") + get_test_project(self.db, "project_2") + rows = db_utils.select_all_table_rows(self.db, "projects", order_by="creation_date") + self.assertEqual(len(rows), 2) + self.assertEqual(rows[0]["project_name"], "project_1") + self.assertEqual(rows[1]["project_name"], "project_2") + + def test_select_rows_by_list_of_field_values(self, *args): + qualification_1_id = get_test_qualification(self.db, "qual_1") + qualification_2_id = get_test_qualification(self.db, "qual_2") + _, worker_1_id = get_test_worker(self.db, worker_name="worker_1") + _, worker_2_id = get_test_worker(self.db, worker_name="worker_2") + grant_test_qualification( + self.db, + worker_id=worker_1_id, + qualification_id=qualification_2_id, + value=1, + ) + grant_test_qualification( + self.db, + worker_id=worker_1_id, + qualification_id=qualification_1_id, + value=2, + ) + grant_test_qualification( + self.db, + worker_id=worker_2_id, + qualification_id=qualification_2_id, + value=3, + ) + + rows = db_utils.select_rows_by_list_of_field_values( + self.db, + "granted_qualifications", + field_names=["worker_id", "qualification_id"], + field_values=[ + [worker_1_id, worker_2_id], + [qualification_2_id], + ], + order_by="creation_date", + ) + + self.assertEqual(len(rows), 2) + self.assertEqual(rows[0]["value"], 1) + self.assertEqual(rows[1]["value"], 3) + + def test_delete_exported_data_without_fk_constraints(self, *args): + task_1_name = "task_1" + + _, requester_id = get_test_requester(self.db) + _, task_1_id = get_test_task(self.db, task_1_name) + get_test_task_run(self.db, task_1_id, requester_id) + + task_rows = db_utils.select_all_table_rows(self.db, "tasks") + task_run_rows = db_utils.select_all_table_rows(self.db, "task_runs") + task_rows_len_before = len(task_rows) + task_runs_rows_len_before = len(task_run_rows) + + dump = { + "tasks": db_utils.serialize_data_for_table(task_rows), + "task_runs": db_utils.serialize_data_for_table(task_run_rows), + } + + db_utils.delete_exported_data_without_fk_constraints( + self.db, + db_dump=dump, + table_names_can_be_cleaned=["tasks"], + ) + + task_rows_len_after = len(db_utils.select_all_table_rows(self.db, "tasks")) + task_runs_rows_len_after = len(db_utils.select_all_table_rows(self.db, "task_runs")) + + self.assertEqual(task_rows_len_before, 1) + self.assertEqual(task_runs_rows_len_before, 1) + self.assertEqual(task_rows_len_after, 0) + self.assertEqual(task_runs_rows_len_after, 1) + + def test_delete_entire_exported_data(self, *args): + get_test_unit(self.db) + + db_utils.delete_entire_exported_data(self.db) + + table_names = db_utils.get_list_of_db_table_names(self.db) + + for table_name in table_names: + rows = db_utils.select_all_table_rows(self.db, table_name) + if table_name == "migrations": + self.assertGreater(len(rows), 0) + else: + self.assertEqual(len(rows), 0) + + def test_get_list_of_provider_types(self, *args): + requester_name_1 = "requester_1" + requester_name_2 = "requester_2" + provider_type_1 = "mock" + provider_type_2 = "prolific" + + get_test_requester(self.db, requester_name_1, provider_type_1) + get_test_requester(self.db, requester_name_2, provider_type_2) + + provider_types = db_utils.get_list_of_provider_types(self.db) + + self.assertEqual(sorted(provider_types), sorted([provider_type_1, provider_type_2])) + + def test_get_latest_row_from_table(self, *args): + task_1_name = "task_1" + 
task_2_name = "task_2" + + _, task_1_id = get_test_task(self.db, task_1_name) + _, task_2_id = get_test_task(self.db, task_2_name) + + task_2 = Task.get(self.db, task_2_id) + + latest_task_row = db_utils.get_latest_row_from_table(self.db, "tasks", "creation_date") + + self.assertEqual(latest_task_row["task_id"], task_2_id) + self.assertEqual( + serialize_date_to_python(latest_task_row["creation_date"]), + task_2.creation_date, + ) + + def test_apply_migrations(self, *args): + new_table_name = "TEST_TABLE" + test_migration_name = "test_migration" + migrations = { + test_migration_name: f""" + CREATE TABLE {new_table_name} ( + id INTEGER PRIMARY KEY, + creation_date DATETIME DEFAULT CURRENT_TIMESTAMP + ); + """ + } + + db_utils.apply_migrations(self.db, migrations) + + table_names = db_utils.get_list_of_db_table_names(self.db) + migration_rows = db_utils.select_all_table_rows(self.db, "migrations") + migration_names = [m["name"] for m in migration_rows] + + self.assertIn(new_table_name, table_names) + self.assertIn(test_migration_name, migration_names) + + def test_get_list_of_db_table_names(self, *args): + table_names = db_utils.get_list_of_db_table_names(self.db) + + self.assertEqual( + sorted(table_names), + sorted( + [ + "agents", + "assignments", + "granted_qualifications", + "imported_data", + "migrations", + "onboarding_agents", + "projects", + "qualifications", + "requesters", + "sqlite_sequence", + "task_runs", + "tasks", + "unit_review", + "units", + "workers", + ] + ), + ) + + def test_get_list_of_tables_to_export(self, *args): + table_names = db_utils.get_list_of_tables_to_export(self.db) + + self.assertEqual( + sorted(table_names), + sorted( + [ + "agents", + "assignments", + "granted_qualifications", + "imported_data", + "onboarding_agents", + "projects", + "qualifications", + "requesters", + "task_runs", + "tasks", + "unit_review", + "units", + "workers", + ] + ), + ) + + def test_get_list_of_available_labels(self, *args): + label_1 = "test_label_1" + label_2 = "test_label_2" + label_3 = "test_label_3" + + db_utils.insert_new_row_in_table( + db=self.db, + table_name="imported_data", + row=dict( + source_file_name="test", + data_labels=f'["{label_1}"]', + table_name="task_runs", + unique_field_names='["task_run_id"]', + unique_field_values=f'["1","2"]', + ), + ) + db_utils.insert_new_row_in_table( + db=self.db, + table_name="imported_data", + row=dict( + source_file_name="test", + data_labels=f'["{label_1}","{label_2}","{label_3}"]', + table_name="tasks", + unique_field_names='["task_id"]', + unique_field_values=f'["1","2"]', + ), + ) + + available_labels = db_utils.get_list_of_available_labels(self.db) + + self.assertEqual(sorted(available_labels), sorted([label_1, label_2, label_3])) + + def test_check_if_row_with_params_exists(self, *args): + requester_name_1 = "requester_1" + provider_type_1 = "mock" + + _, requester_id = get_test_requester(self.db, requester_name_1, provider_type_1) + + already_exists = db_utils.check_if_row_with_params_exists( + db=self.db, + table_name="requesters", + params={ + "requester_id": requester_id, + "provider_type": provider_type_1, + }, + select_field="requester_id", + ) + + not_exists = db_utils.check_if_row_with_params_exists( + db=self.db, + table_name="requesters", + params={ + "requester_id": "wrong_id", + "provider_type": provider_type_1, + }, + select_field="requester_id", + ) + + self.assertTrue(already_exists) + self.assertFalse(not_exists) + + def test_get_providers_datastores(self, *args): + requester_name_1 = "requester_1" + 
requester_name_2 = "requester_2" + provider_type_1 = "mock" + provider_type_2 = "prolific" + + get_test_requester(self.db, requester_name_1, provider_type_1) + get_test_requester(self.db, requester_name_2, provider_type_2) + + datastores = db_utils.get_providers_datastores(self.db) + + self.assertEqual(len(datastores.keys()), 2) + self.assertIn(provider_type_1, datastores) + self.assertIn(provider_type_2, datastores) + self.assertTrue(isinstance(datastores[provider_type_1], MockDatastore)) + self.assertTrue(isinstance(datastores[provider_type_2], ProlificDatastore)) + + def test_db_or_datastore_to_dict(self, *args): + get_test_requester(self.db) + get_test_worker(self.db) + + db_dump = db_utils.db_or_datastore_to_dict(self.db) + + table_names = db_utils.get_list_of_tables_to_export(self.db) + + for table_name in table_names: + self.assertIn(table_name, db_dump) + table_data = db_dump[table_name] + if table_name in ["requesters", "workers"]: + self.assertEqual(len(table_data), 1) + self.assertGreater(len(table_data[0].keys()), 0) + else: + self.assertEqual(len(table_data), 0) + + def test_mephisto_db_to_dict_for_task_runs(self, *args): + tables_without_task_run_id = [ + "workers", + "tasks", + "requesters", + "qualifications", + "granted_qualifications", + ] + + _, requester_id = get_test_requester(self.db) + + table_names = db_utils.get_list_of_tables_to_export(self.db) + + # First TaskRun + task_run_id_1 = get_test_task_run(self.db, requester_id=requester_id) + _, worker_id = get_test_worker(self.db) + make_completed_unit(self.db) + qualification_id = get_test_qualification(self.db, "qual_1") + grant_test_qualification(self.db, worker_id=worker_id, qualification_id=qualification_id) + + db_dump_for_task_run_1 = db_utils.mephisto_db_to_dict_for_task_runs( + self.db, + task_run_ids=[task_run_id_1], + ) + for table_name in table_names: + if table_name == "imported_data": + continue + + table_data = db_dump_for_task_run_1[table_name] + if table_name in ["onboarding_agents", "unit_review", "projects"]: + self.assertEqual(len(table_data), 0) + else: + if table_name not in tables_without_task_run_id: + self.assertEqual(table_data[0]["task_run_id"], task_run_id_1) + self.assertEqual(len(table_data), 1) + + # Second TaskRun + _, task_2_id = get_test_task(self.db, "task_2") + task_run_id_2 = get_test_task_run(self.db, task_id=task_2_id, requester_id=requester_id) + + db_dump_for_task_run_2 = db_utils.mephisto_db_to_dict_for_task_runs( + self.db, + task_run_ids=[task_run_id_2], + ) + for table_name in table_names: + if table_name == "imported_data": + continue + + table_data = db_dump_for_task_run_2[table_name] + if table_name in ["task_runs", "tasks", "requesters"]: + if table_name not in tables_without_task_run_id: + self.assertEqual(table_data[0]["task_run_id"], task_run_id_2) + self.assertEqual(len(table_data), 1) + else: + self.assertEqual(len(table_data), 0) + + def test_select_task_run_ids_since_date(self, *args): + _, requester_id = get_test_requester(self.db) + _, task_1_id = get_test_task(self.db, "task_1") + task_run_1_id = get_test_task_run(self.db, task_1_id, requester_id) + task_run_2_id = get_test_task_run(self.db, task_1_id, requester_id) + + task_run_1 = TaskRun.get(self.db, task_run_1_id) + task_run_2 = TaskRun.get(self.db, task_run_2_id) + + since_task_run_1_created = task_run_1.creation_date - timedelta(milliseconds=1) + since_task_run_2_created = task_run_2.creation_date - timedelta(milliseconds=1) + + task_run_ids_since_task_run_1_created = 
db_utils.select_task_run_ids_since_date( + self.db, + since_task_run_1_created, + ) + task_run_ids_since_task_run_2_created = db_utils.select_task_run_ids_since_date( + self.db, + since_task_run_2_created, + ) + + self.assertEqual( + sorted(task_run_ids_since_task_run_1_created), + sorted([task_run_1_id, task_run_2_id]), + ) + self.assertEqual( + sorted(task_run_ids_since_task_run_2_created), + sorted([task_run_2_id]), + ) + + def test_select_fk_mappings_for_table(self, *args): + units_mappings = db_utils.select_fk_mappings_for_single_table(self.db, "units") + + self.assertEqual( + units_mappings, + { + "agents": {"from": "agent_id", "to": "agent_id"}, + "assignments": {"from": "assignment_id", "to": "assignment_id"}, + "requesters": {"from": "requester_id", "to": "requester_id"}, + "task_runs": {"from": "task_run_id", "to": "task_run_id"}, + "tasks": {"from": "task_id", "to": "task_id"}, + "workers": {"from": "worker_id", "to": "worker_id"}, + }, + ) + + def test_select_fk_mappings_for_tables(self, *args): + fk_mappings = db_utils.select_fk_mappings_for_tables(self.db, ["units", "tasks"]) + + self.assertEqual( + fk_mappings, + { + "tasks": { + "projects": {"from": "project_id", "to": "project_id"}, + "tasks": {"from": "parent_task_id", "to": "task_id"}, + }, + "units": { + "agents": {"from": "agent_id", "to": "agent_id"}, + "assignments": {"from": "assignment_id", "to": "assignment_id"}, + "requesters": {"from": "requester_id", "to": "requester_id"}, + "task_runs": {"from": "task_run_id", "to": "task_run_id"}, + "tasks": {"from": "task_id", "to": "task_id"}, + "workers": {"from": "worker_id", "to": "worker_id"}, + }, + }, + ) + + def test_insert_new_row_in_table(self, *args): + rows_count_before = len(db_utils.select_all_table_rows(self.db, "workers")) + + _, requester_id = get_test_requester(self.db) + requester = Requester.get(self.db, requester_id) + + db_utils.insert_new_row_in_table( + self.db, + "workers", + { + "worker_name": "test_worker", + "provider_type": requester.provider_type, + }, + ) + + rows_count_after = len(db_utils.select_all_table_rows(self.db, "workers")) + + self.assertEqual(rows_count_before, 0) + self.assertEqual(rows_count_after, 1) + + def test_update_row_in_table(self, *args): + updated_requester_name = "updated_requester_name" + + _, requester_id = get_test_requester(self.db) + row_before = self.db.get_requester(requester_id) + + db_utils.update_row_in_table( + self.db, "requesters", {**row_before, **{"requester_name": updated_requester_name}} + ) + + row_after = self.db.get_requester(requester_id) + + self.assertNotEqual(row_before["requester_name"], updated_requester_name) + self.assertEqual(row_after["requester_name"], updated_requester_name) + + def test_retry_generate_id(self, *args): + + # Function to simulate methods in Mephisto DB and provider-specific datastores + @db_utils.retry_generate_id(caught_excs=[EntryAlreadyExistsException]) + def _insert_new_row_in_projects(db: "MephistoDB", name: str): + with db.table_access_condition, db.get_connection() as conn: + c = conn.cursor() + + try: + c.execute( + f""" + INSERT INTO projects( + project_id, project_name + ) VALUES (?, ?); + """, + ( + db_utils.make_randomized_int_id(), + name, + ), + ) + except sqlite3.IntegrityError as e: + if is_unique_failure(e): + raise EntryAlreadyExistsException( + e, + db=db, + table_name="projects", + original_exc=e, + ) + + project_id_1 = 1 + project_id_2 = project_id_1 + project_id_3 = project_id_1 + project_id_4 = db_utils.make_randomized_int_id() + mock_randomized_ids = 
[
+            project_id_1,  # Correct first id
+            project_id_2,  # Conflicting id
+            project_id_3,  # Conflicting id
+            project_id_4,  # Random id that must be used after the conflicts
+        ]
+
+        project_names = ["project_name_1", "project_name_2"]
+
+        with patch("mephisto.utils.db.make_randomized_int_id") as mock_make_randomized_int_id:
+            mock_make_randomized_int_id.side_effect = mock_randomized_ids
+
+            # 1. We call the function only TWICE (once per project name).
+            # 2. The first call creates a project for the first name.
+            # 3. The second call hits the conflicting ids, and the decorator keeps retrying
+            #    the function until the mocked `make_randomized_int_id` returns an id that
+            #    does not conflict (the last value in `mock_randomized_ids`)
+            for project_name in project_names:
+                # Call function wrapped with decorator `retry_generate_id`
+                _insert_new_row_in_projects(self.db, project_name)
+
+        rows = db_utils.select_all_table_rows(self.db, "projects", order_by="creation_date")
+
+        self.assertEqual(len(rows), 2)
+        self.assertEqual(rows[0]["project_id"], str(project_id_1))
+        self.assertEqual(rows[1]["project_id"], str(project_id_4))
diff --git a/test/utils/test_misc.py b/test/utils/test_misc.py
new file mode 100644
index 000000000..ab9cc36cd
--- /dev/null
+++ b/test/utils/test_misc.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Meta Platforms and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+from datetime import datetime
+
+import pytest
+from dateutil.parser import ParserError
+from dateutil.tz import tzlocal
+
+from mephisto.utils.misc import serialize_date_to_python
+
+
+@pytest.mark.utils
+class TestUtilsMisc(unittest.TestCase):
+    def test_serialize_date_to_python(self, *args):
+        common_datetime_string = "2001-01-01 01:01:01.000"
+        common_date_string = "2001-01-01"
+        iso8601_datetime_string = "2002-02-02T02:02:02.000Z"
+        unix_timestamp = 1046660583000  # milliseconds
+        wrong_date_string = "wrong_date"
+
+        python_datetime_1 = serialize_date_to_python(common_datetime_string)
+        self.assertEqual(python_datetime_1, datetime(2001, 1, 1, 1, 1, 1))
+
+        python_datetime_2 = serialize_date_to_python(common_date_string)
+        self.assertEqual(python_datetime_2, datetime(2001, 1, 1, 0, 0, 0))
+
+        python_datetime_3 = serialize_date_to_python(iso8601_datetime_string)
+        self.assertEqual(python_datetime_3, datetime(2002, 2, 2, 2, 2, 2, tzinfo=tzlocal()))
+
+        python_datetime_4 = serialize_date_to_python(unix_timestamp)
+        self.assertEqual(python_datetime_4, datetime(2003, 3, 3, 3, 3, 3))
+
+        with self.assertRaises(ParserError) as cm:
+            serialize_date_to_python(wrong_date_string)
+        self.assertEqual(str(cm.exception), f"Unknown string format: {wrong_date_string}")
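
The tests above double as a usage reference for the porter's Python API. Outside of a test harness, the same round trip might look like the following sketch (assuming a standard local setup where `LocalMephistoDB()` points at your configured data directory; `import_dump` prompts for confirmation, backs up your data first, and merges colliding rows with the default conflict resolver):

```python
from mephisto.abstractions.databases.local_database import LocalMephistoDB
from mephisto.tools.db_data_porter import DBDataPorter

db = LocalMephistoDB()
porter = DBDataPorter(db)

# Export everything into a timestamped dump archive under outputs/export;
# the returned dict contains "dump_path" and "backup_path"
export_results = porter.export_dump()
print(export_results["dump_path"])

# Import the dump back; conflict_resolver_name defaults to DEFAULT_CONFLICT_RESOLVER
porter.import_dump(dump_archive_file_name_or_path=export_results["dump_path"])
```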