Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 250 additions & 20 deletions node/rustchain_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,40 @@
import json
import time
import logging
from typing import List, Dict, Any, Optional
from copy import deepcopy
from typing import List, Dict, Any, Optional, Protocol


class RustChainSyncManager:
class StateProvider(Protocol):
    """Structural interface for a backend that exposes syncable RustChain state.

    Any object offering these seven methods satisfies the protocol; no
    inheritance is required.
    """

    def get_available_sync_tables(self) -> List[str]:
        """Return the names of the tables this provider can synchronize."""

    def calculate_table_hash(self, table_name: str) -> str:
        """Return a digest string summarizing the rows of ``table_name``."""

    def get_merkle_root(self) -> str:
        """Return a single digest string covering the provider's sync tables."""

    def get_primary_key(self, table_name: str) -> Optional[str]:
        """Return the primary-key column of ``table_name``, or ``None``."""

    def get_table_data(
        self, table_name: str, limit: int = 200, offset: int = 0
    ) -> List[Dict[str, Any]]:
        """Return up to ``limit`` rows of ``table_name`` starting at ``offset``."""

    def apply_sync_payload(self, table_name: str, remote_data: List[Dict[str, Any]]):
        """Merge ``remote_data`` rows into ``table_name``.

        Callers treat the result as a success flag (truthy means applied).
        """

    def get_count(self, table_name: str) -> int:
        """Return the number of rows held for ``table_name``."""


