Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 217 additions & 20 deletions node/rustchain_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,40 @@
import json
import time
import logging
from typing import List, Dict, Any, Optional
from copy import deepcopy
from typing import List, Dict, Any, Optional, Protocol


class RustChainSyncManager:
class StateProvider(Protocol):
    """Structural interface for a swappable source of syncable RustChain state.

    The method bodies are intentionally empty: concrete providers supply the
    behaviour, and any object with matching methods satisfies the protocol.
    """

    def get_available_sync_tables(self) -> List[str]:
        """Return the names of the tables this provider can synchronize."""

    def calculate_table_hash(self, table_name: str) -> str:
        """Return a digest string for the contents of ``table_name``."""

    def get_merkle_root(self) -> str:
        """Return one digest string summarizing every syncable table."""

    def get_primary_key(self, table_name: str) -> Optional[str]:
        """Return the primary-key column of ``table_name``, or ``None``."""

    def get_table_data(
        self, table_name: str, limit: int = 200, offset: int = 0
    ) -> List[Dict[str, Any]]:
        """Return up to ``limit`` rows of ``table_name`` starting at ``offset``."""

    def apply_sync_payload(self, table_name: str, remote_data: List[Dict[str, Any]]):
        """Merge rows received from a remote node into ``table_name``."""

    def get_count(self, table_name: str) -> int:
        """Return the number of rows held for ``table_name``."""


