diff --git a/Cargo.lock b/Cargo.lock index e89fccc..e378a11 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -869,7 +869,7 @@ dependencies = [ [[package]] name = "protobuf" version = "0.0.0" -source = "git+https://github.com/thinkparq/protobuf?rev=e2e774e7db7e3d4474d6e7232bb06bbdffc5610c#e2e774e7db7e3d4474d6e7232bb06bbdffc5610c" +source = "git+https://github.com/thinkparq/protobuf?rev=4d5e5db085065acbbaa5bb76ce4b81d6d733e446#4d5e5db085065acbbaa5bb76ce4b81d6d733e446" dependencies = [ "prost", "prost-types", diff --git a/Cargo.toml b/Cargo.toml index 22f5dd9..d15d815 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ itertools = "0" libc = "0" log = { version = "0", features = ["std"] } prost = "0.14" -protobuf = { git = "https://github.com/thinkparq/protobuf", rev = "e2e774e7db7e3d4474d6e7232bb06bbdffc5610c" } +protobuf = { git = "https://github.com/thinkparq/protobuf", rev = "4d5e5db085065acbbaa5bb76ce4b81d6d733e446" } regex = "1" ring = "0" rusqlite = { version = "0", features = ["bundled", "vtab", "array", "fallible_uint"] } diff --git a/mgmtd/assets/beegfs-mgmtd.toml b/mgmtd/assets/beegfs-mgmtd.toml index 6e43c28..8562550 100644 --- a/mgmtd/assets/beegfs-mgmtd.toml +++ b/mgmtd/assets/beegfs-mgmtd.toml @@ -90,9 +90,6 @@ # Defines after which time without contact a client is considered gone and will be removed. # client-auto-remove-timeout = "30m" -# Disables loading the license library. This disables all enterprise features. -# license-disable = false - # The BeeGFS license certificate file. # license-cert-file = "/etc/beegfs/license.pem" @@ -151,7 +148,7 @@ # quota-group-ids-range = "1000-1100" -### Capacity pools ### +### Capacity pools ### # Sets the limits / boundaries of the meta capacity pools. If changed, the whole block must # be uncommented and set. These cannot be lower than the cap-pool-dynamic-meta-limits below. diff --git a/mgmtd/src/app.rs b/mgmtd/src/app.rs index c978c81..0ed1b89 100644 --- a/mgmtd/src/app.rs +++ b/mgmtd/src/app.rs @@ -81,6 +81,7 @@ pub(crate) trait App: Debug + Clone + Send + 'static { fn load_and_verify_license_cert( &self, cert_path: &Path, + prev_trial_serial: Option, ) -> impl Future> + Send; /// Get license certificate data diff --git a/mgmtd/src/app/runtime.rs b/mgmtd/src/app/runtime.rs index 7431200..941d7eb 100644 --- a/mgmtd/src/app/runtime.rs +++ b/mgmtd/src/app/runtime.rs @@ -164,8 +164,13 @@ impl App for RuntimeApp { } } - async fn load_and_verify_license_cert(&self, cert_path: &Path) -> Result { - LicenseVerifier::load_and_verify_license_cert(&self.license, cert_path).await + async fn load_and_verify_license_cert( + &self, + cert_path: &Path, + prev_trial_serial: Option, + ) -> Result { + LicenseVerifier::load_and_verify_license_cert(&self.license, cert_path, prev_trial_serial) + .await } fn get_license_cert_data(&self) -> Result { diff --git a/mgmtd/src/app/test.rs b/mgmtd/src/app/test.rs index 7d5f424..7e5b65a 100644 --- a/mgmtd/src/app/test.rs +++ b/mgmtd/src/app/test.rs @@ -160,7 +160,11 @@ impl App for TestApp { fn notify_client_pulled_state(&self, _node_type: NodeType, _node_id: NodeId) {} - async fn load_and_verify_license_cert(&self, _cert_path: &std::path::Path) -> Result { + async fn load_and_verify_license_cert( + &self, + _cert_path: &std::path::Path, + _prev_trial_serial: Option, + ) -> Result { Ok("dummy cert".to_string()) } diff --git a/mgmtd/src/bee_msg.rs b/mgmtd/src/bee_msg.rs index c7cd2bd..4152b0f 100644 --- a/mgmtd/src/bee_msg.rs +++ b/mgmtd/src/bee_msg.rs @@ -92,14 +92,14 @@ pub(crate) async fn dispatch_request(app: &RuntimeApp, mut req: impl Request) -> macro_rules! dispatch_msg { ($({$msg_type:path => $r:tt, $ctx_str:literal})*) => { // Match on the message ID provided by the request - match req.msg_id() { + match req.header().msg_id() { $( <$msg_type>::ID => { let des: $msg_type = req.deserialize_msg().with_context(|| { format!( "{} ({}) from {:?}", stringify!($msg_type), - req.msg_id(), + req.header().msg_id(), req.addr() ) })?; @@ -186,7 +186,7 @@ async fn handle_unspecified_msg(req: impl Request) -> Result<()> { log::warn!( "Unhandled msg INCOMING from {:?} with ID {}", req.addr(), - req.msg_id() + req.header().msg_id() ); // Signal to the caller that the msg is not handled. The generic response diff --git a/mgmtd/src/bee_msg/authenticate_channel.rs b/mgmtd/src/bee_msg/authenticate_channel.rs index bfb053a..d6d5789 100644 --- a/mgmtd/src/bee_msg/authenticate_channel.rs +++ b/mgmtd/src/bee_msg/authenticate_channel.rs @@ -27,11 +27,12 @@ impl HandleNoResponse for AuthenticateChannel { mod test { use super::*; use crate::app::test::*; + use shared::bee_msg::Header; #[tokio::test] async fn authenticate_channel() { let app = TestApp::new().await; - let mut req = TestRequest::new(AuthenticateChannel::ID); + let mut req = TestRequest::new(Header::default()); AuthenticateChannel { auth_secret: AuthSecret::hash_from_bytes("secret"), diff --git a/mgmtd/src/bee_msg/change_target_consistency_states.rs b/mgmtd/src/bee_msg/change_target_consistency_states.rs index efe66a4..34c4051 100644 --- a/mgmtd/src/bee_msg/change_target_consistency_states.rs +++ b/mgmtd/src/bee_msg/change_target_consistency_states.rs @@ -95,11 +95,12 @@ doesn't match stored state {old_stored}, no consistency state changes will be ma mod test { use super::*; use crate::app::test::*; + use shared::bee_msg::Header; #[tokio::test] async fn change_target_consistency_states() { let app = TestApp::new().await; - let mut req = TestRequest::new(ChangeTargetConsistencyStates::ID); + let mut req = TestRequest::new(Header::default()); // Prepare times app.db @@ -174,7 +175,7 @@ mod test { #[tokio::test] async fn change_target_consistency_states_old_states() { let app = TestApp::new().await; - let mut req = TestRequest::new(ChangeTargetConsistencyStates::ID); + let mut req = TestRequest::new(Header::default()); // Mismatch of reported old state should not change the consistency states let msg = ChangeTargetConsistencyStates { diff --git a/mgmtd/src/bee_msg/common.rs b/mgmtd/src/bee_msg/common.rs index 6d37186..1f01675 100644 --- a/mgmtd/src/bee_msg/common.rs +++ b/mgmtd/src/bee_msg/common.rs @@ -1,7 +1,8 @@ use super::*; use crate::db::node_nic::ReplaceNic; use db::misc::MetaRoot; -use rusqlite::Transaction; +use protobuf::license::VerifyResult; +use rusqlite::{Transaction, params}; use shared::bee_msg::node::*; use shared::bee_msg::target::*; use shared::types::{NodeId, TargetId}; @@ -9,12 +10,44 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; +// Maximum number of clients that can register if license verification fails or license is invalid +const MAX_NUM_CLIENTS: u32 = 5; + /// Processes incoming node information. Registers new nodes if config allows it -pub(super) async fn update_node(msg: RegisterNode, app: &impl App) -> Result { +pub(super) async fn update_node(msg: RegisterNode, app: &impl App, reject: bool) -> Result { let nics = msg.nics.clone(); let requested_node_id = msg.node_id; let registration_disable = app.static_info().user_config.registration_disable; + let licensed_clients: Option = if msg.node_type == NodeType::Client { + match app.get_license_cert_data() { + Ok(r) => match r.result() { + // If license is valid, no limit to client count + VerifyResult::VerifyValid => None, + // no license file loaded, limit number of clients to NUM_CLIENTS + VerifyResult::VerifyError => Some(MAX_NUM_CLIENTS), + // license file was loaded and is outside validity period + VerifyResult::VerifyInvalid => None, + _ => { + log::debug!( + "Unexpected error during license verification, limiting number of clients to {MAX_NUM_CLIENTS}: {0}", + r.message + ); + Some(MAX_NUM_CLIENTS) + } + }, + Err(e) => { + log::debug!( + "Error during license verification, limiting number of clients to {MAX_NUM_CLIENTS}: {e:#}", + ); + Some(MAX_NUM_CLIENTS) + } + } + } else { + // not a client registration, so not going to be used anyway. But let's be defensive + Some(MAX_NUM_CLIENTS) + }; + let licensed_machines = match app.get_licensed_machines() { Ok(n) => n, Err(err) => { @@ -97,6 +130,22 @@ registration token ({new_alias_or_reg_token}) does not match the stored token ({ bail!("Registration of new nodes is not allowed"); } + let num_reg_clients: u32 = tx.query_row( + sql!("SELECT COUNT(DISTINCT node_uid) FROM nodes WHERE node_type = ?1"), + params![NodeType::Client.sql_variant()], + |row| row.get(0), + )?; + + if msg.node_type == NodeType::Client + && let Some(cs) = licensed_clients + && num_reg_clients >= cs { + if reject { + bail!("Number of licensed clients ({MAX_NUM_CLIENTS}) exhausted. Client registration denied."); + } else { + log::warn!("Number of licensed clients ({MAX_NUM_CLIENTS}) exhausted but client doesn't support rejection."); + } + } + let new_alias = if msg.node_type == NodeType::Client { // In versions prior to 8.0 the string node ID generated by the client // started with a number which is not allowed by the new alias schema. diff --git a/mgmtd/src/bee_msg/get_nodes.rs b/mgmtd/src/bee_msg/get_nodes.rs index acacf68..a2f9e9e 100644 --- a/mgmtd/src/bee_msg/get_nodes.rs +++ b/mgmtd/src/bee_msg/get_nodes.rs @@ -61,11 +61,12 @@ impl HandleWithResponse for GetNodes { mod test { use super::*; use crate::app::test::*; + use shared::bee_msg::Header; #[tokio::test] async fn get_nodes() { let app = TestApp::new().await; - let mut req = TestRequest::new(GetNodes::ID); + let mut req = TestRequest::new(Header::default()); let resp = GetNodes { node_type: NodeType::Meta, diff --git a/mgmtd/src/bee_msg/heartbeat.rs b/mgmtd/src/bee_msg/heartbeat.rs index f3127fd..3d75c24 100644 --- a/mgmtd/src/bee_msg/heartbeat.rs +++ b/mgmtd/src/bee_msg/heartbeat.rs @@ -24,6 +24,7 @@ impl HandleWithResponse for Heartbeat { machine_uuid: self.machine_uuid, }, app, + false, ) .await?; diff --git a/mgmtd/src/bee_msg/register_node.rs b/mgmtd/src/bee_msg/register_node.rs index 7d817ad..10d7cf5 100644 --- a/mgmtd/src/bee_msg/register_node.rs +++ b/mgmtd/src/bee_msg/register_node.rs @@ -2,13 +2,18 @@ use super::*; use common::update_node; use shared::bee_msg::node::*; +const COMPATFLAG_CLIENT_SUPPORTS_REGREJ: u8 = 1; + impl HandleWithResponse for RegisterNode { type Response = RegisterNodeResp; - async fn handle(self, app: &impl App, _req: &mut impl Request) -> Result { + async fn handle(self, app: &impl App, req: &mut impl Request) -> Result { fail_on_pre_shutdown(app)?; - let node_id = update_node(self, app).await?; + let reject = + (req.header().msg_compat_feature_flags & COMPATFLAG_CLIENT_SUPPORTS_REGREJ) != 0; + + let node_id = update_node(self, app, reject).await?; let fs_uuid: String = app .read_tx(|tx| db::config::get(tx, db::config::Config::FsUuid)) diff --git a/mgmtd/src/bee_msg/request_exceeded_quota.rs b/mgmtd/src/bee_msg/request_exceeded_quota.rs index cf9915c..17ba72f 100644 --- a/mgmtd/src/bee_msg/request_exceeded_quota.rs +++ b/mgmtd/src/bee_msg/request_exceeded_quota.rs @@ -1,4 +1,5 @@ use super::*; +use crate::license::LicensedFeature; use rusqlite::params; use shared::bee_msg::quota::*; @@ -13,6 +14,8 @@ impl HandleWithResponse for RequestExceededQuota { } async fn handle(self, app: &impl App, _req: &mut impl Request) -> Result { + app.verify_licensed_feature(LicensedFeature::Quota)?; + let inner = app .read_tx(move |tx| { // Quota is calculated per pool, so if a target ID is given, use its assigned pools @@ -66,11 +69,12 @@ mod test { use super::*; use crate::app::test::*; use crate::bee_msg::HandleWithResponse; + use shared::bee_msg::Header; #[tokio::test] async fn request_exceeded_quota() { let app = TestApp::new().await; - let mut req = TestRequest::new(RequestExceededQuota::ID); + let mut req = TestRequest::new(Header::default()); let tests: &[(_, &[u32])] = &[ ( diff --git a/mgmtd/src/config.rs b/mgmtd/src/config.rs index 7c912fe..bf21083 100644 --- a/mgmtd/src/config.rs +++ b/mgmtd/src/config.rs @@ -279,9 +279,11 @@ generate_structs! { /// Disables loading the license library. /// - /// This disables all enterprise features. + /// Deprecated. Loading a license is now mandatory. #[arg(long)] #[arg(num_args = 0..=1, default_missing_value = "true")] + #[arg(hide = true)] + #[serde(skip)] license_disable: bool = false, /// The BeeGFS license certificate file. [default: /etc/beegfs/license.pem] diff --git a/mgmtd/src/db/config.rs b/mgmtd/src/db/config.rs index 10720ac..96f8948 100644 --- a/mgmtd/src/db/config.rs +++ b/mgmtd/src/db/config.rs @@ -17,11 +17,12 @@ pub(crate) enum Config { #[allow(unused)] FsName, CounterLastClientID, + TrialSerial, } // Config entries that should not be changed after initially set. Note that this only controls the // functions below, the database entries could still be changed by manual query -const IMMUTABLE: &[Config] = &[Config::FsUuid, Config::FsInitDateSecs]; +const IMMUTABLE: &[Config] = &[Config::FsUuid, Config::FsInitDateSecs, Config::TrialSerial]; impl Config { /// The string representation of the config key as it is written to the db @@ -31,6 +32,7 @@ impl Config { Config::FsInitDateSecs => "fs_init_date_secs", Config::FsName => "fs_name", Config::CounterLastClientID => "counter_last_client_id", + Config::TrialSerial => "trial_serial", } } } diff --git a/mgmtd/src/grpc/get_license.rs b/mgmtd/src/grpc/get_license.rs index f8b4479..d41b34b 100644 --- a/mgmtd/src/grpc/get_license.rs +++ b/mgmtd/src/grpc/get_license.rs @@ -1,4 +1,6 @@ use super::*; +use crate::db::config::Config; +use protobuf::license::CertType; use protobuf::management::{self as pm, GetLicenseResponse}; pub(crate) async fn get_license( @@ -7,8 +9,26 @@ pub(crate) async fn get_license( ) -> Result { let reload: bool = required_field(req.reload)?; if reload { - app.load_and_verify_license_cert(&app.static_info().user_config.license_cert_file) + let prev_trial_serial: Option = app + .read_tx(|tx| db::config::get(tx, Config::TrialSerial)) .await?; + + let serial = app + .load_and_verify_license_cert( + &app.static_info().user_config.license_cert_file, + prev_trial_serial.clone(), + ) + .await?; + + if app + .get_license_cert_data()? + .data + .is_some_and(|d| d.r#type() == CertType::Trial) + && prev_trial_serial.is_none() + { + app.write_tx(|tx| db::config::set(tx, Config::TrialSerial, serial)) + .await?; + } } let cert_data = app.get_license_cert_data()?; Ok(GetLicenseResponse { diff --git a/mgmtd/src/lib.rs b/mgmtd/src/lib.rs index cac458a..652165b 100644 --- a/mgmtd/src/lib.rs +++ b/mgmtd/src/lib.rs @@ -16,8 +16,10 @@ use crate::app::RuntimeApp; use crate::config::Config; use anyhow::{Context, Result}; use app::App; +use db::config::Config as dbConfig; use db::node_nic::ReplaceNic; use license::LicenseVerifier; +use protobuf::license::CertType; use shared::bee_msg::target::RefreshTargetStates; use shared::conn::incoming; use shared::conn::outgoing::Pool; @@ -114,6 +116,35 @@ pub async fn start(info: StaticInfo, license: LicenseVerifier) -> Result = db + .read_tx(|tx| db::config::get(tx, db::config::Config::TrialSerial)) + .await?; + + // Load and verify license certificate + match license + .load_and_verify_license_cert( + &info.user_config.license_cert_file, + prev_trial_serial.clone(), + ) + .await + { + Ok(serial) => { + if license + .get_license_cert_data()? + .data + .is_some_and(|d| d.r#type() == CertType::Trial) + && prev_trial_serial.is_none() + { + db.write_tx(|tx| db::config::set(tx, dbConfig::TrialSerial, serial)) + .await?; + } + } + Err(err) => log::warn!( + "Loading and verifying license certificate failed. \ + Licensed features will be unavailable: {err}" + ), + }; + // Fill node addrs store from db db.read_tx(db::node_nic::get_all_addrs) .await? diff --git a/mgmtd/src/license.rs b/mgmtd/src/license.rs index 76eeb5b..4a02d73 100644 --- a/mgmtd/src/license.rs +++ b/mgmtd/src/license.rs @@ -203,6 +203,7 @@ impl LicenseVerifier { pub async fn load_and_verify_license_cert( &self, cert_path: impl AsRef, + prev_trial_serial: Option, ) -> Result { let Some(ref library) = self.0 else { bail!("License verification library not loaded."); @@ -223,14 +224,28 @@ impl LicenseVerifier { let message = res.message; match result { - VerifyResult::VerifyValid => { - log::info!("Successfully loaded license certificate: {serial}"); - Ok(serial) - } + VerifyResult::VerifyValid => match self.get_license_cert_data() { + Ok(c) => { + if c.data.is_some_and(|d| { + d.r#type() == CertType::Trial + && prev_trial_serial.clone().is_some_and(|s| serial != s) + }) { + library.init_cert_store(); + Err(anyhow!( + "Unable to apply trial license {serial}, because system was previously used with trial license {}", + prev_trial_serial.unwrap() + )) + } else { + log::info!("Successfully loaded license certificate: {serial}"); + Ok(serial) + } + } + Err(err) => Err(anyhow!("Error getting license data: {err}")), + }, VerifyResult::VerifyInvalid => Err(anyhow!(message)), - VerifyResult::VerifyError => Err(anyhow!( - "Internal error during certificate verification: {message}" - )), + VerifyResult::VerifyError => { + Err(anyhow!("Error during license verification: {message}")) + } VerifyResult::VerifyUnspecified => Err(anyhow!("Unspecified result.")), } } @@ -288,9 +303,9 @@ impl LicenseVerifier { match result { VerifyResult::VerifyValid => Ok(()), VerifyResult::VerifyInvalid => Err(anyhow!(message)), - VerifyResult::VerifyError => Err(anyhow!( - "Internal error during feature verification: {message}" - )), + VerifyResult::VerifyError => { + Err(anyhow!("Error during feature verification: {message}")) + } VerifyResult::VerifyUnspecified => Err(anyhow!("Unspecified result.")), } } diff --git a/mgmtd/src/main.rs b/mgmtd/src/main.rs index 0fa6617..622e315 100644 --- a/mgmtd/src/main.rs +++ b/mgmtd/src/main.rs @@ -113,33 +113,18 @@ If you want to initialize a new system, refer to --help or doc.beegfs.io.", .max_blocking_threads(user_config.max_blocking_threads) .build()?; + // Load the licensing library + // + // SAFETY: + // There is no way to verify that the user loaded dynamic library matches the + // requirements of LicenseVerifier. After all, users can load anything they + // want. Therefore, this is just not safe to do from the Rust compilers + // perspective and loading anything with non-matching fp signatures or not + // behaving as expected will lead to undefined behavior. + let license = unsafe { LicenseVerifier::with_lib(&user_config.license_lib_file) }; + // Run the tokio executor rt.block_on(async move { - // Load the licensing library - let license = if !user_config.license_disable { - // SAFETY: - // There is no way to verify that the user loaded dynamic library matches the - // requirements of LicenseVerifier. After all, users can load anything they - // want. Therefore, this is just not safe to do from the Rust compilers - // perspective and loading anything with non-matching fp signatures or not - // behaving as expected will lead to undefined behavior. - let license = unsafe { LicenseVerifier::with_lib(&user_config.license_lib_file) }; - - if let Err(err) = license - .load_and_verify_license_cert(&user_config.license_cert_file) - .await - { - log::warn!( - "Initializing licensing library failed. \ - Licensed features will be unavailable: {err}" - ); - } - - license - } else { - LicenseVerifier::with_no_lib() - }; - // Start the actual daemon let run = start( StaticInfo { diff --git a/mgmtd/src/quota.rs b/mgmtd/src/quota.rs index d4129e2..1b8e8bc 100644 --- a/mgmtd/src/quota.rs +++ b/mgmtd/src/quota.rs @@ -3,6 +3,7 @@ mod system_id; use crate::app::*; +use crate::license::LicensedFeature; use crate::types::SqliteEnumExt; use anyhow::{Context as AnyhowContext, Result}; use rusqlite::params; @@ -16,10 +17,14 @@ use sqlite_check::sql; use std::collections::HashSet; use std::path::Path; -/// Fetches quota information for all storage targets, calculates exceeded IDs and distributes them. -pub(crate) async fn update_and_distribute(app: &impl App) -> Result<()> { - // Fetch quota data from storage daemons +/// Fetches quota information for all storage targets and updates the quota usage database +pub(crate) async fn fetch_and_update(app: &impl App) -> Result<()> { + if app.verify_licensed_feature(LicensedFeature::Quota).is_err() { + log::warn!("Quota is enabled but feature not licensed. Skipping quota collection"); + return Ok(()); + } + // Fetch quota data from storage daemons let targets: Vec<(TargetId, PoolId, Uid)> = app .read_tx(move |tx| { tx.query_map_collect( @@ -196,19 +201,20 @@ pub(crate) async fn update_and_distribute(app: &impl App) -> Result<()> { } } - if app.static_info().user_config.quota_enforce { - exceeded_quota(app).await?; - } - Ok(()) } -/// Calculate and push exceeded quota info to the nodes -async fn exceeded_quota(app: &impl App) -> Result<()> { +/// Calculates and pushes exceeded quota info to the nodes +pub(crate) async fn distribute_exceeded(app: &impl App) -> Result<()> { + if !app.static_info().user_config.quota_enforce { + return Ok(()); + } log::info!("Calculating and pushing exceeded quota"); + let quota_licensed = app.verify_licensed_feature(LicensedFeature::Quota).is_ok(); + let (msges, nodes) = app - .read_tx(|tx| { + .read_tx(move |tx| { let pools: Vec<_> = tx.query_map_collect(sql!("SELECT pool_id FROM pools"), [], |row| row.get(0))?; @@ -229,29 +235,36 @@ async fn exceeded_quota(app: &impl App) -> Result<()> { } } - // Fill the prepared messages with matching exceeded quota ids - let mut stmt = tx.prepare_cached(sql!( - "SELECT DISTINCT e.quota_id, e.id_type, e.quota_type, st.pool_id - FROM quota_usage AS e - INNER JOIN targets AS st USING(node_type, target_id) - LEFT JOIN quota_default_limits AS d USING(id_type, quota_type, pool_id) - LEFT JOIN quota_limits AS l USING(quota_id, id_type, quota_type, pool_id) - GROUP BY e.quota_id, e.id_type, e.quota_type, st.pool_id - HAVING SUM(e.value) > COALESCE(l.value, d.value)" - ))?; - let mut rows = stmt.query([])?; - while let Some(row) = rows.next()? { - for m in &mut msges { - if row.get::<_, PoolId>(3)? == m.pool_id - && QuotaIdType::from_row(row, 1)? == m.id_type - && QuotaType::from_row(row, 2)? == m.quota_type - { - m.exceeded_quota_ids.push(row.get(0)?); - break; + if quota_licensed { + // Fill the prepared messages with matching exceeded quota ids + let mut stmt = tx.prepare_cached(sql!( + "SELECT DISTINCT e.quota_id, e.id_type, e.quota_type, st.pool_id + FROM quota_usage AS e + INNER JOIN targets AS st USING(node_type, target_id) + LEFT JOIN quota_default_limits AS d USING(id_type, quota_type, pool_id) + LEFT JOIN quota_limits AS l USING(quota_id, id_type, quota_type, pool_id) + GROUP BY e.quota_id, e.id_type, e.quota_type, st.pool_id + HAVING SUM(e.value) > COALESCE(l.value, d.value)" + ))?; + let mut rows = stmt.query([])?; + while let Some(row) = rows.next()? { + for m in &mut msges { + if row.get::<_, PoolId>(3)? == m.pool_id + && QuotaIdType::from_row(row, 1)? == m.id_type + && QuotaType::from_row(row, 2)? == m.quota_type + { + m.exceeded_quota_ids.push(row.get(0)?); + break; + } } } + } else { + log::info!( + "Quota enforcement enabled but feature not licensed. Removing quota limits from nodes" + ); } + // Get all node uids to send the messages to let nodes: Vec = tx.query_map_collect( sql!("SELECT node_uid FROM nodes WHERE node_type IN (?1,?2)"), @@ -330,7 +343,6 @@ mod test { async fn update() { let app = TestApp::with_config(Config { quota_enable: true, - quota_enforce: false, // Exceeded calculation and push is tested separately quota_user_ids_range: Some(0..=9), quota_group_ids_range: Some(0..=9), ..Default::default() @@ -369,7 +381,7 @@ mod test { })) }); - super::update_and_distribute(&app).await.unwrap(); + super::fetch_and_update(&app).await.unwrap(); // Find the amount of target 1 entries which values match the schema they have been reported // with @@ -425,7 +437,7 @@ mod test { })) }); - super::update_and_distribute(&app).await.unwrap(); + super::fetch_and_update(&app).await.unwrap(); // Now target 2 quota should be empty, target 1 quota should be completely untouched due to // the error (even if it only failed for user quota request) @@ -465,7 +477,7 @@ mod test { })) }); - super::update_and_distribute(&app).await.unwrap(); + super::fetch_and_update(&app).await.unwrap(); // Target 1 should now only have the couple of entries resulting from above app.db @@ -488,7 +500,7 @@ mod test { } #[tokio::test] - async fn exceeded_quota() { + async fn distribute_exceeded() { // This fn doesn't need special config let app = TestApp::new().await; @@ -521,6 +533,6 @@ mod test { })) }); - super::exceeded_quota(&app).await.unwrap(); + super::distribute_exceeded(&app).await.unwrap(); } } diff --git a/mgmtd/src/timer.rs b/mgmtd/src/timer.rs index 568caed..e56bab4 100644 --- a/mgmtd/src/timer.rs +++ b/mgmtd/src/timer.rs @@ -3,8 +3,7 @@ use crate::App; use crate::app::RuntimeApp; use crate::db::{self}; -use crate::license::LicensedFeature; -use crate::quota::update_and_distribute; +use crate::quota::{distribute_exceeded, fetch_and_update}; use shared::bee_msg::target::RefreshTargetStates; use shared::run_state::RunStateHandle; use shared::types::NodeType; @@ -19,13 +18,7 @@ pub(crate) fn start_tasks(app: RuntimeApp, run_state: RunStateHandle) { tokio::spawn(switchover(app.clone(), run_state.clone())); if app.info.user_config.quota_enable { - if let Err(err) = app.license.verify_licensed_feature(LicensedFeature::Quota) { - log::error!( - "Quota is enabled in the config, but the feature could not be verified. Continuing without quota support: {err}" - ); - } else { - tokio::spawn(update_quota(app, run_state)); - } + tokio::spawn(update_quota(app, run_state)); } } @@ -63,9 +56,11 @@ async fn update_quota(app: RuntimeApp, mut run_state: RunStateHandle) { loop { log::debug!("Running quota update"); - match update_and_distribute(&app).await { - Ok(_) => {} - Err(err) => log::error!("Updating quota failed: {err:#}"), + if let Err(e) = fetch_and_update(&app).await { + log::error!("Updating quota failed: {e:#}"); + } + if let Err(e) = distribute_exceeded(&app).await { + log::error!("Distributing exceeded quota failed: {e:#}"); } tokio::select! { diff --git a/shared/src/conn/msg_dispatch.rs b/shared/src/conn/msg_dispatch.rs index 2932148..4bf7824 100644 --- a/shared/src/conn/msg_dispatch.rs +++ b/shared/src/conn/msg_dispatch.rs @@ -1,7 +1,7 @@ //! Facilities for dispatching TCP and UDP messages to their message handlers use super::stream::Stream; -use crate::bee_msg::{Header, Msg, MsgId, deserialize_body, serialize}; +use crate::bee_msg::{Header, Msg, deserialize_body, serialize}; use crate::bee_serde::{Deserializable, Serializable}; use anyhow::Result; use std::fmt::Debug; @@ -25,7 +25,7 @@ pub trait Request: Send + Sync { fn respond(self, msg: &M) -> impl Future> + Send; fn authenticate_connection(&mut self); fn addr(&self) -> SocketAddr; - fn msg_id(&self) -> MsgId; + fn header(&self) -> &Header; fn deserialize_msg(&self) -> Result; } @@ -61,8 +61,8 @@ impl Request for StreamRequest<'_> { deserialize_body(self.header, &self.buf[Header::LEN..]) } - fn msg_id(&self) -> MsgId { - self.header.msg_id() + fn header(&self) -> &Header { + self.header } } @@ -96,24 +96,25 @@ impl Request for SocketRequest<'_> { deserialize_body(self.header, &self.buf[Header::LEN..]) } - fn msg_id(&self) -> MsgId { - self.header.msg_id() + fn header(&self) -> &Header { + self.header } } pub mod test { use super::*; + use crate::bee_msg::Header; use std::net::{Ipv4Addr, SocketAddrV4}; pub struct TestRequest { - pub msg_id: MsgId, + pub header: Header, pub authenticate_connection: bool, } impl TestRequest { - pub fn new(msg_id: MsgId) -> Self { + pub fn new(header: Header) -> Self { Self { - msg_id, + header, authenticate_connection: false, } } @@ -133,8 +134,8 @@ pub mod test { SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into() } - fn msg_id(&self) -> MsgId { - self.msg_id + fn header(&self) -> &Header { + &self.header } fn deserialize_msg(&self) -> Result {