Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TriblerDatabaseMigrationChain #7622

Merged
merged 1 commit into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/tribler/core/components/database/database_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ async def run(self):
if self.session.config.gui_test_mode:
db_path = ":memory:"

self.db = TriblerDatabase(str(db_path), create_tables=True)
self.db = TriblerDatabase(str(db_path))

async def shutdown(self):
await super().shutdown()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@
# pylint: disable=protected-access
class TestKnowledgeAccessLayer(TestKnowledgeAccessLayerBase):
@patch.object(TrackedDatabase, 'generate_mapping')
@patch.object(TriblerDatabase, 'fill_default_data', Mock())
def test_constructor_create_tables_true(self, mocked_generate_mapping: Mock):
TriblerDatabase(':memory:')
""" Test that constructor of TriblerDatabase calls TrackedDatabase.generate_mapping with create_tables=True"""
TriblerDatabase()

mocked_generate_mapping.assert_called_with(create_tables=True)

@patch.object(TrackedDatabase, 'generate_mapping')
@patch.object(TriblerDatabase, 'fill_default_data', Mock())
def test_constructor_create_tables_false(self, mocked_generate_mapping: Mock):
TriblerDatabase(':memory:', create_tables=False)
""" Test that constructor of TriblerDatabase calls TrackedDatabase.generate_mapping with create_tables=False"""
TriblerDatabase(create_tables=False)

mocked_generate_mapping.assert_called_with(create_tables=False)

@db_session
Expand Down Expand Up @@ -245,7 +251,7 @@ def test_get_objects_removed(self):
)

self.add_operation(self.db, subject='infohash1', predicate=ResourceType.TAG, obj='tag2', peer=b'4',
operation=Operation.REMOVE)
operation=Operation.REMOVE)

assert self.db.knowledge.get_objects(subject='infohash1', predicate=ResourceType.TAG) == ['tag1']

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from ipv8.test.base import TestBase
from pony.orm import db_session

Expand Down Expand Up @@ -37,8 +38,10 @@ def dump_db(self):
@db_session
def test_set_misc(self):
"""Test that set_misc works as expected"""
self.db.set_misc(key='key', value='value')
assert self.db.get_misc(key='key') == 'value'
self.db.set_misc(key='string', value='value')
self.db.set_misc(key='integer', value=1)
assert self.db.get_misc(key='string') == 'value'
assert self.db.get_misc(key='integer') == '1'

@db_session
def test_non_existent_misc(self):
Expand All @@ -48,3 +51,20 @@ def test_non_existent_misc(self):

# A value if the key does exist
assert self.db.get_misc(key='non existent', default=42) == 42

@db_session
def test_default_version(self):
""" Test that the default version is equal to `CURRENT_VERSION`"""
assert self.db.version == TriblerDatabase.CURRENT_VERSION

@db_session
def test_version_getter_and_setter(self):
""" Test that the version getter and setter work as expected"""
self.db.version = 42
assert self.db.version == 42

@db_session
def test_version_getter_unsupported_type(self):
""" Test that the version getter raises a TypeError if the type is not supported"""
with pytest.raises(TypeError):
self.db.version = 'string'
37 changes: 33 additions & 4 deletions src/tribler/core/components/database/db/tribler_database.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import logging
import os
from typing import Any, Optional

from pony import orm

from tribler.core.components.database.db.layers.health_data_access_layer import HealthDataAccessLayer
from tribler.core.components.database.db.layers.knowledge_data_access_layer import KnowledgeDataAccessLayer
from tribler.core.utilities.pony_utils import TrackedDatabase, get_or_create
from tribler.core.utilities.pony_utils import TrackedDatabase, db_session, get_or_create

MEMORY = ':memory:'


class TriblerDatabase:
CURRENT_VERSION = 1
_SCHEME_VERSION_KEY = 'scheme_version'

def __init__(self, filename: Optional[str] = None, *, create_tables: bool = True, **generate_mapping_kwargs):
self.instance = TrackedDatabase()