class SQLiteStateProvider:
"""
Handles bidirectional SQLite synchronization between RustChain nodes.
SQLite-backed sync state provider.

Security model:
- Table names are allowlisted
Expand All @@ -30,11 +58,9 @@ class RustChainSyncManager:
"transaction_history",
]

def __init__(self, db_path: str, admin_key: str):
def __init__(self, db_path: str, logger: Optional[logging.Logger] = None):
self.db_path = db_path
self.admin_key = admin_key
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger("RustChainSync")
self.logger = logger or logging.getLogger("RustChainSync")
self._schema_cache: Dict[str, Dict[str, Any]] = {}

def _get_connection(self):
Expand Down Expand Up @@ -128,7 +154,7 @@ def get_merkle_root(self) -> str:
combined = "".join(table_hashes)
return hashlib.sha256(combined.encode()).hexdigest()

def _get_primary_key(self, table_name: str) -> Optional[str]:
def get_primary_key(self, table_name: str) -> Optional[str]:
schema = self._load_table_schema(table_name)
if not schema:
return None
Expand Down Expand Up @@ -256,6 +282,189 @@ def apply_sync_payload(self, table_name: str, remote_data: List[Dict[str, Any]])
finally:
conn.close()

def get_count(self, table_name: str) -> int:
if table_name not in self.SYNC_TABLES:
return 0
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
count = cursor.fetchone()[0]
return int(count)
finally:
conn.close()


class InMemoryStateProvider:
    """Small in-memory provider for tests and embedded callers.

    Tables are plain lists of row dicts keyed by table name. A table takes
    part in sync only when a primary-key column is registered for it.
    """

    def __init__(
        self,
        tables: Optional[Dict[str, List[Dict[str, Any]]]] = None,
        primary_keys: Optional[Dict[str, str]] = None,
    ):
        """Copy the initial state so later caller-side mutation cannot leak in.

        Args:
            tables: Mapping of table name to its list of row dicts.
            primary_keys: Mapping of table name to its primary-key column.
        """
        self.tables = deepcopy(tables or {})
        self.primary_keys = dict(primary_keys or {})

    def _is_syncable(self, table_name: str) -> bool:
        """Return True when the table exists and has a registered primary key."""
        return table_name in self.tables and bool(self.primary_keys.get(table_name))

    def get_available_sync_tables(self) -> List[str]:
        """Return the tables that have a primary key registered, in insertion order."""
        return [name for name in self.tables if self.primary_keys.get(name)]

    def calculate_table_hash(self, table_name: str) -> str:
        """Return a SHA-256 hex digest of the table's rows, ordered by primary key.

        Returns an empty string for a table that is not syncable.
        """
        if not self._is_syncable(table_name):
            return ""

        pk = self.primary_keys[table_name]
        rows = self.tables[table_name]
        try:
            ordered = sorted(rows, key=lambda row: row.get(pk))
        except TypeError:
            # Keys that cannot be compared (a missing key yields None, or the
            # column mixes types) would abort hashing; fall back to a
            # deterministic order based on type name and repr.
            ordered = sorted(
                rows,
                key=lambda row: (type(row.get(pk)).__name__, repr(row.get(pk))),
            )

        hasher = hashlib.sha256()
        for row in ordered:
            row_str = json.dumps(row, sort_keys=True, separators=(",", ":"))
            hasher.update(row_str.encode())
        return hasher.hexdigest()

    def get_merkle_root(self) -> str:
        """Return the SHA-256 hex digest of all syncable table hashes concatenated."""
        combined = "".join(
            self.calculate_table_hash(table)
            for table in self.get_available_sync_tables()
        )
        return hashlib.sha256(combined.encode()).hexdigest()

    def get_primary_key(self, table_name: str) -> Optional[str]:
        """Return the registered primary-key column, or ``None`` if there is none."""
        return self.primary_keys.get(table_name)

    def get_table_data(
        self, table_name: str, limit: int = 200, offset: int = 0
    ) -> List[Dict[str, Any]]:
        """Return a deep copy of one page of rows.

        Negative ``limit`` or ``offset`` values are clamped to 0 so they can
        never be interpreted as end-relative slice indices.
        """
        if not self._is_syncable(table_name):
            return []
        start = max(int(offset), 0)
        size = max(int(limit), 0)
        return deepcopy(self.tables[table_name][start:start + size])

    def apply_sync_payload(self, table_name: str, remote_data: List[Dict[str, Any]]):
        """Merge remote rows into the table, matching rows on the primary key.

        Remote values overwrite local values column by column; rows that are
        not dicts or lack the primary key are skipped. A ``None`` payload is
        treated as empty.

        Returns:
            False when the table is unknown or has no primary key, else True.
        """
        pk = self.primary_keys.get(table_name)
        if not pk or table_name not in self.tables:
            return False

        merged = {row.get(pk): dict(row) for row in self.tables[table_name]}
        for row in remote_data or ():
            if not isinstance(row, dict) or pk not in row:
                continue
            # Deep-copy so nested payload values are never shared with the
            # caller, matching the isolation of __init__ and get_table_data.
            merged.setdefault(row[pk], {}).update(deepcopy(row))
        self.tables[table_name] = list(merged.values())
        return True

    def get_count(self, table_name: str) -> int:
        """Return the number of rows in a syncable table, or 0 otherwise."""
        if not self._is_syncable(table_name):
            return 0
        return len(self.tables[table_name])


class FallbackStateProvider:
    """Try multiple providers in order so callers can swap state sources safely.

    Each table is served by the first provider that lists it; reads and writes
    for that table are delegated to that provider. A provider whose table
    listing raises is skipped and the failure is logged.
    """

    def __init__(
        self,
        providers: List["StateProvider"],
        logger: Optional[logging.Logger] = None,
    ):
        """Store the providers in priority order.

        Args:
            providers: Providers to consult, highest priority first.
            logger: Logger for skipped-provider warnings; defaults to the
                "RustChainSync" logger.

        Raises:
            ValueError: If no provider is supplied.
        """
        # Snapshot the sequence so later mutation of the caller's list cannot
        # change (or empty) the fallback chain.
        chain = list(providers or ())
        if not chain:
            raise ValueError("at least one state provider is required")
        self.providers = chain
        if logger is None:
            logger = logging.getLogger("RustChainSync")
        self.logger = logger

    def _tables_of(self, provider: "StateProvider") -> List[str]:
        """Return the provider's table list, or [] (with a warning) if it fails."""
        try:
            return list(provider.get_available_sync_tables())
        except Exception as exc:
            # Best-effort by design: an unavailable provider must not stop the
            # chain, but the failure should be visible to operators.
            self.logger.warning(
                "State provider %s unavailable, trying next: %s",
                type(provider).__name__,
                exc,
            )
            return []

    def _table_owners(self) -> Dict[str, "StateProvider"]:
        """Map each table to the first provider listing it, querying each provider once."""
        owners: Dict[str, "StateProvider"] = {}
        for provider in self.providers:
            for table in self._tables_of(provider):
                owners.setdefault(table, provider)
        return owners

    def _first_table_provider(self, table_name: str) -> Optional["StateProvider"]:
        """Return the highest-priority provider listing ``table_name``, or None."""
        for provider in self.providers:
            if table_name in self._tables_of(provider):
                return provider
        return None

    def get_available_sync_tables(self) -> List[str]:
        """Return the union of all providers' tables, in first-seen order."""
        return list(self._table_owners())

    def calculate_table_hash(self, table_name: str) -> str:
        """Return the owning provider's table hash, or "" if no provider has it."""
        provider = self._first_table_provider(table_name)
        return provider.calculate_table_hash(table_name) if provider else ""

    def get_merkle_root(self) -> str:
        """Return the SHA-256 hex digest of all table hashes concatenated."""
        # Resolve owners in one pass instead of re-querying every provider
        # once per table.
        combined = "".join(
            owner.calculate_table_hash(table)
            for table, owner in self._table_owners().items()
        )
        return hashlib.sha256(combined.encode()).hexdigest()

    def get_primary_key(self, table_name: str) -> Optional[str]:
        """Return the owning provider's primary key for the table, or None."""
        provider = self._first_table_provider(table_name)
        return provider.get_primary_key(table_name) if provider else None

    def get_table_data(
        self, table_name: str, limit: int = 200, offset: int = 0
    ) -> List[Dict[str, Any]]:
        """Return one page of rows from the owning provider, or [] if none."""
        provider = self._first_table_provider(table_name)
        return provider.get_table_data(table_name, limit, offset) if provider else []

    def apply_sync_payload(self, table_name: str, remote_data: List[Dict[str, Any]]):
        """Apply the payload through the owning provider; False if none has the table."""
        provider = self._first_table_provider(table_name)
        return provider.apply_sync_payload(table_name, remote_data) if provider else False

    def get_count(self, table_name: str) -> int:
        """Return the owning provider's row count for the table, or 0 if none."""
        provider = self._first_table_provider(table_name)
        return provider.get_count(table_name) if provider else 0


