Skip to content

Commit

Permalink
Merge pull request #2 from mixxorz/feature/psycopg2
Browse files Browse the repository at this point in the history
Use psycopg2 instead of shell commands
  • Loading branch information
mixxorz authored Jul 31, 2022
2 parents 3a0ab45 + 68a26e2 commit c8af0a8
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 145 deletions.
22 changes: 6 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,38 +91,30 @@ Here's the raw data:

```
pip install DSLR
pip install DSLR psycopg2 # or psycopg2-binary
````
DSLR requires that the Postgres client binaries (`psql`, `createdb`, `dropdb`)
are present in your `PATH`. DSLR uses them to interact with Postgres.
Additionally, the DSLR `export` and `import` snapshot commands require `pg_dump`
and `pg_restore` to be present in your `PATH`.
## Configuration
You can tell DSLR which database to take snapshots of in a few ways:
**PG\* environment variables**
If you have the [PG* environment
variables](https://www.postgresql.org/docs/current/libpq-envars.html) set, DSLR
will automatically try to use these in a similar way to `psql`.
**DATABASE_URL**
If the `DATABASE_URL` environment variable is set, DSLR will use this to connect
to your target database. DSLR will prefer this over the PG* environment
variables.
to your target database.
```bash
export DATABASE_URL=postgres://username:password@host:port/database_name
````
**dslr.toml**
If you have a `dslr.toml` file in the same directory where you're running
`dslr`, DSLR will read its settings from it. DSLR will prefer this over the
environment variables.
If a `dslr.toml` file exists in the current directory, DSLR will read its
settings from there. DSLR will prefer this over the environment variable.
```toml
url: postgres://username:password@host:port/database_name
Expand All @@ -135,8 +127,6 @@ This will override any of the above settings.

## Usage

You're ready to use DSLR!

```
$ dslr snapshot my-first-snapshot
Created new snapshot my-first-snapshot
Expand Down
15 changes: 7 additions & 8 deletions dslr/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from .config import settings
from .console import console, cprint, eprint
from .operations import (
DSLRException,
SnapshotNotFound,
create_snapshot,
delete_snapshot,
Expand Down Expand Up @@ -100,7 +99,7 @@ def snapshot(name: str):
try:
with console.status("Creating snapshot"):
create_snapshot(name)
except DSLRException as e:
except Exception as e:
eprint("Failed to create snapshot")
eprint(e, style="white")
sys.exit(1)
Expand All @@ -126,7 +125,7 @@ def restore(name):
with console.status("Restoring snapshot"):
try:
restore_snapshot(snapshot)
except DSLRException as e:
except Exception as e:
eprint("Failed to restore snapshot")
eprint(e, style="white")
sys.exit(1)
Expand All @@ -141,7 +140,7 @@ def list():
"""
try:
snapshots = get_snapshots()
except DSLRException as e:
except Exception as e:
eprint("Failed to list snapshots")
eprint(f"{e}", style="white")
sys.exit(1)
Expand Down Expand Up @@ -174,7 +173,7 @@ def delete(name):

try:
delete_snapshot(snapshot)
except DSLRException as e:
except Exception as e:
eprint("Failed to delete snapshot")
eprint(e, style="white")
sys.exit(1)
Expand Down Expand Up @@ -212,7 +211,7 @@ def rename(old_name, new_name):

try:
rename_snapshot(old_snapshot, new_name)
except DSLRException as e:
except Exception as e:
eprint("Failed to rename snapshot")
eprint(e, style="white")
sys.exit(1)
Expand All @@ -235,7 +234,7 @@ def export(name):
try:
with console.status("Exporting snapshot"):
export_path = export_snapshot(snapshot)
except DSLRException as e:
except Exception as e:
eprint("Failed to export snapshot")
eprint(e, style="white")
sys.exit(1)
Expand Down Expand Up @@ -269,7 +268,7 @@ def import_(filename, name):
try:
with console.status("Importing snapshot"):
import_snapshot(filename, name)
except DSLRException as e:
except Exception as e:
eprint("Failed to import snapshot")
eprint(e, style="white")
sys.exit(1)
Expand Down
143 changes: 66 additions & 77 deletions dslr/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,58 @@
from datetime import datetime
from typing import List, Optional

from psycopg2 import sql

from .config import settings
from .runner import exec
from .runner import exec_shell, exec_sql

Snapshot = namedtuple("Snapshot", ["dbname", "name", "created_at"])
################################################################################
# Database operations
################################################################################


class DSLRException(Exception):
pass
def kill_connections(dbname: str):
"""
Kills all connections to the given database
"""
exec_sql(
"SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity "
"WHERE pg_stat_activity.datname = %s",
[dbname],
)


def create_database(*, dbname: str, template: Optional[str] = None):
"""
Creates a new database with the given name, optionally using the given template
"""
if template:
exec_sql(
sql.SQL("CREATE DATABASE {} TEMPLATE {}").format(
sql.Identifier(dbname),
sql.Identifier(template),
)
)
else:
exec_sql(
sql.SQL("CREATE DATABASE {}").format(
sql.Identifier(dbname),
)
)


def drop_database(dbname: str):
"""
Drops the given database
"""
exec_sql(sql.SQL("DROP DATABASE {}").format(sql.Identifier(dbname)))


################################################################################
# Snapshot operations
################################################################################

Snapshot = namedtuple("Snapshot", ["dbname", "name", "created_at"])


def generate_snapshot_db_name(
Expand All @@ -26,23 +70,6 @@ def generate_snapshot_db_name(
return f"dslr_{timestamp}_{snapshot_name}"


def kill_connections(dbname: str):
"""
Kills all connections to the given database
"""
result = exec(
"psql",
"-d",
"postgres",
"-c",
"SELECT pg_terminate_backend(pg_stat_activity.pid) "
f"FROM pg_stat_activity WHERE pg_stat_activity.datname = '{dbname}'",
)

if result.returncode != 0:
raise DSLRException(result.stderr)


def get_snapshots() -> List[Snapshot]:
"""
Returns the list of database snapshots
Expand All @@ -52,28 +79,19 @@ def get_snapshots() -> List[Snapshot]:
dslr_<timestamp>_<snapshot_name>
"""
# Find the snapshot databases
result = exec("psql", "-c", "SELECT datname FROM pg_database")

if result.returncode != 0:
raise DSLRException(result.stderr)
result = exec_sql("SELECT datname FROM pg_database WHERE datname LIKE 'dslr_%'")

lines = sorted(
[
line.strip()
for line in result.stdout.split("\n")
if line.strip().startswith("dslr_")
]
)
snapshot_dbnames = sorted([row[0] for row in result])

# Parse the name into a Snapshot
parts = [line.split("_") for line in lines]
parts = [dbname.split("_") for dbname in snapshot_dbnames]
return [
Snapshot(
dbname=line,
name="_".join(part[2:]),
created_at=datetime.fromtimestamp(int(part[1])),
)
for part, line in zip(parts, lines)
for part, line in zip(parts, snapshot_dbnames)
]


Expand Down Expand Up @@ -103,70 +121,47 @@ def create_snapshot(snapshot_name: str):
Snapshotting works by creating a new database using the local database as a
template.
createdb -T source_db_name dslr_<timestamp>_<name>
"""
kill_connections(settings.db.name)

result = exec(
"createdb", "-T", settings.db.name, generate_snapshot_db_name(snapshot_name)
create_database(
dbname=generate_snapshot_db_name(snapshot_name), template=settings.db.name
)

if result.returncode != 0:
raise DSLRException(result.stderr)


def delete_snapshot(snapshot: Snapshot):
"""
Deletes the given snapshot
"""
result = exec("dropdb", snapshot.dbname)

if result.returncode != 0:
raise DSLRException(result.stderr)
drop_database(snapshot.dbname)


def restore_snapshot(snapshot: Snapshot):
"""
Restores the database from the given snapshot
"""
kill_connections(settings.db.name)

result = exec("dropdb", settings.db.name)

if result.returncode != 0:
raise DSLRException(result.stderr)

result = exec("createdb", "-T", snapshot.dbname, settings.db.name)

if result.returncode != 0:
raise DSLRException(result.stderr)
drop_database(settings.db.name)
create_database(dbname=settings.db.name, template=snapshot.dbname)


def rename_snapshot(snapshot: Snapshot, new_name: str):
"""
Renames the given snapshot
"""
result = exec(
"psql",
"-c",
f'ALTER DATABASE "{snapshot.dbname}" RENAME TO '
f'"{generate_snapshot_db_name(new_name, snapshot.created_at)}"',
exec_sql(
sql.SQL("ALTER DATABASE {} RENAME TO {}").format(
sql.Identifier(snapshot.dbname),
sql.Identifier(generate_snapshot_db_name(new_name, snapshot.created_at)),
)
)

if result.returncode != 0:
raise DSLRException(result.stderr)


def export_snapshot(snapshot: Snapshot) -> str:
"""
Exports the given snapshot to a file
"""
export_path = f"{snapshot.name}_{snapshot.created_at:%Y%m%d-%H%M%S}.dump"
result = exec("pg_dump", "-Fc", "-d", snapshot.dbname, "-f", export_path)

if result.returncode != 0:
raise DSLRException(result.stderr)
exec_shell("pg_dump", "-Fc", "-d", snapshot.dbname, "-f", export_path)

return export_path

Expand All @@ -175,13 +170,7 @@ def import_snapshot(import_path: str, snapshot_name: str):
"""
Imports the given snapshot from a file
"""
db_name = generate_snapshot_db_name(snapshot_name)
result = exec("createdb", db_name)

if result.returncode != 0:
raise DSLRException(result.stderr)

result = exec("pg_restore", "-d", db_name, "--no-acl", "--no-owner", import_path)
dbname = generate_snapshot_db_name(snapshot_name)
create_database(dbname=dbname)

if result.returncode != 0:
raise DSLRException(result.stderr)
exec_shell("pg_restore", "-d", dbname, "--no-acl", "--no-owner", import_path)
45 changes: 45 additions & 0 deletions dslr/pg_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Any, List, Optional, Tuple

import psycopg2

from dslr.console import console

from .config import settings


class PGClient:
"""
Thin wrapper around psycopg2
"""

def __init__(self, host, port, user, password, dbname):
self.host = host
self.port = port
self.user = user
self.password = password
self.dbname = dbname

self.conn = psycopg2.connect(
host=host,
port=port,
user=user,
password=password,
dbname=dbname,
)
self.conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)

self.cur = self.conn.cursor()

def execute(self, sql, data) -> Optional[List[Tuple[Any, ...]]]:
if settings.debug:
console.log(f"SQL: {sql}")
console.log(f"DATA: {data}")

self.cur.execute(sql, data)

try:
result = self.cur.fetchall()
except psycopg2.ProgrammingError:
result = None

return result
Loading

0 comments on commit c8af0a8

Please sign in to comment.