Expand All @@ -25,21 +31,31 @@ def __init__(self, filename: Optional[str] = None, *, create_tables: bool = True
self.TorrentHealth = self.health.TorrentHealth
self.Tracker = self.health.Tracker

self.instance.bind('sqlite', filename or ':memory:', create_db=True)
filename = filename or MEMORY
db_does_not_exist = filename == MEMORY or not os.path.isfile(filename)

self.instance.bind('sqlite', filename, create_db=db_does_not_exist)
generate_mapping_kwargs['create_tables'] = create_tables
self.instance.generate_mapping(**generate_mapping_kwargs)
self.logger = logging.getLogger(self.__class__.__name__)

if db_does_not_exist:
self.fill_default_data()

@staticmethod
def define_binding(db):
""" Define common bindings"""

class Misc(db.Entity): # pylint: disable=unused-variable
class Misc(db.Entity):
name = orm.PrimaryKey(str)
value = orm.Optional(str)

return Misc

@db_session
def fill_default_data(self):
self.logger.info('Filling the DB with the default data')
self.set_misc(self._SCHEME_VERSION_KEY, self.CURRENT_VERSION)

def get_misc(self, key: str, default: Optional[str] = None) -> Optional[str]:
data = self.Misc.get(name=key)
return data.value if data else default
Expand All @@ -48,5 +64,18 @@ def set_misc(self, key: str, value: Any):
key_value = get_or_create(self.Misc, name=key)
key_value.value = str(value)

@property
def version(self) -> int:
""" Get the database version"""
return int(self.get_misc(key=self._SCHEME_VERSION_KEY, default=0))

@version.setter
def version(self, value: int):
""" Set the database version"""
if not isinstance(value, int):
raise TypeError('DB version should be integer')

self.set_misc(key=self._SCHEME_VERSION_KEY, value=value)

def shutdown(self) -> None:
self.instance.disconnect()
18 changes: 18 additions & 0 deletions src/tribler/core/upgrade/tribler_db/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest

from tribler.core.components.database.db.tribler_database import TriblerDatabase
from tribler.core.upgrade.tribler_db.migration_chain import TriblerDatabaseMigrationChain
from tribler.core.utilities.path_util import Path
from tribler.core.utilities.simpledefs import STATEDIR_DB_DIR


# pylint: disable=redefined-outer-name


@pytest.fixture
def migration_chain(tmpdir):
""" Create an empty migration chain with an empty database."""
db_file_name = Path(tmpdir) / STATEDIR_DB_DIR / 'tribler.db'
db_file_name.parent.mkdir()
TriblerDatabase(filename=str(db_file_name))
return TriblerDatabaseMigrationChain(state_dir=Path(tmpdir), chain=[])
53 changes: 53 additions & 0 deletions src/tribler/core/upgrade/tribler_db/decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import functools
import logging
from typing import Callable, Optional

from tribler.core.components.database.db.tribler_database import TriblerDatabase
from tribler.core.utilities.pony_utils import db_session

MIGRATION_METADATA = "_tribler_db_migration"

logger = logging.getLogger('Migration (TriblerDB)')


def migration(execute_only_if_version: int, set_after_successful_execution_version: Optional[int] = None):
""" Decorator for migration functions.
The migration executes in the single transaction. If the migration fails, the transaction is rolled back.
The decorator also sets the metadata attribute to the decorated function. It could be checked by
calling the `has_migration_metadata` function.
Args:
execute_only_if_version: Execute the migration only if the current db version is equal to this value.
set_after_successful_execution_version: Set the db version to this value after the migration is executed.
If it is not specified, then `set_after_successful_execution_version = execute_only_if_version + 1`
"""

def decorator(func):
@functools.wraps(func)
@db_session
def wrapper(db: TriblerDatabase, **kwargs):
target_version = execute_only_if_version
if target_version != db.version:
logger.info(
f"Function {func.__name__} is not executed because DB version is not equal to {target_version}. "
f"The current db version is {db.version}"
)
return None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks dangerous to silently skip the migration in case of a version mismatch and can lead to incorrect migrations.

As a possible example:

  • We have migration chain [a, b, c], and after the migration c, the database has version 3
  • Two developers make two independent PRs with migrations d and e, both require db version 3
  • After both PRs were successfully merged, the merge conflict in the migration chain was resolved to [a, b, c, d, e], but both d and e still mistakenly require version 3 of the database.
  • The execution of the migration chain silently skips the migration e

For that reason, it is better to raise an error instead of skipping the migration so the database will not be left in a surprising state.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we were to raise an exception as you suggested, then almost all runs would result in unsuccessful execution.

Example:

  1. We have a migration chain [a, b, c] with the corresponding versions [1, 2, 3]. Let's represent them as [a(1), b(2), c(3)].
  2. Our DB currently has a version set to 2.
  3. We begin executing the migration chain starting with migration a(1).
  4. This results in an exception because the DB version is not equal to 1.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can first select suitable migrations, as described in the previous comment, and execute only them. This way, the migration decorator can throw an exception if the migration was performed at an improper moment. It looks safer than silently skipping migrations with incorrect db version numbers.

Also, consider the migrations [a, b, c, d, e] that mistakenly expect the schema version [1, 3, 2, 4, 5]. With silent skipping of the incorrect migrations, only migrations 1 and 2 will be executed. It is better to prevent such errors.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I'm not sure that it is a good idea to automatically fix potential merge conflict errors.

Also, I'm not sure that the Migration Chain should have a monopoly on changing the version value.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I'm not sure that it is a good idea to automatically fix potential merge conflict errors.

I agree, and what I like in the current approach is that the merge conflicts for migrations always require explicit conflict resolution, as migrations are explicitly listed in the same list.

What I'm concerned about is that the current PR approach allows silent skipping of migration if the schema number specified in the decorator does not agree with the order of migrations in the explicit list. Raising an error looks like a safer approach.


result = func(db, **kwargs)

next_version = set_after_successful_execution_version
if next_version is None:
next_version = target_version + 1
db.version = next_version

return result

setattr(wrapper, MIGRATION_METADATA, {})
return wrapper

return decorator


def has_migration_metadata(f: Callable):
""" Check if the function has migration metadata."""
return hasattr(f, MIGRATION_METADATA)
48 changes: 48 additions & 0 deletions src/tribler/core/upgrade/tribler_db/migration_chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import logging
from typing import Callable, List, Optional

from tribler.core.components.database.db.tribler_database import TriblerDatabase
from tribler.core.upgrade.tribler_db.decorator import has_migration_metadata
from tribler.core.upgrade.tribler_db.scheme_migrations.scheme_migration_0 import scheme_migration_0
from tribler.core.utilities.path_util import Path
from tribler.core.utilities.simpledefs import STATEDIR_DB_DIR


class TriblerDatabaseMigrationChain:
""" A chain of migrations that can be executed on a TriblerDatabase.

To create a new migration, create a new function and decorate it with the `migration` decorator. Then add it to
the `DEFAULT_CHAIN` list.
"""

DEFAULT_CHAIN = [
scheme_migration_0,
# add your migration here
]

def __init__(self, state_dir: Path, chain: Optional[List[Callable]] = None):
self.logger = logging.getLogger(self.__class__.__name__)
self.state_dir = state_dir

db_path = self.state_dir / STATEDIR_DB_DIR / 'tribler.db'
self.logger.info(f'Tribler DB path: {db_path}')
self.db = TriblerDatabase(str(db_path), check_tables=False) if db_path.is_file() else None

self.migrations = chain or self.DEFAULT_CHAIN

def execute(self) -> bool:
""" Execute all migrations in the chain.

Returns: True if all migrations were executed successfully, False otherwise.
An exception in any of the migrations will halt the execution chain and be re-raised.
"""

if not self.db:
return False

for m in self.migrations:
if not has_migration_metadata(m):
raise NotImplementedError(f'The migration {m} should have `migration` decorator')
m(self.db, state_dir=self.state_dir)

return True
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from tribler.core.components.database.db.tribler_database import TriblerDatabase
from tribler.core.upgrade.tribler_db.decorator import migration


@migration(execute_only_if_version=0)
def scheme_migration_0(db: TriblerDatabase, **kwargs): # pylint: disable=unused-argument
""" "This is initial migration, placed here primarily for demonstration purposes.
It doesn't do anything except set the database version to `1`.

For upcoming migrations, there are some guidelines:
1. functions should contain a single parameter, `db: TriblerDatabase`,
2. they should apply the `@migration` decorator.


Utilizing plain SQL (as seen in the example below) is considered good practice since it helps prevent potential
inconsistencies in DB schemes in the future (model versions preceding the current one may differ from it).
For more information see: https://github.com/Tribler/tribler/issues/7382

The example of a migration:

db.execute('ALTER TABLE "TorrentState" ADD "has_data" BOOLEAN DEFAULT 0')
db.execute('UPDATE "TorrentState" SET "has_data" = 1 WHERE last_check > 0')
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from tribler.core.upgrade.tribler_db.migration_chain import TriblerDatabaseMigrationChain
from tribler.core.upgrade.tribler_db.scheme_migrations.scheme_migration_0 import scheme_migration_0
from tribler.core.utilities.pony_utils import db_session


@db_session
def test_scheme_migration_0(migration_chain: TriblerDatabaseMigrationChain):
""" Test that the scheme_migration_0 changes the database version to 1. """
migration_chain.db.version = 0
migration_chain.migrations = [scheme_migration_0]

assert migration_chain.execute()
assert migration_chain.db.version == 1
74 changes: 74 additions & 0 deletions src/tribler/core/upgrade/tribler_db/tests/test_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from unittest.mock import Mock

import pytest

from tribler.core.upgrade.tribler_db.decorator import has_migration_metadata, migration


def test_migration_execute_only_if_version():
""" Test that migration is executed only if the version of the database is equal to the specified one."""

@migration(execute_only_if_version=1)
def test(_: Mock):
return True

assert test(Mock(version=1))
assert not test(Mock(version=2))


def test_set_after_successful_execution_version():
""" Test that the version of the database is set to the specified one after the migration is successfully
executed.
"""

@migration(execute_only_if_version=1, set_after_successful_execution_version=33)
def test(_: Mock):
...

db = Mock(version=1)
test(db)

assert db.version == 33


def test_set_after_successful_execution_version_not_specified():
""" Test that if the version is not specified, the version of the database will be set to
execute_only_if_version + 1
"""

@migration(execute_only_if_version=1)
def test(_: Mock):
...

db = Mock(version=1)
test(db)

assert db.version == 2


def test_set_after_successful_execution_raise_an_exception():
""" Test that if an exception is raised during the migration, the version of the database is not changed."""

@migration(execute_only_if_version=1, set_after_successful_execution_version=33)
def test(_: Mock):
raise TypeError

db = Mock(version=1)
with pytest.raises(TypeError):
test(db)

assert db.version == 1


def test_set_metadata():
""" Test that the metadata flag is set."""

@migration(execute_only_if_version=1)
def simple_migration(_: Mock):
...

def no_migration(_: Mock):
...

assert has_migration_metadata(simple_migration)
assert not has_migration_metadata(no_migration)
Loading
Loading