class SQLiteStateProvider:
"""
Handles bidirectional SQLite synchronization between RustChain nodes.
SQLite-backed sync state provider.

Security model:
- Table names are allowlisted
Expand All @@ -30,11 +58,9 @@ class RustChainSyncManager:
"transaction_history",
]

def __init__(self, db_path: str, admin_key: str):
def __init__(self, db_path: str, logger: Optional[logging.Logger] = None):
self.db_path = db_path
self.admin_key = admin_key
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger("RustChainSync")
self.logger = logger or logging.getLogger("RustChainSync")
self._schema_cache: Dict[str, Dict[str, Any]] = {}

def _get_connection(self):
Expand Down Expand Up @@ -128,7 +154,7 @@ def get_merkle_root(self) -> str:
combined = "".join(table_hashes)
return hashlib.sha256(combined.encode()).hexdigest()

def _get_primary_key(self, table_name: str) -> Optional[str]:
def get_primary_key(self, table_name: str) -> Optional[str]:
schema = self._load_table_schema(table_name)
if not schema:
return None
Expand Down Expand Up @@ -256,6 +282,222 @@ def apply_sync_payload(self, table_name: str, remote_data: List[Dict[str, Any]])
finally:
conn.close()

def get_count(self, table_name: str) -> int:
if table_name not in self.SYNC_TABLES:
return 0
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
count = cursor.fetchone()[0]
return int(count)
finally:
conn.close()


class InMemoryStateProvider:
    """Small in-memory provider for tests and embedded callers.

    State lives in ``self.tables`` (table name -> list of row dicts) and
    ``self.primary_keys`` (table name -> primary-key column). A table is
    only syncable when it appears in both mappings with a non-empty key.
    """

    def __init__(
        self,
        tables: Optional[Dict[str, List[Dict[str, Any]]]] = None,
        primary_keys: Optional[Dict[str, str]] = None,
    ):
        """Copy the initial state so later caller-side mutation has no effect.

        Args:
            tables: Initial rows per table; deep-copied.
            primary_keys: Primary-key column per table; shallow-copied.
        """
        self.tables = deepcopy(tables or {})
        self.primary_keys = dict(primary_keys or {})

    def get_available_sync_tables(self) -> List[str]:
        """Return the tables that have a non-empty primary key configured."""
        return [name for name in self.tables if self.primary_keys.get(name)]

    def calculate_table_hash(self, table_name: str) -> str:
        """Return the SHA-256 hex digest of the table's rows in key order.

        Rows are serialized as compact, key-sorted JSON. Returns ``""`` for
        a table that is not syncable.
        """
        if table_name not in self.get_available_sync_tables():
            return ""

        pk = self.primary_keys[table_name]
        table_rows = self.tables.get(table_name, [])
        try:
            # Natural ordering of the key values is the normal path.
            ordered = sorted(table_rows, key=lambda row: row.get(pk))
        except TypeError:
            # Missing keys (None) or mixed key types cannot be compared
            # directly; fall back to a deterministic (type, repr) ordering
            # instead of failing the whole hash.
            ordered = sorted(
                table_rows,
                key=lambda row: (type(row.get(pk)).__name__, repr(row.get(pk))),
            )

        hasher = hashlib.sha256()
        for row in ordered:
            row_str = json.dumps(row, sort_keys=True, separators=(",", ":"))
            hasher.update(row_str.encode())
        return hasher.hexdigest()

    def get_merkle_root(self) -> str:
        """Return the SHA-256 hex digest of all table hashes concatenated."""
        combined = "".join(
            self.calculate_table_hash(table)
            for table in self.get_available_sync_tables()
        )
        return hashlib.sha256(combined.encode()).hexdigest()

    def get_primary_key(self, table_name: str) -> Optional[str]:
        """Return the configured primary-key column, or ``None`` if unset."""
        return self.primary_keys.get(table_name)

    def get_table_data(
        self, table_name: str, limit: int = 200, offset: int = 0
    ) -> List[Dict[str, Any]]:
        """Return a deep copy of one page of rows.

        Negative ``limit`` or ``offset`` values are clamped to ``0`` so they
        cannot turn into wrap-around slices. Returns ``[]`` for a table that
        is not syncable.
        """
        if table_name not in self.get_available_sync_tables():
            return []
        rows = self.tables.get(table_name, [])
        start = max(int(offset), 0)
        size = max(int(limit), 0)
        return deepcopy(rows[start:start + size])

    def apply_sync_payload(
        self, table_name: str, remote_data: List[Dict[str, Any]]
    ) -> bool:
        """Upsert ``remote_data`` rows into the table, keyed by primary key.

        Incoming fields overwrite existing ones; new keys are appended.
        Entries that are not dicts, lack the primary key, or carry an
        unhashable key value are skipped. A ``None`` payload is treated as
        empty.

        Returns:
            ``False`` when the table is unknown or has no primary key,
            otherwise ``True``.
        """
        pk = self.primary_keys.get(table_name)
        if not pk or table_name not in self.tables:
            return False

        rows_by_pk = {row.get(pk): dict(row) for row in self.tables[table_name]}
        for row in remote_data or []:
            if not isinstance(row, dict) or pk not in row:
                continue
            key = row[pk]
            try:
                merged = rows_by_pk.get(key, {})
            except TypeError:
                # Unhashable key value: malformed row, skip like the others.
                continue
            # Deep-copy so stored state never aliases the caller's payload.
            merged.update(deepcopy(row))
            rows_by_pk[key] = merged
        self.tables[table_name] = list(rows_by_pk.values())
        return True

    def get_count(self, table_name: str) -> int:
        """Return the row count, or ``0`` for a table that is not syncable."""
        if table_name not in self.get_available_sync_tables():
            return 0
        return len(self.tables.get(table_name, []))


class FallbackStateProvider:
    """Try multiple providers in order so callers can swap state sources safely.

    Each operation is attempted on every provider that advertises the
    requested table, in the order given. A provider that raises is logged
    and skipped; the exception is never propagated to the caller.
    """

    def __init__(
        self,
        providers: List["StateProvider"],
        logger: Optional[logging.Logger] = None,
    ):
        """Store the provider chain.

        Args:
            providers: Providers in priority order; must not be empty.
            logger: Where skipped-provider failures are reported. Defaults
                to the ``"RustChainSync"`` logger.

        Raises:
            ValueError: If ``providers`` is empty.
        """
        if not providers:
            raise ValueError("at least one state provider is required")
        # Copy so later mutation of the caller's list cannot alter the chain.
        self.providers = list(providers)
        self.logger = (
            logger if logger is not None else logging.getLogger("RustChainSync")
        )

    def _report_failure(self, provider: Any, operation: str, error: Exception) -> None:
        """Log a provider failure that is about to be skipped."""
        self.logger.warning(
            "State provider %s failed during %s: %r",
            type(provider).__name__,
            operation,
            error,
        )

    def _first_table_provider(self, table_name: str) -> Optional["StateProvider"]:
        """Return the first provider advertising ``table_name``, or ``None``."""
        for provider in self.providers:
            try:
                if table_name in provider.get_available_sync_tables():
                    return provider
            except Exception as error:
                self._report_failure(provider, "get_available_sync_tables", error)
        return None

    def _table_providers(self, table_name: str) -> List["StateProvider"]:
        """Return every provider advertising ``table_name``, in priority order."""
        matching: List["StateProvider"] = []
        for provider in self.providers:
            try:
                if table_name in provider.get_available_sync_tables():
                    matching.append(provider)
            except Exception as error:
                self._report_failure(provider, "get_available_sync_tables", error)
        return matching

    def get_available_sync_tables(self) -> List[str]:
        """Return the union of all providers' tables, first occurrence first."""
        tables: List[str] = []
        for provider in self.providers:
            try:
                for table in provider.get_available_sync_tables():
                    if table not in tables:
                        tables.append(table)
            except Exception as error:
                self._report_failure(provider, "get_available_sync_tables", error)
        return tables

    def calculate_table_hash(self, table_name: str) -> str:
        """Return the hash from the first provider that answers, else ``""``."""
        for provider in self._table_providers(table_name):
            try:
                return provider.calculate_table_hash(table_name)
            except Exception as error:
                self._report_failure(provider, "calculate_table_hash", error)
        return ""

    def get_merkle_root(self) -> str:
        """Return the SHA-256 hex digest of all table hashes concatenated."""
        combined = "".join(
            self.calculate_table_hash(table)
            for table in self.get_available_sync_tables()
        )
        return hashlib.sha256(combined.encode()).hexdigest()

    def get_primary_key(self, table_name: str) -> Optional[str]:
        """Return the first non-empty primary key reported, else ``None``."""
        for provider in self._table_providers(table_name):
            try:
                primary_key = provider.get_primary_key(table_name)
            except Exception as error:
                self._report_failure(provider, "get_primary_key", error)
                continue
            if primary_key:
                return primary_key
        return None

    def get_table_data(
        self, table_name: str, limit: int = 200, offset: int = 0
    ) -> List[Dict[str, Any]]:
        """Return the page from the first provider that answers, else ``[]``."""
        for provider in self._table_providers(table_name):
            try:
                return provider.get_table_data(table_name, limit, offset)
            except Exception as error:
                self._report_failure(provider, "get_table_data", error)
        return []

    def apply_sync_payload(
        self, table_name: str, remote_data: List[Dict[str, Any]]
    ) -> bool:
        """Apply the payload to the first provider that accepts it.

        Returns ``True`` as soon as one provider reports success, and
        ``False`` when every candidate declined or failed.
        """
        for provider in self._table_providers(table_name):
            try:
                if provider.apply_sync_payload(table_name, remote_data):
                    return True
            except Exception as error:
                self._report_failure(provider, "apply_sync_payload", error)
        return False

    def get_count(self, table_name: str) -> int:
        """Return the count from the first provider that answers, else ``0``."""
        for provider in self._table_providers(table_name):
            try:
                return provider.get_count(table_name)
            except Exception as error:
                self._report_failure(provider, "get_count", error)
        return 0


