From d2a2ef7189655fdd1326ca9d987b6588eed64251 Mon Sep 17 00:00:00 2001 From: Dan Lemon Date: Fri, 22 May 2026 08:52:50 +0200 Subject: [PATCH 1/5] chore: add restore database script --- scripts/restore_database.py | 680 ++++++++++++++++++++++++++++++++++++ 1 file changed, 680 insertions(+) create mode 100755 scripts/restore_database.py diff --git a/scripts/restore_database.py b/scripts/restore_database.py new file mode 100755 index 00000000..33f86e7a --- /dev/null +++ b/scripts/restore_database.py @@ -0,0 +1,680 @@ +#!/usr/bin/env python3 +""" +Restore a PostgreSQL database from a Lagoon backup (.tar.gz). + +Lagoon backups are .tar.gz files containing a pg_dump directory-format archive: + .tar.gz -> .tar -> (*.dat files, restore.sql, toc.dat) + +The restore.sql contains $$PATH$$ placeholders for .dat file paths. This script +extracts the archive, replaces $$PATH$$, and streams data via COPY FROM STDIN +(since server-side COPY FROM file requires superuser privileges). + +IMPORTANT: Only use this with backups from trusted sources (e.g. Lagoon). 
+ +Usage: + python scripts/restore_database.py /path/to/backup.tar.gz +""" + +import argparse +import os +import pathlib +import re +import shutil +import sys +import tarfile +import tempfile +from urllib.parse import urlparse + +import psycopg2 +from psycopg2 import sql +from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT + + +def sanitize_error(msg, config): + """Remove sensitive data (password) from error messages.""" + if config.get("password"): + msg = msg.replace(config["password"], "****") + return msg + + +def get_db_config(): + """Parse DATABASE_URL into connection components.""" + database_url = os.getenv( + "DATABASE_URL", "postgres://postgres:postgres@postgres:5432/postgres_service" + ) + parsed = urlparse(database_url) + return { + "host": parsed.hostname, + "port": parsed.port or 5432, + "user": parsed.username, + "password": parsed.password, + "database": parsed.path.lstrip("/"), + } + + +def connect_to_maintenance_db(config): + """Connect to the 'postgres' maintenance database (not the target DB).""" + conn = psycopg2.connect( + host=config["host"], + port=config["port"], + user=config["user"], + password=config["password"], + dbname="postgres", + connect_timeout=10, + ) + conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) + return conn + + +def connect_to_target_db(config): + """Connect to the target database.""" + return psycopg2.connect( + host=config["host"], + port=config["port"], + user=config["user"], + password=config["password"], + dbname=config["database"], + connect_timeout=10, + ) + + +def backup_current_database(config, backup_path): + """Create a best-effort dump (schema + data) of the current database as a safety net. + + Dumps public-schema objects: tables (columns, defaults, NOT NULL), + constraints (PK, unique, FK, check), indexes, sequences, views, functions, + triggers, and data via COPY TO STDOUT. 
+ + NOT included: extensions, custom types/enums/domains, roles, grants/ACLs, + RLS policies, non-public schemas, tablespaces, or publication/subscription + config. For a fully faithful backup, use pg_dump externally. + + The output is replayable with psql for the common case but may require + manual adjustment for databases using the above features. + """ + print(f" Backing up current database to file: {os.path.basename(backup_path)}") + + conn = connect_to_target_db(config) + try: + cur = conn.cursor() + + # Open with restrictive permissions (binary mode required for copy_expert) + fd = os.open(backup_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600) + with os.fdopen(fd, "wb") as f: + f.write("-- Pre-restore safety backup\n") + f.write("-- Generated by restore_database.py\n") + f.write("-- Replay with: psql $DATABASE_URL < this_file.sql\n\n") + + # --- Sequences (must come before tables for DEFAULT references) --- + cur.execute( + "SELECT sequence_name FROM information_schema.sequences " + "WHERE sequence_schema = 'public' ORDER BY sequence_name;" + ) + sequences = [row[0] for row in cur.fetchall()] + for seq in sequences: + cur.execute( + "SELECT last_value, is_called FROM pg_sequences " + "WHERE schemaname = 'public' AND sequencename = %s;", + (seq,), + ) + row = cur.fetchone() + f.write(f'CREATE SEQUENCE IF NOT EXISTS "{seq}";\n') + if row and row[0] is not None: + is_called = "true" if row[1] else "false" + f.write(f"SELECT setval('\"{seq}\"', {row[0]}, {is_called});\n") + f.write("\n") + + # --- Tables (full DDL from pg_catalog) --- + cur.execute( + "SELECT tablename FROM pg_tables " + "WHERE schemaname = 'public' ORDER BY tablename;" + ) + tables = [row[0] for row in cur.fetchall()] + + for table in tables: + # Get column definitions including proper types and defaults + cur.execute(""" + SELECT + a.attname, + pg_catalog.format_type(a.atttypid, a.atttypmod), + a.attnotnull, + pg_catalog.pg_get_expr(d.adbin, d.adrelid) + FROM pg_catalog.pg_attribute a + 
LEFT JOIN pg_catalog.pg_attrdef d + ON d.adrelid = a.attrelid AND d.adnum = a.attnum + WHERE a.attrelid = %s::regclass + AND a.attnum > 0 + AND NOT a.attisdropped + ORDER BY a.attnum; + """, (f'public."{table}"',)) + columns = cur.fetchall() + + col_defs = [] + for col_name, col_type, not_null, default in columns: + parts = [f' "{col_name}" {col_type}'] + if default: + parts.append(f"DEFAULT {default}") + if not_null: + parts.append("NOT NULL") + col_defs.append(" ".join(parts)) + + f.write(f'CREATE TABLE IF NOT EXISTS "{table}" (\n') + f.write(",\n".join(col_defs)) + f.write("\n);\n\n") + + # --- Primary keys and unique constraints --- + cur.execute(""" + SELECT + conname, + conrelid::regclass::text, + pg_catalog.pg_get_constraintdef(oid) + FROM pg_catalog.pg_constraint + WHERE connamespace = 'public'::regnamespace + AND contype IN ('p', 'u') + ORDER BY conrelid::regclass::text, conname; + """) + for con_name, table_name, con_def in cur.fetchall(): + f.write(f'ALTER TABLE {table_name} ADD CONSTRAINT "{con_name}" {con_def};\n') + f.write("\n") + + # --- Foreign keys --- + cur.execute(""" + SELECT + conname, + conrelid::regclass::text, + pg_catalog.pg_get_constraintdef(oid) + FROM pg_catalog.pg_constraint + WHERE connamespace = 'public'::regnamespace + AND contype = 'f' + ORDER BY conrelid::regclass::text, conname; + """) + for con_name, table_name, con_def in cur.fetchall(): + f.write(f'ALTER TABLE {table_name} ADD CONSTRAINT "{con_name}" {con_def};\n') + f.write("\n") + + # --- Check constraints (excluding system-generated NOT NULL checks) --- + cur.execute(""" + SELECT + conname, + conrelid::regclass::text, + pg_catalog.pg_get_constraintdef(oid) + FROM pg_catalog.pg_constraint + WHERE connamespace = 'public'::regnamespace + AND contype = 'c' + ORDER BY conrelid::regclass::text, conname; + """) + for con_name, table_name, con_def in cur.fetchall(): + f.write(f'ALTER TABLE {table_name} ADD CONSTRAINT "{con_name}" {con_def};\n') + f.write("\n") + + # --- Indexes 
(excluding those backing constraints) --- + cur.execute(""" + SELECT pg_catalog.pg_get_indexdef(i.indexrelid) + FROM pg_catalog.pg_index i + JOIN pg_catalog.pg_class c ON c.oid = i.indexrelid + JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE n.nspname = 'public' + AND NOT i.indisprimary + AND NOT EXISTS ( + SELECT 1 FROM pg_catalog.pg_constraint con + WHERE con.conindid = i.indexrelid + ) + ORDER BY c.relname; + """) + for (index_def,) in cur.fetchall(): + f.write(f"{index_def};\n") + f.write("\n") + + # --- Views --- + cur.execute( + "SELECT viewname, definition FROM pg_views " + "WHERE schemaname = 'public' ORDER BY viewname;" + ) + for view_name, view_def in cur.fetchall(): + f.write(f'CREATE OR REPLACE VIEW "{view_name}" AS\n{view_def}\n\n') + + # --- Functions --- + cur.execute(""" + SELECT pg_catalog.pg_get_functiondef(p.oid) + FROM pg_catalog.pg_proc p + JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace + WHERE n.nspname = 'public' + ORDER BY p.proname; + """) + for (func_def,) in cur.fetchall(): + f.write(f"{func_def};\n\n") + + # --- Triggers --- + cur.execute(""" + SELECT pg_catalog.pg_get_triggerdef(t.oid) + FROM pg_catalog.pg_trigger t + JOIN pg_catalog.pg_class c ON c.oid = t.tgrelid + JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE n.nspname = 'public' + AND NOT t.tgisinternal + ORDER BY c.relname, t.tgname; + """) + for (trig_def,) in cur.fetchall(): + f.write(f"{trig_def};\n") + f.write("\n") + + # --- Table data via COPY --- + for table in tables: + copy_out = sql.SQL("COPY {} TO STDOUT").format( + sql.Identifier(table) + ) + f.write(f'COPY "{table}" FROM stdin;\n') + cur.copy_expert(copy_out.as_string(conn), f) + f.write("\\.\n\n") + + # --- Sequence ownership (link sequences to columns) --- + cur.execute(""" + SELECT + s.relname, + t.relname, + a.attname + FROM pg_catalog.pg_class s + JOIN pg_catalog.pg_namespace n ON n.oid = s.relnamespace + JOIN pg_catalog.pg_depend d ON d.objid = s.oid + JOIN 
pg_catalog.pg_class t ON t.oid = d.refobjid + JOIN pg_catalog.pg_attribute a + ON a.attrelid = d.refobjid AND a.attnum = d.refobjsubid + WHERE s.relkind = 'S' + AND n.nspname = 'public' + AND d.deptype = 'a' + ORDER BY s.relname; + """) + for seq_name, table_name, col_name in cur.fetchall(): + f.write( + f'ALTER SEQUENCE "{seq_name}" OWNED BY "{table_name}"."{col_name}";\n' + ) + f.write("\n") + + print(f" Backup complete ({len(tables)} tables, {len(sequences)} sequences)") + finally: + conn.close() + + return backup_path + + +def drop_and_recreate_database(config): + """Drop the target database and recreate it empty.""" + db_name = config["database"] + + if db_name == "postgres": + raise SystemExit( + "Error: Cannot drop the 'postgres' maintenance database. " + "Set DATABASE_URL to point to a different target database." + ) + + conn = connect_to_maintenance_db(config) + cur = conn.cursor() + + print(" Terminating existing connections to target database...") + cur.execute( + "SELECT pg_terminate_backend(pid) FROM pg_stat_activity " + "WHERE datname = %s AND pid <> pg_backend_pid();", + (db_name,), + ) + + print(" Dropping target database if it exists...") + cur.execute(sql.SQL("DROP DATABASE IF EXISTS {}").format(sql.Identifier(db_name))) + + print(" Creating target database...") + cur.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(db_name))) + + cur.close() + conn.close() + print(" Database recreated.") + + +def safe_extract_tar(tar, extract_dir): + """Extract tar members, validating paths stay within extract_dir.""" + extract_path = pathlib.Path(extract_dir).resolve() + safe_members = [] + for member in tar.getmembers(): + if member.issym() or member.islnk(): + raise ValueError(f"Link entries are not allowed in archive: {member.name}") + member_path = (extract_path / member.name).resolve() + if not member_path.is_relative_to(extract_path): + raise ValueError(f"Path traversal detected in archive: {member.name}") + safe_members.append(member) + 
tar.extractall(path=extract_dir, members=safe_members) + + +def extract_backup(backup_path, extract_dir): + """Extract .tar.gz backup to a directory. + + Returns the path to the directory containing restore.sql and .dat files. + Raises SystemExit if restore.sql cannot be found. + """ + print(f" Extracting {backup_path}...") + + with tarfile.open(backup_path, "r:gz") as tar: + safe_extract_tar(tar, extract_dir) + + # The .tar.gz may contain a nested .tar, or directly the dump files + contents = os.listdir(extract_dir) + for item in contents: + item_path = os.path.join(extract_dir, item) + if item.endswith(".tar") and tarfile.is_tarfile(item_path): + nested_dir = os.path.join(extract_dir, "dump") + os.makedirs(nested_dir, exist_ok=True) + with tarfile.open(item_path, "r:") as tar: + safe_extract_tar(tar, nested_dir) + contents = os.listdir(nested_dir) + # Check nested dir and its subdirectories + if "restore.sql" in contents: + return nested_dir + for sub in contents: + sub_path = os.path.join(nested_dir, sub) + if os.path.isdir(sub_path) and "restore.sql" in os.listdir(sub_path): + return sub_path + print(" Error: Could not locate restore.sql in nested archive.") + print(f" Nested contents: {contents}") + sys.exit(1) + + # Check if restore.sql is directly in the extract dir + if "restore.sql" in contents: + return extract_dir + + # Look one level deeper + for item in contents: + item_path = os.path.join(extract_dir, item) + if os.path.isdir(item_path) and "restore.sql" in os.listdir(item_path): + return item_path + + print(" Error: Could not locate restore.sql in archive.") + print(f" Top-level contents: {contents}") + sys.exit(1) + + +def apply_restore(config, dump_dir): + """Apply the database restore from the extracted dump directory. + + Strategy: + 1. Parse restore.sql to separate DDL/schema statements from COPY commands. + 2. Execute DDL statements directly. + 3. For COPY ... FROM '$$PATH$$/xxx.dat' statements, convert to + COPY ... 
FROM STDIN and stream the .dat file contents. This avoids + requiring superuser privileges on the database. + """ + restore_sql_path = os.path.join(dump_dir, "restore.sql") + + if not os.path.exists(restore_sql_path): + print(f" Error: restore.sql not found in {dump_dir}") + sys.exit(1) + + with open(restore_sql_path, "r") as f: + content = f.read() + + # Sanity check + if "PostgreSQL database dump" not in content[:500]: + print(" Warning: restore.sql does not look like a pg_dump file.") + print(f" Header: {content[:100]}") + + abs_dump_dir = os.path.abspath(dump_dir) + + # Remove empty COPY ... FROM stdin blocks (tables with no data): + # COPY public.table (...) FROM stdin; + # \. + content = re.sub( + r"^COPY\s+.+?\s+FROM\s+stdin;\s*\n\\\.\s*\n", + "", + content, + flags=re.MULTILINE, + ) + + # Pattern matching COPY ... FROM '$$PATH$$/filename.dat'; + copy_file_pattern = re.compile( + r"^(COPY\s+\S+\s*\([^)]*\))\s+FROM\s+'\$\$PATH\$\$/([^']+)';?\s*$", + re.MULTILINE, + ) + + # Split content into segments: (text_before, copy_stmt, dat_filename, text_after, ...) + # We process sequentially: run DDL, then when we hit a COPY-from-file, stream it. + segments = copy_file_pattern.split(content) + # split gives: [before, group1, group2, between, group1, group2, ..., after] + # groups come in pairs: (copy_prefix, dat_filename) + + conn = connect_to_target_db(config) + conn.autocommit = True + cur = conn.cursor() + + applied = 0 + data_loaded = 0 + warnings = 0 + + try: + i = 0 + while i < len(segments): + if i % 3 == 0: + # This is a DDL/non-COPY text segment + ddl_block = segments[i].strip() + if ddl_block: + # Execute line by line for statements, but we need to handle + # multi-line statements. Use psycopg2's execute which handles + # multiple statements in one call when separated by semicolons. 
+ try: + cur.execute(ddl_block) + applied += 1 + except psycopg2.Error as e: + msg = e.pgerror or (e.diag.message_primary if e.diag else None) or str(e) + short_msg = sanitize_error(msg.strip().split("\n")[0][:150], config) + # Fail fast on critical DDL (CREATE/ALTER TABLE, CREATE TYPE, etc.) + # Allow non-critical statements (COMMENT, GRANT, SET, ALTER OWNER) to warn-and-continue + non_critical_prefixes = ("COMMENT ", "GRANT ", "REVOKE ", "SET ", "ALTER OWNER", "SELECT ") + block_upper = ddl_block.lstrip().upper() + is_non_critical = any(block_upper.startswith(p) for p in non_critical_prefixes) + if is_non_critical: + print(f" Warning: {short_msg}") + warnings += 1 + else: + print(f" Error (fatal): {short_msg}") + raise SystemExit( + "Critical DDL statement failed. Restore aborted. " + "The database has been dropped and is now empty. " + "Re-run the restore or recover from the safety backup." + ) + i += 1 + else: + # This is a COPY group: segments[i] = copy_prefix, segments[i+1] = dat_filename + copy_prefix = segments[i] + dat_filename = segments[i + 1] + dat_path = os.path.join(abs_dump_dir, dat_filename) + + # Validate dat file path stays within dump directory + resolved_dat = pathlib.Path(dat_path).resolve() + if not resolved_dat.is_relative_to(pathlib.Path(abs_dump_dir).resolve()): + print(f" Error: Path traversal in restore.sql: {dat_filename}") + warnings += 1 + i += 2 + continue + + if os.path.exists(dat_path): + # Convert to COPY ... 
FROM STDIN and stream the file + copy_stdin_sql = f"{copy_prefix} FROM STDIN" + try: + with open(dat_path, "rb") as dat_f: + cur.copy_expert(copy_stdin_sql, dat_f) + data_loaded += 1 + except psycopg2.Error as e: + msg = sanitize_error(str(e).strip().split("\n")[0][:150], config) + print(f" Warning ({dat_filename}): {msg}") + warnings += 1 + else: + print(f" Warning: Data file not found: {dat_filename}") + warnings += 1 + + i += 2 + + print(f" Restore complete: {applied} DDL sections, {data_loaded} tables loaded, {warnings} warnings.") + finally: + cur.close() + conn.close() + + return warnings + + +def check_disk_space(path, required_mb=500): + """Check if there's enough disk space at the given path.""" + stat = os.statvfs(path) + available_mb = (stat.f_bavail * stat.f_frsize) / (1024 * 1024) + if available_mb < required_mb: + print(f" Warning: Only {available_mb:.0f}MB available at {path} (recommend >= {required_mb}MB)") + return False + return True + + +def main(): + parser = argparse.ArgumentParser( + description="Restore a PostgreSQL database from a Lagoon .tar.gz backup." 
+ ) + parser.add_argument( + "backup_file", + help="Path to the .tar.gz backup file", + ) + parser.add_argument( + "--no-backup", + action="store_true", + help="Skip creating a backup of the current database before restoring", + ) + parser.add_argument( + "--backup-dir", + default=None, + help="Directory to store the pre-restore safety backup (default: next to backup file)", + ) + parser.add_argument( + "--extract-dir", + default=None, + help="Directory to extract the backup into (default: system temp directory)", + ) + parser.add_argument( + "--yes", "-y", + action="store_true", + help="Skip confirmation prompt", + ) + args = parser.parse_args() + + # Validate backup file exists and is .tar.gz + if not os.path.exists(args.backup_file): + print(f"Error: Backup file not found: {args.backup_file}") + sys.exit(1) + + if not args.backup_file.endswith(".tar.gz"): + print(f"Error: Backup file must be a .tar.gz file: {args.backup_file}") + sys.exit(1) + + # Validate extract dir if provided, otherwise check tmp is writable + if args.extract_dir: + if not os.path.isdir(args.extract_dir): + print(f"Error: Extract directory does not exist: {args.extract_dir}") + sys.exit(1) + if not os.access(args.extract_dir, os.W_OK): + print(f"Error: Extract directory is not writable: {args.extract_dir}") + sys.exit(1) + extract_base = args.extract_dir + else: + extract_base = tempfile.gettempdir() + if not os.access(extract_base, os.W_OK): + print(f"Error: Temp directory is not writable: {extract_base}") + print(" Use --extract-dir to specify an alternative directory.") + sys.exit(1) + + # Check disk space (estimate: backup file * 3 for extraction headroom) + backup_size_mb = os.path.getsize(args.backup_file) / (1024 * 1024) + recommended_mb = max(500, int(backup_size_mb * 3)) + if not check_disk_space(extract_base, recommended_mb): + if not args.yes: + response = input(" Continue anyway? 
[yes/no]: ").strip().lower() + if response not in ("yes", "y"): + print("Aborted.") + sys.exit(0) + + config = get_db_config() + + print("\nDatabase restore configuration:") + print(f" Backup file : {args.backup_file}") + print(" Target host : [redacted]") + print(" Target DB : [redacted]") + print(" User : [redacted]") + print() + + if not args.yes: + print("WARNING: This will DROP the existing database and restore from backup.") + if not args.no_backup: + print("A safety backup of the current database will be created first.") + else: + print("No safety backup will be created (--no-backup specified).") + response = input("\nAre you sure you want to proceed? [yes/no]: ").strip().lower() + if response not in ("yes", "y"): + print("Aborted.") + sys.exit(0) + + tmpdir = tempfile.mkdtemp(prefix="db_restore_", dir=args.extract_dir) + os.chmod(tmpdir, 0o700) + + try: + # Step 1: Backup current database (unless skipped) + if not args.no_backup: + print("\n[1/4] Backing up current database...") + backup_dir = args.backup_dir or os.path.dirname(os.path.abspath(args.backup_file)) + os.makedirs(backup_dir, exist_ok=True) + safety_backup_path = os.path.join( + backup_dir, f"{config['database']}_pre_restore.sql" + ) + # Avoid overwriting an existing backup + if os.path.exists(safety_backup_path): + base, ext = os.path.splitext(safety_backup_path) + i = 1 + while os.path.exists(f"{base}_{i}{ext}"): + i += 1 + safety_backup_path = f"{base}_{i}{ext}" + + try: + backup_current_database(config, safety_backup_path) + print(" Safety backup created successfully.") + except Exception as e: + print(f" Could not backup current database: {sanitize_error(str(e), config)}") + if args.yes: + print(" Error: Cannot proceed without safety backup in non-interactive mode.") + print(" Use --no-backup to explicitly skip the safety backup.") + sys.exit(1) + response = input(" Continue without backup? 
[yes/no]: ").strip().lower() + if response not in ("yes", "y"): + print("Aborted.") + sys.exit(0) + else: + print("\n[1/4] Skipping backup (--no-backup)") + + # Step 2: Extract the backup archive + print("\n[2/4] Extracting backup archive...") + extract_dir = os.path.join(tmpdir, "extracted") + os.makedirs(extract_dir, exist_ok=True) + dump_dir = extract_backup(args.backup_file, extract_dir) + dump_contents = os.listdir(dump_dir) + dat_count = sum(1 for f in dump_contents if f.endswith(".dat")) + print(f" Dump directory: {dump_dir}") + print(f" Found: restore.sql + {dat_count} data files") + + # Step 3: Drop and recreate the database + print("\n[3/4] Dropping and recreating database...") + drop_and_recreate_database(config) + + # Step 4: Apply the restore + print("\n[4/4] Applying database restore...") + warning_count = apply_restore(config, dump_dir) + + if warning_count > 0: + print(f"\nRestore completed with {warning_count} warning(s).") + sys.exit(2) + else: + print("\nRestore complete!") + + finally: + # Clean up temp extraction directory + shutil.rmtree(tmpdir, ignore_errors=True) + + +if __name__ == "__main__": + main() From 375d2ec89e74b42d4ab59b8380c999541aa31da7 Mon Sep 17 00:00:00 2001 From: Dan Lemon Date: Fri, 22 May 2026 21:53:35 +0200 Subject: [PATCH 2/5] Update scripts/restore_database.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- scripts/restore_database.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/scripts/restore_database.py b/scripts/restore_database.py index 33f86e7a..631bb3e7 100755 --- a/scripts/restore_database.py +++ b/scripts/restore_database.py @@ -463,9 +463,12 @@ def apply_restore(config, dump_dir): short_msg = sanitize_error(msg.strip().split("\n")[0][:150], config) # Fail fast on critical DDL (CREATE/ALTER TABLE, CREATE TYPE, etc.) 
# Allow non-critical statements (COMMENT, GRANT, SET, ALTER OWNER) to warn-and-continue - non_critical_prefixes = ("COMMENT ", "GRANT ", "REVOKE ", "SET ", "ALTER OWNER", "SELECT ") + non_critical_prefixes = ("COMMENT ", "GRANT ", "REVOKE ", "SET ", "SELECT ") block_upper = ddl_block.lstrip().upper() - is_non_critical = any(block_upper.startswith(p) for p in non_critical_prefixes) + is_non_critical = ( + any(block_upper.startswith(p) for p in non_critical_prefixes) + or " OWNER TO " in block_upper + ) if is_non_critical: print(f" Warning: {short_msg}") warnings += 1 From 9eaf1b23e39d55075c0745f76aee2ac03650e1a4 Mon Sep 17 00:00:00 2001 From: Dan Lemon Date: Tue, 26 May 2026 15:59:36 +0200 Subject: [PATCH 3/5] chore: rewrite restore db script to use pg_* related tools --- scripts/restore_database.py | 635 +++++++++++------------------------- 1 file changed, 198 insertions(+), 437 deletions(-) diff --git a/scripts/restore_database.py b/scripts/restore_database.py index 631bb3e7..e5752637 100755 --- a/scripts/restore_database.py +++ b/scripts/restore_database.py @@ -3,11 +3,16 @@ Restore a PostgreSQL database from a Lagoon backup (.tar.gz). Lagoon backups are .tar.gz files containing a pg_dump directory-format archive: - .tar.gz -> .tar -> (*.dat files, restore.sql, toc.dat) + .tar.gz -> .tar -> directory with (*.dat files, toc.dat, restore.sql) -The restore.sql contains $$PATH$$ placeholders for .dat file paths. This script -extracts the archive, replaces $$PATH$$, and streams data via COPY FROM STDIN -(since server-side COPY FROM file requires superuser privileges). +This script: + 1. (Optionally) takes a pre-restore safety backup with `pg_dump -Fc`. + 2. Drops and recreates the target database via `dropdb` / `createdb`. + 3. Extracts the .tar.gz (and the nested .tar) to a temp directory. + 4. Applies the dump with `pg_restore -Fd --no-owner --no-privileges`. + +Requires the postgres client tools (pg_dump, pg_restore, dropdb, createdb, +psql) on PATH. 
In Lagoon these are provided by the `cli` container image. IMPORTANT: Only use this with backups from trusted sources (e.g. Lagoon). @@ -18,22 +23,24 @@ import argparse import os import pathlib -import re import shutil +import subprocess import sys import tarfile import tempfile -from urllib.parse import urlparse +from urllib.parse import unquote, urlparse -import psycopg2 -from psycopg2 import sql -from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT +REQUIRED_TOOLS = ("pg_dump", "pg_restore", "dropdb", "createdb", "psql") -def sanitize_error(msg, config): - """Remove sensitive data (password) from error messages.""" - if config.get("password"): - msg = msg.replace(config["password"], "****") + +def sanitize(msg, config): + """Remove sensitive data (password) from messages.""" + if not msg: + return msg + pw = config.get("password") + if pw: + msg = msg.replace(pw, "****") return msg @@ -43,262 +50,109 @@ def get_db_config(): "DATABASE_URL", "postgres://postgres:postgres@postgres:5432/postgres_service" ) parsed = urlparse(database_url) + db_name = unquote(parsed.path or "").removeprefix("/") return { "host": parsed.hostname, "port": parsed.port or 5432, - "user": parsed.username, - "password": parsed.password, - "database": parsed.path.lstrip("/"), + "user": unquote(parsed.username) if parsed.username is not None else None, + "password": unquote(parsed.password) if parsed.password is not None else None, + "database": db_name, } -def connect_to_maintenance_db(config): - """Connect to the 'postgres' maintenance database (not the target DB).""" - conn = psycopg2.connect( - host=config["host"], - port=config["port"], - user=config["user"], - password=config["password"], - dbname="postgres", - connect_timeout=10, - ) - conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) - return conn - - -def connect_to_target_db(config): - """Connect to the target database.""" - return psycopg2.connect( - host=config["host"], - port=config["port"], - user=config["user"], - 
password=config["password"], - dbname=config["database"], - connect_timeout=10, - ) +def pg_env(config): + """Build an environment dict with libpq connection variables set.""" + env = os.environ.copy() + if config.get("host"): + env["PGHOST"] = str(config["host"]) + if config.get("port"): + env["PGPORT"] = str(config["port"]) + if config.get("user"): + env["PGUSER"] = str(config["user"]) + if config.get("password"): + env["PGPASSWORD"] = str(config["password"]) + # libpq picks up these on its own; no need to pass them as CLI flags. + return env + + +def check_required_tools(): + """Verify all required postgres client tools are on PATH.""" + missing = [t for t in REQUIRED_TOOLS if shutil.which(t) is None] + if missing: + print(f"Error: Required postgres client tool(s) not found: {', '.join(missing)}") + print(" This script must run in an environment with postgres client tools") + print(" installed (e.g. the Lagoon `cli` container).") + sys.exit(1) -def backup_current_database(config, backup_path): - """Create a best-effort dump (schema + data) of the current database as a safety net. +def run_pg_tool(cmd, config, *, capture=True, check=True): + """Run a postgres client command, returning the CompletedProcess. - Dumps public-schema objects: tables (columns, defaults, NOT NULL), - constraints (PK, unique, FK, check), indexes, sequences, views, functions, - triggers, and data via COPY TO STDOUT. + Stderr is captured (and sanitized) so passwords don't leak into logs. + """ + try: + result = subprocess.run( + cmd, + env=pg_env(config), + stdout=subprocess.PIPE if capture else None, + stderr=subprocess.PIPE, + text=True, + check=False, + ) + except FileNotFoundError as e: + raise SystemExit(f"Error: {sanitize(str(e), config)}") - NOT included: extensions, custom types/enums/domains, roles, grants/ACLs, - RLS policies, non-public schemas, tablespaces, or publication/subscription - config. For a fully faithful backup, use pg_dump externally. 
+ if not capture and result.stderr: + sys.stderr.write(sanitize(result.stderr, config)) - The output is replayable with psql for the common case but may require - manual adjustment for databases using the above features. - """ - print(f" Backing up current database to file: {os.path.basename(backup_path)}") + if check and result.returncode != 0: + stderr = sanitize((result.stderr or "").strip(), config) + raise subprocess.CalledProcessError( + result.returncode, cmd, output=result.stdout, stderr=stderr + ) + return result - conn = connect_to_target_db(config) - try: - cur = conn.cursor() - - # Open with restrictive permissions (binary mode required for copy_expert) - fd = os.open(backup_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600) - with os.fdopen(fd, "wb") as f: - f.write("-- Pre-restore safety backup\n") - f.write("-- Generated by restore_database.py\n") - f.write("-- Replay with: psql $DATABASE_URL < this_file.sql\n\n") - - # --- Sequences (must come before tables for DEFAULT references) --- - cur.execute( - "SELECT sequence_name FROM information_schema.sequences " - "WHERE sequence_schema = 'public' ORDER BY sequence_name;" - ) - sequences = [row[0] for row in cur.fetchall()] - for seq in sequences: - cur.execute( - "SELECT last_value, is_called FROM pg_sequences " - "WHERE schemaname = 'public' AND sequencename = %s;", - (seq,), - ) - row = cur.fetchone() - f.write(f'CREATE SEQUENCE IF NOT EXISTS "{seq}";\n') - if row and row[0] is not None: - is_called = "true" if row[1] else "false" - f.write(f"SELECT setval('\"{seq}\"', {row[0]}, {is_called});\n") - f.write("\n") - - # --- Tables (full DDL from pg_catalog) --- - cur.execute( - "SELECT tablename FROM pg_tables " - "WHERE schemaname = 'public' ORDER BY tablename;" - ) - tables = [row[0] for row in cur.fetchall()] - - for table in tables: - # Get column definitions including proper types and defaults - cur.execute(""" - SELECT - a.attname, - pg_catalog.format_type(a.atttypid, a.atttypmod), - 
a.attnotnull, - pg_catalog.pg_get_expr(d.adbin, d.adrelid) - FROM pg_catalog.pg_attribute a - LEFT JOIN pg_catalog.pg_attrdef d - ON d.adrelid = a.attrelid AND d.adnum = a.attnum - WHERE a.attrelid = %s::regclass - AND a.attnum > 0 - AND NOT a.attisdropped - ORDER BY a.attnum; - """, (f'public."{table}"',)) - columns = cur.fetchall() - - col_defs = [] - for col_name, col_type, not_null, default in columns: - parts = [f' "{col_name}" {col_type}'] - if default: - parts.append(f"DEFAULT {default}") - if not_null: - parts.append("NOT NULL") - col_defs.append(" ".join(parts)) - - f.write(f'CREATE TABLE IF NOT EXISTS "{table}" (\n') - f.write(",\n".join(col_defs)) - f.write("\n);\n\n") - - # --- Primary keys and unique constraints --- - cur.execute(""" - SELECT - conname, - conrelid::regclass::text, - pg_catalog.pg_get_constraintdef(oid) - FROM pg_catalog.pg_constraint - WHERE connamespace = 'public'::regnamespace - AND contype IN ('p', 'u') - ORDER BY conrelid::regclass::text, conname; - """) - for con_name, table_name, con_def in cur.fetchall(): - f.write(f'ALTER TABLE {table_name} ADD CONSTRAINT "{con_name}" {con_def};\n') - f.write("\n") - - # --- Foreign keys --- - cur.execute(""" - SELECT - conname, - conrelid::regclass::text, - pg_catalog.pg_get_constraintdef(oid) - FROM pg_catalog.pg_constraint - WHERE connamespace = 'public'::regnamespace - AND contype = 'f' - ORDER BY conrelid::regclass::text, conname; - """) - for con_name, table_name, con_def in cur.fetchall(): - f.write(f'ALTER TABLE {table_name} ADD CONSTRAINT "{con_name}" {con_def};\n') - f.write("\n") - - # --- Check constraints (excluding system-generated NOT NULL checks) --- - cur.execute(""" - SELECT - conname, - conrelid::regclass::text, - pg_catalog.pg_get_constraintdef(oid) - FROM pg_catalog.pg_constraint - WHERE connamespace = 'public'::regnamespace - AND contype = 'c' - ORDER BY conrelid::regclass::text, conname; - """) - for con_name, table_name, con_def in cur.fetchall(): - f.write(f'ALTER TABLE 
{table_name} ADD CONSTRAINT "{con_name}" {con_def};\n') - f.write("\n") - - # --- Indexes (excluding those backing constraints) --- - cur.execute(""" - SELECT pg_catalog.pg_get_indexdef(i.indexrelid) - FROM pg_catalog.pg_index i - JOIN pg_catalog.pg_class c ON c.oid = i.indexrelid - JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE n.nspname = 'public' - AND NOT i.indisprimary - AND NOT EXISTS ( - SELECT 1 FROM pg_catalog.pg_constraint con - WHERE con.conindid = i.indexrelid - ) - ORDER BY c.relname; - """) - for (index_def,) in cur.fetchall(): - f.write(f"{index_def};\n") - f.write("\n") - - # --- Views --- - cur.execute( - "SELECT viewname, definition FROM pg_views " - "WHERE schemaname = 'public' ORDER BY viewname;" - ) - for view_name, view_def in cur.fetchall(): - f.write(f'CREATE OR REPLACE VIEW "{view_name}" AS\n{view_def}\n\n') - - # --- Functions --- - cur.execute(""" - SELECT pg_catalog.pg_get_functiondef(p.oid) - FROM pg_catalog.pg_proc p - JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace - WHERE n.nspname = 'public' - ORDER BY p.proname; - """) - for (func_def,) in cur.fetchall(): - f.write(f"{func_def};\n\n") - - # --- Triggers --- - cur.execute(""" - SELECT pg_catalog.pg_get_triggerdef(t.oid) - FROM pg_catalog.pg_trigger t - JOIN pg_catalog.pg_class c ON c.oid = t.tgrelid - JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE n.nspname = 'public' - AND NOT t.tgisinternal - ORDER BY c.relname, t.tgname; - """) - for (trig_def,) in cur.fetchall(): - f.write(f"{trig_def};\n") - f.write("\n") - - # --- Table data via COPY --- - for table in tables: - copy_out = sql.SQL("COPY {} TO STDOUT").format( - sql.Identifier(table) - ) - f.write(f'COPY "{table}" FROM stdin;\n') - cur.copy_expert(copy_out.as_string(conn), f) - f.write("\\.\n\n") - - # --- Sequence ownership (link sequences to columns) --- - cur.execute(""" - SELECT - s.relname, - t.relname, - a.attname - FROM pg_catalog.pg_class s - JOIN pg_catalog.pg_namespace n ON 
n.oid = s.relnamespace - JOIN pg_catalog.pg_depend d ON d.objid = s.oid - JOIN pg_catalog.pg_class t ON t.oid = d.refobjid - JOIN pg_catalog.pg_attribute a - ON a.attrelid = d.refobjid AND a.attnum = d.refobjsubid - WHERE s.relkind = 'S' - AND n.nspname = 'public' - AND d.deptype = 'a' - ORDER BY s.relname; - """) - for seq_name, table_name, col_name in cur.fetchall(): - f.write( - f'ALTER SEQUENCE "{seq_name}" OWNED BY "{table_name}"."{col_name}";\n' - ) - f.write("\n") - - print(f" Backup complete ({len(tables)} tables, {len(sequences)} sequences)") - finally: - conn.close() +def backup_current_database(config, backup_path): + """Run `pg_dump -Fc` against the current database as a safety net. + + Custom format (-Fc) is compact, supports selective restore via pg_restore, + and faithfully captures everything pg_dump can capture (schemas, tables, + indexes, views, functions, triggers, sequences, types, extensions, ACLs, + RLS policies, etc). + + Replay with: pg_restore -d "$DATABASE_URL" --clean .dump + """ + print(" Backing up current database to a local dump file") + + # Open with restrictive permissions before pg_dump writes to it. + fd = os.open(backup_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600) + os.close(fd) + + cmd = [ + "pg_dump", + "--format=custom", + "--file", backup_path, + "--dbname", config["database"], + ] + try: + run_pg_tool(cmd, config) + except subprocess.CalledProcessError as e: + # Clean up the empty/partial file so we don't leave noise behind. + try: + os.unlink(backup_path) + except OSError as cleanup_err: + # Best-effort cleanup only: preserve original pg_dump failure path. 
+ print(f" Warning: could not remove partial backup file: {cleanup_err}", file=sys.stderr) + raise SystemExit(f" pg_dump failed: {e.stderr or e}") + + size_mb = os.path.getsize(backup_path) / (1024 * 1024) + print(f" Backup complete ({size_mb:.1f} MB)") return backup_path def drop_and_recreate_database(config): - """Drop the target database and recreate it empty.""" + """Drop the target database (forcing connections closed) and recreate it.""" db_name = config["database"] if db_name == "postgres": @@ -307,24 +161,22 @@ def drop_and_recreate_database(config): "Set DATABASE_URL to point to a different target database." ) - conn = connect_to_maintenance_db(config) - cur = conn.cursor() - - print(" Terminating existing connections to target database...") - cur.execute( - "SELECT pg_terminate_backend(pid) FROM pg_stat_activity " - "WHERE datname = %s AND pid <> pg_backend_pid();", - (db_name,), - ) - - print(" Dropping target database if it exists...") - cur.execute(sql.SQL("DROP DATABASE IF EXISTS {}").format(sql.Identifier(db_name))) + print(" Dropping target database (forcing existing connections closed)...") + # --force terminates active connections (PostgreSQL 13+). 
+ try: + run_pg_tool( + ["dropdb", "--if-exists", "--force", db_name], + config, + ) + except subprocess.CalledProcessError as e: + raise SystemExit(f" dropdb failed: {e.stderr or e}") print(" Creating target database...") - cur.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(db_name))) + try: + run_pg_tool(["createdb", db_name], config) + except subprocess.CalledProcessError as e: + raise SystemExit(f" createdb failed: {e.stderr or e}") - cur.close() - conn.close() print(" Database recreated.") @@ -333,8 +185,10 @@ def safe_extract_tar(tar, extract_dir): extract_path = pathlib.Path(extract_dir).resolve() safe_members = [] for member in tar.getmembers(): - if member.issym() or member.islnk(): - raise ValueError(f"Link entries are not allowed in archive: {member.name}") + if not (member.isreg() or member.isdir()): + raise ValueError( + f"Only regular files/directories are allowed in archive: {member.name}" + ) member_path = (extract_path / member.name).resolve() if not member_path.is_relative_to(extract_path): raise ValueError(f"Path traversal detected in archive: {member.name}") @@ -342,181 +196,85 @@ def safe_extract_tar(tar, extract_dir): tar.extractall(path=extract_dir, members=safe_members) +def _find_dump_dir(root): + """Locate a pg_dump directory-format dump under `root`. + + A directory-format dump is a directory containing a `toc.dat` file (and + typically a number of `*.dat` files). We accept either `root` itself or + a single-level subdirectory. + """ + if os.path.isfile(os.path.join(root, "toc.dat")): + return root + for entry in os.listdir(root): + sub = os.path.join(root, entry) + if os.path.isdir(sub) and os.path.isfile(os.path.join(sub, "toc.dat")): + return sub + return None + + def extract_backup(backup_path, extract_dir): - """Extract .tar.gz backup to a directory. + """Extract .tar.gz (and any nested .tar) to `extract_dir`. - Returns the path to the directory containing restore.sql and .dat files. 
- Raises SystemExit if restore.sql cannot be found. + Returns the path to the directory containing `toc.dat` (the pg_dump + directory-format archive root). Exits if it cannot be located. """ print(f" Extracting {backup_path}...") with tarfile.open(backup_path, "r:gz") as tar: safe_extract_tar(tar, extract_dir) - # The .tar.gz may contain a nested .tar, or directly the dump files - contents = os.listdir(extract_dir) - for item in contents: - item_path = os.path.join(extract_dir, item) - if item.endswith(".tar") and tarfile.is_tarfile(item_path): + # The .tar.gz typically contains a nested .tar; unwrap it if present. + for entry in os.listdir(extract_dir): + entry_path = os.path.join(extract_dir, entry) + if entry.endswith(".tar") and tarfile.is_tarfile(entry_path): nested_dir = os.path.join(extract_dir, "dump") os.makedirs(nested_dir, exist_ok=True) - with tarfile.open(item_path, "r:") as tar: + with tarfile.open(entry_path, "r:") as tar: safe_extract_tar(tar, nested_dir) - contents = os.listdir(nested_dir) - # Check nested dir and its subdirectories - if "restore.sql" in contents: - return nested_dir - for sub in contents: - sub_path = os.path.join(nested_dir, sub) - if os.path.isdir(sub_path) and "restore.sql" in os.listdir(sub_path): - return sub_path - print(" Error: Could not locate restore.sql in nested archive.") - print(f" Nested contents: {contents}") + dump_dir = _find_dump_dir(nested_dir) + if dump_dir: + return dump_dir + print(" Error: Could not locate toc.dat in nested archive.") + print(f" Nested contents: {os.listdir(nested_dir)}") sys.exit(1) - # Check if restore.sql is directly in the extract dir - if "restore.sql" in contents: - return extract_dir - - # Look one level deeper - for item in contents: - item_path = os.path.join(extract_dir, item) - if os.path.isdir(item_path) and "restore.sql" in os.listdir(item_path): - return item_path + dump_dir = _find_dump_dir(extract_dir) + if dump_dir: + return dump_dir - print(" Error: Could not locate 
restore.sql in archive.") - print(f" Top-level contents: {contents}") + print(" Error: Could not locate toc.dat in archive.") + print(f" Top-level contents: {os.listdir(extract_dir)}") sys.exit(1) def apply_restore(config, dump_dir): - """Apply the database restore from the extracted dump directory. - - Strategy: - 1. Parse restore.sql to separate DDL/schema statements from COPY commands. - 2. Execute DDL statements directly. - 3. For COPY ... FROM '$$PATH$$/xxx.dat' statements, convert to - COPY ... FROM STDIN and stream the .dat file contents. This avoids - requiring superuser privileges on the database. - """ - restore_sql_path = os.path.join(dump_dir, "restore.sql") - - if not os.path.exists(restore_sql_path): - print(f" Error: restore.sql not found in {dump_dir}") - sys.exit(1) - - with open(restore_sql_path, "r") as f: - content = f.read() - - # Sanity check - if "PostgreSQL database dump" not in content[:500]: - print(" Warning: restore.sql does not look like a pg_dump file.") - print(f" Header: {content[:100]}") - - abs_dump_dir = os.path.abspath(dump_dir) - - # Remove empty COPY ... FROM stdin blocks (tables with no data): - # COPY public.table (...) FROM stdin; - # \. - content = re.sub( - r"^COPY\s+.+?\s+FROM\s+stdin;\s*\n\\\.\s*\n", - "", - content, - flags=re.MULTILINE, - ) - - # Pattern matching COPY ... FROM '$$PATH$$/filename.dat'; - copy_file_pattern = re.compile( - r"^(COPY\s+\S+\s*\([^)]*\))\s+FROM\s+'\$\$PATH\$\$/([^']+)';?\s*$", - re.MULTILINE, - ) - - # Split content into segments: (text_before, copy_stmt, dat_filename, text_after, ...) - # We process sequentially: run DDL, then when we hit a COPY-from-file, stream it. 
- segments = copy_file_pattern.split(content) - # split gives: [before, group1, group2, between, group1, group2, ..., after] - # groups come in pairs: (copy_prefix, dat_filename) - - conn = connect_to_target_db(config) - conn.autocommit = True - cur = conn.cursor() - - applied = 0 - data_loaded = 0 - warnings = 0 + """Apply the database restore from a pg_dump directory-format archive. + Uses `pg_restore -Fd --no-owner --no-privileges` so the restore runs as + the current connection user without requiring superuser privileges or + the original owner roles to exist. `--exit-on-error` makes pg_restore + fail fast on the first error rather than logging warnings and pressing on. + """ + cmd = [ + "pg_restore", + "--format=directory", + "--no-owner", + "--no-privileges", + "--exit-on-error", + "--dbname", config["database"], + dump_dir, + ] try: - i = 0 - while i < len(segments): - if i % 3 == 0: - # This is a DDL/non-COPY text segment - ddl_block = segments[i].strip() - if ddl_block: - # Execute line by line for statements, but we need to handle - # multi-line statements. Use psycopg2's execute which handles - # multiple statements in one call when separated by semicolons. - try: - cur.execute(ddl_block) - applied += 1 - except psycopg2.Error as e: - msg = e.pgerror or (e.diag.message_primary if e.diag else None) or str(e) - short_msg = sanitize_error(msg.strip().split("\n")[0][:150], config) - # Fail fast on critical DDL (CREATE/ALTER TABLE, CREATE TYPE, etc.) - # Allow non-critical statements (COMMENT, GRANT, SET, ALTER OWNER) to warn-and-continue - non_critical_prefixes = ("COMMENT ", "GRANT ", "REVOKE ", "SET ", "SELECT ") - block_upper = ddl_block.lstrip().upper() - is_non_critical = ( - any(block_upper.startswith(p) for p in non_critical_prefixes) - or " OWNER TO " in block_upper - ) - if is_non_critical: - print(f" Warning: {short_msg}") - warnings += 1 - else: - print(f" Error (fatal): {short_msg}") - raise SystemExit( - "Critical DDL statement failed. 
Restore aborted. " - "The database has been dropped and is now empty. " - "Re-run the restore or recover from the safety backup." - ) - i += 1 - else: - # This is a COPY group: segments[i] = copy_prefix, segments[i+1] = dat_filename - copy_prefix = segments[i] - dat_filename = segments[i + 1] - dat_path = os.path.join(abs_dump_dir, dat_filename) - - # Validate dat file path stays within dump directory - resolved_dat = pathlib.Path(dat_path).resolve() - if not resolved_dat.is_relative_to(pathlib.Path(abs_dump_dir).resolve()): - print(f" Error: Path traversal in restore.sql: {dat_filename}") - warnings += 1 - i += 2 - continue - - if os.path.exists(dat_path): - # Convert to COPY ... FROM STDIN and stream the file - copy_stdin_sql = f"{copy_prefix} FROM STDIN" - try: - with open(dat_path, "rb") as dat_f: - cur.copy_expert(copy_stdin_sql, dat_f) - data_loaded += 1 - except psycopg2.Error as e: - msg = sanitize_error(str(e).strip().split("\n")[0][:150], config) - print(f" Warning ({dat_filename}): {msg}") - warnings += 1 - else: - print(f" Warning: Data file not found: {dat_filename}") - warnings += 1 - - i += 2 - - print(f" Restore complete: {applied} DDL sections, {data_loaded} tables loaded, {warnings} warnings.") - finally: - cur.close() - conn.close() + run_pg_tool(cmd, config, capture=False) + except subprocess.CalledProcessError as e: + raise SystemExit( + f" pg_restore failed: {e.stderr or e}\n" + " The database has been dropped and is now empty. 
" + "Re-run the restore or recover from the safety backup with:\n" + " pg_restore -d \"$DATABASE_URL\" --clean .dump" + ) - return warnings + print(" Restore complete.") def check_disk_space(path, required_mb=500): @@ -529,6 +287,11 @@ def check_disk_space(path, required_mb=500): return True +def _presence(value): + """Display whether a connection field is set without logging its value.""" + return "(set)" if value else "(not set)" + + def main(): parser = argparse.ArgumentParser( description="Restore a PostgreSQL database from a Lagoon .tar.gz backup." @@ -559,6 +322,8 @@ def main(): ) args = parser.parse_args() + check_required_tools() + # Validate backup file exists and is .tar.gz if not os.path.exists(args.backup_file): print(f"Error: Backup file not found: {args.backup_file}") @@ -598,9 +363,9 @@ def main(): print("\nDatabase restore configuration:") print(f" Backup file : {args.backup_file}") - print(" Target host : [redacted]") - print(" Target DB : [redacted]") - print(" User : [redacted]") + print(f" Target host : {_presence(config.get('host'))}") + print(f" Target DB : {_presence(config.get('database'))}") + print(f" User : {_presence(config.get('user'))}") print() if not args.yes: @@ -624,7 +389,7 @@ def main(): backup_dir = args.backup_dir or os.path.dirname(os.path.abspath(args.backup_file)) os.makedirs(backup_dir, exist_ok=True) safety_backup_path = os.path.join( - backup_dir, f"{config['database']}_pre_restore.sql" + backup_dir, f"{config['database']}_pre_restore.dump" ) # Avoid overwriting an existing backup if os.path.exists(safety_backup_path): @@ -637,8 +402,8 @@ def main(): try: backup_current_database(config, safety_backup_path) print(" Safety backup created successfully.") - except Exception as e: - print(f" Could not backup current database: {sanitize_error(str(e), config)}") + except SystemExit as exc: + print(f" {exc}", file=sys.stderr) if args.yes: print(" Error: Cannot proceed without safety backup in non-interactive mode.") print(" Use 
--no-backup to explicitly skip the safety backup.") @@ -656,9 +421,9 @@ def main(): os.makedirs(extract_dir, exist_ok=True) dump_dir = extract_backup(args.backup_file, extract_dir) dump_contents = os.listdir(dump_dir) - dat_count = sum(1 for f in dump_contents if f.endswith(".dat")) + dat_count = sum(1 for f in dump_contents if f.endswith(".dat") and f != "toc.dat") print(f" Dump directory: {dump_dir}") - print(f" Found: restore.sql + {dat_count} data files") + print(f" Found: toc.dat + {dat_count} data files") # Step 3: Drop and recreate the database print("\n[3/4] Dropping and recreating database...") @@ -666,13 +431,9 @@ def main(): # Step 4: Apply the restore print("\n[4/4] Applying database restore...") - warning_count = apply_restore(config, dump_dir) + apply_restore(config, dump_dir) - if warning_count > 0: - print(f"\nRestore completed with {warning_count} warning(s).") - sys.exit(2) - else: - print("\nRestore complete!") + print("\nRestore complete!") finally: # Clean up temp extraction directory From c37d167de33c1d8d5d64965022e8eba260fcceb5 Mon Sep 17 00:00:00 2001 From: Dan Lemon Date: Tue, 2 Jun 2026 16:10:36 +0200 Subject: [PATCH 4/5] chore: address restore script PR feedback (libpq params + createdb attrs) - Pass full DATABASE_URL to libpq via --dbname / --maintenance-db so query params (sslmode, connect_timeout, socket host=..., etc.) flow through to every pg_* call instead of being silently dropped by get_db_config(). Trim pg_env() to only set PGPASSWORD; harden sanitize() to also strip the URL-encoded password form. - Replace the manual dropdb/createdb step with pg_restore --create --clean --if-exists --no-owner --no-privileges --no-tablespaces --exit-on-error --dbname=. This replays the source database's CREATE DATABASE from toc.dat, so the restored DB inherits the source's encoding and collation instead of cluster defaults. 
--no-owner / --no-tablespaces strip OWNER and TABLESPACE clauses so the restore does not depend on the source's roles or tablespaces existing on the target cluster. --- scripts/restore_database.py | 184 +++++++++++++++++++++++------------- 1 file changed, 117 insertions(+), 67 deletions(-) diff --git a/scripts/restore_database.py b/scripts/restore_database.py index e5752637..b809ceeb 100755 --- a/scripts/restore_database.py +++ b/scripts/restore_database.py @@ -7,12 +7,22 @@ This script: 1. (Optionally) takes a pre-restore safety backup with `pg_dump -Fc`. - 2. Drops and recreates the target database via `dropdb` / `createdb`. - 3. Extracts the .tar.gz (and the nested .tar) to a temp directory. - 4. Applies the dump with `pg_restore -Fd --no-owner --no-privileges`. - -Requires the postgres client tools (pg_dump, pg_restore, dropdb, createdb, -psql) on PATH. In Lagoon these are provided by the `cli` container image. + 2. Extracts the .tar.gz (and the nested .tar) to a temp directory. + 3. Applies the dump with + `pg_restore -Fd --create --clean --if-exists --no-owner --no-privileges + --no-tablespaces --exit-on-error` + connected to the `postgres` maintenance database. `--create --clean` + replays the source DB's own `CREATE DATABASE` from `toc.dat`, so the + target is recreated with the source's encoding and collation rather + than the cluster's defaults. `--no-owner` and `--no-tablespaces` strip + out OWNER/TABLESPACE clauses so the restore does not depend on the + source's roles or tablespaces existing on the target cluster. + +Requires the postgres client tools (pg_dump, pg_restore, psql) on PATH. In +Lagoon these are provided by the `cli` container image. + +The connecting role needs CREATEDB (to recreate the target database) and +must be able to connect to the `postgres` maintenance database. IMPORTANT: Only use this with backups from trusted sources (e.g. Lagoon). 
@@ -28,30 +38,59 @@ import sys import tarfile import tempfile -from urllib.parse import unquote, urlparse +from urllib.parse import quote, unquote, urlparse, urlunparse -REQUIRED_TOOLS = ("pg_dump", "pg_restore", "dropdb", "createdb", "psql") +REQUIRED_TOOLS = ("pg_dump", "pg_restore", "psql") def sanitize(msg, config): - """Remove sensitive data (password) from messages.""" + """Remove sensitive data (password) from messages. + + The password may appear verbatim (from PGPASSWORD-style logs) or + URL-encoded inside the DATABASE_URL that we now pass on argv via + `--dbname` / `--maintenance-db`. Strip both forms. + """ if not msg: return msg + msg = str(msg) pw = config.get("password") if pw: msg = msg.replace(pw, "****") + encoded_pw = quote(pw, safe="") + if encoded_pw and encoded_pw != pw: + msg = msg.replace(encoded_pw, "****") return msg def get_db_config(): - """Parse DATABASE_URL into connection components.""" + """Parse DATABASE_URL into connection components. + + The full URL is preserved (`url`) and handed to libpq via `--dbname` / + `--maintenance-db` so any libpq-supported parameters (sslmode, + connect_timeout, host=/var/run/..., etc.) flow through unchanged. Parsed + fields are kept only for local needs: + - `database` for the safety guard, the dropdb/createdb target, and the + backup filename; + - `password` for the log-sanitizer and PGPASSWORD (so it never lands in + argv); + - `host` / `user` for the diagnostic presence printout. + + Also derives `maintenance_url`, a URL with the path swapped to + `/postgres`, so `dropdb`/`createdb` can connect to a different database + than the one being dropped/created while inheriting all other params. + """ database_url = os.getenv( "DATABASE_URL", "postgres://postgres:postgres@postgres:5432/postgres_service" ) parsed = urlparse(database_url) db_name = unquote(parsed.path or "").removeprefix("/") + # Replace the path with `/postgres` while preserving netloc and query + # params (so sslmode, connect_timeout, etc. 
carry over). + maintenance_url = urlunparse(parsed._replace(path="/postgres")) return { + "url": database_url, + "maintenance_url": maintenance_url, "host": parsed.hostname, "port": parsed.port or 5432, "user": unquote(parsed.username) if parsed.username is not None else None, @@ -61,17 +100,17 @@ def get_db_config(): def pg_env(config): - """Build an environment dict with libpq connection variables set.""" + """Build an environment dict with PGPASSWORD set. + + All other connection parameters are passed via the libpq URL on the + command line (`--dbname` / `--maintenance-db`), so libpq parses any + extra params (sslmode, connect_timeout, socket host=..., etc.) verbatim. + PGPASSWORD is kept in the environment rather than the URL to avoid + leaking the password into argv / process listings. + """ env = os.environ.copy() - if config.get("host"): - env["PGHOST"] = str(config["host"]) - if config.get("port"): - env["PGPORT"] = str(config["port"]) - if config.get("user"): - env["PGUSER"] = str(config["user"]) if config.get("password"): env["PGPASSWORD"] = str(config["password"]) - # libpq picks up these on its own; no need to pass them as CLI flags. return env @@ -129,11 +168,14 @@ def backup_current_database(config, backup_path): fd = os.open(backup_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600) os.close(fd) + # Pass the full DATABASE_URL via --dbname so libpq parses every param + # (sslmode, connect_timeout, socket host=..., etc.) instead of relying on + # the four fields we extracted in get_db_config(). cmd = [ "pg_dump", "--format=custom", "--file", backup_path, - "--dbname", config["database"], + "--dbname", config["url"], ] try: run_pg_tool(cmd, config) @@ -144,42 +186,13 @@ def backup_current_database(config, backup_path): except OSError as cleanup_err: # Best-effort cleanup only: preserve original pg_dump failure path. 
print(f" Warning: could not remove partial backup file: {cleanup_err}", file=sys.stderr) - raise SystemExit(f" pg_dump failed: {e.stderr or e}") + raise SystemExit(f" pg_dump failed: {sanitize(e.stderr or e, config)}") size_mb = os.path.getsize(backup_path) / (1024 * 1024) print(f" Backup complete ({size_mb:.1f} MB)") return backup_path -def drop_and_recreate_database(config): - """Drop the target database (forcing connections closed) and recreate it.""" - db_name = config["database"] - - if db_name == "postgres": - raise SystemExit( - "Error: Cannot drop the 'postgres' maintenance database. " - "Set DATABASE_URL to point to a different target database." - ) - - print(" Dropping target database (forcing existing connections closed)...") - # --force terminates active connections (PostgreSQL 13+). - try: - run_pg_tool( - ["dropdb", "--if-exists", "--force", db_name], - config, - ) - except subprocess.CalledProcessError as e: - raise SystemExit(f" dropdb failed: {e.stderr or e}") - - print(" Creating target database...") - try: - run_pg_tool(["createdb", db_name], config) - except subprocess.CalledProcessError as e: - raise SystemExit(f" createdb failed: {e.stderr or e}") - - print(" Database recreated.") - - def safe_extract_tar(tar, extract_dir): """Extract tar members, validating paths stay within extract_dir.""" extract_path = pathlib.Path(extract_dir).resolve() @@ -250,28 +263,67 @@ def extract_backup(backup_path, extract_dir): def apply_restore(config, dump_dir): """Apply the database restore from a pg_dump directory-format archive. - Uses `pg_restore -Fd --no-owner --no-privileges` so the restore runs as - the current connection user without requiring superuser privileges or - the original owner roles to exist. `--exit-on-error` makes pg_restore - fail fast on the first error rather than logging warnings and pressing on. 
+ Connects to the `postgres` maintenance DB and lets pg_restore itself + drop and recreate the target with the source's recorded metadata + (encoding, collation, locale provider) replayed from `toc.dat`. + + Flag rationale: + --create / --clean / --if-exists + Replay the source's `CREATE DATABASE` (preserving encoding and + collation, fixing the cluster-defaults correctness gap that a + plain `createdb` would introduce). `--clean --if-exists` makes + the preceding `DROP DATABASE` idempotent so re-running the + script is safe even when the target does not exist. + --no-owner / --no-privileges + Skip OWNER and GRANT clauses (both for the database itself and + for objects inside it), so the restore does not require the + source's roles to exist on the target cluster. The restoring + role takes ownership. + --no-tablespaces + Strip TABLESPACE clauses, so the restore does not require the + source's tablespaces to exist on the target cluster (the DB + and its objects land in the default tablespace). + --exit-on-error + Fail fast on the first error rather than logging warnings and + continuing with a half-restored database. """ + db_name = config["database"] + + # Refuse to wipe the maintenance DB. (`pg_restore --clean --create` + # would happily DROP DATABASE postgres if asked.) + if db_name == "postgres": + raise SystemExit( + "Error: Cannot restore over the 'postgres' maintenance database. " + "Set DATABASE_URL to point to a different target database." + ) + cmd = [ "pg_restore", "--format=directory", + "--create", + "--clean", + "--if-exists", "--no-owner", "--no-privileges", + "--no-tablespaces", "--exit-on-error", - "--dbname", config["database"], + # Connect to the maintenance DB so pg_restore can drop/create the + # target. `maintenance_url` carries every libpq param from the + # configured DATABASE_URL (sslmode, connect_timeout, etc.). 
+ "--dbname", config["maintenance_url"], dump_dir, ] try: run_pg_tool(cmd, config, capture=False) except subprocess.CalledProcessError as e: raise SystemExit( - f" pg_restore failed: {e.stderr or e}\n" - " The database has been dropped and is now empty. " - "Re-run the restore or recover from the safety backup with:\n" - " pg_restore -d \"$DATABASE_URL\" --clean .dump" + f" pg_restore failed: {sanitize(e.stderr or e, config)}\n" + " The target database may have been dropped or partially " + "restored. Recover from the safety backup with:\n" + " pg_restore --clean --create --if-exists --no-owner " + "--no-privileges \\\n" + " -d \"\" " + ".dump" ) print(" Restore complete.") @@ -385,7 +437,7 @@ def main(): try: # Step 1: Backup current database (unless skipped) if not args.no_backup: - print("\n[1/4] Backing up current database...") + print("\n[1/3] Backing up current database...") backup_dir = args.backup_dir or os.path.dirname(os.path.abspath(args.backup_file)) os.makedirs(backup_dir, exist_ok=True) safety_backup_path = os.path.join( @@ -413,10 +465,10 @@ def main(): print("Aborted.") sys.exit(0) else: - print("\n[1/4] Skipping backup (--no-backup)") + print("\n[1/3] Skipping backup (--no-backup)") # Step 2: Extract the backup archive - print("\n[2/4] Extracting backup archive...") + print("\n[2/3] Extracting backup archive...") extract_dir = os.path.join(tmpdir, "extracted") os.makedirs(extract_dir, exist_ok=True) dump_dir = extract_backup(args.backup_file, extract_dir) @@ -425,12 +477,10 @@ def main(): print(f" Dump directory: {dump_dir}") print(f" Found: toc.dat + {dat_count} data files") - # Step 3: Drop and recreate the database - print("\n[3/4] Dropping and recreating database...") - drop_and_recreate_database(config) - - # Step 4: Apply the restore - print("\n[4/4] Applying database restore...") + # Step 3: Apply the restore. 
pg_restore --create --clean handles + # the drop/recreate so the target DB is rebuilt with the source's + # encoding/collation rather than cluster defaults. + print("\n[3/3] Applying database restore...") apply_restore(config, dump_dir) print("\nRestore complete!") From 3822adec45feb7e48431d269675364afeeca2780 Mon Sep 17 00:00:00 2001 From: Dan Lemon Date: Tue, 2 Jun 2026 16:44:06 +0200 Subject: [PATCH 5/5] chore: safe url --- scripts/restore_database.py | 57 ++++++++++++++++++++++++++++++------- 1 file changed, 47 insertions(+), 10 deletions(-) diff --git a/scripts/restore_database.py b/scripts/restore_database.py index b809ceeb..84d78b57 100755 --- a/scripts/restore_database.py +++ b/scripts/restore_database.py @@ -63,13 +63,43 @@ def sanitize(msg, config): return msg +def _strip_password_from_netloc(parsed): + """Return a netloc string with the password component removed. + + Keeps the username (if any) and host[:port] intact, preserving the + original percent-encoding of the username so we don't double-encode + (`urlparse` returns raw, percent-encoded `username` / `password`). + + The password is stripped because `url` / `maintenance_url` are passed + on argv via `--dbname`, where they would otherwise be visible in `ps` + / `/proc//cmdline`. The password is supplied to libpq through + PGPASSWORD instead (see `pg_env`). + """ + netloc = parsed.netloc + # Split off any userinfo ("user[:password]@host[:port]"). Use rsplit so + # an `@` in the password (which is illegal unencoded but defensive) is + # handled by taking the *last* `@` as the userinfo/host boundary. + if "@" not in netloc: + return netloc + userinfo, _, hostport = netloc.rpartition("@") + # Drop the password portion of "user:password" (if present); keep the + # username's original percent-encoding verbatim. + raw_user = userinfo.split(":", 1)[0] + if not raw_user: + return hostport + return f"{raw_user}@{hostport}" + + def get_db_config(): """Parse DATABASE_URL into connection components. 
- The full URL is preserved (`url`) and handed to libpq via `--dbname` / + `url` and `maintenance_url` are handed to libpq via `--dbname` / `--maintenance-db` so any libpq-supported parameters (sslmode, - connect_timeout, host=/var/run/..., etc.) flow through unchanged. Parsed - fields are kept only for local needs: + connect_timeout, host=/var/run/..., etc.) flow through unchanged. The + password is stripped from these URLs before they hit argv to avoid + leaking the credential into `ps` / `/proc/<pid>/cmdline`; libpq picks + it up from PGPASSWORD via `pg_env`. Parsed fields are kept only for + local needs: - `database` for the safety guard, the dropdb/createdb target, and the backup filename; - `password` for the log-sanitizer and PGPASSWORD (so it never lands in @@ -85,11 +115,16 @@ def get_db_config(): ) parsed = urlparse(database_url) db_name = unquote(parsed.path or "").removeprefix("/") - # Replace the path with `/postgres` while preserving netloc and query - # params (so sslmode, connect_timeout, etc. carry over). - maintenance_url = urlunparse(parsed._replace(path="/postgres")) + # Strip the password from the netloc so neither `url` nor + # `maintenance_url` carry the credential when passed on argv via + # `--dbname`. The password flows through PGPASSWORD instead. + safe_netloc = _strip_password_from_netloc(parsed) + safe_url = urlunparse(parsed._replace(netloc=safe_netloc)) + maintenance_url = urlunparse( + parsed._replace(netloc=safe_netloc, path="/postgres") + ) return { - "url": database_url, + "url": safe_url, "maintenance_url": maintenance_url, "host": parsed.hostname, "port": parsed.port or 5432, @@ -168,9 +203,11 @@ def backup_current_database(config, backup_path): fd = os.open(backup_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600) os.close(fd) - # Pass the full DATABASE_URL via --dbname so libpq parses every param - # (sslmode, connect_timeout, socket host=..., etc.) instead of relying on - # the four fields we extracted in get_db_config(). 
+ # Pass the password-stripped DATABASE_URL via --dbname so libpq parses + # every param (sslmode, connect_timeout, socket host=..., etc.) instead + # of relying on the four fields we extracted in get_db_config(). The + # password is supplied via PGPASSWORD in pg_env() so it never appears on + # argv / in `ps` output. cmd = [ "pg_dump", "--format=custom",