diff --git a/node/rustchain_sync.py b/node/rustchain_sync.py index 265f8cfd0..6df0518ef 100644 --- a/node/rustchain_sync.py +++ b/node/rustchain_sync.py @@ -2,17 +2,45 @@ # SPDX-License-Identifier: MIT # Author: @createkr (RayBot AI) # BCOS-Tier: L1 -import sqlite3 import hashlib import json -import time import logging -from typing import List, Dict, Any, Optional +import sqlite3 +import time +from copy import deepcopy +from typing import Any, Dict, List, Optional, Protocol -class RustChainSyncManager: +class StateProvider(Protocol): + """Swappable source of syncable RustChain state.""" + + def get_available_sync_tables(self) -> List[str]: + ... + + def calculate_table_hash(self, table_name: str) -> str: + ... + + def get_merkle_root(self) -> str: + ... + + def get_primary_key(self, table_name: str) -> Optional[str]: + ... + + def get_table_data( + self, table_name: str, limit: int = 200, offset: int = 0 + ) -> List[Dict[str, Any]]: + ... + + def apply_sync_payload(self, table_name: str, remote_data: List[Dict[str, Any]]): + ... + + def get_count(self, table_name: str) -> int: + ... + + +class SQLiteStateProvider: """ - Handles bidirectional SQLite synchronization between RustChain nodes. + SQLite-backed sync state provider. Security model: - Table names are allowlisted @@ -30,11 +58,9 @@ class RustChainSyncManager: "transaction_history", ] - def __init__(self, db_path: str, admin_key: str): + def __init__(self, db_path: str, logger: Optional[logging.Logger] = None): self.db_path = db_path - self.admin_key = admin_key - logging.basicConfig(level=logging.INFO) - self.logger = logging.getLogger("RustChainSync") + self.logger = logger or logging.getLogger("RustChainSync") self._schema_cache: Dict[str, Dict[str, Any]] = {} def _get_connection(self): @@ -128,7 +154,7 @@ def get_merkle_root(self) -> str: combined = "".join(table_hashes) return hashlib.sha256(combined.encode()).hexdigest() - def _get_primary_key(self, table_name: str) -> Optional[str]: + def get_primary_key(self, table_name: str) -> Optional[str]: schema = self._load_table_schema(table_name) if not schema: return None @@ -270,6 +296,222 @@ def apply_sync_payload(self, table_name: str, remote_data: List[Dict[str, Any]]) finally: conn.close() + def get_count(self, table_name: str) -> int: + if table_name not in self.SYNC_TABLES: + return 0 + conn = self._get_connection() + try: + cursor = conn.cursor() + cursor.execute(f"SELECT COUNT(*) FROM {table_name}") + count = cursor.fetchone()[0] + return int(count) + finally: + conn.close() + + +class InMemoryStateProvider: + """Small in-memory provider for tests and embedded callers.""" + + def __init__( + self, + tables: Optional[Dict[str, List[Dict[str, Any]]]] = None, + primary_keys: Optional[Dict[str, str]] = None, + ): + self.tables = deepcopy(tables or {}) + self.primary_keys = dict(primary_keys or {}) + + def get_available_sync_tables(self) -> List[str]: + return [name for name in self.tables if self.primary_keys.get(name)] + + def calculate_table_hash(self, table_name: str) -> str: + if table_name not in self.get_available_sync_tables(): + return "" + + pk = self.primary_keys[table_name] + rows = sorted(self.tables.get(table_name, []), key=lambda row: row.get(pk)) + hasher = hashlib.sha256() + for row in rows: + row_str = json.dumps(row, sort_keys=True, separators=(",", ":")) + hasher.update(row_str.encode()) + return hasher.hexdigest() + + def get_merkle_root(self) -> str: + combined = "".join( + self.calculate_table_hash(table) + for table in self.get_available_sync_tables() + ) + return hashlib.sha256(combined.encode()).hexdigest() + + def get_primary_key(self, table_name: str) -> Optional[str]: + return self.primary_keys.get(table_name) + + def get_table_data( + self, table_name: str, limit: int = 200, offset: int = 0 + ) -> List[Dict[str, Any]]: + if table_name not in self.get_available_sync_tables(): + return [] + rows = self.tables.get(table_name, []) + return deepcopy(rows[int(offset): int(offset) + int(limit)]) + + def apply_sync_payload(self, table_name: str, remote_data: List[Dict[str, Any]]): + pk = self.primary_keys.get(table_name) + if not pk or table_name not in self.tables: + return False + + rows_by_pk = {row.get(pk): dict(row) for row in self.tables[table_name]} + for row in remote_data: + if not isinstance(row, dict) or pk not in row: + continue + existing = rows_by_pk.get(row[pk], {}) + existing.update(row) + rows_by_pk[row[pk]] = existing + self.tables[table_name] = list(rows_by_pk.values()) + return True + + def get_count(self, table_name: str) -> int: + if table_name not in self.get_available_sync_tables(): + return 0 + return len(self.tables.get(table_name, [])) + + +class FallbackStateProvider: + """Try multiple providers in order so callers can swap state sources safely.""" + + def __init__(self, providers: List[StateProvider]): + if not providers: + raise ValueError("at least one state provider is required") + self.providers = providers + + def _first_table_provider(self, table_name: str) -> Optional[StateProvider]: + for provider in self.providers: + try: + if table_name in provider.get_available_sync_tables(): + return provider + except Exception: + continue + return None + + def _table_providers(self, table_name: str) -> List[StateProvider]: + providers: List[StateProvider] = [] + for provider in self.providers: + try: + if table_name in provider.get_available_sync_tables(): + providers.append(provider) + except Exception: + continue + return providers + + def get_available_sync_tables(self) -> List[str]: + tables: List[str] = [] + for provider in self.providers: + try: + for table in provider.get_available_sync_tables(): + if table not in tables: + tables.append(table) + except Exception: + continue + return tables + + def calculate_table_hash(self, table_name: str) -> str: + for provider in self._table_providers(table_name): + try: + return provider.calculate_table_hash(table_name) + except Exception: + continue + return "" + + def get_merkle_root(self) -> str: + combined = "".join( + self.calculate_table_hash(table) + for table in self.get_available_sync_tables() + ) + return hashlib.sha256(combined.encode()).hexdigest() + + def get_primary_key(self, table_name: str) -> Optional[str]: + for provider in self._table_providers(table_name): + try: + primary_key = provider.get_primary_key(table_name) + except Exception: + continue + if primary_key: + return primary_key + return None + + def get_table_data( + self, table_name: str, limit: int = 200, offset: int = 0 + ) -> List[Dict[str, Any]]: + for provider in self._table_providers(table_name): + try: + return provider.get_table_data(table_name, limit, offset) + except Exception: + continue + return [] + + def apply_sync_payload(self, table_name: str, remote_data: List[Dict[str, Any]]): + for provider in self._table_providers(table_name): + try: + if provider.apply_sync_payload(table_name, remote_data): + return True + except Exception: + continue + return False + + def get_count(self, table_name: str) -> int: + for provider in self._table_providers(table_name): + try: + return provider.get_count(table_name) + except Exception: + continue + return 0 + + +class RustChainSyncManager: + """Handles bidirectional synchronization through a swappable state provider.""" + + def __init__( + self, + db_path: str, + admin_key: str, + state_provider: Optional[StateProvider] = None, + ): + self.db_path = db_path + self.admin_key = admin_key + logging.basicConfig(level=logging.INFO) + self.logger = logging.getLogger("RustChainSync") + self.state_provider = state_provider or SQLiteStateProvider( + db_path, + logger=self.logger, + ) + # Backward-compatible access for older tests that clear the SQLite schema cache. + self._schema_cache = getattr(self.state_provider, "_schema_cache", {}) + + @property + def SYNC_TABLES(self) -> List[str]: + return self.get_available_sync_tables() + + def get_available_sync_tables(self) -> List[str]: + return self.state_provider.get_available_sync_tables() + + def calculate_table_hash(self, table_name: str) -> str: + return self.state_provider.calculate_table_hash(table_name) + + def get_merkle_root(self) -> str: + return self.state_provider.get_merkle_root() + + def _get_primary_key(self, table_name: str) -> Optional[str]: + return self.state_provider.get_primary_key(table_name) + + def get_table_data( + self, table_name: str, limit: int = 200, offset: int = 0 + ) -> List[Dict[str, Any]]: + return self.state_provider.get_table_data(table_name, limit, offset) + + def apply_sync_payload(self, table_name: str, remote_data: List[Dict[str, Any]]): + return self.state_provider.apply_sync_payload(table_name, remote_data) + + def _get_count(self, table_name: str) -> int: + return self.state_provider.get_count(table_name) + def get_sync_status(self) -> Dict[str, Any]: """Returns metadata about the current state of synced tables.""" tables = self.SYNC_TABLES @@ -286,15 +528,3 @@ def get_sync_status(self) -> Dict[str, Any]: "pk": self._get_primary_key(t), } return status - - def _get_count(self, table_name: str) -> int: - if table_name not in self.SYNC_TABLES: - return 0 - conn = self._get_connection() - try: - cursor = conn.cursor() - cursor.execute(f"SELECT COUNT(*) FROM {table_name}") - count = cursor.fetchone()[0] - return int(count) - finally: - conn.close() diff --git a/node/tests/test_state_provider_api.py b/node/tests/test_state_provider_api.py new file mode 100644 index 000000000..8a580e4ac --- /dev/null +++ b/node/tests/test_state_provider_api.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: MIT + +import os +import sqlite3 +import sys + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +from rustchain_sync import ( + FallbackStateProvider, + InMemoryStateProvider, + RustChainSyncManager, + SQLiteStateProvider, +) + + +def test_sync_manager_accepts_injected_state_provider(): + provider = InMemoryStateProvider( + tables={ + "miner_attest_recent": [ + {"miner_id": "miner-a", "last_attest": 10}, + ], + }, + primary_keys={"miner_attest_recent": "miner_id"}, + ) + sync = RustChainSyncManager(":memory:", "sync-secret", state_provider=provider) + + assert sync.SYNC_TABLES == ["miner_attest_recent"] + assert sync.get_table_data("miner_attest_recent") == [ + {"miner_id": "miner-a", "last_attest": 10}, + ] + assert sync.get_sync_status()["tables"]["miner_attest_recent"]["count"] == 1 + + +def test_fallback_state_provider_uses_secondary_when_primary_fails(): + class BrokenProvider: + def get_available_sync_tables(self): + raise RuntimeError("primary unavailable") + + def calculate_table_hash(self, table_name): + raise RuntimeError("primary unavailable") + + def get_merkle_root(self): + raise RuntimeError("primary unavailable") + + def get_primary_key(self, table_name): + raise RuntimeError("primary unavailable") + + def get_table_data(self, table_name, limit=200, offset=0): + raise RuntimeError("primary unavailable") + + def apply_sync_payload(self, table_name, remote_data): + raise RuntimeError("primary unavailable") + + def get_count(self, table_name): + raise RuntimeError("primary unavailable") + + secondary = InMemoryStateProvider( + tables={"epoch_rewards": [{"epoch": 7, "reward": 100}]}, + primary_keys={"epoch_rewards": "epoch"}, + ) + provider = FallbackStateProvider([BrokenProvider(), secondary]) + + assert provider.get_available_sync_tables() == ["epoch_rewards"] + assert provider.get_primary_key("epoch_rewards") == "epoch" + assert provider.get_table_data("epoch_rewards") == [ + {"epoch": 7, "reward": 100}, + ] + assert provider.get_count("epoch_rewards") == 1 + + +def test_fallback_state_provider_tries_secondary_when_advertised_table_operation_fails(): + class AdvertisesButFails: + def get_available_sync_tables(self): + return ["epoch_rewards"] + + def calculate_table_hash(self, table_name): + raise RuntimeError("hash failed") + + def get_merkle_root(self): + raise RuntimeError("root failed") + + def get_primary_key(self, table_name): + raise RuntimeError("pk failed") + + def get_table_data(self, table_name, limit=200, offset=0): + raise RuntimeError("data failed") + + def apply_sync_payload(self, table_name, remote_data): + raise RuntimeError("apply failed") + + def get_count(self, table_name): + raise RuntimeError("count failed") + + secondary = InMemoryStateProvider( + tables={"epoch_rewards": [{"epoch": 7, "reward": 100}]}, + primary_keys={"epoch_rewards": "epoch"}, + ) + provider = FallbackStateProvider([AdvertisesButFails(), secondary]) + + assert provider.get_primary_key("epoch_rewards") == "epoch" + assert provider.get_table_data("epoch_rewards") == [ + {"epoch": 7, "reward": 100}, + ] + assert provider.get_count("epoch_rewards") == 1 + assert provider.calculate_table_hash("epoch_rewards") == secondary.calculate_table_hash( + "epoch_rewards" + ) + + assert provider.apply_sync_payload( + "epoch_rewards", + [{"epoch": 7, "reward": 120}], + ) + assert provider.get_table_data("epoch_rewards") == [ + {"epoch": 7, "reward": 120}, + ] + + +def test_default_sqlite_provider_preserves_existing_sync_behavior(tmp_path): + db_path = tmp_path / "rustchain.db" + with sqlite3.connect(db_path) as conn: + conn.execute( + """ + CREATE TABLE miner_attest_recent ( + miner_id TEXT PRIMARY KEY, + last_attest INTEGER NOT NULL + ) + """ + ) + conn.execute( + "INSERT INTO miner_attest_recent (miner_id, last_attest) VALUES (?, ?)", + ("miner-a", 10), + ) + conn.commit() + + sync = RustChainSyncManager(str(db_path), "sync-secret") + + assert isinstance(sync.state_provider, SQLiteStateProvider) + assert sync.get_available_sync_tables() == ["miner_attest_recent"] + assert sync.get_table_data("miner_attest_recent") == [ + {"miner_id": "miner-a", "last_attest": 10}, + ] + + assert sync.apply_sync_payload( + "miner_attest_recent", + [{"miner_id": "miner-a", "last_attest": 5}], + ) + assert sync.get_table_data("miner_attest_recent") == [ + {"miner_id": "miner-a", "last_attest": 10}, + ] + + assert sync.apply_sync_payload( + "miner_attest_recent", + [{"miner_id": "miner-a", "last_attest": 12}], + ) + assert sync.get_table_data("miner_attest_recent") == [ + {"miner_id": "miner-a", "last_attest": 12}, + ]