Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changes/load-tauri-protocol-async.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
tauri: patch:perf
---

Load `tauri://` custom protocol handlers asynchronously to speed up load time
237 changes: 142 additions & 95 deletions crates/tauri/src/protocol/tauri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ use std::{collections::HashMap, sync::Mutex};
struct CachedResponse {
status: http::StatusCode,
headers: http::HeaderMap,
body: bytes::Bytes,
body: Vec<u8>,
}

pub fn get<R: Runtime>(
#[allow(unused_variables)] manager: Arc<AppManager<R>>,
manager: Arc<AppManager<R>>,
window_origin: &str,
web_resource_request_handler: Option<Box<WebResourceRequestHandler>>,
) -> UriSchemeProtocolHandler {
Expand All @@ -39,6 +39,7 @@ pub fn get<R: Runtime>(
url.pop();
}

#[allow(unused_mut)]
let mut client_builder = reqwest::ClientBuilder::new();
if use_https {
#[cfg(feature = "rustls-tls")]
Expand All @@ -47,6 +48,7 @@ pub fn get<R: Runtime>(
}

// we can't load env vars at runtime, gotta embed them in the lib
#[allow(unused_variables)]
if let Some(cert_pem) = option_env!("TAURI_DEV_ROOT_CERTIFICATE") {
#[cfg(any(
feature = "native-tls",
Expand Down Expand Up @@ -78,44 +80,70 @@ pub fn get<R: Runtime>(
}
let client = client_builder.build().unwrap();

let response_cache = Arc::new(Mutex::new(HashMap::new()));
let response_cache = Mutex::new(HashMap::new());

(url, client, response_cache)
};

let context = Arc::new(Context {
manager,
web_resource_request_handler,
window_origin,
#[cfg(all(dev, mobile))]
client,
#[cfg(all(dev, mobile))]
url,
#[cfg(all(dev, mobile))]
response_cache,
});

Box::new(move |_, request, responder| {
match get_response(
request,
&manager,
&window_origin,
web_resource_request_handler.as_deref(),
#[cfg(all(dev, mobile))]
(&url, &client, &response_cache),
) {
Ok(response) => responder.respond(response),
Err(e) => responder.respond(
HttpResponse::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header(CONTENT_TYPE, mime::TEXT_PLAIN.essence_str())
.header("Access-Control-Allow-Origin", &window_origin)
.body(e.to_string().into_bytes())
.unwrap(),
),
}
let context = context.clone();
crate::async_runtime::spawn(async move {
match get_response(&context, request).await {
Ok(response) => responder.respond(response),
Err(e) => responder.respond(
HttpResponse::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header(CONTENT_TYPE, mime::TEXT_PLAIN.essence_str())
.header("Access-Control-Allow-Origin", &context.window_origin)
.body(e.to_string().into_bytes())
.unwrap(),
),
}
});
})
}

fn get_response<R: Runtime>(
#[allow(unused_mut)] mut request: Request<Vec<u8>>,
#[allow(unused_variables)] manager: &AppManager<R>,
window_origin: &str,
web_resource_request_handler: Option<&WebResourceRequestHandler>,
#[cfg(all(dev, mobile))] (url, client, response_cache): (
&str,
&reqwest::Client,
&Arc<Mutex<HashMap<String, CachedResponse>>>,
),
struct Context<R: Runtime> {
manager: Arc<AppManager<R>>,
window_origin: String,
web_resource_request_handler: Option<Box<WebResourceRequestHandler>>,

#[cfg(all(dev, mobile))]
url: String,
#[cfg(all(dev, mobile))]
client: reqwest::Client,
#[cfg(all(dev, mobile))]
response_cache: Mutex<HashMap<String, CachedResponse>>,
}

async fn get_response<R: Runtime>(
context: &Context<R>,
request: Request<Vec<u8>>,
) -> Result<HttpResponse<Cow<'static, [u8]>>, Box<dyn std::error::Error>> {
let Context {
manager,
web_resource_request_handler,
window_origin,
#[cfg(all(dev, mobile))]
client,
#[cfg(all(dev, mobile))]
url,
#[cfg(all(dev, mobile))]
response_cache,
} = context;

// use the entire URI as we are going to proxy the request
let path = if PROXY_DEV_SERVER {
request.uri().to_string()
Expand All @@ -137,86 +165,105 @@ fn get_response<R: Runtime>(
// where `$P` is not `localhost/*`
.unwrap_or_default();

#[allow(unused_mut)]
let mut builder = HttpResponse::builder()
.add_configured_headers(manager.config.app.security.headers.as_ref())
.header("Access-Control-Allow-Origin", window_origin);

#[cfg(all(dev, mobile))]
let mut response = {
let decoded_path = percent_encoding::percent_decode(path.as_bytes())
.decode_utf8_lossy()
.to_string();
let url = format!(
"{}/{}",
url.trim_end_matches('/'),
decoded_path.trim_start_matches('/')
);

let mut proxy_builder = client.request(request.method().clone(), &url);
for (name, value) in request.headers() {
proxy_builder = proxy_builder.header(name, value);
}
proxy_builder = proxy_builder.body(request.body().clone());
match crate::async_runtime::safe_block_on(proxy_builder.send()) {
Ok(r) => {
let mut response_cache_ = response_cache.lock().unwrap();
let mut response = None;
if r.status() == http::StatusCode::NOT_MODIFIED {
response = response_cache_.get(&url);
}
let response = if let Some(r) = response {
r
} else {
let status = r.status();
let headers = r.headers().clone();
let body = crate::async_runtime::safe_block_on(r.bytes())?;
let response = CachedResponse {
status,
headers,
body,
};
response_cache_.insert(url.clone(), response);
response_cache_.get(&url).unwrap()
};
for (name, value) in &response.headers {
builder = builder.header(name, value);
}
builder
.status(response.status)
.body(response.body.to_vec().into())?
}
Err(e) => {
let error_message = format!(
"Failed to request {}: {}{}",
url.as_str(),
e,
if let Some(s) = e.status() {
format!("status code: {}", s.as_u16())
} else if cfg!(target_os = "ios") {
", did you grant local network permissions? That is required to reach the development server. Please grant the permission via the prompt or in `Settings > Privacy & Security > Local Network` and restart the app. See https://support.apple.com/en-us/102229 for more information.".to_string()
} else {
"".to_string()
}
);
log::error!("{error_message}");
return Err(error_message.into());
}
}
};
let mut response =
proxy_dev_request(client, url, response_cache, path, builder, &request).await?;

#[cfg(not(all(dev, mobile)))]
let mut response = {
let use_https_scheme = request.uri().scheme() == Some(&http::uri::Scheme::HTTPS);
let asset = manager.get_asset(path, use_https_scheme)?;
let asset = manager.get_asset(
path,
request.uri().scheme() == Some(&http::uri::Scheme::HTTPS),
)?;
builder = builder.header(CONTENT_TYPE, &asset.mime_type);
if let Some(csp) = &asset.csp_header {
builder = builder.header("Content-Security-Policy", csp);
}
builder.body(asset.bytes.into())?
};
if let Some(handler) = &web_resource_request_handler {

if let Some(handler) = web_resource_request_handler {
handler(request, &mut response);
}

Ok(response)
}

#[cfg(all(dev, mobile))]
async fn proxy_dev_request(
client: &reqwest::Client,
url: &String,
response_cache: &Mutex<HashMap<String, CachedResponse>>,
path: String,
mut builder: http::response::Builder,
request: &Request<Vec<u8>>,
) -> Result<HttpResponse<Cow<'static, [u8]>>, Box<dyn std::error::Error>> {
let decoded_path = percent_encoding::percent_decode(path.as_bytes())
.decode_utf8_lossy()
.to_string();
let url = format!(
"{}/{}",
url.trim_end_matches('/'),
decoded_path.trim_start_matches('/')
);

let mut proxy_builder = client.request(request.method().clone(), &url);
for (name, value) in request.headers() {
proxy_builder = proxy_builder.header(name, value);
}
proxy_builder = proxy_builder.body(request.body().clone());

let response = proxy_builder.send().await.map_err(|e|{
let error_message = format!(
"Failed to request {url}: {e}{}",
if let Some(s) = e.status() {
format!("status code: {}", s.as_u16())
} else if cfg!(target_os = "ios") {
", did you grant local network permissions? That is required to reach the development server. Please grant the permission via the prompt or in `Settings > Privacy & Security > Local Network` and restart the app. See https://support.apple.com/en-us/102229 for more information.".to_string()
} else {
"".to_string()
}
);
log::error!("{error_message}");
error_message
})?;

let status = response.status();

if status == http::StatusCode::NOT_MODIFIED {
if let Some(response) = response_cache.lock().unwrap().get(&url).cloned() {
for (name, value) in &response.headers {
builder = builder.header(name, value);
}

return Ok(builder.status(response.status).body(response.body.into())?);
}
}

let headers = response.headers().clone();
let body = response.bytes().await?.to_vec();
let response = CachedResponse {
status,
headers,
body,
};

response_cache
.lock()
.unwrap()
.insert(url.clone(), response.clone());

for (name, value) in &response.headers {
builder = builder.header(name, value);
}

builder
.status(response.status)
.body(response.body.into())
.map_err(Into::into)
}
Loading