Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions rust/crates/cli/src/commands/server/start.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use axum::response::{IntoResponse, Response};
use axum::routing::{any, get, post};
use owo_colors::OwoColorize;
use pay_core::PaymentState;
use pay_core::ReplayStore;
use pay_core::accounts::AccountsStore;
use pay_core::server::session::SessionMpp;
use pay_core::server::telemetry::FeePayerWallet;
Expand Down Expand Up @@ -91,6 +92,7 @@ struct AppState {
session_mpp: Option<Arc<SessionMpp>>,
browser_rpc_url: Option<String>,
fee_payer_wallet: Option<FeePayerWallet>,
replay_store: ReplayStore,
}

impl PaymentState for AppState {
Expand All @@ -112,6 +114,9 @@ impl PaymentState for AppState {
fn fee_payer_wallet(&self) -> Option<&FeePayerWallet> {
self.fee_payer_wallet.as_ref()
}
fn replay_store(&self) -> Option<&ReplayStore> {
Some(&self.replay_store)
}
}

fn should_use_auto_fee_payer_signer(
Expand Down Expand Up @@ -769,6 +774,7 @@ impl StartCommand {
session_mpp,
browser_rpc_url: Some(BROWSER_RPC_PROXY_PATH.to_string()),
fee_payer_wallet,
replay_store: ReplayStore::default(),
};

let pdb_state = if debugger {
Expand Down
5 changes: 5 additions & 0 deletions rust/crates/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ pub use server::{AccountingKey, AccountingStore, InMemoryStore, current_period};
#[cfg(feature = "server")]
use pay_types::metering::ApiSpec;
#[cfg(feature = "server")]
pub use server::payment::ReplayStore;
#[cfg(feature = "server")]
pub use solana_mpp;
#[cfg(feature = "server")]
use solana_mpp::server::Mpp;
Expand All @@ -55,4 +57,7 @@ pub trait PaymentState: Clone + Send + Sync + 'static {
fn fee_payer_wallet(&self) -> Option<&server::telemetry::FeePayerWallet> {
None
}
fn replay_store(&self) -> Option<&server::payment::ReplayStore> {
None
}
}
241 changes: 240 additions & 1 deletion rust/crates/core/src/server/payment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
//! - No payment header → 402 with MPP challenge (WWW-Authenticate)
//! - Payment header → verify with solana-mpp, then forward upstream

use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

use axum::body::Body;
use axum::http::{HeaderMap, Method, Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use base64::Engine;
use serde_json::json;
use solana_mpp::{
AUTHORIZATION_HEADER, PAYMENT_RECEIPT_HEADER, WWW_AUTHENTICATE_HEADER, format_receipt,
Expand Down Expand Up @@ -37,6 +42,60 @@ const PAYMENT_PAGE_CONTENT_SECURITY_POLICY: &str = "\
img-src 'self' data: blob: https:; \
connect-src 'self' http://localhost:* http://127.0.0.1:* https:; \
worker-src 'self'";
const DEFAULT_REPLAY_TTL: Duration = Duration::from_secs(30 * 60);

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReplayRecord {
used_at: Instant,
}

#[derive(Clone)]
pub struct ReplayStore {
inner: Arc<Mutex<HashMap<String, ReplayRecord>>>,
ttl: Duration,
}

impl Default for ReplayStore {
fn default() -> Self {
Self::new(DEFAULT_REPLAY_TTL)
}
}

impl ReplayStore {
pub fn new(ttl: Duration) -> Self {
Self {
inner: Arc::new(Mutex::new(HashMap::new())),
ttl,
}
}

pub fn contains_recent(&self, key: &str) -> bool {
let mut guard = self.inner.lock().unwrap();
cleanup_expired_locked(&mut guard, self.ttl);
guard.contains_key(key)
}

pub fn mark_used(&self, key: String) {
let mut guard = self.inner.lock().unwrap();
cleanup_expired_locked(&mut guard, self.ttl);
guard.insert(
key,
ReplayRecord {
used_at: Instant::now(),
},
);
}

pub fn cleanup(&self) {
let mut guard = self.inner.lock().unwrap();
cleanup_expired_locked(&mut guard, self.ttl);
}

#[cfg(test)]
fn len(&self) -> usize {
self.inner.lock().unwrap().len()
}
}

/// Axum middleware that gates metered endpoints behind MPP payment.
pub async fn payment_middleware<S: PaymentState>(
Expand Down Expand Up @@ -189,6 +248,8 @@ pub async fn payment_middleware<S: PaymentState>(
auth_value,
subdomain,
&path,
&method,
state.replay_store().cloned(),
state.fee_payer_wallet().cloned(),
req,
next,
Expand Down Expand Up @@ -443,14 +504,16 @@ fn resolve_charge_splits(

#[tracing::instrument(
name = "charge_authorization",
skip(mpps, auth_value, fee_payer_wallet, req, next),
skip(mpps, auth_value, replay_store, fee_payer_wallet, req, next),
fields(subdomain = %subdomain, path = %path)
)]
async fn handle_charge_authorization(
mpps: &[&solana_mpp::server::Mpp],
auth_value: &str,
subdomain: &str,
path: &str,
method: &Method,
replay_store: Option<ReplayStore>,
fee_payer_wallet: Option<telemetry::FeePayerWallet>,
req: Request<Body>,
next: Next,
Expand All @@ -471,6 +534,28 @@ async fn handle_charge_authorization(
for mpp in mpps {
match mpp.verify_credential(&credential).await {
Ok(receipt) => {
let replay_key =
build_mpp_replay_key(&credential, &receipt.reference, method, path);
if let Some(store) = replay_store.as_ref()
&& let Some(key) = replay_key.as_deref()
&& store.contains_recent(key)
{
tracing::warn!(
subdomain = %subdomain,
path = %path,
method = %method,
replay_key = %key,
"Payment proof replay detected"
);
telemetry::record_settlement_error(
"mpp",
subdomain,
path,
"payment proof already used",
false,
);
return replay_failed_response(mpps);
}
let payment = decode_payment_amount(&credential, mpp.decimals() as u8);
telemetry::record_payment_collected(
"mpp",
Expand Down Expand Up @@ -501,6 +586,12 @@ async fn handle_charge_authorization(
{
response.headers_mut().insert(PAYMENT_RECEIPT_HEADER, v);
}
if response.status().is_success()
&& let Some(store) = replay_store.as_ref()
&& let Some(key) = replay_key
{
store.mark_used(key);
}
return response;
}
Err(e) => last_error = Some(e),
Expand Down Expand Up @@ -547,6 +638,26 @@ fn verification_failed_response(
response
}

fn replay_failed_response(mpps: &[&solana_mpp::server::Mpp]) -> Response {
let mut response = (
StatusCode::PAYMENT_REQUIRED,
axum::Json(json!({
"error": "verification_failed",
"message": "payment proof already used",
"retryable": false,
})),
)
.into_response();
let challenges: Vec<_> = mpps
.iter()
.filter_map(|mpp| mpp.charge("0.01").ok())
.collect();
if let Ok(www_auths) = format_www_authenticate_many(&challenges) {
append_www_authenticate_headers(response.headers_mut(), &www_auths);
}
response
}

pub fn readable_verification_message(error: &solana_mpp::server::VerificationError) -> String {
let message = error.to_string();
if message.contains("Fee payer cannot authorize the SPL payment transfer") {
Expand All @@ -561,6 +672,55 @@ pub fn readable_verification_message(error: &solana_mpp::server::VerificationErr
message
}

fn build_mpp_replay_key(
credential: &solana_mpp::PaymentCredential,
receipt_reference: &str,
method: &Method,
path: &str,
) -> Option<String> {
let payer = extract_payer_from_credential(credential)?;
let canonical = format!(
"mpp:{}:{payer}:{receipt_reference}:{}:{path}",
credential.challenge.id,
method.as_str()
);
Some(format!(
"mpp:{}",
blake3::hash(canonical.as_bytes()).to_hex()
))
}

fn extract_payer_from_credential(credential: &solana_mpp::PaymentCredential) -> Option<String> {
if let Some(tx_b64) = credential
.payload
.get("transaction")
.and_then(|value| value.as_str())
{
let tx_bytes = base64::engine::general_purpose::STANDARD
.decode(tx_b64)
.ok()?;
let tx: solana_transaction::Transaction = bincode::deserialize(&tx_bytes).ok()?;
let zero_sig = [0u8; 64];
for (index, sig) in tx.signatures.iter().enumerate() {
if sig.as_ref() != zero_sig && index < tx.message.account_keys.len() {
return Some(tx.message.account_keys[index].to_string());
}
}
return tx.message.account_keys.first().map(ToString::to_string);
}

credential
.payload
.get("source")
.and_then(|value| value.as_str())
.map(ToString::to_string)
}

fn cleanup_expired_locked(entries: &mut HashMap<String, ReplayRecord>, ttl: Duration) {
let now = Instant::now();
entries.retain(|_, record| now.duration_since(record.used_at) < ttl);
}

fn challenge_json_response(body: serde_json::Value, www_auths: &[String]) -> Response {
let mut response = (StatusCode::PAYMENT_REQUIRED, axum::Json(body)).into_response();
append_www_authenticate_headers(response.headers_mut(), www_auths);
Expand Down Expand Up @@ -640,9 +800,13 @@ fn extract_variant_hint(path: &str) -> Option<String> {
#[cfg(test)]
mod tests {
use super::*;
use base64::Engine;
use solana_mpp::WWW_AUTHENTICATE_HEADER;
use solana_mpp::server::Mpp;
use solana_mpp::server::session::SessionConfig;
use solana_signature::Signature;
use solana_transaction::Transaction;
use std::thread;

fn test_mpp() -> Mpp {
Mpp::new(solana_mpp::server::Config {
Expand Down Expand Up @@ -773,6 +937,81 @@ mod tests {
);
}

#[test]
fn replay_store_marks_and_finds_recent_keys() {
let store = ReplayStore::new(Duration::from_secs(60));
assert!(!store.contains_recent("proof-1"));
store.mark_used("proof-1".to_string());
assert!(store.contains_recent("proof-1"));
}

#[test]
fn replay_store_expires_old_entries() {
let store = ReplayStore::new(Duration::from_millis(5));
store.mark_used("proof-1".to_string());
thread::sleep(Duration::from_millis(10));
assert!(!store.contains_recent("proof-1"));
assert_eq!(store.len(), 0);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn replay_failed_response_is_non_retryable_verification_error() {
let mpp = test_mpp();
let response = replay_failed_response(&[&mpp]);
assert_eq!(response.status(), StatusCode::PAYMENT_REQUIRED);

let body = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.unwrap();
let body: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(body["error"], "verification_failed");
assert_eq!(body["message"], "payment proof already used");
assert_eq!(body["retryable"], false);
}

#[test]
fn build_mpp_replay_key_is_stable_for_same_verified_inputs() {
let mpp = test_mpp();
let challenge = mpp.charge("0.01").expect("challenge should build");
let tx = Transaction {
signatures: vec![Signature::new_unique()],
message: solana_message::Message {
header: solana_message::MessageHeader {
num_required_signatures: 1,
num_readonly_signed_accounts: 0,
num_readonly_unsigned_accounts: 0,
},
account_keys: vec![solana_pubkey::Pubkey::new_unique()],
recent_blockhash: solana_hash::Hash::new_unique(),
instructions: vec![],
},
};
let tx_b64 =
base64::engine::general_purpose::STANDARD.encode(bincode::serialize(&tx).unwrap());
let credential = solana_mpp::PaymentCredential::new(
challenge.to_echo(),
serde_json::json!({
"type": "transaction",
"transaction": tx_b64,
}),
);

let key_a = build_mpp_replay_key(
&credential,
"receipt-ref-1",
&Method::POST,
"v1/simple/echo",
);
let key_b = build_mpp_replay_key(
&credential,
"receipt-ref-1",
&Method::POST,
"v1/simple/echo",
);

assert_eq!(key_a, key_b);
}

#[tokio::test]
async fn session_challenge_response_sets_session_header() {
let response = session_challenge_response(
Expand Down