diff --git a/scripts-dev/update_database b/scripts-dev/update_database index 22bee22bc771..1d62f0403afc 100755 --- a/scripts-dev/update_database +++ b/scripts-dev/update_database @@ -26,7 +26,6 @@ from synapse.config.homeserver import HomeServerConfig from synapse.metrics.background_process_metrics import run_as_background_process from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.prepare_database import prepare_database logger = logging.getLogger("update_database") @@ -77,13 +76,8 @@ if __name__ == "__main__": # Instantiate and initialise the homeserver object. hs = MockHomeserver(config) - # Update the database to the latest schema. - database = hs.config.get_single_database() - db_conn = database.make_conn() - prepare_database(db_conn, database.engine, config=config) - db_conn.commit() - - # setup instantiates the store within the homeserver object. + # Setup instantiates the store within the homeserver object and updates the + # DB. hs.setup() store = hs.get_datastore() diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 6592f7b9f04e..5b5368988c81 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -56,7 +56,7 @@ from synapse.storage.data_stores.main.stats import StatsStore from synapse.storage.data_stores.main.user_directory import ( UserDirectoryBackgroundUpdateStore, ) -from synapse.storage.database import Database +from synapse.storage.database import Database, make_conn from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database from synapse.util import Clock @@ -440,9 +440,9 @@ class Porter(object): else: return - def setup_db(self, db_config: DatabaseConnectionConfig): - db_conn = db_config.make_conn() - prepare_database(db_conn, db_config.engine, config=None) + def setup_db(self, db_config: DatabaseConnectionConfig, engine): + db_conn = make_conn(db_config, engine) + prepare_database(db_conn, engine, config=None) db_conn.commit() @@ -460,15 +460,16 @@ class Porter(object): """ self.progress.set_state("Preparing %s" % db_config.config["name"]) - conn = self.setup_db(db_config) + engine = create_engine(db_config.config) + conn = self.setup_db(db_config, engine) hs = MockHomeserver(self.hs_config) - store = Store(Database(hs, db_config), conn, hs) + store = Store(Database(hs, db_config, engine), conn, hs) yield store.db.runInteraction( "%s_engine.check_database" % db_config.config["name"], - db_config.engine.check_database, + engine.check_database, ) return store diff --git a/synapse/config/database.py b/synapse/config/database.py index 83a16176d4af..5f2f3c7cfdbb 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -19,10 +19,7 @@ import yaml -from twisted.enterprise import adbapi - from synapse.config._base import Config, ConfigError -from synapse.storage.engines import create_engine logger = logging.getLogger(__name__) @@ -53,37 +50,6 @@ def __init__(self, name: str, db_config: dict, data_stores: List[str]): self.config = db_config self.data_stores = data_stores - self.engine = create_engine(db_config) - self.config["args"]["cp_openfun"] = self.engine.on_new_connection - - self._pool = None - - def get_pool(self, reactor) -> adbapi.ConnectionPool: - """Get the connection pool for the database. - """ - - if self._pool is None: - self._pool = adbapi.ConnectionPool( - self.config["name"], cp_reactor=reactor, **self.config.get("args", {}) - ) - - return self._pool - - def make_conn(self): - """Make a new connection to the database and return it. - - Returns: - Connection - """ - - db_params = { - k: v - for k, v in self.config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.engine.module.connect(**db_params) - return db_conn - class DatabaseConfig(Config): section = "database" diff --git a/synapse/server.py b/synapse/server.py index 4b0970eada05..7926867b777e 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -273,6 +273,9 @@ def get_clock(self): def get_datastore(self): return self.datastores.main + def get_datastores(self): + return self.datastores + def get_config(self): return self.config diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index 8006f6497e51..0983e059c072 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -15,7 +15,8 @@ import logging -from synapse.storage.database import Database +from synapse.storage.database import Database, make_conn +from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database logger = logging.getLogger(__name__) @@ -34,20 +35,21 @@ def __init__(self, main_store_class, hs): # Note we pass in the main store class here as workers use a different main # store. + self.databases = [] + for database_config in hs.config.database.databases: db_name = database_config.name - with database_config.make_conn() as db_conn: + engine = create_engine(database_config.config) + + with make_conn(database_config, engine) as db_conn: logger.info("Preparing database %r...", db_name) - database_config.engine.check_database(db_conn.cursor()) + engine.check_database(db_conn.cursor()) prepare_database( - db_conn, - database_config.engine, - hs.config, - data_stores=database_config.data_stores, + db_conn, engine, hs.config, data_stores=database_config.data_stores, ) - database = Database(hs, database_config) + database = Database(hs, database_config, engine) if "main" in database_config.data_stores: logger.info("Starting 'main' data store") @@ -55,4 +57,6 @@ def __init__(self, main_store_class, hs): db_conn.commit() + self.databases.append(database) + logger.info("Database %r prepared", db_name) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index dc4e8ee188eb..1003dd84a541 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -24,9 +24,11 @@ from prometheus_client import Histogram +from twisted.enterprise import adbapi from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.config.database import DatabaseConnectionConfig from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater @@ -74,6 +76,37 @@ } +def make_pool( + reactor, db_config: DatabaseConnectionConfig, engine +) -> adbapi.ConnectionPool: + """Get the connection pool for the database. + """ + + return adbapi.ConnectionPool( + db_config.config["name"], + cp_reactor=reactor, + cp_openfun=engine.on_new_connection, + **db_config.config.get("args", {}) + ) + + +def make_conn(db_config: DatabaseConnectionConfig, engine): + """Make a new connection to the database and return it. + + Returns: + Connection + """ + + db_params = { + k: v + for k, v in db_config.config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = engine.module.connect(**db_params) + engine.on_new_connection(db_conn) + return db_conn + + class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() @@ -218,11 +251,11 @@ class Database(object): _TXN_ID = 0 - def __init__(self, hs, database_config): + def __init__(self, hs, database_config: DatabaseConnectionConfig, engine): self.hs = hs self._clock = hs.get_clock() self._database_config = database_config - self._db_pool = database_config.get_pool(hs.get_reactor()) + self._db_pool = make_pool(hs.get_reactor(), database_config, engine) self.updates = BackgroundUpdater(hs, self) @@ -235,7 +268,7 @@ def __init__(self, hs, database_config): # to watch it self._txn_perf_counters = PerformanceCounters() - self.engine = database_config.engine + self.engine = engine # A set of tables that are not safe to use native upserts in. self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index c4ed3f0e88ee..2a1e7c7166de 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -20,7 +20,7 @@ ReplicationClientHandler, ) from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory -from synapse.storage.database import Database +from synapse.storage.database import make_conn from tests import unittest from tests.server import FakeTransport @@ -44,8 +44,9 @@ def prepare(self, reactor, clock, hs): db_config = hs.config.database.get_single_database() self.master_store = self.hs.get_datastore() self.storage = hs.get_storage() + database = hs.get_datastores().databases[0] self.slaved_store = self.STORE_TYPE( - Database(hs, db_config), db_config.get_pool(reactor).connect(), self.hs + database, make_conn(db_config, database.engine), self.hs ) self.event_id = 0 diff --git a/tests/server.py b/tests/server.py index dafbb3cb2838..a554dfdd570a 100644 --- a/tests/server.py +++ b/tests/server.py @@ -308,33 +308,35 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): # Make the thread pool synchronous. clock = server.get_clock() - pool = database.get_pool(clock._reactor) - - def runWithConnection(func, *args, **kwargs): - return threads.deferToThreadPool( - pool._reactor, - pool.threadpool, - pool._runWithConnection, - func, - *args, - **kwargs - ) - def runInteraction(interaction, *args, **kwargs): - return threads.deferToThreadPool( - pool._reactor, - pool.threadpool, - pool._runInteraction, - interaction, - *args, - **kwargs - ) + for database in server.get_datastores().databases: + pool = database._db_pool + + def runWithConnection(func, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runWithConnection, + func, + *args, + **kwargs + ) + + def runInteraction(interaction, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runInteraction, + interaction, + *args, + **kwargs + ) - if pool: pool.runWithConnection = runWithConnection pool.runInteraction = runInteraction pool.threadpool = ThreadPool(clock._reactor) pool.running = True + return server diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 1657941d0372..fd5251269645 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -28,7 +28,7 @@ ApplicationServiceStore, ApplicationServiceTransactionStore, ) -from synapse.storage.database import Database +from synapse.storage.database import Database, make_conn from tests import unittest from tests.utils import setup_test_homeserver @@ -55,8 +55,10 @@ def setUp(self): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - database = hs.config.get_single_database() - self.store = ApplicationServiceStore(database, database.make_conn(), hs) + database = hs.get_datastores().databases[0] + self.store = ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) def tearDown(self): # TODO: suboptimal that we need to create files for tests! @@ -111,10 +113,6 @@ def setUp(self): hs.config.event_cache_size = 1 hs.config.password_providers = [] - database = hs.config.get_single_database() - self.db_pool = database.get_pool(hs.get_reactor()) - self.engine = database.engine - self.as_list = [ {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, {"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"}, @@ -126,9 +124,14 @@ def setUp(self): self.as_yaml_files = [] + # We assume there is only one database in these tests + database = hs.get_datastores().databases[0] + self.db_pool = database._db_pool + self.engine = database.engine + db_config = hs.config.get_single_database() self.store = TestTransactionStore( - Database(hs, db_config), db_config.make_conn(), hs + database, make_conn(db_config, self.engine), hs ) def _add_service(self, url, as_token, id): @@ -422,8 +425,10 @@ def test_unique_works(self): hs.config.event_cache_size = 1 hs.config.password_providers = [] - database = hs.config.get_single_database() - ApplicationServiceStore(database, database.make_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) @defer.inlineCallbacks def test_duplicate_ids(self): @@ -439,8 +444,10 @@ def test_duplicate_ids(self): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - database = hs.config.get_single_database() - ApplicationServiceStore(database, database.make_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) e = cm.exception self.assertIn(f1, str(e)) @@ -461,8 +468,10 @@ def test_duplicate_as_tokens(self): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - database = hs.config.get_single_database() - ApplicationServiceStore(database, database.make_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) e = cm.exception self.assertIn(f1, str(e)) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index cf3c40273078..cdee0a9e60b4 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -52,17 +52,17 @@ def runWithConnection(func, *args, **kwargs): config = Mock() config._disable_native_upserts = True config.event_cache_size = 1 - config.database_config = {"name": "sqlite3"} - engine = create_engine(config.database_config) + hs = TestHomeServer("test", config=config) + + sqlite_config = {"name": "sqlite3"} + engine = create_engine(sqlite_config) fake_engine = Mock(wraps=engine) fake_engine.can_native_upsert = False - hs = TestHomeServer("test", config=config) - mock_db = Mock() - mock_db.engine = fake_engine - mock_db.get_pool.return_value = self.db_pool + db = Database(Mock(), Mock(config=sqlite_config), fake_engine) + db._db_pool = self.db_pool - self.datastore = SQLBaseStore(Database(Mock(), mock_db), None, hs) + self.datastore = SQLBaseStore(db, None, hs) @defer.inlineCallbacks def test_insert_1col(self): diff --git a/tests/utils.py b/tests/utils.py index 15646c8c9eab..9f5bf40b4bba 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -232,9 +232,10 @@ def setup_test_homeserver( } database = DatabaseConnectionConfig("master", database_config, ["main"]) - db_engine = database.engine config.database.databases = [database] + db_engine = create_engine(database.config) + # Create the database before we actually try and connect to it, based off # the template database we generate in setupdb() if datastore is None and isinstance(db_engine, PostgresEngine): @@ -264,13 +265,19 @@ def setup_test_homeserver( **kargs ) + hs.setup() + if homeserverToUse.__name__ == "TestHomeServer": + hs.setup_master() + if isinstance(db_engine, PostgresEngine): + database = hs.get_datastores().databases[0] + # We need to do cleanup on PostgreSQL def cleanup(): import psycopg2 # Close all the db pools - database.get_pool(reactor).close() + database._db_pool.close() dropped = False @@ -309,9 +316,6 @@ def cleanup(): # Register the cleanup hook cleanup_func(cleanup) - hs.setup() - if homeserverToUse.__name__ == "TestHomeServer": - hs.setup_master() else: hs = homeserverToUse( name,