diff --git a/grpc_streaming_attestation/src/server.rs b/grpc_streaming_attestation/src/server.rs index cef7176da65..4446eb53548 100644 --- a/grpc_streaming_attestation/src/server.rs +++ b/grpc_streaming_attestation/src/server.rs @@ -24,7 +24,7 @@ use anyhow::Context; use futures::{Stream, StreamExt}; use oak_remote_attestation::handshaker::{AttestationBehavior, Encryptor, ServerHandshaker}; use oak_utils::LogError; -use std::pin::Pin; +use std::{pin::Pin, sync::Arc}; use tonic::{Request, Response, Status, Streaming}; /// Handler for subsequent encrypted requests from the stream after the handshake is completed. @@ -93,7 +93,7 @@ pub struct AttestationServer { /// Processes data from client requests and creates responses. request_handler: F, /// Configuration information to provide to the client for the attestation step. - additional_info: Vec, + additional_info: Arc>, /// Error logging function that is required for logging attestation protocol errors. /// Errors are only logged on server side and are not sent to clients. error_logger: L, @@ -114,7 +114,7 @@ where Ok(Self { tee_certificate, request_handler, - additional_info, + additional_info: Arc::new(additional_info), error_logger, }) } @@ -136,8 +136,8 @@ where ) -> Result, Status> { let tee_certificate = self.tee_certificate.clone(); let request_handler = self.request_handler.clone(); - let additional_info = self.additional_info.clone(); let error_logger = self.error_logger.clone(); + let additional_info = self.additional_info.clone(); let response_stream = async_stream::try_stream! { let mut request_stream = request_stream.into_inner(); @@ -148,7 +148,7 @@ where error_logger.log_error(&format!("Couldn't create self attestation behavior: {:?}", error)); Status::internal("") })?, - additional_info, + additional_info ); while !handshaker.is_completed() { let incoming_message = request_stream.next() diff --git a/grpc_unary_attestation/src/server.rs b/grpc_unary_attestation/src/server.rs index d845a81e02c..9e6a283c79e 100644 --- a/grpc_unary_attestation/src/server.rs +++ b/grpc_unary_attestation/src/server.rs @@ -24,7 +24,10 @@ use crate::{ use lru::LruCache; use oak_remote_attestation::handshaker::{AttestationBehavior, Encryptor, ServerHandshaker}; use oak_utils::LogError; -use std::{convert::TryInto, sync::Mutex}; +use std::{ + convert::TryInto, + sync::{Arc, Mutex}, +}; use tonic; enum SessionState { @@ -38,7 +41,7 @@ struct SessionTracker { /// PEM encoded X.509 certificate that signs TEE firmware key. tee_certificate: Vec, /// Configuration information to provide to the client for the attestation step. - additional_info: Vec, + additional_info: Arc>, known_sessions: LruCache, } @@ -50,7 +53,7 @@ impl SessionTracker { let known_sessions = LruCache::new(SESSIONS_CACHE_SIZE); Self { tee_certificate, - additional_info, + additional_info: Arc::new(additional_info), known_sessions, } } diff --git a/oak_functions/client/rust/src/attestation.rs b/oak_functions/client/rust/src/attestation.rs index d1dd941dc8d..546ece6b4ad 100644 --- a/oak_functions/client/rust/src/attestation.rs +++ b/oak_functions/client/rust/src/attestation.rs @@ -30,7 +30,8 @@ pub fn into_server_identity_verifier( config_verifier: ConfigurationVerifier, ) -> ServerIdentityVerifier { let server_verifier = move |server_identity: ServerIdentity| -> anyhow::Result<()> { - let config = ConfigurationInfo::decode(server_identity.additional_info.as_ref())?; + let config = + ConfigurationInfo::decode(server_identity.additional_info.as_ref().as_slice())?; // TODO(#2347): Check that ConfigurationInfo does not have additional/unknown fields. config_verifier(config)?; // TODO(#2316): Verify proof of inclusion in Rekor. diff --git a/remote_attestation/rust/src/handshaker.rs b/remote_attestation/rust/src/handshaker.rs index 141bfe476a0..fdf6d2f1588 100644 --- a/remote_attestation/rust/src/handshaker.rs +++ b/remote_attestation/rust/src/handshaker.rs @@ -34,7 +34,7 @@ use crate::{ }, proto::{AttestationInfo, AttestationReport}, }; -use alloc::{boxed::Box, vec, vec::Vec}; +use alloc::{boxed::Box, sync::Arc, vec, vec::Vec}; use anyhow::{anyhow, Context}; use prost::Message; @@ -331,13 +331,13 @@ pub struct ServerHandshaker { transcript: Transcript, /// Additional info about the server, including configuration information and proof of /// inclusion in a verifiable log. - additional_info: Vec, + additional_info: Arc>, } impl ServerHandshaker { /// Creates [`ServerHandshaker`] with `ServerHandshakerState::ExpectingClientIdentity` /// state. - pub fn new(behavior: AttestationBehavior, additional_info: Vec) -> Self { + pub fn new(behavior: AttestationBehavior, additional_info: Arc>) -> Self { Self { behavior, state: ServerHandshakerState::ExpectingClientHello, @@ -448,9 +448,8 @@ impl ServerHandshaker { .as_ref() .context("Couldn't get TEE certificate")?; - let additional_info = self.additional_info.clone(); let attestation_info = - create_attestation_info(signer, additional_info.as_ref(), tee_certificate) + create_attestation_info(signer, self.additional_info.as_ref(), tee_certificate) .context("Couldn't get attestation info")?; let mut server_identity = ServerIdentity::new( @@ -460,7 +459,7 @@ impl ServerHandshaker { .public_key() .context("Couldn't get singing public key")?, attestation_info, - additional_info, + self.additional_info.clone(), ); // Update current transcript. @@ -487,7 +486,7 @@ impl ServerHandshaker { // Attestation info. vec![], // Additional info. - vec![], + Arc::new(vec![]), ) }; diff --git a/remote_attestation/rust/src/message.rs b/remote_attestation/rust/src/message.rs index 3186730d3c5..5938202a875 100644 --- a/remote_attestation/rust/src/message.rs +++ b/remote_attestation/rust/src/message.rs @@ -25,7 +25,7 @@ use crate::crypto::{ KEY_AGREEMENT_ALGORITHM_KEY_LENGTH, NONCE_LENGTH, SIGNATURE_LENGTH, SIGNING_ALGORITHM_KEY_LENGTH, }; -use alloc::vec::Vec; +use alloc::{sync::Arc, vec::Vec}; use anyhow::{anyhow, bail, Context}; use bytes::{Buf, BufMut}; @@ -124,7 +124,7 @@ pub struct ServerIdentity { /// /// The server and the client must be able to agree on a canonical representation of the /// content to be able to deterministically compute the hash of this field. - pub additional_info: Vec, + pub additional_info: Arc>, } /// Client identity message containing remote attestation information and a public key for @@ -222,7 +222,7 @@ impl ServerIdentity { random: [u8; REPLAY_PROTECTION_ARRAY_LENGTH], signing_public_key: [u8; SIGNING_ALGORITHM_KEY_LENGTH], attestation_info: Vec, - additional_info: Vec, + additional_info: Arc>, ) -> Self { Self { version: PROTOCOL_VERSION, @@ -302,7 +302,7 @@ impl Deserializable for ServerIdentity { let mut signing_public_key = [0u8; SIGNING_ALGORITHM_KEY_LENGTH]; input.copy_to_slice(&mut signing_public_key); let attestation_info = get_vec(&mut input)?; - let additional_info = get_vec(&mut input)?; + let additional_info = Arc::new(get_vec(&mut input)?); if input.has_remaining() { bail!( diff --git a/remote_attestation/rust/src/tests/handshaker.rs b/remote_attestation/rust/src/tests/handshaker.rs index 2a82e749473..f26984a384e 100644 --- a/remote_attestation/rust/src/tests/handshaker.rs +++ b/remote_attestation/rust/src/tests/handshaker.rs @@ -22,7 +22,7 @@ use crate::{ }, tests::message::INVALID_MESSAGE_HEADER, }; -use alloc::{boxed::Box, vec}; +use alloc::{boxed::Box, sync::Arc, vec}; use assert_matches::assert_matches; const TEE_MEASUREMENT: &str = "Test TEE measurement"; @@ -48,7 +48,8 @@ fn create_handshakers() -> (ClientHandshaker, ServerHandshaker) { .unwrap(); let additional_info = br"Additional Info".to_vec(); - let server_handshaker = ServerHandshaker::new(bidirectional_attestation, additional_info); + let server_handshaker = + ServerHandshaker::new(bidirectional_attestation, Arc::new(additional_info)); (client_handshaker, server_handshaker) } diff --git a/remote_attestation/rust/src/tests/message.rs b/remote_attestation/rust/src/tests/message.rs index c962aa833f6..5b180b87d0f 100644 --- a/remote_attestation/rust/src/tests/message.rs +++ b/remote_attestation/rust/src/tests/message.rs @@ -25,7 +25,7 @@ use crate::{ MAXIMUM_MESSAGE_SIZE, REPLAY_PROTECTION_ARRAY_LENGTH, SERVER_IDENTITY_HEADER, }, }; -use alloc::{vec, vec::Vec}; +use alloc::{sync::Arc, vec, vec::Vec}; use anyhow::{anyhow, Context}; use assert_matches::assert_matches; use quickcheck::{quickcheck, TestResult}; @@ -96,7 +96,7 @@ fn test_serialize_server_identity() { transcript_signature: Vec, signing_public_key: Vec, attestation_info: Vec, - additional_info: Vec, + additional_info: Arc>, ) -> TestResult { if ephemeral_public_key.len() > KEY_AGREEMENT_ALGORITHM_KEY_LENGTH || random.len() > REPLAY_PROTECTION_ARRAY_LENGTH @@ -131,7 +131,9 @@ fn test_serialize_server_identity() { assert!(result.is_ok()); TestResult::from_bool(result.unwrap()) } - quickcheck(property as fn(Vec, Vec, Vec, Vec, Vec, Vec) -> TestResult); + quickcheck( + property as fn(Vec, Vec, Vec, Vec, Vec, Arc>) -> TestResult, + ); } #[test] @@ -202,7 +204,7 @@ fn test_deserialize_message() { default_array(), default_array(), vec![], - vec![], + Arc::new(vec![]), ); let deserialized_server_identity = deserialize_message(&server_identity.serialize().unwrap()); assert_matches!(deserialized_server_identity, Ok(_));