class RustChainSyncManager:
"""Handles bidirectional synchronization through a swappable state provider."""

def __init__(
self,
db_path: str,
admin_key: str,
state_provider: Optional[StateProvider] = None,
):
self.db_path = db_path
self.admin_key = admin_key
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger("RustChainSync")
self.state_provider = state_provider or SQLiteStateProvider(
db_path,
logger=self.logger,
)
# Backward-compatible access for older tests that clear the SQLite schema cache.
self._schema_cache = getattr(self.state_provider, "_schema_cache", {})

@property
def SYNC_TABLES(self) -> List[str]:
return self.get_available_sync_tables()

def get_available_sync_tables(self) -> List[str]:
return self.state_provider.get_available_sync_tables()

def calculate_table_hash(self, table_name: str) -> str:
return self.state_provider.calculate_table_hash(table_name)

def get_merkle_root(self) -> str:
return self.state_provider.get_merkle_root()

def _get_primary_key(self, table_name: str) -> Optional[str]:
return self.state_provider.get_primary_key(table_name)

def get_table_data(
self, table_name: str, limit: int = 200, offset: int = 0
) -> List[Dict[str, Any]]:
return self.state_provider.get_table_data(table_name, limit, offset)

def apply_sync_payload(self, table_name: str, remote_data: List[Dict[str, Any]]):
return self.state_provider.apply_sync_payload(table_name, remote_data)

def _get_count(self, table_name: str) -> int:
return self.state_provider.get_count(table_name)

def get_sync_status(self) -> Dict[str, Any]:
"""Returns metadata about the current state of synced tables."""
tables = self.SYNC_TABLES
Expand All @@ -272,15 +481,3 @@ def get_sync_status(self) -> Dict[str, Any]:
"pk": self._get_primary_key(t),
}
return status

