diff --git a/Cargo.lock b/Cargo.lock index 5ea4912c43ef..c64810403bdc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5475,6 +5475,7 @@ dependencies = [ "aws-sdk-iam", "aws-sigv4", "base64 0.22.1", + "bb8", "bstr", "bytes", "camino", @@ -5584,6 +5585,19 @@ dependencies = [ "zeroize", ] +[[package]] +name = "proxy-cplane-api" +version = "0.1.0" +dependencies = [ + "axum", + "clap", + "humantime", + "serde", + "serde_json", + "tokio", + "workspace_hack", +] + [[package]] name = "quanta" version = "0.12.5" diff --git a/Cargo.toml b/Cargo.toml index 3f2308679771..784ec8515d11 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ members = [ "pageserver/pagebench", "pageserver/page_api", "proxy", + "proxy-cplane-api", "safekeeper", "safekeeper/client", "storage_broker", diff --git a/docs/review/00-overview.md b/docs/review/00-overview.md new file mode 100644 index 000000000000..15c18cdca870 --- /dev/null +++ b/docs/review/00-overview.md @@ -0,0 +1,87 @@ +# Review pack — `mel/pool-fixes-and-benchmarks` + +You're reviewing five commits sitting on top of Charles's session-pool +WIP commit (`6e871ba05`). Three pieces of work, in dependency order: + +1. **Fix the session-pool half-close bug** so reused connections actually + work (`6a361a996`). +2. **Multi-endpoint routing** so the mock control plane can return + different compute addresses for different endpoint IDs + (`fa654b133`). +3. **Benchmark sweep** comparing direct / proxy+pool / proxy without pool, + plus FINDINGS.md (`9ce39d144`). +4. **Transaction-mode pool multiplexing** — release on every + `ReadyForQuery 'I'`, re-acquire from the pool for the next + transaction (`29f0bc6d2`). +5. **Benchmark update** adding a `proxy_txn` configuration and section 5 + in FINDINGS.md (`ef3db262b`). + +## Reading order + +1. **`01-session-pool-fix.md`** — start here. Smallest commit, cleanest + bug story. The fix is ~30 lines but the chain of reasoning that got + there is half the value. +2. **`02-multi-endpoint.md`** — quick read; mostly CLI plumbing and a + lookup-priority decision in `mock.rs`. +3. **`03-benchmarks.md`** — methodology, findings, and the things the + numbers actually say (vs. what they superficially look like). Read + `benchmarks/FINDINGS.md` alongside. +4. **`04-transaction-mode.md`** — the bigger architectural change. Read + the `ReadyForQueryWatcher` and `proxy_pass_transaction_mode` + sections; skim the rest. Compare against `pgcat/src/client.rs`'s + two-loop structure if you want the canonical reference. +5. **`05-open-questions.md`** — the weak points and decisions I'd + challenge if I were reviewing me. Read this last — it'll change + how you read everything else. + +## High-level commit map + +``` +ef3db262b benchmarks: add proxy_txn config and update FINDINGS +29f0bc6d2 transaction-mode pool multiplexing +9ce39d144 benchmarks: pool/no-pool/direct sweep with FINDINGS +fa654b133 multi-endpoint routing via --compute-endpoint-map +6a361a996 session pooling works: don't poll_shutdown(compute) on client Terminate +6e871ba05 basic bb8 pooling <- Charles's last commit +``` + +`git log --stat --reverse 6e871ba05..HEAD` gives the per-commit diff +sizes. Five commits, ~1200 lines added, ~50 removed. + +## What's in scope vs out of scope + +| In scope | Out of scope | +| ---------------------------------------------- | ---------------------------------- | +| Session-pool correctness (the half-close bug) | Pool eviction / health checks | +| Per-endpoint routing in the mock control plane | Real cplane integration | +| Transaction-mode multiplexing | `SET` / `LISTEN` detection | +| pgbench sweep + FINDINGS | Cancellation across multiplexed conns | +| Memory measurement at peak | Auth caching | + +The five "out of scope" items are all real follow-ups; some are noted in +`05-open-questions.md`. + +## How to run things locally + +Already documented in `benchmarks/README.md`. The shortest reproducer: + +```bash +# Build +cargo build --features testing --bin proxy + +# Smoke session pool +RUST_LOG="proxy::tcp_pool=info" PGPASSWORD=testpw target/debug/proxy \ + --auth-backend=postgres \ + --auth-endpoint='postgresql://proxytest@localhost:5432/proxytest_db' \ + --compute-endpoint='postgresql://localhost:5433/proxytest_db?sslmode=disable' \ + --proxy=127.0.0.1:4432 --mgmt=127.0.0.1:7000 --http=127.0.0.1:7001 --wss=127.0.0.1:7002 \ + --tcp-pool-enabled=true --tcp-pool-max-conns-per-key=5 --tcp-pool-fallback-direct-connect=true + +# Run 5 sessions and check the pid is stable (session pooling works) +for i in 1 2 3 4 5; do + PGPASSWORD=testpw psql "postgresql://proxytest@127.0.0.1:4432/proxytest_db?sslmode=disable&options=endpoint%3Dep-test-123" \ + -c "SELECT $i, pg_backend_pid()" +done + +# Add --tcp-pool-mode=transaction to see multiplex behaviour +``` diff --git a/docs/review/01-session-pool-fix.md b/docs/review/01-session-pool-fix.md new file mode 100644 index 000000000000..e23f583fbcb5 --- /dev/null +++ b/docs/review/01-session-pool-fix.md @@ -0,0 +1,124 @@ +# 01 — Session-pool half-close bug + +Commit `6a361a996` "session pooling works: don't poll_shutdown(compute) +on client Terminate". Three files, ~190 lines net. + +## TL;DR + +When a client cleanly disconnected (Postgres `'X'` Terminate), the +existing `TerminateFilter` correctly **didn't forward** the Terminate +to compute, but the surrounding bidirectional-copy state machine +**still ran `poll_shutdown` on the compute writer** as part of its +"reader hit EOF, shut down the writer" cleanup path. That's +`shutdown(SHUT_WR)` on the live compute socket. Postgres saw the FIN, +closed its side, and the next pool reuse hit `ENOTCONN`. + +Fix: a parallel `transfer_one_direction_no_shutdown` that goes +`Running → Done` directly, used for the client→compute direction in +the pooled bidir copy. Compute's writer is never half-closed. + +## The chain of reasoning that got here + +This took several iterations to localize. The diagnostic timeline +(captured by adding probes that I later removed) was: + +``` +20.247745 tcp pool: opened new connection +20.249371 TerminateFilter: intercepted Postgres Terminate +20.249403 transfer_one_direction: entering ShuttingDown ← BUG +20.249433 poll_shutdown(writer) returned Ready(Ok) ← FIN sent +20.249507 poll_write(empty) error: Broken pipe ← write half dead +20.249559 tcp pool: returning connection ← broken conn pooled +[iter 2, 243ms later] +20.493396 peer_addr error: Invalid argument ← peer gone +20.493689 per-client task IO error: Socket is not connected +``` + +The bug is that `TerminateFilter::poll_read` returns `Ready(Ok(()))` +with **zero bytes filled** when it intercepts `'X'`. `CopyBuffer` +treats zero-fill as EOF (`me.read_done = me.cap == filled_len`), and +`poll_copy` returns `Ready(Ok(amt))`. `transfer_one_direction` then +transitions: + +``` +Running → ShuttingDown → poll_shutdown(writer) → Done +``` + +That `poll_shutdown(writer)` is on the compute side. It's the bug. + +The early `saw_terminate()` check the existing pooled bidir copy +**did** have was checked *after* `transfer_one_direction` returned — +too late, the shutdown had already happened inside the call. + +## The fix + +`transfer_one_direction_no_shutdown` is the same state machine minus +the `ShuttingDown` step: + +```rust +TransferState::Running(buf) => { + let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?; + *state = TransferState::Done(count); // skip ShuttingDown +} +TransferState::ShuttingDown(count) => { + *state = TransferState::Done(*count); // dead branch, but keep +} +TransferState::Done(count) => return Poll::Ready(Ok(*count)), +``` + +In `copy_bidirectional_client_compute_pooled`, swap the +client→compute call to the no-shutdown variant: + +```rust +let client_to_compute_result = transfer_one_direction_no_shutdown( + cx, &mut client_to_compute, &mut filtered_client, compute, +).map_err(ErrorSource::from_client)?; +``` + +The compute→client direction keeps using the regular variant — if +compute really closes, shutting down the client side is correct. + +## Files + +- `proxy/src/pglb/copy_bidirectional.rs` — the fix +- `proxy/src/pglb/passthrough.rs` — unchanged (probes removed) +- `proxy/src/proxy/mod.rs`, `proxy/src/tcp_pool.rs` — small unrelated + prep work in this commit (was-reused signal, pool key on checkout + via `Option<...>` tuple) + +## Read in this order + +1. `proxy/src/pglb/copy_bidirectional.rs:43-95` — read both + `transfer_one_direction` and `transfer_one_direction_no_shutdown` + side by side. The diff is exactly the deletion of the + `poll_shutdown` call. +2. `proxy/src/pglb/copy_bidirectional.rs:120-235` — read + `copy_bidirectional_client_compute_pooled` and find the call site + that uses the no-shutdown variant. The CRITICAL comment is a + load-bearing invariant. +3. `proxy/src/pglb/copy_bidirectional.rs::TerminateFilter` (around + line 460) — to understand what the filter actually emits when it + intercepts `'X'`. That's the input the bidir copy is responding + to. + +## Things I'd challenge in review + +- **The `ShuttingDown` arm in `_no_shutdown` is dead code.** It exists + to keep the state machine total. Pragmatic but a bit ugly. An + alternative is a separate state enum. +- **Why not just have `TerminateFilter` return `Pending` instead of + EOF on intercept?** I considered this in passing — it would avoid + the EOF-cascade entirely. The downside is the filter would need to + self-wake or we'd hang. Strictly cleaner but adds wake-up plumbing. + Worth challenging. +- **`compute→client` still uses the shutdown-variant.** That's fine + *today* because compute closing means the conn is dead anyway. But + it does mean a slow-loris compute (sends a few bytes, then EOFs) + half-closes the client. In our case the client is being torn down + anyway. Edge-case unlikely to matter in practice. +- **The probes I used to find this bug were temporary** (zero-byte + `poll_write` and `peer_addr` checks at release/reuse, plus a + `entering ShuttingDown` log line). I stripped them before the + commit. If you want to see them, look at the diff history for the + fix on `mel/pool-fixes-and-benchmarks` — it's in the conversational + history but was never committed. diff --git a/docs/review/02-multi-endpoint.md b/docs/review/02-multi-endpoint.md new file mode 100644 index 000000000000..807f821c4f09 --- /dev/null +++ b/docs/review/02-multi-endpoint.md @@ -0,0 +1,101 @@ +# 02 — Multi-endpoint routing + +Commit `fa654b133` "multi-endpoint routing via --compute-endpoint-map". +Two files, ~57 lines added. + +## TL;DR + +Charles's earlier `--compute-endpoint` flag was a global override: +every `wake_compute` call used the same compute address regardless of +the endpoint ID in the connection string. This commit adds +`--compute-endpoint-map` to the mock control plane, threads the +`EndpointId` from `ComputeUserInfo` into `do_wake_compute`, and uses +a three-tier lookup so the existing single-tenant test keeps working. + +## Lookup priority (in `do_wake_compute`) + +``` +self.compute_endpoint_map.get(endpoint) // per-endpoint map (new) + .or(self.compute_endpoint.as_ref()) // global override (existing) + .unwrap_or(&self.auth_endpoint) // legacy fallback (existing) +``` + +This means: +- Endpoint IDs in the map → routed to their mapped address +- Endpoint IDs not in the map → fall through to `--compute-endpoint` + if set +- Otherwise → `--auth-endpoint` (the original behaviour) + +So adding a map doesn't change behaviour for IDs that aren't in it. +That kept Charles's single-tenant flow untouched. + +## Flag format + +``` +--compute-endpoint-map='ep-A=postgresql://localhost:5433/db?sslmode=disable,ep-B=postgresql://localhost:5434/db?sslmode=disable' +``` + +Comma-separated entries; first `=` separates endpoint id from URL. +URLs reuse the same `parse_compute_endpoint` validator (must include +hostname, sslmode if present must be `disable` or `require`). + +## Files + +- `proxy/src/control_plane/client/mock.rs` — `compute_endpoint_map: + HashMap` field on `MockControlPlane`, + `do_wake_compute(&EndpointId)` lookup, `wake_compute` plumb-through. +- `proxy/src/binary/proxy.rs` — `--compute-endpoint-map` clap arg, + `parse_compute_endpoint_map` parser. + +## Read in this order + +1. `proxy/src/control_plane/client/mock.rs::do_wake_compute` — + the lookup logic (the chained `or` / `unwrap_or` is the whole + feature). +2. `proxy/src/control_plane/client/mock.rs::wake_compute` (in the + `ControlPlaneApi` impl) — confirms `user_info.endpoint` is what + gets passed in. +3. `proxy/src/binary/proxy.rs::parse_compute_endpoint_map` — the + parser. ~20 lines, validates duplicates and missing separators. + +## How it was verified + +Two docker Postgres instances (`compute-A` on 5433, `compute-B` on +5434), each with a `marker` table containing `'compute-A'` / +`'compute-B'`. Same SCRAM verifier copied from auth Postgres so +SCRAM-passthrough works on both. Then: + +```bash +# Through the proxy with the map: +psql ".../?options=endpoint%3Dep-A" -c "SELECT val FROM marker" +# returns: compute-A + +psql ".../?options=endpoint%3Dep-B" -c "SELECT val FROM marker" +# returns: compute-B +``` + +Plus 5x `pg_backend_pid()` per endpoint to confirm pool isolation: +each endpoint's pool is keyed independently by `(endpoint, db, role)`, +so within `ep-A` the pid is stable (pool reuse) and across A/B the pids +differ (different computes). + +## Things I'd challenge in review + +- **Why not a database-backed mapping?** The brief noted that real + Neon uses `neon_control_plane.endpoints`. I picked the CLI-flag + approach because it's faster to iterate on and doesn't require + schema changes; the brief explicitly suggested this trade-off. + A real cplane would query the db. +- **`HashMap` lookup is sync inside `do_wake_compute`'s async fn.** + Fine — the map is read-only after startup. If it ever became + mutable, would need a `RwLock`. +- **No support for SSL-mode beyond `disable`/`require`.** Inherited + from Charles's `parse_compute_endpoint`. Anything unrecognised + errors at parse time. +- **Endpoint id in the auth context is actually `EndpointId` (a + smol_str wrapper),** so the HashMap key works without allocations. + Look at `proxy/src/types.rs::smol_str_wrapper!` to confirm. +- **Cancellation correctness**: I haven't audited whether + cross-endpoint cancellation routes correctly. Probably fine because + `BackendKeyData` is per-session and the cancel path doesn't go + through `wake_compute`. But worth confirming if you're paranoid. diff --git a/docs/review/03-benchmarks.md b/docs/review/03-benchmarks.md new file mode 100644 index 000000000000..392eae27372f --- /dev/null +++ b/docs/review/03-benchmarks.md @@ -0,0 +1,157 @@ +# 03 — Benchmarks + +Commits `9ce39d144` and `ef3db262b`. The first added the harness and +the initial three configurations; the second added a fourth +(`proxy_txn`) and the transaction-mode discussion in §5 of FINDINGS. + +The headline numbers are in `benchmarks/FINDINGS.md` — that's the +write-up I'd put in the report verbatim. This file is about +methodology and the things the numbers actually mean. + +## Methodology + +3 configurations × 3 workloads × 6 concurrencies × 2 reps × 10s += 108 cells, ~22 minutes wall time on macOS / Apple Silicon, all +co-located on one host (no network). + +Configurations: + +| name | what it isolates | +| -------------- | --------------------------------------------- | +| `direct` | floor: client → docker Postgres on 5433 | +| `proxy_pool` | full proxy path with TCP pool enabled | +| `proxy_nopool` | proxy path without pool — proxy overhead only | +| `proxy_txn` | proxy path with `--tcp-pool-mode=transaction` | + +Workloads: + +| name | pgbench flags | what it stresses | +| ------------------- | -------------- | -------------------------------- | +| `tpcb_steady` | (default) | mixed read/write, persistent | +| `readonly_steady` | `-S` | SELECT-only, persistent | +| `readonly_short` | `-S -C` | SELECT-only, **reconnect per tx**| + +The `-C` workload is what matters — that's where pool reuse pays off, +because every transaction in pgbench `-C` mode goes through a fresh +client → proxy connect. + +## What the numbers say (vs. what they look like at first glance) + +### Steady-state proxy overhead is small + +`tpcb_steady` C=100: direct 4440 / pool 3778 / nopool 3960. Proxy +adds ~13 %. Pool ≈ no-pool because persistent sessions never exercise +the pool — each pgbench client holds one compute conn for the full +30 s, so the pool only saves at most one handshake per client. +**This is the right answer for this workload**, not a problem. + +### Short-session pool benefit is real + +`readonly_short` (proxy_pool TPS / proxy_nopool TPS): + +| C | pool | nopool | ratio | +| --- | ---: | -----: | ----: | +| 1 | 114 | 61 | 1.86× | +| 5 | 499 | 336 | 1.49× | +| 10 | 621 | 357 | 1.74× | +| 25 | 649 | 511 | 1.27× | +| 50 | 588 | 249 | 2.36× | +| 100 | 633 | 384 | 1.65× | + +1.3×–2.4× pool benefit on the short-session workload. + +### The remaining gap to direct is auth-Postgres, not the pool + +`proxy_pool` short-session TPS plateaus around 600–650 even at high +concurrency, well below `direct` (~800–860). The bottleneck isn't +the pool — the pool is doing exactly what it should (compute +handshake skipped). The bottleneck is per-client-connection auth +lookups: every fresh client connection through the proxy queries +`pg_authid` on the auth Postgres. With C=100 -C that's hundreds +of fresh connect-cycles per second, all hitting auth-pg. + +Direct evidence in the proxy log under `-C` C=100: +``` +19 Can't assign requested address (os error 49) +``` + +That's macOS ephemeral-port exhaustion on the proxy → auth-pg socket +(macOS recycles TIME_WAIT slowly). Real evidence for "auth-pg is on +the critical path" — fixing it requires either (a) auth-info caching, +or (b) transaction-level pooling that lets one client connection share +both compute and auth across many short sessions. + +### Memory at peak load + +`proxy_pool` 81 MB RSS, `proxy_nopool` 78 MB. The pool adds ~3 MB to +hold up to ~100 idle compute conns — about 30 KB per pooled session. +Total proxy footprint stays under 100 MB serving 100 concurrent +clients. + +### Transaction mode doesn't help on pgbench + +`readonly_short` proxy_txn ≈ proxy_pool (within noise). Two +structural reasons (covered in `04-transaction-mode.md` and +FINDINGS §5): + +1. `pgbench -C` is one transaction per client session, so each + pgbench session only goes through one txn-boundary release/ + re-acquire cycle. No more multiplex than session mode does. +2. `pgbench` without `-C` keeps every client persistently busy, no + idle gap to multiplex into. + +Transaction mode wins on **idle think time between transactions**. +pgbench doesn't produce that profile. + +A small per-boundary release/re-acquire overhead is visible at low +concurrency on persistent-session workloads (`tpcb_steady` at C=1: +920 → 765, ~17 %). Converges to ~equal at C=100. + +## Files + +- `benchmarks/run_bench.sh` — orchestrator (~250 lines) +- `benchmarks/parse_pcts.py` — percentile post-processor +- `benchmarks/summarize.py` — aggregator (mean ± std) +- `benchmarks/results.csv` — 144 raw rows +- `benchmarks/summary.csv` — 72 aggregated rows +- `benchmarks/FINDINGS.md` — the write-up +- `benchmarks/auth_*.csv` — auth-pg pg_stat_activity samples taken + concurrently with the runs +- `benchmarks/mem_*.txt` — proxy ps snapshots at peak load + +## Read in this order + +1. **`benchmarks/FINDINGS.md`** — sections 1-5. Read end-to-end. +2. **`benchmarks/summary.csv`** — open in your IDE; sort by + workload+concurrency to compare the four configs side by side. +3. **`benchmarks/run_bench.sh`** — to understand exactly how each + number was generated. The interesting bits are `start_proxy`, + `auth_pg_watch_start`, and the per-cell `run_one`. +4. **`benchmarks/auth_proxy_pool_readonly_short_c25_r1.csv`** (and + siblings) — what auth-pg looked like during a `-C` run. + +## Things I'd challenge in review + +- **`-C` runs at C=100 hit ephemeral-port exhaustion on macOS.** That's + a kernel artefact, not a real bottleneck on Linux. The numbers at + C=100 -C are noisier than the others (proxy_pool stddev 82, nopool + 186) because of this. On a Linux host you'd see different numbers. +- **REPS=2 is statistical-stress floor.** With 2 reps, stddev is + population stddev of 2 points — basically range/2. The 30s × 3-rep + full sweep would have been more rigorous; we ran the 10s × 2-rep + variant for time. The mean column is more trustworthy than the + stddev column. +- **Memory snapshot is one moment in time.** RSS during a 100-client + load. Doesn't tell you whether the proxy memory grows over time + with churn. Worth running a longer soak to confirm. +- **DURATION=10 might be short.** Below the 30s floor I originally + recommended. pgbench amortises connection-setup over the duration, + so short duration over-emphasises the connect cost. For workloads + where the connect cost is the main story (`-C`), 10s is fine. +- **`auth_*.csv` polling cadence is 2 seconds.** Coarse. It catches + steady-state contention (≥ 2 active conns concurrently) but not + microsecond bursts. The proxy log is the more reliable signal for + fast contention. +- **No comparison against PgBouncer.** Deliberate per the brief — + apples-to-apples requires transaction pooling, which we now have + but haven't benchmarked against PgBouncer yet. diff --git a/docs/review/04-transaction-mode.md b/docs/review/04-transaction-mode.md new file mode 100644 index 000000000000..b7a284cf650a --- /dev/null +++ b/docs/review/04-transaction-mode.md @@ -0,0 +1,259 @@ +# 04 — Transaction-mode pool multiplexing + +Commit `29f0bc6d2`. Seven files, ~470 lines added. + +This is the bigger change. The reference architecture is +`pgcat/src/client.rs` (in `~/mel/pgcat`) — its two-loop structure +("client idle" outer loop / "transaction in progress" inner loop) is +the model. Read `pgcat/src/client.rs:891-1380` if you want the +canonical PgBouncer-style implementation. + +## TL;DR + +Add `--tcp-pool-mode=session|transaction` (default `session`). In +transaction mode, the proxy releases the compute connection back to +the pool at every `'Z' I` ReadyForQuery boundary and re-acquires for +the next transaction, allowing one compute conn to serve many client +transactions over its lifetime — the PgBouncer transaction-mode +semantic. + +## What the wire-protocol level looks like + +Postgres sends a `'Z'` (ReadyForQuery) message to the client at every +"ready to accept new query" point. The body is one byte indicating +transaction status: + +- `'I'` — idle, not in a transaction. **Safe to release the conn.** +- `'T'` — in a transaction. Conn is mid-`BEGIN`, must be held. +- `'E'` — in a failed transaction. Conn is mid-`BEGIN` after error, + must be held until `ROLLBACK`. + +So the multiplex rule is: at every `'Z'`, look at the status byte. If +`'I'`, return to pool and acquire a (potentially different) conn for +the next transaction. Otherwise, keep the same conn held. + +## Implementation pieces + +Four new pieces, plus wiring: + +### 1. `ReadyForQueryWatcher` + +`proxy/src/pglb/copy_bidirectional.rs`. Mirror of `TerminateFilter` on +the read side. Wraps an `AsyncRead + AsyncWrite`, transparently +forwards bytes, parses Postgres BE protocol message headers as they +flow through, and tracks `last_status: Option` from each `'Z'` +message body byte. + +State machine same shape as `TerminateFilter`: + +``` +AwaitingHeader { header: [u8; 5], pos } → InBody { tag, remaining, body_seen } +``` + +Difference: doesn't drop messages, just observes. For `tag == b'Z'` +messages the first body byte is the status; we record it and let it +flow through to the client. + +```rust +if *tag == b'Z' && *body_seen == 0 { + let status = limited.filled()[0]; + this.last_status = Some(status); + this.saw_ready_for_query = true; +} +``` + +`take_ready_for_query()` returns `true` if a `'Z'` was observed since +the last call (and clears the flag). The boundary pump uses this as +the "transaction over" signal. + +### 2. `copy_bidirectional_until_boundary` + +Same file. Variant of `copy_bidirectional_client_compute_pooled` that +returns when **either** boundary is reached: + +```rust +pub enum BoundaryReason { + ReadyForQuery(u8), // status byte + ClientTerminated, + ComputeClosed, +} +``` + +Uses `TerminateFilter` on client and `ReadyForQueryWatcher` on +compute, never calls `poll_shutdown` on compute. The poll_fn body: + +```rust +if filtered_client.saw_terminate() { return ... ClientTerminated; } +let _ = transfer_one_direction_no_shutdown(client→compute, ...); +if filtered_client.saw_terminate() { return ... ClientTerminated; } +let _ = transfer_one_direction_no_shutdown(compute→client, ...); +if watched_compute.take_ready_for_query() { + return ... ReadyForQuery(watched_compute.last_status()); +} +``` + +### 3. `TcpPoolManager::try_acquire_idle` + +`proxy/src/tcp_pool.rs`. Non-blocking pool re-acquire — pops one idle +conn for the given key if any. Returns `None` if the per-key pool is +empty. No connect closure required. + +This is what the multiplex loop calls between transactions, instead +of the regular `acquire_or_connect`. We don't have a connect closure +plumbed all the way through; if the pool is empty, the simplest MVP +behaviour is to disconnect the client cleanly. + +### 4. `proxy_pass_transaction_mode` + +`proxy/src/pglb/passthrough.rs`. The outer multiplex loop: + +``` +loop { + boundary = copy_bidirectional_until_boundary(client, compute) + + match boundary { + ReadyForQuery('I'): + release compute → pool + try_acquire_idle next compute + if None: close client cleanly + ReadyForQuery('T'|'E'): + keep compute held, continue inner pump + ClientTerminated: + release if last_status was 'I' (clean idle) + discard if mid-transaction (open BEGIN block) + exit loop + ComputeClosed: + error path + } +} +``` + +The `last_known_status` variable tracks tx status across iterations +so the `ClientTerminated` branch knows whether the conn is safe to +pool. Initialised to `'I'` because compute's first `'Z'` (sent during +`forward_compute_params_to_client` in handle_client) has already +landed. + +## How dispatch works + +`ProxyPassthrough::proxy_pass` branches three ways: + +```rust +if transaction_mode { + return proxy_pass_transaction_mode(self).await; +} +// otherwise, the existing session-pooled or direct path +``` + +`transaction_mode` is `tcp_pool_checkout.is_some() && config.mode == +TcpPoolMode::Transaction`. Default mode is `Session`, so existing +deployments aren't affected. + +## Empirical multiplexing + +Verification that the loop actually multiplexes: 10 concurrent psql +sessions, each running 3 queries, in transaction mode. The pool +opened **5 distinct compute backends** for 30 queries; multiple +client sessions were served by the same backend across transactions. +`pg_stat_activity` showed 5 idle backends pooled afterward. + +``` +c3, c4, c5, c8 → pid 49787 (4 sessions sharing one backend) +c2, c7, c9 → pid 49789 (3 sessions sharing one backend) +c1 → pid 49792 +c6 → pid 49791 +c10 → pid 49790 +``` + +That's the pgbouncer-style multiplex. Working correctly. + +## Files + +- `proxy/src/pglb/copy_bidirectional.rs` — `ReadyForQueryWatcher`, + `BoundaryReason`, `copy_bidirectional_until_boundary` +- `proxy/src/pglb/passthrough.rs` — `proxy_pass_transaction_mode`, + dispatch in `ProxyPassthrough::proxy_pass` +- `proxy/src/tcp_pool.rs` — `try_acquire_idle`, `key()` accessor on + `TcpPoolCheckout` +- `proxy/src/config.rs` — `TcpPoolMode` enum, `mode` field on + `TcpPoolConfig` +- `proxy/src/binary/proxy.rs` — `--tcp-pool-mode` clap arg, + `TcpPoolModeArg` +- `proxy/src/binary/local_proxy.rs` — fill in the new field +- `proxy/src/compute/mod.rs` — `#[derive(Clone)]` on `AuthInfo` + (prep for a reusable connect closure that I ended up not using; + could be reverted) + +## Read in this order + +1. **`pgcat/src/client.rs:891-1295`** if available locally — the + reference. The two-loop structure (`loop` at line 891 outer, + `loop` at line 1172 inner with `if !server.in_transaction() + break`) is the canonical PgBouncer-style implementation. +2. **`proxy/src/pglb/copy_bidirectional.rs::ReadyForQueryWatcher`** — + the protocol-watching primitive. +3. **`proxy/src/pglb/copy_bidirectional.rs::copy_bidirectional_until_boundary`** — + the inner pump. +4. **`proxy/src/pglb/passthrough.rs::proxy_pass_transaction_mode`** — + the outer multiplex loop. +5. **`proxy/src/tcp_pool.rs::try_acquire_idle`** — non-blocking + re-acquire. + +## Things I'd challenge in review + +- **There is no `peek-then-acquire` on the client side.** This is the + big one. The current implementation always holds a compute conn + during the bidir pump's `Pending` state — including while the + client is between queries. Releasing on idle would require waiting + for the client to send the next byte *before* acquiring, which + needs a peek-without-consume on the client stream. We don't have + that primitive in our wrappers (`MeasuredStream`, `Stream`, + `PqStream`). pgcat sidesteps this because it reads message-framed + on the client side — `read_message` blocks until a complete + message arrives, *then* it acquires. + Without peek, transaction mode's only release window is the + microseconds between "compute sent `'Z' I`" and "client sends + next byte". That's what limits the throughput win on benchmarks. + +- **Pool-empty-on-reacquire just disconnects the client.** If + transaction-mode is on and `max_conns_per_key` is sized too small + for concurrent demand, clients can be cleanly closed mid-session. + Real pooler would either wait or open a new conn. MVP punt. + +- **`AuthInfo` was made `Clone`.** I did this to support a reusable + connect closure for the multiplex loop, then ended up not using it. + Could be reverted, but it's also harmless. Worth considering for a + future iteration that wants `acquire_or_connect` instead of + `try_acquire_idle`. + +- **Standard PgBouncer transaction-mode caveats apply.** They're noted + in the doc-comment on `proxy_pass_transaction_mode`. Specifically: + `SET` (without `LOCAL`), `LISTEN`, simple `PREPARE`, temp tables, + advisory locks won't survive across transactions. We don't detect + these — the user is on their own. + +- **Cancellation across multiplexed conns is broken.** `pg_cancel_backend(pid)` + takes a pid that's been synthesized by the proxy (in the + was-reused / synthesized BackendKeyData branch) and doesn't + correspond to any single backend. PgBouncer has its own cancel + routing for this. We don't. + +- **Why expose `key()` on `TcpPoolCheckout`?** The multiplex loop + needs `pool_key` to call `try_acquire_idle` after releasing. The + cleanest factoring would be `try_acquire_idle_for(checkout)` that + takes the checkout — but that's awkward because the checkout has + already been consumed by the prior `release()`. The `key()` + accessor lets us snapshot the key before consuming. Alternative: + store the key separately in the multiplex loop's state (it's + cheap, smol_strs). + +- **The `Compute is done, terminate client` info-log fires once per + pooled session even on clean session-mode exit.** Look at + `copy_bidirectional_client_compute_pooled`. Probably should be + silenced or moved to debug. + +- **No tests.** The change is 470 lines and there's no unit test + beyond the existing ones. Worth adding at least one for the + watcher's `take_ready_for_query` semantics with the various + status bytes, and one integration-style test that verifies the + multiplex pid-sharing in a smoke harness. diff --git a/docs/review/05-open-questions.md b/docs/review/05-open-questions.md new file mode 100644 index 000000000000..4019d7f144f2 --- /dev/null +++ b/docs/review/05-open-questions.md @@ -0,0 +1,170 @@ +# 05 — Open questions and weak points + +The honest list of things I'd push back on if I were reviewing me. +Read this last; it'll change how you read everything else. + +Organised by area. + +## Session-pool fix (`6a361a996`) + +- **The `_no_shutdown` variant of `transfer_one_direction` has a dead + `ShuttingDown` arm.** It exists to keep the state-machine match + exhaustive. An alternative is two separate state enums; the cost is + more types for very little code-clarity gain. I went with the dead + arm. A reasonable reviewer could push the other way. + +- **`TerminateFilter` returning EOF (zero-fill `Ready`) is the actual + shape of the protocol violation.** A cleaner design: have the + filter return `Pending` with a self-wake when it sees Terminate, + and have the pump check `saw_terminate` at the top of each + iteration before doing anything. Strictly cleaner; downside is + self-wake plumbing that has to be careful not to busy-loop. I went + with the no-shutdown variant because it's local to the bug. The + filter-returns-Pending variant would probably be worth migrating + to in a future cleanup pass. + +## Multi-endpoint (`fa654b133`) + +- **Compute-endpoint-map is parsed at startup and never refreshed.** + Real cplane would react to endpoint moves. Not a problem for our + testing use-case but worth flagging. + +- **No tests for the parser.** Should add at least three: well-formed + multi-entry input, trailing comma, missing `=`, duplicate id. + Trivially small to add. + +## Benchmarks (`9ce39d144` + `ef3db262b`) + +- **macOS ephemeral-port exhaustion contaminates `-C` runs at + C ≥ 50.** The 19 transient `EADDRNOTAVAIL` errors at C=100 -C are + a kernel artefact. A Linux re-run would give cleaner numbers. + +- **`REPS=2` is the bare minimum for a stddev.** Population stddev + on two points is just |a-b|/2. Means are trustworthy; stddev + columns are not very informative. The original `REPS=3 DURATION=30s` + sweep would have been more rigorous; we ran the short variant for + time. + +- **No PgBouncer comparison.** Per the brief, deliberately deferred + until transaction pooling. Now that we have transaction pooling, + the comparison is unblocked. + +- **Memory snapshot is one moment.** Doesn't characterise growth + under sustained churn. Soak test would be more informative. + +- **The benchmark only stresses pgbench scenarios.** No web-app-like + workload (think-time between transactions). That's the workload + where transaction-mode pool benefit would actually surface, and + we don't have data for it. + +## Transaction mode (`29f0bc6d2`) — the longest list + +- **No client-side peek means we hold compute during client idle.** + The single biggest design limitation. Without a `peek_byte()` on + the client stream, the multiplex loop always has compute held while + waiting for the client's next message. That window between "we + released compute on `'Z' I`" and "client sends next byte" is + microseconds, not milliseconds. **This is why the pgbench + benchmark doesn't show a transaction-mode TPS win.** + pgcat's design is fundamentally cleaner because it uses + message-framed reads on the client (`read_message` blocks for a + full message), which acts as the peek-then-acquire primitive + naturally. Migrating to that model is the big follow-up. + +- **`try_acquire_idle` returning `None` cleanly closes the client.** + Real pooler would either wait briefly for an idle conn (with + a tokio Notify) or fall back to opening a new one. We don't have + the connect closure plumbed through the multiplex loop, so we + can't open new. Fine for our pool-size=200 benchmark scenario + but a real deployment would want graceful contention handling. + +- **`last_known_status` initialised to `'I'` is a load-bearing + assumption.** It assumes that by the time `proxy_pass_transaction_mode` + runs, compute has emitted `'Z' I` and we've consumed it via + `forward_compute_params_to_client` in handle_client. If that + invariant ever changes (e.g., a code path that enters proxy_pass + without forwarding params), the assumption breaks silently and we + could pool a mid-tx conn. Worth either asserting or initialising + to `Some(b'I')` only after explicit ack. + +- **Standard PgBouncer caveats apply but aren't enforced.** `SET` + (without `LOCAL`), `LISTEN`, simple `PREPARE`, temp tables, + advisory locks. We could detect these by parsing the client's + Query text — pgcat does, in `QueryRouter::parse`. We don't. The + user has to know not to use them. + +- **Cancellation correctness.** `pg_cancel_backend(pid)` takes the + synthesized pid the client received in `BackendKeyData`. In + transaction mode that pid doesn't correspond to any single + backend (multiplexed). pgcat handles this with a separate + cancel-routing map keyed on `(process_id, secret_key)` → + `(real_backend_pid, real_secret_key, server)`. We don't have this + yet. Also affects session mode if pool reuse changes the actual + backend pid out from under the client. + +- **The `cancel_user_info` / cancel_token path in `proxy/mod.rs` was + designed assuming one compute conn per session.** In transaction + mode, the conn rotates. Currently the cancel_token captures the + *initial* compute's `process_id` and `secret_key`. After rotation + these are stale. Not just a "no support" issue — it's an actively + wrong reference. Worth checking whether anything reads them later. + +- **Extended-query protocol is not specially handled.** Parse / Bind / + Describe / Execute / Close / Sync messages just flow through as + bytes; we wait for `'Z'` after Sync. Should "just work" because + Postgres only emits `'Z'` after Sync (or implicit Sync). But: I + haven't tested this end-to-end with a client that uses extended + query. pgbench uses simple-Q by default, which is what we tested. + An asyncpg / pgx-style client would exercise extended query and is + worth smoke-testing. + +- **No tests.** Already noted in `04-`. The watcher's state machine + is small but has tricky edge cases (header split across two + poll_reads, body with `'Z'` followed immediately by another + message). Worth at least one unit test. + +- **The `Compute is done, terminate client` info-log fires too + often.** Inherited from `copy_bidirectional_client_compute_pooled`. + In transaction mode it can fire on every iteration. Probably + should be debug. + +- **`Compute closed connection` on `BoundaryReason::ComputeClosed` + uses `io::Error::other`.** Slightly opaque. A typed variant of + `ErrorSource` would be cleaner. + +## Cross-cutting + +- **CLAUDE.md or repo docs**. There's no top-level documentation of + what was changed and how to use it. The benchmarks/README.md is + partial. A short summary in a top-level README of the + `mel/pool-fixes-and-benchmarks` branch's contributions would help + a real reviewer get oriented faster than reading commits. + +- **Telemetry is light.** We have `info!` logs on pool open / reuse + and warn on discard, but no Prometheus metrics for "pool hit rate", + "transaction boundaries per second", "auth-pg query latency", etc. + Real production deployment would want these. + +- **Memory leak risk on long-running sessions in transaction mode.** + Each release/re-acquire creates a new `MeasuredStream`. If those + hold any state internally (a closure capture growing? the recorded + byte counts?), they accumulate. I don't think they do, but worth + spot-checking. + +## What I'd do next if I had another day + +1. Add the client-side peek primitive and restructure the multiplex + loop to release-then-wait-then-acquire. This is the + transaction-mode win we don't currently get. +2. Write a custom pgbench script (`pgbench -f` with `pg_sleep`) that + simulates think-time between transactions. Re-run the proxy_txn + sweep on it. Numbers should move. +3. Auth-info caching. The `pg_authid` lookup per fresh client connection + is the actual bottleneck on `-C`-style workloads. A + `(endpoint, role) → AuthInfo` LRU with a short TTL would cut a + huge chunk of the proxy → auth-pg traffic. +4. Cancellation routing: a `(synthesized_pid, secret_key) → + real_backend` map, keyed at session start, used by the cancel + path. PgBouncer-style. +5. Tests. Specifically: watcher state-machine, the multiplex loop's + release-on-`'Z' I` decision, the multi-endpoint parser. diff --git a/proxy-cplane-api/Cargo.toml b/proxy-cplane-api/Cargo.toml new file mode 100644 index 000000000000..856e70337c86 --- /dev/null +++ b/proxy-cplane-api/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "proxy-cplane-api" +version = "0.1.0" +edition.workspace = true +license.workspace = true + +[dependencies] +axum.workspace = true +clap = { workspace = true, features = ["derive"] } +humantime.workspace = true +serde = { workspace = true, features = ["derive"] } +serde_json.workspace = true +tokio = { workspace = true, features = ["full"] } +workspace_hack.workspace = true diff --git a/proxy-cplane-api/README.md b/proxy-cplane-api/README.md new file mode 100644 index 000000000000..811272c6512c --- /dev/null +++ b/proxy-cplane-api/README.md @@ -0,0 +1,148 @@ +# proxy-cplane-api + +A small mock control plane HTTP server that implements just enough of the +`cplane-v1` API for `proxy` to authenticate clients and route them to a +running compute. Useful for local development and benchmarking — no real +control plane required. + +It exposes three endpoints: + +| Endpoint | Behaviour | +| ----------------------------------- | -------------------------------------------------------------------------- | +| `GET /get_endpoint_access_control` | Returns the static `--scram-secret` for every role. | +| `GET /wake_compute` | Returns a compute address (static or resolved via `neon_local`). | +| `GET /endpoints/{id}/jwks` | Returns JWKS entries from the optional `--jwks-config` JSON file. | + +## Two `/wake_compute` modes + +- **Static mode** (default): every request returns `--compute-address`. Use + this when one long-running compute backs every connection (e.g. benchmarks). +- **`neon_local` mode** (`--neon-local-repo-dir `): on each request the + mock reads `/endpoints//endpoint.json`, returns the + per-endpoint `pg_port`, and shells out to `neon_local endpoint start ` + if the compute isn't already listening. Returns 404 with reason + `ENDPOINT_NOT_FOUND` if the endpoint directory is missing. + +## SCRAM secret + +The proxy must hand the client the same SCRAM secret that the compute has +stored in `pg_authid.rolpassword`, otherwise the SCRAM exchange fails. The +mock returns a single static secret to every caller and you wire matching +credentials by creating roles with the pre-hashed secret as the password — +Postgres stores literal values starting with `SCRAM-SHA-256$` verbatim +instead of re-hashing them. + +The default `--scram-secret` is the canonical hash of the literal password +`"password"`: + +``` +SCRAM-SHA-256$4096:M2ZX/kfDSd3vv5iFO/QNUA==$mookt3EiEpd/vMqGbd7df3qVwfyUfM91Ps72sNewNg4=:3nMi8eBSHggIBNSgAik6lQnE3hQcsS+myylZlYgNA1U= +``` + +So creating any role with that literal as its password lets a client connect +through the proxy with password `"password"`. + +## End-to-end walkthrough with `neon_local` + +This is the full local recipe: spin up the storage stack with `neon_local`, +create an endpoint with a SCRAM-matching role, run `proxy-cplane-api` against +the `.neon` repo, run `proxy`, and connect with `psql`. + +1. **Build** the bits we need: + + ```sh + cargo build --bin neon_local + cargo build -p proxy-cplane-api + cargo build -p proxy --features testing --bin proxy + ``` + +2. **Initialize neon_local** and start the storage stack. `neon_local` + defaults to `./.neon` for its data directory; you can override it via + `NEON_REPO_DIR`: + + ```sh + cargo neon init + cargo neon start + cargo neon tenant create --set-default --pg-version 17 + cargo neon endpoint create main --pg-version 17 --update-catalog + cargo neon endpoint start main --create-test-user + ``` + +3. **Create the SCRAM-matching role** inside the running compute. The + connection string is printed by the previous command (look for + `Starting postgres node at ...`) and listed by `cargo neon endpoint list`. + Connect as the `cloud_admin` superuser and create a role whose stored + `pg_authid.rolpassword` is the literal SCRAM secret: + + ```sh + psql "postgres://cloud_admin@127.0.0.1:55432/postgres" -c \ + "CREATE ROLE \"user\" WITH LOGIN ENCRYPTED PASSWORD \ + 'SCRAM-SHA-256\$4096:M2ZX/kfDSd3vv5iFO/QNUA==\$mookt3EiEpd/vMqGbd7df3qVwfyUfM91Ps72sNewNg4=:3nMi8eBSHggIBNSgAik6lQnE3hQcsS+myylZlYgNA1U=';" \ + -c "GRANT ALL ON DATABASE postgres TO \"user\";" + ``` + + (The `$` characters are escaped for shell quoting; the actual stored value + is the unescaped literal above.) + +4. **Stop the endpoint** so we can prove the mock starts it on demand: + + ```sh + cargo neon endpoint stop main + ``` + +5. **Generate a self-signed cert** for proxy (same as `proxy/README.md`): + + ```sh + openssl req -new -x509 -days 365 -nodes -text \ + -out server.crt -keyout server.key \ + -subj "/CN=*.local.neon.build" + ``` + +6. **Start `proxy-cplane-api`** pointed at the `.neon` directory: + + ```sh + ./target/debug/proxy-cplane-api \ + --listen 127.0.0.1:3010 \ + --neon-local-repo-dir "$PWD/.neon" \ + --neon-local-bin "$PWD/target/debug/neon_local" + ``` + +7. **Start `proxy`** pointed at the mock: + + ```sh + ./target/debug/proxy \ + --auth-backend cplane-v1 \ + --auth-endpoint http://127.0.0.1:3010 \ + -c server.crt -k server.key + ``` + +8. **Connect.** The hostname's first label is the endpoint id (`main`), so + we use `main.local.neon.build` (which resolves to `127.0.0.1`): + + ```sh + PGSSLROOTCERT=./server.crt psql \ + "postgres://user:password@main.local.neon.build:4432/postgres?sslmode=verify-full" + ``` + + The proxy hits `/wake_compute`, the mock spawns `neon_local endpoint start + main`, and `psql` lands on the freshly started compute. Look at the + `proxy-cplane-api` logs — the first connection prints `starting endpoint + main via neon_local` and reports `cold_start_info: pool_miss`; subsequent + connections find the compute already up and report `warm`. + +## CLI reference + +```text +proxy-cplane-api [OPTIONS] + +Options: + --listen Listen address [default: 0.0.0.0:3010] + --scram-secret SCRAM-SHA-256 secret returned for every role + [default: hash of "password"] + --compute-address Static compute address used when --neon-local-repo-dir + is not set [default: 127.0.0.1:5432] + --jwks-config JSON file with JWKS entries for /endpoints/{id}/jwks + --neon-local-repo-dir Enable neon_local mode against this .neon directory + --neon-local-bin Path to neon_local binary [default: neon_local] + --start-timeout How long to wait for compute to come up [default: 30s] +``` diff --git a/proxy-cplane-api/src/main.rs b/proxy-cplane-api/src/main.rs new file mode 100644 index 000000000000..7aabc8ea4c6f --- /dev/null +++ b/proxy-cplane-api/src/main.rs @@ -0,0 +1,432 @@ +//! Mock proxy control plane API. +//! +//! Two modes for `/wake_compute`: +//! +//! - **Static mode** (default): returns the fixed `--compute-address`. Useful for +//! benchmarks where one long-running compute backs every request. +//! - **neon_local mode** (`--neon-local-repo-dir`): looks the endpoint up in +//! `/endpoints//endpoint.json`, returns the per-endpoint +//! `pg_port`, and shells out to `neon_local endpoint start ` if the +//! compute isn't already listening. Returns 404 if the endpoint doesn't +//! exist. +//! +//! `/get_endpoint_access_control` returns a single static SCRAM secret +//! (`--scram-secret`) for every role; operators wire matching credentials by +//! creating compute roles whose stored `pg_authid.rolpassword` is the same +//! literal string. + +use std::net::SocketAddr; +use std::path::PathBuf; +use std::time::{Duration, Instant}; + +use axum::Json; +use axum::extract::{Path, Query, State}; +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use axum::routing::get; +use clap::Parser; +use serde::{Deserialize, Serialize}; +use tokio::signal::unix::{SignalKind, signal}; + +/// Pre-hashed SCRAM-SHA-256 secret for the literal password `"password"`. +/// Mirrors the constant used by hadron's `proxy-bench/cplane-mock`. +const DEFAULT_SCRAM_SECRET: &str = "SCRAM-SHA-256$4096:M2ZX/kfDSd3vv5iFO/QNUA==$mookt3EiEpd/vMqGbd7df3qVwfyUfM91Ps72sNewNg4=:3nMi8eBSHggIBNSgAik6lQnE3hQcsS+myylZlYgNA1U="; + +#[derive(Clone)] +struct AppState { + scram_secret: String, + compute_address: String, + jwks_config: Option, + neon_local: Option, +} + +#[derive(Clone)] +struct NeonLocalConfig { + repo_dir: PathBuf, + bin: PathBuf, + start_timeout: Duration, +} + +#[derive(Clone, Deserialize)] +struct JwksConfig { + jwks: Vec, +} + +#[derive(Clone, Serialize, Deserialize)] +struct JwksEntry { + id: String, + role_names: Vec, + jwks_url: String, + provider_name: String, + jwt_audience: Option, +} + +#[derive(Parser)] +struct Args { + /// Static SCRAM-SHA-256 secret returned by `/get_endpoint_access_control` + /// for every role. The default matches the literal password "password". + #[clap(long, default_value = DEFAULT_SCRAM_SECRET)] + scram_secret: String, + + /// Compute address returned by `/wake_compute` in static mode. + /// Ignored when `--neon-local-repo-dir` is set. + #[clap(long, default_value = "127.0.0.1:5432")] + compute_address: String, + + #[clap(long, default_value = "0.0.0.0:3010")] + listen: SocketAddr, + + /// Path to a JSON file with JWKS config (same format as local_proxy.json). + /// If provided, the /endpoints/{id}/jwks endpoint will serve these JWKS. + #[clap(long)] + jwks_config: Option, + + /// Path to a `neon_local` `.neon` data directory. When set, `/wake_compute` + /// resolves the compute address by reading endpoint configs from this + /// directory and starts the endpoint via `neon_local` if it isn't running. + #[clap(long)] + neon_local_repo_dir: Option, + + /// Path to the `neon_local` binary. Defaults to looking up `neon_local` + /// on PATH. + #[clap(long, default_value = "neon_local")] + neon_local_bin: PathBuf, + + /// How long to wait for a compute to come up after spawning + /// `neon_local endpoint start`. + #[clap(long, default_value = "30s", value_parser = humantime::parse_duration)] + start_timeout: Duration, +} + +#[tokio::main] +async fn main() { + let args = Args::parse(); + + let jwks_config = if let Some(path) = &args.jwks_config { + let data = std::fs::read_to_string(path) + .unwrap_or_else(|e| panic!("failed to read jwks config file {}: {e}", path.display())); + let config: JwksConfig = serde_json::from_str(&data) + .unwrap_or_else(|e| panic!("failed to parse jwks config file {}: {e}", path.display())); + eprintln!( + "loaded {} JWKS entries from {}", + config.jwks.len(), + path.display() + ); + Some(config) + } else { + None + }; + + let neon_local = args.neon_local_repo_dir.map(|repo_dir| { + let repo_dir = repo_dir + .canonicalize() + .unwrap_or_else(|e| panic!("--neon-local-repo-dir {}: {e}", repo_dir.display())); + eprintln!("neon_local mode enabled, repo dir: {}", repo_dir.display()); + NeonLocalConfig { + repo_dir, + bin: args.neon_local_bin, + start_timeout: args.start_timeout, + } + }); + + let state = AppState { + scram_secret: args.scram_secret, + compute_address: args.compute_address, + jwks_config, + neon_local, + }; + + let app = axum::Router::new() + .route( + "/get_endpoint_access_control", + get(get_endpoint_access_control), + ) + .route("/wake_compute", get(wake_compute)) + .route("/endpoints/{id}/jwks", get(get_jwks)) + .with_state(state); + + let mut sigterm = signal(SignalKind::terminate()).unwrap(); + let listener = tokio::net::TcpListener::bind(args.listen).await.unwrap(); + eprintln!("listening on {}", args.listen); + axum::serve(listener, app) + .with_graceful_shutdown(async move { + sigterm.recv().await; + }) + .await + .unwrap(); +} + +#[derive(Deserialize)] +struct AccessControlQuery { + #[allow(dead_code)] + role: Option, + #[allow(dead_code)] + endpointish: Option, +} + +#[derive(Serialize)] +struct AccessControlResponse { + role_secret: String, + allowed_ips: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + project_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + account_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + allowed_vpc_endpoint_ids: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + block_public_connections: Option, + #[serde(skip_serializing_if = "Option::is_none")] + block_vpc_connections: Option, +} + +async fn get_endpoint_access_control( + State(state): State, + Query(_query): Query, +) -> Json { + Json(AccessControlResponse { + role_secret: state.scram_secret.clone(), + allowed_ips: vec![], + project_id: None, + account_id: None, + allowed_vpc_endpoint_ids: None, + block_public_connections: None, + block_vpc_connections: None, + }) +} + +#[derive(Deserialize)] +struct WakeComputeQuery { + endpointish: Option, +} + +#[derive(Serialize)] +struct WakeComputeResponse { + address: String, + #[serde(skip_serializing_if = "Option::is_none")] + server_name: Option, + aux: MetricsAuxInfo, +} + +#[derive(Serialize)] +struct MetricsAuxInfo { + endpoint_id: String, + project_id: String, + branch_id: String, + compute_id: String, + cold_start_info: String, +} + +async fn wake_compute( + State(state): State, + Query(query): Query, +) -> Result, ControlPlaneErrorResponse> { + let endpoint_id = query.endpointish.unwrap_or_default(); + + let (address, cold_start_info) = match &state.neon_local { + Some(cfg) => resolve_via_neon_local(cfg, &endpoint_id).await?, + None => (state.compute_address.clone(), "warm".to_owned()), + }; + + Ok(Json(WakeComputeResponse { + address, + server_name: None, + aux: MetricsAuxInfo { + endpoint_id, + project_id: "project-mock".into(), + branch_id: "branch-mock".into(), + compute_id: "compute-mock".into(), + cold_start_info, + }, + })) +} + +/// Subset of the `endpoint.json` written by `neon_local`. Only the field we +/// need is parsed; serde ignores everything else. +#[derive(Deserialize)] +struct EndpointConf { + pg_port: u16, +} + +async fn resolve_via_neon_local( + cfg: &NeonLocalConfig, + endpoint_id: &str, +) -> Result<(String, String), ControlPlaneErrorResponse> { + if endpoint_id.is_empty() { + return Err(ControlPlaneErrorResponse::endpoint_not_found( + "missing endpointish query parameter", + )); + } + + let conf_path = cfg + .repo_dir + .join("endpoints") + .join(endpoint_id) + .join("endpoint.json"); + + let conf_bytes = match std::fs::read(&conf_path) { + Ok(b) => b, + Err(e) if e.kind() == std::io::ErrorKind::NotFound => { + return Err(ControlPlaneErrorResponse::endpoint_not_found(&format!( + "endpoint {endpoint_id} not found at {}", + conf_path.display() + ))); + } + Err(e) => { + return Err(ControlPlaneErrorResponse::internal(format!( + "failed to read {}: {e}", + conf_path.display() + ))); + } + }; + + let conf: EndpointConf = serde_json::from_slice(&conf_bytes).map_err(|e| { + ControlPlaneErrorResponse::internal(format!("failed to parse {}: {e}", conf_path.display())) + })?; + + let addr = SocketAddr::from(([127, 0, 0, 1], conf.pg_port)); + let already_up = + std::net::TcpStream::connect_timeout(&addr, Duration::from_millis(200)).is_ok(); + + let cold_start_info = if already_up { + "warm".to_owned() + } else { + eprintln!("starting endpoint {endpoint_id} via neon_local"); + let status = tokio::process::Command::new(&cfg.bin) + .args(["endpoint", "start", endpoint_id]) + .env("NEON_REPO_DIR", &cfg.repo_dir) + .status() + .await + .map_err(|e| { + ControlPlaneErrorResponse::internal(format!( + "failed to spawn {}: {e}", + cfg.bin.display() + )) + })?; + if !status.success() { + return Err(ControlPlaneErrorResponse::internal(format!( + "neon_local endpoint start {endpoint_id} exited with {status}" + ))); + } + wait_for_tcp(addr, cfg.start_timeout).await?; + "pool_miss".to_owned() + }; + + Ok((addr.to_string(), cold_start_info)) +} + +async fn wait_for_tcp( + addr: SocketAddr, + timeout: Duration, +) -> Result<(), ControlPlaneErrorResponse> { + let deadline = Instant::now() + timeout; + loop { + if std::net::TcpStream::connect_timeout(&addr, Duration::from_millis(200)).is_ok() { + return Ok(()); + } + if Instant::now() >= deadline { + return Err(ControlPlaneErrorResponse::internal(format!( + "compute at {addr} did not become reachable within {timeout:?}" + ))); + } + tokio::time::sleep(Duration::from_millis(200)).await; + } +} + +/// Mirrors the proxy-side `ControlPlaneErrorMessage` shape so proxy decodes +/// the error reason and treats `ENDPOINT_NOT_FOUND` as non-retryable. +#[derive(Serialize)] +struct ControlPlaneErrorBody { + error: String, + status: ControlPlaneErrorStatus, +} + +#[derive(Serialize)] +struct ControlPlaneErrorStatus { + code: &'static str, + message: String, + details: ControlPlaneErrorDetails, +} + +#[derive(Serialize)] +struct ControlPlaneErrorDetails { + error_info: Option, + retry_info: Option<()>, + user_facing_message: Option, +} + +#[derive(Serialize)] +struct ControlPlaneErrorInfo { + reason: &'static str, +} + +#[derive(Serialize)] +struct UserFacingMessage { + message: String, +} + +struct ControlPlaneErrorResponse { + status: StatusCode, + body: ControlPlaneErrorBody, +} + +impl ControlPlaneErrorResponse { + fn endpoint_not_found(message: &str) -> Self { + Self { + status: StatusCode::NOT_FOUND, + body: ControlPlaneErrorBody { + error: message.to_owned(), + status: ControlPlaneErrorStatus { + code: "NOT_FOUND", + message: message.to_owned(), + details: ControlPlaneErrorDetails { + error_info: Some(ControlPlaneErrorInfo { + reason: "ENDPOINT_NOT_FOUND", + }), + retry_info: None, + user_facing_message: Some(UserFacingMessage { + message: message.to_owned(), + }), + }, + }, + }, + } + } + + fn internal(message: String) -> Self { + Self { + status: StatusCode::INTERNAL_SERVER_ERROR, + body: ControlPlaneErrorBody { + error: message.clone(), + status: ControlPlaneErrorStatus { + code: "INTERNAL", + message: message.clone(), + details: ControlPlaneErrorDetails { + error_info: None, + retry_info: None, + user_facing_message: Some(UserFacingMessage { message }), + }, + }, + }, + } + } +} + +impl IntoResponse for ControlPlaneErrorResponse { + fn into_response(self) -> Response { + (self.status, Json(self.body)).into_response() + } +} + +#[derive(Serialize)] +struct JwksResponse { + jwks: Vec, +} + +async fn get_jwks(State(state): State, Path(_id): Path) -> Json { + let jwks = match &state.jwks_config { + Some(config) => config.jwks.clone(), + None => vec![], + }; + Json(JwksResponse { jwks }) +} diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 0ece79c3290f..23a1bbd6f39c 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -22,6 +22,7 @@ aws-sdk-iam.workspace = true aws-sigv4.workspace = true base64.workspace = true bstr.workspace = true +bb8 = "0.8.6" bytes = { workspace = true, features = ["serde"] } camino.workspace = true chrono.workspace = true diff --git a/proxy/README.md b/proxy/README.md index ce957b90af30..62180b77834b 100644 --- a/proxy/README.md +++ b/proxy/README.md @@ -104,6 +104,10 @@ cases where it is hard to use rows represented as objects (e.g. when several fie ## Test proxy locally +For an end-to-end local setup that uses `neon_local`-managed computes +(pageserver + safekeeper + Postgres) instead of a single shared Postgres, +see [`proxy-cplane-api/README.md`](../proxy-cplane-api/README.md). + Proxy determines project name from the subdomain, request to the `round-rice-566201.somedomain.tld` will be routed to the project named `round-rice-566201`. Unfortunately, `/etc/hosts` does not support domain wildcards, so we can use *.local.neon.build` which resolves to `127.0.0.1`. We will need to have a postgres instance. Assuming that we have set up docker we can set it up as follows: diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index 2a02748a10f4..c6257e16aa59 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -205,6 +205,7 @@ async fn authenticate( endpoint: db_info.aux.endpoint_id.as_str().into(), user: user.clone(), options: NeonOptions::default(), + use_tcp_pool: false, }; ctx.set_dbname(db_info.dbname.into()); diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 491f14b1b6f6..54555b161e7e 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -82,7 +82,8 @@ pub(crate) async fn password_hack_no_authentication( ComputeUserInfo { user: info.user, options: info.options, - endpoint: payload.endpoint, + use_tcp_pool: payload.endpoint.is_pooler(), + endpoint: payload.endpoint.normalize(), }, payload.password, )) diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index a6df2a701184..0dcd3e3137bf 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -159,6 +159,7 @@ pub(crate) struct ComputeUserInfo { pub(crate) endpoint: EndpointId, pub(crate) user: RoleName, pub(crate) options: NeonOptions, + pub(crate) use_tcp_pool: bool, } impl ComputeUserInfo { @@ -184,7 +185,8 @@ impl TryFrom for ComputeUserInfo { options: user_info.options, }), Some(endpoint) => Ok(ComputeUserInfo { - endpoint, + use_tcp_pool: endpoint.is_pooler(), + endpoint: endpoint.normalize(), user: user_info.user, options: user_info.options, }), diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs index 86b64c62c957..992da635ce74 100644 --- a/proxy/src/binary/local_proxy.rs +++ b/proxy/src/binary/local_proxy.rs @@ -24,7 +24,7 @@ use crate::cancellation::CancellationHandler; #[cfg(feature = "rest_broker")] use crate::config::RestConfig; use crate::config::{ - self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig, + self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig, TcpPoolConfig, refresh_config_loop, }; use crate::control_plane::locks::ApiLocks; @@ -280,6 +280,14 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig tls_config: ArcSwapOption::from(None), metric_collection: None, http_config, + tcp_pool_config: TcpPoolConfig { + enabled: false, + mode: crate::config::TcpPoolMode::Session, + max_conns_per_key: 0, + max_total_conns: 0, + idle_timeout: Duration::ZERO, + fallback_direct_connect: false, + }, authentication_config: AuthenticationConfig { jwks_cache: JwkCache::default(), scram_thread_pool: ThreadPool::new(0), diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 583cdc95bf6d..b396bfe63a8e 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -36,7 +36,7 @@ use crate::config::RestConfig; use crate::config::refresh_config_loop; use crate::config::{ self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions, - ProxyConfig, ProxyProtocolV2, remote_storage_from_toml, + ProxyConfig, ProxyProtocolV2, TcpPoolConfig, remote_storage_from_toml, }; use crate::context::parquet::ParquetUploadArgs; use crate::http::health_server::AppMetrics; @@ -111,6 +111,18 @@ struct ProxyCliArgs { default_value = "http://localhost:3000/authenticate_proxy_request/" )] auth_endpoint: String, + /// Upstream compute endpoint for `--auth-backend=postgres`. + /// Example: `postgresql://host:5432/db?sslmode=require` + #[cfg(any(test, feature = "testing"))] + #[clap(long)] + compute_endpoint: Option, + /// Per-endpoint compute routing for `--auth-backend=postgres`. + /// Comma-separated `endpoint_id=postgresql://host:port/db?sslmode=...` entries. + /// Lookup priority: this map → `--compute-endpoint` → `--auth-endpoint`. + /// Example: `ep-A=postgresql://localhost:5433/db?sslmode=disable,ep-B=postgresql://localhost:5434/db?sslmode=disable` + #[cfg(any(test, feature = "testing"))] + #[clap(long)] + compute_endpoint_map: Option, /// JWT used to connect to control plane. #[clap( long, @@ -152,6 +164,8 @@ struct ProxyCliArgs { connect_compute_lock: String, #[clap(flatten)] sql_over_http: SqlOverHttpArgs, + #[clap(flatten)] + tcp_pool: TcpPoolArgs, /// timeout for scram authentication protocol #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)] scram_protocol_timeout: tokio::time::Duration, @@ -314,6 +328,53 @@ struct SqlOverHttpArgs { sql_over_http_max_response_size_bytes: usize, } +#[derive(clap::Args, Clone, Copy, Debug)] +struct TcpPoolArgs { + /// Enable experimental TCP session pooling. + #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] + tcp_pool_enabled: bool, + + /// Maximum number of pooled backend connections per endpoint+db+user key. + #[clap(long, default_value_t = 20)] + tcp_pool_max_conns_per_key: usize, + + /// Maximum number of pooled backend connections across all keys. + #[clap(long, default_value_t = 20000)] + tcp_pool_max_total_conns: usize, + + /// How long idle pooled backend connections should be kept. + #[clap(long, default_value = "5m", value_parser = humantime::parse_duration)] + tcp_pool_idle_timeout: tokio::time::Duration, + + /// If pool acquire fails, fallback to the existing direct-connect path. + #[clap(long, default_value_t = true, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] + tcp_pool_fallback_direct_connect: bool, + + /// Pool mode. `session` (default) holds a compute connection after the + /// first frontend message that needs compute. `transaction` returns the + /// connection to the pool at every transaction boundary (compute sends + /// ReadyForQuery status `'I'`); subsequent transactions on the same client + /// may land on different compute connections. + #[clap(long, value_enum, default_value_t = TcpPoolModeArg::Session)] + tcp_pool_mode: TcpPoolModeArg, +} + +#[derive(Clone, Copy, Debug, clap::ValueEnum)] +#[clap(rename_all = "kebab-case")] +enum TcpPoolModeArg { + Session, + Transaction, +} + +impl From for crate::config::TcpPoolMode { + fn from(arg: TcpPoolModeArg) -> Self { + match arg { + TcpPoolModeArg::Session => Self::Session, + TcpPoolModeArg::Transaction => Self::Transaction, + } + } +} + #[derive(clap::Args, Clone, Debug)] struct PgSniRouterArgs { /// listen for incoming client connections on ip:port @@ -542,6 +603,9 @@ pub async fn run() -> anyhow::Result<()> { // TODO: Add gc regardles of the metric collection being enabled. maintenance_tasks.spawn(usage_metrics::task_main(metrics_config)); } + if config.tcp_pool_config.enabled { + maintenance_tasks.spawn(async move { crate::tcp_pool::manager().gc_worker().await }); + } if let Some(client) = redis_client { // Try to connect to Redis 3 times with 1 + (0..0.1) second interval. @@ -693,6 +757,14 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes, max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes, }; + let tcp_pool_config = TcpPoolConfig { + enabled: args.tcp_pool.tcp_pool_enabled, + mode: args.tcp_pool.tcp_pool_mode.into(), + max_conns_per_key: args.tcp_pool.tcp_pool_max_conns_per_key, + max_total_conns: args.tcp_pool.tcp_pool_max_total_conns, + idle_timeout: args.tcp_pool.tcp_pool_idle_timeout, + fallback_direct_connect: args.tcp_pool.tcp_pool_fallback_direct_connect, + }; let authentication_config = AuthenticationConfig { jwks_cache: JwkCache::default(), scram_thread_pool: thread_pool, @@ -755,6 +827,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { tls_config, metric_collection, http_config, + tcp_pool_config, authentication_config, proxy_protocol_v2: args.proxy_protocol_v2, handshake_timeout: args.handshake_timeout, @@ -845,8 +918,21 @@ fn build_auth_backend( url.set_password(Some(&password)) .expect("Failed to set password"); } + let compute_endpoint = args + .compute_endpoint + .as_deref() + .map(parse_compute_endpoint) + .transpose()?; + let compute_endpoint_map = args + .compute_endpoint_map + .as_deref() + .map(parse_compute_endpoint_map) + .transpose()? + .unwrap_or_default(); let api = control_plane::client::mock::MockControlPlane::new( url, + compute_endpoint, + compute_endpoint_map, !args.is_private_access_proxy, ); let api = control_plane::client::ControlPlaneClient::PostgresMock(api); @@ -929,6 +1015,52 @@ fn build_auth_backend( } } +#[cfg(any(test, feature = "testing"))] +fn parse_compute_endpoint(endpoint: &str) -> anyhow::Result { + let url: ApiUrl = endpoint.parse()?; + ensure!( + url.host_str().is_some(), + "compute-endpoint must include a hostname" + ); + + if let Some((_, sslmode)) = url.query_pairs().find(|(k, _)| k == "sslmode") { + match sslmode.as_ref() { + "disable" | "require" => {} + other => { + bail!( + "unsupported sslmode in compute-endpoint: {other}. supported values: disable,require" + ); + } + } + } + + Ok(url) +} + +#[cfg(any(test, feature = "testing"))] +fn parse_compute_endpoint_map( + input: &str, +) -> anyhow::Result> { + let mut map = std::collections::HashMap::new(); + for entry in input.split(',').map(str::trim).filter(|s| !s.is_empty()) { + let (id, url_str) = entry.split_once('=').with_context(|| { + format!("compute-endpoint-map entry missing '=' separator: {entry}") + })?; + let id = id.trim(); + ensure!( + !id.is_empty(), + "compute-endpoint-map entry has empty endpoint id: {entry}" + ); + let url = parse_compute_endpoint(url_str.trim())?; + let prev = map.insert(crate::types::EndpointId::from(id), url); + ensure!( + prev.is_none(), + "compute-endpoint-map has duplicate entry for endpoint id: {id}" + ); + } + Ok(map) +} + async fn configure_redis( args: &ProxyCliArgs, ) -> anyhow::Result> { @@ -1000,4 +1132,27 @@ mod tests { ] ); } + + #[test] + fn parse_compute_endpoint_with_require_sslmode() { + let parsed = + super::parse_compute_endpoint("postgresql://example.com:5432/db?sslmode=require") + .expect("valid endpoint should parse"); + assert_eq!(parsed.host_str(), Some("example.com")); + assert_eq!(parsed.port(), Some(5432)); + } + + #[test] + fn parse_compute_endpoint_rejects_unknown_sslmode() { + let err = super::parse_compute_endpoint("postgresql://example.com/db?sslmode=prefer") + .expect_err("unsupported sslmode must fail"); + assert!(err.to_string().contains("unsupported sslmode")); + } + + #[test] + fn parse_compute_endpoint_allows_missing_sslmode() { + let parsed = super::parse_compute_endpoint("postgresql://example.com:5432/db") + .expect("missing sslmode should be allowed"); + assert_eq!(parsed.host_str(), Some("example.com")); + } } diff --git a/proxy/src/compute/mod.rs b/proxy/src/compute/mod.rs index 43cfe70206d6..4f799f0b22f3 100644 --- a/proxy/src/compute/mod.rs +++ b/proxy/src/compute/mod.rs @@ -135,6 +135,7 @@ pub enum Auth { } /// A config for authenticating to the compute node. +#[derive(Clone)] pub(crate) struct AuthInfo { /// None for local-proxy, as we use trust-based localhost auth. /// Some for sql-over-http, ws, tcp, and in most cases for console-redirect. @@ -212,6 +213,20 @@ impl AuthInfo { config } + pub(crate) fn tcp_pool_session_reset_query(&self) -> String { + let mut query = String::from("RESET ALL"); + + for (k, v) in self.server_params.iter() { + match k { + "database" | "user" | "replication" => {} + "options" => append_startup_options_settings(&mut query, v), + _ => append_set_config(&mut query, k, v), + } + } + + query + } + /// Apply startup message params to the connection config. pub(crate) fn set_startup_params( &mut self, @@ -266,6 +281,54 @@ impl AuthInfo { } } +fn append_startup_options_settings(query: &mut String, options: &str) { + let mut expect_setting = false; + + for opt in StartupMessageParams::parse_options_raw(options) { + if expect_setting { + append_name_value_setting(query, opt); + expect_setting = false; + continue; + } + + if opt == "-c" { + expect_setting = true; + } else if let Some(setting) = opt.strip_prefix("-c") + && !setting.is_empty() + { + append_name_value_setting(query, setting); + } else if let Some(setting) = opt.strip_prefix("--command=") { + append_name_value_setting(query, setting); + } + } +} + +fn append_name_value_setting(query: &mut String, setting: &str) { + if let Some((name, value)) = setting.split_once('=') { + append_set_config(query, name, value); + } +} + +fn append_set_config(query: &mut String, name: &str, value: &str) { + query.push_str("; SELECT pg_catalog.set_config("); + append_sql_literal(query, name); + query.push_str(", "); + append_sql_literal(query, value); + query.push_str(", false)"); +} + +fn append_sql_literal(query: &mut String, value: &str) { + query.push('\''); + for c in value.chars() { + if c == '\'' { + query.push_str("''"); + } else { + query.push(c); + } + } + query.push('\''); +} + impl ConnectInfo { /// Establish a raw TCP+TLS connection to the compute node. async fn connect_raw( diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 22902dbcabe8..de574c839a7e 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -31,6 +31,7 @@ pub struct ProxyConfig { pub tls_config: ArcSwapOption, pub metric_collection: Option, pub http_config: HttpConfig, + pub tcp_pool_config: TcpPoolConfig, pub authentication_config: AuthenticationConfig, #[cfg(feature = "rest_broker")] pub rest_config: RestConfig, @@ -74,6 +75,30 @@ pub struct HttpConfig { pub max_response_size_bytes: usize, } +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub enum TcpPoolMode { + /// Hold a compute connection after the first frontend message that needs + /// compute. Cached-startup sessions can sit idle before that without + /// checking out a compute connection. + #[default] + Session, + /// Return a compute connection to the pool at every transaction + /// boundary (compute sends ReadyForQuery with status `'I'`). Subsequent + /// transactions on the same client may land on different compute + /// connections — same semantic as PgBouncer transaction mode. + Transaction, +} + +#[derive(Clone, Copy, Debug)] +pub struct TcpPoolConfig { + pub enabled: bool, + pub mode: TcpPoolMode, + pub max_conns_per_key: usize, + pub max_total_conns: usize, + pub idle_timeout: Duration, + pub fallback_direct_connect: bool, +} + pub struct AuthenticationConfig { pub scram_thread_pool: Arc, pub scram_protocol_timeout: tokio::time::Duration, diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index f947abebc0ae..9f18b26202b6 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -233,7 +233,7 @@ pub(crate) async fn handle_client( let session = cancellation_handler.get_key(); - let (process_id, secret_key) = + let (process_id, secret_key, _) = forward_compute_params_to_client(ctx, *session.key(), &mut stream, &mut node.stream) .await?; let stream = stream.flush_and_into_inner().await?; @@ -260,18 +260,17 @@ pub(crate) async fn handle_client( ) .await; }); - Ok(Some(ProxyPassthrough { client: stream, - compute: node.stream.into_framed().into_inner(), - - aux: node.aux, + compute: Some(node), private_link_id: None, + tcp_pool_checkout: None, + tcp_pool_reacquire: None, + tcp_pool_config: config.tcp_pool_config, - _cancel_on_shutdown: cancel_on_shutdown, + _cancel_on_shutdown: Some(cancel_on_shutdown), _req: request_gauge, _conn: conn_gauge, - _db_conn: node.guage, })) } diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index 9e48d913403e..b4bf935fe86a 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -1,5 +1,6 @@ //! Mock console backend which relies on a user-provided postgres instance. +use std::collections::HashMap; use std::io; use std::net::{IpAddr, Ipv4Addr}; use std::str::FromStr; @@ -50,20 +51,29 @@ impl From for ControlPlaneError { #[derive(Clone)] pub struct MockControlPlane { - endpoint: ApiUrl, + auth_endpoint: ApiUrl, + compute_endpoint: Option, + compute_endpoint_map: HashMap, ip_allowlist_check_enabled: bool, } impl MockControlPlane { - pub fn new(endpoint: ApiUrl, ip_allowlist_check_enabled: bool) -> Self { + pub fn new( + auth_endpoint: ApiUrl, + compute_endpoint: Option, + compute_endpoint_map: HashMap, + ip_allowlist_check_enabled: bool, + ) -> Self { Self { - endpoint, + auth_endpoint, + compute_endpoint, + compute_endpoint_map, ip_allowlist_check_enabled, } } pub(crate) fn url(&self) -> &str { - self.endpoint.as_str() + self.auth_endpoint.as_str() } async fn do_get_auth_info( @@ -76,7 +86,7 @@ impl MockControlPlane { // write more code for reopening it if it got closed, which doesn't // seem worth it. let (client, connection) = - tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?; + tokio_postgres::connect(self.auth_endpoint.as_str(), tokio_postgres::NoTls).await?; tokio::spawn(connection); @@ -121,7 +131,7 @@ impl MockControlPlane { Ok((secret, allowed_ips)) } .inspect_err(|e: &GetAuthInfoError| tracing::error!("{e}")) - .instrument(info_span!("postgres", url = self.endpoint.as_str())) + .instrument(info_span!("postgres", url = self.auth_endpoint.as_str())) .await?; Ok(AuthInfo { secret, @@ -139,7 +149,7 @@ impl MockControlPlane { endpoint: &EndpointId, ) -> Result, GetEndpointJwksError> { let (client, connection) = - tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?; + tokio_postgres::connect(self.auth_endpoint.as_str(), tokio_postgres::NoTls).await?; let connection = tokio::spawn(connection); @@ -170,20 +180,27 @@ impl MockControlPlane { Ok(rows) } - async fn do_wake_compute(&self) -> Result { - let port = self.endpoint.port().unwrap_or(5432); - let conn_info = match self.endpoint.host_str() { + async fn do_wake_compute(&self, endpoint: &EndpointId) -> Result { + // Lookup priority: per-endpoint map → global override → auth endpoint. + let target = self + .compute_endpoint_map + .get(endpoint) + .or(self.compute_endpoint.as_ref()) + .unwrap_or(&self.auth_endpoint); + let port = target.port().unwrap_or(5432); + let ssl_mode = parse_ssl_mode(target)?; + let conn_info = match target.host_str() { None => ConnectInfo { host_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), host: "localhost".into(), port, - ssl_mode: SslMode::Disable, + ssl_mode, }, Some(host) => ConnectInfo { host_addr: IpAddr::from_str(host).ok(), host: host.into(), port, - ssl_mode: SslMode::Disable, + ssl_mode, }, }; @@ -202,6 +219,20 @@ impl MockControlPlane { } } +fn parse_ssl_mode(url: &ApiUrl) -> Result { + let Some((_, sslmode)) = url.query_pairs().find(|(k, _)| k == "sslmode") else { + return Ok(SslMode::Disable); + }; + + match sslmode.as_ref() { + "disable" => Ok(SslMode::Disable), + "require" => Ok(SslMode::Require), + other => Err(WakeComputeError::BadComputeAddress( + format!("unsupported sslmode: {other}").into(), + )), + } +} + async fn get_execute_postgres_query( client: &Client, query: &str, @@ -262,8 +293,10 @@ impl super::ControlPlaneApi for MockControlPlane { async fn wake_compute( &self, _ctx: &RequestContext, - _user_info: &ComputeUserInfo, + user_info: &ComputeUserInfo, ) -> Result { - self.do_wake_compute().map_ok(Cached::new_uncached).await + self.do_wake_compute(&user_info.endpoint) + .map_ok(Cached::new_uncached) + .await } } diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index 263d784e7800..34190ad798eb 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -107,6 +107,7 @@ mod scram; mod serverless; mod signals; mod stream; +mod tcp_pool; mod tls; mod types; mod url; diff --git a/proxy/src/pglb/copy_bidirectional.rs b/proxy/src/pglb/copy_bidirectional.rs index 5e4262a3231d..fee1ad94abdf 100644 --- a/proxy/src/pglb/copy_bidirectional.rs +++ b/proxy/src/pglb/copy_bidirectional.rs @@ -13,6 +13,15 @@ enum TransferState { Done(u64), } +impl TransferState { + fn is_buffer_empty_and_flushed(&self) -> bool { + match self { + TransferState::Running(buf) => buf.is_empty_and_flushed(), + TransferState::ShuttingDown(_) | TransferState::Done(_) => true, + } + } +} + #[derive(Debug)] pub(crate) enum ErrorDirection { Read(io::Error), @@ -67,6 +76,38 @@ where } } +/// Like `transfer_one_direction`, but never calls `poll_shutdown` on the writer +/// when the reader hits EOF. Used in the pooled bidirectional copy for the +/// client→compute direction, because shutting down the compute write half +/// (`shutdown(SHUT_WR)`) breaks the connection for pool reuse: Postgres +/// reacts to the FIN by closing the backend, and subsequent reads hit +/// `ENOTCONN`. +fn transfer_one_direction_no_shutdown( + cx: &mut Context<'_>, + state: &mut TransferState, + r: &mut A, + w: &mut B, +) -> Poll> +where + A: AsyncRead + AsyncWrite + Unpin + ?Sized, + B: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ + let mut r = Pin::new(r); + let mut w = Pin::new(w); + loop { + match state { + TransferState::Running(buf) => { + let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?; + *state = TransferState::Done(count); + } + TransferState::ShuttingDown(count) => { + *state = TransferState::Done(*count); + } + TransferState::Done(count) => return Poll::Ready(Ok(*count)), + } + } +} + #[tracing::instrument(skip_all)] pub async fn copy_bidirectional_client_compute( client: &mut Client, @@ -87,14 +128,11 @@ where transfer_one_direction(cx, &mut compute_to_client, compute, client) .map_err(ErrorSource::from_compute)?; - // TODO: 1 info log, with a enum label for close direction. - // Early termination checks from compute to client. if let TransferState::Done(_) = compute_to_client && let TransferState::Running(buf) = &client_to_compute { info!("Compute is done, terminate client"); - // Initiate shutdown client_to_compute = TransferState::ShuttingDown(buf.amt); client_to_compute_result = transfer_one_direction(cx, &mut client_to_compute, client, compute) @@ -106,22 +144,205 @@ where && let TransferState::Running(buf) = &compute_to_client { info!("Client is done, terminate compute"); - // Initiate shutdown compute_to_client = TransferState::ShuttingDown(buf.amt); compute_to_client_result = transfer_one_direction(cx, &mut compute_to_client, compute, client) .map_err(ErrorSource::from_compute)?; } - // It is not a problem if ready! returns early ... (comment remains the same) let client_to_compute = ready!(client_to_compute_result); let compute_to_client = ready!(compute_to_client_result); - Poll::Ready(Ok((client_to_compute, compute_to_client))) }) .await } +#[tracing::instrument(skip_all)] +pub async fn copy_bidirectional_client_compute_pooled( + client: &mut Client, + compute: &mut Compute, +) -> Result<(u64, u64), ErrorSource> +where + Client: AsyncRead + AsyncWrite + Unpin + ?Sized, + Compute: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ + let mut filtered_client = TerminateFilter::new(&mut *client); + + let mut client_to_compute = TransferState::Running(CopyBuffer::new()); + let mut compute_to_client = TransferState::Running(CopyBuffer::new()); + + let result: Result<(u64, u64), ErrorSource> = poll_fn(|cx| { + // CRITICAL: check for clean Terminate BEFORE driving the state machines. + // Once TerminateFilter has seen Terminate, we want to bail without + // calling poll_shutdown on compute (which would close the pooled connection). + if filtered_client.saw_terminate() { + let c2c_count = match &client_to_compute { + TransferState::Running(buf) => buf.amt, + TransferState::ShuttingDown(n) | TransferState::Done(n) => *n, + }; + let s2c_count = match &compute_to_client { + TransferState::Running(buf) => buf.amt, + TransferState::ShuttingDown(n) | TransferState::Done(n) => *n, + }; + info!("clean client disconnect via Terminate; compute connection preserved"); + return Poll::Ready(Ok((c2c_count, s2c_count))); + } + + // CRITICAL: never poll_shutdown(compute) — the connection must be + // reusable in the pool. If the client closes (or sends Terminate), + // we just stop forwarding without half-closing the compute socket. + let client_to_compute_result = transfer_one_direction_no_shutdown( + cx, + &mut client_to_compute, + &mut filtered_client, + compute, + ) + .map_err(ErrorSource::from_client)?; + + // Re-check after the read side has been driven; the filter may have just + // observed Terminate during this poll. + if filtered_client.saw_terminate() { + let c2c_count = match &client_to_compute { + TransferState::Running(buf) => buf.amt, + TransferState::ShuttingDown(n) | TransferState::Done(n) => *n, + }; + let s2c_count = match &compute_to_client { + TransferState::Running(buf) => buf.amt, + TransferState::ShuttingDown(n) | TransferState::Done(n) => *n, + }; + info!("clean client disconnect via Terminate; compute connection preserved"); + return Poll::Ready(Ok((c2c_count, s2c_count))); + } + + let compute_to_client_result = + transfer_one_direction(cx, &mut compute_to_client, compute, &mut filtered_client) + .map_err(ErrorSource::from_compute)?; + + // If compute closes first, this path runs and IS allowed to shut down the + // client (since the compute connection is gone anyway and shouldn't be pooled). + if let TransferState::Done(_) = compute_to_client + && let TransferState::Running(buf) = &client_to_compute + { + info!("Compute is done, terminate client"); + client_to_compute = TransferState::ShuttingDown(buf.amt); + // Note: we do NOT re-drive client_to_compute here because that would + // call poll_shutdown(compute) and we've already established compute is done. + // Just let the result fall through. + } + + let c2c = ready!(client_to_compute_result); + let s2c = ready!(compute_to_client_result); + Poll::Ready(Ok((c2c, s2c))) + }) + .await; + + result +} + +/// Outcome of a single iteration of the transaction-mode pump. +#[derive(Debug, Clone, Copy)] +pub(crate) enum BoundaryReason { + /// Compute sent ReadyForQuery; transaction status is in the byte. + /// `b'I'` = idle (compute can be returned to pool), `b'T'` / `b'E'` = + /// in/failed transaction (compute must be held). + ReadyForQuery(u8), + /// Client sent Terminate ('X'). Caller decides whether to release + /// compute (last status was idle) or discard (mid-transaction). + ClientTerminated, + /// Compute closed the read side. The compute connection is gone; + /// caller should propagate this as an error to the client. + ComputeClosed, +} + +/// Pump bytes between client and compute until a transaction boundary is +/// reached. Used by the transaction-mode multiplex loop. Never calls +/// `poll_shutdown` on the compute writer — the compute connection is +/// expected to live past this call (returned to the pool or held for +/// the next iteration of the inner loop). +#[tracing::instrument(skip_all)] +pub async fn copy_bidirectional_until_boundary( + client: &mut Client, + compute: &mut Compute, +) -> Result +where + Client: AsyncRead + AsyncWrite + Unpin + ?Sized, + Compute: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ + let mut filtered_client = TerminateFilter::new(&mut *client); + let mut watched_compute = ReadyForQueryWatcher::new(&mut *compute); + + let mut client_to_compute = TransferState::Running(CopyBuffer::new()); + let mut compute_to_client = TransferState::Running(CopyBuffer::new()); + + poll_fn(|cx| { + if filtered_client.saw_terminate() { + return Poll::Ready(Ok(BoundaryReason::ClientTerminated)); + } + + // If compute already produced ReadyForQuery, do not read more client + // bytes into this backend. First drain the response bytes that caused + // the boundary to the current client; only then may the caller decide + // whether this backend is safe to return to the pool. + if watched_compute.ready_for_query_pending() { + let _ = transfer_one_direction_no_shutdown( + cx, + &mut compute_to_client, + &mut watched_compute, + &mut filtered_client, + ) + .map_err(ErrorSource::from_compute)?; + + if compute_to_client.is_buffer_empty_and_flushed() { + ready!(Pin::new(&mut filtered_client).poll_flush(cx)) + .map_err(ErrorSource::Client)?; + let status = watched_compute.last_status().unwrap_or(b'?'); + watched_compute.take_ready_for_query(); + return Poll::Ready(Ok(BoundaryReason::ReadyForQuery(status))); + } + + return Poll::Pending; + } + + // Drive client → compute. Never shut down compute here. + let _ = transfer_one_direction_no_shutdown( + cx, + &mut client_to_compute, + &mut filtered_client, + &mut watched_compute, + ) + .map_err(ErrorSource::from_client)?; + + if filtered_client.saw_terminate() { + return Poll::Ready(Ok(BoundaryReason::ClientTerminated)); + } + + // Drive compute → client through the watcher. + let _ = transfer_one_direction_no_shutdown( + cx, + &mut compute_to_client, + &mut watched_compute, + &mut filtered_client, + ) + .map_err(ErrorSource::from_compute)?; + + if watched_compute.ready_for_query_pending() + && compute_to_client.is_buffer_empty_and_flushed() + { + ready!(Pin::new(&mut filtered_client).poll_flush(cx)).map_err(ErrorSource::Client)?; + let status = watched_compute.last_status().unwrap_or(b'?'); + watched_compute.take_ready_for_query(); + return Poll::Ready(Ok(BoundaryReason::ReadyForQuery(status))); + } + + if let TransferState::Done(_) = compute_to_client { + return Poll::Ready(Ok(BoundaryReason::ComputeClosed)); + } + + Poll::Pending + }) + .await +} + #[derive(Debug)] pub(super) struct CopyBuffer { read_done: bool, @@ -145,6 +366,10 @@ impl CopyBuffer { } } + fn is_empty_and_flushed(&self) -> bool { + self.pos == self.cap && !self.need_flush + } + fn poll_fill_buf( &mut self, cx: &mut Context<'_>, @@ -263,10 +488,23 @@ impl CopyBuffer { #[cfg(test)] mod tests { - use tokio::io::AsyncWriteExt; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; use super::*; + fn command_complete(tag: &[u8]) -> Vec { + let mut msg = Vec::with_capacity(5 + tag.len() + 1); + msg.push(b'C'); + msg.extend_from_slice(&((tag.len() + 1 + 4) as u32).to_be_bytes()); + msg.extend_from_slice(tag); + msg.push(0); + msg + } + + fn ready_for_query(status: u8) -> [u8; 6] { + [b'Z', 0, 0, 0, 5, status] + } + #[tokio::test] async fn test_client_to_compute() { let (mut client_client, mut client_proxy) = tokio::io::duplex(8); // Create a mock duplex stream @@ -310,4 +548,432 @@ mod tests { assert_eq!(compute_to_client_count, 5); // 'hello' was transferred assert!(client_to_compute_count <= 8); // response only partially transferred or not at all } + + #[tokio::test] + async fn transaction_boundary_waits_for_ready_for_query_to_reach_client() { + let command = command_complete(b"BEGIN"); + let ready = ready_for_query(b'T'); + let mut response = command.clone(); + response.extend_from_slice(&ready); + + let (mut client_peer, mut client_proxy) = tokio::io::duplex(1); + let (mut compute_proxy, mut compute_peer) = tokio::io::duplex(1024); + + compute_peer.write_all(&response).await.unwrap(); + + let pump = tokio::spawn(async move { + copy_bidirectional_until_boundary(&mut client_proxy, &mut compute_proxy).await + }); + + let mut got_command = vec![0; command.len()]; + client_peer.read_exact(&mut got_command).await.unwrap(); + assert_eq!(got_command, command); + + tokio::task::yield_now().await; + assert!( + !pump.is_finished(), + "boundary returned before ReadyForQuery was written to the client" + ); + + let mut got_ready = [0; 6]; + client_peer.read_exact(&mut got_ready).await.unwrap(); + assert_eq!(got_ready, ready); + + let boundary = pump.await.unwrap().unwrap(); + assert!(matches!(boundary, BoundaryReason::ReadyForQuery(b'T'))); + } +} + +pub(crate) struct TerminateFilter { + inner: R, + state: FilterState, + saw_terminate: bool, +} + +enum FilterState { + AwaitingHeader { + header: [u8; 5], + pos: usize, + }, + InBody { + remaining: usize, + header: [u8; 5], + header_emitted: usize, + }, +} + +impl TerminateFilter { + pub(crate) fn new(inner: R) -> Self { + Self { + inner, + state: FilterState::AwaitingHeader { + header: [0; 5], + pos: 0, + }, + saw_terminate: false, + } + } + + pub(crate) fn saw_terminate(&self) -> bool { + self.saw_terminate + } +} + +impl AsyncRead for TerminateFilter { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + out: &mut ReadBuf<'_>, + ) -> Poll> { + // Destructure self once so we have separate borrows of each field. + // This is the standard Rust idiom for "split borrow" through a Pin. + let this = self.get_mut(); + + if this.saw_terminate { + return Poll::Ready(Ok(())); + } + + loop { + match &mut this.state { + FilterState::InBody { .. } => { + if let FilterState::InBody { + remaining: _, + header, + header_emitted, + } = &mut this.state + { + if *header_emitted < 5 { + if out.remaining() == 0 { + return Poll::Ready(Ok(())); + } + let need = 5 - *header_emitted; + let to_emit = need.min(out.remaining()); + let start = *header_emitted; + out.put_slice(&header[start..start + to_emit]); + *header_emitted += to_emit; + if *header_emitted < 5 { + return Poll::Ready(Ok(())); + } + } + } + let FilterState::InBody { remaining, .. } = &mut this.state else { + unreachable!(); + }; + if *remaining == 0 { + this.state = FilterState::AwaitingHeader { + header: [0; 5], + pos: 0, + }; + continue; + } + let to_read = (*remaining).min(out.remaining()); + if to_read == 0 { + return Poll::Ready(Ok(())); + } + let initial_filled = out.filled().len(); + let mut limited = out.take(to_read); + match Pin::new(&mut this.inner).poll_read(cx, &mut limited) { + Poll::Ready(Ok(())) => { + let n = limited.filled().len(); + if n == 0 { + return Poll::Ready(Ok(())); + } + out.set_filled(initial_filled + n); + *remaining -= n; + return Poll::Ready(Ok(())); + } + other => return other, + } + } + FilterState::AwaitingHeader { header, pos } => { + let need = 5 - *pos; + let mut tmp = [0u8; 5]; + let mut tmp_buf = ReadBuf::new(&mut tmp[..need]); + match Pin::new(&mut this.inner).poll_read(cx, &mut tmp_buf) { + Poll::Ready(Ok(())) => { + let n = tmp_buf.filled().len(); + if n == 0 { + if *pos > 0 { + let to_emit = (*pos).min(out.remaining()); + out.put_slice(&header[..to_emit]); + } + return Poll::Ready(Ok(())); + } + header[*pos..*pos + n].copy_from_slice(&tmp[..n]); + *pos += n; + if *pos < 5 { + continue; + } + let tag = header[0]; + let len = + u32::from_be_bytes([header[1], header[2], header[3], header[4]]) + as usize; + if tag == b'X' && len == 4 { + tracing::info!( + "TerminateFilter: intercepted Postgres Terminate; \ + compute connection preserved" + ); + this.saw_terminate = true; + return Poll::Ready(Ok(())); + } + // Save header bytes locally before mutating this.state. + let header_copy = *header; + let body_len = len.saturating_sub(4); + let to_emit = 5.min(out.remaining()); + out.put_slice(&header_copy[..to_emit]); + this.state = FilterState::InBody { + remaining: body_len, + header: header_copy, + header_emitted: to_emit, + }; + return Poll::Ready(Ok(())); + } + other => return other, + } + } + } + } + } +} + +impl tokio::io::AsyncWrite for TerminateFilter { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_shutdown(cx) + } +} + +/// Transparently pass bytes through from the compute side while parsing +/// Postgres backend protocol message headers. Tracks the last-seen +/// `ReadyForQuery` (`'Z'`) transaction-status byte (`'I'` idle, `'T'` in-tx, +/// `'E'` failed-tx). The transaction-mode pump uses this to decide when a +/// transaction boundary has been reached and the compute connection can be +/// returned to the pool. +pub(crate) struct ReadyForQueryWatcher { + inner: R, + state: WatchState, + last_status: Option, + saw_ready_for_query: bool, +} + +enum WatchState { + AwaitingHeader { + header: [u8; 5], + pos: usize, + }, + InBody { + tag: u8, + remaining: usize, + body_seen: usize, + header: [u8; 5], + header_emitted: usize, + }, +} + +impl ReadyForQueryWatcher { + pub(crate) fn new(inner: R) -> Self { + Self { + inner, + state: WatchState::AwaitingHeader { + header: [0; 5], + pos: 0, + }, + last_status: None, + saw_ready_for_query: false, + } + } + + pub(crate) fn last_status(&self) -> Option { + self.last_status + } + + pub(crate) fn ready_for_query_pending(&self) -> bool { + self.saw_ready_for_query + } + + /// Returns true if a `ReadyForQuery` was observed since the last call + /// (and clears the flag). The caller uses this as the "transaction + /// boundary reached" signal in the multiplex loop. + pub(crate) fn take_ready_for_query(&mut self) -> bool { + std::mem::replace(&mut self.saw_ready_for_query, false) + } +} + +impl AsyncRead for ReadyForQueryWatcher { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + out: &mut ReadBuf<'_>, + ) -> Poll> { + let this = self.get_mut(); + // Once a ReadyForQuery boundary is observed, pause reads until the + // caller consumes it via take_ready_for_query(). This prevents the + // boundary pump from over-reading into subsequent backend messages. + if this.saw_ready_for_query { + return Poll::Pending; + } + + loop { + match &mut this.state { + WatchState::AwaitingHeader { header, pos } => { + let need = 5 - *pos; + let mut tmp = [0u8; 5]; + let mut tmp_buf = ReadBuf::new(&mut tmp[..need]); + match Pin::new(&mut this.inner).poll_read(cx, &mut tmp_buf) { + Poll::Ready(Ok(())) => { + let n = tmp_buf.filled().len(); + if n == 0 { + if *pos > 0 { + let to_emit = (*pos).min(out.remaining()); + out.put_slice(&header[..to_emit]); + } + return Poll::Ready(Ok(())); + } + header[*pos..*pos + n].copy_from_slice(&tmp[..n]); + *pos += n; + if *pos < 5 { + continue; + } + + let tag = header[0]; + let len = + u32::from_be_bytes([header[1], header[2], header[3], header[4]]) + as usize; + let body_len = len.saturating_sub(4); + + // bidir copy uses a 1024-byte read buffer, so + // out.remaining() is always >= 5 here in practice. + // We still cap the put_slice for safety. + let header_copy = *header; + let to_emit = 5.min(out.remaining()); + out.put_slice(&header_copy[..to_emit]); + + this.state = WatchState::InBody { + tag, + remaining: body_len, + body_seen: 0, + header: header_copy, + header_emitted: to_emit, + }; + + if body_len == 0 && to_emit == 5 { + this.state = WatchState::AwaitingHeader { + header: [0; 5], + pos: 0, + }; + } + + if out.remaining() == 0 || to_emit < 5 || body_len == 0 { + return Poll::Ready(Ok(())); + } + + // Fall through to read body bytes in the same poll. + continue; + } + other => return other, + } + } + WatchState::InBody { + tag, + remaining, + body_seen, + header, + header_emitted, + } => { + if *header_emitted < 5 { + if out.remaining() == 0 { + return Poll::Ready(Ok(())); + } + let need = 5 - *header_emitted; + let to_emit = need.min(out.remaining()); + let start = *header_emitted; + out.put_slice(&header[start..start + to_emit]); + *header_emitted += to_emit; + if *header_emitted < 5 { + return Poll::Ready(Ok(())); + } + if *remaining == 0 { + this.state = WatchState::AwaitingHeader { + header: [0; 5], + pos: 0, + }; + continue; + } + } + + if *remaining == 0 { + this.state = WatchState::AwaitingHeader { + header: [0; 5], + pos: 0, + }; + continue; + } + if out.remaining() == 0 { + return Poll::Ready(Ok(())); + } + let to_read = (*remaining).min(out.remaining()); + let initial_filled = out.filled().len(); + let mut limited = out.take(to_read); + match Pin::new(&mut this.inner).poll_read(cx, &mut limited) { + Poll::Ready(Ok(())) => { + let n = limited.filled().len(); + if n == 0 { + return Poll::Ready(Ok(())); + } + + // For ReadyForQuery, the body length must be 1 and + // the only byte is the transaction status. + let mut ready_for_query_status = None; + if *tag == b'Z' { + if *body_seen != 0 || *remaining != 1 || n != 1 { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "malformed ReadyForQuery message", + ))); + } + ready_for_query_status = Some(limited.filled()[0]); + } + + out.set_filled(initial_filled + n); + *remaining -= n; + *body_seen += n; + if let Some(status) = ready_for_query_status { + this.last_status = Some(status); + this.saw_ready_for_query = true; + } + return Poll::Ready(Ok(())); + } + other => return other, + } + } + } + } + } +} + +impl tokio::io::AsyncWrite for ReadyForQueryWatcher { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_write(cx, buf) + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_flush(cx) + } + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().inner).poll_shutdown(cx) + } } diff --git a/proxy/src/pglb/mod.rs b/proxy/src/pglb/mod.rs index 999fa6eb3259..c5df789b209a 100644 --- a/proxy/src/pglb/mod.rs +++ b/proxy/src/pglb/mod.rs @@ -296,7 +296,7 @@ pub(crate) async fn handle_connection( let common_names = tls.map(|tls| &tls.common_names); - let (node, cancel_on_shutdown) = handle_client( + let (node, cancel_on_shutdown, tcp_pool_checkout, tcp_pool_reacquire) = handle_client( config, auth_backend, ctx, @@ -316,18 +316,17 @@ pub(crate) async fn handle_connection( Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()), None => None, }; - Ok(Some(ProxyPassthrough { client, - compute: node.stream.into_framed().into_inner(), - - aux: node.aux, + compute: node, private_link_id, + tcp_pool_checkout, + tcp_pool_reacquire, + tcp_pool_config: config.tcp_pool_config, _cancel_on_shutdown: cancel_on_shutdown, _req: request_gauge, _conn: conn_gauge, - _db_conn: node.guage, })) } diff --git a/proxy/src/pglb/passthrough.rs b/proxy/src/pglb/passthrough.rs index d4c029f6d938..62455988fec6 100644 --- a/proxy/src/pglb/passthrough.rs +++ b/proxy/src/pglb/passthrough.rs @@ -1,18 +1,19 @@ use std::convert::Infallible; +use postgres_client::connect_raw::StartupStream; use smol_str::SmolStr; +use tokio::io::AsyncWriteExt; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::debug; use utils::measured_stream::MeasuredStream; use super::copy_bidirectional::ErrorSource; -use crate::compute::MaybeRustlsStream; +use crate::compute::{ComputeConnection, MaybeRustlsStream}; +use crate::config::{TcpPoolConfig, TcpPoolMode}; use crate::control_plane::messages::MetricsAuxInfo; -use crate::metrics::{ - Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard, - NumDbConnectionsGuard, -}; +use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard}; use crate::stream::Stream; +use crate::tcp_pool::{TcpPoolCheckout, TcpPoolReacquire}; use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS}; /// Forward bytes in both directions (client <-> compute). @@ -64,22 +65,392 @@ pub(crate) async fn proxy_pass( Ok(()) } +/// Same as `proxy_pass` but preserves compute stream on client EOF for pooled reuse. +#[tracing::instrument(skip_all)] +pub(crate) async fn proxy_pass_pooled( + client: impl AsyncRead + AsyncWrite + Unpin, + compute: impl AsyncRead + AsyncWrite + Unpin, + aux: MetricsAuxInfo, + private_link_id: Option, +) -> Result<(), ErrorSource> { + // we will report ingress at a later date + let usage_tx = USAGE_METRICS.register(Ids { + endpoint_id: aux.endpoint_id, + branch_id: aux.branch_id, + private_link_id, + }); + + let metrics = &Metrics::get().proxy.io_bytes; + let m_sent = metrics.with_labels(Direction::Tx); + let mut client = MeasuredStream::new( + client, + |_| {}, + |cnt| { + metrics.get_metric(m_sent).inc_by(cnt as u64); + usage_tx.record_egress(cnt as u64); + }, + ); + + let m_recv = metrics.with_labels(Direction::Rx); + let mut compute = MeasuredStream::new( + compute, + |_| {}, + |cnt| { + metrics.get_metric(m_recv).inc_by(cnt as u64); + usage_tx.record_ingress(cnt as u64); + }, + ); + + debug!("performing the pooled proxy pass..."); + let _ = crate::pglb::copy_bidirectional::copy_bidirectional_client_compute_pooled( + &mut client, + &mut compute, + ) + .await?; + + Ok(()) +} + +struct ComputeParts { + aux: MetricsAuxInfo, + hostname: crate::types::Host, + ssl_mode: postgres_client::config::SslMode, + socket_addr: std::net::SocketAddr, + guage: crate::metrics::NumDbConnectionsGuard<'static>, +} + +fn split_compute(conn: ComputeConnection) -> (ComputeParts, MaybeRustlsStream) { + let stream = conn.stream.into_framed().into_inner(); + let parts = ComputeParts { + aux: conn.aux, + hostname: conn.hostname, + ssl_mode: conn.ssl_mode, + socket_addr: conn.socket_addr, + guage: conn.guage, + }; + (parts, stream) +} + +fn join_compute(parts: ComputeParts, stream: MaybeRustlsStream) -> ComputeConnection { + ComputeConnection { + stream: StartupStream::new(stream), + aux: parts.aux, + hostname: parts.hostname, + ssl_mode: parts.ssl_mode, + socket_addr: parts.socket_addr, + guage: parts.guage, + } +} + +fn release_compute_checkout( + current_checkout: &mut Option, + stream: &mut Option, + compute_parts: &mut Option, + reusable: bool, +) { + if let Some(checkout) = current_checkout.take() { + let stream_inner = stream + .take() + .expect("compute stream must exist before release"); + let c = join_compute( + compute_parts + .take() + .expect("compute must exist before release"), + stream_inner, + ); + checkout.release(c, reusable); + } +} + +async fn write_frontend_message(stream: &mut S, tag: u8, body: &[u8]) -> Result<(), ErrorSource> +where + S: AsyncWrite + Unpin, +{ + stream.write_u8(tag).await.map_err(ErrorSource::Compute)?; + stream + .write_u32((body.len() + 4) as u32) + .await + .map_err(ErrorSource::Compute)?; + stream.write_all(body).await.map_err(ErrorSource::Compute)?; + stream.flush().await.map_err(ErrorSource::Compute) +} + pub(crate) struct ProxyPassthrough { pub(crate) client: Stream, - pub(crate) compute: MaybeRustlsStream, - - pub(crate) aux: MetricsAuxInfo, + pub(crate) compute: Option, pub(crate) private_link_id: Option, + pub(crate) tcp_pool_checkout: Option, + pub(crate) tcp_pool_reacquire: Option, + pub(crate) tcp_pool_config: TcpPoolConfig, - pub(crate) _cancel_on_shutdown: tokio::sync::oneshot::Sender, + pub(crate) _cancel_on_shutdown: Option>, pub(crate) _req: NumConnectionRequestsGuard<'static>, pub(crate) _conn: NumClientConnectionsGuard<'static>, - pub(crate) _db_conn: NumDbConnectionsGuard<'static>, } impl ProxyPassthrough { pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> { - proxy_pass(self.client, self.compute, self.aux, self.private_link_id).await + let mut pt = self; + + // Three paths share this entry point: + // - no pool checkout: legacy direct passthrough. + // - session pool metadata: acquire compute only after first frontend + // message, then hold it for the rest of the client session. + // - transaction pool metadata: multiplex loop that acquires compute + // only while a transaction is active. + let transaction_mode = + pt.tcp_pool_reacquire.is_some() && pt.tcp_pool_config.mode == TcpPoolMode::Transaction; + let session_mode = pt.tcp_pool_config.mode == TcpPoolMode::Session + && (pt.tcp_pool_checkout.is_some() || pt.tcp_pool_reacquire.is_some()); + + if session_mode || transaction_mode { + if let Some(compute) = pt.compute.take() { + let checkout = pt + .tcp_pool_checkout + .take() + .expect("initial pooled compute requires a pool checkout"); + if pt.tcp_pool_reacquire.is_none() { + pt.tcp_pool_reacquire = Some(checkout.reacquire_info()); + } + checkout.release(compute, true); + } + } + + if transaction_mode { + return proxy_pass_transaction_mode(pt).await; + } + + if session_mode { + return proxy_pass_session_mode(pt).await; + } + + let compute = pt + .compute + .expect("non-transaction passthrough requires an initial compute connection"); + let _keep_db_guard_live = &compute.guage; + let mut stream: MaybeRustlsStream = compute.stream.into_framed().into_inner(); + let aux = compute.aux.clone(); + + proxy_pass(pt.client, &mut stream, aux, pt.private_link_id).await } } + +/// Session-mode pool path. Startup may be cold (an initial compute was needed +/// to obtain startup params) or warm (startup params were cached). In either +/// case, do not hold compute while the client is idle after startup. +async fn proxy_pass_session_mode( + pt: ProxyPassthrough, +) -> Result<(), ErrorSource> { + let ProxyPassthrough { + mut client, + private_link_id, + tcp_pool_reacquire, + .. + } = pt; + + debug!("performing the session-mode pump..."); + + let reacquire = tcp_pool_reacquire.expect("session mode requires pool reacquire metadata"); + let pool_key = reacquire.key().clone(); + let reset_query = reacquire.reset_query(); + + let mut client_msg_buf = Vec::new(); + + let (tag, body) = + crate::pqproto::read_message(&mut client, &mut client_msg_buf, i32::MAX as u32) + .await + .map_err(ErrorSource::Client)?; + + if tag == b'X' { + return Ok(()); + } + + let (compute, checkout) = crate::tcp_pool::manager() + .reacquire(pool_key, reset_query) + .await + .map_err(|e| ErrorSource::Compute(std::io::Error::other(e.to_string())))?; + + let (parts, mut stream) = split_compute(compute); + write_frontend_message(&mut stream, tag, &body).await?; + + let aux = parts.aux.clone(); + let result = proxy_pass_pooled(client, &mut stream, aux, private_link_id).await; + let compute = join_compute(parts, stream); + checkout.release(compute, result.is_ok()); + + result +} + +/// Transaction-mode multiplex loop. Pumps one transaction at a time using +/// `copy_bidirectional_until_boundary`; at each ReadyForQuery `'I'` the +/// compute connection is released to the pool and a new one is acquired +/// for the next transaction. Mid-transaction (status `'T'` / `'E'`) the +/// same compute is held until the transaction completes or aborts. +/// +/// Caveats (standard PgBouncer transaction-mode caveats apply): +/// - `SET` (without `LOCAL`) does not survive across transactions +/// - `LISTEN` notifications may be lost +/// - prepared statements created via simple `PREPARE` are not visible +/// to the next transaction +/// - `pg_cancel_backend(pid)` does not work as expected because the pid +/// the client received is synthesized, not a real backend pid +async fn proxy_pass_transaction_mode( + pt: ProxyPassthrough, +) -> Result<(), ErrorSource> { + use crate::pglb::copy_bidirectional::{BoundaryReason, copy_bidirectional_until_boundary}; + + let ProxyPassthrough { + mut client, + private_link_id, + tcp_pool_reacquire, + tcp_pool_config: _, + .. + } = pt; + + // SAFETY: transaction mode is only entered with reacquire metadata. + let reacquire = tcp_pool_reacquire.expect("transaction mode requires pool reacquire metadata"); + let pool_key = reacquire.key().clone(); + let reset_query = reacquire.reset_query(); + let mut current_checkout: Option = None; + let mut stream: Option = None; + let mut compute_parts: Option = None; + let mut client_msg_buf = Vec::new(); + + // Mirror proxy_pass_pooled's metrics wiring so the per-transaction + // pump still records bytes-in/bytes-out for the session. + let mut usage_tx = None; + let metrics = &Metrics::get().proxy.io_bytes; + let m_sent = metrics.with_labels(Direction::Tx); + let m_recv = metrics.with_labels(Direction::Rx); + + debug!("performing the transaction-mode pump..."); + + // After startup handshake the backend is idle from transaction perspective. + let mut last_known_status: u8 = b'I'; + + let final_result: Result<(), ErrorSource> = loop { + if current_checkout.is_none() { + let (tag, body) = + crate::pqproto::read_message(&mut client, &mut client_msg_buf, i32::MAX as u32) + .await + .map_err(ErrorSource::Client)?; + + if tag == b'X' { + break Ok(()); + } + + let (next_compute, next_checkout) = crate::tcp_pool::manager() + .reacquire(pool_key.clone(), reset_query.clone()) + .await + .map_err(|e| ErrorSource::Compute(std::io::Error::other(e.to_string())))?; + + let (next_parts, mut next_stream) = split_compute(next_compute); + write_frontend_message(&mut next_stream, tag, &body).await?; + + stream = Some(next_stream); + compute_parts = Some(next_parts); + current_checkout = Some(next_checkout); + } + + let usage_tx = usage_tx.get_or_insert_with(|| { + let aux = &compute_parts + .as_ref() + .expect("compute metadata must be checked out in transaction loop") + .aux; + USAGE_METRICS.register(Ids { + endpoint_id: aux.endpoint_id, + branch_id: aux.branch_id, + private_link_id: private_link_id.clone(), + }) + }); + + let mut measured_client = MeasuredStream::new( + &mut client, + |_| {}, + |cnt| { + metrics.get_metric(m_sent).inc_by(cnt as u64); + usage_tx.record_egress(cnt as u64); + }, + ); + + let stream_ref = stream + .as_mut() + .expect("compute stream must be checked out in transaction loop"); + let mut measured_compute = MeasuredStream::new( + stream_ref, + |_| {}, + |cnt| { + metrics.get_metric(m_recv).inc_by(cnt as u64); + usage_tx.record_ingress(cnt as u64); + }, + ); + + let boundary = + copy_bidirectional_until_boundary(&mut measured_client, &mut measured_compute).await; + drop(measured_compute); + + match boundary { + Ok(BoundaryReason::ReadyForQuery(status)) => { + if !matches!(status, b'I' | b'T' | b'E') { + release_compute_checkout( + &mut current_checkout, + &mut stream, + &mut compute_parts, + false, + ); + break Err(ErrorSource::Compute(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("invalid ReadyForQuery status: {status}"), + ))); + } + last_known_status = status; + if status == b'I' { + // Transaction complete: release and wait for next frontend message. + release_compute_checkout( + &mut current_checkout, + &mut stream, + &mut compute_parts, + true, + ); + last_known_status = b'I'; + } + // 'T'/'E': stay in inner loop, keep compute held. + } + Ok(BoundaryReason::ClientTerminated) => { + // Pool the compute only if it's known to be idle. If we were + // mid-transaction (last_status was 'T' or 'E'), the conn has + // an open BEGIN block and is unsafe to share. + release_compute_checkout( + &mut current_checkout, + &mut stream, + &mut compute_parts, + last_known_status == b'I', + ); + break Ok(()); + } + Ok(BoundaryReason::ComputeClosed) => { + release_compute_checkout( + &mut current_checkout, + &mut stream, + &mut compute_parts, + false, + ); + break Err(ErrorSource::Compute(std::io::Error::other( + "compute closed connection", + ))); + } + Err(e) => { + release_compute_checkout( + &mut current_checkout, + &mut stream, + &mut compute_parts, + false, + ); + break Err(e); + } + } + }; + + final_result +} diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index b42457cd9568..ff64378ba6e6 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -26,13 +26,14 @@ use tracing::Instrument; use crate::cancellation::{CancelClosure, CancellationHandler}; use crate::compute::{ComputeConnection, PostgresError, RustlsStream}; -use crate::config::ProxyConfig; +use crate::config::{ProxyConfig, TcpPoolConfig, TcpPoolMode}; use crate::context::RequestContext; pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; use crate::pglb::{ClientMode, ClientRequestError}; use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams}; use crate::rate_limiter::EndpointRateLimiter; use crate::stream::{PqStream, Stream}; +use crate::tcp_pool::{TcpPoolCheckout, TcpPoolKey, TcpPoolReacquire}; use crate::types::EndpointCacheKey; use crate::{auth, compute}; @@ -47,7 +48,15 @@ pub(crate) async fn handle_client( endpoint_rate_limiter: Arc, common_names: Option<&HashSet>, params: &StartupMessageParams, -) -> Result<(ComputeConnection, oneshot::Sender), ClientRequestError> { +) -> Result< + ( + Option, + Option>, + Option, + Option, + ), + ClientRequestError, +> { let hostname = mode.hostname(client.get_ref()); // Extract credentials which we're going to use for auth. let result = auth_backend @@ -92,33 +101,112 @@ pub(crate) async fn handle_client( let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys); auth_info.set_startup_params(params, params_compat); - let backend = auth::Backend::ControlPlane(cplane, creds.info); + let pool_key = TcpPoolKey::new( + creds.info.endpoint_cache_key(), + params.get("database").unwrap_or(user.as_ref()).into(), + creds.info.user.clone(), + ); + let startup_cache_key = pool_key.clone(); + let cancel_user_info = creds.info.clone(); + let cplane = (*cplane).clone(); + let tcp_pool_manager = crate::tcp_pool::manager(); + let tcp_pool_config = TcpPoolConfig { + enabled: config.tcp_pool_config.enabled && creds.info.use_tcp_pool, + ..config.tcp_pool_config + }; + let transaction_pooling = + tcp_pool_config.enabled && tcp_pool_config.mode == TcpPoolMode::Transaction; + let lazy_session_pooling = + tcp_pool_config.enabled && tcp_pool_config.mode == TcpPoolMode::Session; + + if (transaction_pooling || lazy_session_pooling) + && let Some(startup_params) = tcp_pool_manager.get_startup_params(&startup_cache_key) + { + let tcp_pool_reacquire = tcp_pool_manager + .prepare_reacquire( + &tcp_pool_config, + pool_key, + ctx.clone(), + config, + cplane, + creds.info, + auth_info, + ) + .await; - // TODO: callback to pglb - let res = connect_auth::connect_to_compute_and_auth( - ctx, - config, - &backend, - auth_info, - connect_compute::TlsNegotiation::Postgres, - ) - .await; + send_client_greeting(ctx, &config.greetings, client); - let mut node = match res { - Ok(node) => node, + let session = cancellation_handler.get_key(); + for (name, value) in startup_params.iter() { + client.write_message(BeMessage::ParameterStatus { + name: name.as_bytes(), + value: value.as_bytes(), + }); + } + client.write_message(BeMessage::BackendKeyData(*session.key())); + client.write_message(BeMessage::ReadyForQuery); + + return Ok((None, None, None, Some(tcp_pool_reacquire))); + } + + let (mut node, tcp_pool_checkout, was_reused) = match crate::tcp_pool::manager() + .acquire_or_connect( + &tcp_pool_config, + pool_key, + ctx.clone(), + config, + cplane, + creds.info, + auth_info, + ) + .await + { + Ok(v) => v, Err(e) => Err(client.throw_error(e, Some(ctx)).await)?, }; + let tcp_pool_reacquire = if transaction_pooling { + tcp_pool_checkout + .as_ref() + .map(TcpPoolCheckout::reacquire_info) + } else { + None + }; send_client_greeting(ctx, &config.greetings, client); - let auth::Backend::ControlPlane(_, user_info) = backend else { - unreachable!("ensured above"); + let session = cancellation_handler.get_key(); + + let (process_id, secret_key) = if was_reused { + if let Some(startup_params) = + crate::tcp_pool::manager().get_startup_params(&startup_cache_key) + { + for (name, value) in startup_params.iter() { + client.write_message(BeMessage::ParameterStatus { + name: name.as_bytes(), + value: value.as_bytes(), + }); + } + } + client.write_message(BeMessage::BackendKeyData(*session.key())); + client.write_message(BeMessage::ReadyForQuery); + (0, 0) + } else { + let (process_id, secret_key, startup_params) = + forward_compute_params_to_client(ctx, *session.key(), client, &mut node.stream).await?; + crate::tcp_pool::manager().set_startup_params(&startup_cache_key, startup_params); + (process_id, secret_key) }; - let session = cancellation_handler.get_key(); + if let Some(checkout) = tcp_pool_checkout.as_ref() { + let reset_query = checkout.reset_query(); + node = match crate::tcp_pool::reset_session(node, &reset_query).await { + Ok(node) => node, + Err(err) => Err(client + .throw_error(PostgresError::Postgres(err), Some(ctx)) + .await)?, + }; + } - let (process_id, secret_key) = - forward_compute_params_to_client(ctx, *session.key(), client, &mut node.stream).await?; let hostname = node.hostname.to_string(); let session_id = ctx.session_id(); @@ -136,14 +224,19 @@ pub(crate) async fn handle_client( secret_key, }, hostname, - user_info, + user_info: cancel_user_info, }, &config.connect_to_compute, ) .await; }); - Ok((node, cancel_on_shutdown)) + Ok(( + Some(node), + Some(cancel_on_shutdown), + tcp_pool_checkout, + tcp_pool_reacquire, + )) } /// Greet the client with any useful information. @@ -194,9 +287,10 @@ pub(crate) async fn forward_compute_params_to_client( cancel_key_data: CancelKeyData, client: &mut PqStream, compute: &mut StartupStream, -) -> Result<(i32, i32), ClientRequestError> { +) -> Result<(i32, i32, Vec<(Box, Box)>), ClientRequestError> { let mut process_id = 0; let mut secret_key = 0; + let mut startup_params = Vec::new(); let err = loop { // if the client buffer is too large, let's write out some bytes now to save some space @@ -223,6 +317,10 @@ pub(crate) async fn forward_compute_params_to_client( name: name.as_bytes(), value: value.as_bytes(), }); + startup_params.push(( + name.to_owned().into_boxed_str(), + value.to_owned().into_boxed_str(), + )); } } // Forward all notices to the client. @@ -233,7 +331,7 @@ pub(crate) async fn forward_compute_params_to_client( } Some(Message::ReadyForQuery(_)) => { client.write_message(BeMessage::ReadyForQuery); - return Ok((process_id, secret_key)); + return Ok((process_id, secret_key, startup_params)); } Some(Message::ErrorResponse(body)) => break postgres_client::Error::db(body), Some(_) => break postgres_client::Error::unexpected_message(), diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 7e0710749e75..9c0ce57a8dd2 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -563,6 +563,7 @@ fn helper_create_connect_info( endpoint: "endpoint".into(), user: "user".into(), options: NeonOptions::parse_options_raw(""), + use_tcp_pool: false, }, ) } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 5b356c8460d4..9e22bb7b43db 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -245,6 +245,7 @@ impl PoolingBackend { conn_info.user_info.endpoint.normalize() )), options: conn_info.user_info.options.clone(), + use_tcp_pool: false, }); let node = connect_compute::connect_to_compute( diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 17305e30f16e..3621befcc128 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -241,6 +241,7 @@ mod tests { user: "user".into(), endpoint: "endpoint".into(), options: NeonOptions::default(), + use_tcp_pool: false, }, dbname: "dbname".into(), }; @@ -292,6 +293,7 @@ mod tests { user: "user".into(), endpoint: "endpoint-2".into(), options: NeonOptions::default(), + use_tcp_pool: false, }, dbname: "dbname".into(), }; diff --git a/proxy/src/serverless/http_util.rs b/proxy/src/serverless/http_util.rs index 0c91ac683579..a140fea77044 100644 --- a/proxy/src/serverless/http_util.rs +++ b/proxy/src/serverless/http_util.rs @@ -266,6 +266,7 @@ pub(crate) fn get_conn_info( endpoint, user: username, options: options.unwrap_or_default(), + use_tcp_pool: false, }; let conn_info = ConnInfo { user_info, dbname }; diff --git a/proxy/src/tcp_pool.rs b/proxy/src/tcp_pool.rs new file mode 100644 index 000000000000..3762e4e347ea --- /dev/null +++ b/proxy/src/tcp_pool.rs @@ -0,0 +1,454 @@ +use std::convert::Infallible; +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use bb8::{ManageConnection, Pool, PooledConnection}; +use futures::TryStreamExt; +use once_cell::sync::Lazy; +use postgres_client::connect_raw::StartupStream; +use postgres_protocol::message::backend::Message; +use rand::Rng; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tracing::debug; + +use crate::auth::Backend; +use crate::auth::backend::{ComputeUserInfo, MaybeOwned}; +use crate::compute::{AuthInfo, ComputeConnection, MaybeRustlsStream}; +use crate::config::{ProxyConfig, TcpPoolConfig}; +use crate::context::RequestContext; +use crate::error::{ErrorKind, ReportableError, UserFacingError}; +use crate::pqproto; +use crate::proxy::connect_auth::{self, AuthError}; +use crate::proxy::connect_compute; +use crate::types::{DbName, EndpointCacheKey, RoleName}; + +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub(crate) struct TcpPoolKey { + endpoint: EndpointCacheKey, + dbname: DbName, + role: RoleName, +} + +impl TcpPoolKey { + pub(crate) fn new(endpoint: EndpointCacheKey, dbname: DbName, role: RoleName) -> Self { + Self { + endpoint, + dbname, + role, + } + } +} + +impl std::fmt::Display for TcpPoolKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}/{}@{}", self.endpoint, self.role, self.dbname) + } +} + +struct PooledCompute { + conn: Option, + fresh: bool, +} + +#[derive(Clone)] +struct ComputeConnectionManager { + ctx: RequestContext, + config: &'static ProxyConfig, + backend: Arc>, + auth_info: AuthInfo, +} + +#[async_trait] +impl ManageConnection for ComputeConnectionManager { + type Connection = PooledCompute; + type Error = AuthError; + + async fn connect(&self) -> Result { + let conn = connect_auth::connect_to_compute_and_auth( + &self.ctx, + self.config, + &self.backend, + self.auth_info.clone(), + connect_compute::TlsNegotiation::Postgres, + ) + .await?; + + Ok(PooledCompute { + conn: Some(conn), + fresh: true, + }) + } + + async fn is_valid(&self, _conn: &mut Self::Connection) -> Result<(), Self::Error> { + Ok(()) + } + + fn has_broken(&self, conn: &mut Self::Connection) -> bool { + conn.conn.is_none() + } +} + +type KeyedPool = Pool; + +#[derive(Default)] +struct Inner { + pools: clashmap::ClashMap, + startup_params: clashmap::ClashMap, Box)>>>, +} + +pub(crate) struct TcpPoolCheckout { + key: TcpPoolKey, + reset_query: Arc, + pooled: Option>, +} + +#[derive(Clone)] +pub(crate) struct TcpPoolReacquire { + key: TcpPoolKey, + reset_query: Arc, +} + +impl TcpPoolReacquire { + pub(crate) fn key(&self) -> &TcpPoolKey { + &self.key + } + + pub(crate) fn reset_query(&self) -> Arc { + self.reset_query.clone() + } +} + +impl TcpPoolCheckout { + pub(crate) fn reset_query(&self) -> Arc { + self.reset_query.clone() + } + + pub(crate) fn reacquire_info(&self) -> TcpPoolReacquire { + TcpPoolReacquire { + key: self.key.clone(), + reset_query: self.reset_query.clone(), + } + } + + pub(crate) fn release(mut self, conn: ComputeConnection, reusable: bool) { + if !reusable { + self.pooled.take(); + return; + } + + if let Some(mut pooled) = self.pooled.take() { + pooled.conn = Some(conn); + pooled.fresh = false; + } + } +} + +#[derive(Debug, Error)] +pub(crate) enum AcquireError { + #[error("{0}")] + Connect(#[from] AuthError), + #[error("{0}")] + Startup(#[from] postgres_client::Error), +} + +impl UserFacingError for AcquireError { + fn to_string_client(&self) -> String { + match self { + AcquireError::Connect(e) => e.to_string_client(), + AcquireError::Startup(e) => e.to_string(), + } + } +} + +impl ReportableError for AcquireError { + fn get_error_kind(&self) -> ErrorKind { + match self { + AcquireError::Connect(e) => e.get_error_kind(), + AcquireError::Startup(_) => ErrorKind::Postgres, + } + } +} + +async fn drain_fresh_startup(conn: &mut ComputeConnection) -> Result<(), postgres_client::Error> { + loop { + let msg = conn + .stream + .try_next() + .await + .map_err(postgres_client::Error::io)?; + + match msg { + Some(Message::ParameterStatus(_)) + | Some(Message::BackendKeyData(_)) + | Some(Message::NoticeResponse(_)) => {} + Some(Message::ReadyForQuery(_)) => return Ok(()), + Some(Message::ErrorResponse(body)) => return Err(postgres_client::Error::db(body)), + Some(_) => return Err(postgres_client::Error::unexpected_message()), + None => return Err(postgres_client::Error::closed()), + } + } +} + +async fn write_simple_query(stream: &mut S, query: &str) -> std::io::Result<()> +where + S: AsyncWrite + Unpin, +{ + stream.write_u8(b'Q').await?; + stream.write_u32((query.len() + 5) as u32).await?; + stream.write_all(query.as_bytes()).await?; + stream.write_u8(0).await?; + stream.flush().await +} + +async fn drain_simple_query(stream: &mut S) -> Result +where + S: AsyncRead + Unpin, +{ + let mut buf = Vec::new(); + loop { + let (tag, body) = pqproto::read_message(stream, &mut buf, 65536) + .await + .map_err(postgres_client::Error::io)?; + + match tag { + b'Z' if body.len() == 1 => return Ok(body[0]), + b'Z' => return Err(postgres_client::Error::unexpected_message()), + b'C' | b'D' | b'I' | b'N' | b'S' | b'T' => {} + b'E' => return Err(postgres_client::Error::unexpected_message()), + _ => return Err(postgres_client::Error::unexpected_message()), + } + } +} + +async fn reset_raw_session( + stream: &mut MaybeRustlsStream, + reset_query: &str, +) -> Result<(), postgres_client::Error> { + write_simple_query(stream, reset_query) + .await + .map_err(postgres_client::Error::io)?; + let status = drain_simple_query(stream).await?; + if status == b'I' { + Ok(()) + } else { + Err(postgres_client::Error::unexpected_message()) + } +} + +pub(crate) async fn reset_session( + conn: ComputeConnection, + reset_query: &str, +) -> Result { + let ComputeConnection { + stream, + aux, + hostname, + ssl_mode, + socket_addr, + guage, + } = conn; + + let mut raw_stream = stream.into_framed().into_inner(); + reset_raw_session(&mut raw_stream, reset_query).await?; + + Ok(ComputeConnection { + stream: StartupStream::new(raw_stream), + aux, + hostname, + ssl_mode, + socket_addr, + guage, + }) +} + +pub(crate) struct TcpPoolManager { + inner: Arc, +} + +impl TcpPoolManager { + pub(crate) fn set_startup_params(&self, key: &TcpPoolKey, params: Vec<(Box, Box)>) { + self.inner + .startup_params + .insert(key.clone(), Arc::new(params)); + } + + pub(crate) fn get_startup_params( + &self, + key: &TcpPoolKey, + ) -> Option, Box)>>> { + self.inner.startup_params.get(key).map(|v| v.clone()) + } + + pub(crate) async fn gc_worker(&self) -> anyhow::Result { + let mut interval = tokio::time::interval(Duration::from_secs(30)); + loop { + interval.tick().await; + self.gc_one_shard(); + } + } + + fn gc_one_shard(&self) { + let shards = self.inner.pools.shards(); + if shards.is_empty() { + return; + } + + let shard_idx = rand::rng().random_range(0..shards.len()); + let mut shard = shards[shard_idx].write(); + shard.retain(|(_, pool)| { + let state = pool.state(); + let keep = state.connections > 0; + if !keep { + debug!("tcp pool: dropping empty keyed pool"); + } + keep + }); + } + + async fn get_or_create_pool( + &self, + key: TcpPoolKey, + mgr: ComputeConnectionManager, + cfg: TcpPoolConfig, + ) -> KeyedPool { + if let Some(pool) = self.inner.pools.get(&key).map(|p| p.clone()) { + return pool; + } + + let pool = Pool::builder() + .max_size(cfg.max_conns_per_key as u32) + .idle_timeout(Some(cfg.idle_timeout)) + .connection_timeout(Duration::from_secs(365 * 24 * 60 * 60)) + .build_unchecked(mgr); + + self.inner + .pools + .entry(key) + .or_insert_with(|| pool.clone()) + .clone() + } + + pub(crate) async fn prepare_reacquire( + &self, + config: &TcpPoolConfig, + key: TcpPoolKey, + ctx: RequestContext, + proxy_config: &'static ProxyConfig, + cplane: crate::control_plane::client::ControlPlaneClient, + user_info: ComputeUserInfo, + auth_info: AuthInfo, + ) -> TcpPoolReacquire { + let reset_query = Arc::::from(auth_info.tcp_pool_session_reset_query()); + let backend = Arc::new(Backend::ControlPlane(MaybeOwned::Owned(cplane), user_info)); + let mgr = ComputeConnectionManager { + ctx, + config: proxy_config, + backend, + auth_info, + }; + + self.get_or_create_pool(key.clone(), mgr, *config).await; + + TcpPoolReacquire { key, reset_query } + } + + pub(crate) async fn acquire_or_connect( + &self, + config: &TcpPoolConfig, + key: TcpPoolKey, + ctx: RequestContext, + proxy_config: &'static ProxyConfig, + cplane: crate::control_plane::client::ControlPlaneClient, + user_info: ComputeUserInfo, + auth_info: AuthInfo, + ) -> Result<(ComputeConnection, Option, bool), AcquireError> { + if !config.enabled { + let backend = Backend::ControlPlane(MaybeOwned::Owned(cplane), user_info); + let conn = connect_auth::connect_to_compute_and_auth( + &ctx, + proxy_config, + &backend, + auth_info, + connect_compute::TlsNegotiation::Postgres, + ) + .await?; + return Ok((conn, None, false)); + } + + let reset_query = Arc::::from(auth_info.tcp_pool_session_reset_query()); + let backend = Arc::new(Backend::ControlPlane(MaybeOwned::Owned(cplane), user_info)); + let mgr = ComputeConnectionManager { + ctx, + config: proxy_config, + backend, + auth_info, + }; + + let pool = self.get_or_create_pool(key.clone(), mgr, *config).await; + let mut pooled = pool.get_owned().await.map_err(|e| match e { + bb8::RunError::User(e) => AcquireError::Connect(e), + bb8::RunError::TimedOut => unreachable!("connection_timeout is effectively disabled"), + })?; + + let was_reused = !pooled.fresh; + let conn = pooled + .conn + .take() + .expect("pooled slot must hold a connection"); + Ok(( + conn, + Some(TcpPoolCheckout { + key, + reset_query, + pooled: Some(pooled), + }), + was_reused, + )) + } + + pub(crate) async fn reacquire( + &self, + key: TcpPoolKey, + reset_query: Arc, + ) -> Result<(ComputeConnection, TcpPoolCheckout), AcquireError> { + let pool = self + .inner + .pools + .get(&key) + .map(|p| p.clone()) + .expect("pool key must exist before re-acquire"); + + let mut pooled = pool.get_owned().await.map_err(|e| match e { + bb8::RunError::User(e) => AcquireError::Connect(e), + bb8::RunError::TimedOut => unreachable!("connection_timeout is effectively disabled"), + })?; + + let mut conn = pooled + .conn + .take() + .expect("pooled slot must hold a connection"); + if pooled.fresh { + drain_fresh_startup(&mut conn).await?; + pooled.fresh = false; + } + let conn = reset_session(conn, &reset_query).await?; + Ok(( + conn, + TcpPoolCheckout { + key, + reset_query, + pooled: Some(pooled), + }, + )) + } +} + +static MANAGER: Lazy = Lazy::new(|| TcpPoolManager { + inner: Arc::new(Inner::default()), +}); + +pub(crate) fn manager() -> &'static TcpPoolManager { + &MANAGER +} diff --git a/proxy/src/types.rs b/proxy/src/types.rs index 43b8dc5b29af..87e1bf5e7b72 100644 --- a/proxy/src/types.rs +++ b/proxy/src/types.rs @@ -67,6 +67,11 @@ const POOLER_SUFFIX: &str = "-pooler"; pub(crate) const LOCAL_PROXY_SUFFIX: &str = "-local-proxy"; impl EndpointId { + #[must_use] + pub(crate) fn is_pooler(&self) -> bool { + self.as_ref().ends_with(POOLER_SUFFIX) + } + #[must_use] fn normalize_str(&self) -> &str { if let Some(stripped) = self.as_ref().strip_suffix(POOLER_SUFFIX) { diff --git a/scripts/run_tcp_pool_benchmark_matrix.sh b/scripts/run_tcp_pool_benchmark_matrix.sh new file mode 100755 index 000000000000..bc46898c1c96 --- /dev/null +++ b/scripts/run_tcp_pool_benchmark_matrix.sh @@ -0,0 +1,186 @@ +#!/usr/bin/env bash +set -euo pipefail + +# External orchestration wrapper for scripts/tcp_pool_benchmark.py. +# +# The Python harness measures one or more pgbench cells. This wrapper owns +# environment reset between cells, so it can run from a different host than the +# proxy/PgCat processes. Configure RESET_*_CMD with local, ssh, or SSM commands. + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +export PGSSLROOTCERT="${PGSSLROOTCERT:-./server.crt}" +export TCP_POOL_BENCH_DOCKER_CMD="${TCP_POOL_BENCH_DOCKER_CMD:-sudo -n docker}" + +PROXY_POOL_URL="${PROXY_POOL_URL:-postgresql://postgres:password@comp1-pooler.local.neon.build:4432/postgres?sslmode=verify-full}" +PROXY_PGCAT_URL="${PROXY_PGCAT_URL:-postgresql://postgres:password@pgcat.local.neon.build:4432/postgres?sslmode=verify-full}" +PROXY_PASSTHROUGH_URL="${PROXY_PASSTHROUGH_URL:-postgresql://postgres:password@comp1.local.neon.build:4432/postgres?sslmode=verify-full}" +BACKEND_URL="${BACKEND_URL:-$PROXY_PASSTHROUGH_URL}" + +RUN_ID="${RUN_ID:-$(date +%Y%m%d-%H%M%S)}" +OUT_DIR="${OUT_DIR:-benchmark-results/tcp-pool-sanity-$RUN_ID}" +TARGETS="${TARGETS:-proxy_pool proxy_pgcat proxy_passthrough}" +WORKLOADS="${WORKLOADS:-readonly_steady readonly_connect}" +STEADY_CONCURRENCIES="${STEADY_CONCURRENCIES:-10 50 100 250}" +CONNECT_CONCURRENCIES="${CONNECT_CONCURRENCIES:-1 5 10 25 50 100}" +STEADY_DURATION="${STEADY_DURATION:-60}" +CONNECT_DURATION="${CONNECT_DURATION:-30}" +REPS="${REPS:-3}" +IDLE_WAITS="${IDLE_WAITS:-5,15,30,45,60,75}" +MAX_JOBS="${MAX_JOBS:-64}" +SAMPLE_INTERVAL="${SAMPLE_INTERVAL:-1.0}" +RESET_WAIT_SECONDS="${RESET_WAIT_SECONDS:-5}" +CLEAN_IDLE_PGBENCH="${CLEAN_IDLE_PGBENCH:-1}" +PRINT_BACKEND_BASELINE="${PRINT_BACKEND_BASELINE:-1}" +WAIT_TARGET_SELECT="${WAIT_TARGET_SELECT:-0}" + +# Optional reset commands. Examples: +# RESET_PROXY_POOL_CMD="ssh ec2-proxy sudo systemctl restart neon-proxy" +# RESET_PROXY_PGCAT_CMD="ssh ec2-proxy sudo systemctl restart neon-proxy && ssh ec2-pgcat sudo systemctl restart pgcat" +# RESET_PROXY_PASSTHROUGH_CMD="ssh ec2-proxy sudo systemctl restart neon-proxy-passthrough" +RESET_PROXY_POOL_CMD="${RESET_PROXY_POOL_CMD:-}" +RESET_PROXY_PGCAT_CMD="${RESET_PROXY_PGCAT_CMD:-}" +RESET_PROXY_PASSTHROUGH_CMD="${RESET_PROXY_PASSTHROUGH_CMD:-}" + +# Extra args are parsed as shell words. By default, sample this local benchmark +# setup: proxy and PgCat processes plus the Dockerized compute node. +EXTRA_BENCH_ARGS="${EXTRA_BENCH_ARGS:---resource 'proxy=pgrep:target/debug/proxy --auth-backend cplane-v1' --resource 'pgcat=pgrep:/home/charles/db_final_project/pgcat/target/debug/pgcat' --resource compute=docker:compute-node-1}" + +target_url() { + case "$1" in + proxy_pool) printf '%s\n' "$PROXY_POOL_URL" ;; + proxy_pgcat) printf '%s\n' "$PROXY_PGCAT_URL" ;; + proxy_passthrough) printf '%s\n' "$PROXY_PASSTHROUGH_URL" ;; + *) echo "unknown target: $1" >&2; return 1 ;; + esac +} + +target_reset_cmd() { + case "$1" in + proxy_pool) printf '%s\n' "$RESET_PROXY_POOL_CMD" ;; + proxy_pgcat) printf '%s\n' "$RESET_PROXY_PGCAT_CMD" ;; + proxy_passthrough) printf '%s\n' "$RESET_PROXY_PASSTHROUGH_CMD" ;; + *) echo "unknown target: $1" >&2; return 1 ;; + esac +} + +workload_concurrencies() { + case "$1" in + readonly_connect) printf '%s\n' "$CONNECT_CONCURRENCIES" ;; + *) printf '%s\n' "$STEADY_CONCURRENCIES" ;; + esac +} + +workload_duration() { + case "$1" in + readonly_connect) printf '%s\n' "$CONNECT_DURATION" ;; + *) printf '%s\n' "$STEADY_DURATION" ;; + esac +} + +run_reset() { + local target="$1" + local cmd + cmd="$(target_reset_cmd "$target")" + if [[ -n "$cmd" ]]; then + echo "$(date -Is) reset target=$target" + bash -lc "$cmd" + fi + if [[ "$RESET_WAIT_SECONDS" != "0" ]]; then + sleep "$RESET_WAIT_SECONDS" + fi +} + +wait_target_select() { + local url="$1" + if [[ "$WAIT_TARGET_SELECT" != "1" ]]; then + return + fi + local attempt + for attempt in {1..60}; do + if psql "$url" -X -q -At -c "select 1" >/dev/null 2>&1; then + return + fi + sleep 1 + done + echo "target did not pass select 1 readiness: $url" >&2 + return 1 +} + +clean_backend_idle_pgbench() { + if [[ "$CLEAN_IDLE_PGBENCH" != "1" ]]; then + return + fi + psql "$BACKEND_URL" -X -q -At -c " + select count(pg_terminate_backend(pid)) + from pg_stat_activity + where datname = current_database() + and pid <> pg_backend_pid() + and application_name = 'pgbench' + and state = 'idle'; + " >/dev/null +} + +print_backend_baseline() { + if [[ "$PRINT_BACKEND_BASELINE" != "1" ]]; then + return + fi + echo "$(date -Is) backend baseline" + psql "$BACKEND_URL" -X -q -P pager=off -c " + select state, application_name, client_addr, count(*) + from pg_stat_activity + where datname = current_database() + and pid <> pg_backend_pid() + group by 1,2,3 + order by count(*) desc, 1,2,3; + " +} + +run_cell() { + local target="$1" + local workload="$2" + local concurrency="$3" + local rep="$4" + local url duration + url="$(target_url "$target")" + duration="$(workload_duration "$workload")" + + run_reset "$target" + wait_target_select "$url" + clean_backend_idle_pgbench + print_backend_baseline + + echo "$(date -Is) benchmark target=$target workload=$workload concurrency=$concurrency rep=$rep" + + local -a extra_args=() + if [[ -n "$EXTRA_BENCH_ARGS" ]]; then + eval "extra_args=( $EXTRA_BENCH_ARGS )" + fi + + python3 scripts/tcp_pool_benchmark.py \ + --target "$target=$url" \ + --backend-url "$BACKEND_URL" \ + --workload "$workload" \ + --concurrency "$concurrency" \ + --duration "$duration" \ + --reps 1 \ + --rep-start "$rep" \ + --idle-waits "$IDLE_WAITS" \ + --max-jobs "$MAX_JOBS" \ + --sample-interval "$SAMPLE_INTERVAL" \ + --out-dir "$OUT_DIR" \ + "${extra_args[@]}" +} + +for target in $TARGETS; do + for workload in $WORKLOADS; do + for concurrency in $(workload_concurrencies "$workload"); do + for rep in $(seq 1 "$REPS"); do + run_cell "$target" "$workload" "$concurrency" "$rep" + done + done + done +done + +echo "wrote $OUT_DIR/summary.csv" diff --git a/scripts/setup_env.sh b/scripts/setup_env.sh new file mode 100755 index 000000000000..0784b0617fff --- /dev/null +++ b/scripts/setup_env.sh @@ -0,0 +1,8 @@ +export PGSSLROOTCERT=./server.crt +export PROXY_POOL_URL='postgresql://postgres:password@comp1-pooler.local.neon.build:4432/postgres?sslmode=verify-full' +export PROXY_PGCAT_URL='postgresql://postgres:password@pgcat.local.neon.build:4432/postgres?sslmode=verify-full' +export PROXY_PASSTHROUGH_URL='postgresql://postgres:password@comp1.local.neon.build:4432/postgres?sslmode=verify-full' +export BACKEND_URL="$PROXY_PASSTHROUGH_URL" +export RESET_PROXY_POOL_CMD="pkill -x proxy || true; sleep 1; NEON_INTERNAL_CA_FILE=/tmp/compute-ca.crt RUST_LOG=proxy LOGFMT=text cargo run -p proxy --bin proxy --features testing -- --auth-backend cplane-v1 --auth-endpoint http://127.0.0.1:3010 --tcp-pool-enabled true --tcp-pool-mode transaction --tcp-pool-max-conns-per-key 50 --tcp-pool-idle-timeout 10s --redis-auth-type plain --redis-plain 'redis://127.0.0.1:6379' -c server.crt -k server.key --endpoint-rps-limit 5000@1s --endpoint-rps-limit 5000@60s --endpoint-rps-limit 5000@600s --wake-compute-limit 5000@1s --wake-compute-limit 5000@60s --wake-compute-limit 5000@600s >/tmp/neon-proxy-bench.log 2>&1 &" +export RESET_PROXY_PASSTHROUGH_CMD="$RESET_PROXY_POOL_CMD" +export RESET_PROXY_PGCAT_CMD="pkill -x pgcat || true; sleep 1; nohup /home/charles/db_final_project/pgcat/target/debug/pgcat /home/charles/db_final_project/pgcat/pgcat.local.toml >/tmp/pgcat-bench.log 2>&1 &" \ No newline at end of file diff --git a/scripts/tcp_pool_benchmark.py b/scripts/tcp_pool_benchmark.py new file mode 100755 index 000000000000..75ec988349fc --- /dev/null +++ b/scripts/tcp_pool_benchmark.py @@ -0,0 +1,794 @@ +#!/usr/bin/env python3 +"""Benchmark proxy TCP pooling configurations with pgbench. + +The harness is intentionally deployment-agnostic: pass each PostgreSQL target +URL explicitly, and optionally pass local PIDs or Docker containers to sample +CPU/memory while pgbench runs. +""" + +from __future__ import annotations + +import argparse +import csv +import glob +import math +import os +import re +import shlex +import shutil +import statistics +import subprocess +import sys +import threading +import time +from collections import defaultdict +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path + + +WORKLOADS = { + "readonly_steady": ["-S"], + "readonly_connect": ["-S", "-C"], + "tpcb_steady": [], +} + +RUN_FIELDS = [ + "target", + "workload", + "concurrency", + "rep", + "duration_s", + "exit_code", + "transactions", + "failed_transactions", + "tps", + "client_handshakes_per_sec", + "latency_avg_ms", + "p50_ms", + "p95_ms", + "p99_ms", + "max_ms", + "initial_or_avg_connection_ms", + "backend_before_total", + "backend_during_max_total", + "backend_after_final_total", + "pgbench_stdout", + "pgbench_stderr", +] + +BACKEND_FIELDS = [ + "target", + "workload", + "concurrency", + "rep", + "phase", + "ts", + "total", + "active", + "idle", + "idle_in_transaction", + "other", + "error", +] + +RESOURCE_FIELDS = [ + "target", + "workload", + "concurrency", + "rep", + "resource", + "kind", + "ts", + "cpu_pct", + "rss_mib", + "vsz_mib", + "mem_mib", + "raw", + "error", +] + + +@dataclass(frozen=True) +class Target: + name: str + url: str + + +@dataclass(frozen=True) +class Resource: + name: str + kind: str + ident: str + + +def utc_now() -> str: + return datetime.now(timezone.utc).isoformat() + + +def parse_csv_ints(raw: str) -> list[int]: + if raw.strip() == "": + return [] + values = [] + for part in raw.split(","): + part = part.strip() + if part: + values.append(int(part)) + if not values: + raise argparse.ArgumentTypeError("expected at least one integer") + return values + + +def parse_target(raw: str) -> Target: + name, sep, url = raw.partition("=") + if not sep or not name or not url: + raise argparse.ArgumentTypeError("--target must be NAME=POSTGRES_URL") + return Target(name=name, url=url) + + +def parse_resource(raw: str) -> Resource: + name, sep, spec = raw.partition("=") + kind, sep2, ident = spec.partition(":") + if not sep or not sep2 or not name or not kind or not ident: + raise argparse.ArgumentTypeError( + "--resource must be NAME=pid:PID, NAME=pgrep:PATTERN, or NAME=docker:CONTAINER" + ) + if kind not in {"pid", "pgrep", "docker"}: + raise argparse.ArgumentTypeError("resource kind must be pid, pgrep, or docker") + return Resource(name=name, kind=kind, ident=ident) + + +def open_csv(path: Path, fields: list[str]): + exists = path.exists() and path.stat().st_size > 0 + f = path.open("a", newline="") + writer = csv.DictWriter(f, fieldnames=fields) + if not exists: + writer.writeheader() + f.flush() + return f, writer + + +def run_cmd(args: list[str], env: dict[str, str] | None = None) -> subprocess.CompletedProcess[str]: + return subprocess.run(args, text=True, capture_output=True, env=env) + + +def pgrep(pattern: str) -> list[int]: + proc = run_cmd(["pgrep", "-f", pattern]) + if proc.returncode not in {0, 1}: + return [] + this_pid = os.getpid() + parent_pid = os.getppid() + pids = [] + for line in proc.stdout.splitlines(): + try: + pid = int(line.strip()) + except ValueError: + continue + if pid not in {this_pid, parent_pid}: + pids.append(pid) + return pids + + +def backend_counts(backend_url: str) -> dict[str, str]: + sql = """ + select coalesce(state, 'other') as state, count(*) + from pg_stat_activity + where datname = current_database() + and pid <> pg_backend_pid() + group by 1 + """ + env = os.environ.copy() + env["PGAPPNAME"] = "tcp_pool_bench_backend_sampler" + proc = run_cmd(["psql", "-X", "-q", "-At", "-F", ",", "-c", sql, backend_url], env=env) + row = { + "total": "0", + "active": "0", + "idle": "0", + "idle_in_transaction": "0", + "other": "0", + "error": "", + } + if proc.returncode != 0: + row["error"] = (proc.stderr or proc.stdout).strip().replace("\n", " ")[:500] + return row + + total = 0 + other = 0 + for line in proc.stdout.splitlines(): + state, _, count_raw = line.partition(",") + try: + count = int(count_raw) + except ValueError: + continue + total += count + key = state.strip().replace(" ", "_") + if key in {"active", "idle", "idle_in_transaction"}: + row[key] = str(int(row[key]) + count) + else: + other += count + row["total"] = str(total) + row["other"] = str(other) + return row + + +def parse_mib(raw: str) -> float | None: + match = re.match(r"\s*([0-9.]+)\s*([KMGT]?i?B)", raw) + if not match: + return None + value = float(match.group(1)) + unit = match.group(2) + factor = { + "B": 1 / 1024 / 1024, + "KB": 1 / 1024, + "KiB": 1 / 1024, + "MB": 1, + "MiB": 1, + "GB": 1024, + "GiB": 1024, + "TB": 1024 * 1024, + "TiB": 1024 * 1024, + }.get(unit) + return value * factor if factor is not None else None + + +def resource_sample(resource: Resource) -> dict[str, str]: + row = { + "cpu_pct": "", + "rss_mib": "", + "vsz_mib": "", + "mem_mib": "", + "raw": "", + "error": "", + } + if resource.kind in {"pid", "pgrep"}: + ident = resource.ident + if resource.kind == "pgrep": + pids = pgrep(resource.ident) + if not pids: + row["error"] = f"no process matched {resource.ident!r}" + return row + ident = str(max(pids)) + proc = run_cmd(["ps", "-p", ident, "-o", "pcpu=", "-o", "rss=", "-o", "vsz="]) + if proc.returncode != 0 or not proc.stdout.strip(): + row["error"] = (proc.stderr or proc.stdout).strip().replace("\n", " ")[:500] + return row + parts = proc.stdout.split() + row["raw"] = " ".join(parts) + if len(parts) >= 3: + row["cpu_pct"] = parts[0] + row["rss_mib"] = f"{int(parts[1]) / 1024:.3f}" + row["vsz_mib"] = f"{int(parts[2]) / 1024:.3f}" + return row + + docker_cmd = shlex.split(os.environ.get("TCP_POOL_BENCH_DOCKER_CMD", "docker")) + proc = run_cmd([ + *docker_cmd, + "stats", + "--no-stream", + "--format", + "{{.CPUPerc}},{{.MemUsage}}", + resource.ident, + ]) + if proc.returncode != 0: + row["error"] = (proc.stderr or proc.stdout).strip().replace("\n", " ")[:500] + return row + raw = proc.stdout.strip() + row["raw"] = raw + cpu_raw, _, mem_raw = raw.partition(",") + row["cpu_pct"] = cpu_raw.strip().rstrip("%") + mem_mib = parse_mib(mem_raw.split("/")[0]) + if mem_mib is not None: + row["mem_mib"] = f"{mem_mib:.3f}" + return row + + +class Sampler: + def __init__( + self, + *, + target: str, + workload: str, + concurrency: int, + rep: int, + backend_url: str | None, + resources: list[Resource], + interval_s: float, + backend_writer: csv.DictWriter, + resource_writer: csv.DictWriter, + files: list, + ) -> None: + self.target = target + self.workload = workload + self.concurrency = concurrency + self.rep = rep + self.backend_url = backend_url + self.resources = resources + self.interval_s = interval_s + self.backend_writer = backend_writer + self.resource_writer = resource_writer + self.files = files + self.stop_event = threading.Event() + self.thread: threading.Thread | None = None + self.backend_samples: list[dict[str, str]] = [] + + def write_backend(self, phase: str) -> dict[str, str] | None: + if not self.backend_url: + return None + counts = backend_counts(self.backend_url) + row = { + "target": self.target, + "workload": self.workload, + "concurrency": self.concurrency, + "rep": self.rep, + "phase": phase, + "ts": utc_now(), + **counts, + } + self.backend_writer.writerow(row) + self.backend_samples.append(row) + for f in self.files: + f.flush() + return row + + def write_resources(self) -> None: + for resource in self.resources: + sample = resource_sample(resource) + row = { + "target": self.target, + "workload": self.workload, + "concurrency": self.concurrency, + "rep": self.rep, + "resource": resource.name, + "kind": resource.kind, + "ts": utc_now(), + **sample, + } + self.resource_writer.writerow(row) + for f in self.files: + f.flush() + + def start(self) -> None: + self.write_backend("before") + + def loop() -> None: + while not self.stop_event.wait(self.interval_s): + self.write_backend("during") + self.write_resources() + + self.thread = threading.Thread(target=loop, daemon=True) + self.thread.start() + + def stop(self) -> None: + self.stop_event.set() + if self.thread: + self.thread.join() + + def after_idle(self, waits: list[int]) -> None: + elapsed = 0 + for wait_s in waits: + sleep_s = max(0, wait_s - elapsed) + if sleep_s: + time.sleep(sleep_s) + elapsed = wait_s + self.write_backend(f"after_{wait_s}s") + self.write_resources() + + def backend_before_total(self) -> str: + for sample in self.backend_samples: + if sample["phase"] == "before": + return sample["total"] + return "" + + def backend_during_max_total(self) -> str: + values = [int(s["total"]) for s in self.backend_samples if s["phase"] == "during" and s["total"]] + return str(max(values)) if values else "" + + def backend_after_final_total(self) -> str: + after = [s for s in self.backend_samples if s["phase"].startswith("after_")] + return after[-1]["total"] if after else "" + + +def parse_pgbench_stdout(stdout: str) -> dict[str, str]: + patterns = { + "transactions": r"number of transactions actually processed:\s+([0-9]+)", + "failed_transactions": r"number of failed transactions:\s+([0-9]+)", + "latency_avg_ms": r"latency average =\s+([0-9.]+)\s+ms", + "initial_or_avg_connection_ms": ( + r"(?:initial connection time|average connection time) =\s+([0-9.]+)\s+ms" + ), + "tps": r"tps =\s+([0-9.]+)", + } + parsed = {key: "" for key in patterns} + for key, pattern in patterns.items(): + match = re.search(pattern, stdout) + if match: + parsed[key] = match.group(1) + return parsed + + +def percentile(values: list[int], pct: float) -> float: + if not values: + return math.nan + values.sort() + idx = max(0, min(len(values) - 1, math.ceil((pct / 100) * len(values)) - 1)) + return values[idx] / 1000 + + +def parse_latency_logs(prefix: Path) -> dict[str, str]: + lat_us: list[int] = [] + for path in glob.glob(str(prefix) + ".*"): + with open(path) as f: + for line in f: + parts = line.split() + if len(parts) < 3: + continue + try: + lat_us.append(int(parts[2])) + except ValueError: + continue + if not lat_us: + return {"p50_ms": "", "p95_ms": "", "p99_ms": "", "max_ms": ""} + return { + "p50_ms": f"{percentile(lat_us, 50):.3f}", + "p95_ms": f"{percentile(lat_us, 95):.3f}", + "p99_ms": f"{percentile(lat_us, 99):.3f}", + "max_ms": f"{max(lat_us) / 1000:.3f}", + } + + +def remove_latency_logs(prefix: Path) -> None: + for path in glob.glob(str(prefix) + ".*"): + try: + os.remove(path) + except OSError: + pass + + +def run_pgbench( + *, + target: Target, + workload: str, + concurrency: int, + rep: int, + duration_s: int, + max_jobs: int, + out_dir: Path, + keep_logs: bool, +) -> tuple[int, dict[str, str], str, str]: + run_name = f"{target.name}_{workload}_c{concurrency}_r{rep}" + prefix = out_dir / "pgbench_logs" / run_name + prefix.parent.mkdir(parents=True, exist_ok=True) + remove_latency_logs(prefix) + + flags = WORKLOADS[workload] + args = [ + "pgbench", + target.url, + "-n", + *flags, + "-c", + str(concurrency), + "-j", + str(max(1, min(concurrency, max_jobs))), + "-T", + str(duration_s), + "-l", + "--log-prefix", + str(prefix), + ] + proc = run_cmd(args) + parsed = parse_pgbench_stdout(proc.stdout) + parsed.update(parse_latency_logs(prefix)) + if workload.endswith("_connect"): + parsed["client_handshakes_per_sec"] = parsed.get("tps", "") + else: + parsed["client_handshakes_per_sec"] = "" + + stdout_path = out_dir / "pgbench_stdout" / f"{run_name}.out" + stderr_path = out_dir / "pgbench_stdout" / f"{run_name}.err" + stdout_path.parent.mkdir(parents=True, exist_ok=True) + stdout_path.write_text(proc.stdout) + stderr_path.write_text(proc.stderr) + + if not keep_logs: + remove_latency_logs(prefix) + + return proc.returncode, parsed, str(stdout_path), str(stderr_path) + + +def to_float(raw: str | None) -> float | None: + if raw is None or raw == "": + return None + try: + return float(raw) + except ValueError: + return None + + +def mean_or_blank(values: list[float]) -> str: + return f"{statistics.mean(values):.3f}" if values else "" + + +def pstdev_or_blank(values: list[float]) -> str: + return f"{statistics.pstdev(values):.3f}" if len(values) > 1 else ("0.000" if values else "") + + +def write_run_summary(out_dir: Path) -> None: + runs_path = out_dir / "runs.csv" + if not runs_path.exists(): + return + + backend_max: dict[tuple[str, str, int, int], int] = defaultdict(int) + backend_final: dict[tuple[str, str, int, int], int] = {} + backend_path = out_dir / "backend_connections.csv" + if backend_path.exists(): + with backend_path.open() as f: + for row in csv.DictReader(f): + key = (row["target"], row["workload"], int(row["concurrency"]), int(row["rep"])) + total = int(row["total"] or 0) + if row["phase"] == "during": + backend_max[key] = max(backend_max[key], total) + if row["phase"].startswith("after_"): + backend_final[key] = total + + grouped: dict[tuple[str, str, int], list[dict[str, str]]] = defaultdict(list) + with runs_path.open() as f: + for row in csv.DictReader(f): + grouped[(row["target"], row["workload"], int(row["concurrency"]))].append(row) + + fields = [ + "target", + "workload", + "concurrency", + "runs", + "failed_runs", + "tps_mean", + "tps_std", + "client_handshakes_per_sec_mean", + "latency_avg_ms_mean", + "p50_ms_mean", + "p95_ms_mean", + "p99_ms_mean", + "backend_during_max_total_mean", + "backend_after_final_total_mean", + ] + with (out_dir / "summary.csv").open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fields) + writer.writeheader() + for (target, workload, concurrency), rows in sorted(grouped.items()): + key_prefix = (target, workload, concurrency) + tps = [x for x in (to_float(r["tps"]) for r in rows) if x is not None] + handshakes = [ + x for x in (to_float(r["client_handshakes_per_sec"]) for r in rows) if x is not None + ] + lat_avg = [x for x in (to_float(r["latency_avg_ms"]) for r in rows) if x is not None] + p50 = [x for x in (to_float(r["p50_ms"]) for r in rows) if x is not None] + p95 = [x for x in (to_float(r["p95_ms"]) for r in rows) if x is not None] + p99 = [x for x in (to_float(r["p99_ms"]) for r in rows) if x is not None] + bmax = [ + float(backend_max[(target, workload, concurrency, int(r["rep"]))]) + for r in rows + if (target, workload, concurrency, int(r["rep"])) in backend_max + ] + bfinal = [ + float(backend_final[(target, workload, concurrency, int(r["rep"]))]) + for r in rows + if (target, workload, concurrency, int(r["rep"])) in backend_final + ] + writer.writerow({ + "target": target, + "workload": workload, + "concurrency": concurrency, + "runs": len(rows), + "failed_runs": sum(1 for r in rows if r["exit_code"] != "0"), + "tps_mean": mean_or_blank(tps), + "tps_std": pstdev_or_blank(tps), + "client_handshakes_per_sec_mean": mean_or_blank(handshakes), + "latency_avg_ms_mean": mean_or_blank(lat_avg), + "p50_ms_mean": mean_or_blank(p50), + "p95_ms_mean": mean_or_blank(p95), + "p99_ms_mean": mean_or_blank(p99), + "backend_during_max_total_mean": mean_or_blank(bmax), + "backend_after_final_total_mean": mean_or_blank(bfinal), + }) + + +def write_resource_summary(out_dir: Path) -> None: + path = out_dir / "resources.csv" + if not path.exists(): + return + grouped: dict[tuple[str, str, int, str, str, int], list[dict[str, str]]] = defaultdict(list) + with path.open() as f: + for row in csv.DictReader(f): + key = ( + row["target"], + row["workload"], + int(row["concurrency"]), + row["resource"], + row["kind"], + int(row["rep"]), + ) + grouped[key].append(row) + + per_rep = [] + for key, rows in grouped.items(): + cpus = [x for x in (to_float(r["cpu_pct"]) for r in rows) if x is not None] + rss = [x for x in (to_float(r["rss_mib"]) for r in rows) if x is not None] + mem = [x for x in (to_float(r["mem_mib"]) for r in rows) if x is not None] + per_rep.append((*key, max(cpus) if cpus else None, max(rss) if rss else None, max(mem) if mem else None)) + + final: dict[tuple[str, str, int, str, str], list[tuple[float | None, float | None, float | None]]] = defaultdict(list) + for target, workload, concurrency, resource, kind, _rep, cpu, rss, mem in per_rep: + final[(target, workload, concurrency, resource, kind)].append((cpu, rss, mem)) + + fields = [ + "target", + "workload", + "concurrency", + "resource", + "kind", + "cpu_pct_max_mean", + "rss_mib_max_mean", + "mem_mib_max_mean", + ] + with (out_dir / "resource_summary.csv").open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fields) + writer.writeheader() + for (target, workload, concurrency, resource, kind), rows in sorted(final.items()): + cpus = [x[0] for x in rows if x[0] is not None] + rss = [x[1] for x in rows if x[1] is not None] + mem = [x[2] for x in rows if x[2] is not None] + writer.writerow({ + "target": target, + "workload": workload, + "concurrency": concurrency, + "resource": resource, + "kind": kind, + "cpu_pct_max_mean": mean_or_blank(cpus), + "rss_mib_max_mean": mean_or_blank(rss), + "mem_mib_max_mean": mean_or_blank(mem), + }) + + +def validate_tools(args: argparse.Namespace) -> None: + missing = [tool for tool in ["pgbench"] if shutil.which(tool) is None] + if args.backend_url and shutil.which("psql") is None: + missing.append("psql") + if any(r.kind == "docker" for r in args.resource) and shutil.which("docker") is None: + missing.append("docker") + if missing: + sys.exit(f"missing required tool(s): {', '.join(sorted(set(missing)))}") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--target", + action="append", + required=True, + type=parse_target, + help="Target to benchmark, as NAME=POSTGRES_URL. Repeat for proxy_pool, proxy_pgcat, proxy_direct.", + ) + parser.add_argument( + "--backend-url", + help="Direct compute PostgreSQL URL used for pg_stat_activity backend connection sampling.", + ) + parser.add_argument("--out-dir", default="benchmark-results/tcp-pool", type=Path) + parser.add_argument("--duration", default=60, type=int, help="Seconds per pgbench run.") + parser.add_argument("--concurrency", default=[10, 50, 100, 250], type=parse_csv_ints) + parser.add_argument("--reps", default=3, type=int) + parser.add_argument( + "--rep-start", + default=1, + type=int, + help="First repetition number to write. Useful when an external wrapper runs one rep per invocation.", + ) + parser.add_argument( + "--workload", + action="append", + choices=sorted(WORKLOADS), + help="Workload to run. Defaults to readonly_steady and readonly_connect.", + ) + parser.add_argument("--max-jobs", default=64, type=int, help="Cap pgbench -j.") + parser.add_argument("--sample-interval", default=1.0, type=float) + parser.add_argument( + "--idle-waits", + default=[10, 30, 60], + type=parse_csv_ints, + help="Seconds after each run to sample backend idle cleanup.", + ) + parser.add_argument( + "--resource", + action="append", + default=[], + type=parse_resource, + help=( + "CPU/memory resource to sample, NAME=pid:PID, NAME=pgrep:PATTERN, " + "or NAME=docker:CONTAINER. Repeatable." + ), + ) + parser.add_argument( + "--keep-pgbench-logs", + action="store_true", + help="Keep raw pgbench latency logs. By default they are parsed then deleted.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + workloads = args.workload or ["readonly_steady", "readonly_connect"] + args.out_dir.mkdir(parents=True, exist_ok=True) + validate_tools(args) + + runs_f, runs_writer = open_csv(args.out_dir / "runs.csv", RUN_FIELDS) + backend_f, backend_writer = open_csv(args.out_dir / "backend_connections.csv", BACKEND_FIELDS) + resource_f, resource_writer = open_csv(args.out_dir / "resources.csv", RESOURCE_FIELDS) + files = [runs_f, backend_f, resource_f] + + try: + for target in args.target: + for workload in workloads: + for concurrency in args.concurrency: + for rep in range(args.rep_start, args.rep_start + args.reps): + print( + f"{utc_now()} target={target.name} workload={workload} " + f"c={concurrency} rep={rep}", + flush=True, + ) + sampler = Sampler( + target=target.name, + workload=workload, + concurrency=concurrency, + rep=rep, + backend_url=args.backend_url, + resources=args.resource, + interval_s=args.sample_interval, + backend_writer=backend_writer, + resource_writer=resource_writer, + files=files, + ) + sampler.start() + exit_code, parsed, stdout_path, stderr_path = run_pgbench( + target=target, + workload=workload, + concurrency=concurrency, + rep=rep, + duration_s=args.duration, + max_jobs=args.max_jobs, + out_dir=args.out_dir, + keep_logs=args.keep_pgbench_logs, + ) + sampler.stop() + sampler.after_idle(args.idle_waits) + + runs_writer.writerow({ + "target": target.name, + "workload": workload, + "concurrency": concurrency, + "rep": rep, + "duration_s": args.duration, + "exit_code": exit_code, + **parsed, + "backend_before_total": sampler.backend_before_total(), + "backend_during_max_total": sampler.backend_during_max_total(), + "backend_after_final_total": sampler.backend_after_final_total(), + "pgbench_stdout": stdout_path, + "pgbench_stderr": stderr_path, + }) + for f in files: + f.flush() + + write_run_summary(args.out_dir) + write_resource_summary(args.out_dir) + finally: + for f in files: + f.close() + + print(f"wrote {args.out_dir / 'summary.csv'}") + if args.resource: + print(f"wrote {args.out_dir / 'resource_summary.csv'}") + + +if __name__ == "__main__": + main()