Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ itertools = "0"
libc = "0"
log = { version = "0", features = ["std"] }
prost = "0.14"
protobuf = { git = "https://github.com/thinkparq/protobuf", rev = "e2e774e7db7e3d4474d6e7232bb06bbdffc5610c" }
protobuf = { git = "https://github.com/thinkparq/protobuf", rev = "4d5e5db085065acbbaa5bb76ce4b81d6d733e446" }
regex = "1"
ring = "0"
rusqlite = { version = "0", features = ["bundled", "vtab", "array", "fallible_uint"] }
Expand Down
1 change: 1 addition & 0 deletions mgmtd/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ pub(crate) trait App: Debug + Clone + Send + 'static {
fn load_and_verify_license_cert(
&self,
cert_path: &Path,
prev_trial_serial: Option<String>,
) -> impl Future<Output = Result<String>> + Send;

/// Get license certificate data
Expand Down
9 changes: 7 additions & 2 deletions mgmtd/src/app/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,13 @@ impl App for RuntimeApp {
}
}

async fn load_and_verify_license_cert(&self, cert_path: &Path) -> Result<String> {
LicenseVerifier::load_and_verify_license_cert(&self.license, cert_path).await
async fn load_and_verify_license_cert(
&self,
cert_path: &Path,
prev_trial_serial: Option<String>,
) -> Result<String> {
LicenseVerifier::load_and_verify_license_cert(&self.license, cert_path, prev_trial_serial)
.await
}

fn get_license_cert_data(&self) -> Result<GetCertDataResult> {
Expand Down
6 changes: 5 additions & 1 deletion mgmtd/src/app/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,11 @@ impl App for TestApp {

fn notify_client_pulled_state(&self, _node_type: NodeType, _node_id: NodeId) {}

async fn load_and_verify_license_cert(&self, _cert_path: &std::path::Path) -> Result<String> {
async fn load_and_verify_license_cert(
&self,
_cert_path: &std::path::Path,
_prev_trial_serial: Option<String>,
) -> Result<String> {
Ok("dummy cert".to_string())
}

Expand Down
6 changes: 3 additions & 3 deletions mgmtd/src/bee_msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ pub(crate) async fn dispatch_request(app: &RuntimeApp, mut req: impl Request) ->
macro_rules! dispatch_msg {
($({$msg_type:path => $r:tt, $ctx_str:literal})*) => {
// Match on the message ID provided by the request
match req.msg_id() {
match req.header().msg_id() {
$(
<$msg_type>::ID => {
let des: $msg_type = req.deserialize_msg().with_context(|| {
format!(
"{} ({}) from {:?}",
stringify!($msg_type),
req.msg_id(),
req.header().msg_id(),
req.addr()
)
})?;
Expand Down Expand Up @@ -186,7 +186,7 @@ async fn handle_unspecified_msg(req: impl Request) -> Result<()> {
log::warn!(
"Unhandled msg INCOMING from {:?} with ID {}",
req.addr(),
req.msg_id()
req.header().msg_id()
);

// Signal to the caller that the msg is not handled. The generic response
Expand Down
3 changes: 2 additions & 1 deletion mgmtd/src/bee_msg/authenticate_channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ impl HandleNoResponse for AuthenticateChannel {
mod test {
use super::*;
use crate::app::test::*;
use shared::bee_msg::Header;

#[tokio::test]
async fn authenticate_channel() {
let app = TestApp::new().await;
let mut req = TestRequest::new(AuthenticateChannel::ID);
let mut req = TestRequest::new(Header::default());

AuthenticateChannel {
auth_secret: AuthSecret::hash_from_bytes("secret"),
Expand Down
5 changes: 3 additions & 2 deletions mgmtd/src/bee_msg/change_target_consistency_states.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,12 @@ doesn't match stored state {old_stored}, no consistency state changes will be ma
mod test {
use super::*;
use crate::app::test::*;
use shared::bee_msg::Header;

#[tokio::test]
async fn change_target_consistency_states() {
let app = TestApp::new().await;
let mut req = TestRequest::new(ChangeTargetConsistencyStates::ID);
let mut req = TestRequest::new(Header::default());

// Prepare times
app.db
Expand Down Expand Up @@ -174,7 +175,7 @@ mod test {
#[tokio::test]
async fn change_target_consistency_states_old_states() {
let app = TestApp::new().await;
let mut req = TestRequest::new(ChangeTargetConsistencyStates::ID);
let mut req = TestRequest::new(Header::default());

// Mismatch of reported old state should not change the consistency states
let msg = ChangeTargetConsistencyStates {
Expand Down
53 changes: 51 additions & 2 deletions mgmtd/src/bee_msg/common.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,53 @@
use super::*;
use crate::db::node_nic::ReplaceNic;
use db::misc::MetaRoot;
use rusqlite::Transaction;
use protobuf::license::VerifyResult;
use rusqlite::{Transaction, params};
use shared::bee_msg::node::*;
use shared::bee_msg::target::*;
use shared::types::{NodeId, TargetId};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;

// Maximum number of clients that can register if license verification fails or license is invalid
const MAX_NUM_CLIENTS: u32 = 5;

/// Processes incoming node information. Registers new nodes if config allows it
pub(super) async fn update_node(msg: RegisterNode, app: &impl App) -> Result<NodeId> {
pub(super) async fn update_node(msg: RegisterNode, app: &impl App, reject: bool) -> Result<NodeId> {
let nics = msg.nics.clone();
let requested_node_id = msg.node_id;
let registration_disable = app.static_info().user_config.registration_disable;

let licensed_clients: Option<u32> = if msg.node_type == NodeType::Client {
match app.get_license_cert_data() {
Ok(r) => match r.result() {
// If license is valid, no limit to client count
VerifyResult::VerifyValid => None,
// no license file loaded, limit number of clients to NUM_CLIENTS
VerifyResult::VerifyError => Some(MAX_NUM_CLIENTS),
// license file was loaded and is outside validity period
VerifyResult::VerifyInvalid => None,
Comment thread
iamjoemccormick marked this conversation as resolved.
_ => {
log::debug!(
Comment thread
philippfalk marked this conversation as resolved.
Dismissed
"Unexpected error during license verification, limiting number of clients to {MAX_NUM_CLIENTS}: {0}",
r.message
);
Some(MAX_NUM_CLIENTS)
}
},
Err(e) => {
log::debug!(
"Error during license verification, limiting number of clients to {MAX_NUM_CLIENTS}: {e:#}",
);
Some(MAX_NUM_CLIENTS)
}
}
} else {
// not a client registration, so not going to be used anyway. But let's be defensive
Some(MAX_NUM_CLIENTS)
};

let licensed_machines = match app.get_licensed_machines() {
Ok(n) => n,
Err(err) => {
Expand Down Expand Up @@ -97,6 +130,22 @@ registration token ({new_alias_or_reg_token}) does not match the stored token ({
bail!("Registration of new nodes is not allowed");
}

let num_reg_clients: u32 = tx.query_row(
sql!("SELECT COUNT(DISTINCT node_uid) FROM nodes WHERE node_type = ?1"),
params![NodeType::Client.sql_variant()],
|row| row.get(0),
)?;

if msg.node_type == NodeType::Client
&& let Some(cs) = licensed_clients
&& num_reg_clients >= cs {
if reject {
bail!("Number of licensed clients ({MAX_NUM_CLIENTS}) exhausted. Client registration denied.");
} else {
log::warn!("Number of licensed clients ({MAX_NUM_CLIENTS}) exhausted but client doesn't support rejection.");
}
}

let new_alias = if msg.node_type == NodeType::Client {
// In versions prior to 8.0 the string node ID generated by the client
// started with a number which is not allowed by the new alias schema.
Expand Down
3 changes: 2 additions & 1 deletion mgmtd/src/bee_msg/get_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,12 @@ impl HandleWithResponse for GetNodes {
mod test {
use super::*;
use crate::app::test::*;
use shared::bee_msg::Header;

#[tokio::test]
async fn get_nodes() {
let app = TestApp::new().await;
let mut req = TestRequest::new(GetNodes::ID);
let mut req = TestRequest::new(Header::default());

let resp = GetNodes {
node_type: NodeType::Meta,
Expand Down
1 change: 1 addition & 0 deletions mgmtd/src/bee_msg/heartbeat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ impl HandleWithResponse for Heartbeat {
machine_uuid: self.machine_uuid,
},
app,
false,
)
.await?;

Expand Down
9 changes: 7 additions & 2 deletions mgmtd/src/bee_msg/register_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@ use super::*;
use common::update_node;
use shared::bee_msg::node::*;

const COMPATFLAG_CLIENT_SUPPORTS_REGREJ: u8 = 1;

impl HandleWithResponse for RegisterNode {
type Response = RegisterNodeResp;

async fn handle(self, app: &impl App, _req: &mut impl Request) -> Result<Self::Response> {
async fn handle(self, app: &impl App, req: &mut impl Request) -> Result<Self::Response> {
fail_on_pre_shutdown(app)?;

let node_id = update_node(self, app).await?;
let reject =
(req.header().msg_compat_feature_flags & COMPATFLAG_CLIENT_SUPPORTS_REGREJ) != 0;

let node_id = update_node(self, app, reject).await?;

let fs_uuid: String = app
.read_tx(|tx| db::config::get(tx, db::config::Config::FsUuid))
Expand Down
6 changes: 5 additions & 1 deletion mgmtd/src/bee_msg/request_exceeded_quota.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::*;
use crate::license::LicensedFeature;
use rusqlite::params;
use shared::bee_msg::quota::*;

Expand All @@ -13,6 +14,8 @@ impl HandleWithResponse for RequestExceededQuota {
}

async fn handle(self, app: &impl App, _req: &mut impl Request) -> Result<Self::Response> {
app.verify_licensed_feature(LicensedFeature::Quota)?;

let inner = app
.read_tx(move |tx| {
// Quota is calculated per pool, so if a target ID is given, use its assigned pools
Expand Down Expand Up @@ -66,11 +69,12 @@ mod test {
use super::*;
use crate::app::test::*;
use crate::bee_msg::HandleWithResponse;
use shared::bee_msg::Header;

#[tokio::test]
async fn request_exceeded_quota() {
let app = TestApp::new().await;
let mut req = TestRequest::new(RequestExceededQuota::ID);
let mut req = TestRequest::new(Header::default());

let tests: &[(_, &[u32])] = &[
(
Expand Down
4 changes: 3 additions & 1 deletion mgmtd/src/db/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ pub(crate) enum Config {
#[allow(unused)]
FsName,
CounterLastClientID,
TrialSerial,
}

// Config entries that should not be changed after initially set. Note that this only controls the
// functions below, the database entries could still be changed by manual query
const IMMUTABLE: &[Config] = &[Config::FsUuid, Config::FsInitDateSecs];
const IMMUTABLE: &[Config] = &[Config::FsUuid, Config::FsInitDateSecs, Config::TrialSerial];
Comment thread
iamjoemccormick marked this conversation as resolved.

impl Config {
/// The string representation of the config key as it is written to the db
Expand All @@ -31,6 +32,7 @@ impl Config {
Config::FsInitDateSecs => "fs_init_date_secs",
Config::FsName => "fs_name",
Config::CounterLastClientID => "counter_last_client_id",
Config::TrialSerial => "trial_serial",
}
}
}
Expand Down
21 changes: 20 additions & 1 deletion mgmtd/src/grpc/get_license.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use super::*;
use crate::db::config::Config;
use protobuf::license::CertType;
use protobuf::management::{self as pm, GetLicenseResponse};

pub(crate) async fn get_license(
Expand All @@ -7,8 +9,25 @@ pub(crate) async fn get_license(
) -> Result<pm::GetLicenseResponse> {
let reload: bool = required_field(req.reload)?;
if reload {
app.load_and_verify_license_cert(&app.static_info().user_config.license_cert_file)
let prev_trial_serial: Option<String> = app
.read_tx(|tx| db::config::get(tx, Config::TrialSerial))
.await?;

let serial = app
.load_and_verify_license_cert(
&app.static_info().user_config.license_cert_file,
prev_trial_serial,
)
.await?;

if app
.get_license_cert_data()?
.data
.is_some_and(|d| d.r#type() == CertType::Trial)
{
app.write_tx(|tx| db::config::set(tx, Config::TrialSerial, serial))
.await?;
}
}
let cert_data = app.get_license_cert_data()?;
Ok(GetLicenseResponse {
Expand Down
27 changes: 27 additions & 0 deletions mgmtd/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ use crate::app::RuntimeApp;
use crate::config::Config;
use anyhow::{Context, Result};
use app::App;
use db::config::Config as dbConfig;
use db::node_nic::ReplaceNic;
use license::LicenseVerifier;
use protobuf::license::CertType;
use shared::bee_msg::target::RefreshTargetStates;
use shared::conn::incoming;
use shared::conn::outgoing::Pool;
Expand Down Expand Up @@ -114,6 +116,31 @@ pub async fn start(info: StaticInfo, license: LicenseVerifier) -> Result<RunCont
})
.await?;

let prev_trial_serial: Option<String> = db
.read_tx(|tx| db::config::get(tx, db::config::Config::TrialSerial))
.await?;

// Load and verify license certificate
match license
.load_and_verify_license_cert(&info.user_config.license_cert_file, prev_trial_serial)
.await
{
Ok(serial) => {
if license
.get_license_cert_data()?
.data
.is_some_and(|d| d.r#type == CertType::Trial.into())
{
db.write_tx(|tx| db::config::set(tx, dbConfig::TrialSerial, serial))
.await?;
}
}
Err(err) => log::warn!(
"Loading and verifying license certificate failed. \
Licensed features will be unavailable: {err}"
),
};

// Fill node addrs store from db
db.read_tx(db::node_nic::get_all_addrs)
.await?
Expand Down
Loading
Loading