diff --git a/.github/workflows/deploy-agent-api.yaml b/.github/workflows/deploy-agent-api.yaml index ad37582e977..703467c4d8b 100644 --- a/.github/workflows/deploy-agent-api.yaml +++ b/.github/workflows/deploy-agent-api.yaml @@ -50,6 +50,7 @@ jobs: PGPASSWORD=POSTGRES_PASSWORD:latest CONTROL_PLANE_DB_CA_CERT=CONTROL_PLANE_DB_CA_CERT:latest CONTROL_PLANE_JWT_SECRET=CONTROL_PLANE_JWT_SECRET:latest + STRIPE_API_KEY=STRIPE_API_KEY:latest env_vars_update_strategy: overwrite secrets_update_strategy: overwrite diff --git a/.github/workflows/platform-test.yaml b/.github/workflows/platform-test.yaml index 7c409dde729..87b322bb7f3 100644 --- a/.github/workflows/platform-test.yaml +++ b/.github/workflows/platform-test.yaml @@ -125,6 +125,13 @@ jobs: - run: mise run build:flowctl-go - run: mise run ci:nextest-build - run: mise run ci:nextest-run + + - name: Stripe integration test + env: + STRIPE_API_KEY: ${{ secrets.STRIPE_TESTMODE_API_KEY }} + if: env.STRIPE_API_KEY != '' + run: cargo nextest run --frozen --run-ignored ignored-only -E 'test(graphql_billing_live_stripe)' + - run: mise run ci:doctest - run: mise run ci:gotest - run: mise run ci:catalog-test diff --git a/.sqlx/query-9b05fe20d2fd53e8c5f4aae509794f2db9fe3414770e2820d0481e0936588d69.json b/.sqlx/query-9b05fe20d2fd53e8c5f4aae509794f2db9fe3414770e2820d0481e0936588d69.json new file mode 100644 index 00000000000..f50a5f84698 --- /dev/null +++ b/.sqlx/query-9b05fe20d2fd53e8c5f4aae509794f2db9fe3414770e2820d0481e0936588d69.json @@ -0,0 +1,67 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT\n date_start as \"date_start!\",\n date_end as \"date_end!\",\n billed_prefix as \"billed_prefix!\",\n line_items as \"line_items!: sqlx::types::Json\",\n subtotal as \"subtotal!\",\n extra as \"extra!: sqlx::types::Json\",\n invoice_type as \"invoice_type!: InvoiceType\"\n FROM invoices_ext\n WHERE billed_prefix = $1\n AND ($2::date IS NULL OR date_start > $2)\n AND ($3::date IS NULL OR date_start < $3)\n AND ($4::date IS NULL OR date_end > $4)\n AND ($5::date IS NULL OR date_end < $5)\n AND ($6::text IS NULL OR invoice_type::text = $6)\n AND (\n $7::date IS NULL\n OR date_end > $7\n OR (date_end = $7 AND date_start > $8)\n OR (date_end = $7 AND date_start = $8 AND invoice_type::text < $9)\n )\n ORDER BY date_end ASC, date_start ASC, invoice_type DESC\n LIMIT $10\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "date_start!", + "type_info": "Date" + }, + { + "ordinal": 1, + "name": "date_end!", + "type_info": "Date" + }, + { + "ordinal": 2, + "name": "billed_prefix!", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "line_items!: sqlx::types::Json", + "type_info": "Jsonb" + }, + { + "ordinal": 4, + "name": "subtotal!", + "type_info": "Int4" + }, + { + "ordinal": 5, + "name": "extra!: sqlx::types::Json", + "type_info": "Jsonb" + }, + { + "ordinal": 6, + "name": "invoice_type!: InvoiceType", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Text", + "Date", + "Date", + "Date", + "Date", + "Text", + "Date", + "Date", + "Text", + "Int8" + ] + }, + "nullable": [ + true, + true, + true, + true, + true, + true, + true + ] + }, + "hash": "9b05fe20d2fd53e8c5f4aae509794f2db9fe3414770e2820d0481e0936588d69" +} diff --git a/.sqlx/query-d96bd638750d5555c9a841b673c60b7707648867b17a0d964ce0c86b6c71761f.json b/.sqlx/query-d96bd638750d5555c9a841b673c60b7707648867b17a0d964ce0c86b6c71761f.json new file mode 100644 index 00000000000..01b0ad7bb1e --- /dev/null +++ b/.sqlx/query-d96bd638750d5555c9a841b673c60b7707648867b17a0d964ce0c86b6c71761f.json @@ -0,0 +1,67 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT\n date_start as \"date_start!\",\n date_end as \"date_end!\",\n billed_prefix as \"billed_prefix!\",\n line_items as \"line_items!: sqlx::types::Json\",\n subtotal as \"subtotal!\",\n extra as \"extra!: sqlx::types::Json\",\n invoice_type as \"invoice_type!: InvoiceType\"\n FROM invoices_ext\n WHERE billed_prefix = $1\n AND ($2::date IS NULL OR date_start > $2)\n AND ($3::date IS NULL OR date_start < $3)\n AND ($4::date IS NULL OR date_end > $4)\n AND ($5::date IS NULL OR date_end < $5)\n AND ($6::text IS NULL OR invoice_type::text = $6)\n AND (\n $7::date IS NULL\n OR date_end < $7\n OR (date_end = $7 AND date_start < $8)\n OR (date_end = $7 AND date_start = $8 AND invoice_type::text > $9)\n )\n ORDER BY date_end DESC, date_start DESC, invoice_type ASC\n LIMIT $10\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "date_start!", + "type_info": "Date" + }, + { + "ordinal": 1, + "name": "date_end!", + "type_info": "Date" + }, + { + "ordinal": 2, + "name": "billed_prefix!", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "line_items!: sqlx::types::Json", + "type_info": "Jsonb" + }, + { + "ordinal": 4, + "name": "subtotal!", + "type_info": "Int4" + }, + { + "ordinal": 5, + "name": "extra!: sqlx::types::Json", + "type_info": "Jsonb" + }, + { + "ordinal": 6, + "name": "invoice_type!: InvoiceType", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Text", + "Date", + "Date", + "Date", + "Date", + "Text", + "Date", + "Date", + "Text", + "Int8" + ] + }, + "nullable": [ + true, + true, + true, + true, + true, + true, + true + ] + }, + "hash": "d96bd638750d5555c9a841b673c60b7707648867b17a0d964ce0c86b6c71761f" +} diff --git a/Cargo.lock b/Cargo.lock index ca83e959596..c1f1a0bdca1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1347,6 +1347,7 @@ version = "0.0.0" dependencies = [ "anyhow", "async-stripe", + "billing-types", "chrono", "clap", "comfy-table", @@ -1363,6 +1364,17 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "billing-types" +version = "0.0.0" +dependencies = [ + "anyhow", + "async-graphql", + "async-stripe", + "serde", + "sqlx", +] + [[package]] name = "bindgen" version = "0.72.1" @@ -2040,10 +2052,12 @@ dependencies = [ "async-graphql", "async-graphql-axum", "async-process", + "async-stripe", "async-trait", "axum", "axum-extra", "base64 0.22.1", + "billing-types", "build", "bytes", "chrono", diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index 641c64f3cc0..4be157b1ef1 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -80,6 +80,7 @@ zeroize = { workspace = true } [dev-dependencies] assemble = { path = "../assemble" } +control-plane-api = { path = "../control-plane-api", features = ["test-support"] } insta = { workspace = true } md5 = { workspace = true } tokio = { workspace = true } diff --git a/crates/agent/src/integration_tests/harness.rs b/crates/agent/src/integration_tests/harness.rs index 7186ed83b32..f9a6e4c0a45 100644 --- a/crates/agent/src/integration_tests/harness.rs +++ b/crates/agent/src/integration_tests/harness.rs @@ -541,45 +541,14 @@ impl TestHarness { /// storage_mappings, and tenants tables should all look just like they /// would in production. pub async fn setup_tenant(&self, tenant: &str) -> sqlx::types::Uuid { - let user_id = sqlx::types::Uuid::new_v4(); let email = format!("{tenant}@{}.test", self.test_name.replace(' ', "-")); - let meta = serde_json::json!({ "picture": format!("http://{tenant}.test/avatar"), "full_name": format!("Full ({tenant}) Name"), }); - let mut txn = self.pool.begin().await.unwrap(); - sqlx::query!( - r#"insert into auth.users(id, email, raw_user_meta_data) values ($1, $2, $3)"#, - user_id, - email.as_str(), - meta - ) - .execute(&mut *txn) - .await - .expect("failed to create user"); - - control_plane_api::directives::beta_onboard::provision_tenant( - "support@estuary.dev", - Some(format!("for test: {}", self.test_name)), - tenant, - user_id, - &mut txn, - ) - .await - .expect("failed to provision tenant"); - - // Remove the estuary_support/ role grant, which gets automatically - // added by a trigger whenever we create a new tenant. Removing it here - // ensures that things still work correctly without it. - sqlx::query!(r#"delete from role_grants where subject_role = 'estuary_support/';"#) - .execute(&mut *txn) + control_plane_api::test_support::provision_test_tenant(&self.pool, tenant, &email, meta) .await - .expect("failed to remove estuary_support/ role"); - - txn.commit().await.expect("failed to commit transaction"); - user_id } pub async fn add_role_grant(&mut self, subject: &str, object: &str, capability: Capability) { @@ -1682,6 +1651,7 @@ impl TestHarness { let app = Arc::new(control_plane_api::App::new( id_gen, + None, &jwt_secret, self.pool.clone(), self.publisher.clone(), diff --git a/crates/agent/src/main.rs b/crates/agent/src/main.rs index cb651d78858..2dc371c1ac3 100644 --- a/crates/agent/src/main.rs +++ b/crates/agent/src/main.rs @@ -41,6 +41,12 @@ struct Args { /// The port to listen on for API requests. #[clap(long, default_value = "8080", env = "API_PORT")] api_port: u16, + /// Stripe secret API key. When provided, the billing GraphQL queries and + /// mutations that interact with Stripe are enabled. Without this, those + /// operations return an error indicating billing is not configured. + #[derivative(Debug = "ignore")] + #[clap(long = "stripe-api-key", env = "STRIPE_API_KEY")] + stripe_api_key: Option, /// Whether to serve job handlers within this agent instance. #[clap(long = "serve-handlers", env = "SERVE_HANDLERS")] serve_handlers: bool, @@ -312,9 +318,15 @@ async fn async_main(args: Args) -> Result<(), anyhow::Error> { args.controller_config, ); - // Wire up the agent's API Application and server. + let billing_provider: Option> = + args.stripe_api_key.map(|api_key| { + Arc::new(control_plane_api::billing::StripeBillingProvider::new( + api_key, + )) as Arc + }); let api_app = Arc::new(App::new( agent::id_generator::with_random_shard(), + billing_provider, jwt_secret.as_bytes(), pg_pool.clone(), publisher.clone(), diff --git a/crates/billing-integrations/Cargo.toml b/crates/billing-integrations/Cargo.toml index f939c7f2fea..96bf62c4d2b 100644 --- a/crates/billing-integrations/Cargo.toml +++ b/crates/billing-integrations/Cargo.toml @@ -13,6 +13,7 @@ license.workspace = true [dependencies] anyhow = { workspace = true } async-stripe = { workspace = true } +billing-types = { path = "../billing-types" } clap = { workspace = true } chrono = { workspace = true } comfy-table = { workspace = true } diff --git a/crates/billing-integrations/src/publish.rs b/crates/billing-integrations/src/publish.rs index f8c3b802c0d..ebe31955887 100644 --- a/crates/billing-integrations/src/publish.rs +++ b/crates/billing-integrations/src/publish.rs @@ -1,5 +1,8 @@ -use crate::stripe_utils::{SearchParams, stripe_search}; use anyhow::{Context, bail}; +use billing_types::{ + InvoiceMetadata, InvoiceSearch, InvoiceType, SearchParams, TENANT_METADATA_KEY, + customer_create_idempotency_key, customer_search_query, stripe_search, +}; use chrono::{Duration, ParseError, Utc}; use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use itertools::Itertools; @@ -9,11 +12,7 @@ use sqlx::{Pool, postgres::PgPoolOptions, types::chrono::NaiveDate}; use std::collections::HashMap; use stripe::InvoiceStatus; -pub const TENANT_METADATA_KEY: &str = "estuary.dev/tenant_name"; const CREATED_BY_BILLING_AUTOMATION: &str = "estuary.dev/created_by_automation"; -pub const INVOICE_TYPE_KEY: &str = "estuary.dev/invoice_type"; -pub const BILLING_PERIOD_START_KEY: &str = "estuary.dev/period_start"; -pub const BILLING_PERIOD_END_KEY: &str = "estuary.dev/period_end"; #[derive(Debug, Clone, Copy, clap::ValueEnum)] #[clap(rename_all = "kebab_case")] @@ -65,17 +64,6 @@ fn parse_date(arg: &str) -> Result { NaiveDate::parse_from_str(arg, "%Y-%m-%d") } -#[derive(Debug, Clone, PartialEq, Eq, Hash, sqlx::Type, Serialize, Deserialize)] -#[sqlx(rename_all = "snake_case")] -enum InvoiceType { - #[serde(rename = "final")] - Final, - #[serde(rename = "preview")] - Preview, - #[serde(rename = "manual")] - Manual, -} - #[derive(Serialize, Deserialize, Debug, Clone)] struct Extra { trial_start: Option, @@ -167,25 +155,18 @@ impl Invoice { let date_start_repr = self.date_start.format("%F").to_string(); let date_end_repr = self.date_end.format("%F").to_string(); - let invoice_type_val = - serde_json::to_value(self.invoice_type.clone()).expect("InvoiceType is serializable"); - let invoice_type_str = invoice_type_val - .as_str() - .expect("InvoiceType is serializable"); - let invoice_search = stripe_search::( &client, "invoices", SearchParams { - query: format!( - r#" - -status:"deleted" AND - customer:"{customer_id}" AND - metadata["{INVOICE_TYPE_KEY}"]:"{invoice_type_str}" AND - metadata["{BILLING_PERIOD_START_KEY}"]:"{date_start_repr}" AND - metadata["{BILLING_PERIOD_END_KEY}"]:"{date_end_repr}" - "# - ), + query: InvoiceSearch { + customer_id: Some(customer_id), + invoice_type: Some(self.invoice_type), + period_start: Some(&date_start_repr), + period_end: Some(&date_end_repr), + ..Default::default() + } + .to_query(), ..Default::default() }, ) @@ -306,12 +287,6 @@ impl Invoice { let date_start_repr = self.date_start.format("%F").to_string(); let date_end_repr = self.date_end.format("%F").to_string(); - let invoice_type_val = - serde_json::to_value(self.invoice_type.clone()).expect("InvoiceType is serializable"); - let invoice_type_str = invoice_type_val - .as_str() - .expect("InvoiceType is serializable"); - let customer = get_or_create_customer_for_tenant( client, db_client, @@ -404,12 +379,15 @@ impl Invoice { value: date_end_human.to_owned(), }, ]), - metadata: Some(HashMap::from([ - (TENANT_METADATA_KEY.to_string(), self.billed_prefix.to_owned()), - (INVOICE_TYPE_KEY.to_string(), invoice_type_str.to_owned()), - (BILLING_PERIOD_START_KEY.to_string(), date_start_repr), - (BILLING_PERIOD_END_KEY.to_string(), date_end_repr) - ])), + metadata: Some( + InvoiceMetadata { + tenant: self.billed_prefix.to_owned(), + invoice_type: self.invoice_type, + period_start: date_start_repr, + period_end: date_end_repr, + } + .to_metadata_map(), + ), ..Default::default() }, ) @@ -791,11 +769,11 @@ async fn get_or_create_customer_for_tenant( tenant: String, create: bool, ) -> anyhow::Result> { - let customers = stripe_search::( + let customers: Vec = stripe_search( client, "customers", SearchParams { - query: format!("metadata[\"{TENANT_METADATA_KEY}\"]:\"{tenant}\""), + query: customer_search_query(&tenant), ..Default::default() }, ) @@ -807,8 +785,16 @@ async fn get_or_create_customer_for_tenant( customer } else if create { tracing::debug!("Creating new customer"); + // Match the deterministic Idempotency-Key used by the GraphQL path so + // a setup-intent flow racing against billing automation can't produce + // a second customer for the same tenant. See `stripe_impl.rs`. + let create_client = client + .clone() + .with_strategy(stripe::RequestStrategy::Idempotent( + customer_create_idempotency_key(&tenant), + )); let new_customer = stripe::Customer::create( - client, + &create_client, stripe::CreateCustomer { name: Some(tenant.as_str()), description: Some( diff --git a/crates/billing-integrations/src/send.rs b/crates/billing-integrations/src/send.rs index 9316a44e1cc..bd626870726 100644 --- a/crates/billing-integrations/src/send.rs +++ b/crates/billing-integrations/src/send.rs @@ -1,7 +1,5 @@ -use crate::{ - publish::{BILLING_PERIOD_START_KEY, INVOICE_TYPE_KEY}, - stripe_utils::{Invoice, fetch_invoices}, -}; +use crate::stripe_utils::{Invoice, fetch_invoices}; +use billing_types::{InvoiceSearch, InvoiceType, StatusFilter}; use chrono::{Datelike, Duration, NaiveDate, Utc}; use clap::Args; use futures::stream::{self, StreamExt}; @@ -44,16 +42,34 @@ pub async fn do_send_invoices(cmd: &SendInvoices) -> anyhow::Result<()> { let month_human_repr = cmd.month.format("%B %Y"); tracing::info!("Fetching Stripe invoices to process for {month_human_repr}"); - let base_final_metadata = format!( - "metadata[\"{INVOICE_TYPE_KEY}\"]:'final' AND metadata[\"{BILLING_PERIOD_START_KEY}\"]:'{month_start}'" - ); - let draft_final_query = format!("status:'draft' AND {base_final_metadata}"); - let open_final_query = format!("status:'open' AND {base_final_metadata}"); + let draft_final_query = InvoiceSearch { + invoice_type: Some(InvoiceType::Final), + period_start: Some(&month_start), + status: StatusFilter::Only(stripe::InvoiceStatus::Draft), + ..Default::default() + } + .to_query(); + let open_final_query = InvoiceSearch { + invoice_type: Some(InvoiceType::Final), + period_start: Some(&month_start), + status: StatusFilter::Only(stripe::InvoiceStatus::Open), + ..Default::default() + } + .to_query(); // Separate queries for manual invoices (we'll filter dates client-side) - let draft_manual_query = - format!("status:'draft' AND metadata[\"{INVOICE_TYPE_KEY}\"]:'manual'"); - let open_manual_query = format!("status:'open' AND metadata[\"{INVOICE_TYPE_KEY}\"]:'manual'"); + let draft_manual_query = InvoiceSearch { + invoice_type: Some(InvoiceType::Manual), + status: StatusFilter::Only(stripe::InvoiceStatus::Draft), + ..Default::default() + } + .to_query(); + let open_manual_query = InvoiceSearch { + invoice_type: Some(InvoiceType::Manual), + status: StatusFilter::Only(stripe::InvoiceStatus::Open), + ..Default::default() + } + .to_query(); // 1. Fetch invoices: final invoices with exact date match + all manual invoices let ( diff --git a/crates/billing-integrations/src/stripe_utils.rs b/crates/billing-integrations/src/stripe_utils.rs index 36a719affbc..fa35f97f562 100644 --- a/crates/billing-integrations/src/stripe_utils.rs +++ b/crates/billing-integrations/src/stripe_utils.rs @@ -1,43 +1,6 @@ -use crate::publish::{BILLING_PERIOD_END_KEY, BILLING_PERIOD_START_KEY, TENANT_METADATA_KEY}; +use billing_types::{InvoiceMetadata, SearchParams, stripe_search}; use num_format::{Locale, ToFormattedString}; -use serde::{Serialize, de::DeserializeOwned}; use std::ops::{Deref, DerefMut}; -use stripe::SearchList; - -#[derive(Serialize, Default, Debug)] -pub struct SearchParams { - pub query: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub limit: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub page: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub expand: Option>, -} - -pub async fn stripe_search( - client: &stripe::Client, - resource: &str, - mut params: SearchParams, -) -> Result, stripe::StripeError> { - let mut all_data = Vec::new(); - let mut page = None; - loop { - if let Some(p) = page { - params.page = Some(p); - } - let resp: SearchList = client - .get_query(&format!("/{}/search", resource), ¶ms) - .await?; - let count = resp.data.len(); - all_data.extend(resp.data); - if count == 0 || !resp.has_more { - break; - } - page = resp.next_page; - } - Ok(all_data) -} pub async fn fetch_invoices( stripe_client: &stripe::Client, @@ -59,7 +22,6 @@ pub async fn fetch_invoices( .map(|inv: stripe::Invoice| Invoice::from(inv)) .collect() }) - .map_err(|e| e.into()) } #[derive(Clone, Debug)] @@ -90,8 +52,8 @@ impl Invoice { self.0 .metadata .as_ref() - .and_then(|m| m.get(TENANT_METADATA_KEY)) - .cloned() + .and_then(InvoiceMetadata::from_metadata_map) + .map(|m| m.tenant) .unwrap_or_default() } pub fn amount(&self) -> f64 { @@ -126,16 +88,16 @@ impl Invoice { self.0 .metadata .as_ref() - .and_then(|m| m.get(BILLING_PERIOD_START_KEY)) - .cloned() + .and_then(InvoiceMetadata::from_metadata_map) + .map(|m| m.period_start) } pub fn period_end(&self) -> Option { self.0 .metadata .as_ref() - .and_then(|m| m.get(BILLING_PERIOD_END_KEY)) - .cloned() + .and_then(InvoiceMetadata::from_metadata_map) + .map(|m| m.period_end) } pub fn to_table_row(&self) -> Vec { diff --git a/crates/billing-types/Cargo.toml b/crates/billing-types/Cargo.toml new file mode 100644 index 00000000000..2199adc6573 --- /dev/null +++ b/crates/billing-types/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "billing-types" +version.workspace = true +rust-version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true +repository.workspace = true +license.workspace = true + +[dependencies] +anyhow = { workspace = true } +async-graphql = { workspace = true, optional = true } +async-stripe = { workspace = true } +serde = { workspace = true } +sqlx = { workspace = true } + +[features] +default = [] +async-graphql = ["dep:async-graphql"] diff --git a/crates/billing-types/src/lib.rs b/crates/billing-types/src/lib.rs new file mode 100644 index 00000000000..3640a6f56ce --- /dev/null +++ b/crates/billing-types/src/lib.rs @@ -0,0 +1,286 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::str::FromStr; + +mod stripe_helpers; +pub use stripe_helpers::{SearchParams, stripe_search}; + +pub const TENANT_METADATA_KEY: &str = "estuary.dev/tenant_name"; +const INVOICE_TYPE_KEY: &str = "estuary.dev/invoice_type"; +const BILLING_PERIOD_START_KEY: &str = "estuary.dev/period_start"; +const BILLING_PERIOD_END_KEY: &str = "estuary.dev/period_end"; + +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Serialize, Deserialize, sqlx::Type)] +#[cfg_attr( + feature = "async-graphql", + derive(async_graphql::Enum), + graphql(rename_items = "SCREAMING_SNAKE_CASE") +)] +#[serde(rename_all = "snake_case")] +#[sqlx(type_name = "text")] +#[sqlx(rename_all = "snake_case")] +pub enum InvoiceType { + Final, + Preview, + Manual, +} + +impl InvoiceType { + pub fn as_str(self) -> &'static str { + match self { + InvoiceType::Final => "final", + InvoiceType::Preview => "preview", + InvoiceType::Manual => "manual", + } + } +} + +impl FromStr for InvoiceType { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "final" => Ok(InvoiceType::Final), + "preview" => Ok(InvoiceType::Preview), + "manual" => Ok(InvoiceType::Manual), + _ => Err(()), + } + } +} + +/// Status clause to append to a Stripe invoice search query. +/// +/// Stripe's search DSL accepts both positive (`status:"open"`) and negative +/// (`-status:"draft"`) filters +#[derive(Debug, Clone, Copy, Default)] +pub enum StatusFilter { + /// No status clause. + #[default] + Any, + /// `status:""`: match only invoices with this status. + Only(stripe::InvoiceStatus), + /// `-status:""`: exclude invoices with this status. + Exclude(stripe::InvoiceStatus), +} + +impl StatusFilter { + fn clause(self) -> Option { + match self { + StatusFilter::Any => None, + StatusFilter::Only(s) => Some(format!(r#"status:"{}""#, s.as_str())), + StatusFilter::Exclude(s) => Some(format!(r#"-status:"{}""#, s.as_str())), + } + } +} + +pub fn customer_search_query(tenant: &str) -> String { + format!(r#"metadata["{TENANT_METADATA_KEY}"]:"{tenant}""#) +} + +/// Deterministic Stripe Idempotency-Key for `Customer::create` calls. Using the +/// tenant name collapses concurrent or retried creations across processes within +/// Stripe's 24-hour idempotency window, so a search-index lag race can't produce +/// duplicate customer rows for the same tenant. +pub fn customer_create_idempotency_key(tenant: &str) -> String { + format!("flow-customer-create:{tenant}") +} + +/// These 4 pieces of metadata link an invoice in Stripe to a row in `invoices_ext`. This is +/// an area that could be improved in the future if needed, but presently `invoices_ext` does not +/// model a single "primary key", which is why we need to use this compound identity. It composes: +/// * "Final" invoices, which come from `internal.billing_historicals`, and use the natural key of +/// `(tenant, billed_month)`. `billing_historicals` does not contain a primary key +/// * "Manual" invoices, which come from `internal.manual_bills` which uses the natural key +/// `(tenant, date_start, date_end)`, again not modelling a primary key. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct InvoiceMetadata { + pub tenant: String, + pub invoice_type: InvoiceType, + pub period_start: String, + pub period_end: String, +} + +impl InvoiceMetadata { + pub fn to_metadata_map(&self) -> HashMap { + HashMap::from([ + (TENANT_METADATA_KEY.to_string(), self.tenant.clone()), + ( + INVOICE_TYPE_KEY.to_string(), + self.invoice_type.as_str().to_string(), + ), + ( + BILLING_PERIOD_START_KEY.to_string(), + self.period_start.clone(), + ), + (BILLING_PERIOD_END_KEY.to_string(), self.period_end.clone()), + ]) + } + + /// Parse an `InvoiceMetadata` from a Stripe invoice's metadata map. + /// Returns `Some` only if all four expected fields are present and the + /// invoice type parses; otherwise returns `None`. + pub fn from_metadata_map(map: &HashMap) -> Option { + Some(Self { + tenant: map.get(TENANT_METADATA_KEY)?.clone(), + invoice_type: map.get(INVOICE_TYPE_KEY)?.parse().ok()?, + period_start: map.get(BILLING_PERIOD_START_KEY)?.clone(), + period_end: map.get(BILLING_PERIOD_END_KEY)?.clone(), + }) + } +} + +/// Filter for a Stripe invoice search. Each `Some` field becomes an AND-joined +/// clause in the resulting query; `None` fields are omitted. +#[derive(Debug, Default, Clone, Copy)] +pub struct InvoiceSearch<'a> { + pub customer_id: Option<&'a str>, + pub invoice_type: Option, + pub period_start: Option<&'a str>, + pub period_end: Option<&'a str>, + pub status: StatusFilter, +} + +impl InvoiceSearch<'_> { + pub fn to_query(&self) -> String { + let mut clauses = Vec::with_capacity(5); + if let Some(id) = self.customer_id { + clauses.push(format!(r#"customer:"{id}""#)); + } + if let Some(invoice_type) = self.invoice_type { + clauses.push(format!( + r#"metadata["{INVOICE_TYPE_KEY}"]:"{}""#, + invoice_type.as_str() + )); + } + if let Some(period_start) = self.period_start { + clauses.push(format!( + r#"metadata["{BILLING_PERIOD_START_KEY}"]:"{period_start}""# + )); + } + if let Some(period_end) = self.period_end { + clauses.push(format!( + r#"metadata["{BILLING_PERIOD_END_KEY}"]:"{period_end}""# + )); + } + if let Some(status) = self.status.clause() { + clauses.push(status); + } + clauses.join(" AND ") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn customer_query_format() { + assert_eq!( + customer_search_query("acme/widgets"), + r#"metadata["estuary.dev/tenant_name"]:"acme/widgets""# + ); + } + + #[test] + fn customer_create_idempotency_key_format() { + // Same input must produce the same key across processes for cross-call + // idempotency to work; the prefix namespaces it from future deterministic + // keys for other Stripe writes. + assert_eq!( + customer_create_idempotency_key("acme/widgets"), + "flow-customer-create:acme/widgets" + ); + assert_eq!( + customer_create_idempotency_key("acme/widgets"), + customer_create_idempotency_key("acme/widgets"), + ); + } + + #[test] + fn invoice_type_parse() { + assert_eq!("final".parse(), Ok(InvoiceType::Final)); + assert_eq!("preview".parse(), Ok(InvoiceType::Preview)); + assert_eq!("manual".parse(), Ok(InvoiceType::Manual)); + assert_eq!("Final".parse::(), Err(())); + assert_eq!("".parse::(), Err(())); + } + + #[test] + fn invoice_metadata_round_trip() { + let original = InvoiceMetadata { + tenant: "acme/widgets".to_string(), + invoice_type: InvoiceType::Final, + period_start: "2026-04-01".to_string(), + period_end: "2026-04-30".to_string(), + }; + let parsed = InvoiceMetadata::from_metadata_map(&original.to_metadata_map()); + assert_eq!(parsed, Some(original)); + } + + #[test] + fn invoice_metadata_missing_field_returns_none() { + let mut map = InvoiceMetadata { + tenant: "acme/widgets".to_string(), + invoice_type: InvoiceType::Final, + period_start: "2026-04-01".to_string(), + period_end: "2026-04-30".to_string(), + } + .to_metadata_map(); + map.remove(BILLING_PERIOD_END_KEY); + assert_eq!(InvoiceMetadata::from_metadata_map(&map), None); + } + + #[test] + fn search_full_exclude_draft() { + let got = InvoiceSearch { + customer_id: Some("cus_123"), + invoice_type: Some(InvoiceType::Final), + period_start: Some("2026-04-01"), + period_end: Some("2026-04-30"), + status: StatusFilter::Exclude(stripe::InvoiceStatus::Draft), + } + .to_query(); + assert_eq!( + got, + r#"customer:"cus_123" AND metadata["estuary.dev/invoice_type"]:"final" AND metadata["estuary.dev/period_start"]:"2026-04-01" AND metadata["estuary.dev/period_end"]:"2026-04-30" AND -status:"draft""# + ); + } + + #[test] + fn search_full_exclude_void() { + let got = InvoiceSearch { + customer_id: Some("cus_123"), + invoice_type: Some(InvoiceType::Final), + period_start: Some("2026-04-01"), + period_end: Some("2026-04-30"), + status: StatusFilter::Exclude(stripe::InvoiceStatus::Void), + } + .to_query(); + assert!(got.ends_with(r#"AND -status:"void""#)); + } + + #[test] + fn search_type_and_period_start() { + let got = InvoiceSearch { + invoice_type: Some(InvoiceType::Final), + period_start: Some("2026-04-01"), + status: StatusFilter::Only(stripe::InvoiceStatus::Draft), + ..Default::default() + } + .to_query(); + assert_eq!( + got, + r#"metadata["estuary.dev/invoice_type"]:"final" AND metadata["estuary.dev/period_start"]:"2026-04-01" AND status:"draft""# + ); + } + + #[test] + fn search_type_only_status_any() { + let got = InvoiceSearch { + invoice_type: Some(InvoiceType::Manual), + ..Default::default() + } + .to_query(); + assert_eq!(got, r#"metadata["estuary.dev/invoice_type"]:"manual""#); + } +} diff --git a/crates/billing-types/src/stripe_helpers.rs b/crates/billing-types/src/stripe_helpers.rs new file mode 100644 index 00000000000..ce1a3fffc59 --- /dev/null +++ b/crates/billing-types/src/stripe_helpers.rs @@ -0,0 +1,36 @@ +use serde::de::DeserializeOwned; + +#[derive(serde::Serialize, Default, Debug)] +pub struct SearchParams { + pub query: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub limit: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub page: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub expand: Option>, +} + +pub async fn stripe_search( + client: &stripe::Client, + resource: &str, + mut params: SearchParams, +) -> anyhow::Result> { + let mut all_data = Vec::new(); + let mut page = None; + loop { + if let Some(p) = page { + params.page = Some(p); + } + let resp: stripe::SearchList = client + .get_query(&format!("/{resource}/search"), ¶ms) + .await?; + let count = resp.data.len(); + all_data.extend(resp.data); + if count == 0 || !resp.has_more { + break; + } + page = resp.next_page; + } + Ok(all_data) +} diff --git a/crates/control-plane-api/Cargo.toml b/crates/control-plane-api/Cargo.toml index ade442dd13e..7c62f1cca4c 100644 --- a/crates/control-plane-api/Cargo.toml +++ b/crates/control-plane-api/Cargo.toml @@ -12,6 +12,8 @@ license.workspace = true activate = { path = "../activate" } allocator = { path = "../allocator" } async-process = { path = "../async-process" } +async-stripe = { workspace = true } +billing-types = { path = "../billing-types", features = ["async-graphql"] } build = { path = "../build" } coroutines = { path = "../coroutines" } doc = { path = "../doc" } @@ -89,3 +91,4 @@ tracing-subscriber = { workspace = true } [features] default = ["sqlx-support"] sqlx-support = [] +test-support = [] diff --git a/crates/control-plane-api/src/billing/db.rs b/crates/control-plane-api/src/billing/db.rs new file mode 100644 index 00000000000..eb86c114863 --- /dev/null +++ b/crates/control-plane-api/src/billing/db.rs @@ -0,0 +1,169 @@ +use crate::billing::InvoiceType; +use chrono::NaiveDate; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct InvoiceCursorKey { + pub date_start: NaiveDate, + pub date_end: NaiveDate, + pub invoice_type: InvoiceType, +} + +impl InvoiceCursorKey { + pub fn from_row(row: &DbInvoiceRow) -> Self { + Self { + date_start: row.date_start, + date_end: row.date_end, + invoice_type: row.invoice_type, + } + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct InvoiceQuery { + pub date_start_gt: Option, + pub date_start_lt: Option, + pub date_end_gt: Option, + pub date_end_lt: Option, + pub invoice_type_eq: Option, +} + +#[derive(Debug, Clone)] +pub struct DbInvoiceRow { + pub date_start: NaiveDate, + pub date_end: NaiveDate, + pub billed_prefix: String, + pub line_items: sqlx::types::Json, + pub subtotal: i32, + pub extra: sqlx::types::Json, + pub invoice_type: InvoiceType, +} + +/// Forward pagination: fetch invoices older than `cursor` (or the newest +/// invoices when `cursor` is `None`). Returned rows are ordered newest-first. +pub async fn fetch_invoice_rows_forward( + pool: &sqlx::PgPool, + tenant: &str, + query: &InvoiceQuery, + cursor: Option, + limit: Option, +) -> anyhow::Result<(Vec, bool)> { + let query_limit = limit.map(|l| l as i64 + 1).unwrap_or(i64::MAX); + let invoice_type_eq = query.invoice_type_eq.map(|t| t.as_str()); + let cursor_date_end = cursor.map(|c| c.date_end); + let cursor_date_start = cursor.map(|c| c.date_start); + let cursor_invoice_type = cursor.map(|c| c.invoice_type.as_str()); + + let mut invoices = sqlx::query_as!( + DbInvoiceRow, + r#" + SELECT + date_start as "date_start!", + date_end as "date_end!", + billed_prefix as "billed_prefix!", + line_items as "line_items!: sqlx::types::Json", + subtotal as "subtotal!", + extra as "extra!: sqlx::types::Json", + invoice_type as "invoice_type!: InvoiceType" + FROM invoices_ext + WHERE billed_prefix = $1 + AND ($2::date IS NULL OR date_start > $2) + AND ($3::date IS NULL OR date_start < $3) + AND ($4::date IS NULL OR date_end > $4) + AND ($5::date IS NULL OR date_end < $5) + AND ($6::text IS NULL OR invoice_type::text = $6) + AND ( + $7::date IS NULL + OR date_end < $7 + OR (date_end = $7 AND date_start < $8) + OR (date_end = $7 AND date_start = $8 AND invoice_type::text > $9) + ) + ORDER BY date_end DESC, date_start DESC, invoice_type ASC + LIMIT $10 + "#, + tenant, + query.date_start_gt, + query.date_start_lt, + query.date_end_gt, + query.date_end_lt, + invoice_type_eq, + cursor_date_end, + cursor_date_start, + cursor_invoice_type, + query_limit, + ) + .fetch_all(pool) + .await?; + + // Query for one extra row so that its presence indicates more rows exist + // past this batch; truncate it before returning. + let has_more = limit.is_some_and(|l| invoices.len() > l); + if let Some(l) = limit { + invoices.truncate(l); + } + Ok((invoices, has_more)) +} + +/// Backward pagination: fetch invoices newer than `cursor`. Returned rows are +/// ordered newest-first (the query selects oldest-first to honor `limit`, then +/// the result is reversed). +pub async fn fetch_invoice_rows_backward( + pool: &sqlx::PgPool, + tenant: &str, + query: &InvoiceQuery, + cursor: Option, + limit: Option, +) -> anyhow::Result<(Vec, bool)> { + let query_limit = limit.map(|l| l as i64 + 1).unwrap_or(i64::MAX); + let invoice_type_eq = query.invoice_type_eq.map(|t| t.as_str()); + let cursor_date_end = cursor.map(|c| c.date_end); + let cursor_date_start = cursor.map(|c| c.date_start); + let cursor_invoice_type = cursor.map(|c| c.invoice_type.as_str()); + + let mut invoices = sqlx::query_as!( + DbInvoiceRow, + r#" + SELECT + date_start as "date_start!", + date_end as "date_end!", + billed_prefix as "billed_prefix!", + line_items as "line_items!: sqlx::types::Json", + subtotal as "subtotal!", + extra as "extra!: sqlx::types::Json", + invoice_type as "invoice_type!: InvoiceType" + FROM invoices_ext + WHERE billed_prefix = $1 + AND ($2::date IS NULL OR date_start > $2) + AND ($3::date IS NULL OR date_start < $3) + AND ($4::date IS NULL OR date_end > $4) + AND ($5::date IS NULL OR date_end < $5) + AND ($6::text IS NULL OR invoice_type::text = $6) + AND ( + $7::date IS NULL + OR date_end > $7 + OR (date_end = $7 AND date_start > $8) + OR (date_end = $7 AND date_start = $8 AND invoice_type::text < $9) + ) + ORDER BY date_end ASC, date_start ASC, invoice_type DESC + LIMIT $10 + "#, + tenant, + query.date_start_gt, + query.date_start_lt, + query.date_end_gt, + query.date_end_lt, + invoice_type_eq, + cursor_date_end, + cursor_date_start, + cursor_invoice_type, + query_limit, + ) + .fetch_all(pool) + .await?; + + let has_more = limit.is_some_and(|l| invoices.len() > l); + if let Some(l) = limit { + invoices.truncate(l); + } + invoices.reverse(); + Ok((invoices, has_more)) +} diff --git a/crates/control-plane-api/src/billing/memory.rs b/crates/control-plane-api/src/billing/memory.rs new file mode 100644 index 00000000000..e0d0a6b98c0 --- /dev/null +++ b/crates/control-plane-api/src/billing/memory.rs @@ -0,0 +1,225 @@ +use super::BillingProvider; +use billing_types::TENANT_METADATA_KEY; +use std::collections::HashMap; +use std::sync::Mutex; + +#[derive(Debug, Default)] +struct State { + customers: Vec, + payment_methods: Vec<(stripe::CustomerId, stripe::PaymentMethod)>, + invoices: Vec<(stripe::CustomerId, stripe::Invoice)>, + payment_intents: Vec, + setup_intent_counter: u64, +} + +/// In-memory `BillingProvider` used by tests and local development. +#[derive(Debug, Default)] +pub struct InMemoryBillingProvider { + state: Mutex, +} + +impl InMemoryBillingProvider { + pub fn new() -> Self { + Self::default() + } + + pub fn add_customer(&self, tenant: &str, id: &str, default_pm: Option<&str>) { + let mut state = self.state.lock().unwrap(); + state.customers.push(stripe::Customer { + id: id.parse().unwrap(), + invoice_settings: Some(stripe::InvoiceSettingCustomerSetting { + default_payment_method: default_pm + .map(|pm| stripe::Expandable::Id(pm.parse().unwrap())), + ..Default::default() + }), + metadata: Some(HashMap::from([( + TENANT_METADATA_KEY.to_string(), + tenant.to_string(), + )])), + ..Default::default() + }); + } + + pub fn add_payment_method( + &self, + customer_id: &str, + id: &str, + type_: stripe::PaymentMethodType, + billing_details: stripe::BillingDetails, + card: Option, + us_bank_account: Option, + ) { + let pm = stripe::PaymentMethod { + id: id.parse().unwrap(), + type_, + billing_details, + card, + us_bank_account, + ..Default::default() + }; + self.state + .lock() + .unwrap() + .payment_methods + .push((customer_id.parse().unwrap(), pm)); + } + + pub fn add_invoice(&self, customer_id: &str, invoice: stripe::Invoice) { + self.state + .lock() + .unwrap() + .invoices + .push((customer_id.parse().unwrap(), invoice)); + } + + pub fn add_payment_intent(&self, pi: stripe::PaymentIntent) { + self.state.lock().unwrap().payment_intents.push(pi); + } + + fn customer_search_tenant(query: &str) -> Option<&str> { + let prefix = format!(r#"metadata["{}"]:""#, TENANT_METADATA_KEY); + query + .strip_prefix(&prefix) + .and_then(|rest| rest.strip_suffix('"')) + } +} + +#[async_trait::async_trait] +impl BillingProvider for InMemoryBillingProvider { + async fn search_customers(&self, query: &str) -> anyhow::Result> { + let state = self.state.lock().unwrap(); + let Some(tenant) = Self::customer_search_tenant(query) else { + return Ok(state.customers.clone()); + }; + + Ok(state + .customers + .iter() + .filter(|customer| { + customer + .metadata + .as_ref() + .and_then(|metadata| metadata.get(TENANT_METADATA_KEY)) + .is_some_and(|value| value == tenant) + }) + .cloned() + .collect()) + } + + async fn create_customer( + &self, + tenant: &str, + _user_email: &str, + _user_name: Option<&str>, + ) -> anyhow::Result { + let mut state = self.state.lock().unwrap(); + let id = format!("cus_mock_{}", tenant.replace('/', "")); + let customer = stripe::Customer { + id: id.parse().unwrap(), + metadata: Some(HashMap::from([( + TENANT_METADATA_KEY.to_string(), + tenant.to_string(), + )])), + ..Default::default() + }; + state.customers.push(customer.clone()); + Ok(customer) + } + + async fn update_customer_default_payment_method( + &self, + customer_id: &stripe::CustomerId, + payment_method_id: Option<&str>, + ) -> anyhow::Result { + let mut state = self.state.lock().unwrap(); + let customer = state + .customers + .iter_mut() + .find(|c| &c.id == customer_id) + .ok_or_else(|| anyhow::anyhow!("customer not found: {customer_id}"))?; + let settings = customer + .invoice_settings + .get_or_insert_with(Default::default); + settings.default_payment_method = + payment_method_id.map(|id| stripe::Expandable::Id(id.parse().unwrap())); + Ok(customer.clone()) + } + + async fn list_payment_methods( + &self, + customer_id: &stripe::CustomerId, + ) -> anyhow::Result> { + let state = self.state.lock().unwrap(); + Ok(state + .payment_methods + .iter() + .filter(|(cid, _)| cid == customer_id) + .map(|(_, method)| method.clone()) + .collect()) + } + + async fn create_setup_intent( + &self, + _customer_id: &stripe::CustomerId, + ) -> anyhow::Result { + let mut state = self.state.lock().unwrap(); + state.setup_intent_counter += 1; + Ok(stripe::SetupIntent { + client_secret: Some(format!( + "seti_mock_{}_secret_test", + state.setup_intent_counter + )), + ..Default::default() + }) + } + + async fn get_payment_method( + &self, + payment_method_id: &stripe::PaymentMethodId, + ) -> anyhow::Result { + let state = self.state.lock().unwrap(); + state + .payment_methods + .iter() + .find(|(_, method)| &method.id == payment_method_id) + .map(|(_, method)| method.clone()) + .ok_or_else(|| anyhow::anyhow!("payment method not found: {payment_method_id}")) + } + + async fn detach_payment_method( + &self, + payment_method_id: &stripe::PaymentMethodId, + ) -> anyhow::Result { + let mut state = self.state.lock().unwrap(); + let idx = state + .payment_methods + .iter() + .position(|(_, method)| &method.id == payment_method_id) + .ok_or_else(|| anyhow::anyhow!("payment method not found: {payment_method_id}"))?; + let (_, method) = state.payment_methods.remove(idx); + Ok(method) + } + + async fn search_invoices(&self, query: &str) -> anyhow::Result> { + let state = self.state.lock().unwrap(); + Ok(state + .invoices + .iter() + .filter(|(customer_id, _)| query.contains(customer_id.as_str())) + .map(|(_, invoice)| invoice.clone()) + .collect()) + } + + async fn retrieve_payment_intent( + &self, + id: &stripe::PaymentIntentId, + ) -> anyhow::Result { + let state = self.state.lock().unwrap(); + state + .payment_intents + .iter() + .find(|pi| &pi.id == id) + .cloned() + .ok_or_else(|| anyhow::anyhow!("payment intent not found: {id}")) + } +} diff --git a/crates/control-plane-api/src/billing/mod.rs b/crates/control-plane-api/src/billing/mod.rs new file mode 100644 index 00000000000..166488483b2 --- /dev/null +++ b/crates/control-plane-api/src/billing/mod.rs @@ -0,0 +1,13 @@ +pub mod db; +pub mod memory; +pub mod provider; +pub mod stripe_impl; + +pub use billing_types::{InvoiceMetadata, InvoiceSearch, InvoiceType, TENANT_METADATA_KEY}; +pub use db::{ + DbInvoiceRow, InvoiceCursorKey, InvoiceQuery, fetch_invoice_rows_backward, + fetch_invoice_rows_forward, +}; +pub use memory::InMemoryBillingProvider; +pub use provider::{BillingProvider, default_payment_method_id}; +pub use stripe_impl::StripeBillingProvider; diff --git a/crates/control-plane-api/src/billing/provider.rs b/crates/control-plane-api/src/billing/provider.rs new file mode 100644 index 00000000000..332916b922a --- /dev/null +++ b/crates/control-plane-api/src/billing/provider.rs @@ -0,0 +1,85 @@ +/// `BillingProvider` is intentionally a narrow seam around outbound Stripe API +/// calls. It is not meant to be a complete billing service boundary, which is +/// why the interface still uses Stripe-native types. Database-backed billing +/// reads live separately under `billing::db`. +/// +/// This trait exists for two reasons: +/// 1. Keep the Stripe SDK wiring in one place. +/// 2. Make resolver tests deterministic without calling live Stripe. +#[async_trait::async_trait] +pub trait BillingProvider: Send + Sync + std::fmt::Debug { + async fn search_customers(&self, query: &str) -> anyhow::Result>; + + async fn create_customer( + &self, + tenant: &str, + user_email: &str, + user_name: Option<&str>, + ) -> anyhow::Result; + + async fn update_customer_default_payment_method( + &self, + customer_id: &stripe::CustomerId, + payment_method_id: Option<&str>, + ) -> anyhow::Result; + + async fn list_payment_methods( + &self, + customer_id: &stripe::CustomerId, + ) -> anyhow::Result>; + + async fn create_setup_intent( + &self, + customer_id: &stripe::CustomerId, + ) -> anyhow::Result; + + async fn get_payment_method( + &self, + payment_method_id: &stripe::PaymentMethodId, + ) -> anyhow::Result; + + async fn detach_payment_method( + &self, + payment_method_id: &stripe::PaymentMethodId, + ) -> anyhow::Result; + + async fn search_invoices(&self, query: &str) -> anyhow::Result>; + + async fn retrieve_payment_intent( + &self, + id: &stripe::PaymentIntentId, + ) -> anyhow::Result; + + async fn find_customer(&self, tenant: &str) -> anyhow::Result> { + let query = billing_types::customer_search_query(tenant); + let customers = self.search_customers(&query).await?; + Ok(customers.into_iter().next()) + } + + async fn require_customer(&self, tenant: &str) -> anyhow::Result { + self.find_customer(tenant) + .await? + .ok_or_else(|| anyhow::anyhow!("no Stripe customer exists for tenant '{tenant}'")) + } + + async fn find_or_create_customer( + &self, + tenant: &str, + email: &str, + full_name: Option<&str>, + ) -> anyhow::Result { + if let Some(existing) = self.find_customer(tenant).await? { + return Ok(existing); + } + + self.create_customer(tenant, email, full_name).await + } +} + +pub fn default_payment_method_id(customer: &stripe::Customer) -> Option { + customer + .invoice_settings + .as_ref() + .and_then(|s| s.default_payment_method.as_ref()) + .map(|e| e.id().to_string()) +} diff --git a/crates/control-plane-api/src/billing/stripe_impl.rs b/crates/control-plane-api/src/billing/stripe_impl.rs new file mode 100644 index 00000000000..7a3293ba0e1 --- /dev/null +++ b/crates/control-plane-api/src/billing/stripe_impl.rs @@ -0,0 +1,173 @@ +use super::BillingProvider; +use billing_types::{ + SearchParams, TENANT_METADATA_KEY, customer_create_idempotency_key, stripe_search, +}; +use std::collections::HashMap; + +/// Production `BillingProvider` backed by the Stripe API. +#[derive(Clone)] +pub struct StripeBillingProvider { + client: stripe::Client, +} + +// Manual impl: `stripe::Client` doesn't derive `Debug`, and we wouldn't want +// it formatted anyway since it holds the API key. `BillingProvider` requires +// `Debug` so this stub satisfies the bound without leaking the secret. +impl std::fmt::Debug for StripeBillingProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("StripeBillingProvider") + .finish_non_exhaustive() + } +} + +impl StripeBillingProvider { + pub fn new(api_key: String) -> Self { + Self { + client: stripe::Client::new(api_key) + .with_strategy(stripe::RequestStrategy::ExponentialBackoff(4)), + } + } +} + +#[async_trait::async_trait] +impl BillingProvider for StripeBillingProvider { + async fn search_customers(&self, query: &str) -> anyhow::Result> { + stripe_search( + &self.client, + "customers", + SearchParams { + query: query.to_string(), + ..Default::default() + }, + ) + .await + } + + async fn create_customer( + &self, + tenant: &str, + user_email: &str, + user_name: Option<&str>, + ) -> anyhow::Result { + let mut metadata = HashMap::from([ + (TENANT_METADATA_KEY.to_string(), tenant.to_string()), + ("created_by_user_email".to_string(), user_email.to_string()), + ]); + if let Some(name) = user_name { + metadata.insert("created_by_user_name".to_string(), name.to_string()); + } + + let description = format!("Represents the billing entity for Flow tenant '{tenant}'"); + // Stripe's customer-search index lags writes by seconds, so two near- + // simultaneous `find_or_create_customer` calls can both miss in search + // and both create. Pinning a deterministic Idempotency-Key per tenant + // collapses retries inside Stripe's 24h idempotency window. + let client = self + .client + .clone() + .with_strategy(stripe::RequestStrategy::Idempotent( + customer_create_idempotency_key(tenant), + )); + let customer = stripe::Customer::create( + &client, + stripe::CreateCustomer { + email: Some(user_email), + name: Some(tenant), + description: Some(&description), + metadata: Some(metadata), + ..Default::default() + }, + ) + .await?; + Ok(customer) + } + + async fn update_customer_default_payment_method( + &self, + customer_id: &stripe::CustomerId, + payment_method_id: Option<&str>, + ) -> anyhow::Result { + let customer = stripe::Customer::update( + &self.client, + customer_id, + stripe::UpdateCustomer { + invoice_settings: Some(stripe::CustomerInvoiceSettings { + default_payment_method: payment_method_id.map(str::to_string), + ..Default::default() + }), + ..Default::default() + }, + ) + .await?; + Ok(customer) + } + + async fn list_payment_methods( + &self, + customer_id: &stripe::CustomerId, + ) -> anyhow::Result> { + let list = stripe::Customer::retrieve_payment_methods( + &self.client, + customer_id, + stripe::CustomerPaymentMethodRetrieval::default(), + ) + .await?; + Ok(list.data) + } + + async fn create_setup_intent( + &self, + customer_id: &stripe::CustomerId, + ) -> anyhow::Result { + let si = stripe::SetupIntent::create( + &self.client, + stripe::CreateSetupIntent { + customer: Some(customer_id.clone()), + description: Some("Store your payment details"), + automatic_payment_methods: Some(stripe::CreateSetupIntentAutomaticPaymentMethods { + enabled: true, + ..Default::default() + }), + ..Default::default() + }, + ) + .await?; + Ok(si) + } + + async fn get_payment_method( + &self, + payment_method_id: &stripe::PaymentMethodId, + ) -> anyhow::Result { + let pm = stripe::PaymentMethod::retrieve(&self.client, payment_method_id, &[]).await?; + Ok(pm) + } + + async fn detach_payment_method( + &self, + payment_method_id: &stripe::PaymentMethodId, + ) -> anyhow::Result { + let pm = stripe::PaymentMethod::detach(&self.client, payment_method_id).await?; + Ok(pm) + } + + async fn search_invoices(&self, query: &str) -> anyhow::Result> { + stripe_search( + &self.client, + "invoices", + SearchParams { + query: query.to_string(), + ..Default::default() + }, + ) + .await + } + + async fn retrieve_payment_intent( + &self, + id: &stripe::PaymentIntentId, + ) -> anyhow::Result { + let pi = stripe::PaymentIntent::retrieve(&self.client, id, &["latest_charge"]).await?; + Ok(pi) + } +} diff --git a/crates/control-plane-api/src/lib.rs b/crates/control-plane-api/src/lib.rs index 93d8532d118..53234fcdc81 100644 --- a/crates/control-plane-api/src/lib.rs +++ b/crates/control-plane-api/src/lib.rs @@ -4,6 +4,7 @@ use sqlx::types::Uuid; pub mod alert_subscriptions; pub mod alerts; +pub mod billing; pub mod connector_tags; pub mod controllers; pub mod data_plane; @@ -22,6 +23,9 @@ pub mod publications; pub mod server; mod text_json; +#[cfg(any(test, feature = "test-support"))] +pub mod test_support; + #[cfg(test)] pub(crate) mod test_server; diff --git a/crates/control-plane-api/src/server/mod.rs b/crates/control-plane-api/src/server/mod.rs index fd0c385f687..ff90670a68c 100644 --- a/crates/control-plane-api/src/server/mod.rs +++ b/crates/control-plane-api/src/server/mod.rs @@ -34,6 +34,7 @@ pub enum Rejection { /// App is the wired application state of the control-plane API. pub struct App { pub _id_generator: std::sync::Mutex, + pub billing_provider: Option>, pub control_plane_jwt_decode_keys: Vec, pub control_plane_jwt_encode_key: tokens::jwt::EncodingKey, pub pg_pool: sqlx::PgPool, @@ -44,6 +45,7 @@ pub struct App { impl App { pub fn new( id_generator: models::IdGenerator, + billing_provider: Option>, jwt_secret: &[u8], pg_pool: sqlx::PgPool, publisher: crate::publications::Publisher, @@ -51,6 +53,7 @@ impl App { ) -> Self { Self { _id_generator: std::sync::Mutex::new(id_generator), + billing_provider, control_plane_jwt_decode_keys: vec![tokens::jwt::DecodingKey::from_secret(jwt_secret)], control_plane_jwt_encode_key: tokens::jwt::EncodingKey::from_secret(jwt_secret), pg_pool, diff --git a/crates/control-plane-api/src/server/public/graphql/billing/invoices.rs b/crates/control-plane-api/src/server/public/graphql/billing/invoices.rs new file mode 100644 index 00000000000..fd88401a266 --- /dev/null +++ b/crates/control-plane-api/src/server/public/graphql/billing/invoices.rs @@ -0,0 +1,626 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use super::super::filters; +use super::payment_methods::{CardPaymentMethodDetails, UsBankAccountPaymentMethodDetails}; +use crate::billing::{self, BillingProvider, InvoiceCursorKey, InvoiceQuery, InvoiceType}; +use anyhow::Context as _; +use async_graphql::{ + ComplexObject, Context, InputObject, Result, SimpleObject, + connection::{self}, + dataloader::{DataLoader, Loader}, +}; +use chrono::NaiveDate; + +pub(super) type InvoiceCursor = InvoiceCursorKey; + +impl connection::CursorType for InvoiceCursorKey { + type Error = anyhow::Error; + + fn decode_cursor(s: &str) -> std::result::Result { + let mut splits = s.split(';'); + let Some(date_end) = splits.next() else { + anyhow::bail!("invalid invoice cursor, no date_end: '{s}'"); + }; + let Some(date_start) = splits.next() else { + anyhow::bail!("invalid invoice cursor, no date_start: '{s}'"); + }; + let Some(invoice_type) = splits.next() else { + anyhow::bail!("invalid invoice cursor, no invoice_type: '{s}'"); + }; + + let date_end = + NaiveDate::parse_from_str(date_end, "%Y-%m-%d").context("invalid invoice cursor")?; + let date_start = + NaiveDate::parse_from_str(date_start, "%Y-%m-%d").context("invalid invoice cursor")?; + let invoice_type = invoice_type.parse::().map_err(|()| { + anyhow::anyhow!("invalid invoice cursor, unknown invoice type: '{invoice_type}'") + })?; + + Ok(Self { + date_start, + date_end, + invoice_type, + }) + } + + fn encode_cursor(&self) -> String { + format!( + "{};{};{}", + self.date_end, + self.date_start, + self.invoice_type.as_str() + ) + } +} + +#[derive(Debug, Clone, Default, InputObject)] +pub struct InvoiceTypeFilter { + pub eq: Option, +} + +#[derive(Debug, Clone, Default, InputObject)] +pub struct InvoiceFilter { + pub date_start: Option, + pub date_end: Option, + pub invoice_type: Option, +} + +impl InvoiceFilter { + pub(super) fn into_query(self) -> InvoiceQuery { + let date_start = self.date_start.unwrap_or_default(); + let date_end = self.date_end.unwrap_or_default(); + + InvoiceQuery { + date_start_gt: date_start.gt, + date_start_lt: date_start.lt, + date_end_gt: date_end.gt, + date_end_lt: date_end.lt, + invoice_type_eq: self.invoice_type.and_then(|f| f.eq), + } + } +} + +#[derive(Debug, Clone, SimpleObject)] +#[graphql(complex)] +pub struct Invoice { + pub date_start: String, + pub date_end: String, + pub invoice_type: InvoiceType, + pub subtotal: i32, + pub line_items: async_graphql::Json, + pub extra: async_graphql::Json, + #[graphql(skip)] + tenant: String, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, async_graphql::Enum)] +pub enum ChargeStatus { + Failed, + Pending, + Succeeded, +} + +impl From<&stripe::ChargeStatus> for ChargeStatus { + fn from(s: &stripe::ChargeStatus) -> Self { + match s { + stripe::ChargeStatus::Failed => Self::Failed, + stripe::ChargeStatus::Pending => Self::Pending, + stripe::ChargeStatus::Succeeded => Self::Succeeded, + } + } +} + +#[derive(Debug, Clone, SimpleObject)] +pub struct InvoicePaymentDetails { + pub status: ChargeStatus, + pub receipt_url: Option, + pub card: Option, + pub us_bank_account: Option, +} + +impl InvoicePaymentDetails { + fn from_charge(charge: &stripe::Charge) -> Self { + let (card, us_bank_account) = match charge.payment_method_details { + Some(ref details) => ( + details.card.as_ref().map(CardPaymentMethodDetails::from), + details + .us_bank_account + .as_ref() + .map(UsBankAccountPaymentMethodDetails::from), + ), + None => (None, None), + }; + Self { + status: ChargeStatus::from(&charge.status), + receipt_url: charge.receipt_url.clone(), + card, + us_bank_account, + } + } +} + +#[ComplexObject] +impl Invoice { + async fn amount_due(&self, ctx: &Context<'_>) -> Result> { + Ok(self + .stripe_invoice(ctx) + .await? + .and_then(|inv| inv.amount_due)) + } + + async fn status(&self, ctx: &Context<'_>) -> Result> { + Ok(self + .stripe_invoice(ctx) + .await? + .and_then(|inv| inv.status.as_ref().map(|s| s.as_str().to_string()))) + } + + async fn invoice_pdf(&self, ctx: &Context<'_>) -> Result> { + Ok(self + .stripe_invoice(ctx) + .await? + .and_then(|inv| inv.invoice_pdf.clone())) + } + + async fn hosted_invoice_url(&self, ctx: &Context<'_>) -> Result> { + Ok(self + .stripe_invoice(ctx) + .await? + .and_then(|inv| inv.hosted_invoice_url.clone())) + } + + async fn payment_details(&self, ctx: &Context<'_>) -> Result> { + let Some(invoice) = self.stripe_invoice(ctx).await? else { + return Ok(None); + }; + let Some(ref pi) = invoice.payment_intent else { + return Ok(None); + }; + let loader = ctx.data::>()?; + let charge = loader.load_one(pi.id()).await?; + Ok(charge.map(|ref c| InvoicePaymentDetails::from_charge(c))) + } +} + +impl Invoice { + pub(super) fn from_row(row: billing::DbInvoiceRow) -> Self { + Self { + date_start: row.date_start.to_string(), + date_end: row.date_end.to_string(), + invoice_type: row.invoice_type, + subtotal: row.subtotal, + line_items: async_graphql::Json(row.line_items.0), + extra: async_graphql::Json(row.extra.0), + tenant: row.billed_prefix, + } + } + + async fn stripe_invoice(&self, ctx: &Context<'_>) -> Result> { + let customer_loader = ctx.data::>()?; + let Some(customer) = customer_loader.load_one(self.tenant.clone()).await? else { + return Ok(None); + }; + let loader = ctx.data::>()?; + loader + .load_one(StripeInvoiceKey { + customer_id: customer.id, + date_start: self.date_start.clone(), + date_end: self.date_end.clone(), + invoice_type: self.invoice_type, + }) + .await + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct StripeInvoiceKey { + customer_id: stripe::CustomerId, + date_start: String, + date_end: String, + invoice_type: InvoiceType, +} + +pub(in crate::server::public::graphql) struct StripeInvoiceLoader(pub Arc); + +impl Loader for StripeInvoiceLoader { + type Value = stripe::Invoice; + type Error = async_graphql::Error; + + async fn load( + &self, + keys: &[StripeInvoiceKey], + ) -> Result> { + let searches = keys.iter().map(|key| { + let query = billing_types::InvoiceSearch { + customer_id: Some(key.customer_id.as_str()), + invoice_type: Some(key.invoice_type), + period_start: Some(&key.date_start), + period_end: Some(&key.date_end), + status: billing_types::StatusFilter::Exclude(stripe::InvoiceStatus::Draft), + } + .to_query(); + let provider = self.0.clone(); + let key = key.clone(); + + async move { + let invoice = provider + .search_invoices(&query) + .await + .map_err(|err| async_graphql::Error::new(err.to_string()))? + .into_iter() + .next(); + Ok::<_, async_graphql::Error>((key, invoice)) + } + }); + + futures::future::join_all(searches) + .await + .into_iter() + .filter_map(|result| match result { + Err(err) => Some(Err(err)), + Ok((_, None)) => None, + Ok((key, Some(invoice))) => Some(Ok((key, invoice))), + }) + .collect() + } +} + +/// Request-scoped loader that resolves Stripe charges by payment intent ID. +pub(in crate::server::public::graphql) struct ChargeDataLoader(pub Arc); + +impl Loader for ChargeDataLoader { + type Value = stripe::Charge; + type Error = async_graphql::Error; + + async fn load( + &self, + keys: &[stripe::PaymentIntentId], + ) -> Result> { + let lookups = keys.iter().map(|pi_id| { + let provider = self.0.clone(); + let pi_id = pi_id.clone(); + async move { + let pi = provider + .retrieve_payment_intent(&pi_id) + .await + .map_err(|err| async_graphql::Error::new(err.to_string()))?; + let charge = match pi.latest_charge { + Some(stripe::Expandable::Object(charge)) => Some(*charge), + _ => None, + }; + Ok::<_, async_graphql::Error>((pi_id, charge)) + } + }); + + futures::future::join_all(lookups) + .await + .into_iter() + .filter_map(|result| match result { + Err(err) => Some(Err(err)), + Ok((_, None)) => None, + Ok((pi_id, Some(charge))) => Some(Ok((pi_id, charge))), + }) + .collect() + } +} + +#[cfg(test)] +mod tests { + use super::super::test_util::*; + use crate::billing; + use crate::test_server; + use serde_json::json; + use std::sync::Arc; + + #[sqlx::test( + migrations = "../../supabase/migrations", + fixtures(path = "../../../../fixtures", scripts("data_planes", "alice")) + )] + async fn graphql_billing_invoice_filter(pool: sqlx::PgPool) { + let _guard = test_server::init(); + let tenant = "aliceco"; + let user_id = provision_test_tenant(&pool, tenant).await; + + insert_billing_historical(&pool, tenant, "2024-01-01", 1234, "Usage").await; + insert_billing_historical(&pool, tenant, "2024-02-01", 900, "Usage").await; + + let (server, token) = start_server_and_token(&pool, user_id, tenant, mock_provider()).await; + + let response: serde_json::Value = server + .graphql( + &json!({ + "query": r#" + query { + tenant(name: "aliceco/") { + name + billing { + invoices( + first: 10 + filter: { + invoiceType: { eq: FINAL } + dateStart: { gt: "2023-12-31", lt: "2024-02-01" } + } + ) { + edges { + node { + dateStart + dateEnd + invoiceType + subtotal + lineItems + extra + } + } + } + } + } + } + "# + }), + Some(&token), + ) + .await; + + insta::assert_json_snapshot!("invoice_filter_by_date_start", response); + + let by_end: serde_json::Value = server + .graphql( + &json!({ + "query": r#" + query { + tenant(name: "aliceco/") { + billing { + invoices( + first: 10 + filter: { + invoiceType: { eq: FINAL } + dateEnd: { gt: "2024-01-31", lt: "2024-03-01" } + } + ) { + edges { node { dateStart dateEnd invoiceType } } + } + } + } + } + "# + }), + Some(&token), + ) + .await; + let edges = &by_end["data"]["tenant"]["billing"]["invoices"]["edges"]; + assert_eq!(edges.as_array().map(Vec::len), Some(1)); + assert_eq!(edges[0]["node"]["dateStart"].as_str(), Some("2024-02-01")); + assert_eq!(edges[0]["node"]["dateEnd"].as_str(), Some("2024-02-29")); + } + + #[sqlx::test( + migrations = "../../supabase/migrations", + fixtures(path = "../../../../fixtures", scripts("data_planes", "alice")) + )] + async fn graphql_billing_invoice_stripe_fields(pool: sqlx::PgPool) { + let _guard = test_server::init(); + let tenant = "invoicefields"; + let user_id = provision_test_tenant(&pool, tenant).await; + + insert_billing_historical(&pool, tenant, "2024-02-01", 2500, "Manual").await; + + let mock = billing::InMemoryBillingProvider::new(); + mock.add_customer("invoicefields/", "cus_invoice", None); + mock.add_invoice( + "cus_invoice", + stripe::Invoice { + amount_due: Some(2600), + status: Some(stripe::InvoiceStatus::Paid), + invoice_pdf: Some("https://example.test/invoice.pdf".to_string()), + hosted_invoice_url: Some("https://example.test/hosted".to_string()), + payment_intent: Some(stripe::Expandable::Id("pi_test_123".parse().unwrap())), + ..Default::default() + }, + ); + mock.add_payment_intent(stripe::PaymentIntent { + id: "pi_test_123".parse().unwrap(), + latest_charge: Some(stripe::Expandable::Object(Box::new(stripe::Charge { + status: stripe::ChargeStatus::Succeeded, + receipt_url: Some("https://example.test/receipt".to_string()), + payment_method_details: Some(stripe::PaymentMethodDetails { + card: Some(stripe::PaymentMethodDetailsCard { + brand: Some("visa".to_string()), + last4: Some("4242".to_string()), + exp_month: 12, + exp_year: 2025, + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }))), + ..Default::default() + }); + + let (server, token) = start_server_and_token(&pool, user_id, tenant, Arc::new(mock)).await; + + let response: serde_json::Value = server + .graphql( + &json!({ + "query": r#" + query { + tenant(name: "invoicefields/") { + billing { + invoices( + first: 1 + filter: { + invoiceType: { eq: FINAL } + dateStart: { gt: "2024-01-31", lt: "2024-02-02" } + } + ) { + edges { + node { + amountDue + status + invoicePdf + hostedInvoiceUrl + paymentDetails { + status + receiptUrl + card { brand last4 expMonth expYear } + usBankAccount { bankName last4 accountHolderType } + } + } + } + } + } + } + } + "# + }), + Some(&token), + ) + .await; + + insta::assert_json_snapshot!("invoice_stripe_fields", response); + } + + fn invoices_page(response: &serde_json::Value) -> &serde_json::Value { + &response["data"]["tenant"]["billing"]["invoices"] + } + + fn cursor(page: &serde_json::Value, field: &str) -> String { + page["pageInfo"][field] + .as_str() + .unwrap_or_else(|| panic!("page is missing {field}: {page:#?}")) + .to_string() + } + + async fn fetch_page( + server: &test_server::TestServer, + token: &str, + tenant: &str, + filter: serde_json::Value, + page_args: serde_json::Value, + ) -> serde_json::Value { + let mut variables = serde_json::Map::from_iter([ + ("tenant".to_string(), json!(tenant)), + ("filter".to_string(), filter), + ]); + variables.extend(page_args.as_object().unwrap().clone()); + server + .graphql( + &json!({ "query": INVOICES_PAGE_QUERY, "variables": variables }), + Some(token), + ) + .await + } + + #[sqlx::test( + migrations = "../../supabase/migrations", + fixtures(path = "../../../../fixtures", scripts("data_planes", "alice")) + )] + async fn graphql_billing_invoice_pagination(pool: sqlx::PgPool) { + let _guard = test_server::init(); + let tenant = "invoicepages"; + let user_id = provision_test_tenant(&pool, tenant).await; + + for month in ["2024-01-01", "2024-02-01", "2024-03-01"] { + insert_billing_historical(&pool, tenant, month, 500, "Usage").await; + } + + let (server, token) = start_server_and_token(&pool, user_id, tenant, mock_provider()).await; + let filter = json!({ + "invoiceType": { "eq": "FINAL" }, + "dateStart": { "gt": "2023-12-31", "lt": "2024-04-01" }, + }); + + let first_page = fetch_page( + &server, + &token, + "invoicepages/", + filter.clone(), + json!({"first": 1}), + ) + .await; + insta::assert_json_snapshot!("pagination_first_page", invoices_page(&first_page)); + + let after = cursor(invoices_page(&first_page), "endCursor"); + let second_page = fetch_page( + &server, + &token, + "invoicepages/", + filter.clone(), + json!({"after": after, "first": 1}), + ) + .await; + insta::assert_json_snapshot!("pagination_second_page", invoices_page(&second_page)); + + let before = cursor(invoices_page(&second_page), "startCursor"); + let previous_page = fetch_page( + &server, + &token, + "invoicepages/", + filter, + json!({"before": before, "last": 1}), + ) + .await; + insta::assert_json_snapshot!("pagination_previous_page", invoices_page(&previous_page)); + } + + #[sqlx::test( + migrations = "../../supabase/migrations", + fixtures(path = "../../../../fixtures", scripts("data_planes", "alice")) + )] + async fn graphql_billing_invoice_tie_break_pagination(pool: sqlx::PgPool) { + let _guard = test_server::init(); + let tenant = "invoicetie"; + let user_id = provision_test_tenant(&pool, tenant).await; + + for month in ["2024-02-01", "2024-03-01"] { + insert_billing_historical(&pool, tenant, month, 500, "Usage").await; + } + + sqlx::query( + r#" + insert into internal.manual_bills (tenant, usd_cents, description, date_start, date_end) + values ($1, 700, 'Manual adjustment', '2024-03-01', '2024-03-31') + "#, + ) + .bind(format!("{tenant}/")) + .execute(&pool) + .await + .expect("insert manual bill"); + + let (server, token) = start_server_and_token(&pool, user_id, tenant, mock_provider()).await; + let filter = json!({ "dateStart": { "gt": "2024-01-31", "lt": "2024-04-01" } }); + + let first_page = fetch_page( + &server, + &token, + "invoicetie/", + filter.clone(), + json!({"first": 1}), + ) + .await; + insta::assert_json_snapshot!("tie_break_first_page", invoices_page(&first_page)); + + let after = cursor(invoices_page(&first_page), "endCursor"); + let second_page = fetch_page( + &server, + &token, + "invoicetie/", + filter.clone(), + json!({"after": after, "first": 1}), + ) + .await; + insta::assert_json_snapshot!("tie_break_second_page", invoices_page(&second_page)); + + let before = cursor(invoices_page(&second_page), "startCursor"); + let previous_page = fetch_page( + &server, + &token, + "invoicetie/", + filter, + json!({"before": before, "last": 1}), + ) + .await; + insta::assert_json_snapshot!("tie_break_previous_page", invoices_page(&previous_page)); + } +} diff --git a/crates/control-plane-api/src/server/public/graphql/billing/mod.rs b/crates/control-plane-api/src/server/public/graphql/billing/mod.rs new file mode 100644 index 00000000000..cf42669ad23 --- /dev/null +++ b/crates/control-plane-api/src/server/public/graphql/billing/mod.rs @@ -0,0 +1,310 @@ +use std::sync::Arc; + +use crate::billing::BillingProvider; +use async_graphql::Context; + +mod invoices; +mod mutations; +mod payment_methods; +mod tenant; + +pub(super) use invoices::{ChargeDataLoader, StripeInvoiceLoader}; +pub use mutations::BillingMutation; +pub(super) use tenant::CustomerDataLoader; + +fn billing_provider(ctx: &Context<'_>) -> async_graphql::Result> { + ctx.data::>() + .cloned() + .map_err(|_| async_graphql::Error::new("Billing is not configured")) +} + +#[cfg(test)] +pub(super) mod test_util { + use crate::{billing, test_server}; + use serde_json::json; + use std::sync::Arc; + + pub async fn provision_test_tenant(pool: &sqlx::PgPool, tenant: &str) -> uuid::Uuid { + crate::test_support::provision_test_tenant( + pool, + tenant, + &format!("{tenant}@example.test"), + json!({"full_name": format!("{tenant} admin")}), + ) + .await + } + + pub fn mock_provider() -> Arc { + Arc::new(billing::InMemoryBillingProvider::new()) + } + + pub async fn insert_billing_historical( + pool: &sqlx::PgPool, + tenant: &str, + month: &str, + subtotal: i32, + description: &str, + ) { + let billed_at = format!("{month}T00:00:00Z"); + sqlx::query( + r#" + insert into internal.billing_historicals (tenant, billed_month, report) + values ( + $1, + $2::timestamptz, + jsonb_build_object( + 'billed_month', $2, + 'subtotal', $3::int, + 'line_items', jsonb_build_array(jsonb_build_object('description', $4, 'subtotal', $3::int)) + ) + ) + "#, + ) + .bind(format!("{tenant}/")) + .bind(&billed_at) + .bind(subtotal) + .bind(description) + .execute(pool) + .await + .expect("insert billing historical"); + } + + pub const INVOICES_PAGE_QUERY: &str = r#" + query InvoicesPage( + $tenant: String! + $filter: InvoiceFilter + $after: String + $before: String + $first: Int + $last: Int + ) { + tenant(name: $tenant) { + billing { + invoices( + after: $after + before: $before + first: $first + last: $last + filter: $filter + ) { + pageInfo { hasNextPage hasPreviousPage startCursor endCursor } + edges { cursor node { dateStart dateEnd invoiceType } } + } + } + } + } + "#; + + pub async fn start_server_and_token( + pool: &sqlx::PgPool, + user_id: uuid::Uuid, + tenant: &str, + provider: Arc, + ) -> (test_server::TestServer, String) { + let server = test_server::TestServer::start_with_config( + pool.clone(), + test_server::snapshot(pool.clone(), true).await, + Some(provider), + models::AlertConfig::default(), + ) + .await; + let token = server.make_access_token(user_id, Some(&format!("{tenant}@example.test"))); + (server, token) + } +} + +#[cfg(test)] +mod tests { + use super::test_util::provision_test_tenant; + use crate::{billing, test_server}; + use serde_json::json; + use std::sync::Arc; + + async fn attach_test_card( + client: &stripe::Client, + customer_id: &stripe::CustomerId, + test_pm_token: &str, + ) -> stripe::PaymentMethod { + let pm_id: stripe::PaymentMethodId = test_pm_token.parse().unwrap(); + stripe::PaymentMethod::attach( + client, + &pm_id, + stripe::AttachPaymentMethod { + customer: customer_id.clone(), + }, + ) + .await + .expect("attach test payment method") + } + + async fn wait_for_customer_searchable( + provider: &dyn billing::BillingProvider, + tenant: &str, + ) -> stripe::Customer { + for _ in 0..30 { + if let Ok(Some(customer)) = provider.find_customer(tenant).await { + return customer; + } + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + } + panic!("customer for tenant '{tenant}' never became searchable after 60s"); + } + + /// Exercises every Stripe API call made by the billing GraphQL mutations: + /// - Customer search, create, update + /// - SetupIntent create + /// - PaymentMethod list, detach + #[ignore = "requires STRIPE_API_KEY set to a Stripe testmode key"] + #[sqlx::test( + migrations = "../../supabase/migrations", + fixtures(path = "../../../../fixtures", scripts("data_planes", "alice")) + )] + async fn graphql_billing_live_stripe(pool: sqlx::PgPool) { + use crate::billing::StripeBillingProvider; + + let _guard = test_server::init(); + let stripe_key = + std::env::var("STRIPE_API_KEY").expect("STRIPE_API_KEY must be set to run this test"); + let stripe_client = stripe::Client::new(stripe_key.clone()); + + let tenant = format!("stripeit{}", uuid::Uuid::new_v4().simple()); + let user_id = provision_test_tenant(&pool, &tenant).await; + let provider: Arc = + Arc::new(StripeBillingProvider::new(stripe_key)); + let server = test_server::TestServer::start_with_config( + pool.clone(), + test_server::snapshot(pool, true).await, + Some(provider.clone()), + models::AlertConfig::default(), + ) + .await; + let token = server.make_access_token(user_id, Some(&format!("{tenant}@example.test"))); + + // Phase 1: createBillingSetupIntent for a new tenant. + // Exercises: Customer search (miss) → Customer create → SetupIntent create. + let response: serde_json::Value = server + .graphql( + &json!({ + "query": format!(r#" + mutation {{ + createBillingSetupIntent(tenant: "{tenant}/") {{ + clientSecret + }} + }} + "#) + }), + Some(&token), + ) + .await; + assert!( + response["data"]["createBillingSetupIntent"]["clientSecret"] + .as_str() + .is_some(), + "setup intent should return a client secret: {response:?}" + ); + + // Phase 2: Wait for the customer to become searchable. + // Stripe's /customers/search API has eventual consistency; + // all subsequent GraphQL mutations depend on search to find the customer. + let customer = wait_for_customer_searchable(provider.as_ref(), &format!("{tenant}/")).await; + + // Attach two payment methods directly via the Stripe API (simulates + // what Stripe.js does client-side after the SetupIntent completes). + let card_a = attach_test_card(&stripe_client, &customer.id, "pm_card_visa").await; + let card_b = attach_test_card(&stripe_client, &customer.id, "pm_card_mastercard").await; + + // Phase 3: setBillingPaymentMethod. + // Exercises: Customer search (hit) → Customer update → Customer search + PaymentMethod list. + let response: serde_json::Value = server + .graphql( + &json!({ + "query": format!(r#" + mutation {{ + setBillingPaymentMethod(tenant: "{tenant}/", paymentMethodId: "{}") {{ + primaryPaymentMethod {{ id }} + paymentMethods {{ id }} + }} + }} + "#, card_a.id) + }), + Some(&token), + ) + .await; + assert_eq!( + response["data"]["setBillingPaymentMethod"]["primaryPaymentMethod"]["id"], + json!(card_a.id.to_string()), + "card_a should be set as primary: {response:?}" + ); + let pm_ids: Vec<&str> = response["data"]["setBillingPaymentMethod"]["paymentMethods"] + .as_array() + .unwrap() + .iter() + .filter_map(|pm| pm["id"].as_str()) + .collect(); + assert!( + pm_ids.contains(&card_a.id.as_str()), + "card_a should be listed" + ); + assert!( + pm_ids.contains(&card_b.id.as_str()), + "card_b should be listed" + ); + + // Phase 4: deleteBillingPaymentMethod. + // Exercises: PaymentMethod detach → Customer search → PaymentMethod list → Customer update (fallback). + let response: serde_json::Value = server + .graphql( + &json!({ + "query": format!(r#" + mutation {{ + deleteBillingPaymentMethod(tenant: "{tenant}/", paymentMethodId: "{}") {{ + primaryPaymentMethod {{ id }} + paymentMethods {{ id }} + }} + }} + "#, card_a.id) + }), + Some(&token), + ) + .await; + assert_eq!( + response["data"]["deleteBillingPaymentMethod"]["primaryPaymentMethod"]["id"], + json!(card_b.id.to_string()), + "card_b should become primary after deleting card_a: {response:?}" + ); + let pm_ids: Vec<&str> = response["data"]["deleteBillingPaymentMethod"]["paymentMethods"] + .as_array() + .unwrap() + .iter() + .filter_map(|pm| pm["id"].as_str()) + .collect(); + assert!( + !pm_ids.contains(&card_a.id.as_str()), + "card_a should be gone" + ); + assert!(pm_ids.contains(&card_b.id.as_str()), "card_b should remain"); + + // Phase 5: createBillingSetupIntent again for the same tenant. + // Exercises the "find" branch of find_or_create_customer (customer already exists). + let response: serde_json::Value = server + .graphql( + &json!({ + "query": format!(r#" + mutation {{ + createBillingSetupIntent(tenant: "{tenant}/") {{ + clientSecret + }} + }} + "#) + }), + Some(&token), + ) + .await; + assert!( + response["data"]["createBillingSetupIntent"]["clientSecret"] + .as_str() + .is_some(), + "setup intent for existing customer should return a client secret: {response:?}" + ); + } +} diff --git a/crates/control-plane-api/src/server/public/graphql/billing/mutations.rs b/crates/control-plane-api/src/server/public/graphql/billing/mutations.rs new file mode 100644 index 00000000000..5dad7f1f51c --- /dev/null +++ b/crates/control-plane-api/src/server/public/graphql/billing/mutations.rs @@ -0,0 +1,356 @@ +use super::super::tenant::{validate_tenant_name, verify_tenant}; +use super::billing_provider; +use super::payment_methods::PaymentMethod; +use crate::billing::{self, BillingProvider}; +use anyhow::Context as _; +use async_graphql::{Context, Result, SimpleObject}; + +async fn require_customer_payment_methods( + provider: &dyn BillingProvider, + customer_id: &stripe::CustomerId, + payment_method_id: &str, +) -> Result> { + let methods = provider + .list_payment_methods(customer_id) + .await + .map_err(|err| async_graphql::Error::new(err.to_string()))?; + + if methods + .iter() + .all(|method| method.id.as_str() != payment_method_id) + { + return Err(async_graphql::Error::new("payment method not found")); + } + + Ok(methods) +} + +#[derive(Debug, Default)] +pub struct BillingMutation; + +#[async_graphql::Object] +impl BillingMutation { + async fn create_billing_setup_intent( + &self, + ctx: &Context<'_>, + tenant: String, + ) -> Result { + let env = ctx.data::()?; + let tenant = validate_tenant_name(&tenant)?; + verify_tenant(env, tenant.as_str(), models::Capability::Admin).await?; + + let claims = env.claims()?; + let email = claims + .email + .as_deref() + .context("authenticated user is missing an email claim")?; + let full_name: Option = sqlx::query_scalar( + "SELECT raw_user_meta_data->>'full_name' FROM auth.users WHERE id = $1", + ) + .bind(claims.sub) + .fetch_one(&env.pg_pool) + .await + .map_err(|err| async_graphql::Error::new(err.to_string()))?; + + let provider = billing_provider(ctx)?; + let customer = provider + .as_ref() + .find_or_create_customer(tenant.as_str(), email, full_name.as_deref()) + .await + .map_err(|err| async_graphql::Error::new(err.to_string()))?; + let setup_intent = provider + .create_setup_intent(&customer.id) + .await + .map_err(|err| async_graphql::Error::new(err.to_string()))?; + let client_secret = setup_intent + .client_secret + .context("stripe setup intent response was missing client_secret")?; + + Ok(CreateBillingSetupIntentPayload { client_secret }) + } + + async fn set_billing_payment_method( + &self, + ctx: &Context<'_>, + tenant: String, + payment_method_id: String, + ) -> Result { + let env = ctx.data::()?; + let tenant = validate_tenant_name(&tenant)?; + verify_tenant(env, tenant.as_str(), models::Capability::Admin).await?; + + let provider = billing_provider(ctx)?; + let customer = provider + .as_ref() + .require_customer(tenant.as_str()) + .await + .map_err(|err| async_graphql::Error::new(err.to_string()))?; + let methods = + require_customer_payment_methods(provider.as_ref(), &customer.id, &payment_method_id) + .await?; + let updated_customer = provider + .update_customer_default_payment_method(&customer.id, Some(payment_method_id.as_str())) + .await + .map_err(|err| async_graphql::Error::new(err.to_string()))?; + + let primary_payment_method = billing::default_payment_method_id(&updated_customer) + .and_then(|id| methods.iter().find(|m| m.id.as_str() == id)) + .map(PaymentMethod::from); + Ok(BillingPaymentMethodPayload { + payment_methods: methods.iter().map(PaymentMethod::from).collect(), + primary_payment_method, + }) + } + + async fn delete_billing_payment_method( + &self, + ctx: &Context<'_>, + tenant: String, + payment_method_id: String, + ) -> Result { + let env = ctx.data::()?; + let tenant = validate_tenant_name(&tenant)?; + verify_tenant(env, tenant.as_str(), models::Capability::Admin).await?; + + let provider = billing_provider(ctx)?; + let customer = provider + .as_ref() + .require_customer(tenant.as_str()) + .await + .map_err(|err| async_graphql::Error::new(err.to_string()))?; + let methods = + require_customer_payment_methods(provider.as_ref(), &customer.id, &payment_method_id) + .await?; + let deleted_payment_method_id: stripe::PaymentMethodId = payment_method_id + .parse() + .map_err(|_| async_graphql::Error::new("invalid payment method ID"))?; + let deleted_default_payment_method = billing::default_payment_method_id(&customer) + .as_deref() + == Some(payment_method_id.as_str()); + + provider + .detach_payment_method(&deleted_payment_method_id) + .await + .map_err(|err| async_graphql::Error::new(err.to_string()))?; + let remaining_methods: Vec = methods + .into_iter() + .filter(|method| method.id.as_str() != payment_method_id) + .collect(); + + let primary_id = if deleted_default_payment_method { + let fallback = remaining_methods + .first() + .map(|method| method.id.to_string()); + let updated_customer = provider + .update_customer_default_payment_method(&customer.id, fallback.as_deref()) + .await + .map_err(|err| async_graphql::Error::new(err.to_string()))?; + billing::default_payment_method_id(&updated_customer) + } else { + billing::default_payment_method_id(&customer) + }; + let primary_payment_method = primary_id + .and_then(|id| remaining_methods.iter().find(|m| m.id.as_str() == id)) + .map(PaymentMethod::from); + + Ok(BillingPaymentMethodPayload { + payment_methods: remaining_methods.iter().map(PaymentMethod::from).collect(), + primary_payment_method, + }) + } +} + +#[derive(Debug, Clone, SimpleObject)] +pub struct CreateBillingSetupIntentPayload { + client_secret: String, +} + +#[derive(Debug, Clone, SimpleObject)] +pub struct BillingPaymentMethodPayload { + payment_methods: Vec, + primary_payment_method: Option, +} + +#[cfg(test)] +mod tests { + use super::super::test_util::*; + use crate::billing; + use crate::test_server; + use serde_json::json; + use std::sync::Arc; + + #[sqlx::test( + migrations = "../../supabase/migrations", + fixtures(path = "../../../../fixtures", scripts("data_planes", "alice")) + )] + async fn graphql_billing_payment_methods_and_mutations(pool: sqlx::PgPool) { + let _guard = test_server::init(); + let tenant = "billingmock"; + let user_id = provision_test_tenant(&pool, tenant).await; + let victim_tenant = "billingvictim"; + let victim_user_id = provision_test_tenant(&pool, victim_tenant).await; + + let mock = billing::InMemoryBillingProvider::new(); + mock.add_customer("billingmock/", "cus_123", Some("pm_1")); + mock.add_payment_method( + "cus_123", + "pm_1", + stripe::PaymentMethodType::Card, + stripe::BillingDetails { + name: Some("Alice".to_string()), + ..Default::default() + }, + Some(stripe::CardDetails { + brand: "visa".to_string(), + last4: "4242".to_string(), + ..Default::default() + }), + None, + ); + mock.add_payment_method( + "cus_123", + "pm_2", + stripe::PaymentMethodType::UsBankAccount, + stripe::BillingDetails { + name: Some("Alice".to_string()), + ..Default::default() + }, + None, + Some(stripe::PaymentMethodUsBankAccount { + bank_name: Some("STRIPE TEST BANK".to_string()), + last4: Some("6789".to_string()), + ..Default::default() + }), + ); + mock.add_customer("billingvictim/", "cus_victim", Some("pm_v")); + mock.add_payment_method( + "cus_victim", + "pm_v", + stripe::PaymentMethodType::Card, + stripe::BillingDetails { + name: Some("Victim".to_string()), + ..Default::default() + }, + Some(stripe::CardDetails { + brand: "visa".to_string(), + last4: "4444".to_string(), + exp_month: 12, + exp_year: 2030, + ..Default::default() + }), + None, + ); + + let (server, token) = start_server_and_token(&pool, user_id, tenant, Arc::new(mock)).await; + let victim_token = server.make_access_token( + victim_user_id, + Some(&format!("{victim_tenant}@example.test")), + ); + + let query_response: serde_json::Value = server + .graphql( + &json!({ + "query": r#" + query { + tenant(name: "billingmock/") { + billing { + primaryPaymentMethod { id } + paymentMethods { + id + type + billingDetails { + name + } + card { + brand + last4 + expMonth + expYear + } + usBankAccount { + bankName + last4 + accountHolderType + } + } + } + } + } + "# + }), + Some(&token), + ) + .await; + insta::assert_json_snapshot!("payment_methods_query", query_response); + + let mutation_response: serde_json::Value = server + .graphql( + &json!({ + "query": r#" + mutation { + setBillingPaymentMethod(tenant: "billingmock/", paymentMethodId: "pm_2") { + primaryPaymentMethod { id } + paymentMethods { id } + } + } + "# + }), + Some(&token), + ) + .await; + insta::assert_json_snapshot!("set_payment_method", mutation_response); + + // Delete the current default (pm_2); expect fallback to promote pm_1. + let delete_default_response: serde_json::Value = server + .graphql( + &json!({ + "query": r#" + mutation { + deleteBillingPaymentMethod(tenant: "billingmock/", paymentMethodId: "pm_2") { + primaryPaymentMethod { id } + paymentMethods { id } + } + } + "# + }), + Some(&token), + ) + .await; + insta::assert_json_snapshot!("delete_default_payment_method", delete_default_response); + + let cross_tenant_delete_response: serde_json::Value = server + .graphql( + &json!({ + "query": r#" + mutation { + deleteBillingPaymentMethod(tenant: "billingmock/", paymentMethodId: "pm_v") { + primaryPaymentMethod { id } + } + } + "# + }), + Some(&token), + ) + .await; + insta::assert_json_snapshot!("cross_tenant_delete_denied", cross_tenant_delete_response); + + let victim_query_response: serde_json::Value = server + .graphql( + &json!({ + "query": r#" + query { + tenant(name: "billingvictim/") { + billing { + primaryPaymentMethod { id } + paymentMethods { id } + } + } + } + "# + }), + Some(&victim_token), + ) + .await; + insta::assert_json_snapshot!("victim_tenant_query", victim_query_response); + } +} diff --git a/crates/control-plane-api/src/server/public/graphql/billing/payment_methods.rs b/crates/control-plane-api/src/server/public/graphql/billing/payment_methods.rs new file mode 100644 index 00000000000..ac21e4568ca --- /dev/null +++ b/crates/control-plane-api/src/server/public/graphql/billing/payment_methods.rs @@ -0,0 +1,100 @@ +use async_graphql::SimpleObject; + +#[derive(Debug, Clone, SimpleObject)] +pub struct PaymentMethod { + pub id: String, + #[graphql(name = "type")] + pub type_: String, + pub billing_details: PaymentMethodBillingDetails, + pub card: Option, + pub us_bank_account: Option, +} + +#[derive(Debug, Clone, SimpleObject)] +pub struct PaymentMethodBillingDetails { + pub name: Option, +} + +impl From<&stripe::BillingDetails> for PaymentMethodBillingDetails { + fn from(details: &stripe::BillingDetails) -> Self { + Self { + name: details.name.clone(), + } + } +} + +#[derive(Debug, Clone, SimpleObject)] +pub struct CardPaymentMethodDetails { + pub brand: String, + pub last4: String, + pub exp_month: i64, + pub exp_year: i64, +} + +impl From<&stripe::CardDetails> for CardPaymentMethodDetails { + fn from(card: &stripe::CardDetails) -> Self { + Self { + brand: card.brand.clone(), + last4: card.last4.clone(), + exp_month: card.exp_month, + exp_year: card.exp_year, + } + } +} + +impl From<&stripe::PaymentMethodDetailsCard> for CardPaymentMethodDetails { + fn from(card: &stripe::PaymentMethodDetailsCard) -> Self { + Self { + brand: card.brand.clone().unwrap_or_default(), + last4: card.last4.clone().unwrap_or_default(), + exp_month: card.exp_month, + exp_year: card.exp_year, + } + } +} + +#[derive(Debug, Clone, SimpleObject)] +pub struct UsBankAccountPaymentMethodDetails { + pub bank_name: Option, + pub last4: Option, + pub account_holder_type: Option, +} + +impl From<&stripe::PaymentMethodUsBankAccount> for UsBankAccountPaymentMethodDetails { + fn from(account: &stripe::PaymentMethodUsBankAccount) -> Self { + Self { + bank_name: account.bank_name.clone(), + last4: account.last4.clone(), + account_holder_type: account + .account_holder_type + .map(|kind| kind.as_str().to_string()), + } + } +} + +impl From<&stripe::PaymentMethodDetailsUsBankAccount> for UsBankAccountPaymentMethodDetails { + fn from(account: &stripe::PaymentMethodDetailsUsBankAccount) -> Self { + Self { + bank_name: account.bank_name.clone(), + last4: account.last4.clone(), + account_holder_type: account + .account_holder_type + .map(|kind| kind.as_str().to_string()), + } + } +} + +impl From<&stripe::PaymentMethod> for PaymentMethod { + fn from(pm: &stripe::PaymentMethod) -> Self { + Self { + id: pm.id.to_string(), + type_: pm.type_.as_str().to_string(), + billing_details: PaymentMethodBillingDetails::from(&pm.billing_details), + card: pm.card.as_ref().map(CardPaymentMethodDetails::from), + us_bank_account: pm + .us_bank_account + .as_ref() + .map(UsBankAccountPaymentMethodDetails::from), + } + } +} diff --git a/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__invoice_filter_by_date_start.snap b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__invoice_filter_by_date_start.snap new file mode 100644 index 00000000000..a8125036395 --- /dev/null +++ b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__invoice_filter_by_date_start.snap @@ -0,0 +1,41 @@ +--- +source: crates/control-plane-api/src/server/public/graphql/billing/invoices.rs +expression: response +--- +{ + "data": { + "tenant": { + "billing": { + "invoices": { + "edges": [ + { + "node": { + "dateEnd": "2024-01-31", + "dateStart": "2024-01-01", + "extra": { + "billed_month": "2024-01-01T00:00:00Z", + "line_items": [ + { + "description": "Usage", + "subtotal": 1234 + } + ], + "subtotal": 1234 + }, + "invoiceType": "FINAL", + "lineItems": [ + { + "description": "Usage", + "subtotal": 1234 + } + ], + "subtotal": 1234 + } + } + ] + } + }, + "name": "aliceco/" + } + } +} diff --git a/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__invoice_stripe_fields.snap b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__invoice_stripe_fields.snap new file mode 100644 index 00000000000..4d66c3420a3 --- /dev/null +++ b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__invoice_stripe_fields.snap @@ -0,0 +1,36 @@ +--- +source: crates/control-plane-api/src/server/public/graphql/billing/invoices.rs +assertion_line: 454 +expression: response +--- +{ + "data": { + "tenant": { + "billing": { + "invoices": { + "edges": [ + { + "node": { + "amountDue": 2600, + "hostedInvoiceUrl": "https://example.test/hosted", + "invoicePdf": "https://example.test/invoice.pdf", + "paymentDetails": { + "card": { + "brand": "visa", + "expMonth": 12, + "expYear": 2025, + "last4": "4242" + }, + "receiptUrl": "https://example.test/receipt", + "status": "SUCCEEDED", + "usBankAccount": null + }, + "status": "paid" + } + } + ] + } + } + } + } +} diff --git a/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__pagination_first_page.snap b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__pagination_first_page.snap new file mode 100644 index 00000000000..357b45b53ac --- /dev/null +++ b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__pagination_first_page.snap @@ -0,0 +1,22 @@ +--- +source: crates/control-plane-api/src/server/public/graphql/billing/invoices.rs +expression: invoices_page(&first_page) +--- +{ + "edges": [ + { + "cursor": "2024-03-31;2024-03-01;final", + "node": { + "dateEnd": "2024-03-31", + "dateStart": "2024-03-01", + "invoiceType": "FINAL" + } + } + ], + "pageInfo": { + "endCursor": "2024-03-31;2024-03-01;final", + "hasNextPage": true, + "hasPreviousPage": false, + "startCursor": "2024-03-31;2024-03-01;final" + } +} diff --git a/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__pagination_previous_page.snap b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__pagination_previous_page.snap new file mode 100644 index 00000000000..6ed6adbc3bb --- /dev/null +++ b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__pagination_previous_page.snap @@ -0,0 +1,22 @@ +--- +source: crates/control-plane-api/src/server/public/graphql/billing/invoices.rs +expression: invoices_page(&previous_page) +--- +{ + "edges": [ + { + "cursor": "2024-03-31;2024-03-01;final", + "node": { + "dateEnd": "2024-03-31", + "dateStart": "2024-03-01", + "invoiceType": "FINAL" + } + } + ], + "pageInfo": { + "endCursor": "2024-03-31;2024-03-01;final", + "hasNextPage": true, + "hasPreviousPage": false, + "startCursor": "2024-03-31;2024-03-01;final" + } +} diff --git a/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__pagination_second_page.snap b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__pagination_second_page.snap new file mode 100644 index 00000000000..b5066d9e8f4 --- /dev/null +++ b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__pagination_second_page.snap @@ -0,0 +1,22 @@ +--- +source: crates/control-plane-api/src/server/public/graphql/billing/invoices.rs +expression: invoices_page(&second_page) +--- +{ + "edges": [ + { + "cursor": "2024-02-29;2024-02-01;final", + "node": { + "dateEnd": "2024-02-29", + "dateStart": "2024-02-01", + "invoiceType": "FINAL" + } + } + ], + "pageInfo": { + "endCursor": "2024-02-29;2024-02-01;final", + "hasNextPage": true, + "hasPreviousPage": true, + "startCursor": "2024-02-29;2024-02-01;final" + } +} diff --git a/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__tie_break_first_page.snap b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__tie_break_first_page.snap new file mode 100644 index 00000000000..357b45b53ac --- /dev/null +++ b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__tie_break_first_page.snap @@ -0,0 +1,22 @@ +--- +source: crates/control-plane-api/src/server/public/graphql/billing/invoices.rs +expression: invoices_page(&first_page) +--- +{ + "edges": [ + { + "cursor": "2024-03-31;2024-03-01;final", + "node": { + "dateEnd": "2024-03-31", + "dateStart": "2024-03-01", + "invoiceType": "FINAL" + } + } + ], + "pageInfo": { + "endCursor": "2024-03-31;2024-03-01;final", + "hasNextPage": true, + "hasPreviousPage": false, + "startCursor": "2024-03-31;2024-03-01;final" + } +} diff --git a/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__tie_break_previous_page.snap b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__tie_break_previous_page.snap new file mode 100644 index 00000000000..6ed6adbc3bb --- /dev/null +++ b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__tie_break_previous_page.snap @@ -0,0 +1,22 @@ +--- +source: crates/control-plane-api/src/server/public/graphql/billing/invoices.rs +expression: invoices_page(&previous_page) +--- +{ + "edges": [ + { + "cursor": "2024-03-31;2024-03-01;final", + "node": { + "dateEnd": "2024-03-31", + "dateStart": "2024-03-01", + "invoiceType": "FINAL" + } + } + ], + "pageInfo": { + "endCursor": "2024-03-31;2024-03-01;final", + "hasNextPage": true, + "hasPreviousPage": false, + "startCursor": "2024-03-31;2024-03-01;final" + } +} diff --git a/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__tie_break_second_page.snap b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__tie_break_second_page.snap new file mode 100644 index 00000000000..cba3cdb5d17 --- /dev/null +++ b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__invoices__tests__tie_break_second_page.snap @@ -0,0 +1,22 @@ +--- +source: crates/control-plane-api/src/server/public/graphql/billing/invoices.rs +expression: invoices_page(&second_page) +--- +{ + "edges": [ + { + "cursor": "2024-03-31;2024-03-01;manual", + "node": { + "dateEnd": "2024-03-31", + "dateStart": "2024-03-01", + "invoiceType": "MANUAL" + } + } + ], + "pageInfo": { + "endCursor": "2024-03-31;2024-03-01;manual", + "hasNextPage": true, + "hasPreviousPage": true, + "startCursor": "2024-03-31;2024-03-01;manual" + } +} diff --git a/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__mutations__tests__cross_tenant_delete_denied.snap b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__mutations__tests__cross_tenant_delete_denied.snap new file mode 100644 index 00000000000..842d97baeff --- /dev/null +++ b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__mutations__tests__cross_tenant_delete_denied.snap @@ -0,0 +1,22 @@ +--- +source: crates/control-plane-api/src/server/public/graphql/billing/mutations.rs +assertion_line: 337 +expression: cross_tenant_delete_response +--- +{ + "data": null, + "errors": [ + { + "locations": [ + { + "column": 27, + "line": 3 + } + ], + "message": "payment method not found", + "path": [ + "deleteBillingPaymentMethod" + ] + } + ] +} diff --git a/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__mutations__tests__delete_default_payment_method.snap b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__mutations__tests__delete_default_payment_method.snap new file mode 100644 index 00000000000..4b2ee643fb8 --- /dev/null +++ b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__mutations__tests__delete_default_payment_method.snap @@ -0,0 +1,19 @@ +--- +source: crates/control-plane-api/src/server/public/graphql/billing/mutations.rs +assertion_line: 321 +expression: delete_default_response +--- +{ + "data": { + "deleteBillingPaymentMethod": { + "paymentMethods": [ + { + "id": "pm_1" + } + ], + "primaryPaymentMethod": { + "id": "pm_1" + } + } + } +} diff --git a/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__mutations__tests__payment_methods_query.snap b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__mutations__tests__payment_methods_query.snap new file mode 100644 index 00000000000..5c17b9597eb --- /dev/null +++ b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__mutations__tests__payment_methods_query.snap @@ -0,0 +1,45 @@ +--- +source: crates/control-plane-api/src/server/public/graphql/billing/mutations.rs +assertion_line: 286 +expression: query_response +--- +{ + "data": { + "tenant": { + "billing": { + "paymentMethods": [ + { + "billingDetails": { + "name": "Alice" + }, + "card": { + "brand": "visa", + "expMonth": 0, + "expYear": 0, + "last4": "4242" + }, + "id": "pm_1", + "type": "card", + "usBankAccount": null + }, + { + "billingDetails": { + "name": "Alice" + }, + "card": null, + "id": "pm_2", + "type": "us_bank_account", + "usBankAccount": { + "accountHolderType": null, + "bankName": "STRIPE TEST BANK", + "last4": "6789" + } + } + ], + "primaryPaymentMethod": { + "id": "pm_1" + } + } + } + } +} diff --git a/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__mutations__tests__set_payment_method.snap b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__mutations__tests__set_payment_method.snap new file mode 100644 index 00000000000..205ccb2f616 --- /dev/null +++ b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__mutations__tests__set_payment_method.snap @@ -0,0 +1,22 @@ +--- +source: crates/control-plane-api/src/server/public/graphql/billing/mutations.rs +assertion_line: 303 +expression: mutation_response +--- +{ + "data": { + "setBillingPaymentMethod": { + "paymentMethods": [ + { + "id": "pm_1" + }, + { + "id": "pm_2" + } + ], + "primaryPaymentMethod": { + "id": "pm_2" + } + } + } +} diff --git a/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__mutations__tests__victim_tenant_query.snap b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__mutations__tests__victim_tenant_query.snap new file mode 100644 index 00000000000..0e06339ec75 --- /dev/null +++ b/crates/control-plane-api/src/server/public/graphql/billing/snapshots/control_plane_api__server__public__graphql__billing__mutations__tests__victim_tenant_query.snap @@ -0,0 +1,21 @@ +--- +source: crates/control-plane-api/src/server/public/graphql/billing/mutations.rs +assertion_line: 356 +expression: victim_query_response +--- +{ + "data": { + "tenant": { + "billing": { + "paymentMethods": [ + { + "id": "pm_v" + } + ], + "primaryPaymentMethod": { + "id": "pm_v" + } + } + } + } +} diff --git a/crates/control-plane-api/src/server/public/graphql/billing/tenant.rs b/crates/control-plane-api/src/server/public/graphql/billing/tenant.rs new file mode 100644 index 00000000000..cdadaa63898 --- /dev/null +++ b/crates/control-plane-api/src/server/public/graphql/billing/tenant.rs @@ -0,0 +1,196 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use super::super::tenant::{Tenant, verify_tenant}; +use super::billing_provider; +use super::invoices::{Invoice, InvoiceCursor, InvoiceFilter}; +use super::payment_methods::PaymentMethod; +use crate::billing::{self, BillingProvider}; +use async_graphql::{ + ComplexObject, Context, Result, + connection::{self, Connection}, + dataloader::{DataLoader, Loader}, +}; + +#[ComplexObject] +impl Tenant { + async fn billing(&self, ctx: &Context<'_>) -> Result { + let env = ctx.data::()?; + verify_tenant(env, &self.name, models::Capability::Admin).await?; + let provider = billing_provider(ctx)?; + Ok(TenantBilling::new(self.name.clone(), provider)) + } +} + +#[derive(Debug, Clone)] +pub struct TenantBilling { + tenant: String, + provider: Arc, +} + +impl TenantBilling { + fn new(tenant: String, provider: Arc) -> Self { + Self { tenant, provider } + } +} + +pub(in crate::server::public::graphql) struct CustomerDataLoader(pub Arc); + +impl Loader for CustomerDataLoader { + type Value = stripe::Customer; + type Error = async_graphql::Error; + + async fn load(&self, keys: &[String]) -> Result> { + let lookups = keys.iter().map(|tenant| { + let provider = self.0.clone(); + let tenant = tenant.clone(); + async move { + let customer = provider + .find_customer(&tenant) + .await + .map_err(|err| async_graphql::Error::new(err.to_string()))?; + Ok::<_, async_graphql::Error>((tenant, customer)) + } + }); + + futures::future::join_all(lookups) + .await + .into_iter() + .filter_map(|result| match result { + Err(err) => Some(Err(err)), + Ok((_, None)) => None, + Ok((tenant, Some(customer))) => Some(Ok((tenant, customer))), + }) + .collect() + } +} + +#[async_graphql::Object] +impl TenantBilling { + async fn payment_methods(&self, ctx: &Context<'_>) -> Result> { + let loader = ctx.data::>()?; + let Some(customer) = loader.load_one(self.tenant.clone()).await? else { + return Ok(Vec::new()); + }; + let methods = self + .provider + .list_payment_methods(&customer.id) + .await + .map_err(|err| async_graphql::Error::new(err.to_string()))?; + Ok(methods.iter().map(PaymentMethod::from).collect()) + } + + async fn primary_payment_method(&self, ctx: &Context<'_>) -> Result> { + let loader = ctx.data::>()?; + let Some(customer) = loader.load_one(self.tenant.clone()).await? else { + return Ok(None); + }; + let Some(primary_id) = billing::default_payment_method_id(&customer) else { + return Ok(None); + }; + let pm = self + .provider + .get_payment_method(&primary_id.parse().map_err(|_| { + async_graphql::Error::new("invalid payment method ID in customer default") + })?) + .await + .map_err(|err| async_graphql::Error::new(err.to_string()))?; + Ok(Some(PaymentMethod::from(&pm))) + } + + async fn invoices( + &self, + ctx: &Context<'_>, + filter: Option, + after: Option, + before: Option, + first: Option, + last: Option, + ) -> Result> { + let env = ctx.data::()?; + let tenant = self.tenant.clone(); + let query = filter.unwrap_or_default().into_query(); + + connection::query_with::( + after, + before, + first, + last, + |after, before, first, last| async move { + let (rows, has_prev, has_next) = if before.is_some() || last.is_some() { + let (rows, has_prev) = billing::fetch_invoice_rows_backward( + &env.pg_pool, + &tenant, + &query, + before, + last, + ) + .await + .map_err(async_graphql::Error::from)?; + (rows, has_prev, before.is_some()) + } else { + let (rows, has_next) = billing::fetch_invoice_rows_forward( + &env.pg_pool, + &tenant, + &query, + after, + first, + ) + .await + .map_err(async_graphql::Error::from)?; + (rows, after.is_some(), has_next) + }; + + let mut connection = Connection::new(has_prev, has_next); + connection.edges.extend(rows.into_iter().map(|row| { + let cursor = InvoiceCursor::from_row(&row); + let invoice = Invoice::from_row(row); + connection::Edge::new(cursor, invoice) + })); + Ok(connection) + }, + ) + .await + } +} + +#[cfg(test)] +mod tests { + use super::super::test_util::*; + use crate::test_server; + use serde_json::json; + + #[sqlx::test( + migrations = "../../supabase/migrations", + fixtures(path = "../../../../fixtures", scripts("data_planes", "alice")) + )] + async fn graphql_tenant_query_authorization(pool: sqlx::PgPool) { + let _guard = test_server::init(); + let owner_tenant = "tenantowner"; + let target_tenant = "tenanttarget"; + let owner_user_id = provision_test_tenant(&pool, owner_tenant).await; + let _target_user_id = provision_test_tenant(&pool, target_tenant).await; + + let (server, token) = + start_server_and_token(&pool, owner_user_id, owner_tenant, mock_provider()).await; + + // verify_tenant runs before tenant_exists, so querying another tenant + // (or a nonexistent one) fails identically; one assertion is enough. + let unauthorized: serde_json::Value = server + .graphql( + &json!({ + "query": format!(r#" + query {{ + tenant(name: "{target_tenant}/") {{ + name + }} + }} + "#) + }), + Some(&token), + ) + .await; + assert_eq!(unauthorized["data"]["tenant"], serde_json::Value::Null); + assert_eq!(unauthorized["errors"].as_array().map(Vec::len), Some(1)); + } +} diff --git a/crates/control-plane-api/src/server/public/graphql/filters.rs b/crates/control-plane-api/src/server/public/graphql/filters.rs index 994e27440a6..7de4758cff8 100644 --- a/crates/control-plane-api/src/server/public/graphql/filters.rs +++ b/crates/control-plane-api/src/server/public/graphql/filters.rs @@ -1,8 +1,16 @@ +use chrono::NaiveDate; + #[derive(Debug, Clone, Default, async_graphql::InputObject)] pub struct BoolFilter { pub eq: Option, } +#[derive(Debug, Clone, Default, async_graphql::InputObject)] +pub struct DateFilter { + pub gt: Option, + pub lt: Option, +} + #[derive(Debug, Clone, Default, async_graphql::InputObject)] pub struct PrefixFilter { pub starts_with: Option, diff --git a/crates/control-plane-api/src/server/public/graphql/mod.rs b/crates/control-plane-api/src/server/public/graphql/mod.rs index be94746c1fc..e3b012003da 100644 --- a/crates/control-plane-api/src/server/public/graphql/mod.rs +++ b/crates/control-plane-api/src/server/public/graphql/mod.rs @@ -26,6 +26,7 @@ mod alert_subscriptions; mod alert_types; mod alerts; mod authorized_prefixes; +mod billing; mod data_planes; mod filters; pub(crate) use data_planes::parse_data_plane_name; @@ -38,6 +39,7 @@ mod prefixes; mod publication_history; pub mod status; mod storage_mappings; +mod tenant; /// A JSON object, the shape of which is opaque to the graphql schema pub type JsonObject = async_graphql::Json>; @@ -64,11 +66,13 @@ pub struct QueryRoot( data_planes::DataPlanesQuery, invite_links::InviteLinksQuery, connectors::ConnectorsQuery, + tenant::TenantQuery, ); // Represents the portion of the GraphQL schema that deals with mutations. #[derive(Debug, Default, async_graphql::MergedObject)] pub struct MutationRoot( + billing::BillingMutation, storage_mappings::StorageMappingsMutation, alert_configs::AlertConfigsMutation, alert_subscriptions::AlertSubscriptionsMutation, @@ -94,19 +98,37 @@ pub fn schema_sdl() -> String { #[axum::debug_handler(state=std::sync::Arc)] pub(crate) async fn graphql_handler( + axum::extract::State(app): axum::extract::State>, axum::Extension(schema): axum::Extension, env: crate::Envelope, axum::extract::Json(req): axum::extract::Json, ) -> axum::response::Response { let pg_pool = env.pg_pool.clone(); - let request = req + let mut request = req .data(env) .data(async_graphql::dataloader::DataLoader::new( PgDataLoader(pg_pool), tokio::spawn, )); + if let Some(ref billing_provider) = app.billing_provider { + request = request + .data(billing_provider.clone()) + .data(async_graphql::dataloader::DataLoader::new( + billing::StripeInvoiceLoader(billing_provider.clone()), + tokio::spawn, + )) + .data(async_graphql::dataloader::DataLoader::new( + billing::ChargeDataLoader(billing_provider.clone()), + tokio::spawn, + )) + .data(async_graphql::dataloader::DataLoader::new( + billing::CustomerDataLoader(billing_provider.clone()), + tokio::spawn, + )); + } + let response = schema.execute(request).await; // Check for AuthZRetry errors - return 307 redirect for the first one found. diff --git a/crates/control-plane-api/src/server/public/graphql/tenant.rs b/crates/control-plane-api/src/server/public/graphql/tenant.rs new file mode 100644 index 00000000000..e9a45f1f6a5 --- /dev/null +++ b/crates/control-plane-api/src/server/public/graphql/tenant.rs @@ -0,0 +1,69 @@ +use async_graphql::{Context, Result, SimpleObject}; +use validator::Validate; + +#[derive(Debug, Default)] +pub struct TenantQuery; + +#[async_graphql::Object] +impl TenantQuery { + async fn tenant(&self, ctx: &Context<'_>, name: String) -> Result> { + let env = ctx.data::()?; + let tenant = validate_tenant_name(&name)?; + + verify_tenant(env, tenant.as_str(), models::Capability::Read).await?; + + if !tenant_exists(&env.pg_pool, tenant.as_str()).await? { + return Ok(None); + } + + Ok(Some(Tenant { + name: tenant.to_string(), + })) + } +} + +#[derive(Debug, Clone, SimpleObject)] +#[graphql(complex)] +pub struct Tenant { + pub name: String, +} + +pub(super) async fn tenant_exists(pool: &sqlx::PgPool, tenant: &str) -> Result { + let exists = sqlx::query_scalar::<_, bool>( + r#" + SELECT EXISTS( + SELECT 1 + FROM tenants + WHERE tenant = $1 + ) + "#, + ) + .bind(tenant) + .fetch_one(pool) + .await?; + + Ok(exists) +} + +pub(super) async fn verify_tenant( + env: &crate::Envelope, + tenant: &str, + capability: models::Capability, +) -> Result<()> { + let policy_result = crate::server::evaluate_names_authorization( + env.snapshot(), + env.claims()?, + capability, + [tenant], + ); + let (_expiry, ()) = env.authorization_outcome(policy_result).await?; + Ok(()) +} + +pub(super) fn validate_tenant_name(name: &str) -> Result { + let prefix = models::Prefix::new(name); + prefix + .validate() + .map_err(|err| async_graphql::Error::new(format!("invalid tenant name: {err}")))?; + Ok(prefix) +} diff --git a/crates/control-plane-api/src/test_server.rs b/crates/control-plane-api/src/test_server.rs index 576ddd600bd..ee3f243fb9c 100644 --- a/crates/control-plane-api/src/test_server.rs +++ b/crates/control-plane-api/src/test_server.rs @@ -70,13 +70,34 @@ pub struct TestServer { impl TestServer { pub async fn start(pg_pool: sqlx::PgPool, snapshot: Arc>) -> Self { - Self::start_with_alert_defaults(pg_pool, snapshot, models::AlertConfig::default()).await + Self::start_with_config( + pg_pool, + snapshot, + Some(Arc::new(crate::billing::InMemoryBillingProvider::new())), + models::AlertConfig::default(), + ) + .await } pub async fn start_with_alert_defaults( pg_pool: sqlx::PgPool, snapshot: Arc>, alert_config_defaults: models::AlertConfig, + ) -> Self { + Self::start_with_config( + pg_pool, + snapshot, + Some(Arc::new(crate::billing::InMemoryBillingProvider::new())), + alert_config_defaults, + ) + .await + } + + pub async fn start_with_config( + pg_pool: sqlx::PgPool, + snapshot: Arc>, + billing_provider: Option>, + alert_config_defaults: models::AlertConfig, ) -> Self { let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); // TODO(johnny): Aggregate into a sink? @@ -95,6 +116,7 @@ impl TestServer { let app = Arc::new(crate::App::new( models::IdGenerator::new(0), + billing_provider, b"test-jwt-secret-for-integration-tests", pg_pool.clone(), publisher, diff --git a/crates/control-plane-api/src/test_support.rs b/crates/control-plane-api/src/test_support.rs new file mode 100644 index 00000000000..61ac5d9a100 --- /dev/null +++ b/crates/control-plane-api/src/test_support.rs @@ -0,0 +1,35 @@ +pub async fn provision_test_tenant( + pool: &sqlx::PgPool, + tenant: &str, + email: &str, + user_meta: serde_json::Value, +) -> uuid::Uuid { + let user_id = uuid::Uuid::new_v4(); + let mut txn = pool.begin().await.expect("begin txn"); + + sqlx::query(r#"insert into auth.users (id, email, raw_user_meta_data) values ($1, $2, $3)"#) + .bind(user_id) + .bind(email) + .bind(&user_meta) + .execute(&mut *txn) + .await + .expect("insert auth user"); + + crate::directives::beta_onboard::provision_tenant( + "support@estuary.dev", + Some("test tenant".to_string()), + tenant, + user_id, + &mut txn, + ) + .await + .expect("provision tenant"); + + sqlx::query(r#"delete from role_grants where subject_role = 'estuary_support/';"#) + .execute(&mut *txn) + .await + .expect("delete support grant"); + + txn.commit().await.expect("commit tenant"); + user_id +} diff --git a/mise/tasks/local/control-plane b/mise/tasks/local/control-plane index 8ecab30abd0..e5f9eb7cb29 100755 --- a/mise/tasks/local/control-plane +++ b/mise/tasks/local/control-plane @@ -28,6 +28,13 @@ SSL_CERT_FILE=${FLOW_LOCAL}/ca.crt EOF +# Forward STRIPE_API_KEY from the invoking shell when set +# (e.g. a testmode key pasted before `mise run`). Without it, billing +# GraphQL operations return an error indicating billing is not configured. +if [ -n "${STRIPE_API_KEY:-}" ]; then + echo "STRIPE_API_KEY=${STRIPE_API_KEY}" >> "${FLOW_LOCAL}/env/agent.env" +fi + cat > "${FLOW_LOCAL}/env/config-encryption.env" <