class RustChainSyncManager:
"""Handles bidirectional synchronization through a swappable state provider."""

def __init__(
self,
db_path: str,
admin_key: str,
state_provider: Optional[StateProvider] = None,
):
self.db_path = db_path
self.admin_key = admin_key
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger("RustChainSync")
self.state_provider = state_provider or SQLiteStateProvider(
db_path,
logger=self.logger,
)
# Backward-compatible access for older tests that clear the SQLite schema cache.
self._schema_cache = getattr(self.state_provider, "_schema_cache", {})

@property
def SYNC_TABLES(self) -> List[str]:
return self.get_available_sync_tables()

def get_available_sync_tables(self) -> List[str]:
return self.state_provider.get_available_sync_tables()

def calculate_table_hash(self, table_name: str) -> str:
return self.state_provider.calculate_table_hash(table_name)

def get_merkle_root(self) -> str:
return self.state_provider.get_merkle_root()

def _get_primary_key(self, table_name: str) -> Optional[str]:
return self.state_provider.get_primary_key(table_name)

def get_table_data(
self, table_name: str, limit: int = 200, offset: int = 0
) -> List[Dict[str, Any]]:
return self.state_provider.get_table_data(table_name, limit, offset)

def apply_sync_payload(self, table_name: str, remote_data: List[Dict[str, Any]]):
return self.state_provider.apply_sync_payload(table_name, remote_data)

def _get_count(self, table_name: str) -> int:
return self.state_provider.get_count(table_name)

def get_sync_status(self) -> Dict[str, Any]:
"""Returns metadata about the current state of synced tables."""
tables = self.SYNC_TABLES
Expand All @@ -272,15 +514,3 @@ def get_sync_status(self) -> Dict[str, Any]:
"pk": self._get_primary_key(t),
}
return status

def _get_count(self, table_name: str) -> int:
    """Count the rows of ``table_name`` directly through a database connection.

    Returns ``0`` for any name that is not in ``self.SYNC_TABLES``.

    NOTE(review): the preceding hunk header (-272,15 +514,3) indicates these
    lines are the removed side of the diff. The body relies on
    ``self._get_connection``, which is not defined in the visible part of the
    refactored manager - confirm this definition is gone from the merged file.
    """
    if table_name not in self.SYNC_TABLES:
        return 0
    conn = self._get_connection()
    try:
        cursor = conn.cursor()
        # The table name is interpolated into the SQL text; the membership
        # check above is the only guard on its value.
        cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
        count = cursor.fetchone()[0]
        return int(count)
    finally:
        # Close the connection on both the success and the error path.
        conn.close()
Loading
Loading