diff --git a/sdk/src/polyfill/worker.ts b/sdk/src/polyfill/worker.ts index ef115a03d..8f979db29 100644 --- a/sdk/src/polyfill/worker.ts +++ b/sdk/src/polyfill/worker.ts @@ -1,216 +1,188 @@ -function patch($worker: typeof import("node:worker_threads"), $os: typeof import("node:os")) { - // This is technically not a part of the Worker polyfill, - // but Workers are used for multi-threading, so this is often - // needed when writing Worker code. - if (globalThis.navigator == null) { - globalThis.navigator = { - hardwareConcurrency: $os.cpus().length, - } as Navigator; - } - - globalThis.Worker = class Worker extends EventTarget { - private _worker: import("node:worker_threads").Worker; +import * as $worker from "node:worker_threads"; +import * as $os from "node:os"; + +// This is technically not a part of the Worker polyfill, +// but Workers are used for multi-threading, so this is often +// needed when writing Worker code. +if (globalThis.navigator == null) { + globalThis.navigator = { + hardwareConcurrency: $os.cpus().length, + } as Navigator; +} - constructor(url: string | URL, options?: WorkerOptions | undefined) { - super(); +globalThis.Worker = class Worker extends EventTarget { + private _worker: import("node:worker_threads").Worker; - if (url instanceof URL) { - if (url.protocol !== "file:") { - throw new Error("Worker only supports file: URLs"); - } + constructor(url: string | URL, options?: WorkerOptions | undefined) { + super(); - url = url.href; - - } else { - throw new Error("Filepaths are unreliable, use `new URL(\"...\", import.meta.url)` instead."); + if (url instanceof URL) { + if (url.protocol !== "file:") { + throw new Error("Worker only supports file: URLs"); } - if (!options || options.type !== "module") { - throw new Error("Workers must use \`type: \"module\"\`"); - } + url = url.href; - // This uses some funky stuff like `patch.toString()`. - // - // This is needed so that it can synchronously run the polyfill code - // inside of the worker. - // - // It can't use `require` because the file doesn't have a `.cjs` file extension. - // - // It can't use `import` because that's asynchronous, and the file path - // might be different if using a bundler. - const code = ` - ${patch.toString()} - - // Inject the polyfill into the worker - patch(require("node:worker_threads"), require("node:os")); - - const { workerData } = require("node:worker_threads"); - - // This actually loads and runs the worker file - import(workerData.url) - .catch((e) => { - // TODO maybe it should send a message to the parent? - console.error(e.stack); - }); - `; - - this._worker = new $worker.Worker(code, { - eval: true, - workerData: { - url, - }, - }); - - this._worker.on("message", (data) => { - this.dispatchEvent(new MessageEvent("message", { data })); - }); - - this._worker.on("messageerror", (error) => { - throw new Error("UNIMPLEMENTED"); - }); - - this._worker.on("error", (error) => { - // TODO attach the error to the event somehow - const event = new Event("error"); - this.dispatchEvent(event); - }); + } else { + throw new Error("Filepaths are unreliable, use `new URL(\"...\", import.meta.url)` instead."); } - set onmessage(f: () => void) { - throw new Error("UNIMPLEMENTED"); + if (!options || options.type !== "module") { + throw new Error("Workers must use \`type: \"module\"\`"); } - set onmessageerror(f: () => void) { - throw new Error("UNIMPLEMENTED"); - } + const code = ` + const { workerData } = require("node:worker_threads"); + + import(workerData.polyfill) + .then(() => import(workerData.url)) + .catch((e) => { + // TODO maybe it should send a message to the parent? + console.error(e.stack); + }); + `; + + this._worker = new $worker.Worker(code, { + eval: true, + workerData: { + url, + polyfill: new URL("node-polyfill.js", import.meta.url).href, + }, + }); - set onerror(f: () => void) { + this._worker.on("message", (data) => { + this.dispatchEvent(new MessageEvent("message", { data })); + }); + + this._worker.on("messageerror", (error) => { throw new Error("UNIMPLEMENTED"); - } + }); - postMessage(message: any, transfer: Array): void; - postMessage(message: any, options?: StructuredSerializeOptions | undefined): void; - postMessage(value: any, transfer: any) { - this._worker.postMessage(value, transfer); - } + this._worker.on("error", (error) => { + // TODO attach the error to the event somehow + const event = new Event("error"); + this.dispatchEvent(event); + }); + } - terminate() { - this._worker.terminate(); - } + set onmessage(f: () => void) { + throw new Error("UNIMPLEMENTED"); + } - // This is Node-specific, it allows the process to exit - // even if the Worker is still running. - unref() { - this._worker.unref(); - } - }; + set onmessageerror(f: () => void) { + throw new Error("UNIMPLEMENTED"); + } + set onerror(f: () => void) { + throw new Error("UNIMPLEMENTED"); + } - if (!$worker.isMainThread) { - const globals = globalThis as unknown as DedicatedWorkerGlobalScope; + postMessage(message: any, transfer: Array): void; + postMessage(message: any, options?: StructuredSerializeOptions | undefined): void; + postMessage(value: any, transfer: any) { + this._worker.postMessage(value, transfer); + } - // This is used to create the onmessage, onmessageerror, and onerror setters - const makeSetter = (prop: string, event: string) => { - let oldvalue: () => void; + terminate() { + this._worker.terminate(); + } - Object.defineProperty(globals, prop, { - get() { - return oldvalue; - }, - set(value) { - if (oldvalue) { - globals.removeEventListener(event, oldvalue); - } + // This is Node-specific, it allows the process to exit + // even if the Worker is still running. + unref() { + this._worker.unref(); + } +}; - oldvalue = value; - if (oldvalue) { - globals.addEventListener(event, oldvalue); - } - }, - }); - }; +if (!$worker.isMainThread) { + const globals = globalThis as unknown as DedicatedWorkerGlobalScope; - // This makes sure that `f` is only run once - const memoize = (f: () => void) => { - let run = false; + // This is used to create the onmessage, onmessageerror, and onerror setters + const makeSetter = (prop: string, event: string) => { + let oldvalue: () => void; - return () => { - if (!run) { - run = true; - f(); + Object.defineProperty(globals, prop, { + get() { + return oldvalue; + }, + set(value) { + if (oldvalue) { + globals.removeEventListener(event, oldvalue); } - }; - }; + oldvalue = value; - // We only start listening for messages / errors when the worker calls addEventListener - const startOnMessage = memoize(() => { - $worker.parentPort!.on("message", (data) => { - workerEvents.dispatchEvent(new MessageEvent("message", { data })); - }); + if (oldvalue) { + globals.addEventListener(event, oldvalue); + } + }, }); + }; - const startOnMessageError = memoize(() => { - throw new Error("UNIMPLEMENTED"); - }); + // This makes sure that `f` is only run once + const memoize = (f: () => void) => { + let run = false; - const startOnError = memoize(() => { - $worker.parentPort!.on("error", (data) => { - workerEvents.dispatchEvent(new Event("error")); - }); + return () => { + if (!run) { + run = true; + f(); + } + }; + }; + + + // We only start listening for messages / errors when the worker calls addEventListener + const startOnMessage = memoize(() => { + $worker.parentPort!.on("message", (data) => { + workerEvents.dispatchEvent(new MessageEvent("message", { data })); }); + }); + const startOnMessageError = memoize(() => { + throw new Error("UNIMPLEMENTED"); + }); - // Node workers don't have top-level events, so we have to make our own - const workerEvents = new EventTarget(); + const startOnError = memoize(() => { + $worker.parentPort!.on("error", (data) => { + workerEvents.dispatchEvent(new Event("error")); + }); + }); - globals.close = () => { - process.exit(); - }; - globals.addEventListener = (type: string, callback: EventListenerOrEventListenerObject | null, options?: boolean | EventListenerOptions | undefined) => { - workerEvents.addEventListener(type, callback, options); + // Node workers don't have top-level events, so we have to make our own + const workerEvents = new EventTarget(); - if (type === "message") { - startOnMessage(); - } else if (type === "messageerror") { - startOnMessageError(); - } else if (type === "error") { - startOnError(); - } - }; + globals.close = () => { + process.exit(); + }; - globals.removeEventListener = (type: string, callback: EventListenerOrEventListenerObject | null, options?: boolean | EventListenerOptions | undefined) => { - workerEvents.removeEventListener(type, callback, options); - }; + globals.addEventListener = (type: string, callback: EventListenerOrEventListenerObject | null, options?: boolean | EventListenerOptions | undefined) => { + workerEvents.addEventListener(type, callback, options); - function postMessage(message: any, transfer: Transferable[]): void; - function postMessage(message: any, options?: StructuredSerializeOptions | undefined): void; - function postMessage(value: any, transfer: any) { - $worker.parentPort!.postMessage(value, transfer); + if (type === "message") { + startOnMessage(); + } else if (type === "messageerror") { + startOnMessageError(); + } else if (type === "error") { + startOnError(); } + }; - globals.postMessage = postMessage; + globals.removeEventListener = (type: string, callback: EventListenerOrEventListenerObject | null, options?: boolean | EventListenerOptions | undefined) => { + workerEvents.removeEventListener(type, callback, options); + }; - makeSetter("onmessage", "message"); - makeSetter("onmessageerror", "messageerror"); - makeSetter("onerror", "error"); + function postMessage(message: any, transfer: Transferable[]): void; + function postMessage(message: any, options?: StructuredSerializeOptions | undefined): void; + function postMessage(value: any, transfer: any) { + $worker.parentPort!.postMessage(value, transfer); } -} - -async function polyfill() { - const [$worker, $os] = await Promise.all([ - import("node:worker_threads"), - import("node:os"), - ]); + globals.postMessage = postMessage; - patch($worker, $os); + makeSetter("onmessage", "message"); + makeSetter("onmessageerror", "messageerror"); + makeSetter("onerror", "error"); } - -if (globalThis.Worker == null) { - await polyfill(); -} - -export {}; diff --git a/wasm/src/lib.rs b/wasm/src/lib.rs index a2fe4c1b5..27aaa7ebd 100644 --- a/wasm/src/lib.rs +++ b/wasm/src/lib.rs @@ -167,10 +167,20 @@ pub use types::Field; #[cfg(not(test))] mod thread_pool; -use wasm_bindgen::prelude::*; +#[cfg(test)] +mod thread_pool { + use std::future::Future; + + pub fn spawn(f: F) -> impl Future + where + A: Send + 'static, + F: FnOnce() -> A + Send + 'static, + { + async move { f() } + } +} -#[cfg(not(test))] -use thread_pool::ThreadPool; +use wasm_bindgen::prelude::*; use std::str::FromStr; @@ -219,7 +229,7 @@ use types::native; pub async fn init_thread_pool(url: web_sys::Url, num_threads: usize) -> Result<(), JsValue> { console_error_panic_hook::set_once(); - ThreadPool::builder().url(url).num_threads(num_threads).build_global().await?; + thread_pool::ThreadPool::builder().url(url).num_threads(num_threads).build_global().await?; Ok(()) } diff --git a/wasm/src/thread_pool/mod.rs b/wasm/src/thread_pool/mod.rs index c2c18c4b0..98d45db45 100644 --- a/wasm/src/thread_pool/mod.rs +++ b/wasm/src/thread_pool/mod.rs @@ -14,7 +14,7 @@ // You should have received a copy of the GNU General Public License // along with the Aleo SDK library. If not, see . -use futures::future::try_join_all; +use futures::{channel::oneshot, future::try_join_all}; use rayon::ThreadBuilder; use spmc::{channel, Receiver, Sender}; use std::future::Future; @@ -29,13 +29,17 @@ use wasm_bindgen_futures::JsFuture; }); worker.addEventListener("message", (event) => { - // When running in Node, this allows the process to exit - // even though the Worker is still running. - if (worker.unref) { - worker.unref(); - } - - resolve(worker); + // This is needed in Node to wait one extra tick, so that way + // the Worker can fully initialize before we return. + setTimeout(() => { + resolve(worker); + + // When running in Node, this allows the process to exit + // even though the Worker is still running. + if (worker.unref) { + worker.unref(); + } + }, 0); }, { capture: true, once: true, @@ -48,6 +52,16 @@ use wasm_bindgen_futures::JsFuture; }); }); } + + export function startTimer() { + // Starts a super-long timer in order to keep the Node + // process alive until we manually cancel it. + return setTimeout(() => {}, Math.pow(2, 31) - 1); + } + + export function stopTimer(timer) { + clearTimeout(timer); + } "###)] extern "C" { #[wasm_bindgen(js_name = spawnWorker)] @@ -57,6 +71,51 @@ extern "C" { memory: &JsValue, address: *const Receiver, ) -> js_sys::Promise; + + #[wasm_bindgen(js_name = startTimer)] + fn start_timer() -> f64; + + #[wasm_bindgen(js_name = stopTimer)] + fn stop_timer(timer: f64); +} + +/// Runs a function on the Rayon thread-pool. +/// +/// When the function returns, the Future will resolve +/// with the return value of the function. +/// +/// # NodeJS +/// +/// This will keep the NodeJS process alive until the +/// Future is resolved. +pub fn spawn(f: F) -> impl Future +where + A: Send + 'static, + F: FnOnce() -> A + Send + 'static, +{ + // This is necessary in order to stop the Node process + // from exiting while the spawned task is running. + struct Timer(f64); + + impl Drop for Timer { + fn drop(&mut self) { + stop_timer(self.0); + } + } + + let timer = Timer(start_timer()); + + let (sender, receiver) = oneshot::channel(); + + rayon::spawn(move || { + let _ = sender.send(f()); + }); + + async move { + let output = receiver.await.unwrap_throw(); + drop(timer); + output + } } async fn spawn_workers(url: web_sys::Url, num_threads: usize) -> Result, JsValue> {