def _get_count(self, table_name: str) -> int:
if table_name not in self.SYNC_TABLES:
return 0
conn = self._get_connection()
try:
cursor = conn.cursor()
cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
count = cursor.fetchone()[0]
return int(count)
finally:
conn.close()
111 changes: 111 additions & 0 deletions node/tests/test_state_provider_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# SPDX-License-Identifier: MIT

import os
import sqlite3
import sys

sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from rustchain_sync import (
FallbackStateProvider,
InMemoryStateProvider,
RustChainSyncManager,
SQLiteStateProvider,
)


def test_sync_manager_accepts_injected_state_provider():
    """An injected provider is used as the manager's data source."""
    seed_row = {"miner_id": "miner-a", "last_attest": 10}
    in_memory = InMemoryStateProvider(
        tables={"miner_attest_recent": [dict(seed_row)]},
        primary_keys={"miner_attest_recent": "miner_id"},
    )
    manager = RustChainSyncManager(
        ":memory:", "sync-secret", state_provider=in_memory
    )

    assert manager.SYNC_TABLES == ["miner_attest_recent"]
    assert manager.get_table_data("miner_attest_recent") == [seed_row]

    table_status = manager.get_sync_status()["tables"]["miner_attest_recent"]
    assert table_status["count"] == 1


def test_fallback_state_provider_uses_secondary_when_primary_fails():
    """A provider that always raises is skipped in favour of the next one."""

    class BrokenProvider:
        """Provider whose every call fails."""

        @staticmethod
        def _unavailable():
            raise RuntimeError("primary unavailable")

        def get_available_sync_tables(self):
            self._unavailable()

        def calculate_table_hash(self, table_name):
            self._unavailable()

        def get_merkle_root(self):
            self._unavailable()

        def get_primary_key(self, table_name):
            self._unavailable()

        def get_table_data(self, table_name, limit=200, offset=0):
            self._unavailable()

        def apply_sync_payload(self, table_name, remote_data):
            self._unavailable()

        def get_count(self, table_name):
            self._unavailable()

    reward_row = {"epoch": 7, "reward": 100}
    backup = InMemoryStateProvider(
        tables={"epoch_rewards": [dict(reward_row)]},
        primary_keys={"epoch_rewards": "epoch"},
    )
    chain = FallbackStateProvider([BrokenProvider(), backup])

    assert chain.get_available_sync_tables() == ["epoch_rewards"]
    assert chain.get_primary_key("epoch_rewards") == "epoch"
    assert chain.get_table_data("epoch_rewards") == [reward_row]
    assert chain.get_count("epoch_rewards") == 1


def test_default_sqlite_provider_preserves_existing_sync_behavior(tmp_path):
    """Without an injected provider the manager syncs through SQLite as before."""
    table = "miner_attest_recent"
    database_file = tmp_path / "rustchain.db"
    with sqlite3.connect(database_file) as conn:
        conn.execute(
            """
            CREATE TABLE miner_attest_recent (
            miner_id TEXT PRIMARY KEY,
            last_attest INTEGER NOT NULL
            )
            """
        )
        conn.execute(
            "INSERT INTO miner_attest_recent (miner_id, last_attest) VALUES (?, ?)",
            ("miner-a", 10),
        )
        conn.commit()

    manager = RustChainSyncManager(str(database_file), "sync-secret")

    assert isinstance(manager.state_provider, SQLiteStateProvider)
    assert manager.get_available_sync_tables() == [table]
    assert manager.get_table_data(table) == [
        {"miner_id": "miner-a", "last_attest": 10},
    ]

    # A payload with last_attest 5 leaves the stored value at 10.
    older_payload = [{"miner_id": "miner-a", "last_attest": 5}]
    assert manager.apply_sync_payload(table, older_payload)
    assert manager.get_table_data(table) == [
        {"miner_id": "miner-a", "last_attest": 10},
    ]

    # A payload with last_attest 12 replaces the stored value.
    newer_payload = [{"miner_id": "miner-a", "last_attest": 12}]
    assert manager.apply_sync_payload(table, newer_payload)
    assert manager.get_table_data(table) == [
        {"miner_id": "miner-a", "last_attest": 12},
    ]
Loading