From f6dfbc23a9e826c7964a48c357e597497063ba9d Mon Sep 17 00:00:00 2001 From: Mitchel Cabuloy Date: Sun, 31 Jul 2022 22:04:34 +0800 Subject: [PATCH 1/9] Add PGClient --- dslr/cli.py | 15 +++++++-------- dslr/operations.py | 19 +++++-------------- dslr/pg_client.py | 23 +++++++++++++++++++++++ dslr/runner.py | 26 ++++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 22 deletions(-) create mode 100644 dslr/pg_client.py diff --git a/dslr/cli.py b/dslr/cli.py index 1d9981c..cb29e7f 100644 --- a/dslr/cli.py +++ b/dslr/cli.py @@ -10,7 +10,6 @@ from .config import settings from .console import console, cprint, eprint from .operations import ( - DSLRException, SnapshotNotFound, create_snapshot, delete_snapshot, @@ -100,7 +99,7 @@ def snapshot(name: str): try: with console.status("Creating snapshot"): create_snapshot(name) - except DSLRException as e: + except Exception as e: eprint("Failed to create snapshot") eprint(e, style="white") sys.exit(1) @@ -126,7 +125,7 @@ def restore(name): with console.status("Restoring snapshot"): try: restore_snapshot(snapshot) - except DSLRException as e: + except Exception as e: eprint("Failed to restore snapshot") eprint(e, style="white") sys.exit(1) @@ -141,7 +140,7 @@ def list(): """ try: snapshots = get_snapshots() - except DSLRException as e: + except Exception as e: eprint("Failed to list snapshots") eprint(f"{e}", style="white") sys.exit(1) @@ -174,7 +173,7 @@ def delete(name): try: delete_snapshot(snapshot) - except DSLRException as e: + except Exception as e: eprint("Failed to delete snapshot") eprint(e, style="white") sys.exit(1) @@ -212,7 +211,7 @@ def rename(old_name, new_name): try: rename_snapshot(old_snapshot, new_name) - except DSLRException as e: + except Exception as e: eprint("Failed to rename snapshot") eprint(e, style="white") sys.exit(1) @@ -235,7 +234,7 @@ def export(name): try: with console.status("Exporting snapshot"): export_path = export_snapshot(snapshot) - except DSLRException as e: + except Exception as e: eprint("Failed to export snapshot") eprint(e, style="white") sys.exit(1) @@ -269,7 +268,7 @@ def import_(filename, name): try: with console.status("Importing snapshot"): import_snapshot(filename, name) - except DSLRException as e: + except Exception as e: eprint("Failed to import snapshot") eprint(e, style="white") sys.exit(1) diff --git a/dslr/operations.py b/dslr/operations.py index 2d97030..373e137 100644 --- a/dslr/operations.py +++ b/dslr/operations.py @@ -3,7 +3,7 @@ from typing import List, Optional from .config import settings -from .runner import exec +from .runner import exec, exec_sql Snapshot = namedtuple("Snapshot", ["dbname", "name", "created_at"]) @@ -52,28 +52,19 @@ def get_snapshots() -> List[Snapshot]: dslr__ """ # Find the snapshot databases - result = exec("psql", "-c", "SELECT datname FROM pg_database") + result = exec_sql("SELECT datname FROM pg_database WHERE datname LIKE 'dslr_%'") - if result.returncode != 0: - raise DSLRException(result.stderr) - - lines = sorted( - [ - line.strip() - for line in result.stdout.split("\n") - if line.strip().startswith("dslr_") - ] - ) + snapshot_dbnames = sorted([row[0] for row in result]) # Parse the name into a Snapshot - parts = [line.split("_") for line in lines] + parts = [dbname.split("_") for dbname in snapshot_dbnames] return [ Snapshot( dbname=line, name="_".join(part[2:]), created_at=datetime.fromtimestamp(int(part[1])), ) - for part, line in zip(parts, lines) + for part, line in zip(parts, snapshot_dbnames) ] diff --git a/dslr/pg_client.py b/dslr/pg_client.py new file mode 100644 index 0000000..930d20a --- /dev/null +++ b/dslr/pg_client.py @@ -0,0 +1,23 @@ +from typing import Any, List, Tuple + +import psycopg2 + + +class PGClient: + """ + Thin wrapper around psycopg2 + """ + + def __init__(self, host, port, user, password, dbname): + self.conn = psycopg2.connect( + host=host, + port=port, + user=user, + password=password, + dbname=dbname, + ) + self.cur = self.conn.cursor() + + def execute(self, sql, data) -> List[Tuple[Any, ...]]: + self.cur.execute(sql, data) + return self.cur.fetchall() diff --git a/dslr/runner.py b/dslr/runner.py index 6334686..5c4d0e4 100644 --- a/dslr/runner.py +++ b/dslr/runner.py @@ -1,6 +1,9 @@ import os import subprocess from collections import namedtuple +from typing import Any, Dict, List, Optional, Tuple + +from dslr.pg_client import PGClient from .config import settings from .console import console @@ -35,8 +38,31 @@ def exec(*cmd: str) -> Result: console.log("STDOUT:\n", stdout.decode("utf-8"), "\n") console.log("STDERR:\n", stderr.decode("utf-8"), "\n") + # TODO: Make this raise an exception instead return Result( returncode=p.returncode, stdout=stdout.decode("utf-8"), stderr=stderr.decode("utf-8"), ) + + +# Singleton instance of PGClient +pg_client: Optional[PGClient] = None + + +def exec_sql(sql: str, data: Optional[Dict[Any, Any]] = None) -> List[Tuple[Any, ...]]: + """ + Executes a SQL query. + """ + global pg_client + + if not pg_client: + pg_client = PGClient( + host=settings.db.host, + port=settings.db.port, + user=settings.db.username, + password=settings.db.password, + dbname="postgres", + ) + + return pg_client.execute(sql, data) From fdb695c264803a3d35c983109e146d0e33198073 Mon Sep 17 00:00:00 2001 From: Mitchel Cabuloy Date: Sun, 31 Jul 2022 22:27:43 +0800 Subject: [PATCH 2/9] Implement commands with exec_sql --- dslr/operations.py | 61 +++++++++++++--------------------------------- dslr/runner.py | 9 ++++--- tests/test_cli.py | 37 ++++++++++++++-------------- 3 files changed, 41 insertions(+), 66 deletions(-) diff --git a/dslr/operations.py b/dslr/operations.py index 373e137..606c685 100644 --- a/dslr/operations.py +++ b/dslr/operations.py @@ -3,7 +3,7 @@ from typing import List, Optional from .config import settings -from .runner import exec, exec_sql +from .runner import exec_shell, exec_sql Snapshot = namedtuple("Snapshot", ["dbname", "name", "created_at"]) @@ -30,18 +30,12 @@ def kill_connections(dbname: str): """ Kills all connections to the given database """ - result = exec( - "psql", - "-d", - "postgres", - "-c", - "SELECT pg_terminate_backend(pg_stat_activity.pid) " - f"FROM pg_stat_activity WHERE pg_stat_activity.datname = '{dbname}'", + exec_sql( + "SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity " + "WHERE pg_stat_activity.datname = %s", + [dbname], ) - if result.returncode != 0: - raise DSLRException(result.stderr) - def get_snapshots() -> List[Snapshot]: """ @@ -99,22 +93,17 @@ def create_snapshot(snapshot_name: str): """ kill_connections(settings.db.name) - result = exec( - "createdb", "-T", settings.db.name, generate_snapshot_db_name(snapshot_name) + exec_sql( + "CREATE DATABASE %s TEMPLATE %s", + [generate_snapshot_db_name(snapshot_name), settings.db.name], ) - if result.returncode != 0: - raise DSLRException(result.stderr) - def delete_snapshot(snapshot: Snapshot): """ Deletes the given snapshot """ - result = exec("dropdb", snapshot.dbname) - - if result.returncode != 0: - raise DSLRException(result.stderr) + exec_sql("DROP DATABASE %s", [snapshot.dbname]) def restore_snapshot(snapshot: Snapshot): @@ -123,30 +112,15 @@ def restore_snapshot(snapshot: Snapshot): """ kill_connections(settings.db.name) - result = exec("dropdb", settings.db.name) - - if result.returncode != 0: - raise DSLRException(result.stderr) - - result = exec("createdb", "-T", snapshot.dbname, settings.db.name) - - if result.returncode != 0: - raise DSLRException(result.stderr) + exec_sql("DROP DATABASE %s", [settings.db.name]) + exec_sql("CREATE DATABASE %s TEMPLATE %s", [settings.db.name, snapshot.dbname]) def rename_snapshot(snapshot: Snapshot, new_name: str): """ Renames the given snapshot """ - result = exec( - "psql", - "-c", - f'ALTER DATABASE "{snapshot.dbname}" RENAME TO ' - f'"{generate_snapshot_db_name(new_name, snapshot.created_at)}"', - ) - - if result.returncode != 0: - raise DSLRException(result.stderr) + exec_sql("ALTER DATABASE %s RENAME TO %s", [snapshot.dbname, new_name]) def export_snapshot(snapshot: Snapshot) -> str: @@ -154,7 +128,7 @@ def export_snapshot(snapshot: Snapshot) -> str: Exports the given snapshot to a file """ export_path = f"{snapshot.name}_{snapshot.created_at:%Y%m%d-%H%M%S}.dump" - result = exec("pg_dump", "-Fc", "-d", snapshot.dbname, "-f", export_path) + result = exec_shell("pg_dump", "-Fc", "-d", snapshot.dbname, "-f", export_path) if result.returncode != 0: raise DSLRException(result.stderr) @@ -167,12 +141,11 @@ def import_snapshot(import_path: str, snapshot_name: str): Imports the given snapshot from a file """ db_name = generate_snapshot_db_name(snapshot_name) - result = exec("createdb", db_name) - - if result.returncode != 0: - raise DSLRException(result.stderr) + exec_sql("CREATE DATABASE %s", [db_name]) - result = exec("pg_restore", "-d", db_name, "--no-acl", "--no-owner", import_path) + result = exec_shell( + "pg_restore", "-d", db_name, "--no-acl", "--no-owner", import_path + ) if result.returncode != 0: raise DSLRException(result.stderr) diff --git a/dslr/runner.py b/dslr/runner.py index 5c4d0e4..3c2c620 100644 --- a/dslr/runner.py +++ b/dslr/runner.py @@ -1,7 +1,7 @@ import os import subprocess from collections import namedtuple -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, List, Optional, Tuple from dslr.pg_client import PGClient @@ -11,7 +11,7 @@ Result = namedtuple("Result", ["returncode", "stdout", "stderr"]) -def exec(*cmd: str) -> Result: +def exec_shell(*cmd: str) -> Result: """ Executes a command. """ @@ -50,13 +50,16 @@ def exec(*cmd: str) -> Result: pg_client: Optional[PGClient] = None -def exec_sql(sql: str, data: Optional[Dict[Any, Any]] = None) -> List[Tuple[Any, ...]]: +def exec_sql(sql: str, data: Optional[List[Any]] = None) -> List[Tuple[Any, ...]]: """ Executes a SQL query. """ global pg_client if not pg_client: + # We always want to connect to the `postgres` and not the target + # database because all of our operations don't need to query the target + # database. pg_client = PGClient( host=settings.db.host, port=settings.db.port, diff --git a/tests/test_cli.py b/tests/test_cli.py index 277a13e..14625fc 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,5 +1,6 @@ import os from datetime import datetime +from typing import Any, List, Optional, Tuple from unittest import TestCase, mock from click.testing import CliRunner @@ -7,28 +8,25 @@ from dslr import cli, operations, runner -def stub_exec(*cmd: str): - # Set up fake snapshots - if "SELECT datname FROM pg_database" in cmd: - fake_snapshot_1 = operations.generate_snapshot_db_name( - "existing-snapshot-1", - created_at=datetime(2020, 1, 1, 0, 0, 0, 0), - ) - fake_snapshot_2 = operations.generate_snapshot_db_name( - "existing-snapshot-2", - created_at=datetime(2020, 1, 2, 0, 0, 0, 0), - ) - return runner.Result( - returncode=0, - stdout="\n".join([fake_snapshot_1, fake_snapshot_2]), - stderr="", - ) - +def stub_exec_shell(*cmd: str): return runner.Result(returncode=0, stdout="", stderr="") +def stub_exec_sql(sql: str, data: Optional[List[Any]] = None) -> List[Tuple[Any, ...]]: + fake_snapshot_1 = operations.generate_snapshot_db_name( + "existing-snapshot-1", + created_at=datetime(2020, 1, 1, 0, 0, 0, 0), + ) + fake_snapshot_2 = operations.generate_snapshot_db_name( + "existing-snapshot-2", + created_at=datetime(2020, 1, 2, 0, 0, 0, 0), + ) + return [(fake_snapshot_1,), (fake_snapshot_2,)] + + @mock.patch.dict(os.environ, {"DATABASE_URL": "postgres://user:pw@test:5432/my_db"}) -@mock.patch("dslr.operations.exec", new=stub_exec) +@mock.patch("dslr.operations.exec_shell", new=stub_exec_shell) +@mock.patch("dslr.operations.exec_sql", new=stub_exec_sql) class CliTest(TestCase): def test_executes(self): runner = CliRunner() @@ -168,7 +166,8 @@ def test_import_overwrite(self): ) -@mock.patch("dslr.operations.exec", new=stub_exec) +@mock.patch("dslr.operations.exec_shell", new=stub_exec_shell) +@mock.patch("dslr.operations.exec_sql", new=stub_exec_sql) class ConfigTest(TestCase): @mock.patch.dict( os.environ, {"DATABASE_URL": "postgres://envvar:pw@test:5432/my_db"} From 536c9a6ce88c598f164452ed1fa7a435e3b9c77a Mon Sep 17 00:00:00 2001 From: Mitchel Cabuloy Date: Sun, 31 Jul 2022 23:15:45 +0800 Subject: [PATCH 3/9] Refactor operations --- dslr/operations.py | 92 ++++++++++++++++++++++++++++++++-------------- dslr/pg_client.py | 28 ++++++++++++-- dslr/runner.py | 8 +++- tests/test_cli.py | 6 +-- 4 files changed, 99 insertions(+), 35 deletions(-) diff --git a/dslr/operations.py b/dslr/operations.py index 606c685..81a60d9 100644 --- a/dslr/operations.py +++ b/dslr/operations.py @@ -2,16 +2,65 @@ from datetime import datetime from typing import List, Optional +from psycopg2 import sql + from .config import settings from .runner import exec_shell, exec_sql -Snapshot = namedtuple("Snapshot", ["dbname", "name", "created_at"]) - class DSLRException(Exception): pass +################################################################################ +# Database operations +################################################################################ + + +def kill_connections(dbname: str): + """ + Kills all connections to the given database + """ + exec_sql( + "SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity " + "WHERE pg_stat_activity.datname = %s", + [dbname], + ) + + +def create_database(*, dbname: str, template: Optional[str] = None): + """ + Creates a new database with the given name, optionally using the given template + """ + if template: + exec_sql( + sql.SQL("CREATE DATABASE {} TEMPLATE {}").format( + sql.Identifier(dbname), + sql.Identifier(template), + ) + ) + else: + exec_sql( + sql.SQL("CREATE DATABASE {}").format( + sql.Identifier(dbname), + ) + ) + + +def drop_database(dbname: str): + """ + Drops the given database + """ + exec_sql(sql.SQL("DROP DATABASE {}").format(sql.Identifier(dbname))) + + +################################################################################ +# Snapshot operations +################################################################################ + +Snapshot = namedtuple("Snapshot", ["dbname", "name", "created_at"]) + + def generate_snapshot_db_name( snapshot_name: str, created_at: Optional[datetime] = None ) -> str: @@ -26,17 +75,6 @@ def generate_snapshot_db_name( return f"dslr_{timestamp}_{snapshot_name}" -def kill_connections(dbname: str): - """ - Kills all connections to the given database - """ - exec_sql( - "SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity " - "WHERE pg_stat_activity.datname = %s", - [dbname], - ) - - def get_snapshots() -> List[Snapshot]: """ Returns the list of database snapshots @@ -88,14 +126,10 @@ def create_snapshot(snapshot_name: str): Snapshotting works by creating a new database using the local database as a template. - - createdb -T source_db_name dslr__ """ kill_connections(settings.db.name) - - exec_sql( - "CREATE DATABASE %s TEMPLATE %s", - [generate_snapshot_db_name(snapshot_name), settings.db.name], + create_database( + dbname=generate_snapshot_db_name(snapshot_name), template=settings.db.name ) @@ -103,7 +137,7 @@ def delete_snapshot(snapshot: Snapshot): """ Deletes the given snapshot """ - exec_sql("DROP DATABASE %s", [snapshot.dbname]) + drop_database(snapshot.dbname) def restore_snapshot(snapshot: Snapshot): @@ -111,16 +145,20 @@ def restore_snapshot(snapshot: Snapshot): Restores the database from the given snapshot """ kill_connections(settings.db.name) - - exec_sql("DROP DATABASE %s", [settings.db.name]) - exec_sql("CREATE DATABASE %s TEMPLATE %s", [settings.db.name, snapshot.dbname]) + drop_database(settings.db.name) + create_database(dbname=settings.db.name, template=snapshot.dbname) def rename_snapshot(snapshot: Snapshot, new_name: str): """ Renames the given snapshot """ - exec_sql("ALTER DATABASE %s RENAME TO %s", [snapshot.dbname, new_name]) + exec_sql( + sql.SQL("ALTER DATABASE {} RENAME TO {}").format( + sql.Identifier(snapshot.dbname), + sql.Identifier(generate_snapshot_db_name(new_name, snapshot.created_at)), + ) + ) def export_snapshot(snapshot: Snapshot) -> str: @@ -140,11 +178,11 @@ def import_snapshot(import_path: str, snapshot_name: str): """ Imports the given snapshot from a file """ - db_name = generate_snapshot_db_name(snapshot_name) - exec_sql("CREATE DATABASE %s", [db_name]) + dbname = generate_snapshot_db_name(snapshot_name) + create_database(dbname=dbname) result = exec_shell( - "pg_restore", "-d", db_name, "--no-acl", "--no-owner", import_path + "pg_restore", "-d", dbname, "--no-acl", "--no-owner", import_path ) if result.returncode != 0: diff --git a/dslr/pg_client.py b/dslr/pg_client.py index 930d20a..dbc67c4 100644 --- a/dslr/pg_client.py +++ b/dslr/pg_client.py @@ -1,7 +1,11 @@ -from typing import Any, List, Tuple +from typing import Any, List, Optional, Tuple import psycopg2 +from dslr.console import console + +from .config import settings + class PGClient: """ @@ -9,6 +13,12 @@ class PGClient: """ def __init__(self, host, port, user, password, dbname): + self.host = host + self.port = port + self.user = user + self.password = password + self.dbname = dbname + self.conn = psycopg2.connect( host=host, port=port, @@ -16,8 +26,20 @@ def __init__(self, host, port, user, password, dbname): password=password, dbname=dbname, ) + self.conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) + self.cur = self.conn.cursor() - def execute(self, sql, data) -> List[Tuple[Any, ...]]: + def execute(self, sql, data) -> Optional[List[Tuple[Any, ...]]]: + if settings.debug: + console.log(f"SQL: {sql}") + console.log(f"DATA: {data}") + self.cur.execute(sql, data) - return self.cur.fetchall() + + try: + result = self.cur.fetchall() + except psycopg2.ProgrammingError: + result = None + + return result diff --git a/dslr/runner.py b/dslr/runner.py index 3c2c620..2193706 100644 --- a/dslr/runner.py +++ b/dslr/runner.py @@ -1,7 +1,9 @@ import os import subprocess from collections import namedtuple -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, Union + +from psycopg2 import sql from dslr.pg_client import PGClient @@ -50,7 +52,9 @@ def exec_shell(*cmd: str) -> Result: pg_client: Optional[PGClient] = None -def exec_sql(sql: str, data: Optional[List[Any]] = None) -> List[Tuple[Any, ...]]: +def exec_sql( + sql: Union[sql.SQL, str], data: Optional[List[Any]] = None +) -> Optional[List[Tuple[Any, ...]]]: """ Executes a SQL query. """ diff --git a/tests/test_cli.py b/tests/test_cli.py index 14625fc..12d97ee 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,6 +1,6 @@ import os from datetime import datetime -from typing import Any, List, Optional, Tuple +from typing import Any, List, Tuple from unittest import TestCase, mock from click.testing import CliRunner @@ -8,11 +8,11 @@ from dslr import cli, operations, runner -def stub_exec_shell(*cmd: str): +def stub_exec_shell(*args, **kwargs): return runner.Result(returncode=0, stdout="", stderr="") -def stub_exec_sql(sql: str, data: Optional[List[Any]] = None) -> List[Tuple[Any, ...]]: +def stub_exec_sql(*args, **kwargs) -> List[Tuple[Any, ...]]: fake_snapshot_1 = operations.generate_snapshot_db_name( "existing-snapshot-1", created_at=datetime(2020, 1, 1, 0, 0, 0, 0), From 0d1bacf9cb477e8855962bc8786d3d1107a8596b Mon Sep 17 00:00:00 2001 From: Mitchel Cabuloy Date: Sun, 31 Jul 2022 23:26:01 +0800 Subject: [PATCH 4/9] Fix tests --- tests/test_cli.py | 36 ++++++++++++------------------------ 1 file changed, 12 insertions(+), 24 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index 12d97ee..4bf6154 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -166,17 +166,15 @@ def test_import_overwrite(self): ) -@mock.patch("dslr.operations.exec_shell", new=stub_exec_shell) -@mock.patch("dslr.operations.exec_sql", new=stub_exec_sql) +@mock.patch("dslr.cli.get_snapshots") class ConfigTest(TestCase): @mock.patch.dict( os.environ, {"DATABASE_URL": "postgres://envvar:pw@test:5432/my_db"} ) @mock.patch("dslr.cli.settings") - @mock.patch("dslr.operations.settings") - def test_database_url(self, mock_operations_settings, mock_cli_settings): + def test_database_url(self, mock_cli_settings, mock_get_snapshots): runner = CliRunner() - result = runner.invoke(cli.cli, ["snapshot", "my-snapshot"]) + result = runner.invoke(cli.cli, ["list"]) self.assertEqual(result.exit_code, 0) @@ -186,14 +184,13 @@ def test_database_url(self, mock_operations_settings, mock_cli_settings): ) @mock.patch("dslr.cli.settings") - @mock.patch("dslr.operations.settings") - def test_toml(self, mock_operations_settings, mock_cli_settings): + def test_toml(self, mock_cli_settings, mock_get_snapshots): with mock.patch( "builtins.open", mock.mock_open(read_data=b"url = 'postgres://toml:pw@test:5432/my_db'"), ): runner = CliRunner() - result = runner.invoke(cli.cli, ["snapshot", "my-snapshot"]) + result = runner.invoke(cli.cli, ["list"]) self.assertEqual(result.exit_code, 0) @@ -203,12 +200,11 @@ def test_toml(self, mock_operations_settings, mock_cli_settings): ) @mock.patch("dslr.cli.settings") - @mock.patch("dslr.operations.settings") - def test_db_option(self, mock_operations_settings, mock_cli_settings): + def test_db_option(self, mock_cli_settings, mock_get_snapshots): runner = CliRunner() result = runner.invoke( cli.cli, - ["--url", "postgres://cli:pw@test:5432/my_db", "snapshot", "my-snapshot"], + ["--url", "postgres://cli:pw@test:5432/my_db", "list"], ) self.assertEqual(result.exit_code, 0) @@ -219,13 +215,10 @@ def test_db_option(self, mock_operations_settings, mock_cli_settings): ) @mock.patch("dslr.cli.settings") - @mock.patch("dslr.operations.settings") - def test_settings_preference_order( - self, mock_operations_settings, mock_cli_settings - ): + def test_settings_preference_order(self, mock_cli_settings, mock_get_snapshots): # No options passed (e.g. PG environment variables are used) runner = CliRunner() - result = runner.invoke(cli.cli, ["snapshot", "my-snapshot"]) + result = runner.invoke(cli.cli, ["list"]) self.assertEqual(result.exit_code, 0) # DATABASE_URL environment variable is used @@ -233,7 +226,7 @@ def test_settings_preference_order( os.environ, {"DATABASE_URL": "postgres://envvar:pw@test:5432/my_db"} ): runner = CliRunner() - result = runner.invoke(cli.cli, ["snapshot", "my-snapshot"]) + result = runner.invoke(cli.cli, ["list"]) self.assertEqual(result.exit_code, 0) # TOML file is used @@ -242,19 +235,14 @@ def test_settings_preference_order( mock.mock_open(read_data=b"url = 'postgres://toml:pw@test:5432/my_db'"), ): runner = CliRunner() - result = runner.invoke(cli.cli, ["snapshot", "my-snapshot"]) + result = runner.invoke(cli.cli, ["list"]) self.assertEqual(result.exit_code, 0) # --url option is used runner = CliRunner() result = runner.invoke( cli.cli, - [ - "--url", - "postgres://cli:pw@test:5432/my_db", - "snapshot", - "my-snapshot", - ], + ["--url", "postgres://cli:pw@test:5432/my_db", "list"], ) self.assertEqual(result.exit_code, 0) From 1323551dcafd8016c567f868e86a492bda2af254 Mon Sep 17 00:00:00 2001 From: Mitchel Cabuloy Date: Sun, 31 Jul 2022 23:33:36 +0800 Subject: [PATCH 5/9] Update README --- README.md | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index b9c8d88..9f16fc0 100644 --- a/README.md +++ b/README.md @@ -91,28 +91,21 @@ Here's the raw data: ``` -pip install DSLR +pip install DSLR psycopg2 # or psycopg2-binary ```` -DSLR requires that the Postgres client binaries (`psql`, `createdb`, `dropdb`) -are present in your `PATH`. DSLR uses them to interact with Postgres. +Additionally, the DSLR `export` and `import` snapshot commands require `pg_dump` +and `pg_restore` to be present in your `PATH`. ## Configuration You can tell DSLR which database to take snapshots of in a few ways: -**PG\* environment variables** - -If you have the [PG* environment -variables](https://www.postgresql.org/docs/current/libpq-envars.html) set, DSLR -will automatically try to use these in a similar way to `psql`. - **DATABASE_URL** If the `DATABASE_URL` environment variable is set, DSLR will use this to connect -to your target database. DSLR will prefer this over the PG* environment -variables. +to your target database. ```bash export DATABASE_URL=postgres://username:password@host:port/database_name @@ -120,9 +113,8 @@ export DATABASE_URL=postgres://username:password@host:port/database_name **dslr.toml** -If you have a `dslr.toml` file in the same directory where you're running -`dslr`, DSLR will read its settings from it. DSLR will prefer this over the -environment variables. +If a `dslr.toml` file exists in the current directory, DSLR will read its +settings from there. DSLR will prefer this over the environment variable. ```toml url: postgres://username:password@host:port/database_name @@ -135,8 +127,6 @@ This will override any of the above settings. ## Usage -You're ready to use DSLR! - ``` $ dslr snapshot my-first-snapshot Created new snapshot my-first-snapshot From 9d90acf6685400ff027f7e1596a1403a0ba579d1 Mon Sep 17 00:00:00 2001 From: Mitchel Cabuloy Date: Sun, 31 Jul 2022 23:38:17 +0800 Subject: [PATCH 6/9] Add psycopg2-binary requirement on tox --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index a77daf9..21066c6 100644 --- a/tox.ini +++ b/tox.ini @@ -17,6 +17,7 @@ python = 3.10: py310, flake8, black, isort [testenv] +deps = psycopg2-binary commands = python -m unittest From dd72964e53e4ee951be7c5a32e50362abf056142 Mon Sep 17 00:00:00 2001 From: Mitchel Cabuloy Date: Sun, 31 Jul 2022 23:44:35 +0800 Subject: [PATCH 7/9] Remove DSLRException --- dslr/operations.py | 17 ++--------------- dslr/runner.py | 7 ++++--- 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/dslr/operations.py b/dslr/operations.py index 81a60d9..59f572d 100644 --- a/dslr/operations.py +++ b/dslr/operations.py @@ -7,11 +7,6 @@ from .config import settings from .runner import exec_shell, exec_sql - -class DSLRException(Exception): - pass - - ################################################################################ # Database operations ################################################################################ @@ -166,10 +161,7 @@ def export_snapshot(snapshot: Snapshot) -> str: Exports the given snapshot to a file """ export_path = f"{snapshot.name}_{snapshot.created_at:%Y%m%d-%H%M%S}.dump" - result = exec_shell("pg_dump", "-Fc", "-d", snapshot.dbname, "-f", export_path) - - if result.returncode != 0: - raise DSLRException(result.stderr) + exec_shell("pg_dump", "-Fc", "-d", snapshot.dbname, "-f", export_path) return export_path @@ -181,9 +173,4 @@ def import_snapshot(import_path: str, snapshot_name: str): dbname = generate_snapshot_db_name(snapshot_name) create_database(dbname=dbname) - result = exec_shell( - "pg_restore", "-d", dbname, "--no-acl", "--no-owner", import_path - ) - - if result.returncode != 0: - raise DSLRException(result.stderr) + exec_shell("pg_restore", "-d", dbname, "--no-acl", "--no-owner", import_path) diff --git a/dslr/runner.py b/dslr/runner.py index 2193706..d78a321 100644 --- a/dslr/runner.py +++ b/dslr/runner.py @@ -10,7 +10,7 @@ from .config import settings from .console import console -Result = namedtuple("Result", ["returncode", "stdout", "stderr"]) +Result = namedtuple("Result", ["stdout", "stderr"]) def exec_shell(*cmd: str) -> Result: @@ -40,9 +40,10 @@ def exec_shell(*cmd: str) -> Result: console.log("STDOUT:\n", stdout.decode("utf-8"), "\n") console.log("STDERR:\n", stderr.decode("utf-8"), "\n") - # TODO: Make this raise an exception instead + if p.returncode != 0: + raise RuntimeError(f"Command failed: {cmd}") + return Result( - returncode=p.returncode, stdout=stdout.decode("utf-8"), stderr=stderr.decode("utf-8"), ) From dd668fb7cd993dc0ada19fb3bd4bff2a275bb850 Mon Sep 17 00:00:00 2001 From: Mitchel Cabuloy Date: Sun, 31 Jul 2022 23:52:10 +0800 Subject: [PATCH 8/9] Fix tests --- tests/test_cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index 4bf6154..631f98e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -9,7 +9,7 @@ def stub_exec_shell(*args, **kwargs): - return runner.Result(returncode=0, stdout="", stderr="") + return runner.Result(stdout="", stderr="") def stub_exec_sql(*args, **kwargs) -> List[Tuple[Any, ...]]: From 68a26e2b5e0291f7443b17392955f377b2ac28a4 Mon Sep 17 00:00:00 2001 From: Mitchel Cabuloy Date: Mon, 1 Aug 2022 00:00:36 +0800 Subject: [PATCH 9/9] Update comment --- dslr/runner.py | 2 +- tests/test_cli.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dslr/runner.py b/dslr/runner.py index d78a321..cc48bb8 100644 --- a/dslr/runner.py +++ b/dslr/runner.py @@ -63,7 +63,7 @@ def exec_sql( if not pg_client: # We always want to connect to the `postgres` and not the target - # database because all of our operations don't need to query the target + # database because none of our operations need to query the target # database. pg_client = PGClient( host=settings.db.host, diff --git a/tests/test_cli.py b/tests/test_cli.py index 631f98e..35a9e1b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -8,7 +8,7 @@ from dslr import cli, operations, runner -def stub_exec_shell(*args, **kwargs): +def stub_exec_shell(*args, **kwargs) -> runner.Result: return runner.Result(stdout="", stderr="")